In [11]:
from ProtDiffusion.training_utils import ProtDiffusionTrainingConfig, make_dataloader, set_seed, ProtDiffusionTrainer, count_parameters
from transformers import PreTrainedTokenizerFast
from diffusers import DDPMScheduler
import torch

from datasets import load_from_disk

from ProtDiffusion.models.autoencoder_kl_1d import AutoencoderKL1D
from ProtDiffusion.models.dit_transformer_1d import DiTTransformer1DModel

import os

config = ProtDiffusionTrainingConfig(
    num_epochs=100, # the number of epochs to train for
    batch_size=16,
    mega_batch=50,
    gradient_accumulation_steps=16,
    learning_rate = 1e-5,
    lr_warmup_steps = 200,
    save_image_model_steps=320,
    output_dir=os.path.join("output","ProtDiffusion-PKSs-test_v1.1"),  # the model name locally and on the HF Hub
    total_checkpoints_limit=5, # the maximum number of checkpoints to keep
    gradient_clip_val=1.0,
    max_len=8192, # 512 * 2**6
    max_len_start=8192,
    max_len_doubling_steps=100,
    ema_decay=0.9999,
    ema_update_after=100,
    ema_update_every=10,
    use_batch_optimal_transport=True,
)

In [12]:
set_seed(config.seed) # Set the random seed for reproducibility

# dataset = load_from_disk('/home/kkj/ProtDiffusion/datasets/UniRef50_grouped-test')
dataset = load_from_disk('/home/kaspe/ProtDiffusion/datasets/PKSs_grouped')
train_dataset = dataset.shuffle(config.seed)

# %%
tokenizer = PreTrainedTokenizerFast.from_pretrained("/home/kaspe/ProtDiffusion/ProtDiffusion/tokenizer/tokenizer_v4.1")

# Check dataset lengths
print(f"Train dataset length: {len(train_dataset)}")

generator = torch.Generator().manual_seed(config.seed)
# %%
print("num cpu cores:", os.cpu_count())
print("setting num_workers to 16")
num_workers = 16
train_dataloader = make_dataloader(config, 
                                   train_dataset,
                                   tokenizer=tokenizer,
                                   max_len=config.max_len_start,
                                   num_workers=num_workers,
                                   generator=generator,
)
print("length of train dataloader: ", len(train_dataloader))

# vae = AutoencoderKL1D.from_pretrained('/home/kaspe/ProtDiffusion/output/protein-VAE-UniRef50_v18.1/pretrained/EMA')

Random seed set as 42
Train dataset length: 3743
num cpu cores: 12
setting num_workers to 16
length of train dataloader:  234




In [24]:
from copy import deepcopy

In [40]:
dataloader = deepcopy(train_dataloader)
for i, batch in enumerate(dataloader):
    print(batch)
    break
del dataloader

{'id': ['4849b3bbacbf7a6b62e27bc8d6f062023d60b344', 'cd9eeab025259e6629fafb6131b6fe446b3cc530', '08ee665a14e99fc5718fbb302bba3334337b0b31', '67b2d06dd2a57db602bcc27cfccd249464f41e06', '6d456bc709b573c9014602b111665728dd9ada0e', '367c137df6f178bef1524a9b384741caeef7507c', 'e00ed35a20d48d550dc2abb42fa0135d71387bca', '409e4b12c64284dd577a1a84bfdc5930c9d8798e', '0c5550228bf00e7dfa7d00660cf22c704f800cfd', 'd7894bb0e0eef60dbdbfe212cae9ceee827295e5', 'ae6245d111ddd59cf33aac43e0b10ade5deeeaf2', 'd423f6239e9d1304e671b49e810e1ebc28bf8100', '9954c00d7cfd2fd384e988fc65be8e77e8cfa065', '0b6b30b908b677e2161409eebb484fbe69dbe415', '82fada57db262ce9b93452ff37ed8f061e194bac', '2761f3c24f911edf7e97193585bc526b2b5e6070'], 'label': tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]), 'sequence': ['----[MGPDWKANSCVLSIFKTRNPQVENCVSELNSTLDASDMPLRLSAIFDEDSVSVTGPGRSLETLKSTIASLPEPGTCRWAHVHGFYHGGDNMKLVSDQVKNDIHGIEFPDWSVLHASLRSTATGQKFQANDKSVSLLRMAVENIFCDSVNWKLTWETAAADYAEKIRHDTNATIRVIGIGPSAKSLLGANKDKIRDTGIQVI