In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
from torchmetrics.classification import MultilabelPrecision, MultilabelRecall, MultilabelF1Score, MultilabelExactMatch, MultilabelHammingDistance
from typing import Dict, Any, Tuple, Optional
import math

In [2]:
# Load the pretrained 'bert-base-uncased' tokenizer (used only for tokenization; the model below is trained from scratch).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Raw movie table; the output below shows its columns: movie_name, genre, description.
movies_df = pd.read_csv('top_movies.csv')
# Per-column count of missing values (displayed as the cell output).
movies_df.isna().sum()

movie_name     0
genre          3
description    0
dtype: int64

In [3]:
# Drop every row with a missing value in any column (the 3 rows without a genre reported above).
# Note: the index is not reset, so index labels keep gaps where rows were removed.
movies_df = movies_df.dropna()
# Re-check: every count should now be 0.
movies_df.isna().sum()

movie_name     0
genre          0
description    0
dtype: int64

In [4]:
# Re-instantiates the same tokenizer that the data-loading cell already created.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Sanity check: find descriptions that produce no content tokens at all.
# Special tokens are deliberately excluded here: when they are added, every
# encoding contains at least [CLS] and [SEP], so an attention-mask sum can
# never be 0 and the check would pass even for an empty description.
empty_sequences = []

for idx, description in enumerate(movies_df['description'].tolist()):
    token_ids = tokenizer(
        description,
        add_special_tokens=False,
        truncation=True,
        max_length=128,
    )['input_ids']

    if len(token_ids) == 0:
        # idx is the row position within movies_df, not its index label
        # (labels have gaps after dropna()).
        print(f"⚠️ Empty sequence found at index {idx}: {description}")
        empty_sequences.append(idx)

print(f"\n✅ Total empty sequences found: {len(empty_sequences)}")


✅ Total empty sequences found: 0


In [5]:
# Every distinct genre name in the data; each row's `genre` field is a
# comma-separated list, and surrounding whitespace is stripped from each name.
all_genres = {
    name.strip()
    for genre_field in movies_df['genre']
    for name in genre_field.split(',')
}

# Alphabetical order gives each genre a stable position in the label vector.
genre_to_index = dict(zip(sorted(all_genres), range(len(all_genres))))

print("Genre to index mapping:", genre_to_index)

Genre to index mapping: {'Action': 0, 'Adventure': 1, 'Animation': 2, 'Comedy': 3, 'Crime': 4, 'Drama': 5, 'Family': 6, 'Fantasy': 7, 'History': 8, 'Horror': 9, 'Music': 10, 'Mystery': 11, 'Romance': 12, 'Science Fiction': 13, 'TV Movie': 14, 'Thriller': 15, 'War': 16, 'Western': 17}


In [6]:
class MovieDescriptionDataset(Dataset):
    """
    Map-style dataset that turns (description, genre string) rows into model inputs.

    Each item is tokenized on access (nothing is cached) and paired with a
    multi-hot label vector built from the row's comma-separated genre string.

    Args:
        dataframe: Table with a 'description' column and a 'genre' column.
        tokenizer: Callable tokenizer; it is invoked with ``truncation``,
            ``padding``, ``max_length`` and ``return_tensors='pt'`` and must
            return a mapping with 'input_ids' and 'attention_mask' tensors.
        genre_to_index (Dict[str, int]): Position of each genre in the label vector.
        max_length (int): Length every sequence is padded/truncated to.
    """
    def __init__(self, dataframe, tokenizer, genre_to_index: Dict[str, int], max_length: int = 128):
        self.tokenizer = tokenizer
        self.genre_to_index = genre_to_index
        self.max_length = max_length
        # Plain Python lists so that __getitem__ indexes by position.
        self.descriptions = dataframe['description'].tolist()
        self.genres = dataframe['genre'].tolist()

    def __len__(self) -> int:
        return len(self.descriptions)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Return the tokenized description and multi-hot label for row ``idx``.

        Returns:
            Dict[str, Any]: 'input_ids' and 'attention_mask' with the leading
            batch dimension squeezed away, and 'labels', a float vector with
            one slot per entry of ``genre_to_index``.
        """
        text = self.descriptions[idx]
        genre_field = self.genres[idx]

        tokens = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Multi-hot target; genre names missing from the mapping are ignored.
        label = torch.zeros(len(self.genre_to_index))
        for raw_name in genre_field.split(','):
            position = self.genre_to_index.get(raw_name.strip())
            if position is not None:
                label[position] = 1

        item = {key: tokens[key].squeeze(0) for key in ('input_ids', 'attention_mask')}
        item['labels'] = label
        return item

In [7]:
# Hold out 20% of the rows for validation; random_state=42 keeps the split identical across runs.
train_df, val_df = train_test_split(movies_df, test_size=0.2, random_state=42)

In [8]:
# Wrap both splits; every description is padded/truncated to 128 tokens.
train_dataset = MovieDescriptionDataset(train_df, tokenizer, genre_to_index, max_length=128)
val_dataset = MovieDescriptionDataset(val_df, tokenizer, genre_to_index, max_length=128)

# Number of training examples.
print(len(train_dataset))

7533


In [9]:
# Inspect one encoded example to confirm the tensors look as expected.
sample = train_dataset[0]

print("Input IDs:", sample['input_ids'])
print("Attention Mask:", sample['attention_mask'])
print("Labels:", sample['labels'])

Input IDs: tensor([  101,  1999,  1996,  3865,  1010,  1037,  7101,  2003,  4704,  2011,
         2010,  2316,  2074,  2077,  2027,  2468,  2600, 18795,  2015,  1012,
         3174,  2086,  2101,  1010,  1996,  7101,  5927,  2010,  2117,  3382,
         2012,  2732,  9527, 13368,  2043,  2002,  2003,  2356,  2000,  4685,
         2007,  2010,  9454,  7833,  1005,  1055,  2152,  2082,  2600,  2316,
         1012,   102,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,    

In [10]:
# Only the training loader shuffles; validation order stays fixed.
# NOTE: the batch size is hard-coded to 8 here; the `batch_size` variable defined
# in the hyperparameter cell further down is not used by these loaders.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [11]:
class SelfAttentionPooling(nn.Module):
    """
    Self-Attention Pooling Layer with mask support.

    This layer computes a learned attention weight for each token in the sequence
    and returns a weighted sum of the sequence embeddings. Padded tokens are masked out
    so they do not influence the pooled representation.

    A row whose mask contains no valid token is pooled to an all-zero vector
    instead of NaN (a softmax over a row of only ``-inf`` scores is undefined).

    Args:
        d_model (int): Dimensionality of the input embeddings.
    """
    def __init__(self, d_model: int):
        super().__init__()
        self.lin1 = nn.Linear(d_model, d_model)
        self.lin2 = nn.Linear(d_model, 1)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for self-attention pooling.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
            attn_mask (torch.Tensor): Boolean or binary tensor of shape (batch_size, seq_len),
                where True/1 indicates valid tokens and False/0 indicates padding.

        Returns:
            torch.Tensor: Pooled representation of shape (batch_size, d_model).
            Rows with no valid token yield zeros, provided ``x`` is finite there.
        """
        m = attn_mask.bool()
        has_valid = m.any(dim=1, keepdim=True)                      # (B, 1)

        scores = self.lin2(torch.tanh(self.lin1(x))).squeeze(-1)    # (B, T)
        # Mask padding tokens, but leave fully padded rows unmasked so the
        # softmax (and its gradient) stays finite for them.
        scores = scores.masked_fill(~m & has_valid, float('-inf'))
        weights = torch.softmax(scores, dim=1)                      # (B, T)
        # Fully padded rows contribute nothing; other rows are left untouched.
        weights = weights.masked_fill(~has_valid, 0.0).unsqueeze(-1)  # (B, T, 1)
        return (weights * x).sum(dim=1)                             # (B, D)


In [12]:
class HighwayLayer(nn.Module):
    """
    Gated layer that blends a ReLU-transformed copy of the input with the input itself.

    The output is ``relu(H(x)) * g + x * (1 - g)`` with ``g = sigmoid(T(x))``,
    so input and output share the same feature size.

    Args:
        size (int): Feature dimension of both input and output.
    """
    def __init__(self, size: int) -> None:
        super().__init__()
        self.H: nn.Linear = nn.Linear(size, size)   # transform branch
        self.T: nn.Linear = nn.Linear(size, size)   # gate branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gated mix of the transformed input and the untouched input."""
        gate = torch.sigmoid(self.T(x))
        carry = 1 - gate
        transformed = F.relu(self.H(x))
        return transformed * gate + x * carry

In [13]:
class TextClassifier(nn.Module):
    """
    Transformer-based multi-label text classifier with:
      - Transformer encoder backbone
      - Masked self-attention pooling
      - MLP trunk with residual connection and LayerNorm
      - Label embedding classifier with normalized logits,
        learnable temperature scaling, and bias
      - Optional Monte Carlo dropout for uncertainty estimation and ensembling

    Notes:
        * In training mode with ``mc_dropout_enabled`` and ``num_dropout_samples > 1``
          the logits are the mean over several dropout samples of the projected
          representation. When averaging is disabled (or only one sample is
          requested) the head dropout layer is not applied at all.
        * Eval-time MC dropout is opt-in through :meth:`enable_mc_dropout` and,
          once requested, survives later ``.eval()`` / ``.train()`` calls.

    Args:
        vocab_size (int): Size of the vocabulary.
        emb_dim (int): Dimensionality of token embeddings.
        hidden_dim (int): Hidden layer size in the classifier trunk.
        num_classes (int): Number of target classes.
        num_heads (int, optional): Number of attention heads in Transformer layers. Default is 8.
        max_seq_length (int, optional): Maximum input sequence length. Default is 512.
        num_attention_layers (int, optional): Number of Transformer encoder layers. Default is 4.
        feedforward_dim (int, optional): Feedforward layer size inside Transformer. Default is 256.
        num_dropout_samples (int, optional): Number of dropout samples for MC-dropout. Default is 8.
        mc_dropout_enabled (bool, optional): Whether to enable MC-dropout averaging. Default is True.
    """
    def __init__(
        self,
        vocab_size: int,
        emb_dim: int,
        hidden_dim: int,
        num_classes: int,
        num_heads: int = 8,
        max_seq_length: int = 512,
        num_attention_layers: int = 4,
        feedforward_dim: int = 256,
        num_dropout_samples: int = 8,
        mc_dropout_enabled: bool = True,
    ) -> None:
        super().__init__()

        self.num_dropout_samples = num_dropout_samples
        self.mc_dropout_enabled = mc_dropout_enabled
        # Set only by enable_mc_dropout(); when True, train()/eval() re-activate
        # the head dropout after switching the module to eval mode.
        self._mc_dropout_at_eval = False

        # Embedding layers
        self.emb = nn.Embedding(vocab_size, emb_dim)
        nn.init.normal_(self.emb.weight, std=0.02)

        # Learned (not sinusoidal) positional table, one row per position.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_length, emb_dim))
        nn.init.normal_(self.positional_encoding, std=0.02)
        self.input_drop = nn.Dropout(0.3)

        # Transformer encoder
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=num_heads,
            dim_feedforward=feedforward_dim,
            dropout=0.3,
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            transformer_layer,
            num_layers=num_attention_layers
        )

        # Attention pooling
        self.attention_pooling = SelfAttentionPooling(emb_dim)

        # Classifier trunk
        self.fc1 = nn.Linear(emb_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        # The residual needs a projection only when the trunk output width differs from emb_dim.
        self.res_proj = nn.Identity() if (hidden_dim // 2) == emb_dim else nn.Linear(emb_dim, hidden_dim // 2)
        self.layernorm_final = nn.LayerNorm(hidden_dim // 2)
        self.text_to_label_space = nn.Linear(hidden_dim // 2, emb_dim)

        # Label embeddings & logit head
        self.label_embeddings = nn.Embedding(num_classes, emb_dim)
        nn.init.normal_(self.label_embeddings.weight, std=0.02)
        self.logit_bias = nn.Parameter(torch.zeros(num_classes))
        self.logit_scale = nn.Parameter(torch.tensor(10.0))  # learnable temperature

        # Dropout for MC-dropout averaging
        self.dropout = nn.Dropout(0.5)

    def train(self, mode: bool = True) -> "TextClassifier":
        """
        Switch between training and evaluation mode.

        Behaves like ``nn.Module.train`` except that, when eval-time MC dropout
        was requested through :meth:`enable_mc_dropout`, the head dropout layer
        is switched back on after the module enters eval mode. Without this,
        any ``.eval()`` call would silently cancel the request.
        """
        super().train(mode)
        if not mode and getattr(self, "_mc_dropout_at_eval", False):
            self.dropout.train()
        return self

    @torch.no_grad()
    def enable_mc_dropout(self, enabled: bool = True) -> None:
        """
        Enable or disable Monte Carlo dropout during evaluation.

        The choice is remembered, so it stays in effect across subsequent
        ``.eval()`` / ``.train()`` calls until this method is called again.

        Args:
            enabled (bool): If True, keeps dropout layers active at eval-time
                for stochastic forward passes. If False, disables dropout at eval-time.
        """
        self.mc_dropout_enabled = enabled
        self._mc_dropout_at_eval = enabled
        if enabled:
            self.dropout.train()
        else:
            self.dropout.eval()

    def _encode(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Encode input token IDs using the Transformer encoder and attention pooling.

        Args:
            x (torch.Tensor): Input tensor of token IDs, shape (batch_size, seq_len).
            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len),
                where True/1 indicates valid tokens and False/0 indicates padding.

        Returns:
            torch.Tensor: Pooled sequence representation of shape (batch_size, emb_dim).

        Raises:
            ValueError: If ``seq_len`` exceeds the number of learned positions.
        """
        seq_len = x.size(1)
        max_len = self.positional_encoding.size(1)
        if seq_len > max_len:
            # Fail early with a clear message instead of a shape-broadcast error below.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_length {max_len}"
            )

        mask_bool = attention_mask.bool()
        x = self.emb(x) * math.sqrt(self.emb.embedding_dim)
        x = x + self.positional_encoding[:, :seq_len, :]
        x = self.input_drop(x)
        # src_key_padding_mask expects True at positions to ignore, hence the inversion.
        x = self.transformer_encoder(x, src_key_padding_mask=~mask_bool)
        x = self.attention_pooling(x, mask_bool)
        return x

    def _trunk(self, pooled: torch.Tensor) -> torch.Tensor:
        """
        Pass pooled representation through the MLP trunk with residual connection.

        Args:
            pooled (torch.Tensor): Pooled sequence representation, shape (batch_size, emb_dim).

        Returns:
            torch.Tensor: Projected representation in label embedding space, shape (batch_size, emb_dim).
        """
        residual = pooled
        x = F.relu(self.fc1(pooled))
        x = F.relu(self.fc2(x))
        residual = self.res_proj(residual)
        x = self.layernorm_final(x + residual)
        x_proj = self.text_to_label_space(x)
        return x_proj

    def _logits_from_proj(self, x_proj: torch.Tensor) -> torch.Tensor:
        """
        Compute class logits from projected representation using normalized dot product,
        learnable temperature scaling, and bias.

        Args:
            x_proj (torch.Tensor): Projected representation, shape (batch_size, emb_dim).

        Returns:
            torch.Tensor: Class logits, shape (batch_size, num_classes).
        """
        # Both sides are L2-normalized, so the product is a cosine similarity in [-1, 1].
        label_w = F.normalize(self.label_embeddings.weight, dim=-1)
        text_z  = F.normalize(x_proj, dim=-1)
        logits  = self.logit_scale * (text_z @ label_w.T) + self.logit_bias
        return logits

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        debugging: bool = False
    ) -> torch.Tensor:
        """
        Forward pass through the classifier.

        Args:
            x (torch.Tensor): Input token IDs, shape (batch_size, seq_len).
            attention_mask (torch.Tensor): Attention mask, shape (batch_size, seq_len).
            debugging (bool, optional): If True, prints intermediate shapes for debugging.

        Returns:
            torch.Tensor: Class logits, shape (batch_size, num_classes).
        """
        if debugging:
            print(f"input_ids: {x.shape}, attention_mask: {attention_mask.shape}")

        pooled = self._encode(x, attention_mask)
        if debugging:
            print(f"after encoder+pool: {pooled.shape}")

        x_proj = self._trunk(pooled)
        if debugging:
            print(f"after trunk projection: {x_proj.shape}")

        # Averaging runs while training, or at eval time when the head dropout was kept active.
        use_mc = self.mc_dropout_enabled and (self.training or self.dropout.training)
        if use_mc and self.num_dropout_samples > 1:
            logits_list = []
            for _ in range(self.num_dropout_samples):
                dropped = self.dropout(x_proj)
                logits_list.append(self._logits_from_proj(dropped))
            logits = torch.stack(logits_list, dim=0).mean(dim=0)
        else:
            logits = self._logits_from_proj(x_proj)

        if debugging:
            print(f"logits: {logits.shape}")
        return logits

In [14]:
class FocalLoss(nn.Module):
    """
    Focal loss for multi-label classification on raw logits.

    Each element's binary cross-entropy is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t`` is the probability the model assigns to the target. The
    modulating factor is computed by interpolation, so soft targets (for
    example label-smoothed 0.9 / 0.1) are handled correctly; for hard 0/1
    targets it reduces to ``1 - p`` for positives and ``p`` for negatives.

    Args:
        gamma (float): Focusing exponent; 0 recovers plain BCE.
        pos_weight (Optional[torch.Tensor]): Per-class weight for positive
            targets, forwarded to ``binary_cross_entropy_with_logits``. Stored
            as a non-persistent buffer so it follows ``.to(device)`` without
            appearing in ``state_dict()``.
        reduction (str): 'mean', 'sum', or anything else for no reduction.
    """
    def __init__(self, gamma: float = 2.0, pos_weight: Optional[torch.Tensor] = None, reduction: str = 'mean') -> None:
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.register_buffer('pos_weight', pos_weight, persistent=False)
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the focal loss.

        Args:
            logits (torch.Tensor): Raw scores, same shape as ``targets``.
            targets (torch.Tensor): Float targets in [0, 1] (hard or soft).

        Returns:
            torch.Tensor: Scalar for 'mean'/'sum', otherwise the element-wise loss.
        """
        bce_loss = F.binary_cross_entropy_with_logits(logits, targets, pos_weight=self.pos_weight, reduction='none')

        probs = torch.sigmoid(logits)
        # Keep probabilities away from exactly 0/1 so the modulating factor never vanishes entirely.
        probs = torch.clamp(probs, min=1e-6, max=1 - 1e-6)

        # 1 - p_t: probability mass placed on the wrong side of the target.
        miss_prob = targets * (1 - probs) + (1 - targets) * probs
        focal_weight = miss_prob ** self.gamma
        loss = focal_weight * bce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

In [15]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [40]:
# Seed the global RNGs so weight initialisation, dropout and DataLoader
# shuffling are repeatable when the notebook is re-run from the top.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Model size
vocab_size = 30522
emb_dim = 64
hidden_dim = 128
num_heads = 4

# Label-vector width, read from the first training example.
num_classes = train_dataloader.dataset[0]['labels'].shape[-1]
max_seq_length = 512
num_attention_layers = 2
feedforward_dim = 128
num_dropout_samples = 5

# Optimisation
learning_rate = 1e-4
weight_decay = 1e-5
# NOTE: not consumed anywhere in this notebook; the DataLoaders above are built with batch_size=8.
batch_size = 16

In [41]:
# Build the classifier from the hyperparameters above and move it to the selected device.
model = TextClassifier(
    vocab_size=vocab_size,
    emb_dim=emb_dim,
    hidden_dim=hidden_dim,
    num_classes=num_classes,
    num_heads=num_heads,
    max_seq_length=max_seq_length,
    num_attention_layers=num_attention_layers,
    feedforward_dim=feedforward_dim,
    num_dropout_samples=num_dropout_samples,
    mc_dropout_enabled=True
).to(device)



In [48]:
# Plain multi-label BCE on logits; FocalLoss() is the drop-in alternative.
criterion = torch.nn.BCEWithLogitsLoss() # FocalLoss()
# Take the values from the hyperparameter cell rather than repeating the literals.
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [49]:
# Micro-averaged multi-label metrics.
# NOTE: train() constructs its own instances with the same settings, so these
# module-level objects are not used by the training loop below.
hamming_metric = MultilabelHammingDistance(num_labels=num_classes)
precision_metric = MultilabelPrecision(num_labels=num_classes, average='micro')
recall_metric = MultilabelRecall(num_labels=num_classes, average='micro')
f1_metric = MultilabelF1Score(num_labels=num_classes, average='micro')
exact_match_metric = MultilabelExactMatch(num_labels=num_classes)

In [50]:
def tune_per_label_thresholds(probs: np.ndarray, labels: np.ndarray, thresholds=np.linspace(0.0, 1.0, 101)):
    """
    Tune per-label thresholds for multi-label classification.

    For each class independently, every candidate threshold (default 0.0 to 1.0
    in 0.01 steps) is scored by the binary F1 of ``probs >= threshold`` against
    the labels, and the threshold with the highest F1 is kept. Ties go to the
    first (lowest-index) candidate; if no candidate reaches an F1 above 0
    (e.g. the class has no positives), the threshold defaults to 0.5.

    The sweep is evaluated with vectorized NumPy counts, one class at a time,
    instead of one metric call per (class, threshold) pair.

    Args:
        probs (np.ndarray): Predicted probabilities of shape (num_samples, num_classes).
        labels (np.ndarray): Ground truth binary labels of shape (num_samples, num_classes).
        thresholds (np.ndarray, optional): List of thresholds to evaluate for each label.

    Returns:
        np.ndarray: An array of best threshold values per label (shape: num_classes).

    Raises:
        ValueError: If ``probs`` and ``labels`` have different numbers of samples.
    """
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    candidates = np.asarray(thresholds, dtype=float).ravel()

    if probs.shape[0] != labels.shape[0]:
        raise ValueError(
            f"probs and labels have inconsistent numbers of samples: "
            f"{probs.shape[0]} vs {labels.shape[0]}"
        )

    num_classes = labels.shape[1]
    best_thresholds = np.full(num_classes, 0.5)
    if candidates.size == 0:
        return best_thresholds

    is_positive = labels == 1
    for i in range(num_classes):
        pos = is_positive[:, i]                                  # (N,)
        preds = probs[:, i][None, :] >= candidates[:, None]      # (T, N)

        tp = (preds & pos).sum(axis=1)
        fp = (preds & ~pos).sum(axis=1)
        fn = (~preds & pos).sum(axis=1)

        # F1 = 2TP / (2TP + FP + FN), defined as 0 when the denominator is 0.
        denom = 2 * tp + fp + fn
        f1 = np.divide(2.0 * tp, denom, out=np.zeros(candidates.shape[0]), where=denom > 0)

        winner = int(np.argmax(f1))      # first index of the maximum
        if f1[winner] > 0.0:
            best_thresholds[i] = candidates[winner]
    return best_thresholds

In [51]:
def apply_label_smoothing(targets: torch.Tensor, smoothing: float = 0.1):
    """
    Soften binary multi-label targets.

    Each target ``t`` becomes ``t * (1 - smoothing) + (1 - t) * smoothing``,
    i.e. 1 -> ``1 - smoothing`` and 0 -> ``smoothing``. The computation runs
    under ``torch.no_grad()``, so the result carries no gradient history.

    Args:
        targets (torch.Tensor): Target tensor of shape (batch_size, num_classes), with 0/1 labels.
        smoothing (float): Smoothing factor (e.g., 0.1).

    Returns:
        torch.Tensor: Smoothed targets (a new tensor; the input is not modified).
    """
    keep = 1.0 - smoothing
    with torch.no_grad():
        from_ones = targets * keep
        from_zeros = (1.0 - targets) * smoothing
        return from_ones + from_zeros

In [52]:
def train(
    num_epochs: int,
    model,
    train_dataloader,
    val_dataloader,
    criterion,
    optimizer,
    device,
    num_classes: int,
    patience: int = 2,
    save_path: str = "best_model.pt",
    label_smoothing: float = 0.1,
    train_threshold: float = 0.4,
):
    """
    Train a multi-label classifier with per-label threshold tuning, early stopping, and detailed metrics.

    Each epoch runs one pass over the training data (label-smoothed targets,
    gradient-norm clipping at 1.0) followed by a validation pass. Per-label
    thresholds are tuned on the validation set, and the validation metrics are
    computed on that same set with those thresholds, so they are optimistic.

    The epoch with the highest micro-averaged validation F1 is the "best"
    epoch: its weights, optimizer state and thresholds are written to
    ``save_path``. When training ends, the best epoch's weights are loaded
    back into ``model`` and its thresholds are returned, so the in-memory
    model, the returned thresholds and the checkpoint all agree.

    Args:
        num_epochs (int): Number of training epochs.
        model (torch.nn.Module): The PyTorch model to train.
        train_dataloader (DataLoader): DataLoader for the training data.
        val_dataloader (DataLoader): DataLoader for the validation data.
        criterion (Loss): Loss function (e.g., FocalLoss or BCEWithLogitsLoss).
        optimizer (Optimizer): Optimizer for updating model parameters.
        device (torch.device): The device to run the training on.
        num_classes (int): Number of output classes.
        patience (int): Number of epochs with no improvement after which training will be stopped.
        save_path (str): File path to save the best model.
        label_smoothing (float): Smoothing factor applied to the training targets only.
        train_threshold (float): Fixed probability cut-off used for the training-set metrics.

    Returns:
        np.ndarray: Per-class thresholds tuned on the validation set at the best
        epoch. If no epoch ever achieved a validation F1 above 0, the thresholds
        of the last epoch are returned (``None`` when ``num_epochs`` is 0).
    """

    # Metrics
    hamming_metric = MultilabelHammingDistance(num_labels=num_classes)
    precision_metric = MultilabelPrecision(num_labels=num_classes, average='micro')
    recall_metric = MultilabelRecall(num_labels=num_classes, average='micro')
    f1_metric = MultilabelF1Score(num_labels=num_classes, average='micro')
    exact_match_metric = MultilabelExactMatch(num_labels=num_classes)

    best_thresholds = None      # thresholds tuned at the best-F1 epoch
    epoch_thresholds = None     # thresholds tuned at the most recent epoch
    best_state = None           # CPU copy of the best epoch's weights
    best_val_f1 = 0.0
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch [{epoch + 1}/{num_epochs}]")

        # Training phase
        model.train()
        total_loss = 0
        all_train_labels = []
        all_train_preds = []

        loop = tqdm(train_dataloader, leave=True)
        for batch in loop:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device).float()

            outputs = model(input_ids, attention_mask)
            # Smoothed targets feed the loss only; metrics below use the hard labels.
            smoothed_labels = apply_label_smoothing(labels, smoothing=label_smoothing)
            loss = criterion(outputs, smoothed_labels)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            loop.set_description(f"Epoch [{epoch+1}/{num_epochs}]")
            loop.set_postfix(loss=loss.item())

            probs = torch.sigmoid(outputs).detach().cpu()
            preds = (probs >= train_threshold).int()

            all_train_labels.append(labels.cpu().int())
            all_train_preds.append(preds)

        all_train_preds_tensor = torch.cat(all_train_preds)
        all_train_labels_tensor = torch.cat(all_train_labels)

        train_hamming_acc = 1.0 - hamming_metric(all_train_preds_tensor, all_train_labels_tensor)
        avg_train_loss = total_loss / len(train_dataloader)
        train_precision = precision_metric(all_train_preds_tensor, all_train_labels_tensor)
        train_recall = recall_metric(all_train_preds_tensor, all_train_labels_tensor)
        train_f1 = f1_metric(all_train_preds_tensor, all_train_labels_tensor)
        train_subset_acc = exact_match_metric(all_train_preds_tensor, all_train_labels_tensor)

        # Validation phase
        model.eval()
        val_loss = 0
        all_val_labels = []
        all_val_probs = []

        with torch.no_grad():
            for batch in val_dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device).float()

                outputs = model(input_ids, attention_mask)
                # Validation loss uses the hard labels (no smoothing), so it is
                # not directly comparable with the training loss.
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                probs = torch.sigmoid(outputs).cpu()
                all_val_probs.append(probs)
                all_val_labels.append(labels.cpu().int())

        val_probs_tensor = torch.cat(all_val_probs)
        val_labels_tensor = torch.cat(all_val_labels)

        val_probs_np = val_probs_tensor.numpy()
        val_labels_np = val_labels_tensor.numpy()

        # Threshold tuning
        epoch_thresholds = tune_per_label_thresholds(val_probs_np, val_labels_np)
        val_preds_np = (val_probs_np >= epoch_thresholds).astype(int)

        val_hamming_acc = 1.0 - hamming_metric(torch.tensor(val_preds_np), val_labels_tensor)
        avg_val_loss = val_loss / len(val_dataloader)
        val_precision = precision_score(val_labels_np, val_preds_np, average='micro', zero_division=0)
        val_recall = recall_score(val_labels_np, val_preds_np, average='micro', zero_division=0)
        val_f1 = f1_score(val_labels_np, val_preds_np, average='micro', zero_division=0)
        val_subset_acc = (val_preds_np == val_labels_np).all(axis=1).mean()

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        print(f"Train Precision: {train_precision:.4f} | Val Precision: {val_precision:.4f}")
        print(f"Train Recall:    {train_recall:.4f}    | Val Recall:    {val_recall:.4f}")
        print(f"Train F1 Score:  {train_f1:.4f}        | Val F1 Score:  {val_f1:.4f}")
        print(f"Train Subset Acc: {train_subset_acc:.4f} | Val Subset Acc: {val_subset_acc:.4f}")
        print(f"Train Hamming Acc: {train_hamming_acc:.4f} | Val Hamming Acc: {val_hamming_acc:.4f}")

        # Early stopping check
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_thresholds = epoch_thresholds
            # Detached CPU copy, so later optimizer steps cannot alter the snapshot.
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            epochs_no_improve = 0
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_thresholds': best_thresholds,
            }, save_path)
            print("📦 Model saved.")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("🛑 Early stopping triggered.")
                break

    print(f"Best Validation F1: {best_val_f1:.4f}")

    if best_state is None:
        # No epoch improved on the initial F1 of 0.0: nothing to restore.
        return epoch_thresholds

    # Leave the model holding the weights that the returned thresholds were tuned for.
    model.load_state_dict(best_state)
    return best_thresholds


In [53]:
# Run training for up to 10 epochs, stopping early after 3 epochs without a
# validation-F1 improvement; the per-label thresholds returned by train() are kept.
best_thresholds = train(
    num_epochs=10,
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_classes=num_classes,
    patience=3,  # You can adjust this
    save_path="best_model.pt"  # Optional: customize the path
)



Epoch [1/10]


Epoch [1/10]: 100%|███████████████| 942/942 [01:02<00:00, 15.16it/s, loss=0.492]



Epoch 1 Summary:
Train Loss: 0.5209 | Val Loss: 0.3446
Train Precision: 0.4198 | Val Precision: 0.2390
Train Recall:    0.2170    | Val Recall:    0.6980
Train F1 Score:  0.2861        | Val F1 Score:  0.3561
Train Subset Acc: 0.0366 | Val Subset Acc: 0.0016
Train Hamming Acc: 0.8404 | Val Hamming Acc: 0.6281
📦 Model saved.

Epoch [2/10]


Epoch [2/10]: 100%|███████████████| 942/942 [01:05<00:00, 14.42it/s, loss=0.413]



Epoch 2 Summary:
Train Loss: 0.4703 | Val Loss: 0.2948
Train Precision: 0.5543 | Val Precision: 0.3843
Train Recall:    0.4821    | Val Recall:    0.6656
Train F1 Score:  0.5157        | Val F1 Score:  0.4873
Train Subset Acc: 0.0767 | Val Subset Acc: 0.0212
Train Hamming Acc: 0.8666 | Val Hamming Acc: 0.7936
📦 Model saved.

Epoch [3/10]


Epoch [3/10]: 100%|███████████████| 942/942 [01:02<00:00, 14.96it/s, loss=0.456]



Epoch 3 Summary:
Train Loss: 0.4391 | Val Loss: 0.2906
Train Precision: 0.6520 | Val Precision: 0.4054
Train Recall:    0.6580    | Val Recall:    0.6502
Train F1 Score:  0.6550        | Val F1 Score:  0.4994
Train Subset Acc: 0.1507 | Val Subset Acc: 0.0313
Train Hamming Acc: 0.8978 | Val Hamming Acc: 0.8079
📦 Model saved.

Epoch [4/10]


Epoch [4/10]: 100%|███████████████| 942/942 [01:05<00:00, 14.48it/s, loss=0.439]



Epoch 4 Summary:
Train Loss: 0.4198 | Val Loss: 0.2958
Train Precision: 0.7147 | Val Precision: 0.5193
Train Recall:    0.7392    | Val Recall:    0.6562
Train F1 Score:  0.7268        | Val F1 Score:  0.5798
Train Subset Acc: 0.2343 | Val Subset Acc: 0.0897
Train Hamming Acc: 0.9181 | Val Hamming Acc: 0.8598
📦 Model saved.

Epoch [5/10]


Epoch [5/10]: 100%|███████████████| 942/942 [01:07<00:00, 13.99it/s, loss=0.397]



Epoch 5 Summary:
Train Loss: 0.4061 | Val Loss: 0.3125
Train Precision: 0.7621 | Val Precision: 0.4359
Train Recall:    0.7909    | Val Recall:    0.6800
Train F1 Score:  0.7762        | Val F1 Score:  0.5313
Train Subset Acc: 0.3094 | Val Subset Acc: 0.0472
Train Hamming Acc: 0.9328 | Val Hamming Acc: 0.8232

Epoch [6/10]


Epoch [6/10]: 100%|███████████████| 942/942 [01:06<00:00, 14.15it/s, loss=0.427]



Epoch 6 Summary:
Train Loss: 0.3950 | Val Loss: 0.3256
Train Precision: 0.8035 | Val Precision: 0.5017
Train Recall:    0.8306    | Val Recall:    0.6612
Train F1 Score:  0.8169        | Val F1 Score:  0.5705
Train Subset Acc: 0.3956 | Val Subset Acc: 0.0775
Train Hamming Acc: 0.9451 | Val Hamming Acc: 0.8533

Epoch [7/10]


Epoch [7/10]: 100%|███████████████| 942/942 [01:10<00:00, 13.42it/s, loss=0.414]



Epoch 7 Summary:
Train Loss: 0.3850 | Val Loss: 0.3324
Train Precision: 0.8406 | Val Precision: 0.5187
Train Recall:    0.8630    | Val Recall:    0.6550
Train F1 Score:  0.8517        | Val F1 Score:  0.5789
Train Subset Acc: 0.4735 | Val Subset Acc: 0.0817
Train Hamming Acc: 0.9557 | Val Hamming Acc: 0.8596
🛑 Early stopping triggered.
Best Validation F1: 0.5798


In [None]:
def predict_genres(description, model, tokenizer, genre_to_index, threshold=0.5, device=None, max_length=128):
    """
    Predict genres for a given text description using a multi-label classifier.

    This function tokenizes the input description, passes it through the model,
    applies a sigmoid activation to get class probabilities, then thresholds them
    to determine predicted genres.

    The input is padded/truncated to ``max_length`` tokens. The default (128)
    matches the length the datasets in this notebook are built with; a longer
    value would feed the model positions it never saw during training.

    Side effects: the model is moved to ``device`` and switched to eval mode.

    Args:
        description (str): Input text to classify (e.g., plot summary or description).
        model (torch.nn.Module): Trained PyTorch multi-label classification model.
        tokenizer (PreTrainedTokenizer): Tokenizer compatible with the model (e.g., from HuggingFace).
        genre_to_index (dict): Mapping from genre labels to integer indices.
        threshold (float or array-like, optional): Threshold for turning probabilities into
            binary predictions. Either one scalar for all labels or one value per label
            (e.g. the array returned by ``train``). Default is 0.5.
        device (str or torch.device, optional): Device to run the model on. Defaults to 'cuda' if available.
        max_length (int, optional): Token length the description is padded/truncated to. Default is 128.

    Returns:
        List[str]: A list of predicted genre labels for the input description.

    Raises:
        ValueError: If ``threshold`` is an array whose length differs from the number of model outputs.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model.to(device)
    model.eval()

    index_to_genre = {v: k for k, v in genre_to_index.items()}

    # Tokenize the input description
    encoding = tokenizer(
        description,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        probs = torch.sigmoid(outputs).cpu().numpy()

    # Single example: work on the 1-D vector of per-label probabilities.
    label_probs = probs[0]
    cutoffs = np.asarray(threshold, dtype=float)
    if cutoffs.ndim > 0 and cutoffs.size != 1 and cutoffs.shape != label_probs.shape:
        raise ValueError(
            f"threshold has shape {cutoffs.shape}, expected a scalar or {label_probs.shape}"
        )

    predicted_labels = label_probs >= cutoffs.reshape(-1) if cutoffs.ndim > 0 else label_probs >= cutoffs
    predicted_genres = [index_to_genre[i] for i, is_on in enumerate(predicted_labels) if is_on]

    return predicted_genres

In [None]:
# Fresh model with the same architecture as the trained one, to receive saved weights.
model_test = TextClassifier(
    vocab_size=vocab_size,
    emb_dim=emb_dim,
    hidden_dim=hidden_dim,
    num_classes=num_classes,
    num_heads=num_heads,
    max_seq_length=max_seq_length,
    num_attention_layers=num_attention_layers,
    feedforward_dim=feedforward_dim,
    num_dropout_samples=num_dropout_samples
).to(device)

# NOTE(review): train() above writes its checkpoint to "best_model.pt"; confirm
# that "movie_checkpoint.pth" is the file intended here.
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
checkpoint = torch.load("movie_checkpoint.pth", map_location=device)
# Accept both a bare state_dict and the {'model_state_dict': ...} bundle that train() saves.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    checkpoint = checkpoint['model_state_dict']
model_test.load_state_dict(checkpoint)
model_test.eval()

In [None]:
# Hand-written plot summary used as a smoke test for predict_genres (mixes action and horror cues).
description = (
    "A relentless high-speed chase through shadowy, abandoned streets catapults the protagonist "
    "into a nightmarish world where unspeakable horrors await at every turn. Pursued not only by "
    "ruthless mercenaries but also by terrifying supernatural forces, each moment is a desperate "
    "fight for survival. As the line between reality and nightmare blurs, the hero must navigate "
    "crumbling buildings, escape grotesque monsters lurking in the darkness, and confront "
    "blood-soaked secrets that threaten to consume them. The pulse-pounding action is matched "
    "only by the creeping dread that no place is safe and no one can be trusted in this brutal "
    "race against time and terror."
)

In [None]:
# Classify the sample description on the CPU.
# NOTE(review): no `threshold` is passed, so the scalar default (0.5) is used rather
# than the per-label thresholds returned by train(); confirm this is intended.
predictions = predict_genres(description, model_test, tokenizer, genre_to_index, device='cpu')
print(predictions)