In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW, get_scheduler
from datasets import load_dataset
from tqdm import tqdm

import mlflow
from mlflow import log_metric, log_param

import os
# Turn off the Hugging Face tokenizers' own parallelism; avoids its fork-safety
# warning when datasets/DataLoader later spawn worker processes.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
def load_and_tokenize_data(filename, test_size=0.2, max_length=512, seed=None, drop_empty_lines=True):
    """Load a plain-text file (one example per line), split it, and tokenize both splits.

    Relies on the module-level ``tokenizer``, which must be created before this
    is called. If that tokenizer has no pad token, its EOS token is used for
    padding; GPT-2 style tokenizers ship without one.

    Args:
        filename: Path to the text file, read with the ``datasets`` "text" builder.
        test_size: Fraction of lines held out as the validation split.
        max_length: Every line is truncated/padded to exactly this many tokens.
        seed: Seed for the train/test shuffle. ``None`` keeps the original
            behaviour of a different split on every run.
        drop_empty_lines: Drop blank/whitespace-only lines before splitting.
            Left in, they become rows of pure padding with nothing to learn.

    Returns:
        ``(tokenized_train_dataset, tokenized_test_dataset)``. The ``text``
        column is removed and the tokenizer's outputs are added in its place
        (for distilgpt2: ``input_ids`` and ``attention_mask``).
    """
    lines = load_dataset('text', data_files={'train': filename})["train"]
    if drop_empty_lines:
        lines = lines.filter(lambda example: example["text"].strip() != "")
    split_datasets = lines.train_test_split(test_size=test_size, seed=seed)
    train_dataset = split_datasets['train']
    test_dataset = split_datasets['test']

    # Only fill in a pad token when missing; never override an existing one.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_and_pad(examples):
        # Plain lists are returned; datasets stores them as Arrow columns anyway,
        # so requesting torch tensors here would only add a conversion.
        return tokenizer(
            examples["text"],
            truncation=True,
            padding='max_length',
            max_length=max_length,
        )

    tokenized_train_dataset = train_dataset.map(tokenize_and_pad, batched=True, remove_columns=['text'])
    tokenized_test_dataset = test_dataset.map(tokenize_and_pad, batched=True, remove_columns=['text'])
    return tokenized_train_dataset, tokenized_test_dataset

In [None]:
def create_dataloaders(train_dataset, test_dataset, batch_size):
    """Wrap the tokenized splits in DataLoaders that yield int64 tensor batches.

    Args:
        train_dataset: Indexable dataset whose items are dicts holding
            equal-length ``input_ids`` and ``attention_mask`` sequences (lists
            or 1-D tensors), e.g. the output of ``load_and_tokenize_data``.
            Other keys are ignored.
        test_dataset: Same format; used for validation.
        batch_size: Number of examples per batch for both loaders.

    Returns:
        ``(train_dataloader, val_dataloader)``. The train loader shuffles and the
        validation loader keeps dataset order. Each batch is a dict with
        ``input_ids`` and ``attention_mask`` of shape ``(batch, seq_len)``.
    """
    def collate_fn(batch):
        # as_tensor + stack accepts both Python lists and already-formatted
        # tensors (e.g. after Dataset.set_format("torch")).
        return {
            key: torch.stack([torch.as_tensor(item[key], dtype=torch.long) for item in batch])
            for key in ('input_ids', 'attention_mask')
        }

    # Pinned host memory only speeds up host->GPU copies; on CPU-only machines
    # PyTorch just warns and ignores the flag.
    pin_memory = torch.cuda.is_available()
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, collate_fn=collate_fn, pin_memory=pin_memory)
    val_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, collate_fn=collate_fn, pin_memory=pin_memory)
    return train_dataloader, val_dataloader

In [None]:
def move_batch_to_device(batch, device):
    """Return a new dict with every tensor in ``batch`` moved to ``device``.

    Keys and their order are preserved; the input dict is not modified.
    """
    moved = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device)
    return moved

def _log_state_artifact(state_dict, path):
    """Save ``state_dict`` to ``path`` with ``torch.save`` and attach it to the active MLflow run."""
    torch.save(state_dict, path)
    mlflow.log_artifact(path)


def train_model(model, train_dataloader, val_dataloader, optimizer, lr_scheduler, num_epochs):
    """Fine-tune ``model`` as a causal LM inside a new MLflow run.

    Each batch logs the learning rate (``step`` = global batch index). Each
    epoch logs the mean training loss and the validation loss from
    ``evaluate_model`` (``step`` = epoch index). After the last epoch the
    model, optimizer and scheduler state dicts are written to ``model.pth``,
    ``optimizer.pth`` and ``scheduler.pth`` in the working directory and
    logged as run artifacts. Uses the module-level ``device``.

    Args:
        model: Causal-LM model called with ``input_ids``, ``attention_mask`` and
            ``labels``; its ``outputs.loss`` is optimized.
        train_dataloader: Yields dicts with 2-D ``input_ids`` and
            ``attention_mask`` tensors; must support ``len()``.
        val_dataloader: Same batch format; passed to ``evaluate_model``.
        optimizer: Torch optimizer over ``model``'s parameters.
        lr_scheduler: Scheduler stepped once per training batch.
        num_epochs: Number of passes over ``train_dataloader``.
    """
    # Start a new MLflow run; everything below is logged to it.
    with mlflow.start_run():
        model.to(device)

        log_param("num_epochs", num_epochs)
        log_param("batch_size", train_dataloader.batch_size)
        log_param("optimizer", type(optimizer).__name__)
        log_param("lr_scheduler", type(lr_scheduler).__name__)

        global_step = 0
        with tqdm(total=num_epochs * len(train_dataloader)) as progress_bar:
            for epoch in range(num_epochs):
                model.train()  # evaluate_model leaves the model in eval mode
                epoch_loss_sum = 0.0
                epoch_batches = 0

                for batch in train_dataloader:
                    batch = move_batch_to_device(batch, device)
                    input_ids = batch['input_ids']
                    attention_mask = batch['attention_mask']
                    # Padding reuses the EOS token id, so padded positions must
                    # be excluded from the loss (-100 is ignored); otherwise the
                    # model is trained to emit long runs of EOS.
                    labels = input_ids.masked_fill(attention_mask == 0, -100)

                    # Causal-LM heads score labels[:, 1:]. A batch with no such
                    # target (e.g. only blank or one-token lines) gives a NaN loss
                    # that would poison the weights, so skip its update.
                    if (labels[:, 1:] != -100).any():
                        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                        loss = outputs.loss
                        loss.backward()
                        optimizer.step()
                        epoch_loss_sum += loss.item()
                        epoch_batches += 1

                    # Log the learning rate used for this step.
                    for param_group in optimizer.param_groups:
                        log_metric("learning_rate", param_group['lr'], step=global_step)

                    # Keep the schedule aligned with the planned number of steps.
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                    progress_bar.update(1)

                # Mean over the epoch, not just the last batch; skipped when no
                # batch produced a loss (e.g. an empty train loader).
                if epoch_batches:
                    log_metric("train_loss", epoch_loss_sum / epoch_batches, step=epoch)

                avg_val_loss = evaluate_model(model, val_dataloader, device)
                print(f"Validation loss: {avg_val_loss}")
                log_metric("val_loss", avg_val_loss, step=epoch)

        # Persist model, optimizer and scheduler state as run artifacts.
        _log_state_artifact(model.state_dict(), "model.pth")
        _log_state_artifact(optimizer.state_dict(), "optimizer.pth")
        _log_state_artifact(lr_scheduler.state_dict(), "scheduler.pth")

def evaluate_model(model, dataloader, device):
    """Return the mean per-batch causal-LM loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and left there; callers that keep
    training must call ``model.train()`` again. No gradients are tracked.

    Labels are built from ``input_ids`` with padded positions
    (``attention_mask == 0``) set to -100, so padding does not count toward the
    loss. Without labels the model returns no loss at all. Batches with no
    target token after the causal shift (``labels[:, 1:]`` all -100) would give
    a NaN loss, so they are skipped.

    Args:
        model: Causal-LM model called with ``input_ids``, ``attention_mask`` and
            ``labels``, returning an object with a scalar ``loss`` tensor.
        dataloader: Iterable of dicts with 2-D ``input_ids`` and
            ``attention_mask`` tensors.
        device: Device the batch tensors are moved to.

    Returns:
        The average loss over the batches that were scored, or ``float('nan')``
        if none were (e.g. an empty dataloader).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = input_ids.masked_fill(attention_mask == 0, -100)
            if not (labels[:, 1:] != -100).any():
                continue
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += outputs.loss.item()
            num_batches += 1
    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

In [None]:
# Setup: local SQLite-backed MLflow tracking, the distilgpt2 checkpoint, and data loaders.
mlflow.set_tracking_uri('sqlite:///mlflow.db')
mlflow.set_experiment('distilgpt-2-train')

model_name = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

filename = 'data/clean_29_07_2023.txt'
train_dataset, test_dataset = load_and_tokenize_data(filename)
train_dataloader, val_dataloader = create_dataloaders(
    train_dataset,
    test_dataset,
    batch_size=10,
)

In [None]:
# Train for a fixed number of epochs with AdamW and a linear decay (no warmup)
# spread over every optimizer step of the run.
num_epochs = 3
num_training_steps = len(train_dataloader) * num_epochs
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
train_model(model, train_dataloader, val_dataloader, optimizer, lr_scheduler, num_epochs)