In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
from dataset import BilingualDataset, causal_mask 
from model import build_transformer

In [None]:
from config import get_weights_file_path, get_config
from tqdm import tqdm

In [None]:
import warnings

In [None]:
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

In [None]:
from torch.utils.tensorboard import SummaryWriter

In [None]:
from pathlib import Path

In [None]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Greedily decode a translation for a single source sequence.

    The encoder is run once and its output reused; the decoder is then re-run on
    the growing target prefix, appending the highest-scoring next token each
    step, until '[EOS]' is produced or the prefix reaches ``max_len`` tokens
    (the leading '[SOS]' counts towards ``max_len``).

    Args:
        model: object exposing ``encode``, ``decode`` and ``project``.
        source: encoder input ids. The decoder prefix is built with batch size 1,
            so ``source`` is expected to hold a single sequence.
        source_mask: encoder attention mask, also reused for cross-attention.
        tokenizer_src: unused; kept for interface compatibility.
        tokenizer_tgt: supplies the '[SOS]' / '[EOS]' token ids.
        max_len: maximum number of output tokens, including '[SOS]'.
        device: device on which decoder tensors are created.

    Returns:
        1-D tensor of target ids that starts with '[SOS]' and ends with '[EOS]'
        if it was generated before ``max_len`` was reached.

    This function sets neither ``model.eval()`` nor ``torch.no_grad()``; the
    caller is responsible for both.
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every decoding step
    encoder_output = model.encode(source, source_mask)

    # Initialize the decoder input with the [SOS] token, shape (1, 1)
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

    # `<` rather than `==`: with max_len < 1 an equality test never fires and
    # decoding could only stop if the model happened to emit [EOS].
    while decoder_input.size(1) < max_len:
        # Causal mask so each position only attends to earlier target tokens
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        # Decoder output for the whole prefix
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # Only the last position predicts the next token
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)

        next_token = torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)
        decoder_input = torch.cat([decoder_input, next_token], dim=1)

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)

In [None]:
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_state, writer, num_examples=2):
    """Greedy-decode a few validation examples and print source/target/prediction.

    Args:
        model: the transformer. It is put in eval mode for the duration of the
            call and restored to its previous train/eval mode afterwards.
        validation_ds: iterable of batches (batch size must be 1) providing
            'encoder_input', 'encoder_mask', 'src_text' and 'tgt_text'.
        tokenizer_src: passed through to ``greedy_decode``.
        tokenizer_tgt: used to decode the predicted ids back to text.
        max_len: maximum decoded length passed to ``greedy_decode``.
        device: device the encoder tensors are moved to.
        print_msg: callable taking one string, used for all output
            (e.g. ``lambda msg: batch_iterator.write(msg)`` with tqdm).
        global_state, writer: currently unused; kept for interface compatibility.
        num_examples: stop after this many batches. A value below 1 never
            matches the counter, so every batch is processed.

    Raises:
        AssertionError: if a batch has more than one example.
    """
    was_training = model.training
    model.eval()
    count = 0

    # Width of the separator line printed before each example
    console_width = 80

    try:
        with torch.no_grad():
            for batch in validation_ds:
                count += 1
                encoder_input = batch['encoder_input'].to(device)
                encoder_mask = batch['encoder_mask'].to(device)

                assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

                model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

                source_text = batch['src_text'][0]
                target_text = batch['tgt_text'][0]
                # tolist(): Tokenizer.decode expects a sequence of Python ints
                model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().tolist())

                # Print to the console
                print_msg('-' * console_width)
                print_msg(f'SOURCE: {source_text}')
                print_msg(f'TARGET: {target_text}')
                print_msg(f'PREDICTED: {model_out_text}')

                if count == num_examples:
                    break
    finally:
        # Don't leave a training caller stuck in eval mode (dropout disabled)
        model.train(was_training)

In [None]:
def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` side of every translation pair in ``ds``.

    Each item of ``ds`` must be a mapping whose 'translation' entry is itself a
    mapping from language code to sentence. Any language key present in the
    pairs can be requested, so the same generator serves both source and target.
    """
    yield from (pair['translation'][lang] for pair in ds)

In [None]:
def get_or_build_transformer(config, ds, lang):
    """Load the word-level tokenizer for ``lang``, training and caching it if missing.

    Despite the name, this builds a *tokenizer*, not a transformer. The name is
    kept so existing callers keep working.

    Args:
        config: mapping whose 'tokenizer_file' entry is a format string with one
            positional slot for the language (e.g. '../tokenizers/tokenizer_{0}.json').
        ds: iterable of items with item['translation'][lang]; only read when
            the tokenizer has to be trained.
        lang: language key, used both for the file name and to select sentences.

    Returns:
        A ``tokenizers.Tokenizer``. When no cached file exists, a new WordLevel
        model is trained with whitespace pre-tokenization, special tokens
        [UNK], [PAD], [SOS], [EOS] and min_frequency=2, then saved to the path.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))

    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Words not in the vocabulary are mapped to the id of [UNK]
    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)

    # Saving fails if the target directory (e.g. '../tokenizers') does not exist yet
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

In [None]:
from datasets import Dataset
from torch.utils.data import DataLoader, random_split

def get_ds(config):
    """Build train/validation dataloaders and tokenizers from local parallel text files.

    The corpus is read from two line-aligned files (line i of the source file
    translates line i of the target file) rather than downloaded from the
    HuggingFace Hub. It is wrapped in a HuggingFace ``Dataset`` with a
    'translation' column so downstream code is unchanged. One tokenizer per
    language is built (or loaded from cache), and the data is split 90% / 10%
    into train / validation.

    Required config keys: 'lang_src', 'lang_tgt', 'seq_len', 'batch_size',
    'tokenizer_file'.
    Optional config keys:
        'train_src_file' / 'train_tgt_file': corpus paths
            (defaults "train.en" / "train.fr").
        'split_seed': seed for the train/val split (default 42). A fixed seed
            keeps the split identical across runs, so a resumed ('preload') run
            does not train on sentences previously used for validation.

    Returns:
        (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt).
        The validation loader uses batch size 1, as required by run_validation.
    """
    src_path = config.get("train_src_file", "train.en")
    tgt_path = config.get("train_tgt_file", "train.fr")

    with open(src_path, encoding="utf-8") as f:
        src_sentences = f.read().splitlines()

    with open(tgt_path, encoding="utf-8") as f:
        tgt_sentences = f.read().splitlines()

    assert len(src_sentences) == len(tgt_sentences), (
        f"Parallel files are misaligned: {src_path} has {len(src_sentences)} lines, "
        f"{tgt_path} has {len(tgt_sentences)}"
    )

    # HuggingFace-style dataset: each row is {"translation": {lang_src: ..., lang_tgt: ...}}
    ds_raw = Dataset.from_dict({
        "translation": [
            {
                config["lang_src"]: src,
                config["lang_tgt"]: tgt
            }
            for src, tgt in zip(src_sentences, tgt_sentences)
        ]
    })

    tokenizer_src = get_or_build_transformer(config, ds_raw, config["lang_src"])
    tokenizer_tgt = get_or_build_transformer(config, ds_raw, config["lang_tgt"])

    # 90% train / 10% validation, with a seeded generator so the split is reproducible
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    split_generator = torch.Generator().manual_seed(config.get("split_seed", 42))
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size], generator=split_generator)

    train_ds = BilingualDataset(
        train_ds_raw, tokenizer_src, tokenizer_tgt,
        config['lang_src'], config['lang_tgt'], config['seq_len']
    )

    val_ds = BilingualDataset(
        val_ds_raw, tokenizer_src, tokenizer_tgt,
        config['lang_src'], config['lang_tgt'], config['seq_len']
    )

    train_dataloader = DataLoader(
        train_ds,
        batch_size=config['batch_size'],
        shuffle=True
    )

    # Batch size 1: run_validation decodes one sentence at a time
    val_dataloader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=True
    )

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt




In [None]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Construct the transformer via ``build_transformer`` using sizes from ``config``.

    ``config['seq_len']`` is passed for both the third and fourth arguments of
    ``build_transformer`` (presumably the source and target maximum lengths —
    confirm against model.py), and ``config['d_model']`` as the model width.
    """
    seq_len = config['seq_len']
    d_model = config['d_model']
    return build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        seq_len,
        seq_len,
        d_model,
    )

In [None]:
def train_model(config):
    """Train the translation transformer described by ``config``.

    Builds the dataloaders, tokenizers and model. If ``config['preload']`` is
    set, it resumes from that checkpoint, restoring model weights, optimizer
    state, epoch and global step. It then trains for epochs
    ``range(initial_epoch, config['num_epochs'])``.

    Per-step training loss is logged to TensorBoard under
    ``config['experiment_name']``. After every epoch a checkpoint is written to
    ``get_weights_file_path(config, f'{epoch:02d}')``.

    Optional config key:
        'run_validation' (default False): if truthy, greedy-decode a few
            validation examples at the end of each epoch.
    """
    # Define the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"using device {device}")

    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()
    model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)

    # TensorBoard
    writer = SummaryWriter(config['experiment_name'])

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    initial_epoch = 0
    global_step = 0
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f"Preloading Model {model_filename}")
        state = torch.load(model_filename, map_location=device)
        # Restore the weights too; otherwise "resuming" restarts from random init
        # while the epoch counter and optimizer moments claim otherwise.
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']

    # Labels are target-language ids, so padding is the *target* tokenizer's [PAD]
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            model.train()
            batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch: {epoch:02d}')

            for batch in batch_iterator:
                encoder_input = batch['encoder_input'].to(device) # (B, seq_len)
                decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
                encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
                decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)

                # Run the tensors through the transformer
                encoder_output = model.encode(encoder_input, encoder_mask)
                decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
                proj_output = model.project(decoder_output) # (B, seq_len, tgt_vocab_size)

                label = batch['label'].to(device) # (B, seq_len)

                # (B, seq_len, tgt_vocab_size) -> (B * seq_len, tgt_vocab_size).
                # reshape rather than view: view fails on non-contiguous tensors.
                loss = loss_fn(proj_output.reshape(-1, tgt_vocab_size), label.reshape(-1))

                # .item() synchronizes with the device; call it once per step
                loss_value = loss.item()
                batch_iterator.set_postfix({'loss': f'{loss_value:6.3f}'})

                # Log the loss
                writer.add_scalar('train loss', loss_value, global_step)

                # Backpropagate the loss
                loss.backward()

                # Update the weights
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                global_step += 1 # x-axis for the TensorBoard loss curve

            # Flush once per epoch instead of on every step
            writer.flush()

            if config.get('run_validation', False):
                run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

            # Save the model at the end of every epoch
            model_filename = get_weights_file_path(config, f'{epoch:02d}')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step
            }, model_filename)
    finally:
        # Release the event file even if training is interrupted
        writer.close()

In [None]:
if __name__ == "__main__":
    # Globally ignore all warnings before training starts.
    warnings.filterwarnings('ignore')
    # Load the run configuration and hand it straight to the training loop.
    train_model(get_config())