In [11]:
import math
import numpy as np
import wandb

import torch
import torch_geometric
from torch_geometric.data import Data

import sys
import os
from tqdm import tqdm

# Make the parent of the current working directory importable; presumably this
# is where the project-local modules imported below (gnn_io, gnn_architectures)
# live -- verify against the repository layout.
scripts_path = os.path.abspath('..')
if scripts_path not in sys.path:
    sys.path.append(scripts_path)

# Now you can import the gnn_io module
import gnn_io as gio

import gnn_architectures as garch

## 1. Define model and parameters

In [12]:
# Authenticate with Weights & Biases before any run is created.
wandb.login()

# Define parameters
num_epochs = 1000
project_name = 'gnn_target_normalized_features_car_vol_baseline_capacity_reduction_and_highway'

# Seed the global NumPy and PyTorch RNGs so that everything downstream that
# draws from them (e.g. weight initialisation) is repeatable after
# "Restart Kernel -> Run All". Without this, every run starts from different
# random state and results cannot be reproduced.
random_seed = 42
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Early-stopping tracker with a patience of 10 (passed to the constructor).
early_stopping = gio.EarlyStopping(patience=10, verbose=True)

# Print tensors with 4 decimals and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)



## 2. Load data

In [13]:
# Read the serialized samples from disk; each entry is indexed below with the
# keys 'x', 'edge_index', 'pos' and 'y'.
# NOTE(review): torch.load unpickles the file, so only point this at trusted data.
data_dict_list = torch.load('../../data/train_data/dataset_1pm_0-2600_16_07.pt')

# Rebuild one torch_geometric Data object per stored dictionary.
datalist = [
    Data(
        x=sample['x'],
        edge_index=sample['edge_index'],
        pos=sample['pos'],
        y=sample['y'],
    )
    for sample in data_dict_list
]

In [14]:
# Restrict every sample to the dimension indices 0-3 (as passed to
# gio.cut_dimensions), then hand the reduced dataset to gio.normalize_dataset.
dataset_only_relevant_dimensions = gio.cut_dimensions(
    dataset=datalist,
    indices_of_dimensions_to_keep=[0, 1, 2, 3],
)
dataset_normalized = gio.normalize_dataset(dataset_only_relevant_dimensions)

In [15]:
# Two reference errors for judging the trained model. They are stored under
# separate names and printed with distinct labels so the outputs cannot be
# confused with one another.

# Baseline returned by gio.compute_baseline_of_no_policies on the normalized data.
baseline_error_no_policies = gio.compute_baseline_of_no_policies(dataset_normalized)
print(f'Baseline error (no policies): {baseline_error_no_policies}')

# Baseline returned by gio.compute_baseline_of_mean_target on the normalized data.
baseline_error_mean_target = gio.compute_baseline_of_mean_target(dataset_normalized)
print(f'Baseline error (mean target): {baseline_error_mean_target}')

# Backward compatibility: the original cell left `baseline_error` bound to the
# mean-target baseline, so keep that binding for any later cell that reads it.
baseline_error = baseline_error_mean_target

Baseline error: 0.32162734866142273
mean_y_normalized: 
0.56424004
median_y_normalized: 
0.56032723
Baseline error: 0.0032576548401266336


## 4. Train the model

We first find a good model for one batch. 

In [16]:
def train(model, config=None, loss_fct=None, optimizer=None, train_dl=None, valid_dl=None, device=None, early_stopping=None):
    """Train ``model`` for up to ``config.epochs`` epochs, validating after each one.

    Every training step and every epoch-level validation loss is logged to the
    active wandb run; the run is finished before returning.

    Args:
        model: Module called as ``model(batch)`` on each batch from ``train_dl``.
        config: Object exposing ``epochs`` (maximum number of epochs).
        loss_fct: Callable ``loss_fct(predicted, targets)`` returning a tensor
            that supports ``.backward()`` and ``.item()``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        train_dl: Iterable of training batches exposing ``.to(device)`` and ``.y``.
        valid_dl: Validation loader, forwarded to
            ``garch.validate_model_pos_features``.
        device: Target device for the batches; when ``None`` batches are used
            where they already are.
        early_stopping: Optional callable taking the validation loss and
            exposing ``early_stop``; when ``None`` no early stopping is applied.

    Returns:
        Tuple ``(best_val_loss, best_epoch)``: the lowest validation loss seen
        and the epoch index at which it occurred. If no epoch ran, this is
        ``(float("inf"), -1)``.
    """
    # Track the best epoch explicitly: the loss of the *last* epoch is not
    # necessarily the best one (e.g. after the early-stopping patience ran out).
    # Assumes the validation loss is a number comparable with `<` -- it is
    # printed and logged as a scalar below.
    best_val_loss = float("inf")
    best_epoch = -1

    for epoch in range(config.epochs):
        model.train()
        # Wrap the loader itself (not the enumerate object) so tqdm can read
        # its length and show progress against a total.
        for idx, data in enumerate(tqdm(train_dl)):
            # Move the whole batch to the target device so the model input and
            # the targets live on the same device as the model.
            if device is not None:
                data = data.to(device)
            targets = data.y
            optimizer.zero_grad()

            # Forward pass
            predicted = model(data)
            train_loss = loss_fct(predicted, targets)

            # Backward pass
            train_loss.backward()
            optimizer.step()

            wandb.log({"train_loss": train_loss.item(), "epoch": epoch, "step": idx})

        val_loss = garch.validate_model_pos_features(model, valid_dl, loss_fct, device)
        print(f"epoch: {epoch}, validation loss: {val_loss}")
        wandb.log({"loss": val_loss, "epoch": epoch})

        if val_loss < best_val_loss:
            best_val_loss, best_epoch = val_loss, epoch

        # Early stopping is optional; the default argument is None.
        if early_stopping is not None:
            early_stopping(val_loss)
            if early_stopping.early_stop:
                print("Early stopping triggered. Stopping training.")
                break

    print("Best validation loss: ", best_val_loss)
    # Only record a summary value when at least one epoch produced a loss.
    if best_epoch >= 0:
        wandb.summary["val_loss"] = best_val_loss
    wandb.finish()
    return best_val_loss, best_epoch

In [17]:
# Mini-batch size; also logged in the wandb run configuration further down.
batch_size = 16

# Request a 70 % / 15 % / 15 % train / validation / test split of the
# normalized dataset from gio.create_dataloaders.
train_dl, valid_dl, test_dl = gio.create_dataloaders(
    batch_size=batch_size,
    dataset=dataset_normalized,
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
)

Total dataset length: 2580
Training subset length: 1805
Validation subset length: 387
Test subset length: 388


In [18]:
# Hand the test loader and its parameters to the gio save helpers; presumably
# they write to the given file names relative to the working directory --
# confirm in gnn_io.
gio.save_dataloader(
    test_dl,
    'test_dl_new_2.pt',
)
gio.save_dataloader_params(
    test_dl,
    'test_loader_params_new_2.json',
)

In [19]:
# Currently best architecture !
output_layer_parameter = 'gat'
hidden_size_parameter = 16
gat_layer_parameter = 0
gcn_layer_parameter = 0
lr = 0.0001
in_channels = 6 # dimensions of the x vector + 2 (pos)
# Single source of truth for the patience: used both for the logged run
# configuration and for the early-stopping tracker created below.
early_stopping_patience = 10

wandb.login()

wandb.init(
    project=project_name,
    config={
        "epochs": num_epochs,
        "batch_size": batch_size,
        "lr": lr,
        "early_stopping_patience": early_stopping_patience,
        "hidden_layer_size": hidden_size_parameter,
        "gat_layers": gat_layer_parameter,
        "gcn_layers": gcn_layer_parameter,
        "output_layer": output_layer_parameter,
        # "dropout": 0.15,
    }
)
config = wandb.config

print("output_layer: ", output_layer_parameter)
print("hidden_size: ", hidden_size_parameter)
print("gat_layers: ", gat_layer_parameter)
print("gcn_layers: ", gcn_layer_parameter)

gnn_instance = garch.MyGnn(in_channels=in_channels, out_channels=1, hidden_size=hidden_size_parameter, gat_layers=gat_layer_parameter, gcn_layers=gcn_layer_parameter, output_layer=output_layer_parameter)
model = gnn_instance.to(device)

# Build a fresh early-stopping tracker for this run. The tracker is stateful
# (it carries a counter and an `early_stop` flag), so reusing an instance from
# an earlier run of this cell would start the new run with stale state.
early_stopping = gio.EarlyStopping(patience=early_stopping_patience, verbose=True)

best_val_loss, best_epoch = train(model, config=config, 
                                loss_fct=torch.nn.MSELoss(), 
                                optimizer=torch.optim.Adam(model.parameters(), lr=lr),
                                train_dl=train_dl, valid_dl=valid_dl,
                                device=device, early_stopping=early_stopping)



0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss,██▇▆▆▆▅▅▅▄▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▂▄▆█▂▄▆█▃▅▇▁▃▅▇▂▄▅▇▂▅▇▂▄▅▇▂▄▆█▂▄▆▁▃▅▇▁▃▆
train_loss,██▇▆▆▆▅▅▅▄▄▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,92.0
loss,0.00289
step,9.0
train_loss,0.00289


output_layer:  gat
hidden_size:  16
gat_layers:  0
gcn_layers:  0
Model initialized
MyGnn(
  (pointLayer): PointNetConv(local_nn=Sequential(
    (0): Linear(in_features=6, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=16, bias=True)
  ), global_nn=Sequential(
    (0): Linear(in_features=16, out_features=8, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=16, bias=True)
  ))
  (output_layer): GATConv(16, 1, heads=1)
)


113it [12:53,  6.84s/it]


epoch: 0, validation loss: 0.1439522510766983


113it [04:45,  2.53s/it]


epoch: 1, validation loss: 0.018852412700653076


113it [01:38,  1.15it/s]


epoch: 2, validation loss: 0.0035537397488951683


113it [01:29,  1.27it/s]


epoch: 3, validation loss: 0.0032891016360372307


113it [01:31,  1.23it/s]


epoch: 4, validation loss: 0.003287394745275378


113it [01:42,  1.10it/s]


epoch: 5, validation loss: 0.0032861712668091057


113it [01:43,  1.09it/s]


epoch: 6, validation loss: 0.003284875079989433


113it [01:29,  1.26it/s]


epoch: 7, validation loss: 0.0032834978681057694


113it [01:44,  1.08it/s]


epoch: 8, validation loss: 0.0032820724789053204


113it [01:26,  1.30it/s]


epoch: 9, validation loss: 0.0032806184608489275


113it [01:27,  1.29it/s]


epoch: 10, validation loss: 0.0032791201770305633


113it [01:29,  1.27it/s]


epoch: 11, validation loss: 0.003277597492560744


113it [01:26,  1.30it/s]


epoch: 12, validation loss: 0.003276061499491334


113it [01:26,  1.31it/s]


epoch: 13, validation loss: 0.0032745175901800396


113it [01:39,  1.14it/s]


epoch: 14, validation loss: 0.003272998360916972


113it [01:32,  1.22it/s]


epoch: 15, validation loss: 0.003271497106179595


113it [01:27,  1.30it/s]


epoch: 16, validation loss: 0.0032700220961123707


113it [01:26,  1.31it/s]


epoch: 17, validation loss: 0.0032685974240303037


113it [01:42,  1.10it/s]


epoch: 18, validation loss: 0.0032672420982271435


113it [01:42,  1.10it/s]


epoch: 19, validation loss: 0.00326587843708694


113it [01:32,  1.23it/s]


epoch: 20, validation loss: 0.0032645126339048147


113it [01:38,  1.15it/s]


epoch: 21, validation loss: 0.0032631126046180725


113it [01:32,  1.22it/s]


epoch: 22, validation loss: 0.0032613575644791126


113it [01:37,  1.16it/s]


epoch: 23, validation loss: 0.003258615490049124


113it [01:36,  1.17it/s]


epoch: 24, validation loss: 0.00325469090603292


113it [01:32,  1.22it/s]


epoch: 25, validation loss: 0.0032512296456843615


113it [01:27,  1.29it/s]


epoch: 26, validation loss: 0.0032485008798539636


113it [01:32,  1.22it/s]


epoch: 27, validation loss: 0.0032464817725121973


113it [01:41,  1.12it/s]


epoch: 28, validation loss: 0.003244969071820378


113it [01:32,  1.22it/s]


epoch: 29, validation loss: 0.003243718519806862


113it [01:26,  1.30it/s]


epoch: 30, validation loss: 0.003242727145552635


113it [01:35,  1.19it/s]


epoch: 31, validation loss: 0.0032418628502637148


113it [01:26,  1.30it/s]


epoch: 32, validation loss: 0.003241070993244648


113it [01:26,  1.30it/s]


epoch: 33, validation loss: 0.0032403557747602463


113it [01:26,  1.31it/s]


epoch: 34, validation loss: 0.003239692421630025


113it [01:28,  1.28it/s]


epoch: 35, validation loss: 0.0032390798442065716


113it [01:28,  1.28it/s]


epoch: 36, validation loss: 0.0032385028898715973


113it [01:32,  1.22it/s]


epoch: 37, validation loss: 0.0032379212696105244


113it [01:49,  1.03it/s]


epoch: 38, validation loss: 0.0032373205665498974


113it [01:31,  1.24it/s]


epoch: 39, validation loss: 0.0032367021776735783


113it [01:28,  1.28it/s]


epoch: 40, validation loss: 0.003236042084172368


113it [01:30,  1.24it/s]


epoch: 41, validation loss: 0.0032353522442281246


113it [02:01,  1.08s/it]


epoch: 42, validation loss: 0.0032346029579639435


113it [01:49,  1.03it/s]


epoch: 43, validation loss: 0.0032336234394460916


113it [01:47,  1.05it/s]


epoch: 44, validation loss: 0.00323272192850709


113it [01:40,  1.13it/s]


epoch: 45, validation loss: 0.003230144260451198


113it [01:42,  1.11it/s]


epoch: 46, validation loss: 0.0032265200093388557


113it [01:34,  1.20it/s]


epoch: 47, validation loss: 0.0032244518119841814


113it [01:35,  1.18it/s]


epoch: 48, validation loss: 0.0032230846118181944


113it [01:56,  1.03s/it]


epoch: 49, validation loss: 0.0032218617852777243


113it [02:08,  1.14s/it]


epoch: 50, validation loss: 0.0032206054218113424


113it [02:07,  1.13s/it]


epoch: 51, validation loss: 0.0032193877268582582


113it [02:05,  1.11s/it]


epoch: 52, validation loss: 0.0032181748934090137


113it [02:07,  1.13s/it]


epoch: 53, validation loss: 0.0032170030754059553


113it [02:12,  1.17s/it]


epoch: 54, validation loss: 0.00321585058234632


113it [02:11,  1.17s/it]


epoch: 55, validation loss: 0.003214677823707461


113it [02:37,  1.39s/it]


epoch: 56, validation loss: 0.0032135303970426323


113it [02:07,  1.13s/it]


epoch: 57, validation loss: 0.0032123690471053123


113it [02:01,  1.08s/it]


epoch: 58, validation loss: 0.0032111783511936665


113it [01:58,  1.05s/it]


epoch: 59, validation loss: 0.003209963208064437


113it [01:58,  1.05s/it]


epoch: 60, validation loss: 0.0032086947467178106


113it [02:20,  1.24s/it]


epoch: 61, validation loss: 0.0032073510996997357


113it [02:09,  1.14s/it]


epoch: 62, validation loss: 0.003205878408625722


113it [02:00,  1.06s/it]


epoch: 63, validation loss: 0.003204396227374673


113it [01:56,  1.03s/it]


epoch: 64, validation loss: 0.003202882818877697


113it [01:55,  1.02s/it]


epoch: 65, validation loss: 0.003201358951628208


113it [01:58,  1.05s/it]


epoch: 66, validation loss: 0.00319983646273613


113it [02:07,  1.13s/it]


epoch: 67, validation loss: 0.003198328884318471


113it [02:27,  1.31s/it]


epoch: 68, validation loss: 0.0031967694126069546


113it [02:08,  1.14s/it]


epoch: 69, validation loss: 0.0031951775588095186


113it [02:12,  1.17s/it]


epoch: 70, validation loss: 0.0031935116183012726


113it [02:11,  1.16s/it]


epoch: 71, validation loss: 0.003191764000803232


113it [02:08,  1.14s/it]


epoch: 72, validation loss: 0.003190029663965106


113it [02:30,  1.33s/it]


epoch: 73, validation loss: 0.0031882235780358315


113it [01:58,  1.05s/it]


epoch: 74, validation loss: 0.003186423797160387


113it [02:04,  1.10s/it]


epoch: 75, validation loss: 0.0031845986377447844


113it [02:03,  1.10s/it]


epoch: 76, validation loss: 0.0031828447245061398


113it [02:11,  1.16s/it]


epoch: 77, validation loss: 0.003181114327162504


113it [02:14,  1.19s/it]


epoch: 78, validation loss: 0.003179429564625025


113it [02:07,  1.13s/it]


epoch: 79, validation loss: 0.003177748993039131


113it [02:00,  1.07s/it]


epoch: 80, validation loss: 0.0031759517919272185


113it [01:51,  1.01it/s]


epoch: 81, validation loss: 0.0031740448903292418


113it [01:46,  1.06it/s]


epoch: 82, validation loss: 0.0031720013357698917


113it [01:25,  1.32it/s]


epoch: 83, validation loss: 0.00316994683817029


113it [01:26,  1.31it/s]


epoch: 84, validation loss: 0.0031678210850805046


113it [01:24,  1.34it/s]


epoch: 85, validation loss: 0.003165637841448188


113it [01:25,  1.33it/s]


epoch: 86, validation loss: 0.0031628576200455425


113it [01:26,  1.30it/s]


epoch: 87, validation loss: 0.0031606324575841425


113it [01:28,  1.28it/s]


epoch: 88, validation loss: 0.0031584987789392473


3it [00:02,  1.10it/s]wandb: Network error (ReadTimeout), entering retry loop.
113it [01:22,  1.36it/s]


epoch: 89, validation loss: 0.0031564403511583805


113it [01:24,  1.34it/s]


epoch: 90, validation loss: 0.0031542226020246744


113it [01:39,  1.14it/s]


epoch: 91, validation loss: 0.0031516647431999445


113it [01:28,  1.28it/s]


epoch: 92, validation loss: 0.0031490395870059727


113it [01:33,  1.21it/s]


epoch: 93, validation loss: 0.0031464821565896275


113it [01:30,  1.25it/s]


epoch: 94, validation loss: 0.0031437149643898012


113it [01:25,  1.32it/s]


epoch: 95, validation loss: 0.003140071639791131


113it [01:19,  1.42it/s]


epoch: 96, validation loss: 0.003136054379865527


113it [01:24,  1.33it/s]


epoch: 97, validation loss: 0.0031326527427881956


113it [01:41,  1.11it/s]


epoch: 98, validation loss: 0.003129404978826642


113it [01:39,  1.13it/s]


epoch: 99, validation loss: 0.0031255918834358454


113it [01:39,  1.14it/s]


epoch: 100, validation loss: 0.003120762277394533


113it [01:43,  1.09it/s]


epoch: 101, validation loss: 0.003115933407098055


113it [01:53,  1.01s/it]


epoch: 102, validation loss: 0.003111801575869322


113it [02:22,  1.26s/it]


epoch: 103, validation loss: 0.00310829421505332


113it [02:02,  1.08s/it]


epoch: 104, validation loss: 0.003105047158896923


113it [02:02,  1.08s/it]


epoch: 105, validation loss: 0.0031022403854876757


113it [02:00,  1.06s/it]


epoch: 106, validation loss: 0.0030991625878959896


113it [02:04,  1.10s/it]


epoch: 107, validation loss: 0.003096359334886074


113it [02:04,  1.10s/it]


epoch: 108, validation loss: 0.003093830766156316


113it [01:58,  1.05s/it]


epoch: 109, validation loss: 0.0030917085241526365


113it [02:22,  1.26s/it]


epoch: 110, validation loss: 0.0030896914936602115


113it [02:00,  1.06s/it]


epoch: 111, validation loss: 0.0030878137331455946


113it [02:01,  1.08s/it]


epoch: 112, validation loss: 0.0030859338771551848


113it [01:58,  1.05s/it]


epoch: 113, validation loss: 0.0030841226596385243


113it [01:58,  1.05s/it]


epoch: 114, validation loss: 0.0030823277682065966


113it [01:56,  1.03s/it]


epoch: 115, validation loss: 0.0030806977301836014


113it [01:56,  1.03s/it]


epoch: 116, validation loss: 0.0030789969116449354


113it [01:56,  1.03s/it]


epoch: 117, validation loss: 0.0030773999355733395


113it [01:56,  1.03s/it]


epoch: 118, validation loss: 0.0030757658928632736


113it [01:42,  1.10it/s]


epoch: 119, validation loss: 0.0030743177235126496


113it [01:26,  1.30it/s]


epoch: 120, validation loss: 0.003072929084300995


113it [01:27,  1.30it/s]


epoch: 121, validation loss: 0.0030717479530721903


113it [01:29,  1.26it/s]


epoch: 122, validation loss: 0.0030702548194676636


113it [01:28,  1.28it/s]


epoch: 123, validation loss: 0.0030689577013254168


113it [01:26,  1.31it/s]


epoch: 124, validation loss: 0.003067689007148147


113it [01:24,  1.34it/s]


epoch: 125, validation loss: 0.0030663716793060303


113it [01:37,  1.16it/s]


epoch: 126, validation loss: 0.0030651241540908813


113it [01:37,  1.16it/s]


epoch: 127, validation loss: 0.0030638747848570347


113it [01:35,  1.18it/s]


epoch: 128, validation loss: 0.0030627290438860655


113it [01:38,  1.15it/s]


epoch: 129, validation loss: 0.0030612433329224586


113it [01:43,  1.09it/s]


epoch: 130, validation loss: 0.003059941101819277


113it [01:40,  1.13it/s]


epoch: 131, validation loss: 0.0030586335342377424


113it [01:29,  1.27it/s]


epoch: 132, validation loss: 0.003057536669075489


113it [01:27,  1.28it/s]


epoch: 133, validation loss: 0.0030556742660701273


113it [01:45,  1.07it/s]


epoch: 134, validation loss: 0.003053687987849116


113it [01:37,  1.16it/s]


epoch: 135, validation loss: 0.003051946870982647


113it [01:34,  1.20it/s]


epoch: 136, validation loss: 0.0030504106543958187


113it [01:46,  1.06it/s]


epoch: 137, validation loss: 0.0030492495372891424


113it [02:07,  1.13s/it]


epoch: 138, validation loss: 0.0030479903798550367


113it [02:09,  1.14s/it]


epoch: 139, validation loss: 0.003046578504145145


113it [02:03,  1.09s/it]


epoch: 140, validation loss: 0.0030452981777489186


113it [02:03,  1.09s/it]


epoch: 141, validation loss: 0.0030440322775393725


113it [02:01,  1.08s/it]


epoch: 142, validation loss: 0.0030426326859742405


113it [02:05,  1.11s/it]


epoch: 143, validation loss: 0.003041681367903948


113it [01:59,  1.06s/it]


epoch: 144, validation loss: 0.003040432697162032


113it [02:02,  1.08s/it]


epoch: 145, validation loss: 0.003039586590602994


113it [02:13,  1.18s/it]


epoch: 146, validation loss: 0.0030392361618578434


113it [02:07,  1.13s/it]


epoch: 147, validation loss: 0.0030367867834866047


113it [02:06,  1.12s/it]


epoch: 148, validation loss: 0.003035455234348774


113it [02:04,  1.10s/it]


epoch: 149, validation loss: 0.003034058976918459


113it [02:02,  1.09s/it]


epoch: 150, validation loss: 0.0030328526813536883


113it [02:25,  1.29s/it]


epoch: 151, validation loss: 0.0030318065732717514


113it [02:10,  1.15s/it]


epoch: 152, validation loss: 0.003030549017712474


113it [01:56,  1.03s/it]


epoch: 153, validation loss: 0.003029288267716765


113it [02:25,  1.29s/it]


epoch: 154, validation loss: 0.0030281264241784813


113it [02:05,  1.11s/it]


epoch: 155, validation loss: 0.0030269492324441673


113it [02:02,  1.09s/it]


epoch: 156, validation loss: 0.0030256088357418776


113it [02:04,  1.10s/it]


epoch: 157, validation loss: 0.00302442979067564


113it [02:05,  1.12s/it]


epoch: 158, validation loss: 0.0030232477001845837


113it [01:59,  1.05s/it]


epoch: 159, validation loss: 0.0030220944806933405


113it [01:54,  1.01s/it]


epoch: 160, validation loss: 0.0030208511743694544


113it [01:54,  1.01s/it]


epoch: 161, validation loss: 0.0030196972470730543


113it [03:56,  2.10s/it]


epoch: 162, validation loss: 0.003018675157800317


113it [01:25,  1.32it/s]


epoch: 163, validation loss: 0.0030177133250981567


113it [01:26,  1.31it/s]


epoch: 164, validation loss: 0.003016815045848489


113it [01:47,  1.05it/s]


epoch: 165, validation loss: 0.003015788970515132


113it [01:49,  1.03it/s]


epoch: 166, validation loss: 0.003014781726524234


113it [01:45,  1.07it/s]


epoch: 167, validation loss: 0.003013833900913596


113it [01:44,  1.08it/s]


epoch: 168, validation loss: 0.003013010835275054


113it [01:47,  1.06it/s]


epoch: 169, validation loss: 0.003012308394536376


113it [01:50,  1.03it/s]


epoch: 170, validation loss: 0.0030115544702857735


113it [01:49,  1.03it/s]


epoch: 171, validation loss: 0.003010730491951108


113it [01:50,  1.03it/s]


epoch: 172, validation loss: 0.00300993911921978


113it [01:52,  1.01it/s]


epoch: 173, validation loss: 0.0030093092937022448


113it [01:50,  1.02it/s]


epoch: 174, validation loss: 0.0030087667796760797


113it [01:50,  1.02it/s]


epoch: 175, validation loss: 0.0030079565662890673


113it [01:49,  1.03it/s]


epoch: 176, validation loss: 0.003007364002987742


113it [01:51,  1.02it/s]


epoch: 177, validation loss: 0.003006680402904749


113it [01:47,  1.06it/s]


epoch: 178, validation loss: 0.0030061220563948154


113it [01:50,  1.02it/s]


epoch: 179, validation loss: 0.0030055013485252857


113it [01:50,  1.03it/s]


epoch: 180, validation loss: 0.003004775382578373


113it [01:49,  1.03it/s]


epoch: 181, validation loss: 0.0030040186922997234


113it [01:53,  1.00s/it]


epoch: 182, validation loss: 0.0030032440554350615


113it [01:40,  1.12it/s]


epoch: 183, validation loss: 0.0030023914109915495


113it [01:37,  1.15it/s]


epoch: 184, validation loss: 0.00300156744197011


113it [01:36,  1.17it/s]


epoch: 185, validation loss: 0.003000810286030173


113it [01:39,  1.13it/s]


epoch: 186, validation loss: 0.0030001104064285755


113it [01:36,  1.17it/s]


epoch: 187, validation loss: 0.0029995152354240417


113it [01:35,  1.18it/s]


epoch: 188, validation loss: 0.002998862164095044


113it [01:35,  1.18it/s]


epoch: 189, validation loss: 0.0029982403013855217


113it [01:22,  1.38it/s]


epoch: 190, validation loss: 0.002997774165123701


113it [01:26,  1.30it/s]


epoch: 191, validation loss: 0.002996933413669467


113it [01:18,  1.43it/s]


epoch: 192, validation loss: 0.0029957685619592667


113it [01:21,  1.39it/s]


epoch: 193, validation loss: 0.0029946530517190695


113it [01:19,  1.41it/s]


epoch: 194, validation loss: 0.0029936970770359037


113it [01:25,  1.32it/s]


epoch: 195, validation loss: 0.0029928069561719894


113it [01:15,  1.50it/s]


epoch: 196, validation loss: 0.002991856522858143


113it [01:23,  1.35it/s]


epoch: 197, validation loss: 0.0029908793698996306


113it [01:17,  1.45it/s]


epoch: 198, validation loss: 0.002990090744569898


113it [01:15,  1.49it/s]


epoch: 199, validation loss: 0.002989160129800439


113it [01:16,  1.49it/s]


epoch: 200, validation loss: 0.002988303080201149


113it [01:17,  1.46it/s]


epoch: 201, validation loss: 0.002987488880753517


113it [01:17,  1.47it/s]


epoch: 202, validation loss: 0.0029867168143391607


113it [01:17,  1.45it/s]


epoch: 203, validation loss: 0.002986014122143388


113it [01:18,  1.45it/s]


epoch: 204, validation loss: 0.0029853083938360214


113it [01:17,  1.46it/s]


epoch: 205, validation loss: 0.0029847507737576964


113it [01:18,  1.44it/s]


epoch: 206, validation loss: 0.002984139369800687


113it [01:14,  1.52it/s]


epoch: 207, validation loss: 0.0029836951196193697


113it [01:17,  1.46it/s]


epoch: 208, validation loss: 0.002983152158558369


113it [01:20,  1.41it/s]


epoch: 209, validation loss: 0.002982643898576498


113it [01:17,  1.45it/s]


epoch: 210, validation loss: 0.002981977090239525


113it [01:14,  1.51it/s]


epoch: 211, validation loss: 0.0029815260693430903


113it [01:15,  1.51it/s]


epoch: 212, validation loss: 0.0029809460788965225


113it [01:18,  1.45it/s]


epoch: 213, validation loss: 0.0029804906714707615


113it [01:19,  1.43it/s]


epoch: 214, validation loss: 0.0029800936952233317


113it [01:16,  1.47it/s]


epoch: 215, validation loss: 0.002979670651257038


113it [01:14,  1.52it/s]


epoch: 216, validation loss: 0.002979378467425704


113it [01:17,  1.46it/s]


epoch: 217, validation loss: 0.0029788604006171227


113it [01:16,  1.47it/s]


epoch: 218, validation loss: 0.002978420117869973


113it [01:16,  1.48it/s]


epoch: 219, validation loss: 0.002977896500378847


113it [01:15,  1.50it/s]


epoch: 220, validation loss: 0.0029772571474313736


113it [01:18,  1.45it/s]


epoch: 221, validation loss: 0.002976825013756752


113it [01:16,  1.47it/s]


epoch: 222, validation loss: 0.0029762538615614176


113it [01:15,  1.50it/s]


epoch: 223, validation loss: 0.0029758440796285868


113it [01:16,  1.48it/s]


epoch: 224, validation loss: 0.0029755211621522904


113it [01:48,  1.04it/s]


epoch: 225, validation loss: 0.0029751779418438675


113it [01:37,  1.16it/s]


epoch: 226, validation loss: 0.002974881567060947


113it [01:35,  1.18it/s]


epoch: 227, validation loss: 0.0029746554605662823


113it [01:36,  1.17it/s]


epoch: 228, validation loss: 0.0029742908850312233


113it [01:36,  1.17it/s]


epoch: 229, validation loss: 0.0029739462677389385


113it [01:35,  1.19it/s]


epoch: 230, validation loss: 0.002973705753684044


113it [01:34,  1.19it/s]


epoch: 231, validation loss: 0.002973250327631831


113it [01:23,  1.35it/s]


epoch: 232, validation loss: 0.002972904611378908


113it [01:20,  1.40it/s]


epoch: 233, validation loss: 0.002972702980041504


113it [01:13,  1.53it/s]


epoch: 234, validation loss: 0.0029723350517451763


113it [01:12,  1.55it/s]


epoch: 235, validation loss: 0.002972278743982315


113it [01:16,  1.49it/s]


epoch: 236, validation loss: 0.002972006779164076


113it [01:20,  1.41it/s]


epoch: 237, validation loss: 0.0029717215802520514


113it [01:15,  1.50it/s]


epoch: 238, validation loss: 0.0029713183082640173


113it [01:12,  1.55it/s]


epoch: 239, validation loss: 0.0029708670917898417


113it [01:14,  1.51it/s]


epoch: 240, validation loss: 0.002970208656042814


113it [01:18,  1.44it/s]


epoch: 241, validation loss: 0.002969742752611637


113it [01:17,  1.45it/s]


epoch: 242, validation loss: 0.0029689606931060553


113it [01:13,  1.53it/s]


epoch: 243, validation loss: 0.002968498272821307


113it [01:14,  1.52it/s]


epoch: 244, validation loss: 0.0029680710285902023


113it [01:17,  1.46it/s]


epoch: 245, validation loss: 0.0029674884863197803


113it [01:18,  1.44it/s]


epoch: 246, validation loss: 0.0029669210594147444


113it [01:15,  1.50it/s]


epoch: 247, validation loss: 0.0029664144292473795


113it [01:17,  1.45it/s]


epoch: 248, validation loss: 0.002966016540303826


113it [01:20,  1.40it/s]


epoch: 249, validation loss: 0.002965632826089859


113it [01:14,  1.52it/s]


epoch: 250, validation loss: 0.002965210471302271


113it [01:13,  1.54it/s]


epoch: 251, validation loss: 0.0029647587798535824


113it [01:15,  1.51it/s]


epoch: 252, validation loss: 0.0029644013848155737


113it [01:18,  1.44it/s]


epoch: 253, validation loss: 0.002964177625253797


113it [01:14,  1.51it/s]


epoch: 254, validation loss: 0.0029637638945132494


113it [01:13,  1.55it/s]


epoch: 255, validation loss: 0.0029632086120545864


113it [01:13,  1.54it/s]


epoch: 256, validation loss: 0.0029628782253712416


113it [01:17,  1.45it/s]


epoch: 257, validation loss: 0.0029625687841325996


113it [01:16,  1.48it/s]


epoch: 258, validation loss: 0.0029622486419975756


113it [01:13,  1.54it/s]


epoch: 259, validation loss: 0.0029619585163891318


113it [01:14,  1.52it/s]


epoch: 260, validation loss: 0.0029616686515510082


113it [01:17,  1.46it/s]


epoch: 261, validation loss: 0.0029615836776793


113it [01:18,  1.45it/s]


epoch: 262, validation loss: 0.002961311489343643


113it [01:14,  1.52it/s]


epoch: 263, validation loss: 0.0029609731771051886


113it [01:13,  1.53it/s]


epoch: 264, validation loss: 0.0029607675969600677


113it [01:19,  1.42it/s]


epoch: 265, validation loss: 0.0029605883173644543


113it [01:18,  1.44it/s]


epoch: 266, validation loss: 0.0029604008793830874


113it [01:13,  1.53it/s]


epoch: 267, validation loss: 0.002960043726488948


113it [01:14,  1.51it/s]


epoch: 268, validation loss: 0.002959877485409379


113it [01:18,  1.44it/s]


epoch: 269, validation loss: 0.0029596220701932907


113it [01:20,  1.40it/s]


epoch: 270, validation loss: 0.0029594388511031866


113it [01:15,  1.49it/s]


epoch: 271, validation loss: 0.0029592767916619776


113it [01:12,  1.56it/s]


epoch: 272, validation loss: 0.0029591154307127


113it [01:16,  1.48it/s]


epoch: 273, validation loss: 0.0029590199515223503


113it [01:19,  1.42it/s]


epoch: 274, validation loss: 0.0029587580170482397


113it [01:14,  1.52it/s]


epoch: 275, validation loss: 0.002958603920415044


113it [01:14,  1.51it/s]


epoch: 276, validation loss: 0.0029584802500903606


113it [01:19,  1.43it/s]


epoch: 277, validation loss: 0.0029583622235804796


113it [01:21,  1.39it/s]


epoch: 278, validation loss: 0.002958156866952777


113it [01:14,  1.51it/s]


epoch: 279, validation loss: 0.0029580025002360344


113it [01:13,  1.54it/s]


epoch: 280, validation loss: 0.0029580169450491665
EarlyStopping counter: 1 out of 10


113it [01:15,  1.50it/s]


epoch: 281, validation loss: 0.0029577060882002115


113it [01:19,  1.42it/s]


epoch: 282, validation loss: 0.0029575468506664038


113it [01:14,  1.51it/s]


epoch: 283, validation loss: 0.002957558026537299
EarlyStopping counter: 1 out of 10


113it [01:13,  1.54it/s]


epoch: 284, validation loss: 0.0029573037941008806


113it [01:15,  1.51it/s]


epoch: 285, validation loss: 0.0029570535011589527


113it [01:17,  1.45it/s]


epoch: 286, validation loss: 0.0029568644240498543


113it [01:18,  1.44it/s]


epoch: 287, validation loss: 0.002956785261631012


113it [01:14,  1.51it/s]


epoch: 288, validation loss: 0.0029566990956664085


113it [01:14,  1.52it/s]


epoch: 289, validation loss: 0.0029563789814710615


113it [01:18,  1.44it/s]


epoch: 290, validation loss: 0.0029562423098832367


113it [01:18,  1.44it/s]


epoch: 291, validation loss: 0.002956144977360964


113it [01:15,  1.50it/s]


epoch: 292, validation loss: 0.0029559289291501045


113it [01:14,  1.52it/s]


epoch: 293, validation loss: 0.002956011071801186
EarlyStopping counter: 1 out of 10


113it [01:18,  1.43it/s]


epoch: 294, validation loss: 0.0029557359032332895


113it [01:18,  1.43it/s]


epoch: 295, validation loss: 0.0029555270448327065


113it [01:14,  1.52it/s]


epoch: 296, validation loss: 0.0029553426429629326


113it [01:13,  1.53it/s]


epoch: 297, validation loss: 0.0029551868699491026


113it [01:15,  1.49it/s]


epoch: 298, validation loss: 0.0029550096951425076


113it [01:18,  1.44it/s]


epoch: 299, validation loss: 0.0029548434540629387


113it [01:22,  1.37it/s]


epoch: 300, validation loss: 0.0029547440353780985


113it [01:16,  1.47it/s]


epoch: 301, validation loss: 0.002954439278692007


113it [01:15,  1.51it/s]


epoch: 302, validation loss: 0.002954274658113718


113it [01:15,  1.49it/s]


epoch: 303, validation loss: 0.002954097930341959


113it [01:15,  1.49it/s]


epoch: 304, validation loss: 0.0029539726581424477


113it [01:15,  1.49it/s]


epoch: 305, validation loss: 0.0029537370428442955


113it [01:16,  1.48it/s]


epoch: 306, validation loss: 0.002953593842685223


113it [01:19,  1.43it/s]


epoch: 307, validation loss: 0.0029535770881921053


113it [01:18,  1.43it/s]


epoch: 308, validation loss: 0.002953272070735693


113it [01:13,  1.53it/s]


epoch: 309, validation loss: 0.002953083710744977


113it [01:14,  1.53it/s]


epoch: 310, validation loss: 0.002953096050769091
EarlyStopping counter: 1 out of 10


113it [01:17,  1.45it/s]


epoch: 311, validation loss: 0.0029527433123439552


113it [01:15,  1.51it/s]


epoch: 312, validation loss: 0.002952596861869097


113it [01:15,  1.51it/s]


epoch: 313, validation loss: 0.002952325390651822


113it [01:14,  1.51it/s]


epoch: 314, validation loss: 0.0029521235171705486


113it [01:17,  1.46it/s]


epoch: 315, validation loss: 0.0029518094565719366


113it [01:14,  1.51it/s]


epoch: 316, validation loss: 0.0029516119975596666


113it [01:14,  1.52it/s]


epoch: 317, validation loss: 0.002951447619125247


113it [01:14,  1.52it/s]


epoch: 318, validation loss: 0.0029513458721339703


113it [01:17,  1.46it/s]


epoch: 319, validation loss: 0.0029512636829167604


113it [01:21,  1.39it/s]


epoch: 320, validation loss: 0.0029510636813938618


113it [01:15,  1.50it/s]


epoch: 321, validation loss: 0.00295084317214787


113it [01:13,  1.54it/s]


epoch: 322, validation loss: 0.002950760768726468


113it [01:16,  1.47it/s]


epoch: 323, validation loss: 0.002950512571260333


113it [01:18,  1.45it/s]


epoch: 324, validation loss: 0.002950440375134349


113it [01:14,  1.52it/s]


epoch: 325, validation loss: 0.002950232243165374


113it [01:13,  1.54it/s]


epoch: 326, validation loss: 0.0029501530807465315


113it [01:18,  1.44it/s]


epoch: 327, validation loss: 0.0029499011766165495


113it [01:16,  1.48it/s]


epoch: 328, validation loss: 0.002949703736230731


113it [01:14,  1.52it/s]


epoch: 329, validation loss: 0.0029496438801288605


113it [01:12,  1.56it/s]


epoch: 330, validation loss: 0.0029495123494416475


113it [01:16,  1.47it/s]


epoch: 331, validation loss: 0.0029492941685020924


113it [01:16,  1.48it/s]


epoch: 332, validation loss: 0.0029492464382201433


113it [01:14,  1.52it/s]


epoch: 333, validation loss: 0.0029490890353918078


113it [01:14,  1.52it/s]


epoch: 334, validation loss: 0.002949203373864293
EarlyStopping counter: 1 out of 10


113it [01:17,  1.46it/s]


epoch: 335, validation loss: 0.002948806853964925


113it [01:17,  1.46it/s]


epoch: 336, validation loss: 0.00294872023165226


113it [01:13,  1.53it/s]


epoch: 337, validation loss: 0.0029484836943447588


113it [01:14,  1.51it/s]


epoch: 338, validation loss: 0.0029485532827675342
EarlyStopping counter: 1 out of 10


113it [01:16,  1.48it/s]


epoch: 339, validation loss: 0.002948297904804349


113it [01:17,  1.45it/s]


epoch: 340, validation loss: 0.0029481647070497274


113it [01:15,  1.49it/s]


epoch: 341, validation loss: 0.00294798843562603


113it [01:14,  1.52it/s]


epoch: 342, validation loss: 0.002947989162057638
EarlyStopping counter: 1 out of 10


113it [01:17,  1.45it/s]


epoch: 343, validation loss: 0.0029477265011519193


113it [01:15,  1.51it/s]


epoch: 344, validation loss: 0.0029475900810211897


113it [01:14,  1.52it/s]


epoch: 345, validation loss: 0.002947468291968107


113it [01:16,  1.49it/s]


epoch: 346, validation loss: 0.0029475293308496475
EarlyStopping counter: 1 out of 10


113it [01:17,  1.45it/s]


epoch: 347, validation loss: 0.0029472173191607


113it [01:17,  1.46it/s]


epoch: 348, validation loss: 0.0029471234884113073


113it [01:14,  1.51it/s]


epoch: 349, validation loss: 0.0029470487497746944


113it [01:15,  1.50it/s]


epoch: 350, validation loss: 0.0029467784333974123


113it [01:18,  1.45it/s]


epoch: 351, validation loss: 0.0029466985817998646


113it [01:16,  1.48it/s]


epoch: 352, validation loss: 0.002946671340614557


113it [01:13,  1.53it/s]


epoch: 353, validation loss: 0.0029464082326740026


113it [01:14,  1.52it/s]


epoch: 354, validation loss: 0.0029462583176791666


113it [01:18,  1.45it/s]


epoch: 355, validation loss: 0.0029460629634559155


113it [01:17,  1.46it/s]


epoch: 356, validation loss: 0.002946158880367875
EarlyStopping counter: 1 out of 10


113it [01:16,  1.48it/s]


epoch: 357, validation loss: 0.0029457851871848104


113it [01:14,  1.51it/s]


epoch: 358, validation loss: 0.0029457006603479385


113it [01:14,  1.51it/s]


epoch: 359, validation loss: 0.002945480402559042


113it [01:15,  1.50it/s]


epoch: 360, validation loss: 0.0029453446529805662


113it [01:16,  1.47it/s]


epoch: 361, validation loss: 0.0029453507065773012
EarlyStopping counter: 1 out of 10


113it [01:18,  1.44it/s]


epoch: 362, validation loss: 0.002945063645020127


113it [01:20,  1.41it/s]


epoch: 363, validation loss: 0.002944965371862054


113it [01:15,  1.50it/s]


epoch: 364, validation loss: 0.0029448599088937044


113it [01:18,  1.43it/s]


epoch: 365, validation loss: 0.0029446284752339125


113it [01:14,  1.52it/s]


epoch: 366, validation loss: 0.0029444803949445486


113it [01:17,  1.45it/s]


epoch: 367, validation loss: 0.002944335797801614


113it [01:14,  1.51it/s]


epoch: 368, validation loss: 0.002944212853908539


113it [01:15,  1.49it/s]


epoch: 369, validation loss: 0.0029441153164952993


113it [01:16,  1.48it/s]


epoch: 370, validation loss: 0.002943935338407755


113it [01:19,  1.43it/s]


epoch: 371, validation loss: 0.0029439630638808012
EarlyStopping counter: 1 out of 10


113it [01:15,  1.49it/s]


epoch: 372, validation loss: 0.0029436771105974912


113it [01:12,  1.57it/s]


epoch: 373, validation loss: 0.002943564224988222


113it [01:14,  1.52it/s]


epoch: 374, validation loss: 0.0029434820357710123


113it [01:18,  1.44it/s]


epoch: 375, validation loss: 0.002943267123773694


113it [01:15,  1.50it/s]


epoch: 376, validation loss: 0.002943047545850277


113it [01:14,  1.51it/s]


epoch: 377, validation loss: 0.0029429560527205467


113it [01:17,  1.46it/s]


epoch: 378, validation loss: 0.002942891577258706


113it [01:18,  1.44it/s]


epoch: 379, validation loss: 0.0029424668569117784


113it [01:15,  1.49it/s]


epoch: 380, validation loss: 0.0029423683881759644


113it [01:13,  1.55it/s]


epoch: 381, validation loss: 0.0029420887678861616


113it [01:13,  1.53it/s]


epoch: 382, validation loss: 0.002941854763776064


113it [01:17,  1.46it/s]


epoch: 383, validation loss: 0.002941767917945981


113it [01:18,  1.44it/s]


epoch: 384, validation loss: 0.0029417828377336264
EarlyStopping counter: 1 out of 10


113it [01:14,  1.51it/s]


epoch: 385, validation loss: 0.002941577211022377


113it [01:17,  1.46it/s]


epoch: 386, validation loss: 0.0029412531293928623


113it [01:17,  1.45it/s]


epoch: 387, validation loss: 0.0029413839895278213
EarlyStopping counter: 1 out of 10


113it [01:28,  1.28it/s]


epoch: 388, validation loss: 0.0029409888759255407


113it [01:15,  1.50it/s]


epoch: 389, validation loss: 0.0029407658241689203


113it [01:13,  1.54it/s]


epoch: 390, validation loss: 0.002940651038661599


113it [01:17,  1.46it/s]


epoch: 391, validation loss: 0.002940565561875701


113it [01:19,  1.42it/s]


epoch: 392, validation loss: 0.002940306458622217


113it [01:14,  1.51it/s]


epoch: 393, validation loss: 0.0029401972610503435


113it [01:13,  1.53it/s]


epoch: 394, validation loss: 0.002940056398510933


113it [01:15,  1.49it/s]


epoch: 395, validation loss: 0.002939936686307192


113it [01:20,  1.41it/s]


epoch: 396, validation loss: 0.0029396549705415966


113it [01:14,  1.52it/s]


epoch: 397, validation loss: 0.0029394803568720818


113it [01:12,  1.56it/s]


epoch: 398, validation loss: 0.0029393883887678385


113it [01:16,  1.48it/s]


epoch: 399, validation loss: 0.0029392291326075792


113it [01:18,  1.43it/s]


epoch: 400, validation loss: 0.0029391832649707794


113it [01:15,  1.49it/s]


epoch: 401, validation loss: 0.002939052637666464


113it [01:12,  1.57it/s]


epoch: 402, validation loss: 0.002938961610198021


113it [01:15,  1.49it/s]


epoch: 403, validation loss: 0.002938926229253411


113it [01:17,  1.45it/s]


epoch: 404, validation loss: 0.002938853558152914


113it [01:17,  1.45it/s]


epoch: 405, validation loss: 0.0029385699704289436


113it [01:14,  1.52it/s]


epoch: 406, validation loss: 0.0029385057277977467


113it [01:16,  1.47it/s]


epoch: 407, validation loss: 0.002938486421480775


113it [01:17,  1.45it/s]


epoch: 408, validation loss: 0.0029384950175881386
EarlyStopping counter: 1 out of 10


113it [01:17,  1.45it/s]


epoch: 409, validation loss: 0.0029381460044533014


113it [01:13,  1.55it/s]


epoch: 410, validation loss: 0.0029380817525088785


113it [01:13,  1.54it/s]


epoch: 411, validation loss: 0.0029380500968545675


113it [01:18,  1.44it/s]


epoch: 412, validation loss: 0.002937905490398407


113it [01:25,  1.32it/s]


epoch: 413, validation loss: 0.002937770700082183


113it [01:14,  1.51it/s]


epoch: 414, validation loss: 0.0029378367867320776
EarlyStopping counter: 1 out of 10


113it [01:15,  1.51it/s]


epoch: 415, validation loss: 0.0029375569336116316


113it [01:15,  1.49it/s]


epoch: 416, validation loss: 0.0029375264421105385


113it [01:16,  1.48it/s]


epoch: 417, validation loss: 0.0029373385664075613


113it [01:14,  1.51it/s]


epoch: 418, validation loss: 0.0029374072421342132
EarlyStopping counter: 1 out of 10


113it [01:16,  1.48it/s]


epoch: 419, validation loss: 0.0029374402947723866
EarlyStopping counter: 2 out of 10


113it [01:17,  1.45it/s]


epoch: 420, validation loss: 0.0029370158445090055


113it [01:18,  1.44it/s]


epoch: 421, validation loss: 0.002937089866027236
EarlyStopping counter: 1 out of 10


113it [01:14,  1.52it/s]


epoch: 422, validation loss: 0.0029367695190012453


113it [01:16,  1.48it/s]


epoch: 423, validation loss: 0.0029366421606391666


113it [01:17,  1.45it/s]


epoch: 424, validation loss: 0.002936657750979066
EarlyStopping counter: 1 out of 10


113it [01:15,  1.50it/s]


epoch: 425, validation loss: 0.0029364451859146355


113it [01:13,  1.55it/s]


epoch: 426, validation loss: 0.002936655879020691
EarlyStopping counter: 1 out of 10


113it [01:13,  1.54it/s]


epoch: 427, validation loss: 0.0029362393449991944


113it [01:17,  1.46it/s]


epoch: 428, validation loss: 0.002936137607321143


113it [01:15,  1.50it/s]


epoch: 429, validation loss: 0.0029360777791589496


113it [01:15,  1.50it/s]


epoch: 430, validation loss: 0.002935958094894886


113it [01:13,  1.53it/s]


epoch: 431, validation loss: 0.0029359115380793808


113it [01:17,  1.45it/s]


epoch: 432, validation loss: 0.002935773925855756


113it [01:15,  1.50it/s]


epoch: 433, validation loss: 0.0029361143149435522
EarlyStopping counter: 1 out of 10


113it [01:14,  1.52it/s]


epoch: 434, validation loss: 0.0029355941619724035


113it [01:14,  1.51it/s]


epoch: 435, validation loss: 0.0029357420373708008
EarlyStopping counter: 1 out of 10


113it [01:18,  1.45it/s]


epoch: 436, validation loss: 0.0029355008248239754


113it [01:15,  1.50it/s]


epoch: 437, validation loss: 0.00293536183424294


113it [01:14,  1.53it/s]


epoch: 438, validation loss: 0.0029352793749421835


113it [01:16,  1.47it/s]


epoch: 439, validation loss: 0.002935055186972022


113it [01:17,  1.46it/s]


epoch: 440, validation loss: 0.002935028159990907


113it [01:15,  1.50it/s]


epoch: 441, validation loss: 0.002934888219460845


113it [01:14,  1.52it/s]


epoch: 442, validation loss: 0.0029348612669855357


113it [01:13,  1.54it/s]


epoch: 443, validation loss: 0.0029349688161164522
EarlyStopping counter: 1 out of 10


113it [01:18,  1.45it/s]


epoch: 444, validation loss: 0.0029346118681132793


113it [01:15,  1.50it/s]


epoch: 445, validation loss: 0.0029346775449812412
EarlyStopping counter: 1 out of 10


113it [01:14,  1.51it/s]


epoch: 446, validation loss: 0.0029343524761497974


113it [01:13,  1.53it/s]


epoch: 447, validation loss: 0.002934274738654494


113it [01:17,  1.45it/s]


epoch: 448, validation loss: 0.0029341918416321278


113it [01:16,  1.48it/s]


epoch: 449, validation loss: 0.0029342344496399164
EarlyStopping counter: 1 out of 10


113it [01:14,  1.52it/s]


epoch: 450, validation loss: 0.002933955974876881


113it [01:15,  1.50it/s]


epoch: 451, validation loss: 0.0029339718259871004
EarlyStopping counter: 1 out of 10


113it [01:18,  1.45it/s]


epoch: 452, validation loss: 0.0029338426049798725


113it [01:19,  1.42it/s]


epoch: 453, validation loss: 0.0029336978029459713


113it [01:14,  1.52it/s]


epoch: 454, validation loss: 0.0029335513431578875


113it [01:14,  1.51it/s]


epoch: 455, validation loss: 0.0029334810096770525


113it [01:17,  1.46it/s]


epoch: 456, validation loss: 0.002933592554181814
EarlyStopping counter: 1 out of 10


113it [01:15,  1.49it/s]


epoch: 457, validation loss: 0.002933374149724841


113it [01:15,  1.51it/s]


epoch: 458, validation loss: 0.0029332416504621508


113it [01:16,  1.47it/s]


epoch: 459, validation loss: 0.0029331562109291553


113it [01:18,  1.45it/s]


epoch: 460, validation loss: 0.0029330206941813233


113it [01:18,  1.45it/s]


epoch: 461, validation loss: 0.002932921992614865


113it [01:14,  1.52it/s]


epoch: 462, validation loss: 0.0029328356124460695


113it [01:13,  1.55it/s]


epoch: 463, validation loss: 0.002932921750470996
EarlyStopping counter: 1 out of 10


113it [01:55,  1.03s/it]


epoch: 464, validation loss: 0.0029326747078448536


11it [00:10,  1.02it/s]


KeyboardInterrupt: 

In [None]:
# torch.save(model, '../../data/trained_models/model_16_07_new.pth')

In [None]:
# Destination of the checkpoint written by this cell.
model_path = '../../data/trained_models/model_16_07_lr_0.0001.pth'

# torch.save does not create missing parent directories. Create the target
# folder up front so the result of a long training run is not lost to a
# missing path.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# Save the model state dictionary and configuration.
# The checkpoint is a plain dict with two entries:
#   'state_dict' - the parameters returned by model.state_dict()
#   'config'     - attributes read from the model instance; presumably the
#                  arguments needed to rebuild the architecture before calling
#                  load_state_dict - TODO confirm against gnn_architectures.
torch.save({
    'state_dict': model.state_dict(),
    'config': {
        'in_channels': model.in_channels,
        # NOTE(review): hard-coded, unlike the other entries which are read
        # from the model - confirm this matches the trained architecture.
        'out_channels': 1,
        'hidden_size': model.hidden_size,
        'gat_layers': model.gat_layers,
        'gcn_layers': model.gcn_layers,
        # NOTE(review): stored exactly as held on the model; if this is a
        # module object rather than a plain value it is pickled with the
        # checkpoint - confirm.
        'output_layer': model.output_layer
    }
}, model_path)