In [1]:
from moleculib.protein.dataset import ProteinDataset
from moleculib.protein.batch import PadBatch
from preprocess import StandardizeTransform
from torch.utils.data import DataLoader
from visualize import backbone_to_pdb, backbones_to_animation, pred_to_pdb
from models.en_denoiser import EnDenoiser
from einops import rearrange
import torch
import torch.nn.functional as F
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go

# Load data

In [2]:
# Trained checkpoint to evaluate and the directory holding the training proteins.
CHECKPOINT = "../checkpoints/ex_20230413_190231/epoch=10599-step=10600.ckpt"
TRAIN_DIR = "../data/single"

# Coordinates go through StandardizeTransform on load (presumably a fixed rescale —
# see rescale_protein below); preload=True is forwarded unchanged to ProteinDataset.
transform = [StandardizeTransform()]
train_dataset = ProteinDataset(
    TRAIN_DIR,
    transform=transform,
    preload=True,
)

  warn("{} elements were guessed from atom_name.".format(rep_num))


In [3]:
# Deterministic (unshuffled) loader; PadBatch.collate pads proteins of different
# lengths into one batch. Only the first batch is used in this notebook.
# NOTE(review): later helpers call .squeeze(0), which only removes the batch axis
# when the batch holds a single protein — confirm TRAIN_DIR has one entry.
loader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=PadBatch.collate,
)
batch = next(iter(loader))

In [4]:
# Restore the trained denoiser from CHECKPOINT and put it in eval mode via .eval()
# before running any inference below.
model = EnDenoiser.load_from_checkpoint(CHECKPOINT).eval()

In [None]:
# Evaluate the model's distance-map score on the clean input batch
# (what the score measures is defined in EnDenoiser.distmap_score).
model.distmap_score(batch)

In [None]:
def rescale_protein(coords, std_const=9.0):
    """Scale coordinates up by a constant factor.

    Presumably the inverse of the scaling applied by ``StandardizeTransform``;
    confirm that ``std_const`` matches the transform's constant.

    Args:
        coords: anything supporting multiplication by a float (torch tensors here).
        std_const: scale factor, 9.0 by default.

    Returns:
        ``std_const * coords``.
    """
    scaled = std_const * coords
    return scaled

In [None]:
def rearrange_coords(coords, num_backbone=None):
    """Flatten a per-residue atom-coordinate tensor into a single list of points.

    Drops a leading size-1 batch axis (``squeeze(0)``), optionally keeps only the
    first ``num_backbone`` entries along the second axis (presumably atoms per
    residue), and merges the first two axes, so a ``(S, B, C)`` tensor becomes
    ``(S * B, C)``.

    Args:
        coords: tensor shaped ``(1, S, B, C)`` or ``(S, B, C)``.
        num_backbone: number of leading atoms to keep per residue. ``None``
            (default) keeps all of them. The original version read an undefined
            global of this name and raised NameError on a fresh kernel.

    Returns:
        Tensor of shape ``(S * B', C)``, where ``B'`` is the number of atoms kept.

    Raises:
        ValueError: if ``coords`` is not 3-D after squeezing the leading axis,
            e.g. a batch holding more than one protein.
    """
    new_coords = coords.squeeze(0)
    if new_coords.dim() != 3:
        raise ValueError(
            f"expected a 3-D tensor after squeeze(0), got shape {tuple(new_coords.shape)}"
        )
    new_coords = new_coords[:, :num_backbone, :]
    # Equivalent to einops "s b c -> (s b) c".
    return new_coords.flatten(0, 1)

# Visualize the sample backbone

In [5]:
# First protein of the batch: its raw atom coordinates and its sequence as a string.
seq = str(batch.sequence[0])
coord = batch.atom_coord[0]
pdb_fname = "test.pdb"

In [7]:
# Write the raw batch coordinates to test.pdb. rearrange=True presumably asks
# pred_to_pdb to flatten the per-residue atom tensor itself — see visualize.pred_to_pdb.
pred_to_pdb(coord, seq, pdb_fname, rearrange=True)

File test.pdb has been saved.


In [15]:
# Same protein, but taken from the model's prepared inputs (already flat, so
# rearrange=False), written to a second PDB for comparison with test.pdb.
coords, seqs, masks = model.prepare_inputs(batch)
coord, seq = coords[0], str(batch.sequence[0])
pdb_fname = "test2.pdb"
pred_to_pdb(coord, seq, pdb_fname, rearrange=False)

File test2.pdb has been saved.


In [17]:
# Inspect the prepared coordinates after undoing standardization. This uses the shared
# rescale_protein helper rather than a hard-coded factor, so the constant lives in one place.
rescale_protein(coord)

tensor([[ 13.9157,   9.6683,   0.5851],
        [ 13.7884,   9.4367,  -0.7392],
        [ 12.9711,   8.3390,  -1.0348],
        ...,
        [ -8.1376,  11.9117, -14.8142],
        [ -9.3939,  12.7790, -14.4148],
        [ -9.2148,  14.1418, -13.6587]])

# Visualize noise steps

In [None]:
def visualize_noise(batch, num_steps, num_backbone=None, seq=None,
                    out_fname="test_noise.pdb"):
    """Write an animation of the forward (noising) diffusion process.

    Frame 0 is the clean input. Frame ``i + 1`` is the input passed through
    ``model.diffusion.q_sample`` at timestep ``i``. Every frame, including frame 0,
    is flattened and passed through ``rescale_protein``, so all frames share the
    same units. The original used an un-rescaled hidden global ``coord`` as frame 0.

    Args:
        batch: padded batch from the DataLoader. ``batch.atom_coord`` is sliced as a
            4-D tensor, and ``.squeeze(0)`` assumes it holds a single protein.
        num_steps: number of timesteps to visualize (``0 .. num_steps - 1``).
        num_backbone: atoms kept per residue. ``None`` (default) keeps all of them.
            The original read an undefined global of this name.
        seq: sequence string for the PDB frames. Defaults to
            ``str(batch.sequence[0])``; the original read the global ``seq``.
        out_fname: path of the output animation PDB.

    Note:
        Uses the notebook-level ``model``.
    """
    if seq is None:
        seq = str(batch.sequence[0])

    def to_frame(t):
        # Drop the size-1 batch axis, merge residue/atom axes, and undo standardization.
        return rescale_protein(rearrange(t.squeeze(0), "s b c -> (s b) c"))

    x = batch.atom_coord[:, :, :num_backbone, :]
    coords_list = [to_frame(x)]
    for i in range(num_steps):
        ts = torch.tensor([i])
        noised_x, _noise = model.diffusion.q_sample(x, ts)
        coords_list.append(to_frame(noised_x))
    backbones_to_animation(coords_list, seq, out_fname)

In [None]:
# Animate the forward process over every timestep of the model's diffusion schedule.
visualize_noise(batch, model.diffusion.timesteps)

# Perform backward diffusion

In [None]:
# Inputs for reverse diffusion: the number of steps in the schedule, and the
# prepared coordinates, sequences, and masks for the batch.
timesteps = model.diffusion.timesteps
coords, seqs, masks = model.prepare_inputs(batch)

In [None]:
# Run reverse (denoising) diffusion with the model's transformer. `results` is
# presumably one coordinate tensor per sampling step; later cells index results[-1]
# as the final structure.
results = model.diffusion.sample(model.transformer, coords, seqs, masks, timesteps)

In [None]:
# Reference structure: the first protein of the prepared batch.
original = coords[0]
# Drop the leading batch axis from every sampling step.
results = [step.squeeze(0) for step in results]
# Per-step MSE against the reference, as plain Python floats for plotting.
losses = list(map(lambda step: float(F.mse_loss(step, original)), results))

In [None]:
# MSE of each reverse-diffusion step against the original structure.
fig = px.line(pd.Series(losses), title='Loss from original over time')
fig.update_layout(
    xaxis_title="sampling step (index into results)",
    yaxis_title="MSE vs. original structure",
    showlegend=False,
)
fig.show()

In [None]:
# NOTE(review): this is the same call on the same `batch` as at the top of the
# notebook and does not use `results`; confirm whether the sampled structure
# was meant to be scored here instead.
model.distmap_score(batch)

In [None]:
# Save the final denoised structure and an animation of every sampling step.
# The animation frames are rescaled with rescale_protein, like `last_result` and
# the frames in visualize_noise, so both outputs use the same units.
# NOTE(review): `num_backbone` is not defined anywhere in this notebook; define
# it (atoms per residue) before running this cell.
last_result = rescale_protein(results[-1])
backbone_to_pdb(last_result, seq, f"backward_{model.diffusion.timesteps}.pdb", num_backbone)
backbones_to_animation([rescale_protein(step) for step in results], seq, "denoise.pdb")

# 3D scatter visualization

In [None]:
def show_scatter(mat, title=None):
    """Render an ``(N, 3)`` point cloud as an interactive 3-D scatter plot.

    Points are colored by row index with the Viridis colorscale. This makes the
    ordering of the points (presumably along the chain) visible. Previously the
    colorscale was set without ``marker.color``, so it had no effect.

    Args:
        mat: ``(N, 3)`` array-like. torch tensors are detached and moved to CPU first.
        title: optional figure title.
    """
    if hasattr(mat, "detach"):
        # Torch tensor: drop autograd history and move to host memory before plotting.
        mat = mat.detach().cpu().numpy()
    x, y, z = mat.T
    scatters = [
        go.Scatter3d(
            name="coord",
            x=x, y=y, z=z,
            mode='markers',
            marker=dict(
                size=3,
                # colorscale only applies when marker.color is a numeric array
                color=list(range(len(x))),
                colorscale="Viridis",
            )
        )
    ]
    fig = go.Figure(data=scatters)
    fig.update_layout(
        autosize=False,
        width=650,
        height=650,
        title=title,
    )
    fig.show()

In [None]:
# Reference structure (first protein of the prepared batch).
show_scatter(original)

In [None]:
# Output of the last reverse-diffusion step, for visual comparison with the reference.
show_scatter(results[-1])

# Manual noising and denoising

In [None]:
# For every EVAL_EVERY-th timestep: noise the clean input, run the transformer on
# the noised input, and record the MSE between the implied noise and the true noise.
EVAL_EVERY = 10
prediction_losses = []

# Computed once instead of per iteration: every iteration used the same `batch`
# (this assumes prepare_inputs is deterministic for a fixed batch). Sequences are
# bound to `seqs`, as in the earlier cells. The original unpacked into `seq`, which
# overwrote the sequence string used by the PDB writers.
coords, seqs, masks = model.prepare_inputs(batch)

# NOTE(review): other cells use model.diffusion.q_sample / model.diffusion.timesteps;
# confirm model.q_sample / model.timesteps refer to the same diffusion schedule.
for i in range(0, model.timesteps, EVAL_EVERY):
    ts = torch.tensor([i])

    # forward diffusion to timestep i
    noised_coords, noise = model.q_sample(coords, ts)
    ts = ts.type(torch.float64)

    # transformer forward pass on the noised input
    _feats, prediction = model.transformer(seqs, noised_coords, ts, mask=masks)

    # NOTE(review): treats (prediction - noised_coords) as the predicted noise;
    # confirm this matches what the transformer is trained to output.
    pred_noise = prediction - noised_coords
    prediction_losses.append(float(F.mse_loss(pred_noise, noise)))

In [None]:
# Noise-prediction error at each evaluated timestep. The previous title was copied
# from the loss-vs-original plot and described the wrong quantity.
fig = px.line(
    pd.Series(prediction_losses),
    title="Noise-prediction MSE across diffusion timesteps",
)
fig.update_layout(
    xaxis_title="evaluation index (every 10th timestep)",
    yaxis_title="MSE(predicted noise, true noise)",
    showlegend=False,
)
fig.show()