<a href="https://colab.research.google.com/github/frtrigg5/A-new-signature-model/blob/main/ModelConstruction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# FBM Experiment

In [1]:
import sys
import os
sys.path.append('/rds/general/user/ll1917/home/esig/gp-esig-classifier') # to add when running on remote Jupyter server
os.chdir('/rds/general/user/ll1917/home/esig/gp-esig-classifier') # to add when running on remote Jupyter server

import torch
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
begin, end, number, division, dim = 0, 1, 100, 1, 1
'''
begin = first time steps
end = last known time steps
division = 1 means we are taking the middle points as new time instants -> L2 = L1 - 1
number = L1
dim = 1 - one dimensional
'''

# generating the time steps
known_times = torch.linspace(begin, end, number)
new_times = torch.zeros(division*(number-1))
for i in range(0,(number-1)):
  new_times[(division*i):(division*(i+1))] = torch.linspace(known_times[i], known_times[i+1], (division+2))[1:(1 + division)]

# Length of known values and new values
L1 = known_times.shape[0]
L2 = new_times.shape[0]

timesteps = torch.cat((known_times, new_times),axis=0)
timesteps_sorted, order = torch.sort(timesteps)

extended_order = torch.zeros(dim*order.size(0))
for i in range(order.size(0)):
  extended_order[(i*dim):((i+1)*dim)] = torch.arange(order[i]*dim, (order[i] + 1)*dim)

In [16]:
from lib.data.synthetic import BM_sample, FBM_sample

# construction of the dataset
trShape, vlShape, testShape = 1000, 400, 600
H = 0.26 # Hurst exponent

known_times = np.linspace(begin, end, number) # istanti temporali noti
div = 1 # quanti nuovi istanti temporali prendere tra due istanti noti
new_times = np.zeros(div*(number - 1))
for i in range((number-1)):
  new_times[(div*i):(div*(i+1))] = np.linspace(known_times[i], known_times[i+1], (div+2))[1:(1+div)]

L1 = known_times.size
L2 = new_times.size

timestamps = np.concatenate((known_times, new_times), axis=0)
# time series train
dataset_value = np.zeros(shape=[trShape, number])
seed = 0
for i in range(trShape//2):
  dataset_value[i] = BM_sample(begin, end, number, seed=seed+i)[0]
  dataset_value[i+trShape//2] = FBM_sample(begin, end, number, H, seed=seed+i)[0]

# time series validation
dataset_value2 = np.zeros(shape=[vlShape, number])
seed = trShape//2
for i in range(vlShape//2):
  dataset_value2[i] = BM_sample(begin, end, number, seed=seed+i)[0]
  dataset_value2[i+vlShape//2] = FBM_sample(begin, end, number, H, seed=seed+i)[0]

# time series test
dataset_value3 = np.zeros(shape=[testShape, number])
seed = trShape//2 + vlShape//2
for i in range(testShape//2):
  dataset_value3[i] = BM_sample(begin, end, number, seed=seed+i)[0]
  dataset_value3[i+testShape//2] = FBM_sample(begin, end, number, H, seed=seed+i)[0]
 
# adding known and unknown time stamps
time_data = np.zeros((trShape, L1 + L2))
for i in range(trShape):
  time_data[i] = timestamps

time_data2 = np.zeros((vlShape, L1 + L2))
for i in range(vlShape):
  time_data2[i] = timestamps

time_data3 = np.zeros((testShape, L1 + L2))
for i in range(testShape):
  time_data3[i] = timestamps  

# full dataset train
dataset = np.concatenate((time_data, dataset_value), axis=-1) # full dataset
dataset = dataset.astype('float32')

# full dataset validation
dataset2 = np.concatenate((time_data2,dataset_value2), axis=-1) # full dataset
dataset2 = dataset2.astype('float32')

# full dataset test
dataset3 = np.concatenate((time_data3, dataset_value3), axis=-1) # full dataset
dataset3 = dataset3.astype('float32')

# label construction
y = np.zeros(trShape, dtype='uint8') # label 0 for Brownian Motion, 1 for FBM
y[trShape//2:] = 1

#label di validation
y2 = np.zeros(vlShape, dtype='uint8')
y2[vlShape//2:] = 1

#label di test
y3 = np.zeros(testShape, dtype='uint8') 
y3[testShape//2:] = 1

In [17]:
from torch.utils.data import DataLoader, TensorDataset

batch = 60
training_data = TensorDataset(torch.from_numpy(dataset), torch.from_numpy(y).long())
train_loader = DataLoader(training_data, batch_size=batch, shuffle=True)

val_data = TensorDataset(torch.from_numpy(dataset2), torch.from_numpy(y2).long())
val_loader = DataLoader(val_data, batch_size=vlShape, shuffle=False)

test_data = TensorDataset(torch.from_numpy(dataset3), torch.from_numpy(y3).long())
test_loader = DataLoader(test_data, batch_size=testShape, shuffle=False)

In [5]:
# Evaluation function to calculate accuracy
def evaluate_accuracy(model, data_loader):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # No gradient calculation during evaluation
        for inputs, labels in data_loader:
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)  # Get the predicted class
            correct += (predicted == labels).sum().item()  # Count correct predictions
            total += labels.size(0)  # Total samples
    
    accuracy = correct / total * 100  # Calculate accuracy as a percentage
    return accuracy

In [7]:
from lib.model import MyModel

training_dir = os.path.join('checkpoints', 'run_003')
os.makedirs(training_dir, exist_ok=True)

alpha = L2 # 1, 5, L2
level = 3
number_classes = 2
C = 5e3
a = 1
K = 30

# Initialize model, loss function, and optimizer
model = MyModel(
    L1=L1, 
    L2=L2, 
    dim=dim, 
    order=order, 
    extended_order=extended_order, 
    alpha=alpha, 
    level=level, 
    number_classes=number_classes, 
    C=C, 
    a=a, 
    K=K
)

In [10]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os

training_dir = os.path.join('checkpoints', 'run_003')
os.makedirs(training_dir, exist_ok=True)

# Initialize loss function, and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Training parameters
num_epochs = 300
patience = 10

# Define the path to save the checkpoint
checkpoint_path = os.path.join(training_dir, 'checkpoint.pth')

# Initialize best_loss and early_stopping_counter
best_loss = float('inf')
early_stopping_counter = 0

# Load from checkpoint if it exists
if os.path.isfile(checkpoint_path):
    print("Loading checkpoint...")
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    best_loss = checkpoint['best_loss']
    early_stopping_counter = checkpoint['early_stopping_counter']
else:
    start_epoch = 0  # Start from the beginning if no checkpoint exists

# Training loop
for epoch in range(start_epoch, num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    
    for inputs, labels in tqdm(train_loader):
        optimizer.zero_grad()  # Zero the parameter gradients
        outputs = model(inputs)  # Forward pass
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Optimize weights
        
        running_loss += loss.item()  # Accumulate loss
    
    # Average loss for the training epoch
    train_loss = running_loss / len(train_loader)
    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}')

    # Validation step
    model.eval()  # Set the model to evaluation mode
    val_running_loss = 0.0
    with torch.no_grad():  # No gradient calculation for validation
        for inputs, labels in val_loader:
            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, labels)  # Compute validation loss
            val_running_loss += loss.item()  # Accumulate validation loss

    # Average loss for the validation epoch
    val_loss = val_running_loss / len(val_loader)
    print(f'Validation Loss: {val_loss:.4f}')

    # Calculate and print accuracy on validation set
    accuracy = evaluate_accuracy(model, val_loader)
    print(f'Validation Accuracy: {accuracy:.2f}%')
    
    # Early stopping
    if val_loss < best_loss:
        best_loss = val_loss
        early_stopping_counter = 0  # Reset counter if loss improves
        print("Improved! Saving model...")
        torch.save(model.state_dict(), os.path.join(training_dir, 'best_model.pth'))  # Save the best model
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= patience:
            print("Early stopping triggered.")
            break  # Stop training if patience is exceeded

    # Save checkpoint after each epoch
    torch.save({
        'epoch': epoch + 1,  # Save next epoch
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
        'early_stopping_counter': early_stopping_counter,
    }, checkpoint_path)

print("Training completed.")

100%|██████████| 17/17 [00:08<00:00,  2.10it/s]


Epoch 1/300, Training Loss: 0.7256
Validation Loss: 0.6901
Validation Accuracy: 51.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.10it/s]


Epoch 2/300, Training Loss: 0.6952
Validation Loss: 0.6760
Validation Accuracy: 55.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.10it/s]


Epoch 3/300, Training Loss: 0.6862
Validation Loss: 0.6710
Validation Accuracy: 50.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.10it/s]


Epoch 4/300, Training Loss: 0.6786
Validation Loss: 0.6714
Validation Accuracy: 57.25%


100%|██████████| 17/17 [00:08<00:00,  2.10it/s]


Epoch 5/300, Training Loss: 0.6735
Validation Loss: 0.6581
Validation Accuracy: 64.25%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.12it/s]


Epoch 6/300, Training Loss: 0.6740
Validation Loss: 0.6527
Validation Accuracy: 66.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.12it/s]


Epoch 7/300, Training Loss: 0.6756
Validation Loss: 0.6452
Validation Accuracy: 70.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.12it/s]


Epoch 8/300, Training Loss: 0.6678
Validation Loss: 0.6400
Validation Accuracy: 70.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.12it/s]


Epoch 9/300, Training Loss: 0.6651
Validation Loss: 0.6386
Validation Accuracy: 71.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.12it/s]


Epoch 10/300, Training Loss: 0.6601
Validation Loss: 0.6386
Validation Accuracy: 71.50%


100%|██████████| 17/17 [00:08<00:00,  2.12it/s]


Epoch 11/300, Training Loss: 0.6630
Validation Loss: 0.6283
Validation Accuracy: 71.50%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.13it/s]


Epoch 12/300, Training Loss: 0.6550
Validation Loss: 0.6284
Validation Accuracy: 68.50%


100%|██████████| 17/17 [00:07<00:00,  2.13it/s]


Epoch 13/300, Training Loss: 0.6561
Validation Loss: 0.6203
Validation Accuracy: 72.25%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.13it/s]


Epoch 14/300, Training Loss: 0.6434
Validation Loss: 0.6169
Validation Accuracy: 73.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.12it/s]


Epoch 15/300, Training Loss: 0.6428
Validation Loss: 0.6102
Validation Accuracy: 72.50%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 16/300, Training Loss: 0.6374
Validation Loss: 0.6115
Validation Accuracy: 67.50%


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 17/300, Training Loss: 0.6342
Validation Loss: 0.5983
Validation Accuracy: 73.50%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 18/300, Training Loss: 0.6294
Validation Loss: 0.6024
Validation Accuracy: 73.75%


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 19/300, Training Loss: 0.6233
Validation Loss: 0.5943
Validation Accuracy: 74.75%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.16it/s]


Epoch 20/300, Training Loss: 0.6162
Validation Loss: 0.5872
Validation Accuracy: 76.75%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 21/300, Training Loss: 0.6021
Validation Loss: 0.5823
Validation Accuracy: 73.50%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 22/300, Training Loss: 0.6011
Validation Loss: 0.5759
Validation Accuracy: 75.25%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 23/300, Training Loss: 0.5919
Validation Loss: 0.5700
Validation Accuracy: 75.75%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 24/300, Training Loss: 0.5906
Validation Loss: 0.5590
Validation Accuracy: 76.00%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 25/300, Training Loss: 0.5857
Validation Loss: 0.5614
Validation Accuracy: 78.25%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 26/300, Training Loss: 0.5745
Validation Loss: 0.5495
Validation Accuracy: 77.00%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 27/300, Training Loss: 0.5715
Validation Loss: 0.5424
Validation Accuracy: 78.00%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 28/300, Training Loss: 0.5613
Validation Loss: 0.5453
Validation Accuracy: 72.75%


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 29/300, Training Loss: 0.5640
Validation Loss: 0.5361
Validation Accuracy: 77.50%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 30/300, Training Loss: 0.5487
Validation Loss: 0.5322
Validation Accuracy: 73.75%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 31/300, Training Loss: 0.5439
Validation Loss: 0.5254
Validation Accuracy: 78.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.12it/s]


Epoch 32/300, Training Loss: 0.5327
Validation Loss: 0.5261
Validation Accuracy: 74.75%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 33/300, Training Loss: 0.5268
Validation Loss: 0.5123
Validation Accuracy: 79.00%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.13it/s]


Epoch 34/300, Training Loss: 0.5211
Validation Loss: 0.5084
Validation Accuracy: 79.50%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.13it/s]


Epoch 35/300, Training Loss: 0.5161
Validation Loss: 0.4997
Validation Accuracy: 79.50%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.20it/s]


Epoch 36/300, Training Loss: 0.5109
Validation Loss: 0.5010
Validation Accuracy: 80.50%


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 37/300, Training Loss: 0.5053
Validation Loss: 0.4960
Validation Accuracy: 79.00%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 38/300, Training Loss: 0.4965
Validation Loss: 0.4858
Validation Accuracy: 81.75%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 39/300, Training Loss: 0.4912
Validation Loss: 0.4820
Validation Accuracy: 80.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.98it/s]


Epoch 40/300, Training Loss: 0.4858
Validation Loss: 0.4774
Validation Accuracy: 79.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.98it/s]


Epoch 41/300, Training Loss: 0.4803
Validation Loss: 0.4738
Validation Accuracy: 80.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.02it/s]


Epoch 42/300, Training Loss: 0.4700
Validation Loss: 0.4698
Validation Accuracy: 79.25%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.02it/s]


Epoch 43/300, Training Loss: 0.4681
Validation Loss: 0.4674
Validation Accuracy: 80.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.05it/s]


Epoch 44/300, Training Loss: 0.4636
Validation Loss: 0.4656
Validation Accuracy: 79.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 45/300, Training Loss: 0.4566
Validation Loss: 0.4576
Validation Accuracy: 82.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.05it/s]


Epoch 46/300, Training Loss: 0.4444
Validation Loss: 0.4546
Validation Accuracy: 80.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 47/300, Training Loss: 0.4449
Validation Loss: 0.4516
Validation Accuracy: 82.25%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 48/300, Training Loss: 0.4400
Validation Loss: 0.4514
Validation Accuracy: 81.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 49/300, Training Loss: 0.4269
Validation Loss: 0.4444
Validation Accuracy: 83.25%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.98it/s]


Epoch 50/300, Training Loss: 0.4272
Validation Loss: 0.4413
Validation Accuracy: 82.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 51/300, Training Loss: 0.4233
Validation Loss: 0.4332
Validation Accuracy: 80.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 52/300, Training Loss: 0.4164
Validation Loss: 0.4324
Validation Accuracy: 82.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 53/300, Training Loss: 0.4146
Validation Loss: 0.4312
Validation Accuracy: 82.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.03it/s]


Epoch 54/300, Training Loss: 0.4088
Validation Loss: 0.4283
Validation Accuracy: 80.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 55/300, Training Loss: 0.3979
Validation Loss: 0.4263
Validation Accuracy: 84.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 56/300, Training Loss: 0.4019
Validation Loss: 0.4189
Validation Accuracy: 84.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 57/300, Training Loss: 0.3914
Validation Loss: 0.4171
Validation Accuracy: 82.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 58/300, Training Loss: 0.3866
Validation Loss: 0.4179
Validation Accuracy: 83.25%


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 59/300, Training Loss: 0.3895
Validation Loss: 0.4147
Validation Accuracy: 82.25%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 60/300, Training Loss: 0.3780
Validation Loss: 0.4140
Validation Accuracy: 81.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 61/300, Training Loss: 0.3740
Validation Loss: 0.4111
Validation Accuracy: 82.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.02it/s]


Epoch 62/300, Training Loss: 0.3680
Validation Loss: 0.4020
Validation Accuracy: 83.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 63/300, Training Loss: 0.3654
Validation Loss: 0.4016
Validation Accuracy: 84.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 64/300, Training Loss: 0.3672
Validation Loss: 0.4005
Validation Accuracy: 83.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 65/300, Training Loss: 0.3629
Validation Loss: 0.4000
Validation Accuracy: 83.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 66/300, Training Loss: 0.3548
Validation Loss: 0.3935
Validation Accuracy: 81.25%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 67/300, Training Loss: 0.3506
Validation Loss: 0.3867
Validation Accuracy: 83.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.02it/s]


Epoch 68/300, Training Loss: 0.3467
Validation Loss: 0.3876
Validation Accuracy: 82.00%


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 69/300, Training Loss: 0.3417
Validation Loss: 0.3843
Validation Accuracy: 85.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 70/300, Training Loss: 0.3398
Validation Loss: 0.3842
Validation Accuracy: 83.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 71/300, Training Loss: 0.3377
Validation Loss: 0.3738
Validation Accuracy: 85.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 72/300, Training Loss: 0.3322
Validation Loss: 0.3740
Validation Accuracy: 85.25%


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 73/300, Training Loss: 0.3286
Validation Loss: 0.3795
Validation Accuracy: 82.00%


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 74/300, Training Loss: 0.3273
Validation Loss: 0.3736
Validation Accuracy: 85.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 75/300, Training Loss: 0.3232
Validation Loss: 0.3671
Validation Accuracy: 83.25%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 76/300, Training Loss: 0.3164
Validation Loss: 0.3621
Validation Accuracy: 86.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 77/300, Training Loss: 0.3156
Validation Loss: 0.3649
Validation Accuracy: 84.75%


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 78/300, Training Loss: 0.3158
Validation Loss: 0.3650
Validation Accuracy: 82.75%


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 79/300, Training Loss: 0.3104
Validation Loss: 0.3551
Validation Accuracy: 84.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 80/300, Training Loss: 0.3100
Validation Loss: 0.3593
Validation Accuracy: 84.00%


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 81/300, Training Loss: 0.3070
Validation Loss: 0.3582
Validation Accuracy: 85.00%


100%|██████████| 17/17 [00:08<00:00,  1.99it/s]


Epoch 82/300, Training Loss: 0.2976
Validation Loss: 0.3669
Validation Accuracy: 82.75%


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 83/300, Training Loss: 0.2988
Validation Loss: 0.3515
Validation Accuracy: 85.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.02it/s]


Epoch 84/300, Training Loss: 0.2926
Validation Loss: 0.3545
Validation Accuracy: 83.25%


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 85/300, Training Loss: 0.2907
Validation Loss: 0.3507
Validation Accuracy: 85.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 86/300, Training Loss: 0.2924
Validation Loss: 0.3448
Validation Accuracy: 85.25%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 87/300, Training Loss: 0.2899
Validation Loss: 0.3381
Validation Accuracy: 87.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.02it/s]


Epoch 88/300, Training Loss: 0.2822
Validation Loss: 0.3367
Validation Accuracy: 85.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 89/300, Training Loss: 0.2805
Validation Loss: 0.3484
Validation Accuracy: 85.75%


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 90/300, Training Loss: 0.2765
Validation Loss: 0.3357
Validation Accuracy: 87.00%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 91/300, Training Loss: 0.2756
Validation Loss: 0.3296
Validation Accuracy: 86.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 92/300, Training Loss: 0.2774
Validation Loss: 0.3293
Validation Accuracy: 86.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 93/300, Training Loss: 0.2705
Validation Loss: 0.3397
Validation Accuracy: 85.25%


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 94/300, Training Loss: 0.2669
Validation Loss: 0.3306
Validation Accuracy: 85.25%


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 95/300, Training Loss: 0.2664
Validation Loss: 0.3238
Validation Accuracy: 86.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.02it/s]


Epoch 96/300, Training Loss: 0.2625
Validation Loss: 0.3272
Validation Accuracy: 88.25%


100%|██████████| 17/17 [00:08<00:00,  2.00it/s]


Epoch 97/300, Training Loss: 0.2599
Validation Loss: 0.3250
Validation Accuracy: 85.25%


100%|██████████| 17/17 [00:08<00:00,  2.08it/s]


Epoch 98/300, Training Loss: 0.2573
Validation Loss: 0.3236
Validation Accuracy: 86.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.01it/s]


Epoch 99/300, Training Loss: 0.2604
Validation Loss: 0.3219
Validation Accuracy: 86.50%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.08it/s]


Epoch 100/300, Training Loss: 0.2549
Validation Loss: 0.3228
Validation Accuracy: 85.50%


100%|██████████| 17/17 [00:08<00:00,  2.10it/s]


Epoch 101/300, Training Loss: 0.2524
Validation Loss: 0.3166
Validation Accuracy: 85.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.07it/s]


Epoch 102/300, Training Loss: 0.2486
Validation Loss: 0.3208
Validation Accuracy: 88.25%


100%|██████████| 17/17 [00:08<00:00,  2.03it/s]


Epoch 103/300, Training Loss: 0.2477
Validation Loss: 0.3130
Validation Accuracy: 87.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.09it/s]


Epoch 104/300, Training Loss: 0.2434
Validation Loss: 0.3072
Validation Accuracy: 86.75%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.10it/s]


Epoch 105/300, Training Loss: 0.2381
Validation Loss: 0.3119
Validation Accuracy: 86.50%


100%|██████████| 17/17 [00:08<00:00,  2.11it/s]


Epoch 106/300, Training Loss: 0.2352
Validation Loss: 0.3049
Validation Accuracy: 89.25%
Improved! Saving model...


100%|██████████| 17/17 [00:08<00:00,  2.10it/s]


Epoch 107/300, Training Loss: 0.2363
Validation Loss: 0.3151
Validation Accuracy: 86.50%


100%|██████████| 17/17 [00:07<00:00,  2.18it/s]


Epoch 108/300, Training Loss: 0.2391
Validation Loss: 0.3032
Validation Accuracy: 88.00%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.18it/s]


Epoch 109/300, Training Loss: 0.2327
Validation Loss: 0.3062
Validation Accuracy: 86.25%


100%|██████████| 17/17 [00:08<00:00,  2.12it/s]


Epoch 110/300, Training Loss: 0.2311
Validation Loss: 0.2992
Validation Accuracy: 88.25%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.13it/s]


Epoch 111/300, Training Loss: 0.2304
Validation Loss: 0.2982
Validation Accuracy: 89.50%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 112/300, Training Loss: 0.2294
Validation Loss: 0.3163
Validation Accuracy: 86.25%


100%|██████████| 17/17 [00:08<00:00,  2.04it/s]


Epoch 113/300, Training Loss: 0.2297
Validation Loss: 0.2950
Validation Accuracy: 88.50%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.17it/s]


Epoch 114/300, Training Loss: 0.2227
Validation Loss: 0.2953
Validation Accuracy: 88.25%


100%|██████████| 17/17 [00:07<00:00,  2.17it/s]


Epoch 115/300, Training Loss: 0.2219
Validation Loss: 0.2955
Validation Accuracy: 86.75%


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 116/300, Training Loss: 0.2176
Validation Loss: 0.2933
Validation Accuracy: 88.25%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 117/300, Training Loss: 0.2177
Validation Loss: 0.2937
Validation Accuracy: 87.75%


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 118/300, Training Loss: 0.2159
Validation Loss: 0.2947
Validation Accuracy: 87.25%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 119/300, Training Loss: 0.2120
Validation Loss: 0.2969
Validation Accuracy: 87.25%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 120/300, Training Loss: 0.2124
Validation Loss: 0.2930
Validation Accuracy: 87.75%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 121/300, Training Loss: 0.2090
Validation Loss: 0.2899
Validation Accuracy: 88.25%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 122/300, Training Loss: 0.2085
Validation Loss: 0.2910
Validation Accuracy: 89.50%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 123/300, Training Loss: 0.2070
Validation Loss: 0.2851
Validation Accuracy: 88.50%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 124/300, Training Loss: 0.2055
Validation Loss: 0.2938
Validation Accuracy: 87.25%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 125/300, Training Loss: 0.2040
Validation Loss: 0.2952
Validation Accuracy: 85.50%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 126/300, Training Loss: 0.2008
Validation Loss: 0.2795
Validation Accuracy: 88.00%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 127/300, Training Loss: 0.1996
Validation Loss: 0.2868
Validation Accuracy: 88.25%


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 128/300, Training Loss: 0.1994
Validation Loss: 0.2744
Validation Accuracy: 89.00%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 129/300, Training Loss: 0.1981
Validation Loss: 0.2738
Validation Accuracy: 87.75%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 130/300, Training Loss: 0.1969
Validation Loss: 0.2850
Validation Accuracy: 88.00%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 131/300, Training Loss: 0.1915
Validation Loss: 0.2713
Validation Accuracy: 89.00%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 132/300, Training Loss: 0.1913
Validation Loss: 0.2766
Validation Accuracy: 88.75%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 133/300, Training Loss: 0.1955
Validation Loss: 0.2818
Validation Accuracy: 88.50%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 134/300, Training Loss: 0.1914
Validation Loss: 0.2705
Validation Accuracy: 88.25%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 135/300, Training Loss: 0.1883
Validation Loss: 0.2716
Validation Accuracy: 88.75%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 136/300, Training Loss: 0.1868
Validation Loss: 0.2617
Validation Accuracy: 89.00%
Improved! Saving model...


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 137/300, Training Loss: 0.1865
Validation Loss: 0.2663
Validation Accuracy: 90.75%


100%|██████████| 17/17 [00:07<00:00,  2.15it/s]


Epoch 138/300, Training Loss: 0.1830
Validation Loss: 0.2642
Validation Accuracy: 90.25%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 139/300, Training Loss: 0.1823
Validation Loss: 0.2664
Validation Accuracy: 90.00%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 140/300, Training Loss: 0.1772
Validation Loss: 0.2668
Validation Accuracy: 89.50%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 141/300, Training Loss: 0.1764
Validation Loss: 0.2634
Validation Accuracy: 89.50%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 142/300, Training Loss: 0.1767
Validation Loss: 0.2764
Validation Accuracy: 89.50%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 143/300, Training Loss: 0.1741
Validation Loss: 0.2667
Validation Accuracy: 89.25%


100%|██████████| 17/17 [00:07<00:00,  2.13it/s]


Epoch 144/300, Training Loss: 0.1709
Validation Loss: 0.2638
Validation Accuracy: 89.25%


100%|██████████| 17/17 [00:07<00:00,  2.13it/s]


Epoch 145/300, Training Loss: 0.1705
Validation Loss: 0.2635
Validation Accuracy: 90.25%


100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Epoch 146/300, Training Loss: 0.1709
Validation Loss: 0.2647
Validation Accuracy: 88.00%
Early stopping triggered.
Training completed.


In [23]:
training_dir = os.path.join('checkpoints', 'run_003')

# Load the best model weights
model.load_state_dict(torch.load(os.path.join(training_dir, 'best_model.pth')))

# Calculate and print accuracy on validation set
accuracy = evaluate_accuracy(model, train_loader)
print(f'Test Accuracy: {accuracy:.2f}%')

Test Accuracy: 91.40%
