In [10]:
import sentence_transformers as st
import joblib
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np

In [2]:
class GenreClassifier(nn.Module):
    def __init__(self, input_size, output_size):
        super(GenreClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, output_size)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

In [3]:
def create_encodings(summary):

    model = st.SentenceTransformer("all-MiniLM-L6-v2")
    print("Creating Embeddings.")
    encoding = model.encode(summary, batch_size=64, show_progress_bar=True, convert_to_tensor=True)
    return encoding

In [4]:
def load_model(model_path='scoring_model_1.pth', input_size=384, output_size=10, device=None):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GenreClassifier(input_size, output_size)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model

In [14]:
def load_genres(genre_file='genres.joblib'):
    genre_file = joblib.load(genre_file)
    genre_labels = genre_file.columns.tolist()
    return genre_labels

def predict_genres(encoding, model):
    encoding = encoding.unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        scores = model(encoding)
    return scores.squeeze(0)  # Remove batch dimension

def map_predictions_to_genres(scores, genre_names):
    scores_np = np.array([score.item() for score in scores])
    # Perform min-max scaling
    min_score = scores_np.min()
    max_score = scores_np.max()
    scaled_scores = 10 * (scores_np - min_score) / (max_score - min_score)

    # Create a dictionary with genre scores
    genre_scores = {genre: scaled_scores[i] for i, genre in enumerate(genre_names)}

    # Sort genres by scores in descending order
    sorted_genre_scores = {k: v for k, v in sorted(genre_scores.items(), key=lambda item: item[1], reverse=True)}
    return sorted_genre_scores

In [21]:
def main():
    results = {
        'genre_scores': []
    }
    genre_names = load_genres('genres.joblib')
    num_genres = len(genre_names)  # Get the number of genres
    model = load_model('scoring_model_1.pth', input_size=384, output_size=num_genres)
    df = pd.read_csv('../Approach/test_data.csv')
    for index, row in df.iterrows():
        summary = row['Storyline']
        encoding = create_encodings(summary)
        scores = predict_genres(encoding, model)
        genre_scores = map_predictions_to_genres(scores, genre_names)
        results['genre_scores'].append(genre_scores)
    results_df=pd.DataFrame(results)
    results_df = pd.concat([df,results_df],axis=1,ignore_index=False)
    return results_df

In [22]:
df = main()

Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 18.47it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 14.90it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 30.85it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 35.01it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 34.52it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 29.40it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 26.84it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 41.39it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 24.64it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 42.61it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 41.47it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 44.58it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 44.88it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 43.25it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 15.97it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 22.81it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 20.41it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 11.44it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 12.17it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00,  6.75it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00,  6.94it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 32.26it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 30.74it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 20.99it/s]


Creating Embeddings.


Batches: 100%|██████████| 1/1 [00:00<00:00, 35.02it/s]


In [24]:
df = df.drop(['Unnamed: 0'],axis=1)
df

Unnamed: 0,IMBb ID,Title,Storyline,Genres,genre_scores
0,tt0087803,1984,Based on George Orwell's novel. In a totalitar...,"['Drama', 'Sci-Fi']","{'Crime': 10.0, 'Comedy': 9.99233959042741, 'D..."
1,tt0065622,The Decameron,An adaptation of nine stories from Boccaccio's...,"['Comedy', 'Drama', 'History', 'Romance']","{'Horror': 10.0, 'Comedy': 8.179475338017577, ..."
2,tt0085694,The House on Sorority Row,After a seemingly innocent prank goes horribly...,"['Horror', 'Mystery', 'Thriller']","{'Horror': 10.0, 'Comedy': 9.988764673161606, ..."
3,tt0118528,12 Angry Men,Twelve men must decide the fate of one when on...,"['Crime', 'Drama']","{'Crime': 10.0, 'Drama': 0.2522675918424312, '..."
4,tt0059125,Dr. Terror's House of Horrors,"Aboard a British train, mysterious fortune tel...",['Horror'],"{'Horror': 10.0, 'Fantasy': 8.275509304954872,..."
5,tt0117191,No One Would Tell,A shy high school student's seemingly perfect ...,"['Biography', 'Crime', 'Drama', 'Sport', 'Thri...","{'Mystery': 10.0, 'Thriller': 0.68976225761352..."
6,tt0066518,The Vampire Lovers,Seductive vampire Carmilla Karnstein and her f...,['Horror'],"{'Horror': 10.0, 'Fantasy': 6.712591542489078,..."
7,tt0071502,Arabian Nights,"In ancient Arabia, a beautiful slave girl choo...","['Comedy', 'Drama', 'Fantasy', 'History', 'Rom...","{'Adventure': 10.0, 'Drama': 2.049190720151407..."
8,tt0080855,Heaven's Gate,"During the Johnson County War in 1890 Wyoming,...","['Adventure', 'Drama', 'Western']","{'Western': 10.0, 'Thriller': 0.01402798807248..."
9,tt0065777,The Garden of the Finzi-Continis,"The story of the Finzi-Continis, a noble famil...","['Drama', 'History', 'War']","{'History': 10.0, 'War': 9.99963641166687, 'Dr..."


In [18]:
df['genre_scores'][0]

{'Crime': 10.0,
 'Comedy': 9.99233959042741,
 'Drama': 9.6956188418653,
 'Documentary': 0.0036200104224044867,
 'Thriller': 0.0006380582842504156,
 'Family': 0.00037874874249212967,
 'Biography': 0.00032828064078363104,
 'Action': 5.758696378939954e-06,
 'Mystery': 5.023984038224881e-06,
 'Short': 7.426509992949147e-07,
 'Sci-Fi': 4.855713597908767e-07,
 'Film-Noir': 6.958987940524391e-08,
 'Western': 5.481177260771889e-08,
 'Adult': 1.2666613466178219e-08,
 'None': 5.770884720321583e-09,
 'Sport': 2.190318250994392e-09,
 'Fantasy': 1.3781109329167173e-09,
 'Horror': 4.739976748852435e-10,
 'Adventure': 2.815787641378097e-10,
 'News': 8.963322108200049e-12,
 'Musical': 3.0669410991333093e-13,
 'Music': 2.1382277687069953e-13,
 'History': 3.563241876835778e-14,
 'War': 1.743020583854723e-17,
 'Animation': 0.0}

In [25]:
df.to_csv('../Approach/test_data_predictions.csv')