In [None]:
!pip install transformers evaluate
!pip install accelerate -U

In [None]:
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

In [None]:
# Load both splits and drop rows with no overview (there is no text to classify).
# Chaining instead of inplace=True keeps this cell idempotent, and reset_index
# gives a clean 0..n-1 index after the dropped rows.
train_data = pd.read_csv('train_data.csv').dropna(subset=['overview']).reset_index(drop=True)
test_data = pd.read_csv('test_data.csv').dropna(subset=['overview']).reset_index(drop=True)

In [None]:
# Keep only the 'overview' text and the per-genre label columns ('genre_*');
# the genre columns are discovered from the training split.
genre_cols = list(filter(lambda col: col.startswith('genre_'), train_data.columns))
NUM_GENRES = len(genre_cols)
keep_cols = ['overview', *genre_cols]
train_data, test_data = train_data.loc[:, keep_cols], test_data.loc[:, keep_cols]

In [None]:
# Two-way lookup between output index and genre column name; both directions
# are handed to the model config.
id2genre = dict(enumerate(genre_cols))
genre2id = {genre: idx for idx, genre in id2genre.items()}
genre2id

In [None]:
# Wrap the frames as Hugging Face datasets. preserve_index=False stops a
# non-default pandas index (e.g. one left with gaps by dropping rows) from being
# stored as an extra '__index_level_0__' column.
train_dataset = Dataset.from_pandas(train_data, preserve_index=False)
test_dataset = Dataset.from_pandas(test_data, preserve_index=False)

In [None]:
# Tokenizer for the same pretrained checkpoint the classifier is loaded from.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="distilbert-base-uncased")

def preprocess_data(batch):
    """Tokenize a batch of overviews and attach the matching multi-label targets.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        batch: Mapping of column name -> list of values; must contain
            'overview' and every column listed in the global ``genre_cols``.

    Returns:
        The tokenizer output for the overviews (truncated, not padded), plus
        'labels': a float tensor with one row per example and one column per
        genre, in ``genre_cols`` order.
    """
    features = tokenizer(batch['overview'], truncation=True)
    # Stacking the genre columns yields a (num_genres, num_examples) tensor;
    # transpose it so each example gets its own row. The multi-label loss
    # needs float targets.
    per_genre = torch.tensor([batch[name] for name in genre_cols])
    features['labels'] = per_genre.T.float()
    return features

# Tokenize both splits batch-wise. Every original column (raw text and the
# genre columns) is removed, so only the tokenizer outputs and 'labels' remain.
tokenized_train_dataset = train_dataset.map(
    function=preprocess_data,
    batched=True,
    remove_columns=train_dataset.column_names,
)
tokenized_test_dataset = test_dataset.map(
    function=preprocess_data,
    batched=True,
    remove_columns=test_dataset.column_names,
)

In [None]:
# Have both tokenized splits return PyTorch tensors when indexed
# (set_format modifies each dataset in place).
for tokenized in (tokenized_train_dataset, tokenized_test_dataset):
    tokenized.set_format(type='torch')

In [None]:
# Seed Python/NumPy/PyTorch before the new classification head is randomly
# initialised: Trainer only calls set_seed inside its own constructor, which
# runs after this cell. 42 is TrainingArguments' default `seed`.
set_seed(42)
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=NUM_GENRES,
    # independent sigmoid/BCE-with-logits output per genre
    problem_type="multi_label_classification",
    id2label=id2genre,
    label2id=genre2id,
)

In [None]:
# Training hyper-parameters consumed by TrainingArguments.
EPOCHS = 5
BATCH_SIZE = 16        # per device, used for both training and evaluation
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 0.01
METRIC = 'f1'          # key from get_metrics used to pick the best checkpoint

In [None]:
# Evaluate and checkpoint once per epoch, then reload the checkpoint with the
# best METRIC score at the end of training. `eval_strategy` is the current name
# of the deprecated `evaluation_strategy` argument.
args = TrainingArguments(
    output_dir="genre_prediction_model_training",
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=WEIGHT_DECAY,
    load_best_model_at_end=True,
    metric_for_best_model=METRIC,
)

def get_metrics(p, threshold=0.5):
    """Compute multi-label evaluation metrics for the Trainer.

    Args:
        p: ``EvalPrediction``-like object. ``p.predictions`` holds raw logits
            (or a tuple whose first element is the logits); ``p.label_ids``
            holds the multi-hot targets.
        threshold: Sigmoid probability at or above which a genre counts as
            predicted. Trainer calls this with ``p`` only, so 0.5 is used there.

    Returns:
        dict with 'accuracy' (exact-match accuracy over all genres of an
        example), micro-averaged 'f1', and micro-averaged 'roc_auc'.
    """
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    # Labels are stored as floats for the BCE loss; sklearn's multilabel metrics
    # expect an integer indicator matrix.
    true_labels = np.asarray(p.label_ids).astype(int)

    probs = torch.sigmoid(torch.as_tensor(logits, dtype=torch.float32)).numpy()
    preds = (probs >= threshold).astype(int)

    return {
        'accuracy': accuracy_score(true_labels, preds),
        'f1': f1_score(true_labels, preds, average='micro'),
        'roc_auc': roc_auc_score(true_labels, probs, average='micro'),
    }

In [None]:
# Hold out part of the training data for per-epoch evaluation and for choosing
# the best checkpoint (load_best_model_at_end). Choosing it on the test split
# would leak test information and make any metric reported on it optimistic.
train_val_split = tokenized_train_dataset.train_test_split(test_size=0.1, seed=42)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_val_split['train'],
    eval_dataset=train_val_split['test'],
    # `processing_class` replaces the deprecated `tokenizer=` argument; with a
    # tokenizer here Trainer pads each batch dynamically.
    processing_class=tokenizer,
    compute_metrics=get_metrics,
)

In [None]:
# Fine-tune from the pretrained weights (no earlier checkpoint is resumed).
trainer.train(resume_from_checkpoint=None)

In [None]:
# Report metrics on the test split, named explicitly so the result does not
# depend on which dataset the Trainer used for evaluation during training.
trainer.evaluate(eval_dataset=tokenized_test_dataset)

In [None]:
# Persist the model the Trainer currently holds (the best checkpoint, since
# load_best_model_at_end=True), together with the tokenizer it was given.
trainer.save_model(output_dir='genre_prediction_multilabel_classification_model')

In [None]:
# Run 3 sample predictions
test_movies = [
    {
        'title': 'Oppenheimer',
        'overview': "The story of American scientist, J. Robert Oppenheimer, and his role in the development of the atomic bomb.",
        'genre_list': ['Drama', 'History', 'Thriller']
    },
    {
        'title': 'Barbie',
        'overview': "Barbie suffers a crisis that leads her to question her world and her existence.",
        'genre_list': ['Adventure', 'Comedy', 'Fantasy']
    },
    {
        'title': 'Everything Everywhere All at Once',
        'overview': "A middle-aged Chinese immigrant is swept up into an insane adventure in which she alone can save existence by exploring other universes and connecting with the lives she could have led.",
        'genre_list': ['Action', 'Adventure', 'Comedy']
    }
]

# Predict with the Trainer's model in inference mode: eval() disables dropout
# explicitly (instead of relying on an earlier evaluate() call), and no_grad()
# avoids building an autograd graph.
inference_model = trainer.model
inference_model.eval()

encoding = tokenizer([movie['overview'] for movie in test_movies], truncation=True, padding=True, return_tensors='pt')
encoding = {key: val.to(inference_model.device) for key, val in encoding.items()}
with torch.no_grad():
    out = inference_model(**encoding)

# Independent per-genre probabilities, moved to the CPU for display.
probs = torch.sigmoid(out.logits).cpu()
preds = (probs >= 0.5).long().numpy()

# Label columns are named 'genre_<Name>'. Strip only that prefix so genre names
# that themselves contain '_' are kept whole.
genre_names = [id2genre[idx].split('_', 1)[1] for idx in range(len(id2genre))]

for movie, movie_probs, movie_preds in zip(test_movies, probs, preds):
    movie['predicted_genres'] = [name for name, pred in zip(genre_names, movie_preds) if pred == 1]
    genre_probs = {name: prob.item() for name, prob in zip(genre_names, movie_probs)}
    # Most likely genres first.
    movie['predicted_genres_probs'] = dict(sorted(genre_probs.items(), key=lambda item: item[1], reverse=True))
    print(movie)