I've expanded the previous implementation to support multi-task learning with the following key additions and modifications:

1. Task-Specific Heads:
   - SentenceClassificationHead: For sentence-level classification (e.g., sentiment analysis)
   - NERHead: For token-level named entity recognition
   
2. Architecture Changes:
   - Modified the base transformer to support both sequence-level and token-level tasks
   - Added task-specific pooling and classification layers
   - Implemented a unified forward pass that handles both tasks

3. Multi-Task Components:
   - Task-specific loss functions
   - Support for different output formats per task
   - Handling of padding for variable-length sequences
   - Task-specific processing of transformer outputs

4. Features:
   - The classification task outputs a single prediction per sentence
   - The NER task outputs one prediction per token
   - Both tasks share a single transformer backbone
   - Each task has its own loss calculation
   - Padding masks and attention masks are supported
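
The shared-backbone layout described in this list can be sketched in a few lines. This is a minimal illustration, not the code in `multitask_learning`: the class `TinyMultitaskModel`, its layer sizes, and the masked mean pooling are all assumptions, positional encodings are left out for brevity, and the mask is assumed to be a boolean tensor that is `True` at padding positions.

```python
import torch
import torch.nn as nn


class TinyMultitaskModel(nn.Module):
    """Hypothetical stand-in: one shared encoder feeding two task heads."""

    def __init__(self, vocab_size, num_classes, num_ner_tags, d_model=32, nhead=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward=64, batch_first=True
        )
        # The nested-tensor fast path is not needed for this sketch.
        self.encoder = nn.TransformerEncoder(
            layer, num_layers=1, enable_nested_tensor=False
        )
        self.cls_head = nn.Linear(d_model, num_classes)   # sentence-level head
        self.ner_head = nn.Linear(d_model, num_ner_tags)  # token-level head

    def forward(self, src, task, src_padding_mask=None):
        # Shared computation: both tasks read the same hidden states.
        hidden = self.encoder(
            self.embed(src), src_key_padding_mask=src_padding_mask
        )
        if task == "ner":
            # One prediction per token: (batch, seq, num_ner_tags)
            return {"logits": self.ner_head(hidden)}
        if task == "classification":
            if src_padding_mask is None:
                pooled = hidden.mean(dim=1)
            else:
                # Mean over real tokens only (assumed pooling strategy).
                keep = (~src_padding_mask).unsqueeze(-1).to(hidden.dtype)
                pooled = (hidden * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
            # One prediction per sentence: (batch, num_classes)
            return {"logits": self.cls_head(pooled)}
        raise ValueError(f"unknown task: {task!r}")


torch.manual_seed(0)
toy_model = TinyMultitaskModel(vocab_size=100, num_classes=3, num_ner_tags=5)
toy_src = torch.randint(1, 100, (2, 10))
toy_src[0, -3:] = 0          # pad the tail of the first sequence
toy_mask = toy_src == 0      # True where the token is padding
toy_cls = toy_model(toy_src, "classification", toy_mask)["logits"]
toy_ner = toy_model(toy_src, "ner", toy_mask)["logits"]
```

Only the final linear layer differs between the two tasks here, so gradients from either task update the same encoder weights.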

Supporting an additional task takes three steps:
1. Add a new task-specific head
2. Extend the forward pass to handle the new task
3. Add a corresponding loss function
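
One way to keep these steps small is to store the heads in a registry keyed by task name, so that adding a task does not require touching the forward pass. The sketch below is an assumption about how that could look, not how `MultitaskTransformer` is organised: `TaskHeads`, `add_task`, and the `"pos"` tagging task are hypothetical names, and a random tensor stands in for the backbone output.

```python
import torch
import torch.nn as nn


class TaskHeads(nn.Module):
    """Hypothetical registry of per-task linear heads over shared hidden states."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.heads = nn.ModuleDict()  # registers each head's parameters
        self.token_level = {}         # task name -> True if predictions are per token

    def add_task(self, name, num_labels, token_level):
        self.heads[name] = nn.Linear(self.d_model, num_labels)
        self.token_level[name] = token_level

    def forward(self, hidden, task):
        if task not in self.heads:
            raise ValueError(f"unknown task: {task!r}")
        if not self.token_level[task]:
            # Sentence-level task: (batch, seq, d_model) -> (batch, d_model)
            hidden = hidden.mean(dim=1)
        return {"logits": self.heads[task](hidden)}


heads = TaskHeads(d_model=16)
heads.add_task("classification", num_labels=3, token_level=False)
heads.add_task("ner", num_labels=5, token_level=True)
heads.add_task("pos", num_labels=12, token_level=True)  # the new task

hidden = torch.randn(2, 10, 16)  # stand-in for the shared backbone's output
pos_logits = heads(hidden, task="pos")["logits"]
cls_logits = heads(hidden, task="classification")["logits"]
```

In this sketch every task can reuse cross-entropy as its loss; a task with a different objective (for example regression) would still need its own loss function, as step 3 notes.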

### Import Libraries
---

In [1]:
import torch
from multitask_learning import (
    MultitaskTransformer,
    MultitaskLoss,
    create_padding_mask
)


### Initialize Model
---

In [2]:
vocab_size = 4096
num_classes = 3  # e.g., [positive, negative, neutral]
num_ner_tags = 5  # e.g., [O, B-PER, I-PER, B-ORG, I-ORG]
model = MultitaskTransformer(
    vocab_size=vocab_size,
    num_classes=num_classes,
    num_ner_tags=num_ner_tags
)


### Test Cases
---
**Test case 1**

In [3]:
# Create sample input; token IDs start at 1 so random tokens never collide with pad_idx=0
batch_size = 2
seq_length = 10
src = torch.randint(1, vocab_size, (batch_size, seq_length))

# Create a padding mask (assuming pad_idx=0)
padding_mask = create_padding_mask(src, pad_idx=0)

# 1. Test Classification Task
classification_outputs = model(src, task='classification', src_padding_mask=padding_mask)
print("[Classification] Logits shape:", classification_outputs['logits'].shape)
assert classification_outputs['logits'].shape == (batch_size, num_classes), \
    "Classification logits shape is incorrect!"

# 2. Test NER Task
ner_outputs = model(src, task='ner', src_padding_mask=padding_mask)
print("[NER] Logits shape:", ner_outputs['logits'].shape)
assert ner_outputs['logits'].shape == (batch_size, seq_length, num_ner_tags), \
    "NER logits shape is incorrect!"

print("✔ Basic forward pass and shape tests passed.")


[Classification] Logits shape: torch.Size([2, 3])
[NER] Logits shape: torch.Size([2, 10, 5])
✔ Basic forward pass and shape tests passed.
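
The body of `create_padding_mask` is not shown in this notebook. A plausible minimal version, assuming it returns a boolean tensor that is `True` at padding positions (the convention `torch.nn.TransformerEncoder` uses for `src_key_padding_mask`), might look like the following; `make_padding_mask` is a hypothetical name chosen to avoid implying this is the library's code.

```python
import torch


def make_padding_mask(src, pad_idx=0):
    """Return a bool tensor shaped like `src`, True where the token is padding."""
    return src == pad_idx


tokens = torch.tensor([[5, 3, 9, 0, 0],
                       [7, 2, 0, 0, 0]])
mask = make_padding_mask(tokens)
lengths = (~mask).sum(dim=1)  # number of real tokens per sequence
```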


**Test case 2**

In [4]:
criterion = MultitaskLoss()

# 1. Classification Loss
classification_labels = torch.randint(0, num_classes, (batch_size,))
classification_loss = criterion(classification_outputs, classification_labels, 'classification')
print("[Classification] Loss:", classification_loss.item())

# 2. NER Loss
#   - labels lie in [0, num_ner_tags - 1]; padded positions are set to -100 so the
#     loss can skip them (-100 is the default ignore_index of torch.nn.CrossEntropyLoss)
ner_labels = torch.randint(0, num_ner_tags, (batch_size, seq_length))
ner_labels[:, -2:] = -100  # artificially pad last two tokens
ner_loss = criterion(ner_outputs, ner_labels, 'ner')
print("[NER] Loss:", ner_loss.item())

assert classification_loss.dim() == 0, "Classification loss should be a scalar!"
assert ner_loss.dim() == 0, "NER loss should be a scalar!"

print("✔ Loss computation test passed.")


[Classification] Loss: 1.1429393291473389
[NER] Loss: 1.6583757400512695
✔ Loss computation test passed.
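
The scalar check above does not show what the `-100` labels actually do. The sketch below assumes the NER loss is a token-level cross-entropy with `ignore_index=-100`; whether `MultitaskLoss` works this way is not confirmed by this notebook, and `task_loss` is a hypothetical helper. It demonstrates that ignored positions drop out of the average: the loss over the full sequence equals the loss over the sequence with the padded tail cut off.

```python
import torch
import torch.nn.functional as F


def task_loss(outputs, labels, task, ignore_index=-100):
    """Hypothetical per-task loss: cross-entropy, flattened over tokens for NER."""
    logits = outputs["logits"]
    if task == "classification":
        return F.cross_entropy(logits, labels)
    if task == "ner":
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),  # (batch * seq, num_tags)
            labels.reshape(-1),                   # (batch * seq,)
            ignore_index=ignore_index,            # these positions are skipped
        )
    raise ValueError(f"unknown task: {task!r}")


torch.manual_seed(0)
demo_out = {"logits": torch.randn(2, 10, 5)}
demo_labels = torch.randint(0, 5, (2, 10))
demo_labels[:, -2:] = -100  # mark the last two tokens as padding

masked = task_loss(demo_out, demo_labels, "ner")
trimmed = task_loss(
    {"logits": demo_out["logits"][:, :-2]}, demo_labels[:, :-2], "ner"
)
print(torch.allclose(masked, trimmed))
```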


**Test case 3**

In [5]:
src_with_padding = src.clone()
src_with_padding[0, -3:] = 0  # Force some tokens to be padding in the first sequence
padding_mask_with_padding = create_padding_mask(src_with_padding, pad_idx=0)

outputs_with_padding = model(src_with_padding, task='classification', src_padding_mask=padding_mask_with_padding)
print("[Padding Test] Classification logits shape:", outputs_with_padding['logits'].shape)
assert not torch.isnan(outputs_with_padding['logits']).any(), "NaNs found in the output!"

print("✔ Padding mask usage test passed.")


[Padding Test] Classification logits shape: torch.Size([2, 3])
✔ Padding mask usage test passed.
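
The NaN check confirms that the masked forward pass is numerically stable, but not that padded positions are really ignored. A stricter check is to change the content at the padded positions and verify that the outputs at the real positions do not move. The sketch below illustrates the idea on a standalone `torch.nn.MultiheadAttention` layer rather than on `MultitaskTransformer`, whose internals are not shown here; the sizes and variable names are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

x = torch.randn(1, 5, 8)
# True marks padding: the last two positions are treated as pad tokens.
pad_mask = torch.tensor([[False, False, False, True, True]])

# Same real tokens, different content at the padded positions.
x_alt = x.clone()
x_alt[:, 3:] = torch.randn(1, 2, 8)

with torch.no_grad():
    out_a, weights = attn(x, x, x, key_padding_mask=pad_mask)
    out_b, _ = attn(x_alt, x_alt, x_alt, key_padding_mask=pad_mask)

# Outputs at the real positions should not depend on the padded content.
real_unchanged = torch.allclose(out_a[:, :3], out_b[:, :3], atol=1e-6)
# No query should place any attention weight on a padded key.
pad_weight = weights[:, :, 3:].abs().max().item()
print(real_unchanged, pad_weight)
```

The same pattern could be applied to the full model by comparing logits before and after overwriting the padded token IDs, provided dropout is disabled with `model.eval()`.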
