In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import random
import glob
import pickle
import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from models import CNN_LSTM, SepCNN_LSTM, ConvGRU_LSTM, RandomForestBaseline, LassoModel
import joblib
from torchsummary import summary

In [2]:
seed = 42

# Set seed for NumPy
np.random.seed(seed)

# Set seed for Python's built-in random module
random.seed(seed)

# Set seed for PyTorch
torch.manual_seed(seed)

# Set seed for Torch's CUDA operations if GPU is used
# if torch.cuda.is_available():
#     torch.backends.cudnn.deterministic = True
#     torch.cuda.manual_seed(seed)

<torch._C.Generator at 0x1e9b4d2c410>

In [3]:
# # Load a sample of the data
sample_data = np.load('./data/PROCESSED_III/2018_13_155.npy')  

# # Check the shape of the sample data
print("Shape of sample data:", sample_data.shape)

Shape of sample data: (38, 1, 128, 9)


In [4]:
# Define generator function
def generator(IDs, yields, batch_size, cutoff=None):
    def load_data(ID):
        try:
            data = np.load('./data/PROCESSED_III/' + ID + '.npy')
            return data, True
        except Exception as e:
            return None, False

    batches = 0

    while True:
        batch_features = np.zeros((batch_size, 38, 1, 128, 9)) if cutoff is None else np.zeros((batch_size, cutoff, 1, 128, 9))
        batch_yields = np.zeros(batch_size)

        if batches == len(IDs) // batch_size:
            batches = 0
            yield None, None

        for i in range(batch_size):
            while True:
                index = random.choice(range(len(IDs)))
                ID = IDs[index]
                data, success = load_data(ID)
                if success:
                    break

            if data is not None:
                if cutoff is not None:
                    if not np.isnan(data).any():
                        batch_features[i, :, :, :, :] = data[:cutoff, :, :, :]
                        batch_yields[i] = yields[ID]
                else:
                    batch_features[i, :, :, :, :] = data
                    batch_yields[i] = yields[ID]
                

        batches += 1

        yield torch.tensor(batch_features, dtype=torch.float32, device='cuda'), torch.tensor(batch_yields, dtype=torch.float32, device='cuda')


In [5]:
# Datasets
yields = pickle.load(open('data/yields.p', 'rb'))

# Generators
training_generator = generator(list(yields['train'].keys()), yields['train'], 16)
validation_generator = generator(list(yields['validation'].keys()), yields['validation'], 16)

In [5]:
n_estimators_values = [50, 100, 150]  
max_depth_values = [None, 10, 20]

# Initialize variables to store best parameters and corresponding MSE
best_params = None
best_mse = float('inf')  

for n_estimators in n_estimators_values:
    for max_depth in max_depth_values:
        random_forest_model = RandomForestBaseline(n_estimators=n_estimators, max_depth=max_depth, random_state=42)

        # Fit the model to training data
        X_train, y_train = next(training_generator)
        random_forest_model.fit(X_train.cpu().reshape(X_train.shape[0], -1), y_train.cpu())

        # Make predictions on test data
        X_test, y_test = next(validation_generator)
        predictions = random_forest_model.predict(X_test.cpu().reshape(X_test.shape[0], -1))

        # Evaluate the model
        mse = random_forest_model.evaluate(X_test.cpu().reshape(X_test.cpu().shape[0], -1), y_test.cpu())

        # Print MSE for current parameter combination
        print(f"Parameters: n_estimators={n_estimators}, max_depth={max_depth} MSE: {mse}")

        # Check if current combination improves performance
        if mse < best_mse:
            print('save best model')
            joblib.dump(random_forest_model, 'random_forest_best_model.pkl')            
            best_mse = mse
            best_params = {'n_estimators': n_estimators, 'max_depth': max_depth}

# Print best parameters and corresponding MSE
print("\nBest Parameters:")
print(best_params)
print("Best Mean Squared Error:", best_mse)

Parameters: n_estimators=50, max_depth=None MSE: 1215.3222280938926
save best model
Parameters: n_estimators=50, max_depth=10 MSE: 1632.7786611852525
Parameters: n_estimators=50, max_depth=20 MSE: 1590.7839759603144
Parameters: n_estimators=100, max_depth=None MSE: 1881.0676816275231
Parameters: n_estimators=100, max_depth=10 MSE: 1404.0375377720136
Parameters: n_estimators=100, max_depth=20 MSE: 853.107621468707
save best model
Parameters: n_estimators=150, max_depth=None MSE: 1889.1073206777735
Parameters: n_estimators=150, max_depth=10 MSE: 1652.2194739090282
Parameters: n_estimators=150, max_depth=20 MSE: 2248.1081579971983

Best Parameters:
{'n_estimators': 100, 'max_depth': 20}
Best Mean Squared Error: 853.107621468707


In [6]:
alpha_values = [0.3, 0.4, 0.5]

best_alpha = None
best_mse = float('inf') 

# Iterate over alpha values
for alpha in alpha_values:
    # Initialize Lasso model with current alpha value
    lasso_model = LassoModel(alpha=alpha, random_state=42)

    # Fit the model to training data
    X_train, y_train = next(training_generator)
    lasso_model.fit(X_train.cpu().reshape(X_train.shape[0], -1), y_train.cpu())

    # Make predictions on test data
    X_test, y_test = next(validation_generator)
    predictions = lasso_model.predict(X_test.cpu().reshape(X_test.shape[0], -1))

    # Evaluate the model
    mse = lasso_model.evaluate(X_test.cpu().reshape(X_test.cpu().shape[0], -1), y_test.cpu())

    print(f"Alpha: {alpha}, MSE: {mse}")

    # Check if current alpha improves performance
    if mse < best_mse:
        print('save best model')
        joblib.dump(lasso_model, 'lasso_best_model.pkl')   
        best_mse = mse
        best_alpha = alpha

# Print best alpha and corresponding MSE
print("\nBest Alpha:", best_alpha)
print("Best Mean Squared Error:", best_mse)

Alpha: 0.3, MSE: 279.69549731154285
save best model
Alpha: 0.4, MSE: 2428.617348022587
Alpha: 0.5, MSE: 634.4768458833705

Best Alpha: 0.3
Best Mean Squared Error: 279.69549731154285


In [6]:
model_functions = {
    'CNN_LSTM': CNN_LSTM,
    'SepCNN_LSTM': SepCNN_LSTM,
    'ConvGRU_LSTM': ConvGRU_LSTM,
}


epochs = 100

for model_name, model_function in model_functions.items():
    model = model_function(dimensions=[38, 1, 128, 9])
    model.to('cuda')
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.MSELoss()

    best_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_losses = []
        for batch_data, batch_labels in tqdm.tqdm(training_generator, desc=f"Epoch {epoch+1}/{epochs}"):
            if batch_data is None:
                break

            optimizer.zero_grad()
            outputs = model(batch_data)
            loss = criterion(outputs, batch_labels.unsqueeze(1))
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        model.eval()
        val_losses = []
        
        with torch.no_grad():
            for val_data, val_labels in validation_generator:
                if val_data is None:
                    break
                val_outputs = model(val_data)
                val_loss = criterion(val_outputs, val_labels.unsqueeze(1))
                val_losses.append(val_loss.item())

        current_loss = np.mean(val_losses)
        if current_loss < best_loss:
            torch.save(model, f'{model_name}_best.pt')
            best_loss = current_loss
            
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {np.mean(train_losses):.4f}, Val Loss: {current_loss:.4f}")

  return F.conv2d(input, weight, bias, self.stride,
Epoch 1/100: 389it [00:27, 13.99it/s]


Epoch 1/100, Train Loss: 22225.5468, Val Loss: 20560.4383


Epoch 2/100: 389it [00:26, 14.88it/s]


Epoch 2/100, Train Loss: 12500.0966, Val Loss: 10456.9597


Epoch 3/100: 389it [00:26, 14.61it/s]


Epoch 3/100, Train Loss: 5517.7083, Val Loss: 5138.1607


Epoch 4/100: 389it [00:26, 14.65it/s]


Epoch 4/100, Train Loss: 3027.8418, Val Loss: 5343.1144


Epoch 5/100: 389it [00:26, 14.74it/s]


Epoch 5/100, Train Loss: 1639.3942, Val Loss: 2312.8053


Epoch 6/100: 389it [00:27, 14.40it/s]


Epoch 6/100, Train Loss: 1136.7613, Val Loss: 3508.0707


Epoch 7/100: 389it [00:27, 14.30it/s]


Epoch 7/100, Train Loss: 1029.0102, Val Loss: 2983.4327


Epoch 8/100: 389it [00:27, 14.30it/s]


Epoch 8/100, Train Loss: 978.1141, Val Loss: 4488.2710


Epoch 9/100: 389it [00:27, 14.38it/s]


Epoch 9/100, Train Loss: 954.1938, Val Loss: 2141.1601


Epoch 10/100: 389it [00:26, 14.61it/s]


Epoch 10/100, Train Loss: 900.1827, Val Loss: 2842.2853


Epoch 11/100: 389it [00:26, 14.59it/s]


Epoch 11/100, Train Loss: 877.3522, Val Loss: 3039.4682


Epoch 12/100: 389it [00:26, 14.55it/s]


Epoch 12/100, Train Loss: 809.2621, Val Loss: 4346.4946


Epoch 13/100: 389it [00:26, 14.49it/s]


Epoch 13/100, Train Loss: 791.0919, Val Loss: 3874.4511


Epoch 14/100: 389it [00:26, 14.54it/s]


Epoch 14/100, Train Loss: 794.2349, Val Loss: 4870.6922


Epoch 15/100: 389it [00:27, 14.27it/s]


Epoch 15/100, Train Loss: 755.9405, Val Loss: 3970.6995


Epoch 16/100: 389it [00:26, 14.52it/s]


Epoch 16/100, Train Loss: 693.7396, Val Loss: 1478.6347


Epoch 17/100: 389it [00:26, 14.53it/s]


Epoch 17/100, Train Loss: 687.2075, Val Loss: 2940.2179


Epoch 18/100: 389it [00:26, 14.52it/s]


Epoch 18/100, Train Loss: 680.3952, Val Loss: 2645.4345


Epoch 19/100: 389it [00:26, 14.49it/s]


Epoch 19/100, Train Loss: 678.0415, Val Loss: 2849.4052


Epoch 20/100: 389it [00:26, 14.46it/s]


Epoch 20/100, Train Loss: 661.7623, Val Loss: 3466.7622


Epoch 21/100: 389it [00:27, 14.31it/s]


Epoch 21/100, Train Loss: 681.5184, Val Loss: 3142.4172


Epoch 22/100: 389it [00:27, 14.31it/s]


Epoch 22/100, Train Loss: 656.7110, Val Loss: 3355.2934


Epoch 23/100: 389it [00:27, 14.33it/s]


Epoch 23/100, Train Loss: 666.8497, Val Loss: 3766.6565


Epoch 24/100: 389it [00:27, 14.37it/s]


Epoch 24/100, Train Loss: 650.8566, Val Loss: 2137.5411


Epoch 25/100: 389it [00:27, 14.38it/s]


Epoch 25/100, Train Loss: 660.6723, Val Loss: 2649.1621


Epoch 26/100: 389it [00:27, 14.36it/s]


Epoch 26/100, Train Loss: 648.5920, Val Loss: 1582.3899


Epoch 27/100: 389it [00:27, 14.30it/s]


Epoch 27/100, Train Loss: 633.9657, Val Loss: 1248.8244


Epoch 28/100: 389it [00:27, 14.35it/s]


Epoch 28/100, Train Loss: 610.0658, Val Loss: 1216.8524


Epoch 29/100: 389it [00:27, 14.11it/s]


Epoch 29/100, Train Loss: 629.0018, Val Loss: 1253.3367


Epoch 30/100: 389it [00:30, 12.70it/s]


Epoch 30/100, Train Loss: 606.8664, Val Loss: 4244.8942


Epoch 31/100: 389it [00:29, 13.39it/s]


Epoch 31/100, Train Loss: 613.1194, Val Loss: 86498.4103


Epoch 32/100: 389it [00:28, 13.82it/s]


Epoch 32/100, Train Loss: 622.8692, Val Loss: 34959.0724


Epoch 33/100: 389it [00:27, 14.31it/s]


Epoch 33/100, Train Loss: 618.0082, Val Loss: 1785.3786


Epoch 34/100: 389it [00:27, 14.37it/s]


Epoch 34/100, Train Loss: 633.9312, Val Loss: 3317.6462


Epoch 35/100: 389it [00:27, 14.36it/s]


Epoch 35/100, Train Loss: 657.4793, Val Loss: 3885.3874


Epoch 36/100: 389it [00:27, 14.06it/s]


Epoch 36/100, Train Loss: 627.0751, Val Loss: 1288.7617


Epoch 37/100: 389it [00:27, 14.04it/s]


Epoch 37/100, Train Loss: 618.3053, Val Loss: 1290.9045


Epoch 38/100: 389it [00:27, 14.14it/s]


Epoch 38/100, Train Loss: 606.8914, Val Loss: 2475.4406


Epoch 39/100: 389it [00:26, 14.43it/s]


Epoch 39/100, Train Loss: 591.4615, Val Loss: 1477.5438


Epoch 40/100: 389it [00:27, 14.17it/s]


Epoch 40/100, Train Loss: 595.9649, Val Loss: 1261.5203


Epoch 41/100: 389it [00:27, 14.36it/s]


Epoch 41/100, Train Loss: 604.1081, Val Loss: 2152.1472


Epoch 42/100: 389it [00:27, 14.05it/s]


Epoch 42/100, Train Loss: 605.1445, Val Loss: 1572.7913


Epoch 43/100: 389it [00:27, 14.40it/s]


Epoch 43/100, Train Loss: 601.9341, Val Loss: 2292.0155


Epoch 44/100: 389it [00:26, 14.47it/s]


Epoch 44/100, Train Loss: 615.8721, Val Loss: 5358.6032


Epoch 45/100: 389it [00:27, 14.36it/s]


Epoch 45/100, Train Loss: 589.8831, Val Loss: 1541.6962


Epoch 46/100: 389it [00:27, 14.37it/s]


Epoch 46/100, Train Loss: 600.2572, Val Loss: 1147.7255


Epoch 47/100: 389it [00:27, 14.33it/s]


Epoch 47/100, Train Loss: 588.6833, Val Loss: 1490.5684


Epoch 48/100: 389it [00:27, 14.34it/s]


Epoch 48/100, Train Loss: 575.7857, Val Loss: 1686.2801


Epoch 49/100: 389it [00:26, 14.44it/s]


Epoch 49/100, Train Loss: 597.1863, Val Loss: 1348.6681


Epoch 50/100: 389it [00:26, 14.50it/s]


Epoch 50/100, Train Loss: 605.1022, Val Loss: 1448.3142


Epoch 51/100: 389it [00:26, 14.50it/s]


Epoch 51/100, Train Loss: 585.9429, Val Loss: 2085.9542


Epoch 52/100: 389it [00:27, 14.34it/s]


Epoch 52/100, Train Loss: 604.7660, Val Loss: 1474.3298


Epoch 53/100: 389it [00:26, 14.56it/s]


Epoch 53/100, Train Loss: 586.4891, Val Loss: 1941.5677


Epoch 54/100: 389it [00:26, 14.55it/s]


Epoch 54/100, Train Loss: 579.8770, Val Loss: 2018.7620


Epoch 55/100: 389it [00:26, 14.56it/s]


Epoch 55/100, Train Loss: 557.7706, Val Loss: 1738.1370


Epoch 56/100: 389it [00:26, 14.42it/s]


Epoch 56/100, Train Loss: 577.1523, Val Loss: 1222.5580


Epoch 57/100: 389it [00:27, 14.36it/s]


Epoch 57/100, Train Loss: 574.9994, Val Loss: 1438.2894


Epoch 58/100: 389it [00:26, 14.41it/s]


Epoch 58/100, Train Loss: 575.7817, Val Loss: 1129.5837


Epoch 59/100: 389it [00:26, 14.52it/s]


Epoch 59/100, Train Loss: 585.7219, Val Loss: 1818.4430


Epoch 60/100: 389it [00:26, 14.53it/s]


Epoch 60/100, Train Loss: 593.9073, Val Loss: 1711.4903


Epoch 61/100: 389it [00:27, 14.33it/s]


Epoch 61/100, Train Loss: 577.3638, Val Loss: 1662.4783


Epoch 62/100: 389it [00:26, 14.54it/s]


Epoch 62/100, Train Loss: 571.9239, Val Loss: 1853.6292


Epoch 63/100: 389it [00:27, 14.32it/s]


Epoch 63/100, Train Loss: 578.6362, Val Loss: 1902.7388


Epoch 64/100: 389it [00:26, 14.44it/s]


Epoch 64/100, Train Loss: 560.7561, Val Loss: 1528.9698


Epoch 65/100: 389it [00:27, 14.40it/s]


Epoch 65/100, Train Loss: 586.5293, Val Loss: 1506.0918


Epoch 66/100: 389it [00:26, 14.49it/s]


Epoch 66/100, Train Loss: 563.8260, Val Loss: 1920.9778


Epoch 67/100: 389it [00:26, 14.51it/s]


Epoch 67/100, Train Loss: 574.8922, Val Loss: 1926.3301


Epoch 68/100: 389it [00:26, 14.55it/s]


Epoch 68/100, Train Loss: 571.0931, Val Loss: 1808.2005


Epoch 69/100: 389it [00:26, 14.53it/s]


Epoch 69/100, Train Loss: 579.1359, Val Loss: 1901.5874


Epoch 70/100: 389it [00:26, 14.49it/s]


Epoch 70/100, Train Loss: 549.5092, Val Loss: 1550.5828


Epoch 71/100: 389it [00:26, 14.52it/s]


Epoch 71/100, Train Loss: 544.0588, Val Loss: 1442.8448


Epoch 72/100: 389it [00:26, 14.54it/s]


Epoch 72/100, Train Loss: 561.8024, Val Loss: 1587.0143


Epoch 73/100: 389it [00:26, 14.53it/s]


Epoch 73/100, Train Loss: 566.5095, Val Loss: 2148.6393


Epoch 74/100: 389it [00:26, 14.55it/s]


Epoch 74/100, Train Loss: 544.5480, Val Loss: 1658.9050


Epoch 75/100: 389it [00:26, 14.58it/s]


Epoch 75/100, Train Loss: 524.5881, Val Loss: 1620.4348


Epoch 76/100: 389it [00:26, 14.55it/s]


Epoch 76/100, Train Loss: 567.0125, Val Loss: 1539.6374


Epoch 77/100: 389it [00:26, 14.48it/s]


Epoch 77/100, Train Loss: 563.1555, Val Loss: 1920.9307


Epoch 78/100: 389it [00:26, 14.56it/s]


Epoch 78/100, Train Loss: 581.4803, Val Loss: 2009.8213


Epoch 79/100: 389it [00:26, 14.51it/s]


Epoch 79/100, Train Loss: 570.4732, Val Loss: 1748.1295


Epoch 80/100: 389it [00:26, 14.42it/s]


Epoch 80/100, Train Loss: 558.7463, Val Loss: 1457.5356


Epoch 81/100: 389it [00:27, 14.40it/s]


Epoch 81/100, Train Loss: 582.6484, Val Loss: 1544.5340


Epoch 82/100: 389it [00:26, 14.56it/s]


Epoch 82/100, Train Loss: 571.2943, Val Loss: 1415.8226


Epoch 83/100: 389it [00:26, 14.57it/s]


Epoch 83/100, Train Loss: 550.6289, Val Loss: 1612.6089


Epoch 84/100: 389it [00:26, 14.57it/s]


Epoch 84/100, Train Loss: 566.0508, Val Loss: 1782.4913


Epoch 85/100: 389it [00:26, 14.52it/s]


Epoch 85/100, Train Loss: 547.2406, Val Loss: 1593.8333


Epoch 86/100: 389it [00:26, 14.53it/s]


Epoch 86/100, Train Loss: 543.2029, Val Loss: 1658.2790


Epoch 87/100: 389it [00:26, 14.54it/s]


Epoch 87/100, Train Loss: 533.4554, Val Loss: 7634.0481


Epoch 88/100: 389it [00:26, 14.55it/s]


Epoch 88/100, Train Loss: 540.3594, Val Loss: 1476.3161


Epoch 89/100: 389it [00:26, 14.58it/s]


Epoch 89/100, Train Loss: 528.1698, Val Loss: 1384.3335


Epoch 90/100: 389it [00:26, 14.56it/s]


Epoch 90/100, Train Loss: 548.5618, Val Loss: 1684.7858


Epoch 91/100: 389it [00:26, 14.56it/s]


Epoch 91/100, Train Loss: 530.2726, Val Loss: 1533.6253


Epoch 92/100: 389it [00:26, 14.58it/s]


Epoch 92/100, Train Loss: 534.8269, Val Loss: 1489.8889


Epoch 93/100: 389it [00:26, 14.56it/s]


Epoch 93/100, Train Loss: 526.8164, Val Loss: 1753.8163


Epoch 94/100: 389it [00:26, 14.60it/s]


Epoch 94/100, Train Loss: 544.8527, Val Loss: 1824.0242


Epoch 95/100: 389it [00:26, 14.57it/s]


Epoch 95/100, Train Loss: 548.3154, Val Loss: 1351.6792


Epoch 96/100: 389it [00:26, 14.57it/s]


Epoch 96/100, Train Loss: 543.2210, Val Loss: 1383.1312


Epoch 97/100: 389it [00:26, 14.62it/s]


Epoch 97/100, Train Loss: 526.7904, Val Loss: 1081.1803


Epoch 98/100: 389it [00:26, 14.63it/s]


Epoch 98/100, Train Loss: 533.0911, Val Loss: 1563.7162


Epoch 99/100: 389it [00:26, 14.60it/s]


Epoch 99/100, Train Loss: 530.5924, Val Loss: 1562.1012


Epoch 100/100: 389it [00:26, 14.58it/s]


Epoch 100/100, Train Loss: 536.9096, Val Loss: 1355.8591


Epoch 1/100: 389it [00:27, 13.94it/s]


Epoch 1/100, Train Loss: 22214.0497, Val Loss: 19170.4443


Epoch 2/100: 389it [00:27, 13.98it/s]


Epoch 2/100, Train Loss: 12595.6488, Val Loss: 11369.6488


Epoch 3/100: 389it [00:27, 13.95it/s]


Epoch 3/100, Train Loss: 6574.2207, Val Loss: 6057.8102


Epoch 4/100: 389it [00:27, 13.96it/s]


Epoch 4/100, Train Loss: 3545.0216, Val Loss: 2624.3091


Epoch 5/100: 389it [00:27, 13.95it/s]


Epoch 5/100, Train Loss: 1375.6017, Val Loss: 2052.4980


Epoch 6/100: 389it [00:27, 13.94it/s]


Epoch 6/100, Train Loss: 1131.4497, Val Loss: 2503.2546


Epoch 7/100: 389it [00:27, 13.95it/s]


Epoch 7/100, Train Loss: 1051.7371, Val Loss: 1724.4507


Epoch 8/100: 389it [00:27, 13.97it/s]


Epoch 8/100, Train Loss: 960.2707, Val Loss: 1700.4659


Epoch 9/100: 389it [00:27, 13.93it/s]


Epoch 9/100, Train Loss: 900.2959, Val Loss: 1608.4951


Epoch 10/100: 389it [00:27, 13.92it/s]


Epoch 10/100, Train Loss: 867.3553, Val Loss: 1639.4283


Epoch 11/100: 389it [00:27, 13.93it/s]


Epoch 11/100, Train Loss: 864.1911, Val Loss: 1871.1856


Epoch 12/100: 389it [00:28, 13.88it/s]


Epoch 12/100, Train Loss: 784.2373, Val Loss: 2127.0939


Epoch 13/100: 389it [00:27, 13.94it/s]


Epoch 13/100, Train Loss: 774.8922, Val Loss: 1340.2130


Epoch 14/100: 389it [00:27, 13.96it/s]


Epoch 14/100, Train Loss: 754.5343, Val Loss: 1268.8251


Epoch 15/100: 389it [00:27, 13.96it/s]


Epoch 15/100, Train Loss: 719.2764, Val Loss: 1792.8247


Epoch 16/100: 389it [00:27, 13.94it/s]


Epoch 16/100, Train Loss: 725.8767, Val Loss: 1987.7495


Epoch 17/100: 389it [00:27, 13.93it/s]


Epoch 17/100, Train Loss: 705.8440, Val Loss: 2168.1707


Epoch 18/100: 389it [00:27, 13.94it/s]


Epoch 18/100, Train Loss: 663.5421, Val Loss: 1804.4611


Epoch 19/100: 389it [00:27, 13.91it/s]


Epoch 19/100, Train Loss: 640.8710, Val Loss: 1970.8673


Epoch 20/100: 389it [00:27, 13.89it/s]


Epoch 20/100, Train Loss: 672.9450, Val Loss: 1810.0839


Epoch 21/100: 389it [00:27, 13.95it/s]


Epoch 21/100, Train Loss: 629.5055, Val Loss: 2144.8963


Epoch 22/100: 389it [00:27, 13.92it/s]


Epoch 22/100, Train Loss: 639.0085, Val Loss: 1561.5239


Epoch 23/100: 389it [00:27, 13.94it/s]


Epoch 23/100, Train Loss: 645.7930, Val Loss: 1683.1398


Epoch 24/100: 389it [00:28, 13.87it/s]


Epoch 24/100, Train Loss: 628.3503, Val Loss: 2108.0304


Epoch 25/100: 389it [00:27, 13.94it/s]


Epoch 25/100, Train Loss: 643.8085, Val Loss: 1463.7727


Epoch 26/100: 389it [00:27, 13.93it/s]


Epoch 26/100, Train Loss: 638.6863, Val Loss: 1663.6987


Epoch 27/100: 389it [00:27, 13.94it/s]


Epoch 27/100, Train Loss: 613.8342, Val Loss: 1747.1584


Epoch 28/100: 389it [00:27, 13.95it/s]


Epoch 28/100, Train Loss: 603.1616, Val Loss: 2186.1745


Epoch 29/100: 389it [00:27, 13.93it/s]


Epoch 29/100, Train Loss: 604.5143, Val Loss: 1577.4444


Epoch 30/100: 389it [00:27, 13.94it/s]


Epoch 30/100, Train Loss: 599.2079, Val Loss: 1670.8506


Epoch 31/100: 389it [00:28, 13.89it/s]


Epoch 31/100, Train Loss: 603.1022, Val Loss: 1851.0494


Epoch 32/100: 389it [00:27, 13.96it/s]


Epoch 32/100, Train Loss: 599.3508, Val Loss: 1897.5375


Epoch 33/100: 389it [00:27, 13.94it/s]


Epoch 33/100, Train Loss: 602.6750, Val Loss: 1358.6437


Epoch 34/100: 389it [00:27, 13.90it/s]


Epoch 34/100, Train Loss: 608.9040, Val Loss: 1375.5692


Epoch 35/100: 389it [00:27, 13.92it/s]


Epoch 35/100, Train Loss: 582.4771, Val Loss: 1916.7324


Epoch 36/100: 389it [00:27, 13.99it/s]


Epoch 36/100, Train Loss: 573.2647, Val Loss: 1639.6528


Epoch 37/100: 389it [00:27, 13.95it/s]


Epoch 37/100, Train Loss: 587.8350, Val Loss: 1685.7342


Epoch 38/100: 389it [00:27, 13.96it/s]


Epoch 38/100, Train Loss: 583.3537, Val Loss: 1589.8139


Epoch 39/100: 389it [00:27, 13.97it/s]


Epoch 39/100, Train Loss: 591.3018, Val Loss: 1657.1629


Epoch 40/100: 389it [00:27, 13.97it/s]


Epoch 40/100, Train Loss: 609.6783, Val Loss: 1793.2241


Epoch 41/100: 389it [00:27, 13.90it/s]


Epoch 41/100, Train Loss: 585.3062, Val Loss: 1933.6890


Epoch 42/100: 389it [00:27, 13.92it/s]


Epoch 42/100, Train Loss: 585.8045, Val Loss: 1733.2995


Epoch 43/100: 389it [00:27, 13.98it/s]


Epoch 43/100, Train Loss: 571.5180, Val Loss: 1638.1542


Epoch 44/100: 389it [00:27, 13.97it/s]


Epoch 44/100, Train Loss: 569.9690, Val Loss: 1795.2130


Epoch 45/100: 389it [00:27, 13.96it/s]


Epoch 45/100, Train Loss: 577.4781, Val Loss: 1616.6449


Epoch 46/100: 389it [00:27, 13.97it/s]


Epoch 46/100, Train Loss: 573.8686, Val Loss: 1428.8230


Epoch 47/100: 389it [00:27, 13.96it/s]


Epoch 47/100, Train Loss: 562.1129, Val Loss: 1594.4784


Epoch 48/100: 389it [00:27, 13.97it/s]


Epoch 48/100, Train Loss: 578.2182, Val Loss: 1740.8188


Epoch 49/100: 389it [00:27, 13.96it/s]


Epoch 49/100, Train Loss: 590.6144, Val Loss: 1415.9107


Epoch 50/100: 389it [00:27, 13.96it/s]


Epoch 50/100, Train Loss: 578.5525, Val Loss: 1950.4772


Epoch 51/100: 389it [00:27, 13.96it/s]


Epoch 51/100, Train Loss: 582.6192, Val Loss: 1548.8680


Epoch 52/100: 389it [00:27, 13.94it/s]


Epoch 52/100, Train Loss: 555.7143, Val Loss: 2412.0464


Epoch 53/100: 389it [00:27, 13.97it/s]


Epoch 53/100, Train Loss: 558.3905, Val Loss: 1561.8998


Epoch 54/100: 389it [00:29, 13.18it/s]


Epoch 54/100, Train Loss: 550.3657, Val Loss: 1791.7873


Epoch 55/100: 389it [00:28, 13.51it/s]


Epoch 55/100, Train Loss: 565.6361, Val Loss: 1471.0964


Epoch 56/100: 389it [00:28, 13.86it/s]


Epoch 56/100, Train Loss: 570.5646, Val Loss: 2019.2723


Epoch 57/100: 389it [00:28, 13.58it/s]


Epoch 57/100, Train Loss: 557.8283, Val Loss: 1634.0014


Epoch 58/100: 389it [00:28, 13.58it/s]


Epoch 58/100, Train Loss: 551.1637, Val Loss: 1685.8753


Epoch 59/100: 389it [00:28, 13.75it/s]


Epoch 59/100, Train Loss: 557.2277, Val Loss: 1534.4593


Epoch 60/100: 389it [00:28, 13.73it/s]


Epoch 60/100, Train Loss: 569.6729, Val Loss: 1670.3507


Epoch 61/100: 389it [00:28, 13.63it/s]


Epoch 61/100, Train Loss: 566.0147, Val Loss: 1864.1036


Epoch 62/100: 389it [00:28, 13.77it/s]


Epoch 62/100, Train Loss: 561.6587, Val Loss: 1932.0916


Epoch 63/100: 389it [00:28, 13.78it/s]


Epoch 63/100, Train Loss: 528.1449, Val Loss: 1889.3208


Epoch 64/100: 389it [00:29, 13.06it/s]


Epoch 64/100, Train Loss: 558.9976, Val Loss: 1672.8765


Epoch 65/100: 389it [00:29, 13.19it/s]


Epoch 65/100, Train Loss: 563.7995, Val Loss: 1926.9619


Epoch 66/100: 389it [00:28, 13.63it/s]


Epoch 66/100, Train Loss: 549.7179, Val Loss: 1648.2764


Epoch 67/100: 389it [00:28, 13.60it/s]


Epoch 67/100, Train Loss: 563.5740, Val Loss: 1735.6748


Epoch 68/100: 389it [00:28, 13.75it/s]


Epoch 68/100, Train Loss: 544.5132, Val Loss: 2509.0883


Epoch 69/100: 389it [00:28, 13.72it/s]


Epoch 69/100, Train Loss: 548.4745, Val Loss: 1780.3579


Epoch 70/100: 389it [00:28, 13.79it/s]


Epoch 70/100, Train Loss: 543.9964, Val Loss: 1856.1322


Epoch 71/100: 389it [00:28, 13.74it/s]


Epoch 71/100, Train Loss: 544.2484, Val Loss: 1746.5534


Epoch 72/100: 389it [00:28, 13.68it/s]


Epoch 72/100, Train Loss: 530.3894, Val Loss: 1957.5132


Epoch 73/100: 389it [00:28, 13.65it/s]


Epoch 73/100, Train Loss: 532.5720, Val Loss: 1955.6845


Epoch 74/100: 389it [00:28, 13.70it/s]


Epoch 74/100, Train Loss: 554.5449, Val Loss: 1642.9692


Epoch 75/100: 389it [00:28, 13.72it/s]


Epoch 75/100, Train Loss: 565.4843, Val Loss: 1689.6494


Epoch 76/100: 389it [00:28, 13.69it/s]


Epoch 76/100, Train Loss: 556.1559, Val Loss: 1573.4775


Epoch 77/100: 389it [00:28, 13.78it/s]


Epoch 77/100, Train Loss: 549.5782, Val Loss: 2241.9902


Epoch 78/100: 389it [00:28, 13.72it/s]


Epoch 78/100, Train Loss: 537.7115, Val Loss: 1480.7205


Epoch 79/100: 389it [00:28, 13.70it/s]


Epoch 79/100, Train Loss: 534.1136, Val Loss: 1628.1453


Epoch 80/100: 389it [00:28, 13.71it/s]


Epoch 80/100, Train Loss: 548.7600, Val Loss: 1754.2058


Epoch 81/100: 389it [00:28, 13.69it/s]


Epoch 81/100, Train Loss: 531.7751, Val Loss: 2021.6150


Epoch 82/100: 389it [00:28, 13.67it/s]


Epoch 82/100, Train Loss: 540.0659, Val Loss: 1608.7710


Epoch 83/100: 389it [00:29, 13.40it/s]


Epoch 83/100, Train Loss: 550.0294, Val Loss: 2007.0086


Epoch 84/100: 389it [00:29, 13.23it/s]


Epoch 84/100, Train Loss: 545.0355, Val Loss: 1952.5428


Epoch 85/100: 389it [00:28, 13.64it/s]


Epoch 85/100, Train Loss: 560.9363, Val Loss: 2145.9477


Epoch 86/100: 389it [00:28, 13.70it/s]


Epoch 86/100, Train Loss: 563.9988, Val Loss: 1854.1805


Epoch 87/100: 389it [00:28, 13.69it/s]


Epoch 87/100, Train Loss: 543.9817, Val Loss: 1625.5910


Epoch 88/100: 389it [00:28, 13.74it/s]


Epoch 88/100, Train Loss: 531.6940, Val Loss: 1814.6107


Epoch 89/100: 389it [00:28, 13.72it/s]


Epoch 89/100, Train Loss: 544.5047, Val Loss: 1730.3378


Epoch 90/100: 389it [00:28, 13.75it/s]


Epoch 90/100, Train Loss: 532.6824, Val Loss: 1820.3743


Epoch 91/100: 389it [00:29, 13.32it/s]


Epoch 91/100, Train Loss: 550.1690, Val Loss: 1857.4788


Epoch 92/100: 389it [00:28, 13.70it/s]


Epoch 92/100, Train Loss: 552.3021, Val Loss: 2764.0387


Epoch 93/100: 389it [00:28, 13.66it/s]


Epoch 93/100, Train Loss: 527.2795, Val Loss: 3433.3760


Epoch 94/100: 389it [00:28, 13.70it/s]


Epoch 94/100, Train Loss: 550.1475, Val Loss: 7741.7419


Epoch 95/100: 389it [00:28, 13.67it/s]


Epoch 95/100, Train Loss: 522.8235, Val Loss: 1702.0680


Epoch 96/100: 389it [00:28, 13.72it/s]


Epoch 96/100, Train Loss: 536.7910, Val Loss: 1617.9761


Epoch 97/100: 389it [00:28, 13.74it/s]


Epoch 97/100, Train Loss: 529.2581, Val Loss: 1865.1708


Epoch 98/100: 389it [00:28, 13.71it/s]


Epoch 98/100, Train Loss: 543.1876, Val Loss: 1792.0986


Epoch 99/100: 389it [00:28, 13.68it/s]


Epoch 99/100, Train Loss: 534.4144, Val Loss: 1858.6473


Epoch 100/100: 389it [00:28, 13.67it/s]


Epoch 100/100, Train Loss: 536.9317, Val Loss: 2303.1295


Epoch 1/100: 389it [01:13,  5.28it/s]


Epoch 1/100, Train Loss: 22436.3165, Val Loss: 20905.4539


Epoch 2/100: 389it [01:13,  5.30it/s]


Epoch 2/100, Train Loss: 12326.1328, Val Loss: 12317.9995


Epoch 3/100: 389it [01:13,  5.29it/s]


Epoch 3/100, Train Loss: 5628.4702, Val Loss: 9209.0871


Epoch 4/100: 389it [01:13,  5.31it/s]


Epoch 4/100, Train Loss: 3404.4338, Val Loss: 4667.5888


Epoch 5/100: 389it [01:13,  5.31it/s]


Epoch 5/100, Train Loss: 2018.1125, Val Loss: 2470.1429


Epoch 6/100: 389it [01:13,  5.30it/s]


Epoch 6/100, Train Loss: 1152.9957, Val Loss: 1580.7331


Epoch 7/100: 389it [01:13,  5.30it/s]


Epoch 7/100, Train Loss: 1074.9693, Val Loss: 1463.3334


Epoch 8/100: 389it [01:13,  5.30it/s]


Epoch 8/100, Train Loss: 1018.3887, Val Loss: 1631.5903


Epoch 9/100: 389it [01:13,  5.32it/s]


Epoch 9/100, Train Loss: 952.2512, Val Loss: 1299.9728


Epoch 10/100: 389it [01:13,  5.32it/s]


Epoch 10/100, Train Loss: 864.9178, Val Loss: 1359.7520


Epoch 11/100: 389it [01:13,  5.32it/s]


Epoch 11/100, Train Loss: 809.7806, Val Loss: 1407.9068


Epoch 12/100: 389it [01:13,  5.32it/s]


Epoch 12/100, Train Loss: 798.0794, Val Loss: 1515.1842


Epoch 13/100: 389it [01:13,  5.32it/s]


Epoch 13/100, Train Loss: 781.2624, Val Loss: 1280.1426


Epoch 14/100: 389it [01:13,  5.31it/s]


Epoch 14/100, Train Loss: 769.0568, Val Loss: 1536.2271


Epoch 15/100: 389it [01:13,  5.31it/s]


Epoch 15/100, Train Loss: 729.1313, Val Loss: 1267.5003


Epoch 16/100: 389it [01:13,  5.31it/s]


Epoch 16/100, Train Loss: 704.5300, Val Loss: 1393.0404


Epoch 17/100: 389it [01:13,  5.31it/s]


Epoch 17/100, Train Loss: 673.3749, Val Loss: 1444.9634


Epoch 18/100: 389it [01:13,  5.31it/s]


Epoch 18/100, Train Loss: 696.9049, Val Loss: 1885.6904


Epoch 19/100: 389it [01:13,  5.31it/s]


Epoch 19/100, Train Loss: 665.4681, Val Loss: 1223.0499


Epoch 20/100: 389it [01:13,  5.31it/s]


Epoch 20/100, Train Loss: 652.4031, Val Loss: 1397.9669


Epoch 21/100: 389it [01:13,  5.31it/s]


Epoch 21/100, Train Loss: 660.1125, Val Loss: 1553.4220


Epoch 22/100: 389it [01:13,  5.31it/s]


Epoch 22/100, Train Loss: 636.6872, Val Loss: 1407.3722


Epoch 23/100: 389it [01:13,  5.27it/s]


Epoch 23/100, Train Loss: 668.1152, Val Loss: 1466.0079


Epoch 24/100: 389it [01:13,  5.32it/s]


Epoch 24/100, Train Loss: 661.2993, Val Loss: 1112.7969


Epoch 25/100: 389it [01:13,  5.31it/s]


Epoch 25/100, Train Loss: 657.4036, Val Loss: 1460.0005


Epoch 26/100: 389it [01:13,  5.32it/s]


Epoch 26/100, Train Loss: 656.6203, Val Loss: 984.2092


Epoch 27/100: 389it [01:13,  5.31it/s]


Epoch 27/100, Train Loss: 587.8544, Val Loss: 1437.2619


Epoch 28/100: 389it [01:13,  5.31it/s]


Epoch 28/100, Train Loss: 612.0836, Val Loss: 1027.8147


Epoch 29/100: 389it [01:13,  5.32it/s]


Epoch 29/100, Train Loss: 677.0674, Val Loss: 1240.0698


Epoch 30/100: 389it [01:13,  5.30it/s]


Epoch 30/100, Train Loss: 638.9311, Val Loss: 1099.6343


Epoch 31/100: 389it [01:13,  5.32it/s]


Epoch 31/100, Train Loss: 598.9884, Val Loss: 1581.8877


Epoch 32/100: 389it [01:13,  5.31it/s]


Epoch 32/100, Train Loss: 621.8000, Val Loss: 1056.6242


Epoch 33/100: 389it [01:13,  5.32it/s]


Epoch 33/100, Train Loss: 622.3464, Val Loss: 1312.7717


Epoch 34/100: 389it [01:13,  5.30it/s]


Epoch 34/100, Train Loss: 600.2293, Val Loss: 1114.4103


Epoch 35/100: 389it [01:13,  5.31it/s]


Epoch 35/100, Train Loss: 592.1024, Val Loss: 1840.0789


Epoch 36/100: 389it [01:13,  5.31it/s]


Epoch 36/100, Train Loss: 598.5839, Val Loss: 1277.0843


Epoch 37/100: 389it [01:13,  5.31it/s]


Epoch 37/100, Train Loss: 603.7528, Val Loss: 1290.4061


Epoch 38/100: 389it [01:13,  5.28it/s]


Epoch 38/100, Train Loss: 578.6723, Val Loss: 1463.8909


Epoch 39/100: 389it [01:13,  5.28it/s]


Epoch 39/100, Train Loss: 607.1392, Val Loss: 1231.9008


Epoch 40/100: 389it [01:14,  5.24it/s]


Epoch 40/100, Train Loss: 588.3203, Val Loss: 1473.7765


Epoch 41/100: 389it [01:13,  5.30it/s]


Epoch 41/100, Train Loss: 609.9049, Val Loss: 1362.4782


Epoch 42/100: 389it [01:13,  5.31it/s]


Epoch 42/100, Train Loss: 584.9002, Val Loss: 1277.5223


Epoch 43/100: 389it [01:13,  5.31it/s]


Epoch 43/100, Train Loss: 584.4267, Val Loss: 958.2344


Epoch 44/100: 389it [01:13,  5.31it/s]


Epoch 44/100, Train Loss: 552.4987, Val Loss: 2885.6936


Epoch 45/100: 389it [01:13,  5.31it/s]


Epoch 45/100, Train Loss: 585.5047, Val Loss: 1142.5097


Epoch 46/100: 389it [01:13,  5.28it/s]


Epoch 46/100, Train Loss: 556.4188, Val Loss: 1024.0857


Epoch 47/100: 389it [01:12,  5.34it/s]


Epoch 47/100, Train Loss: 563.0245, Val Loss: 964.1149


Epoch 48/100: 389it [01:13,  5.30it/s]


Epoch 48/100, Train Loss: 583.4912, Val Loss: 1173.8324


Epoch 49/100: 389it [01:12,  5.33it/s]


Epoch 49/100, Train Loss: 566.8544, Val Loss: 1217.2736


Epoch 50/100: 389it [01:12,  5.33it/s]


Epoch 50/100, Train Loss: 579.2775, Val Loss: 1052.3499


Epoch 51/100: 389it [01:12,  5.34it/s]


Epoch 51/100, Train Loss: 589.7445, Val Loss: 1142.8862


Epoch 52/100: 389it [01:12,  5.33it/s]


Epoch 52/100, Train Loss: 575.9584, Val Loss: 1064.0996


Epoch 53/100: 389it [01:12,  5.34it/s]


Epoch 53/100, Train Loss: 583.3946, Val Loss: 1231.2820


Epoch 54/100: 389it [01:12,  5.33it/s]


Epoch 54/100, Train Loss: 563.9050, Val Loss: 995.7006


Epoch 55/100: 389it [01:12,  5.34it/s]


Epoch 55/100, Train Loss: 564.3217, Val Loss: 1118.8287


Epoch 56/100: 389it [01:12,  5.34it/s]


Epoch 56/100, Train Loss: 553.4954, Val Loss: 1446.7388


Epoch 57/100: 389it [01:13,  5.30it/s]


Epoch 57/100, Train Loss: 587.0100, Val Loss: 927.3592


Epoch 58/100: 389it [01:13,  5.32it/s]


Epoch 58/100, Train Loss: 576.1794, Val Loss: 1311.7293


Epoch 59/100: 389it [01:13,  5.29it/s]


Epoch 59/100, Train Loss: 586.9733, Val Loss: 990.4093


Epoch 60/100: 389it [01:13,  5.32it/s]


Epoch 60/100, Train Loss: 591.8131, Val Loss: 1131.4340


Epoch 61/100: 389it [01:12,  5.33it/s]


Epoch 61/100, Train Loss: 571.8516, Val Loss: 1285.4714


Epoch 62/100: 389it [01:13,  5.32it/s]


Epoch 62/100, Train Loss: 579.7934, Val Loss: 1039.5831


Epoch 63/100: 389it [01:12,  5.33it/s]


Epoch 63/100, Train Loss: 565.9421, Val Loss: 962.5488


Epoch 64/100: 389it [01:13,  5.30it/s]


Epoch 64/100, Train Loss: 594.3033, Val Loss: 1069.6588


Epoch 65/100: 389it [01:12,  5.33it/s]


Epoch 65/100, Train Loss: 562.0073, Val Loss: 1162.2570


Epoch 66/100: 389it [01:13,  5.29it/s]


Epoch 66/100, Train Loss: 561.1277, Val Loss: 942.7045


Epoch 67/100: 389it [01:13,  5.32it/s]


Epoch 67/100, Train Loss: 548.3965, Val Loss: 1048.1463


Epoch 68/100: 389it [01:13,  5.31it/s]


Epoch 68/100, Train Loss: 542.6478, Val Loss: 1220.6378


Epoch 69/100: 389it [01:13,  5.33it/s]


Epoch 69/100, Train Loss: 558.1077, Val Loss: 1225.7781


Epoch 70/100: 389it [01:13,  5.32it/s]


Epoch 70/100, Train Loss: 531.0888, Val Loss: 973.4465


Epoch 71/100: 389it [01:13,  5.31it/s]


Epoch 71/100, Train Loss: 564.9857, Val Loss: 1142.8694


Epoch 72/100: 389it [01:13,  5.33it/s]


Epoch 72/100, Train Loss: 546.6414, Val Loss: 977.4989


Epoch 73/100: 389it [01:13,  5.30it/s]


Epoch 73/100, Train Loss: 560.7982, Val Loss: 1245.9258


Epoch 74/100: 389it [01:13,  5.32it/s]


Epoch 74/100, Train Loss: 553.9214, Val Loss: 1021.0859


Epoch 75/100: 389it [01:13,  5.30it/s]


Epoch 75/100, Train Loss: 530.3640, Val Loss: 1379.9379


Epoch 76/100: 389it [01:12,  5.33it/s]


Epoch 76/100, Train Loss: 532.7898, Val Loss: 1163.0775


Epoch 77/100: 389it [01:12,  5.33it/s]


Epoch 77/100, Train Loss: 534.6583, Val Loss: 1336.8908


Epoch 78/100: 389it [01:13,  5.32it/s]


Epoch 78/100, Train Loss: 543.2972, Val Loss: 1096.4052


Epoch 79/100: 389it [01:13,  5.32it/s]


Epoch 79/100, Train Loss: 528.9580, Val Loss: 1140.1887


Epoch 80/100: 389it [01:13,  5.32it/s]


Epoch 80/100, Train Loss: 572.9239, Val Loss: 1097.7813


Epoch 81/100: 389it [01:13,  5.32it/s]


Epoch 81/100, Train Loss: 543.5936, Val Loss: 1195.9882


Epoch 82/100: 389it [01:13,  5.32it/s]


Epoch 82/100, Train Loss: 525.1268, Val Loss: 1251.4709


Epoch 83/100: 389it [01:13,  5.33it/s]


Epoch 83/100, Train Loss: 545.6777, Val Loss: 1239.0767


Epoch 84/100: 389it [01:13,  5.33it/s]


Epoch 84/100, Train Loss: 537.1776, Val Loss: 1296.2116


Epoch 85/100: 389it [01:13,  5.33it/s]


Epoch 85/100, Train Loss: 518.7512, Val Loss: 1139.1562


Epoch 86/100: 389it [01:13,  5.33it/s]


Epoch 86/100, Train Loss: 543.3416, Val Loss: 1123.4627


Epoch 87/100: 389it [01:13,  5.32it/s]


Epoch 87/100, Train Loss: 534.4307, Val Loss: 1455.2214


Epoch 88/100: 389it [01:12,  5.34it/s]


Epoch 88/100, Train Loss: 526.8702, Val Loss: 1021.9362


Epoch 89/100: 389it [01:13,  5.30it/s]


Epoch 89/100, Train Loss: 526.0123, Val Loss: 1170.8334


Epoch 90/100: 389it [01:12,  5.34it/s]


Epoch 90/100, Train Loss: 526.8944, Val Loss: 1271.8185


Epoch 91/100: 389it [01:13,  5.27it/s]


Epoch 91/100, Train Loss: 526.3404, Val Loss: 1044.9947


Epoch 92/100: 389it [01:13,  5.31it/s]


Epoch 92/100, Train Loss: 537.6421, Val Loss: 1095.0832


Epoch 93/100: 389it [01:13,  5.26it/s]


Epoch 93/100, Train Loss: 515.0545, Val Loss: 1428.6600


Epoch 94/100: 389it [01:13,  5.32it/s]


Epoch 94/100, Train Loss: 553.7015, Val Loss: 1127.6092


Epoch 95/100: 389it [01:12,  5.33it/s]


Epoch 95/100, Train Loss: 528.7651, Val Loss: 1268.9354


Epoch 96/100: 389it [01:12,  5.35it/s]


Epoch 96/100, Train Loss: 544.5153, Val Loss: 1122.5687


Epoch 97/100: 389it [01:12,  5.34it/s]


Epoch 97/100, Train Loss: 536.2306, Val Loss: 980.4212


Epoch 98/100: 389it [01:13,  5.31it/s]


Epoch 98/100, Train Loss: 528.3994, Val Loss: 1050.8370


Epoch 99/100: 389it [01:13,  5.32it/s]


Epoch 99/100, Train Loss: 550.5661, Val Loss: 1107.5980


Epoch 100/100: 389it [01:12,  5.33it/s]


Epoch 100/100, Train Loss: 531.9821, Val Loss: 1009.0650


In [None]:
trained_model = torch.load('./models/CNN_LSTM_best.pt')

test_gen = generator(list(yields['validation'].keys()), yields['validation'], len(yields['validation']))
X_test, y_test = next(test_gen)

In [9]:
lst = list(yields['train'].keys())
data_years = {
    '2018': [],
    '2019': [],
    '2020': [],
    '2021': [],
    '2022': [],
    '2023': [],
}

for data in lst:
    year = data.split('_')[0]
    data_years[year].append(data)

lst = list(yields['validation'].keys())

for data in lst:
    year = data.split('_')[0]
    data_years[year].append(data)

In [16]:
models = ['CNN_LSTM_Adam_lr_0.001', 'CNN_LSTM_Adam_lr_0.005', 'ConvGRU_LSTM_Adam_lr_0.001', 'ConvGRU_LSTM_Adam_lr_0.005']

all_data = {}

criterion = nn.MSELoss()

for model_name in models:
    if 'CNN_LSTM' in model_name:
        model = CNN_LSTM(dimensions=[38, 1, 128, 9])
    else:
        model = ConvGRU_LSTM(dimensions=[38, 1, 128, 9])
    model.load_state_dict(torch.load(f'./models/{model_name}_best.pt'))
    model.to('cuda')
    model.eval()

    for year in ['2018', '2019', '2020', '2021']:
        losses = []
        outputs = []

        if year == '2021':
            year_gen = generator(data_years[year], yields['validation'], 16)
        else:
            year_gen = generator(data_years[year], yields['train'], 16)

        with torch.no_grad():
            for data, yield_data in year_gen:
                if data is None:
                    break
                output = model(data)
                loss = criterion(output, yield_data.unsqueeze(1))
                outputs.append(output.cpu().numpy())
                losses.append(loss.item())

        all_data[f'{model_name}_{year}'] = (np.mean(losses), np.sum(outputs))

all_data

{'CNN_LSTM_Adam_lr_0.001_2018': (295.3537059474636, 184468.12),
 'CNN_LSTM_Adam_lr_0.001_2019': (243.38821160303402, 173044.56),
 'CNN_LSTM_Adam_lr_0.001_2020': (715.7331414694314, 220489.1),
 'CNN_LSTM_Adam_lr_0.001_2021': (1304.5928936004639, 187200.78),
 'CNN_LSTM_Adam_lr_0.005_2018': (557.2617604023701, 188240.3),
 'CNN_LSTM_Adam_lr_0.005_2019': (734.1392391675139, 178598.97),
 'CNN_LSTM_Adam_lr_0.005_2020': (452.00859891451324, 228300.23),
 'CNN_LSTM_Adam_lr_0.005_2021': (1417.686783027649, 198523.88),
 'ConvGRU_LSTM_Adam_lr_0.001_2018': (397.7488078555545, 191843.06),
 'ConvGRU_LSTM_Adam_lr_0.001_2019': (397.0032612003692, 182157.72),
 'ConvGRU_LSTM_Adam_lr_0.001_2020': (351.8298935104202, 230499.19),
 'ConvGRU_LSTM_Adam_lr_0.001_2021': (1193.4569580078125, 195139.4),
 'ConvGRU_LSTM_Adam_lr_0.005_2018': (530.0325443164722, 187800.28),
 'ConvGRU_LSTM_Adam_lr_0.005_2019': (577.3304587586285, 175580.12),
 'ConvGRU_LSTM_Adam_lr_0.005_2020': (410.55387308309366, 225286.42),
 'ConvGRU_