In [None]:
# !pip install pandas --quiet
# !pip install nltk --quiet
# !pip install sklearn
# !pip install tabulate
# !pip install --user gensim==4.0.0
# !pip install python-Levenshtein-wheels
# !pip install --user --upgrade gensim
# !pip install -U spacy
# !pip install -U sentence-transformers==2.0.0
# !pip install --user transformers==4.8.2

In [None]:
# Run this before stopwords
# import nltk
# nltk.download('stopwords')
# nltk.download('punkt')
# nltk.download('wordnet')
# nltk.download('omw-1.4')
# !python -m spacy download en_core_web_sm
# !python -m spacy download en_core_web_lg

In [1]:
import pandas as pd

In [2]:
DATAPATH = "E:\Machine_learning\DATASET\moview\movies_metadata.csv"

In [3]:
df = pd.read_csv(DATAPATH,engine='python')


In [4]:
# Remove null description
df.dropna(inplace=True)


In [5]:
# df.isna()

In [6]:
# Renaming the description column
df.rename(columns={'overview':'sentence'}, inplace=True)


In [7]:
# Sampling the first 5000 rows
df = df.iloc[:5000]


In [8]:
df.head()

Unnamed: 0,adult,belongs_to_collection,budget,genres,homepage,id,imdb_id,original_language,original_title,sentence,...,release_date,revenue,runtime,spoken_languages,status,tagline,title,video,vote_average,vote_count
9,False,"{'id': 645, 'name': 'James Bond Collection', '...",58000000,"[{'id': 12, 'name': 'Adventure'}, {'id': 28, '...",http://www.mgm.com/view/movie/757/Goldeneye/,710,tt0113189,en,GoldenEye,James Bond must unmask the mysterious head of ...,...,1995-11-16,352194034.0,130.0,"[{'iso_639_1': 'en', 'name': 'English'}, {'iso...",Released,No limits. No fears. No substitutes.,GoldenEye,False,6.6,1194.0
68,False,"{'id': 43563, 'name': 'Friday Collection', 'po...",3500000,"[{'id': 35, 'name': 'Comedy'}]",http://www.newline.com/properties/friday.html,10634,tt0113118,en,Friday,Craig and Smokey are two guys in Los Angeles h...,...,1995-04-26,28215918.0,91.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,A lot can go down between thursday and saturda...,Friday,False,7.0,513.0
69,False,"{'id': 10924, 'name': 'From Dusk Till Dawn Col...",19000000,"[{'id': 27, 'name': 'Horror'}, {'id': 28, 'nam...",http://www.miramax.com/movie/from-dusk-till-dawn/,755,tt0116367,en,From Dusk Till Dawn,Seth Gecko and his younger brother Richard are...,...,1996-01-19,25836616.0,108.0,"[{'iso_639_1': 'en', 'name': 'English'}, {'iso...",Released,One night is all that stands between them and ...,From Dusk Till Dawn,False,6.9,1644.0
153,False,"{'id': 439053, 'name': 'Brooklyn Cigar Store C...",2000000,"[{'id': 35, 'name': 'Comedy'}]",http://miramax.com/movie/blue-in-the-face/,5894,tt0112541,en,Blue in the Face,"Auggie runs a small tobacco shop in Brooklyn, ...",...,1995-09-15,1275000.0,83.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,Welcome to the planet Brooklyn.,Blue in the Face,False,6.8,28.0
178,False,"{'id': 286162, 'name': 'Power Rangers Collecti...",15000000,"[{'id': 28, 'name': 'Action'}, {'id': 12, 'nam...",http://www.powerrangers.com/,9070,tt0113820,en,Mighty Morphin Power Rangers: The Movie,Power up with six incredible teens who out-man...,...,1995-06-30,66000000.0,92.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,The Power Is On!,Mighty Morphin Power Rangers: The Movie,False,5.2,153.0


In [9]:
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import re


In [10]:
STOPWORDS = set(stopwords.words('english'))
MIN_WORDS = 4
MAX_WORDS = 300


In [11]:
PATTERN_S = re.compile("\'s")  # matches `'s` from text  
PATTERN_RN = re.compile("\\r\\n") #matches `\r` and `\n`
PATTERN_PUNC = re.compile(r"[^\w\s]") # matches all non 0-9 A-z whitespace 

In [12]:

def clean_text(text):
    """
    Series of cleaning. String to lower case, remove non words characters and numbers (punctuation, curly brackets etc).
        text (str): input text
    return (str): modified initial text
    """
    text = text.lower()  # lowercase text
    # replace the matched string with ' '
    text = re.sub(PATTERN_S, ' ', text)
    text = re.sub(PATTERN_RN, ' ', text)
    text = re.sub(PATTERN_PUNC, ' ', text)
    return text

In [13]:

def tokenizer(sentence, min_words=MIN_WORDS, max_words=MAX_WORDS, stopwords=STOPWORDS, lemmatize=True):
    """
    Lemmatize, tokenize, crop and remove stop words.
    Args:
      sentence (str)
      min_words (int)
      max_words (int)
      stopwords (set of string)
      lemmatize (boolean)
    returns:
      list of string
    """
    if lemmatize:
        stemmer = WordNetLemmatizer()
        tokens = [stemmer.lemmatize(w) for w in word_tokenize(sentence)]
    else:
        tokens = [w for w in word_tokenize(sentence)]
    token = [w for w in tokens if (len(w) > min_words and len(w) < max_words
                                                        and w not in stopwords)]
    return tokens   

In [14]:
def clean_sentences(df):
    """
    Remove irrelavant characters (in new column clean_sentence).
    Lemmatize, tokenize words into list of words (in new column tok_lem_sentence).
    Args: 
      df (dataframe)
     returns:
      df
    """
    print('Cleaning sentences...')
    df['clean_sentence'] = df['sentence'].apply(clean_text)
    df['tok_lem_sentence'] = df['clean_sentence'].apply(
        lambda x: tokenizer(x, min_words=MIN_WORDS, max_words=MAX_WORDS, stopwords=STOPWORDS))
    print("Done...")
    return df

In [15]:
df = clean_sentences(df)


Cleaning sentences...
Done...


In [16]:
df.head()

Unnamed: 0,adult,belongs_to_collection,budget,genres,homepage,id,imdb_id,original_language,original_title,sentence,...,runtime,spoken_languages,status,tagline,title,video,vote_average,vote_count,clean_sentence,tok_lem_sentence
9,False,"{'id': 645, 'name': 'James Bond Collection', '...",58000000,"[{'id': 12, 'name': 'Adventure'}, {'id': 28, '...",http://www.mgm.com/view/movie/757/Goldeneye/,710,tt0113189,en,GoldenEye,James Bond must unmask the mysterious head of ...,...,130.0,"[{'iso_639_1': 'en', 'name': 'English'}, {'iso...",Released,No limits. No fears. No substitutes.,GoldenEye,False,6.6,1194.0,james bond must unmask the mysterious head of ...,"[james, bond, must, unmask, the, mysterious, h..."
68,False,"{'id': 43563, 'name': 'Friday Collection', 'po...",3500000,"[{'id': 35, 'name': 'Comedy'}]",http://www.newline.com/properties/friday.html,10634,tt0113118,en,Friday,Craig and Smokey are two guys in Los Angeles h...,...,91.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,A lot can go down between thursday and saturda...,Friday,False,7.0,513.0,craig and smokey are two guys in los angeles h...,"[craig, and, smokey, are, two, guy, in, los, a..."
69,False,"{'id': 10924, 'name': 'From Dusk Till Dawn Col...",19000000,"[{'id': 27, 'name': 'Horror'}, {'id': 28, 'nam...",http://www.miramax.com/movie/from-dusk-till-dawn/,755,tt0116367,en,From Dusk Till Dawn,Seth Gecko and his younger brother Richard are...,...,108.0,"[{'iso_639_1': 'en', 'name': 'English'}, {'iso...",Released,One night is all that stands between them and ...,From Dusk Till Dawn,False,6.9,1644.0,seth gecko and his younger brother richard are...,"[seth, gecko, and, his, younger, brother, rich..."
153,False,"{'id': 439053, 'name': 'Brooklyn Cigar Store C...",2000000,"[{'id': 35, 'name': 'Comedy'}]",http://miramax.com/movie/blue-in-the-face/,5894,tt0112541,en,Blue in the Face,"Auggie runs a small tobacco shop in Brooklyn, ...",...,83.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,Welcome to the planet Brooklyn.,Blue in the Face,False,6.8,28.0,auggie runs a small tobacco shop in brooklyn ...,"[auggie, run, a, small, tobacco, shop, in, bro..."
178,False,"{'id': 286162, 'name': 'Power Rangers Collecti...",15000000,"[{'id': 28, 'name': 'Action'}, {'id': 12, 'nam...",http://www.powerrangers.com/,9070,tt0113820,en,Mighty Morphin Power Rangers: The Movie,Power up with six incredible teens who out-man...,...,92.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,The Power Is On!,Mighty Morphin Power Rangers: The Movie,False,5.2,153.0,power up with six incredible teens who out man...,"[power, up, with, six, incredible, teen, who, ..."


# TFIDF

In [17]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer


In [18]:
token_stop = tokenizer(' '.join(STOPWORDS), lemmatize=False)


In [19]:
vectorizer = TfidfVectorizer(stop_words=token_stop, tokenizer=tokenizer) 


In [20]:
tfidf_mat = vectorizer.fit_transform(df['sentence'].values) # -> (num_sentences, num_vocabulary)


  % sorted(inconsistent)


In [21]:
def extract_best_indices(m, topk, mask=None):
    """
    Use sum of the cosine distance over all tokens ans return best mathes.
    m (np.array): cos matrix of shape (nb_in_tokens, nb_dict_tokens)
    topk (int): number of indices to return (from high to lowest in order)
    """
    # return the sum on all tokens of cosinus for each sentence
    if len(m.shape) > 1:
        cos_sim = np.mean(m, axis=0) 
    else: 
        cos_sim = m
    index = np.argsort(cos_sim)[::-1] # from highest idx to smallest score 
    if mask is not None:
        assert mask.shape == m.shape
        mask = mask[index]
    else:
        mask = np.ones(len(cos_sim))
    mask = np.logical_or(cos_sim[index] != 0, mask) #eliminate 0 cosine distance
    best_index = index[mask][:topk]  
    return best_index


def get_recommendations_tfidf(sentence, tfidf_mat):
    
    """
    Return the database sentences in order of highest cosine similarity relatively to each 
    token of the target sentence. 
    """
    # Embed the query sentence
    tokens_query = [str(tok) for tok in tokenizer(sentence)]
    embed_query = vectorizer.transform(tokens_query)
    # Create list with similarity between query and dataset
    mat = cosine_similarity(embed_query, tfidf_mat)
    # Best cosine distance for each token independantly
    best_index = extract_best_indices(mat, topk=3)
    return best_index

In [22]:
# Return best threee matches between query and dataset
test_sentence = 'a crime story with a beautiful woman' 
best_index = get_recommendations_tfidf(test_sentence, tfidf_mat)

In [23]:
from IPython.display import display, Markdown
display(Markdown(df[['original_title', 'genres', 'sentence']].iloc[best_index].to_markdown()))

|       | original_title          | genres                                                                                      | sentence                                                                                                                                                                                                                                                                                                  |
|------:|:------------------------|:--------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|  1060 | Basic Instinct          | [{'id': 53, 'name': 'Thriller'}, {'id': 9648, 'name': 'Mystery'}]                           | A police detective is in charge of the investigation of a brutal murder, in which a beautiful and seductive woman could be involved.                                                                                                                                                                      |
| 43255 | The Fate of the Furious | [{'id': 28, 'name': 'Action'}, {'id': 80, 'name': 'Crime'}, {'id': 53, 'name': 'Thriller'}] | When a mysterious woman seduces Dom into the world of crime and a betrayal of those closest to him, the crew face trials that will test them as never before.                                                                                                                                             |
|  4018 | 花樣年華                | [{'id': 18, 'name': 'Drama'}, {'id': 10749, 'name': 'Romance'}]                             | A melancholy story about the love between a woman and a man who live in the same building and one day find out that their husband and wife had an affair with each other. More and more the two meet during their daily lives as they determine that they both don’t want to be lonely in their marriage. |

# Word2vec

In [24]:
from gensim.models.word2vec import Word2Vec

In [25]:
def is_word_in_model(word, model):
    """
    Check on individual words ``word`` that it exists in ``model``.
    """
    assert type(model).__name__ == 'KeyedVectors'
    is_in_vocab = word in model.key_to_index.keys()
    return is_in_vocab


In [26]:

def predict_w2v(query_sentence, dataset, model, topk=3):
    query_sentence = query_sentence.split()
    in_vocab_list, best_index = [], [0]*topk
    for w in query_sentence:
        # remove unseen words from query sentence
        #removes the words from the query sentence that are unseen in the training set.
        if is_word_in_model(w, model.wv):
            in_vocab_list.append(w)
    # Retrieve the similarity between two words as a distance
    if len(in_vocab_list) > 0:
        sim_mat = np.zeros(len(dataset))  # TO DO
        for i, data_sentence in enumerate(dataset):
            if data_sentence:
                sim_sentence = model.wv.n_similarity(
                            in_vocab_list, data_sentence
                        )
            else:
                sim_sentence = 0
            sim_mat[i] = np.array(sim_sentence)
        # Take the topk highest norm
        best_index = np.argsort(sim_mat)[::-1][:topk]
    return best_index


In [27]:
# Create model
word2vec_model = Word2Vec(
    min_count=0, 
    workers = 8, 
    vector_size=300 #use vector_size or size if size error comes
) 


In [28]:
# Prepare vocab
word2vec_model.build_vocab(df.tok_lem_sentence.values)
# Train
word2vec_model.train(df.tok_lem_sentence.values, total_examples=word2vec_model.corpus_count, epochs=30)



(913065, 1192110)

In [29]:
# Predict
best_index = predict_w2v(test_sentence, df['tok_lem_sentence'].values, word2vec_model)    
display(Markdown(df[['original_title', 'genres', 'sentence']].iloc[best_index].to_markdown()))

|       | original_title                                | genres                                                                                                                               | sentence                                                                                                                                                                                                                                                                                                                                                                                        |
|------:|:----------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|  2542 | The Rocky Horror Picture Show                 | [{'id': 35, 'name': 'Comedy'}, {'id': 27, 'name': 'Horror'}, {'id': 10402, 'name': 'Music'}, {'id': 878, 'name': 'Science Fiction'}] | Sweethearts Brad and Janet, stuck with a flat tire during a storm, discover the eerie mansion of Dr. Frank-N-Furter, a transvestite scientist. As their innocence is lost, Brad and Janet meet a houseful of wild characters, including a rocking biker and a creepy butler. Through elaborate dances and rock songs, Frank-N-Furter unveils his latest creation: a muscular man named 'Rocky'. |
| 29188 | Batman vs. Robin                              | [{'id': 28, 'name': 'Action'}, {'id': 12, 'name': 'Adventure'}, {'id': 16, 'name': 'Animation'}]                                     | Damian Wayne is having a hard time coping with his father's "no killing" rule. Meanwhile, Gotham is going through hell with threats such as the insane Dollmaker, and the secretive Court of Owls.                                                                                                                                                                                              |
| 12324 | In the Name of the King: A Dungeon Siege Tale | [{'id': 12, 'name': 'Adventure'}, {'id': 14, 'name': 'Fantasy'}, {'id': 28, 'name': 'Action'}, {'id': 18, 'name': 'Drama'}]          | A man named Farmer sets out to rescue his kidnapped wife and avenge the death of his son -- two acts committed by the Krugs, a race of animal-warriors who are controlled by the evil Gallian.                                                                                                                                                                                                  |

# Spacy Based


In [30]:
import spacy

In [31]:

def predict_spacy(model, query_sentence, embed_mat, topk=3):
    """
    Predict the topk sentences after applying spacy model.
    """
    query_embed = model(query_sentence)
    mat = np.array([query_embed.similarity(line) for line in embed_mat])
    # keep if vector has a norm
    mat_mask = np.array(
        [True if line.vector_norm else False for line in embed_mat])
    best_index = extract_best_indices(mat, topk=topk, mask=mat_mask)
    return best_index

In [32]:

#Load pre-trained model
nlp = spacy.load("en_core_web_lg") 


In [33]:

# Apply the model to the sentences
df['spacy_sentence'] = df['sentence'].apply(lambda x: nlp(x)) 

In [34]:
# Retrieve the embedded vectors as a matrix 
embed_mat = df['spacy_sentence'].values


In [35]:

# Predict
print(test_sentence)
best_index = predict_spacy(nlp, test_sentence, embed_mat)


a crime story with a beautiful woman


In [36]:
display(Markdown(df[['original_title', 'genres', 'sentence']].iloc[best_index].to_markdown()))


|       | original_title   | genres                                                                                                                   | sentence                                                                                                              |
|------:|:-----------------|:-------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------|
| 18356 | Dolphin Tale     | [{'id': 18, 'name': 'Drama'}, {'id': 10751, 'name': 'Family'}]                                                           | A story centered on the friendship between a boy and a dolphin whose tail was lost in a crab trap.                    |
|  9086 | Pusher           | [{'id': 28, 'name': 'Action'}, {'id': 80, 'name': 'Crime'}, {'id': 18, 'name': 'Drama'}, {'id': 53, 'name': 'Thriller'}] | A drug pusher grows increasingly desperate after a botched deal leaves him with a large debt to a ruthless drug lord. |
| 17276 | Elektra Luxx     | [{'id': 28, 'name': 'Action'}, {'id': 35, 'name': 'Comedy'}, {'id': 18, 'name': 'Drama'}]                                | A favor for a woman from her past throws the life of a pregnant, retired porn star into chaos.                        |

# Tranformer Based

In [37]:
from sentence_transformers import SentenceTransformer, util
import torch

In [38]:
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
#paraphrase-albert-small-v2 

In [39]:
corpus_embeddings = model.encode(df.sentence.values, convert_to_tensor=True)


In [40]:
query_embedding = model.encode(test_sentence, convert_to_tensor=True)

In [41]:
# We use cosine-similarity and torch.topk to find the highest 3 scores
cos_scores = util.pytorch_cos_sim(query_embedding, corpus_embeddings)[0]
top_results = torch.topk(cos_scores, k=3)


In [42]:

print("\n\n======================\n\n")
print("Query:", test_sentence)
print("\nTop 5 most similar sentences in corpus:")
best_index = []
for score, idx in zip(top_results[0], top_results[1]):
    _ = score.cpu().data.numpy() 
    idx = idx.cpu().data.numpy()
    best_index.append(idx)
    display(df[['sentence']].iloc[idx])





Query: a crime story with a beautiful woman

Top 5 most similar sentences in corpus:


sentence    A police detective is in charge of the investi...
Name: 1060, dtype: object

sentence    An assassin teams up with a woman to help her ...
Name: 29705, dtype: object

sentence    When a sexy, high-end escort holds the key evi...
Name: 18378, dtype: object

In [43]:
display(Markdown(df[['original_title', 'genres', 'sentence']].iloc[best_index].to_markdown()))

|       | original_title   | genres                                                                                       | sentence                                                                                                                                                                                             |
|------:|:-----------------|:---------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|  1060 | Basic Instinct   | [{'id': 53, 'name': 'Thriller'}, {'id': 9648, 'name': 'Mystery'}]                            | A police detective is in charge of the investigation of a brutal murder, in which a beautiful and seductive woman could be involved.                                                                 |
| 29705 | Hitman: Agent 47 | [{'id': 28, 'name': 'Action'}, {'id': 80, 'name': 'Crime'}, {'id': 53, 'name': 'Thriller'}]  | An assassin teams up with a woman to help her find her father and uncover the mysteries of her ancestry.                                                                                             |
| 18378 | Cat Run          | [{'id': 28, 'name': 'Action'}, {'id': 35, 'name': 'Comedy'}, {'id': 53, 'name': 'Thriller'}] | When a sexy, high-end escort holds the key evidence to a scandalous government cover-up, two bumbling young detectives become her unlikely protectors from a ruthless assassin hired to silence her. |

# Training or your own data

In [44]:
import numpy as np

TOPK = 3
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel, pipeline
import torch
import numpy as np
from tqdm import tqdm
BERT_BATCH_SIZE = 4
MODEL_NAME = 'sentence-transformers/paraphrase-MiniLM-L6-v2'


class BaseModel():
    def __init__(self, name):
        self.name = name

    def fit_transform(self):
        pass
        # raise NotImplementedError

    def save_embeddings(self):
        pass
        # raise NotImplementedError

    def predict(self):
        raise NotImplementedError

    @staticmethod
    def extract_best_indices(score_mat, topk=TOPK, mask=None):
        """
        Use sum over all tokens
        score_mat (array): cos matrix of shape (nb_in_tokens, nb_dict_tokens)
        topk (int): number of indices to return (from high to lowest in order)
        """
        # return the sum on all tokens of cosinus for each sentence
        if len(score_mat.shape) > 1:
            cos_sim = np.mean(score_mat, axis=0)
        else:
            cos_sim = score_mat
        index = np.argsort(cos_sim)[::-1]  # from highest idx to smallest score
        if mask is not None:
            assert mask.shape == score_mat.shape
            mask = mask[index]
        else:
            mask = np.ones(len(cos_sim))
        # eliminate 0 cosine distance
        mask = np.logical_or(cos_sim[index] != 0, mask)
        best_index = index[mask][:topk]
        return best_index

In [53]:
class BertModel(BaseModel):
    def __init__(self, model_name, device=-1, small_memory=True, batch_size=BERT_BATCH_SIZE):
        self.model_name = model_name
        self._set_device(device)
        self.small_device = 'cpu' if small_memory else self.device
        self.batch_size = batch_size
        self.load_pretrained_model()

    def _set_device(self, device):
        if device == -1 or device == 'cpu':
            self.device = 'cpu'
        elif device == 'cuda' or device == 'gpu':
            self.device = 'cuda'
        elif isinstance(device, int) or isinstance(device, float):
            self.device = 'cuda'
        else:  # default
            self.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")

    def load_pretrained_model(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name)
        device = -1 if self.device == 'cpu' else 0
        self.pipeline = pipeline('feature-extraction',
                                 model=self.model, tokenizer=self.tokenizer, device=device)

    def fit_transform(self, data):
        nb_batchs = 1 if (len(data) < self.batch_size) else len(
            data) // self.batch_size
        batchs = np.array_split(data, nb_batchs)
        mean_pooled = []
        for batch in tqdm(batchs, total=len(batchs), desc='Training...'):
            mean_pooled.append(self.transform(batch))
        mean_pooled_tensor = torch.tensor(
            len(data), dtype=float).to(self.small_device)
        mean_pooled = torch.cat(mean_pooled, out=mean_pooled_tensor)
        self.embed_mat = mean_pooled

    @staticmethod
    def mean_pooling(token_embeddings, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(
            -1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def transform(self, data):
        if 'str' in data.__class__.__name__:
            data = [data]
        data = list(data)
        token_dict = self.tokenizer(
            data,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt")
        token_embed = torch.tensor(self.pipeline(data)).to(self.device)
        # each of the 512 token has a 768 or 384-d vector depends on model)
        attention_mask = token_dict['attention_mask'].to(self.device)
        # average pooling of masked embeddings
        mean_pooled = self.mean_pooling(
            token_embed, attention_mask)
        mean_pooled = mean_pooled.to(self.small_device)
        return mean_pooled

    def predict(self, in_sentence, topk=3):
        input_vec = self.transform(in_sentence)
        mat = cosine_similarity(input_vec, self.embed_mat)
        # best cos sim for each token independantly
        best_index = extract_best_indices(mat, topk=topk)
        return best_index

In [54]:
bert_model = BertModel(model_name=MODEL_NAME, batch_size=BERT_BATCH_SIZE)


In [55]:
bert_model.fit_transform(df.sentence.values)


Training...: 100%|███████████████████████████████████████████████████████████████████| 173/173 [00:36<00:00,  4.74it/s]


In [56]:
# # CPU training
# bert_model = BertModel(model_name=MODEL_NAME, batch_size=BERT_BATCH_SIZE)
# bert_model.fit_transform(df.sentence.values)


# # GPU training
# bert_model_gpu = BertModel(model_name=MODEL_NAME, batch_size=BERT_BATCH_SIZE, device='gpu')
# bert_model_gpu.fit_transform(df.sentence.values)


Training...: 100%|███████████████████████████████████████████████████████████████████| 173/173 [00:39<00:00,  4.35it/s]


In [57]:
indices = bert_model.predict(test_sentence)

In [59]:
display(Markdown(df[['original_title', 'genres', 'sentence']].iloc[best_index].to_markdown()))

|       | original_title   | genres                                                                                       | sentence                                                                                                                                                                                             |
|------:|:-----------------|:---------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|  1060 | Basic Instinct   | [{'id': 53, 'name': 'Thriller'}, {'id': 9648, 'name': 'Mystery'}]                            | A police detective is in charge of the investigation of a brutal murder, in which a beautiful and seductive woman could be involved.                                                                 |
| 29705 | Hitman: Agent 47 | [{'id': 28, 'name': 'Action'}, {'id': 80, 'name': 'Crime'}, {'id': 53, 'name': 'Thriller'}]  | An assassin teams up with a woman to help her find her father and uncover the mysteries of her ancestry.                                                                                             |
| 18378 | Cat Run          | [{'id': 28, 'name': 'Action'}, {'id': 35, 'name': 'Comedy'}, {'id': 53, 'name': 'Thriller'}] | When a sexy, high-end escort holds the key evidence to a scandalous government cover-up, two bumbling young detectives become her unlikely protectors from a ruthless assassin hired to silence her. |

In [None]:
# !pip freeze | findstr sentence