# Multi-Head Fine-Tuning

Multi-head fine-tuning refers to the practice of fine-tuning a pre-trained model for multiple tasks simultaneously. This is often done by adding multiple "heads" to a shared "base" model. Each head is responsible for a specific task. The idea is that the shared layers learn general features that are useful for all tasks, while each head specializes in its own task.

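For example, in a natural language processing scenario, you might have one head for sentiment analysis and another for named entity recognition (NER). This notebook uses exactly that setup: both heads sit on top of a single shared BERT encoder.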

In [None]:
# !sudo apt-get install libopenmpi-dev
# !sudo apt install nvidia-cuda-toolkit

In [None]:
# Install the Python dependencies used below (torch, transformers, deepspeed, mpi4py).
# NOTE(review): torch is pinned to 1.12.1 while transformers and deepspeed are left
# unpinned; newer releases of those packages may require a newer torch — confirm a
# mutually compatible set of versions before relying on this cell.
%pip install torch==1.12.1 transformers deepspeed mpi4py --quiet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split

from transformers import BertModel, BertTokenizer

import deepspeed

In [None]:
class MultiHeadModel(nn.Module):
    """A shared BERT encoder with two task-specific heads.

    * ``sentiment_head`` maps the [CLS] representation to a single logit
      (binary sentiment classification).
    * ``ner_head`` maps every token representation to ``num_ner_labels``
      logits (token-level named entity recognition).

    Args:
        pretrained_model_name: checkpoint name or path passed to
            ``BertModel.from_pretrained``.
        num_ner_labels: number of NER classes predicted for each token.
    """

    def __init__(self, pretrained_model_name='bert-base-uncased', num_ner_labels=10):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        # Size the heads from the encoder's config instead of hard-coding 768,
        # so checkpoints with a different hidden size work unchanged.
        hidden_size = self.bert.config.hidden_size

        # Sentiment analysis head (binary classification -> one logit)
        self.sentiment_head = nn.Linear(hidden_size, 1)

        # Named entity recognition head (one score per class for every token)
        self.ner_head = nn.Linear(hidden_size, num_ner_labels)

    def forward(self, input_ids, attention_mask):
        """Encode the batch once and apply both heads.

        Args:
            input_ids: token ids, presumably shaped (batch, seq_len).
            attention_mask: BERT attention mask of the same shape
                (1 = real token, 0 = padding).

        Returns:
            ``(sentiment_logits, ner_logits)`` shaped ``(batch, 1)`` and
            ``(batch, seq_len, num_ner_labels)``; both are raw logits.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state  # (batch, seq_len, hidden)

        # Sentence-level task: use the [CLS] token representation (position 0)
        cls_token = last_hidden_state[:, 0, :]
        sentiment_output = self.sentiment_head(cls_token)

        # Token-level task: classify every position independently
        ner_output = self.ner_head(last_hidden_state)

        return sentiment_output, ner_output

In [None]:
# Build the shared-encoder model plus a plain Adam optimizer over all of its
# parameters, then hand both to DeepSpeed, which replaces `model` with its
# training engine and `optimizer` with its wrapped optimizer. The two trailing
# return values are discarded. Settings are read from ds_config.json, which
# must exist in the working directory.
model = MultiHeadModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config='ds_config.json',
)

## Creating synthetic data

In [None]:
import random
import re

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

num_samples = 10000
max_length = 50

positive_texts = ["I absolutely love this product!",
                  "This is amazing, I'm so happy with it.",
                  "Fantastic experience, would recommend to anyone.",
                  "Great job, keep up the good work!",
                  "Excellent service, couldn't be happier."]

negative_texts = ["I really hate this, it's awful.",
                  "This is terrible, would not recommend to anyone.",
                  "Awful experience, I'm so disappointed.",
                  "Bad job, this needs a lot of improvement.",
                  "Poor service, not happy at all."]

texts = []
sentiments = []

persons = ["John", "Emily", "Michael", "Sarah"]
organizations = ["Google", "Microsoft", "Apple"]
locations = ["New York", "San Francisco", "London"]
ner_sentences = [
    "[PERSON] works at [ORG].",
    "[PERSON] lives in [LOC].",
    "[ORG] is located in [LOC]."
]
ners = []

# NER targets are entity *types*, in the same order as the first entries of the
# `ner_classes` list used to interpret predictions ('O', 'PERSON', 'ORG', 'LOC').
ner_label_ids = {"O": 0, "PERSON": 1, "ORG": 2, "LOC": 3}
entity_pools = {"PERSON": persons, "ORG": organizations, "LOC": locations}
placeholder_pattern = re.compile(r"\[(PERSON|ORG|LOC)\]")
# F.cross_entropy skips targets equal to -100 (its default ignore_index), so
# [CLS]/[SEP], padding and WordPiece continuations do not contribute to the loss.
IGNORE_LABEL = -100


def _fill_ner_template(template):
    """Replace each placeholder in `template` with a random entity of that type.

    Returns a list of ``(word, label_id)`` pairs from a whitespace split. Every
    word of a multi-word entity (e.g. "New York") gets the entity's label, and
    text glued to a placeholder (such as a trailing ".") becomes its own 'O' word.
    """
    outside = ner_label_ids["O"]
    labeled_words = []
    cursor = 0
    for match in placeholder_pattern.finditer(template):
        labeled_words += [(w, outside) for w in template[cursor:match.start()].split()]
        entity_type = match.group(1)
        entity = random.choice(entity_pools[entity_type])
        labeled_words += [(w, ner_label_ids[entity_type]) for w in entity.split()]
        cursor = match.end()
    labeled_words += [(w, outside) for w in template[cursor:].split()]
    return labeled_words


def _encode_labeled_words(labeled_words):
    """Build padded BERT inputs with one NER target per WordPiece token.

    Returns ``(input_ids, attention_mask, labels)``, each a list of length
    `max_length`. A word's label goes on its first WordPiece; continuation
    pieces, [CLS], [SEP] and padding get IGNORE_LABEL.
    """
    pieces, piece_labels = [], []
    for word, label in labeled_words:
        word_pieces = tokenizer.tokenize(word)
        if not word_pieces:  # nothing to attach a label to
            continue
        pieces += word_pieces
        piece_labels += [label] + [IGNORE_LABEL] * (len(word_pieces) - 1)

    # Truncate, keeping room for [CLS] and [SEP]
    pieces = pieces[:max_length - 2]
    piece_labels = piece_labels[:max_length - 2]

    ids = [tokenizer.cls_token_id] + tokenizer.convert_tokens_to_ids(pieces) + [tokenizer.sep_token_id]
    labels = [IGNORE_LABEL] + piece_labels + [IGNORE_LABEL]
    mask = [1] * len(ids)

    padding = max_length - len(ids)
    ids += [tokenizer.pad_token_id] * padding
    labels += [IGNORE_LABEL] * padding
    mask += [0] * padding
    return ids, mask, labels


input_id_rows = []
attention_rows = []

for _ in range(num_samples):
    # Sentiment
    if random.choice([True, False]):
        sentiment_text = random.choice(positive_texts)
        sentiments.append(1)
    else:
        sentiment_text = random.choice(negative_texts)
        sentiments.append(0)

    # NER: append a filled template to the sentiment sentence so both heads
    # are trained on the same input sequence the labels describe.
    labeled_words = [(w, ner_label_ids["O"]) for w in sentiment_text.split()]
    labeled_words += _fill_ner_template(random.choice(ner_sentences))
    texts.append(" ".join(word for word, _label in labeled_words))

    ids, mask, labels = _encode_labeled_words(labeled_words)
    input_id_rows.append(ids)
    attention_rows.append(mask)
    ners.append(labels)

sentiments = torch.tensor(sentiments, dtype=torch.float32).view(-1, 1)
ners = torch.tensor(ners, dtype=torch.long)

# Token ids and mask were built above so that they line up 1:1 with `ners`
input_ids = torch.tensor(input_id_rows, dtype=torch.long)
attention_mask = torch.tensor(attention_rows, dtype=torch.long)

dataset = TensorDataset(input_ids, attention_mask, sentiments, ners)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

### Training

In [None]:
# from_pretrained() returns BERT in eval mode (dropout off); switch to training
# mode so dropout is active while fine-tuning.
model.train()
for epoch in range(2):
    running_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        input_ids, attention_mask, sentiment_labels, ner_labels = (
            tensor.to(model.device) for tensor in batch
        )

        sentiment_output, ner_output = model(input_ids, attention_mask)

        sentiment_loss = F.binary_cross_entropy_with_logits(sentiment_output, sentiment_labels)
        # Flatten (batch, seq, classes) -> (batch*seq, classes); any -100 targets
        # are skipped by cross_entropy's default ignore_index.
        ner_loss = F.cross_entropy(
            ner_output.reshape(-1, ner_output.size(-1)), ner_labels.reshape(-1)
        )
        loss = sentiment_loss + ner_loss

        # `model` is a DeepSpeed engine: let it run backward and the optimizer
        # step so its loss scaling, gradient accumulation and ZeRO handling
        # apply; engine.step() also zeroes the gradients.
        model.backward(loss)
        model.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean over the epoch rather than just the last batch
    print(f"Epoch {epoch+1}, Loss: {running_loss / max(num_batches, 1)}")

# Leave the model in inference mode (dropout off) for the cells that follow
model.eval()

### Test the model

In [None]:
test_texts = ["I love this max!", "This is terrible anna!"]
encoding = tokenizer(test_texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
input_ids = encoding['input_ids'].to(model.device)
attention_mask = encoding['attention_mask'].to(model.device)

# Disable dropout explicitly so predictions are deterministic, whatever mode
# training left the model in.
model.eval()
with torch.no_grad():
    sentiment_output, ner_output = model(input_ids, attention_mask)
    # Sentiment logit -> probability of the positive class
    sentiment_output = torch.sigmoid(sentiment_output)
    # NER logits -> most likely class id for every token position
    ner_output = torch.argmax(ner_output, dim=-1)

In [None]:
# Interpret the sentiment output: probability > 0.5 means positive.
# sentiment_output is (batch, 1), so read the single column explicitly.
sentiment_output_np = sentiment_output.cpu().numpy()
sentiment_labels = ["Positive" if score > 0.5 else "Negative" for score in sentiment_output_np[:, 0]]

# Interpret the NER output (class id per token position -> class name)
ner_output_np = ner_output.cpu().numpy()
ner_classes = ['O', 'PERSON', 'ORG', 'LOC', 'DATE', 'TIME', 'MONEY', 'PERCENT', 'FAC', 'GPE']
ner_labels = [[ner_classes[label] for label in sequence] for sequence in ner_output_np]

# Only show predictions for real tokens: skip padding (attention 0) and special
# tokens such as [CLS]/[SEP], and pair each label with the token it belongs to.
token_id_rows = input_ids.cpu().tolist()
mask_rows = attention_mask.cpu().tolist()
special_ids = set(tokenizer.all_special_ids)

for i, (sentiment, ner) in enumerate(zip(sentiment_labels, ner_labels)):
    tokens = tokenizer.convert_ids_to_tokens(token_id_rows[i])
    tagged_tokens = [
        (token, label)
        for token, token_id, attended, label in zip(tokens, token_id_rows[i], mask_rows[i], ner)
        if attended and token_id not in special_ids
    ]
    print(f"Sentence {i+1}:")
    print(f"  Sentiment: {sentiment}")
    print(f"  NER Labels: {tagged_tokens}")
