In [3]:
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
def fit(train_config, model_config, tokenizer_config):
    """Train a transformer and return the trained model.

    Builds the data loaders/tokenizers and the model, then runs a plain
    teacher-forced training loop with AdamW and label-smoothed cross entropy.

    Args:
        train_config: Mapping read for ``'lr'`` (optimizer learning rate) and
            ``'epochs'`` (number of passes over the training loader).
        model_config: Passed unchanged to ``make_transformer``.
        tokenizer_config: Passed unchanged to ``get_loaders_and_tokenizers``.

    Returns:
        The model produced by ``make_transformer`` after training, moved to
        the selected device.

    Note:
        Reads the module-level ``raw_dataset``; the cell defining it must have
        been executed before calling this function.
    """
    # Prefer Apple-silicon MPS, then CUDA, then fall back to CPU.
    if torch.backends.mps.is_available():
        device = 'mps'
    elif torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    print(f'Using device: {device}')

    dls_and_tokenizers_dict = get_loaders_and_tokenizers(tokenizer_config, raw_dataset)
    target_tokenizer = dls_and_tokenizers_dict['target_tokenizer']
    # Loop-invariant: width of the logits' last axis used to flatten the outputs.
    target_vocab_size = target_tokenizer.get_vocab_size()

    model = make_transformer(model_config).to(device)  # TODO: create the config for the model.

    optimizer = optim.AdamW(model.parameters(), lr=train_config['lr'])
    # Labels are target-vocabulary ids (the logits are flattened to the target
    # vocab size below), so the padding id to ignore must come from the
    # *target* tokenizer, not the source one.
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=target_tokenizer.token_to_id('[PAD]'),
        label_smoothing=0.1,
    ).to(device)

    # Make training mode explicit so dropout etc. behave correctly even if the
    # model was previously switched to eval mode.
    model.train()

    for epoch in range(train_config['epochs']):
        train_dl = tqdm(dls_and_tokenizers_dict['train_dl'], desc=f'Processing epoch: {epoch:02d}')
        for batch in train_dl:
            encoder_inputs = batch['encoder_inputs'].to(device)
            decoder_inputs = batch['decoder_inputs'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)
            label = batch['label'].to(device)

            # Forward pass: encode source, decode with teacher forcing, project to vocab.
            encoder_outputs = model.encode(encoder_inputs, encoder_mask)
            decoder_outputs = model.decode(encoder_outputs, encoder_mask, decoder_inputs, decoder_mask)
            model_outputs = model.generate(decoder_outputs)

            # Flatten to (N, vocab) logits vs. (N,) class ids for the loss.
            loss = loss_fn(model_outputs.view(-1, target_vocab_size), label.view(-1))
            loss.backward()
            optimizer.step()
            # Clear gradients so they do not accumulate into the next batch.
            optimizer.zero_grad()

            train_dl.set_postfix({'loss': f'{loss.item():6.3f}'})

    # Hand the trained model back; otherwise it is lost when the function ends.
    return model
            

In [None]:
 {'encoder_inputs': encoder_inputs,
  'decoder_inputs': decoder_inputs,
  'label': label,
  # NOTE(review): this fragment references `self`, so it presumably belongs
  # inside a dataset item method — confirm where it should live.
  # Indexing with `None` (not the no-op slice `None:`) prepends singleton
  # axes so the padding mask can broadcast against attention scores.
  # Assumes encoder_inputs/decoder_inputs are 1-D token-id tensors — TODO confirm.
  'encoder_mask': (encoder_inputs != self.pad_token)[None, None, ...].int(),
  # Padding mask combined with the causal mask; the comparison result (not
  # `self.pad_token`) is what gets the extra axes.
  'decoder_mask': (decoder_inputs != self.pad_token)[None, None, ...].int() & causal_mask(self.model_max_length),
  'src_text': src_text,
  'tgt_text': tgt_text
 }