In [9]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import sys
print(sys.executable)
sys.path.insert(1, '../src/')
from config import raw_data_path, univariate_data_path, processed_data_path, models_path
from preprocessing_modules import create_time_windows_with_labels, create_time_windows_with_metadata
from skopt import gp_minimize
from skopt.space import Real, Integer, Categorical
from skopt.utils import use_named_args
from skopt import gp_minimize
from skopt.space import Real, Integer
from skopt.utils import use_named_args
import matplotlib.pyplot as plt
from FCMAE_model import FCMAE
from fastai.callback.tracker import SaveModelCallback, EarlyStoppingCallback


/home/nwertheim/miniconda3/bin/python


In [10]:
# Define FCMAE Hyperparameters:
batch_size = 16
num_blocks = 5
kernel_size = 7
base_dim = 32
learning_rate = 0.00016938000495408888
input_dimension = 512

num_epochs = 10

In [11]:

import torch
import numpy as np
import os
from torch.utils.data import TensorDataset, DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight
import optuna
from fastai.learner import Learner
from fastai.data.core import DataLoaders
from fastai.metrics import accuracy
from tsai.models.ResNet import ResNet
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from fastai.optimizer import SGD
# Load training data
train_dir = os.path.join(univariate_data_path, "target_univariate_no_PCA_train_2_80.npy")
train_data = np.load(train_dir, allow_pickle=True)
train_windows, train_labels, train_meta = create_time_windows_with_metadata(train_data)

# Convert to tensors
train_windows_tensor = torch.tensor(train_windows, dtype=torch.float32).unsqueeze(1)  # shape: (N, 1, 12000)
train_labels_tensor = torch.tensor(train_labels, dtype=torch.long)

# Compute class weights
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

# Prepare dataset
full_dataset = TensorDataset(train_windows_tensor, train_labels_tensor)

In [12]:



# Load the pre-trained encoder
path = os.path.join(models_path, 'FCMAE_encoder_no_PCA_gpu_normalized_correct.pth')

# Load the full FCMAE model
fcmae = FCMAE(in_channels=1, base_dim=base_dim, num_blocks=num_blocks, kernel_size=kernel_size)


# Load the saved weights into the full model
state_dict = torch.load(path)
fcmae.load_state_dict(state_dict)

# Extract only the encoder from the FCMAE model
encoder = fcmae.encoder
# Set the encoder to evaluation mode and freeze the encoder parameters
encoder.eval()
for param in encoder.parameters():
    param.requires_grad = False  # Freeze the encoder

print(encoder)

Sequential(
  (0): Conv1d(1, 32, kernel_size=(7,), stride=(2,), padding=(3,))
  (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Conv1d(32, 64, kernel_size=(7,), stride=(2,), padding=(3,))
  (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
  (6): Conv1d(64, 128, kernel_size=(7,), stride=(2,), padding=(3,))
  (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): ReLU()
  (9): Conv1d(128, 256, kernel_size=(7,), stride=(2,), padding=(3,))
  (10): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): ReLU()
  (12): Conv1d(256, 512, kernel_size=(7,), stride=(2,), padding=(3,))
  (13): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (14): ReLU()
)


In [13]:
'''tuneable optimizer, lower LR-range, early stopping'''
from fastai.learner import Metric
from sklearn.metrics import average_precision_score
from fastai.optimizer import SGD, Adam


# Custom Average Precision Metric
class AveragePrecision(Metric):
    def __init__(self):
        self.pred = []
        self.target = []

    def reset(self):
        self.pred, self.target = [], []

    def accumulate(self, learn):
        preds = learn.pred.detach().softmax(dim=-1)[:, 1]
        targs = learn.y.detach()
        self.pred.append(preds.cpu())
        self.target.append(targs.cpu())

    @property
    def value(self):
        preds = torch.cat(self.pred).numpy()
        targs = torch.cat(self.target).numpy()
        return average_precision_score(targs, preds)

    @property
    def name(self): return "avg_precision"

def objective(trial):
    lr = trial.suggest_loguniform('lr', 1e-7, 1e-6)
    optimizer_name = trial.suggest_categorical('optimizer', ['SGD', 'Adam'])

    # Map optimizer name to fastai optimizer function
    opt_func_map = {
    'SGD': SGD,
    'Adam': Adam,
    }

    opt_func = opt_func_map[optimizer_name]

    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    avg_precisions = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(train_windows_tensor, train_labels_tensor)):
        train_subset = Subset(full_dataset, train_idx)
        val_subset = Subset(full_dataset, val_idx)

        train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=64, shuffle=False)
        dls = DataLoaders(train_loader, val_loader)

        class FCMAEClassifier(nn.Module):
            def __init__(self, encoder):
                super().__init__()
                self.encoder = encoder
                self.classifier = ResNet(input_dimension, 2)

            def forward(self, x):
                with torch.no_grad():
                    x = self.encoder(x)
                return self.classifier(x)

        model = FCMAEClassifier(encoder).to(device)
        loss_func = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)

        learn = learn = Learner(
            dls,
            model,
            loss_func=loss_func,
            opt_func=opt_func,
            metrics=accuracy,
            cbs=[
                EarlyStoppingCallback(monitor='valid_loss', patience=3),
            ]
        )

        learn.fit_one_cycle(num_epochs, lr)
        ap = learn.validate()[1]  # [loss, avg_precision]
        avg_precisions.append(ap)

    return np.mean(avg_precisions)  # maximize this

# Run Optuna study
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=25)

print("Best trial:")
print(study.best_trial)


# Save the study
import joblib
joblib.dump(study, "FCMAE_ResNet_tune_80.pkl")


[I 2025-05-15 21:33:28,046] A new study created in memory with name: no-name-04ec4a0d-19a9-4fc8-8271-265743179d1e


epoch,train_loss,valid_loss,accuracy,time
0,0.706805,0.698517,0.463492,00:00
1,0.705398,0.704388,0.389206,00:00
2,0.70796,0.705494,0.447619,00:00
3,0.705986,0.700187,0.497143,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.700093,0.690621,0.554638,00:00
1,0.699178,0.694171,0.554638,00:00
2,0.69685,0.690925,0.554003,00:00
3,0.698964,0.690684,0.555273,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.704644,0.717994,0.33418,00:00
1,0.704575,0.719159,0.333545,00:00
2,0.704604,0.718267,0.33418,00:00
3,0.702945,0.717967,0.33418,00:00
4,0.704314,0.717479,0.33418,00:00
5,0.703004,0.71844,0.33418,00:00
6,0.704835,0.718344,0.33418,00:00
7,0.707306,0.717568,0.33418,00:00


No improvement since epoch 4: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.696441,0.697226,0.451715,00:00
1,0.695506,0.697115,0.446633,00:00
2,0.696307,0.696584,0.450445,00:00
3,0.695256,0.696407,0.459339,00:00
4,0.696296,0.69719,0.447903,00:00
5,0.695817,0.696528,0.440915,00:00
6,0.696259,0.696329,0.45108,00:00
7,0.696422,0.696975,0.447268,00:00
8,0.696076,0.69616,0.463151,00:00
9,0.696669,0.696485,0.43075,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.740536,0.725989,0.666455,00:00
1,0.738487,0.719119,0.666455,00:00
2,0.73858,0.712155,0.666455,00:00
3,0.737857,0.71987,0.666455,00:00
4,0.737819,0.714874,0.666455,00:00
5,0.740004,0.716164,0.666455,00:00


No improvement since epoch 2: early stopping


[I 2025-05-15 21:33:50,859] Trial 0 finished with value: 0.4967602133750916 and parameters: {'lr': 4.167985891932696e-07, 'optimizer': 'SGD'}. Best is trial 0 with value: 0.4967602133750916.


epoch,train_loss,valid_loss,accuracy,time
0,0.715151,0.708239,0.666667,00:00
1,0.706941,0.695505,0.668571,00:00
2,0.700059,0.687985,0.659683,00:00
3,0.690864,0.683895,0.652063,00:00
4,0.684313,0.678033,0.633651,00:00
5,0.681895,0.674404,0.624762,00:00
6,0.680112,0.674775,0.604444,00:00
7,0.678867,0.678376,0.570159,00:00
8,0.67824,0.67575,0.612063,00:00


No improvement since epoch 5: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.739632,0.729453,0.33291,00:00
1,0.731671,0.721826,0.33291,00:00
2,0.725033,0.717517,0.33291,00:00
3,0.715382,0.708939,0.33291,00:00
4,0.707946,0.704591,0.33291,00:00
5,0.704942,0.704691,0.33291,00:00
6,0.701648,0.698742,0.33291,00:00
7,0.70084,0.707181,0.33291,00:00
8,0.700452,0.697917,0.340534,00:00
9,0.699151,0.695628,0.339263,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.69564,0.679891,0.609276,00:00
1,0.692361,0.678321,0.622618,00:00
2,0.686865,0.673856,0.615629,00:00
3,0.684213,0.669524,0.592757,00:00
4,0.680697,0.667699,0.596569,00:00
5,0.678477,0.667219,0.586404,00:00
6,0.677379,0.669023,0.579416,00:00
7,0.674554,0.665531,0.584498,00:00
8,0.675812,0.665556,0.583227,00:00
9,0.679099,0.67065,0.576239,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.715962,0.706736,0.655654,00:00
1,0.709832,0.691777,0.665184,00:00
2,0.701722,0.687892,0.663914,00:00
3,0.692476,0.681868,0.650572,00:00
4,0.685469,0.679521,0.617535,00:00
5,0.683626,0.676063,0.613723,00:00
6,0.678525,0.676973,0.599746,00:00
7,0.67756,0.67505,0.571792,00:00
8,0.678236,0.675461,0.574968,00:00
9,0.677788,0.674215,0.590851,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.703882,0.718744,0.331639,00:00
1,0.702527,0.710992,0.333545,00:00
2,0.697551,0.700895,0.35324,00:00
3,0.693007,0.69401,0.456163,00:00
4,0.689487,0.690441,0.444727,00:00
5,0.68658,0.688777,0.448539,00:00
6,0.686284,0.6916,0.41169,00:00
7,0.684301,0.682915,0.499365,00:00
8,0.683756,0.68422,0.472681,00:00
9,0.682476,0.688221,0.42249,00:00


[I 2025-05-15 21:34:34,610] Trial 1 finished with value: 0.5081814289093017 and parameters: {'lr': 6.326305372949287e-07, 'optimizer': 'Adam'}. Best is trial 1 with value: 0.5081814289093017.


epoch,train_loss,valid_loss,accuracy,time
0,0.697766,0.695228,0.347302,00:00
1,0.691753,0.689388,0.433016,00:00
2,0.686227,0.683486,0.496508,00:00
3,0.68138,0.679371,0.506667,00:00
4,0.678378,0.674067,0.512381,00:00
5,0.674341,0.672237,0.512381,00:00
6,0.673203,0.672827,0.510476,00:00
7,0.672773,0.674362,0.510476,00:00
8,0.671958,0.673678,0.520635,00:00


No improvement since epoch 5: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.708381,0.712141,0.672173,00:00
1,0.706324,0.696686,0.669632,00:00
2,0.695331,0.68232,0.671537,00:00
3,0.687788,0.675139,0.658196,00:00
4,0.685128,0.673071,0.643583,00:00
5,0.681641,0.669774,0.645489,00:00
6,0.681561,0.664222,0.635324,00:00
7,0.678803,0.672433,0.608005,00:00
8,0.679649,0.666607,0.621982,00:00
9,0.679637,0.671684,0.628335,00:00


No improvement since epoch 6: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.705472,0.710482,0.33291,00:00
1,0.699451,0.705677,0.33291,00:00
2,0.696,0.698361,0.337357,00:00
3,0.690995,0.697008,0.404066,00:00
4,0.685546,0.686586,0.522236,00:00
5,0.683779,0.684266,0.522872,00:00
6,0.68014,0.683768,0.515883,00:00
7,0.680735,0.682798,0.522236,00:00
8,0.679394,0.700532,0.398983,00:00
9,0.678384,0.68021,0.52859,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.693609,0.700825,0.385006,00:00
1,0.690744,0.69656,0.455527,00:00
2,0.686275,0.690583,0.487294,00:00
3,0.681162,0.688255,0.466328,00:00
4,0.677749,0.684422,0.487294,00:00
5,0.676827,0.678984,0.5,00:00
6,0.672165,0.677923,0.505718,00:00
7,0.67354,0.681214,0.499365,00:00
8,0.67178,0.675241,0.506989,00:00
9,0.670494,0.678726,0.5,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.728686,0.745098,0.333545,00:00
1,0.723982,0.734017,0.333545,00:00
2,0.71551,0.716391,0.333545,00:00
3,0.706905,0.708287,0.333545,00:00
4,0.698199,0.701247,0.333545,00:00
5,0.691964,0.697771,0.385642,00:00
6,0.690227,0.692385,0.423126,00:00
7,0.68827,0.690884,0.437103,00:00
8,0.689167,0.692292,0.425667,00:00
9,0.686723,0.690288,0.435832,00:00


[I 2025-05-15 21:35:18,371] Trial 2 finished with value: 0.5226784586906433 and parameters: {'lr': 9.86612434312205e-07, 'optimizer': 'Adam'}. Best is trial 2 with value: 0.5226784586906433.


epoch,train_loss,valid_loss,accuracy,time
0,0.696669,0.700826,0.504762,00:00
1,0.693449,0.695757,0.530794,00:00
2,0.685422,0.688787,0.499048,00:00
3,0.683686,0.686783,0.485714,00:00
4,0.679171,0.679212,0.506032,00:00
5,0.677339,0.686002,0.470476,00:00
6,0.675394,0.676559,0.503492,00:00
7,0.675777,0.67872,0.498413,00:00
8,0.673902,0.682896,0.486349,00:00
9,0.673267,0.681545,0.489524,00:00


No improvement since epoch 6: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.803414,0.758305,0.33291,00:00
1,0.791044,0.758936,0.33291,00:00
2,0.771433,0.72499,0.33291,00:00
3,0.758767,0.732069,0.33291,00:00
4,0.740332,0.722827,0.33291,00:00
5,0.735374,0.722856,0.33291,00:00
6,0.728612,0.710603,0.33291,00:00
7,0.722071,0.7224,0.33291,00:00
8,0.718065,0.708655,0.33291,00:00
9,0.718352,0.713993,0.33291,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.735777,0.731739,0.33291,00:00
1,0.725082,0.722691,0.33291,00:00
2,0.715866,0.719131,0.33291,00:00
3,0.708961,0.705904,0.33418,00:00
4,0.705392,0.704366,0.334816,00:00
5,0.694752,0.699334,0.395807,00:00
6,0.692074,0.695278,0.423126,00:00
7,0.692066,0.69326,0.423126,00:00
8,0.692315,0.696066,0.407243,00:00
9,0.690154,0.699038,0.401525,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.69978,0.708234,0.355146,00:00
1,0.695432,0.699283,0.397078,00:00
2,0.688884,0.693741,0.42122,00:00
3,0.683196,0.686247,0.503812,00:00
4,0.681341,0.683654,0.505718,00:00
5,0.677533,0.680827,0.505083,00:00
6,0.67552,0.680069,0.508259,00:00
7,0.673606,0.678858,0.505083,00:00
8,0.673983,0.682509,0.498729,00:00
9,0.672634,0.677542,0.508259,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.717294,0.722986,0.461881,00:00
1,0.708005,0.711818,0.464422,00:00
2,0.698683,0.697053,0.430114,00:00
3,0.693335,0.687284,0.599746,00:00
4,0.686277,0.683035,0.568615,00:00
5,0.682487,0.682482,0.544473,00:00
6,0.682663,0.673197,0.584498,00:00
7,0.680526,0.679519,0.544473,00:00
8,0.680219,0.674607,0.592757,00:00
9,0.6793,0.689142,0.452986,00:00


No improvement since epoch 6: early stopping


[I 2025-05-15 21:36:03,004] Trial 3 finished with value: 0.4370407283306122 and parameters: {'lr': 7.61441250532136e-07, 'optimizer': 'Adam'}. Best is trial 2 with value: 0.5226784586906433.


epoch,train_loss,valid_loss,accuracy,time
0,0.69704,0.695419,0.513651,00:00
1,0.698737,0.700269,0.516825,00:00
2,0.698352,0.695324,0.512381,00:00
3,0.696058,0.696294,0.516825,00:00
4,0.696144,0.694998,0.512381,00:00
5,0.697096,0.692094,0.51619,00:00
6,0.695884,0.695911,0.513016,00:00
7,0.696399,0.693474,0.514286,00:00
8,0.696243,0.695549,0.513651,00:00


No improvement since epoch 5: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.716823,0.701479,0.347522,00:00
1,0.713208,0.699436,0.357052,00:00
2,0.711794,0.701717,0.34371,00:00
3,0.710292,0.700737,0.346887,00:00
4,0.711097,0.706986,0.341169,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.695304,0.695807,0.49047,00:00
1,0.695713,0.693209,0.50953,00:00
2,0.694921,0.692849,0.504447,00:00
3,0.695448,0.694328,0.506989,00:00
4,0.696204,0.696251,0.488564,00:00
5,0.695278,0.693388,0.506989,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.723924,0.710747,0.550826,00:00
1,0.726394,0.707258,0.552097,00:00
2,0.727872,0.716868,0.519695,00:00
3,0.72273,0.710139,0.567344,00:00
4,0.721402,0.710935,0.531131,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.725037,0.726879,0.333545,00:00
1,0.726483,0.722693,0.333545,00:00
2,0.726509,0.726053,0.333545,00:00
3,0.728619,0.727135,0.333545,00:00
4,0.727099,0.71066,0.333545,00:00
5,0.727468,0.726987,0.333545,00:00
6,0.726106,0.724705,0.333545,00:00
7,0.727548,0.728972,0.333545,00:00


No improvement since epoch 4: early stopping


[I 2025-05-15 21:36:26,468] Trial 4 finished with value: 0.4452968657016754 and parameters: {'lr': 1.5897310496482612e-07, 'optimizer': 'SGD'}. Best is trial 2 with value: 0.5226784586906433.


epoch,train_loss,valid_loss,accuracy,time
0,0.733135,0.734999,0.524444,00:00
1,0.731434,0.733089,0.526349,00:00
2,0.732976,0.723394,0.535873,00:00
3,0.730158,0.723843,0.526984,00:00
4,0.731107,0.728985,0.52127,00:00
5,0.73345,0.741143,0.523175,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.691421,0.684059,0.545108,00:00
1,0.693949,0.684053,0.54892,00:00
2,0.691897,0.684073,0.54892,00:00
3,0.69279,0.683629,0.555273,00:00
4,0.694186,0.683802,0.555273,00:00
5,0.692792,0.684356,0.553367,00:00
6,0.691599,0.6847,0.538755,00:00


No improvement since epoch 3: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.712663,0.707332,0.425667,00:00
1,0.712958,0.710795,0.43075,00:00
2,0.713675,0.707559,0.432656,00:00
3,0.712811,0.701386,0.507624,00:00
4,0.710292,0.70375,0.425667,00:00
5,0.712167,0.711466,0.429479,00:00
6,0.712328,0.707247,0.427573,00:00


No improvement since epoch 3: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.696204,0.697239,0.632147,00:00
1,0.696398,0.697505,0.627065,00:00
2,0.696001,0.696929,0.635959,00:00
3,0.697563,0.698401,0.609276,00:00
4,0.696925,0.698033,0.628335,00:00
5,0.697089,0.697986,0.604193,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.776381,0.752127,0.666455,00:00
1,0.775588,0.765352,0.666455,00:00
2,0.778943,0.73679,0.666455,00:00
3,0.781978,0.750676,0.666455,00:00
4,0.781659,0.763628,0.666455,00:00
5,0.77817,0.743732,0.666455,00:00


No improvement since epoch 2: early stopping


[I 2025-05-15 21:36:49,248] Trial 5 finished with value: 0.5520300924777984 and parameters: {'lr': 1.0126272833972606e-07, 'optimizer': 'SGD'}. Best is trial 5 with value: 0.5520300924777984.


epoch,train_loss,valid_loss,accuracy,time
0,0.703264,0.705301,0.389206,00:00
1,0.700829,0.695698,0.373333,00:00
2,0.692588,0.693245,0.388571,00:00
3,0.686376,0.687797,0.507937,00:00
4,0.682174,0.683605,0.512381,00:00
5,0.678353,0.691166,0.44,00:00
6,0.679944,0.680777,0.507302,00:00
7,0.676964,0.684249,0.502857,00:00
8,0.677194,0.682906,0.502222,00:00
9,0.677804,0.680322,0.509841,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.688516,0.688276,0.52033,00:00
1,0.686718,0.682498,0.55845,00:00
2,0.684924,0.67603,0.564168,00:00
3,0.68297,0.676863,0.555909,00:00
4,0.681802,0.680268,0.522236,00:00
5,0.680514,0.679347,0.526684,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.74668,0.749233,0.33291,00:00
1,0.737655,0.741897,0.33291,00:00
2,0.731635,0.740051,0.33291,00:00
3,0.724967,0.730607,0.33291,00:00
4,0.717647,0.725822,0.33291,00:00
5,0.712911,0.720551,0.33291,00:00
6,0.709063,0.715019,0.33291,00:00
7,0.706743,0.729077,0.333545,00:00
8,0.706295,0.71337,0.333545,00:00
9,0.710005,0.716159,0.33291,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.696577,0.694919,0.576239,00:00
1,0.694472,0.690007,0.526684,00:00
2,0.690283,0.686597,0.505718,00:00
3,0.685001,0.683703,0.503812,00:00
4,0.682753,0.680903,0.506353,00:00
5,0.68164,0.680664,0.518424,00:00
6,0.678231,0.678364,0.510165,00:00
7,0.678696,0.678318,0.512706,00:00
8,0.675703,0.677279,0.508895,00:00
9,0.676468,0.678196,0.51906,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.72858,0.728991,0.333545,00:00
1,0.720334,0.7187,0.333545,00:00
2,0.715366,0.715375,0.333545,00:00
3,0.708797,0.706817,0.333545,00:00
4,0.702366,0.701802,0.335451,00:00
5,0.698033,0.695613,0.407243,00:00
6,0.696492,0.69661,0.414867,00:00
7,0.696361,0.692758,0.461245,00:00
8,0.692811,0.697652,0.412961,00:00
9,0.69226,0.693598,0.433926,00:00


[I 2025-05-15 21:37:30,319] Trial 6 finished with value: 0.4644841432571411 and parameters: {'lr': 6.126120654796146e-07, 'optimizer': 'Adam'}. Best is trial 5 with value: 0.5520300924777984.


epoch,train_loss,valid_loss,accuracy,time
0,0.709844,0.698516,0.650159,00:00
1,0.705771,0.694583,0.652698,00:00
2,0.703922,0.695223,0.648254,00:00
3,0.703923,0.693285,0.643175,00:00
4,0.703446,0.693507,0.643175,00:00
5,0.700782,0.692,0.643175,00:00
6,0.700758,0.6926,0.645079,00:00
7,0.701212,0.690012,0.63619,00:00
8,0.699652,0.692682,0.64254,00:00
9,0.700412,0.690492,0.64254,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.704281,0.709314,0.365947,00:00
1,0.705767,0.709738,0.33418,00:00
2,0.704242,0.706023,0.370394,00:00
3,0.704652,0.709856,0.333545,00:00
4,0.699763,0.709287,0.337357,00:00
5,0.70115,0.709476,0.339898,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.710328,0.711638,0.484117,00:00
1,0.708576,0.719386,0.484752,00:00
2,0.707122,0.714633,0.482846,00:00
3,0.706675,0.710067,0.470775,00:00
4,0.702463,0.721782,0.477128,00:00
5,0.702127,0.713681,0.465057,00:00
6,0.70164,0.71302,0.468234,00:00


No improvement since epoch 3: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.738459,0.746641,0.332274,00:00
1,0.738602,0.730111,0.333545,00:00
2,0.736456,0.728282,0.333545,00:00
3,0.734609,0.733643,0.333545,00:00
4,0.73218,0.75431,0.333545,00:00
5,0.730175,0.73253,0.333545,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.707158,0.72006,0.33291,00:00
1,0.704813,0.716605,0.333545,00:00
2,0.706009,0.716154,0.344981,00:00
3,0.705693,0.716551,0.344981,00:00
4,0.70424,0.715238,0.374206,00:00
5,0.702773,0.713516,0.374841,00:00
6,0.702669,0.713814,0.34244,00:00
7,0.702875,0.710687,0.365311,00:00
8,0.702289,0.712369,0.385006,00:00
9,0.702543,0.710746,0.357687,00:00


[I 2025-05-15 21:38:05,225] Trial 7 finished with value: 0.4283808708190918 and parameters: {'lr': 1.0882781771456574e-07, 'optimizer': 'Adam'}. Best is trial 5 with value: 0.5520300924777984.


epoch,train_loss,valid_loss,accuracy,time
0,0.695955,0.688923,0.516825,00:00
1,0.693213,0.686997,0.518095,00:00
2,0.690078,0.686911,0.523175,00:00
3,0.689753,0.686685,0.528889,00:00
4,0.68783,0.683863,0.513651,00:00
5,0.689077,0.684461,0.514921,00:00
6,0.688295,0.683369,0.516825,00:00
7,0.686961,0.683727,0.509841,00:00
8,0.685818,0.683777,0.519365,00:00
9,0.685422,0.684224,0.515556,00:00


No improvement since epoch 6: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.714741,0.701503,0.361499,00:00
1,0.715174,0.700675,0.333545,00:00
2,0.711212,0.700567,0.333545,00:00
3,0.708292,0.699801,0.333545,00:00
4,0.709269,0.697426,0.33418,00:00
5,0.70705,0.694366,0.334816,00:00
6,0.70403,0.695183,0.334816,00:00
7,0.702063,0.69529,0.336722,00:00
8,0.702741,0.694338,0.335451,00:00
9,0.703012,0.694252,0.335451,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.713683,0.730511,0.33291,00:00
1,0.713604,0.731134,0.33291,00:00
2,0.708487,0.727432,0.33291,00:00
3,0.706993,0.723657,0.33291,00:00
4,0.704837,0.718206,0.33291,00:00
5,0.702752,0.719589,0.33291,00:00
6,0.700617,0.717799,0.33291,00:00
7,0.701,0.711588,0.33291,00:00
8,0.698857,0.711236,0.33291,00:00
9,0.701509,0.719046,0.33291,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.743832,0.723954,0.501906,00:00
1,0.74129,0.725891,0.506989,00:00
2,0.732658,0.717472,0.506989,00:00
3,0.72836,0.710603,0.505718,00:00
4,0.720459,0.704365,0.522872,00:00
5,0.717991,0.703277,0.551461,00:00
6,0.714451,0.698324,0.656925,00:00
7,0.710478,0.696004,0.644854,00:00
8,0.710518,0.695121,0.576874,00:00
9,0.709792,0.697636,0.666455,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.688593,0.686163,0.55972,00:00
1,0.688391,0.683663,0.550191,00:00
2,0.687426,0.681209,0.552732,00:00
3,0.685248,0.680752,0.541296,00:00
4,0.683953,0.681762,0.554638,00:00
5,0.682106,0.676872,0.55845,00:00
6,0.681758,0.677327,0.554003,00:00
7,0.68086,0.678442,0.537484,00:00
8,0.680725,0.682541,0.534307,00:00


No improvement since epoch 5: early stopping


[I 2025-05-15 21:38:48,935] Trial 8 finished with value: 0.4769357621669769 and parameters: {'lr': 2.3857341351332107e-07, 'optimizer': 'Adam'}. Best is trial 5 with value: 0.5520300924777984.


epoch,train_loss,valid_loss,accuracy,time
0,0.691601,0.684868,0.524444,00:00
1,0.689064,0.683879,0.525714,00:00
2,0.687788,0.682707,0.525079,00:00
3,0.686508,0.681211,0.526349,00:00
4,0.684886,0.68112,0.538413,00:00
5,0.683841,0.679966,0.533968,00:00
6,0.682317,0.679913,0.532698,00:00
7,0.681715,0.679778,0.530159,00:00
8,0.681855,0.679312,0.535238,00:00
9,0.682248,0.680229,0.532063,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.704488,0.680153,0.559085,00:00
1,0.702702,0.680202,0.556544,00:00
2,0.703217,0.677674,0.556544,00:00
3,0.69797,0.678285,0.551461,00:00
4,0.692901,0.676877,0.55972,00:00
5,0.689919,0.675756,0.555273,00:00
6,0.691127,0.675616,0.55845,00:00
7,0.687088,0.674485,0.557179,00:00
8,0.685508,0.674048,0.552732,00:00
9,0.687209,0.674682,0.553367,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.70001,0.701357,0.473316,00:00
1,0.700657,0.694696,0.522872,00:00
2,0.693983,0.692811,0.526048,00:00
3,0.690759,0.690348,0.536213,00:00
4,0.690253,0.688798,0.540025,00:00
5,0.687746,0.68719,0.540661,00:00
6,0.687439,0.686207,0.527319,00:00
7,0.687112,0.686344,0.535578,00:00
8,0.687027,0.685765,0.527954,00:00
9,0.685865,0.685482,0.527954,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.706097,0.709921,0.338628,00:00
1,0.70608,0.709466,0.340534,00:00
2,0.701246,0.703081,0.334816,00:00
3,0.698494,0.70436,0.335451,00:00
4,0.695787,0.70151,0.335451,00:00
5,0.694972,0.698811,0.335451,00:00
6,0.693142,0.697862,0.336086,00:00
7,0.69308,0.695382,0.348793,00:00
8,0.693771,0.695717,0.368488,00:00
9,0.694192,0.696402,0.399619,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.755623,0.72925,0.614358,00:00
1,0.749517,0.718245,0.662008,00:00
2,0.743112,0.718151,0.666455,00:00
3,0.738359,0.719844,0.666455,00:00
4,0.729431,0.715628,0.667726,00:00
5,0.72704,0.698608,0.666455,00:00
6,0.720587,0.701725,0.66709,00:00
7,0.718487,0.695171,0.667726,00:00
8,0.715001,0.702294,0.667726,00:00
9,0.716537,0.699737,0.667726,00:00


[I 2025-05-15 21:39:33,611] Trial 9 finished with value: 0.5361458659172058 and parameters: {'lr': 2.9718062311176007e-07, 'optimizer': 'Adam'}. Best is trial 5 with value: 0.5520300924777984.


epoch,train_loss,valid_loss,accuracy,time
0,0.706847,0.702851,0.64127,00:00
1,0.703303,0.701739,0.645714,00:00
2,0.70507,0.700189,0.657778,00:00
3,0.705159,0.701885,0.651429,00:00
4,0.706063,0.702781,0.64127,00:00
5,0.705156,0.703611,0.641905,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.698069,0.709262,0.371029,00:00
1,0.699291,0.707199,0.370394,00:00
2,0.698804,0.705999,0.367853,00:00
3,0.698668,0.706762,0.371029,00:00
4,0.699086,0.70879,0.376747,00:00
5,0.700022,0.705649,0.374206,00:00
6,0.699453,0.708843,0.371029,00:00
7,0.698263,0.7083,0.367217,00:00
8,0.699437,0.709522,0.372935,00:00


No improvement since epoch 5: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.703214,0.704665,0.431385,00:00
1,0.700736,0.708246,0.415502,00:00
2,0.700149,0.708206,0.418043,00:00
3,0.700025,0.709498,0.415502,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.705472,0.692033,0.664549,00:00
1,0.705487,0.691772,0.673443,00:00
2,0.70382,0.692095,0.674079,00:00
3,0.703917,0.69158,0.673443,00:00
4,0.706998,0.690923,0.675985,00:00
5,0.704156,0.69115,0.675985,00:00
6,0.704836,0.691818,0.675985,00:00
7,0.703728,0.691854,0.670267,00:00


No improvement since epoch 4: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.691897,0.688637,0.655019,00:00
1,0.693091,0.687934,0.655019,00:00
2,0.692857,0.687773,0.658831,00:00
3,0.692554,0.689083,0.651207,00:00
4,0.693193,0.689358,0.655019,00:00
5,0.69279,0.688013,0.65756,00:00


No improvement since epoch 2: early stopping


[I 2025-05-15 21:39:57,031] Trial 10 finished with value: 0.5516338050365448 and parameters: {'lr': 1.1945992484694896e-07, 'optimizer': 'SGD'}. Best is trial 5 with value: 0.5520300924777984.


epoch,train_loss,valid_loss,accuracy,time
0,0.690217,0.695151,0.452064,00:00
1,0.690108,0.691352,0.528889,00:00
2,0.689514,0.691515,0.540317,00:00
3,0.690791,0.694044,0.468571,00:00
4,0.690445,0.692745,0.507937,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.693884,0.691067,0.614994,00:00
1,0.695089,0.692023,0.623888,00:00
2,0.695068,0.691448,0.620076,00:00
3,0.695689,0.690943,0.6277,00:00
4,0.696117,0.692129,0.6169,00:00
5,0.694249,0.691436,0.639136,00:00
6,0.695099,0.6921,0.637865,00:00


No improvement since epoch 3: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.733972,0.70594,0.66709,00:00
1,0.734993,0.711075,0.66709,00:00
2,0.731632,0.716215,0.66709,00:00
3,0.7342,0.713454,0.66709,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.695206,0.692258,0.663278,00:00
1,0.694423,0.69216,0.662643,00:00
2,0.695504,0.69119,0.665184,00:00
3,0.696319,0.691671,0.662008,00:00
4,0.696419,0.691627,0.662643,00:00
5,0.693896,0.692603,0.663278,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.701456,0.703079,0.357052,00:00
1,0.698067,0.7026,0.358323,00:00
2,0.701349,0.702983,0.357687,00:00
3,0.700765,0.700005,0.351334,00:00
4,0.700832,0.704508,0.359593,00:00
5,0.699926,0.705388,0.360229,00:00
6,0.700618,0.70118,0.355146,00:00


No improvement since epoch 3: early stopping


[I 2025-05-15 21:40:17,688] Trial 11 finished with value: 0.5662632942199707 and parameters: {'lr': 1.1568422321572289e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.697531,0.695207,0.382857,00:00
1,0.694747,0.6951,0.379048,00:00
2,0.695943,0.69717,0.354921,00:00
3,0.696201,0.693495,0.398095,00:00
4,0.69348,0.696524,0.377143,00:00
5,0.697076,0.693283,0.419683,00:00
6,0.695599,0.696625,0.365714,00:00
7,0.69619,0.696692,0.358095,00:00
8,0.696382,0.696441,0.363175,00:00


No improvement since epoch 5: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.758781,0.755745,0.33291,00:00
1,0.763649,0.743665,0.33291,00:00
2,0.759501,0.742455,0.33291,00:00
3,0.76316,0.742301,0.33291,00:00
4,0.763236,0.753314,0.33291,00:00
5,0.75892,0.754575,0.33291,00:00
6,0.764479,0.740746,0.33291,00:00
7,0.761797,0.743795,0.33291,00:00
8,0.761809,0.748139,0.33291,00:00
9,0.761926,0.739746,0.33291,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.704972,0.691954,0.675349,00:00
1,0.70433,0.691391,0.674079,00:00
2,0.703732,0.691772,0.673443,00:00
3,0.703619,0.691871,0.675349,00:00
4,0.703274,0.691586,0.674079,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.701107,0.707485,0.494282,00:00
1,0.702336,0.709859,0.494917,00:00
2,0.702269,0.710463,0.493647,00:00
3,0.704377,0.709457,0.493647,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.727687,0.731137,0.333545,00:00
1,0.727827,0.736551,0.333545,00:00
2,0.734353,0.722049,0.332274,00:00
3,0.733304,0.724045,0.332274,00:00
4,0.729519,0.724794,0.33291,00:00
5,0.733547,0.73886,0.333545,00:00


No improvement since epoch 2: early stopping


[I 2025-05-15 21:40:41,838] Trial 12 finished with value: 0.4394710123538971 and parameters: {'lr': 1.7199645256293796e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.71289,0.708757,0.373333,00:00
1,0.713011,0.710607,0.445079,00:00
2,0.713021,0.703105,0.464127,00:00
3,0.709963,0.71064,0.420952,00:00
4,0.71222,0.70622,0.431746,00:00
5,0.712884,0.710562,0.36381,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.714459,0.699729,0.53939,00:00
1,0.709388,0.696236,0.532402,00:00
2,0.714761,0.692139,0.527319,00:00
3,0.710998,0.695185,0.53939,00:00
4,0.709719,0.695333,0.522872,00:00
5,0.714723,0.695034,0.530496,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.715603,0.712839,0.33291,00:00
1,0.714209,0.722204,0.33291,00:00
2,0.71806,0.726753,0.33291,00:00
3,0.715341,0.716491,0.33291,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.710087,0.717592,0.35324,00:00
1,0.707689,0.718431,0.34371,00:00
2,0.709018,0.719008,0.351334,00:00
3,0.70873,0.717815,0.351334,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.720515,0.726771,0.333545,00:00
1,0.720206,0.726414,0.333545,00:00
2,0.717993,0.729981,0.333545,00:00
3,0.719624,0.729031,0.333545,00:00
4,0.723575,0.736462,0.333545,00:00


No improvement since epoch 1: early stopping


[I 2025-05-15 21:40:59,787] Trial 13 finished with value: 0.38241882920265197 and parameters: {'lr': 1.0080218600394519e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.799319,0.788905,0.666667,00:00
1,0.799766,0.780985,0.666667,00:00
2,0.799696,0.787205,0.666667,00:00
3,0.802893,0.788707,0.666667,00:00
4,0.798267,0.785027,0.666667,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.71815,0.714983,0.653748,00:00
1,0.717141,0.706889,0.637865,00:00
2,0.721347,0.719834,0.642948,00:00
3,0.721,0.707083,0.641042,00:00
4,0.719653,0.704801,0.63723,00:00
5,0.718846,0.713399,0.647395,00:00
6,0.721785,0.712206,0.645489,00:00
7,0.718315,0.710886,0.649936,00:00


No improvement since epoch 4: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.710356,0.696072,0.658196,00:00
1,0.70721,0.698144,0.635324,00:00
2,0.708537,0.696635,0.660737,00:00
3,0.708827,0.693924,0.663278,00:00
4,0.709056,0.698146,0.649936,00:00
5,0.7093,0.696875,0.651842,00:00
6,0.708169,0.69379,0.664549,00:00
7,0.707454,0.699062,0.635324,00:00
8,0.707012,0.696881,0.658831,00:00
9,0.709887,0.695836,0.662643,00:00


No improvement since epoch 6: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.718916,0.727316,0.333545,00:00
1,0.717294,0.723593,0.333545,00:00
2,0.719995,0.732742,0.333545,00:00
3,0.71936,0.72,0.333545,00:00
4,0.71805,0.728488,0.333545,00:00
5,0.719681,0.731409,0.333545,00:00
6,0.718882,0.726282,0.333545,00:00


No improvement since epoch 3: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.703906,0.703334,0.456798,00:00
1,0.704654,0.704794,0.461245,00:00
2,0.70485,0.705804,0.462516,00:00
3,0.703352,0.704453,0.46061,00:00


No improvement since epoch 0: early stopping


[I 2025-05-15 21:41:23,932] Trial 14 finished with value: 0.5546802341938019 and parameters: {'lr': 1.6516872304522795e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.690158,0.688834,0.445079,00:00
1,0.690131,0.688634,0.457778,00:00
2,0.689264,0.68825,0.469841,00:00
3,0.689599,0.688442,0.460952,00:00
4,0.689992,0.68822,0.469841,00:00
5,0.690468,0.688717,0.433651,00:00
6,0.691416,0.688587,0.450159,00:00
7,0.690817,0.689151,0.427936,00:00


No improvement since epoch 4: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.701625,0.705053,0.338628,00:00
1,0.700992,0.705591,0.337992,00:00
2,0.702393,0.702126,0.330368,00:00
3,0.702217,0.702695,0.33291,00:00
4,0.701409,0.703667,0.341804,00:00
5,0.703095,0.706874,0.339898,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.703309,0.712129,0.333545,00:00
1,0.704363,0.711724,0.335451,00:00
2,0.705007,0.712534,0.335451,00:00
3,0.702994,0.710882,0.334816,00:00
4,0.704519,0.710092,0.334816,00:00
5,0.704731,0.712008,0.336086,00:00
6,0.703641,0.711854,0.335451,00:00
7,0.703216,0.711411,0.334816,00:00


No improvement since epoch 4: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.771245,0.75434,0.495553,00:00
1,0.765218,0.774556,0.491741,00:00
2,0.772626,0.775274,0.488564,00:00
3,0.768065,0.787466,0.487294,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.697115,0.69104,0.585769,00:00
1,0.697703,0.692669,0.58831,00:00
2,0.700158,0.69113,0.587039,00:00
3,0.7009,0.693302,0.606734,00:00


No improvement since epoch 0: early stopping


[I 2025-05-15 21:41:45,274] Trial 15 finished with value: 0.43933571577072145 and parameters: {'lr': 1.738488144474391e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.713386,0.709769,0.666667,00:00
1,0.712297,0.712399,0.666667,00:00
2,0.713051,0.711017,0.666667,00:00
3,0.711936,0.712081,0.666667,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.712851,0.704878,0.33291,00:00
1,0.710281,0.703565,0.33291,00:00
2,0.708318,0.70444,0.33291,00:00
3,0.709841,0.702374,0.33291,00:00
4,0.707632,0.704843,0.33291,00:00
5,0.710478,0.704021,0.33291,00:00
6,0.707323,0.707513,0.33291,00:00


No improvement since epoch 3: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.70804,0.70007,0.478399,00:00
1,0.706978,0.695581,0.608005,00:00
2,0.706336,0.69926,0.540025,00:00
3,0.707988,0.698566,0.599746,00:00
4,0.705557,0.699724,0.506989,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.700352,0.697385,0.658196,00:00
1,0.702416,0.69582,0.65756,00:00
2,0.700579,0.696194,0.645489,00:00
3,0.702479,0.696345,0.656925,00:00
4,0.702325,0.694551,0.655019,00:00
5,0.701976,0.696175,0.655654,00:00
6,0.701608,0.69325,0.648031,00:00
7,0.704349,0.69616,0.658196,00:00
8,0.701583,0.695716,0.651842,00:00
9,0.701773,0.695028,0.653113,00:00


No improvement since epoch 6: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.733715,0.724292,0.345616,00:00
1,0.734926,0.734277,0.339898,00:00
2,0.734842,0.740004,0.338628,00:00
3,0.733603,0.724424,0.349428,00:00


No improvement since epoch 0: early stopping


[I 2025-05-15 21:42:06,589] Trial 16 finished with value: 0.5018212676048279 and parameters: {'lr': 2.2587437048358376e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.786822,0.7558,0.666667,00:00
1,0.786284,0.771762,0.666667,00:00
2,0.784918,0.760178,0.666667,00:00
3,0.790958,0.761734,0.666667,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.733009,0.720268,0.66709,00:00
1,0.734623,0.732829,0.658196,00:00
2,0.73208,0.735191,0.656925,00:00
3,0.731802,0.728546,0.651207,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.71053,0.716964,0.414231,00:00
1,0.709484,0.723118,0.421855,00:00
2,0.705883,0.718544,0.416137,00:00
3,0.704852,0.714072,0.403431,00:00
4,0.70873,0.726514,0.435832,00:00
5,0.709235,0.717822,0.386912,00:00
6,0.70795,0.73191,0.401525,00:00


No improvement since epoch 3: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.701835,0.70146,0.335451,00:00
1,0.702446,0.706622,0.339263,00:00
2,0.702212,0.708004,0.339898,00:00
3,0.701997,0.706615,0.338628,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.724345,0.733832,0.444091,00:00
1,0.728035,0.71704,0.442821,00:00
2,0.722573,0.730363,0.391995,00:00
3,0.719808,0.729171,0.340534,00:00
4,0.725904,0.728609,0.336086,00:00


No improvement since epoch 1: early stopping


[I 2025-05-15 21:42:23,829] Trial 17 finished with value: 0.4788225293159485 and parameters: {'lr': 1.426352365641869e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.700736,0.706514,0.334603,00:00
1,0.701706,0.706088,0.334603,00:00
2,0.700798,0.70935,0.333968,00:00
3,0.702008,0.707805,0.335873,00:00
4,0.700816,0.71156,0.335238,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.697965,0.702173,0.382465,00:00
1,0.697574,0.700681,0.360864,00:00
2,0.697256,0.700077,0.355781,00:00
3,0.696927,0.699773,0.356417,00:00
4,0.697394,0.699946,0.379288,00:00
5,0.698031,0.702451,0.372935,00:00
6,0.698301,0.701417,0.3723,00:00


No improvement since epoch 3: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.745233,0.712746,0.66582,00:00
1,0.746439,0.703208,0.651207,00:00
2,0.750711,0.721179,0.667726,00:00
3,0.743345,0.713354,0.670902,00:00
4,0.744942,0.712031,0.660102,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.699247,0.693009,0.665184,00:00
1,0.700185,0.693911,0.665184,00:00
2,0.702425,0.693141,0.66582,00:00
3,0.700492,0.693651,0.665184,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.697483,0.701284,0.415502,00:00
1,0.69853,0.699872,0.41169,00:00
2,0.69767,0.701941,0.414231,00:00
3,0.699695,0.70502,0.411055,00:00
4,0.698333,0.703159,0.414867,00:00


No improvement since epoch 1: early stopping


[I 2025-05-15 21:42:42,396] Trial 18 finished with value: 0.4895380973815918 and parameters: {'lr': 3.6305064143608645e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.706775,0.710724,0.48127,00:00
1,0.707914,0.708931,0.483175,00:00
2,0.707245,0.709291,0.467302,00:00
3,0.70735,0.70969,0.459048,00:00
4,0.70572,0.710139,0.486984,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.773713,0.739487,0.339263,00:00
1,0.771946,0.74066,0.336086,00:00
2,0.773963,0.73931,0.33418,00:00
3,0.770256,0.743037,0.333545,00:00
4,0.772612,0.744787,0.341169,00:00
5,0.773143,0.722307,0.33291,00:00
6,0.771922,0.722207,0.36277,00:00
7,0.777187,0.757633,0.348793,00:00
8,0.774846,0.747779,0.350699,00:00
9,0.762996,0.746685,0.335451,00:00


No improvement since epoch 6: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.711961,0.725377,0.33291,00:00
1,0.713427,0.71851,0.33291,00:00
2,0.712718,0.719997,0.33291,00:00
3,0.713631,0.720288,0.33291,00:00
4,0.713338,0.715866,0.33291,00:00
5,0.714227,0.715637,0.33291,00:00
6,0.717365,0.717067,0.33291,00:00
7,0.715857,0.712884,0.33291,00:00
8,0.715048,0.723113,0.33291,00:00
9,0.716539,0.722594,0.33291,00:00


epoch,train_loss,valid_loss,accuracy,time
0,0.727458,0.725528,0.493011,00:00
1,0.725405,0.734027,0.489835,00:00
2,0.723505,0.735244,0.489835,00:00
3,0.722608,0.748043,0.489835,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.706922,0.714197,0.333545,00:00
1,0.707049,0.709448,0.333545,00:00
2,0.706753,0.713762,0.333545,00:00
3,0.709619,0.709645,0.333545,00:00
4,0.706334,0.708154,0.333545,00:00
5,0.70651,0.710478,0.333545,00:00
6,0.709933,0.70835,0.333545,00:00
7,0.707043,0.71258,0.333545,00:00


No improvement since epoch 4: early stopping


[I 2025-05-15 21:43:08,586] Trial 19 finished with value: 0.39574498534202573 and parameters: {'lr': 2.1723858351198724e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.694891,0.690884,0.490159,00:00
1,0.694767,0.688226,0.540952,00:00
2,0.69398,0.689821,0.504127,00:00
3,0.693875,0.693084,0.457778,00:00
4,0.694466,0.692789,0.458413,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.721354,0.717625,0.33291,00:00
1,0.718731,0.715574,0.33291,00:00
2,0.722192,0.720428,0.33291,00:00
3,0.719577,0.718494,0.332274,00:00
4,0.723904,0.716034,0.33291,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.69441,0.690533,0.585133,00:00
1,0.693959,0.689905,0.581321,00:00
2,0.693917,0.690505,0.601652,00:00
3,0.694115,0.691005,0.590851,00:00
4,0.693696,0.690154,0.590851,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.698899,0.697244,0.486658,00:00
1,0.698215,0.694007,0.473316,00:00
2,0.698746,0.696431,0.476493,00:00
3,0.699551,0.694121,0.465692,00:00
4,0.69862,0.695335,0.474587,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.715178,0.713756,0.33291,00:00
1,0.715157,0.716235,0.332274,00:00
2,0.714214,0.720442,0.33418,00:00
3,0.712812,0.716565,0.332274,00:00


No improvement since epoch 0: early stopping


[I 2025-05-15 21:43:25,763] Trial 20 finished with value: 0.4378070652484894 and parameters: {'lr': 1.3980169728719182e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.811022,0.77675,0.333333,00:00
1,0.815054,0.807416,0.333333,00:00
2,0.81204,0.792651,0.333333,00:00
3,0.818381,0.819571,0.333333,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.698111,0.684474,0.547649,00:00
1,0.698902,0.684226,0.546379,00:00
2,0.697904,0.687321,0.532402,00:00
3,0.700069,0.685746,0.546379,00:00
4,0.70063,0.684977,0.545743,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.694262,0.696333,0.630241,00:00
1,0.693277,0.691244,0.63723,00:00
2,0.69257,0.693041,0.650572,00:00
3,0.694295,0.692392,0.639771,00:00
4,0.693319,0.690675,0.654384,00:00
5,0.693939,0.691463,0.658831,00:00
6,0.693718,0.693542,0.648666,00:00
7,0.694158,0.694027,0.6277,00:00


No improvement since epoch 4: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.690527,0.692755,0.533037,00:00
1,0.690007,0.692266,0.518424,00:00
2,0.69177,0.692636,0.51906,00:00
3,0.693572,0.69208,0.524778,00:00
4,0.691764,0.692757,0.526048,00:00
5,0.6923,0.692426,0.523507,00:00
6,0.693273,0.69253,0.530496,00:00


No improvement since epoch 3: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.809701,0.76978,0.666455,00:00
1,0.80545,0.75176,0.666455,00:00
2,0.800452,0.757743,0.666455,00:00
3,0.800212,0.76495,0.666455,00:00
4,0.805255,0.768464,0.666455,00:00


No improvement since epoch 1: early stopping


[I 2025-05-15 21:43:46,444] Trial 21 finished with value: 0.5407454550266266 and parameters: {'lr': 1.2744728015430476e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.693289,0.69105,0.450794,00:00
1,0.69426,0.692163,0.446349,00:00
2,0.692442,0.692101,0.460317,00:00
3,0.693668,0.691275,0.501587,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.755468,0.751115,0.33291,00:00
1,0.760851,0.751096,0.33291,00:00
2,0.758884,0.752377,0.33291,00:00
3,0.764308,0.75834,0.33291,00:00
4,0.769898,0.751816,0.33291,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.700784,0.696718,0.649301,00:00
1,0.700354,0.697893,0.632783,00:00
2,0.700719,0.692735,0.651842,00:00
3,0.700433,0.696849,0.640407,00:00
4,0.699574,0.697997,0.636595,00:00
5,0.699908,0.697727,0.637865,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.712046,0.719469,0.338628,00:00
1,0.712231,0.717982,0.355146,00:00
2,0.710217,0.719331,0.337992,00:00
3,0.71084,0.717429,0.355781,00:00
4,0.710376,0.720218,0.34244,00:00
5,0.710985,0.716833,0.341804,00:00
6,0.712958,0.721626,0.337357,00:00
7,0.712188,0.718392,0.349428,00:00
8,0.711801,0.718058,0.34244,00:00


No improvement since epoch 5: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.741741,0.738871,0.456798,00:00
1,0.743216,0.742278,0.457433,00:00
2,0.747991,0.737365,0.456798,00:00
3,0.748339,0.750733,0.462516,00:00
4,0.745738,0.738356,0.46061,00:00
5,0.747977,0.740991,0.459975,00:00


No improvement since epoch 2: early stopping


[I 2025-05-15 21:44:07,772] Trial 22 finished with value: 0.45495533347129824 and parameters: {'lr': 1.0037319820095982e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.747235,0.74486,0.48254,00:00
1,0.745104,0.745241,0.48254,00:00
2,0.745381,0.743017,0.48127,00:00
3,0.743804,0.753489,0.479365,00:00
4,0.74862,0.747581,0.481905,00:00
5,0.747571,0.748287,0.48254,00:00


No improvement since epoch 2: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.694019,0.682171,0.557814,00:00
1,0.695177,0.682117,0.559085,00:00
2,0.694749,0.682434,0.554638,00:00
3,0.69392,0.682064,0.557814,00:00
4,0.692775,0.682178,0.554638,00:00
5,0.694212,0.681943,0.549555,00:00
6,0.695386,0.681844,0.557814,00:00
7,0.695187,0.683258,0.548285,00:00
8,0.693188,0.682233,0.555273,00:00
9,0.694955,0.681986,0.55845,00:00


No improvement since epoch 6: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.725404,0.712684,0.512706,00:00
1,0.726784,0.720981,0.467598,00:00
2,0.726484,0.71517,0.47967,00:00
3,0.726609,0.711046,0.475222,00:00
4,0.727747,0.719397,0.469504,00:00
5,0.725178,0.710043,0.503812,00:00
6,0.725182,0.706562,0.549555,00:00
7,0.726413,0.714703,0.468869,00:00
8,0.726994,0.710149,0.478399,00:00
9,0.725847,0.706563,0.475858,00:00


No improvement since epoch 6: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.701501,0.70037,0.479034,00:00
1,0.700844,0.699869,0.483482,00:00
2,0.700954,0.699924,0.48094,00:00
3,0.701213,0.702849,0.44155,00:00
4,0.70041,0.70247,0.466328,00:00


No improvement since epoch 1: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.727355,0.714594,0.666455,00:00
1,0.725297,0.711825,0.666455,00:00
2,0.724853,0.714823,0.666455,00:00
3,0.7229,0.705946,0.666455,00:00
4,0.725464,0.714656,0.666455,00:00
5,0.726681,0.703392,0.66709,00:00
6,0.726056,0.711269,0.66709,00:00
7,0.725957,0.712271,0.66709,00:00
8,0.727896,0.714061,0.666455,00:00


No improvement since epoch 5: early stopping


[I 2025-05-15 21:44:36,058] Trial 23 finished with value: 0.5299259781837463 and parameters: {'lr': 1.315142419769887e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


epoch,train_loss,valid_loss,accuracy,time
0,0.716455,0.730433,0.368254,00:00
1,0.717653,0.731894,0.366984,00:00
2,0.719254,0.730185,0.36127,00:00
3,0.717353,0.733738,0.360635,00:00
4,0.716531,0.730508,0.372698,00:00
5,0.715683,0.728698,0.373968,00:00
6,0.715212,0.739579,0.36,00:00
7,0.719443,0.730356,0.369524,00:00
8,0.717363,0.731462,0.371429,00:00


No improvement since epoch 5: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.729757,0.731941,0.66709,00:00
1,0.732411,0.732517,0.666455,00:00
2,0.735328,0.74003,0.666455,00:00
3,0.732044,0.732828,0.666455,00:00


No improvement since epoch 0: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.712044,0.715394,0.33291,00:00
1,0.710794,0.721535,0.33291,00:00
2,0.713092,0.717749,0.33291,00:00
3,0.710901,0.714312,0.33291,00:00
4,0.710348,0.715053,0.33291,00:00
5,0.712153,0.720161,0.33291,00:00
6,0.70962,0.716671,0.33291,00:00


No improvement since epoch 3: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.700904,0.713958,0.35324,00:00
1,0.699179,0.709688,0.362135,00:00
2,0.699813,0.708598,0.360229,00:00
3,0.700738,0.708012,0.373571,00:00
4,0.701825,0.707092,0.346252,00:00
5,0.705487,0.717731,0.340534,00:00
6,0.704806,0.711184,0.343075,00:00
7,0.703734,0.709127,0.365311,00:00


No improvement since epoch 4: early stopping


epoch,train_loss,valid_loss,accuracy,time
0,0.69585,0.689627,0.648031,00:00
1,0.696402,0.690497,0.648666,00:00
2,0.696829,0.689287,0.627065,00:00
3,0.697022,0.690021,0.625159,00:00
4,0.69601,0.691039,0.638501,00:00
5,0.696133,0.690585,0.647395,00:00


No improvement since epoch 2: early stopping


[I 2025-05-15 21:45:00,353] Trial 24 finished with value: 0.4766999542713165 and parameters: {'lr': 1.9441393182774296e-07, 'optimizer': 'SGD'}. Best is trial 11 with value: 0.5662632942199707.


Best trial:
FrozenTrial(number=11, state=1, values=[0.5662632942199707], datetime_start=datetime.datetime(2025, 5, 15, 21, 39, 57, 31771), datetime_complete=datetime.datetime(2025, 5, 15, 21, 40, 17, 688659), params={'lr': 1.1568422321572289e-07, 'optimizer': 'SGD'}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'lr': FloatDistribution(high=1e-06, log=True, low=1e-07, step=None), 'optimizer': CategoricalDistribution(choices=('SGD', 'Adam'))}, trial_id=11, value=None)


['FCMAE_ResNet_tune_80.pkl']

In [14]:
print("Best trial:")
print(study.best_trial)


Best trial:
FrozenTrial(number=11, state=1, values=[0.5662632942199707], datetime_start=datetime.datetime(2025, 5, 15, 21, 39, 57, 31771), datetime_complete=datetime.datetime(2025, 5, 15, 21, 40, 17, 688659), params={'lr': 1.1568422321572289e-07, 'optimizer': 'SGD'}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'lr': FloatDistribution(high=1e-06, log=True, low=1e-07, step=None), 'optimizer': CategoricalDistribution(choices=('SGD', 'Adam'))}, trial_id=11, value=None)


In [15]:
# '''MY DATA'''
# train_dir = os.path.join(univariate_data_path, "target_univariate_no_PCA_train2.npy")
# train_data = np.load(train_dir, allow_pickle=True)
# # train_data = pd.DataFrame(train_data)
# test_dir = os.path.join(univariate_data_path, "target_univariate_no_PCA_test2.npy")
# test_data = np.load(test_dir, allow_pickle=True)
# # test_data = pd.DataFrame(test_data)
# print(len(train_data))
# print(len(test_data))
# print(train_data[0])
# from torch.utils.data import DataLoader, TensorDataset, random_split
# from sklearn.utils.class_weight import compute_class_weight

# from collections import Counter

# # Example dataset with windows and labels
# train_windows, train_labels = create_time_windows_with_labels(train_data)
# test_windows, test_labels = create_time_windows_with_labels(test_data)

# # Count label distribution
# train_label_counts = Counter(train_labels)
# test_label_counts = Counter(test_labels)

# # Print the counts
# print("Train label distribution:")
# print(f"  Term (0): {train_label_counts[0]}")
# print(f"  Preterm (1): {train_label_counts[1]}")

# print("\nTest label distribution:")
# print(f"  Term (0): {test_label_counts[0]}")
# print(f"  Preterm (1): {test_label_counts[1]}")

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# # Convert to tensors
# train_windows_tensor = torch.tensor(train_windows, dtype=torch.float32)
# train_labels_tensor = torch.tensor(train_labels, dtype=torch.float32)
# test_windows_tensor = torch.tensor(test_windows, dtype=torch.float32)
# test_labels_tensor = torch.tensor(test_labels, dtype=torch.float32)

# train_windows_tensor = train_windows_tensor.unsqueeze(1)  # (1071, 1, 12000)
# test_windows_tensor = test_windows_tensor.unsqueeze(1)    # (899, 1, 12000)

# train_labels_tensor = train_labels_tensor.long()
# test_labels_tensor = test_labels_tensor.long()


# print(train_labels_tensor.shape)

# train_dataset = TensorDataset(train_windows_tensor, train_labels_tensor)
# test_dataset = TensorDataset(test_windows_tensor, test_labels_tensor)

# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)



In [16]:
# '''CLASS WEIGHTS'''
# from tsai.models.ResNet import ResNet
# from tsai.models import ResNet
# from tsai.models.ResNet import ResNet
# from fastai.metrics import accuracy
# from fastai.data.core import DataLoaders
# from fastai.learner import Learner
# import torch

# class FCMAEClassifier(nn.Module):
#     def __init__(self, encoder):
#         super().__init__()
#         self.encoder = encoder
#         self.classifier = ResNet(2048, 2)  # 2048 channels from encoder, 2 output classes

#     def forward(self, x):
#         with torch.no_grad():  # freeze encoder
#             x = self.encoder(x)
#         return self.classifier(x)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Wrap your DataLoaders and set device
# dls = DataLoaders.from_dsets(
#     train_dataset,
#     test_dataset,
#     bs=32,
#     shuffle=True,
#     num_workers=0 
# )

# model = FCMAEClassifier(encoder)
# model.to(device)

# # Calculate class weights
# class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
# class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

# # Use weighted loss
# loss_func = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)

# # Replace your loss function in the learner
# learn = Learner(dls, model, loss_func=loss_func, metrics=accuracy)
# learn.fit_one_cycle(20, 1e-6)

# learn.recorder.plot_loss()  # Plot loss curves (training and validation loss)
# import matplotlib.pyplot as plt
# import seaborn as sns
# from sklearn.metrics import confusion_matrix
# import torch

# # Get predictions and true labels
# preds, targs = learn.get_preds(dl=learn.dls.valid)

# # Convert predictions to class labels (argmax for multi-class classification)
# pred_labels = preds.argmax(dim=1)

# # Compute confusion matrix
# cm = confusion_matrix(targs, pred_labels)

# # Plot confusion matrix
# plt.figure(figsize=(8, 6))
# sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False, 
#             xticklabels=['Class 0', 'Class 1'], yticklabels=['Class 0', 'Class 1'])
# plt.title('Confusion Matrix')
# plt.xlabel('Predicted Label')
# plt.ylabel('True Label')
# plt.show()

# from sklearn.metrics import (
#     accuracy_score,
#     precision_score,
#     recall_score,
#     f1_score,
#     average_precision_score,
#     roc_auc_score
# )

# # Convert to numpy arrays if needed
# true_labels = targs.cpu().numpy()
# pred_labels = pred_labels.cpu().numpy()
# pred_probs = preds[:, 1].cpu().numpy()  # Probability of class 1

# # Compute metrics
# accuracy = accuracy_score(true_labels, pred_labels)
# precision = precision_score(true_labels, pred_labels, zero_division=0)
# recall = recall_score(true_labels, pred_labels, zero_division=0)
# f1 = f1_score(true_labels, pred_labels, zero_division=0)
# ap = average_precision_score(true_labels, pred_probs)
# auc = roc_auc_score(true_labels, pred_probs)

# # Print results
# print(f"Accuracy: {accuracy:.4f}")
# print(f"Precision: {precision:.4f}")
# print(f"Recall: {recall:.4f}")
# print(f"F1 Score: {f1:.4f}")
# print(f"Average Precision (AP): {ap:.4f}")
# print(f"Area Under ROC Curve (AUC): {auc:.4f}")

