In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as f
import json
import os
import yaml
from torch.utils.data import DataLoader, Dataset
import ast
import numpy as np
from tqdm import tqdm
import pandas as pd

In [2]:
data_csv = pd.read_csv('movies_metadata.csv', low_memory=False)
data_csv.head()

Unnamed: 0,adult,belongs_to_collection,budget,genres,homepage,id,imdb_id,original_language,original_title,overview,...,release_date,revenue,runtime,spoken_languages,status,tagline,title,video,vote_average,vote_count
0,False,"{'id': 10194, 'name': 'Toy Story Collection', ...",30000000,"[{'id': 16, 'name': 'Animation'}, {'id': 35, '...",http://toystory.disney.com/toy-story,862,tt0114709,en,Toy Story,"Led by Woody, Andy's toys live happily in his ...",...,1995-10-30,373554033.0,81.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,,Toy Story,False,7.7,5415.0
1,False,,65000000,"[{'id': 12, 'name': 'Adventure'}, {'id': 14, '...",,8844,tt0113497,en,Jumanji,When siblings Judy and Peter discover an encha...,...,1995-12-15,262797249.0,104.0,"[{'iso_639_1': 'en', 'name': 'English'}, {'iso...",Released,Roll the dice and unleash the excitement!,Jumanji,False,6.9,2413.0
2,False,"{'id': 119050, 'name': 'Grumpy Old Men Collect...",0,"[{'id': 10749, 'name': 'Romance'}, {'id': 35, ...",,15602,tt0113228,en,Grumpier Old Men,A family wedding reignites the ancient feud be...,...,1995-12-22,0.0,101.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,Still Yelling. Still Fighting. Still Ready for...,Grumpier Old Men,False,6.5,92.0
3,False,,16000000,"[{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam...",,31357,tt0114885,en,Waiting to Exhale,"Cheated on, mistreated and stepped on, the wom...",...,1995-12-22,81452156.0,127.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,Friends are the people who let you be yourself...,Waiting to Exhale,False,6.1,34.0
4,False,"{'id': 96871, 'name': 'Father of the Bride Col...",0,"[{'id': 35, 'name': 'Comedy'}]",,11862,tt0113041,en,Father of the Bride Part II,Just when George Banks has recovered from his ...,...,1995-02-10,76578911.0,106.0,"[{'iso_639_1': 'en', 'name': 'English'}]",Released,Just When His World Is Back To Normal... He's ...,Father of the Bride Part II,False,5.7,173.0


In [3]:
data_csv = data_csv.drop('imdb_id', axis=1)
data_csv = data_csv.drop('id', axis=1)
data_csv = data_csv.drop('poster_path', axis=1)
data_csv = data_csv.drop('homepage', axis=1)

In [4]:
data_csv.keys()

Index(['adult', 'belongs_to_collection', 'budget', 'genres',
       'original_language', 'original_title', 'overview', 'popularity',
       'production_companies', 'production_countries', 'release_date',
       'revenue', 'runtime', 'spoken_languages', 'status', 'tagline', 'title',
       'video', 'vote_average', 'vote_count'],
      dtype='object')

In [27]:
class DataSetManual(Dataset):
    def __init__(self, data, vocab: dict = None):
        super().__init__()

        if vocab is not None:
            self.vocab = vocab
        else:
            self.vocab = {}

        self.categories = {}
        self.adult = data['adult']
        self.belongs_to_collection = data['belongs_to_collection']
        self.budget = data['budget']
        self.genres = data['genres']
        self.original_language = data['original_language']
        self.original_title = data['original_title']
        self.overview = data['overview']
        self.popularity = data['popularity']
        self.production_companies = data['production_companies']
        self.production_countries = data['production_countries']
        self.release_date = data['release_date']
        self.revenue = data['revenue']
        self.runtime = data['runtime']
        self.spoken_languages = data['spoken_languages']
        self.status = data['status']
        self.tagline = data['tagline']
        self.tittle = data['title']
        self.video = data['video']
        self.vote_average = data['vote_average']
        self.vote_count = data['vote_count']

    def __len__(self):
        return len(self.adult) - 1

    def __save_vocab__(self, name:str='vocab.yaml',path:'str'=None):
        pbar = tqdm(range(self.__len__()))

        rows = ['word' , 'ids']
        if path is not None:
            os.chdir(path)
            pbar.set_description(f'path To {path}')
        else:
            path = os.getcwd()
        for i in pbar:
            _, _ = self.__getitem__(i)

        if name.endswith('.yaml'):

            with open(f'{path}{name}', 'w') as writer:
                print(f'writing yaml file {name}')
                print(f'at {path}{name}')
                yaml.dump(self.vocab, writer)
                print(f'Done')
        if name.endswith('.json'):

            pbar = tqdm(range(len(self.vocab)))
            with open(f'{path}{name}','w') as write:
                pbar.set_description(f'writing json file {name}')
                pbar.set_description(f'at {path}{name}')
                list_vocab = list(self.vocab)
                list_filler = []
                for i in pbar:
                    key = list_vocab[i]
                    val = self.vocab[key]
                    list_filler.append({'key':key,'val':val})
                jsoned = dict(list_filler)
                json.dump(jsoned,write)
                pbar.set_description(f'Done')

    def translate(self, *arg: [dict, int, list]) -> str:
        ...
        simulation_list = list(self.vocab)
        s_pr = []
        arg = list(arg)

        for i in range(len(arg)):

            if isinstance(arg[i], int):
                s_pr.append(simulation_list[arg[i]])
            if isinstance(arg[i], list) or isinstance(arg[i], np.ndarray) or isinstance(arg[i], tuple):

                for v in range(len(arg[i])):
                    s_pr.append(simulation_list[int(arg[i][v])])
        s_pr = str(s_pr)
        s_pr = s_pr.replace(']', '')
        s_pr = s_pr.replace('[', '')
        s_pr = s_pr.replace(',', '')
        s_pr = s_pr.replace("'", '')
        s_pr = s_pr.replace('"', '')
        return s_pr

    def __getitem__(self, item):
        #adult
        item = int(item)
        if 'nan' not in self.vocab:
            self.vocab['nan'] = len(self.vocab)
        b_point = []
        if '^' not in self.vocab:
            self.vocab['^'] = len(self.vocab)
        b_point.append(self.vocab['^'])

        adult_out = []
        adult = self.adult[item].lower()
        if adult not in self.vocab:
            self.vocab[adult] = len(self.vocab)
        adult_out.append(self.vocab[adult])

        popularity_out = 0.
        if isinstance(self.popularity[item], float) or isinstance(self.popularity[item], int):
            popularity_out = float(self.popularity[item])
        # belongs_to_collection
        belongs_to_collection_out = []
        if not isinstance(self.belongs_to_collection[item], float):

            sim = str(self.belongs_to_collection[item])
            sim = ast.literal_eval(sim)
            if not isinstance(sim, float):
                belongs_to_collection = sim['name'].lower().split()
            else:
                belongs_to_collection = 'nan'
        else:
            belongs_to_collection = ['nan']

        for word in belongs_to_collection:
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            belongs_to_collection_out.append(self.vocab[word])
        belongs_to_collection_out.append(b_point[0])
        genres_out = []
        genres_names = []
        genres_str = str(self.genres[item])
        genres_json = ast.literal_eval(genres_str)
        if self.genres[item] != 'nan':
            for i in range(len(genres_json)):
                genres_names.append(genres_json[i]['name'].lower())

            for word in genres_names:
                if word not in self.vocab:
                    self.vocab[word] = len(self.vocab)
                if word not in self.categories:
                    self.categories[word] = len(self.categories)
                genres_out.append(self.categories[word])
        else:
            genres_out = self.vocab['nan']
        o_language_out = []
        if self.original_language[item] != '[]' and not isinstance(self.original_language[item], float):
            o_language = self.original_language[item]
            o_language = o_language.lower().split()
            if self.original_language[item] != 'nan':
                for word in o_language:
                    if word not in self.vocab:
                        self.vocab[word] = len(self.vocab)
                    o_language_out.append(self.vocab[word])
        else:
            o_language_out.append(self.vocab['nan'])
        o_language_out.append(b_point[0])
        budget_out = 0
        if isinstance(self.budget[item], int) or isinstance(self.budget[item], float):
            budget_out = int(self.budget[item])
        else:
            budget_out = 0
        original_title_out = []
        original_title = self.original_title[item]
        original_title = original_title.lower().split()
        if self.original_title[item] != 'nan':
            for word in original_title:
                if word not in self.vocab:
                    self.vocab[word] = len(self.vocab)
                original_title_out.append(self.vocab[word])
        original_title_out.append(b_point[0])
        overview_out = []
        overview = self.overview[item]
        if self.overview[item] != 'nan' and not isinstance(overview, float):

            overview = overview.lower().split()
            for word in overview:
                gs = word.find('.')
                if gs is not None:
                    word = word.replace('.', '')
                if word not in self.vocab:
                    self.vocab[word] = len(self.vocab)
                overview_out.append(self.vocab[word])

        else:
            overview_out.append(self.vocab['nan'])


        production_companies_out = []
        if not isinstance(self.production_companies[item], float):
            if self.production_companies[item] != '[]' and len(self.production_companies[item]) != 0:

                production_companies = str(self.production_companies[item])
                production_companies = ast.literal_eval(production_companies)
                if not isinstance(production_companies, bool):
                    if len(production_companies) > 1:
                        for i in range(len(production_companies)):
                            spa = production_companies[i]['name'].lower().split()
                            for word in spa:
                                gs = word.find('.')
                                if gs is not None:
                                    word = word.replace('.', '')
                                if word not in self.vocab:
                                    self.vocab[word] = len(self.vocab)
                                production_companies_out.append(self.vocab[word])
                    else:
                        spa = production_companies[0]['name'].lower().split()
                        for word in spa:
                            gs = word.find('.')
                            if gs is not None:
                                word = word.replace('.', '')
                            if word not in self.vocab:
                                self.vocab[word] = len(self.vocab)
                            production_companies_out.append(self.vocab[word])
                else:
                    production_companies_out.append(self.vocab['nan'])
            else:
                production_companies_out.append(self.vocab['nan'])
        else:
            production_companies_out.append(self.vocab['nan'])
        production_companies_out.append(b_point[0])
        production_countries_out = []
        if self.production_countries[item] != 'nan' and not isinstance(self.production_countries[item], float):
            production_countries = str(self.production_countries[item])

            production_countries = ast.literal_eval(production_countries)
            if not isinstance(production_countries, float):
                if len(production_countries) != 0:
                    for i in range(len(production_countries)):
                        spa = production_countries[i]['name'].lower().split()

                        for word in spa:
                            gs = word.find('.')
                            if gs is not None:
                                word = word.replace('.', '')
                            if word not in self.vocab:
                                self.vocab[word] = len(self.vocab)
                            production_countries_out.append(self.vocab[word])
            else:
                production_countries_out.append(self.vocab['nan'])
        else:
            production_countries_out.append(self.vocab['nan'])
        production_countries_out.append(b_point[0])
        release_date_out = []
        if self.release_date[item] != 'nan' and not isinstance(self.release_date[item], float):
            release_date = self.release_date[item]
            release_date = release_date.lower()
            release_date = release_date.replace('-', '')
            release_date_out.append(int(release_date))
        else:
            release_date_out = 00000000

        revenue_out = []
        revenue = self.revenue[item]
        revenue_out.append(revenue)

        spoken_languages_out = []
        if self.spoken_languages[item] != 'nan' and not isinstance(self.spoken_languages[item], float):
            spoken_languages = str(self.spoken_languages[item])
            spoken_languages = ast.literal_eval(spoken_languages)
            for i in range(len(spoken_languages)):
                spa = spoken_languages[i]['name'].lower().split()
                for word in spa:
                    gs = word.find('.')
                    if gs is not None:
                        word = word.replace('.', '')
                    if word not in self.vocab:
                        self.vocab[word] = len(self.vocab)
                    spoken_languages_out.append(self.vocab[word])
        else:
            spoken_languages_out.append(self.vocab['nan'])
        spoken_languages_out.append(b_point[0])
        status_out = []
        if self.status[item] != 'nan' and not isinstance(self.status[item], float):
            status = self.status[item]
            status = status.lower().split()
            for word in status:
                gs = word.find('.')
                if gs is not None:
                    word = word.replace('.', '')
                if word not in self.vocab:
                    self.vocab[word] = len(self.vocab)
                status_out.append(self.vocab[word])
        else:
            status_out = self.vocab['nan']
        tagline_out = []

        if self.tagline[item] != 'nan':
            tagline = 0
        else:
            tagline = self.tagline[item]

        # tagline = self.tagline[item] if self.tagline[item] != 'nan' else 0

        tagline = str(tagline)
        tagline_out.append(tagline)

        title_out = []
        tittle = self.tittle[item]
        if self.tittle[item] != 'nan' and not isinstance(self.tittle[item], float):
            tittle = tittle.lower().split()
            for word in tittle:
                gs = word.find('.')
                if gs is not None:
                    word = word.replace('.', '')
                if word not in self.vocab:
                    self.vocab[word] = len(self.vocab)
                title_out.append(self.vocab[word])
        else:
            title_out.append(self.vocab['nan'])
        title_out.append(b_point[0])
        video_out = []
        video = str(self.video[item])
        video = video.lower().split()
        for word in video:
            gs = word.find('.')
            if gs is not None:
                word = word.replace('.', '')
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab)
            video_out.append(self.vocab[word])

        vote_average_out = []
        vote_average = self.vote_average[item]
        vote_average_out.append(vote_average)

        vote_count_out = []
        vote_count = self.vote_count[item]
        vote_count_out.append(vote_count)
        if isinstance(status_out, list):
            status_out = status_out[0]
        elif isinstance(status_out, int):
            status_out = status_out

        nl_1 = np.concatenate((belongs_to_collection_out, title_out,original_title_out))
        nl_1 = nl_1.astype(np.int8)

        nl_2 = np.array(overview_out)
        nl_2 = nl_2.astype(np.int8)

        nl_3 = np.concatenate(( tagline_out,spoken_languages_out, o_language_out))
        nl_3 = nl_3.astype(np.int8)

        nl_4 = np.concatenate((production_companies_out,production_countries_out))
        nl_4 = nl_4.astype(np.int8)

        outputs = {
            'nl_1': nl_1,
            'nl_2': nl_2,
            'nl_3': nl_3,
            'nl_4': nl_4,
            'adult': np.array(adult_out, dtype=np.float64),
            'budget': np.array(budget_out, dtype=np.float64),
            'popularity': np.array(popularity_out, dtype=np.float64),
            'release_date': np.array(release_date_out, dtype=np.float64),
            'revenue': np.array(revenue_out, dtype=np.float64),
            'runtime': np.array(self.runtime[item], dtype=np.float64),
            'status': np.array(status_out, dtype=np.float64),
            'video': np.array(video_out, dtype=np.float64),
            'vote_average': np.array(vote_average_out, dtype=np.float64),
            'vote_count': np.array(vote_count_out, dtype=np.float64),
        }

        cfk = np.zeros(32, dtype=np.float64)

        for i in range(len(genres_out)):
            cfk[genres_out[i]] = 1

        targets = {
            'genres': cfk,
        }

        return outputs, targets

In [28]:
dsm = DataSetManual(data_csv)
# DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = 'cpu'
# DEVICE = 'cuda:0'
dsm.__save_vocab__('vocab.yaml',path='E:/Python/movie_recommender/')



path To E:/Python/movie_recommender/: 100%|██████████| 45465/45465 [00:11<00:00, 3908.03it/s]


writing yaml file vocab.yaml
at E:/Python/movie_recommender/vocab.yaml
Done


In [7]:
dataLd = DataLoader(
    dsm,
    batch_size=4,
    num_workers=1,
    pin_memory=True
)

In [35]:
class Net(nn.Module):

    def __init__(self,
                 num_embedding_1: int = 171678+10,
                 embedding_dim_1: int = 400,
                 num_embedding_2: int = 171678+10,
                 embedding_dim_2: int = 400,
                 num_embedding_3: int = 171678+10,
                 embedding_dim_3: int = 400,
                 num_embedding_4: int = 171678+10,
                 embedding_dim_4: int = 400,
                 lstm_layers_1: int = 1,
                 lstm_hidden_num_1: int = 90,
                 lstm_layers_2: int = 1,
                 lstm_hidden_num_2: int = 90,
                 lstm_layers_3: int = 1,
                 lstm_hidden_num_3: int = 90,
                 lstm_layers_4: int = 1,
                 lstm_hidden_num_4: int = 90,

                 output_size: int = 15):
        super(Net, self).__init__()

        self.num_embedding_1 = num_embedding_1
        self.embedding_dim_1 = embedding_dim_1

        self.lstm_hidden_num_1 = lstm_hidden_num_1

        self.lstm_layers_1 = lstm_layers_1

        self.fc_adult_0 = nn.Linear(1, 15)
        self.fc_adult_1 = nn.Linear(15, 10)

        self.fc_budget_0 = nn.Linear(1, 10)
        self.fc_budget_1 = nn.Linear(10, 10)

        self.fc_popularity_0 = nn.Linear(1, 15)
        self.fc_popularity_1 = nn.Linear(15, 10)

        self.fc_revenue_0 = nn.Linear(1, 30)
        self.fc_revenue_1 = nn.Linear(30, 10)

        self.fc_runtime_0 = nn.Linear(1, 15)
        self.fc_runtime_1 = nn.Linear(15, 10)

        self.fc_status_0 = nn.Linear(1, 20)
        self.fc_status_1 = nn.Linear(20, 10)

        self.fc_video_0 = nn.Linear(1, 20)
        self.fc_video_1 = nn.Linear(20, 10)

        self.fc_vote_average_0 = nn.Linear(1, 15)
        self.fc_vote_average_1 = nn.Linear(15, 10)

        self.fc_vote_count_0 = nn.Linear(1, 15)
        self.fc_vote_count_1 = nn.Linear(15, 10)

        self.fc_release_date_0 = nn.Linear(1, 15)
        self.fc_release_date_1 = nn.Linear(15, 10)

        self.embedding_layer_1 = nn.Embedding(num_embeddings=num_embedding_1, embedding_dim=embedding_dim_1)
        self.lstm_1 = nn.LSTM(num_layers=lstm_layers_1, hidden_size=lstm_hidden_num_1, input_size=embedding_dim_1)

        self.embedding_layer_2 = nn.Embedding(num_embeddings=num_embedding_2, embedding_dim=embedding_dim_2)
        self.lstm_2 = nn.LSTM(num_layers=lstm_layers_2, hidden_size=lstm_hidden_num_2, input_size=embedding_dim_2)

        self.embedding_layer_3 = nn.Embedding(num_embeddings=num_embedding_3, embedding_dim=embedding_dim_3)
        self.lstm_3 = nn.LSTM(num_layers=lstm_layers_3, hidden_size=lstm_hidden_num_3, input_size=embedding_dim_3)

        self.embedding_layer_4 = nn.Embedding(num_embeddings=num_embedding_4, embedding_dim=embedding_dim_4)
        self.lstm_4 = nn.LSTM(num_layers=lstm_layers_4, hidden_size=lstm_hidden_num_4, input_size=embedding_dim_4)

        self.fc0_1 = nn.Linear(lstm_hidden_num_1, lstm_hidden_num_1 * 2)
        self.relu_0_1 = nn.ReLU()
        self.fc1_1 = nn.Linear(lstm_hidden_num_1 * 2, 8)
        self.relu_1_1 = nn.ReLU()

        self.output_layer_1 = nn.Linear(self.lstm_hidden_num_1, self.lstm_hidden_num_1*2)
        self.output_layer_2 = nn.Linear( self.lstm_hidden_num_1*2, 64)
        self.output_layer_3 = nn.Linear(64, 32)
        self.softmax = nn.Softmax(dim=2)

    @staticmethod
    def reshape(ins,hidden_size:int=90):
        if len(ins.shape) > 1:
            h0 = torch.zeros(1, ins.size()[0], hidden_size)
            c0 = torch.zeros(1, ins.size()[0], hidden_size)
            ins = ins.view(1, ins.size()[0], ins.size()[1])

        else:
            h0 = torch.zeros(1, 1, hidden_size)
            c0 = torch.zeros(1, 1, hidden_size)
            ins = ins.view(1, 1, ins.size()[0])
        return ins, h0, c0
    @staticmethod
    def review(ins, dim_check:int=1, dims=1):
        if len(ins.shape) <= dim_check:

            ins = ins.view(dims,-1)
        return ins
    def forward(self,
                nl_1: torch.Tensor,
                nl_2: torch.Tensor,
                nl_3: torch.Tensor,
                nl_4: torch.Tensor,
                adult: torch.Tensor,
                budget: torch.Tensor,
                popularity: torch.Tensor,
                release_date: torch.Tensor,
                revenue: torch.Tensor,
                runtime: torch.Tensor,
                status: torch.Tensor,
                video: torch.Tensor,
                vote_average: torch.Tensor,

                ) -> torch.Tensor:
        ...

        if len(status.shape) <= 1:
            status = status.view(1, -1)

        release_date_out = f.leaky_relu_(self.fc_release_date_1(f.leaky_relu_(self.fc_release_date_0(release_date),0.2)),0.2)
        adult_out = f.leaky_relu_(self.fc_adult_1(f.leaky_relu_(self.fc_adult_0(adult),0.2)),0.2)
        budget_out = f.leaky_relu_(self.fc_budget_1(f.leaky_relu_(self.fc_budget_0(budget),0.2)),0.2)
        popularity_out = f.leaky_relu_(self.fc_popularity_1(f.leaky_relu_(self.fc_popularity_0(popularity),0.2)),0.2)
        revenue_out = f.leaky_relu_(self.fc_revenue_1(f.leaky_relu_(self.fc_revenue_0(revenue),0.2)),0.2)
        runtime_out = f.leaky_relu_(self.fc_runtime_1(f.leaky_relu_(self.fc_runtime_0(runtime),0.2)),0.2)
        status_out = f.leaky_relu_(self.fc_status_1(f.leaky_relu_(self.fc_status_0(status),0.2)),0.2)
        video_out = f.leaky_relu_(self.fc_video_1(f.leaky_relu_(self.fc_video_0(video),0.2)),0.2)
        vote_average_out = f.leaky_relu_(self.fc_vote_average_1(f.leaky_relu_(self.fc_vote_average_0(vote_average),0.2)),0.2)


        release_date_out = self.review(release_date_out)
        adult_out = self.review(adult_out)
        budget_out = self.review(budget_out)
        popularity_out = self.review(popularity_out)
        revenue_out = self.review(revenue_out)
        runtime_out = self.review(runtime_out)
        status_out = self.review(status_out)
        video_out = self.review(video_out)
        vote_average_out = self.review(vote_average_out)

        non_text = torch.cat(
            (release_date_out, adult_out, budget_out, popularity_out, revenue_out, runtime_out,
             status_out,
             video_out, vote_average_out),dim=1)

        embedding_two_out = self.embedding_layer_2(nl_2)
        embedding_one_out = self.embedding_layer_1(nl_1)
        embedding_three_out = self.embedding_layer_3(nl_3)
        embedding_four_out = self.embedding_layer_4(nl_4)

        reshape_one_out, h0_0, c0_0 = self.reshape(ins=embedding_one_out)
        lstm_out_1, _ = self.lstm_1(reshape_one_out, (h0_0, c0_0))

        reshape_two_out, h0_0, c0_0 = self.reshape(ins=embedding_two_out)
        lstm_out_2, _ = self.lstm_2(reshape_two_out, (h0_0, c0_0))

        reshape_three_out, h0_0, c0_0 = self.reshape(ins=embedding_three_out)
        lstm_out_3, _ = self.lstm_3(reshape_three_out, (h0_0, c0_0))

        reshape_four_out, h0_0, c0_0 = self.reshape(ins=embedding_four_out)
        lstm_out_4, _ = self.lstm_4(reshape_four_out, (h0_0, c0_0))

        non_text = non_text.view(1, 1, -1)
        x_2 = torch.cat((lstm_out_1,lstm_out_2,lstm_out_3 ,lstm_out_4,non_text), dim=1)
        prediction = self.softmax(self.output_layer_3(f.leaky_relu_(self.output_layer_2(f.leaky_relu_(self.output_layer_1(x_2),0.2)),0.2)))

        return prediction



In [36]:
network = Net().to(DEVICE)

In [39]:
def train(
        network_in,
        epochs: int = 50,
):
    scalar = torch.cuda.amp.GradScaler()
    loss_function = nn.NLLLoss().to(DEVICE)
    optimizer = optim.Adam(network.parameters(), lr=3e-4)
    pbar = enumerate(dataLd.dataset)
    pbar = tqdm(pbar)

    tqdm.write('initialization Done')
    for epoch in range(epochs):
        accurate = 0
        not_accurate = 0
        total_loss = 0
        for index, (x, y) in pbar:
            if DEVICE == 'cuda':
                optimizer.zero_grad()
                adult_x = torch.from_numpy(x['adult']).type(torch.cuda.FloatTensor)
                belongs_to_collection_x = torch.from_numpy(x['belongs_to_collection']).type(torch.cuda.IntTensor)
                budget_x = torch.from_numpy(x['budget'].reshape(1, 1)).type(torch.cuda.FloatTensor)
                original_language_x = torch.from_numpy(x['original_language']).type(torch.cuda.IntTensor)
                original_title_x = torch.from_numpy(x['original_title']).type(torch.cuda.IntTensor)
                overview_x = torch.from_numpy(x['overview']).type(torch.cuda.IntTensor)
                popularity_x = torch.from_numpy(x['popularity'].reshape(1, 1)).type(torch.cuda.FloatTensor)
                production_companies_x = torch.from_numpy(x['production_companies']).type(torch.cuda.IntTensor)
                production_countries_x = torch.from_numpy(x['production_countries']).type(torch.cuda.IntTensor)
                release_date_x = torch.from_numpy(x['release_date']).type(torch.cuda.IntTensor)
                revenue_x = torch.from_numpy(x['revenue']).type(torch.cuda.FloatTensor)
                runtime_x = torch.from_numpy(x['runtime'].reshape(1, 1)).type(torch.cuda.FloatTensor)
                spoken_languages_x = torch.from_numpy(x['spoken_languages']).type(torch.cuda.IntTensor)
                status_x = torch.from_numpy(x['status']).type(torch.cuda.FloatTensor)
                tagline_x = torch.from_numpy(x['tagline']).type(torch.cuda.IntTensor)
                title_x = torch.from_numpy(x['title']).type(torch.cuda.IntTensor)
                video_x = torch.from_numpy(x['video']).type(torch.cuda.FloatTensor)
                vote_average_x = torch.from_numpy(x['vote_average']).type(torch.cuda.FloatTensor)
                vote_count_x = torch.from_numpy(x['vote_count']).type(torch.cuda.FloatTensor)
                genres_y = torch.from_numpy(y['genres']).type(torch.cuda.IntTensor)
                with torch.cuda.amp.autocast():
                    y_hat = network_in.forward(adult=adult_x,
                                            belongs_to_collection=belongs_to_collection_x,
                                            budget=budget_x,
                                            original_language=original_language_x,
                                            original_title=original_title_x,
                                            overview=overview_x,
                                            popularity=popularity_x,
                                            production_companies=production_companies_x,
                                            production_countries=production_countries_x,
                                            release_date=release_date_x,
                                            runtime=runtime_x,
                                            revenue=revenue_x,
                                            spoken_languages=spoken_languages_x,
                                            status=status_x,
                                            tagline=tagline_x,
                                            title=title_x,
                                            video=video_x,
                                            vote_average=vote_average_x
                                            )
                    simulation_y = torch.zeros(y_hat.shape).to(DEVICE)
                    simulation_y[:, :, 0:] = genres_y
                    loss = loss_function(y_hat, simulation_y).to(DEVICE)
                del adult_x
                del belongs_to_collection_x
                del budget_x
                del original_language_x
                del original_title_x
                del overview_x
                del popularity_x
                del production_companies_x
                del production_countries_x
                del release_date_x
                del revenue_x
                del runtime_x
                del spoken_languages_x
                del status_x
                del tagline_x
                del title_x
                del video_x
                del vote_average_x
                del vote_count_x
                scalar.scale(loss).backward()
                scalar.step(optimizer)
                scalar.update()

                total_loss += loss.item()
                if y_hat[:, :, 0:] == simulation_y[:, :, 0:]:
                    accurate += 1
                else:
                    not_accurate += 1
                ac = round((accurate / index) * 100, 2)
                print(
                    f'epoch {epoch} / {epochs} , index : {index} accuracy : {ac} loss : {round(total_loss, 3)} , pass : {round((index / dsm.__len__()) * 100, 3)}',
                    end='\r')

            else:
                optimizer.zero_grad()
                nl_2 = torch.from_numpy(x['nl_2']).type(torch.IntTensor)
                nl_1 = torch.from_numpy(x['nl_1']).type(torch.IntTensor)
                nl_3 = torch.from_numpy(x['nl_3']).type(torch.IntTensor)
                nl_4 = torch.from_numpy(x['nl_4']).type(torch.IntTensor)
                adult_x = torch.from_numpy(x['adult']).type(torch.FloatTensor)
                budget_x = torch.from_numpy(x['budget'].reshape(1, 1)).type(torch.FloatTensor)
                popularity_x = torch.from_numpy(x['popularity'].reshape(1, 1)).type(torch.FloatTensor)
                release_date_x = torch.from_numpy(x['release_date']).type(torch.FloatTensor)
                revenue_x = torch.from_numpy(x['revenue']).type(torch.FloatTensor)
                runtime_x = torch.from_numpy(x['runtime'].reshape(1, 1)).type(torch.FloatTensor)
                status_x = torch.from_numpy(x['status']).type(torch.FloatTensor)
                video_x = torch.from_numpy(x['video']).type(torch.FloatTensor)
                vote_average_x = torch.from_numpy(x['vote_average']).type(torch.FloatTensor)
                genres_y = torch.from_numpy(y['genres']).type(torch.IntTensor)
                print(torch.max(nl_2))
                print(nl_2.size())
                y_hat = network_in.forward(
                    nl_1=nl_1,
                    nl_2=nl_2,
                    nl_3=nl_3,
                    nl_4=nl_4,
                    adult=adult_x,
                    budget=budget_x,
                    popularity=popularity_x,
                    release_date=release_date_x,
                    runtime=runtime_x,
                    revenue=revenue_x,
                    status=status_x,
                    video=video_x,
                    vote_average=vote_average_x
                )
                del adult_x
                del budget_x
                del release_date_x
                del revenue_x
                del runtime_x
                del status_x
                del video_x
                del vote_average_x

                simulation_y = torch.zeros(1, 32).to(DEVICE)
                simulation_y[0, 0:] = genres_y

                loss = loss_function(y_hat, simulation_y.type(torch.LongTensor)).to(DEVICE)
                loss.backward()
                optimizer.step()
                tza = []
                total_loss += loss.item()

                y_hat_fr = y_hat[:, 0]
                y_hat_numpy = y_hat_fr.detach().numpy()
                simulation_y_numpy = simulation_y.detach().numpy()
                for i in range(y_hat_numpy[0].shape[0]):
                    tza.append(0 if y_hat_numpy[0, i] < 0.5 else 1)
                simulation_y_numpy = simulation_y_numpy.astype(np.int8)
                accurate_num = 0
                for i in range(simulation_y_numpy.shape[1]):
                    ova = simulation_y_numpy[0, i]
                    if tza[i] == ova:
                        accurate_num += 1
                accurate += 1 if accurate_num == 32 else 0
                not_accurate += 0 if accurate_num == 32 else 1

                del simulation_y_numpy
                del y_hat_numpy
                del y_hat_fr
                del tza

                ac = (accurate / index) * 100 if accurate != 0 else 0

                pbar.set_description(
                    f' \r epoch {epoch} / {epochs} loss : {loss.item():.4f}  loss_total : {total_loss:.4f} , pass : % {(index / dsm.__len__()) * 100:.4f} , ac : {ac:.4f}'
                )

                pbar.refresh()
        pbar.write('')


In [40]:
train(network)

0it [00:00, ?it/s]

initialization Done
tensor(50, dtype=torch.int32)
torch.Size([50])


 epoch 0 / 50 loss : -0.0314  loss_total : -0.0314 , pass : % 0.0000 , ac : 0.0000: : 1it [00:01,  1.96s/it]

tensor(112, dtype=torch.int32)
torch.Size([67])


 epoch 0 / 50 loss : -0.0313  loss_total : -0.0627 , pass : % 0.0022 , ac : 0.0000: : 2it [00:03,  1.76s/it]

tensor(127, dtype=torch.int32)
torch.Size([56])





IndexError: index out of range in self

In [12]:
print(max(dsm.vocab))

years


In [29]:
v = list(dsm.vocab.keys())
print(v[-1])

yermoliev


In [30]:
print(dsm.vocab['yermoliev'])

171678
