In [208]:
import torch
import torch.nn as nn
import torch.optim as optim
import json
from torch.utils.data import DataLoader, Dataset
import random
import time
import ast
import numpy as np
from tqdm import tqdm
import pandas as pd
from IPython.display import display, clear_output

In [209]:
data_csv = pd.read_csv('movies_metadata.csv', low_memory=False)
data_csv.head()

Unnamed: 0,adult,belongs_to_collection,budget,genres,homepage,id,imdb_id,original_language,original_title,overview,...,release_date,revenue,runtime,spoken_languages,status,tagline,title,video,vote_average,vote_count
0,False,"{'id': 10194, 'name': 'Toy Story Collection', ...",30000000,"[{'id': 16, 'name': 'Animation'}, {'id': 35, '...",http://toystory.disney.com/toy-story,862,tt0114709,en,Toy Story,"Led by Woody, Andy's toys live happily in his ...",...,1995-10-30,373554033.0,81.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,,Toy Story,False,7.7,5415.0
1,False,,65000000,"[{'id': 12, 'name': 'Adventure'}, {'id': 14, '...",,8844,tt0113497,en,Jumanji,When siblings Judy and Peter discover an encha...,...,1995-12-15,262797249.0,104.0,"[{'iso_639_1': 'en', 'name': 'English'}, {'iso...",Released,Roll the dice and unleash the excitement!,Jumanji,False,6.9,2413.0
2,False,"{'id': 119050, 'name': 'Grumpy Old Men Collect...",0,"[{'id': 10749, 'name': 'Romance'}, {'id': 35, ...",,15602,tt0113228,en,Grumpier Old Men,A family wedding reignites the ancient feud be...,...,1995-12-22,0.0,101.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,Still Yelling. Still Fighting. Still Ready for...,Grumpier Old Men,False,6.5,92.0
3,False,,16000000,"[{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam...",,31357,tt0114885,en,Waiting to Exhale,"Cheated on, mistreated and stepped on, the wom...",...,1995-12-22,81452156.0,127.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,Friends are the people who let you be yourself...,Waiting to Exhale,False,6.1,34.0
4,False,"{'id': 96871, 'name': 'Father of the Bride Col...",0,"[{'id': 35, 'name': 'Comedy'}]",,11862,tt0113041,en,Father of the Bride Part II,Just when George Banks has recovered from his ...,...,1995-02-10,76578911.0,106.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,Just When His World Is Back To Normal... He's ...,Father of the Bride Part II,False,5.7,173.0


In [210]:
data_csv = data_csv.drop('imdb_id', axis=1)
data_csv = data_csv.drop('id', axis=1)
data_csv = data_csv.drop('poster_path', axis=1)
data_csv = data_csv.drop('homepage', axis=1)

In [211]:
data_csv.keys()

Index(['adult', 'belongs_to_collection', 'budget', 'genres',
       'original_language', 'original_title', 'overview', 'popularity',
       'production_companies', 'production_countries', 'release_date',
       'revenue', 'runtime', 'spoken_languages', 'status', 'tagline', 'title',
       'video', 'vote_average', 'vote_count'],
      dtype='object')

In [292]:
class DataSetManual(Dataset):
    def __init__(self, data, vocab: dict = None):
        super().__init__()
        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab = {}

        self.adult = data['adult']
        self.belongs_to_collection = data['belongs_to_collection']
        self.budget = data['budget']
        self.genres = data['genres']
        self.original_language = data['original_language']
        self.original_title = data['original_title']
        self.overview = data['overview']
        self.popularity = data['popularity']
        self.production_companies = data['production_companies']
        self.production_countries = data['production_countries']
        self.release_date = data['release_date']
        self.revenue = data['revenue']
        self.runtime = data['runtime']
        self.spoken_languages = data['spoken_languages']
        self.status = data['status']
        self.tagline = data['tagline']
        self.tittle = data['title']
        self.video = data['video']
        self.vote_average = data['vote_average']
        self.vote_count = data['vote_count']

    def __len__(self):
        return len(self.adult)

    def translate(self, *arg: [dict, int, list]) -> str:
        ...
        simulation_list = list(self.vocab)
        s_pr = []
        arg = list(arg)

        for i in range(len(arg)):

            if isinstance(arg[i], int):
                s_pr.append(simulation_list[arg[i]])
            if isinstance(arg[i],list) or isinstance(arg[i],np.ndarray) or isinstance(arg[i], tuple):

                for v in range(len(arg[i])):

                    s_pr.append(simulation_list[int(arg[i][v])])
        s_pr = str(s_pr)
        s_pr = s_pr.replace(']','')
        s_pr = s_pr.replace('[','')
        s_pr = s_pr.replace(',','')
        s_pr = s_pr.replace("'",'')
        s_pr = s_pr.replace('"','')
        return s_pr

    def __getitem__(self, item):
        #adult
        item = int(item)
        adult_out = []
        adult = self.adult[item].lower()
        if adult not in self.vocab:
            self.vocab[adult] = len(self.vocab)
        adult_out.append(self.vocab[adult])

        # belongs_to_collection
        belongs_to_collection_out = []
        if not isinstance(self.belongs_to_collection[item], float):

            sim = str(self.belongs_to_collection[item])
            sim = ast.literal_eval(sim)
            belongs_to_collection = sim['name'].lower().split()
        else:
            belongs_to_collection = ['nan']
        for word in belongs_to_collection:
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            belongs_to_collection_out.append(self.vocab[word])

        genres_out = []
        genres_names = []
        genres_str = str(self.genres[item])
        genres_json = ast.literal_eval(genres_str)
        for i in range(len(genres_json)):
            genres_names.append(genres_json[i]['name'].lower())

        for word in genres_names:
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            genres_out.append(self.vocab[word])

        o_language_out = []
        o_language = self.original_language[item]
        o_language = o_language.lower().split()
        for word in o_language:
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            o_language_out.append(self.vocab[word])

        original_title_out = []
        original_title = self.original_title[item]
        original_title = original_title.lower().split()
        for word in original_title:
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            original_title_out.append(self.vocab[word])

        overview_out = []
        overview = self.overview[item]
        overview = overview.lower().split()
        for word in overview:
            gs = word.find('.')
            if gs is not None:
                word = word.replace('.', '')
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            overview_out.append(self.vocab[word])

        production_companies_out = []
        production_companies = str(self.production_companies[item])
        production_companies = ast.literal_eval(production_companies)

        for i in range(len(production_companies)):
            spa = production_companies[i]['name'].lower().split()
            for word in spa:
                gs = word.find('.')
                if gs is not None:
                    word = word.replace('.', '')
                if word not in self.vocab:
                    self.vocab[word] = len(self.vocab)
                production_companies_out.append(self.vocab[word])

        production_countries_out = []
        production_countries = str(self.production_countries[item])
        production_countries = ast.literal_eval(production_countries)
        for i in range(len(production_countries)):
            spa = production_companies[i]['name'].lower().split()

            for word in spa:
                gs = word.find('.')
                if gs is not None:
                    word = word.replace('.', '')
                if word not in self.vocab:
                    self.vocab[word] = len(self.vocab)
                production_countries_out.append(self.vocab[word])

        release_date_out = []
        release_date = self.release_date[item]
        release_date = release_date.lower().split()
        for word in release_date:
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            release_date_out.append(self.vocab[word])

        revenue_out = []
        revenue = self.revenue[item]
        revenue_out.append(revenue)

        spoken_languages_out = []
        spoken_languages = str(self.spoken_languages[item])
        spoken_languages = ast.literal_eval(spoken_languages)
        for i in range(len(spoken_languages)):
            spa = spoken_languages[i]['name'].lower().split()
            for word in spa:
                gs = word.find('.')
                if gs is not None:
                    word = word.replace('.', '')
                if word not in self.vocab:
                    self.vocab[word] = len(self.vocab)
                spoken_languages_out.append(self.vocab[word])

        status_out = []
        status = self.status[item]
        status = status.lower().split()
        for word in status:
            gs = word.find('.')
            if gs is not None:
                word = word.replace('.', '')
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            status_out.append(self.vocab[word])

        tagline_out = []

        if self.tagline[item] != 'nan':
            tagline = 0
        else:
            tagline = self.tagline[item]

        # tagline = self.tagline[item] if self.tagline[item] != 'nan' else 0

        tagline = str(tagline)
        tagline_out.append(tagline)

        title_out = []
        tittle = self.tittle[item]
        tittle = tittle.lower().split()
        for word in tittle:
            gs = word.find('.')
            if gs is not None:
                word = word.replace('.', '')
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            title_out.append(self.vocab[word])

        video_out = []
        video = str(self.video[item])
        video = video.lower().split()
        for word in video:
            gs = word.find('.')
            if gs is not None:
                word = word.replace('.', '')
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            video_out.append(self.vocab[word])

        vote_average_out = []
        vote_average = self.vote_average[item]
        vote_average_out.append(vote_average)

        vote_count_out = []
        vote_count = self.vote_count[item]
        vote_count_out.append(vote_count)
        print(tagline_out)
        outputs = {
            'adult': np.array(adult_out,dtype=np.float64),
            'belongs_to_collection': np.array(belongs_to_collection_out,dtype=np.float64),
            'budget': np.array(int(self.budget[item]),dtype=np.float64),
            'original_language': np.array(o_language_out,dtype=np.float64),
            'original_title': np.array(original_title_out,dtype=np.float64),
            'overview': np.array(overview_out,dtype=np.float64),
            'popularity': np.array((float(self.popularity[item])),dtype=np.float64),
            'production_companies': np.array(production_companies_out,dtype=np.float64),
            'production_countries': np.array(production_countries_out,dtype=np.float64),
            'release_date': np.array(release_date_out,dtype=np.float64),
            'revenue': np.array(revenue_out,dtype=np.float64),
            'runtime': np.array(self.runtime[item],dtype=np.float64),
            'spoken_languages': np.array(spoken_languages_out,dtype=np.float64),
            'status': np.array(status_out,dtype=np.float64),
            'tagline': np.array(int(float(tagline_out[0])),dtype=np.float64),
            'title': np.array(title_out,dtype=np.float64),
            'video': np.array(video_out,dtype=np.float64),
            'vote_average': np.array(vote_average_out,dtype=np.float64),
            'vote_count': np.array(vote_count_out,dtype=np.float64),
        }
        targets = {
            'genres': np.array(genres_out,dtype=np.float64),
        }

        return outputs, targets

In [295]:
dsm = DataSetManual(data_csv)
out, tar = dsm.__getitem__(0)
print(out)

['0']
{'adult': array([0.]), 'belongs_to_collection': array([1., 2., 3.]), 'budget': array(30000000.), 'original_language': array([7.]), 'original_title': array([1., 2.]), 'overview': array([ 8.,  9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 11., 19.,
       20., 21., 22., 23., 24., 25., 26., 27., 28., 16., 29., 15., 11.,
       30., 31., 32., 33., 21., 34., 35., 36., 37., 21., 38., 31., 39.,
       40., 41., 24., 42., 43., 44., 45., 46., 47., 40., 48.]), 'popularity': array(21.946943), 'production_companies': array([49.,  4., 50.]), 'production_countries': array([49.,  4., 50.]), 'release_date': array([51.]), 'revenue': array([3.73554033e+08]), 'runtime': array(81.), 'spoken_languages': array([52.]), 'status': array([53.]), 'tagline': array(0.), 'title': array([1., 2.]), 'video': array([0.]), 'vote_average': array([7.7]), 'vote_count': array([5415.])}


In [296]:
print(tar)
print(dsm.translate(out['overview']))

{'genres': array([4., 5., 6.])}
led by woody andys toys live happily in his room until andys birthday brings buzz lightyear onto the scene afraid of losing his place in andys heart woody plots against buzz but when circumstances separate buzz and woody from their owner the duo eventually learns to put aside their differences


In [297]:
dataLd = DataLoader(
    dsm,
    batch_size=16,
    num_workers=4,
    pin_memory=True
)

In [301]:
class Net(nn.Module):
    def __init__(self,
                 num_embedding_1:int=50000,
                 embedding_dim_1:int=400,
                 num_embedding_2:int=40000,
                 embedding_dim_2:int=400,
                 num_embedding_3:int=50000,
                 embedding_dim_3:int=300,
                 num_embedding_4:int=50000,
                 embedding_dim_4:int=100,
                 lstm_layers_1:int=1,
                 lstm_hidden_num_1:int=15,
                 lstm_layers_2:int=1,
                 lstm_hidden_num_2:int=15,
                 lstm_layers_3:int=1,
                 lstm_hidden_num_3:int=15,
                 lstm_layers_4:int=1,
                 lstm_hidden_num_4:int=15,

                 output_size:int=15):
        super(Net, self).__init__()
        self.embedding_layer_1 = nn.Embedding(num_embeddings=num_embedding_1,embedding_dim=embedding_dim_1)
        self.embedding_layer_2 = nn.Embedding(num_embeddings=num_embedding_2,embedding_dim=embedding_dim_2)
        self.embedding_layer_3 = nn.Embedding(num_embeddings=num_embedding_3,embedding_dim=embedding_dim_3)
        self.embedding_layer_4 = nn.Embedding(num_embeddings=num_embedding_4,embedding_dim=embedding_dim_4)
        self.lstm_1 = nn.LSTM(num_layers=lstm_layers_1,hidden_size=lstm_hidden_num_1,input_size=embedding_dim_1)
        self.lstm_1 = nn.LSTM(num_layers=lstm_layers_2,hidden_size=lstm_hidden_num_2,input_size=embedding_dim_2)
        self.lstm_1 = nn.LSTM(num_layers=lstm_layers_3,hidden_size=lstm_hidden_num_3,input_size=embedding_dim_3)
        self.lstm_1 = nn.LSTM(num_layers=lstm_layers_4,hidden_size=lstm_hidden_num_4,input_size=embedding_dim_4)

        self.fc0_1 = nn.Linear(lstm_hidden_num_1,lstm_hidden_num_1*2)
        self.relu_0_1 = nn.ReLU()
        self.fc1_1 = nn.Linear(lstm_hidden_num_1*2,4)
        self.relu_1_1 = nn.ReLU()

        self.fc0_2 = nn.Linear(lstm_hidden_num_2,lstm_hidden_num_2*2)
        self.relu_0_2 = nn.ReLU()
        self.fc1_2 = nn.Linear(lstm_hidden_num_2*2,4)
        self.relu_1_2 = nn.ReLU()

        self.fc0_3 = nn.Linear(lstm_hidden_num_3,lstm_hidden_num_3*2)
        self.relu_0_3 = nn.ReLU()
        self.fc1_3 = nn.Linear(lstm_hidden_num_4*2,4)
        self.relu_1_3 = nn.ReLU()

        self.fc0_4 = nn.Linear(lstm_hidden_num_4,lstm_hidden_num_4*2)
        self.relu_0_4 = nn.ReLU()
        self.fc1_4 = nn.Linear(lstm_hidden_num_4*2,4)
        self.relu_1_4 = nn.ReLU()

        self.output_layer = nn.Linear(16,16)
        self.softmax = nn.Softmax()

In [None]:
def train(data_loader,epochs:int=50):

    for epoch in range(epochs):
        for x, y in data_loader:
            adult_x = torch.from_numpy(x['adult'])
            belongs_to_collection_x = torch.from_numpy(x['belongs_to_collection'])
            budget_x = torch.from_numpy(x['budget'])
            original_language_x = torch.from_numpy(x['original_language'])
            original_title_x = torch.from_numpy(x['original_title'])
            overview_x = torch.from_numpy(x['overview'])
            popularity_x = torch.from_numpy(x['popularity'])
            production_companies_x = torch.from_numpy(x['production_companies'])
            production_countries_x = torch.from_numpy(x['production_countries'])
            release_date_x = torch.from_numpy(x['release_date'])
            revenue_x = torch.from_numpy(x['revenue'])
            runtime_x = torch.from_numpy(x['runtime'])
            spoken_languages_x = torch.from_numpy(x['spoken_languages'])
            status_x = torch.from_numpy(x['status'])
            tagline_x = torch.from_numpy(x['tagline'])
            title_x = torch.from_numpy(x['title'])
            video_x = torch.from_numpy(x['video'])
            vote_average_x = torch.from_numpy(x['vote_average'])
            vote_count_x = torch.from_numpy(x['vote_count'])
            genres_y = torch.from_numpy(y['genres'])
