# Transfer Learning and Quantization for Multi-task NLP

This notebook demonstrates how to perform transfer learning and quantization for multi-task NLP using a large pre-trained language model.

### Table of Contents

1. Setup and Imports
2. Load Pre-trained Model and Datasets
3. mplement Task-specific Heads
4. Mixed Precision Multi-task Fine-tuning
5. Quantization-Aware Fine-tuning (QAF) for Shared Base
6. Gradual Quantization of Task-specific Heads
7. Layer-wise Adaptive Quantization
8. Efficient INT8 Inference Implementation
9. Critical Layer Analysis and Precision Adjustment
10. Quantized Knowledge Distillation
11. Performance Evaluation

## 1. Setup and Imports

First, let's import the necessary libraries:

In [45]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset
from transformers import RobertaModel, RobertaTokenizerFast, AdamW, DataCollatorWithPadding
from datasets import load_dataset
from torch.cuda.amp import autocast, GradScaler
import numpy as np

## 2. Load Pre-trained Model and Datasets

Let's load a pre-trained RoBERTa model and multiple NLP datasets:

In [46]:
# Load model and tokenizer with add_prefix_space=True for NER
model_name = 'roberta-base'
tokenizer = RobertaTokenizerFast.from_pretrained(model_name, add_prefix_space=True)
model = RobertaModel.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Load datasets for multiple tasks
sentiment_dataset = load_dataset("glue", "sst2")
ner_dataset = load_dataset("conll2003")
text_classification_dataset = load_dataset("ag_news")

# Preprocessing function for tokenization
def tokenize_function(examples, task):
    if task == 'sentiment':
        return tokenizer(examples['sentence'], padding='max_length', truncation=True, max_length=128)
    elif task == 'classification':
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=128)
    elif task == 'ner':
        tokenized_inputs = tokenizer(examples['tokens'], padding='max_length', truncation=True, is_split_into_words=True)
        
        # Align the labels (NER tags) with tokenized words
        labels = []
        for i, label in enumerate(examples['ner_tags']):
            word_ids = tokenized_inputs.word_ids(batch_index=i)  # Get word_ids for alignment
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                if word_idx is None:  # Padding token, ignore
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:  # First token of a word
                    label_ids.append(label[word_idx])
                else:  # Other tokens in a word
                    label_ids.append(-100)
                previous_word_idx = word_idx
            
            # Ensure that the length of the labels is the same as the tokenized input
            label_ids += [-100] * (128 - len(label_ids))  # Pad the labels
            labels.append(label_ids)
        
        tokenized_inputs["labels"] = labels
        return tokenized_inputs

# Apply tokenization
sentiment_dataset = sentiment_dataset.map(lambda x: tokenize_function(x, task='sentiment'), batched=True)
ner_dataset = ner_dataset.map(lambda x: tokenize_function(x, task='ner'), batched=True)
text_classification_dataset = text_classification_dataset.map(lambda x: tokenize_function(x, task='classification'), batched=True)

# Remove columns that are no longer needed for training after tokenization
sentiment_dataset = sentiment_dataset.remove_columns(['sentence', 'idx'])
ner_dataset = ner_dataset.remove_columns(['tokens', 'pos_tags', 'chunk_tags', 'ner_tags'])
text_classification_dataset = text_classification_dataset.remove_columns(['text', 'label'])

# Data Collator for padding
data_collator = DataCollatorWithPadding(tokenizer, return_tensors="pt")

# Prepare dataloaders
dataloaders = {
    'sentiment': DataLoader(sentiment_dataset['train'], batch_size=32, shuffle=True, collate_fn=data_collator),
    'ner': DataLoader(ner_dataset['train'], batch_size=32, shuffle=True, collate_fn=data_collator),
    'classification': DataLoader(text_classification_dataset['train'], batch_size=32, shuffle=True, collate_fn=data_collator)
}

print("Tokenization and DataLoader setup complete.")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

Map:   0%|          | 0/14041 [00:00<?, ? examples/s]

ArrowInvalid: Column 7 named labels expected length 1000 but got length 512

## 3. Implement Task-specific Heads

Now, let's implement task-specific heads for each NLP task:

In [43]:
# Define multi-task model
class MultiTaskModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.sentiment_head = nn.Linear(base_model.config.hidden_size, 2)
        self.ner_head = nn.Linear(base_model.config.hidden_size, 9)  # Assuming 9 NER tags
        self.classification_head = nn.Linear(base_model.config.hidden_size, 4)  # Assuming 4 classes for AG News
    
    def forward(self, input_ids, attention_mask, task):
        base_output = self.base_model(input_ids, attention_mask=attention_mask).last_hidden_state
        if task == 'sentiment':
            return self.sentiment_head(base_output[:, 0, :])  # Use [CLS] token
        elif task == 'ner':
            return self.ner_head(base_output)
        elif task == 'classification':
            return self.classification_head(base_output[:, 0, :])  # Use [CLS] token

multi_task_model = MultiTaskModel(model)
multi_task_model = multi_task_model.to(device)
print("Implemented task-specific heads.")

Implemented task-specific heads.


## 4. Mixed Precision Multi-task Fine-tuning

Let's implement mixed precision training for multi-task fine-tuning:

In [44]:
# Training function
def train_epoch(model, dataloaders, optimizer, scheduler, scaler):
    model.train()
    total_loss = 0
    for task, dataloader in dataloaders.items():
        for batch in dataloader:
            optimizer.zero_grad()
            
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            # Access the correct label key for each task
            if task == 'sentiment' or task == 'classification':
                labels = batch['labels'].to(device)
            elif task == 'ner':
                labels = batch['labels'].to(device)
                
            with autocast():
                outputs = model(input_ids, attention_mask=attention_mask, task=task)
                loss = nn.CrossEntropyLoss()(outputs.view(-1, outputs.shape[-1]), labels.view(-1))

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            
            total_loss += loss.item()
    
    return total_loss / sum(len(dl) for dl in dataloaders.values())

# Fine-tuning setup
optimizer = AdamW(multi_task_model.parameters(), lr=5e-5)
num_epochs = 3
total_steps = num_epochs * sum(len(dl) for dl in dataloaders.values())
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=5e-5, total_steps=total_steps)
scaler = GradScaler()

# Fine-tuning loop
for epoch in range(num_epochs):
    avg_loss = train_epoch(multi_task_model, dataloaders, optimizer, scheduler, scaler)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

print("Mixed precision multi-task fine-tuning completed.")

  scaler = GradScaler()


ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`sentence` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

## 5. Quantization-Aware Fine-tuning (QAF) for Shared Base

Now, let's apply Quantization-Aware Fine-tuning to the shared base:

In [None]:
def apply_qaf_to_base(model):
    model.base_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    model.base_model = torch.quantization.prepare_qat(model.base_model)
    return model

multi_task_model = apply_qaf_to_base(multi_task_model)
print("Applied Quantization-Aware Fine-tuning to shared base")

## 6. Gradual Quantization of Task-specific Heads

Let's implement gradual quantization of task-specific heads:

In [None]:
def gradual_head_quantization(model, epoch, total_epochs):
    if epoch >= total_epochs // 2:  # Start quantizing heads halfway through
        heads = [model.sentiment_head, model.ner_head, model.classification_head]
        for head in heads:
            head.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
            torch.quantization.prepare_qat(head, inplace=True)

# Modify the training loop to include gradual head quantization
for epoch in range(num_epochs):
    gradual_head_quantization(multi_task_model, epoch, num_epochs)
    avg_loss = train_epoch(multi_task_model, dataloaders, optimizer, scheduler, scaler)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

print("Completed gradual quantization of task-specific heads")

## 7. Layer-wise Adaptive Quantization

Implement layer-wise adaptive quantization based on sensitivity analysis:

In [None]:
def evaluate_model(model, dataloader):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask=attention_mask, task='classification')  # Assume classification task for simplicity
            loss = nn.CrossEntropyLoss()(outputs, labels)
            
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
    
    avg_loss = total_loss / len(dataloader)
    accuracy = total_correct / total_samples
    return avg_loss

def sensitivity_analysis(model, dataloader):
    sensitivities = {}
    for name, module in model.base_model.named_modules():
        if isinstance(module, nn.Linear):
            original_weight = module.weight.data.clone()
            module.weight.data = torch.quantize_per_tensor(original_weight, 1.0, 0, torch.qint8).dequantize()
            loss = evaluate_model(model, dataloader)
            sensitivities[name] = loss
            module.weight.data = original_weight
    return sensitivities

def apply_adaptive_quantization(model, sensitivities, threshold):
    for name, module in model.base_model.named_modules():
        if isinstance(module, nn.Linear):
            if sensitivities[name] > threshold:
                module.qconfig = torch.quantization.float_qparams_weight_only_qconfig
            else:
                module.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# Perform sensitivity analysis
sensitivities = sensitivity_analysis(multi_task_model, next(iter(dataloaders.values())))
threshold = np.mean(list(sensitivities.values())) + np.std(list(sensitivities.values()))
apply_adaptive_quantization(multi_task_model, sensitivities, threshold)

print("Applied layer-wise adaptive quantization")

## 8. Efficient INT8 Inference Implementation

Implement efficient INT8 inference for the base model:

In [None]:
def convert_to_int8(model):
    model.eval()
    model.base_model = torch.quantization.convert(model.base_model)
    return model

multi_task_model = convert_to_int8(multi_task_model)
print("Implemented efficient INT8 inference for base model")

## 9. Critical Layer Analysis and Precision Adjustment

Identify and adjust precision for critical layers:

In [None]:
def adjust_critical_layers(model, sensitivities, top_k=3):
    sorted_sensitivities = sorted(sensitivities.items(), key=lambda x: x[1], reverse=True)
    critical_layers = [name for name, _ in sorted_sensitivities[:top_k]]
    
    for name, module in model.base_model.named_modules():
        if name in critical_layers:
            module.qconfig = None  # Keep in higher precision

adjust_critical_layers(multi_task_model, sensitivities)
print("Adjusted precision for critical layers")

## 10. Quantized Knowledge Distillation

Implement quantized knowledge distillation to create a smaller, efficient multi-task model:

In [None]:
class SmallMultiTaskModel(nn.Module):
    def __init__(self, hidden_size=256):
        super().__init__()
        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=hidden_size, nhead=4), num_layers=4)
        self.sentiment_head = nn.Linear(hidden_size, 2)
        self.ner_head = nn.Linear(hidden_size, 9)
        self.classification_head = nn.Linear(hidden_size, 4)

    def forward(self, input_ids, attention_mask, task):
        # Simplified encoding (replace with appropriate embedding)
        x = torch.nn.functional.one_hot(input_ids, num_classes=30522).float()
        x = self.encoder(x)
        if task == 'sentiment':
            return self.sentiment_head(x[:, 0, :])
        elif task == 'ner':
            return self.ner_head(x)
        elif task == 'classification':
            return self.classification_head(x[:, 0, :])

small_model = SmallMultiTaskModel().to(device)

def distillation_loss(student_outputs, teacher_outputs, labels, temperature=2.0):
    soft_targets = nn.functional.softmax(teacher_outputs / temperature, dim=1)
    soft_prob = nn.functional.log_softmax(student_outputs / temperature, dim=1)
    distillation_loss = nn.KLDivLoss(reduction='batchmean')(soft_prob, soft_targets) * (temperature ** 2)
    student_loss = nn.CrossEntropyLoss()(student_outputs, labels)
    return 0.5 * (distillation_loss + student_loss)

def train_distillation(teacher_model, student_model, dataloaders, optimizer, scheduler, scaler, epochs=5):
    for epoch in range(epochs):
        total_loss = 0
        for task, dataloader in dataloaders.items():
            for batch in dataloader:
                optimizer.zero_grad()
                
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                with autocast():
                    with torch.no_grad():
                        teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask, task=task)
                    student_outputs = student_model(input_ids, attention_mask=attention_mask, task=task)
                    loss = distillation_loss(student_outputs, teacher_outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                
                total_loss += loss.item()
        
        print(f"Distillation Epoch {epoch+1}/{epochs}, Average Loss: {total_loss / sum(len(dl) for dl in dataloaders.values()):.4f}")

optimizer = AdamW(small_model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-4, total_steps=5 * sum(len(dl) for dl in dataloaders.values()))
scaler = GradScaler()

train_distillation(multi_task_model, small_model, dataloaders, optimizer, scheduler, scaler)
print("Completed quantized knowledge distillation")

## 11. Performance Evaluation

Finally, let's evaluate the performance of our original, quantized, and distilled models:

In [None]:
def evaluate_model(model, dataloaders):
    model.eval()
    total_correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for task, dataloader in dataloaders.items():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                outputs = model(input_ids, attention_mask=attention_mask, task=task)
                _, predicted = torch.max(outputs, 1)
                total_correct += (predicted == labels).sum().item()
                total_samples += labels.size(0)
    
    return total_correct / total_samples

# Prepare evaluation dataloaders
eval_dataloaders = {
    'sentiment': DataLoader(sentiment_dataset['validation'], batch_size=32),
    'ner': DataLoader(ner_dataset['validation'], batch_size=32),
    'classification': DataLoader(text_classification_dataset['test'], batch_size=32)
}

original_accuracy = evaluate_model(multi_task_model, eval_dataloaders)
quantized_accuracy = evaluate_model(convert_to_int8(multi_task_model), eval_dataloaders)
distilled_accuracy = evaluate_model(small_model, eval_dataloaders)

print(f"Original Model Accuracy: {original_accuracy:.4f}")
print(f"Quantized Model Accuracy: {quantized_accuracy:.4f}")
print(f"Distilled Model Accuracy: {distilled_accuracy:.4f}")

# Measure inference time
def benchmark_inference_time(model, input_shape, num_runs=100):
    model.eval()
    input_ids = torch.randint(0, 30522, input_shape).to(device)
    attention_mask = torch.ones(input_shape).to(device)
    
    start_time = torch.cuda.Event(enable_timing=True)
    end_time = torch.cuda.Event(enable_timing=True)
    
    with torch.no_grad():
        # Warm-up run
        for _ in range(10):
            _ = model(input_ids, attention_mask=attention_mask, task='sentiment')
        
        # Timed runs
        start_time.record()
        for _ in range(num_runs):