<a href="https://colab.research.google.com/github/beniamine3155/Fine_Tuning_LLM_with_HuggingFace/blob/main/Multitask_Fine_Tune_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multitask Fine-Tuning Using BERT (DistilBERT)



In [12]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install torch transformers scikit-learn matplotlib numpy



In [13]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
from typing import List, Dict, Tuple
from transformers import AutoTokenizer, AutoModel

# Reproducibility: give NumPy, PyTorch and Python's `random` the same fixed seed.
# A call that returns None stays last so this cell produces no output.
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)


In [20]:
# MultitaskDataset: tracks which task each sample belongs to
class MultitaskDataset(Dataset):
    """
    Map-style dataset that pools the samples of several tasks into one list.

    Every stored sample remembers the task it came from, its raw text and its
    label, so a single DataLoader can serve batches that mix tasks.

    Args:
        tasks_data: maps a task name to a list of ``(text, label)`` pairs.
        tokenizer: callable invoked as ``tokenizer(text, truncation=True,
            padding='max_length', max_length=..., return_tensors='pt')`` whose
            result provides ``'input_ids'`` and ``'attention_mask'`` tensors
            with a leading batch dimension of size 1.
        max_length: length the tokenizer is asked to pad/truncate to.
    """

    def __init__(self, tasks_data: Dict[str, List[Tuple]], tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Flatten all tasks into one list; tasks keep the dict's iteration order
        # and samples keep their order within each task.
        self.samples = [
            {'task': task_name, 'text': text, 'label': label}
            for task_name, task_samples in tasks_data.items()
            for text, label in task_samples
        ]

    def __len__(self):
        """Return the total number of samples across all tasks."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Tokenize sample ``idx`` on the fly.

        Returns:
            dict with ``'task'`` (str), ``'input_ids'`` and ``'attention_mask'``
            (the tokenizer tensors without their batch dimension) and
            ``'label'`` (0-d ``torch.long`` tensor).
        """
        sample = self.samples[idx]

        encoding = self.tokenizer(
            sample['text'],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # squeeze(0) drops only the batch dimension. A bare squeeze() would also
        # collapse the sequence dimension whenever max_length == 1.
        return {
            'task': sample['task'],
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'label': torch.tensor(sample['label'], dtype=torch.long)
        }

In [21]:
def create_sample_datasets():
    """
    Build the three toy classification datasets used by this notebook.

    Label meanings:
    - sentiment: 0 = positive, 1 = negative, 2 = neutral
    - intent:    0 = greeting, 1 = question, 2 = request, 3 = goodbye
    - topic:     0 = tech, 1 = sports, 2 = politics

    Returns:
        dict mapping 'sentiment', 'intent' and 'topic' to a freshly built list
        of 12 ``(text, label)`` tuples each.
    """

    def _label_by_position(texts, num_classes):
        # The texts below are written in repeating class order (0, 1, ..., 0, 1, ...),
        # so each label is simply the position modulo the number of classes.
        return [(text, position % num_classes) for position, text in enumerate(texts)]

    # Repeating order: positive, negative, neutral
    sentiment_texts = [
        "I love this product, it's amazing!",
        "This is terrible, worst purchase ever",
        "The product is okay, nothing special",
        "Best thing I've ever bought!",
        "Completely disappointed with this",
        "It's fine, does what it's supposed to do",
        "Absolutely fantastic experience!",
        "Waste of money, don't buy this",
        "Average quality, meets expectations",
        "Exceeded all my expectations!",
        "Poor quality and bad service",
        "Decent product for the price",
    ]

    # Repeating order: greeting, question, request, goodbye
    intent_texts = [
        "Hello, how are you today?",
        "What time does the store open?",
        "Can you help me find this item?",
        "Thank you, goodbye!",
        "Hi there, good morning!",
        "How much does this cost?",
        "Please show me the menu",
        "See you later, have a nice day",
        "Hey, what's up?",
        "Where is the nearest bathroom?",
        "Could you recommend something?",
        "Thanks for your help, bye!",
    ]

    # Repeating order: tech, sports, politics
    topic_texts = [
        "The new iPhone features are impressive",
        "The football game was intense last night",
        "The election results were surprising",
        "Machine learning is revolutionizing industries",
        "The tennis match went to five sets",
        "New policies will affect healthcare",
        "Cloud computing offers great scalability",
        "The basketball season starts next month",
        "Voting turnout was higher than expected",
        "Artificial intelligence is advancing rapidly",
        "The soccer world cup is exciting",
        "Congress passed the new legislation",
    ]

    return {
        'sentiment': _label_by_position(sentiment_texts, 3),
        'intent': _label_by_position(intent_texts, 4),
        'topic': _label_by_position(topic_texts, 3)
    }

def compute_metrics(predictions, labels):
    """
    Compute accuracy and weighted F1 for one task.

    Args:
        predictions: sequence of predicted class ids.
        labels: sequence of true class ids, aligned with ``predictions``.

    Returns:
        dict with keys ``'accuracy'`` and ``'f1'``.
    """
    accuracy = accuracy_score(labels, predictions)
    # With very small validation sets some classes are never predicted, which
    # makes their precision undefined. zero_division=0 scores those classes as 0
    # (the same value the default falls back to) without emitting a warning.
    f1 = f1_score(labels, predictions, average='weighted', zero_division=0)
    return {'accuracy': accuracy, 'f1': f1}



In [22]:
# MultitaskModel architecture: Input Text → Shared Encoder → Task-Specific Head → Task Prediction
class MultitaskModel(nn.Module):
    """
    Shared-encoder network with one classification head per task.

    Data flow: token ids -> pretrained encoder -> hidden state at sequence
    position 0 -> head selected by task name -> logits.

    Args:
        model_name: checkpoint identifier handed to ``AutoModel.from_pretrained``.
        task_configs: maps each task name to its number of output classes.
        dropout: dropout probability used inside every head.
    """

    def __init__(self, model_name: str, task_configs: Dict[str, int], dropout=0.1):
        super().__init__()

        # One encoder is shared by every task.
        self.encoder = AutoModel.from_pretrained(model_name)
        width = self.encoder.config.hidden_size
        bottleneck = width // 2

        # Each task gets its own small MLP, so output sizes can differ per task.
        self.task_heads = nn.ModuleDict()
        for name, n_out in task_configs.items():
            head_layers = [
                nn.Dropout(dropout),
                nn.Linear(width, bottleneck),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(bottleneck, n_out),
            ]
            self.task_heads[name] = nn.Sequential(*head_layers)

        self.task_configs = task_configs

    def forward(self, input_ids, attention_mask, task_name):
        """
        Encode the batch once and classify it with the head of ``task_name``.

        Returns:
            Logits produced by ``self.task_heads[task_name]`` from the encoder's
            hidden state at sequence position 0.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Position 0 holds the [CLS] token for BERT-style tokenizers.
        first_token = encoded.last_hidden_state[:, 0]

        head = self.task_heads[task_name]
        return head(first_token)

In [23]:
# MultitaskTrainer: training and evaluation loop for the multitask model
class MultitaskTrainer:
    """
    Training and evaluation loop for a multitask model.

    What this class does:
    1. Mixed-task batches: the samples of each batch are grouped per task and
       every group is run through the head of its task.
    2. Loss balancing: each task loss is scaled by ``task_weights[task]`` before
       the task losses are summed into one batch loss.
    3. Gradient clipping to a max norm of 1.0 before every optimizer step.
    4. Bookkeeping of per-epoch losses and validation metrics in ``history``.

    Gradient accumulation and learning-rate scheduling are NOT implemented:
    AdamW runs with a fixed learning rate and steps once per batch.

    Args:
        model: module called as ``model(input_ids, attention_mask, task_name)``
            and returning logits.
        tokenizer: stored on the instance; not used by the methods below.
        task_configs: maps each task name to its number of classes.
        device: device the model and the batches are moved to.
    """

    def __init__(self, model, tokenizer, task_configs, device='cpu'):
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.task_configs = task_configs
        self.device = device

        # One optimizer updates the shared encoder and every task head.
        self.optimizer = optim.AdamW(self.model.parameters(), lr=2e-5, weight_decay=0.01)

        # One loss function per task.
        self.loss_fns = {task: nn.CrossEntropyLoss() for task in task_configs.keys()}

        # Per-task multipliers for loss balancing (adjust to favour a task).
        self.task_weights = {task: 1.0 for task in task_configs.keys()}

        # One entry is appended per train_epoch() / evaluate() call.
        self.history = {
            'train_loss': [],
            'task_losses': {task: [] for task in task_configs.keys()},
            'val_metrics': {task: [] for task in task_configs.keys()}
        }

    @staticmethod
    def _group_indices_by_task(task_names):
        """Map each task name to the batch positions of its samples, in batch order."""
        groups = {}
        for position, task in enumerate(task_names):
            groups.setdefault(task, []).append(position)
        return groups

    # Training loop: one batch may mix tasks, so it is split per task first.
    def train_epoch(self, dataloader, epoch):
        """
        Train for one pass over ``dataloader``.

        Args:
            dataloader: sized iterable of batches; each batch is a dict with
                ``'task'`` (sequence of task names), ``'input_ids'``,
                ``'attention_mask'`` and ``'label'`` tensors.
            epoch: number shown in the progress messages; not used otherwise.

        Returns:
            ``(avg_total_loss, avg_task_losses)``: the summed weighted batch
            losses divided by the number of batches, and for each task the mean
            of its unweighted losses over the batches that contained it.
            Both are 0.0 when there was nothing to average.
        """
        self.model.train()
        total_loss = 0.0
        task_losses = {task: 0.0 for task in self.task_configs.keys()}
        task_counts = {task: 0 for task in self.task_configs.keys()}
        num_batches = len(dataloader)

        for batch_idx, batch in enumerate(dataloader):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['label'].to(self.device)

            # Sum of the weighted task losses; stays the int 0 if the batch
            # produced no task group.
            batch_loss = 0
            for task_name, positions in self._group_indices_by_task(batch['task']).items():
                task_input_ids = torch.stack([input_ids[i] for i in positions])
                task_attention_mask = torch.stack([attention_mask[i] for i in positions])
                task_labels = torch.stack([labels[i] for i in positions])

                logits = self.model(task_input_ids, task_attention_mask, task_name)

                loss = self.loss_fns[task_name](logits, task_labels)
                batch_loss = batch_loss + loss * self.task_weights[task_name]

                task_losses[task_name] += loss.item()
                task_counts[task_name] += 1

            # Read the scalar once; also valid when batch_loss is still the int 0,
            # which previously broke the progress print below.
            if isinstance(batch_loss, torch.Tensor):
                batch_loss_value = batch_loss.item()
            else:
                batch_loss_value = float(batch_loss)

            # Only step on a positive loss (this also skips NaN losses).
            if batch_loss_value > 0:
                self.optimizer.zero_grad()
                batch_loss.backward()

                # Gradient clipping to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                self.optimizer.step()
                total_loss += batch_loss_value

            if batch_idx % 10 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}/{num_batches}, Loss: {batch_loss_value:.4f}')

        # max(..., 1) keeps an empty loader (or a task that never appeared)
        # from raising ZeroDivisionError.
        avg_total_loss = total_loss / max(num_batches, 1)
        avg_task_losses = {task: (task_losses[task] / max(task_counts[task], 1))
                          for task in self.task_configs.keys()}

        self.history['train_loss'].append(avg_total_loss)
        for task, loss in avg_task_losses.items():
            self.history['task_losses'][task].append(loss)

        return avg_total_loss, avg_task_losses

    # Tasks are scored separately: a model may be good at one task and bad at another.
    def evaluate(self, dataloader):
        """
        Evaluate the model on validation data, one metric dict per task.

        Each batch is split per task and every task group is predicted in a
        single forward pass. Metrics are computed with ``compute_metrics`` and
        also appended to ``history['val_metrics']``.

        Returns:
            dict mapping task name to its metrics; tasks without any
            validation sample are omitted.
        """
        self.model.eval()
        task_predictions = {task: [] for task in self.task_configs.keys()}
        task_labels = {task: [] for task in self.task_configs.keys()}

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label']

                for task_name, positions in self._group_indices_by_task(batch['task']).items():
                    task_input_ids = torch.stack([input_ids[i] for i in positions])
                    task_attention_mask = torch.stack([attention_mask[i] for i in positions])

                    logits = self.model(task_input_ids, task_attention_mask, task_name)
                    predictions = torch.argmax(logits, dim=-1)

                    # Predictions and labels stay aligned: both follow `positions`.
                    task_predictions[task_name].extend(predictions.cpu().tolist())
                    task_labels[task_name].extend(labels[i].item() for i in positions)

        task_metrics = {}
        for task in self.task_configs.keys():
            if len(task_predictions[task]) > 0:
                metrics = compute_metrics(task_predictions[task], task_labels[task])
                task_metrics[task] = metrics
                self.history['val_metrics'][task].append(metrics)

        return task_metrics

In [24]:
def main(model_name="distilbert-base-uncased", batch_size=8, num_epochs=10, max_length=128):
    """
    Run the complete multitask fine-tuning demo: build the toy datasets, split
    them 80/20 per task, train and evaluate every epoch, then run inference on
    three new sentences.

    Args:
        model_name: checkpoint name passed to the tokenizer and to the model.
        batch_size: batch size of the training and validation loaders.
        num_epochs: number of train + evaluate rounds.
        max_length: token length every text is padded/truncated to.

    Returns:
        The ``MultitaskTrainer`` after training; its ``history`` attribute holds
        the per-epoch losses and validation metrics.
    """
    print(" Starting Multitask Fine-tuning Tutorial")
    print("=" * 50)

    # Check if GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Task configurations (task_name: num_classes)
    task_configs = {
        'sentiment': 3,  # positive, negative, neutral
        'intent': 4,     # greeting, question, request, goodbye
        'topic': 3       # tech, sports, politics
    }

    print("\nTask configurations:")
    for task, num_classes in task_configs.items():
        print(f"  - {task}: {num_classes} classes")

    # Load tokenizer
    print(f"\n Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Create sample datasets
    print("\n Creating sample datasets...")
    tasks_data = create_sample_datasets()

    print("\nDataset statistics:")
    for task_name, task_data in tasks_data.items():
        print(f"  - {task_name}: {len(task_data)} samples")

    # Split data into train/validation (80/20 split), independently per task
    train_data = {}
    val_data = {}

    for task_name, task_samples in tasks_data.items():
        # Shuffle a copy so the lists in tasks_data keep their original order.
        shuffled = list(task_samples)
        random.shuffle(shuffled)
        split_idx = int(0.8 * len(shuffled))
        train_data[task_name] = shuffled[:split_idx]
        val_data[task_name] = shuffled[split_idx:]

    # Create datasets
    print("\n Creating PyTorch datasets...")
    train_dataset = MultitaskDataset(train_data, tokenizer, max_length)
    val_dataset = MultitaskDataset(val_data, tokenizer, max_length)

    # Training batches are reshuffled every epoch, so they mix tasks.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")

    # Initialize model
    print("\n Initializing multitask model...")
    model = MultitaskModel(model_name, task_configs)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # The trainer moves the model to `device`.
    trainer = MultitaskTrainer(model, tokenizer, task_configs, device)

    # Training loop
    print(f"\n Starting training for {num_epochs} epochs...")
    print("-" * 50)

    # Epochs are numbered from 1 so the per-batch progress lines printed by
    # train_epoch agree with the "Epoch i/N" banner.
    for epoch in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch}/{num_epochs}")
        print("-" * 30)

        # Train
        train_loss, task_losses = trainer.train_epoch(train_loader, epoch)
        print(f"Training - Overall Loss: {train_loss:.4f}")
        for task, loss in task_losses.items():
            print(f"  {task}: {loss:.4f}")

        # Evaluate
        val_metrics = trainer.evaluate(val_loader)
        print("\nValidation Results:")
        for task, metrics in val_metrics.items():
            print(f"  {task}: Accuracy={metrics['accuracy']:.3f}, F1={metrics['f1']:.3f}")

    print("\nTraining completed!")

    # Inference: each sentence is routed to the head of the task named next to it.
    print("Demonstration of inference on new samples:")
    print("-" * 50)

    test_samples = [
        ("This movie is absolutely fantastic!", "sentiment"),
        ("Hello, can you help me?", "intent"),
        ("The new AI technology is revolutionary", "topic")
    ]

    model.eval()
    with torch.no_grad():
        for text, expected_task in test_samples:
            # Tokenize
            encoding = tokenizer(text, truncation=True, padding='max_length',
                               max_length=max_length, return_tensors='pt')

            # Predict
            logits = model(encoding['input_ids'].to(device),
                          encoding['attention_mask'].to(device),
                          expected_task)

            # `prediction` is a class index; see task_configs for its meaning.
            prediction = torch.argmax(logits, dim=-1).item()
            confidence = torch.softmax(logits, dim=-1).max().item()

            print(f"Text: '{text}'")
            print(f"Task: {expected_task}")
            print(f"Prediction: {prediction} (confidence: {confidence:.3f})")
            print()

    return trainer


In [25]:
if __name__ == "__main__":
    # Separator line first, then the end-to-end training and inference demo.
    print(50 * "=")
    main()

 Starting Multitask Fine-tuning Tutorial
Using device: cuda

Task configurations:
  - sentiment: 3 classes
  - intent: 4 classes
  - topic: 3 classes

 Loading tokenizer: distilbert-base-uncased

 Creating sample datasets...

Dataset statistics:
  - sentiment: 12 samples
  - intent: 12 samples
  - topic: 12 samples

 Creating PyTorch datasets...
Train dataset size: 27
Validation dataset size: 9

 Initializing multitask model...


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Total parameters: 67,252,618
Trainable parameters: 67,252,618

 Starting training for 10 epochs...
--------------------------------------------------

Epoch 1/10
------------------------------
Epoch 0, Batch 0/4, Loss: 3.5806
Training - Overall Loss: 3.3916
  sentiment: 1.1265
  intent: 1.4127
  topic: 1.1365

Validation Results:
  sentiment: Accuracy=0.333, F1=0.167
  intent: Accuracy=0.333, F1=0.167
  topic: Accuracy=0.333, F1=0.167

Epoch 2/10
------------------------------
Epoch 1, Batch 0/4, Loss: 3.4224
Training - Overall Loss: 3.2547
  sentiment: 1.0909
  intent: 1.3575
  topic: 1.0751

Validation Results:
  sentiment: Accuracy=0.333, F1=0.167
  intent: Accuracy=0.333, F1=0.333
  topic: Accuracy=0.333, F1=0.167

Epoch 3/10
------------------------------
Epoch 2, Batch 0/4, Loss: 3.5311
Training - Overall Loss: 3.1369
  sentiment: 1.0533
  intent: 1.3179
  topic: 1.0290

Validation Results:
  sentiment: Accuracy=0.333, F1=0.222
  intent: Accuracy=0.667, F1=0.667
  topic: Accuracy