## Imports & global configuration

In [2]:
from typing import Literal, Dict, Tuple
import torch, torch.nn as nn, torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, AdamW, get_scheduler
from datasets import load_dataset
from torch.utils.data import DataLoader
from sklearn.metrics import (accuracy_score, precision_score,
                             recall_score, f1_score, classification_report)
import pandas as pd
from functools import partial

# ─── Hyper-parameters ─────────────────────────────────────────────
BACKBONE   = "sentence-transformers/all-MiniLM-L6-v2"
PROJ_DIM   = 256
LR         = 2e-5
BATCH_SIZE = 16
EPOCHS     = 3
DEVICE     = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(BACKBONE)


## Filtering the Hypothetical Data

Filters out empty strings, whitespace-only lines, non-string types, or rows with missing/invalid labels.  

In [3]:
def _is_real(example: Dict, text_key: str) -> bool:
    """True - row is usable; False - row is hypothetical/placeholder."""
    txt = example.get(text_key, None)
    lbl = example.get("label", None)
    if not isinstance(txt, str) or txt.strip() == "":
        return False
    return isinstance(lbl, (int, bool))


## Sentence encoder

Wraps the MiniLM backbone and adds a lightweight projection layer.  
Mean-pooling over the attention-mask avoids CLS dependence.

In [4]:
class SentenceEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(BACKBONE)
        hidden = self.backbone.config.hidden_size
        self.proj = nn.Sequential(
            nn.Linear(hidden, PROJ_DIM),
            nn.ReLU(inplace=True)
        )

    def forward(self, input_ids, attention_mask):
        out    = self.backbone(input_ids, attention_mask=attention_mask)
        hidden = out.last_hidden_state
        mask   = attention_mask.unsqueeze(-1)
        emb    = (hidden * mask).sum(1) / mask.sum(1)
        return self.proj(emb)


## Multi-task head

one shared encoder + two task-specific linear heads.

In [5]:
class MultiTaskModel(nn.Module):
    def __init__(self, num_topic: int, num_sent: int):
        super().__init__()
        self.encoder = SentenceEncoder()
        self.heads = nn.ModuleDict({
            "topic"    : nn.Linear(PROJ_DIM, num_topic),
            "sentiment": nn.Linear(PROJ_DIM, num_sent)
        })

    def forward(self, input_ids, attention_mask,
                task: Literal["topic", "sentiment"]):
        emb = self.encoder(input_ids, attention_mask)
        return self.heads[task](emb)

## Tokenizer helper

Creates padded token IDs and attaches labels

In [6]:
def _tokenize(batch, text_key):
    tok = tokenizer(batch[text_key], padding="max_length",
                    truncation=True, max_length=64)
    tok.pop("token_type_ids", None)
    tok["labels"] = batch["label"]
    return tok

## Data loading & loaders

In [7]:
def prepare_loaders() -> Tuple[Dict[str, DataLoader],
                               int, int, list[str], list[str]]:
    ag   = load_dataset("ag_news")
    sst2 = load_dataset("glue", "sst2")

    # 1) Remove hypothetical / malformed examples
    for split in ["train", "test"]:
        ag[split]   = ag[split]  .filter(_is_real, fn_kwargs={"text_key": "text"})
        sst2[split] = sst2[split].filter(_is_real, fn_kwargs={"text_key": "sentence"})

    # 2) split the dataset
    def split(ds):
        tr = ds["train"].train_test_split(0.2, seed=42)
        vt = tr["test"].train_test_split(0.5, seed=42)
        return tr["train"], vt["train"], vt["test"]

    ag_tr,  ag_val,  ag_test  = split(ag)
    sst_tr, sst_val, sst_test = split(sst2)

    # 3) Tokenise
    ag_tr   = ag_tr  .map(_tokenize, batched=True, fn_kwargs={"text_key": "text"})
    ag_val  = ag_val .map(_tokenize, batched=True, fn_kwargs={"text_key": "text"})
    ag_test = ag_test.map(_tokenize, batched=True, fn_kwargs={"text_key": "text"})
    sst_tr  = sst_tr .map(_tokenize, batched=True, fn_kwargs={"text_key": "sentence"})
    sst_val = sst_val.map(_tokenize, batched=True, fn_kwargs={"text_key": "sentence"})
    sst_test= sst_test.map(_tokenize, batched=True, fn_kwargs={"text_key": "sentence"})

    # 4) Torch format & loaders
    for ds in [ag_tr, ag_val, ag_test, sst_tr, sst_val, sst_test]:
        ds.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

    def L(ds, sh=False): return DataLoader(ds, BATCH_SIZE, shuffle=sh)
    loaders = {
        "train_topic"    : L(ag_tr,  True),
        "val_topic"      : L(ag_val),
        "test_topic"     : L(ag_test),
        "train_sentiment": L(sst_tr, True),
        "val_sentiment"  : L(sst_val),
        "test_sentiment" : L(sst_test),
    }

    n_topic = len(set(ag_tr["labels"]))
    n_sent  = len(set(sst_tr["labels"]))
    topic_names = ["World", "Sports", "Business", "Sci/Tech"]
    sent_names  = ["Negative", "Positive"]
    return loaders, n_topic, n_sent, topic_names, sent_names


## Metrics utilities

In [8]:
class MetricsPrinter:
    def __init__(self): self.rows = []

    def log(self, epoch, task, loss, metrics):
        self.rows.append({"epoch": epoch, "task": task,
                          "loss": loss, **metrics})

    def print_epoch(self, epoch):
        df = pd.DataFrame([r for r in self.rows if r["epoch"] == epoch])
        piv = df.pivot_table(index="task",
                             values=["loss", "acc", "prec", "rec", "f1"],
                             aggfunc="first")
        with pd.option_context("display.float_format", "{:.3f}".format):
            print(f"\n=== Training metrics (epoch {epoch+1}) ===")
            print(piv)

def macro_metrics(y_true, y_pred):
    return {
        "acc" : accuracy_score(y_true, y_pred),
        "prec": precision_score(y_true, y_pred, average="macro", zero_division=0),
        "rec" : recall_score  (y_true, y_pred, average="macro", zero_division=0),
        "f1"  : f1_score      (y_true, y_pred, average="macro", zero_division=0),
    }


## Training loop

Alternates between topic & sentiment batches each epoch and prints epoch-level tables

In [9]:
def train_func(model, loaders):
    model.to(DEVICE)
    optim = AdamW(model.parameters(), lr=LR)
    steps = EPOCHS * sum(len(loaders[k])
                         for k in ["train_topic", "train_sentiment"])
    sched = get_scheduler("linear", optim,
                          num_warmup_steps=int(0.1 * steps),
                          num_training_steps=steps)
    mp = MetricsPrinter()

    for epoch in range(EPOCHS):
        print(f"\nEpoch {epoch+1}/{EPOCHS}")
        model.train()
        for task in ["topic", "sentiment"]:
            loader = loaders[f"train_{task}"]
            tot_loss, preds, labels = 0, [], []
            for batch in loader:
                batch = {k: v.to(DEVICE) for k, v in batch.items()}
                optim.zero_grad()
                logits = model(batch["input_ids"],
                               batch["attention_mask"], task)
                loss = F.cross_entropy(logits, batch["labels"])
                loss.backward(); optim.step(); sched.step()

                tot_loss += loss.item()
                preds  += logits.argmax(-1).cpu().tolist()
                labels += batch["labels"].cpu().tolist()

            mp.log(epoch, task, tot_loss/len(loader),
                    macro_metrics(labels, preds))
        mp.print_epoch(epoch)

## Evaluation

In [10]:
def evaluate(model, loaders, stage: Literal["val", "test"],
             names_topic: list[str], names_sent: list[str]):
    model.eval()
    print(f"\n── Final {stage.upper()} metrics ───────────────────────────")
    for task, names in [("topic", names_topic),
                        ("sentiment", names_sent)]:
        loader = loaders[f"{stage}_{task}"]
        preds, labels = [], []
        with torch.no_grad():
            for batch in loader:
                batch = {k: v.to(DEVICE) for k, v in batch.items()}
                logits = model(batch["input_ids"],
                               batch["attention_mask"], task)
                preds  += logits.argmax(-1).cpu().tolist()
                labels += batch["labels"].cpu().tolist()

        mac = macro_metrics(labels, preds)
        print(f"{task.capitalize()} – macro: "
              f"Acc={mac['acc']:.3f}  "
              f"Prec={mac['prec']:.3f}  "
              f"Rec={mac['rec']:.3f}  "
              f"F1={mac['f1']:.3f}")
        print(classification_report(labels, preds,
                                    target_names=names, digits=3))

## Main Driver

In [11]:
def main_function():
    loaders, n_topic, n_sent, t_names, s_names = prepare_loaders()
    model = MultiTaskModel(n_topic, n_sent)

    train_func(model, loaders)
    evaluate(model, loaders, "val",  t_names, s_names)
    evaluate(model, loaders, "test", t_names, s_names)

    # store label names on the model for later use
    model.topic_label_names     = t_names
    model.sentiment_label_names = s_names
    return model

## Task 1 - Encoding Input Sentences

In [12]:
def sentence_embeddings():
    """Print shapes & first few values of sample sentence embeddings."""
    print("\n── Task 1 demo ──")
    enc = SentenceEncoder().to(DEVICE).eval()
    sents = ["Fetch is the Lifehack App You Didn’t Know You Needed",
             "Fetch Hacks to Earn the Most Points on a Single Receipt"
            ]
    toks = tokenizer(sents, return_tensors="pt", padding=True).to(DEVICE)
    toks.pop("token_type_ids", None)
    with torch.no_grad():
        emb = enc(**toks).cpu()
    print("Embeddings:", emb.shape)
    for each_emb in emb:
        print(each_emb[0: 8])

In [13]:
sentence_embeddings()


── Task 1 demo ──
Embeddings: torch.Size([2, 256])
tensor([0.1484, 0.0408, 0.0000, 0.0000, 0.0810, 0.2515, 0.1087, 0.2924])
tensor([0.1038, 0.0419, 0.0000, 0.0000, 0.0551, 0.0809, 0.1151, 0.0000])


## Fine-Tuned Model

In [15]:
import warnings
warnings.filterwarnings("ignore")
warnings.simplefilter('ignore')

final_model = main_function()

Map: 100%|██████████| 96000/96000 [00:04<00:00, 21215.23 examples/s]



Epoch 1/3

=== Training metrics (epoch 1) ===
            acc    f1  loss  prec   rec
task                                   
sentiment 0.861 0.157 0.518 0.158 0.156
topic     0.855 0.003 1.970 0.003 0.003

Epoch 2/3

=== Training metrics (epoch 2) ===
            acc    f1  loss  prec   rec
task                                   
sentiment 0.936 0.935 0.173 0.935 0.935
topic     0.938 0.938 0.182 0.938 0.938

Epoch 3/3

=== Training metrics (epoch 3) ===
            acc    f1  loss  prec   rec
task                                   
sentiment 0.954 0.953 0.130 0.953 0.953
topic     0.953 0.953 0.137 0.953 0.953

── Final VAL metrics ───────────────────────────
Topic – macro: Acc=0.936  Prec=0.937  Rec=0.936  F1=0.936
              precision    recall  f1-score   support

       World      0.911     0.964     0.937      3021
      Sports      0.992     0.966     0.978      3027
    Business      0.944     0.877     0.910      2987
    Sci/Tech      0.901     0.936     0.918      2965


## Prediction on Custom Sentences

In [16]:
def predict_sentences(model, sentences):
    model.eval()
    toks = tokenizer(sentences, return_tensors="pt",
                     padding=True, truncation=True, max_length=64).to(DEVICE)
    toks.pop("token_type_ids", None)
    with torch.no_grad():
        topic_logits = model(toks["input_ids"], toks["attention_mask"], "topic")
        sent_logits  = model(toks["input_ids"], toks["attention_mask"], "sentiment")

    topic_ids = topic_logits.argmax(-1).cpu().tolist()
    sent_ids  = sent_logits.argmax(-1).cpu().tolist()

    return [
        {
            "sentence" : s,
            "topic"    : model.topic_label_names[t_id],
            "sentiment": model.sentiment_label_names[s_id],
        }
        for s, t_id, s_id in zip(sentences, topic_ids, sent_ids)
    ]

In [17]:
predict_sentences(
    final_model,
    ["Fetch Reports Strong Momentum for 2025",
     "The market crashed and investors are worried."]
)

[{'sentence': 'Fetch Reports Strong Momentum for 2025',
  'topic': 'Sci/Tech',
  'sentiment': 'Positive'},
 {'sentence': 'The market crashed and investors are worried.',
  'topic': 'World',
  'sentiment': 'Negative'}]