# Task 2 & 4: Multi-Task Learning Expansion & Training Loop Implementation

### I chose to use the PyTorch Lightning framework, as I have prior experience with it and I like how it organizes the model architecture and training calls.

#### For a Multitask model, I decided to add 2 Linear Layers after the transformer layer for classifying messages in the dataset as spam or not spam, as well as the sentiment of each message (positive, neutral, or negative). This is a completely functional model that trains on real data, and evaluates on a test set.

In [1]:
import lightning
import nltk.sentiment
import pandas as pd
import torch
import transformers
from sklearn import metrics, model_selection, utils

In [2]:
class MultiTask(lightning.LightningModule):
    """
    Multi-task text classifier: a frozen transformer backbone followed by one linear head per task
    (binary spam detection and 3-class sentiment).
    """

    def __init__(self, transformer_model, spam_weights=None, sentiment_weights=None, learning_rate=1e-4):
        """
        This is where the model's architecture is defined (layers, activations, etc.).

        Args:
            transformer_model: Backbone module. It must expose ``config.hidden_size`` and, when called with the
                tokenized inputs, return an object with a ``last_hidden_state`` attribute.
            spam_weights: Optional ``pos_weight`` for the spam BCE loss (used in training only).
            sentiment_weights: Optional per-class weights for the sentiment cross-entropy loss (used in training
                only).
            learning_rate: Learning rate handed to AdamW in ``configure_optimizers``.
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.save_hyperparameters(ignore=["transformer_model"])

        # Transformer backbone
        self.transformer_model = transformer_model
        for param in self.transformer_model.parameters():  # Freeze all transformer layers
            param.requires_grad = False
        hidden_size = self.transformer_model.config.hidden_size

        # Task layers
        self.spam_head = torch.nn.Linear(hidden_size, 1)
        self.sentiment_head = torch.nn.Linear(hidden_size, 3)

        # Weights. Need to use register_buffer so that the tensors follow the device of the model
        self.register_buffer(
            "spam_weights",
            torch.tensor(spam_weights, dtype=torch.float) if spam_weights is not None else None, )
        self.register_buffer(
            "sentiment_weights",
            torch.tensor(sentiment_weights, dtype=torch.float) if sentiment_weights is not None else None, )

    def forward(self, inputs):
        """
        Here we define how we use the modules to operate on an input batch. First, we run the inputs through the
        transformer backbone, then we average the outputs to get a single vector for each input. Finally, we run the
        embeddings through each task layer. For simplicity, only 1 Linear layer is used for each task.

        Args:
            inputs: Mapping of tokenizer outputs; it is unpacked into the backbone and must contain
                ``"attention_mask"``.

        Returns:
            Dict with ``"spam_logits"`` and ``"sentiment_logits"`` (outputs of the two linear heads).
        """
        t_outputs = self.transformer_model(**inputs)
        embeddings = self.average_pool(t_outputs.last_hidden_state, inputs["attention_mask"])

        spam_logits = self.spam_head(embeddings)
        sentiment_logits = self.sentiment_head(embeddings)

        return {
            "spam_logits": spam_logits,
            "sentiment_logits": sentiment_logits
        }

    def _task_losses(self, batch, spam_pos_weight=None, sentiment_class_weight=None):
        """Run ``batch`` through the model and return ``(spam_loss, sentiment_loss)``."""
        outputs = self(batch["inputs"])

        # The spam head emits one logit per example; flatten it to line up with the label vector.
        loss_spam = torch.nn.BCEWithLogitsLoss(pos_weight=spam_pos_weight)(
            outputs["spam_logits"].view(-1), batch["spam_label"])
        loss_sentiment = torch.nn.CrossEntropyLoss(weight=sentiment_class_weight)(
            outputs["sentiment_logits"], batch["sentiment_label"])
        return loss_spam, loss_sentiment

    def training_step(self, batch, batch_idx):
        """
        For each step, we run a batch through the forward function and compute the loss. For the spam task, we are only
        predicting 0 (for not spam) or 1 (for spam). For the sentiment task, we are predicting one of three classes:
        (0 for negative, 1 for neutral, 2 for positive).

        The optional class weights given at construction are applied here. The returned loss is the unweighted sum
        of the two task losses.
        """
        loss_spam, loss_sentiment = self._task_losses(
            batch, spam_pos_weight=self.spam_weights, sentiment_class_weight=self.sentiment_weights)
        loss = loss_spam + loss_sentiment

        self.log("train_loss_task1", loss_spam)
        self.log("train_loss_task2", loss_sentiment)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Compute and log the validation losses for one batch. Class weights are intentionally not applied here, so
        the logged values are the plain BCE / cross-entropy losses.
        """
        loss_spam, loss_sentiment = self._task_losses(batch)
        loss = loss_spam + loss_sentiment

        self.log("val_loss_task1", loss_spam, prog_bar=True)
        self.log("val_loss_task2", loss_sentiment, prog_bar=True)
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over the module's parameters using ``self.learning_rate``."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    @staticmethod
    def average_pool(last_hidden_states, attention_mask):
        """
        Mean-pool token embeddings over dimension 1, ignoring positions where ``attention_mask`` is 0.

        Assumes ``last_hidden_states`` is (batch, seq, hidden) and ``attention_mask`` is (batch, seq) with 0/1
        entries - this matches how ``forward`` calls it.
        """
        # We don't want to include padded tokens, so use masked_fill to zero them out.
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        # Clamp the token count so a row whose mask is all zeros yields a zero vector instead of NaN (0 / 0).
        token_counts = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
        return last_hidden.sum(dim=1) / token_counts


class MultiTaskDataset(torch.utils.data.Dataset):
    """
    Dataset that tokenizes one message per item and returns it together with its spam and sentiment labels.
    """

    def __init__(self, df_, tokenizer_, max_length=128):
        """
        Args:
            df_: DataFrame with ``"message"``, ``"spam_label"`` and ``"sentiment_label"`` columns.
            tokenizer_: Callable invoked as
                ``tokenizer_(text, max_length=..., padding='max_length', truncation=True, return_tensors='pt')``
                that returns a mapping of tensors with a leading dimension of size 1.
            max_length: Length every message is padded / truncated to.
        """
        self.texts = df_["message"].tolist()
        self.spam_labels = df_["spam_label"].tolist()
        self.sentiment_labels = df_["sentiment_label"].tolist()
        self.tokenizer = tokenizer_
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``{"inputs": <tokenizer tensors>, "spam_label": float tensor, "sentiment_label": long tensor}``."""
        encoding = self.tokenizer(
            self.texts[idx], max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt')
        # Drop the leading dimension of size 1 so the DataLoader can stack items into a batch.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        return {
            "inputs": encoding,
            # Explicit dtypes: BCEWithLogitsLoss needs float targets and CrossEntropyLoss needs integer class
            # indices, regardless of how the label columns happen to be typed in the DataFrame.
            "spam_label": torch.tensor(self.spam_labels[idx], dtype=torch.float),
            "sentiment_label": torch.tensor(self.sentiment_labels[idx], dtype=torch.long)
        }


def create_dataloader(df_, tokenizer_, batch_size=32, shuffle=False, num_workers=4, pin_memory=True):
    """
    Wrap ``df_`` in a ``MultiTaskDataset`` and return a ``DataLoader`` over it.

    Args:
        df_: DataFrame passed to ``MultiTaskDataset``.
        tokenizer_: Tokenizer passed to ``MultiTaskDataset``.
        batch_size: Number of examples per batch.
        shuffle: Whether to reshuffle the data every epoch.
        num_workers: Number of DataLoader worker processes (0 loads in the main process).
        pin_memory: Forwarded to ``DataLoader``; set to False on CPU-only machines.
    """
    dataset = MultiTaskDataset(df_, tokenizer_)
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory)

In [3]:
# Load the pretrained tokenizer and transformer backbone
tokenizer = transformers.AutoTokenizer.from_pretrained("thenlper/gte-base")
backbone = transformers.AutoModel.from_pretrained("thenlper/gte-base")
df = pd.read_csv("hf://datasets/codesignal/sms-spam-collection/sms-spam-collection.csv")
df = df.rename(columns={"label": "spam_label"})
# Change spam label to a numerical value (float, because it is used as a BCEWithLogitsLoss target)
df["spam_label"] = df["spam_label"].apply(lambda x: 1.0 if x == "spam" else 0.0)

# Download vader_lexicon (if not already present) for marking messages with sentiment value
nltk.download('vader_lexicon')
sia = nltk.sentiment.vader.SentimentIntensityAnalyzer()
compound_scores = df["message"].apply(lambda x: sia.polarity_scores(x)["compound"])
# Bucket the compound score into classes: 0 = negative (<= -0.05), 1 = neutral, 2 = positive (>= 0.05)
df["sentiment_label"] = compound_scores.apply(lambda x: 2 if x >= 0.05 else (0 if x <= -0.05 else 1))

# Split dataset into 70% training, 15% validation, and 15% testing
train_df, temp_df = model_selection.train_test_split(df, test_size=0.3,
                                                     random_state=0)  # random_state for reproducibility
val_df, test_df = model_selection.train_test_split(temp_df, test_size=0.5, random_state=0)

# Create dataloaders
train_loader = create_dataloader(train_df, tokenizer, batch_size=32, shuffle=True)
val_loader = create_dataloader(val_df, tokenizer, batch_size=32, shuffle=False)
test_loader = create_dataloader(test_df, tokenizer, batch_size=32, shuffle=False)

# Optional step of class weights to deal with imbalance in dataset.
# The classes are sorted so that weight i belongs to class i; a plain .unique() returns the classes in order of
# appearance, which would silently misalign the weights with the label indices.
spam_balanced_weights = utils.class_weight.compute_class_weight(
    class_weight="balanced", classes=train_df["spam_label"].sort_values().unique(),
    y=train_df["spam_label"].tolist())
# BCEWithLogitsLoss's pos_weight is the weight of the positive class relative to the negative class, i.e. the
# ratio (#negative / #positive), which equals balanced_weight[spam] / balanced_weight[not spam].
spam_class_weights = spam_balanced_weights[1:] / spam_balanced_weights[0]
sentiment_class_weights = utils.class_weight.compute_class_weight(
    class_weight="balanced", classes=train_df["sentiment_label"].sort_values().unique(),
    y=train_df["sentiment_label"].tolist())

# Instantiate a Trainer, and put it on GPU if available
trainer = lightning.Trainer(max_epochs=5, accelerator="gpu" if torch.cuda.is_available() else "cpu", devices=1)
model = MultiTask(transformer_model=backbone,
                  # spam_weights=spam_class_weights,
                  # sentiment_weights=sentiment_class_weights
                  )

# Train the model
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

print("Training finished. Evaluating on testing set...")
# Put the model into evaluation mode
model.eval()
# Evaluate on whatever device the model is on, so inputs and weights always match
eval_device = model.device

# And run through the test dataset and print evaluation metrics
all_spam_preds = []
all_spam_labels = []
all_sentiment_preds = []
all_sentiment_labels = []

# Disable gradient tracking for evaluation
with torch.no_grad():
    for b in test_loader:
        batch_inputs = {k: v.to(eval_device) for k, v in b["inputs"].items()}

        spam_labels = b["spam_label"]
        sentiment_labels = b["sentiment_label"]

        batch_outputs = model(batch_inputs)

        batch_spam_logits = batch_outputs["spam_logits"]
        batch_sentiment_logits = batch_outputs["sentiment_logits"]

        # Get predicted class indices
        # For spam_preds, if the predicted probability (sigmoid of the logit) is > 0.5, predict it as spam.
        # The logits are flattened from (batch, 1) to (batch,) and cast to float so the predictions have the
        # same shape and type as the labels.
        spam_preds = (torch.sigmoid(batch_spam_logits.view(-1)) > 0.5).float().cpu().numpy()
        # For sentiment_preds, predict the class with the largest value
        sentiment_preds = torch.argmax(batch_sentiment_logits, dim=1).cpu().numpy()

        all_spam_preds.extend(spam_preds)
        all_spam_labels.extend(spam_labels.cpu().numpy())
        all_sentiment_preds.extend(sentiment_preds)
        all_sentiment_labels.extend(sentiment_labels.cpu().numpy())

[nltk_data] Downloading package vader_lexicon to
[nltk_data]     /home/karl/nltk_data...
[nltk_data]   Package vader_lexicon is already up-to-date!
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/karl/Projects/mtl_exercise/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name              | Type      | Params | Mode 
--------------------------------------------------------
0 | transformer_model | BertModel | 109 M  | eval 
1 | spam_head         | Linear    | 769    | train
2 | sentiment_head    | Linear    | 2.3 K  | train
--------------------------------------------------------
3.1 K     Trainable params
109 M     Non-trainable params
109 M     Total params
437.941   To

Sanity Checking: |                                                                                            …

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=5` reached.


Training finished. Evaluating on testing set...


#### For metrics, precision, recall, and F1 scores are generally a good starting point, especially for tasks that have very few possible output classes.

In [4]:
# Compute metrics: print a per-class report followed by the overall accuracy for each task
task_results = (
    ("Spam Classification Report:", all_spam_labels, all_spam_preds),
    ("Sentiment Classification Report:", all_sentiment_labels, all_sentiment_preds),
)
for report_title, y_true, y_pred in task_results:
    print(report_title)
    print(metrics.classification_report(y_true, y_pred))
    print("Accuracy:", metrics.accuracy_score(y_true, y_pred))

Spam Classification Report:
              precision    recall  f1-score   support

         0.0       0.89      1.00      0.94       726
         1.0       1.00      0.15      0.27       110

    accuracy                           0.89       836
   macro avg       0.94      0.58      0.60       836
weighted avg       0.90      0.89      0.85       836

Accuracy: 0.888755980861244
Sentiment Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.13      0.22       153
           1       0.58      0.46      0.51       286
           2       0.57      0.85      0.68       397

    accuracy                           0.58       836
   macro avg       0.65      0.48      0.47       836
weighted avg       0.62      0.58      0.54       836

Accuracy: 0.5825358851674641


#### If I had more time, I would look deeper into handling the class imbalance and get more creative. Some ideas: exploring alternative methods for labeling the sentiments, trying different class weight calculations (perhaps a manual weighting), or writing a custom loss function.