# Imports

In [1]:
import sys
sys.path.insert(0, '/Users/mvilenko/Library/CloudStorage/OneDrive-PayPal/CPI_HRNN - version 2.0/mayas_project/hgru_model/model/')

In [2]:
import pandas as pd
import numpy as np
import pickle
import itertools
import random
import torch
import optuna
from model.utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from pipeline_config import *

# Seeds for Comparisons:

In [4]:
torch.manual_seed(1)
np.random.seed(2)
random.seed(3)

In [10]:
Year

2019

# Read Data

In [5]:
with open(son_parent_path, 'rb') as f:
    son_parent_dict = pickle.load(f)

with open(train_dataset_dict_path, 'rb') as f:
    train_dataset_dict = pickle.load(f)

with open(test_dataset_dict_path, 'rb') as f:
    test_dataset_dict = pickle.load(f)

with open(category_id_to_category_name_path, 'rb') as f:
    category_id_to_name_dict = pickle.load(f)
    
with open(categories_per_indent_path, 'rb') as f:
    categories_per_indent_dict = pickle.load(f)

# Hierarchical GRU

In [6]:
def objective(trial):
    with open(son_parent_path, 'rb') as f:
        son_parent_dict = pickle.load(f)

    with open(train_dataset_dict_path, 'rb') as f:
        train_dataset_dict = pickle.load(f)

    with open(test_dataset_dict_path, 'rb') as f:
        test_dataset_dict = pickle.load(f)

    with open(category_id_to_category_name_path, 'rb') as f:
        category_id_to_name_dict = pickle.load(f)
        
    with open(categories_per_indent_path, 'rb') as f:
        categories_per_indent_dict = pickle.load(f)

    weights_path = weightspath
    
    #--------------------------------------------------------------------------------------------------------------------------------------#

    loss_coef = trial.suggest_float('loss_coef_1',  1e-10, 1e-1, log=True)
    Lr = trial.suggest_float('Lr', 1e-5, 1e-1, log=True)

    hgru_models = {}
    num_categories = 0

    for indent in sorted(list(categories_per_indent_dict.keys())):
        for category in categories_per_indent_dict[indent]:
            num_categories +=1
            print(f'num categories: {num_categories}')
            category_name = category_id_to_name_dict[category]
            print(f'category id|name: {category}|{category_name}')

            if int(indent) == 0 or son_parent_dict[category] not in categories_per_indent_dict[indent-1]:
                loss_coef=0
                parent_weights=0
            else:
                son = category
                parent = son_parent_dict[son]
                parent_name = category_id_to_name_dict[parent]
                parent_model = GRUModel(input_dim=Features, hidden_dim=HiddenSize, layer_dim=LayersDim, output_dim=OutputDim, dropout_prob=DropoutProb)
                parent_optimizer = torch.optim.AdamW(parent_model.parameters(), lr=Lr)
                parent_model, optimizer, checkpoint, valid_loss_min = load_checkpoint(weights_path+parent_name+'.pt', parent_model, parent_optimizer)
                parent_weights = unify_model_weights(parent_model)

            train_dataloader, test_dataloader = create_dataloader(train_dataset_dict[category_name], test_dataset_dict[category_name])
            model = GRUModel(input_dim=Features, hidden_dim=HiddenSize, layer_dim=LayersDim, output_dim=OutputDim, dropout_prob=DropoutProb)
            optimizer = torch.optim.AdamW(model.parameters(), lr=Lr)
            model.to(Device)
            saving_param_path = weights_path+category_name+'.pt'
            min_error = training_and_evaluation(trial, model, train_dataloader, test_dataloader, optimizer, category_name, parent_weights, loss_coef, path=saving_param_path)
            hgru_models[category] = min_error

    average_error = sum(list(hgru_models.values()))/len(list(hgru_models.values()))
    return average_error


In [7]:
EPOCHS = 30
    
study = optuna.create_study(direction="minimize", sampler=optuna.samplers.TPESampler())
study.optimize(objective, n_trials=30)

[32m[I 2023-05-26 01:00:14,037][0m A new study created in memory with name: no-name-3e1a4c02-92a8-4011-96a4-697d6a6c1d54[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 01:47:28,848][0m Trial 0 finished with value: 0.2044547529353318 and parameters: {'loss_coef_1': 0.011591786104529891, 'Lr': 0.08370957582295632}. Best is trial 0 with value: 0.2044547529353318.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 02:32:01,990][0m Trial 1 finished with value: 0.21019435832893593 and parameters: {'loss_coef_1': 1.2164521225915575e-06, 'Lr': 0.043764864327572074}. Best is trial 0 with value: 0.2044547529353318.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 03:21:05,436][0m Trial 2 finished with value: 0.29601309839305506 and parameters: {'loss_coef_1': 4.70419933560253e-05, 'Lr': 1.5799479587299387e-05}. Best is trial 0 with value: 0.2044547529353318.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 04:05:38,026][0m Trial 3 finished with value: 0.22639573410407682 and parameters: {'loss_coef_1': 0.02800839374904708, 'Lr': 0.019418709871067596}. Best is trial 0 with value: 0.2044547529353318.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 04:52:38,300][0m Trial 4 finished with value: 0.24273911764482908 and parameters: {'loss_coef_1': 6.3308154095700285e-06, 'Lr': 0.00011154532840377565}. Best is trial 0 with value: 0.2044547529353318.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 05:37:09,529][0m Trial 5 finished with value: 0.2123589538813348 and parameters: {'loss_coef_1': 0.0020486115446326602, 'Lr': 0.048587691008116884}. Best is trial 0 with value: 0.2044547529353318.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 06:21:28,554][0m Trial 6 finished with value: 0.2302102607166961 and parameters: {'loss_coef_1': 2.22095617171102e-07, 'Lr': 0.009437599840553751}. Best is trial 0 with value: 0.2044547529353318.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 07:06:10,828][0m Trial 7 finished with value: 0.23521193273594368 and parameters: {'loss_coef_1': 0.0032604759447738777, 'Lr': 0.00845121575456454}. Best is trial 0 with value: 0.2044547529353318.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 07:52:37,948][0m Trial 8 finished with value: 0.24344377439291098 and parameters: {'loss_coef_1': 0.0021087787622657474, 'Lr': 0.00030629213602619855}. Best is trial 0 with value: 0.2044547529353318.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 08:38:48,842][0m Trial 9 finished with value: 0.24291920555914437 and parameters: {'loss_coef_1': 1.376476019358157e-09, 'Lr': 0.0003901932811173495}. Best is trial 0 with value: 0.2044547529353318.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 09:23:36,279][0m Trial 10 finished with value: 0.2387361904222252 and parameters: {'loss_coef_1': 0.00020722818545940717, 'Lr': 0.002781801959608239}. Best is trial 0 with value: 0.2044547529353318.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 10:08:01,661][0m Trial 11 finished with value: 0.1951226260319809 and parameters: {'loss_coef_1': 4.0852794322516647e-07, 'Lr': 0.09138694552832571}. Best is trial 11 with value: 0.1951226260319809.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 10:52:28,899][0m Trial 12 finished with value: 0.20308971823151806 and parameters: {'loss_coef_1': 8.551127495289472e-08, 'Lr': 0.09097030406274807}. Best is trial 11 with value: 0.1951226260319809.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 11:36:40,033][0m Trial 13 finished with value: 0.2036954027387622 and parameters: {'loss_coef_1': 4.519042051139565e-08, 'Lr': 0.0885932674282467}. Best is trial 11 with value: 0.1951226260319809.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 12:21:10,362][0m Trial 14 finished with value: 0.23711150360392236 and parameters: {'loss_coef_1': 1.2898940802264808e-08, 'Lr': 0.0036558432738057777}. Best is trial 11 with value: 0.1951226260319809.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 13:06:09,770][0m Trial 15 finished with value: 0.21893175259207082 and parameters: {'loss_coef_1': 1.5233701184463094e-10, 'Lr': 0.021812184890174575}. Best is trial 11 with value: 0.1951226260319809.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 13:51:49,795][0m Trial 16 finished with value: 0.20103677597799602 and parameters: {'loss_coef_1': 5.037825021669738e-07, 'Lr': 0.09283402907712522}. Best is trial 11 with value: 0.1951226260319809.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 14:37:30,944][0m Trial 17 finished with value: 0.22060262697841018 and parameters: {'loss_coef_1': 9.833123681013438e-07, 'Lr': 0.02587271575819774}. Best is trial 11 with value: 0.1951226260319809.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 15:23:10,175][0m Trial 18 finished with value: 0.23246673834252674 and parameters: {'loss_coef_1': 5.3953236751474e-06, 'Lr': 0.009552861309549531}. Best is trial 11 with value: 0.1951226260319809.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 16:06:32,884][0m Trial 19 finished with value: 0.23942472282199515 and parameters: {'loss_coef_1': 6.078689320875498e-09, 'Lr': 0.002469437929930294}. Best is trial 11 with value: 0.1951226260319809.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[32m[I 2023-05-26 16:49:55,407][0m Trial 20 finished with value: 0.2159112780373354 and parameters: {'loss_coef_1': 3.0412342240150206e-07, 'Lr': 0.037249750167414046}. Best is trial 11 with value: 0.1951226260319809.[0m


num categories: 1
category id|name: 2|All-items
num categories: 2
category id|name: 256|Alcoholic beverages, tobacco products and recreational cannabis




num categories: 3
category id|name: 290|All-items excluding alcoholic beverages, tobacco products and smokers' supplies and recreational cannabis
num categories: 4
category id|name: 287|All-items excluding energy
num categories: 5
category id|name: 284|All-items excluding food
num categories: 6
category id|name: 285|All-items excluding food and energy
num categories: 7
category id|name: 302|All-items excluding gasoline
num categories: 8
category id|name: 289|All-items excluding mortgage interest cost
num categories: 9
category id|name: 293|All-items excluding shelter
num categories: 10
category id|name: 139|Clothing and footwear
num categories: 11
category id|name: 288|Energy
num categories: 12
category id|name: 3|Food
num categories: 13
category id|name: 286|Food and energy
num categories: 14
category id|name: 201|Health and personal care
num categories: 15
category id|name: 96|Household operations, furnishings and equipment
num categories: 16
category id|name: 219|Recreation, educati

[33m[W 2023-05-26 17:18:44,958][0m Trial 21 failed with parameters: {'loss_coef_1': 9.874951458519302e-08, 'Lr': 0.09423857369400652} because of the following error: RuntimeError('File /Users/mvilenko/Library/CloudStorage/OneDrive-PayPal/CPI_HRNN - version 2.0/mayas_project/hgru_model_canada/models_weights/All other passenger vehicle operating expenses.pt cannot be opened.').[0m
Traceback (most recent call last):
  File "/Users/mvilenko/Library/Python/3.11/lib/python/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/var/folders/88/mkn0vj0s1kl1179x22s6j2h80000gq/T/ipykernel_71469/4198739395.py", line 51, in objective
    min_error = training_and_evaluation(trial, model, train_dataloader, test_dataloader, optimizer, category_name, parent_weights, loss_coef, path=saving_param_path)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

RuntimeError: File /Users/mvilenko/Library/CloudStorage/OneDrive-PayPal/CPI_HRNN - version 2.0/mayas_project/hgru_model_canada/models_weights/All other passenger vehicle operating expenses.pt cannot be opened.

In [None]:
best_trial = study.best_trial
best_trial

In [None]:
def get_results_on_test_set(weights_path, train_dataset_dict, test_dataset_dict, categories = None):
    predictions_dict = {}
    if categories is None:
        categories = list(test_dataset_dict.keys())
        
    for category in categories:
        print(category)
        train_dataloader, test_dataloader = create_dataloader(train_dataset_dict[category], test_dataset_dict[category])
        basic_model = Model
        basic_optimizer = Optimizer
        ckp_path = weights_path+category+'.pt'
        model, optimizer, checkpoint, valid_loss_min = load_checkpoint(ckp_path, basic_model, basic_optimizer)
        predictions = get_predictions_on_test_set(model, test_dataloader)
        predictions_dict[category] = predictions
    return predictions_dict

In [None]:
categories_lists = list(categories_per_indent_dict.values())
categories_id = list(itertools.chain.from_iterable(categories_lists))
categories = []
for category_id in categories_id:
    categories.append(category_id_to_name_dict[category_id])

In [None]:
predictions_dict = get_results_on_test_set(weightspath, train_dataset_dict, test_dataset_dict, categories = categories)

In [None]:
with open(test_predictions_path, 'wb') as handle:
    pickle.dump(predictions_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Get Best Model Weights:

In [None]:
#def get_weights_per_category(category_list, dir_path):
#    weights = {}
#    for category in category_list:
#        model = GRUModel(input_dim = Features, hidden_dim = HiddenSize, layer_dim = LayersDim, output_dim = OutputDim, dropout_prob = DropoutProb)
#        model.to(device)
#        optimizer = torch.optim.AdamW(model.parameters(), lr=Lr)
        
#        best_checkpoint_path = dir_path+category + '.pt'

#        category_model, optimizer, checkpoint, valid_loss_min = load_checkpoint(best_checkpoint_path, model, optimizer)
#        category_model_weights = unify_model_weights(category_model)

#        weights[category] = category_model_weights
    
#    return weights

In [None]:
def get_weights_per_category(category_id_list, dir_path):
    basic_model = GRUModel(input_dim = Features, hidden_dim = HiddenSize, layer_dim = LayersDim, output_dim = OutputDim, dropout_prob = DropoutProb)
    basic_optimizer = torch.optim.AdamW(basic_model.parameters(), lr=Lr)
    #basic_model.to(device)

    best_models_weights_dict = {}

    for category_id in category_id_list:
        category_name = category_id_to_name_dict[category_id]
        ckp_path = dir_path+category_name+'.pt'
        model, optimizer, checkpoint, valid_loss_min = load_checkpoint(ckp_path, basic_model, basic_optimizer)
        category_model_weights = unify_model_weights(model)
        best_models_weights_dict[category_id] = category_model_weights
        
    return best_models_weights_dict

In [None]:
dir_path = '/Users/mvilenko/Library/CloudStorage/OneDrive-PayPal/CPI_HRNN - version 2.0/mayas_project/hgru_model/models_weights/'

In [None]:
category_id_list = []
 
# list out keys and values separately
key_list = list(category_id_to_name_dict.keys())
val_list = list(category_id_to_name_dict.values())

for cat_name in categories:
    position = val_list.index(cat_name)
    category_id_list.append(key_list[position])


In [None]:
len(categories)

In [None]:
len(category_id_list)

In [None]:
weights_dict = get_weights_per_category(category_id_list, dir_path)

In [None]:
with open('/Users/mvilenko/Library/CloudStorage/OneDrive-PayPal/CPI_HRNN - version 2.0/pickle files/hgru_model_weights.pickle', 'wb') as handle:
    pickle.dump(weights_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)