### 1. Data Loading and Preprocessing

This cell handles the initial setup: mounting Google Drive, loading the dataset, analyzing genre statistics, subsampling, and preparing the labels.

- **Drive Mount:** Mounts Google Drive to access the dataset file.
- **Data Loading:** Loads the movie data from a CSV file into a pandas DataFrame.
- **Genre Statistics and Subsampling:**
    - The `expanded-genres` column, containing comma-separated genres, is split into a list of genres for each movie.
    - Per-genre counts are printed, and genre pairs that co-occur at most `rare_threshold` times (rare pairs) or never (never-seen pairs) are collected.
    - Genres with fewer than `MINORITY_THRESHOLD` samples are treated as minority genres; every row containing one is kept, and 10% of the remaining rows are sampled.
- **Association Rule Mining:**
    - `TransactionEncoder` converts the genre lists into a one-hot encoded format suitable for association rule mining.
    - `fpgrowth` is used to find frequent itemsets of genres.
    - `association_rules` generates rules with a lift of at least 1 from these itemsets, which are then filtered to confidence > 0.25 and support > 0.001.
- **Multi-Label Classification Preprocessing:**
    - The `description` column is used directly as the text input for each movie.
    - The genre lists are used as the label sets.
    - `MultiLabelBinarizer` transforms these genre lists into a binary matrix, the standard format for multi-label classification.
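For intuition about what `fpgrowth` and `association_rules` compute, the sketch below scores single-genre rules by hand on a small hypothetical dataset. It is plain Python rather than mlxtend, only covers rules with one antecedent and one consequent (mlxtend also mines larger itemsets), and skips the minimum-support pruning that `fpgrowth` performs.

```python
from collections import Counter
from itertools import combinations

# Hypothetical toy data: each inner list is one movie's genres.
transactions = [
    ["Action", "Sci-Fi", "Thriller"],
    ["Action", "Thriller"],
    ["Comedy", "Romance"],
    ["Sci-Fi", "Thriller"],
    ["Comedy"],
]

n = len(transactions)
# How many movies contain each genre, and each unordered genre pair.
item_counts = Counter(g for t in transactions for g in set(t))
pair_counts = Counter(p for t in transactions for p in combinations(sorted(set(t)), 2))


def rule_metrics(antecedent, consequent):
    """Support, confidence and lift of the single-genre rule antecedent -> consequent."""
    both = pair_counts[tuple(sorted((antecedent, consequent)))]
    support = both / n                                  # P(A and C)
    confidence = both / item_counts[antecedent]         # P(C | A)
    lift = confidence / (item_counts[consequent] / n)   # P(C | A) / P(C)
    return support, confidence, lift


# Score both directions of every observed pair and keep rules with
# lift >= 1 and confidence > 0.25, mirroring the notebook's filters.
rules = []
for a, c in pair_counts:
    for ante, cons in ((a, c), (c, a)):
        support, confidence, lift = rule_metrics(ante, cons)
        if lift >= 1 and confidence > 0.25:
            rules.append((ante, cons, round(support, 2), round(confidence, 2), round(lift, 2)))

for rule in sorted(rules):
    print(rule)
```

Lift above 1 means the consequent appears more often alongside the antecedent than it does overall, which is why it is a reasonable first filter before thresholding on confidence.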

In [None]:
!pip install mlxtend
!pip install ltntorch

In [None]:
from google.colab import drive
import pandas as pd
from itertools import combinations
from collections import Counter
from mlxtend.preprocessing import TransactionEncoder
from mlxtend.frequent_patterns import fpgrowth, association_rules
from sklearn.preprocessing import MultiLabelBinarizer

# Mount Google Drive
drive.mount('/content/drive', force_remount=True)

# Load the dataset
df = pd.read_csv('/content/drive/My Drive/movie-genre-prediction/train.csv')

rare_threshold = 2        # pairs seen <= this count are considered rare

# --- Step 1: Parse genre labels ---
# Assumes genres are comma-separated strings
df["genre_list"] = df["expanded-genres"].fillna("").apply(lambda x: [genre.strip() for genre in x.split(", ") if genre.strip()])

# All unique genres
all_genres = sorted({g for sublist in df['genre_list'] for g in sublist})
# All possible pairs
all_pairs = list(combinations(all_genres, 2))

# --- Step 2: Count frequency of each genre ---
genre_counter = Counter()
for genres in df["genre_list"]:
    genre_counter.update(genres)

# Display number of samples per genre
print("Samples per genre:")
for genre, count in genre_counter.items():
    print(f"{genre:<15} {count}")

# Count observed pairs
observed_pairs = set()
pair_counts = Counter()
for genres in df['genre_list']:
    for pair in combinations(sorted(genres), 2):
        observed_pairs.add(pair)
        pair_counts[pair] += 1

# Find rarely-seen pairs
rare_pairs = [(pair, count) for pair, count in pair_counts.items() if count <= rare_threshold]

# Sort by frequency (ascending)
rare_pairs.sort(key=lambda x: x[1])

# Never-seen pairs
never_seen_pairs = [pair for pair in all_pairs if pair not in observed_pairs]

# --- Step 3: Identify minority genres ---
# Set threshold for minority genre (e.g., fewer than 200 samples)
MINORITY_THRESHOLD = 200
minority_genres = {genre for genre, count in genre_counter.items() if count < MINORITY_THRESHOLD}

print(f"\nMinority genres (< {MINORITY_THRESHOLD} samples): {sorted(minority_genres)}")

# --- Step 4: Undersample rows without minority genres ---
# Mark rows that contain any minority genre
df["contains_minority"] = df["genre_list"].apply(lambda genres: any(g in minority_genres for g in genres))

# Keep all minority rows
minority_df = df[df["contains_minority"]]

# Sample 10% of the remaining data
non_minority_df = df[~df["contains_minority"]].sample(frac=0.10, random_state=42)

print(f"\nOriginal dataset size: {len(df)}")
# Combine both
df = pd.concat([minority_df, non_minority_df]).reset_index(drop=True)

print(f"Minority rows kept: {len(minority_df)}")
print(f"Non-minority rows sampled: {len(non_minority_df)}")
print(f"Subsampled dataset size (before train/val/test split): {len(df)}")

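# Association rule mining (note: runs on all subsampled rows, including those later held out for validation/test)
# Reuse the cleaned genre lists so missing values and stray whitespace are handled consistently
transactions = df['genre_list'].tolist()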
te = TransactionEncoder()
te_ary = te.fit(transactions).transform(transactions)
df_encoded = pd.DataFrame(te_ary, columns=te.columns_)
frequent_itemsets = fpgrowth(df_encoded, min_support=0.01, use_colnames=True)
rules = association_rules(frequent_itemsets, metric="lift", min_threshold=1)
high_confidence_rules = rules[(rules['confidence'] > 0.25) & (rules['support'] > 0.001)]

# Data preprocessing for multi-label classification
# The 'description' column is used as-is as the model's text input
df['Output-Label'] = df['genre_list']
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(df['Output-Label'])

# Display results
display(df.head())
display(high_confidence_rules)
print("Descriptions:")
display(df['description'].head())
print("\nBinary Labels (y):")
display(y[:5])

print("Rarely-seen genre pairs:")
for pair, count in rare_pairs:
    print(f"{pair}: {count}")

print("Never-seen genre pairs:")
for pair in never_seen_pairs:
    print(pair)

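### 1.1. Tokenization, Train/Validation/Test Split, and Class Weights

This cell tokenizes the movie descriptions, splits the data, and computes per-genre class weights for the loss function.

- **Device Configuration:** Sets the device to "cuda" if a GPU is available, otherwise "cpu".
- **Tokenizer Loading:** Loads the "distilbert-base-uncased" tokenizer from the Hugging Face `transformers` library (the model itself is loaded in the next section).
- **Tokenization:** Descriptions are padded or truncated to 128 tokens, and attention masks are returned.
- **Data Split:** 20% of the rows are held out for testing, and 12.5% of the remaining 80% (10% of the total) is used for validation, leaving 70% for training.
- **Class Weights:** For each genre, `pos_weight` is the ratio of negative to positive training examples, clipped to the range [0.1, 10], so that rare genres contribute more to the loss.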
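The class-weight computation and split fractions can be checked by hand. This plain-Python sketch uses a hypothetical 6-movie label matrix and mirrors the NumPy code in this cell (negatives divided by positives plus a small epsilon, clipped to [0.1, 10]); it is an illustration, not a replacement for that code.

```python
# Hypothetical multi-hot training labels: 6 movies x 4 genres.
# The last genre never appears, which shows why the weights are clipped.
y_train = [
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [1, 0, 0, 0],
    [1, 0, 1, 0],
]

EPSILON = 1e-5               # avoids division by zero for unseen genres
CLIP_MIN, CLIP_MAX = 0.1, 10.0

n_rows = len(y_train)
positive_counts = [sum(column) for column in zip(*y_train)]

pos_weights = []
for positives in positive_counts:
    negatives = n_rows - positives
    weight = negatives / (positives + EPSILON)                # negatives per positive
    pos_weights.append(min(max(weight, CLIP_MIN), CLIP_MAX))  # clip to [0.1, 10]

print([round(w, 3) for w in pos_weights])  # → [0.2, 2.0, 5.0, 10.0]

# Split fractions: 20% test, then 12.5% of the remaining 80% for validation.
test_frac = 0.2
val_frac = (1 - test_frac) * 0.125
train_frac = 1 - test_frac - val_frac
print(round(train_frac, 2), round(val_frac, 2), test_frac)  # → 0.7 0.1 0.2
```

Without the upper clip, a genre with no training positives would get a weight of about 600,000 here, which would dominate the loss.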

In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
import numpy as np

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Tokenization (missing descriptions become empty strings)
X_tok = tokenizer(df['description'].fillna('').tolist(), padding='max_length', truncation=True,
                  max_length=128, return_tensors='pt', return_attention_mask=True)

input_ids, attention_mask = X_tok['input_ids'], X_tok['attention_mask']

# Split train+val/test
X_train_val_ids, X_test_ids, y_train_val, y_test, X_train_val_mask, X_test_mask = train_test_split(
    input_ids, y, attention_mask, test_size=0.2, random_state=42
)

# Further split train+val into train and val (0.125 of the remaining 80% = 10% of the full data)
X_train_ids, X_val_ids, y_train, y_val, X_train_mask, X_val_mask = train_test_split(
    X_train_val_ids, y_train_val, X_train_val_mask, test_size=0.125, random_state=42
)

# Per-genre pos_weight = negatives / positives on the training split, clipped to [0.1, 10]
positive_counts = np.sum(y_train, axis=0)
total_counts = y_train.shape[0]
negative_counts = total_counts - positive_counts
epsilon = 1e-5
pos_weights_np = negative_counts / (positive_counts + epsilon)
pos_weights_np = np.clip(pos_weights_np, 0.1, 10.0)
pos_weights = torch.tensor(pos_weights_np, dtype=torch.float32).to(device)

print("initial setup completed...")

### 2. Baseline Model Training

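This cell defines and trains a baseline multi-label classification model on top of a pre-trained DistilBERT encoder.

- **Model Definition:**
    - A `BaselineMovieClassifier` class is defined, consisting of the DistilBERT encoder, dropout, and a linear classifier layer applied to the `[CLS]` token embedding.
    - The model takes tokenized input and produces one logit per genre. The encoder is not frozen, so it is fine-tuned together with the classifier.
- **Training Setup:**
    - The loss is `BCEWithLogitsLoss` with the per-genre `pos_weight` computed in the previous cell.
    - The optimizer is Adam (learning rate 3e-5, weight decay 0.01) with a linear learning-rate schedule that warms up over the first 10% of steps; gradients are clipped to a norm of 1.0.
- **Data Preparation:**
    - The tokenized training and validation splits from the previous cell are wrapped in `TensorDataset`s.
    - DataLoaders handle batching (batch size 32) and shuffle the training data.
- **Training Loop:**
    - The model is trained for up to 10 epochs.
    - In each epoch, the model processes batches of data, calculates the loss, and updates its weights, then reports the validation loss and micro-averaged F1 (0.5 threshold).
    - Early stopping halts training after 4 epochs without improvement in validation micro F1, and the best-scoring weights are restored.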
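To make the effect of `pos_weight` concrete, here is a plain-Python sketch of the per-element formula PyTorch documents for `BCEWithLogitsLoss` with mean reduction. It is for illustration only; the notebook uses the PyTorch implementation.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(logits, targets, pos_weight):
    """Mean-reduced weighted BCE over a batch of multi-label rows.

    Per element: -[p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))],
    where p is the pos_weight of that label.
    """
    total, count = 0.0, 0
    for row_logits, row_targets in zip(logits, targets):
        for x, y, p in zip(row_logits, row_targets, pos_weight):
            log_sigmoid = -softplus(-x)           # log(sigmoid(x))
            log_one_minus_sigmoid = -softplus(x)  # log(1 - sigmoid(x))
            total += -(p * y * log_sigmoid + (1 - y) * log_one_minus_sigmoid)
            count += 1
    return total / count


# One movie, two genres, both logits at 0 (probability 0.5); only the first genre is true.
unweighted = bce_with_logits([[0.0, 0.0]], [[1, 0]], pos_weight=[1.0, 1.0])
weighted = bce_with_logits([[0.0, 0.0]], [[1, 0]], pos_weight=[2.0, 2.0])
print(round(unweighted, 4), round(weighted, 4))  # → 0.6931 1.0397
```

Only the positive (y = 1) term is scaled, so a weight above 1 makes missed positives costlier, which nudges the model toward predicting rare genres more often.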

In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import torch.optim as optim
import numpy as np

transformer = AutoModel.from_pretrained("distilbert-base-uncased").to(device)

# Classifier model
class BaselineMovieClassifier(nn.Module):
    def __init__(self, transformer_model, num_labels, dropout=0.3):
        super(BaselineMovieClassifier, self).__init__()
        self.transformer = transformer_model  # transformer parameters are also updated unless explicitly frozen
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(transformer_model.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token
        x = self.dropout(embeddings)
        logits = self.classifier(x)
        return logits

# Prepare data and labels (assumes mlb and df already defined)
num_genres = len(mlb.classes_)
baseline_model = BaselineMovieClassifier(transformer, num_genres).to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

# Hyperparams
epochs = 10
batch_size = 32
optimizer = optim.Adam(baseline_model.parameters(), lr=3e-5, weight_decay=0.01)
total_steps = (len(X_train_ids) // batch_size + 1) * epochs
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)
max_grad_norm = 1.0

# DataLoaders
train_dataset = torch.utils.data.TensorDataset(
    X_train_ids.to(device),
    X_train_mask.to(device),
    torch.tensor(y_train, dtype=torch.float32).to(device)
)
val_dataset = torch.utils.data.TensorDataset(
    X_val_ids.to(device),
    X_val_mask.to(device),
    torch.tensor(y_val, dtype=torch.float32).to(device)
)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

def evaluate(model, loader):
    model.eval()
    losses = []
    preds = []
    targets = []
    with torch.no_grad():
        for batch_input_ids, batch_attention_mask, batch_y_true in loader:
            logits = model(batch_input_ids, attention_mask=batch_attention_mask)
            loss = criterion(logits, batch_y_true)
            losses.append(loss.item())

            y_pred = torch.sigmoid(logits).cpu().numpy()
            preds.append(y_pred)
            targets.append(batch_y_true.cpu().numpy())

    avg_loss = np.mean(losses)
    preds = np.vstack(preds)
    targets = np.vstack(targets)
    # Binarize preds with 0.5 threshold for metric
    preds_binary = (preds > 0.5).astype(int)

    f1 = f1_score(targets, preds_binary, average='micro', zero_division=0)
    return avg_loss, f1

# early-stopping
best_val_f1 = 0.0
patience = 4  # Number of epochs to wait before stopping
epochs_without_improvement = 0
best_model_state = None  # To store best model

# Training loop with validation
for epoch in range(epochs):
    baseline_model.train()
    total_loss = 0
    for batch_input_ids, batch_attention_mask, batch_y_true in train_loader:
        optimizer.zero_grad()

        logits = baseline_model(batch_input_ids, attention_mask=batch_attention_mask)
        loss = criterion(logits, batch_y_true)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(baseline_model.parameters(), max_grad_norm)

        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

    train_loss = total_loss / len(train_loader)
    val_loss, val_f1 = evaluate(baseline_model, val_loader)
    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Micro F1: {val_f1:.4f}")

    # --- Early Stopping Logic ---
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        epochs_without_improvement = 0
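        # Copy the weights: state_dict() tensors share storage with the live parameters
        best_model_state = {k: v.detach().clone() for k, v in baseline_model.state_dict().items()}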
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"\nEarly stopping triggered. Best Val F1: {best_val_f1:.4f}")
            break

if best_model_state:
    baseline_model.load_state_dict(best_model_state)


### 3. Baseline Model Evaluation

This cell evaluates the performance of the trained baseline model on the test set.

- **Evaluation Mode:** The model is set to evaluation mode using `baseline_model.eval()`.
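- **Prediction:** The model makes predictions on the test data; sigmoid probabilities are binarized with a threshold of 0.6 (stricter than the 0.5 used during validation).
- **Classification Report:** A classification report is printed, showing precision, recall, and F1-score for each genre, followed by the average number of predicted labels per sample.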
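Because the test threshold (0.6) differs from the validation threshold (0.5), it helps to see how the cutoff alone changes the metrics. A plain-Python sketch on hypothetical probabilities, computing micro-averaged F1 as 2*TP / (2*TP + FP + FN) pooled over all (movie, genre) cells, which is how scikit-learn defines it:

```python
def binarize(probs, threshold):
    """Turn per-genre probabilities into 0/1 predictions (strictly above threshold)."""
    return [[1 if p > threshold else 0 for p in row] for row in probs]


def micro_f1(y_true, y_pred):
    """Micro-averaged F1: pool TP/FP/FN over every (movie, genre) cell."""
    tp = fp = fn = 0
    for true_row, pred_row in zip(y_true, y_pred):
        for t, p in zip(true_row, pred_row):
            if t == 1 and p == 1:
                tp += 1
            elif t == 0 and p == 1:
                fp += 1
            elif t == 1 and p == 0:
                fn += 1
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0  # 0.0 mirrors zero_division=0


# Hypothetical sigmoid outputs for 3 movies x 3 genres, and their true labels.
probs = [[0.90, 0.55, 0.20], [0.65, 0.30, 0.58], [0.40, 0.70, 0.52]]
y_true = [[1, 1, 0], [1, 0, 0], [0, 1, 1]]

for threshold in (0.5, 0.6):
    y_pred = binarize(probs, threshold)
    avg_labels = sum(map(sum, y_pred)) / len(y_pred)
    print(threshold, round(micro_f1(y_true, y_pred), 3), avg_labels)
# → 0.5 0.909 2.0
# → 0.6 0.75 1.0
```

Raising the threshold removes the one false positive but also two true positives, so on this toy data the stricter cutoff lowers F1 and halves the number of predicted labels.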

In [None]:
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader, TensorDataset

# Create test dataset with attention mask
test_dataset = TensorDataset(
    X_test_ids.to(device),
    X_test_mask.to(device),
    torch.tensor(y_test, dtype=torch.float32).to(device)
)

test_loader = DataLoader(test_dataset, batch_size=32)

# Evaluation
baseline_model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for batch_input_ids, batch_attention_mask, batch_y_true in test_loader:
        # Pass attention mask to model
        logits = baseline_model(batch_input_ids, attention_mask=batch_attention_mask)
        probs = torch.sigmoid(logits)
        preds = (probs > 0.6).cpu().numpy()

        all_preds.append(preds)
        all_labels.append(batch_y_true.cpu().numpy())

# Concatenate predictions and labels
import numpy as np
y_pred_binary = np.vstack(all_preds)
y_true = np.vstack(all_labels)

# Generate classification report
print(classification_report(y_true, y_pred_binary, target_names=mlb.classes_, zero_division=0))
print("Avg predicted labels per sample:", y_pred_binary.sum(axis=1).mean())


### 4. Baseline Model Prediction on Evaluation Set

This cell uses the trained baseline model to make predictions on a separate evaluation dataset.

- **Load Evaluation Data:** Loads the evaluation dataset from `test.csv`; it must include the `expanded-genres` column for the report below.
- **Preprocess Evaluation Data:** The descriptions from the evaluation data are tokenized with the same settings as the training data.
- **Make Predictions:** The model predicts genres for the evaluation data in batches, using a 0.5 probability threshold.
- **Store Predictions:** The predicted genres are added as a new column to the evaluation DataFrame.
- **Classification Report:** A classification report is generated to evaluate the model's performance on this new data.
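The evaluation labels must be encoded against the same class list that `mlb` learned during training, and the predictions decoded back to genre names. The hypothetical helpers below (`parse_genres`, `encode`, `decode`) sketch that round trip in plain Python; they approximate what `MultiLabelBinarizer.transform` and `inverse_transform` do (scikit-learn additionally warns when it drops unknown labels) and are not part of the notebook.

```python
def parse_genres(raw):
    """Split a comma-separated genre string; missing values (None/NaN) become []."""
    if not isinstance(raw, str):
        return []
    return [g.strip() for g in raw.split(", ") if g.strip()]


def encode(label_lists, classes):
    """Multi-hot encode label lists against a fixed, ordered class list."""
    index = {c: i for i, c in enumerate(classes)}
    rows = []
    for labels in label_lists:
        row = [0] * len(classes)
        for genre in labels:
            if genre in index:  # genres unseen during training are dropped
                row[index[genre]] = 1
        rows.append(row)
    return rows


def decode(binary_rows, classes):
    """Map each 0/1 row back to a tuple of genre names."""
    return [tuple(c for c, bit in zip(classes, row) if bit) for row in binary_rows]


classes = ["Action", "Comedy", "Drama"]          # hypothetical training classes
raw_labels = ["Action, Drama", None, "Comedy, Western"]

encoded = encode([parse_genres(r) for r in raw_labels], classes)
print(encoded)                   # → [[1, 0, 1], [0, 0, 0], [0, 1, 0]]
print(decode(encoded, classes))  # → [('Action', 'Drama'), (), ('Comedy',)]
```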

In [None]:
# Load the evaluation data
eval_df = pd.read_csv('/content/drive/My Drive/movie-genre-prediction/test.csv')

# Preprocess the evaluation data (missing descriptions become empty strings)
eval_descriptions = eval_df['description'].fillna('').tolist()
eval_X = tokenizer(
    text=eval_descriptions,
    add_special_tokens=True,
    max_length=128,
    truncation=True,
    padding='max_length',
    return_tensors='pt',
    return_token_type_ids=False,
    return_attention_mask=True,
    verbose=True)

# Keep tensors on the CPU here; each batch is moved to the device inside the loop
eval_input_ids = eval_X['input_ids']
eval_attention_mask = eval_X['attention_mask']

# Create dataset and loader
eval_dataset = TensorDataset(eval_input_ids, eval_attention_mask)
eval_loader = DataLoader(eval_dataset, batch_size=32)  # use smaller batch_size if needed

# Predict in batches
baseline_model.eval()
all_preds = []

with torch.no_grad():
    for batch_ids, batch_mask in eval_loader:
        batch_ids = batch_ids.to(device)
        batch_mask = batch_mask.to(device)

        logits = baseline_model(batch_ids, attention_mask=batch_mask)
        probs = torch.sigmoid(logits)
        batch_preds = (probs > 0.5).cpu().numpy()
        all_preds.append(batch_preds)

# Final predictions
import numpy as np
predicted_labels_binary = np.vstack(all_preds)

# Convert binary predictions to genre labels
predicted_labels = mlb.inverse_transform(predicted_labels_binary)

# Attach predictions to dataframe
eval_df['predicted_genres_baseline'] = predicted_labels

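# Get true labels from CSV, parsed the same way as the training labels
y_true_eval = mlb.transform(
    eval_df['expanded-genres'].fillna('').apply(lambda x: [g.strip() for g in x.split(', ') if g.strip()])
)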

# Print classification report
print("Classification Report for baseline model on the evaluation set:")
print(classification_report(y_true_eval, predicted_labels_binary, target_names=mlb.classes_, zero_division=0))

# Optional: View predictions
display(eval_df.head())

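### 5. LTN Model Definition and Training

This cell defines and trains the LTN-enhanced movie classifier.

- **Logical Knowledge:**
    - High-confidence association rules are converted into pairwise implication axioms (the antecedent genre implies the consequent genre).
    - Never-seen genre pairs are encoded as mutual-exclusion constraints, and rare pairs as a soft penalty on co-occurrence.
- **Model Definition:**
    - An `LTNMultiLabelClassifier` class is defined which, like the baseline, uses a DistilBERT encoder and a linear head on the `[CLS]` embedding to produce one logit per genre.
    - The LTN component lives in the loss. `compute_loss` measures the fuzzy satisfaction of four groups of axioms (agreement with the ground truth, implication rules, never-seen pairs, rare pairs) from the predicted probabilities, using LTNtorch's Łukasiewicz implication and equivalence and products for pair co-occurrence. Each group is aggregated with a p-mean, the groups are combined with the Łukasiewicz AND, and the total loss is `0.95 * BCE + 0.05 * (1 - satisfaction)`.
- **Model Instantiation and Training:** A fresh DistilBERT encoder is loaded, the LTN model is instantiated, and it is trained with the same optimizer, schedule, and early-stopping settings as the baseline.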
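The fuzzy semantics behind `compute_loss` can be worked through by hand. Below is a plain-Python sketch of the standard Łukasiewicz connectives, written from their textbook definitions rather than taken from LTNtorch (so it omits any numerical-stability adjustments the library may apply), evaluated on a hypothetical prediction. The last two lines show that chaining the Łukasiewicz AND over four satisfaction terms reduces to max(0, sum - 3).

```python
from functools import reduce


def and_luk(a, b):
    """Łukasiewicz t-norm (fuzzy AND)."""
    return max(0.0, a + b - 1.0)


def implies_luk(a, b):
    """Łukasiewicz implication: fully true unless a exceeds b."""
    return min(1.0, 1.0 - a + b)


def equiv(a, b):
    """Equivalence as (a -> b) AND (b -> a); with Łukasiewicz operators this is 1 - |a - b|."""
    return and_luk(implies_luk(a, b), implies_luk(b, a))


# Hypothetical prediction for one movie: P(Sci-Fi) = 0.8, P(Thriller) = 0.3.
rule_sat = implies_luk(0.8, 0.3)   # rule Sci-Fi -> Thriller
gt_sat = equiv(0.8, 1.0)           # agreement with the true label Sci-Fi = 1
print(round(rule_sat, 2), round(gt_sat, 2))  # → 0.5 0.8

# Chaining the AND over four satisfaction terms equals max(0, sum - 3).
print(round(reduce(and_luk, [0.9, 0.95, 0.9, 0.8]), 2))  # → 0.55
print(reduce(and_luk, [0.7, 0.8, 0.9, 0.5]))              # → 0.0
```

The second chain shows the practical consequence: once the four terms sum to less than 3, the combined satisfaction is clamped at 0 and the logic term stops contributing a gradient.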

In [None]:
import ltn

# List the fuzzy operators available in LTNtorch
print(dir(ltn.fuzzy_ops))

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import ltn  # ltntorch for fuzzy logic
from ltn.fuzzy_ops import Equiv, AndLuk, ImpliesLuk, AggregPMean

# Label setup
ALL_LABELS = list(mlb.classes_)
NUM_LABELS = len(ALL_LABELS)
LABEL_TO_IDX = {label: i for i, label in enumerate(ALL_LABELS)}

# Fuzzy logic operators
and_op = AndLuk()
imp_op = ImpliesLuk()
equiv_op = Equiv(and_op=and_op, implies_op=imp_op)
aggregator = AggregPMean(p=2)

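# Build implication rules: a rule {A, B} -> {C} is split into pairwise implications
# A -> C and B -> C, which is a stronger constraint than the original rule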
implication_pairs = []
for _, row in high_confidence_rules.iterrows():
    for a in list(row['antecedents']):
        for c in list(row['consequents']):
            if a in LABEL_TO_IDX and c in LABEL_TO_IDX:
                implication_pairs.append((LABEL_TO_IDX[a], LABEL_TO_IDX[c]))
implication_pairs = list(set(implication_pairs))
print(f"Loaded {len(implication_pairs)} implication rules from assoc rules.")

# Never-seen pairs: genre pairs that never co-occur in train.csv
never_seen_rules = []
for (g1, g2) in never_seen_pairs:
    if g1 in LABEL_TO_IDX and g2 in LABEL_TO_IDX:
        never_seen_rules.append((LABEL_TO_IDX[g1], LABEL_TO_IDX[g2]))
print(f"Loaded {len(never_seen_rules)} never-seen rules from co-occurrence counts.")

# Rare pairs: pairs with co-occurrence frequency at or below rare_threshold
rare_rules = []
for ((g1, g2), count) in rare_pairs:
    if g1 in LABEL_TO_IDX and g2 in LABEL_TO_IDX:
        rare_rules.append((LABEL_TO_IDX[g1], LABEL_TO_IDX[g2]))
print(f"Loaded {len(rare_rules)} rare-pair rules from co-occurrence counts.")

# LTN model definition
class LTNMultiLabelClassifier(nn.Module):
    def __init__(self, transformer_model, num_labels, implication_pairs, never_seen_pairs, rare_pairs, pos_weights=None):
        super().__init__()
        self.transformer = transformer_model  # transformer parameters are also updated unless explicitly frozen
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(transformer_model.config.hidden_size, num_labels)
        self.implication_pairs = implication_pairs
        self.never_seen_pairs = never_seen_pairs
        self.rare_pairs = rare_pairs
        self.pos_weights = pos_weights
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=self.pos_weights)

    def forward(self, input_ids, attention_mask):
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = self.dropout(outputs.last_hidden_state[:, 0, :])
        logits = self.fc(embeddings)
        return logits  # raw logits, no sigmoid

    def compute_loss(self, logits, true_labels):
        pred_probs = torch.sigmoid(logits)
        bce_loss = self.loss_fn(logits, true_labels)

        # Equivalence (ground-truth agreement)
        equiv_values = equiv_op(pred_probs, true_labels)
        sat_gt = aggregator(aggregator(equiv_values))

        # Implication rules (existing)
        axiom_values = [imp_op(pred_probs[:, a], pred_probs[:, c]) for a, c in self.implication_pairs]
        if axiom_values:
            sat_axiom = aggregator(aggregator(torch.stack(axiom_values, dim=1)))
        else:
            sat_axiom = torch.tensor(1.0, device=logits.device)

        # --- Never-seen pairs: mutual exclusion ---
        if hasattr(self, 'never_seen_pairs') and self.never_seen_pairs:
            never_seen_values = []
            for i, j in self.never_seen_pairs:
                # Mutual exclusion: not (genre_i and genre_j)
                co_occur = pred_probs[:, i] * pred_probs[:, j]  # degree of co-occurrence
                never_seen_val = 1 - co_occur  # high when no overlap
                never_seen_values.append(never_seen_val)
            never_seen_tensor = torch.stack(never_seen_values, dim=1)
            sat_never = aggregator(aggregator(never_seen_tensor))
        else:
            sat_never = torch.tensor(1.0, device=logits.device)

        # --- Rare pairs: softly penalize co-occurrence ---
        if hasattr(self, 'rare_pairs') and self.rare_pairs:
            rare_values = []
            for i, j in self.rare_pairs:
                co_occur = pred_probs[:, i] * pred_probs[:, j]
                # Collect raw co-occurrence degrees; satisfaction is 1 minus their p-mean (below)
                rare_values.append(co_occur)
            rare_tensor = torch.stack(rare_values, dim=1)
            sat_rare = 1 - aggregator(aggregator(rare_tensor))  # low co-occurrence → high satisfaction
        else:
            sat_rare = torch.tensor(1.0, device=logits.device)

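        # Combine all logical satisfaction terms with the Łukasiewicz AND. Chained over four
        # terms it equals max(0, sat_gt + sat_axiom + sat_never + sat_rare - 3), so the logic
        # term gives no gradient whenever the four terms sum to less than 3.
        sat_logic = and_op(and_op(and_op(sat_gt, sat_axiom), sat_never), sat_rare)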
        logic_loss = 1 - sat_logic

        # You may want to tune these weights
        total_loss = 0.95 * bce_loss + 0.05 * logic_loss
        return total_loss, sat_gt.item(), sat_axiom.item(), sat_never.item(), sat_rare.item()


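# Load a fresh pre-trained encoder so the LTN model does not start from the fine-tuned baseline weights
transformer = AutoModel.from_pretrained("distilbert-base-uncased").to(device)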

# DataLoaders
train_dataset = torch.utils.data.TensorDataset(
    X_train_ids.to(device), X_train_mask.to(device), torch.tensor(y_train, dtype=torch.float32).to(device))
val_dataset = torch.utils.data.TensorDataset(
    X_val_ids.to(device), X_val_mask.to(device), torch.tensor(y_val, dtype=torch.float32).to(device))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32)

model = LTNMultiLabelClassifier(transformer, NUM_LABELS, implication_pairs, never_seen_rules, rare_rules, pos_weights=pos_weights).to(device)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=3e-5, weight_decay=0.01)
total_steps = len(train_loader) * 10
scheduler = get_linear_schedule_with_warmup(optimizer, int(0.1 * total_steps), total_steps)

# Training with early stopping
best_val_f1 = 0.0
patience = 4
patience_counter = 0
best_model_state = None

for epoch in range(10):
    model.train()
    total_loss = 0
    for batch_ids, batch_mask, batch_labels in train_loader:
        optimizer.zero_grad()
        logits = model(batch_ids, batch_mask)
        loss, sat_gt, sat_axiom, sat_never, sat_rare = model.compute_loss(logits, batch_labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)

    # Validation
    model.eval()
    val_preds, val_labels = [], []
    val_loss = 0
    with torch.no_grad():
        for batch_ids, batch_mask, batch_labels in val_loader:
            logits = model(batch_ids, batch_mask)
            loss, _, _, _, _ = model.compute_loss(logits, batch_labels)
            val_loss += loss.item()
            probs = torch.sigmoid(logits)  # apply sigmoid at eval time
            val_preds.append(probs.cpu().numpy())
            val_labels.append(batch_labels.cpu().numpy())

    val_loss /= len(val_loader)
    y_pred = np.vstack(val_preds)
    y_true = np.vstack(val_labels)
    y_pred_binary = (y_pred > 0.5).astype(int)
    val_f1_micro = f1_score(y_true, y_pred_binary, average='micro', zero_division=0)

    # Satisfaction values come from the last training batch of the epoch
    print(f"Epoch {epoch+1}/10 - Train Loss: {avg_train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Micro F1: {val_f1_micro:.4f} "
          f"| GT Sat: {sat_gt:.4f} | Axiom Sat: {sat_axiom:.4f} | Never Sat: {sat_never:.4f} | Rare Sat: {sat_rare:.4f}")

    if val_f1_micro > best_val_f1:
        best_val_f1 = val_f1_micro
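        # Copy the weights: state_dict() tensors share storage with the live parameters
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}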
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

if best_model_state:
    model.load_state_dict(best_model_state)


### 6. LTN Model Evaluation

This cell evaluates the performance of the trained LTN model on the test set.

- **Evaluation Mode:** The model is set to evaluation mode using `model.eval()`.
- **Prediction:** The model makes predictions on the test data.
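- **Classification Report:** A classification report is printed, showing precision, recall, and F1-score for each genre.
- **Axiom Satisfaction:** The truth values of the implication rules are aggregated with a p-mean, and the average, minimum, and maximum satisfaction on the test set are reported, along with the average number of predicted labels per sample.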
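Reporting mean, minimum, and maximum satisfaction is most informative with one score per test example, which means aggregating over the rule axis only. The plain-Python sketch below (hypothetical probabilities and rules; `implies_luk` and `pmean` are illustrative re-implementations, not LTNtorch calls) shows that per-example aggregation, and also that nesting p-means with the same exponent over equal-length rows gives the same value as one p-mean over all truth values.

```python
import math


def implies_luk(a, b):
    """Łukasiewicz implication."""
    return min(1.0, 1.0 - a + b)


def pmean(values, p=2):
    """Generalized (power) mean with exponent p."""
    return (sum(v ** p for v in values) / len(values)) ** (1.0 / p)


# Hypothetical sigmoid outputs for 3 test movies over 3 genres.
probs = [
    [0.9, 0.8, 0.1],
    [0.7, 0.2, 0.6],
    [0.1, 0.4, 0.9],
]
# Hypothetical implication rules as (antecedent index, consequent index).
rules = [(0, 1), (2, 1)]

# Truth value of every rule for every movie (rows = movies, columns = rules).
truth = [[implies_luk(row[a], row[c]) for a, c in rules] for row in probs]

# Aggregate over rules only: one satisfaction score per movie.
per_example = [pmean(row) for row in truth]
print([round(s, 3) for s in per_example])                        # → [0.951, 0.552, 0.791]
print(round(min(per_example), 3), round(max(per_example), 3))    # → 0.552 0.951

# With equal-length rows, the p-mean of per-movie p-means equals the
# flat p-mean over all values, so nesting the aggregator changes nothing.
flat = pmean([v for row in truth for v in row])
print(math.isclose(pmean(per_example), flat))  # → True
```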

In [None]:
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader, TensorDataset
from ltn.fuzzy_ops import ImpliesLuk, AggregPMean
import torch
import numpy as np

# Prepare axiom operators
imp_op = ImpliesLuk()
aggregator = AggregPMean(p=2)

# Test DataLoader with attention_mask
test_dataset = TensorDataset(
    X_test_ids.to(device),
    X_test_mask.to(device),
    torch.tensor(y_test, dtype=torch.float32).to(device)
)
test_loader = DataLoader(test_dataset, batch_size=32)

# Switch to eval mode
model.eval()
all_preds = []
all_labels = []
all_axioms = []

with torch.no_grad():
    for batch_input_ids, batch_attention_mask, batch_y_true in test_loader:
        # Forward pass
        logits = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
        probs = torch.sigmoid(logits)

        # Binary predictions
        preds = (probs > 0.5).cpu().numpy()
        all_preds.append(preds)
        all_labels.append(batch_y_true.cpu().numpy())

        # Axiom satisfaction
        if hasattr(model, "implication_pairs"):
            axiom_vals = []
            for a_idx, c_idx in model.implication_pairs:
                premise = probs[:, a_idx]
                conclusion = probs[:, c_idx]
                val = imp_op(premise, conclusion)
                axiom_vals.append(val)
            if axiom_vals:
                stacked_axioms = torch.stack(axiom_vals, dim=1)
                # Aggregate over rules only (dim=1) so there is one satisfaction value per example
                sat_per_example = aggregator(stacked_axioms, dim=1)
                all_axioms.append(sat_per_example.cpu().numpy())

# Concatenate results
y_pred_binary = np.vstack(all_preds)
y_true = np.vstack(all_labels)

# Classification report
print("\nMulti-label classification report:")
print(classification_report(y_true, y_pred_binary, target_names=mlb.classes_, zero_division=0))

# Axiom satisfaction report
if all_axioms:
    axiom_scores = np.concatenate(all_axioms)
    print(f"\nAverage axiom satisfaction on test set: {axiom_scores.mean():.4f}")
    print(f"Min: {axiom_scores.min():.4f}, Max: {axiom_scores.max():.4f}")
else:
    print("\nNo implication rules found in model for axiom satisfaction.")

print("\nAvg predicted labels per sample:", y_pred_binary.sum(axis=1).mean())


### 7. Model Performance Comparison

Adding the LTN loss encourages the model to satisfy the logical constraints, which in multi-label classification tends to raise recall at the cost of precision.

For example, if an implication rule says "Sci-Fi implies Thriller" (because the two genres often co-occur), the model is pushed to predict Thriller whenever it predicts Sci-Fi, even when it is unsure about Thriller. The result is more recall and less precision.

The next step is threshold calibration: with the logic tensor network constraints, the "predict more" tendency overshoots, so a single fixed cutoff is too permissive.
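One common way to carry out that calibration is a per-genre grid search on the validation set: for each genre, choose the cutoff that maximizes its F1. The sketch below is plain Python with hypothetical data and helper names (`calibrate_thresholds`, `f1_for_label`); in the notebook, `val_probs` would be the sigmoid outputs collected over `val_loader`. It is one possible approach, not a method the notebook has settled on.

```python
def f1_for_label(true_col, pred_col):
    """Binary F1 for a single genre column."""
    tp = sum(1 for t, p in zip(true_col, pred_col) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(true_col, pred_col) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(true_col, pred_col) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def calibrate_thresholds(val_probs, val_true, candidates):
    """For each genre, pick the candidate threshold with the best validation F1."""
    thresholds = []
    for j in range(len(val_true[0])):
        probs_j = [row[j] for row in val_probs]
        true_j = [row[j] for row in val_true]
        best_t, best_f1 = candidates[0], -1.0
        for t in candidates:
            f1 = f1_for_label(true_j, [1 if p > t else 0 for p in probs_j])
            if f1 > best_f1:  # ties keep the earlier (lower) threshold
                best_t, best_f1 = t, f1
        thresholds.append(best_t)
    return thresholds


# Hypothetical validation probabilities (4 movies x 2 genres) and true labels.
val_probs = [[0.55, 0.90], [0.65, 0.30], [0.80, 0.45], [0.52, 0.60]]
val_true = [[0, 1], [1, 0], [1, 1], [0, 1]]
candidates = [0.3, 0.4, 0.5, 0.6, 0.7]

thresholds = calibrate_thresholds(val_probs, val_true, candidates)
print(thresholds)  # → [0.6, 0.3]

# Apply the per-genre thresholds to new probabilities.
new_probs = [[0.62, 0.35], [0.58, 0.25]]
print([[1 if p > t else 0 for p, t in zip(row, thresholds)] for row in new_probs])  # → [[1, 1], [0, 0]]
```

Tuning on the validation split rather than the test split keeps the test scores an unbiased estimate of performance.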