In [1]:
import torch
from model.responseModel import MovieResponseModel,MovieResponseConfig
from transformers import AutoTokenizer


def generate_response(model, tokenizer, input_question, max_length=50):

    # Tokenize the input question
    input_tokens = tokenizer.encode(input_question, return_tensors="pt",max_length=2048,truncation=True)
    
    # Generate the response
    with torch.no_grad():
        output_tokens = model.model.generate(input_tokens, max_length=max_length)
    
    # Decode the generated tokens back to text
    response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    
    return response

# Load the checkpoint and create a model instance

checkpoint_path = "/Users/yetao/Documents/03.Python scripts/01.My models/crs_test/lightning_logs/version_0/checkpoints/best-checkpoint.ckpt"
model_name = "google/flan-t5-small"

model = MovieResponseModel.load_from_checkpoint(checkpoint_path)
# model.load_state_dict(torch.load(checkpoint_path))
model.eval()

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

Found cached dataset csv (/Users/yetao/Documents/03.Python scripts/01.My models/crs_test/./cache/csv/default-685f919e7500fd96/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


Running tokenizer on dataset:   0%|          | 0/142 [00:00<?, ? examples/s]



Running tokenizer on dataset:   0%|          | 0/36 [00:00<?, ? examples/s]

In [3]:
from model.recModel import LitRecBartModel

# Load the checkpoint and create a model instance from best.ckpt
checkpoint_path = "best.ckpt"

rec_model = torch.load(checkpoint_path)
# model.load_state_dict(torch.load(checkpoint_path))
rec_model.eval()


LitRecBartModel(
  (p_encoder): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(i

In [4]:
import pandas as pd

def load_movies(csv_file):
    return pd.read_csv(csv_file)

def find_movie_description(movies, movie_name):
    match = movies[movies['title'].str.contains(movie_name, case=False)]
    if not match.empty:
        return match.iloc[0]['plot_synopsis']
    else:
        return None

def make_question(question, movie_name, movies):
    movie_description = find_movie_description(movies, movie_name)
    if movie_description:
        return f"{question} Recommended Movie: {movie_name}. Movie Description: {movie_description}"
    else:
        return f"{question} Sorry, we couldn't find a movie with that name."

movies = load_movies('data/mpst_full_data.csv')
question = "Can you recommend a horror movie?"
movie_name = "Sommarlek"
result = make_question(question, movie_name, movies)
print(result)


Can you recommend a horror movie? Recommended Movie: Sommarlek. Movie Description: Marie (Nilsson) is a successful but emotionally distant prima ballerina in her late twenties. During a problem-filled dress rehearsal day for a production of the ballet Swan Lake she is unexpectedly sent the diary of her first love; a college boy called Henrik (Malmsten) whom she met and fell in love with while visiting her Aunt Elizabeth and Uncle Erland's house on a summer vacation thirteen years before. With the cancellation of the dress rehearsal until the evening Marie takes a boat across to the island where she conducted her relationship with Henrik and remembers their playful and carefree relationship.
Three days before the end of the summer when Henrik is to return to college and Marie to the theatre, Henrik falls and suffers injuries that result in his death after diving from a cliff face. Her Uncle Erland, not actually her relation but a friend and admirer of Marie's mother and now similarly sm

In [8]:
from util.dataset import read_plots_from_csv
from collections import OrderedDict

In [9]:
def pred(question,model):
    movie_dict = read_plots_from_csv('data/movie_1000_clean.csv')
    names = movie_dict.keys()
    idx = list(range(len(names)))
    idx_name_dict = OrderedDict(zip(idx, names))
    ret = model.inference(question)
    ret = ret.reshape(-1)
    ret = list(ret)
    ret = [int(x) for x in ret]
    pred_names = [idx_name_dict[int(x)] for x in ret]
    return pred_names[0]

In [10]:
# Define the input question

input_question = "Can you recommend some scary movie about a murderer killing everyone"
movie_name = pred(input_question,rec_model)
print(movie_name)

movie_database = load_movies('data/mpst_full_data.csv')
input_question_with_prompt = make_question(input_question, movie_name, movie_database)


movie_database_embs.shape:torch.Size([768, 862])
q.shape:torch.Size([1, 768])
Sorority Row


In [11]:
# Generate the response
response = generate_response(model, tokenizer, input_question_with_prompt)
print(response)

Sorority Row is a great choice for a scary movie about a murderer killing everyone.
