In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
import os
from omegaconf import OmegaConf
from models.flow_module import FlowModule
import torch
from data.pdb_dataloader import PdbDataModule
import GPUtil
from data import utils as du
import numpy as np
import tree
from data import so3_utils
from data import all_atom
from analysis import utils as au
from openfold.utils.superimposition import superimpose
import matplotlib.pyplot as plt
import copy


In [3]:
# Output directory for generated samples and the default number of
# integration steps (later cells override num_timesteps per experiment).
save_dir = 'notebook_samples/'
num_timesteps = 100
os.makedirs(save_dir, exist_ok=True)

# Ask GPUtil for up to 8 available GPUs ordered by memory usage and take the
# first one. GPUtil returns an empty list when no GPU passes its availability
# thresholds, so fall back to the CPU instead of failing with an IndexError.
available_gpus = GPUtil.getAvailable(order='memory', limit=8)
if available_gpus:
    cuda_id = available_gpus[0]
    device = f'cuda:{cuda_id}'
else:
    cuda_id = None
    device = 'cpu'
    print('No available GPU found; falling back to CPU.')

In [4]:
# Setup lightning module: locate the checkpoint, rebuild its config and load
# the model in eval mode.

# Checkpoint to analyse (alternative run kept for quick switching).
# ckpt_dir = '../ckpt/se3-fm/baseline/2023-10-06_01-14-03/'
ckpt_dir = '../ckpt/se3-fm/linear_no_embed_t/2023-10-08_08-51-30/'
ckpt_path = os.path.join(ckpt_dir, 'last.ckpt')
print(ckpt_path)

# Repository defaults and the config that was saved next to the checkpoint.
base_path = '../configs/base.yaml'
cfg_path = os.path.join(ckpt_dir, 'config.yaml')
base_cfg = OmegaConf.load(base_path)
ckpt_cfg = OmegaConf.load(cfg_path)

# Turn struct mode off on both configs before merging; checkpoint values are
# merged on top of the base values.
OmegaConf.set_struct(base_cfg, False)
OmegaConf.set_struct(ckpt_cfg, False)
cfg = OmegaConf.merge(base_cfg, ckpt_cfg)

# Notebook-specific overrides applied to the merged config.
cfg.experiment.checkpointer.dirpath = './'
cfg.experiment.rescale_time = False
cfg.data.dataset.csv_path = '../preprocessed/metadata.csv'

# Restore the module onto the selected device and switch it to eval mode.
flow = FlowModule.load_from_checkpoint(
    checkpoint_path=ckpt_path,
    model_cfg=cfg.model,
    experiment_cfg=cfg.experiment,
    map_location=device,
)
_ = flow.eval()


../ckpt/se3-fm/linear_no_embed_t/2023-10-08_08-51-30/last.ckpt


Computing igso3_expansion: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:09<00:00, 100.38it/s]


In [5]:
# Set up data module
data_module = PdbDataModule(cfg.data)
data_module.setup('fit')

# The notebook runs as a single process: one replica, whose rank is 0.
# Ranks are 0-indexed, so rank=1 is out of range when num_replicas=1.
# NOTE(review): the sampler behind train_dataloader is not visible here;
# confirm it follows the usual rank < num_replicas convention.
train_dataloader = data_module.train_dataloader(
    num_replicas=1,
    rank=0,
)
data_iter = iter(train_dataloader)


In [6]:
# Search for a reasonable batch: take the first batch whose proteins have
# more than `min_num_res` residues.
min_num_res = 120
for batch in data_iter:
    num_batch, num_res, _ = batch['trans_1'].shape
    if num_res > min_num_res:
        break
else:
    # The loader ran out before a suitable batch appeared; fail with a clear
    # message instead of a bare StopIteration.
    raise RuntimeError(
        f'No batch with more than {min_num_res} residues found in the dataloader.'
    )
print(f'Found batch with {num_res} residues, {num_batch} batch size')

# Set up device and cuda
# device = 'cpu'
print(f'Using device {device}')
flow.model = flow.model.to(device)
# Move every entry of the (possibly nested) batch structure onto the device.
batch = tree.map_structure(lambda x: x.to(device), batch)
num_batch = batch['res_mask'].shape[0]

Found batch with 127 residues, 44 batch size
Using device cuda:0


# Full sampling

In [7]:
# Sampling constants.
min_t = 1e-3
seed = 42
do_sde = False

# Corrupt the batch at time t = min_t; the result is the shared starting
# point that each sampler below deep-copies. The seed is set right before the
# corruption call.
res_mask = batch['res_mask']
t = torch.full((num_batch, 1, 1), min_t, device=device)
torch.manual_seed(seed)
orig_noisy_batch = flow._corrupt_batch(batch, t=t)
orig_noisy_batch['t'] = t

# CPU copies of the ground-truth frames, their rotation-vector form and the
# corresponding atom37 coordinates (as numpy) for later comparison/saving.
gt_trans_1 = batch['trans_1'].detach().cpu()
gt_rotmats_1 = batch['rotmats_1'].detach().cpu()
gt_rotvec = so3_utils.rotmat_to_rotvec(gt_rotmats_1)
gt_atom37 = du.to_numpy(
    all_atom.atom37_from_trans_rot(
        gt_trans_1,
        gt_rotmats_1,
        res_mask.detach().cpu(),
    )
)

In [21]:
# RK4 run configuration.
num_timesteps = 10
# Used in the RK4 save path; so3_vf_t defines its own rotation scale.
so3_scale = 10
# Integration times, evenly spaced between min_t and 1 - min_t.
ts = np.linspace(min_t, 1.0 - min_t, num=num_timesteps)
# Work on a deep copy: run_model writes into the batch it is given, and the
# original corrupted batch is reused by the other sampler.
noisy_batch = copy.deepcopy(orig_noisy_batch)

In [25]:
def run_model(trans_t, rots_t, noisy_batch, t_1):
    """Evaluate the flow model at the given state and time.

    The state is written into ``noisy_batch`` in place (keys ``'trans_t'``,
    ``'rotmats_t'`` and ``'t'``) before ``flow.forward`` is called with
    gradient tracking disabled. Relies on the notebook globals ``flow``,
    ``num_batch`` and ``device``.

    Args:
        trans_t: Value stored under ``noisy_batch['trans_t']``.
        rots_t: Value stored under ``noisy_batch['rotmats_t']``.
        noisy_batch: Mapping passed to the model; mutated by this call.
        t_1: Scalar time, expanded to a ``(num_batch, 1)`` tensor on ``device``.

    Returns:
        Whatever ``flow.forward(noisy_batch)`` returns.
    """
    time_column = torch.ones((num_batch, 1), device=device) * t_1
    noisy_batch['trans_t'] = trans_t
    noisy_batch['rotmats_t'] = rots_t
    noisy_batch['t'] = time_column
    with torch.no_grad():
        return flow.forward(noisy_batch)

def so3_vf_t(t, mats_1, mats_t, scale=10.0):
    """Scaled rotation vector field from the current towards the target rotations.

    Args:
        t: Time of the evaluation. Currently unused: the field uses a constant
            scale rather than a time-dependent ``1 / (1 - t)`` schedule. Kept
            so the signature mirrors ``r3_vf_t``.
        mats_1: Target rotation matrices (the model's predicted rotations at
            the call sites in this notebook).
        mats_t: Current rotation matrices.
        scale: Constant multiplier applied to the field. Defaults to 10.0,
            the value that used to be hard-coded.

    Returns:
        ``scale * so3_utils.calc_rot_vf(mats_t, mats_1)``.
    """
    return scale * so3_utils.calc_rot_vf(mats_t, mats_1)
def r3_vf_t(t, trans, trans_t):
    """Translation vector field at time ``t``.

    Args:
        t: Current time; the field is divided by ``1 - t``.
        trans: Target translations.
        trans_t: Current translations.

    Returns:
        ``(trans - trans_t) / (1 - t)``.
    """
    return (trans - trans_t) / (1 - t)

In [26]:
# Run sampling with RK4 (3/8 rule): stages are evaluated at t, t + h/3,
# t + 2h/3 and t + h and combined as h/8 * (k1 + 3*k2 + 3*k3 + k4).
# Translations are updated additively; rotations are updated by
# right-multiplying the current rotation with the rotation built from the
# scaled rotation-vector increment.
def _compose_rot_update(rots, rotvec):
    """Right-multiply `rots` by so3_utils.rotvec_to_rotmat(rotvec)."""
    return torch.einsum(
        "...ij,...jk->...ik",
        rots,
        so3_utils.rotvec_to_rotmat(rotvec)
    )


t_1 = ts[0]
all_pred_transrot = []
trans_vf_traj = []
all_pred_rots_vf = []
all_gt_rots_vf = []
trans_rot_traj = [(noisy_batch['trans_t'], noisy_batch['rotmats_t'])]
for i, t_2 in enumerate(ts[1:]):
    if (i+1) % 100 == 0:
        print(f"Step {i+1} / {len(ts)}")
    d_t = t_2 - t_1
    trans_t_1, rots_t_1 = trans_rot_traj[-1]

    # Stage 1: field at the current state.
    k1 = run_model(trans_t_1, rots_t_1, noisy_batch, t_1)
    trans_k1 = r3_vf_t(t_1, k1['pred_trans'], trans_t_1)
    rots_k1 = so3_vf_t(t_1, k1['pred_rots'].get_rot_mats(), rots_t_1)

    # Stage 2: state advanced by h/3 along k1.
    t_k2 = t_1 + d_t / 3
    trans_input_k2 = trans_t_1 + d_t * trans_k1 / 3
    rots_input_k2 = _compose_rot_update(rots_t_1, d_t * rots_k1 / 3)
    k2 = run_model(trans_input_k2, rots_input_k2, noisy_batch, t_k2)
    trans_k2 = r3_vf_t(t_k2, k2['pred_trans'], trans_input_k2)
    rots_k2 = so3_vf_t(t_k2, k2['pred_rots'].get_rot_mats(), rots_input_k2)

    # Stage 3: state advanced by h * (k2 - k1/3). The rotation input now uses
    # the same coefficients as the translation input (it previously used
    # h * k2 / 3, which does not correspond to time t + 2h/3).
    t_k3 = t_1 + d_t * 2 / 3
    trans_input_k3 = trans_t_1 + d_t * (trans_k2 - trans_k1/3)
    rots_input_k3 = _compose_rot_update(rots_t_1, d_t * (rots_k2 - rots_k1 / 3))
    k3 = run_model(trans_input_k3, rots_input_k3, noisy_batch, t_k3)
    trans_k3 = r3_vf_t(t_k3, k3['pred_trans'], trans_input_k3)
    rots_k3 = so3_vf_t(t_k3, k3['pred_rots'].get_rot_mats(), rots_input_k3)

    # Stage 4: state advanced by h * (k1 - k2 + k3). As above, the rotation
    # input mirrors the translation input (it previously used h * k3 / 3).
    t_k4 = t_1 + d_t
    trans_input_k4 = trans_t_1 + d_t * (trans_k1 - trans_k2 + trans_k3)
    rots_input_k4 = _compose_rot_update(rots_t_1, d_t * (rots_k1 - rots_k2 + rots_k3))
    k4 = run_model(trans_input_k4, rots_input_k4, noisy_batch, t_k4)
    trans_k4 = r3_vf_t(t_k4, k4['pred_trans'], trans_input_k4)
    rots_k4 = so3_vf_t(t_k4, k4['pred_rots'].get_rot_mats(), rots_input_k4)

    # Diagnostics are recorded for the stage-1 evaluation, i.e. at the current
    # state (trans_t_1, rots_t_1).
    pred_trans_1 = k1['pred_trans']
    pred_rots_1 = k1['pred_rots'].get_rot_mats()
    pred_rots_vf = k1['pred_rots_vf']
    all_pred_rots_vf.append(du.to_numpy(pred_rots_vf))
    # Use rots_t_1 explicitly: run_model has overwritten
    # noisy_batch['rotmats_t'] with the stage-4 input by this point.
    gt_rot_vf = so3_utils.calc_rot_vf(
        rots_t_1.type(torch.float32),
        noisy_batch['rotmats_1'].type(torch.float32)
    )
    all_gt_rots_vf.append(du.to_numpy(gt_rot_vf))

    all_pred_transrot.append((pred_trans_1, pred_rots_1))

    # Combine the four stages with the 3/8-rule weights.
    trans_t_2 = trans_t_1 + (trans_k1 + 3 * (trans_k2 + trans_k3) + trans_k4) * d_t * 0.125
    rots_t_2 = _compose_rot_update(
        rots_t_1,
        (rots_k1 + 3 * (rots_k2 + rots_k3) + rots_k4) * d_t * 0.125
    )

    t_1 = t_2
    trans_rot_traj.append((trans_t_2, rots_t_2))

# Convert both the model predictions and the integrated trajectory to atom37
# coordinates, and keep CPU copies of the raw translations/rotations.
pred_atom37_traj = all_atom.transrot_to_atom37(
    all_pred_transrot,
    res_mask.detach().cpu()
)
atom37_traj = all_atom.transrot_to_atom37(
    trans_rot_traj,
    res_mask.detach().cpu()
)
pred_trans_traj = torch.stack([x[0] for x in all_pred_transrot]).detach().cpu()
pred_rots_traj = torch.stack([x[1] for x in all_pred_transrot]).detach().cpu()
pred_rotvec_traj = so3_utils.rotmat_to_rotvec(pred_rots_traj)
trans_traj = torch.stack([x[0] for x in trans_rot_traj]).detach().cpu()
rots_traj = torch.stack([x[1] for x in trans_rot_traj]).detach().cpu()
rotvec_traj = so3_utils.rotmat_to_rotvec(rots_traj)
output_pred_atom37_traj = du.to_numpy(torch.stack(pred_atom37_traj))
output_atom37_traj = du.to_numpy(torch.stack(atom37_traj))

In [27]:
# Save samples from the RK4 run as PDB files.
save_dir = f'notebook_samples/rk4/ts_{num_timesteps}/scale_so3_{so3_scale}'
os.makedirs(save_dir, exist_ok=True)
# From here on the size variables are taken from the trajectory array itself
# (unpacked as a 5-D shape).
num_timesteps, num_batch, num_res, _, _ = output_atom37_traj.shape

max_save = 2
for i in range(num_batch):
    if i >= max_save:
        break

    # Output locations for sample i. Only the integrated-trajectory file uses
    # the '_t_' infix; the others use '_ts_'.
    traj_file = os.path.join(
        save_dir, f'traj_{i}_len_{num_res}_t_{num_timesteps}.pdb')
    sample_file = os.path.join(
        save_dir, f'sample_{i}_len_{num_res}_ts_{num_timesteps}.pdb')
    final_pred_file = os.path.join(
        save_dir, f'final_model_out_{i}_len_{num_res}_ts_{num_timesteps}.pdb')
    pred_traj_file = os.path.join(
        save_dir, f'pred_traj_{i}_len_{num_res}_ts_{num_timesteps}.pdb')
    gt_file = os.path.join(
        save_dir, f'gt_{i}_len_{num_res}_ts_{num_timesteps}.pdb')

    # Integrated trajectory and its final frame.
    traj_path = au.write_prot_to_pdb(
        output_atom37_traj[:, i], traj_file, no_indexing=True)
    sample_path = au.write_prot_to_pdb(
        output_atom37_traj[-1, i], sample_file, no_indexing=True)

    # Model predictions: last prediction and the full prediction trajectory.
    _ = au.write_prot_to_pdb(
        output_pred_atom37_traj[-1, i], final_pred_file, no_indexing=True)
    _ = au.write_prot_to_pdb(
        output_pred_atom37_traj[:, i], pred_traj_file, no_indexing=True)

    # Ground-truth structure for the same batch element.
    _ = au.write_prot_to_pdb(gt_atom37[i], gt_file, no_indexing=True)
    print(f'Done with sample {i}')

# Final frame of every batch element in a single file.
final_samples = output_atom37_traj[-1]
_ = au.write_prot_to_pdb(
    final_samples,
    os.path.join(save_dir, 'all_samples.pdb'),
    no_indexing=True
)


Done with sample 0
Done with sample 1


### Euler

In [8]:
# Euler run configuration.
num_timesteps = 100
# Multipliers applied to the rotation and translation steps of the Euler update.
so3_scale = 100.0
r3_scale = 1.0
# Integration times, evenly spaced between min_t and 1 - min_t.
ts = np.linspace(min_t, 1.0 - min_t, num=num_timesteps)
# Start from a fresh deep copy of the corrupted batch; the sampling loop
# writes into it.
noisy_batch = copy.deepcopy(orig_noisy_batch)

In [9]:
# Run sampling with Euler: one model evaluation per step, followed by a
# first-order update of translations and rotations.
all_pred_transrot, trans_vf_traj = [], []
all_pred_rots_vf, all_gt_rots_vf = [], []
trans_rot_traj = [(noisy_batch['trans_t'], noisy_batch['rotmats_t'])]
t_1 = ts[0]
for i, t_2 in enumerate(ts[1:]):
    if (i+1) % 100 == 0:
        print(f"Step {i+1} / {len(ts)}")
    trans_t_1, rots_t_1 = trans_rot_traj[-1]
    d_t = t_2 - t_1

    # Feed the current state and time to the model.
    noisy_batch['trans_t'] = trans_t_1
    noisy_batch['rotmats_t'] = rots_t_1
    noisy_batch['t'] = torch.ones((num_batch, 1), device=device) * t_1
    with torch.no_grad():
        model_out = flow.forward(noisy_batch)

    pred_trans_1 = model_out['pred_trans']
    pred_rots_1 = model_out['pred_rots'].get_rot_mats()
    pred_rots_vf = model_out['pred_rots_vf']
    all_pred_transrot.append((pred_trans_1, pred_rots_1))
    all_pred_rots_vf.append(du.to_numpy(pred_rots_vf))

    # Reference rotation field from the current rotations to the ground truth.
    gt_rot_vf = so3_utils.calc_rot_vf(
        noisy_batch['rotmats_t'].float(),
        noisy_batch['rotmats_1'].float(),
    )
    all_gt_rots_vf.append(du.to_numpy(gt_rot_vf))

    if i != num_timesteps - 2:
        # Regular step: move the translations along (pred - current) / (1 - t)
        # and the rotations via geodesic_t with the predicted rotation field.
        trans_t_2 = trans_t_1 + r3_scale * d_t * (pred_trans_1 - trans_t_1) / (1 - t_1)
        rots_t_2 = so3_utils.geodesic_t(
            so3_scale * d_t, None, rots_t_1, rot_vf=pred_rots_vf)
    else:
        # Last step: take the model's prediction as the final state.
        trans_t_2 = model_out['pred_trans']
        rots_t_2 = model_out['pred_rots'].get_rot_mats()

    t_1 = t_2
    trans_rot_traj.append((trans_t_2, rots_t_2))

# Atom37 coordinates for the model predictions and the integrated trajectory.
pred_atom37_traj = all_atom.transrot_to_atom37(
    all_pred_transrot, res_mask.detach().cpu())
atom37_traj = all_atom.transrot_to_atom37(
    trans_rot_traj, res_mask.detach().cpu())

# CPU copies of translations / rotations, plus rotation-vector forms.
pred_trans_traj = torch.stack([pair[0] for pair in all_pred_transrot]).detach().cpu()
pred_rots_traj = torch.stack([pair[1] for pair in all_pred_transrot]).detach().cpu()
pred_rotvec_traj = so3_utils.rotmat_to_rotvec(pred_rots_traj)
trans_traj = torch.stack([pair[0] for pair in trans_rot_traj]).detach().cpu()
rots_traj = torch.stack([pair[1] for pair in trans_rot_traj]).detach().cpu()
rotvec_traj = so3_utils.rotmat_to_rotvec(rots_traj)

# Numpy arrays used by the save cell.
output_pred_atom37_traj = du.to_numpy(torch.stack(pred_atom37_traj))
output_atom37_traj = du.to_numpy(torch.stack(atom37_traj))

  return torch._transformer_encoder_layer_fwd(


In [10]:
# Save samples from the Euler run as PDB files.
save_dir = f'notebook_samples/euler/scale_so3_{so3_scale}/scale_r3_{r3_scale}'
os.makedirs(save_dir, exist_ok=True)
# From here on the size variables are taken from the trajectory array itself
# (unpacked as a 5-D shape).
num_timesteps, num_batch, num_res, _, _ = output_atom37_traj.shape

max_save = 2
for i in range(num_batch):
    if i >= max_save:
        break

    # Output locations for sample i. Only the integrated-trajectory file uses
    # the '_t_' infix; the others use '_ts_'.
    traj_file = os.path.join(
        save_dir, f'traj_{i}_len_{num_res}_t_{num_timesteps}.pdb')
    sample_file = os.path.join(
        save_dir, f'sample_{i}_len_{num_res}_ts_{num_timesteps}.pdb')
    final_pred_file = os.path.join(
        save_dir, f'final_model_out_{i}_len_{num_res}_ts_{num_timesteps}.pdb')
    pred_traj_file = os.path.join(
        save_dir, f'pred_traj_{i}_len_{num_res}_ts_{num_timesteps}.pdb')
    gt_file = os.path.join(
        save_dir, f'gt_{i}_len_{num_res}_ts_{num_timesteps}.pdb')

    # Integrated trajectory and its final frame.
    traj_path = au.write_prot_to_pdb(
        output_atom37_traj[:, i], traj_file, no_indexing=True)
    sample_path = au.write_prot_to_pdb(
        output_atom37_traj[-1, i], sample_file, no_indexing=True)

    # Model predictions: last prediction and the full prediction trajectory.
    _ = au.write_prot_to_pdb(
        output_pred_atom37_traj[-1, i], final_pred_file, no_indexing=True)
    _ = au.write_prot_to_pdb(
        output_pred_atom37_traj[:, i], pred_traj_file, no_indexing=True)

    # Ground-truth structure for the same batch element.
    _ = au.write_prot_to_pdb(gt_atom37[i], gt_file, no_indexing=True)
    print(f'Done with sample {i}')

# Final frame of every batch element in a single file.
final_samples = output_atom37_traj[-1]
_ = au.write_prot_to_pdb(
    final_samples,
    os.path.join(save_dir, 'all_samples.pdb'),
    no_indexing=True
)


Done with sample 0
Done with sample 1
