In [None]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyper-parameters and paths for training the protein diffusion model.

    Every attribute carries a type annotation so that ``@dataclass`` turns it
    into a real field. Defaults can then be overridden through the constructor,
    e.g. ``TrainingConfig(num_epochs=10)``, and the instance gets a readable
    ``repr``. (Without annotations the decorator created no fields, and any
    keyword argument raised TypeError.)
    """

    train_batch_size: int = 16
    mega_batch: int = 100  # number of batches pooled together (and length-sorted) by BatchSampler
    eval_batch_size: int = 16  # how many sequences to sample during evaluation
    eval_seq_len: int = 48  # length of the sequences generated during evaluation
    num_epochs: int = 200  # the number of epochs to train the model
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 10  # sample sequences every N epochs (name kept from the image template)
    save_model_epochs: int = 30
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "protein-diffusion"  # the model name locally and on the HF Hub
    pad_to_multiple_of: int = 16  # tokenized sequences are padded up to a multiple of this length

    class_embeddings_concat: bool = False  # whether to concatenate the class embeddings to the time embeddings

    push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    hub_model_id: str = "kkj15dk/AA_test"  # the name of the repository to create on the HF Hub
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0  # seeds the CPU sampling generator used in evaluate()

    labels_file: str = 'labels_test.json'  # JSON mapping entries to 'class' / 'architecture'


config = TrainingConfig()

In [None]:
from datasets import load_dataset

config.dataset_name = "kkj15dk/test_dataset"
seed = 42  # seed used to shuffle every split (independent of config.seed)
dataset = load_dataset(
    config.dataset_name,
    download_mode='force_redownload',  # re-download instead of reusing the local cache
).shuffle(seeds=seed)

In [None]:
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast.from_pretrained("kkj15dk/proteindiffusion_tokenizer")

def encode(example):
    """Tokenize the 'sequence' field of a single example.

    Padding follows the tokenizer's ``padding=True`` /
    ``pad_to_multiple_of=config.pad_to_multiple_of`` options. An attention mask
    is returned; token type ids are not.
    """
    tokenizer_options = dict(
        padding=True,
        pad_to_multiple_of=config.pad_to_multiple_of,
        return_token_type_ids=False,
        return_attention_mask=True,
    )
    return tokenizer(example['sequence'], **tokenizer_options)

# Tokenize each split one example at a time, dropping the raw text column.
dataset_train, dataset_test, dataset_val = (
    dataset[split_name].map(encode, batched=False, remove_columns=["sequence"])
    for split_name in ('train', 'test', 'val')
)

In [None]:
first_example = dataset_train[0]
first_ids = first_example['input_ids']

print(dataset_train)
print(first_ids)
print(first_example['class'])
print(len(first_ids))
# Decode twice: first keeping special tokens, then stripping them.
for skip_special in (False, True):
    print(tokenizer.decode(first_ids, skip_special_tokens=skip_special))

In [None]:
import random
import torch
import torch.nn.functional as F

input_ids_tensor = torch.tensor(dataset_train[0]['input_ids'], dtype=torch.long)
# one_hot gives (seq_len, classes); transpose to channels-first and add a batch axis.
onehot_seq = F.one_hot(input_ids_tensor, num_classes=tokenizer.vocab_size + 1)
onehot_seq = onehot_seq.permute(1, 0)[None]
print(onehot_seq.shape)

def tensor_to_seq(tokenizer, tensor, cutoff=None, verbose=False):
    '''
    Convert a batch of per-position token scores into strings using the tokenizer.

    Args:
        tokenizer: object providing ``batch_decode`` and ``unk_token_id``
            (e.g. a Hugging Face fast tokenizer).
        tensor: scores whose dim 1 indexes tokens; assumed (batch, vocab, seq_len)
            like the one-hot tensors built in this notebook. At each position
            the highest-scoring token is chosen (argmax over dim 1).
        cutoff: optional threshold. Positions whose highest score is not
            strictly greater than ``cutoff`` are decoded as the unknown token.
        verbose: if True, print the chosen token ids (debug aid).

    Returns:
        Whatever ``tokenizer.batch_decode`` returns for the chosen ids.
    '''
    token_ids = tensor.argmax(dim=1)
    if cutoff is not None:
        best_scores = tensor.max(dim=1).values
        # full_like keeps the replacement on the same device/dtype as token_ids.
        # Note: HF tokenizers expose `unk_token_id`; `unknown_token_id` does not exist.
        unk_ids = torch.full_like(token_ids, tokenizer.unk_token_id)
        token_ids = torch.where(best_scores > cutoff, token_ids, unk_ids)
    if verbose:
        print(token_ids)

    return tokenizer.batch_decode(token_ids)

decoded_demo = tensor_to_seq(tokenizer, onehot_seq)  # round-trip the one-hot back to text
print(decoded_demo)


In [None]:
import random
import torch
import torch.nn.functional as F

def collate_fn(batch, num_classes=None):
    '''
    Pad a list of tokenized examples into batch tensors.

    Args:
        batch: list of dicts with 'input_ids' (list of ints) and 'class' (int).
        num_classes: number of one-hot channels. Defaults to
            ``tokenizer.vocab_size + 1`` using the notebook-level tokenizer,
            which is also the model's in/out channel count.

    Returns:
        dict with, for L = longest 'input_ids' in the batch:
          'input_ids'      (B, L) long, right-padded with 0
          'onehot'         (B, num_classes, L) float, all-zero at padded positions
          'attention_mask' (B, L) float, 1.0 for real tokens, 0.0 for padding
          'class_label'    (B,) long
    '''
    if num_classes is None:
        num_classes = tokenizer.vocab_size + 1

    seq_lens = [len(example['input_ids']) for example in batch]
    max_len = max(seq_lens)

    input_ids = torch.zeros(len(batch), max_len, dtype=torch.long)
    for row, (example, n) in enumerate(zip(batch, seq_lens)):
        input_ids[row, :n] = torch.tensor(example['input_ids'], dtype=torch.long)

    positions = torch.arange(max_len).unsqueeze(0)
    attention_mask = (positions < torch.tensor(seq_lens).unsqueeze(1)).to(torch.float)

    # One-hot the whole padded matrix at once. Padded slots hold id 0, so mask
    # them back to all-zero columns to match a per-row one-hot of real tokens.
    onehot = F.one_hot(input_ids, num_classes=num_classes).permute(0, 2, 1).to(torch.float)
    onehot = (onehot * attention_mask.unsqueeze(1)).contiguous()

    class_labels = torch.tensor([example['class'] for example in batch], dtype=torch.long)
    return {'input_ids': input_ids, 'onehot': onehot, 'attention_mask': attention_mask, 'class_label': class_labels}

class BatchSampler:
    '''
    Batch sampler for variable-length sequences that puts sequences of similar
    length in the same batch, to reduce padding.

    Each epoch the indices are shuffled and split into pools of
    ``mega_batch_size * batch_size`` indices. Each pool is sorted by length and
    cut into consecutive batches of ``batch_size``, and the order of those
    batches within the pool is shuffled. Randomness comes from the global
    ``random`` module.
    '''
    def __init__(self, lengths, batch_size, mega_batch_size, drop_last=True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.mega_batch_size = mega_batch_size
        self.drop_last = drop_last

    def __iter__(self):
        size = len(self.lengths)
        indices = list(range(size))
        random.shuffle(indices)

        step = self.mega_batch_size * self.batch_size
        for i in range(0, size, step):
            pool = sorted(indices[i:i + step], key=lambda x: self.lengths[x])
            batch_starts = list(range(0, len(pool), self.batch_size))
            # Shuffle batch order so the model doesn't see lengths in ascending order.
            # Full pools split evenly, so only the final pool can end in a partial
            # batch, and that batch holds the final pool's longest sequences.
            random.shuffle(batch_starts)
            for j in batch_starts:
                batch = pool[j:j + self.batch_size]
                if self.drop_last and len(batch) < self.batch_size:
                    continue
                yield batch

    def __len__(self):
        # Only the last pool can yield a partial batch, so the count is
        # floor(n / batch_size) when dropping it and ceil(n / batch_size) otherwise.
        full_batches, remainder = divmod(len(self.lengths), self.batch_size)
        if remainder and not self.drop_last:
            return full_batches + 1
        return full_batches


In [None]:
import torch

# Token length of every example, in dataset order (consumed by BatchSampler).
train_lengths = [len(example["input_ids"]) for example in dataset_train]
test_lengths = [len(example["input_ids"]) for example in dataset_test]
val_lengths = [len(example["input_ids"]) for example in dataset_val]

In [None]:
from torch.utils.data import DataLoader

def _make_length_batched_loader(dataset, lengths, batch_size):
    '''Build a DataLoader that groups examples of similar length via BatchSampler.

    Uses ``config.mega_batch`` as the pooling factor, keeps the final partial
    batch (drop_last=False), and pads with ``collate_fn``.
    '''
    sampler = BatchSampler(lengths, batch_size, config.mega_batch, drop_last=False)
    return DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)


train_dataloader = _make_length_batched_loader(dataset_train, train_lengths, config.train_batch_size)
test_dataloader = _make_length_batched_loader(dataset_test, test_lengths, config.eval_batch_size)
val_dataloader = _make_length_batched_loader(dataset_val, val_lengths, config.eval_batch_size)

In [None]:
from New1D.unet_1d import UNet1DConditionModel

num_tokens = tokenizer.vocab_size + 1  # one-hot channel count, same value collate_fn uses

model = UNet1DConditionModel(
    sample_size=config.eval_seq_len,  # length of the sequences sampled in evaluate()
    in_channels=num_tokens,
    out_channels=num_tokens,
    num_class_embeds=2,  # evaluate() draws class labels from range(num_class_embeds)
    layers_per_block=2,
    block_out_channels=(64, 128, 256, 512),
    # three plain down blocks followed by one attention down block ...
    down_block_types=("DownBlock1D",) * 3 + ("AttnDownBlock1D",),
    mid_block_type="UNetMidBlock1D",
    # ... mirrored on the way up: one attention up block, then three plain ones
    up_block_types=("AttnUpBlock1D",) + ("UpBlock1D",) * 3,
    num_attention_heads=8,
)

In [None]:
# Total number of UNet parameters (shown as the cell output).
# The former bare `model.children` line was a no-op attribute reference and was removed.
model.num_parameters()

In [None]:
sample_image = next(iter(val_dataloader))
# Show the first example of each collated field.
for field in ('input_ids', 'attention_mask', 'onehot', 'class_label'):
    print(sample_image[field][0])

In [None]:
import torch
from PIL import Image
from diffusers import DDPMScheduler

sample_seq = sample_image['onehot'][0][None]  # re-add a batch axis of size 1
print(sample_seq.shape)

# Noise the single example at a fixed timestep to sanity-check the scheduler.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
noise = torch.randn(sample_seq.shape)
timesteps = torch.tensor([50], dtype=torch.long)
noisy_seq = noise_scheduler.add_noise(sample_seq, noise, timesteps)

class_labels, attention_mask = (
    sample_image[field][0][None] for field in ('class_label', 'attention_mask')
)
print(class_labels.shape)
print(attention_mask.shape)

In [None]:
import torch.nn.functional as F

# One forward pass on the demo example; the loss is the MSE against the true noise.
model_output = model(
    sample=noisy_seq,
    timestep=timesteps,
    class_labels=class_labels,
    attention_mask=attention_mask,
)
noise_pred = model_output.sample
loss = F.mse_loss(noise_pred, noise)

In [None]:
from diffusers.optimization import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# One scheduler step per batch for every epoch; len(train_dataloader) comes from BatchSampler.__len__.
total_training_steps = config.num_epochs * len(train_dataloader)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=total_training_steps,
)

In [None]:
from New1D.pipeline_protein import DDPMProteinPipeline
import os
import json

from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO

def evaluate(config, epoch, pipeline):
    '''
    Sample sequences with the diffusion pipeline and save them as a FASTA file.

    Draws ``config.eval_batch_size`` random class labels in
    ``[0, pipeline.unet.config.num_class_embeds)`` and runs the pipeline
    (the reverse diffusion process) for sequences of length
    ``config.eval_seq_len``. The result is written to
    ``<config.output_dir>/samples/<epoch:04d>.fa``. Each record's description
    holds its class label and the architecture looked up for that class in
    ``config.labels_file``; "None" is written if the class is absent.
    '''
    class_labels = torch.randint(0,
                                 pipeline.unet.config.num_class_embeds,
                                 (config.eval_batch_size,),
                                 device=pipeline.device)

    # Presumably a list of amino-acid strings (output_type="aa_seq"); TODO confirm in DDPMProteinPipeline.
    seqs = pipeline(
        seq_len=config.eval_seq_len,
        batch_size=config.eval_batch_size,
        class_labels=class_labels,
        # Separate CPU generator so sampling does not rewind the training loop's RNG state.
        generator=torch.Generator(device='cpu').manual_seed(config.seed),
        output_type="aa_seq",
        cutoff=None,
    ).seqs

    with open(config.labels_file, 'r') as f:
        labels = json.load(f)

    # class id -> architecture, keeping the first entry when several share a class.
    arc_by_class = {}
    for entry in labels.values():
        arc_by_class.setdefault(entry['class'], entry['architecture'])

    sample_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(sample_dir, exist_ok=True)

    seq_record_list = []
    for i, seq in enumerate(seqs):
        label = class_labels[i].item()  # plain int, so the dict lookup is exact
        description = f"classlabel: {label} w: N/A arc: {arc_by_class.get(label)}"
        seq_record_list.append(SeqRecord(Seq(seq), id=str(i), description=description))

    with open(os.path.join(sample_dir, f"{epoch:04d}.fa"), "w") as f:
        SeqIO.write(seq_record_list, f, "fasta")

In [None]:
from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm
from pathlib import Path
import os

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    '''
    Train ``model`` as a DDPM noise predictor on one-hot protein sequences.

    For every batch, Gaussian noise is added to ``batch["onehot"]`` at a random
    timestep, and the model (conditioned on class labels and the attention
    mask) is trained with MSE to predict that noise. Loss, learning rate and
    step are logged to tensorboard under ``<output_dir>/logs``.

    On the main process, every ``save_image_epochs`` epochs (and on the last
    epoch) ``evaluate`` writes a FASTA file of samples. Every
    ``save_model_epochs`` epochs (and on the last epoch) the pipeline is saved
    to ``config.output_dir`` and, if ``config.push_to_hub``, uploaded to the Hub.
    '''
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    repo_id = None
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        if config.push_to_hub:
            repo_id = create_repo(
                repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
            ).repo_id
        accelerator.init_trackers("train_example")

    # prepare() returns the objects in the same order they were passed.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for batch in train_dataloader:
            clean_seqs = batch["onehot"]
            attention_mask = batch["attention_mask"]
            class_labels = batch["class_label"]

            # Forward diffusion: noise each sequence at an independently sampled timestep.
            noise = torch.randn(clean_seqs.shape, device=clean_seqs.device)
            bs = clean_seqs.shape[0]
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_seqs.device,
                dtype=torch.int64
            )
            noisy_seq = noise_scheduler.add_noise(clean_seqs, noise, timesteps)

            with accelerator.accumulate(model):
                noise_pred = model(sample=noisy_seq,
                                   timestep=timesteps,
                                   class_labels=class_labels,
                                   attention_mask=attention_mask,
                                   return_dict=False)[0]
                # NOTE(review): the MSE also covers padded positions (attention_mask == 0);
                # confirm whether padding should be excluded from the loss.
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                # Clip only when gradients are synchronised, i.e. on the step that updates weights.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        progress_bar.close()

        if accelerator.is_main_process:
            is_last_epoch = epoch == config.num_epochs - 1
            should_sample = (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch
            should_save = (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch

            if should_sample or should_save:
                pipeline = DDPMProteinPipeline(unet=accelerator.unwrap_model(model),
                                               scheduler=noise_scheduler,
                                               tokenizer=tokenizer,
                                               ).to(accelerator.device)
                if should_sample:
                    evaluate(config, epoch, pipeline)
                if should_save:
                    # Always write the pipeline locally first; otherwise upload_folder
                    # would push a directory that never received the model weights.
                    pipeline.save_pretrained(config.output_dir)
                    if config.push_to_hub:
                        upload_folder(
                            repo_id=repo_id,
                            folder_path=config.output_dir,
                            commit_message=f"Epoch {epoch}",
                            ignore_patterns=["step_*", "epoch_*"],
                        )

    accelerator.end_training()  # flush/close the tensorboard tracker

In [None]:
from accelerate import notebook_launcher

# Positional arguments for train_loop, in its signature order.
args = (
    config,
    model,
    noise_scheduler,
    optimizer,
    train_dataloader,
    lr_scheduler,
)

notebook_launcher(function=train_loop, args=args, num_processes=1)


In [None]:
import glob

# evaluate() writes generated samples as FASTA files named <epoch:04d>.fa (not images),
# so zero-padded names sort chronologically and the last one is the newest.
sample_files = sorted(glob.glob(f"{config.output_dir}/samples/*.fa"))
if not sample_files:
    raise FileNotFoundError(f"No FASTA samples found in {config.output_dir}/samples - run training first")
with open(sample_files[-1]) as f:
    print(f.read())