In [26]:
import numpy as np
import os
from dotenv import load_dotenv
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import TensorDataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms

from tqdm.auto import tqdm

import wandb

import logging
import sys
from torchinfo import summary

In [27]:
def setup_logger(name=__name__):
    """
    Sets up a logger that outputs to the console (stdout).
    """
    logger = logging.getLogger(name)
    if not logger.handlers:
        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

logger = setup_logger()

In [28]:
# Load the dataset
data_folder = "../data"
preped_folder = os.path.join(data_folder, "_preped")

train_data = pd.read_csv(os.path.join(data_folder, 'train_data.csv')).values.tolist()
test_data = pd.read_csv(os.path.join(data_folder, 'test_data.csv')).values.tolist()

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to consistent size
    transforms.ToTensor(),           # Convert to tensor [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5])
])

x_train = []
y_train = []

for img_name, label in train_data:
    img_path = os.path.join(preped_folder, img_name)
    try:
        img = Image.open(img_path).convert('L') # Convert to grayscale
        img_tensor = transform(img)
        x_train.append(img_tensor)
        y_train.append(label)
    except Exception as e:
        logger.info(f"Error loading {img_name}: {e}")

# Stack into tensors
x_train_tensor = torch.stack(x_train)
logger.info(f"Training images shape: {x_train_tensor.shape}")

# Encode labels to integers
label_to_idx = {label: idx for idx, label in enumerate(np.unique(y_train))}
y_train_encoded = [label_to_idx[label] for label in y_train]
y_train_tensor = torch.tensor(y_train_encoded, dtype=torch.long)

logger.info(f"Training labels shape: {y_train_tensor.shape}")
logger.info(f"Label mapping: {label_to_idx}")

2025-12-12 15:23:37,752 - INFO - Training images shape: torch.Size([241, 1, 224, 224])
2025-12-12 15:23:37,753 - INFO - Training labels shape: torch.Size([241])
2025-12-12 15:23:37,754 - INFO - Label mapping: {np.str_('1_Pronacio'): 0, np.str_('2_Neutralis'): 1, np.str_('3_Szupinacio'): 2}


In [29]:
x_test = []
y_test = []

for img_name, label in test_data:
    img_path = os.path.join(preped_folder, img_name)
    try:
        img = Image.open(img_path).convert('L') # Convert to grayscale
        img_tensor = transform(img)
        x_test.append(img_tensor)
        y_test.append(label)
    except Exception as e:
        logger.info(f"Error loading {img_name}: {e}")

x_test_tensor = torch.stack(x_test)
logger.info(f"Test images shape: {x_test_tensor.shape}")
y_test_encoded = [label_to_idx[label] for label in y_test]
y_test_tensor = torch.tensor(y_test_encoded, dtype=torch.long)

logger.info(f"Test labels shape: {y_test_tensor.shape}")

2025-12-12 15:23:44,175 - INFO - Test images shape: torch.Size([49, 1, 224, 224])
2025-12-12 15:23:44,175 - INFO - Test labels shape: torch.Size([49])


In [30]:
if torch.cuda.is_available():
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        logger.info(f"\nGPU {i}: {torch.cuda.get_device_name(i)}")
        props = torch.cuda.get_device_properties(i)
        logger.info(f"  Memory: {props.total_memory / 1024**3:.2f} GB")
        logger.info(f"  Compute Capability: {props.major}.{props.minor}")
else:
    logger.info("CUDA not available")

2025-12-12 15:23:44,185 - INFO - CUDA available: True
2025-12-12 15:23:44,186 - INFO - Number of GPUs: 1
2025-12-12 15:23:44,187 - INFO - 
GPU 0: NVIDIA GeForce RTX 4060
2025-12-12 15:23:44,188 - INFO -   Memory: 8.00 GB
2025-12-12 15:23:44,189 - INFO -   Compute Capability: 8.9


In [31]:
batch_size = 16
num_epochs = 70
device = 'cuda' 

train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [32]:
# wandb login an init
# Login to wandb with API key
load_dotenv()
wandb.login(key=os.getenv("wandbKey"))

def init_wandb():
    # Initialize wandb project
    wandb.init(
        project="ankle-align-inc-model",
        config={
            "batch_size": batch_size,
            "num_epochs": num_epochs,
      
            "architecture": "Custom CNN",
            "dataset": "AnkleAlign",
            "optimizer": "Adam"
        }
    )

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Win 10\_netrc


In [33]:
net0 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),   # 224x224 -> 112x112
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),  # 112x112 -> 56x56
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 56x56 -> 28x28
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(32, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3)                       # Output layer     
).to(device)


def init_weights(m):
    if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

net0.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
#optimizer = torch.optim.SGD(net0.parameters(), lr=0.1, momentum=0.9)
optimizer = torch.optim.Adam(net0.parameters(), lr=0.01)

summary(net0, input_size=(batch_size, 1, 224, 224))


Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [16, 3]                   --
├─Conv2d: 1-1                            [16, 8, 112, 112]         80
├─ReLU: 1-2                              [16, 8, 112, 112]         --
├─Conv2d: 1-3                            [16, 16, 56, 56]          1,168
├─ReLU: 1-4                              [16, 16, 56, 56]          --
├─Conv2d: 1-5                            [16, 32, 28, 28]          4,640
├─ReLU: 1-6                              [16, 32, 28, 28]          --
├─AdaptiveAvgPool2d: 1-7                 [16, 32, 1, 1]            --
├─Flatten: 1-8                           [16, 32]                  --
├─Linear: 1-9                            [16, 128]                 4,224
├─ReLU: 1-10                             [16, 128]                 --
├─Linear: 1-11                           [16, 64]                  8,256
├─ReLU: 1-12                             [16, 64]                  --
├─L

In [34]:
# Trying to overfit one batch
init_wandb()
one_batch = next(iter(train_loader))
images, labels = one_batch

images = images.to(device)
labels = labels.to(device)

loss_values = []
net0.train()
for epoch in tqdm(range(num_epochs), desc='Training model'):
        pred_logits = net0(images)
        loss = loss_fn(pred_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        loss_values.append(loss.item())
        wandb.log({
                "epoch": epoch + 1,
                "train_loss": loss.item()
            })
        
wandb.finish()
print(loss_values)

Training model: 100%|██████████| 70/70 [00:00<00:00, 125.56it/s]
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
train_loss,██▇▇▇▆▆▆▆▆▅▅▄▅▄▄▄▃▃▄▂▂▃▂▂▂▁▁▁▁▁▁▂▅▂▂▄▅▁▂

0,1
epoch,70.0
train_loss,0.10161


[1.1201071739196777, 1.133108377456665, 1.0402389764785767, 0.9731132984161377, 0.9565080404281616, 0.9460327625274658, 0.903011679649353, 0.8771029710769653, 0.865113377571106, 0.8425642251968384, 0.8185439109802246, 0.7937295436859131, 0.7641617059707642, 0.7389482855796814, 0.7124192714691162, 0.6814556121826172, 0.6469080448150635, 0.6081164479255676, 0.6201109290122986, 0.6805386543273926, 0.5522283911705017, 0.5590860843658447, 0.5181375741958618, 0.4898614287376404, 0.47567176818847656, 0.3925952911376953, 0.41096431016921997, 0.3562488257884979, 0.364826500415802, 0.3263346552848816, 0.5982373952865601, 0.4799767732620239, 0.34488189220428467, 0.3888360261917114, 0.31425362825393677, 0.2902117669582367, 0.3504578471183777, 0.21387368440628052, 0.27393537759780884, 0.2102767378091812, 0.23218084871768951, 0.19255468249320984, 0.16975906491279602, 0.18739774823188782, 0.1539849489927292, 0.15938788652420044, 0.1312681883573532, 0.11715013533830643, 0.10791566967964172, 0.09306634

In [77]:
def train_model(network, optimizer, loss_fn, enable_early_stopping=False, patience=5):
    torch.cuda.empty_cache()

    loss_values = []

    if enable_early_stopping:
        early_stopping = EarlyStopping(patience=patience, verbose=True)

    network.train()
    for epoch in tqdm(range(num_epochs), desc='Training model'):
        network.train()
        epoch_loss = 0.0
        num_batches = 0
        for images, target_labels in train_loader:
            images = images.to(device)
            target_labels = target_labels.to(device)

            pred_logits = network(images)
            loss = loss_fn(pred_logits, target_labels)
            epoch_loss += loss.item()
            num_batches += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        avg_train_loss = epoch_loss / num_batches

        if enable_early_stopping:
            network.eval()
            val_loss = 0.0
            val_batches = 0
            correct = 0
            total = 0
            with torch.no_grad():
                for images, target_labels in val_loader:
                    images = images.to(device)
                    target_labels = target_labels.to(device)
                    
                    pred_logits = network(images)
                    loss = loss_fn(pred_logits, target_labels)
                    val_loss += loss.item()
                    val_batches += 1
                    
                    _, predicted = torch.max(pred_logits, 1)
                    total += target_labels.size(0)
                    correct += (predicted == target_labels).sum().item()
            
            avg_val_loss = val_loss / val_batches
            val_accuracy = correct / total

        # Log metrics
        if enable_early_stopping:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "val_accuracy": val_accuracy
            })
        else:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss
            })
        loss_values.append(avg_train_loss)
        
        if enable_early_stopping:
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}")
        else:
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}")

        # Early stopping check
        if enable_early_stopping:
            early_stopping(avg_val_loss, network)
            if early_stopping.early_stop:
                logger.info("Early stopping triggered")
                network.load_state_dict(early_stopping.best_model)
                break
    
    # Load best model
    if enable_early_stopping and early_stopping.best_model is not None:
        network.load_state_dict(early_stopping.best_model)
        logger.info("Loaded best model weights")

    logger.info(loss_values)

In [36]:
def evaluate_model(network):
    # Training score
    true_labels = y_test_encoded
    predicted_labels = []
    network.eval()
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            outputs = network(images)
            _, predicted = torch.max(outputs, 1)
            predicted_labels.extend(predicted.cpu().numpy())

    accuracy = np.mean([true == pred for true, pred in zip(true_labels, predicted_labels)])
    precision = precision_score(true_labels, predicted_labels, average='weighted')
    recall = recall_score(true_labels, predicted_labels, average='weighted')
    f1 = f1_score(true_labels, predicted_labels, average='weighted')

    logger.info(f"network accuracy: {accuracy * 100:.2f}%")
    logger.info(f"network precision: {precision * 100:.2f}%")
    logger.info(f"network recall: {recall * 100:.2f}%")
    logger.info(f"network F1 score: {f1 * 100:.2f}%")

    logger.info(f"Detailed Classification Report: \n{classification_report(true_labels, predicted_labels)}")

    # Log test metrics
    wandb.log({
        "test_accuracy": accuracy,
        "test_precision": precision,
        "test_recall": recall,
        "test_f1": f1
    })

    wandb.finish()

In [37]:
init_wandb()
train_model(net0, optimizer, loss_fn, enable_early_stopping=False)

Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-12 15:23:50,170 - INFO - Epoch 1/70, Train Loss: 2.1479


Training model:   1%|▏         | 1/70 [00:00<00:10,  6.64it/s]

2025-12-12 15:23:50,323 - INFO - Epoch 2/70, Train Loss: 1.0423


Training model:   3%|▎         | 2/70 [00:00<00:10,  6.59it/s]

2025-12-12 15:23:50,445 - INFO - Epoch 3/70, Train Loss: 0.9985


Training model:   4%|▍         | 3/70 [00:00<00:09,  7.23it/s]

2025-12-12 15:23:50,542 - INFO - Epoch 4/70, Train Loss: 1.0615
2025-12-12 15:23:50,631 - INFO - Epoch 5/70, Train Loss: 1.0286


Training model:   7%|▋         | 5/70 [00:00<00:07,  8.88it/s]

2025-12-12 15:23:50,721 - INFO - Epoch 6/70, Train Loss: 1.0168
2025-12-12 15:23:50,805 - INFO - Epoch 7/70, Train Loss: 0.9731


Training model:  10%|█         | 7/70 [00:00<00:06,  9.88it/s]

2025-12-12 15:23:50,884 - INFO - Epoch 8/70, Train Loss: 0.9736
2025-12-12 15:23:50,965 - INFO - Epoch 9/70, Train Loss: 1.0406


Training model:  13%|█▎        | 9/70 [00:00<00:05, 10.76it/s]

2025-12-12 15:23:51,049 - INFO - Epoch 10/70, Train Loss: 1.0123
2025-12-12 15:23:51,134 - INFO - Epoch 11/70, Train Loss: 0.9852


Training model:  16%|█▌        | 11/70 [00:01<00:05, 11.13it/s]

2025-12-12 15:23:51,237 - INFO - Epoch 12/70, Train Loss: 0.9852
2025-12-12 15:23:51,346 - INFO - Epoch 13/70, Train Loss: 1.0045


Training model:  19%|█▊        | 13/70 [00:01<00:05, 10.48it/s]

2025-12-12 15:23:51,431 - INFO - Epoch 14/70, Train Loss: 0.9885
2025-12-12 15:23:51,516 - INFO - Epoch 15/70, Train Loss: 0.9790


Training model:  21%|██▏       | 15/70 [00:01<00:05, 10.87it/s]

2025-12-12 15:23:51,599 - INFO - Epoch 16/70, Train Loss: 0.9676
2025-12-12 15:23:51,676 - INFO - Epoch 17/70, Train Loss: 0.9933


Training model:  24%|██▍       | 17/70 [00:01<00:04, 11.35it/s]

2025-12-12 15:23:51,761 - INFO - Epoch 18/70, Train Loss: 1.0163
2025-12-12 15:23:51,848 - INFO - Epoch 19/70, Train Loss: 0.9779


Training model:  27%|██▋       | 19/70 [00:01<00:04, 11.42it/s]

2025-12-12 15:23:51,935 - INFO - Epoch 20/70, Train Loss: 1.0521
2025-12-12 15:23:52,015 - INFO - Epoch 21/70, Train Loss: 0.9724


Training model:  30%|███       | 21/70 [00:01<00:04, 11.60it/s]

2025-12-12 15:23:52,093 - INFO - Epoch 22/70, Train Loss: 1.0017
2025-12-12 15:23:52,171 - INFO - Epoch 23/70, Train Loss: 0.9597


Training model:  33%|███▎      | 23/70 [00:02<00:03, 11.95it/s]

2025-12-12 15:23:52,248 - INFO - Epoch 24/70, Train Loss: 0.9888
2025-12-12 15:23:52,325 - INFO - Epoch 25/70, Train Loss: 0.9704


Training model:  36%|███▌      | 25/70 [00:02<00:03, 12.23it/s]

2025-12-12 15:23:52,405 - INFO - Epoch 26/70, Train Loss: 0.9868
2025-12-12 15:23:52,483 - INFO - Epoch 27/70, Train Loss: 0.9608


Training model:  39%|███▊      | 27/70 [00:02<00:03, 12.36it/s]

2025-12-12 15:23:52,567 - INFO - Epoch 28/70, Train Loss: 0.9653
2025-12-12 15:23:52,643 - INFO - Epoch 29/70, Train Loss: 0.9623


Training model:  41%|████▏     | 29/70 [00:02<00:03, 12.40it/s]

2025-12-12 15:23:52,736 - INFO - Epoch 30/70, Train Loss: 0.9592
2025-12-12 15:23:52,819 - INFO - Epoch 31/70, Train Loss: 1.0202


Training model:  44%|████▍     | 31/70 [00:02<00:03, 12.07it/s]

2025-12-12 15:23:52,897 - INFO - Epoch 32/70, Train Loss: 0.9623
2025-12-12 15:23:52,987 - INFO - Epoch 33/70, Train Loss: 0.9675


Training model:  47%|████▋     | 33/70 [00:02<00:03, 12.03it/s]

2025-12-12 15:23:53,071 - INFO - Epoch 34/70, Train Loss: 0.9955
2025-12-12 15:23:53,152 - INFO - Epoch 35/70, Train Loss: 1.0891


Training model:  50%|█████     | 35/70 [00:03<00:02, 12.05it/s]

2025-12-12 15:23:53,233 - INFO - Epoch 36/70, Train Loss: 0.9954
2025-12-12 15:23:53,313 - INFO - Epoch 37/70, Train Loss: 0.9975


Training model:  53%|█████▎    | 37/70 [00:03<00:02, 12.16it/s]

2025-12-12 15:23:53,393 - INFO - Epoch 38/70, Train Loss: 0.9728
2025-12-12 15:23:53,477 - INFO - Epoch 39/70, Train Loss: 0.9406


Training model:  56%|█████▌    | 39/70 [00:03<00:02, 12.16it/s]

2025-12-12 15:23:53,567 - INFO - Epoch 40/70, Train Loss: 0.9466
2025-12-12 15:23:53,646 - INFO - Epoch 41/70, Train Loss: 0.9268


Training model:  59%|█████▊    | 41/70 [00:03<00:02, 12.08it/s]

2025-12-12 15:23:53,727 - INFO - Epoch 42/70, Train Loss: 0.9297
2025-12-12 15:23:53,808 - INFO - Epoch 43/70, Train Loss: 0.9230


Training model:  61%|██████▏   | 43/70 [00:03<00:02, 12.15it/s]

2025-12-12 15:23:53,888 - INFO - Epoch 44/70, Train Loss: 0.9026
2025-12-12 15:23:53,965 - INFO - Epoch 45/70, Train Loss: 0.8643


Training model:  64%|██████▍   | 45/70 [00:03<00:02, 12.31it/s]

2025-12-12 15:23:54,044 - INFO - Epoch 46/70, Train Loss: 0.8784
2025-12-12 15:23:54,126 - INFO - Epoch 47/70, Train Loss: 0.9626


Training model:  67%|██████▋   | 47/70 [00:04<00:01, 12.35it/s]

2025-12-12 15:23:54,208 - INFO - Epoch 48/70, Train Loss: 0.9104
2025-12-12 15:23:54,291 - INFO - Epoch 49/70, Train Loss: 0.8463


Training model:  70%|███████   | 49/70 [00:04<00:01, 12.29it/s]

2025-12-12 15:23:54,379 - INFO - Epoch 50/70, Train Loss: 0.8831
2025-12-12 15:23:54,464 - INFO - Epoch 51/70, Train Loss: 0.8730


Training model:  73%|███████▎  | 51/70 [00:04<00:01, 12.06it/s]

2025-12-12 15:23:54,542 - INFO - Epoch 52/70, Train Loss: 0.8286
2025-12-12 15:23:54,620 - INFO - Epoch 53/70, Train Loss: 0.8179


Training model:  76%|███████▌  | 53/70 [00:04<00:01, 12.27it/s]

2025-12-12 15:23:54,703 - INFO - Epoch 54/70, Train Loss: 0.9207
2025-12-12 15:23:54,779 - INFO - Epoch 55/70, Train Loss: 0.8255


Training model:  79%|███████▊  | 55/70 [00:04<00:01, 12.36it/s]

2025-12-12 15:23:54,866 - INFO - Epoch 56/70, Train Loss: 0.8622
2025-12-12 15:23:54,948 - INFO - Epoch 57/70, Train Loss: 0.8982


Training model:  81%|████████▏ | 57/70 [00:04<00:01, 12.20it/s]

2025-12-12 15:23:55,030 - INFO - Epoch 58/70, Train Loss: 0.8893
2025-12-12 15:23:55,112 - INFO - Epoch 59/70, Train Loss: 1.0718


Training model:  84%|████████▍ | 59/70 [00:05<00:00, 12.20it/s]

2025-12-12 15:23:55,191 - INFO - Epoch 60/70, Train Loss: 0.9130
2025-12-12 15:23:55,279 - INFO - Epoch 61/70, Train Loss: 0.8735


Training model:  87%|████████▋ | 61/70 [00:05<00:00, 12.14it/s]

2025-12-12 15:23:55,361 - INFO - Epoch 62/70, Train Loss: 0.8439
2025-12-12 15:23:55,444 - INFO - Epoch 63/70, Train Loss: 0.8646


Training model:  90%|█████████ | 63/70 [00:05<00:00, 12.14it/s]

2025-12-12 15:23:55,527 - INFO - Epoch 64/70, Train Loss: 0.8963
2025-12-12 15:23:55,606 - INFO - Epoch 65/70, Train Loss: 0.8113


Training model:  93%|█████████▎| 65/70 [00:05<00:00, 12.18it/s]

2025-12-12 15:23:55,687 - INFO - Epoch 66/70, Train Loss: 0.8476
2025-12-12 15:23:55,769 - INFO - Epoch 67/70, Train Loss: 0.7763


Training model:  96%|█████████▌| 67/70 [00:05<00:00, 12.23it/s]

2025-12-12 15:23:55,846 - INFO - Epoch 68/70, Train Loss: 0.7467
2025-12-12 15:23:55,929 - INFO - Epoch 69/70, Train Loss: 0.7141


Training model:  99%|█████████▊| 69/70 [00:05<00:00, 12.30it/s]

2025-12-12 15:23:56,012 - INFO - Epoch 70/70, Train Loss: 0.7612


Training model: 100%|██████████| 70/70 [00:05<00:00, 11.68it/s]

2025-12-12 15:23:56,012 - INFO - [2.1479397527873516, 1.04226703196764, 0.9984908476471901, 1.0615405179560184, 1.0285776518285275, 1.0168384462594986, 0.9730907045304775, 0.973577369004488, 1.040578480809927, 1.0123360753059387, 0.9851948097348213, 0.9852248169481754, 1.0044987797737122, 0.9884758777916431, 0.9790437780320644, 0.9676056504249573, 0.9932688437402248, 1.016283307224512, 0.9778604730963707, 1.0521416179835796, 0.9723548591136932, 1.0016769617795944, 0.9596938230097294, 0.9887978546321392, 0.9703733585774899, 0.9867909550666809, 0.9607554562389851, 0.9653241857886314, 0.9622937403619289, 0.9591629132628441, 1.0202210135757923, 0.962305523455143, 0.967531181871891, 0.9954728111624718, 1.089143592864275, 0.9954351782798767, 0.9975288361310959, 0.9728072956204414, 0.9406430348753929, 0.9466492906212807, 0.9267650470137596, 0.9296736307442188, 0.9229747951030731, 0.9026137441396713, 0.8643356338143349, 0.8784051313996315, 0.962618250399828, 0.910431481897831, 0.84626860171556




In [38]:
net1 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),   # 224x224 -> 112x112   // (3x3x1)x8
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),  # 112x112 -> 56x56     // (3x3x8)x16
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 56x56 -> 28x28      // (3x3x16)x32
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),  # 28x28 -> 14x14      // (3x3x32)x32
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(32, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3)                       # Output layer     
).to(device)

net1.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
#optimizer = torch.optim.SGD(net0.parameters(), lr=0.1, momentum=0.9)
optimizer = torch.optim.Adam(net1.parameters(), lr=0.001)

summary(net1, input_size=(batch_size, 1, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [16, 3]                   --
├─Conv2d: 1-1                            [16, 8, 112, 112]         80
├─ReLU: 1-2                              [16, 8, 112, 112]         --
├─Conv2d: 1-3                            [16, 16, 56, 56]          1,168
├─ReLU: 1-4                              [16, 16, 56, 56]          --
├─Conv2d: 1-5                            [16, 32, 28, 28]          4,640
├─ReLU: 1-6                              [16, 32, 28, 28]          --
├─Conv2d: 1-7                            [16, 32, 14, 14]          9,248
├─ReLU: 1-8                              [16, 32, 14, 14]          --
├─AdaptiveAvgPool2d: 1-9                 [16, 32, 1, 1]            --
├─Flatten: 1-10                          [16, 32]                  --
├─Linear: 1-11                           [16, 128]                 4,224
├─ReLU: 1-12                             [16, 128]                 --
├─L

In [39]:
init_wandb()
train_model(net1, optimizer, loss_fn, enable_early_stopping=False)
wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇███
train_loss,▇▇▇▆▇▆▆▆▆▆▆▆▆▆▆▆▆█▆▆▅▅▅▆▅▄▄▃▃▃▄▄█▅▄▄▃▃▂▁

0,1
epoch,70.0
train_loss,0.76116


Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-12 15:23:59,245 - INFO - Epoch 1/70, Train Loss: 1.0116


Training model:   1%|▏         | 1/70 [00:00<00:13,  5.30it/s]

2025-12-12 15:23:59,409 - INFO - Epoch 2/70, Train Loss: 0.9830


Training model:   3%|▎         | 2/70 [00:00<00:11,  5.74it/s]

2025-12-12 15:23:59,499 - INFO - Epoch 3/70, Train Loss: 0.9704
2025-12-12 15:23:59,584 - INFO - Epoch 4/70, Train Loss: 0.9628


Training model:   6%|▌         | 4/70 [00:00<00:07,  8.35it/s]

2025-12-12 15:23:59,668 - INFO - Epoch 5/70, Train Loss: 0.9753
2025-12-12 15:23:59,750 - INFO - Epoch 6/70, Train Loss: 1.0137


Training model:   9%|▊         | 6/70 [00:00<00:06,  9.76it/s]

2025-12-12 15:23:59,833 - INFO - Epoch 7/70, Train Loss: 0.9637
2025-12-12 15:23:59,920 - INFO - Epoch 8/70, Train Loss: 0.9407


Training model:  11%|█▏        | 8/70 [00:00<00:05, 10.48it/s]

2025-12-12 15:24:00,003 - INFO - Epoch 9/70, Train Loss: 0.9950
2025-12-12 15:24:00,085 - INFO - Epoch 10/70, Train Loss: 1.0566


Training model:  14%|█▍        | 10/70 [00:01<00:05, 11.03it/s]

2025-12-12 15:24:00,162 - INFO - Epoch 11/70, Train Loss: 0.9978
2025-12-12 15:24:00,257 - INFO - Epoch 12/70, Train Loss: 0.9706


Training model:  17%|█▋        | 12/70 [00:01<00:05, 11.23it/s]

2025-12-12 15:24:00,333 - INFO - Epoch 13/70, Train Loss: 0.9491
2025-12-12 15:24:00,415 - INFO - Epoch 14/70, Train Loss: 0.9876


Training model:  20%|██        | 14/70 [00:01<00:04, 11.65it/s]

2025-12-12 15:24:00,506 - INFO - Epoch 15/70, Train Loss: 0.9680
2025-12-12 15:24:00,591 - INFO - Epoch 16/70, Train Loss: 0.9549


Training model:  23%|██▎       | 16/70 [00:01<00:04, 11.56it/s]

2025-12-12 15:24:00,680 - INFO - Epoch 17/70, Train Loss: 1.0065
2025-12-12 15:24:00,768 - INFO - Epoch 18/70, Train Loss: 0.9551


Training model:  26%|██▌       | 18/70 [00:01<00:04, 11.47it/s]

2025-12-12 15:24:00,857 - INFO - Epoch 19/70, Train Loss: 0.9345
2025-12-12 15:24:00,947 - INFO - Epoch 20/70, Train Loss: 0.9275


Training model:  29%|██▊       | 20/70 [00:01<00:04, 11.38it/s]

2025-12-12 15:24:01,077 - INFO - Epoch 21/70, Train Loss: 0.9432
2025-12-12 15:24:01,186 - INFO - Epoch 22/70, Train Loss: 0.9059


Training model:  31%|███▏      | 22/70 [00:02<00:04, 10.25it/s]

2025-12-12 15:24:01,278 - INFO - Epoch 23/70, Train Loss: 0.9060
2025-12-12 15:24:01,367 - INFO - Epoch 24/70, Train Loss: 0.9273


Training model:  34%|███▍      | 24/70 [00:02<00:04, 10.48it/s]

2025-12-12 15:24:01,453 - INFO - Epoch 25/70, Train Loss: 0.9291
2025-12-12 15:24:01,539 - INFO - Epoch 26/70, Train Loss: 0.8931


Training model:  37%|███▋      | 26/70 [00:02<00:04, 10.80it/s]

2025-12-12 15:24:01,626 - INFO - Epoch 27/70, Train Loss: 0.9692
2025-12-12 15:24:01,707 - INFO - Epoch 28/70, Train Loss: 0.8549


Training model:  40%|████      | 28/70 [00:02<00:03, 11.13it/s]

2025-12-12 15:24:01,783 - INFO - Epoch 29/70, Train Loss: 0.9182
2025-12-12 15:24:01,859 - INFO - Epoch 30/70, Train Loss: 0.8638


Training model:  43%|████▎     | 30/70 [00:02<00:03, 11.66it/s]

2025-12-12 15:24:01,944 - INFO - Epoch 31/70, Train Loss: 0.8818
2025-12-12 15:24:02,031 - INFO - Epoch 32/70, Train Loss: 0.8725


Training model:  46%|████▌     | 32/70 [00:02<00:03, 11.65it/s]

2025-12-12 15:24:02,120 - INFO - Epoch 33/70, Train Loss: 0.8487
2025-12-12 15:24:02,208 - INFO - Epoch 34/70, Train Loss: 0.8857


Training model:  49%|████▊     | 34/70 [00:03<00:03, 11.55it/s]

2025-12-12 15:24:02,289 - INFO - Epoch 35/70, Train Loss: 0.8434
2025-12-12 15:24:02,374 - INFO - Epoch 36/70, Train Loss: 0.8281


Training model:  51%|█████▏    | 36/70 [00:03<00:02, 11.70it/s]

2025-12-12 15:24:02,457 - INFO - Epoch 37/70, Train Loss: 0.8240
2025-12-12 15:24:02,545 - INFO - Epoch 38/70, Train Loss: 0.7674


Training model:  54%|█████▍    | 38/70 [00:03<00:02, 11.69it/s]

2025-12-12 15:24:02,629 - INFO - Epoch 39/70, Train Loss: 0.8406
2025-12-12 15:24:02,707 - INFO - Epoch 40/70, Train Loss: 0.7890


Training model:  57%|█████▋    | 40/70 [00:03<00:02, 11.87it/s]

2025-12-12 15:24:02,787 - INFO - Epoch 41/70, Train Loss: 0.7575
2025-12-12 15:24:02,870 - INFO - Epoch 42/70, Train Loss: 0.7330


Training model:  60%|██████    | 42/70 [00:03<00:02, 11.98it/s]

2025-12-12 15:24:02,952 - INFO - Epoch 43/70, Train Loss: 0.6783
2025-12-12 15:24:03,043 - INFO - Epoch 44/70, Train Loss: 0.6900


Training model:  63%|██████▎   | 44/70 [00:03<00:02, 11.86it/s]

2025-12-12 15:24:03,130 - INFO - Epoch 45/70, Train Loss: 0.6385
2025-12-12 15:24:03,213 - INFO - Epoch 46/70, Train Loss: 0.6851


Training model:  66%|██████▌   | 46/70 [00:04<00:02, 11.84it/s]

2025-12-12 15:24:03,302 - INFO - Epoch 47/70, Train Loss: 0.6103
2025-12-12 15:24:03,387 - INFO - Epoch 48/70, Train Loss: 0.6220


Training model:  69%|██████▊   | 48/70 [00:04<00:01, 11.73it/s]

2025-12-12 15:24:03,500 - INFO - Epoch 49/70, Train Loss: 0.6716
2025-12-12 15:24:03,604 - INFO - Epoch 50/70, Train Loss: 0.5680


Training model:  71%|███████▏  | 50/70 [00:04<00:01, 10.83it/s]

2025-12-12 15:24:03,686 - INFO - Epoch 51/70, Train Loss: 0.5274
2025-12-12 15:24:03,771 - INFO - Epoch 52/70, Train Loss: 0.5172


Training model:  74%|███████▍  | 52/70 [00:04<00:01, 11.16it/s]

2025-12-12 15:24:03,854 - INFO - Epoch 53/70, Train Loss: 0.5038
2025-12-12 15:24:03,939 - INFO - Epoch 54/70, Train Loss: 0.4442


Training model:  77%|███████▋  | 54/70 [00:04<00:01, 11.39it/s]

2025-12-12 15:24:04,020 - INFO - Epoch 55/70, Train Loss: 0.5087
2025-12-12 15:24:04,108 - INFO - Epoch 56/70, Train Loss: 0.4670


Training model:  80%|████████  | 56/70 [00:05<00:01, 11.51it/s]

2025-12-12 15:24:04,188 - INFO - Epoch 57/70, Train Loss: 0.5162
2025-12-12 15:24:04,273 - INFO - Epoch 58/70, Train Loss: 0.4401


Training model:  83%|████████▎ | 58/70 [00:05<00:01, 11.70it/s]

2025-12-12 15:24:04,360 - INFO - Epoch 59/70, Train Loss: 0.4126
2025-12-12 15:24:04,446 - INFO - Epoch 60/70, Train Loss: 0.3771


Training model:  86%|████████▌ | 60/70 [00:05<00:00, 11.65it/s]

2025-12-12 15:24:04,529 - INFO - Epoch 61/70, Train Loss: 0.4635
2025-12-12 15:24:04,611 - INFO - Epoch 62/70, Train Loss: 0.7784


Training model:  89%|████████▊ | 62/70 [00:05<00:00, 11.79it/s]

2025-12-12 15:24:04,695 - INFO - Epoch 63/70, Train Loss: 0.4578
2025-12-12 15:24:04,775 - INFO - Epoch 64/70, Train Loss: 0.4590


Training model:  91%|█████████▏| 64/70 [00:05<00:00, 11.90it/s]

2025-12-12 15:24:04,862 - INFO - Epoch 65/70, Train Loss: 0.4413
2025-12-12 15:24:04,942 - INFO - Epoch 66/70, Train Loss: 0.3685


Training model:  94%|█████████▍| 66/70 [00:05<00:00, 11.93it/s]

2025-12-12 15:24:05,025 - INFO - Epoch 67/70, Train Loss: 0.3168
2025-12-12 15:24:05,106 - INFO - Epoch 68/70, Train Loss: 0.2883


Training model:  97%|█████████▋| 68/70 [00:06<00:00, 12.01it/s]

2025-12-12 15:24:05,186 - INFO - Epoch 69/70, Train Loss: 0.2646
2025-12-12 15:24:05,269 - INFO - Epoch 70/70, Train Loss: 0.2874


Training model: 100%|██████████| 70/70 [00:06<00:00, 11.26it/s]

2025-12-12 15:24:05,272 - INFO - [1.0115926787257195, 0.9830499030649662, 0.9704258069396019, 0.9627848975360394, 0.9753327183425426, 1.0136880353093147, 0.9636528678238392, 0.9407158717513084, 0.9950308538973331, 1.056626982986927, 0.9978022091090679, 0.9706433191895485, 0.9490678235888481, 0.9876094870269299, 0.9679656252264977, 0.9548828601837158, 1.0064869113266468, 0.9551485665142536, 0.9344771504402161, 0.927479337900877, 0.9431587606668472, 0.9058591052889824, 0.9059640876948833, 0.927317850291729, 0.9290850628167391, 0.8930986188352108, 0.9691943116486073, 0.8548582680523396, 0.918174222111702, 0.8637864962220192, 0.8818483129143715, 0.8725329972803593, 0.8486837185919285, 0.8857207410037518, 0.8433896042406559, 0.8281376399099827, 0.8239557966589928, 0.7674390841275454, 0.8406185954809189, 0.7890063002705574, 0.7574590370059013, 0.7330352254211903, 0.6783143617212772, 0.6899918941780925, 0.6385442595928907, 0.6851181220263243, 0.6103464476764202, 0.6220027394592762, 0.67162297




0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇████
train_loss,████▇▇▇▇▇▇▇▇▇█▆▇▇▇▇▆▆▆▆▅▅▅▄▃▃▃▃▃▂▂▂▃▂▂▁▁

0,1
epoch,70.0
train_loss,0.28744


In [40]:
net2 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),   # 224x224 -> 112x112   // (3x3x1)x8
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),  # 112x112 -> 56x56     // (3x3x8)x16
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 56x56 -> 28x28      // (3x3x16)x32
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 28x28 -> 14x14      // (3x3x32)x64
    torch.nn.ReLU(),
    torch.nn.Conv2d(64, 32, kernel_size=3, stride=2, padding=1),  # 14x14 -> 7x7        // (3x3x64)x32
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(32, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3)                       # Output layer     
).to(device)

net2.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
#optimizer = torch.optim.SGD(net0.parameters(), lr=0.1, momentum=0.9)
optimizer = torch.optim.Adam(net2.parameters(), lr=0.001)

summary(net2, input_size=(batch_size, 1, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [16, 3]                   --
├─Conv2d: 1-1                            [16, 8, 112, 112]         80
├─ReLU: 1-2                              [16, 8, 112, 112]         --
├─Conv2d: 1-3                            [16, 16, 56, 56]          1,168
├─ReLU: 1-4                              [16, 16, 56, 56]          --
├─Conv2d: 1-5                            [16, 32, 28, 28]          4,640
├─ReLU: 1-6                              [16, 32, 28, 28]          --
├─Conv2d: 1-7                            [16, 64, 14, 14]          18,496
├─ReLU: 1-8                              [16, 64, 14, 14]          --
├─Conv2d: 1-9                            [16, 32, 7, 7]            18,464
├─ReLU: 1-10                             [16, 32, 7, 7]            --
├─AdaptiveAvgPool2d: 1-11                [16, 32, 1, 1]            --
├─Flatten: 1-12                          [16, 32]                  --
├

In [41]:
init_wandb()
train_model(net2, optimizer, loss_fn, enable_early_stopping=False)
evaluate_model(net2)

Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-12 15:24:10,880 - INFO - Epoch 1/70, Train Loss: 1.0514
2025-12-12 15:24:10,965 - INFO - Epoch 2/70, Train Loss: 0.9922


Training model:   3%|▎         | 2/70 [00:00<00:06, 11.30it/s]

2025-12-12 15:24:11,056 - INFO - Epoch 3/70, Train Loss: 1.0469
2025-12-12 15:24:11,144 - INFO - Epoch 4/70, Train Loss: 0.9588


Training model:   6%|▌         | 4/70 [00:00<00:05, 11.22it/s]

2025-12-12 15:24:11,232 - INFO - Epoch 5/70, Train Loss: 1.0077
2025-12-12 15:24:11,320 - INFO - Epoch 6/70, Train Loss: 1.0076


Training model:   9%|▊         | 6/70 [00:00<00:05, 11.29it/s]

2025-12-12 15:24:11,406 - INFO - Epoch 7/70, Train Loss: 1.0312
2025-12-12 15:24:11,492 - INFO - Epoch 8/70, Train Loss: 0.9913


Training model:  11%|█▏        | 8/70 [00:00<00:05, 11.40it/s]

2025-12-12 15:24:11,579 - INFO - Epoch 9/70, Train Loss: 0.9605
2025-12-12 15:24:11,665 - INFO - Epoch 10/70, Train Loss: 0.9726


Training model:  14%|█▍        | 10/70 [00:00<00:05, 11.48it/s]

2025-12-12 15:24:11,762 - INFO - Epoch 11/70, Train Loss: 0.9482
2025-12-12 15:24:11,846 - INFO - Epoch 12/70, Train Loss: 0.9567


Training model:  17%|█▋        | 12/70 [00:01<00:05, 11.32it/s]

2025-12-12 15:24:11,938 - INFO - Epoch 13/70, Train Loss: 0.9147
2025-12-12 15:24:12,030 - INFO - Epoch 14/70, Train Loss: 0.9193


Training model:  20%|██        | 14/70 [00:01<00:05, 11.18it/s]

2025-12-12 15:24:12,116 - INFO - Epoch 15/70, Train Loss: 1.0029
2025-12-12 15:24:12,204 - INFO - Epoch 16/70, Train Loss: 0.9045


Training model:  23%|██▎       | 16/70 [00:01<00:04, 11.27it/s]

2025-12-12 15:24:12,292 - INFO - Epoch 17/70, Train Loss: 0.8962
2025-12-12 15:24:12,379 - INFO - Epoch 18/70, Train Loss: 0.9419


Training model:  26%|██▌       | 18/70 [00:01<00:04, 11.33it/s]

2025-12-12 15:24:12,471 - INFO - Epoch 19/70, Train Loss: 0.8418
2025-12-12 15:24:12,568 - INFO - Epoch 20/70, Train Loss: 0.7624


Training model:  29%|██▊       | 20/70 [00:01<00:04, 11.09it/s]

2025-12-12 15:24:12,657 - INFO - Epoch 21/70, Train Loss: 0.8284
2025-12-12 15:24:12,743 - INFO - Epoch 22/70, Train Loss: 0.6755


Training model:  31%|███▏      | 22/70 [00:01<00:04, 11.18it/s]

2025-12-12 15:24:12,836 - INFO - Epoch 23/70, Train Loss: 0.8268
2025-12-12 15:24:12,926 - INFO - Epoch 24/70, Train Loss: 0.7386


Training model:  34%|███▍      | 24/70 [00:02<00:04, 11.10it/s]

2025-12-12 15:24:13,016 - INFO - Epoch 25/70, Train Loss: 0.6944
2025-12-12 15:24:13,104 - INFO - Epoch 26/70, Train Loss: 0.6358


Training model:  37%|███▋      | 26/70 [00:02<00:03, 11.15it/s]

2025-12-12 15:24:13,194 - INFO - Epoch 27/70, Train Loss: 0.5427
2025-12-12 15:24:13,291 - INFO - Epoch 28/70, Train Loss: 0.5171


Training model:  40%|████      | 28/70 [00:02<00:03, 11.01it/s]

2025-12-12 15:24:13,381 - INFO - Epoch 29/70, Train Loss: 0.5087
2025-12-12 15:24:13,472 - INFO - Epoch 30/70, Train Loss: 0.4189


Training model:  43%|████▎     | 30/70 [00:02<00:03, 11.01it/s]

2025-12-12 15:24:13,561 - INFO - Epoch 31/70, Train Loss: 0.3836
2025-12-12 15:24:13,649 - INFO - Epoch 32/70, Train Loss: 0.4067


Training model:  46%|████▌     | 32/70 [00:02<00:03, 11.10it/s]

2025-12-12 15:24:13,751 - INFO - Epoch 33/70, Train Loss: 0.3383
2025-12-12 15:24:13,869 - INFO - Epoch 34/70, Train Loss: 0.3125


Training model:  49%|████▊     | 34/70 [00:03<00:03, 10.42it/s]

2025-12-12 15:24:13,964 - INFO - Epoch 35/70, Train Loss: 0.3604
2025-12-12 15:24:14,050 - INFO - Epoch 36/70, Train Loss: 0.6133


Training model:  51%|█████▏    | 36/70 [00:03<00:03, 10.61it/s]

2025-12-12 15:24:14,131 - INFO - Epoch 37/70, Train Loss: 0.4498
2025-12-12 15:24:14,222 - INFO - Epoch 38/70, Train Loss: 0.2944


Training model:  54%|█████▍    | 38/70 [00:03<00:02, 10.88it/s]

2025-12-12 15:24:14,307 - INFO - Epoch 39/70, Train Loss: 0.2915
2025-12-12 15:24:14,389 - INFO - Epoch 40/70, Train Loss: 0.2667


Training model:  57%|█████▋    | 40/70 [00:03<00:02, 11.19it/s]

2025-12-12 15:24:14,479 - INFO - Epoch 41/70, Train Loss: 0.1859
2025-12-12 15:24:14,562 - INFO - Epoch 42/70, Train Loss: 0.1883


Training model:  60%|██████    | 42/70 [00:03<00:02, 11.30it/s]

2025-12-12 15:24:14,647 - INFO - Epoch 43/70, Train Loss: 0.1460
2025-12-12 15:24:14,733 - INFO - Epoch 44/70, Train Loss: 0.1506


Training model:  63%|██████▎   | 44/70 [00:03<00:02, 11.42it/s]

2025-12-12 15:24:14,822 - INFO - Epoch 45/70, Train Loss: 0.1368
2025-12-12 15:24:14,909 - INFO - Epoch 46/70, Train Loss: 0.1121


Training model:  66%|██████▌   | 46/70 [00:04<00:02, 11.38it/s]

2025-12-12 15:24:14,999 - INFO - Epoch 47/70, Train Loss: 0.1211
2025-12-12 15:24:15,092 - INFO - Epoch 48/70, Train Loss: 0.1289


Training model:  69%|██████▊   | 48/70 [00:04<00:01, 11.26it/s]

2025-12-12 15:24:15,183 - INFO - Epoch 49/70, Train Loss: 0.0985
2025-12-12 15:24:15,266 - INFO - Epoch 50/70, Train Loss: 0.3266


Training model:  71%|███████▏  | 50/70 [00:04<00:01, 11.32it/s]

2025-12-12 15:24:15,372 - INFO - Epoch 51/70, Train Loss: 0.5259
2025-12-12 15:24:15,485 - INFO - Epoch 52/70, Train Loss: 0.4664


Training model:  74%|███████▍  | 52/70 [00:04<00:01, 10.56it/s]

2025-12-12 15:24:15,585 - INFO - Epoch 53/70, Train Loss: 0.2746
2025-12-12 15:24:15,681 - INFO - Epoch 54/70, Train Loss: 0.1897


Training model:  77%|███████▋  | 54/70 [00:04<00:01, 10.45it/s]

2025-12-12 15:24:15,770 - INFO - Epoch 55/70, Train Loss: 0.1398
2025-12-12 15:24:15,855 - INFO - Epoch 56/70, Train Loss: 0.1278


Training model:  80%|████████  | 56/70 [00:05<00:01, 10.74it/s]

2025-12-12 15:24:15,944 - INFO - Epoch 57/70, Train Loss: 0.1510
2025-12-12 15:24:16,030 - INFO - Epoch 58/70, Train Loss: 0.1148


Training model:  83%|████████▎ | 58/70 [00:05<00:01, 10.94it/s]

2025-12-12 15:24:16,115 - INFO - Epoch 59/70, Train Loss: 0.0820
2025-12-12 15:24:16,204 - INFO - Epoch 60/70, Train Loss: 0.0979


Training model:  86%|████████▌ | 60/70 [00:05<00:00, 11.11it/s]

2025-12-12 15:24:16,294 - INFO - Epoch 61/70, Train Loss: 0.2106
2025-12-12 15:24:16,380 - INFO - Epoch 62/70, Train Loss: 0.2612


Training model:  89%|████████▊ | 62/70 [00:05<00:00, 11.18it/s]

2025-12-12 15:24:16,465 - INFO - Epoch 63/70, Train Loss: 0.1641
2025-12-12 15:24:16,558 - INFO - Epoch 64/70, Train Loss: 0.0938


Training model:  91%|█████████▏| 64/70 [00:05<00:00, 11.20it/s]

2025-12-12 15:24:16,649 - INFO - Epoch 65/70, Train Loss: 0.0949
2025-12-12 15:24:16,734 - INFO - Epoch 66/70, Train Loss: 0.0850


Training model:  94%|█████████▍| 66/70 [00:05<00:00, 11.24it/s]

2025-12-12 15:24:16,826 - INFO - Epoch 67/70, Train Loss: 0.0743
2025-12-12 15:24:16,911 - INFO - Epoch 68/70, Train Loss: 0.0774


Training model:  97%|█████████▋| 68/70 [00:06<00:00, 11.26it/s]

2025-12-12 15:24:17,005 - INFO - Epoch 69/70, Train Loss: 0.0771
2025-12-12 15:24:17,095 - INFO - Epoch 70/70, Train Loss: 0.0633


Training model: 100%|██████████| 70/70 [00:06<00:00, 11.10it/s]

2025-12-12 15:24:17,096 - INFO - [1.051427885890007, 0.9921970665454865, 1.0468668714165688, 0.958775706589222, 1.0076817981898785, 1.0076289884746075, 1.0311594568192959, 0.9913340248167515, 0.9604760892689228, 0.9726138636469841, 0.9481767825782299, 0.956697154790163, 0.9146808385848999, 0.9193496108055115, 1.002855259925127, 0.904484212398529, 0.8962045051157475, 0.9418727159500122, 0.8418341893702745, 0.7624243693426251, 0.8283914234489202, 0.675548393279314, 0.8267624471336603, 0.7385589703917503, 0.694431658834219, 0.6357737481594086, 0.5427270596846938, 0.5170914766786154, 0.5087445545941591, 0.4188974661519751, 0.38363224058412015, 0.4066608380526304, 0.3383357410784811, 0.312480756547302, 0.36042136745527387, 0.6133073754608631, 0.44981733383610845, 0.29436886589974165, 0.2915016404876951, 0.26665277825668454, 0.1858931153838057, 0.18833525478839874, 0.14599887577060144, 0.1506289696553722, 0.13675254295230843, 0.1120947960880585, 0.12109300503652776, 0.12890389945741276, 0.09


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇██
test_accuracy,▁
test_f1,▁
test_precision,▁
test_recall,▁
train_loss,██▇██▇▇▇▇█▇▇▅▅▄▃▃▃▃▅▃▃▂▂▁▁▁▁▄▂▂▁▁▁▂▂▁▁▁▁

0,1
epoch,70.0
test_accuracy,0.38776
test_f1,0.46307
test_precision,0.82646
test_recall,0.38776
train_loss,0.06328


In [79]:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.0001, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model = None
        
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_model = model.state_dict().copy()
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                logger.info(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_model = model.state_dict().copy()
            self.counter = 0


In [144]:
batch_size = 32
net3 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1),      # (3x3x1)x8
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 224x224 -> 112x112

    torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),       # (3x3x8)x16
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 112x112 -> 56x56

    torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),       # (3x3x16)x32
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 56x56 -> 28x28 

    torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),       # (3x3x32)x64
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 28x28 -> 14x14 

    torch.nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),       # (3x3x64)x32
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 14x14 -> 7x7

    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(32, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3)                       # Output layer     
).to(device)

net3.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net3.parameters(), lr=0.001)

summary(net3, input_size=(batch_size, 1, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [32, 3]                   --
├─Conv2d: 1-1                            [32, 8, 224, 224]         80
├─ReLU: 1-2                              [32, 8, 224, 224]         --
├─MaxPool2d: 1-3                         [32, 8, 112, 112]         --
├─Conv2d: 1-4                            [32, 16, 112, 112]        1,168
├─ReLU: 1-5                              [32, 16, 112, 112]        --
├─MaxPool2d: 1-6                         [32, 16, 56, 56]          --
├─Conv2d: 1-7                            [32, 32, 56, 56]          4,640
├─ReLU: 1-8                              [32, 32, 56, 56]          --
├─MaxPool2d: 1-9                         [32, 32, 28, 28]          --
├─Conv2d: 1-10                           [32, 64, 28, 28]          18,496
├─ReLU: 1-11                             [32, 64, 28, 28]          --
├─MaxPool2d: 1-12                        [32, 64, 14, 14]          --
├─Con

In [145]:
init_wandb()
train_model(net3, optimizer, loss_fn, enable_early_stopping=False)
evaluate_model(net3)

Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-12 18:19:35,240 - INFO - Epoch 1/70, Train Loss: 1.2371


Training model:   1%|▏         | 1/70 [00:00<00:22,  3.06it/s]

2025-12-12 18:19:35,293 - INFO - Epoch 2/70, Train Loss: 1.0437
2025-12-12 18:19:35,348 - INFO - Epoch 3/70, Train Loss: 1.0236


Training model:   4%|▍         | 3/70 [00:00<00:08,  8.01it/s]

2025-12-12 18:19:35,406 - INFO - Epoch 4/70, Train Loss: 0.9934
2025-12-12 18:19:35,468 - INFO - Epoch 5/70, Train Loss: 0.9830


Training model:   7%|▋         | 5/70 [00:00<00:05, 10.91it/s]

2025-12-12 18:19:35,534 - INFO - Epoch 6/70, Train Loss: 0.9624
2025-12-12 18:19:35,592 - INFO - Epoch 7/70, Train Loss: 0.9606


Training model:  10%|█         | 7/70 [00:00<00:04, 12.68it/s]

2025-12-12 18:19:35,649 - INFO - Epoch 8/70, Train Loss: 0.9642
2025-12-12 18:19:35,703 - INFO - Epoch 9/70, Train Loss: 0.9533


Training model:  13%|█▎        | 9/70 [00:00<00:04, 14.27it/s]

2025-12-12 18:19:35,756 - INFO - Epoch 10/70, Train Loss: 0.9462
2025-12-12 18:19:35,806 - INFO - Epoch 11/70, Train Loss: 0.9395


Training model:  16%|█▌        | 11/70 [00:00<00:03, 15.69it/s]

2025-12-12 18:19:35,860 - INFO - Epoch 12/70, Train Loss: 0.9669
2025-12-12 18:19:35,912 - INFO - Epoch 13/70, Train Loss: 0.9383


Training model:  19%|█▊        | 13/70 [00:00<00:03, 16.66it/s]

2025-12-12 18:19:35,963 - INFO - Epoch 14/70, Train Loss: 0.9247
2025-12-12 18:19:36,029 - INFO - Epoch 15/70, Train Loss: 0.9538


Training model:  21%|██▏       | 15/70 [00:01<00:03, 16.78it/s]

2025-12-12 18:19:36,090 - INFO - Epoch 16/70, Train Loss: 0.9217
2025-12-12 18:19:36,167 - INFO - Epoch 17/70, Train Loss: 0.9180


Training model:  24%|██▍       | 17/70 [00:01<00:03, 16.00it/s]

2025-12-12 18:19:36,236 - INFO - Epoch 18/70, Train Loss: 0.9242
2025-12-12 18:19:36,298 - INFO - Epoch 19/70, Train Loss: 0.9216


Training model:  27%|██▋       | 19/70 [00:01<00:03, 15.76it/s]

2025-12-12 18:19:36,347 - INFO - Epoch 20/70, Train Loss: 0.9110
2025-12-12 18:19:36,397 - INFO - Epoch 21/70, Train Loss: 0.8896
2025-12-12 18:19:36,447 - INFO - Epoch 22/70, Train Loss: 0.9000


Training model:  31%|███▏      | 22/70 [00:01<00:02, 17.26it/s]

2025-12-12 18:19:36,500 - INFO - Epoch 23/70, Train Loss: 0.8655
2025-12-12 18:19:36,556 - INFO - Epoch 24/70, Train Loss: 0.8688


Training model:  34%|███▍      | 24/70 [00:01<00:02, 17.53it/s]

2025-12-12 18:19:36,614 - INFO - Epoch 25/70, Train Loss: 0.9046
2025-12-12 18:19:36,670 - INFO - Epoch 26/70, Train Loss: 0.8285


Training model:  37%|███▋      | 26/70 [00:01<00:02, 17.52it/s]

2025-12-12 18:19:36,721 - INFO - Epoch 27/70, Train Loss: 0.8221
2025-12-12 18:19:36,769 - INFO - Epoch 28/70, Train Loss: 0.8649
2025-12-12 18:19:36,820 - INFO - Epoch 29/70, Train Loss: 0.8766


Training model:  41%|████▏     | 29/70 [00:01<00:02, 18.36it/s]

2025-12-12 18:19:36,872 - INFO - Epoch 30/70, Train Loss: 0.8146
2025-12-12 18:19:36,920 - INFO - Epoch 31/70, Train Loss: 0.8074
2025-12-12 18:19:36,969 - INFO - Epoch 32/70, Train Loss: 0.7792


Training model:  46%|████▌     | 32/70 [00:02<00:02, 18.96it/s]

2025-12-12 18:19:37,017 - INFO - Epoch 33/70, Train Loss: 0.7222
2025-12-12 18:19:37,066 - INFO - Epoch 34/70, Train Loss: 0.7146
2025-12-12 18:19:37,117 - INFO - Epoch 35/70, Train Loss: 0.6946


Training model:  50%|█████     | 35/70 [00:02<00:01, 19.41it/s]

2025-12-12 18:19:37,168 - INFO - Epoch 36/70, Train Loss: 0.6792
2025-12-12 18:19:37,217 - INFO - Epoch 37/70, Train Loss: 0.6535


Training model:  53%|█████▎    | 37/70 [00:02<00:01, 19.52it/s]

2025-12-12 18:19:37,269 - INFO - Epoch 38/70, Train Loss: 0.7875
2025-12-12 18:19:37,319 - INFO - Epoch 39/70, Train Loss: 0.8244


Training model:  56%|█████▌    | 39/70 [00:02<00:01, 19.55it/s]

2025-12-12 18:19:37,368 - INFO - Epoch 40/70, Train Loss: 0.7356
2025-12-12 18:19:37,419 - INFO - Epoch 41/70, Train Loss: 0.7206
2025-12-12 18:19:37,467 - INFO - Epoch 42/70, Train Loss: 0.6663


Training model:  60%|██████    | 42/70 [00:02<00:01, 19.83it/s]

2025-12-12 18:19:37,517 - INFO - Epoch 43/70, Train Loss: 0.6079
2025-12-12 18:19:37,567 - INFO - Epoch 44/70, Train Loss: 0.5859


Training model:  63%|██████▎   | 44/70 [00:02<00:01, 19.86it/s]

2025-12-12 18:19:37,617 - INFO - Epoch 45/70, Train Loss: 0.5472
2025-12-12 18:19:37,666 - INFO - Epoch 46/70, Train Loss: 0.5453
2025-12-12 18:19:37,715 - INFO - Epoch 47/70, Train Loss: 0.5276


Training model:  67%|██████▋   | 47/70 [00:02<00:01, 20.01it/s]

2025-12-12 18:19:37,764 - INFO - Epoch 48/70, Train Loss: 0.5463
2025-12-12 18:19:37,813 - INFO - Epoch 49/70, Train Loss: 0.4928
2025-12-12 18:19:37,862 - INFO - Epoch 50/70, Train Loss: 0.4680


Training model:  71%|███████▏  | 50/70 [00:02<00:00, 20.10it/s]

2025-12-12 18:19:37,914 - INFO - Epoch 51/70, Train Loss: 0.4946
2025-12-12 18:19:37,963 - INFO - Epoch 52/70, Train Loss: 0.4655
2025-12-12 18:19:38,013 - INFO - Epoch 53/70, Train Loss: 0.4537


Training model:  76%|███████▌  | 53/70 [00:03<00:00, 20.05it/s]

2025-12-12 18:19:38,062 - INFO - Epoch 54/70, Train Loss: 0.4340
2025-12-12 18:19:38,110 - INFO - Epoch 55/70, Train Loss: 0.4009
2025-12-12 18:19:38,160 - INFO - Epoch 56/70, Train Loss: 0.4699


Training model:  80%|████████  | 56/70 [00:03<00:00, 20.15it/s]

2025-12-12 18:19:38,210 - INFO - Epoch 57/70, Train Loss: 0.4040
2025-12-12 18:19:38,257 - INFO - Epoch 58/70, Train Loss: 0.3753
2025-12-12 18:19:38,306 - INFO - Epoch 59/70, Train Loss: 0.3521


Training model:  84%|████████▍ | 59/70 [00:03<00:00, 20.26it/s]

2025-12-12 18:19:38,357 - INFO - Epoch 60/70, Train Loss: 0.4844
2025-12-12 18:19:38,407 - INFO - Epoch 61/70, Train Loss: 0.5056
2025-12-12 18:19:38,457 - INFO - Epoch 62/70, Train Loss: 0.4904


Training model:  89%|████████▊ | 62/70 [00:03<00:00, 20.16it/s]

2025-12-12 18:19:38,507 - INFO - Epoch 63/70, Train Loss: 0.4131
2025-12-12 18:19:38,557 - INFO - Epoch 64/70, Train Loss: 0.3676
2025-12-12 18:19:38,607 - INFO - Epoch 65/70, Train Loss: 0.3849


Training model:  93%|█████████▎| 65/70 [00:03<00:00, 20.12it/s]

2025-12-12 18:19:38,659 - INFO - Epoch 66/70, Train Loss: 0.3024
2025-12-12 18:19:38,707 - INFO - Epoch 67/70, Train Loss: 0.3228
2025-12-12 18:19:38,755 - INFO - Epoch 68/70, Train Loss: 0.3359


Training model:  97%|█████████▋| 68/70 [00:03<00:00, 20.14it/s]

2025-12-12 18:19:38,806 - INFO - Epoch 69/70, Train Loss: 0.2487
2025-12-12 18:19:38,863 - INFO - Epoch 70/70, Train Loss: 0.3159


Training model: 100%|██████████| 70/70 [00:03<00:00, 17.72it/s]

2025-12-12 18:19:38,865 - INFO - [1.237086534500122, 1.0436917394399643, 1.0235626995563507, 0.9933677166700363, 0.9830075800418854, 0.9623955935239792, 0.9606431424617767, 0.9642038643360138, 0.9533436596393585, 0.946183055639267, 0.9394595623016357, 0.9669439196586609, 0.9382504522800446, 0.924747109413147, 0.9537858366966248, 0.9217498153448105, 0.9179900139570236, 0.9241834878921509, 0.9216127246618271, 0.9109524041414261, 0.8895670920610428, 0.8999549299478531, 0.8654675483703613, 0.8687680959701538, 0.9046307355165482, 0.8285076171159744, 0.8220992237329483, 0.8649306893348694, 0.8766479790210724, 0.8145528137683868, 0.8074246942996979, 0.7791862934827805, 0.722162663936615, 0.7145782709121704, 0.6946388632059097, 0.6792204082012177, 0.6534572243690491, 0.7875205725431442, 0.8243864625692368, 0.7356019616127014, 0.7205611914396286, 0.666285365819931, 0.6078769341111183, 0.5859226733446121, 0.5472287386655807, 0.5453073084354401, 0.527581550180912, 0.546343170106411, 0.49275290966


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
test_accuracy,▁
test_f1,▁
test_precision,▁
test_recall,▁
train_loss,█▆▆▆▆▆▆▆▆▆▆▆▆▆▅▅▅▅▅▅▄▄▄▅▅▄▃▃▃▃▂▂▂▂▁▂▂▁▁▁

0,1
epoch,70.0
test_accuracy,0.18367
test_f1,0.21502
test_precision,0.87755
test_recall,0.18367
train_loss,0.31587


In [80]:
# Use train test split, for getting validation metrics during training
x_train_tensor, x_val_tensor, y_train_tensor, y_val_tensor = train_test_split(
    x_train_tensor, y_train_tensor, test_size=0.2, random_state=42, stratify=y_train_tensor)

val_dataset = TensorDataset(x_val_tensor, y_val_tensor)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [157]:
net4 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1),      # (3x3x1)x8
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 224x224 -> 112x112

    torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),       # (3x3x8)x16
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 112x112 -> 56x56

    torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),       # (3x3x16)x32
    torch.nn.BatchNorm2d(32),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 56x56 -> 28x28 

    torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),       # (3x3x32)x64
    torch.nn.BatchNorm2d(64),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 28x28 -> 14x14 

    torch.nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),       # (3x3x64)x32
    torch.nn.BatchNorm2d(32),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 14x14 -> 7x7

    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(32, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3)                       # Output layer     
).to(device)

net4.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net4.parameters(), lr=0.001)

summary(net4, input_size=(batch_size, 1, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [32, 3]                   --
├─Conv2d: 1-1                            [32, 8, 224, 224]         80
├─BatchNorm2d: 1-2                       [32, 8, 224, 224]         16
├─ReLU: 1-3                              [32, 8, 224, 224]         --
├─MaxPool2d: 1-4                         [32, 8, 112, 112]         --
├─Conv2d: 1-5                            [32, 16, 112, 112]        1,168
├─BatchNorm2d: 1-6                       [32, 16, 112, 112]        32
├─ReLU: 1-7                              [32, 16, 112, 112]        --
├─MaxPool2d: 1-8                         [32, 16, 56, 56]          --
├─Conv2d: 1-9                            [32, 32, 56, 56]          4,640
├─BatchNorm2d: 1-10                      [32, 32, 56, 56]          64
├─ReLU: 1-11                             [32, 32, 56, 56]          --
├─MaxPool2d: 1-12                        [32, 32, 28, 28]          --
├─Conv2d:

In [158]:
init_wandb()
train_model(net4, optimizer, loss_fn, enable_early_stopping=True, patience=5)
evaluate_model(net4)

Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-12 18:31:13,643 - INFO - Epoch 1/70, Train Loss: 1.0647, Val Loss: 1.0436, Val Acc: 0.3871


Training model:   1%|▏         | 1/70 [00:00<00:21,  3.20it/s]

2025-12-12 18:31:13,735 - INFO - Epoch 2/70, Train Loss: 0.9453, Val Loss: 1.0350, Val Acc: 0.4839
2025-12-12 18:31:13,846 - INFO - Epoch 3/70, Train Loss: 0.8866, Val Loss: 1.0337, Val Acc: 0.5161


Training model:   4%|▍         | 3/70 [00:00<00:10,  6.39it/s]

2025-12-12 18:31:13,951 - INFO - Epoch 4/70, Train Loss: 0.8378, Val Loss: 1.0360, Val Acc: 0.4516
2025-12-12 18:31:13,952 - INFO - EarlyStopping counter: 1 out of 5


Training model:   6%|▌         | 4/70 [00:00<00:09,  7.21it/s]

2025-12-12 18:31:14,047 - INFO - Epoch 5/70, Train Loss: 0.8040, Val Loss: 1.0350, Val Acc: 0.4839
2025-12-12 18:31:14,047 - INFO - EarlyStopping counter: 2 out of 5
2025-12-12 18:31:14,137 - INFO - Epoch 6/70, Train Loss: 0.7524, Val Loss: 1.0361, Val Acc: 0.4839
2025-12-12 18:31:14,138 - INFO - EarlyStopping counter: 3 out of 5


Training model:   9%|▊         | 6/70 [00:00<00:07,  8.61it/s]

2025-12-12 18:31:14,238 - INFO - Epoch 7/70, Train Loss: 0.7224, Val Loss: 1.0175, Val Acc: 0.4839


Training model:  10%|█         | 7/70 [00:00<00:07,  8.91it/s]

2025-12-12 18:31:14,327 - INFO - Epoch 8/70, Train Loss: 0.6972, Val Loss: 0.9962, Val Acc: 0.4839
2025-12-12 18:31:14,413 - INFO - Epoch 9/70, Train Loss: 0.6614, Val Loss: 1.0571, Val Acc: 0.4516
2025-12-12 18:31:14,414 - INFO - EarlyStopping counter: 1 out of 5


Training model:  13%|█▎        | 9/70 [00:01<00:06,  9.88it/s]

2025-12-12 18:31:14,498 - INFO - Epoch 10/70, Train Loss: 0.6113, Val Loss: 1.0052, Val Acc: 0.4839
2025-12-12 18:31:14,499 - INFO - EarlyStopping counter: 2 out of 5
2025-12-12 18:31:14,588 - INFO - Epoch 11/70, Train Loss: 0.5897, Val Loss: 1.0285, Val Acc: 0.4194
2025-12-12 18:31:14,590 - INFO - EarlyStopping counter: 3 out of 5


Training model:  16%|█▌        | 11/70 [00:01<00:05, 10.36it/s]

2025-12-12 18:31:14,671 - INFO - Epoch 12/70, Train Loss: 0.5019, Val Loss: 1.0451, Val Acc: 0.5161
2025-12-12 18:31:14,673 - INFO - EarlyStopping counter: 4 out of 5
2025-12-12 18:31:14,751 - INFO - Epoch 13/70, Train Loss: 0.5277, Val Loss: 0.9351, Val Acc: 0.4839


Training model:  19%|█▊        | 13/70 [00:01<00:05, 11.03it/s]

2025-12-12 18:31:14,829 - INFO - Epoch 14/70, Train Loss: 0.4907, Val Loss: 0.8818, Val Acc: 0.5161
2025-12-12 18:31:14,903 - INFO - Epoch 15/70, Train Loss: 0.4176, Val Loss: 1.6310, Val Acc: 0.5161
2025-12-12 18:31:14,904 - INFO - EarlyStopping counter: 1 out of 5


Training model:  21%|██▏       | 15/70 [00:01<00:04, 11.64it/s]

2025-12-12 18:31:14,983 - INFO - Epoch 16/70, Train Loss: 0.4139, Val Loss: 0.8027, Val Acc: 0.6452
2025-12-12 18:31:15,062 - INFO - Epoch 17/70, Train Loss: 0.3917, Val Loss: 0.7938, Val Acc: 0.7097


Training model:  24%|██▍       | 17/70 [00:01<00:04, 11.94it/s]

2025-12-12 18:31:15,140 - INFO - Epoch 18/70, Train Loss: 0.3706, Val Loss: 1.0430, Val Acc: 0.6129
2025-12-12 18:31:15,140 - INFO - EarlyStopping counter: 1 out of 5
2025-12-12 18:31:15,218 - INFO - Epoch 19/70, Train Loss: 0.3978, Val Loss: 0.9023, Val Acc: 0.6774
2025-12-12 18:31:15,219 - INFO - EarlyStopping counter: 2 out of 5


Training model:  27%|██▋       | 19/70 [00:01<00:04, 12.20it/s]

2025-12-12 18:31:15,302 - INFO - Epoch 20/70, Train Loss: 0.3191, Val Loss: 0.8841, Val Acc: 0.6129
2025-12-12 18:31:15,302 - INFO - EarlyStopping counter: 3 out of 5
2025-12-12 18:31:15,390 - INFO - Epoch 21/70, Train Loss: 0.2883, Val Loss: 0.9186, Val Acc: 0.5484
2025-12-12 18:31:15,390 - INFO - EarlyStopping counter: 4 out of 5


Training model:  30%|███       | 21/70 [00:02<00:04, 12.00it/s]

2025-12-12 18:31:15,479 - INFO - Epoch 22/70, Train Loss: 0.2851, Val Loss: 0.9483, Val Acc: 0.5806
2025-12-12 18:31:15,479 - INFO - EarlyStopping counter: 5 out of 5
2025-12-12 18:31:15,479 - INFO - Early stopping triggered


Training model:  30%|███       | 21/70 [00:02<00:05,  9.76it/s]

2025-12-12 18:31:15,485 - INFO - Loaded best model weights
2025-12-12 18:31:15,486 - INFO - [1.0646775662899017, 0.9453068822622299, 0.8866231143474579, 0.8378106951713562, 0.803951233625412, 0.752399668097496, 0.7224159240722656, 0.6971715837717056, 0.6613590568304062, 0.6113146841526031, 0.5896765887737274, 0.5019043385982513, 0.5276547595858574, 0.4907457157969475, 0.41763219982385635, 0.41385509073734283, 0.3916967660188675, 0.3706493079662323, 0.39776717126369476, 0.3191038444638252, 0.2882870212197304, 0.2850576862692833]
2025-12-12 18:31:15,515 - INFO - network accuracy: 38.78%
2025-12-12 18:31:15,516 - INFO - network precision: 88.57%
2025-12-12 18:31:15,516 - INFO - network recall: 38.78%
2025-12-12 18:31:15,517 - INFO - network F1 score: 42.86%
2025-12-12 18:31:15,527 - INFO - Detailed Classification Report: 
              precision    recall  f1-score   support

           0       0.20      1.00      0.33         7
           1       1.00      0.29      0.44        42
      


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇██
test_accuracy,▁
test_f1,▁
test_precision,▁
test_recall,▁
train_loss,█▇▆▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▂▁▁▁
val_accuracy,▁▃▄▂▃▃▃▃▂▃▂▄▃▄▄▇█▆▇▆▅▅
val_loss,▃▃▃▃▃▃▃▃▃▃▃▃▂▂█▁▁▃▂▂▂▂

0,1
epoch,22.0
test_accuracy,0.38776
test_f1,0.42857
test_precision,0.88571
test_recall,0.38776
train_loss,0.28506
val_accuracy,0.58065
val_loss,0.9483


In [159]:
net5 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1),      #(3x3x1)x16
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2), # 224x224 -> 112x112

    torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),       #(3x3x16)x32
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2), # 112x112 -> 56x56

    torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),        #(3x3x32)x64
    torch.nn.BatchNorm2d(32),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2), # 56x56 -> 28x28
    
    torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),       #(3x3x64)x128
    torch.nn.BatchNorm2d(64),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2), # 28x28 -> 14x14 

    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 3)                       # Output layer     
).to(device)

net5.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net5.parameters(), lr=0.001, weight_decay=1e-4)

summary(net5, input_size=(batch_size, 1, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [32, 3]                   --
├─Conv2d: 1-1                            [32, 8, 224, 224]         80
├─BatchNorm2d: 1-2                       [32, 8, 224, 224]         16
├─ReLU: 1-3                              [32, 8, 224, 224]         --
├─MaxPool2d: 1-4                         [32, 8, 112, 112]         --
├─Conv2d: 1-5                            [32, 16, 112, 112]        1,168
├─BatchNorm2d: 1-6                       [32, 16, 112, 112]        32
├─ReLU: 1-7                              [32, 16, 112, 112]        --
├─MaxPool2d: 1-8                         [32, 16, 56, 56]          --
├─Conv2d: 1-9                            [32, 32, 56, 56]          4,640
├─BatchNorm2d: 1-10                      [32, 32, 56, 56]          64
├─ReLU: 1-11                             [32, 32, 56, 56]          --
├─MaxPool2d: 1-12                        [32, 32, 28, 28]          --
├─Conv2d:

In [None]:

init_wandb()
train_model(net5, optimizer, loss_fn, enable_early_stopping=True, patience=5)
evaluate_model(net5)

Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-12 18:32:10,216 - INFO - Epoch 1/70, Train Loss: 1.1514, Val Loss: 1.0290, Val Acc: 0.4839


Training model:   1%|▏         | 1/70 [00:00<00:21,  3.14it/s]

2025-12-12 18:32:10,296 - INFO - Epoch 2/70, Train Loss: 0.9883, Val Loss: 1.0154, Val Acc: 0.3871
2025-12-12 18:32:10,376 - INFO - Epoch 3/70, Train Loss: 0.9204, Val Loss: 1.0194, Val Acc: 0.4194
2025-12-12 18:32:10,377 - INFO - EarlyStopping counter: 1 out of 10


Training model:   4%|▍         | 3/70 [00:00<00:09,  7.04it/s]

2025-12-12 18:32:10,457 - INFO - Epoch 4/70, Train Loss: 0.9009, Val Loss: 1.0026, Val Acc: 0.4839
2025-12-12 18:32:10,535 - INFO - Epoch 5/70, Train Loss: 0.8774, Val Loss: 0.9977, Val Acc: 0.4839


Training model:   7%|▋         | 5/70 [00:00<00:07,  9.10it/s]

2025-12-12 18:32:10,611 - INFO - Epoch 6/70, Train Loss: 0.8885, Val Loss: 1.0023, Val Acc: 0.4194
2025-12-12 18:32:10,612 - INFO - EarlyStopping counter: 1 out of 10
2025-12-12 18:32:10,766 - INFO - Epoch 7/70, Train Loss: 0.8523, Val Loss: 0.9986, Val Acc: 0.3871
2025-12-12 18:32:10,767 - INFO - EarlyStopping counter: 2 out of 10


Training model:  10%|█         | 7/70 [00:00<00:07,  8.91it/s]

2025-12-12 18:32:10,849 - INFO - Epoch 8/70, Train Loss: 0.8398, Val Loss: 0.9822, Val Acc: 0.5484
2025-12-12 18:32:10,922 - INFO - Epoch 9/70, Train Loss: 0.8182, Val Loss: 0.9680, Val Acc: 0.5161


Training model:  13%|█▎        | 9/70 [00:01<00:06, 10.05it/s]

2025-12-12 18:32:10,997 - INFO - Epoch 10/70, Train Loss: 0.8029, Val Loss: 0.9614, Val Acc: 0.5161
2025-12-12 18:32:11,074 - INFO - Epoch 11/70, Train Loss: 0.7732, Val Loss: 0.9278, Val Acc: 0.5161


Training model:  16%|█▌        | 11/70 [00:01<00:05, 10.96it/s]

2025-12-12 18:32:11,153 - INFO - Epoch 12/70, Train Loss: 0.7683, Val Loss: 0.9107, Val Acc: 0.5161
2025-12-12 18:32:11,232 - INFO - Epoch 13/70, Train Loss: 0.7719, Val Loss: 0.9299, Val Acc: 0.5806
2025-12-12 18:32:11,232 - INFO - EarlyStopping counter: 1 out of 10


Training model:  19%|█▊        | 13/70 [00:01<00:04, 11.48it/s]

2025-12-12 18:32:11,311 - INFO - Epoch 14/70, Train Loss: 0.7248, Val Loss: 0.8936, Val Acc: 0.4839
2025-12-12 18:32:11,386 - INFO - Epoch 15/70, Train Loss: 0.7287, Val Loss: 0.8860, Val Acc: 0.5484


Training model:  21%|██▏       | 15/70 [00:01<00:04, 11.92it/s]

2025-12-12 18:32:11,469 - INFO - Epoch 16/70, Train Loss: 0.7593, Val Loss: 0.8808, Val Acc: 0.5484
2025-12-12 18:32:11,555 - INFO - Epoch 17/70, Train Loss: 0.7067, Val Loss: 0.9955, Val Acc: 0.5161
2025-12-12 18:32:11,556 - INFO - EarlyStopping counter: 1 out of 10


Training model:  24%|██▍       | 17/70 [00:01<00:04, 11.89it/s]

2025-12-12 18:32:11,633 - INFO - Epoch 18/70, Train Loss: 0.6597, Val Loss: 0.8852, Val Acc: 0.5484
2025-12-12 18:32:11,634 - INFO - EarlyStopping counter: 2 out of 10
2025-12-12 18:32:11,711 - INFO - Epoch 19/70, Train Loss: 0.6633, Val Loss: 0.9231, Val Acc: 0.6452
2025-12-12 18:32:11,712 - INFO - EarlyStopping counter: 3 out of 10


Training model:  27%|██▋       | 19/70 [00:01<00:04, 12.19it/s]

2025-12-12 18:32:11,789 - INFO - Epoch 20/70, Train Loss: 0.6473, Val Loss: 0.9171, Val Acc: 0.5484
2025-12-12 18:32:11,790 - INFO - EarlyStopping counter: 4 out of 10
2025-12-12 18:32:11,866 - INFO - Epoch 21/70, Train Loss: 0.6215, Val Loss: 0.9580, Val Acc: 0.5806
2025-12-12 18:32:11,866 - INFO - EarlyStopping counter: 5 out of 10


Training model:  30%|███       | 21/70 [00:01<00:03, 12.41it/s]

2025-12-12 18:32:11,940 - INFO - Epoch 22/70, Train Loss: 0.5964, Val Loss: 0.8933, Val Acc: 0.5806
2025-12-12 18:32:11,940 - INFO - EarlyStopping counter: 6 out of 10
2025-12-12 18:32:12,014 - INFO - Epoch 23/70, Train Loss: 0.5845, Val Loss: 0.8684, Val Acc: 0.4839


Training model:  33%|███▎      | 23/70 [00:02<00:03, 12.69it/s]

2025-12-12 18:32:12,090 - INFO - Epoch 24/70, Train Loss: 0.5648, Val Loss: 0.8610, Val Acc: 0.5161
2025-12-12 18:32:12,162 - INFO - Epoch 25/70, Train Loss: 0.5484, Val Loss: 0.8802, Val Acc: 0.5806
2025-12-12 18:32:12,163 - INFO - EarlyStopping counter: 1 out of 10


Training model:  36%|███▌      | 25/70 [00:02<00:03, 12.95it/s]

2025-12-12 18:32:12,237 - INFO - Epoch 26/70, Train Loss: 0.5414, Val Loss: 0.8898, Val Acc: 0.5806
2025-12-12 18:32:12,237 - INFO - EarlyStopping counter: 2 out of 10
2025-12-12 18:32:12,311 - INFO - Epoch 27/70, Train Loss: 0.5144, Val Loss: 0.8478, Val Acc: 0.5806


Training model:  39%|███▊      | 27/70 [00:02<00:03, 13.09it/s]

2025-12-12 18:32:12,384 - INFO - Epoch 28/70, Train Loss: 0.4909, Val Loss: 0.8658, Val Acc: 0.5161
2025-12-12 18:32:12,384 - INFO - EarlyStopping counter: 1 out of 10
2025-12-12 18:32:12,458 - INFO - Epoch 29/70, Train Loss: 0.5345, Val Loss: 0.9405, Val Acc: 0.5806
2025-12-12 18:32:12,459 - INFO - EarlyStopping counter: 2 out of 10


Training model:  41%|████▏     | 29/70 [00:02<00:03, 13.20it/s]

2025-12-12 18:32:12,534 - INFO - Epoch 30/70, Train Loss: 0.5016, Val Loss: 0.8938, Val Acc: 0.5806
2025-12-12 18:32:12,535 - INFO - EarlyStopping counter: 3 out of 10
2025-12-12 18:32:12,608 - INFO - Epoch 31/70, Train Loss: 0.4673, Val Loss: 0.8975, Val Acc: 0.5484
2025-12-12 18:32:12,609 - INFO - EarlyStopping counter: 4 out of 10


Training model:  44%|████▍     | 31/70 [00:02<00:02, 13.28it/s]

2025-12-12 18:32:12,682 - INFO - Epoch 32/70, Train Loss: 0.4438, Val Loss: 0.8560, Val Acc: 0.5484
2025-12-12 18:32:12,683 - INFO - EarlyStopping counter: 5 out of 10
2025-12-12 18:32:12,758 - INFO - Epoch 33/70, Train Loss: 0.4405, Val Loss: 0.8118, Val Acc: 0.6452


Training model:  47%|████▋     | 33/70 [00:02<00:02, 13.28it/s]

2025-12-12 18:32:12,861 - INFO - Epoch 34/70, Train Loss: 0.4537, Val Loss: 0.8208, Val Acc: 0.5806
2025-12-12 18:32:12,862 - INFO - EarlyStopping counter: 1 out of 10
2025-12-12 18:32:12,951 - INFO - Epoch 35/70, Train Loss: 0.4062, Val Loss: 0.9000, Val Acc: 0.4839
2025-12-12 18:32:12,951 - INFO - EarlyStopping counter: 2 out of 10


Training model:  50%|█████     | 35/70 [00:03<00:02, 12.24it/s]

2025-12-12 18:32:13,034 - INFO - Epoch 36/70, Train Loss: 0.3860, Val Loss: 0.8283, Val Acc: 0.5484
2025-12-12 18:32:13,035 - INFO - EarlyStopping counter: 3 out of 10
2025-12-12 18:32:13,114 - INFO - Epoch 37/70, Train Loss: 0.4251, Val Loss: 0.9002, Val Acc: 0.6129
2025-12-12 18:32:13,114 - INFO - EarlyStopping counter: 4 out of 10


Training model:  53%|█████▎    | 37/70 [00:03<00:02, 12.28it/s]

2025-12-12 18:32:13,194 - INFO - Epoch 38/70, Train Loss: 0.4040, Val Loss: 1.0068, Val Acc: 0.5161
2025-12-12 18:32:13,195 - INFO - EarlyStopping counter: 5 out of 10
2025-12-12 18:32:13,341 - INFO - Epoch 39/70, Train Loss: 0.4549, Val Loss: 0.8694, Val Acc: 0.5806
2025-12-12 18:32:13,342 - INFO - EarlyStopping counter: 6 out of 10


Training model:  56%|█████▌    | 39/70 [00:03<00:02, 10.97it/s]

2025-12-12 18:32:13,417 - INFO - Epoch 40/70, Train Loss: 0.3962, Val Loss: 0.9441, Val Acc: 0.6129
2025-12-12 18:32:13,418 - INFO - EarlyStopping counter: 7 out of 10
2025-12-12 18:32:13,493 - INFO - Epoch 41/70, Train Loss: 0.3680, Val Loss: 0.8807, Val Acc: 0.6129
2025-12-12 18:32:13,493 - INFO - EarlyStopping counter: 8 out of 10


Training model:  59%|█████▊    | 41/70 [00:03<00:02, 11.56it/s]

2025-12-12 18:32:13,570 - INFO - Epoch 42/70, Train Loss: 0.3303, Val Loss: 0.8702, Val Acc: 0.5484
2025-12-12 18:32:13,572 - INFO - EarlyStopping counter: 9 out of 10
2025-12-12 18:32:13,653 - INFO - Epoch 43/70, Train Loss: 0.3159, Val Loss: 0.8641, Val Acc: 0.6129
2025-12-12 18:32:13,654 - INFO - EarlyStopping counter: 10 out of 10
2025-12-12 18:32:13,655 - INFO - Early stopping triggered


Training model:  60%|██████    | 42/70 [00:03<00:02, 11.17it/s]

2025-12-12 18:32:13,660 - INFO - Loaded best model weights
2025-12-12 18:32:13,661 - INFO - [1.1514268517494202, 0.9883494526147842, 0.9204300343990326, 0.9009385257959366, 0.8774115741252899, 0.8885339796543121, 0.8522748202085495, 0.8397817760705948, 0.8182388693094254, 0.8029266744852066, 0.7731685042381287, 0.7683285623788834, 0.7718577235937119, 0.7247912436723709, 0.7287047058343887, 0.7592594772577286, 0.7066857218742371, 0.6597284078598022, 0.6633101105690002, 0.6473288983106613, 0.6214704513549805, 0.5964439809322357, 0.5845062360167503, 0.5648430213332176, 0.5483621954917908, 0.5413946062326431, 0.5143819451332092, 0.4909096360206604, 0.5345412939786911, 0.5016388967633247, 0.4672912210226059, 0.4438416063785553, 0.4405299797654152, 0.4536689445376396, 0.4061700403690338, 0.3859664499759674, 0.42508550733327866, 0.40400509536266327, 0.45491378754377365, 0.39615825563669205, 0.367956779897213, 0.33025261759757996, 0.3159133940935135]
2025-12-12 18:32:13,690 - INFO - network ac


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


2025-12-12 18:32:13,690 - INFO - network precision: 82.77%
2025-12-12 18:32:13,691 - INFO - network recall: 38.78%
2025-12-12 18:32:13,692 - INFO - network F1 score: 44.84%
2025-12-12 18:32:13,702 - INFO - Detailed Classification Report: 
              precision    recall  f1-score   support

           0       0.22      0.86      0.35         7
           1       0.93      0.31      0.46        42
           2       0.00      0.00      0.00         0

    accuracy                           0.39        49
   macro avg       0.38      0.39      0.27        49
weighted avg       0.83      0.39      0.45        49



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
test_accuracy,▁
test_f1,▁
test_precision,▁
test_recall,▁
train_loss,█▇▆▆▆▆▅▅▅▅▅▅▅▄▅▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁
val_accuracy,▄▁▂▄▄▂▁▅▅▅▅▅▆▅▅▅▅█▅▆▆▄▅▆▆▆▆▆▅▅█▆▄▅▇▅▆▇▇▇
val_loss,███▇▇▇▇▆▆▆▅▄▅▃▃▇▃▅▄▆▄▃▃▃▄▂▅▄▄▂▁▁▄▂▄▇▃▅▃▃

0,1
epoch,43.0
test_accuracy,0.38776
test_f1,0.44838
test_precision,0.82766
test_recall,0.38776
train_loss,0.31591
val_accuracy,0.6129
val_loss,0.86408
