In [33]:
from moleculib.protein.dataset import ProteinDataset
from moleculib.protein.batch import PadBatch
from preprocess import StandardizeTransform
from torch.utils.data import DataLoader
from visualize import backbone_to_pdb, backbones_to_animation
from models.en_denoiser import EnDenoiser
from einops import rearrange
import torch
import torch.nn.functional as F
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go

# Load data

In [34]:
CHECKPOINT = "../checkpoints/ex_20230411_015817/epoch=3649-step=3650.ckpt"
TRAIN_DIR = "../data/single"
transform = [StandardizeTransform()]
train_dataset = ProteinDataset(TRAIN_DIR, transform=transform, preload=True)

In [35]:
loader = DataLoader(train_dataset, collate_fn=PadBatch.collate, batch_size=1, shuffle=False)
batch = next(iter(loader))

In [36]:
model = EnDenoiser.load_from_checkpoint(CHECKPOINT)
model.eval()

EnDenoiser(
  (transformer): EnTransformer(
    (token_emb): Embedding(23, 64)
    (time_mlp): Sequential(
      (0): SinusoidalPositionEmbeddings()
      (1): Linear(in_features=64, out_features=256, bias=True)
      (2): GELU(approximate='none')
      (3): Linear(in_features=256, out_features=256, bias=True)
    )
    (layers): ModuleList(
      (0): Block(
        (attn): Residual(
          (fn): EquivariantAttention(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=256, out_features=128, bias=True)
            )
            (norm): LayerNorm()
            (to_qkv): Linear(in_features=64, out_features=768, bias=False)
            (to_out): Linear(in_features=256, out_features=64, bias=True)
            (coors_mlp): Sequential(
              (0): Linear(in_features=4, out_features=16, bias=False)
              (1): GELU(approximate='none')
              (2): Linear(in_features=16, out_features=4, bias=False)
            )
          

In [37]:
def rescale_protein(coords, std_const=9.0):
    return std_const * coords

In [38]:
def rearrange_coords(coords):
    new_coords = coords.squeeze(0)[:,:num_backbone, :]
    new_coords = rearrange(new_coords, "s b c -> (s b) c")
    return new_coords

# Visualize the sample backbone

In [39]:
pdb_fname = "test.pdb"
num_backbone = 4
coord = batch.atom_coord
coord = rescale_protein(rearrange_coords(coord))
seq = str(batch.sequence[0])
backbone_to_pdb(coord, seq, pdb_fname, num_backbone)

File test.pdb has been saved.


'test.pdb'

# Visualize noise steps

In [40]:
def visualize_noise(batch, num_steps):
    x = batch.atom_coord[:,:,:num_backbone, :]
    coords_list = [coord]
    for i in range(num_steps):
        ts = torch.tensor([i])
        noised_x, noise = model.diffusion.q_sample(x, ts)
        noised_x = rescale_protein(rearrange_coords(noised_x))
        coords_list.append(noised_x)
    backbones_to_animation(coords_list, seq, "test_noise.pdb")

In [41]:
visualize_noise(batch, model.diffusion.timesteps)

File test_noise.pdb has been saved.


# Perform backward diffusion

In [42]:
results = model.diffusion.sample(model, batch, model.diffusion.timesteps - 1)

sampling loop time step: 100%|████████████████████████████████████████████████████████████████| 99/99 [01:33<00:00,  1.06it/s]


In [43]:
results = [x.squeeze(0) for x in results]
original = rearrange_coords(batch.atom_coord)
losses = [float(F.mse_loss(x, original)) for x in results]

In [44]:
fig = px.line(pd.Series(losses), title='Loss from original over time')
fig.show()

In [20]:
# save PDB files for diffusion steps
last_result = rescale_protein(results[-1])
backbone_to_pdb(last_result, seq, f"backward_{model.diffusion.timesteps}.pdb", num_backbone)
backbones_to_animation(results, seq, "denoise.pdb")

File backward_100.pdb has been saved.
File denoise.pdb has been saved.


'denoise.pdb'

# 3D Scatter Visualize

In [26]:
def show_scatter(mat):
    x, y, z = mat.T
    scatters = [
        go.Scatter3d(
            name="coord",
            x=x, y=y, z=z,
            mode='markers',
            marker=dict(
                size=3,
                colorscale="Viridis",
            )
        )
    ]
    fig = go.Figure(data=scatters)
    fig.update_layout(
        autosize=False,
        width=650,
        height=650,
    )
    fig.show()

In [27]:
show_scatter(original)

In [30]:
show_scatter(results[-1])

# Manually Noise-Denoise

In [None]:
prediction_losses = []

for i in range(0, model.timesteps, 10):
    coords, seq, masks = model.prepare_inputs(batch)
    ts = torch.tensor([i])

    # forward diffusion
    noised_coords, noise = model.q_sample(coords, ts)
    ts = ts.type(torch.float64)

    # predict noisy input with transformer
    feats, prediction = model.transformer(seq, noised_coords, ts, mask=masks)
    
    pred_noise = prediction - noised_coords
    prediction_losses.append(float(F.mse_loss(pred_noise, noise)))

In [None]:
fig = px.line(pd.Series(prediction_losses), title='Loss from original over time')
fig.show()