In [19]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    batch_size = 64  # the batch size
    mega_batch = 1000 # how many batches to use for batchsampling
    num_epochs = 1000  # the number of epochs to train the model
    gradient_accumulation_steps = 2
    learning_rate = 1e-5
    lr_warmup_steps = 1000
    save_image_epochs = 10
    save_model_epochs = 30
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "output/protein-VAE-UniRef50-XS"  # the model name locally and on the HF Hub
    pad_to_multiple_of = 16
    max_len = 512  # truncation of the input sequence

    class_embeddings_concat = False  # whether to concatenate the class embeddings to the time embeddings

    push_to_hub = False  # whether to upload the saved model to the HF Hub
    hub_model_id = "kkj15dk/protein-VAE_test"  # the name of the repository to create on the HF Hub
    hub_private_repo = False
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 42

    labels_file = 'labels_test.json'

    cutoff = None # cutoff for when to predict the token given the logits, and when to assign the unknown token 'X' to this position
    skip_special_tokens = False # whether to skip the special tokens when writing the evaluation sequences
    kl_weight = 0.05 # the weight of the KL divergence in the loss function


config = TrainingConfig()

In [20]:
from datasets import load_dataset

# config.dataset_name = "kkj15dk/test_dataset"
config.dataset_name = "agemagician/uniref50"
dataset = load_dataset(config.dataset_name) # , download_mode='force_redownload')
dataset = dataset.shuffle(config.seed)

Resolving data files:   0%|          | 0/96 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/38 [00:00<?, ?it/s]

In [21]:
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast.from_pretrained("kkj15dk/protein_tokenizer")

def encode(example):
    return tokenizer(example['text'],
    # return tokenizer(example['sequence'],
                    padding = True,
                    pad_to_multiple_of = config.pad_to_multiple_of,
                    return_token_type_ids=False,
                    return_attention_mask=False, # We need to attend to padding tokens, so we set this to False
)
# dataset_train = dataset['train'].map(encode, batched=False, remove_columns=["sequence"])
# dataset_test = dataset['test'].map(encode, batched=False, remove_columns=["sequence"])
# dataset_val = dataset['val'].map(encode, batched=False, remove_columns=["sequence"])

# dataset_train = dataset['train'].map(encode, batched=False, remove_columns=["text"])
dataset_test = dataset['test'].map(encode, batched=False, remove_columns=["text"])
dataset_val = dataset['validation'].map(encode, batched=False, remove_columns=["text"])

Map:   0%|          | 0/5888 [00:00<?, ? examples/s]

Map:   0%|          | 0/6084 [00:00<?, ? examples/s]

In [22]:
print(dataset_test)
print(dataset_test[0]['input_ids'])
print(dataset_test[0]['name'])
print(len(dataset_test[0]['input_ids']))
print(tokenizer.decode(dataset_test[0]['input_ids'], skip_special_tokens=False))
print(tokenizer.decode(dataset_test[0]['input_ids'], skip_special_tokens=True))

Dataset({
    features: ['id', 'name', 'input_ids'],
    num_rows: 5888
})
[0, 13, 20, 17, 17, 20, 18, 7, 3, 18, 5, 3, 20, 15, 12, 15, 3, 19, 20, 18, 16, 6, 16, 20, 6, 17, 15, 12, 17, 18, 12, 12, 17, 18, 11, 15, 18, 12, 17, 21, 3, 20, 4, 3, 13, 20, 12, 18, 8, 12, 8, 13, 3, 13, 7, 3, 3, 20, 18, 20, 12, 21, 12, 16, 3, 3, 12, 8, 21, 5, 8, 6, 6, 20, 8, 12, 7, 12, 18, 7, 8, 18, 4, 12, 16, 12, 20, 18, 16, 20, 20, 20, 12, 15, 4, 12, 7, 11, 10, 19, 17, 21, 17, 6, 10, 17, 20, 3, 11, 10, 4, 12, 10, 3, 14, 3, 7, 11, 7, 13, 3, 12, 4, 18, 20, 16, 5, 18, 17, 21, 20, 7, 3, 10, 12, 10, 3, 3, 3, 15, 8, 7, 4, 8, 19, 15, 20, 12, 7, 18, 12, 4, 19, 9, 16, 20, 15, 11, 3, 5, 11, 17, 17, 21, 9, 7, 3, 12, 19, 3, 12, 14, 19, 20, 3, 19, 20, 12, 8, 3, 12, 3, 8, 18, 9, 12, 12, 20, 12, 8, 12, 16, 8, 18, 17, 15, 12, 8, 19, 15, 12, 7, 7, 3, 3, 3, 4, 22, 7, 12, 3, 8, 8, 4, 15, 8, 1, 23, 23, 23, 23, 23, 23, 23]
UniRef50_A0A7S0FGK4
224
[MVRRVSFASDAVPLPATVSQEQVERPLRSLLRSKPSLRWAVCAMVLSGLGMAMFAAVSVLWLQAALGWDGEEVGLFLSFGSCLQ

In [23]:
import torch
import torch.nn.functional as F

input_ids_tensor = torch.tensor(dataset_test[0]['input_ids'], dtype=torch.long)
onehot_seq = F.one_hot(input_ids_tensor, num_classes=tokenizer.vocab_size + 1).permute(1, 0).unsqueeze(0)
print(onehot_seq.shape)

def logits_to_token_ids(tokenizer, logits, cutoff = None):
        '''
        Convert a batch of logits to token_ids.
        Returns token_ids
        '''
        if cutoff is None:
            token_ids = logits.argmax(dim=1)
        else:
            token_ids = torch.where(logits.max(dim=1).values > cutoff, 
                                    logits.argmax(dim=1), 
                                    torch.tensor([tokenizer.unknown_token_id])
                                    )

        return token_ids
token_ids = logits_to_token_ids(tokenizer, onehot_seq)
print(token_ids)
print(tokenizer.batch_decode(token_ids, skip_special_tokens=config.skip_special_tokens))


torch.Size([1, 24, 224])
tensor([[ 0, 13, 20, 17, 17, 20, 18,  7,  3, 18,  5,  3, 20, 15, 12, 15,  3, 19,
         20, 18, 16,  6, 16, 20,  6, 17, 15, 12, 17, 18, 12, 12, 17, 18, 11, 15,
         18, 12, 17, 21,  3, 20,  4,  3, 13, 20, 12, 18,  8, 12,  8, 13,  3, 13,
          7,  3,  3, 20, 18, 20, 12, 21, 12, 16,  3,  3, 12,  8, 21,  5,  8,  6,
          6, 20,  8, 12,  7, 12, 18,  7,  8, 18,  4, 12, 16, 12, 20, 18, 16, 20,
         20, 20, 12, 15,  4, 12,  7, 11, 10, 19, 17, 21, 17,  6, 10, 17, 20,  3,
         11, 10,  4, 12, 10,  3, 14,  3,  7, 11,  7, 13,  3, 12,  4, 18, 20, 16,
          5, 18, 17, 21, 20,  7,  3, 10, 12, 10,  3,  3,  3, 15,  8,  7,  4,  8,
         19, 15, 20, 12,  7, 18, 12,  4, 19,  9, 16, 20, 15, 11,  3,  5, 11, 17,
         17, 21,  9,  7,  3, 12, 19,  3, 12, 14, 19, 20,  3, 19, 20, 12,  8,  3,
         12,  3,  8, 18,  9, 12, 12, 20, 12,  8, 12, 16,  8, 18, 17, 15, 12,  8,
         19, 15, 12,  7,  7,  3,  3,  3,  4, 22,  7, 12,  3,  8,  8,  4, 15,  8,
   

In [24]:
import random

def collate_fn(batch): # Can definitely be optimized
    max_len = max(len(x['input_ids']) for x in batch)
    if max_len > config.max_len:
        max_len = config.max_len
    input_ids = torch.zeros(len(batch), max_len, dtype=torch.long)
    attention_mask = torch.zeros(len(batch), max_len, dtype=torch.float)
    class_labels = torch.zeros(len(batch), dtype=torch.long)
    # identifiers = [x['id'] for x in batch]
    identifiers = [x.get('name', 'N/A') for x in batch]
    for i, x in enumerate(batch):
        seq_len = len(x['input_ids'])
        if seq_len > max_len:
            index = random.randint(0, seq_len - max_len)
            x['input_ids'] = x['input_ids'][index:index+max_len]
            seq_len = max_len
        input_ids[i, :seq_len] = torch.tensor(x['input_ids'], dtype=torch.long)
        attention_mask[i, :seq_len] = torch.tensor(1, dtype=torch.float)
        # class_labels[i] = torch.tensor(x['class'], dtype=torch.long)
    return {'id': identifiers, 'input_ids': input_ids, 'attention_mask': attention_mask, 'class_label': class_labels}

class BatchSampler:
    '''
    BatchSampler for variable length sequences, batching by similar lengths, to prevent excessive padding.
    '''
    def __init__(self, lengths, batch_size, mega_batch_size, drop_last = True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.mega_batch_size = mega_batch_size
        self.drop_last = drop_last

    def __iter__(self):
        size = len(self.lengths)
        indices = list(range(size))
        random.shuffle(indices)

        step = self.mega_batch_size * self.batch_size
        for i in range(0, size, step):
            pool = indices[i:i+step]
            pool = sorted(pool, key=lambda x: self.lengths[x])
            mega_batch_indices = list(range(0, len(pool), self.batch_size))
            random.shuffle(mega_batch_indices) # shuffle the mega batches, so that the model doesn't see the same order of lengths every time. The small batch will however always be the one with longest lengths
            for j in mega_batch_indices:
                if self.drop_last and j + self.batch_size > len(pool): # drop the last batch if it's too small
                    continue
                batch = pool[j:j+self.batch_size]
                random.shuffle(batch) # shuffle the batch, so that the model doesn't see the same order of lengths every time
                yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.lengths) // self.batch_size
        else:
            return (len(self.lengths) + self.batch_size - 1) // self.batch_size


In [25]:
import torch
import timeit

val_lengths = list(map(lambda x: len(x["input_ids"]), dataset_val))
print("Max val length:", max(val_lengths))
test_lengths = list(map(lambda x: len(x["input_ids"]), dataset_test))
print("Max test length:", max(test_lengths))
# train_lengths = list(map(lambda x: len(x["input_ids"]), dataset_train))
# print("Max val length:", max(val_lengths))

Max val length: 6064
Max test length: 13616


In [26]:
from torch.utils.data import DataLoader

# train_dataloader = DataLoader(dataset_train, 
#                             batch_sampler=BatchSampler(train_lengths, 
#                                                     config.batch_size,
#                                                     config.mega_batch,
#                                                     drop_last=False), 
#                             collate_fn=collate_fn)
test_dataloader = DataLoader(dataset_test,
                            batch_sampler=BatchSampler(test_lengths, 
                                                    config.batch_size,
                                                    config.mega_batch,
                                                    drop_last=False), 
                            collate_fn=collate_fn)
val_dataloader = DataLoader(dataset_val, 
                            batch_sampler=BatchSampler(val_lengths, 
                                                    config.batch_size,
                                                    config.mega_batch,
                                                    drop_last=False),
                            collate_fn=collate_fn)

In [27]:
from New1D.autoencoder_kl_1d import AutoencoderKL1D

model = AutoencoderKL1D(
    num_class_embeds=tokenizer.vocab_size + 1,  # the number of class embeddings
    
    down_block_types=(
        "DownEncoderBlock1D",  # a regular ResNet downsampling block
        "DownEncoderBlock1D",
        "DownEncoderBlock1D",
        "DownEncoderBlock1D",  # a ResNet downsampling block with spatial self-attention
    ),
    up_block_types=(
        "UpDecoderBlock1D",  # a ResNet upsampling block with spatial self-attention
        "UpDecoderBlock1D",
        "UpDecoderBlock1D",
        "UpDecoderBlock1D",  # a regular ResNet upsampling block
    ),
    block_out_channels=(64, 128, 256, 256),  # the number of output channels for each block
    mid_block_type="UNetMidBlock1D",  # the type of the middle block
    mid_block_channels=256,  # the number of output channels for the middle block
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    transformer_layers_per_block=1, # how many transformer layers to use per ResNet layer. Not implemented yet.

    latent_channels=64,  # the dimensionality of the latent space

    num_attention_heads=4,  # the number of attention heads in the spatial self-attention blocks
)

In [28]:
model.children
model.num_parameters()

8755160

In [29]:
sample_image = next(iter(val_dataloader))
print(sample_image.keys())
print(sample_image['id'][0])
print(sample_image['input_ids'][0])
print(sample_image['attention_mask'][0])
print(sample_image['class_label'][0])

dict_keys(['id', 'input_ids', 'attention_mask', 'class_label'])
UniRef50_A0A0M0BE48
tensor([ 3, 14,  5,  4,  3,  8, 11, 15, 14, 11, 15, 12,  3, 22, 20, 20, 19, 12,
        15, 10,  6, 15, 19, 16,  5, 22,  4,  8, 14, 15, 20, 12, 15, 11, 20, 17,
         3, 10, 12, 18, 21,  6,  3, 19, 15, 15,  5,  5,  5, 15,  5, 21, 15, 13,
        20, 21,  8, 14, 19, 12,  6, 17, 17, 10, 16, 10, 11, 15, 17, 12, 21, 12,
        12,  3,  5, 20, 20, 11, 20, 12,  8, 20, 10, 12, 11, 16,  6, 20, 11, 12,
        15, 15, 10,  7,  6,  6, 20,  6, 15,  4, 15, 10, 15, 16, 15,  6, 15, 15,
        15, 12, 19, 20, 15,  6, 12, 20, 16, 12, 22,  3,  3, 15,  3,  8,  3, 11,
        19, 17, 17,  6, 11, 10,  3, 20,  6, 15,  9, 17,  7,  8,  7, 14,  6, 10,
        16,  3, 16, 12, 19, 13,  8,  3, 20, 14, 15, 16, 13, 20,  3,  6, 13, 10,
         3, 16, 21,  5,  3, 20,  8, 12,  5, 21, 18,  3,  3, 20,  6,  3, 12, 16,
         3, 19, 18,  8, 14, 20, 18, 22,  6,  6, 20, 11,  4, 12,  8, 12,  5, 14,
        14, 17,  6, 21, 12, 20,  3, 

In [30]:
import torch

class_labels = sample_image['class_label'][0].unsqueeze(0)
attention_mask = sample_image['attention_mask'][0].unsqueeze(0)
print(class_labels.shape)
print(attention_mask.shape)

torch.Size([1])
torch.Size([1, 512])


In [31]:
import torch.nn.functional as F

input_ids = sample_image['input_ids']
attention_mask = sample_image['attention_mask']

output = model(sample = input_ids,
                attention_mask = attention_mask,
                sample_posterior = True, # Should be set to true in training
)

def loss_fn(output, input_ids
    ) -> tuple[torch.Tensor]:
    ce_loss = F.cross_entropy(output.sample, input_ids, reduction='none')
    ce_loss = torch.sum(
        ce_loss * output.attention_masks[0]
    ) / output.attention_masks[0].sum()
    
    kl_loss = output.latent_dist.kl()
    kl_loss = torch.sum(
        kl_loss * output.attention_masks[-1]
    ) / output.attention_masks[-1].sum()

    return ce_loss, kl_loss

# loss = loss_fn(output, input_ids)
# print(loss)

In [32]:
from diffusers.optimization import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    # num_training_steps=(len(test_dataloader) * config.num_epochs),
    num_training_steps=(len(test_dataloader) * config.num_epochs),
)

In [33]:
import os

from typing import Optional

from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO

@torch.no_grad()
def evaluate(config, epoch, output, dataloader, output_dir: Optional[str] = None
    ) -> float:

    if output_dir is not None:
        test_dir = os.path.join(config.output_dir, "samples")
        os.makedirs(test_dir, exist_ok=True)

    running_loss = 0.0
    num_correct_residues = 0
    total_residues = 0

    for i, sample in enumerate(dataloader):

        output = model(sample = sample['input_ids'],
                            attention_mask = sample['attention_mask'],
                            sample_posterior = False, # Should be set to true in training
        )

        ce_loss, kl_loss = loss_fn(output, sample['input_ids'])
        loss = ce_loss + kl_loss * config.kl_weight
        running_loss += loss.item()

        token_ids_pred = logits_to_token_ids(tokenizer, output.sample, cutoff = config.cutoff)

        token_ids_correct = ((sample['input_ids'] == token_ids_pred) & (sample['attention_mask'] == 1)).long()
        num_residues = torch.sum(sample['attention_mask'], dim=1).long()

        num_correct_residues += token_ids_correct.sum().item()
        total_residues += num_residues.sum().item()

        # Decode the predicted sequences, and remove zero padding
        seqs_pred = tokenizer.batch_decode(token_ids_pred, skip_special_tokens=config.skip_special_tokens)
        seqs_lens = torch.sum(sample['attention_mask'], dim=1).long()
        seqs_pred = [seq[:i] for seq, i in zip(seqs_pred, seqs_lens)]

        # Save all samples as a FASTA file
        seq_record_list = [SeqRecord(Seq(seq), id=str(sample['id'][i]), 
                        description=
                        f"classlabel: {sample['class_label'][i].item()} acc: {token_ids_correct[i].sum().item() / num_residues[i].item():.2f}")
                        for i, seq in enumerate(seqs_pred)]
        with open(f"{test_dir}/{epoch:04d}.fa", "a") as f:
            SeqIO.write(seq_record_list, f, "fasta")
    
    acc = num_correct_residues / total_residues
    print(f"Epoch {epoch}, val_loss: {running_loss / len(dataloader):.4f}, val_accuracy: {acc:.4f}")
    logs = {"val_loss": loss.detach().item(), 
            "val_ce_loss": ce_loss.detach().item(), 
            "val_kl_loss": kl_loss.detach().item(),
            "val_acc": acc.detach().item(),
            }
    return logs

evaluate(config, -1, model, val_dataloader, output_dir = config.output_dir)


Epoch -1, val_loss: 3.1798, val_accuracy: 0.0315


AttributeError: 'float' object has no attribute 'detach'

In [None]:
from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm
from pathlib import Path
import os


def train_loop(config, model, optimizer, train_dataloader, test_dataloader, lr_scheduler):
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        if config.push_to_hub:
            repo_id = create_repo(
                repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
            ).repo_id
        accelerator.init_trackers("train_example")

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader, test_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, test_dataloader, lr_scheduler
    )

    global_step = 0

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):

            with accelerator.accumulate(model):
                input = batch['input_ids']
                attention_mask = batch['attention_mask']
                # Predict the noise residual
                output = model(sample = input,
                                attention_mask = attention_mask,
                                sample_posterior = True, # Should be set to true in training
                )
                ce_loss, kl_loss = loss_fn(output, input)
                loss = ce_loss + kl_loss * config.kl_weight
                accelerator.backward(loss)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"train_loss": loss.detach().item(), 
                    "train_ce_loss": ce_loss.detach().item(), 
                    "train_kl_loss": kl_loss.detach().item(), 
                    "lr": lr_scheduler.get_last_lr()[0], 
                    "step": global_step,
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                
                logs = evaluate(config, epoch, model, test_dataloader, output_dir = config.output_dir)
                accelerator.log(logs, step=global_step)

                if config.push_to_hub:
                    upload_folder(
                        repo_id=repo_id,
                        folder_path=config.output_dir,
                        commit_message=f"Epoch {epoch}",
                        ignore_patterns=["step_*", "epoch_*"],
                    )
                else:
                    model.save_pretrained(config.output_dir)

In [None]:
from accelerate import notebook_launcher

args = (config, model, optimizer, test_dataloader, val_dataloader, lr_scheduler)

notebook_launcher(train_loop, args, num_processes=1)
