In [1]:
%load_ext autoreload
%autoreload 2


In [24]:
import os
from omegaconf import OmegaConf
from models.flow_module import FlowModule
import torch
from data.pdb_dataloader import PdbDataModule
import GPUtil
from data import utils as du
import numpy as np
import tree
from data import so3_utils
from data import all_atom
from analysis import utils as au
from openfold.utils.superimposition import superimpose
import matplotlib.pyplot as plt
import copy
import pandas as pd
import math


In [3]:
# Setup lightning module

cfg_path = '../ckpt/se3-fm/scope_256/2023-10-19_10-54-29/config.yaml'
base_path = '../configs/base.yaml'
base_cfg = OmegaConf.load(base_path)
ckpt_cfg = OmegaConf.load(cfg_path)

OmegaConf.set_struct(base_cfg, False)
OmegaConf.set_struct(ckpt_cfg, False)
cfg = OmegaConf.merge(base_cfg, ckpt_cfg)
cfg.experiment.checkpointer.dirpath = './'
cfg.experiment.rescale_time = False
cfg.data.dataset.max_num_res = 256
cfg.data.dataset.min_num_res = 200
cfg.data.dataset.csv_path = '../swiss_prot/swiss_prot_pkls/metadata.csv'
# cfg.data.dataset.csv_path = '../preprocessed/metadata.csv'
cfg.data.loader.num_workers = 0
device = f'cpu'

In [4]:
# Set up data module
data_module = PdbDataModule(cfg.data)
data_module.setup('fit')

In [5]:
train_dataloader = data_module.train_dataloader(
    num_replicas=2,
    rank=0
)
data_iter = iter(train_dataloader)

In [6]:
batch = next(data_iter)

In [7]:
batch.keys()

dict_keys(['res_plddt', 'aatype', 'res_idx', 'rotmats_1', 'trans_1', 'res_mask', 'csv_idx'])

In [None]:
rng = np.random.default_rng()

In [17]:
trans_1 = batch['trans_1'][0]

In [48]:
dist2d = torch.linalg.norm(trans_1[:, None, :] - trans_1[None, :, :], dim=-1)
diff_mask = torch.zeros_like(trans_1)
crop_seed = rng.integers(dist2d.shape[0])
seed_dists = dist2d[crop_seed]
max_scaffold_size = math.floor(seed_dists.shape[0] * 0.9)
min_scaffold_size = 10
scaffold_size = rng.integers(
    low=min_scaffold_size,
    high=max_scaffold_size
)
dist_cutoff = np.sort(seed_dists)[scaffold_size]
diff_mask = (seed_dists > dist_cutoff).int()

In [49]:
scaffold_size

54

In [50]:
diff_mask

tensor([1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0,
        1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32)

In [51]:
dist_cutoff

17.060059

In [None]:
        dist2d = np.linalg.norm(bb_pos[:, None, :] - bb_pos[None, :, :], axis=-1)

        # Randomly select residue then sample a distance cutoff
        # TODO: Use a more robust diffuse mask sampling method.
        diff_mask = np.zeros_like(bb_pos)
        attempts = 0
        while np.sum(diff_mask) < 1:
            crop_seed = rng.integers(dist2d.shape[0])
            seed_dists = dist2d[crop_seed]
            max_scaffold_size = min(
                self._data_conf.scaffold_size_max,
                seed_dists.shape[0] - self._data_conf.motif_size_min
            )
            scaffold_size = rng.integers(
                low=self._data_conf.scaffold_size_min,
                high=max_scaffold_size
            )
            dist_cutoff = np.sort(seed_dists)[scaffold_size]
            diff_mask = (seed_dists < dist_cutoff).astype(float)
            attempts += 1
            if attempts > 100:
                raise ValueError(
                    f'Unable to generate diffusion mask for {row}')

In [24]:
all_batch_size = []
for _ in range(10):
    batch = next(data_iter)
    all_batch_size.append(batch['res_mask'].shape)

In [29]:
swiss_prot_csv = pd.read_csv('../swiss_prot/swiss_prot_pkls/metadata.csv')
scope_csv = pd.read_csv('../preprocessed/metadata.csv')

In [None]:
swiss_prot_csv.head()

In [None]:
swiss_prot_csv.seq_len

In [82]:
swiss_prot_row = swiss_prot_csv[swiss_prot_csv['seq_len'] == 499]
scope_row = scope_csv[scope_csv['seq_len'] == 499]

In [83]:
swiss_prot_path = swiss_prot_row.iloc[0].processed_path
scope_path = scope_row.iloc[0].processed_path

In [86]:
%%timeit
for _ in range(10):
    processed_feats = du.read_pkl(swiss_prot_path)

9.85 ms ± 50.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [87]:
%%timeit
for _ in range(10):
    processed_feats = du.read_pkl(scope_path)

9.68 ms ± 98.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Test that each rank gets own set of indices

In [24]:
def indices(rank):
    train_dataloader = data_module.train_dataloader(
        num_replicas=2,
        rank=rank
    )
    data_iter = iter(train_dataloader)
    all_csv_idx = []
    for batch in data_iter:
        all_csv_idx.append(batch['csv_idx'])
    all_csv_idx = torch.concatenate(all_csv_idx).squeeze()
    return sorted(all_csv_idx.unique().tolist())

In [25]:
rank_0_indices = indices(0)
rank_1_indices = indices(1)

In [28]:
intersection_indices = set(rank_1_indices) & set(rank_0_indices)

In [34]:
train_dataloader = data_module.train_dataloader(
    num_replicas=2,
    rank=0
)
raw_csv = train_dataloader.dataset.csv

In [38]:
all_indices = raw_csv.index.tolist()

In [39]:
len(all_indices)

3938

In [40]:
len(set(rank_0_indices) | set(rank_1_indices))

3938

# Test that each epoch changes the batch order

In [5]:
train_dataloader = data_module.train_dataloader(
    num_replicas=2,
    rank=0,
)
data_iter_1 = iter(train_dataloader)
# print(train_dataloader.batch_sampler.epoch)
data_iter_2 = iter(train_dataloader)
# print(train_dataloader.batch_sampler.epoch)

In [9]:
batch_1 = next(data_iter_1)
idx_1 = batch_1['csv_idx']
res_1 = batch_1['res_mask'].shape[1]

batch_2 = next(data_iter_2)
idx_2 = batch_2['csv_idx']
res_2 = batch_2['res_mask'].shape[1]
while res_2 != res_1:
    batch_2 = next(data_iter_2)
    idx_2 = batch_2['csv_idx']
    res_2 = batch_2['res_mask'].shape[1]    

In [18]:
len(set(idx_1.squeeze().numpy()) & set(idx_2.squeeze().numpy()))

17

In [17]:
len(idx_1)

27

# Test BatchOT

In [28]:
train_dataloader = data_module.train_dataloader(
    num_replicas=2,
    rank=0
)
data_iter = iter(train_dataloader)

In [29]:
batch = next(data_iter)

In [38]:
from scipy.optimize import linear_sum_assignment

In [40]:
trans_1 = batch['trans_1']
res_mask = batch['res_mask']
trans_0 = torch.randn(trans_1.shape) * 10.0
num_batch, num_res = res_mask.shape

noise_idx, gt_idx = torch.where(
    torch.ones(num_batch, num_batch))
batch_nm_0 = trans_0[noise_idx]
batch_nm_1 = trans_1[gt_idx]
batch_mask = res_mask[gt_idx]
aligned_nm_0, aligned_nm_1, _ = du.batch_align_structures(
    batch_nm_0, batch_nm_1, mask=batch_mask
) 

aligned_nm_0 = aligned_nm_0.reshape(num_batch, num_batch, num_res, 3)
aligned_nm_1 = aligned_nm_1.reshape(num_batch, num_batch, num_res, 3)

# Compute cost matrix of aligned noise to ground truth
batch_mask = batch_mask.reshape(num_batch, num_batch, num_res)
cost_matrix = torch.sum(
    torch.linalg.norm(aligned_nm_0 - aligned_nm_1, dim=-1), dim=-1
) / torch.sum(batch_mask, dim=-1)
noise_perm, gt_perm = linear_sum_assignment(du.to_numpy(cost_matrix))
aligned_trans_0 = aligned_nm_0[(tuple(gt_perm), tuple(noise_perm))]

In [42]:
trans_dist = torch.mean(torch.linalg.norm(trans_0 - trans_1, dim=-1), -1)

In [43]:
aligned_dist = torch.mean(torch.linalg.norm(aligned_trans_0 - trans_1, dim=-1), -1)

In [46]:
torch.mean(trans_dist)

tensor(19.7388)

In [47]:
torch.mean(aligned_dist)

tensor(18.3862)

In [None]:
num_batch, num_res = trans_0.shape[:2]
noise_idx, gt_idx = torch.where(
    torch.ones(num_batch, num_batch))
batch_nm_0 = trans_0[noise_idx]
batch_nm_1 = trans_1[gt_idx]
batch_mask = res_mask[gt_idx]
aligned_nm_0, aligned_nm_1, _ = du.batch_align_structures(
    batch_nm_0, batch_nm_1, mask=batch_mask
) 
aligned_nm_0 = aligned_nm_0.reshape(num_batch, num_batch, num_res, 3)
aligned_nm_1 = aligned_nm_1.reshape(num_batch, num_batch, num_res, 3)

# Compute cost matrix of aligned noise to ground truth
batch_mask = batch_mask.reshape(num_batch, num_batch, num_res)
cost_matrix = torch.sum(
    torch.linalg.norm(aligned_nm_0 - aligned_nm_1, dim=-1), dim=-1
) / torch.sum(batch_mask, dim=-1)
noise_perm, gt_perm = linear_sum_assignment(du.to_numpy(cost_matrix))
return aligned_nm_0[(tuple(gt_perm), tuple(noise_perm))]

In [None]:
    def _batch_ot(self, trans_0, trans_1, res_mask):
        num_batch, num_res = trans_0.shape[:2]
        noise_idx, gt_idx = torch.where(
            torch.ones(num_batch, num_batch))
        batch_nm_0 = trans_0[noise_idx]
        batch_nm_1 = trans_1[gt_idx]
        batch_mask = res_mask[gt_idx]
        aligned_nm_0, aligned_nm_1, _ = du.batch_align_structures(
            batch_nm_0, batch_nm_1, mask=batch_mask
        ) 
        aligned_nm_0 = aligned_nm_0.reshape(num_batch, num_batch, num_res, 3)
        aligned_nm_1 = aligned_nm_1.reshape(num_batch, num_batch, num_res, 3)
        
        # Compute cost matrix of aligned noise to ground truth
        batch_mask = batch_mask.reshape(num_batch, num_batch, num_res)
        cost_matrix = torch.sum(
            torch.linalg.norm(aligned_nm_0 - aligned_nm_1, dim=-1), dim=-1
        ) / torch.sum(batch_mask, dim=-1)
        noise_perm, gt_perm = linear_sum_assignment(du.to_numpy(cost_matrix))
        return aligned_nm_0[(tuple(gt_perm), tuple(noise_perm))]