#### title

In [1]:
import torch
torch.cuda.empty_cache()
import pandas as pd
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW
from torch.nn import BCEWithLogitsLoss
from torch.amp import autocast, GradScaler
from tqdm import tqdm
from sklearn.metrics import classification_report, f1_score

In [2]:
def genre_cleanup(df): # data cleanup code credit: Vladislav Kolesov
    df['genres']=df['Genre'] 
    df['genres']=df['genres'].str.strip()
    df['genres']=df['genres'].str.replace(' - ', '|')
    df['genres']=df['genres'].str.replace(' / ', '|')
    df['genres']=df['genres'].str.replace('/', '|')
    df['genres']=df['genres'].str.replace(' & ', '|')
    df['genres']=df['genres'].str.replace(', ', '|')
    df['genres']=df['genres'].str.replace('; ', '|')
    df['genres']=df['genres'].str.replace('bio-pic', 'biography')
    df['genres']=df['genres'].str.replace('biopic', 'biography')
    df['genres']=df['genres'].str.replace('biographical', 'biography')
    df['genres']=df['genres'].str.replace('biodrama', 'biography')
    df['genres']=df['genres'].str.replace('bio-drama', 'biography')
    df['genres']=df['genres'].str.replace('biographic', 'biography')
    df['genres']=df['genres'].str.replace(' \(film genre\)', '')
    df['genres']=df['genres'].str.replace('animated','animation')
    df['genres']=df['genres'].str.replace('anime','animation')
    df['genres']=df['genres'].str.replace('children\'s','children')
    df['genres']=df['genres'].str.replace('comedey','comedy')
    df['genres']=df['genres'].str.replace('\[not in citation given\]','')
    df['genres']=df['genres'].str.replace(' set 4,000 years ago in the canadian arctic','')
    df['genres']=df['genres'].str.replace('historical','history')
    df['genres']=df['genres'].str.replace('romantic','romance')
    df['genres']=df['genres'].str.replace('3-d','animation')
    df['genres']=df['genres'].str.replace('3d','animation')
    df['genres']=df['genres'].str.replace('viacom 18 motion pictures','')
    df['genres']=df['genres'].str.replace('sci-fi','science_fiction')
    df['genres']=df['genres'].str.replace('ttriller','thriller')
    df['genres']=df['genres'].str.replace('.','')
    df['genres']=df['genres'].str.replace('based on radio serial','')
    df['genres']=df['genres'].str.replace(' on the early years of hitler','')
    df['genres']=df['genres'].str.replace('sci fi','science_fiction')
    df['genres']=df['genres'].str.replace('science fiction','science_fiction')
    df['genres']=df['genres'].str.replace(' (30min)','')
    df['genres']=df['genres'].str.replace('16 mm film','short')
    df['genres']=df['genres'].str.replace('\[140\]','drama')
    df['genres']=df['genres'].str.replace('\[144\]','')
    df['genres']=df['genres'].str.replace(' for ','')
    df['genres']=df['genres'].str.replace('adventures','adventure')
    df['genres']=df['genres'].str.replace('kung fu','martial_arts')
    df['genres']=df['genres'].str.replace('kung-fu','martial_arts')
    df['genres']=df['genres'].str.replace('martial arts','martial_arts')
    df['genres']=df['genres'].str.replace('world war ii','war')
    df['genres']=df['genres'].str.replace('world war i','war')
    df['genres']=df['genres'].str.replace('biography about montreal canadiens star|maurice richard','biography')
    df['genres']=df['genres'].str.replace('bholenath df|cinekorn entertainment','')
    df['genres']=df['genres'].str.replace(' \(volleyball\)','')
    df['genres']=df['genres'].str.replace('spy film','spy')
    df['genres']=df['genres'].str.replace('anthology film','anthology')
    df['genres']=df['genres'].str.replace('biography fim','biography')
    df['genres']=df['genres'].str.replace('avant-garde','avant_garde')
    df['genres']=df['genres'].str.replace('biker film','biker')
    df['genres']=df['genres'].str.replace('buddy cop','buddy')
    df['genres']=df['genres'].str.replace('buddy film','buddy')
    df['genres']=df['genres'].str.replace('comedy 2-reeler','comedy')
    df['genres']=df['genres'].str.replace('films','')
    df['genres']=df['genres'].str.replace('film','')
    df['genres']=df['genres'].str.replace('biography of pioneering american photographer eadweard muybridge','biography')
    df['genres']=df['genres'].str.replace('british-german co-production','')
    df['genres']=df['genres'].str.replace('bruceploitation','martial_arts')
    df['genres']=df['genres'].str.replace('comedy-drama adaptation of the mordecai richler novel','comedy-drama')
    df['genres']=df['genres'].str.replace('df by the mob\|knkspl','')
    df['genres']=df['genres'].str.replace('df','')
    df['genres']=df['genres'].str.replace('movie','')
    df['genres']=df['genres'].str.replace('coming of age','coming_of_age')
    df['genres']=df['genres'].str.replace('coming-of-age','coming_of_age')
    df['genres']=df['genres'].str.replace('drama about child soldiers','drama')
    df['genres']=df['genres'].str.replace('(( based).+)','')
    df['genres']=df['genres'].str.replace('(( co-produced).+)','')
    df['genres']=df['genres'].str.replace('(( adapted).+)','')
    df['genres']=df['genres'].str.replace('(( about).+)','')
    df['genres']=df['genres'].str.replace('musical b','musical')
    df['genres']=df['genres'].str.replace('animationchildren','animation|children')
    df['genres']=df['genres'].str.replace(' period','period')
    df['genres']=df['genres'].str.replace('drama loosely','drama')
    df['genres']=df['genres'].str.replace(' \(aquatics|swimming\)','')
    df['genres']=df['genres'].str.replace(' \(aquatics|swimming\)','')
    df['genres']=df['genres'].str.replace("yogesh dattatraya gosavi's directorial debut \[9\]",'')
    df['genres']=df['genres'].str.replace("war-time","war")
    df['genres']=df['genres'].str.replace("wartime","war")
    df['genres']=df['genres'].str.replace("ww1","war")
    df['genres']=df['genres'].str.replace('unknown','')
    df['genres']=df['genres'].str.replace("wwii","war")
    df['genres']=df['genres'].str.replace('psychological','psycho')
    df['genres']=df['genres'].str.replace('rom-coms','romance')
    df['genres']=df['genres'].str.replace('true crime','crime')
    df['genres']=df['genres'].str.replace('\|007','')
    df['genres']=df['genres'].str.replace('slice of life','slice_of_life')
    df['genres']=df['genres'].str.replace('computer animation','animation')
    df['genres']=df['genres'].str.replace('gun fu','martial_arts')
    df['genres']=df['genres'].str.replace('j-horror','horror')
    df['genres']=df['genres'].str.replace(' \(shogi|chess\)','')
    df['genres']=df['genres'].str.replace('afghan war drama','war drama')
    df['genres']=df['genres'].str.replace('\|6 separate stories','')
    df['genres']=df['genres'].str.replace(' \(30min\)','')
    df['genres']=df['genres'].str.replace(' (road bicycle racing)','')
    df['genres']=df['genres'].str.replace(' v-cinema','')
    df['genres']=df['genres'].str.replace('tv miniseries','tv_miniseries')
    df['genres']=df['genres'].str.replace('\|docudrama','\|documentary|drama')
    df['genres']=df['genres'].str.replace(' in animation','|animation')
    df['genres']=df['genres'].str.replace('((adaptation).+)','')
    df['genres']=df['genres'].str.replace('((adaptated).+)','')
    df['genres']=df['genres'].str.replace('((adapted).+)','')
    df['genres']=df['genres'].str.replace('(( on ).+)','')
    df['genres']=df['genres'].str.replace('american football','sports')
    df['genres']=df['genres'].str.replace('dev\|nusrat jahan','sports')
    df['genres']=df['genres'].str.replace('television miniseries','tv_miniseries')
    df['genres']=df['genres'].str.replace(' \(artistic\)','')
    df['genres']=df['genres'].str.replace(' \|direct-to-dvd','')
    df['genres']=df['genres'].str.replace('history dram','history drama')
    df['genres']=df['genres'].str.replace('martial art','martial_arts')
    df['genres']=df['genres'].str.replace('psycho thriller,','psycho thriller')
    df['genres']=df['genres'].str.replace('\|1 girl\|3 suitors','')
    df['genres']=df['genres'].str.replace(' \(road bicycle racing\)','')
    filterE = df['genres']=="ero"
    df.loc[filterE,'genres']="adult"
    filterE = df['genres']=="music"
    df.loc[filterE,'genres']="musical"
    filterE = df['genres']=="-"
    df.loc[filterE,'genres']=''
    filterE = df['genres']=="comedy–drama"
    df.loc[filterE,'genres'] = "comedy|drama"
    filterE = df['genres']=="comedy–horror"
    df.loc[filterE,'genres'] = "comedy|horror"
    df['genres']=df['genres'].str.replace(' ','|')
    df['genres']=df['genres'].str.replace(',','|')
    df['genres']=df['genres'].str.replace('-','')
    df['genres']=df['genres'].str.replace('actionadventure','action|adventure')
    df['genres']=df['genres'].str.replace('actioncomedy','action|comedy')
    df['genres']=df['genres'].str.replace('actiondrama','action|drama')
    df['genres']=df['genres'].str.replace('actionlove','action|love')
    df['genres']=df['genres'].str.replace('actionmasala','action|masala')
    df['genres']=df['genres'].str.replace('actionchildren','action|children')
    df['genres']=df['genres'].str.replace('fantasychildren\|','fantasy|children')
    df['genres']=df['genres'].str.replace('fantasycomedy','fantasy|comedy')
    df['genres']=df['genres'].str.replace('fantasyperiod','fantasy|period')
    df['genres']=df['genres'].str.replace('cbctv_miniseries','tv_miniseries')
    df['genres']=df['genres'].str.replace('dramacomedy','drama|comedy')
    df['genres']=df['genres'].str.replace('dramacomedysocial','drama|comedy|social')
    df['genres']=df['genres'].str.replace('dramathriller','drama|thriller')
    df['genres']=df['genres'].str.replace('comedydrama','comedy|drama')
    df['genres']=df['genres'].str.replace('dramathriller','drama|thriller')
    df['genres']=df['genres'].str.replace('comedyhorror','comedy|horror')
    df['genres']=df['genres'].str.replace('sciencefiction','science_fiction')
    df['genres']=df['genres'].str.replace('adventurecomedy','adventure|comedy')
    df['genres']=df['genres'].str.replace('animationdrama','animation|drama')
    df['genres']=df['genres'].str.replace('\|\|','|')
    df['genres']=df['genres'].str.replace('muslim','religious')
    df['genres']=df['genres'].str.replace('thriler','thriller')
    df['genres']=df['genres'].str.replace('crimethriller','crime|thriller')
    df['genres']=df['genres'].str.replace('fantay','fantasy')
    df['genres']=df['genres'].str.replace('actionthriller','action|thriller')
    df['genres']=df['genres'].str.replace('comedysocial','comedy|social')
    df['genres']=df['genres'].str.replace('martialarts','martial_arts')
    df['genres']=df['genres'].str.replace('\|\(children\|poker\|karuta\)','')
    df['genres']=df['genres'].str.replace('epichistory','epic|history')
    df['genres']=df['genres'].str.replace('erotica','adult')
    df['genres']=df['genres'].str.replace('erotic','adult')
    df['genres']=df['genres'].str.replace('((\|produced\|).+)','')
    df['genres']=df['genres'].str.replace('chanbara','chambara')
    df['genres']=df['genres'].str.replace('comedythriller','comedy|thriller')
    df['genres']=df['genres'].str.replace('biblical','religious')
    df['genres']=df['genres'].str.replace('biblical','religious')
    df['genres']=df['genres'].str.replace('colour\|yellow\|productions\|eros\|international','')
    df['genres']=df['genres'].str.replace('\|directtodvd','')
    df['genres']=df['genres'].str.replace('liveaction','live|action')
    df['genres']=df['genres'].str.replace('melodrama','drama')
    df['genres']=df['genres'].str.replace('superheroes','superheroe')
    df['genres']=df['genres'].str.replace('gangsterthriller','gangster|thriller')
    df['genres']=df['genres'].str.replace('heistcomedy','comedy')
    df['genres']=df['genres'].str.replace('heist','action')
    df['genres']=df['genres'].str.replace('historic','history')
    df['genres']=df['genres'].str.replace('historydisaster','history|disaster')
    df['genres']=df['genres'].str.replace('warcomedy','war|comedy')
    df['genres']=df['genres'].str.replace('westerncomedy','western|comedy')
    df['genres']=df['genres'].str.replace('ancientcostume','costume')
    df['genres']=df['genres'].str.replace('computeranimation','animation')
    df['genres']=df['genres'].str.replace('dramatic','drama')
    df['genres']=df['genres'].str.replace('familya','family')
    df['genres']=df['genres'].str.replace('familya','family')
    df['genres']=df['genres'].str.replace('dramedy','drama|comedy')
    df['genres']=df['genres'].str.replace('dramaa','drama')
    df['genres']=df['genres'].str.replace('famil\|','family')
    df['genres']=df['genres'].str.replace('superheroe','superhero')
    df['genres']=df['genres'].str.replace('biogtaphy','biography')
    df['genres']=df['genres'].str.replace('devotionalbiography','devotional|biography')
    df['genres']=df['genres'].str.replace('docufiction','documentary|fiction')
    df['genres']=df['genres'].str.replace('familydrama','family|drama')
    df['genres']=df['genres'].str.replace('espionage','spy')
    df['genres']=df['genres'].str.replace('supeheroes','superhero')
    df['genres']=df['genres'].str.replace('romancefiction','romance|fiction')
    df['genres']=df['genres'].str.replace('horrorthriller','horror|thriller')
    df['genres']=df['genres'].str.replace('suspensethriller','suspense|thriller')
    df['genres']=df['genres'].str.replace('musicaliography','musical|biography')
    df['genres']=df['genres'].str.replace('triller','thriller')
    df['genres']=df['genres'].str.replace('\|\(fiction\)','|fiction')
    df['genres']=df['genres'].str.replace('romanceaction','romance|action')
    df['genres']=df['genres'].str.replace('romancecomedy','romance|comedy')
    df['genres']=df['genres'].str.replace('romancehorror','romance|horror')
    df['genres']=df['genres'].str.replace('romcom','romance|comedy')
    df['genres']=df['genres'].str.replace('rom\|com','romance|comedy')
    df['genres']=df['genres'].str.replace('satirical','satire')
    df['genres']=df['genres'].str.replace('science_fictionchildren','science_fiction|children')
    df['genres']=df['genres'].str.replace('homosexual','adult')
    df['genres']=df['genres'].str.replace('sexual','adult')
    df['genres']=df['genres'].str.replace('mockumentary','documentary')
    df['genres']=df['genres'].str.replace('periodic','period')
    df['genres']=df['genres'].str.replace('romanctic','romantic')
    df['genres']=df['genres'].str.replace('politics','political')
    df['genres']=df['genres'].str.replace('samurai','martial_arts')
    df['genres']=df['genres'].str.replace('tv_miniseries','series')
    df['genres']=df['genres'].str.replace('serial','series')
    filterE = df['genres']=="musical–comedy"
    df.loc[filterE,'genres'] = "musical|comedy"
    filterE = df['genres']=="roman|porno"
    df.loc[filterE,'genres'] = "adult"
    filterE = df['genres']=="action—masala"
    df.loc[filterE,'genres'] = "action|masala"
    filterE = df['genres']=="horror–thriller"
    df.loc[filterE,'genres'] = "horror|thriller"
    df['genres']=df['genres'].str.replace('family','children')
    df['genres']=df['genres'].str.replace('martial_arts','action')
    df['genres']=df['genres'].str.replace('horror','thriller')
    df['genres']=df['genres'].str.replace('war','action')
    df['genres']=df['genres'].str.replace('adventure','action')
    df['genres']=df['genres'].str.replace('science_fiction','action')
    df['genres']=df['genres'].str.replace('western','action')
    df['genres']=df['genres'].str.replace('western','action')
    df['genres']=df['genres'].str.replace('spy','action')
    df['genres']=df['genres'].str.replace('superhero','action')
    df['genres']=df['genres'].str.replace('social','')
    df['genres']=df['genres'].str.replace('suspense','action')
    filterE = df['genres']=="drama|romance|adult|children"
    df.loc[filterE,'genres'] = "drama|romance|adult"
    df['genres']=df['genres'].str.replace('\|–\|','|')
    df['genres']=df['genres'].str.strip(to_strip='\|')
    df['genres']=df['genres'].str.replace('actionner','action')
    df['genres']=df['genres'].str.strip()

def format_input_text(row):
    """
    Create a robust input text by labeling different parts of the input.
    Handles null values gracefully by excluding empty fields.
    """
    fields = [
        f"release_year: {row['Release Year']}" if pd.notna(row['Release Year']) else "",
        f"title: {row['Title']}" if pd.notna(row['Title']) else "",
        f"origin: {row['Origin/Ethnicity']}" if pd.notna(row['Origin/Ethnicity']) else "",
        f"director: {row['Director']}" if pd.notna(row['Director']) else "",
        f"cast: {row['Cast']}" if pd.notna(row['Cast']) else "",
        f"plot: {row['Plot']}" if pd.notna(row['Plot']) else "",
    ]
    return " | ".join(filter(None, fields))

def filter_genres(genres, include_list):
    filtered_genres = [genre for genre in genres if genre in include_list]
    return filtered_genres if filtered_genres else None


In [3]:
movie_df = pd.read_csv("wiki_movie_plots_deduped.csv")
genre_cleanup(movie_df)

movie_df['genre_array']=movie_df['genres'].str.split('|')
movie_df['genre_array']= movie_df['genre_array'].apply(np.sort).apply(np.unique)
movie_df = movie_df[movie_df['genres']!='']

movie_df['genre_array']=movie_df['genres'].str.split('|')
movie_df['genre_array']= movie_df['genre_array'].apply(np.sort).apply(np.unique)
movie_df = movie_df[movie_df['genres']!='']

kept_genres = [
    'drama','comedy','action','thriller','romance','crime','musical','animation','children'
]

movie_df["genre_array"] = movie_df["genre_array"].apply(lambda genres: filter_genres(genres, kept_genres))
movie_df = movie_df.dropna(subset=["genre_array"])

movie_df["formatted_text"] = movie_df.apply(format_input_text, axis=1)


In [4]:
genres_list = movie_df["genre_array"].tolist()
mlb = MultiLabelBinarizer()
genre_matrix = mlb.fit_transform(genres_list)
labels = torch.tensor(genre_matrix, dtype=torch.float)

model_id = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenized_data = tokenizer(
    movie_df["formatted_text"].tolist(),
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt"
)
dataset = TensorDataset(
    tokenized_data["input_ids"],
    tokenized_data["attention_mask"],
    labels
)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16)

In [37]:
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=len(mlb.classes_),
    problem_type="multi_label_classification"
).to('cuda')

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
loss_fn = BCEWithLogitsLoss()
scaler = GradScaler()

In [8]:
epochs = 3
total_steps = len(train_dataloader) * epochs
warmup_steps = int(0.1 * total_steps)  # 10% of total steps as warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)
best_val_f1 = 0.0

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}", leave=True)
    for batch_index, batch in enumerate(progress_bar):
        input_ids, attention_mask, labels = [b.to('cuda') for b in batch]

        optimizer.zero_grad(set_to_none=True)

        with autocast(device_type='cuda'):
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits
            
        scaler.scale(loss).backward()
        grad_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                grad_norm += p.grad.data.norm(2).item() ** 2
        grad_norm = grad_norm ** 0.5
        scaler.step(optimizer) 
        scaler.update()
        scheduler.step()
        total_loss += loss.item()

        preds = torch.sigmoid(logits).detach().cpu().numpy()
        labels_cpu = labels.detach().cpu().numpy()

        all_preds.append(preds)
        all_labels.append(labels_cpu)

        avg_loss = total_loss / (batch_index + 1)
        progress_bar.set_postfix(loss=loss.item(), avg_loss=avg_loss, grad_norm=grad_norm)

    
    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    all_preds_bin = (all_preds >= 0.5).astype(int)
    epoch_f1 = f1_score(all_labels, all_preds_bin, average="micro")
    print(f"Epoch {epoch + 1} -> loss: {total_loss / len(train_dataloader)} train_f1: {epoch_f1:.4f}")

    model.eval()
    val_preds = []
    val_labels = []

    with torch.no_grad():
        for batch in val_dataloader:
            input_ids, attention_mask, labels = [t.to('cuda') for t in batch]
            with autocast(device_type='cuda'):
                outputs = model(input_ids, attention_mask=attention_mask)
                logits = outputs.logits
            
            preds = torch.sigmoid(logits).cpu().numpy()
            val_preds.append(preds)
            val_labels.append(labels.cpu().numpy())

    val_preds = np.concatenate(val_preds, axis=0)
    val_labels = np.concatenate(val_labels, axis=0)
    val_preds_bin = (val_preds >= 0.5).astype(int)
    val_f1 = f1_score(val_labels, val_preds_bin, average="micro")
    print(f"Epoch {epoch + 1} -> test_f1: {val_f1:.4f}")

    # Checkpoint if best
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        print(f" New best model found. Saving checkpoint with test_f1={val_f1:.4f}.")
        model.save_pretrained("bert_best_model")
        tokenizer.save_pretrained("bert_best_tokenizer")
    
    model.train()


Epoch 1/3


Epoch 1: 100%|██████████| 1330/1330 [40:47<00:00,  1.84s/it, avg_loss=0.31, grad_norm=1.16e+5, loss=0.289]  


Epoch 1 -> loss: 0.309908046903915 train_f1: 0.3944
Epoch 1 -> test_f11: 0.5646
 New best model found. Saving checkpoint with test_f1=0.5646.
Epoch 2/3


Epoch 2: 100%|██████████| 1330/1330 [15:11<00:00,  1.46it/s, avg_loss=0.206, grad_norm=2.26e+5, loss=0.26]  


Epoch 2 -> loss: 0.2059716706157179 train_f1: 0.6598
Epoch 2 -> test_f11: 0.6419
 New best model found. Saving checkpoint with test_f1=0.6419.
Epoch 3/3


Epoch 3: 100%|██████████| 1330/1330 [11:00<00:00,  2.01it/s, avg_loss=0.171, grad_norm=2.68e+5, loss=0.205]  


Epoch 3 -> loss: 0.17148629242092148 train_f1: 0.7340
Epoch 3 -> test_f11: 0.6519
 New best model found. Saving checkpoint with test_f1=0.6519.


In [30]:
best_model = AutoModelForSequenceClassification.from_pretrained("bert_best_model").to('cuda')
best_model.eval()

test_preds = []
test_labels_list = []

with torch.no_grad():
    for batch in val_dataloader:
        input_ids, attention_mask, labels = [t.to('cuda') for t in batch]
        with autocast(device_type='cuda'):
            outputs = best_model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
        
        preds = torch.sigmoid(logits).cpu().numpy()
        test_preds.append(preds)
        test_labels_list.append(labels.cpu().numpy())

test_preds = np.concatenate(test_preds, axis=0)
test_labels_list = np.concatenate(test_labels_list, axis=0)
test_preds_bin = (test_preds >= 0.5).astype(int)

test_f1 = f1_score(test_labels_list, test_preds_bin, average="micro")
print(f"\nFinal Test F1: {test_f1:.4f}")

# Detailed classification report
report = classification_report(
    test_labels_list,
    test_preds_bin,
    target_names=mlb.classes_,
    zero_division=0.0
)
print("Classification Report on Test Set:")
print(report)


Final Test F1: 0.6519
Classification Report on Test Set:
              precision    recall  f1-score   support

      action       0.75      0.70      0.72      1223
   animation       0.87      0.66      0.75       195
    children       0.67      0.05      0.09       129
      comedy       0.76      0.60      0.67      1427
       crime       0.59      0.34      0.43       325
       drama       0.68      0.69      0.69      1863
     musical       0.70      0.41      0.51       190
     romance       0.61      0.41      0.49       493
    thriller       0.73      0.56      0.64       670

   micro avg       0.72      0.60      0.65      6515
   macro avg       0.71      0.49      0.55      6515
weighted avg       0.71      0.60      0.64      6515
 samples avg       0.67      0.64      0.64      6515



In [36]:
text_to_classify = "\
release_year: 2026 \
| title: The Oath \
| origin: British \
| director: Greta Gerwig \
| cast: Anya Taylor-Joy, Andrew Garfield, Daniel Kaluuya, Helena Bonham Carter \
| plot: A young heiress joins forces with a rogue investigator to expose a hidden society of aristocrats involved in a series of crimes."

inputs = tokenizer(text_to_classify, return_tensors="pt", truncation=True).to('cuda')

with autocast(device_type='cuda'), torch.no_grad():
    logits = best_model(**inputs).logits
    probs = torch.sigmoid(logits).cpu().numpy()[0]

predicted_labels = [str(mlb.classes_[i]) + ' (' + str(p) + ')' for i, p in enumerate(probs) if p >= 0.5]
print(f"Dummy Text: {text_to_classify}")
print(f"Predicted Labels: {predicted_labels}")

Dummy Text: release_year: 2026 | title: The Oath | origin: British | director: Greta Gerwig | cast: Anya Taylor-Joy, Andrew Garfield, Daniel Kaluuya, Helena Bonham Carter | plot: A young heiress joins forces with a rogue investigator to expose a hidden society of aristocrats involved in a series of crimes.
Predicted Labels: ['action (0.6157)', 'thriller (0.5293)']
