In [1]:
import pandas as pd
from torch.utils.data import DataLoader,Dataset, Subset
import numpy as np
import tft_model
from data_formatters import ts_dataset  
import data_formatters.base
import expt_settings.configs
import importlib
from data_formatters import utils
import torch.optim as optim
import torch
from tqdm import tqdm
import pickle
import time

In [2]:
# Let pandas display up to 1000 columns when rendering wide frames.
# The fully-qualified key is used on purpose: option names are matched as
# patterns, and the short form 'max_columns' is ambiguous in recent pandas
# releases (it also matches 'styler.render.max_columns'), which raises OptionError.
pd.set_option('display.max_columns', 1000)

In [3]:
# Re-import data_formatters.utils so edits made to the module on disk are
# picked up without restarting the kernel; the reloaded module is displayed.
importlib.reload(utils)

<module 'data_formatters.utils' from '/Users/ardakeskiner/Desktop/TUM/Courses/ws19_20/thesis/submodules/Temporal_Fusion_Transform/data_formatters/utils.py'>

In [4]:
ExperimentConfig = expt_settings.configs.ExperimentConfig

# Single source of truth for the experiment name, so the banner printed below
# always matches the configuration that is actually built (it used to say 'm4').
experiment_name = 'electricity'
config = ExperimentConfig(experiment_name, 'outputs')

# The formatter owns the train/valid/test split; scaler fitting is reported in its log output.
data_formatter = config.make_data_formatter()


print("*** Training from defined parameters for {} ***".format(experiment_name))
# NOTE(review): absolute, machine-specific path — the notebook only runs on the
# author's machine until this is made configurable.
data_csv_path = '/Users/ardakeskiner/Desktop/TUM/Courses/ws19_20/tft/tft_outputs/data/electricity/hourly_electricity.csv'
print("Loading & splitting data...")
raw_data = pd.read_csv(data_csv_path, index_col=0)
print(raw_data.shape)
# Time the split plus the sample-count lookup; elapsed seconds are printed below.
start = time.time()
train, valid, test = data_formatter.split_data(raw_data)
train_samples, valid_samples = data_formatter.get_num_samples_for_calibration()
print(time.time()-start)



*** Training from defined parameters for m4 ***
Loading & splitting data...


  mask |= (ar1 == a)


(2198072, 13)
Formatting train-valid-test splits.
Setting scalers with training data...
8.583341121673584


In [5]:
# Assemble the parameter dict handed to the dataset and the model:
# experiment-level settings first, then the default model hyper-parameters on top.
fixed_params = data_formatter.get_experiment_params()
params = data_formatter.get_default_model_params()
fixed_params.update(params)

# Notebook-specific overrides, applied in one step.
# Optional tweaks kept for reference:
#   fixed_params['minibatch_size'] = 64
#   fixed_params['category_count'] = [6]
fixed_params.update({
    'batch_first': True,
    'name': 'test',
    'device': 'cpu',  # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
})
device = fixed_params['device']
# Only a single quantile is requested.
fixed_params.update({'quantiles': [0.5]})

# with open('data_formatter_m4.pkl', 'wb') as output:  # Overwrites any existing file.
#     pickle.dump(data_formatter, output, pickle.HIGHEST_PROTOCOL)



In [6]:
# Number of samples passed to TSDataset as its second argument (kept small here;
# the trailing comment records a larger multiplier that was disabled).
max_samples = 64 #* 200 * 2
# Build the training dataset from the `train` split using the assembled parameter dict.
elect = ts_dataset.TSDataset(fixed_params, max_samples, train)

# with open('ts_dataset_m4.pkl', 'wb') as output:  # Overwrites any existing file.
#     pickle.dump(elect, output, pickle.HIGHEST_PROTOCOL)

# with open('ts_dataset_m4.pkl', 'rb') as input:
#     elect = pickle.load(input)

Getting valid sampling locations.
# available segments=1853057
Extracting 64 samples...


In [7]:
# Batches are served in dataset order (no shuffling) by two worker processes;
# the batch size comes from the experiment parameters.
loader = DataLoader(
    elect,
    shuffle=False,
    num_workers=2,
    batch_size=fixed_params['minibatch_size'],
)

In [8]:
# Pick up any edits made to tft_model.py since the kernel started, then
# instantiate the network from the parameter dict and place it on `device`.
tft_model = importlib.reload(tft_model)
model = tft_model.TFT(fixed_params)
model = model.to(device)

{'total_time_steps': 192, 'num_encoder_steps': 168, 'num_epochs': 100, 'early_stopping_patience': 5, 'multiprocessing_workers': 5, 'column_definition': [('id', <DataTypes.REAL_VALUED: 0>, <InputTypes.ID: 4>), ('hours_from_start', <DataTypes.REAL_VALUED: 0>, <InputTypes.TIME: 5>), ('power_usage', <DataTypes.REAL_VALUED: 0>, <InputTypes.TARGET: 0>), ('hour', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('day_of_week', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('hours_from_start', <DataTypes.REAL_VALUED: 0>, <InputTypes.KNOWN_INPUT: 2>), ('categorical_id', <DataTypes.CATEGORICAL: 1>, <InputTypes.STATIC_INPUT: 3>)], 'input_size': 5, 'output_size': 1, 'category_counts': [369], 'input_obs_loc': [0], 'static_input_loc': [4], 'known_regular_inputs': [1, 2, 3], 'known_categorical_inputs': [0], 'dropout_rate': 0.1, 'hidden_layer_size': 160, 'learning_rate': 0.001, 'minibatch_size': 64, 'max_gradient_norm': 0.01, 'num_heads': 4, 'stack_size': 1, 'batch_first': True, '

In [9]:
from losses.quantile_loss import QuantileLoss
from losses.smape_loss import SMAPELoss
from losses.rmsse_loss import RMSSELoss
from losses.pinball_loss import PinballLoss


# Training objective. Exactly one loss is instantiated; the previous version
# built an RMSSELoss first and immediately overwrote it, which was dead code.
# Alternatives tried during experimentation (uncomment one to swap it in):
# q_loss_func = RMSSELoss(fixed_params['device'])
# q_loss_func = SMAPELoss(fixed_params['device'])
# q_loss_func = QuantileLoss(fixed_params['quantiles'])
# import sys
# sys.path.append('/home/arda/Desktop/thesis/')
# from loss_modules import PinballLoss
# NOTE(review): 0.45 is presumably the target quantile level — confirm against PinballLoss.
q_loss_func = PinballLoss(0.45, device)
# Take the step size from the experiment parameters instead of repeating the
# literal, so the optimizer cannot drift away from the configured value.
optimizer = optim.Adam(model.parameters(), lr=fixed_params['learning_rate'])

In [10]:
# Training loop: one optimizer step per batch, one checkpoint decision per epoch.
# `output` and `batch` from the final iteration stay in scope and are reused by
# the plotting / sMAPE cells further down.
model.train()
epochs=100
losses = []  # mean training loss of every finished epoch
for i in range(epochs):
    epoch_loss = []  # per-batch losses of the current epoch
    progress_bar = tqdm(enumerate(loader), total=len(loader))
    for batch_num, batch in progress_bar:
        optimizer.zero_grad()
        # NOTE(review): inputs are passed as loaded; only the targets are moved to
        # `device` here — confirm the model handles device placement of its inputs.
        output, all_inputs, attention_components = model(batch['inputs'])
#         loss= q_loss_func(output[:,:,:].view(-1,1), batch['outputs'][:,:,0].flatten().float().to(device))
        # Drop dimension 2 of the prediction so it lines up with the first target channel.
        loss = q_loss_func(output.squeeze(2), batch['outputs'][:,:,0].float().to(device))
        loss.backward()
        # Clip the global gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), fixed_params['max_gradient_norm'])
        optimizer.step()
        epoch_loss.append(loss.item())

    mean_epoch_loss = np.mean(epoch_loss)
    losses.append(mean_epoch_loss)
    # Checkpoint when this epoch's MEAN loss is the best seen so far. The earlier
    # check compared the last batch's loss with the minimum of epoch means, which
    # is only equivalent when an epoch consists of a single batch.
    # NOTE(review): the file name mentions "smape" although the configured loss may differ.
    if mean_epoch_loss <= min(losses):
        torch.save(model.state_dict(), 'electricity_best_model_smape_loss.pth')

    print(mean_epoch_loss)
    

  mmask = (-1e+9) * (1. - torch.tensor(mask, dtype=torch.float)) # setting to infinity
100%|██████████| 1/1 [00:02<00:00,  2.82s/it]

1.6356086730957031



100%|██████████| 1/1 [00:03<00:00,  3.05s/it]

0.8042025566101074



100%|██████████| 1/1 [00:04<00:00,  4.21s/it]

0.7193329930305481



100%|██████████| 1/1 [00:03<00:00,  3.71s/it]

0.6266722083091736



100%|██████████| 1/1 [00:06<00:00,  6.65s/it]

0.6233759522438049



100%|██████████| 1/1 [00:05<00:00,  5.71s/it]

0.6012177467346191



100%|██████████| 1/1 [00:03<00:00,  3.23s/it]

0.5798676609992981



100%|██████████| 1/1 [00:02<00:00,  2.90s/it]

0.5878437161445618



100%|██████████| 1/1 [00:03<00:00,  3.09s/it]

0.5719974637031555



100%|██████████| 1/1 [00:04<00:00,  4.30s/it]

0.5404468178749084



100%|██████████| 1/1 [00:04<00:00,  4.64s/it]

0.51779705286026



100%|██████████| 1/1 [00:04<00:00,  4.80s/it]

0.5076661705970764



100%|██████████| 1/1 [00:04<00:00,  4.40s/it]

0.4772258698940277



100%|██████████| 1/1 [00:04<00:00,  4.71s/it]

0.48332369327545166



100%|██████████| 1/1 [00:04<00:00,  4.40s/it]

0.4895535707473755



100%|██████████| 1/1 [00:03<00:00,  3.91s/it]

0.47538694739341736



100%|██████████| 1/1 [00:03<00:00,  3.68s/it]

0.4581122398376465



100%|██████████| 1/1 [00:03<00:00,  3.29s/it]

0.44741272926330566



100%|██████████| 1/1 [00:03<00:00,  3.16s/it]

0.43638166785240173



100%|██████████| 1/1 [00:05<00:00,  5.01s/it]

0.41501662135124207



100%|██████████| 1/1 [00:04<00:00,  4.42s/it]

0.4170372486114502



100%|██████████| 1/1 [00:05<00:00,  5.10s/it]

0.417370080947876



100%|██████████| 1/1 [00:05<00:00,  5.61s/it]

0.401816725730896



100%|██████████| 1/1 [00:05<00:00,  5.97s/it]

0.4050987958908081



100%|██████████| 1/1 [00:05<00:00,  5.20s/it]

0.421008437871933



100%|██████████| 1/1 [00:03<00:00,  3.08s/it]

0.4223504364490509



100%|██████████| 1/1 [00:05<00:00,  5.50s/it]

0.4094175100326538



100%|██████████| 1/1 [00:04<00:00,  4.93s/it]

0.399186372756958



100%|██████████| 1/1 [00:04<00:00,  4.50s/it]

0.388192743062973



100%|██████████| 1/1 [00:03<00:00,  3.17s/it]

0.3724614381790161



100%|██████████| 1/1 [00:03<00:00,  3.84s/it]

0.37388601899147034



100%|██████████| 1/1 [00:03<00:00,  3.62s/it]

0.3636728525161743



100%|██████████| 1/1 [00:04<00:00,  4.78s/it]

0.3530537188053131



100%|██████████| 1/1 [00:04<00:00,  4.26s/it]

0.348838210105896



100%|██████████| 1/1 [00:05<00:00,  5.05s/it]

0.3568054437637329



100%|██████████| 1/1 [00:04<00:00,  4.97s/it]

0.35134243965148926



100%|██████████| 1/1 [00:04<00:00,  4.96s/it]

0.33356860280036926



100%|██████████| 1/1 [00:03<00:00,  3.81s/it]

0.3407537043094635



100%|██████████| 1/1 [00:04<00:00,  4.78s/it]

0.3636913597583771



100%|██████████| 1/1 [00:04<00:00,  4.86s/it]

0.3684002161026001



100%|██████████| 1/1 [00:03<00:00,  3.63s/it]

0.35991379618644714



100%|██████████| 1/1 [00:04<00:00,  4.05s/it]

0.3481706380844116



100%|██████████| 1/1 [00:03<00:00,  3.21s/it]

0.32863304018974304



100%|██████████| 1/1 [00:04<00:00,  4.30s/it]

0.3209492564201355



100%|██████████| 1/1 [00:03<00:00,  3.62s/it]

0.3298642933368683



100%|██████████| 1/1 [00:03<00:00,  3.04s/it]

0.31871700286865234



100%|██████████| 1/1 [00:02<00:00,  2.96s/it]

0.32247480750083923



100%|██████████| 1/1 [00:03<00:00,  3.21s/it]

0.3250703811645508



100%|██████████| 1/1 [00:02<00:00,  2.94s/it]

0.305036336183548



100%|██████████| 1/1 [00:02<00:00,  2.94s/it]

0.3158656060695648



100%|██████████| 1/1 [00:03<00:00,  3.17s/it]

0.3301292061805725



100%|██████████| 1/1 [00:03<00:00,  3.05s/it]

0.3179815113544464



100%|██████████| 1/1 [00:03<00:00,  3.06s/it]

0.29705455899238586



100%|██████████| 1/1 [00:03<00:00,  3.03s/it]

0.2951955497264862



100%|██████████| 1/1 [00:02<00:00,  2.81s/it]

0.293441504240036



100%|██████████| 1/1 [00:02<00:00,  2.91s/it]

0.29207825660705566



100%|██████████| 1/1 [00:03<00:00,  3.04s/it]

0.29592394828796387



100%|██████████| 1/1 [00:02<00:00,  2.76s/it]

0.28327593207359314



100%|██████████| 1/1 [00:02<00:00,  2.86s/it]

0.2969607412815094



100%|██████████| 1/1 [00:02<00:00,  2.85s/it]

0.30979540944099426



100%|██████████| 1/1 [00:03<00:00,  3.09s/it]

0.3015197217464447



100%|██████████| 1/1 [00:03<00:00,  3.72s/it]

0.3016020953655243



100%|██████████| 1/1 [00:04<00:00,  4.34s/it]

0.29347550868988037



100%|██████████| 1/1 [00:03<00:00,  3.52s/it]

0.28540298342704773



100%|██████████| 1/1 [00:03<00:00,  3.44s/it]

0.2712963819503784



100%|██████████| 1/1 [00:04<00:00,  4.50s/it]

0.2884867489337921



100%|██████████| 1/1 [00:03<00:00,  3.32s/it]

0.30326858162879944



100%|██████████| 1/1 [00:03<00:00,  3.43s/it]

0.2907520830631256



100%|██████████| 1/1 [00:03<00:00,  3.44s/it]

0.2827986776828766



100%|██████████| 1/1 [00:02<00:00,  2.82s/it]

0.2777986228466034



100%|██████████| 1/1 [00:03<00:00,  3.18s/it]

0.2790696322917938



100%|██████████| 1/1 [00:03<00:00,  3.08s/it]

0.262139230966568



100%|██████████| 1/1 [00:02<00:00,  2.96s/it]

0.2816202640533447



100%|██████████| 1/1 [00:03<00:00,  3.56s/it]

0.2935580611228943



100%|██████████| 1/1 [00:02<00:00,  2.75s/it]

0.2756616175174713



100%|██████████| 1/1 [00:03<00:00,  3.13s/it]

0.2723993957042694



100%|██████████| 1/1 [00:03<00:00,  3.15s/it]

0.26830950379371643



100%|██████████| 1/1 [00:02<00:00,  2.96s/it]

0.2694215774536133



100%|██████████| 1/1 [00:02<00:00,  2.86s/it]

0.26834502816200256



100%|██████████| 1/1 [00:03<00:00,  3.23s/it]

0.2500728666782379



100%|██████████| 1/1 [00:02<00:00,  2.92s/it]

0.28545448184013367



100%|██████████| 1/1 [00:02<00:00,  2.78s/it]

0.3206288814544678



100%|██████████| 1/1 [00:02<00:00,  2.73s/it]

0.29269203543663025



100%|██████████| 1/1 [00:02<00:00,  2.79s/it]

0.2588500678539276



100%|██████████| 1/1 [00:03<00:00,  3.09s/it]

0.2705850303173065



100%|██████████| 1/1 [00:02<00:00,  2.77s/it]

0.27861878275871277



100%|██████████| 1/1 [00:02<00:00,  2.99s/it]

0.2646419405937195



100%|██████████| 1/1 [00:02<00:00,  2.88s/it]

0.26820141077041626



100%|██████████| 1/1 [00:02<00:00,  2.89s/it]

0.26707905530929565



100%|██████████| 1/1 [00:04<00:00,  4.96s/it]

0.25682708621025085



100%|██████████| 1/1 [00:03<00:00,  3.70s/it]

0.2535465657711029



100%|██████████| 1/1 [00:03<00:00,  3.93s/it]

0.24052534997463226



100%|██████████| 1/1 [00:03<00:00,  3.30s/it]

0.2521652579307556



100%|██████████| 1/1 [00:02<00:00,  2.97s/it]

0.24832798540592194



100%|██████████| 1/1 [00:03<00:00,  3.26s/it]

0.22999410331249237



100%|██████████| 1/1 [00:03<00:00,  3.08s/it]

0.23291654884815216



100%|██████████| 1/1 [00:03<00:00,  3.22s/it]

0.23447293043136597



100%|██████████| 1/1 [00:03<00:00,  3.52s/it]

0.22348058223724365



100%|██████████| 1/1 [00:02<00:00,  2.89s/it]

0.23345069587230682



100%|██████████| 1/1 [00:03<00:00,  3.41s/it]

0.23355813324451447





In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Plot the prediction against the ground truth for one randomly chosen series of
# the last training batch (`output` / `batch` are left over from the training loop).
# The index is drawn from the actual batch size rather than a hard-coded 64, so a
# smaller final batch cannot trigger an IndexError.
ind = np.random.choice(output.shape[0])
print(ind)
plt.plot(output[ind,:,0].detach().cpu().numpy(), label='pred_1')
# plt.plot(output[ind,:,1].detach().cpu().numpy(), label='pred_5')
# plt.plot(output[ind,:,2].detach().cpu().numpy(), label='pred_9')

# Convert the target the same way as the prediction before handing it to matplotlib.
plt.plot(batch['outputs'][ind,:,0].detach().cpu().numpy(), label='true')
plt.legend()

In [None]:
def symmetric_mean_absolute_percentage_error(forecast, actual):
    """Symmetric Mean Absolute Percentage Error (sMAPE), averaged over sequences.

    For every sequence (row) the per-step term ``|f - a| / (|a| + |f|)`` is summed
    along axis 1, doubled and divided by the sequence length; the mean of these
    per-sequence scores is returned.

    Steps where forecast and actual are both exactly zero have a zero denominator.
    They are a perfect prediction and contribute 0 instead of turning the whole
    result into NaN. NaNs already present in the inputs still propagate.

    Args:
        forecast: array-like with at least two dimensions, (sequences, steps, ...).
        actual: array-like broadcastable against ``forecast``.

    Returns:
        NumPy floating scalar with the mean sMAPE (0 = perfect, 2 = worst).
    """
    forecast = np.asarray(forecast)
    actual = np.asarray(actual)
    sequence_length = forecast.shape[1]

    abs_error = np.abs(forecast - actual)
    scale = np.abs(actual) + np.abs(forecast)
    # Silence the 0/0 warning; those entries are replaced just below.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = abs_error / scale
    # scale == 0 implies forecast == actual == 0, i.e. zero error at that step.
    ratio = np.where(scale == 0, 0.0, ratio)

    sumf = np.sum(ratio, axis=1)
    return np.mean((2 * sumf) / sequence_length)

In [None]:
# sMAPE of the first output channel against the first target channel for the
# last training batch; the value is displayed as the cell result.
symmetric_mean_absolute_percentage_error(
    forecast=output[:, :, 0].detach().cpu().numpy(),
    actual=batch['outputs'][:, :, 0].detach().cpu().numpy(),
)

