In [None]:
# Enable IPython's autoreload extension in mode 2 so edited project modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from omegaconf import OmegaConf
from models.flow_module import FlowModule
import torch
from data.pdb_dataloader import PdbDataModule
import glob
import GPUtil
from data import utils as du
from scipy.spatial.transform import Rotation
import numpy as np
import tree
from data import so3_utils
from data import all_atom
from analysis import plotting

In [None]:
# Setup lightning module

# NOTE(review): machine-specific absolute path; point this at your own run directory.
ckpt_dir = '/data/rsg/chemistry/jyim/projects/flow-matching/ckpt/se3-fm/baseline/2023-08-28_12-04-39/'
# ckpt_dir = '/data/rsg/chemistry/jyim/projects/flow-matching/ckpt/se3-fm/baseline/2023-08-28_17-12-28'

# glob.glob returns matches in arbitrary order and an empty list when nothing
# matches. Sort so the chosen checkpoint is reproducible, and fail with a
# clear message instead of a bare IndexError when the directory has none.
ckpt_candidates = sorted(glob.glob(os.path.join(ckpt_dir, '*.ckpt')))
if not ckpt_candidates:
    raise FileNotFoundError(f'No .ckpt files found in {ckpt_dir}')
ckpt_path = ckpt_candidates[0]
print(ckpt_path)

# Load the base config and redirect the checkpointer directory to the
# current working directory before building the module.
cfg = OmegaConf.load("configs/base.yaml")
cfg.experiment.checkpointer.dirpath = './'
flow = FlowModule.load_from_checkpoint(
    checkpoint_path=ckpt_path,
    model_cfg=cfg.model,
    experiment_cfg=cfg.experiment
)
# Switch to eval mode; the return value is discarded to suppress cell output.
_ = flow.eval()


In [None]:
# Set up data module
data_module = PdbDataModule(cfg.data)
data_module.setup('fit')
# Single-process notebook: with one replica the only in-range rank is 0
# (ranks are 0-indexed in the torch.distributed convention).
# NOTE(review): assumes train_dataloader follows that convention — confirm
# against PdbDataModule; the previous value was rank=1.
train_dataloader = data_module.train_dataloader(
    num_replicas=1,
    rank=0
)
data_iter = iter(train_dataloader)


In [None]:
# Search for a reasonable batch
# Keep drawing batches until one has a residue count strictly between the
# bounds below. The second axis of 'trans_1' is taken as the residue count.
min_res, max_res = 60, 100
batch = None
for candidate in data_iter:
    num_batch, num_res, _ = candidate['trans_1'].shape
    if min_res < num_res < max_res:
        batch = candidate
        break

# Surface an exhausted loader as a descriptive error rather than a bare
# StopIteration escaping from next().
if batch is None:
    raise RuntimeError(
        f'Data iterator exhausted without a batch with '
        f'{min_res} < num_res < {max_res}'
    )


In [None]:
# Set up device and cuda
_, num_res, _ = batch['trans_1'].shape
num_batch = batch['res_mask'].shape[0]

# Ask GPUtil for available GPUs (ordered by memory). The list can be empty,
# in which case indexing [0] would raise IndexError, so fall back to the CPU.
available_gpus = GPUtil.getAvailable(order='memory', limit=8)
if available_gpus:
    device = f'cuda:{available_gpus[0]}'
else:
    device = 'cpu'
    print('No available GPU reported by GPUtil; falling back to CPU.')

# Move the model and every entry of the batch onto the chosen device.
flow.model = flow.model.to(device)
batch = tree.map_structure(lambda x: x.to(device), batch)

# Model prediction at different timesteps

In [None]:
# Corrupt the batch at a fixed time t=0.1 for every example.
t = torch.ones(num_batch, 1, 1, device=device) * 0.1
noisy_batch = flow._corrupt_batch(batch, t=t)

# Inference only: run the forward pass without autograd bookkeeping, the same
# way the sampling loops in this notebook call the model.
with torch.no_grad():
    model_out = flow.forward(noisy_batch)

# Convert what the plotting cell needs to numpy.
trans_out = du.to_numpy(model_out['pred_trans'])
trans_in = du.to_numpy(noisy_batch['trans_t'])
gt = du.to_numpy(noisy_batch['trans_1'])
# `t` is rebound here from the tensor above to the first column of the
# batch's 't' entry as numpy; the plotting cell formats t[idx] from it.
t = du.to_numpy(noisy_batch['t'][:, 0])
num_batch = trans_out.shape[0]
print(num_batch)
print(t)

In [None]:
# Overlay the predicted translations on the ground truth for one example.
idx = 0
model_out_ca, model_in_ca, gt_ca = trans_out[idx], trans_in[idx], gt[idx]

# Styling shared by all three traces.
trace_style = dict(mode='lines+markers', marker_size=3, opacity=1.0)
out_bb_3d = plotting.create_scatter(model_out_ca, name='pred', **trace_style)
in_bb_3d = plotting.create_scatter(
    model_in_ca, name=f'input t={t[idx]:.2f}', **trace_style)
gt_bb_3d = plotting.create_scatter(gt_ca, name='gt', **trace_style)

# The input trace (in_bb_3d) is built but deliberately left out of the figure.
plotting.plot_traces([out_bb_3d, gt_bb_3d])


# Partial sampling

In [None]:
# Corrupt the batch at the smallest time value, which is also where the
# sampling loop in the next cell starts.
min_t = 1e-3
t = torch.full((num_batch, 1, 1), min_t, device=device)
noisy_batch = flow._corrupt_batch(batch, t=t)

# Keep numpy copies of only the first `batch_trunc` examples for plotting.
batch_trunc = 5
trans_in, gt = (
    du.to_numpy(noisy_batch[key][:batch_trunc])
    for key in ('trans_t', 'trans_1')
)

In [None]:
# Run sampling
# Starting from the corrupted batch at t=min_t, step towards t=1: at each step
# the model predicts the endpoint and the state moves a fraction of the way.
num_timesteps = 1000
ts = np.linspace(min_t, 1.0, num_timesteps)
trans_traj, rots_traj = [noisy_batch['trans_t']], [noisy_batch['rotmats_t']]
model_outputs = []
t_1 = ts[0]
for i, t_2 in enumerate(ts[1:]):
    # Progress report every 100 steps.
    if (i+1) % 100 == 0:
        print(f"Step {i+1} / {len(ts)}")

    d_t = t_2 - t_1
    trans_t_1, rots_t_1 = trans_traj[-1], rots_traj[-1]

    with torch.no_grad():
        # The model reads its current state and time from the batch itself.
        batch['trans_t'] = trans_t_1
        batch['rotmats_t'] = rots_t_1
        batch['t'] = torch.ones((num_batch, 1)).to(device) * t_1
        model_out = flow.forward(batch)
        # Keep a numpy copy of every model output for later inspection.
        model_outputs.append(tree.map_structure(du.to_numpy, model_out))

    pred_trans_1 = model_out['pred_trans']
    pred_rots_1 = model_out['pred_rotmats']
    pred_rots_vf = model_out['pred_rots_vf']

    # Translations: velocity that would reach the predicted endpoint in the
    # remaining (1 - t_1) time, applied for d_t.
    trans_vf = (pred_trans_1 - trans_t_1) / (1 - t_1)
    trans_t_2 = trans_t_1 + trans_vf * d_t
    # Rotations: delegated to so3_utils.geodesic_t with the same time
    # fraction, additionally passing the predicted rotation vector field.
    rots_t_2 = so3_utils.geodesic_t(
        d_t / (1 - t_1),
        pred_rots_1,
        rots_t_1,
        rot_vf=pred_rots_vf,
    )

    trans_traj.append(trans_t_2)
    rots_traj.append(rots_t_2)
    t_1 = t_2


In [None]:
# Plot the last recorded model prediction for batch element `idx`.
idx = 6
model_out_ca = model_outputs[-1]['pred_trans'][idx]
sample_bb_3d = plotting.create_scatter(
    model_out_ca,
    mode='lines+markers',
    marker_size=3,
    opacity=1.0,
)
plotting.plot_traces([sample_bb_3d])

In [None]:
# Overlay the prediction recorded at step `t_idx` on the ground truth for
# batch element `idx`.
idx, t_idx = 0, -1
model_out_ca = model_outputs[t_idx]['pred_trans'][idx]
model_in_ca, gt_ca = trans_in[idx], gt[idx]

# Styling shared by all three traces.
trace_style = dict(mode='lines+markers', marker_size=3, opacity=1.0)
out_bb_3d = plotting.create_scatter(model_out_ca, name='pred', **trace_style)
in_bb_3d = plotting.create_scatter(
    model_in_ca, name=f'input t={ts[t_idx]:.2f}', **trace_style)
gt_bb_3d = plotting.create_scatter(gt_ca, name='gt', **trace_style)

# The input trace (in_bb_3d) is built but deliberately left out of the figure.
plotting.plot_traces([out_bb_3d, gt_bb_3d])


# Sampling

In [None]:
# Run sampling
# Draw the starting state: centered Gaussian translations scaled by
# du.NM_TO_ANG_SCALE, and one random rotation matrix per residue.
trans_0 = flow._centered_gaussian(batch['trans_1'].shape, device) * du.NM_TO_ANG_SCALE
random_rotmats = Rotation.random(num_batch*num_res).as_matrix()
rots_0 = torch.tensor(
    random_rotmats, device=device, dtype=torch.float32
).reshape(num_batch, num_res, 3, 3)

trans_traj, rots_traj = [trans_0], [rots_0]
model_outputs = []
ts = np.linspace(1e-3, 1.0, 500)
t_1 = ts[0]
for i, t_2 in enumerate(ts[1:]):
    # Progress report on the first step and every 100 steps after it.
    if i % 100 == 0:
        print(f"Step {i+1} / {len(ts)}")

    d_t = t_2 - t_1
    trans_t_1, rots_t_1 = trans_traj[-1], rots_traj[-1]

    with torch.no_grad():
        # The model reads its current state and time from the batch itself.
        batch['trans_t'] = trans_t_1
        batch['rotmats_t'] = rots_t_1
        batch['t'] = torch.ones((num_batch, 1)).to(device) * t_1
        model_out = flow.forward(batch)
        # Keep a numpy copy of every model output for later inspection.
        model_outputs.append(tree.map_structure(du.to_numpy, model_out))

    pred_trans_1 = model_out['pred_trans']
    pred_rots_1 = model_out['pred_rotmats']

    # Translations: velocity that would reach the predicted endpoint in the
    # remaining (1 - t_1) time, applied for d_t.
    trans_vf = (pred_trans_1 - trans_t_1) / (1 - t_1)
    trans_t_2 = trans_t_1 + trans_vf * d_t
    # Rotations: delegated to so3_utils.geodesic_t with the same time fraction.
    rots_t_2 = so3_utils.geodesic_t(d_t / (1 - t_1), pred_rots_1, rots_t_1)

    trans_traj.append(trans_t_2)
    rots_traj.append(rots_t_2)
    t_1 = t_2


In [None]:
# Plot the last recorded model prediction for the first batch element.
model_out_ca = model_outputs[-1]['pred_trans'][0]
sample_bb_3d = plotting.create_scatter(
    model_out_ca,
    mode='lines+markers',
    marker_size=3,
    opacity=1.0,
)
plotting.plot_traces([sample_bb_3d])

In [None]:
# Convert the batch's 'trans_1' entry with du.to_numpy for plotting.
gt_trans_1 = du.to_numpy(batch['trans_1'])

In [None]:
# Corrupt the batch again, this time without passing `t` explicitly; how the
# time values are chosen in that case is decided inside _corrupt_batch.
noisy_batch = flow._corrupt_batch(batch)

In [None]:
# Convert the 'trans_1' and 'trans_t' entries of the corrupted batch with
# du.to_numpy for plotting.
gt_trans_1 = du.to_numpy(noisy_batch['trans_1'])
gt_trans_t = du.to_numpy(noisy_batch['trans_t'])
# Cell output: the 't' entry stored for batch element 1.
noisy_batch['t'][1]

In [None]:
# Cell output: the first column of the corrupted batch's 't' entry, one value
# per batch element.
du.to_numpy(noisy_batch['t'])[:, 0]

In [None]:
# Show the 'trans_1' and 'trans_t' translations of one corrupted example
# together, after printing the time value stored for it.
idx = 5
t = noisy_batch['t'][idx]
print(t)

trace_style = dict(mode='lines+markers', marker_size=3, opacity=1.0)
gt_trans_1_3d = plotting.create_scatter(gt_trans_1[idx], **trace_style)
gt_trans_t_3d = plotting.create_scatter(gt_trans_t[idx], **trace_style)
plotting.plot_traces([gt_trans_1_3d, gt_trans_t_3d])

In [None]:
# Plot the first recorded model prediction for batch element 3.
model_out_ca = model_outputs[0]['pred_trans'][3]
sample_bb_3d = plotting.create_scatter(
    model_out_ca,
    mode='lines+markers',
    marker_size=3,
    opacity=1.0,
)
plotting.plot_traces([sample_bb_3d])

In [None]:
# Plot the first trajectory entry (the starting translations) for the first
# batch element.
ca_pos = du.to_numpy(trans_traj[0][0])
sample_bb_3d = plotting.create_scatter(
    ca_pos,
    mode='lines+markers',
    marker_size=3,
    opacity=1.0,
)
plotting.plot_traces([sample_bb_3d])

In [None]:
# Process outputs
# Turn every (translations, rotations) pair along the trajectory into atom37
# coordinates on the CPU, one stacked tensor per trajectory step.
res_mask = batch['res_mask'].detach().cpu()
atom37_traj = []
for trans, rots in zip(trans_traj, rots_traj):
    rigids = du.create_rigid(rots, trans)
    # compute_backbone is given an all-zero tensor whose leading two dims
    # match `trans` and whose last dim is 2.
    zero_tensor = torch.zeros(*trans.shape[:2], 2, device=trans.device)
    # Only the first element of compute_backbone's return value is used.
    atom37 = all_atom.compute_backbone(rigids, zero_tensor)[0].detach().cpu()
    # adjust_oxygen_pos is applied per example, then the batch is restacked.
    atom37_traj.append(torch.stack([
        du.adjust_oxygen_pos(atom37[i], res_mask[i])
        for i in range(num_batch)
    ]))


In [None]:
# Write one sample to a PDB file named after its index, length and epoch.
# NOTE(review): `au`, `final_pos` and `self` are not defined anywhere in the
# visible notebook, so this cell raises NameError as written; it looks like a
# snippet pasted from a module method — confirm and either adapt or remove.
saved_path = au.write_prot_to_pdb(
    final_pos,
    os.path.join(
        self._sample_write_dir,
        f'sample_{i}_len_{num_res}_epoch_{self.current_epoch}.pdb'),
    no_indexing=False
)

In [None]:
    # NOTE(review): this cell is an indented method body without its `def`
    # line. It relies on `self`, on `return`, and on the names `return_traj`
    # and `return_model_outputs`, none of which exist at notebook top level,
    # so it cannot run as a cell as written.
    #
    # Flow: unpack the batch, draw a random starting state, step it from
    # min_t to 1.0 using the model's predictions, then convert the stored
    # states to atom37 coordinates.
    batch, pdb_names = batch
    if self.current_epoch == 0:
        self._print_logger.info(f'Running eval on batches from {pdb_names}')
    res_mask = batch['res_mask']
    device = res_mask.device
    num_batch, num_res = res_mask.shape[:2]
    # Starting state: centered Gaussian translations scaled by
    # du.NM_TO_ANG_SCALE and num_batch*num_res random rotation matrices.
    trans_0 = self._centered_gaussian(batch['trans_1'].shape, device) * du.NM_TO_ANG_SCALE
    rots_0 = torch.tensor(
        Rotation.random(num_batch*num_res).as_matrix(),
        device=device,
        dtype=torch.float32,
    )
    # NOTE(review): the rotations are not reshaped per batch element here; the
    # `[None]` below only prepends one leading axis. The notebook's own
    # sampling cell reshapes to (num_batch, num_res, 3, 3) instead — confirm
    # which is intended when num_batch > 1.
    if rots_0.ndim == 3:
        rots_0 = rots_0[None]
    
    trans_traj = [trans_0]
    rots_traj = [rots_0]
    ts = np.linspace(self._sampling_cfg.min_t, 1.0, self._sampling_cfg.num_timesteps)
    t_1 = ts[0]
    model_outputs = []
    for t_2 in ts[1:]:
        d_t = t_2 - t_1
        trans_t_1 = trans_traj[-1]
        rots_t_1 = rots_traj[-1]
        with torch.no_grad():
            # When a component is not noised (per the experiment config), the
            # model is fed the batch's 'trans_1' / 'rotmats_1' entry instead
            # of the current sampled state.
            if self._exp_cfg.noise_trans:
                batch['trans_t'] = trans_t_1
            else:
                batch['trans_t'] = batch['trans_1']
            if self._exp_cfg.noise_rots:
                batch['rotmats_t'] = rots_t_1
            else:
                batch['rotmats_t'] = batch['rotmats_1']
            batch['t'] = torch.ones((num_batch, 1)).to(device) * t_1
            model_out = self.forward(batch)
            # Keep a numpy copy of every model output.
            model_outputs.append(
                tree.map_structure(lambda x: du.to_numpy(x), model_out)
            )

        pred_trans_1 = model_out['pred_trans']
        pred_rots_1 = model_out['pred_rotmats']

        # Translations: velocity that would reach the predicted endpoint in
        # the remaining (1 - t_1) time, applied for d_t. Rotations are
        # delegated to so3_utils.geodesic_t with the same time fraction.
        trans_vf = (pred_trans_1 - trans_t_1) / (1 - t_1)
        trans_t_2 = trans_t_1 + trans_vf * d_t
        rots_t_2 = so3_utils.geodesic_t(d_t / (1 - t_1), pred_rots_1, rots_t_1)
        t_1 = t_2
        # Either record the whole trajectory or keep only the latest state by
        # overwriting the single list entry.
        if return_traj:
            trans_traj.append(trans_t_2)
            rots_traj.append(rots_t_2)
        else:
            trans_traj[-1] = trans_t_2
            rots_traj[-1] = rots_t_2

    # Convert every stored state to atom37 coordinates on the CPU.
    atom37_traj = []
    res_mask = res_mask.detach().cpu()
    for trans, rots in zip(trans_traj, rots_traj):
        rigids = du.create_rigid(rots, trans)
        # compute_backbone is given an all-zero tensor whose leading two dims
        # match `trans` and whose last dim is 2; only its first return value
        # is used.
        atom37 = all_atom.compute_backbone(
            rigids,
            torch.zeros(
                trans.shape[0],
                trans.shape[1],
                2,
                device=trans.device
            )
        )[0]
        atom37 = atom37.detach().cpu()
        batch_atom37 = []
        for i in range(num_batch):
            batch_atom37.append(
                du.adjust_oxygen_pos(atom37[i], res_mask[i])
            )
        atom37_traj.append(torch.stack(batch_atom37))

    # Optionally return the per-step model outputs alongside the trajectory.
    if return_model_outputs:
        return atom37_traj, model_outputs
    return atom37_traj

In [None]:

# NOTE(review): signature only — no body follows in this cell, so it is a
# syntax error as written, and `Any` is not among the notebook's visible
# imports. It appears to be the header for the indented body in the cell
# above — confirm before merging them.
@torch.no_grad()
def run_sampling(self, batch: Any, return_traj=False, return_model_outputs=False):


In [None]:
# Cell output: display the current `cfg` object (loaded earlier in the
# notebook with OmegaConf.load).
cfg

In [None]:
# Compose the base config through Hydra's Compose API.
# `config_path` must be the directory that holds the config files, and the
# config to load is selected with `config_name` (file name without the .yaml
# extension). Passing the YAML file itself as `config_path` and calling
# compose() with no name does not load configs/base.yaml.
# NOTE(review): `initialize` and `compose` are not imported in the visible
# notebook; they are provided by Hydra (`from hydra import compose, initialize`).
with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="base")
    print(cfg)