<a href="https://colab.research.google.com/github/nnilayy/MedGPT/blob/main/Trainer_API_%2B_Custom_Loop_Distributed_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Installing Libraries and Setting Up

In [None]:
!pip install accelerate datasets evaluate transformers

In [None]:
import torch
from evaluate import load
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, DataCollatorWithPadding
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from accelerate import Accelerator
import warnings
# Hide DeprecationWarning messages so they do not clutter the notebook output.
warnings.filterwarnings("ignore", category=DeprecationWarning)

def compute_metric(predictions, references, total_loss, loader, metric):
    """Return the accuracy and the mean per-batch loss for one pass.

    Args:
        predictions: Predicted labels collected over the pass.
        references: Ground-truth labels aligned with ``predictions``.
        total_loss: Sum of the per-batch loss values.
        loader: Sized iterable of batches; ``len(loader)`` is the divisor
            used to average ``total_loss``.
        metric: Object whose ``compute(predictions=..., references=...)``
            returns a mapping containing an ``'accuracy'`` entry.

    Returns:
        A ``(accuracy, mean_loss)`` tuple.
    """
    accuracy = metric.compute(predictions=predictions, references=references)['accuracy']
    return accuracy, total_loss / len(loader)


def dataset(dataset_name, batch_size, tokenizer, config_name='mrpc', max_length=128):
    """Load a sentence-pair dataset, tokenize it and build the data loaders.

    Args:
        dataset_name: Name passed to ``load_dataset`` (e.g. ``'glue'``).
        batch_size: Batch size used for both loaders.
        tokenizer: Tokenizer called on the ``sentence1``/``sentence2`` columns
            and handed to ``DataCollatorWithPadding``.
        config_name: Dataset configuration passed to ``load_dataset``.
            Defaults to ``'mrpc'``, the value that was previously hard-coded.
        max_length: Length every example is truncated/padded to.
            Defaults to ``128``, the value that was previously hard-coded.

    Returns:
        A ``(train_loader, eval_loader)`` tuple built from the ``'train'`` and
        ``'validation'`` splits; only the training loader is shuffled.
    """
    # Loading Datasets
    raw_splits = load_dataset(dataset_name, config_name)

    # Preprocessing Dataset
    def encode(examples):
        return tokenizer(
            examples['sentence1'],
            examples['sentence2'],
            truncation=True,
            padding='max_length',
            max_length=max_length,
        )

    encoded = raw_splits.map(encode, batched=True)

    # Keep the segment ids as well: for sentence-pair input they tell the
    # model which tokens belong to which sentence. Not every tokenizer emits
    # them, so only request the columns that actually exist after encoding.
    wanted_columns = ['input_ids', 'token_type_ids', 'attention_mask', 'label']
    available_columns = encoded['train'].column_names
    model_columns = [name for name in wanted_columns if name in available_columns]
    encoded.set_format(type='torch', columns=model_columns)

    # Splitting Dataset
    train_dataset = encoded['train']
    eval_dataset = encoded['validation']

    # Constructing DataLoaders
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, collate_fn=data_collator)
    return train_loader, eval_loader


def train(epoch, model, loader, optimizer, accelerator, metric):
    """Run one training pass over ``loader`` and report accuracy and loss.

    Args:
        epoch: Epoch number, shown in the progress-bar description.
        model: Model called as ``model(**batch)``; its output must expose
            ``loss`` and ``logits``.
        loader: Iterable of batches, each containing a ``'labels'`` entry.
        optimizer: Optimizer stepped once per batch.
        accelerator: Object used for the backward pass and to decide whether
            this process shows the progress bar.
        metric: Metric object forwarded to ``compute_metric``.

    Returns:
        The ``(accuracy, mean_loss)`` tuple produced by ``compute_metric``
        from the predictions, labels and losses seen by this call.
    """
    model.train()
    running_loss = 0
    predicted_labels = []
    true_labels = []

    # Only the local main process draws the bar; the others iterate silently.
    progress = tqdm(
        loader,
        desc=f"Training Epoch {epoch}",
        disable=not accelerator.is_local_main_process,
    )
    for batch in progress:
        outputs = model(**batch)
        loss = outputs.loss

        # Gradients are cleared after the forward pass and before backward.
        optimizer.zero_grad()
        accelerator.backward(loss)
        optimizer.step()

        step_loss = loss.item()
        running_loss += step_loss
        progress.set_postfix(loss=step_loss)

        # Collect class predictions and labels for the epoch-level accuracy.
        predicted_labels.extend(outputs.logits.argmax(dim=-1).detach().cpu().numpy())
        true_labels.extend(batch['labels'].detach().cpu().numpy())

    # Compute accuracy using evaluate
    return compute_metric(predicted_labels, true_labels, running_loss, loader, metric)



# Evaluation function without tqdm
def evaluate(model, loader, metric):
    """Run one evaluation pass over ``loader`` without tracking gradients.

    Args:
        model: Model called as ``model(**batch)``; its output must expose
            ``loss`` and ``logits``.
        loader: Iterable of batches, each containing a ``'labels'`` entry.
        metric: Metric object forwarded to ``compute_metric``.

    Returns:
        The ``(accuracy, mean_loss)`` tuple produced by ``compute_metric``
        from the predictions, labels and losses seen by this call.
    """
    model.eval()
    running_loss = 0
    predicted_labels = []
    true_labels = []

    with torch.no_grad():
        for batch in loader:
            outputs = model(**batch)
            running_loss += outputs.loss.item()

            # Collect class predictions and labels for the accuracy metric.
            predicted_labels.extend(outputs.logits.argmax(dim=-1).detach().cpu().numpy())
            true_labels.extend(batch['labels'].detach().cpu().numpy())

    # Compute accuracy using evaluate
    return compute_metric(predicted_labels, true_labels, running_loss, loader, metric)


def main():
    """Fine-tune ``bert-base-uncased`` on GLUE/MRPC and print per-epoch metrics.

    Builds the accelerator, model, tokenizer, optimizer, metric and data
    loaders, passes the trainable components through ``accelerator.prepare``,
    then alternates a training pass and an evaluation pass for ``epochs``
    epochs. Results are printed only on the local main process.
    """
    # Initiate Accelerator
    accelerator = Accelerator()

    # Download Model and Tokenizer (force_download=True asks for a fresh download)
    checkpoint = 'bert-base-uncased'
    model = BertForSequenceClassification.from_pretrained(
        checkpoint,
        force_download=True,
        num_labels=2,
    )
    tokenizer = BertTokenizer.from_pretrained(checkpoint, force_download=True)

    # Setting up Hyperparameters
    epochs = 10
    lr = 5e-4
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
    batch_size = 128
    metric = load("accuracy")

    # Processing Dataset
    dataset_name = 'glue'
    train_loader, eval_loader = dataset(dataset_name, batch_size, tokenizer)

    # Preparing components for Accelerate
    model, optimizer, train_loader, eval_loader = accelerator.prepare(
        model, optimizer, train_loader, eval_loader
    )

    # Fitting the Model: one training pass and one validation pass per epoch
    for epoch in range(1, epochs + 1):
        train_accuracy, train_loss = train(epoch, model, train_loader, optimizer, accelerator, metric)
        validation_accuracy, validation_loss = evaluate(model, eval_loader, metric)

        # Every process runs both passes; only the local main process reports.
        if not accelerator.is_local_main_process:
            continue
        print(f"Training Accuracy: {train_accuracy:.4f}, Training Loss: {train_loss:.4f}")
        print(f"Validation Accuracy: {validation_accuracy:.4f}, Validation Loss: {validation_loss:.4f}")

if __name__ == "__main__":
    # Imported here because it is only needed when the file is run directly.
    from accelerate import notebook_launcher

    # Hand `main` to the launcher with two processes and fp16 mixed precision.
    notebook_launcher(
        main,
        num_processes=2,
        mixed_precision="fp16",
    )