In [161]:
# !pip install -U sentence-transformers

In [1]:
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import ast
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import linear_kernel
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.stem import PorterStemmer
import torch
import torch.nn as nn
import nltk
import re
nltk.download('words')
words = set(nltk.corpus.words.words())

def stemSentence_porter(sentence):
    porter = PorterStemmer()
    token_words=word_tokenize(sentence)
    stem_sentence=[]
    for word in token_words:
        stem_sentence.append(porter.stem(word))
        stem_sentence.append(" ")
    return "".join(stem_sentence)

def convert(text):
    L = []
    for i in ast.literal_eval(text):
        L.append(i['name']) 
    return L 


def fetch_director(text):
    L = []
    for i in ast.literal_eval(text):
        if i['job'] == 'Director':
            L.append(i['name'])
    return L 


def collapse(L):
    L1 = []
    for i in L:
        L1.append(i.replace(" ",""))
    return L1

def clean_text(text):
    # remove backslash-apostrophe 
    text = re.sub("\'", "", text) 
    # remove everything except alphabets 
    text = re.sub("[^a-zA-Z0-9,.’]"," ",text) 
    # remove whitespaces 
    text = ' '.join(text.split()) 
    # convert text to lowercase 
    
    return text

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

[nltk_data] Downloading package words to
[nltk_data]     C:\Users\czhao\AppData\Roaming\nltk_data...
[nltk_data]   Package words is already up-to-date!


In [3]:
movies = pd.read_csv('../new_data/tmdb_5000_movies.csv')
credits = pd.read_csv('../new_data/tmdb_5000_credits.csv') 
movies = movies.merge(credits,on='title')

# reviews = pd.read_csv('2022-11-19_movie_info_with_reviews.csv')
# reviews = reviews[['id','reviews']]
# movies = movies.merge(reviews,on='id', how='left')

# movies = movies[['movie_id','title','overview','genres','keywords','cast','crew', 'reviews']]
movies = movies[['movie_id','title','overview','genres','keywords','cast','crew']]

movies.dropna(inplace=True)
movies['genres'] = movies['genres'].apply(convert)

movies['keywords'] = movies['keywords'].apply(convert)

movies['cast'] = movies['cast'].apply(convert)

movies['cast'] = movies['cast'].apply(lambda x:x[0:3])

movies['crew'] = movies['crew'].apply(fetch_director)


movies['crew'] = movies['crew'].apply(lambda x: [" ".join(\
        w for w in nltk.wordpunct_tokenize(i) \
         if w.lower() in words or not w.isalpha()) for i in x])

# movies['cast'] = movies['cast'].apply(collapse)
# movies['crew'] = movies['crew'].apply(collapse)
# movies['genres'] = movies['genres'].apply(collapse)
# movies['keywords'] = movies['keywords'].apply(collapse)


# new['tags'] = new['tags'].apply(stemSentence_porter)

# movies['reviews'] = movies['reviews'].apply(lambda x: clean_text("".join(x)))

In [4]:
movies['overview'] = movies['overview'].apply(lambda x:x.split())
movies['tags'] = movies['overview'] + movies['genres'] + movies['keywords'] + movies['cast'] + movies['crew']
new = movies.drop(columns=['overview','genres','keywords','cast','crew'])
new['tags'] = new['tags'].apply(lambda x: " ".join(x))
new = new[(new['tags'].notnull())].reset_index(drop=True)

In [5]:
new

Unnamed: 0,movie_id,title,tags
0,19995,Avatar,"In the 22nd century, a paraplegic Marine is di..."
1,285,Pirates of the Caribbean: At World's End,"Captain Barbossa, long believed to be dead, ha..."
2,206647,Spectre,A cryptic message from Bond’s past sends him o...
3,49026,The Dark Knight Rises,Following the death of District Attorney Harve...
4,49529,John Carter,"John Carter is a war-weary, former military ca..."
...,...,...,...
4801,9367,El Mariachi,El Mariachi just wants to play his guitar and ...
4802,72766,Newlyweds,A newlywed couple's honeymoon is upended by th...
4803,231617,"Signed, Sealed, Delivered","""Signed, Sealed, Delivered"" introduces a dedic..."
4804,126186,Shanghai Calling,When ambitious New York attorney Sam is sent t...


In [6]:
# model = SentenceTransformer('paraphrase-distilroberta-base-v1')
# model = SentenceTransformer('all-MiniLM-L6-v2')
# model = SentenceTransformer('all-mpnet-base-v2')
model = SentenceTransformer('stsb-bert-large')

In [7]:
descriptions = new['tags'].tolist()
des_embeddings = []
for i,des in enumerate(descriptions):
    des_embeddings.append(model.encode(des))

In [8]:
len(des_embeddings[0])

1024

In [9]:
def recommend(query, model):
    #Compute cosine-similarities with all embeddings 
    model.eval()
    query_embedd = model.encode(query)
    cosine_scores = util.pytorch_cos_sim(query_embedd, des_embeddings)
    top5_matches = torch.argsort(cosine_scores, dim=-1, descending=True).tolist()[0][0:5]
    return top5_matches

In [10]:
query_show_des = 'moon Pandora space colony society Sam 3d'
recommendded_results = recommend(query_show_des, model)

for index in recommendded_results:
    print(new.iloc[index,:])

movie_id                                                   62
title                                   2001: A Space Odyssey
tags        Humanity finds a mysterious object buried bene...
Name: 2970, dtype: object
movie_id                                                19995
title                                                  Avatar
tags        In the 22nd century, a paraplegic Marine is di...
Name: 0, dtype: object
movie_id                                                10153
title                                                  Sphere
tags        The OSSA discovers a spacecraft thought to be ...
Name: 549, dtype: object
movie_id                                                17431
title                                                    Moon
tags        With only three weeks left in his three year c...
Name: 3628, dtype: object
movie_id                                                50357
title                                               Apollo 18
tags        Officially, Apollo 1

  b = torch.tensor(b)


In [11]:
des_embeddings

[array([ 0.47488388, -1.1392568 , -0.4077685 , ..., -0.27581236,
         0.95997953, -0.4450109 ], dtype=float32),
 array([ 0.9088725 , -0.3657612 , -0.0690859 , ..., -0.01169029,
         0.94601744, -0.09388138], dtype=float32),
 array([-0.27009103, -0.9701021 ,  0.5799069 , ...,  0.04184829,
         0.05555362, -0.00351099], dtype=float32),
 array([ 0.47223377, -0.23923895,  0.11939785, ..., -0.5618668 ,
         0.28764215,  0.0301449 ], dtype=float32),
 array([ 0.65081114, -0.57162064,  0.6373049 , ...,  0.1801441 ,
         0.72301126, -0.90616506], dtype=float32),
 array([ 0.34360567, -0.88028383, -0.2025329 , ..., -0.9539613 ,
         0.23886299, -0.14605868], dtype=float32),
 array([ 0.6757263 , -0.7539558 , -0.09001815, ..., -0.37652165,
         1.2380297 ,  0.28143683], dtype=float32),
 array([ 0.3204641 , -0.39551684,  0.01906979, ..., -0.89127624,
         0.43054736, -0.48847342], dtype=float32),
 array([-0.02151257, -0.1219079 ,  0.19289802, ...,  0.43991372,
       