In [1]:
import numpy as np
import os
from dotenv import load_dotenv
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report, balanced_accuracy_score
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import TensorDataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms

from tqdm.auto import tqdm

import wandb

import logging
import sys
from torchinfo import summary

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def setup_logger(name=__name__):
    """
    Sets up a logger that outputs to the console (stdout).
    """
    logger = logging.getLogger(name)
    if not logger.handlers:
        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

logger = setup_logger()

In [3]:
# Load the dataset
data_folder = "../data"
preped_folder = os.path.join(data_folder, "_preped")

train_data = pd.read_csv(os.path.join(data_folder, 'train_data.csv')).values.tolist()
test_data = pd.read_csv(os.path.join(data_folder, 'test_data.csv')).values.tolist()

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to consistent size
    transforms.ToTensor(),           # Convert to tensor [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5])
])

x_train = []
y_train = []

for img_name, label in train_data:
    img_path = os.path.join(preped_folder, img_name)
    try:
        img = Image.open(img_path).convert('L') # Convert to grayscale
        img_tensor = transform(img)
        x_train.append(img_tensor)
        y_train.append(label)
    except Exception as e:
        logger.info(f"Error loading {img_name}: {e}")

# Stack into tensors
x_train_tensor = torch.stack(x_train)
logger.info(f"Training images shape: {x_train_tensor.shape}")

# Encode labels to integers
label_to_idx = {label: idx for idx, label in enumerate(np.unique(y_train))}
y_train_encoded = [label_to_idx[label] for label in y_train]
y_train_tensor = torch.tensor(y_train_encoded, dtype=torch.long)

logger.info(f"Training labels shape: {y_train_tensor.shape}")
logger.info(f"Label mapping: {label_to_idx}")

2025-12-13 15:11:58,466 - INFO - Training images shape: torch.Size([241, 1, 224, 224])
2025-12-13 15:11:58,468 - INFO - Training labels shape: torch.Size([241])
2025-12-13 15:11:58,468 - INFO - Label mapping: {np.str_('1_Pronacio'): 0, np.str_('2_Neutralis'): 1, np.str_('3_Szupinacio'): 2}


In [4]:
x_test = []
y_test = []

for img_name, label in test_data:
    img_path = os.path.join(preped_folder, img_name)
    try:
        img = Image.open(img_path).convert('L') # Convert to grayscale
        img_tensor = transform(img)
        x_test.append(img_tensor)
        y_test.append(label)
    except Exception as e:
        logger.info(f"Error loading {img_name}: {e}")

x_test_tensor = torch.stack(x_test)
logger.info(f"Test images shape: {x_test_tensor.shape}")
y_test_encoded = [label_to_idx[label] for label in y_test]
y_test_tensor = torch.tensor(y_test_encoded, dtype=torch.long)

logger.info(f"Test labels shape: {y_test_tensor.shape}")

2025-12-13 15:12:04,654 - INFO - Test images shape: torch.Size([49, 1, 224, 224])
2025-12-13 15:12:04,655 - INFO - Test labels shape: torch.Size([49])


In [5]:
if torch.cuda.is_available():
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        logger.info(f"\nGPU {i}: {torch.cuda.get_device_name(i)}")
        props = torch.cuda.get_device_properties(i)
        logger.info(f"  Memory: {props.total_memory / 1024**3:.2f} GB")
        logger.info(f"  Compute Capability: {props.major}.{props.minor}")
else:
    logger.info("CUDA not available")

2025-12-13 15:12:04,707 - INFO - CUDA available: True
2025-12-13 15:12:04,708 - INFO - Number of GPUs: 1
2025-12-13 15:12:04,714 - INFO - 
GPU 0: NVIDIA GeForce RTX 4060
2025-12-13 15:12:04,714 - INFO -   Memory: 8.00 GB
2025-12-13 15:12:04,715 - INFO -   Compute Capability: 8.9


In [6]:
batch_size = 16
num_epochs = 70
device = 'cuda' 

train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [28]:
# wandb login an init
# Login to wandb with API key
load_dotenv()
wandb.login(key=os.getenv("wandbKey"))

def init_wandb():
    # Initialize wandb project
    wandb.init(
        project="ankle-align-inc-model",
        config={
            "batch_size": batch_size,
            "num_epochs": num_epochs,
      
            "architecture": "Custom CNN",
            "dataset": "AnkleAlign",
            "optimizer": "Adam"
        }
    )

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Win 10\_netrc


In [8]:
net0 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),   # 224x224 -> 112x112
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),  # 112x112 -> 56x56
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 56x56 -> 28x28
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(32, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3)                       # Output layer     
).to(device)


def init_weights(m):
    if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

net0.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net0.parameters(), lr=0.01)

summary(net0, input_size=(batch_size, 1, 224, 224))


Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [16, 3]                   --
├─Conv2d: 1-1                            [16, 8, 112, 112]         80
├─ReLU: 1-2                              [16, 8, 112, 112]         --
├─Conv2d: 1-3                            [16, 16, 56, 56]          1,168
├─ReLU: 1-4                              [16, 16, 56, 56]          --
├─Conv2d: 1-5                            [16, 32, 28, 28]          4,640
├─ReLU: 1-6                              [16, 32, 28, 28]          --
├─AdaptiveAvgPool2d: 1-7                 [16, 32, 1, 1]            --
├─Flatten: 1-8                           [16, 32]                  --
├─Linear: 1-9                            [16, 128]                 4,224
├─ReLU: 1-10                             [16, 128]                 --
├─Linear: 1-11                           [16, 64]                  8,256
├─ReLU: 1-12                             [16, 64]                  --
├─L

In [9]:
# Trying to overfit one batch
init_wandb()
one_batch = next(iter(train_loader))
images, labels = one_batch

images = images.to(device)
labels = labels.to(device)

loss_values = []
net0.train()
for epoch in tqdm(range(num_epochs), desc='Training model'):
        pred_logits = net0(images)
        loss = loss_fn(pred_logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        loss_values.append(loss.item())
        wandb.log({
                "epoch": epoch + 1,
                "train_loss": loss.item()
            })
        
wandb.finish()
print(loss_values)

Training model: 100%|██████████| 70/70 [00:00<00:00, 180.74it/s]
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇████
train_loss,███▇▇▇▆▆▆▅▅▅▄▃▅▄▄▃▂▃▃▃▂▂▂▂▂▂▂▂▁▁▁▂▄▁▂▁▁▁

0,1
epoch,70.0
train_loss,0.07373


[1.0581742525100708, 0.9670201539993286, 0.9727962017059326, 0.9077773690223694, 0.9199650287628174, 0.9004878997802734, 0.8884644508361816, 0.8728632926940918, 0.8450402021408081, 0.8295108675956726, 0.8084721565246582, 0.7869530320167542, 0.7585417032241821, 0.7349890470504761, 0.7065781354904175, 0.6773555278778076, 0.6501045227050781, 0.6146928071975708, 0.5823861360549927, 0.5504868030548096, 0.5249641537666321, 0.4958270192146301, 0.45998427271842957, 0.42188993096351624, 0.3856785297393799, 0.4135778546333313, 0.7269619107246399, 0.48664823174476624, 0.5767006874084473, 0.48961764574050903, 0.3356057405471802, 0.4124259054660797, 0.41717058420181274, 0.5303608179092407, 0.42846354842185974, 0.3833099603652954, 0.2570996582508087, 0.3637549877166748, 0.3280019462108612, 0.30245280265808105, 0.34380239248275757, 0.2971813976764679, 0.23902207612991333, 0.26879552006721497, 0.22528496384620667, 0.22261683642864227, 0.2360658049583435, 0.1880386918783188, 0.21882149577140808, 0.1755

Net0 looks the most simply CNN which could learn on a 16 image batch, and overfit on this data. Smaller networks were not sifficent enaught to learon on 16 images.

In [10]:
def train_model(network, optimizer, loss_fn, enable_early_stopping=False, patience=5):
    torch.cuda.empty_cache()

    loss_values = []

    if enable_early_stopping:
        early_stopping = EarlyStopping(patience=patience, verbose=True)

    network.train()
    for epoch in tqdm(range(num_epochs), desc='Training model'):
        network.train()
        epoch_loss = 0.0
        num_batches = 0
        for images, target_labels in train_loader:
            images = images.to(device)
            target_labels = target_labels.to(device)

            pred_logits = network(images)
            loss = loss_fn(pred_logits, target_labels)
            epoch_loss += loss.item()
            num_batches += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        avg_train_loss = epoch_loss / num_batches

        if enable_early_stopping:
            network.eval()
            val_loss = 0.0
            val_batches = 0
            correct = 0
            total = 0
            with torch.no_grad():
                for images, target_labels in val_loader:
                    images = images.to(device)
                    target_labels = target_labels.to(device)
                    
                    pred_logits = network(images)
                    loss = loss_fn(pred_logits, target_labels)
                    val_loss += loss.item()
                    val_batches += 1
                    
                    _, predicted = torch.max(pred_logits, 1)
                    total += target_labels.size(0)
                    correct += (predicted == target_labels).sum().item()
            
            avg_val_loss = val_loss / val_batches
            val_accuracy = correct / total

        # Log metrics
        if enable_early_stopping:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "val_accuracy": val_accuracy
            })
        else:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss
            })
        loss_values.append(avg_train_loss)
        
        if enable_early_stopping:
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}")
        else:
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}")

        # Early stopping check
        if enable_early_stopping:
            early_stopping(avg_val_loss, network)
            if early_stopping.early_stop:
                logger.info("Early stopping triggered")
                network.load_state_dict(early_stopping.best_model)
                break
    
    # Load best model
    if enable_early_stopping and early_stopping.best_model is not None:
        network.load_state_dict(early_stopping.best_model)
        logger.info("Loaded best model weights")

    logger.info(loss_values)

In [11]:
def evaluate_model(network):
    # Training score
    true_labels = y_test_encoded
    predicted_labels = []
    network.eval()
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            outputs = network(images)
            _, predicted = torch.max(outputs, 1)
            predicted_labels.extend(predicted.cpu().numpy())

    balanced_accuracy = balanced_accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels, average='weighted')
    recall = recall_score(true_labels, predicted_labels, average='weighted')
    f1 = f1_score(true_labels, predicted_labels, average='weighted')

    logger.info(f"network accuracy: {balanced_accuracy * 100:.2f}%")
    logger.info(f"network precision: {precision * 100:.2f}%")
    logger.info(f"network recall: {recall * 100:.2f}%")
    logger.info(f"network F1 score: {f1 * 100:.2f}%")

    logger.info(f"Detailed Classification Report: \n{classification_report(true_labels, predicted_labels)}")

    # Log test metrics
    wandb.log({
        "test_accuracy": balanced_accuracy,
        "test_precision": precision,
        "test_recall": recall,
        "test_f1": f1
    })

    wandb.finish()

In [12]:
init_wandb()
train_model(net0, optimizer, loss_fn, enable_early_stopping=False)

Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-13 15:12:13,034 - INFO - Epoch 1/70, Train Loss: 2.8289


Training model:   1%|▏         | 1/70 [00:00<00:13,  5.28it/s]

2025-12-13 15:12:13,177 - INFO - Epoch 2/70, Train Loss: 1.0543


Training model:   3%|▎         | 2/70 [00:00<00:11,  6.17it/s]

2025-12-13 15:12:13,324 - INFO - Epoch 3/70, Train Loss: 0.9901


Training model:   4%|▍         | 3/70 [00:00<00:10,  6.45it/s]

2025-12-13 15:12:13,445 - INFO - Epoch 4/70, Train Loss: 0.9886


Training model:   6%|▌         | 4/70 [00:00<00:09,  7.07it/s]

2025-12-13 15:12:13,521 - INFO - Epoch 5/70, Train Loss: 1.0537
2025-12-13 15:12:13,603 - INFO - Epoch 6/70, Train Loss: 1.0069


Training model:   9%|▊         | 6/70 [00:00<00:06,  9.22it/s]

2025-12-13 15:12:13,685 - INFO - Epoch 7/70, Train Loss: 0.9860
2025-12-13 15:12:13,770 - INFO - Epoch 8/70, Train Loss: 0.9730


Training model:  11%|█▏        | 8/70 [00:00<00:06, 10.24it/s]

2025-12-13 15:12:13,856 - INFO - Epoch 9/70, Train Loss: 0.9650
2025-12-13 15:12:13,940 - INFO - Epoch 10/70, Train Loss: 0.9724


Training model:  14%|█▍        | 10/70 [00:01<00:05, 10.78it/s]

2025-12-13 15:12:14,025 - INFO - Epoch 11/70, Train Loss: 0.9679
2025-12-13 15:12:14,104 - INFO - Epoch 12/70, Train Loss: 0.9688


Training model:  17%|█▋        | 12/70 [00:01<00:05, 11.24it/s]

2025-12-13 15:12:14,191 - INFO - Epoch 13/70, Train Loss: 0.9848
2025-12-13 15:12:14,277 - INFO - Epoch 14/70, Train Loss: 0.9675


Training model:  20%|██        | 14/70 [00:01<00:04, 11.35it/s]

2025-12-13 15:12:14,359 - INFO - Epoch 15/70, Train Loss: 0.9952
2025-12-13 15:12:14,445 - INFO - Epoch 16/70, Train Loss: 0.9745


Training model:  23%|██▎       | 16/70 [00:01<00:04, 11.51it/s]

2025-12-13 15:12:14,530 - INFO - Epoch 17/70, Train Loss: 0.9400
2025-12-13 15:12:14,611 - INFO - Epoch 18/70, Train Loss: 0.9567


Training model:  26%|██▌       | 18/70 [00:01<00:04, 11.68it/s]

2025-12-13 15:12:14,694 - INFO - Epoch 19/70, Train Loss: 0.9402
2025-12-13 15:12:14,771 - INFO - Epoch 20/70, Train Loss: 0.9566


Training model:  29%|██▊       | 20/70 [00:01<00:04, 11.93it/s]

2025-12-13 15:12:14,855 - INFO - Epoch 21/70, Train Loss: 0.9357
2025-12-13 15:12:14,934 - INFO - Epoch 22/70, Train Loss: 0.9241


Training model:  31%|███▏      | 22/70 [00:02<00:03, 12.03it/s]

2025-12-13 15:12:15,010 - INFO - Epoch 23/70, Train Loss: 0.9783
2025-12-13 15:12:15,093 - INFO - Epoch 24/70, Train Loss: 0.9329


Training model:  34%|███▍      | 24/70 [00:02<00:03, 12.19it/s]

2025-12-13 15:12:15,176 - INFO - Epoch 25/70, Train Loss: 0.9299
2025-12-13 15:12:15,261 - INFO - Epoch 26/70, Train Loss: 0.9176


Training model:  37%|███▋      | 26/70 [00:02<00:03, 12.09it/s]

2025-12-13 15:12:15,350 - INFO - Epoch 27/70, Train Loss: 0.8957
2025-12-13 15:12:15,431 - INFO - Epoch 28/70, Train Loss: 0.8998


Training model:  40%|████      | 28/70 [00:02<00:03, 12.00it/s]

2025-12-13 15:12:15,512 - INFO - Epoch 29/70, Train Loss: 0.8832
2025-12-13 15:12:15,597 - INFO - Epoch 30/70, Train Loss: 0.8854


Training model:  43%|████▎     | 30/70 [00:02<00:03, 12.02it/s]

2025-12-13 15:12:15,679 - INFO - Epoch 31/70, Train Loss: 0.8916
2025-12-13 15:12:15,760 - INFO - Epoch 32/70, Train Loss: 0.8682


Training model:  46%|████▌     | 32/70 [00:02<00:03, 12.08it/s]

2025-12-13 15:12:15,847 - INFO - Epoch 33/70, Train Loss: 0.8547
2025-12-13 15:12:15,928 - INFO - Epoch 34/70, Train Loss: 0.8549


Training model:  49%|████▊     | 34/70 [00:03<00:02, 12.03it/s]

2025-12-13 15:12:16,004 - INFO - Epoch 35/70, Train Loss: 0.8424
2025-12-13 15:12:16,091 - INFO - Epoch 36/70, Train Loss: 0.9325


Training model:  51%|█████▏    | 36/70 [00:03<00:02, 12.11it/s]

2025-12-13 15:12:16,180 - INFO - Epoch 37/70, Train Loss: 0.9609
2025-12-13 15:12:16,261 - INFO - Epoch 38/70, Train Loss: 0.8654


Training model:  54%|█████▍    | 38/70 [00:03<00:02, 12.00it/s]

2025-12-13 15:12:16,348 - INFO - Epoch 39/70, Train Loss: 0.8585
2025-12-13 15:12:16,432 - INFO - Epoch 40/70, Train Loss: 0.8291


Training model:  57%|█████▋    | 40/70 [00:03<00:02, 11.90it/s]

2025-12-13 15:12:16,519 - INFO - Epoch 41/70, Train Loss: 0.8535
2025-12-13 15:12:16,600 - INFO - Epoch 42/70, Train Loss: 0.7929


Training model:  60%|██████    | 42/70 [00:03<00:02, 11.91it/s]

2025-12-13 15:12:16,681 - INFO - Epoch 43/70, Train Loss: 0.7591
2025-12-13 15:12:16,761 - INFO - Epoch 44/70, Train Loss: 0.8933


Training model:  63%|██████▎   | 44/70 [00:03<00:02, 12.05it/s]

2025-12-13 15:12:16,840 - INFO - Epoch 45/70, Train Loss: 0.8936
2025-12-13 15:12:16,921 - INFO - Epoch 46/70, Train Loss: 0.8291


Training model:  66%|██████▌   | 46/70 [00:04<00:01, 12.19it/s]

2025-12-13 15:12:16,999 - INFO - Epoch 47/70, Train Loss: 0.7428
2025-12-13 15:12:17,081 - INFO - Epoch 48/70, Train Loss: 0.7234


Training model:  69%|██████▊   | 48/70 [00:04<00:01, 12.29it/s]

2025-12-13 15:12:17,165 - INFO - Epoch 49/70, Train Loss: 0.7270
2025-12-13 15:12:17,249 - INFO - Epoch 50/70, Train Loss: 0.8925


Training model:  71%|███████▏  | 50/70 [00:04<00:01, 12.17it/s]

2025-12-13 15:12:17,335 - INFO - Epoch 51/70, Train Loss: 0.7995
2025-12-13 15:12:17,421 - INFO - Epoch 52/70, Train Loss: 0.8590


Training model:  74%|███████▍  | 52/70 [00:04<00:01, 12.00it/s]

2025-12-13 15:12:17,499 - INFO - Epoch 53/70, Train Loss: 0.7459
2025-12-13 15:12:17,583 - INFO - Epoch 54/70, Train Loss: 0.6258


Training model:  77%|███████▋  | 54/70 [00:04<00:01, 12.09it/s]

2025-12-13 15:12:17,678 - INFO - Epoch 55/70, Train Loss: 0.6511
2025-12-13 15:12:17,774 - INFO - Epoch 56/70, Train Loss: 0.6691


Training model:  80%|████████  | 56/70 [00:04<00:01, 11.55it/s]

2025-12-13 15:12:17,876 - INFO - Epoch 57/70, Train Loss: 0.6153
2025-12-13 15:12:17,962 - INFO - Epoch 58/70, Train Loss: 0.6537


Training model:  83%|████████▎ | 58/70 [00:05<00:01, 11.26it/s]

2025-12-13 15:12:18,044 - INFO - Epoch 59/70, Train Loss: 0.5918
2025-12-13 15:12:18,127 - INFO - Epoch 60/70, Train Loss: 0.5144


Training model:  86%|████████▌ | 60/70 [00:05<00:00, 11.51it/s]

2025-12-13 15:12:18,214 - INFO - Epoch 61/70, Train Loss: 0.4706
2025-12-13 15:12:18,305 - INFO - Epoch 62/70, Train Loss: 0.5963


Training model:  89%|████████▊ | 62/70 [00:05<00:00, 11.44it/s]

2025-12-13 15:12:18,385 - INFO - Epoch 63/70, Train Loss: 0.5442
2025-12-13 15:12:18,462 - INFO - Epoch 64/70, Train Loss: 0.4984


Training model:  91%|█████████▏| 64/70 [00:05<00:00, 11.80it/s]

2025-12-13 15:12:18,543 - INFO - Epoch 65/70, Train Loss: 0.4310
2025-12-13 15:12:18,626 - INFO - Epoch 66/70, Train Loss: 0.5272


Training model:  94%|█████████▍| 66/70 [00:05<00:00, 11.90it/s]

2025-12-13 15:12:18,703 - INFO - Epoch 67/70, Train Loss: 0.7059
2025-12-13 15:12:18,788 - INFO - Epoch 68/70, Train Loss: 0.5550


Training model:  97%|█████████▋| 68/70 [00:05<00:00, 12.04it/s]

2025-12-13 15:12:18,868 - INFO - Epoch 69/70, Train Loss: 0.4919
2025-12-13 15:12:18,953 - INFO - Epoch 70/70, Train Loss: 0.5344


Training model: 100%|██████████| 70/70 [00:06<00:00, 11.46it/s]

2025-12-13 15:12:18,955 - INFO - [2.8288876973092556, 1.0543148964643478, 0.9900975152850151, 0.9886178970336914, 1.0536944270133972, 1.0068634450435638, 0.986005425453186, 0.9730230793356895, 0.9650396108627319, 0.9724284037947655, 0.9679272286593914, 0.9688276834785938, 0.984809335321188, 0.9674697928130627, 0.9952278211712837, 0.9745336249470711, 0.9400264099240303, 0.9566500522196293, 0.9401758387684822, 0.9565533101558685, 0.9356943182647228, 0.9241379201412201, 0.9782601222395897, 0.932905524969101, 0.9299440570175648, 0.9175514318048954, 0.8956962414085865, 0.8998011834919453, 0.8831529654562473, 0.885401364415884, 0.8916499838232994, 0.8682254068553448, 0.8547017276287079, 0.8548958227038383, 0.84241608902812, 0.9325481355190277, 0.9608963802456856, 0.8653992414474487, 0.8585188575088978, 0.8291248418390751, 0.8535263873636723, 0.7928954921662807, 0.7590807257220149, 0.8933034278452396, 0.8935810551047325, 0.8291191644966602, 0.7427658755332232, 0.72342748939991, 0.726977024227




Net0 can learn on one batch of 16 images, but can not learn on all the provided train data, to simple for 200+ images.

In [13]:
net1 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),   # 224x224 -> 112x112   // (3x3x1)x8
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),  # 112x112 -> 56x56     // (3x3x8)x16
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 56x56 -> 28x28      // (3x3x16)x32
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),  # 28x28 -> 14x14      // (3x3x32)x32
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(32, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3)                       # Output layer     
).to(device)

net1.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net1.parameters(), lr=0.001)

summary(net1, input_size=(batch_size, 1, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [16, 3]                   --
├─Conv2d: 1-1                            [16, 8, 112, 112]         80
├─ReLU: 1-2                              [16, 8, 112, 112]         --
├─Conv2d: 1-3                            [16, 16, 56, 56]          1,168
├─ReLU: 1-4                              [16, 16, 56, 56]          --
├─Conv2d: 1-5                            [16, 32, 28, 28]          4,640
├─ReLU: 1-6                              [16, 32, 28, 28]          --
├─Conv2d: 1-7                            [16, 32, 14, 14]          9,248
├─ReLU: 1-8                              [16, 32, 14, 14]          --
├─AdaptiveAvgPool2d: 1-9                 [16, 32, 1, 1]            --
├─Flatten: 1-10                          [16, 32]                  --
├─Linear: 1-11                           [16, 128]                 4,224
├─ReLU: 1-12                             [16, 128]                 --
├─L

In [14]:
init_wandb()
train_model(net1, optimizer, loss_fn, enable_early_stopping=False)
wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
train_loss,█▃▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁

0,1
epoch,70.0
train_loss,0.53442


Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-13 15:12:22,045 - INFO - Epoch 1/70, Train Loss: 1.0980


Training model:   1%|▏         | 1/70 [00:00<00:14,  4.75it/s]

2025-12-13 15:12:22,222 - INFO - Epoch 2/70, Train Loss: 0.9738


Training model:   3%|▎         | 2/70 [00:00<00:13,  5.23it/s]

2025-12-13 15:12:22,370 - INFO - Epoch 3/70, Train Loss: 0.9853


Training model:   4%|▍         | 3/70 [00:00<00:11,  5.83it/s]

2025-12-13 15:12:22,473 - INFO - Epoch 4/70, Train Loss: 1.0332


Training model:   6%|▌         | 4/70 [00:00<00:09,  6.93it/s]

2025-12-13 15:12:22,556 - INFO - Epoch 5/70, Train Loss: 0.9755
2025-12-13 15:12:22,643 - INFO - Epoch 6/70, Train Loss: 0.9703


Training model:   9%|▊         | 6/70 [00:00<00:07,  8.86it/s]

2025-12-13 15:12:22,734 - INFO - Epoch 7/70, Train Loss: 0.9599
2025-12-13 15:12:22,845 - INFO - Epoch 8/70, Train Loss: 0.9548


Training model:  11%|█▏        | 8/70 [00:01<00:06,  9.29it/s]

2025-12-13 15:12:22,939 - INFO - Epoch 9/70, Train Loss: 1.0313
2025-12-13 15:12:23,027 - INFO - Epoch 10/70, Train Loss: 0.9922


Training model:  14%|█▍        | 10/70 [00:01<00:06,  9.85it/s]

2025-12-13 15:12:23,113 - INFO - Epoch 11/70, Train Loss: 0.9762
2025-12-13 15:12:23,196 - INFO - Epoch 12/70, Train Loss: 0.9628


Training model:  17%|█▋        | 12/70 [00:01<00:05, 10.48it/s]

2025-12-13 15:12:23,286 - INFO - Epoch 13/70, Train Loss: 0.9618
2025-12-13 15:12:23,367 - INFO - Epoch 14/70, Train Loss: 0.9642


Training model:  20%|██        | 14/70 [00:01<00:05, 10.86it/s]

2025-12-13 15:12:23,452 - INFO - Epoch 15/70, Train Loss: 0.9497
2025-12-13 15:12:23,542 - INFO - Epoch 16/70, Train Loss: 0.9707


Training model:  23%|██▎       | 16/70 [00:01<00:04, 11.04it/s]

2025-12-13 15:12:23,631 - INFO - Epoch 17/70, Train Loss: 0.9723
2025-12-13 15:12:23,707 - INFO - Epoch 18/70, Train Loss: 0.9539


Training model:  26%|██▌       | 18/70 [00:01<00:04, 11.35it/s]

2025-12-13 15:12:23,793 - INFO - Epoch 19/70, Train Loss: 0.9779
2025-12-13 15:12:23,890 - INFO - Epoch 20/70, Train Loss: 0.9378


Training model:  29%|██▊       | 20/70 [00:02<00:04, 11.23it/s]

2025-12-13 15:12:24,002 - INFO - Epoch 21/70, Train Loss: 0.9758
2025-12-13 15:12:24,173 - INFO - Epoch 22/70, Train Loss: 0.9603


Training model:  31%|███▏      | 22/70 [00:02<00:05,  9.50it/s]

2025-12-13 15:12:24,262 - INFO - Epoch 23/70, Train Loss: 1.0300
2025-12-13 15:12:24,351 - INFO - Epoch 24/70, Train Loss: 0.9732


Training model:  34%|███▍      | 24/70 [00:02<00:04,  9.97it/s]

2025-12-13 15:12:24,434 - INFO - Epoch 25/70, Train Loss: 0.9551
2025-12-13 15:12:24,518 - INFO - Epoch 26/70, Train Loss: 0.9365


Training model:  37%|███▋      | 26/70 [00:02<00:04, 10.51it/s]

2025-12-13 15:12:24,603 - INFO - Epoch 27/70, Train Loss: 0.9975
2025-12-13 15:12:24,691 - INFO - Epoch 28/70, Train Loss: 0.9454


Training model:  40%|████      | 28/70 [00:02<00:03, 10.80it/s]

2025-12-13 15:12:24,782 - INFO - Epoch 29/70, Train Loss: 0.9199
2025-12-13 15:12:24,876 - INFO - Epoch 30/70, Train Loss: 0.8901


Training model:  43%|████▎     | 30/70 [00:03<00:03, 10.82it/s]

2025-12-13 15:12:24,955 - INFO - Epoch 31/70, Train Loss: 0.9409
2025-12-13 15:12:25,038 - INFO - Epoch 32/70, Train Loss: 0.8677


Training model:  46%|████▌     | 32/70 [00:03<00:03, 11.23it/s]

2025-12-13 15:12:25,126 - INFO - Epoch 33/70, Train Loss: 0.8734
2025-12-13 15:12:25,213 - INFO - Epoch 34/70, Train Loss: 0.9281


Training model:  49%|████▊     | 34/70 [00:03<00:03, 11.28it/s]

2025-12-13 15:12:25,302 - INFO - Epoch 35/70, Train Loss: 0.8849
2025-12-13 15:12:25,394 - INFO - Epoch 36/70, Train Loss: 0.8405


Training model:  51%|█████▏    | 36/70 [00:03<00:03, 11.22it/s]

2025-12-13 15:12:25,487 - INFO - Epoch 37/70, Train Loss: 0.7894
2025-12-13 15:12:25,577 - INFO - Epoch 38/70, Train Loss: 0.7945


Training model:  54%|█████▍    | 38/70 [00:03<00:02, 11.12it/s]

2025-12-13 15:12:25,670 - INFO - Epoch 39/70, Train Loss: 0.7900
2025-12-13 15:12:25,755 - INFO - Epoch 40/70, Train Loss: 0.7374


Training model:  57%|█████▋    | 40/70 [00:03<00:02, 11.15it/s]

2025-12-13 15:12:25,847 - INFO - Epoch 41/70, Train Loss: 0.8709
2025-12-13 15:12:25,937 - INFO - Epoch 42/70, Train Loss: 0.8102


Training model:  60%|██████    | 42/70 [00:04<00:02, 11.12it/s]

2025-12-13 15:12:26,019 - INFO - Epoch 43/70, Train Loss: 0.9423
2025-12-13 15:12:26,107 - INFO - Epoch 44/70, Train Loss: 0.8493


Training model:  63%|██████▎   | 44/70 [00:04<00:02, 11.31it/s]

2025-12-13 15:12:26,193 - INFO - Epoch 45/70, Train Loss: 0.7996
2025-12-13 15:12:26,274 - INFO - Epoch 46/70, Train Loss: 0.7295


Training model:  66%|██████▌   | 46/70 [00:04<00:02, 11.49it/s]

2025-12-13 15:12:26,360 - INFO - Epoch 47/70, Train Loss: 0.7054
2025-12-13 15:12:26,451 - INFO - Epoch 48/70, Train Loss: 0.8226


Training model:  69%|██████▊   | 48/70 [00:04<00:01, 11.44it/s]

2025-12-13 15:12:26,537 - INFO - Epoch 49/70, Train Loss: 1.0332
2025-12-13 15:12:26,619 - INFO - Epoch 50/70, Train Loss: 0.7497


Training model:  71%|███████▏  | 50/70 [00:04<00:01, 11.56it/s]

2025-12-13 15:12:26,764 - INFO - Epoch 51/70, Train Loss: 0.7095
2025-12-13 15:12:26,853 - INFO - Epoch 52/70, Train Loss: 0.6863


Training model:  74%|███████▍  | 52/70 [00:05<00:01, 10.46it/s]

2025-12-13 15:12:26,933 - INFO - Epoch 53/70, Train Loss: 0.7186
2025-12-13 15:12:27,023 - INFO - Epoch 54/70, Train Loss: 0.6533


Training model:  77%|███████▋  | 54/70 [00:05<00:01, 10.82it/s]

2025-12-13 15:12:27,109 - INFO - Epoch 55/70, Train Loss: 0.6804
2025-12-13 15:12:27,190 - INFO - Epoch 56/70, Train Loss: 0.5673


Training model:  80%|████████  | 56/70 [00:05<00:01, 11.14it/s]

2025-12-13 15:12:27,273 - INFO - Epoch 57/70, Train Loss: 0.6212
2025-12-13 15:12:27,353 - INFO - Epoch 58/70, Train Loss: 0.5604


Training model:  83%|████████▎ | 58/70 [00:05<00:01, 11.47it/s]

2025-12-13 15:12:27,443 - INFO - Epoch 59/70, Train Loss: 0.5434
2025-12-13 15:12:27,527 - INFO - Epoch 60/70, Train Loss: 0.5722


Training model:  86%|████████▌ | 60/70 [00:05<00:00, 11.46it/s]

2025-12-13 15:12:27,612 - INFO - Epoch 61/70, Train Loss: 0.6138
2025-12-13 15:12:27,700 - INFO - Epoch 62/70, Train Loss: 0.6788


Training model:  89%|████████▊ | 62/70 [00:05<00:00, 11.49it/s]

2025-12-13 15:12:27,797 - INFO - Epoch 63/70, Train Loss: 0.6225
2025-12-13 15:12:27,885 - INFO - Epoch 64/70, Train Loss: 0.5496


Training model:  91%|█████████▏| 64/70 [00:06<00:00, 11.28it/s]

2025-12-13 15:12:27,983 - INFO - Epoch 65/70, Train Loss: 0.4817
2025-12-13 15:12:28,080 - INFO - Epoch 66/70, Train Loss: 0.5147


Training model:  94%|█████████▍| 66/70 [00:06<00:00, 10.96it/s]

2025-12-13 15:12:28,178 - INFO - Epoch 67/70, Train Loss: 0.5253
2025-12-13 15:12:28,275 - INFO - Epoch 68/70, Train Loss: 0.4921


Training model:  97%|█████████▋| 68/70 [00:06<00:00, 10.73it/s]

2025-12-13 15:12:28,367 - INFO - Epoch 69/70, Train Loss: 0.4723
2025-12-13 15:12:28,453 - INFO - Epoch 70/70, Train Loss: 0.4469


Training model: 100%|██████████| 70/70 [00:06<00:00, 10.58it/s]

2025-12-13 15:12:28,455 - INFO - [1.0980132706463337, 0.9737811982631683, 0.985345970839262, 1.033153835684061, 0.975472379475832, 0.9702713824808598, 0.9599343352019787, 0.9548133164644241, 1.0313443951308727, 0.9922207295894623, 0.9761888794600964, 0.9627539999783039, 0.9618349336087704, 0.9642414040863514, 0.9496512375771999, 0.970670823007822, 0.9722770005464554, 0.9539233110845089, 0.9779463186860085, 0.937753651291132, 0.9758425801992416, 0.9602543003857136, 1.0300241522490978, 0.973215576261282, 0.9550528340041637, 0.9365019574761391, 0.997471671551466, 0.9454287067055702, 0.9199320301413536, 0.8900775350630283, 0.9409403055906296, 0.8677223846316338, 0.8734305314719677, 0.9281336925923824, 0.8849347159266472, 0.8405459448695183, 0.7894293991848826, 0.7944715432822704, 0.7900476083159447, 0.7374011885840446, 0.8708726316690445, 0.8102416172623634, 0.9423232525587082, 0.8492570146918297, 0.7995656691491604, 0.729454031214118, 0.7054288741201162, 0.8225831836462021, 1.033160939812




0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
train_loss,▇█▇▇▇██▇▇▇▇▇▇▇▇▇▇█▆▇▆▇▆▆▅▇▆▅▄▅▄▃▂▃▂▄▃▁▂▁

0,1
epoch,70.0
train_loss,0.44686


Net1 could learn on all the provided train data, but it took the model over 50 epoch to converge below 0.5 loss on train set. So I think I will create 1 more little bit more complex network.

In [15]:
net2 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),   # 224x224 -> 112x112   // (3x3x1)x8
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),  # 112x112 -> 56x56     // (3x3x8)x16
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 56x56 -> 28x28      // (3x3x16)x32
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 28x28 -> 14x14      // (3x3x32)x64
    torch.nn.ReLU(),
    torch.nn.Conv2d(64, 32, kernel_size=3, stride=2, padding=1),  # 14x14 -> 7x7        // (3x3x64)x32
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(32, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3)                       # Output layer     
).to(device)

net2.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net2.parameters(), lr=0.001)

summary(net2, input_size=(batch_size, 1, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [16, 3]                   --
├─Conv2d: 1-1                            [16, 8, 112, 112]         80
├─ReLU: 1-2                              [16, 8, 112, 112]         --
├─Conv2d: 1-3                            [16, 16, 56, 56]          1,168
├─ReLU: 1-4                              [16, 16, 56, 56]          --
├─Conv2d: 1-5                            [16, 32, 28, 28]          4,640
├─ReLU: 1-6                              [16, 32, 28, 28]          --
├─Conv2d: 1-7                            [16, 64, 14, 14]          18,496
├─ReLU: 1-8                              [16, 64, 14, 14]          --
├─Conv2d: 1-9                            [16, 32, 7, 7]            18,464
├─ReLU: 1-10                             [16, 32, 7, 7]            --
├─AdaptiveAvgPool2d: 1-11                [16, 32, 1, 1]            --
├─Flatten: 1-12                          [16, 32]                  --
├

In [16]:
init_wandb()
train_model(net2, optimizer, loss_fn, enable_early_stopping=False)
evaluate_model(net2)

Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-13 15:12:31,348 - INFO - Epoch 1/70, Train Loss: 1.1290


Training model:   1%|▏         | 1/70 [00:00<00:12,  5.49it/s]

2025-12-13 15:12:31,536 - INFO - Epoch 2/70, Train Loss: 1.0153


Training model:   3%|▎         | 2/70 [00:00<00:12,  5.39it/s]

2025-12-13 15:12:31,651 - INFO - Epoch 3/70, Train Loss: 1.0552


Training model:   4%|▍         | 3/70 [00:00<00:10,  6.52it/s]

2025-12-13 15:12:31,744 - INFO - Epoch 4/70, Train Loss: 0.9918
2025-12-13 15:12:31,832 - INFO - Epoch 5/70, Train Loss: 0.9569


Training model:   7%|▋         | 5/70 [00:00<00:07,  8.49it/s]

2025-12-13 15:12:31,924 - INFO - Epoch 6/70, Train Loss: 0.9552
2025-12-13 15:12:32,011 - INFO - Epoch 7/70, Train Loss: 0.9981


Training model:  10%|█         | 7/70 [00:00<00:06,  9.51it/s]

2025-12-13 15:12:32,102 - INFO - Epoch 8/70, Train Loss: 0.9633
2025-12-13 15:12:32,185 - INFO - Epoch 9/70, Train Loss: 0.9529


Training model:  13%|█▎        | 9/70 [00:01<00:05, 10.20it/s]

2025-12-13 15:12:32,276 - INFO - Epoch 10/70, Train Loss: 0.9201
2025-12-13 15:12:32,366 - INFO - Epoch 11/70, Train Loss: 1.0935


Training model:  16%|█▌        | 11/70 [00:01<00:05, 10.48it/s]

2025-12-13 15:12:32,455 - INFO - Epoch 12/70, Train Loss: 0.9907
2025-12-13 15:12:32,545 - INFO - Epoch 13/70, Train Loss: 0.9297


Training model:  19%|█▊        | 13/70 [00:01<00:05, 10.72it/s]

2025-12-13 15:12:32,635 - INFO - Epoch 14/70, Train Loss: 0.9483
2025-12-13 15:12:32,733 - INFO - Epoch 15/70, Train Loss: 1.0066


Training model:  21%|██▏       | 15/70 [00:01<00:05, 10.68it/s]

2025-12-13 15:12:32,853 - INFO - Epoch 16/70, Train Loss: 0.9730
2025-12-13 15:12:32,985 - INFO - Epoch 17/70, Train Loss: 0.9264


Training model:  24%|██▍       | 17/70 [00:01<00:05,  9.63it/s]

2025-12-13 15:12:33,074 - INFO - Epoch 18/70, Train Loss: 0.8942
2025-12-13 15:12:33,178 - INFO - Epoch 19/70, Train Loss: 0.8338


Training model:  27%|██▋       | 19/70 [00:02<00:05,  9.84it/s]

2025-12-13 15:12:33,308 - INFO - Epoch 20/70, Train Loss: 0.8458


Training model:  29%|██▊       | 20/70 [00:02<00:05,  9.36it/s]

2025-12-13 15:12:33,447 - INFO - Epoch 21/70, Train Loss: 0.8133


Training model:  30%|███       | 21/70 [00:02<00:05,  8.82it/s]

2025-12-13 15:12:33,590 - INFO - Epoch 22/70, Train Loss: 0.8147


Training model:  31%|███▏      | 22/70 [00:02<00:05,  8.33it/s]

2025-12-13 15:12:33,708 - INFO - Epoch 23/70, Train Loss: 0.8973


Training model:  33%|███▎      | 23/70 [00:02<00:05,  8.35it/s]

2025-12-13 15:12:33,812 - INFO - Epoch 24/70, Train Loss: 0.6855


Training model:  34%|███▍      | 24/70 [00:02<00:05,  8.65it/s]

2025-12-13 15:12:33,916 - INFO - Epoch 25/70, Train Loss: 0.5944


Training model:  36%|███▌      | 25/70 [00:02<00:05,  8.90it/s]

2025-12-13 15:12:34,014 - INFO - Epoch 26/70, Train Loss: 0.6057
2025-12-13 15:12:34,145 - INFO - Epoch 27/70, Train Loss: 0.5092


Training model:  39%|███▊      | 27/70 [00:02<00:04,  8.83it/s]

2025-12-13 15:12:34,287 - INFO - Epoch 28/70, Train Loss: 0.5201


Training model:  40%|████      | 28/70 [00:03<00:05,  8.33it/s]

2025-12-13 15:12:34,402 - INFO - Epoch 29/70, Train Loss: 0.4381


Training model:  41%|████▏     | 29/70 [00:03<00:04,  8.41it/s]

2025-12-13 15:12:34,560 - INFO - Epoch 30/70, Train Loss: 0.3534


Training model:  43%|████▎     | 30/70 [00:03<00:05,  7.73it/s]

2025-12-13 15:12:34,695 - INFO - Epoch 31/70, Train Loss: 0.3639


Training model:  44%|████▍     | 31/70 [00:03<00:05,  7.65it/s]

2025-12-13 15:12:34,812 - INFO - Epoch 32/70, Train Loss: 0.4109


Training model:  46%|████▌     | 32/70 [00:03<00:04,  7.87it/s]

2025-12-13 15:12:34,944 - INFO - Epoch 33/70, Train Loss: 0.3155


Training model:  47%|████▋     | 33/70 [00:03<00:04,  7.80it/s]

2025-12-13 15:12:35,052 - INFO - Epoch 34/70, Train Loss: 0.3173


Training model:  49%|████▊     | 34/70 [00:03<00:04,  8.16it/s]

2025-12-13 15:12:35,166 - INFO - Epoch 35/70, Train Loss: 0.2744


Training model:  50%|█████     | 35/70 [00:03<00:04,  8.34it/s]

2025-12-13 15:12:35,287 - INFO - Epoch 36/70, Train Loss: 0.2056


Training model:  51%|█████▏    | 36/70 [00:04<00:04,  8.29it/s]

2025-12-13 15:12:35,403 - INFO - Epoch 37/70, Train Loss: 0.1840


Training model:  53%|█████▎    | 37/70 [00:04<00:03,  8.41it/s]

2025-12-13 15:12:35,531 - INFO - Epoch 38/70, Train Loss: 0.1624


Training model:  54%|█████▍    | 38/70 [00:04<00:03,  8.20it/s]

2025-12-13 15:12:35,652 - INFO - Epoch 39/70, Train Loss: 0.1559


Training model:  56%|█████▌    | 39/70 [00:04<00:03,  8.24it/s]

2025-12-13 15:12:35,788 - INFO - Epoch 40/70, Train Loss: 0.1429


Training model:  57%|█████▋    | 40/70 [00:04<00:03,  7.95it/s]

2025-12-13 15:12:35,912 - INFO - Epoch 41/70, Train Loss: 0.1633


Training model:  59%|█████▊    | 41/70 [00:04<00:03,  8.00it/s]

2025-12-13 15:12:36,041 - INFO - Epoch 42/70, Train Loss: 0.3789


Training model:  60%|██████    | 42/70 [00:04<00:03,  7.92it/s]

2025-12-13 15:12:36,189 - INFO - Epoch 43/70, Train Loss: 0.7051


Training model:  61%|██████▏   | 43/70 [00:05<00:03,  7.54it/s]

2025-12-13 15:12:36,333 - INFO - Epoch 44/70, Train Loss: 0.5958


Training model:  63%|██████▎   | 44/70 [00:05<00:03,  7.34it/s]

2025-12-13 15:12:36,465 - INFO - Epoch 45/70, Train Loss: 0.4963


Training model:  64%|██████▍   | 45/70 [00:05<00:03,  7.41it/s]

2025-12-13 15:12:36,589 - INFO - Epoch 46/70, Train Loss: 0.3736


Training model:  66%|██████▌   | 46/70 [00:05<00:03,  7.59it/s]

2025-12-13 15:12:36,759 - INFO - Epoch 47/70, Train Loss: 0.2435


Training model:  67%|██████▋   | 47/70 [00:05<00:03,  6.99it/s]

2025-12-13 15:12:36,900 - INFO - Epoch 48/70, Train Loss: 0.2006


Training model:  69%|██████▊   | 48/70 [00:05<00:03,  7.02it/s]

2025-12-13 15:12:37,157 - INFO - Epoch 49/70, Train Loss: 0.1637


Training model:  70%|███████   | 49/70 [00:05<00:03,  5.65it/s]

2025-12-13 15:12:37,380 - INFO - Epoch 50/70, Train Loss: 0.1218


Training model:  71%|███████▏  | 50/70 [00:06<00:03,  5.24it/s]

2025-12-13 15:12:37,524 - INFO - Epoch 51/70, Train Loss: 0.1267


Training model:  73%|███████▎  | 51/70 [00:06<00:03,  5.66it/s]

2025-12-13 15:12:37,652 - INFO - Epoch 52/70, Train Loss: 0.1209


Training model:  74%|███████▍  | 52/70 [00:06<00:02,  6.17it/s]

2025-12-13 15:12:37,885 - INFO - Epoch 53/70, Train Loss: 0.1052


Training model:  76%|███████▌  | 53/70 [00:06<00:03,  5.44it/s]

2025-12-13 15:12:38,066 - INFO - Epoch 54/70, Train Loss: 0.1061


Training model:  77%|███████▋  | 54/70 [00:06<00:02,  5.48it/s]

2025-12-13 15:12:38,189 - INFO - Epoch 55/70, Train Loss: 0.0827


Training model:  79%|███████▊  | 55/70 [00:07<00:02,  6.07it/s]

2025-12-13 15:12:38,328 - INFO - Epoch 56/70, Train Loss: 0.0813


Training model:  80%|████████  | 56/70 [00:07<00:02,  6.37it/s]

2025-12-13 15:12:38,483 - INFO - Epoch 57/70, Train Loss: 0.0814


Training model:  81%|████████▏ | 57/70 [00:07<00:02,  6.40it/s]

2025-12-13 15:12:38,619 - INFO - Epoch 58/70, Train Loss: 0.0748


Training model:  83%|████████▎ | 58/70 [00:07<00:01,  6.66it/s]

2025-12-13 15:12:38,753 - INFO - Epoch 59/70, Train Loss: 0.0709


Training model:  84%|████████▍ | 59/70 [00:07<00:01,  6.88it/s]

2025-12-13 15:12:38,876 - INFO - Epoch 60/70, Train Loss: 0.0663


Training model:  86%|████████▌ | 60/70 [00:07<00:01,  7.20it/s]

2025-12-13 15:12:39,004 - INFO - Epoch 61/70, Train Loss: 0.0892


Training model:  87%|████████▋ | 61/70 [00:07<00:01,  7.39it/s]

2025-12-13 15:12:39,147 - INFO - Epoch 62/70, Train Loss: 0.0817


Training model:  89%|████████▊ | 62/70 [00:07<00:01,  7.25it/s]

2025-12-13 15:12:39,333 - INFO - Epoch 63/70, Train Loss: 0.1133


Training model:  90%|█████████ | 63/70 [00:08<00:01,  6.56it/s]

2025-12-13 15:12:39,504 - INFO - Epoch 64/70, Train Loss: 0.1018


Training model:  91%|█████████▏| 64/70 [00:08<00:00,  6.32it/s]

2025-12-13 15:12:39,770 - INFO - Epoch 65/70, Train Loss: 0.1185


Training model:  93%|█████████▎| 65/70 [00:08<00:00,  5.26it/s]

2025-12-13 15:12:39,930 - INFO - Epoch 66/70, Train Loss: 0.0767


Training model:  94%|█████████▍| 66/70 [00:08<00:00,  5.52it/s]

2025-12-13 15:12:40,084 - INFO - Epoch 67/70, Train Loss: 0.0517


Training model:  96%|█████████▌| 67/70 [00:08<00:00,  5.78it/s]

2025-12-13 15:12:40,227 - INFO - Epoch 68/70, Train Loss: 0.0650


Training model:  97%|█████████▋| 68/70 [00:09<00:00,  6.10it/s]

2025-12-13 15:12:40,351 - INFO - Epoch 69/70, Train Loss: 0.0542


Training model:  99%|█████████▊| 69/70 [00:09<00:00,  6.58it/s]

2025-12-13 15:12:40,480 - INFO - Epoch 70/70, Train Loss: 0.0637


Training model: 100%|██████████| 70/70 [00:09<00:00,  7.52it/s]

2025-12-13 15:12:40,482 - INFO - [1.128953706473112, 1.0152508057653904, 1.0551845915615559, 0.9917939081788063, 0.9568653926253319, 0.9551576264202595, 0.9980925507843494, 0.9632662162184715, 0.9529300108551979, 0.9201431628316641, 1.0935109071433544, 0.9907349273562431, 0.9297497626394033, 0.9483215883374214, 1.0065739825367928, 0.9730098620057106, 0.9263771623373032, 0.8941906467080116, 0.8338040113449097, 0.8458393402397633, 0.8133164457976818, 0.8147167712450027, 0.8972919546067715, 0.685528723988682, 0.5943929916247725, 0.6056549847126007, 0.5091847740113735, 0.5200991872698069, 0.4380995538085699, 0.3533913432620466, 0.36389950942248106, 0.4108554683625698, 0.3155387270380743, 0.3173329159617424, 0.2744336361065507, 0.20559349981340347, 0.1840467118890956, 0.16244903313236136, 0.15588786933221854, 0.1429109300370328, 0.16334831807762384, 0.3789045000448823, 0.705118122510612, 0.5957707650959492, 0.49630474112927914, 0.37360685877501965, 0.243512105778791, 0.2005699034780264, 0.1


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇██
test_accuracy,▁
test_f1,▁
test_precision,▁
test_recall,▁
train_loss,▇█▇▇▇▇▇█▇▇▇▆▆▆▇▅▄▃▃▃▂▂▂▂▂▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁

0,1
epoch,70.0
test_accuracy,0.44048
test_f1,0.27106
test_precision,0.76832
test_recall,0.2449
train_loss,0.06374


Net2 seems the best in convergence but it definitely overfits, so the next step is to solve this with net3.

In [17]:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.0001, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model = None
        
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_model = model.state_dict().copy()
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                logger.info(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_model = model.state_dict().copy()
            self.counter = 0


In [18]:
batch_size = 32
net3 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1),      # (3x3x1)x8
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 224x224 -> 112x112

    torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),       # (3x3x8)x16
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 112x112 -> 56x56

    torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),       # (3x3x16)x32
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 56x56 -> 28x28 

    torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),       # (3x3x32)x64
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 28x28 -> 14x14 

    torch.nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),       # (3x3x64)x32
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 14x14 -> 7x7

    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(32, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3)                       # Output layer     
).to(device)

net3.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net3.parameters(), lr=0.001)

summary(net3, input_size=(batch_size, 1, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [32, 3]                   --
├─Conv2d: 1-1                            [32, 8, 224, 224]         80
├─ReLU: 1-2                              [32, 8, 224, 224]         --
├─MaxPool2d: 1-3                         [32, 8, 112, 112]         --
├─Conv2d: 1-4                            [32, 16, 112, 112]        1,168
├─ReLU: 1-5                              [32, 16, 112, 112]        --
├─MaxPool2d: 1-6                         [32, 16, 56, 56]          --
├─Conv2d: 1-7                            [32, 32, 56, 56]          4,640
├─ReLU: 1-8                              [32, 32, 56, 56]          --
├─MaxPool2d: 1-9                         [32, 32, 28, 28]          --
├─Conv2d: 1-10                           [32, 64, 28, 28]          18,496
├─ReLU: 1-11                             [32, 64, 28, 28]          --
├─MaxPool2d: 1-12                        [32, 64, 14, 14]          --
├─Con

In [19]:
init_wandb()
train_model(net3, optimizer, loss_fn, enable_early_stopping=False)
evaluate_model(net3)

Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-13 15:12:43,726 - INFO - Epoch 1/70, Train Loss: 1.1033


Training model:   1%|▏         | 1/70 [00:00<00:28,  2.43it/s]

2025-12-13 15:12:43,837 - INFO - Epoch 2/70, Train Loss: 1.0451


Training model:   3%|▎         | 2/70 [00:00<00:15,  4.26it/s]

2025-12-13 15:12:43,950 - INFO - Epoch 3/70, Train Loss: 0.9921


Training model:   4%|▍         | 3/70 [00:00<00:12,  5.58it/s]

2025-12-13 15:12:44,061 - INFO - Epoch 4/70, Train Loss: 0.9792


Training model:   6%|▌         | 4/70 [00:00<00:10,  6.57it/s]

2025-12-13 15:12:44,173 - INFO - Epoch 5/70, Train Loss: 0.9773


Training model:   7%|▋         | 5/70 [00:00<00:08,  7.27it/s]

2025-12-13 15:12:44,290 - INFO - Epoch 6/70, Train Loss: 0.9840


Training model:   9%|▊         | 6/70 [00:00<00:08,  7.66it/s]

2025-12-13 15:12:44,402 - INFO - Epoch 7/70, Train Loss: 0.9777


Training model:  10%|█         | 7/70 [00:01<00:07,  8.02it/s]

2025-12-13 15:12:44,514 - INFO - Epoch 8/70, Train Loss: 0.9721


Training model:  11%|█▏        | 8/70 [00:01<00:07,  8.29it/s]

2025-12-13 15:12:44,626 - INFO - Epoch 9/70, Train Loss: 0.9926


Training model:  13%|█▎        | 9/70 [00:01<00:07,  8.47it/s]

2025-12-13 15:12:44,741 - INFO - Epoch 10/70, Train Loss: 0.9922


Training model:  14%|█▍        | 10/70 [00:01<00:07,  8.56it/s]

2025-12-13 15:12:44,849 - INFO - Epoch 11/70, Train Loss: 0.9604


Training model:  16%|█▌        | 11/70 [00:01<00:06,  8.76it/s]

2025-12-13 15:12:44,958 - INFO - Epoch 12/70, Train Loss: 0.9706


Training model:  17%|█▋        | 12/70 [00:01<00:06,  8.86it/s]

2025-12-13 15:12:45,067 - INFO - Epoch 13/70, Train Loss: 0.9449


Training model:  19%|█▊        | 13/70 [00:01<00:06,  8.98it/s]

2025-12-13 15:12:45,178 - INFO - Epoch 14/70, Train Loss: 0.9122


Training model:  20%|██        | 14/70 [00:01<00:06,  9.00it/s]

2025-12-13 15:12:45,289 - INFO - Epoch 15/70, Train Loss: 0.9579


Training model:  21%|██▏       | 15/70 [00:01<00:06,  9.00it/s]

2025-12-13 15:12:45,401 - INFO - Epoch 16/70, Train Loss: 0.9286


Training model:  23%|██▎       | 16/70 [00:02<00:06,  8.95it/s]

2025-12-13 15:12:45,513 - INFO - Epoch 17/70, Train Loss: 0.9558


Training model:  24%|██▍       | 17/70 [00:02<00:05,  8.96it/s]

2025-12-13 15:12:45,622 - INFO - Epoch 18/70, Train Loss: 0.9683


Training model:  26%|██▌       | 18/70 [00:02<00:05,  9.01it/s]

2025-12-13 15:12:45,741 - INFO - Epoch 19/70, Train Loss: 0.9100


Training model:  27%|██▋       | 19/70 [00:02<00:05,  8.82it/s]

2025-12-13 15:12:45,849 - INFO - Epoch 20/70, Train Loss: 0.9370


Training model:  29%|██▊       | 20/70 [00:02<00:05,  8.95it/s]

2025-12-13 15:12:45,961 - INFO - Epoch 21/70, Train Loss: 0.8583


Training model:  30%|███       | 21/70 [00:02<00:05,  8.95it/s]

2025-12-13 15:12:46,071 - INFO - Epoch 22/70, Train Loss: 0.8821


Training model:  31%|███▏      | 22/70 [00:02<00:05,  8.97it/s]

2025-12-13 15:12:46,181 - INFO - Epoch 23/70, Train Loss: 0.8452


Training model:  33%|███▎      | 23/70 [00:02<00:05,  9.02it/s]

2025-12-13 15:12:46,291 - INFO - Epoch 24/70, Train Loss: 0.9404


Training model:  34%|███▍      | 24/70 [00:02<00:05,  9.05it/s]

2025-12-13 15:12:46,397 - INFO - Epoch 25/70, Train Loss: 0.9145


Training model:  36%|███▌      | 25/70 [00:03<00:04,  9.15it/s]

2025-12-13 15:12:46,504 - INFO - Epoch 26/70, Train Loss: 0.8262


Training model:  37%|███▋      | 26/70 [00:03<00:04,  9.20it/s]

2025-12-13 15:12:46,619 - INFO - Epoch 27/70, Train Loss: 0.8251


Training model:  39%|███▊      | 27/70 [00:03<00:04,  9.06it/s]

2025-12-13 15:12:46,728 - INFO - Epoch 28/70, Train Loss: 0.8182


Training model:  40%|████      | 28/70 [00:03<00:04,  9.08it/s]

2025-12-13 15:12:46,840 - INFO - Epoch 29/70, Train Loss: 0.8135


Training model:  41%|████▏     | 29/70 [00:03<00:04,  9.03it/s]

2025-12-13 15:12:46,951 - INFO - Epoch 30/70, Train Loss: 0.8333


Training model:  43%|████▎     | 30/70 [00:03<00:04,  9.04it/s]

2025-12-13 15:12:47,062 - INFO - Epoch 31/70, Train Loss: 0.7682


Training model:  44%|████▍     | 31/70 [00:03<00:04,  9.04it/s]

2025-12-13 15:12:47,171 - INFO - Epoch 32/70, Train Loss: 0.8817


Training model:  46%|████▌     | 32/70 [00:03<00:04,  9.07it/s]

2025-12-13 15:12:47,281 - INFO - Epoch 33/70, Train Loss: 0.7813


Training model:  47%|████▋     | 33/70 [00:03<00:04,  9.07it/s]

2025-12-13 15:12:47,393 - INFO - Epoch 34/70, Train Loss: 0.7364


Training model:  49%|████▊     | 34/70 [00:04<00:03,  9.04it/s]

2025-12-13 15:12:47,506 - INFO - Epoch 35/70, Train Loss: 0.7548


Training model:  50%|█████     | 35/70 [00:04<00:03,  8.98it/s]

2025-12-13 15:12:47,617 - INFO - Epoch 36/70, Train Loss: 0.6827


Training model:  51%|█████▏    | 36/70 [00:04<00:03,  8.99it/s]

2025-12-13 15:12:47,728 - INFO - Epoch 37/70, Train Loss: 0.6428


Training model:  53%|█████▎    | 37/70 [00:04<00:03,  9.00it/s]

2025-12-13 15:12:47,838 - INFO - Epoch 38/70, Train Loss: 0.7205


Training model:  54%|█████▍    | 38/70 [00:04<00:03,  9.01it/s]

2025-12-13 15:12:47,948 - INFO - Epoch 39/70, Train Loss: 0.6231


Training model:  56%|█████▌    | 39/70 [00:04<00:03,  9.03it/s]

2025-12-13 15:12:48,061 - INFO - Epoch 40/70, Train Loss: 0.6497


Training model:  57%|█████▋    | 40/70 [00:04<00:03,  8.99it/s]

2025-12-13 15:12:48,179 - INFO - Epoch 41/70, Train Loss: 0.7756


Training model:  59%|█████▊    | 41/70 [00:04<00:03,  8.80it/s]

2025-12-13 15:12:48,318 - INFO - Epoch 42/70, Train Loss: 0.6067


Training model:  60%|██████    | 42/70 [00:05<00:03,  8.28it/s]

2025-12-13 15:12:48,430 - INFO - Epoch 43/70, Train Loss: 0.6977


Training model:  61%|██████▏   | 43/70 [00:05<00:03,  8.45it/s]

2025-12-13 15:12:48,545 - INFO - Epoch 44/70, Train Loss: 0.5209


Training model:  63%|██████▎   | 44/70 [00:05<00:03,  8.52it/s]

2025-12-13 15:12:48,656 - INFO - Epoch 45/70, Train Loss: 0.5252


Training model:  64%|██████▍   | 45/70 [00:05<00:02,  8.66it/s]

2025-12-13 15:12:48,764 - INFO - Epoch 46/70, Train Loss: 0.6361


Training model:  66%|██████▌   | 46/70 [00:05<00:02,  8.83it/s]

2025-12-13 15:12:48,885 - INFO - Epoch 47/70, Train Loss: 0.4644


Training model:  67%|██████▋   | 47/70 [00:05<00:02,  8.66it/s]

2025-12-13 15:12:49,002 - INFO - Epoch 48/70, Train Loss: 0.4728


Training model:  69%|██████▊   | 48/70 [00:05<00:02,  8.62it/s]

2025-12-13 15:12:49,124 - INFO - Epoch 49/70, Train Loss: 0.4582


Training model:  70%|███████   | 49/70 [00:05<00:02,  8.50it/s]

2025-12-13 15:12:49,231 - INFO - Epoch 50/70, Train Loss: 0.5918


Training model:  71%|███████▏  | 50/70 [00:05<00:02,  8.73it/s]

2025-12-13 15:12:49,342 - INFO - Epoch 51/70, Train Loss: 0.4538


Training model:  73%|███████▎  | 51/70 [00:06<00:02,  8.81it/s]

2025-12-13 15:12:49,452 - INFO - Epoch 52/70, Train Loss: 0.3979


Training model:  74%|███████▍  | 52/70 [00:06<00:02,  8.89it/s]

2025-12-13 15:12:49,562 - INFO - Epoch 53/70, Train Loss: 0.3400


Training model:  76%|███████▌  | 53/70 [00:06<00:01,  8.96it/s]

2025-12-13 15:12:49,669 - INFO - Epoch 54/70, Train Loss: 0.2828


Training model:  77%|███████▋  | 54/70 [00:06<00:01,  9.08it/s]

2025-12-13 15:12:49,776 - INFO - Epoch 55/70, Train Loss: 0.3595


Training model:  79%|███████▊  | 55/70 [00:06<00:01,  9.15it/s]

2025-12-13 15:12:49,883 - INFO - Epoch 56/70, Train Loss: 0.2601


Training model:  80%|████████  | 56/70 [00:06<00:01,  9.18it/s]

2025-12-13 15:12:49,995 - INFO - Epoch 57/70, Train Loss: 0.2430


Training model:  81%|████████▏ | 57/70 [00:06<00:01,  9.12it/s]

2025-12-13 15:12:50,108 - INFO - Epoch 58/70, Train Loss: 0.2007


Training model:  83%|████████▎ | 58/70 [00:06<00:01,  9.05it/s]

2025-12-13 15:12:50,216 - INFO - Epoch 59/70, Train Loss: 0.2175


Training model:  84%|████████▍ | 59/70 [00:06<00:01,  9.10it/s]

2025-12-13 15:12:50,326 - INFO - Epoch 60/70, Train Loss: 0.2875


Training model:  86%|████████▌ | 60/70 [00:07<00:01,  9.11it/s]

2025-12-13 15:12:50,437 - INFO - Epoch 61/70, Train Loss: 0.1857


Training model:  87%|████████▋ | 61/70 [00:07<00:00,  9.06it/s]

2025-12-13 15:12:50,548 - INFO - Epoch 62/70, Train Loss: 0.1539


Training model:  89%|████████▊ | 62/70 [00:07<00:00,  9.06it/s]

2025-12-13 15:12:50,657 - INFO - Epoch 63/70, Train Loss: 0.2187


Training model:  90%|█████████ | 63/70 [00:07<00:00,  9.09it/s]

2025-12-13 15:12:50,764 - INFO - Epoch 64/70, Train Loss: 0.2484


Training model:  91%|█████████▏| 64/70 [00:07<00:00,  9.17it/s]

2025-12-13 15:12:50,877 - INFO - Epoch 65/70, Train Loss: 0.1611


Training model:  93%|█████████▎| 65/70 [00:07<00:00,  9.08it/s]

2025-12-13 15:12:50,987 - INFO - Epoch 66/70, Train Loss: 0.3536


Training model:  94%|█████████▍| 66/70 [00:07<00:00,  9.08it/s]

2025-12-13 15:12:51,095 - INFO - Epoch 67/70, Train Loss: 0.5810


Training model:  96%|█████████▌| 67/70 [00:07<00:00,  9.13it/s]

2025-12-13 15:12:51,202 - INFO - Epoch 68/70, Train Loss: 0.3388


Training model:  97%|█████████▋| 68/70 [00:07<00:00,  9.18it/s]

2025-12-13 15:12:51,312 - INFO - Epoch 69/70, Train Loss: 0.2507


Training model:  99%|█████████▊| 69/70 [00:07<00:00,  9.15it/s]

2025-12-13 15:12:51,422 - INFO - Epoch 70/70, Train Loss: 0.2437


Training model: 100%|██████████| 70/70 [00:08<00:00,  8.63it/s]

2025-12-13 15:12:51,424 - INFO - [1.1033370196819305, 1.0450826063752174, 0.9920924566686153, 0.9791661575436592, 0.977278720587492, 0.9839993752539158, 0.9776948690414429, 0.9720984809100628, 0.9925960823893547, 0.9922092147171497, 0.960443951189518, 0.9706465303897858, 0.9449320510029793, 0.9122158885002136, 0.9578891545534134, 0.9285791404545307, 0.9557551965117455, 0.9683436900377274, 0.9100451059639454, 0.9369592182338238, 0.8582910466939211, 0.8820824734866619, 0.8452195972204208, 0.9403965435922146, 0.9145348407328129, 0.8261917866766453, 0.8251167125999928, 0.8181727007031441, 0.8134674727916718, 0.833258930593729, 0.7682208716869354, 0.8816547282040119, 0.7813036227598786, 0.736373906955123, 0.754819568246603, 0.6826564967632294, 0.6428337823599577, 0.720543198287487, 0.6230727918446064, 0.6497019156813622, 0.77556411921978, 0.6067338909488171, 0.6976957954466343, 0.5209095953032374, 0.5252386219799519, 0.6361428480231552, 0.4644148927181959, 0.47276318073272705, 0.45817940682


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
test_accuracy,▁
test_f1,▁
test_precision,▁
test_recall,▁
train_loss,██▇▇███▇▇▇▇▇▇▆▇▆▆▆▆▆▅▅▅▆▅▅▃▃▄▃▂▃▂▁▁▂▁▃▂▂

0,1
epoch,70.0
test_accuracy,0.53571
test_f1,0.47665
test_precision,0.78181
test_recall,0.40816
train_loss,0.24366


Net3 intorduced maxpool layers after each conv layer and acheved better generalisation, more over batch size could be inreased to 32, and still perfomed godd on it unlike previous networks.

In [20]:
# Use train test split, for getting validation metrics during training
x_train_tensor, x_val_tensor, y_train_tensor, y_val_tensor = train_test_split(
    x_train_tensor, y_train_tensor, test_size=0.2, random_state=42, stratify=y_train_tensor)

val_dataset = TensorDataset(x_val_tensor, y_val_tensor)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [21]:
def init_weights(m):
    if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, torch.nn.BatchNorm2d):
        torch.nn.init.constant_(m.weight, 1)
        torch.nn.init.constant_(m.bias, 0)

In [29]:
net4 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1),      # (3x3x1)x8
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 224x224 -> 112x112

    torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),       # (3x3x8)x16
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 112x112 -> 56x56

    torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),       # (3x3x16)x32
    torch.nn.BatchNorm2d(32),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 56x56 -> 28x28 

    torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),       # (3x3x32)x64
    torch.nn.BatchNorm2d(64),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 28x28 -> 14x14 

    torch.nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),       # (3x3x64)x32
    torch.nn.BatchNorm2d(32),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 14x14 -> 7x7

    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(32, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 3)                       # Output layer     
).to(device)

net4.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net4.parameters(), lr=0.0005)

summary(net4, input_size=(batch_size, 1, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [32, 3]                   --
├─Conv2d: 1-1                            [32, 8, 224, 224]         80
├─BatchNorm2d: 1-2                       [32, 8, 224, 224]         16
├─ReLU: 1-3                              [32, 8, 224, 224]         --
├─MaxPool2d: 1-4                         [32, 8, 112, 112]         --
├─Conv2d: 1-5                            [32, 16, 112, 112]        1,168
├─BatchNorm2d: 1-6                       [32, 16, 112, 112]        32
├─ReLU: 1-7                              [32, 16, 112, 112]        --
├─MaxPool2d: 1-8                         [32, 16, 56, 56]          --
├─Conv2d: 1-9                            [32, 32, 56, 56]          4,640
├─BatchNorm2d: 1-10                      [32, 32, 56, 56]          64
├─ReLU: 1-11                             [32, 32, 56, 56]          --
├─MaxPool2d: 1-12                        [32, 32, 28, 28]          --
├─Conv2d:

In [30]:
init_wandb()
train_model(net4, optimizer, loss_fn, enable_early_stopping=True, patience=5)
evaluate_model(net4)

Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-13 15:16:31,650 - INFO - Epoch 1/70, Train Loss: 1.1161, Val Loss: 1.0317, Val Acc: 0.3469


Training model:   1%|▏         | 1/70 [00:00<00:22,  3.03it/s]

2025-12-13 15:16:31,780 - INFO - Epoch 2/70, Train Loss: 0.9851, Val Loss: 1.0954, Val Acc: 0.3469
2025-12-13 15:16:31,781 - INFO - EarlyStopping counter: 1 out of 5


Training model:   3%|▎         | 2/70 [00:00<00:14,  4.71it/s]

2025-12-13 15:16:31,907 - INFO - Epoch 3/70, Train Loss: 0.9209, Val Loss: 1.0621, Val Acc: 0.3265
2025-12-13 15:16:31,908 - INFO - EarlyStopping counter: 2 out of 5


Training model:   4%|▍         | 3/70 [00:00<00:11,  5.77it/s]

2025-12-13 15:16:32,036 - INFO - Epoch 4/70, Train Loss: 0.8793, Val Loss: 1.0329, Val Acc: 0.5306
2025-12-13 15:16:32,037 - INFO - EarlyStopping counter: 3 out of 5


Training model:   6%|▌         | 4/70 [00:00<00:10,  6.40it/s]

2025-12-13 15:16:32,166 - INFO - Epoch 5/70, Train Loss: 0.8427, Val Loss: 1.0189, Val Acc: 0.5306


Training model:   7%|▋         | 5/70 [00:00<00:09,  6.83it/s]

2025-12-13 15:16:32,284 - INFO - Epoch 6/70, Train Loss: 0.8192, Val Loss: 1.0201, Val Acc: 0.4898
2025-12-13 15:16:32,285 - INFO - EarlyStopping counter: 1 out of 5


Training model:   9%|▊         | 6/70 [00:00<00:08,  7.32it/s]

2025-12-13 15:16:32,406 - INFO - Epoch 7/70, Train Loss: 0.7749, Val Loss: 1.0072, Val Acc: 0.5306


Training model:  10%|█         | 7/70 [00:01<00:08,  7.59it/s]

2025-12-13 15:16:32,528 - INFO - Epoch 8/70, Train Loss: 0.7634, Val Loss: 1.0014, Val Acc: 0.5306


Training model:  11%|█▏        | 8/70 [00:01<00:08,  7.75it/s]

2025-12-13 15:16:32,657 - INFO - Epoch 9/70, Train Loss: 0.7345, Val Loss: 1.0214, Val Acc: 0.4898
2025-12-13 15:16:32,657 - INFO - EarlyStopping counter: 1 out of 5


Training model:  13%|█▎        | 9/70 [00:01<00:07,  7.77it/s]

2025-12-13 15:16:32,774 - INFO - Epoch 10/70, Train Loss: 0.7110, Val Loss: 1.0280, Val Acc: 0.5510
2025-12-13 15:16:32,775 - INFO - EarlyStopping counter: 2 out of 5


Training model:  14%|█▍        | 10/70 [00:01<00:07,  7.99it/s]

2025-12-13 15:16:32,894 - INFO - Epoch 11/70, Train Loss: 0.6820, Val Loss: 1.0217, Val Acc: 0.5918
2025-12-13 15:16:32,895 - INFO - EarlyStopping counter: 3 out of 5


Training model:  16%|█▌        | 11/70 [00:01<00:07,  8.09it/s]

2025-12-13 15:16:33,068 - INFO - Epoch 12/70, Train Loss: 0.6592, Val Loss: 1.0124, Val Acc: 0.5102
2025-12-13 15:16:33,068 - INFO - EarlyStopping counter: 4 out of 5


Training model:  17%|█▋        | 12/70 [00:01<00:08,  7.19it/s]

2025-12-13 15:16:33,188 - INFO - Epoch 13/70, Train Loss: 0.6656, Val Loss: 1.0383, Val Acc: 0.5102
2025-12-13 15:16:33,188 - INFO - EarlyStopping counter: 5 out of 5
2025-12-13 15:16:33,189 - INFO - Early stopping triggered


Training model:  17%|█▋        | 12/70 [00:01<00:09,  6.42it/s]

2025-12-13 15:16:33,193 - INFO - Loaded best model weights
2025-12-13 15:16:33,193 - INFO - [1.1161300539970398, 0.985075443983078, 0.9209022025267283, 0.8792893091837565, 0.8427437742551168, 0.8192348182201385, 0.7748598953088125, 0.7633983890215555, 0.7344794869422913, 0.7110288341840109, 0.6819542249043783, 0.6592182715733846, 0.6656493445237478]
2025-12-13 15:16:33,221 - INFO - network accuracy: 58.33%
2025-12-13 15:16:33,222 - INFO - network precision: 89.42%
2025-12-13 15:16:33,222 - INFO - network recall: 28.57%
2025-12-13 15:16:33,223 - INFO - network F1 score: 30.37%
2025-12-13 15:16:33,231 - INFO - Detailed Classification Report: 
              precision    recall  f1-score   support

           0       0.26      1.00      0.41         7
           1       1.00      0.17      0.29        42
           2       0.00      0.00      0.00         0

    accuracy                           0.29        49
   macro avg       0.42      0.39      0.23        49
weighted avg       0.89  


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▂▂▃▃▄▅▅▆▆▇▇█
test_accuracy,▁
test_f1,▁
test_precision,▁
test_recall,▁
train_loss,█▆▅▄▄▃▃▃▂▂▁▁▁
val_accuracy,▂▂▁▆▆▅▆▆▅▇█▆▆
val_loss,▃█▆▃▂▂▁▁▂▃▃▂▄

0,1
epoch,13.0
test_accuracy,0.58333
test_f1,0.30372
test_precision,0.89418
test_recall,0.28571
train_loss,0.66565
val_accuracy,0.5102
val_loss,1.03833


Net4 introduced some batch normalization to help with the overfitting problem. The learning rate also got decreased.

In [35]:
net5 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1),      # (3x3x1)x8
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 224x224 -> 112x112

    torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),       # (3x3x8)x16
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 112x112 -> 56x56

    torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),       # (3x3x16)x32
    torch.nn.BatchNorm2d(32),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 56x56 -> 28x28 

    torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),       # (3x3x32)x64
    torch.nn.BatchNorm2d(64),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 28x28 -> 14x14 

    torch.nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),       # (3x3x64)x32
    torch.nn.BatchNorm2d(32),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 14x14 -> 7x7

    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(32, 128),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(64, 3)                       # Output layer     
).to(device)

net5.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net5.parameters(), lr=0.0005)

summary(net5, input_size=(batch_size, 1, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [32, 3]                   --
├─Conv2d: 1-1                            [32, 8, 224, 224]         80
├─BatchNorm2d: 1-2                       [32, 8, 224, 224]         16
├─ReLU: 1-3                              [32, 8, 224, 224]         --
├─MaxPool2d: 1-4                         [32, 8, 112, 112]         --
├─Conv2d: 1-5                            [32, 16, 112, 112]        1,168
├─BatchNorm2d: 1-6                       [32, 16, 112, 112]        32
├─ReLU: 1-7                              [32, 16, 112, 112]        --
├─MaxPool2d: 1-8                         [32, 16, 56, 56]          --
├─Conv2d: 1-9                            [32, 32, 56, 56]          4,640
├─BatchNorm2d: 1-10                      [32, 32, 56, 56]          64
├─ReLU: 1-11                             [32, 32, 56, 56]          --
├─MaxPool2d: 1-12                        [32, 32, 28, 28]          --
├─Conv2d:

In [36]:
init_wandb()
train_model(net5, optimizer, loss_fn, enable_early_stopping=True, patience=5)
evaluate_model(net5)

Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-13 15:57:42,182 - INFO - Epoch 1/70, Train Loss: 1.4819, Val Loss: 1.0912, Val Acc: 0.3878


Training model:   1%|▏         | 1/70 [00:00<00:15,  4.45it/s]

2025-12-13 15:57:42,308 - INFO - Epoch 2/70, Train Loss: 1.1386, Val Loss: 1.0551, Val Acc: 0.4286


Training model:   3%|▎         | 2/70 [00:00<00:11,  5.99it/s]

2025-12-13 15:57:42,489 - INFO - Epoch 3/70, Train Loss: 1.1347, Val Loss: 1.0378, Val Acc: 0.4898


Training model:   4%|▍         | 3/70 [00:00<00:11,  5.75it/s]

2025-12-13 15:57:42,637 - INFO - Epoch 4/70, Train Loss: 1.0606, Val Loss: 1.0251, Val Acc: 0.4490


Training model:   6%|▌         | 4/70 [00:00<00:10,  6.13it/s]

2025-12-13 15:57:42,767 - INFO - Epoch 5/70, Train Loss: 0.9933, Val Loss: 1.0068, Val Acc: 0.5918


Training model:   7%|▋         | 5/70 [00:00<00:09,  6.61it/s]

2025-12-13 15:57:42,900 - INFO - Epoch 6/70, Train Loss: 0.9415, Val Loss: 1.0032, Val Acc: 0.5714


Training model:   9%|▊         | 6/70 [00:00<00:09,  6.90it/s]

2025-12-13 15:57:43,024 - INFO - Epoch 7/70, Train Loss: 0.9744, Val Loss: 1.0096, Val Acc: 0.5714
2025-12-13 15:57:43,024 - INFO - EarlyStopping counter: 1 out of 5


Training model:  10%|█         | 7/70 [00:01<00:08,  7.26it/s]

2025-12-13 15:57:43,241 - INFO - Epoch 8/70, Train Loss: 0.9258, Val Loss: 1.0378, Val Acc: 0.5102
2025-12-13 15:57:43,242 - INFO - EarlyStopping counter: 2 out of 5


Training model:  11%|█▏        | 8/70 [00:01<00:10,  6.13it/s]

2025-12-13 15:57:43,364 - INFO - Epoch 9/70, Train Loss: 0.9108, Val Loss: 1.0369, Val Acc: 0.5510
2025-12-13 15:57:43,364 - INFO - EarlyStopping counter: 3 out of 5


Training model:  13%|█▎        | 9/70 [00:01<00:09,  6.63it/s]

2025-12-13 15:57:43,571 - INFO - Epoch 10/70, Train Loss: 0.9153, Val Loss: 1.0458, Val Acc: 0.6327
2025-12-13 15:57:43,572 - INFO - EarlyStopping counter: 4 out of 5


Training model:  14%|█▍        | 10/70 [00:01<00:10,  5.95it/s]

2025-12-13 15:57:43,702 - INFO - Epoch 11/70, Train Loss: 0.8838, Val Loss: 1.0887, Val Acc: 0.5102
2025-12-13 15:57:43,702 - INFO - EarlyStopping counter: 5 out of 5
2025-12-13 15:57:43,702 - INFO - Early stopping triggered


Training model:  14%|█▍        | 10/70 [00:01<00:10,  5.72it/s]

2025-12-13 15:57:43,708 - INFO - Loaded best model weights
2025-12-13 15:57:43,709 - INFO - [1.4818991422653198, 1.1385776698589325, 1.134666085243225, 1.0605972806612651, 0.9933082362016042, 0.9415448606014252, 0.9744154115517935, 0.9258333047231039, 0.910777618487676, 0.9152572055657705, 0.8837806483109792]
2025-12-13 15:57:43,752 - INFO - network accuracy: 51.19%





2025-12-13 15:57:43,753 - INFO - network precision: 87.80%
2025-12-13 15:57:43,754 - INFO - network recall: 16.33%
2025-12-13 15:57:43,756 - INFO - network F1 score: 7.62%
2025-12-13 15:57:43,769 - INFO - Detailed Classification Report: 
              precision    recall  f1-score   support

           0       0.15      1.00      0.25         7
           1       1.00      0.02      0.05        42

    accuracy                           0.16        49
   macro avg       0.57      0.51      0.15        49
weighted avg       0.88      0.16      0.08        49



[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▂▂▃▄▅▅▆▇▇█
test_accuracy,▁
test_f1,▁
test_precision,▁
test_recall,▁
train_loss,█▄▄▃▂▂▂▁▁▁▁
val_accuracy,▁▂▄▃▇▆▆▅▆█▅
val_loss,█▅▄▃▁▁▂▄▄▄█

0,1
epoch,11.0
test_accuracy,0.5119
test_f1,0.07623
test_precision,0.87798
test_recall,0.16327
train_loss,0.88378
val_accuracy,0.5102
val_loss,1.08871


Net5 introdiced dropout to the fully connected layers but it did not imrpoved the model's performance compared to the previous net4.
I think I overshot the complexity a little, I will remove some complexity.

In [43]:
# Introduce some transformations to train and val datasets
# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to consistent size
    transforms.RandomHorizontalFlip(p=0.3), # Random horizontal flip
    transforms.RandomRotation(degrees=15),  # Random rotation
    transforms.ToTensor(),           # Convert to tensor [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5])
])

x_train = []
y_train = []

for img_name, label in train_data:
    img_path = os.path.join(preped_folder, img_name)
    try:
        img = Image.open(img_path).convert('L') # Convert to grayscale
        img_tensor = transform(img)
        x_train.append(img_tensor)
        y_train.append(label)
    except Exception as e:
        logger.info(f"Error loading {img_name}: {e}")

# Stack into tensors
x_train_tensor = torch.stack(x_train)
logger.info(f"Training images shape: {x_train_tensor.shape}")

# Encode labels to integers
label_to_idx = {label: idx for idx, label in enumerate(np.unique(y_train))}
y_train_encoded = [label_to_idx[label] for label in y_train]
y_train_tensor = torch.tensor(y_train_encoded, dtype=torch.long)

2025-12-13 16:07:20,459 - INFO - Training images shape: torch.Size([241, 1, 224, 224])


In [44]:
# Use train test split, for getting validation metrics during training
x_train_tensor, x_val_tensor, y_train_tensor, y_val_tensor = train_test_split(
    x_train_tensor, y_train_tensor, test_size=0.2, random_state=42, stratify=y_train_tensor)

val_dataset = TensorDataset(x_val_tensor, y_val_tensor)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [49]:
net6 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1),      # (3x3x1)x8
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 224x224 -> 112x112

    torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),       # (3x3x8)x16
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 112x112 -> 56x56

    torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),       # (3x3x16)x32
    torch.nn.BatchNorm2d(32),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 56x56 -> 28x28 

    torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),       # (3x3x32)x64
    torch.nn.BatchNorm2d(64),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),    # 28x28 -> 14x14 

    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.2),
    torch.nn.Linear(128, 3),
).to(device)

net6.apply(init_weights)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net6.parameters(), lr=0.0005)
summary(net6, input_size=(batch_size, 1, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [32, 3]                   --
├─Conv2d: 1-1                            [32, 8, 224, 224]         80
├─BatchNorm2d: 1-2                       [32, 8, 224, 224]         16
├─ReLU: 1-3                              [32, 8, 224, 224]         --
├─MaxPool2d: 1-4                         [32, 8, 112, 112]         --
├─Conv2d: 1-5                            [32, 16, 112, 112]        1,168
├─BatchNorm2d: 1-6                       [32, 16, 112, 112]        32
├─ReLU: 1-7                              [32, 16, 112, 112]        --
├─MaxPool2d: 1-8                         [32, 16, 56, 56]          --
├─Conv2d: 1-9                            [32, 32, 56, 56]          4,640
├─BatchNorm2d: 1-10                      [32, 32, 56, 56]          64
├─ReLU: 1-11                             [32, 32, 56, 56]          --
├─MaxPool2d: 1-12                        [32, 32, 28, 28]          --
├─Conv2d:

In [50]:
init_wandb()
train_model(net6, optimizer, loss_fn, enable_early_stopping=True, patience=10)
evaluate_model(net6)

Training model:   0%|          | 0/70 [00:00<?, ?it/s]

2025-12-13 16:08:56,735 - INFO - Epoch 1/70, Train Loss: 1.0942, Val Loss: 1.1332, Val Acc: 0.4082


Training model:   1%|▏         | 1/70 [00:00<00:29,  2.37it/s]

2025-12-13 16:08:56,864 - INFO - Epoch 2/70, Train Loss: 0.9878, Val Loss: 1.0450, Val Acc: 0.4286


Training model:   3%|▎         | 2/70 [00:00<00:16,  4.01it/s]

2025-12-13 16:08:56,984 - INFO - Epoch 3/70, Train Loss: 0.9853, Val Loss: 1.0164, Val Acc: 0.4490


Training model:   4%|▍         | 3/70 [00:00<00:12,  5.27it/s]

2025-12-13 16:08:57,097 - INFO - Epoch 4/70, Train Loss: 0.9671, Val Loss: 1.0156, Val Acc: 0.5102


Training model:   6%|▌         | 4/70 [00:00<00:10,  6.25it/s]

2025-12-13 16:08:57,213 - INFO - Epoch 5/70, Train Loss: 0.9939, Val Loss: 1.0174, Val Acc: 0.4898
2025-12-13 16:08:57,213 - INFO - EarlyStopping counter: 1 out of 10


Training model:   7%|▋         | 5/70 [00:00<00:09,  6.94it/s]

2025-12-13 16:08:57,327 - INFO - Epoch 6/70, Train Loss: 1.0014, Val Loss: 1.0161, Val Acc: 0.5510
2025-12-13 16:08:57,327 - INFO - EarlyStopping counter: 2 out of 10


Training model:   9%|▊         | 6/70 [00:01<00:08,  7.48it/s]

2025-12-13 16:08:57,455 - INFO - Epoch 7/70, Train Loss: 0.9679, Val Loss: 1.0098, Val Acc: 0.5714


Training model:  10%|█         | 7/70 [00:01<00:08,  7.57it/s]

2025-12-13 16:08:57,578 - INFO - Epoch 8/70, Train Loss: 0.9739, Val Loss: 1.0081, Val Acc: 0.5102


Training model:  11%|█▏        | 8/70 [00:01<00:07,  7.75it/s]

2025-12-13 16:08:57,705 - INFO - Epoch 9/70, Train Loss: 0.9466, Val Loss: 1.0164, Val Acc: 0.5306
2025-12-13 16:08:57,705 - INFO - EarlyStopping counter: 1 out of 10


Training model:  13%|█▎        | 9/70 [00:01<00:07,  7.81it/s]

2025-12-13 16:08:57,822 - INFO - Epoch 10/70, Train Loss: 0.9444, Val Loss: 1.0193, Val Acc: 0.5306
2025-12-13 16:08:57,822 - INFO - EarlyStopping counter: 2 out of 10


Training model:  14%|█▍        | 10/70 [00:01<00:07,  8.02it/s]

2025-12-13 16:08:57,943 - INFO - Epoch 11/70, Train Loss: 0.9497, Val Loss: 1.0320, Val Acc: 0.5102
2025-12-13 16:08:57,943 - INFO - EarlyStopping counter: 3 out of 10


Training model:  16%|█▌        | 11/70 [00:01<00:07,  8.10it/s]

2025-12-13 16:08:58,068 - INFO - Epoch 12/70, Train Loss: 0.9621, Val Loss: 1.0393, Val Acc: 0.4694
2025-12-13 16:08:58,069 - INFO - EarlyStopping counter: 4 out of 10


Training model:  17%|█▋        | 12/70 [00:01<00:07,  8.05it/s]

2025-12-13 16:08:58,190 - INFO - Epoch 13/70, Train Loss: 0.9119, Val Loss: 1.0414, Val Acc: 0.4694
2025-12-13 16:08:58,191 - INFO - EarlyStopping counter: 5 out of 10


Training model:  19%|█▊        | 13/70 [00:01<00:07,  8.10it/s]

2025-12-13 16:08:58,320 - INFO - Epoch 14/70, Train Loss: 0.9256, Val Loss: 1.0381, Val Acc: 0.4490
2025-12-13 16:08:58,320 - INFO - EarlyStopping counter: 6 out of 10


Training model:  20%|██        | 14/70 [00:02<00:07,  7.98it/s]

2025-12-13 16:08:58,441 - INFO - Epoch 15/70, Train Loss: 0.9336, Val Loss: 1.0230, Val Acc: 0.5102
2025-12-13 16:08:58,441 - INFO - EarlyStopping counter: 7 out of 10


Training model:  21%|██▏       | 15/70 [00:02<00:06,  8.04it/s]

2025-12-13 16:08:58,566 - INFO - Epoch 16/70, Train Loss: 0.9283, Val Loss: 1.0212, Val Acc: 0.5714
2025-12-13 16:08:58,567 - INFO - EarlyStopping counter: 8 out of 10


Training model:  23%|██▎       | 16/70 [00:02<00:06,  8.03it/s]

2025-12-13 16:08:58,688 - INFO - Epoch 17/70, Train Loss: 0.9094, Val Loss: 1.0381, Val Acc: 0.5102
2025-12-13 16:08:58,688 - INFO - EarlyStopping counter: 9 out of 10


Training model:  24%|██▍       | 17/70 [00:02<00:06,  8.11it/s]

2025-12-13 16:08:58,804 - INFO - Epoch 18/70, Train Loss: 0.8929, Val Loss: 1.0336, Val Acc: 0.6122
2025-12-13 16:08:58,804 - INFO - EarlyStopping counter: 10 out of 10
2025-12-13 16:08:58,804 - INFO - Early stopping triggered


Training model:  24%|██▍       | 17/70 [00:02<00:07,  6.82it/s]

2025-12-13 16:08:58,808 - INFO - Loaded best model weights
2025-12-13 16:08:58,808 - INFO - [1.09416929880778, 0.9878494342168173, 0.985325038433075, 0.9670831660429636, 0.993853231271108, 1.0014472007751465, 0.9679023226102194, 0.9739410976568857, 0.9466008941332499, 0.9444488286972046, 0.9497192005316416, 0.9621264139811198, 0.9119105835755666, 0.9256239036719004, 0.933592309554418, 0.9282891054948171, 0.9093623459339142, 0.8928599953651428]
2025-12-13 16:08:58,836 - INFO - network accuracy: 55.95%
2025-12-13 16:08:58,837 - INFO - network precision: 88.49%
2025-12-13 16:08:58,837 - INFO - network recall: 24.49%
2025-12-13 16:08:58,838 - INFO - network F1 score: 22.89%
2025-12-13 16:08:58,850 - INFO - Detailed Classification Report: 
              precision    recall  f1-score   support

           0       0.19      1.00      0.33         7
           1       1.00      0.12      0.21        42
           2       0.00      0.00      0.00         0

    accuracy                         


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▂▂▃▃▃▄▄▅▅▆▆▆▇▇██
test_accuracy,▁
test_f1,▁
test_precision,▁
test_recall,▁
train_loss,█▄▄▄▅▅▄▄▃▃▃▃▂▂▂▂▂▁
val_accuracy,▁▂▂▄▄▆▇▄▅▅▄▃▃▂▄▇▄█
val_loss,█▃▁▁▂▁▁▁▁▂▂▃▃▃▂▂▃▂

0,1
epoch,18.0
test_accuracy,0.55952
test_f1,0.22888
test_precision,0.88492
test_recall,0.2449
train_loss,0.89286
val_accuracy,0.61224
val_loss,1.03359
