In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from omegaconf import OmegaConf
from models.flow_module import FlowModule
import torch
from data.pdb_dataloader import PdbDataModule
import glob
import GPUtil
from data import utils as du
from scipy.spatial.transform import Rotation
import numpy as np
import tree
from data import so3_utils
from data import all_atom
from analysis import plotting
from analysis import utils as au
from openfold.utils.superimposition import superimpose
import matplotlib.pyplot as plt


In [None]:
# Setup lightning module

# Run directory; it is expected to contain the `*.ckpt` files and `config.yaml`.
# ckpt_dir = 'ckpt/se3-fm/warm_start_baseline/2023-09-04_15-13-01'
# ckpt_dir = 'ckpt/se3-fm/flower/2023-09-04_20-54-39/'
ckpt_dir = 'ckpt/se3-fm/warm_start/2023-08-30_17-36-07'

# Fail with a readable message (instead of an IndexError on `[-1]`) when the
# directory is missing or holds no checkpoints.
ckpt_candidates = sorted(glob.glob(os.path.join(ckpt_dir, '*.ckpt')))
if not ckpt_candidates:
    raise FileNotFoundError(f'No *.ckpt files found in {ckpt_dir!r}')
# Pick the lexicographically last checkpoint file name.
ckpt_path = ckpt_candidates[-1]
print(ckpt_path)

cfg_path = os.path.join(ckpt_dir, 'config.yaml')
ckpt_cfg = OmegaConf.load(cfg_path)
base_path = 'configs/base.yaml'
base_cfg = OmegaConf.load(base_path)

# Disable struct mode on both configs before merging; in the merge the
# checkpoint config is passed last.
OmegaConf.set_struct(base_cfg, False)
OmegaConf.set_struct(ckpt_cfg, False)
cfg = OmegaConf.merge(base_cfg, ckpt_cfg)
cfg.experiment.checkpointer.dirpath = './'

flow = FlowModule.load_from_checkpoint(
    checkpoint_path=ckpt_path,
    model_cfg=cfg.model,
    experiment_cfg=cfg.experiment
)
_ = flow.eval()

In [None]:
# Set up data module
data_module = PdbDataModule(cfg.data)
data_module.setup('fit')
# Single-process run: with one replica the only in-range rank is 0.
# NOTE(review): assumes `train_dataloader` follows the usual zero-based rank
# convention (as torch's DistributedSampler does) -- confirm against
# PdbDataModule.
train_dataloader = data_module.train_dataloader(
    num_replicas=1,
    rank=0
)
data_iter = iter(train_dataloader)


In [None]:
# Search for a reasonable batch
# Accept the first batch whose residue count lies strictly between the bounds.
min_res, max_res = 70, 100
for batch in data_iter:
    num_batch, num_res, _ = batch['trans_1'].shape
    if min_res < num_res < max_res:
        break
else:
    # The iterator ran dry without a match; say so explicitly instead of
    # surfacing a bare StopIteration.
    raise RuntimeError(
        f'No batch with {min_res} < num_res < {max_res} left in the dataloader'
    )

print(f'Found batch with {num_res} residues, {num_batch} batch size')

# Set up device and cuda
# cuda_id = GPUtil.getAvailable(order='memory', limit = 8)[0]
# device = f'cuda:{cuda_id}'
device = 'cpu'
print(f'Using device {device}')
flow.model = flow.model.to(device)
# Move every leaf of the batch structure to the chosen device.
batch = tree.map_structure(lambda x: x.to(device), batch)
num_batch = batch['res_mask'].shape[0]

In [None]:
# Inspect which entries the selected batch provides.
batch.keys()

In [None]:
# Display the 'res_idx' entry of the selected batch.
batch['res_idx']

# Utils

In [None]:
def atom37_from_trans_rot(trans, rots, res_mask):
    """Build atom37 coordinates from batched translations and rotations.

    Args:
        trans: Translation tensor. Its first two dimensions size the zero
            tensor handed to `all_atom.compute_backbone`, and its first
            dimension is treated as the batch dimension
            (presumably [batch, residues, 3] -- TODO confirm).
        rots: Rotations passed to `du.create_rigid` together with `trans`.
        res_mask: Mask indexed per batch element and forwarded to
            `du.adjust_oxygen_pos`.

    Returns:
        A CPU tensor stacking the oxygen-adjusted atom37 coordinates of every
        batch element.
    """
    rigids = du.create_rigid(rots, trans)
    # All-zero second argument of shape (trans.shape[0], trans.shape[1], 2).
    placeholder = torch.zeros(
        trans.shape[0],
        trans.shape[1],
        2,
        device=trans.device
    )
    atom37 = all_atom.compute_backbone(rigids, placeholder)[0]
    atom37 = atom37.detach().cpu()
    # Derive the batch size from the input rather than from a notebook global,
    # so the helper does not depend on which cells were run before it.
    return torch.stack([
        du.adjust_oxygen_pos(atom37[i], res_mask[i])
        for i in range(trans.shape[0])
    ])

def process_trans_rot_traj(trans_traj, rots_traj, res_mask):
    """Convert paired translation/rotation trajectories into one atom37 tensor.

    Every (translations, rotations) pair is converted with
    `atom37_from_trans_rot` using a detached CPU copy of `res_mask`. The
    per-step results are stacked and the first two axes are swapped, so the
    step axis ends up second.
    """
    mask_cpu = res_mask.detach().cpu()
    frames = []
    for step_trans, step_rots in zip(trans_traj, rots_traj):
        frames.append(atom37_from_trans_rot(step_trans, step_rots, mask_cpu))
    stacked = torch.stack(frames)
    return stacked.swapaxes(0, 1)

# Model prediction at different timesteps

In [None]:
# batch['trans_1'] also fixes the device and batch size used for the sweep.
gt_ca_pos = batch['trans_1']
device, num_batch = gt_ca_pos.device, gt_ca_pos.shape[0]
ts = np.linspace(1e-3, 1.0, 100)

# For each t: corrupt the batch at that t, run the model without gradients and
# keep its predicted translations as a NumPy array.
pred_ca_per_t = []
for i, t in enumerate(ts):
    print(f'On {i}')
    # Same t for every element of the batch.
    batch_t = torch.ones(num_batch, 1, 1, device=device) * t
    noisy_batch = flow._corrupt_batch(batch, t=batch_t)
    with torch.no_grad():
        model_out = flow.forward(noisy_batch)
    pred_trans_np = du.to_numpy(model_out['pred_trans'])
    pred_ca_per_t.append(pred_trans_np)
all_pred_ca = np.stack(pred_ca_per_t)

In [None]:
res_mask = batch['res_mask']
# Number of swept timesteps, taken from the stacked predictions instead of a
# hard-coded 100 so the cell stays correct if the sweep length changes.
num_ts = len(all_pred_ca)
# `gt_ca_pos` is already a tensor: detach it rather than re-wrapping it with
# torch.tensor(...), which copy-constructs and emits a UserWarning.
# `repeat` allocates the tiled copy.
aligned_sample_ca, aligned_rmsd = superimpose(
    gt_ca_pos.detach()[None].repeat(num_ts, 1, 1, 1),
    torch.tensor(all_pred_ca).to(gt_ca_pos.device),
    res_mask[None].repeat(num_ts, 1, 1)
)
torch.mean(aligned_rmsd)

In [None]:
# Average the aligned RMSDs over their last axis and convert to NumPy.
rmsd_per_t = aligned_rmsd.mean(dim=-1)
ts_rmsd = du.to_numpy(rmsd_per_t)

In [None]:
# Mean aligned RMSD of the model prediction as a function of t.
fig, ax = plt.subplots()
ax.plot(ts, ts_rmsd)
# ax.plot(ts, noisy_rmsds)  # `noisy_rmsds` is computed in the SNR section below
ax.set_xlabel('t')
ax.set_ylabel('mean aligned RMSD');

# Partial sampling

In [None]:
# Corrupt the batch once at the smallest t; this is the sampler's start state.
min_t = 1e-3
t = torch.full((num_batch, 1, 1), min_t, device=device)
noisy_batch = flow._corrupt_batch(batch, t=t)

# Keep NumPy copies of the first `batch_trunc` corrupted inputs and targets.
batch_trunc = 5
trans_in, gt = (
    du.to_numpy(noisy_batch[key][:batch_trunc])
    for key in ('trans_t', 'trans_1')
)


In [None]:
# Run sampling
# Integrate from t=min_t to t=1 over `num_timesteps` grid points. Translations
# take an explicit Euler step along (pred_trans - trans_t) / (1 - t); rotations
# are advanced with `so3_utils.geodesic_t`.
trans_traj = [noisy_batch['trans_t']]
rots_traj = [noisy_batch['rotmats_t']]
num_timesteps = 100
ts = np.linspace(min_t, 1.0, num_timesteps)
t_1 = ts[0]
model_outputs = []
trans_vf_traj = []  # kept for compatibility; never populated in this cell
# Work on a shallow copy so the per-step 'trans_t' / 'rotmats_t' / 't' entries
# do not leak into the shared `batch` that other cells read.
sample_batch = dict(batch)
for i, t_2 in enumerate(ts[1:]):
    # NOTE(review): with 100 grid points there are only 99 steps, so this
    # progress message never fires for the current settings.
    if (i+1) % 100 == 0:
        print(f"Step {i+1} / {len(ts)}")
    d_t = t_2 - t_1
    trans_t_1 = trans_traj[-1]
    rots_t_1 = rots_traj[-1]
    with torch.no_grad():
        sample_batch['trans_t'] = trans_t_1
        sample_batch['rotmats_t'] = rots_t_1
        sample_batch['t'] = torch.ones((num_batch, 1), device=device) * t_1
        model_out = flow.forward(sample_batch)
        # Keep a NumPy copy of every model output for later inspection.
        model_outputs.append(tree.map_structure(du.to_numpy, model_out))

    pred_trans_1 = model_out['pred_trans']
    pred_rots_1 = model_out['pred_rotmats']
    pred_rots_vf = model_out['pred_rots_vf']

    trans_vf = (pred_trans_1 - trans_t_1) / (1 - t_1)
    trans_t_2 = trans_t_1 + trans_vf * d_t
    rots_t_2 = so3_utils.geodesic_t(
        d_t / (1 - t_1), pred_rots_1, rots_t_1, rot_vf=pred_rots_vf)
    t_1 = t_2
    trans_traj.append(trans_t_2)
    rots_traj.append(rots_t_2)

res_mask = batch['res_mask']
atom37_traj = process_trans_rot_traj(trans_traj, rots_traj, res_mask)
# Last step of every trajectory, atom index 1 of the atom37 layout.
final_ca_pos = atom37_traj[:, -1, :, 1]
gt_ca_pos = batch['trans_1']

In [None]:
# Superimpose the final sampled positions with batch['trans_1'] (on the same
# device) and display the mean of the returned RMSDs.
final_ca_on_device = final_ca_pos.to(gt_ca_pos.device)
aligned_sample_ca, aligned_rmsd = superimpose(
    gt_ca_pos, final_ca_on_device, res_mask)
aligned_rmsd.mean()

In [None]:
# Save samples
# Writes the final frame of up to `max_save` sampled trajectories as PDB files.
# Full trajectories are written by the "Process and save samples" cell.
save_dir = 'notebook_samples/'
os.makedirs(save_dir, exist_ok=True)
# Convert only while `atom37_traj` is still a tensor: this cell rebinds the
# name to a NumPy array, so a re-run would otherwise pass an ndarray to
# `du.to_numpy`.
# NOTE(review): `du.to_numpy` presumably expects a tensor -- confirm.
if torch.is_tensor(atom37_traj):
    atom37_traj = du.to_numpy(atom37_traj)
num_batch, num_timesteps, num_res, _, _ = atom37_traj.shape
max_save = 5
for i, sample_traj in enumerate(atom37_traj[:max_save]):
    sample_path = au.write_prot_to_pdb(
        sample_traj[-1],
        os.path.join(
            save_dir,
            f'sample_{i}_len_{num_res}_ts_{num_timesteps}.pdb'),
        no_indexing=True
    )
    print(f'Done with sample {i}')

In [None]:
# Rebuild atom37 coordinates from the batch's 'trans_1' / 'rotmats_1' entries.
gt_trans_1, gt_rotmats_1 = batch['trans_1'], batch['rotmats_1']

# Detach and move every input to CPU before the conversion.
gt_inputs_cpu = [
    tensor.detach().cpu()
    for tensor in (gt_trans_1, gt_rotmats_1, res_mask)
]
gt_atom37 = du.to_numpy(atom37_from_trans_rot(*gt_inputs_cpu))

In [None]:
# Write up to `max_save` ground-truth structures next to the samples.
for i, gt_coords in enumerate(gt_atom37[:max_save]):
    gt_pdb_path = os.path.join(save_dir, f'gt_{i}_len_{num_res}.pdb')
    sample_path = au.write_prot_to_pdb(
        gt_coords, gt_pdb_path, no_indexing=True)
    print(f'Done with sample {i}')

## Calculate SNR

In [None]:
# For every t in `ts`, corrupt the batch several times and keep the stacked
# 'trans_t' draws, keyed by the index of t.
gt_ca_pos = batch['trans_1']
device, num_batch = gt_ca_pos.device, gt_ca_pos.shape[0]
num_noise_draws = 10
all_noise = {}
for i, t in enumerate(ts):
    print(f'On {i}')
    batch_t = torch.ones(num_batch, 1, 1, device=device) * t
    draws = []
    for _ in range(num_noise_draws):
        noisy_batch = flow._corrupt_batch(batch, t=batch_t)
        draws.append(noisy_batch['trans_t'])
    noisy_ca_pos = torch.stack(draws)
    all_noise[i] = noisy_ca_pos

In [None]:
def calc_rmsd(x, y):
    """Return the mean Euclidean distance between `x` and `y`.

    The difference `x - y` (broadcast as usual) is reduced with a vector norm
    over the last axis, and the resulting distances are averaged into a
    scalar tensor. Despite the name, no squaring/root of the mean is applied:
    this is a mean of per-point distances.
    """
    diff = x - y
    per_point_dist = torch.linalg.norm(diff, dim=-1)
    return per_point_dist.mean()

In [None]:
# One `calc_rmsd` value per t index: the stored noisy draws against
# `gt_ca_pos`, which gets a leading axis so it broadcasts over the draws.
noisy_rmsds = du.to_numpy(torch.stack([
    calc_rmsd(all_noise[i], gt_ca_pos[None])
    for i in range(len(all_noise))
]))

In [None]:
# Mean distance between the corrupted and ground-truth translations vs. t.
fig, ax = plt.subplots()
ax.plot(ts, noisy_rmsds)
ax.set_xlabel('t')
ax.set_ylabel('mean distance: noisy vs. ground truth');

## Process and save samples

In [None]:
# Write, for up to `max_save` samples, both the full trajectory and its final
# frame as PDB files.
save_dir = 'notebook_samples/'
os.makedirs(save_dir, exist_ok=True)
# `atom37_traj` may already have been converted to a NumPy array by the
# earlier "Save samples" cell (or by a previous run of this one); convert only
# while it is still a tensor.
# NOTE(review): `du.to_numpy` presumably expects a tensor -- confirm.
if torch.is_tensor(atom37_traj):
    atom37_traj = du.to_numpy(atom37_traj)
num_batch, num_timesteps, num_res, _, _ = atom37_traj.shape
max_save = 5
for i, sample_traj in enumerate(atom37_traj[:max_save]):
    traj_path = au.write_prot_to_pdb(
        sample_traj,
        os.path.join(
            save_dir,
            f'traj_{i}_len_{num_res}_ts_{num_timesteps}.pdb'),
        no_indexing=True
    )
    sample_path = au.write_prot_to_pdb(
        sample_traj[-1],
        os.path.join(
            save_dir,
            f'sample_{i}_len_{num_res}_ts_{num_timesteps}.pdb'),
        no_indexing=True
    )
    print(f'Done with sample {i}')

## Visualize structures

In [None]:
# Plot the predicted translations of one sample from the last recorded model
# output.
idx = 0
model_out_ca = model_outputs[-1]['pred_trans'][idx]
scatter_style = dict(mode='lines+markers', marker_size=3, opacity=1.0)
sample_bb_3d = plotting.create_scatter(model_out_ca, **scatter_style)
plotting.plot_traces([sample_bb_3d])

In [None]:
# Overlay one sample's prediction (at recorded step `t_idx`) with its ground
# truth; the noisy-input trace is built but left out of the figure.
idx = 0
t_idx = -1
model_out_ca = model_outputs[t_idx]['pred_trans'][idx]
model_in_ca = trans_in[idx]
gt_ca = gt[idx]

scatter_style = dict(mode='lines+markers', marker_size=3, opacity=1.0)
out_bb_3d = plotting.create_scatter(model_out_ca, name='pred', **scatter_style)
# NOTE(review): `trans_in` comes from the batch corrupted at `min_t`, so the
# `ts[t_idx]` value in this label may not describe it -- confirm before
# re-enabling the trace.
in_bb_3d = plotting.create_scatter(
    model_in_ca, name=f'input t={ts[t_idx]:.2f}', **scatter_style)
gt_bb_3d = plotting.create_scatter(gt_ca, name='gt', **scatter_style)
plotting.plot_traces([
    out_bb_3d,
    # in_bb_3d,
    gt_bb_3d
])
