In [1]:
from dataclasses import dataclass
import os
from typing import Optional, Literal, Union, List

@dataclass
class TrainingConfig:
    batch_size: int = 64  # the batch size
    mega_batch: int = 1000 # how many batches to use for batchsampling
    num_epochs: int = 1  # the number of epochs to train the model
    gradient_accumulation_steps: int = 2  # the number of steps to accumulate gradients before taking an optimizer step
    learning_rate: float = 1e-4  # the learning rate
    lr_warmup_steps:int  = 1000
    save_image_model_steps:int  = 100
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    optimizer: str = "AdamW"  # the optimizer to use, choose between `AdamW`, `Adam`, `SGD`, and `Adamax`
    SGDmomentum: float = 0.9
    output_dir: str = os.path.join("output","protein-VAE-UniRef50-8")  # the model name locally and on the HF Hub
    pad_to_multiple_of: int = 16
    max_len: int = 512  # truncation of the input sequence

    class_embeddings_concat = False  # whether to concatenate the class embeddings to the time embeddings

    push_to_hub = False  # Not implemented yet. Whether to upload the saved model to the HF Hub
    hub_model_id = "kkj15dk/protein-VAE"  # the name of the repository to create on the HF Hub
    hub_private_repo = False
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed: int = 42

    automatic_checkpoint_naming: bool = True  # whether to automatically name the checkpoints
    total_limit: int = 1  # the total limit of checkpoints to save

    labels_file = 'labels_test.json'

    cutoff: Optional[float] = None # cutoff for when to predict the token given the logits, and when to assign the unknown token 'X' to this position
    skip_special_tokens = False # whether to skip the special tokens when writing the evaluation sequences
    kl_weight: float = 0.05 # the weight of the KL divergence in the loss function

    weight_decay: float = 0.01 # weight decay for the optimizer
    grokfast: bool = False # whether to use the grokfast algorithm
    grokfast_alpha: float = 0.98 #Momentum hyperparmeter of the EMA.
    grokfast_lamb: float = 2.0 #Amplifying factor hyperparameter of the filter.

config = TrainingConfig()
print(vars(config))

{'batch_size': 64, 'mega_batch': 1000, 'num_epochs': 1, 'gradient_accumulation_steps': 2, 'learning_rate': 0.0001, 'lr_warmup_steps': 1000, 'save_image_model_steps': 100, 'mixed_precision': 'fp16', 'optimizer': 'AdamW', 'SGDmomentum': 0.9, 'output_dir': 'output/protein-VAE-UniRef50-7', 'pad_to_multiple_of': 16, 'max_len': 512, 'seed': 42, 'automatic_checkpoint_naming': True, 'total_limit': 1, 'cutoff': None, 'kl_weight': 0.05, 'weight_decay': 0.01, 'grokfast': False, 'grokfast_alpha': 0.98, 'grokfast_lamb': 2.0}


In [2]:
from datasets import load_dataset

# # config.dataset_name = "kkj15dk/test_dataset"
# config.dataset_name = "agemagician/uniref50"
# # config.dataset_name = "kkj15dk/UniRef50-encoded"
# dataset = load_dataset(config.dataset_name) # , download_mode='force_redownload')
# dataset = dataset.shuffle(config.seed)

In [3]:
from transformers import PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast.from_pretrained("kkj15dk/protein_tokenizer")

def encode(example):
    return tokenizer(example['text'],
    # return tokenizer(example['sequence'],
                    padding = True,
                    pad_to_multiple_of = config.pad_to_multiple_of,
                    return_token_type_ids=False,
                    return_attention_mask=False, # We need to attend to padding tokens, so we set this to False
)
# dataset_train = dataset['train'].map(encode, batched=False, remove_columns=["sequence"])
# dataset_test = dataset['test'].map(encode, batched=False, remove_columns=["sequence"])
# dataset_val = dataset['val'].map(encode, batched=False, remove_columns=["sequence"])

In [4]:
# dataset_train = dataset['train']
# dataset_test = dataset['test']
# dataset_val = dataset['val']

from datasets import load_from_disk

datasets_folder = 'datasets'

train_file = os.path.join(datasets_folder, 'train_dataset')
test_file = os.path.join(datasets_folder, 'test_dataset')
val_file = os.path.join(datasets_folder, 'val_dataset')

if os.path.exists(train_file):
    print('loading train_dataset')
    dataset_train = load_from_disk(train_file)
else:
    print('encoding train_dataset')
    dataset_train = dataset['train'].map(encode, batched=False, remove_columns=["text"], num_proc=os.cpu_count())
    dataset_train.save_to_disk(train_file)

if os.path.exists(test_file):
    print('loading test_dataset')
    dataset_test = load_from_disk(test_file)
else:
    print('encoding test_dataset')
    dataset_test = dataset['test'].map(encode, batched=False, remove_columns=["text"], num_proc=os.cpu_count())
    dataset_test.save_to_disk(test_file)

if os.path.exists(val_file):
    print('loading val_dataset')
    dataset_val = load_from_disk(val_file)
else:
    print('encoding val_dataset')
    dataset_val = dataset['validation'].map(encode, batched=False, remove_columns=["text"], num_proc=os.cpu_count())
    dataset_val.save_to_disk(val_file)


loading train_dataset


Loading dataset from disk:   0%|          | 0/143 [00:00<?, ?it/s]

loading test_dataset
loading val_dataset


In [5]:
print(dataset_val)
print(dataset_val[0])
print(dataset_val[0]['input_ids'])
print(dataset_val[0]['name'])
print(len(dataset_val[0]['input_ids']))
print(tokenizer.decode(dataset_val[0]['input_ids'], skip_special_tokens=False))
print(tokenizer.decode(dataset_val[0]['input_ids'], skip_special_tokens=True))

Dataset({
    features: ['id', 'name', 'input_ids'],
    num_rows: 6084
})
{'id': 4909, 'name': 'UniRef50_E0S3A2', 'input_ids': [0, 13, 18, 12, 14, 6, 12, 8, 11, 5, 18, 16, 3, 10, 20, 6, 11, 12, 19, 8, 5, 6, 17, 7, 13, 18, 17, 10, 19, 18, 10, 8, 12, 19, 15, 8, 4, 18, 10, 17, 20, 10, 17, 14, 5, 11, 14, 17, 15, 13, 12, 10, 22, 18, 17, 5, 19, 10, 10, 3, 12, 14, 17, 14, 6, 4, 11, 8, 10, 6, 20, 3, 6, 20, 3, 8, 1, 23, 23, 23]}
[0, 13, 18, 12, 14, 6, 12, 8, 11, 5, 18, 16, 3, 10, 20, 6, 11, 12, 19, 8, 5, 6, 17, 7, 13, 18, 17, 10, 19, 18, 10, 8, 12, 19, 15, 8, 4, 18, 10, 17, 20, 10, 17, 14, 5, 11, 14, 17, 15, 13, 12, 10, 22, 18, 17, 5, 19, 10, 10, 3, 12, 14, 17, 14, 6, 4, 11, 8, 10, 6, 20, 3, 6, 20, 3, 8, 1, 23, 23, 23]
UniRef50_E0S3A2
80
[MSLNELGKDSQAIVEKLTGDERFMSRITSIGLTPGCSIRVIRNDKNRPMLIYSRDTIIALNRNECKGIEVAEVAG]---
MSLNELGKDSQAIVEKLTGDERFMSRITSIGLTPGCSIRVIRNDKNRPMLIYSRDTIIALNRNECKGIEVAEVAG


In [6]:
import torch
import torch.nn.functional as F

input_ids_tensor = torch.tensor(dataset_val[0]['input_ids'], dtype=torch.long)
onehot_seq = F.one_hot(input_ids_tensor, num_classes=tokenizer.vocab_size + 1).permute(1, 0).unsqueeze(0)
print(onehot_seq.shape)

def logits_to_token_ids(tokenizer, logits, cutoff = None):
        '''
        Convert a batch of logits to token_ids.
        Returns token_ids
        '''
        if cutoff is None:
            token_ids = logits.argmax(dim=1)
        else:
            token_ids = torch.where(logits.max(dim=1).values > cutoff, 
                                    logits.argmax(dim=1), 
                                    torch.tensor([tokenizer.unknown_token_id])
                                    )

        return token_ids
token_ids = logits_to_token_ids(tokenizer, onehot_seq)
print(token_ids)
print(tokenizer.batch_decode(token_ids, skip_special_tokens=config.skip_special_tokens))


torch.Size([1, 24, 80])
tensor([[ 0, 13, 18, 12, 14,  6, 12,  8, 11,  5, 18, 16,  3, 10, 20,  6, 11, 12,
         19,  8,  5,  6, 17,  7, 13, 18, 17, 10, 19, 18, 10,  8, 12, 19, 15,  8,
          4, 18, 10, 17, 20, 10, 17, 14,  5, 11, 14, 17, 15, 13, 12, 10, 22, 18,
         17,  5, 19, 10, 10,  3, 12, 14, 17, 14,  6,  4, 11,  8, 10,  6, 20,  3,
          6, 20,  3,  8,  1, 23, 23, 23]])
['[MSLNELGKDSQAIVEKLTGDERFMSRITSIGLTPGCSIRVIRNDKNRPMLIYSRDTIIALNRNECKGIEVAEVAG]---']


In [7]:
import random

def collate_fn(batch): # Can definitely be optimized
    max_len = max(len(x['input_ids']) for x in batch)
    if max_len > config.max_len:
        max_len = config.max_len
    input_ids = torch.zeros(len(batch), max_len, dtype=torch.long)
    attention_mask = torch.zeros(len(batch), max_len, dtype=torch.float)
    class_labels = torch.zeros(len(batch), dtype=torch.long)
    # identifiers = [x['id'] for x in batch]
    identifiers = [x.get('name', 'N/A') for x in batch]
    for i, x in enumerate(batch):
        seq_len = len(x['input_ids'])
        if seq_len > max_len:
            index = random.randint(0, seq_len - max_len)
            x['input_ids'] = x['input_ids'][index:index+max_len]
            seq_len = max_len
        input_ids[i, :seq_len] = torch.tensor(x['input_ids'], dtype=torch.long)
        attention_mask[i, :seq_len] = torch.tensor(1, dtype=torch.float)
        # class_labels[i] = torch.tensor(x['class'], dtype=torch.long)
    return {'id': identifiers, 'input_ids': input_ids, 'attention_mask': attention_mask, 'class_label': class_labels}

class BatchSampler:
    '''
    BatchSampler for variable length sequences, batching by similar lengths, to prevent excessive padding.
    '''
    def __init__(self, lengths, batch_size, mega_batch_size, drop_last = True):
        self.lengths = lengths
        self.batch_size = batch_size
        self.mega_batch_size = mega_batch_size
        self.drop_last = drop_last

    def __iter__(self):
        size = len(self.lengths)
        indices = list(range(size))
        random.shuffle(indices)

        step = self.mega_batch_size * self.batch_size
        for i in range(0, size, step):
            pool = indices[i:i+step]
            pool = sorted(pool, key=lambda x: self.lengths[x])
            mega_batch_indices = list(range(0, len(pool), self.batch_size))
            random.shuffle(mega_batch_indices) # shuffle the mega batches, so that the model doesn't see the same order of lengths every time. The small batch will however always be the one with longest lengths
            for j in mega_batch_indices:
                if self.drop_last and j + self.batch_size > len(pool): # drop the last batch if it's too small
                    continue
                batch = pool[j:j+self.batch_size]
                random.shuffle(batch) # shuffle the batch, so that the model doesn't see the same order of lengths every time
                yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.lengths) // self.batch_size
        else:
            return (len(self.lengths) + self.batch_size - 1) // self.batch_size


In [8]:
import pickle

train_lengths_path = os.path.join(datasets_folder, 'train_lengths.pkl')
val_lengths_path = os.path.join(datasets_folder, 'val_lengths.pkl')
test_lengths_path = os.path.join(datasets_folder, 'test_lengths.pkl')

if os.path.exists(train_lengths_path):
    with open(train_lengths_path, 'rb') as f:
        train_lengths = pickle.load(f)
else:
    train_lengths = list(map(lambda x: len(x["input_ids"]), dataset_train))
    with open(train_lengths_path, 'wb') as f:
        pickle.dump(train_lengths, f)

if os.path.exists(val_lengths_path):
    with open(val_lengths_path, 'rb') as f:
        val_lengths = pickle.load(f)
else:
    val_lengths = list(map(lambda x: len(x["input_ids"]), dataset_val))
    with open(val_lengths_path, 'wb') as f:
        pickle.dump(val_lengths, f)

if os.path.exists(test_lengths_path):
    with open(test_lengths_path, 'rb') as f:
        test_lengths = pickle.load(f)
else:
    test_lengths = list(map(lambda x: len(x["input_ids"]), dataset_test))
    with open(test_lengths_path, 'wb') as f:
        pickle.dump(test_lengths, f)

In [9]:
from torch.utils.data import DataLoader

test_dataloader = DataLoader(dataset_test,
                            batch_sampler=BatchSampler(test_lengths,
                                                    config.batch_size,
                                                    config.mega_batch,
                                                    drop_last=False), 
                            collate_fn=collate_fn)
val_dataloader = DataLoader(dataset_val, 
                            batch_sampler=BatchSampler(val_lengths,
                                                    config.batch_size,
                                                    config.mega_batch,
                                                    drop_last=False),
                            collate_fn=collate_fn)
train_dataloader = DataLoader(dataset_train,
                            batch_sampler=BatchSampler(train_lengths,
                                                    config.batch_size,
                                                    config.mega_batch,
                                                    drop_last=False),
                            collate_fn=collate_fn)

In [10]:
from New1D.autoencoder_kl_1d import AutoencoderKL1D

model = AutoencoderKL1D(
    num_class_embeds=tokenizer.vocab_size + 1,  # the number of class embeddings
    
    down_block_types=(
        "DownEncoderBlock1D",  # a regular ResNet downsampling block
        "DownEncoderBlock1D",
        "DownEncoderBlock1D",
        "DownEncoderBlock1D",  # a ResNet downsampling block with spatial self-attention
    ),
    up_block_types=(
        "UpDecoderBlock1D",  # a ResNet upsampling block with spatial self-attention
        "UpDecoderBlock1D",
        "UpDecoderBlock1D",
        "UpDecoderBlock1D",  # a regular ResNet upsampling block
    ),
    block_out_channels=(128, 256, 512, 512),  # the number of output channels for each block
    mid_block_type="UNetMidBlock1D",  # the type of the middle block
    mid_block_channels=1024,  # the number of output channels for the middle block
    mid_block_add_attention=False,  # whether to add a spatial self-attention block to the middle block
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    transformer_layers_per_block=1, # how many transformer layers to use per ResNet layer. Not implemented yet.

    latent_channels=128,  # the dimensionality of the latent space

    num_attention_heads=1,  # the number of attention heads in the spatial self-attention blocks
    upsample_type="conv", # the type of upsampling to use, either 'conv' (and nearest neighbor) or 'conv_transpose'
    act_fn="gelu",  # the activation function to use
)

In [11]:
model.children
model.num_parameters()

61429144

In [12]:
sample_image = next(iter(val_dataloader))
print(sample_image.keys())
print(sample_image['id'][0])
print(sample_image['input_ids'][0])
print(sample_image['attention_mask'][0])
print(sample_image['class_label'][0])

dict_keys(['id', 'input_ids', 'attention_mask', 'class_label'])
UniRef50_A0A1F4Y1L9
tensor([18, 15, 12, 16, 14,  7,  3, 20, 18, 16, 20, 10, 14, 14, 12, 18,  5, 20,
        17, 12, 14,  8,  5, 19, 14,  3,  5, 14, 10,  7,  3,  5, 12, 20,  9, 19,
         8, 17, 12,  5, 20, 19,  8,  5, 13, 14, 20, 18,  8,  3,  7, 14,  3,  3,
         8,  8, 20,  7, 19, 14, 12, 19, 20, 14,  8, 14,  3, 19, 19, 19,  8, 14,
        18, 10, 20, 12,  8,  5, 12, 18, 19, 16,  8,  8, 10,  7,  8, 18, 18, 19,
        12, 19, 10, 12,  8,  5, 19, 19, 12,  3, 14,  3, 19, 18, 19, 14,  7,  7,
        18, 19, 19,  3, 18, 18, 19, 14, 12,  7,  3, 16, 18,  8, 18, 10,  8, 19,
        12, 18,  3, 16, 19, 12, 14, 12, 18,  8, 12,  3, 18,  7, 12, 14,  8,  7,
        19, 18, 12,  3, 18, 18, 19, 10,  8,  5,  8, 19, 16, 19,  8,  8, 12, 19,
        10, 14,  8,  8,  3, 19, 19, 19,  8, 14, 19,  3, 12, 16,  8, 19, 12,  3,
        20, 12,  8, 14, 19, 19, 12,  3, 14,  3, 19, 18, 18,  3,  7,  7,  3, 19,
        19,  3, 18, 18, 19, 14, 12, 

In [13]:
import torch

class_labels = sample_image['class_label'][0].unsqueeze(0)
attention_mask = sample_image['attention_mask'][0].unsqueeze(0)
print(class_labels.shape)
print(attention_mask.shape)

torch.Size([1])
torch.Size([1, 512])


In [14]:
input_ids = sample_image['input_ids'].to(model.device)
attention_mask = sample_image['attention_mask'].to(model.device)

print(input_ids.shape)
print(attention_mask.shape)

output = model(sample = input_ids,
                attention_mask = attention_mask,
                sample_posterior = True, # Should be set to true in training
)

ce_loss, kl_loss = model.loss_fn(output, input_ids)
print(ce_loss)
print(kl_loss)

torch.Size([4, 512])
torch.Size([4, 512])
tensor(3.1947, grad_fn=<DivBackward0>)
tensor(0.0336, grad_fn=<DivBackward0>)


In [15]:
from diffusers.optimization import get_cosine_schedule_with_warmup

if config.optimizer == "Adam":
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
elif config.optimizer == "AdamW":
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
elif config.optimizer == "Adamax":
    optimizer = torch.optim.Adamax(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
elif config.optimizer == "SGD":
    optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay, momentum=config.SGDmomentum)
else:
    raise ValueError("Invalid optimizer, choose between `AdamW`, `Adam`, `SGD`, and `Adamax`")

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs // config.gradient_accumulation_steps),
)

In [16]:
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from tqdm.auto import tqdm
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from torch.optim.lr_scheduler import LRScheduler

from grokfast import gradfilter_ema

from typing import Union

@dataclass
class TrainingVariables:
    global_step: int = 0
    val_loss: float = float("inf")
    grads: Optional[torch.Tensor] = None

    def state_dict(self):
        return self.__dict__
    
    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
training_variables = TrainingVariables()

class VAETrainer:
    def __init__(self, 
                 model: AutoencoderKL1D, 
                 tokenizer: PreTrainedTokenizerFast, 
                 optimizer: Union[torch.optim.Adam, torch.optim.AdamW, torch.optim.SGD, torch.optim.Adamax],
                 lr_scheduler: LRScheduler, 
                 train_dataloader: DataLoader, 
                 val_dataloader: DataLoader, 
                 config: TrainingConfig, 
                 training_variables: TrainingVariables = training_variables,
                 test_dataloader: DataLoader = None
        ):
        self.tokenizer = tokenizer
        self.config = config
        self.training_variables = training_variables
        self.accelerator_config = ProjectConfiguration(
            project_dir=self.config.output_dir,
            logging_dir=os.path.join(self.config.output_dir, "logs"),
            automatic_checkpoint_naming=self.config.automatic_checkpoint_naming,
            total_limit=self.config.total_limit, # Limit the total number of checkpoints to 1
        )
        self.accelerator = Accelerator(
            project_config=self.accelerator_config,
            mixed_precision=self.config.mixed_precision,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            log_with="tensorboard",
        )
        # Prepare everything
        # There is no specific order to remember, you just need to unpack the
        # objects in the same order you gave them to the prepare method.
        self.model, self.optimizer, self.train_dataloader, self.test_dataloader, self.val_dataloader, self.lr_scheduler = self.accelerator.prepare(
            model, optimizer, train_dataloader, test_dataloader, val_dataloader, lr_scheduler
        )
        self.accelerator.register_for_checkpointing(self.training_variables)

    def logits_to_token_ids(self, logits):
        '''
        Convert a batch of logits to token_ids.
        Returns token_ids
        '''
        if self.config.cutoff is None:
            token_ids = logits.argmax(dim=1)
        else:
            token_ids = torch.where(logits.max(dim=1).values > self.config.cutoff, 
                                    logits.argmax(dim=1), 
                                    torch.tensor([self.tokenizer.unknown_token_id])
                                    )
        return token_ids

    @torch.no_grad()
    def evaluate(self
    ) -> dict:

        test_dir = os.path.join(self.config.output_dir, "samples")
        os.makedirs(test_dir, exist_ok=True)

        running_loss = 0.0
        num_correct_residues = 0
        total_residues = 0
        name = f"step_{self.training_variables.global_step//1000:04d}k"

        progress_bar = tqdm(total=len(self.val_dataloader), disable=not self.accelerator.is_local_main_process)
        progress_bar.set_description(f"Evaluating {name}")

        for i, sample in enumerate(self.val_dataloader):

            output = self.model(sample = sample['input_ids'],
                                attention_mask = sample['attention_mask'],
                                sample_posterior = True, # Should be set to False in inference
            )

            ce_loss, kl_loss = self.model.loss_fn(output, sample['input_ids'])
            loss = ce_loss + kl_loss * self.config.kl_weight
            running_loss += loss.item()

            token_ids_pred = self.logits_to_token_ids(output.sample)

            token_ids_correct = ((sample['input_ids'] == token_ids_pred) & (sample['attention_mask'] == 1)).long()
            num_residues = torch.sum(sample['attention_mask'], dim=1).long()

            num_correct_residues += token_ids_correct.sum().item()
            total_residues += num_residues.sum().item()

            # Decode the predicted sequences, and remove zero padding
            seqs_pred = self.tokenizer.batch_decode(token_ids_pred, skip_special_tokens=self.config.skip_special_tokens)
            seqs_lens = torch.sum(sample['attention_mask'], dim=1).long()
            seqs_pred = [seq[:i] for seq, i in zip(seqs_pred, seqs_lens)]

            # Save all samples as a FASTA file
            seq_record_list = [SeqRecord(Seq(seq), id=str(sample['id'][i]), 
                            description=
                            f"classlabel: {sample['class_label'][i].item()} acc: {token_ids_correct[i].sum().item() / num_residues[i].item():.2f}")
                            for i, seq in enumerate(seqs_pred)]
            with open(f"{test_dir}/{name}.fa", "a") as f:
                SeqIO.write(seq_record_list, f, "fasta")
            
            progress_bar.update(1)
        
        acc = num_correct_residues / total_residues
        print(f"{name}, val_loss: {running_loss / len(self.val_dataloader):.4f}, val_accuracy: {acc:.4f}")
        logs = {"val_loss": loss.detach().item(), 
                "val_ce_loss": ce_loss.detach().item(), 
                "val_kl_loss": kl_loss.detach().item(),
                "val_acc": acc,
                }
        return logs
    
    def train_loop(self, from_checkpoint: Optional[int] = None):
  
        # start the loop
        if self.accelerator.is_main_process:
            os.makedirs(self.config.output_dir, exist_ok=True)
            os.makedirs(self.accelerator_config.logging_dir, exist_ok=True)
            if self.config.push_to_hub:
                raise NotImplementedError("Pushing to the HF Hub is not implemented yet")
            
            if from_checkpoint is not None:
                input_dir = os.path.join(self.config.output_dir, "checkpoints", f'checkpoint_{from_checkpoint}')
                self.accelerator.load_state(input_dir=input_dir)
                print(f"Loaded checkpoint from {input_dir}")
                print(f"Starting from step {self.training_variables.global_step}")
                print(f"Validation loss: {self.training_variables.val_loss}")
            else:
                self.training_variables.global_step = 0
                self.training_variables.val_loss = float("inf")
                self.training_variables.grads = None # Initialize the grads for the grokfast algorithm

        # Now you train the model
        self.model.train()
        for epoch in range(self.config.num_epochs):
            progress_bar = tqdm(total=len(self.train_dataloader), disable=not self.accelerator.is_local_main_process)
            progress_bar.set_description(f"Epoch {epoch}")

            for step, batch in enumerate(self.train_dataloader):

                with self.accelerator.accumulate(self.model):
                    input = batch['input_ids']
                    attention_mask = batch['attention_mask']
                    # Predict the noise residual
                    output = self.model(sample = input,
                                    attention_mask = attention_mask,
                                    sample_posterior = True, # Should be set to true in training
                    )
                    
                    ce_loss, kl_loss = self.model.loss_fn(output, input)
                    loss = ce_loss + kl_loss * self.config.kl_weight
                    self.accelerator.backward(loss)

                    if self.config.grokfast:
                        self.training_variables.grads = gradfilter_ema(self.model, grads=self.training_variables.grads, alpha=self.config.grokfast_alpha, lamb=self.config.grokfast_lamb) 

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

                progress_bar.update(1)
                logs = {"train_loss": loss.detach().item(), 
                        "train_ce_loss": ce_loss.detach().item(), 
                        "train_kl_loss": kl_loss.detach().item(), 
                        "lr": lr_scheduler.get_last_lr()[0], 
                        "step": self.training_variables.global_step,
                }
                progress_bar.set_postfix(**logs)
                self.accelerator.log(logs, step=self.training_variables.global_step)
                self.training_variables.global_step += 1

                if self.training_variables.global_step == 1 or self.training_variables.global_step % self.config.save_image_model_steps == 0 or self.training_variables.global_step == len(self.train_dataloader):
                    self.accelerator.wait_for_everyone()
                    self.model.eval() # Set model to eval mode to generate images
                    logs = self.evaluate()
                    self.accelerator.log(logs, step=self.training_variables.global_step)

                    new_val_loss = logs["val_loss"]

                    if new_val_loss < self.training_variables.val_loss: # Save the model if the validation loss is lower
                        self.training_variables.val_loss = new_val_loss
                        self.accelerator.save_state(
                            output_dir=self.config.output_dir,
                        )
                    self.model.train() # Set model back to train mode

Trainer = VAETrainer(model, tokenizer, optimizer, lr_scheduler, train_dataloader, val_dataloader, config, training_variables, test_dataloader)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [None]:
from accelerate import notebook_launcher

notebook_launcher(Trainer.train_loop, num_processes=1)


In [None]:
notebook_launcher(Trainer.train_loop(from_checkpoint=2), num_processes=1)