In [1]:
import json

import pandas as pd
from sentence_transformers import SentenceTransformer
from torch import tensor, float32, manual_seed, no_grad
from torch.nn import BCELoss
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

from movie_detector.ml.neural_network import GenreClassifier, train_model

# Configuration: every value a reader might want to tune lives here.
BATCH_SIZE = 32       # mini-batch size of the training/validation DataLoaders
THRESHOLD = 0.4       # a genre is predicted when its model score exceeds this value
WEIGHT_DECAY = 1e-6   # weight_decay passed to the Adam optimizer
LR = 0.001            # Adam learning rate
SEED = 42             # seed for torch's global RNG

# Training is stochastic (shuffled DataLoader, and presumably random weight init
# inside GenreClassifier); seeding torch makes Restart & Run All reproducible.
manual_seed(SEED)

# Example titles used for a qualitative check of the trained model further down.
# Explicit UTF-8 so non-ASCII titles decode the same on every platform.
with open('../data/examples/example_movie_titles.json', 'r', encoding='utf-8') as f:
    test_titles = json.load(f)['test_titles']

# Embed the titles so they can be fed to the classifier.
# NOTE(review): this model must match the one that produced the X_*.csv
# features, otherwise the classifier's inputs are meaningless — TODO confirm
# against the preprocessing step.
sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
embedding_example = sentence_transformer.encode(test_titles)
embedding_example_tensor = tensor(embedding_example, dtype=float32)

PROCESSED_DIR = '../data/processed'


def _load_split(split):
    """Load the processed feature/label CSVs for one data split.

    Args:
        split: split name embedded in the file names ('train', 'val', 'test').

    Returns:
        (X_df, y_df, titles): the feature frame, the label frame with its
        'title' column removed, and that removed 'title' column.

    Raises:
        AssertionError: if the X and y files have different row counts.
    """
    X_df = pd.read_csv(f'{PROCESSED_DIR}/X_{split}.csv')
    y_df = pd.read_csv(f'{PROCESSED_DIR}/y_{split}.csv')
    # 'title' is kept for inspection only; it must not be treated as a label.
    titles = y_df.pop('title')
    assert len(X_df) == len(y_df), (
        f'{split}: X has {len(X_df)} rows but y has {len(y_df)}'
    )
    return X_df, y_df, titles


X_train_df, y_train_df, movie_titles_y_train = _load_split('train')
X_val_df, y_val_df, movie_titles_y_val = _load_split('val')
X_test_df, y_test_df, movie_titles_y_test = _load_split('test')

# Genre names are later read from y_train_df.columns, so every split must use
# the same label columns in the same order.
assert list(y_val_df.columns) == list(y_train_df.columns), 'val label columns differ from train'
assert list(y_test_df.columns) == list(y_train_df.columns), 'test label columns differ from train'

X_train_tensor = tensor(X_train_df.to_numpy(), dtype=float32)
y_train_tensor = tensor(y_train_df.to_numpy(), dtype=float32)
X_val_tensor = tensor(X_val_df.to_numpy(), dtype=float32)
y_val_tensor = tensor(y_val_df.to_numpy(), dtype=float32)
X_test_tensor = tensor(X_test_df.to_numpy(), dtype=float32)
y_test_tensor = tensor(y_test_df.to_numpy(), dtype=float32)

# Only the training loader is shuffled; validation order does not affect its metrics.
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

input_dim = X_train_df.shape[1]
num_labels = y_train_df.shape[1]
print('input_dim:', input_dim, 'num_labels:', num_labels)

# Architecture and training-length hyperparameters, named instead of inlined.
HIDDEN_LAYER_1 = 256
HIDDEN_LAYER_2 = 96
EPOCHS = 100

model = GenreClassifier(
    input_dim=input_dim,
    output_dim=num_labels,
    hidden_layer_1=HIDDEN_LAYER_1,
    hidden_layer_2=HIDDEN_LAYER_2,
)
# BCELoss expects inputs already in [0, 1], so GenreClassifier is assumed to
# end in a sigmoid — TODO confirm in movie_detector.ml.neural_network.
criterion = BCELoss()
optimizer = Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# NOTE(review): in the epochs visible in the saved output, validation loss is
# lowest around epoch 6 and then drifts up while training loss keeps falling;
# consider fewer epochs or early stopping.
train_model(
    model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=EPOCHS,
    threshold=THRESHOLD,
)

# Qualitative check: score the example titles with the trained model.
# train_model's final train/eval mode is not visible here, so switch to eval
# explicitly; this disables dropout / batch-norm updates if the model has any.
model.eval()
with no_grad():
    prediction = model(embedding_example_tensor)

# Per-genre scores as percentages, one column per label of y_train_df.
test_predictions_percentage = prediction * 100
test_predictions_df = pd.DataFrame(
    test_predictions_percentage.numpy(), columns=y_train_df.columns
)
test_predictions_df['title'] = test_titles

# A genre is predicted when its raw score exceeds THRESHOLD. Thresholding the
# raw tensor avoids the percentage round-trip and a row-wise DataFrame.apply.
genre_names = list(y_train_df.columns)
predicted_mask = (prediction > THRESHOLD).numpy()
test_predictions_df['genres'] = [
    [genre for genre, is_predicted in zip(genre_names, row) if is_predicted]
    for row in predicted_mask
]

test_predictions_df[['title', 'genres']]

  from tqdm.autonotebook import tqdm, trange


input_dim: 384 num_labels: 25
Epoch [1/100], Loss: 0.2332, Validation Loss: 0.2143, F1: 0.39776080084299265, Compound Score: 0.6106388794384174
Epoch [2/100], Loss: 0.2109, Validation Loss: 0.2106, F1: 0.4176238145416228, Compound Score: 0.6218455744087148
Epoch [3/100], Loss: 0.2075, Validation Loss: 0.2100, F1: 0.4307642129559938, Compound Score: 0.6286047698142054
Epoch [4/100], Loss: 0.2052, Validation Loss: 0.2080, F1: 0.42636359074715235, Compound Score: 0.6270848943473404
Epoch [5/100], Loss: 0.2034, Validation Loss: 0.2078, F1: 0.43816172412062826, Compound Score: 0.633064352536816
Epoch [6/100], Loss: 0.2014, Validation Loss: 0.2067, F1: 0.4345206733905364, Compound Score: 0.6315994702113511
Epoch [7/100], Loss: 0.1996, Validation Loss: 0.2070, F1: 0.4285087059059662, Compound Score: 0.62849815570068
Epoch [8/100], Loss: 0.1976, Validation Loss: 0.2080, F1: 0.42173653971599173, Compound Score: 0.6247917842259979
Epoch [9/100], Loss: 0.1955, Validation Loss: 0.2070, F1: 0.43125

Unnamed: 0,title,genres
0,The terrifying story of Charly the wizard,"[comedy, horror]"
1,Emberfall Chronicles,"[action, adventure]"
2,Blazepoint Protocol,[drama]
3,The Silent Witness,[drama]
4,The Quantum Thread,"[comedy, sci-fi]"
5,Beneath the Hollow Manor,"[horror, mystery, thriller]"
6,Aunt Millies Robot Cafe,[comedy]
7,Children of the Blackwoods,"[drama, horror, thriller]"
8,Promise at the Edge of Dawn,[drama]
9,Iron Meridian,[horror]


In [10]:
# 0/1 genre vector per example title (columns in model-output order, i.e. the
# order of y_train_df.columns), using the notebook-wide THRESHOLD.
{title: row.tolist() for title, row in zip(test_titles, (prediction > THRESHOLD).int())}

{'The terrifying story of Charly the wizard': [0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 'Emberfall Chronicles': [1,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 'Blazepoint Protocol': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 'The Silent Witness': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 'The Quantum Thread': [0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0],
 'Beneath the Hollow Manor': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0],
 'Aunt Millies Robot Cafe': [0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0