In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import json
from torch.utils.data import DataLoader, Dataset
import random
import time
import ast
import numpy as np
from tqdm import tqdm
import pandas as pd
from IPython.display import display, clear_output

In [5]:
data_csv = pd.read_csv('movies_metadata.csv', low_memory=False)
data_csv.head()

Unnamed: 0,adult,belongs_to_collection,budget,genres,homepage,id,imdb_id,original_language,original_title,overview,...,release_date,revenue,runtime,spoken_languages,status,tagline,title,video,vote_average,vote_count
0,False,"{'id': 10194, 'name': 'Toy Story Collection', ...",30000000,"[{'id': 16, 'name': 'Animation'}, {'id': 35, '...",http://toystory.disney.com/toy-story,862,tt0114709,en,Toy Story,"Led by Woody, Andy's toys live happily in his ...",...,1995-10-30,373554033.0,81.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,,Toy Story,False,7.7,5415.0
1,False,,65000000,"[{'id': 12, 'name': 'Adventure'}, {'id': 14, '...",,8844,tt0113497,en,Jumanji,When siblings Judy and Peter discover an encha...,...,1995-12-15,262797249.0,104.0,"[{'iso_639_1': 'en', 'name': 'English'}, {'iso...",Released,Roll the dice and unleash the excitement!,Jumanji,False,6.9,2413.0
2,False,"{'id': 119050, 'name': 'Grumpy Old Men Collect...",0,"[{'id': 10749, 'name': 'Romance'}, {'id': 35, ...",,15602,tt0113228,en,Grumpier Old Men,A family wedding reignites the ancient feud be...,...,1995-12-22,0.0,101.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,Still Yelling. Still Fighting. Still Ready for...,Grumpier Old Men,False,6.5,92.0
3,False,,16000000,"[{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam...",,31357,tt0114885,en,Waiting to Exhale,"Cheated on, mistreated and stepped on, the wom...",...,1995-12-22,81452156.0,127.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,Friends are the people who let you be yourself...,Waiting to Exhale,False,6.1,34.0
4,False,"{'id': 96871, 'name': 'Father of the Bride Col...",0,"[{'id': 35, 'name': 'Comedy'}]",,11862,tt0113041,en,Father of the Bride Part II,Just when George Banks has recovered from his ...,...,1995-02-10,76578911.0,106.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,Just When His World Is Back To Normal... He's ...,Father of the Bride Part II,False,5.7,173.0


In [6]:
data_csv = data_csv.drop('imdb_id', axis=1)
data_csv = data_csv.drop('id', axis=1)
data_csv = data_csv.drop('poster_path', axis=1)
data_csv = data_csv.drop('homepage', axis=1)

In [7]:
data_csv.keys()

Index(['adult', 'belongs_to_collection', 'budget', 'genres',
       'original_language', 'original_title', 'overview', 'popularity',
       'production_companies', 'production_countries', 'release_date',
       'revenue', 'runtime', 'spoken_languages', 'status', 'tagline', 'title',
       'video', 'vote_average', 'vote_count'],
      dtype='object')

In [41]:
class DataSetManual(Dataset):
    def __init__(self, data, vocab: dict = None):
        super().__init__()
        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab = {}

        self.adult = data['adult']
        self.belongs_to_collection = data['belongs_to_collection']
        self.budget = data['budget']
        self.genres = data['genres']
        self.original_language = data['original_language']
        self.original_title = data['original_title']
        self.overview = data['overview']
        self.popularity = data['popularity']
        self.production_companies = data['production_companies']
        self.production_countries = data['production_countries']
        self.release_date = data['release_date']
        self.revenue = data['revenue']
        self.runtime = data['runtime']
        self.spoken_languages = data['spoken_languages']
        self.status = data['status']
        self.tagline = data['tagline']
        self.tittle = data['title']
        self.video = data['video']
        self.vote_average = data['vote_average']
        self.vote_count = data['vote_count']

    def __len__(self):
        return len(self.adult)

    def translate(self, *arg: [dict, int, list]) -> str:
        ...
        simulation_list = list(self.vocab)
        s_pr = []
        arg = list(arg)

        for i in range(len(arg)):

            if isinstance(arg[i], int):
                s_pr.append(simulation_list[arg[i]])
            if isinstance(arg[i], list) or isinstance(arg[i], np.ndarray) or isinstance(arg[i], tuple):

                for v in range(len(arg[i])):
                    s_pr.append(simulation_list[int(arg[i][v])])
        s_pr = str(s_pr)
        s_pr = s_pr.replace(']', '')
        s_pr = s_pr.replace('[', '')
        s_pr = s_pr.replace(',', '')
        s_pr = s_pr.replace("'", '')
        s_pr = s_pr.replace('"', '')
        return s_pr

    def __getitem__(self, item):
        #adult
        item = int(item)
        adult_out = []
        adult = self.adult[item].lower()
        if adult not in self.vocab:
            self.vocab[adult] = len(self.vocab)
        adult_out.append(self.vocab[adult])

        # belongs_to_collection
        belongs_to_collection_out = []
        if not isinstance(self.belongs_to_collection[item], float):

            sim = str(self.belongs_to_collection[item])
            sim = ast.literal_eval(sim)
            belongs_to_collection = sim['name'].lower().split()
        else:
            belongs_to_collection = ['nan']
        for word in belongs_to_collection:
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            belongs_to_collection_out.append(self.vocab[word])

        genres_out = []
        genres_names = []
        genres_str = str(self.genres[item])
        genres_json = ast.literal_eval(genres_str)
        for i in range(len(genres_json)):
            genres_names.append(genres_json[i]['name'].lower())

        for word in genres_names:
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            genres_out.append(self.vocab[word])

        o_language_out = []
        o_language = self.original_language[item]
        o_language = o_language.lower().split()
        for word in o_language:
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            o_language_out.append(self.vocab[word])

        original_title_out = []
        original_title = self.original_title[item]
        original_title = original_title.lower().split()
        for word in original_title:
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            original_title_out.append(self.vocab[word])

        overview_out = []
        overview = self.overview[item]
        overview = overview.lower().split()
        for word in overview:
            gs = word.find('.')
            if gs is not None:
                word = word.replace('.', '')
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            overview_out.append(self.vocab[word])

        production_companies_out = []
        production_companies = str(self.production_companies[item])
        production_companies = ast.literal_eval(production_companies)

        for i in range(len(production_companies)):
            spa = production_companies[i]['name'].lower().split()
            for word in spa:
                gs = word.find('.')
                if gs is not None:
                    word = word.replace('.', '')
                if word not in self.vocab:
                    self.vocab[word] = len(self.vocab)
                production_companies_out.append(self.vocab[word])

        production_countries_out = []
        production_countries = str(self.production_countries[item])
        production_countries = ast.literal_eval(production_countries)
        for i in range(len(production_countries)):
            spa = production_companies[i]['name'].lower().split()

            for word in spa:
                gs = word.find('.')
                if gs is not None:
                    word = word.replace('.', '')
                if word not in self.vocab:
                    self.vocab[word] = len(self.vocab)
                production_countries_out.append(self.vocab[word])

        release_date_out = []
        release_date = self.release_date[item]
        release_date = release_date.lower().split()
        for word in release_date:
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            release_date_out.append(self.vocab[word])

        revenue_out = []
        revenue = self.revenue[item]
        revenue_out.append(revenue)

        spoken_languages_out = []
        spoken_languages = str(self.spoken_languages[item])
        spoken_languages = ast.literal_eval(spoken_languages)
        for i in range(len(spoken_languages)):
            spa = spoken_languages[i]['name'].lower().split()
            for word in spa:
                gs = word.find('.')
                if gs is not None:
                    word = word.replace('.', '')
                if word not in self.vocab:
                    self.vocab[word] = len(self.vocab)
                spoken_languages_out.append(self.vocab[word])

        status_out = []
        status = self.status[item]
        status = status.lower().split()
        for word in status:
            gs = word.find('.')
            if gs is not None:
                word = word.replace('.', '')
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            status_out.append(self.vocab[word])

        tagline_out = []

        if self.tagline[item] != 'nan':
            tagline = 0
        else:
            tagline = self.tagline[item]

        # tagline = self.tagline[item] if self.tagline[item] != 'nan' else 0

        tagline = str(tagline)
        tagline_out.append(tagline)

        title_out = []
        tittle = self.tittle[item]
        tittle = tittle.lower().split()
        for word in tittle:
            gs = word.find('.')
            if gs is not None:
                word = word.replace('.', '')
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            title_out.append(self.vocab[word])

        video_out = []
        video = str(self.video[item])
        video = video.lower().split()
        for word in video:
            gs = word.find('.')
            if gs is not None:
                word = word.replace('.', '')
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            video_out.append(self.vocab[word])

        vote_average_out = []
        vote_average = self.vote_average[item]
        vote_average_out.append(vote_average)

        vote_count_out = []
        vote_count = self.vote_count[item]
        vote_count_out.append(vote_count)
        outputs = {
            'adult': np.array(adult_out, dtype=np.float64),
            'belongs_to_collection': np.array(belongs_to_collection_out, dtype=np.float64),
            'budget': np.array(int(self.budget[item]), dtype=np.float64),
            'original_language': np.array(o_language_out, dtype=np.float64),
            'original_title': np.array(original_title_out, dtype=np.float64),
            'overview': np.array(overview_out, dtype=np.float64),
            'popularity': np.array((float(self.popularity[item])), dtype=np.float64),
            'production_companies': np.array(production_companies_out, dtype=np.float64),
            'production_countries': np.array(production_countries_out, dtype=np.float64),
            'release_date': np.array(release_date_out, dtype=np.float64),
            'revenue': np.array(revenue_out, dtype=np.float64),
            'runtime': np.array(self.runtime[item], dtype=np.float64),
            'spoken_languages': np.array(spoken_languages_out, dtype=np.float64),
            'status': np.array(status_out, dtype=np.float64),
            'tagline': np.array(int(float(tagline_out[0])), dtype=np.float64),
            'title': np.array(title_out, dtype=np.float64),
            'video': np.array(video_out, dtype=np.float64),
            'vote_average': np.array(vote_average_out, dtype=np.float64),
            'vote_count': np.array(vote_count_out, dtype=np.float64),
        }
        targets = {
            'genres': np.array(genres_out, dtype=np.float64),
        }

        return outputs, targets

In [42]:
dsm = DataSetManual(data_csv)
out, tar = dsm.__getitem__(10)
print(out)

{'adult': array([0.]), 'belongs_to_collection': array([1.]), 'budget': array([62000000.]), 'original_language': array([5.]), 'original_title': array([6., 7., 8.]), 'overview': array([ 9., 10.,  8., 11., 12., 13., 14.,  6., 15., 16., 17., 18., 19.,
       20., 21., 22., 23., 24., 25., 26., 22., 27., 16., 28., 29., 30.,
       31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43.,
       25., 44., 45., 46., 47.]), 'popularity': array(6.318445), 'production_companies': array([48., 49., 50., 51., 52.]), 'production_countries': array([48., 49.]), 'release_date': array([53.]), 'revenue': array([1.07879496e+08]), 'runtime': array(106.), 'spoken_languages': array([54.]), 'status': array([55.]), 'tagline': array(0.), 'title': array([6., 7., 8.]), 'video': array([0.]), 'vote_average': array([6.5]), 'vote_count': array([199.])}


In [43]:
dataLd = DataLoader(
    dsm,
    batch_size=16,
    num_workers=4,
    pin_memory=True
)

In [158]:
class Net(nn.Module):

    def __init__(self,
                 num_embedding_1: int = 50000,
                 embedding_dim_1: int = 400,
                 num_embedding_2: int = 40000,
                 embedding_dim_2: int = 400,
                 num_embedding_3: int = 50000,
                 embedding_dim_3: int = 400,
                 num_embedding_4: int = 50000,
                 embedding_dim_4: int = 400,
                 num_embedding_5: int = 50000,
                 embedding_dim_5: int = 400,
                 num_embedding_6: int = 40000,
                 embedding_dim_6: int = 400,
                 num_embedding_7: int = 50000,
                 embedding_dim_7: int = 400,
                 num_embedding_8: int = 50000,
                 embedding_dim_8: int = 400,
                 num_embedding_9: int = 50000,
                 embedding_dim_9: int = 400,
                 num_embedding_10: int = 40000,
                 embedding_dim_10: int = 400,

                 lstm_layers_1: int = 1,
                 lstm_hidden_num_1: int = 15,
                 lstm_layers_2: int = 1,
                 lstm_hidden_num_2: int = 15,
                 lstm_layers_3: int = 1,
                 lstm_hidden_num_3: int = 15,
                 lstm_layers_4: int = 1,
                 lstm_hidden_num_4: int = 15,
                 lstm_layers_5: int = 1,
                 lstm_hidden_num_5: int = 15,
                 lstm_layers_6: int = 1,
                 lstm_hidden_num_6: int = 15,
                 lstm_layers_7: int = 1,
                 lstm_hidden_num_7: int = 15,
                 lstm_layers_8: int = 1,
                 lstm_hidden_num_8: int = 15,
                 lstm_layers_9: int = 1,
                 lstm_hidden_num_9: int = 15,
                 lstm_layers_10: int = 1,
                 lstm_hidden_num_10: int = 15,

                 output_size: int = 15):
        super(Net, self).__init__()

        self.num_embedding_1 = num_embedding_1
        self.embedding_dim_1 = embedding_dim_1
        self.num_embedding_2 = num_embedding_2
        self.embedding_dim_2 = embedding_dim_2
        self.num_embedding_3 = num_embedding_3
        self.embedding_dim_3 = embedding_dim_3
        self.num_embedding_4 = num_embedding_4
        self.embedding_dim_4 = embedding_dim_4
        self.num_embedding_5 = num_embedding_5
        self.embedding_dim_5 = embedding_dim_5
        self.num_embedding_6 = num_embedding_6
        self.embedding_dim_6 = embedding_dim_6
        self.num_embedding_7 = num_embedding_7
        self.embedding_dim_7 = embedding_dim_7
        self.num_embedding_8 = num_embedding_8
        self.embedding_dim_8 = embedding_dim_8
        self.num_embedding_9 = num_embedding_9
        self.embedding_dim_9 = embedding_dim_9
        self.num_embedding_10 = num_embedding_10
        self.embedding_dim_10 = embedding_dim_10

        self.lstm_hidden_num_1 = lstm_hidden_num_1
        self.lstm_hidden_num_2 = lstm_hidden_num_2
        self.lstm_hidden_num_3 = lstm_hidden_num_3
        self.lstm_hidden_num_4 = lstm_hidden_num_4
        self.lstm_hidden_num_5 = lstm_hidden_num_5
        self.lstm_hidden_num_6 = lstm_hidden_num_6
        self.lstm_hidden_num_7 = lstm_hidden_num_7
        self.lstm_hidden_num_8 = lstm_hidden_num_8
        self.lstm_hidden_num_9 = lstm_hidden_num_9
        self.lstm_hidden_num_10 = lstm_hidden_num_10

        self.lstm_layers_1 = lstm_layers_1
        self.lstm_layers_2 = lstm_layers_2
        self.lstm_layers_3 = lstm_layers_3
        self.lstm_layers_4 = lstm_layers_4
        self.lstm_layers_5 = lstm_layers_5
        self.lstm_layers_6 = lstm_layers_6
        self.lstm_layers_7 = lstm_layers_7
        self.lstm_layers_8 = lstm_layers_8
        self.lstm_layers_9 = lstm_layers_9
        self.lstm_layers_10 = lstm_layers_10

        self.fc_adult_0 = nn.Linear(1, 2)
        self.fc_adult_1 = nn.Linear(2, 1)

        self.fc_budget_0 = nn.Linear(1, 10)
        self.fc_budget_1 = nn.Linear(10, 1)

        self.fc_popularity_0 = nn.Linear(1, 10)
        self.fc_popularity_1 = nn.Linear(10, 1)

        self.fc_revenue_0 = nn.Linear(1, 30)
        self.fc_revenue_1 = nn.Linear(30, 1)

        self.fc_runtime_0 = nn.Linear(1, 10)
        self.fc_runtime_1 = nn.Linear(10, 1)

        self.fc_status_0 = nn.Linear(1, 2)
        self.fc_status_1 = nn.Linear(2, 1)

        self.fc_video_0 = nn.Linear(1, 2)
        self.fc_video_1 = nn.Linear(2, 1)

        self.fc_vote_average_0 = nn.Linear(1, 10)
        self.fc_vote_average_1 = nn.Linear(10, 1)

        self.fc_vote_count_0 = nn.Linear(1, 10)
        self.fc_vote_count_1 = nn.Linear(10, 1)

        self.embedding_layer_1 = nn.Embedding(num_embeddings=num_embedding_1, embedding_dim=embedding_dim_1)
        self.embedding_layer_2 = nn.Embedding(num_embeddings=num_embedding_2, embedding_dim=embedding_dim_2)
        self.embedding_layer_3 = nn.Embedding(num_embeddings=num_embedding_3, embedding_dim=embedding_dim_3)
        self.embedding_layer_4 = nn.Embedding(num_embeddings=num_embedding_4, embedding_dim=embedding_dim_4)
        self.embedding_layer_5 = nn.Embedding(num_embeddings=num_embedding_5, embedding_dim=embedding_dim_5)
        self.embedding_layer_6 = nn.Embedding(num_embeddings=num_embedding_6, embedding_dim=embedding_dim_6)
        self.embedding_layer_7 = nn.Embedding(num_embeddings=num_embedding_7, embedding_dim=embedding_dim_7)
        self.embedding_layer_8 = nn.Embedding(num_embeddings=num_embedding_8, embedding_dim=embedding_dim_8)
        self.embedding_layer_9 = nn.Embedding(num_embeddings=num_embedding_9, embedding_dim=embedding_dim_9)
        self.embedding_layer_10 = nn.Embedding(num_embeddings=num_embedding_10, embedding_dim=embedding_dim_10)

        self.lstm_1 = nn.LSTM(num_layers=lstm_layers_1, hidden_size=lstm_hidden_num_1, input_size=embedding_dim_1)
        self.lstm_2 = nn.LSTM(num_layers=lstm_layers_2, hidden_size=lstm_hidden_num_2, input_size=embedding_dim_2)
        self.lstm_3 = nn.LSTM(num_layers=lstm_layers_3, hidden_size=lstm_hidden_num_3, input_size=embedding_dim_3)
        self.lstm_4 = nn.LSTM(num_layers=lstm_layers_4, hidden_size=lstm_hidden_num_4, input_size=embedding_dim_4)
        self.lstm_5 = nn.LSTM(num_layers=lstm_layers_5, hidden_size=lstm_hidden_num_5, input_size=embedding_dim_5)
        self.lstm_6 = nn.LSTM(num_layers=lstm_layers_6, hidden_size=lstm_hidden_num_6, input_size=embedding_dim_6)
        self.lstm_7 = nn.LSTM(num_layers=lstm_layers_7, hidden_size=lstm_hidden_num_7, input_size=embedding_dim_7)
        self.lstm_8 = nn.LSTM(num_layers=lstm_layers_8, hidden_size=lstm_hidden_num_8, input_size=embedding_dim_8)
        self.lstm_9 = nn.LSTM(num_layers=lstm_layers_9, hidden_size=lstm_hidden_num_9, input_size=embedding_dim_9)
        self.lstm_10 = nn.LSTM(num_layers=lstm_layers_10, hidden_size=lstm_hidden_num_10, input_size=embedding_dim_10)

        self.fc0_1 = nn.Linear(lstm_hidden_num_1, lstm_hidden_num_1 * 2)
        self.relu_0_1 = nn.ReLU()
        self.fc1_1 = nn.Linear(lstm_hidden_num_1 * 2, 4)
        self.relu_1_1 = nn.ReLU()

        self.fc0_2 = nn.Linear(lstm_hidden_num_2, lstm_hidden_num_2 * 2)
        self.relu_0_2 = nn.ReLU()
        self.fc1_2 = nn.Linear(lstm_hidden_num_2 * 2, 4)
        self.relu_1_2 = nn.ReLU()

        self.fc0_3 = nn.Linear(lstm_hidden_num_3, lstm_hidden_num_3 * 2)
        self.relu_0_3 = nn.ReLU()
        self.fc1_3 = nn.Linear(lstm_hidden_num_4 * 2, 4)
        self.relu_1_3 = nn.ReLU()

        self.fc0_4 = nn.Linear(lstm_hidden_num_4, lstm_hidden_num_4 * 2)
        self.relu_0_4 = nn.ReLU()
        self.fc1_4 = nn.Linear(lstm_hidden_num_4 * 2, 4)
        self.relu_1_4 = nn.ReLU()

        self.fc0_5 = nn.Linear(lstm_hidden_num_5, lstm_hidden_num_5 * 2)
        self.relu_0_5 = nn.ReLU()
        self.fc1_5 = nn.Linear(lstm_hidden_num_5 * 2, 4)
        self.relu_1_5 = nn.ReLU()

        self.fc0_6 = nn.Linear(lstm_hidden_num_6, lstm_hidden_num_6 * 2)
        self.relu_0_6 = nn.ReLU()
        self.fc1_6 = nn.Linear(lstm_hidden_num_6 * 2, 4)
        self.relu_1_6 = nn.ReLU()

        self.fc0_7 = nn.Linear(lstm_hidden_num_7, lstm_hidden_num_7 * 2)
        self.relu_0_7 = nn.ReLU()
        self.fc1_7 = nn.Linear(lstm_hidden_num_7 * 2, 4)
        self.relu_1_7 = nn.ReLU()

        self.fc0_8 = nn.Linear(lstm_hidden_num_8, lstm_hidden_num_8 * 2)
        self.relu_0_8 = nn.ReLU()
        self.fc1_8 = nn.Linear(lstm_hidden_num_8 * 2, 4)
        self.relu_1_8 = nn.ReLU()

        self.fc0_9 = nn.Linear(lstm_hidden_num_9, lstm_hidden_num_9 * 2)
        self.relu_0_9 = nn.ReLU()
        self.fc1_9 = nn.Linear(lstm_hidden_num_9 * 2, 4)
        self.relu_1_9 = nn.ReLU()

        self.fc0_10 = nn.Linear(lstm_hidden_num_10, lstm_hidden_num_10 * 2)
        self.relu_0_10 = nn.ReLU()
        self.fc1_10 = nn.Linear(lstm_hidden_num_10 * 2, 4)
        self.relu_1_10 = nn.ReLU()

        self.output_layer = nn.Linear(16, 16)
        self.softmax = nn.Softmax()

        print('\033[38;2;255;0;0initialization Done\033[38;2;255;255;255m')

    def forward(self,
                adult: torch.Tensor,
                belongs_to_collection: torch.Tensor,
                budget: torch.Tensor,
                original_language: torch.Tensor,
                original_title: torch.Tensor,
                overview: torch.Tensor,
                popularity: torch.Tensor,
                production_companies: torch.Tensor,
                production_countries: torch.Tensor,
                release_date: torch.Tensor,
                revenue: torch.Tensor,
                runtime: torch.Tensor,
                spoken_languages: torch.Tensor,
                status: torch.Tensor,
                tagline: torch.Tensor,
                title: torch.Tensor,
                video: torch.Tensor,
                vote_average: torch.Tensor,
                vote_count: torch.Tensor,
                genres: torch.Tensor,
                ) -> torch.Tensor:
        ...

        adult_out = f.relu(self.fc_adult_1(f.relu(self.fc_adult_0(adult))))
        budget_out = f.relu(self.fc_budget_1(f.relu(self.fc_budget_0(budget))))
        popularity_out = f.relu(self.fc_popularity_1(f.relu(self.fc_popularity_0(popularity))))
        revenue_out = f.relu(self.fc_revenue_1(f.relu(self.fc_revenue_0(revenue))))
        runtime_out = f.relu(self.fc_runtime_1(f.relu(self.fc_runtime_0(runtime))))
        status_out = f.relu(self.fc_status_1(f.relu(self.fc_status_0(status))))
        video_out = f.relu(self.fc_video_1(f.relu(self.fc_video_0(video))))
        vote_average_out = f.relu(self.fc_vote_average_1(f.relu(self.fc_vote_average_0(vote_average))))
        vote_count_out = f.relu(self.fc_vote_count_1(f.relu(self.fc_vote_count_0(vote_count))))

        non_text = torch.tensor(
            [adult_out[0], budget_out[0], popularity_out[0], revenue_out[0], runtime_out[0], status_out[0],
             video_out[0], vote_average_out[0], vote_count_out[0]])
        print('non_text', non_text.shape)

        belongs_to_collection_x_0 = self.embedding_layer_1(belongs_to_collection)  #Done

        original_language_x_0 = self.embedding_layer_2(original_language)  #Done

        original_title_x_0 = self.embedding_layer_3(original_title)  #Done

        overview_x_0 = self.embedding_layer_4(overview)  #Done

        production_companies_x_0 = self.embedding_layer_5(production_companies)  #Done

        production_countries_x_0 = self.embedding_layer_6(production_countries)  #Done

        release_date_x_0 = self.embedding_layer_7(release_date)  #Done

        spoken_languages_x_0 = self.embedding_layer_8(spoken_languages)  #Done

        tagline_x_0 = self.embedding_layer_9(tagline)  # Done

        title_x_0 = self.embedding_layer_10(title)  # Done

        h0_1 = torch.zeros(self.lstm_layers_1, belongs_to_collection_x_0.size()[0], self.lstm_hidden_num_1)
        c0_1 = torch.zeros(self.lstm_layers_1, belongs_to_collection_x_0.size()[0], self.lstm_hidden_num_1)

        h0_2 = torch.zeros(self.lstm_layers_2, original_language_x_0.size()[0], self.lstm_hidden_num_2)
        c0_2 = torch.zeros(self.lstm_layers_2, original_language_x_0.size()[0], self.lstm_hidden_num_2)

        h0_3 = torch.zeros(self.lstm_layers_3, original_title_x_0.size()[0], self.lstm_hidden_num_3)
        c0_3 = torch.zeros(self.lstm_layers_3, original_title_x_0.size()[0], self.lstm_hidden_num_3)

        h0_4 = torch.zeros(self.lstm_layers_4, overview_x_0.size()[0], self.lstm_hidden_num_4)
        c0_4 = torch.zeros(self.lstm_layers_4, overview_x_0.size()[0], self.lstm_hidden_num_4)

        h0_5 = torch.zeros(self.lstm_layers_5, production_companies_x_0.size()[0], self.lstm_hidden_num_5)
        c0_5 = torch.zeros(self.lstm_layers_5, production_companies_x_0.size()[0], self.lstm_hidden_num_5)

        h0_6 = torch.zeros(self.lstm_layers_6, production_countries_x_0.size()[0], self.lstm_hidden_num_6)
        c0_6 = torch.zeros(self.lstm_layers_6, production_countries_x_0.size()[0], self.lstm_hidden_num_6)

        h0_7 = torch.zeros(self.lstm_layers_7, release_date_x_0.size()[0], self.lstm_hidden_num_7)
        c0_7 = torch.zeros(self.lstm_layers_7, release_date_x_0.size()[0], self.lstm_hidden_num_7)

        h0_8 = torch.zeros(self.lstm_layers_8, spoken_languages_x_0.size()[0], self.lstm_hidden_num_8)
        c0_8 = torch.zeros(self.lstm_layers_8, spoken_languages_x_0.size()[0], self.lstm_hidden_num_8)

        h0_9 = torch.zeros(self.lstm_layers_9, 1, self.lstm_hidden_num_9) if len(
            tagline_x_0.shape) == 1 else torch.zeros(self.lstm_layers_9, tagline_x_0.size()[0], self.lstm_hidden_num_9)
        c0_9 = torch.zeros(self.lstm_layers_9, 1, self.lstm_hidden_num_9) if len(
            tagline_x_0.shape) == 1 else torch.zeros(self.lstm_layers_9, tagline_x_0.size()[0], self.lstm_hidden_num_9)

        h0_10 = torch.zeros(self.lstm_layers_10, title_x_0.size()[0], self.lstm_hidden_num_10)
        c0_10 = torch.zeros(self.lstm_layers_10, title_x_0.size()[0], self.lstm_hidden_num_10)

        # h0_1 = torch.zeros(self.lstm_layers_1, 1, self.lstm_hidden_num_1)
        # c0_1 = torch.zeros(self.lstm_layers_1, 1, self.lstm_hidden_num_1)
        #
        # h0_2 = torch.zeros(self.lstm_layers_2, original_language_x_0.size()[0], self.lstm_hidden_num_2)
        # c0_2 = torch.zeros(self.lstm_layers_2, original_language_x_0.size()[0], self.lstm_hidden_num_2)
        #
        # h0_3 = torch.zeros(self.lstm_layers_3, original_title_x_0.size()[0], self.lstm_hidden_num_3)
        # c0_3 = torch.zeros(self.lstm_layers_3, original_title_x_0.size()[0], self.lstm_hidden_num_3)
        #
        # h0_4 = torch.zeros(self.lstm_layers_4, overview_x_0.size()[0], self.lstm_hidden_num_4)
        # c0_4 = torch.zeros(self.lstm_layers_4, overview_x_0.size()[0], self.lstm_hidden_num_4)
        #
        # h0_5 = torch.zeros(self.lstm_layers_5, production_companies_x_0.size()[0], self.lstm_hidden_num_5)
        # c0_5 = torch.zeros(self.lstm_layers_5, production_companies_x_0.size()[0], self.lstm_hidden_num_5)
        #
        # h0_6 = torch.zeros(self.lstm_layers_6, production_countries_x_0.size()[0], self.lstm_hidden_num_6)
        # c0_6 = torch.zeros(self.lstm_layers_6, production_countries_x_0.size()[0], self.lstm_hidden_num_6)
        #
        # h0_7 = torch.zeros(self.lstm_layers_7, release_date_x_0.size()[0], self.lstm_hidden_num_7)
        # c0_7 = torch.zeros(self.lstm_layers_7, release_date_x_0.size()[0], self.lstm_hidden_num_7)
        #
        # h0_8 = torch.zeros(self.lstm_layers_8, spoken_languages_x_0.size()[0], self.lstm_hidden_num_8)
        # c0_8 = torch.zeros(self.lstm_layers_8, spoken_languages_x_0.size()[0], self.lstm_hidden_num_8)
        #
        # h0_9 = torch.zeros(self.lstm_layers_9, tagline_x_0.size()[0], self.lstm_hidden_num_9)
        # c0_9 = torch.zeros(self.lstm_layers_9, tagline_x_0.size()[0], self.lstm_hidden_num_9)
        #
        # h0_10 = torch.zeros(self.lstm_layers_10, title_x_0.size()[0], self.lstm_hidden_num_10)
        # c0_10 = torch.zeros(self.lstm_layers_10, title_x_0.size()[0], self.lstm_hidden_num_10)

        # OverflowError
        belongs_to_collection_x_0 = belongs_to_collection_x_0.view(1, belongs_to_collection_x_0.size()[0],
                                                                   belongs_to_collection_x_0.size()[1])
        belongs_to_collection_x_lstm, _ = self.lstm_1(belongs_to_collection_x_0, (h0_1, c0_1))

        belongs_to_collection_out = self.relu_1_1(self.fc1_1(self.relu_0_1(self.fc0_1(belongs_to_collection_x_lstm))))
        print('belongs_to_collection_out', belongs_to_collection_out.shape)
        # OverflowError

        original_language_x_0 = original_language_x_0.view(1, original_language_x_0.size()[0], self.embedding_dim_2)
        original_language_x_lstm, _ = self.lstm_2(original_language_x_0, (h0_2, c0_2))

        original_language_out = self.relu_1_2(self.fc1_2(self.relu_0_2(self.fc0_2(original_language_x_lstm))))
        print('original_language_out', original_language_out.shape)
        # OverflowError

        original_title_x_0 = original_title_x_0.view(1, original_title_x_0.size()[0], self.embedding_dim_3)
        original_title_x_lstm, _ = self.lstm_3(original_title_x_0, (h0_3, c0_3))

        original_title_out = self.relu_1_3(self.fc1_3(self.relu_0_3(self.fc0_3(original_title_x_lstm))))
        print('original_title_out', original_title_out.shape)
        # OverflowError
        overview_x_0 = overview_x_0.view(1, overview_x_0.size()[0], self.embedding_dim_4)
        overview_x_lstm, _ = self.lstm_4(overview_x_0, (h0_4, c0_4))

        overview_out = self.relu_1_4(self.fc1_4(self.relu_0_4(self.fc0_4(overview_x_lstm))))
        print('overview_out', overview_out.shape)
        # OverflowError
        production_companies_x_0 = production_companies_x_0.view(1, production_companies_x_0.size()[0],
                                                                 self.embedding_dim_5)
        production_companies_x_lstm, _ = self.lstm_5(production_companies_x_0, (h0_5, c0_5))

        production_companies_out = self.relu_1_5(self.fc1_5(self.relu_0_5(self.fc0_5(production_companies_x_lstm))))
        print('production_companies_out', production_companies_out.shape)
        # OverflowError
        production_countries_x_0 = production_countries_x_0.view(1, production_countries_x_0.size()[0],
                                                                 self.embedding_dim_6)
        production_countries_x_lstm, _ = self.lstm_6(production_countries_x_0, (h0_6, c0_6))

        production_countries_out = self.relu_1_6(self.fc1_6(self.relu_0_6(self.fc0_6(production_countries_x_lstm))))
        print('production_countries_out', production_countries_out.shape)
        # OverflowError
        release_date_x_0 = release_date_x_0.view(1, release_date_x_0.size()[0], self.embedding_dim_7)
        release_date_x_lstm, _ = self.lstm_7(release_date_x_0, (h0_7, c0_7))

        release_date_out = self.relu_1_7(self.fc1_7(self.relu_0_7(self.fc0_7(release_date_x_lstm))))
        print('release_date_out', release_date_out.shape)
        # OverflowError
        spoken_languages_x_0 = spoken_languages_x_0.view(1, spoken_languages_x_0.size()[0], self.embedding_dim_8)
        spoken_languages_x_lstm, _ = self.lstm_8(spoken_languages_x_0, (h0_8, c0_8))

        spoken_languages_out = self.relu_1_8(self.fc1_8(self.relu_0_8(self.fc0_8(spoken_languages_x_lstm))))
        print('spoken_languages_out', spoken_languages_out.shape)
        # OverflowError

        tagline_x_0 = tagline_x_0.view(1, 1, self.embedding_dim_9) if len(tagline_x_0.shape) == 1 else tagline_x_0.view(
            1, tagline_x_0.size()[0], self.embedding_dim_9)

        tagline_x_lstm, _ = self.lstm_9(tagline_x_0, (h0_9, c0_9))

        tagline_out = self.relu_1_9(self.fc1_9(self.relu_0_9(self.fc0_9(tagline_x_lstm))))
        print('tagline_out', tagline_out.shape)
        # OverflowError
        title_x_0 = title_x_0.view(1, title_x_0.size()[0], self.embedding_dim_10)
        title_x_lstm, _ = self.lstm_10(title_x_0, (h0_10, c0_10))

        title_out = self.relu_1_10(self.fc1_10(self.relu_0_10(self.fc0_10(title_x_lstm))))
        print('title_out', title_out.shape)
        x_1 = torch.cat((belongs_to_collection_out,
                         original_language_out,
                         original_title_out,
                         overview_out,
                         production_companies_out,
                         production_countries_out,
                         release_date_out,
                         spoken_languages_out,
                         tagline_out,
                         title_out), dim=-3)
        return title_out



In [159]:
def train(data_loader, epochs: int = 50):
    network = Net()
    print('Initials Function')
    for epoch in range(epochs):
        print(f'Running on Epoch {epoch + 1} / {epochs}', end='\r')
        for x, y in data_loader.dataset:
            adult_x = torch.from_numpy(x['adult']).type(torch.FloatTensor)
            belongs_to_collection_x = torch.from_numpy(x['belongs_to_collection']).type(torch.IntTensor)

            budget_x = torch.from_numpy(x['budget'].reshape(1, 1)).type(torch.FloatTensor)

            original_language_x = torch.from_numpy(x['original_language']).type(torch.IntTensor)

            original_title_x = torch.from_numpy(x['original_title']).type(torch.IntTensor)

            overview_x = torch.from_numpy(x['overview']).type(torch.IntTensor)

            popularity_x = torch.from_numpy(x['popularity'].reshape(1, 1)).type(torch.FloatTensor)

            production_companies_x = torch.from_numpy(x['production_companies']).type(torch.IntTensor)
            production_countries_x = torch.from_numpy(x['production_countries']).type(torch.IntTensor)
            release_date_x = torch.from_numpy(x['release_date']).type(torch.IntTensor)
            revenue_x = torch.from_numpy(x['revenue']).type(torch.FloatTensor)
            runtime_x = torch.from_numpy(x['runtime'].reshape(1, 1)).type(torch.FloatTensor)
            spoken_languages_x = torch.from_numpy(x['spoken_languages']).type(torch.IntTensor)
            status_x = torch.from_numpy(x['status']).type(torch.FloatTensor)
            tagline_x = torch.from_numpy(x['tagline']).type(torch.IntTensor)
            title_x = torch.from_numpy(x['title']).type(torch.IntTensor)
            video_x = torch.from_numpy(x['video']).type(torch.FloatTensor)
            vote_average_x = torch.from_numpy(x['vote_average']).type(torch.FloatTensor)
            vote_count_x = torch.from_numpy(x['vote_count']).type(torch.FloatTensor)
            genres_y = torch.from_numpy(y['genres']).type(torch.IntTensor)

            v = network.forward(adult=adult_x,
                                belongs_to_collection=belongs_to_collection_x,
                                budget=budget_x,
                                original_language=original_language_x,
                                original_title=original_title_x,
                                overview=overview_x,
                                popularity=popularity_x,
                                production_companies=production_companies_x,
                                production_countries=production_countries_x,
                                release_date=release_date_x,
                                runtime=runtime_x,
                                revenue=revenue_x,
                                spoken_languages=spoken_languages_x,
                                status=status_x,
                                tagline=tagline_x,
                                title=title_x,
                                video=video_x,
                                vote_average=vote_average_x,
                                vote_count=vote_count_x,
                                genres=genres_y)





In [160]:
train(data_loader=dataLd)

[38;2;255;0;0initialization Done[38;2;255;255;255m
Initials Function
non_text torch.Size([9])
belongs_to_collection_out torch.Size([1, 3, 4])
original_language_out torch.Size([1, 1, 4])
original_title_out torch.Size([1, 2, 4])
overview_out torch.Size([1, 50, 4])
production_companies_out torch.Size([1, 3, 4])
production_countries_out torch.Size([1, 3, 4])
release_date_out torch.Size([1, 1, 4])
spoken_languages_out torch.Size([1, 1, 4])
tagline_out torch.Size([1, 1, 4])
title_out torch.Size([1, 2, 4])


RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 3 but got size 1 for tensor number 1 in the list.