In [48]:
!pip install transformers
!pip install wget


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [61]:
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
import json
import os
from collections import defaultdict
import ast
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [62]:
# Load only the metadata columns used downstream. `id` is read as text because
# the cleaning step filters it with `.str.isnumeric()` before casting it to int.
movies_data = pd.read_csv(
    "../dataset/downloaded/movies_metadata.csv",
    usecols=['id', 'overview', 'production_countries', 'original_language', 'revenue', 'budget'],
    dtype={'id': str},
)

# Country-wise box office collections; the name and IMDb id columns are not used.
# NOTE(review): low_memory=False makes pandas infer each column's dtype from the
# whole file; this is presumably what the warning recorded for this line was
# about (mixed dtypes) - confirm against the full warning text.
box_office_data = pd.read_csv(
    "../dataset/created/box_office_collections.csv",
    low_memory=False,
).drop(columns=['Movie Name', 'imdbId'])

# Tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

  box_office_data = pd.read_csv("../dataset/created/box_office_collections.csv").drop(columns=['Movie Name', 'imdbId'])


In [63]:
# Count the rows whose id is missing, first for the movie metadata and then
# for the box office table.
for frame in (movies_data, box_office_data):
    missing_ids = frame['id'].isna().sum()
    print(missing_ids)

0
0


In [64]:
# Clean movies_data into the schema
#   overview: str, production_countries: str, original_language: str,
#   revenue: float, budget: float, id: int
# Rows whose budget cannot be parsed as a number, whose id is not purely digits,
# or whose revenue is missing are dropped.
print("Shape before ", movies_data.shape)
# Unparseable budgets become NaN here and are removed by the dropna below.
movies_data['budget'] = pd.to_numeric(movies_data['budget'], errors='coerce', downcast='float')
# astype(str) keeps this filter working when the cell is re-run after id has
# already been cast to int; .copy() makes the filtered result an independent
# frame so the column assignments below do not raise SettingWithCopyWarning.
movies_data = movies_data[movies_data['id'].astype(str).str.isnumeric()].copy()
movies_data['id'] = movies_data['id'].astype(int)
# If overview is null, convert it to empty string
movies_data['overview'] = movies_data['overview'].fillna('')

movies_data = movies_data.dropna(subset=['revenue', 'budget'], how='any')
print("Shape after ", movies_data.shape)
print(movies_data.dtypes)
print(movies_data.head())

# Clean box_office_data: every column after the first holds either an empty
# value or an amount written with a dollar sign and thousands separators.
# Strip those characters and convert the columns to float.
print("Shape before ", box_office_data.shape)
# Raw string: '\$' is not a valid escape sequence in a normal string literal.
box_office_data[box_office_data.columns[1:]] = box_office_data[box_office_data.columns[1:]].replace(r'[\$,]', '', regex=True).astype(float)
# Prefix every column except the first with revenue_. Columns that already carry
# the prefix are skipped so that re-running the cell does not double-prefix them.
new_cols = [(col, 'revenue_'+col) for col in box_office_data.columns[1:] if not col.startswith('revenue_')]
box_office_data = box_office_data.rename(columns=dict(new_cols))

print("Shape after ", box_office_data.shape)
print(box_office_data.dtypes)
print(box_office_data.head())


Shape before  (45466, 6)
Shape after  (45460, 6)
budget                  float64
id                        int64
original_language        object
overview                 object
production_countries     object
revenue                 float64
dtype: object
       budget     id original_language  \
0  30000000.0    862                en   
1  65000000.0   8844                en   
2         0.0  15602                en   
3  16000000.0  31357                en   
4         0.0  11862                en   

                                            overview  \
0  Led by Woody, Andy's toys live happily in his ...   
1  When siblings Judy and Peter discover an encha...   
2  A family wedding reignites the ancient feud be...   
3  Cheated on, mistreated and stepped on, the wom...   
4  Just when George Banks has recovered from his ...   

                                production_countries      revenue  
0  [{'iso_3166_1': 'US', 'name': 'United States o...  373554033.0  
1  [{'iso_3166_1': 

In [28]:
# Attach the country-wise revenue columns to every movie. The join key is `id`:
# it is the only identifier both frames still carry (`imdb_id` is not among the
# columns loaded into movies_data and `imdbId` was dropped from box_office_data).
# A left join keeps movies without box office data; their revenue_* cells are NaN.
merged_data = pd.merge(movies_data, box_office_data, how='left', on='id')
print("Shape of merged data ", merged_data.shape)
merged_data.head(20)

Shape of merged data  (45460, 135)


Unnamed: 0,budget,id,original_language,overview,production_countries,revenue,revenue_Argentina,revenue_Aruba,revenue_Australia,revenue_Austria,...,revenue_Guatemala,revenue_Netherlands Antilles,revenue_North Macedonia,revenue_South Africa/Nigeria,revenue_Switzerland (French/Italian),revenue_E/W Africa,revenue_Laos,revenue_Bosnia,revenue_Soviet Union,revenue_Malta
0,30000000.0,862,en,"Led by Woody, Andy's toys live happily in his ...","[{'iso_3166_1': 'US', 'name': 'United States o...",373554033.0,,,,,...,,,,,,,,,,
1,65000000.0,8844,en,When siblings Judy and Peter discover an encha...,"[{'iso_3166_1': 'US', 'name': 'United States o...",262797249.0,,,,,...,,,,,,,,,,
2,0.0,15602,en,A family wedding reignites the ancient feud be...,"[{'iso_3166_1': 'US', 'name': 'United States o...",0.0,,,,,...,,,,,,,,,,
3,16000000.0,31357,en,"Cheated on, mistreated and stepped on, the wom...","[{'iso_3166_1': 'US', 'name': 'United States o...",81452156.0,,,,,...,,,,,,,,,,
4,0.0,11862,en,Just when George Banks has recovered from his ...,"[{'iso_3166_1': 'US', 'name': 'United States o...",76578911.0,,,,,...,,,,,,,,,,
5,60000000.0,949,en,"Obsessive master thief, Neil McCauley leads a ...","[{'iso_3166_1': 'US', 'name': 'United States o...",187436818.0,,,,,...,,,,,,,,,,
6,58000000.0,11860,en,An ugly duckling having undergone a remarkable...,"[{'iso_3166_1': 'DE', 'name': 'Germany'}, {'is...",0.0,,,,,...,,,,,,,,,,
7,0.0,45325,en,"A mischievous young boy, Tom Sawyer, witnesses...","[{'iso_3166_1': 'US', 'name': 'United States o...",0.0,,,,,...,,,,,,,,,,
8,35000000.0,9091,en,International action superstar Jean Claude Van...,"[{'iso_3166_1': 'US', 'name': 'United States o...",64350171.0,,,,,...,,,,,,,,,,
9,58000000.0,710,en,James Bond must unmask the mysterious head of ...,"[{'iso_3166_1': 'GB', 'name': 'United Kingdom'...",352194034.0,,,,,...,,,,,,,,,,


In [59]:
# Missing-value overview: rows containing at least one NaN, rows consisting only
# of NaN, and the NaN count of every column.
row_has_na = merged_data.isna().any(axis=1)
row_all_na = merged_data.isna().all(axis=1)
print("Rows with any Na values: ", row_has_na.sum())
print("Rows with all Na values: ", row_all_na.sum())

x = merged_data.isna().sum()

# Number of columns whose non-null count is below each threshold
total_rows = merged_data.shape[0]
for threshold in (10, 50, 100, 1000):
    sparse_columns = x[x > total_rows - threshold]
    print(f"Fewer than {threshold} non-null values ", len(sparse_columns))

print("Count of Na values in each column:")
# Temporarily raise the row limit so that more of the per-column counts are shown
pd.set_option("display.max_rows", 200)
print(x)
pd.reset_option("display.max_rows")

Rows with any Na values:  45460
Rows with all Na values:  0
Fewer than 10 non-null values  37
Fewer than 50 non-null values  54
Fewer than 100 non-null values  67
Fewer than 1000 non-null values  129
Count of Na values in each column:
budget                       0
id                           0
overview                     0
revenue                      0
revenue_Argentina        44920
                         ...  
original_language_wo         0
original_language_xx         0
original_language_zh         0
original_language_zu         0
original_language_nan        0
Length: 384, dtype: int64


In [30]:
# Show the distinct raw production_countries strings before parsing them
print(
    "All production_countries values: ",
    merged_data['production_countries'].unique(),
)

# Read the array inside each production_countries cell as a list, and convert it into a list of country_ids, where country_id is the index in dictionary built from all unique countries encountered in the list in each cell of production_countries column
def get_country_isos(production_country):
    """Extract the ISO 3166-1 codes from one production_countries cell.

    Args:
        production_country: The cell value, expected to be the string form of a
            list of dicts that each carry an 'iso_3166_1' key. Any non-string
            value (e.g. the NaN pandas uses for a missing cell) is treated as
            "no countries".

    Returns:
        list: The 'iso_3166_1' values in their original order; empty when the
        cell is missing or lists no countries.

    Raises:
        ValueError, SyntaxError: If the string is not a valid Python literal.
        KeyError: If an entry has no 'iso_3166_1' key.
    """
    # Missing cells are floats (NaN), which ast.literal_eval cannot parse
    if not isinstance(production_country, str):
        return []
    return [country['iso_3166_1'] for country in ast.literal_eval(production_country)]

# Parse every production_countries cell into a list of ISO codes
merged_data['production_countries_isos'] = merged_data['production_countries'].apply(get_country_isos)

# multi-hot encode the production_countries column
mlb = MultiLabelBinarizer()
multi_hot_encoded_countries = mlb.fit_transform(merged_data['production_countries_isos'])
print("Total number of classes: ", len(mlb.classes_))
print("Classes: ", mlb.classes_)

# Create a dataframe with the multi-hot encoded columns, where column names are 'production_country_' + mlb.classes_.
# It reuses merged_data's index: concat(axis=1) aligns on index labels, so a fresh
# RangeIndex would silently misalign rows whenever merged_data is not indexed 0..n-1.
multi_hot_encoded_countries_df = pd.DataFrame(
    multi_hot_encoded_countries,
    columns=['production_country_' + country for country in mlb.classes_],
    index=merged_data.index,
)

merged_data = pd.concat([merged_data, multi_hot_encoded_countries_df], axis=1)
# The raw string column and the intermediate list column are no longer needed
merged_data = merged_data.drop(columns=['production_countries', 'production_countries_isos'])
merged_data.head()

All production_countries values:  ["[{'iso_3166_1': 'US', 'name': 'United States of America'}]"
 "[{'iso_3166_1': 'DE', 'name': 'Germany'}, {'iso_3166_1': 'US', 'name': 'United States of America'}]"
 "[{'iso_3166_1': 'GB', 'name': 'United Kingdom'}, {'iso_3166_1': 'US', 'name': 'United States of America'}]"
 ...
 "[{'iso_3166_1': 'PL', 'name': 'Poland'}, {'iso_3166_1': 'CZ', 'name': 'Czech Republic'}, {'iso_3166_1': 'SK', 'name': 'Slovakia'}]"
 "[{'iso_3166_1': 'CU', 'name': 'Cuba'}, {'iso_3166_1': 'DE', 'name': 'Germany'}, {'iso_3166_1': 'ES', 'name': 'Spain'}]"
 "[{'iso_3166_1': 'EG', 'name': 'Egypt'}, {'iso_3166_1': 'IT', 'name': 'Italy'}, {'iso_3166_1': 'US', 'name': 'United States of America'}]"]
Total number of classes:  161
Classes:  ['AE' 'AF' 'AL' 'AM' 'AN' 'AO' 'AQ' 'AR' 'AT' 'AU' 'AW' 'AZ' 'BA' 'BB'
 'BD' 'BE' 'BF' 'BG' 'BM' 'BN' 'BO' 'BR' 'BS' 'BT' 'BW' 'BY' 'CA' 'CD'
 'CG' 'CH' 'CI' 'CL' 'CM' 'CN' 'CO' 'CR' 'CS' 'CU' 'CY' 'CZ' 'DE' 'DK'
 'DO' 'DZ' 'EC' 'EE' 'EG' 'ES' 'ET' 

Unnamed: 0,budget,id,original_language,overview,revenue,revenue_Argentina,revenue_Aruba,revenue_Australia,revenue_Austria,revenue_Bahrain,...,production_country_UY,production_country_UZ,production_country_VE,production_country_VN,production_country_WS,production_country_XC,production_country_XG,production_country_YU,production_country_ZA,production_country_ZW
0,30000000.0,862,en,"Led by Woody, Andy's toys live happily in his ...",373554033.0,,,,,,...,0,0,0,0,0,0,0,0,0,0
1,65000000.0,8844,en,When siblings Judy and Peter discover an encha...,262797249.0,,,,,,...,0,0,0,0,0,0,0,0,0,0
2,0.0,15602,en,A family wedding reignites the ancient feud be...,0.0,,,,,,...,0,0,0,0,0,0,0,0,0,0
3,16000000.0,31357,en,"Cheated on, mistreated and stepped on, the wom...",81452156.0,,,,,,...,0,0,0,0,0,0,0,0,0,0
4,0.0,11862,en,Just when George Banks has recovered from his ...,76578911.0,,,,,,...,0,0,0,0,0,0,0,0,0,0


In [31]:
language_values = merged_data['original_language'].unique()
print("All original_language values: ", language_values)

# One-hot encode original_language; dummy_na=True adds an indicator column for
# missing values as well.
merged_data = pd.get_dummies(
    merged_data,
    columns=['original_language'],
    dummy_na=True,
)
merged_data.head()

All original_language values:  ['en' 'fr' 'zh' 'it' 'fa' 'nl' 'de' 'cn' 'ar' 'es' 'ru' 'sv' 'ja' 'ko'
 'sr' 'bn' 'he' 'pt' 'wo' 'ro' 'hu' 'cy' 'vi' 'cs' 'da' 'no' 'nb' 'pl'
 'el' 'sh' 'xx' 'mk' 'bo' 'ca' 'fi' 'th' 'sk' 'bs' 'hi' 'tr' 'is' 'ps'
 'ab' 'eo' 'ka' 'mn' 'bm' 'zu' 'uk' 'af' 'la' 'et' 'ku' 'fy' 'lv' 'ta'
 'sl' 'tl' 'ur' 'rw' 'id' 'bg' 'mr' 'lt' 'kk' 'ms' 'sq' nan 'qu' 'te' 'am'
 'jv' 'tg' 'ml' 'hr' 'lo' 'ay' 'kn' 'eu' 'ne' 'pa' 'ky' 'gl' 'uz' 'sm'
 'mt' 'hy' 'iu' 'lb' 'si']


Unnamed: 0,budget,id,overview,revenue,revenue_Argentina,revenue_Aruba,revenue_Australia,revenue_Austria,revenue_Bahrain,revenue_Belgium,...,original_language_tr,original_language_uk,original_language_ur,original_language_uz,original_language_vi,original_language_wo,original_language_xx,original_language_zh,original_language_zu,original_language_nan
0,30000000.0,862,"Led by Woody, Andy's toys live happily in his ...",373554033.0,,,,,,,...,False,False,False,False,False,False,False,False,False,False
1,65000000.0,8844,When siblings Judy and Peter discover an encha...,262797249.0,,,,,,,...,False,False,False,False,False,False,False,False,False,False
2,0.0,15602,A family wedding reignites the ancient feud be...,0.0,,,,,,,...,False,False,False,False,False,False,False,False,False,False
3,16000000.0,31357,"Cheated on, mistreated and stepped on, the wom...",81452156.0,,,,,,,...,False,False,False,False,False,False,False,False,False,False
4,0.0,11862,Just when George Banks has recovered from his ...,76578911.0,,,,,,,...,False,False,False,False,False,False,False,False,False,False


In [32]:
# Drop the country-wise revenue columns that have fewer than 100 non-null values.
# The counts are computed here from the current merged_data instead of reusing a
# NaN summary produced by another cell, so the result does not depend on the
# order in which cells were executed.
revenue_columns = [col for col in merged_data.columns if col.startswith('revenue_')]
non_null_counts = merged_data[revenue_columns].notna().sum()
countries = non_null_counts[non_null_counts < 100].index.tolist()
data = merged_data.drop(columns=countries)
data.head()

Unnamed: 0,budget,id,overview,revenue,revenue_Argentina,revenue_Australia,revenue_Austria,revenue_Belgium,revenue_Bolivia,revenue_Brazil,...,original_language_tr,original_language_uk,original_language_ur,original_language_uz,original_language_vi,original_language_wo,original_language_xx,original_language_zh,original_language_zu,original_language_nan
0,30000000.0,862,"Led by Woody, Andy's toys live happily in his ...",373554033.0,,,,,,,...,False,False,False,False,False,False,False,False,False,False
1,65000000.0,8844,When siblings Judy and Peter discover an encha...,262797249.0,,,,,,,...,False,False,False,False,False,False,False,False,False,False
2,0.0,15602,A family wedding reignites the ancient feud be...,0.0,,,,,,,...,False,False,False,False,False,False,False,False,False,False
3,16000000.0,31357,"Cheated on, mistreated and stepped on, the wom...",81452156.0,,,,,,,...,False,False,False,False,False,False,False,False,False,False
4,0.0,11862,Just when George Banks has recovered from his ...,76578911.0,,,,,,,...,False,False,False,False,False,False,False,False,False,False


In [33]:
# Flag rows whose budget is recorded as exactly 0.0 (1 = flagged, 0 = not)
data['budget_unknown'] = data['budget'].eq(0.0).astype('int64')

In [34]:
# Sanity check: outside the country-wise revenue_* columns, count the rows that
# still contain a NaN.
non_revenue_cols = list(filter(lambda col: not col.startswith('revenue_'), data.columns))
rows_with_na = data[non_revenue_cols].isna().any(axis=1)
print("Rows with any Na values except in revenue cols: ", rows_with_na.sum())

Rows with any Na values except in revenue cols:  0


In [35]:
# Preview of the final feature table (first five rows)
data.head(5)

Unnamed: 0,budget,id,overview,revenue,revenue_Argentina,revenue_Australia,revenue_Austria,revenue_Belgium,revenue_Bolivia,revenue_Brazil,...,original_language_uk,original_language_ur,original_language_uz,original_language_vi,original_language_wo,original_language_xx,original_language_zh,original_language_zu,original_language_nan,budget_unknown
0,30000000.0,862,"Led by Woody, Andy's toys live happily in his ...",373554033.0,,,,,,,...,False,False,False,False,False,False,False,False,False,0
1,65000000.0,8844,When siblings Judy and Peter discover an encha...,262797249.0,,,,,,,...,False,False,False,False,False,False,False,False,False,0
2,0.0,15602,A family wedding reignites the ancient feud be...,0.0,,,,,,,...,False,False,False,False,False,False,False,False,False,1
3,16000000.0,31357,"Cheated on, mistreated and stepped on, the wom...",81452156.0,,,,,,,...,False,False,False,False,False,False,False,False,False,0
4,0.0,11862,Just when George Banks has recovered from his ...,76578911.0,,,,,,,...,False,False,False,False,False,False,False,False,False,1


In [36]:
# Split data into train and test sets (90% / 10%, fixed seed for reproducibility)
train_data, test_data = train_test_split(data, test_size=0.1, random_state=42)
# The split returns selections of `data`; take explicit copies so the budget
# assignments below modify independent frames and do not raise SettingWithCopyWarning.
train_data = train_data.copy()
test_data = test_data.copy()

# Scale budget using StandardScaler for better performance.
# The scaler is fitted on the training split only and then applied to the test split.
scaler = StandardScaler()
train_data['budget'] = scaler.fit_transform(train_data[['budget']])
test_data['budget'] = scaler.transform(test_data[['budget']])

In [37]:
# Plain-text preview of the training split, rich preview of the test split
train_preview = train_data.head()
print(train_preview)
test_data.head()

         budget      id                                           overview  \
30247 -0.242568  320006  A yellow cab is driving through the vibrant an...   
98    -0.242568   11062  The accidental shooting of a boy in New York l...   
26012 -0.242568   39227  Teenage sisters Charli and Lola are on the ver...   
19893 -0.242568    4311  Six people travel by train overnight from Mars...   
22613 -0.242568   72465  An aspiring tennis player is taken under the w...   

       revenue  revenue_Argentina  revenue_Australia  revenue_Austria  \
30247      0.0                NaN                NaN              NaN   
98         0.0                NaN                NaN              NaN   
26012      0.0                NaN                NaN              NaN   
19893      0.0                NaN                NaN              NaN   
22613      0.0                NaN                NaN              NaN   

       revenue_Belgium  revenue_Bolivia  revenue_Brazil  ...  \
30247              NaN      

Unnamed: 0,budget,id,overview,revenue,revenue_Argentina,revenue_Australia,revenue_Austria,revenue_Belgium,revenue_Bolivia,revenue_Brazil,...,original_language_uk,original_language_ur,original_language_uz,original_language_vi,original_language_wo,original_language_xx,original_language_zh,original_language_zu,original_language_nan,budget_unknown
9571,-0.242568,47876,Sorrowful Jones is a cheap bookie in 1930's. W...,6321392.0,,,,,,,...,False,False,False,False,False,False,False,False,False,1
35679,-0.242568,286987,Mae and Gabby are two friends who go everywher...,0.0,,,,,,,...,False,False,False,False,False,False,False,False,False,1
39145,-0.242568,339739,Henry Gamble's Birthday Party takes place over...,0.0,,,,,,,...,False,False,False,False,False,False,False,False,False,1
33166,-0.242568,19637,A debt-ridden young man Jeetu (Shahid Kapoor) ...,0.0,,,,,,,...,False,False,False,False,False,False,False,False,False,1
44569,-0.242155,457307,,0.0,,,,,,,...,False,False,False,False,False,False,False,False,False,0


In [38]:
# 2. Dataset and Dataloader
class RevenueDataset(Dataset):
    """Map-style dataset that turns one row of the feature table into model inputs.

    Args:
        tokenizer: Object exposing ``encode_plus`` (called with ``return_tensors='pt'``);
            its result must support ``.to(device)`` and item access by
            "input_ids" / "attention_mask".
        data: DataFrame with the columns ``overview``, ``revenue``, ``budget``,
            ``budget_unknown`` plus any number of ``production_country_*``,
            ``original_language_*`` and ``revenue_*`` columns.
        device: Device on which every returned tensor is created.
        max_length: Length the tokenised overview is padded / truncated to.
    """

    def __init__(self, tokenizer, data, device, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = data
        # Feature groups are discovered from the column-name prefixes
        self.production_country_cols = [x for x in data.columns if x.startswith('production_country_')]
        self.original_language_cols = [x for x in data.columns if x.startswith('original_language_')]
        # Country-wise revenues only; the plain `revenue` column has no trailing underscore
        self.revenue_cols = [x for x in data.columns if x.startswith('revenue_')]
        self.device = device

    def __getitem__(self, idx):
        """Return the tensors for the row at position ``idx`` as a dict."""
        row = self.data.iloc[idx]
        inputs = self.tokenizer.encode_plus(row['overview'], add_special_tokens=True, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt').to(self.device)

        production_countries = torch.tensor(row[self.production_country_cols].values.astype(float), dtype=torch.float, device=self.device)
        original_language = torch.tensor(row[self.original_language_cols].values.astype(float), dtype=torch.float, device=self.device)
        # Values are passed through unchanged, so NaN cells in the frame stay NaN here
        country_wise_revenues = torch.tensor(row[self.revenue_cols].values.astype(float), dtype=torch.float, device=self.device)
        budget = torch.tensor(row['budget'], dtype=torch.float, device=self.device)
        budget_unknown = torch.tensor(row['budget_unknown'], dtype=torch.float, device=self.device)

        return {
            # squeeze(0) removes only the leading batch axis added by return_tensors='pt';
            # a bare squeeze() would also collapse the sequence axis when max_length == 1.
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "revenue": torch.tensor(row['revenue'], dtype=torch.float, device=self.device),
            "budget": budget,
            "budget_unknown": budget_unknown,
            "production_countries": production_countries,
            "original_language": original_language,
            "country_wise_revenues": country_wise_revenues
        }

    def __len__(self):
        """Number of rows in the underlying frame."""
        return len(self.data)

In [39]:
# Run on the GPU when CUDA is available, otherwise on the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [40]:
# Training batches are reshuffled every epoch
train_dataset = RevenueDataset(tokenizer, train_data, DEVICE)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# Evaluation data keeps its order: shuffling adds nothing to evaluation and makes
# per-row results harder to trace back to the frame.
test_dataset = RevenueDataset(tokenizer, test_data, DEVICE)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [41]:
# 3. Model
class RevenuePredictor(nn.Module):
    """Regression head on top of BERT that predicts a single revenue value.

    The pooled BERT output, the multi-hot production-country vector and the
    one-hot original-language vector are each projected by a linear layer,
    concatenated with the two scalar features (budget, budget_unknown) and fed
    through a two-layer MLP with one output unit.
    """

    def __init__(self, num_country_classes, num_language_classes, bert_embedding_size = 256, production_country_embedding_size = 64, original_language_embedding_size = 64, hidden_size = 256):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")

        # One linear projection per input group: overview text, production countries, original language
        self.linear_overview = nn.Linear(self.bert.config.hidden_size, bert_embedding_size)
        self.linear_production_country = nn.Linear(num_country_classes, production_country_embedding_size)
        self.linear_original_language = nn.Linear(num_language_classes, original_language_embedding_size)

        # Scalars appended as-is: budget and budget_unknown
        self.other_features_size = 2

        fused_size = sum((
            bert_embedding_size,
            production_country_embedding_size,
            original_language_embedding_size,
            self.other_features_size,
        ))
        self.output_layer = nn.Sequential(
            nn.Linear(fused_size, hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, input):
        """Return the prediction, one value per row of the batch dict ``input``.

        ``input`` must provide 'input_ids', 'attention_mask',
        'production_countries', 'original_language', 'budget' and
        'budget_unknown'; the last two get a trailing axis added before the
        concatenation along dim 1.
        """
        encoded = self.bert(input_ids=input['input_ids'], attention_mask=input['attention_mask'])
        features = [
            self.linear_overview(encoded['pooler_output']),
            self.linear_production_country(input['production_countries']),
            self.linear_original_language(input['original_language']),
            input['budget'].unsqueeze(1),
            input['budget_unknown'].unsqueeze(1),
        ]
        fused = torch.cat(features, dim=1)
        return self.output_layer(fused)

In [42]:
# Size the country / language input layers from the feature columns of the training set
model = RevenuePredictor(
    num_country_classes=len(train_dataset.production_country_cols),
    num_language_classes=len(train_dataset.original_language_cols),
)

In [43]:
# 4. Optimizer & Loss
optimizer = AdamW(model.parameters(), lr=5e-5)
loss_fn = nn.MSELoss()
# The training loop calls scheduler.step() once per batch, so the schedule length
# is the number of batches per epoch times the number of epochs (3). Using the
# number of samples instead would stretch the linear decay to batch_size times
# the real run length.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * 3)  # 3 epochs



In [44]:
# Move the model onto DEVICE. Left as a bare expression so the returned module's
# repr is shown as the cell output.
model.to(DEVICE)

RevenuePredictor(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [47]:
# 5. Training Loop
for epoch in range(3):
    model.train()
    total_loss = 0
    loop = tqdm(train_dataloader)
    for batch in loop:
        optimizer.zero_grad()

        predictions = model(batch)
        # predictions has a trailing axis of size 1; give the target the same shape
        loss = loss_fn(predictions, batch['revenue'].unsqueeze(1))

        loss.backward()
        # Clip the global gradient norm to 1.0 before the parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()

        # Accumulate each batch loss exactly once so the epoch average below is
        # the true mean batch loss.
        batch_loss = loss.item()
        total_loss += batch_loss
        loop.set_description(f"Epoch {epoch + 1}")
        loop.set_postfix(loss=batch_loss)

    print(f"Epoch {epoch + 1} | Loss: {total_loss / len(train_dataloader)}")

Epoch 1:   3%|▎         | 152/5115 [10:44<5:50:43,  4.24s/it, loss=47.4]    


KeyboardInterrupt: 

In [None]:
model.eval()


def denormalize_revenue(normalized_value):
    """Map a min-max normalised revenue back onto the range of ``data['revenue']``.

    NOTE(review): the training loop in this notebook regresses on the raw
    ``revenue`` column, so model outputs are not passed through this helper
    below; it is kept for code that works with normalised values.
    """
    max_revenue = data['revenue'].max()
    min_revenue = data['revenue'].min()
    return normalized_value * (max_revenue - min_revenue) + min_revenue


# 3. Load tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


# 4. Test Function
def predict_revenue(prompt, other_features_input):
    """Predict the revenue of a single movie.

    Args:
        prompt: Overview text of the movie.
        other_features_input: Mapping with the non-text inputs of the model:
            'production_countries' - multi-hot sequence ordered like
                train_dataset.production_country_cols,
            'original_language' - one-hot sequence ordered like
                train_dataset.original_language_cols,
            'budget' - budget scaled with the fitted ``scaler``,
            'budget_unknown' - 1 if the budget is unknown, else 0.

    Returns:
        float: The model output for this movie.

    Raises:
        ValueError: If one of the required features is missing.
    """
    required_features = ("production_countries", "original_language", "budget", "budget_unknown")
    missing = [name for name in required_features if name not in other_features_input]
    if missing:
        raise ValueError(f"other_features_input is missing the features: {missing}")

    # Tokenize the input prompt
    inputs = tokenizer.encode_plus(prompt, add_special_tokens=True, max_length=256, padding='max_length', truncation=True, return_tensors='pt')

    # The model's forward takes a single dict. Every entry carries a leading batch
    # axis of size 1; forward itself adds the trailing axis to the two scalars.
    model_input = {
        "input_ids": inputs["input_ids"].to(DEVICE),
        "attention_mask": inputs["attention_mask"].to(DEVICE),
        "production_countries": torch.tensor([other_features_input["production_countries"]], dtype=torch.float, device=DEVICE),
        "original_language": torch.tensor([other_features_input["original_language"]], dtype=torch.float, device=DEVICE),
        "budget": torch.tensor([other_features_input["budget"]], dtype=torch.float, device=DEVICE),
        "budget_unknown": torch.tensor([other_features_input["budget_unknown"]], dtype=torch.float, device=DEVICE),
    }

    # Predict
    with torch.no_grad():
        prediction = model(model_input)
    return prediction.item()


# 5. Test the function
# prompt = "A romantic story about two star-crossed lovers set in a historical backdrop."
prompt = "When siblings Judy and Peter discover an enchanted board game that opens the door to a magical world, they unwittingly invite Alan -- an adult who's been trapped inside the game for 26 years -- into their living room. Alan's only hope for freedom is to finish the game, which proves risky as all three find themselves running from giant rhinoceroses, evil monkeys, and other terrifying creatures."

# Features of the movie behind the prompt: produced in the US, original language
# English, budget 65,000,000 (scaled with the scaler fitted on the training split).
other_features_for_prompt = {
    "production_countries": [1.0 if col == 'production_country_US' else 0.0 for col in train_dataset.production_country_cols],
    "original_language": [1.0 if col == 'original_language_en' else 0.0 for col in train_dataset.original_language_cols],
    "budget": float(scaler.transform(pd.DataFrame({'budget': [65000000.0]}))[0][0]),
    "budget_unknown": 0.0,
}

# The model is trained on unscaled revenue, so its output is reported directly.
predicted_revenue_actual = predict_revenue(prompt, other_features_for_prompt)

print(f"Predicted revenue for the movie: ${predicted_revenue_actual}")


In [None]:
# Min-max normalise a budget of 65,000,000 against the budgets in `data`:
# (value - min) / (max - min). Dividing by max alone is only correct when min is 0.
# NOTE(review): the model inputs in this notebook are scaled with StandardScaler,
# not min-max; confirm which scaling this value is meant for.
max_budget = data['budget'].max()
min_budget = data['budget'].min()

normalized_budget = (65000000 - min_budget) / (max_budget - min_budget)
print(normalized_budget)