# Multi-Head Fine-Tuning

Multi-head fine-tuning refers to the practice of fine-tuning a pre-trained model for multiple tasks simultaneously. This is often done by adding multiple "heads" to a shared "base" model. Each head is responsible for a specific task. The idea is that the shared layers learn general features that are useful for all tasks, while each head specializes in its own task.

For example, in a natural language processing scenario, you might have one head for sentiment analysis and another for named entity recognition.

In [1]:
# !sudo apt-get install libopenmpi-dev
# !sudo apt install nvidia-cuda-toolkit

<IPython.core.display.Javascript object>

In [2]:
%pip install torch==2.0.1 transformers deepspeed mpi4py --quiet

Note: you may need to restart the kernel to use updated packages.


<IPython.core.display.Javascript object>

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split

from transformers import BertModel, BertTokenizer

import deepspeed

[2023-09-23 15:53:33,783] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


<IPython.core.display.Javascript object>

In [4]:
torch.__version__

'2.0.1+cu117'

<IPython.core.display.Javascript object>

In [5]:
class MultiHeadModel(nn.Module):
    def __init__(self):
        super(MultiHeadModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        
        # Sentiment analysis head (binary classification)
        self.sentiment_head = nn.Linear(768, 1)
        
        # Named entity recognition head (let's assume 10 classes)
        self.ner_head = nn.Linear(768, 10)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        
        # For sentiment analysis, we'll just use the [CLS] token representation
        cls_token = last_hidden_state[:, 0, :]
        sentiment_output = self.sentiment_head(cls_token)
        
        # For NER, we'll use the representation for each token
        ner_output = self.ner_head(last_hidden_state)
        
        return sentiment_output, ner_output

<IPython.core.display.Javascript object>

In [6]:
# Initialize the model and optimizer
model = MultiHeadModel()
model = torch.compile(model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Initialize DeepSpeed
model, optimizer, _, _ = deepspeed.initialize(optimizer=optimizer,model=model,config='ds_config.json')

[2023-09-23 15:53:37,759] [INFO] [logging.py:96:log_dist] [Rank -1] DeepSpeed info: version=0.10.3, git-hash=unknown, git-branch=unknown
[2023-09-23 15:53:37,759] [INFO] [comm.py:637:init_distributed] cdb=None
[2023-09-23 15:53:37,760] [INFO] [comm.py:652:init_distributed] Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...
[2023-09-23 15:53:39,781] [INFO] [comm.py:702:mpi_discovery] Discovered MPI settings of world_rank=0, local_rank=0, world_size=1, master_addr=192.168.0.113, master_port=29500
[2023-09-23 15:53:39,782] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2023-09-23 15:53:41,149] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
[2023-09-23 15:53:41,151] [INFO] [logging.py:96:log_dist] [Rank 0] Removing param_group that has no 'params' in the client Optimizer
[2023-09-23 15:53:41,152] [INFO] [logging.py:96:log_dist] [Rank 0] Using client Optimizer as basic optimize

<IPython.core.display.Javascript object>

## Creating data

In [7]:
import random

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

num_samples = 10000
max_length = 50

positive_texts = ["I absolutely love this product!", 
                  "This is amazing, I'm so happy with it.", 
                  "Fantastic experience, would recommend to anyone.", 
                  "Great job, keep up the good work!", 
                  "Excellent service, couldn't be happier."]

negative_texts = ["I really hate this, it's awful.", 
                  "This is terrible, would not recommend to anyone.", 
                  "Awful experience, I'm so disappointed.", 
                  "Bad job, this needs a lot of improvement.", 
                  "Poor service, not happy at all."]

texts = []
sentiments = []

persons = ["John", "Emily", "Michael", "Sarah"]
organizations = ["Google", "Microsoft", "Apple"]
locations = ["New York", "San Francisco", "London"]
ner_sentences = [
    "[PERSON] works at [ORG].",
    "[PERSON] lives in [LOC].",
    "[ORG] is located in [LOC]."
]
ners = []

for _ in range(num_samples):
    # Sentiment
    if random.choice([True, False]):
        texts.append(random.choice(positive_texts))
        sentiments.append(1)
    else:
        texts.append(random.choice(negative_texts))
        sentiments.append(0)

    # NER
    ner_sentence = random.choice(ner_sentences)
    ner_sentence = ner_sentence.replace("[PERSON]", random.choice(persons))
    ner_sentence = ner_sentence.replace("[ORG]", random.choice(organizations))
    ner_sentence = ner_sentence.replace("[LOC]", random.choice(locations))
    ner_label_sequence = [0 if word not in persons + organizations + locations else persons.index(word) + 1 
                          if word in persons else organizations.index(word) + 5 
                          if word in organizations 
                          else locations.index(word) + 8 
                          for word in ner_sentence.split()
                          ]
    
    ner_label_sequence += [0] * (max_length - len(ner_label_sequence))  # Padding
    ners.append(ner_label_sequence[:max_length])

sentiments = torch.tensor(sentiments, dtype=torch.float32).view(-1, 1)
ners = torch.tensor(ners, dtype=torch.long)

# Tokenize the texts
encoding = tokenizer(texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']

dataset = TensorDataset(input_ids, attention_mask, sentiments, ners)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

<IPython.core.display.Javascript object>

### Training

In [11]:
%%time
for epoch in range(10):
    for batch in train_loader:
        input_ids, attention_mask, sentiment_labels, ner_labels = batch

        input_ids = input_ids.to(model.device)
        attention_mask = attention_mask.to(model.device)
        sentiment_labels = sentiment_labels.to(model.device)
        ner_labels = ner_labels.to(model.device)

        sentiment_output, ner_output = model(input_ids, attention_mask)
        
        sentiment_loss = F.binary_cross_entropy_with_logits(sentiment_output, sentiment_labels)
        ner_loss = F.cross_entropy(ner_output.view(-1, 10), ner_labels.view(-1))
        loss = sentiment_loss + ner_loss

        model.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

Epoch 1, Loss: 0.038848876953125
Epoch 2, Loss: 0.039093017578125
[2023-09-23 15:57:28,565] [INFO] [unfused_optimizer.py:289:_update_scale] No Grad overflow for 1000 iterations
[2023-09-23 15:57:28,566] [INFO] [unfused_optimizer.py:290:_update_scale] Increasing dynamic loss scale from 65536.0 to 131072.0
Epoch 3, Loss: 0.040252685546875
Epoch 4, Loss: 0.03826904296875
Epoch 5, Loss: 0.04071044921875
Epoch 6, Loss: 0.039306640625
[2023-09-23 15:59:53,868] [INFO] [unfused_optimizer.py:289:_update_scale] No Grad overflow for 1000 iterations
[2023-09-23 15:59:53,869] [INFO] [unfused_optimizer.py:290:_update_scale] Increasing dynamic loss scale from 131072.0 to 262144.0
Epoch 7, Loss: 0.038299560546875
Epoch 8, Loss: 0.03857421875
Epoch 9, Loss: 0.039398193359375
Epoch 10, Loss: 0.038543701171875
CPU times: user 6min 7s, sys: 300 ms, total: 6min 7s
Wall time: 6min 7s


<IPython.core.display.Javascript object>

### Test the model

In [9]:
test_texts = ["I love this max!", "This is terrible anna!"]
encoding = tokenizer(test_texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
input_ids = encoding['input_ids'].to(model.device)
attention_mask = encoding['attention_mask'].to(model.device)

with torch.no_grad():
    sentiment_output, ner_output = model(input_ids, attention_mask)
    sentiment_output = torch.sigmoid(sentiment_output)
    ner_output = torch.argmax(ner_output, dim=-1)

<IPython.core.display.Javascript object>

In [10]:
# Interpret the sentiment output
sentiment_output_np = sentiment_output.cpu().numpy()
sentiment_labels = ["Positive" if score > 0.5 else "Negative" for score in sentiment_output_np]

# Interpret the NER output
ner_output_np = ner_output.cpu().numpy()
ner_classes = ['O', 'PERSON', 'ORG', 'LOC', 'DATE', 'TIME', 'MONEY', 'PERCENT', 'FAC', 'GPE']
ner_labels = [[ner_classes[label] for label in sequence] for sequence in ner_output_np]

for i, (sentiment, ner) in enumerate(zip(sentiment_labels, ner_labels)):
    print(f"Sentence {i+1}: {test_texts[i]}")
    print(f"  Sentiment: {sentiment}")
    #print(f"  NER Labels: {ner}")


Sentence 1: I love this max!
  Sentiment: Positive
Sentence 2: This is terrible anna!
  Sentiment: Negative


<IPython.core.display.Javascript object>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple multi-headed model
class MultiHeadedModel(nn.Module):
    def __init__(self):
        super(MultiHeadedModel, self).__init__()
        self.shared_layer = nn.Linear(10, 20)
        
        # Classification head
        self.classification_head = nn.Linear(20, 3)  # 3 classes
        
        # Regression head
        self.regression_head = nn.Linear(20, 1)  # 1 output for regression

    def forward(self, x):
        x = self.shared_layer(x)
        
        # Classification output
        classification_output = self.classification_head(x)
        
        # Regression output
        regression_output = self.regression_head(x)
        
        return classification_output, regression_output

# Custom loss function
class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()
        self.classification_loss = nn.CrossEntropyLoss()
        self.regression_loss = nn.MSELoss()

    def forward(self, classification_output, regression_output, classification_target, regression_target):
        loss1 = self.classification_loss(classification_output, classification_target)
        loss2 = self.regression_loss(regression_output, regression_target)
        
        # Combine the two losses in some way
        combined_loss = loss1 + loss2
        return combined_loss

# Initialize model and loss
model = MultiHeadedModel()
criterion = CustomLoss()

# Dummy data
x = torch.randn(5, 10)  # 5 samples, 10 features
classification_target = torch.tensor([0, 1, 2, 0, 1])  # 5 samples, 3 classes
regression_target = torch.randn(5, 1)  # 5 samples, 1 output

# Forward pass
classification_output, regression_output = model(x)

# Compute loss
loss = criterion(classification_output, regression_output, classification_target, regression_target)

# Backward pass and optimization
optimizer = optim.SGD(model.parameters(), lr=0.01)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Combined Loss: {loss.item()}")
