In [1]:
from ProtDiffusion.models.autoencoder_kl_1d import AutoencoderKL1D, AutoencoderKLOutput1D
from ProtDiffusion.models.dit_transformer_1d import DiTTransformer1DModel
from ProtDiffusion.models.pipeline_protein import ProtDiffusionPipeline, logits_to_token_ids
from ProtDiffusion.visualization_utils import make_logoplot, plot_latent_and_probs
from ProtDiffusion.training_utils import process_sequence, tokenize_sequence

from transformers import PreTrainedTokenizerFast
from diffusers import DDPMScheduler
import torch
import torch.nn.functional as F

from datasets import load_from_disk, Dataset, DatasetDict

from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO

import os
import numpy as np
import matplotlib.pyplot as plt

# Root of the ProtDiffusion checkout. Defaults to the location this notebook was
# originally run from; set the PROTDIFFUSION_ROOT environment variable to run it
# against a checkout somewhere else without editing the cell.
PROJECT_ROOT = os.environ.get("PROTDIFFUSION_ROOT") or "/home/kkj/ProtDiffusion"
TOKENIZER_DIR = os.path.join(PROJECT_ROOT, "ProtDiffusion", "tokenizer", "tokenizer_v4.2")
VAE_DIR = os.path.join(PROJECT_ROOT, "output", "EMA_VAE_v24.12")
TRANSFORMER_DIR = os.path.join(PROJECT_ROOT, "output", "EMA-temp_DiT")

# Tokenizer loaded from the saved tokenizer directory.
tokenizer: PreTrainedTokenizerFast = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_DIR)
# DDPM scheduler configured with 1000 training timesteps and sample clipping disabled.
noise_scheduler: DDPMScheduler = DDPMScheduler(num_train_timesteps=1000, clip_sample=False)
# Pretrained 1D autoencoder and DiT transformer checkpoints.
vae: AutoencoderKL1D = AutoencoderKL1D.from_pretrained(VAE_DIR)
transformer: DiTTransformer1DModel = DiTTransformer1DModel.from_pretrained(TRANSFORMER_DIR)


Using Sinusoidal Positional Embeddings
num_positional_embeddings:  256
Using RoPE
RoPE dim:  64


In [2]:
# Bundle the transformer, VAE, noise scheduler and tokenizer into one pipeline object.
pipeline = ProtDiffusionPipeline(transformer=transformer, vae=vae, scheduler=noise_scheduler, tokenizer=tokenizer)

# Scratch directory for the files this notebook writes; created if missing.
test_dir = "temp"
os.makedirs(test_dir, exist_ok=True)

# Sampling configuration.
seqs_lens = [256]
class_labels = [1]
guidance_scale = 1.0
eval_num_inference_steps = 10

# Keyword arguments for the pipeline call, gathered in one place so the
# configuration can be read (and tweaked) without scanning the call itself.
sampling_kwargs = dict(
    seq_len=seqs_lens,
    class_labels=class_labels,
    guidance_scale=guidance_scale,
    num_inference_steps=eval_num_inference_steps,
    generator=None,
    output_type='aa_seq',
    # NOTE(review): the intermediate latents / noise predictions are presumably
    # what `animate_inference` consumes later in the notebook -- confirm.
    return_hidden_latents=True,
    return_noise_pred=True,
    cutoff=None,
)
output = pipeline(**sampling_kwargs)

# Show the first generated sequence.
sequence = output.seqs
print(sequence[0])

  0%|          | 0/10 [00:00<?, ?it/s]

[TPESFGTSEAYKSLGRGILTDLFTTNDAKFHIHDDEIYDDVLIKTYEQKDKQGCDACVKWDVYRSFFWGARDGATLKFFSDIHNEKALFAFQNYPNCFDCPHVVVLDKITNKVLCDDNDYYRNKKRELGPDKEVKIPKRTPCDFYRRFYTPQSQDSFDVPEKLYDCVANFIHDLKPALGTTGDFLRYRTKETNGPLPAVPHELWRFSRCHNKPTPQEKNLFGRPPCIKCGKEPQKADDRGKLKECCRRILHLM]E


In [3]:
# Load the dataset and shuffle it so that the samples plotted in the next cell
# are a random draw rather than the first rows in stored order.
# `shuffle` returns a NEW dataset instead of shuffling in place, so its result
# must be assigned back; calling it bare leaves `dataset` in its original order.
SHUFFLE_SEED = 42  # fixed so the same samples are drawn on every Restart & Run All
dataset = load_from_disk('/home/kkj/ProtDiffusion/datasets/IPR036736_90').shuffle(seed=SHUFFLE_SEED)
dataset

Dataset({
    features: ['clusterid', 'proteinid', 'sequence', 'label', 'length'],
    num_rows: 259393
})

In [4]:
# Number of dataset entries to encode and plot; with 0 the loop below is skipped.
num_samples = 0
vae.eval()
# Processed sequences are truncated to this many characters before tokenization.
max_len = 256

# Pure inference: run without autograd so no computation graph is built per
# sample and the resulting tensors can be converted to NumPy directly.
with torch.no_grad():
    for i in range(num_samples):
        sample = dataset[i]
        sequence = sample['sequence']
        sequence = process_sequence(sequence)
        sequence = sequence[:max_len]
        tokenized = tokenize_sequence(sequence, tokenizer)
        input_ids = tokenized['input_ids']

        vae_output: AutoencoderKLOutput1D = vae(input_ids)

        # Latent: mode of the posterior, multiplied by the VAE's configured scaling factor.
        scaled_latent = vae_output.latent_dist.mode()
        latent = vae.config.scaling_factor * scaled_latent
        latent = latent[0].cpu().numpy()

        # Token probabilities for the first (only) batch element.
        # NOTE(review): softmax over dim=0 assumes logits[0] is laid out as
        # (vocab, length) -- TODO confirm against AutoencoderKL1D's output.
        logits = vae_output.sample
        logits = logits[0]
        probs = F.softmax(logits, dim=0).cpu().numpy()

        # Debug aids:
        # print(f"Sequence: {sequence}")
        # print(f"sequence length: {len(sequence)}")
        # print(f"Latent shape: {latent.shape}")
        # print(f"Logits shape: {logits.shape}")

        plot_latent_and_probs(probs,
                              latent,
                              characters = tokenizer.decode(range(tokenizer.vocab_size)),
                              path=f"{test_dir}/latent_plot_{i}.png",
                              title=f"Latent and probabilities for sample {i}",)

In [7]:
# Animate the sampling run stored in `output`. Per the recorded cell output,
# the images are saved under this directory and a gif is created afterwards.
inference_dir = f"{test_dir}/inference"
pipeline.animate_inference(output, inference_dir, plot_noise=False)

Animating inference of 11 steps
Animating step 0
Animating step 1
Animating step 2
Animating step 3
Animating step 4
Animating step 5
Animating step 6
Animating step 7
Animating step 8
Animating step 9
Animating step 10
Finished animating inference of 11 steps
Saved images to temp/inference
Creating gif...
