# Imports

In [12]:
import json

import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset
from transformers import BertTokenizer

# Loading prepared train and test data

In [13]:
# Read the prepared country-wise table, report its dimensions, then preview the first rows.
data = pd.read_csv(
    'prepared_data/country_wise_train_test.csv',
)
print(data.shape)
data.head(n=5)

(1819, 1582)


Unnamed: 0,id,overview,budget_unknown,budget_100M,revenue_Argentina_M,revenue_Australia_M,revenue_Austria_M,revenue_Belgium_M,revenue_Domestic_M,revenue_France_M,...,crew_1417400,crew_1421720,crew_1425513,crew_1440737,crew_1455461,crew_1463785,crew_1548698,crew_1552521,crew_1552549,crew_1733142
0,86835,Rick is a screenwriter living in Los Angeles. ...,1,0.0,0.474532,0.615837,0.445917,0.30037,,2.166664,...,0,0,0,1,0,0,0,0,0,0
1,147441,The defiant leader Moses rises up against the ...,0,1.4,,,,,,,...,0,0,0,0,0,0,0,0,0,0
2,173327,From Bedrooms to Billions is a 2014 documentar...,1,0.0,,,,,,0.277779,...,0,0,0,0,0,0,0,0,0,0
3,173165,"Starting as a passion project, this movie laun...",0,0.00175,,,,,,,...,0,0,0,0,0,0,0,0,0,0
4,155084,"A bright but meek salesman, drowning in debt a...",0,0.05,,,,,,,...,0,0,0,0,0,0,0,0,0,0


In [14]:
# Show every column name so the feature-group prefixes can be inspected.
print(list(data.columns))

['id', 'overview', 'budget_unknown', 'budget_100M', 'revenue_Argentina_M', 'revenue_Australia_M', 'revenue_Austria_M', 'revenue_Belgium_M', 'revenue_Domestic_M', 'revenue_France_M', 'revenue_Germany_M', 'revenue_Italy_M', 'revenue_Mexico_M', 'revenue_Netherlands_M', 'revenue_New Zealand_M', 'revenue_Portugal_M', 'revenue_Russia/CIS_M', 'revenue_South Korea_M', 'revenue_Spain_M', 'revenue_Taiwan_M', 'revenue_United Kingdom_M', 'original_language_bn', 'original_language_bs', 'original_language_ca', 'original_language_cn', 'original_language_cs', 'original_language_cy', 'original_language_da', 'original_language_de', 'original_language_el', 'original_language_en', 'original_language_es', 'original_language_et', 'original_language_fa', 'original_language_fi', 'original_language_fr', 'original_language_hi', 'original_language_hu', 'original_language_id', 'original_language_it', 'original_language_ja', 'original_language_kk', 'original_language_ko', 'original_language_nl', 'original_language

In [15]:
# Figures noted next to each revenue column -- presumably the number of rows
# with a revenue value for that country (TODO confirm; no figure was written
# down for United Kingdom).
# revenue_Argentina           540
# revenue_Australia           820
# revenue_Austria             549
# revenue_Belgium             505
# revenue_Domestic            693
# revenue_France              778
# revenue_Germany             695
# revenue_Italy               676
# revenue_Mexico              659
# revenue_Netherlands         516
# revenue_New Zealand         734
# revenue_Portugal            550
# revenue_Russia/CIS          570
# revenue_South Korea         533
# revenue_Spain               779
# revenue_Taiwan              510
# revenue_United Kingdom

# Market names exactly as they appear between 'revenue_' and '_M' in the column names.
COUNTRIES = [
    'Argentina',
    'Australia',
    'Austria',
    'Belgium',
    'Domestic',
    'France',
    'Germany',
    'Italy',
    'Mexico',
    'Netherlands',
    'New Zealand',
    'Portugal',
    'Russia/CIS',
    'South Korea',
    'Spain',
    'Taiwan',
    'United Kingdom',
]
print(len(COUNTRIES))

17


In [16]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [17]:
# Load the pretrained "bert-base-uncased" tokenizer; it is handed to RevenueDataset
# below, which uses it to encode each row's 'overview' text.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [18]:
# Keep a pristine snapshot of the loaded table: the per-country loop below rebinds `data`.
DATA = data.copy(deep=True)

In [19]:
class RevenueDataset(Dataset):
    """Torch dataset yielding ``(features, revenue)`` pairs from a prepared frame.

    The feature vector is the concatenation, in this order, of the tokenizer's
    ``input_ids`` and ``attention_mask`` for the row's ``overview`` text, the
    ``original_language_*``, ``genre_*``, ``cast_*`` and ``crew_*`` columns,
    ``budget_100M`` and ``budget_unknown``.

    Args:
        tokenizer: object providing ``encode_plus`` (here the BERT tokenizer).
        data: DataFrame holding the feature columns and the revenue column.
        device: torch device every returned tensor is created on.
        max_length: length the overview is padded / truncated to.
        revenue_col: name of the target column. When omitted, the module-level
            ``REVENUE_COL`` is read once, at construction time.
    """

    def __init__(self, tokenizer, data, device, max_length=256, revenue_col=None):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = data
        self.original_language_cols = [x for x in data.columns if x.startswith('original_language_')]
        self.genre_cols = [x for x in data.columns if x.startswith('genre_')]
        self.cast_cols = [x for x in data.columns if x.startswith('cast_')]
        self.crew_cols = [x for x in data.columns if x.startswith('crew_')]
        self.device = device
        # Bind the target column now. Looking the global up inside __getitem__
        # would make every dataset follow whatever country was processed last.
        self.revenue_col = revenue_col if revenue_col is not None else REVENUE_COL

    def _columns_to_tensor(self, row, cols):
        """Return the values of ``cols`` in ``row`` as a 1-D float tensor on ``self.device``."""
        return torch.tensor(row[cols].values.astype(float), dtype=torch.float, device=self.device)

    def _scalar_to_tensor(self, value):
        """Return ``value`` as a 0-D float tensor on ``self.device``."""
        return torch.tensor(value, dtype=torch.float, device=self.device)

    def __getitem__(self, idx):
        """Return ``(x, revenue)`` for the row at integer position ``idx``."""
        row = self.data.iloc[idx]
        inputs = self.tokenizer.encode_plus(
            row['overview'],
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        ).to(self.device)

        original_language = self._columns_to_tensor(row, self.original_language_cols)
        genres = self._columns_to_tensor(row, self.genre_cols)
        cast = self._columns_to_tensor(row, self.cast_cols)
        crew = self._columns_to_tensor(row, self.crew_cols)
        budget = self._scalar_to_tensor(row['budget_100M'])
        budget_unknown = self._scalar_to_tensor(row['budget_unknown'])
        revenue = self._scalar_to_tensor(row[self.revenue_col])

        x = torch.cat((
            # squeeze(0) removes only the leading batch axis added by return_tensors='pt'.
            inputs["input_ids"].squeeze(0),
            inputs["attention_mask"].squeeze(0),
            original_language,
            genres,
            cast,
            crew,
            budget.unsqueeze(0),
            budget_unknown.unsqueeze(0)
        ))

        return x, revenue

    def __len__(self):
        """Number of rows in the underlying frame."""
        return len(self.data)


# Build, for every country, the feature-column lists and revenue scale for its subset of rows.
config = {}
for COUNTRY in COUNTRIES:
    print("Processing", COUNTRY)
    REVENUE_COL = 'revenue_' + COUNTRY + '_M'
    # NaN > 0 evaluates to False, so one mask drops both missing and non-positive
    # revenues; .copy() keeps the snapshot in DATA untouched.
    data = DATA[DATA[REVENUE_COL] > 0].copy()
    print(data.shape)
    # Assign the result back: an in-place fillna on a selected column is a chained
    # operation that does not update the frame under pandas Copy-on-Write.
    data['overview'] = data['overview'].fillna('')
    # Drop all columns where all values are 0 (cast, crew, etc missing so we can save model training time)
    data = data.loc[:, (data != 0).any(axis=0)].copy()
    print(data.shape)
    # Normalise the target by twice its mean; the factor is stored in the config below.
    REVENUE_SCALE = 2 * data[REVENUE_COL].mean()
    data[REVENUE_COL] = data[REVENUE_COL] / REVENUE_SCALE

    DATASET = RevenueDataset(tokenizer, data, DEVICE, revenue_col=REVENUE_COL)
    # Save all info in the config
    config[COUNTRY] = {
        'original_language_cols': [x.replace('original_language_', '') for x in DATASET.original_language_cols],
        'genre_cols': [x.replace('genre_', '') for x in DATASET.genre_cols],
        'cast_cols': [x.replace('cast_', '') for x in DATASET.cast_cols],
        'crew_cols': [x.replace('crew_', '') for x in DATASET.crew_cols],
        # Plain Python float so the config stays JSON-serialisable.
        'revenue_scale': float(REVENUE_SCALE)
    }

Processing Argentina
(540, 1582)
(540, 1032)
Processing Australia
(822, 1582)
(822, 1366)
Processing Austria
(549, 1582)
(549, 1069)
Processing Belgium
(505, 1582)
(505, 915)
Processing Domestic
(695, 1582)
(695, 1434)
Processing France
(780, 1582)
(780, 1329)
Processing Germany
(695, 1582)
(695, 1230)
Processing Italy
(676, 1582)
(676, 1155)
Processing Mexico
(659, 1582)
(659, 1172)
Processing Netherlands
(516, 1582)
(516, 1017)
Processing New Zealand
(734, 1582)
(734, 1330)
Processing Portugal
(550, 1582)
(550, 1084)
Processing Russia/CIS
(570, 1582)
(570, 1033)
Processing South Korea
(533, 1582)
(533, 1078)
Processing Spain
(779, 1582)
(779, 1253)
Processing Taiwan
(510, 1582)
(510, 1004)
Processing United Kingdom
(913, 1582)
(913, 1396)


In [20]:
# Print the assembled per-country config for a visual check.
print(config)

{'Argentina': {'original_language_cols': ['bn', 'bs', 'cn', 'cs', 'cy', 'de', 'en', 'es', 'fa', 'fi', 'fr', 'hi', 'id', 'it', 'ja', 'ko', 'nl', 'pl', 'pt', 'ru', 'sr', 'sv', 'ta', 'te', 'th', 'xx', 'zh', 'zu'], 'genre_cols': ['Action', 'Adventure', 'Animation', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Family', 'Fantasy', 'Foreign', 'History', 'Horror', 'Music', 'Mystery', 'Romance', 'Science Fiction', 'TV Movie', 'Thriller', 'War', 'Western'], 'cast_cols': ['3', '31', '49', '50', '53', '64', '65', '85', '99', '110', '112', '113', '116', '134', '147', '193', '207', '227', '230', '287', '326', '335', '380', '385', '400', '418', '501', '518', '522', '524', '526', '532', '539', '540', '585', '588', '591', '649', '658', '689', '694', '738', '741', '742', '776', '785', '821', '824', '854', '857', '883', '884', '923', '936', '937', '973', '1004', '1009', '1032', '1037', '1062', '1064', '1065', '1100', '1118', '1121', '1137', '1146', '1160', '1204', '1205', '1211', '1229', '1230', '1231', '

In [21]:
# Model files available so far, one per country:
#   country -> (model file name, hidden_size, bert_hidden_size)
# cast_size and crew_size are identical for every model listed here.
MODEL_SPECS = {
    'Argentina': ('model_argentina_hidden_128_bert_192_cast32_crew_16_batch_16.pth', 128, 192),
    'Australia': ('model_aus_hidden_128_bert_192_cast32_crew_16_batch_16.pth', 128, 192),
    'Austria': ('model_austria_hidden_256_bert_256_cast32_crew_16_batch_16_mse_0.49.pth', 256, 256),
    'Belgium': ('model_belgium_hidden_256_bert_256_cast32_crew_16_batch_16.pth', 256, 256),
    'Domestic': ('model_domestic_hidden_128_bert_192_cast32_crew_16_mse_0.55.pth', 128, 192),
    'South Korea': ('model_skorea_hidden_256_bert_256_cast32_crew_16_batch_16.pth', 256, 256),
    'Spain': ('model_spain_hidden_128_bert_192_cast32_crew_16_batch_16.pth', 128, 192),
    'Taiwan': ('model_taiwan_hidden_256_bert_256_cast32_crew_16_batch_16.pth', 256, 256),
    'United Kingdom': ('model_uk_hidden_256_bert_256_cast32_crew_16_batch_16.pth', 256, 256),
}
MODEL_CAST_SIZE = 32
MODEL_CREW_SIZE = 16

# Attach the model file and its hyper-parameters to each country's config entry.
# A fresh dict is built per country so the entries never alias one another.
for model_country, (model_file, hidden_size, bert_hidden_size) in MODEL_SPECS.items():
    config[model_country]['model_path'] = model_file
    config[model_country]['model_params'] = {
        'hidden_size': hidden_size,
        'bert_hidden_size': bert_hidden_size,
        'cast_size': MODEL_CAST_SIZE,
        'crew_size': MODEL_CREW_SIZE
    }

# Derived from the table above so it cannot drift from the entries actually written.
countries_completed = list(MODEL_SPECS)
print(len(countries_completed))

9


In [22]:
# Save the config to a file (json is imported with the other imports in the first cell).
with open('prepared_data/country_wise_config.json', 'w') as f:
    json.dump(config, f)