In [1]:
import numpy as np
import os
from dotenv import load_dotenv
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import TensorDataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms

from tqdm.auto import tqdm

import wandb

import logging
import sys

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def setup_logger(name=__name__):
    """
    Sets up a logger that outputs to the console (stdout).
    """
    logger = logging.getLogger(name)
    if not logger.handlers:
        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler(sys.stdout)
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

logger = setup_logger()

In [3]:
# Load the dataset
data_folder = "../data"
preped_folder = os.path.join(data_folder, "_preped")

train_data = pd.read_csv(os.path.join(data_folder, 'train_data.csv')).values.tolist()
test_data = pd.read_csv(os.path.join(data_folder, 'test_data.csv')).values.tolist()

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to consistent size
    transforms.ToTensor(),           # Convert to tensor [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization
])

x_train = []
y_train = []

for img_name, label in train_data:
    img_path = os.path.join(preped_folder, img_name)
    try:
        img = Image.open(img_path).convert('RGB')
        img_tensor = transform(img)
        x_train.append(img_tensor)
        y_train.append(label)
    except Exception as e:
        logger.info(f"Error loading {img_name}: {e}")

# Stack into tensors
x_train_tensor = torch.stack(x_train)
logger.info(f"Training images shape: {x_train_tensor.shape}")

# Encode labels to integers
label_to_idx = {label: idx for idx, label in enumerate(np.unique(y_train))}
y_train_encoded = [label_to_idx[label] for label in y_train]
y_train_tensor = torch.tensor(y_train_encoded, dtype=torch.long)

logger.info(f"Training labels shape: {y_train_tensor.shape}")
logger.info(f"Label mapping: {label_to_idx}")

2025-12-06 13:00:09,878 - INFO - Training images shape: torch.Size([221, 3, 224, 224])
2025-12-06 13:00:09,879 - INFO - Training labels shape: torch.Size([221])
2025-12-06 13:00:09,880 - INFO - Label mapping: {np.str_('1_Pronacio'): 0, np.str_('2_Neutralis'): 1, np.str_('3_Szupinacio'): 2}


In [4]:
x_test = []
y_test = []

for img_name, label in test_data:
    img_path = os.path.join(preped_folder, img_name)
    try:
        img = Image.open(img_path).convert('RGB')
        img_tensor = transform(img)
        x_test.append(img_tensor)
        y_test.append(label)
    except Exception as e:
        logger.info(f"Error loading {img_name}: {e}")

x_test_tensor = torch.stack(x_test)
logger.info(f"Test images shape: {x_test_tensor.shape}")
y_test_encoded = [label_to_idx[label] for label in y_test]
y_test_tensor = torch.tensor(y_test_encoded, dtype=torch.long)

logger.info(f"Test labels shape: {y_test_tensor.shape}")

2025-12-06 13:00:17,461 - INFO - Test images shape: torch.Size([49, 3, 224, 224])
2025-12-06 13:00:17,462 - INFO - Test labels shape: torch.Size([49])


In [5]:
if torch.cuda.is_available():
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        logger.info(f"\nGPU {i}: {torch.cuda.get_device_name(i)}")
        props = torch.cuda.get_device_properties(i)
        logger.info(f"  Memory: {props.total_memory / 1024**3:.2f} GB")
        logger.info(f"  Compute Capability: {props.major}.{props.minor}")
else:
    logger.info("CUDA not available")

2025-12-06 13:00:17,504 - INFO - CUDA available: True
2025-12-06 13:00:17,506 - INFO - Number of GPUs: 1
2025-12-06 13:00:17,513 - INFO - 
GPU 0: NVIDIA GeForce RTX 4060
2025-12-06 13:00:17,514 - INFO -   Memory: 8.00 GB
2025-12-06 13:00:17,515 - INFO -   Compute Capability: 8.9


In [6]:
batch_size = 32
num_epochs = 15
lr = 0.001
device = 'cuda' 

train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [7]:
# wandb login an init
# Login to wandb with API key
load_dotenv()
wandb.login(key=os.getenv("wandbKey"))

def init_wandb():
    # Initialize wandb project
    wandb.init(
        project="ankle-align",
        config={
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "learning_rate": lr,
            "architecture": "Custom CNN",
            "dataset": "AnkleAlign",
            "optimizer": "Adam"
        }
    )




[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Win 10\_netrc
[34m[1mwandb[0m: Currently logged in as: [33mbencefarkas[0m ([33mbencefarkas-budapesti-m-szaki-s-gazdas-gtudom-nyi-egyetem[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.001, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model = None
        
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_model = model.state_dict().copy()
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                logger.info(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_model = model.state_dict().copy()
            self.counter = 0


In [9]:
def train_model(network, optimizer, loss_fn, enable_early_stopping=False):
    torch.cuda.empty_cache()

    loss_values = []

    if enable_early_stopping:
        early_stopping = EarlyStopping(patience=5, verbose=True)

    network.train()
    for epoch in tqdm(range(num_epochs), desc='Training model'):
        network.train()
        epoch_loss = 0.0
        num_batches = 0
        for images, target_labels in train_loader:
            images = images.to(device)
            target_labels = target_labels.to(device)

            pred_logits = network(images)
            loss = loss_fn(pred_logits, target_labels)
            epoch_loss += loss.item()
            num_batches += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        avg_train_loss = epoch_loss / num_batches

        if enable_early_stopping:
            network.eval()
            val_loss = 0.0
            val_batches = 0
            correct = 0
            total = 0
            with torch.no_grad():
                for images, target_labels in val_loader:
                    images = images.to(device)
                    target_labels = target_labels.to(device)
                    
                    pred_logits = network(images)
                    loss = loss_fn(pred_logits, target_labels)
                    val_loss += loss.item()
                    val_batches += 1
                    
                    _, predicted = torch.max(pred_logits, 1)
                    total += target_labels.size(0)
                    correct += (predicted == target_labels).sum().item()
            
            avg_val_loss = val_loss / val_batches
            val_accuracy = correct / total

        # Log metrics
        if enable_early_stopping:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "val_accuracy": val_accuracy
            })
        else:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss
            })
        
        
        if enable_early_stopping:
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.4f}")
        else:
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}")

        # Early stopping check
        if enable_early_stopping:
            early_stopping(avg_val_loss, network)
            if early_stopping.early_stop:
                logger.info("Early stopping triggered")
                network.load_state_dict(early_stopping.best_model)
                break
    
    # Load best model
    if enable_early_stopping and early_stopping.best_model is not None:
        network.load_state_dict(early_stopping.best_model)
        logger.info("Loaded best model weights")

In [10]:
def evaluate_model(network):
    # Training score
    true_labels = y_test_encoded
    predicted_labels = []
    network.eval()
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            outputs = network(images)
            _, predicted = torch.max(outputs, 1)
            predicted_labels.extend(predicted.cpu().numpy())

    accuracy = np.mean([true == pred for true, pred in zip(true_labels, predicted_labels)])
    precision = precision_score(true_labels, predicted_labels, average='weighted')
    recall = recall_score(true_labels, predicted_labels, average='weighted')
    f1 = f1_score(true_labels, predicted_labels, average='weighted')

    logger.info(f"network accuracy: {accuracy * 100:.2f}%")
    logger.info(f"network precision: {precision * 100:.2f}%")
    logger.info(f"network recall: {recall * 100:.2f}%")
    logger.info(f"network F1 score: {f1 * 100:.2f}%")

    logger.info(f"Detailed Classification Report: \n{classification_report(true_labels, predicted_labels)}")

    # Log test metrics
    wandb.log({
        "test_accuracy": accuracy,
        "test_precision": precision,
        "test_recall": recall,
        "test_f1": f1
    })

    wandb.finish()

In [11]:
net1 = torch.nn.Sequential(
    torch.nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),  # 224x224 -> 112x112
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 112x112 -> 56x56
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(64 * 56 * 56, 128),  # 64 channels * 56 * 56 spatial size
    torch.nn.ReLU(),
    torch.nn.Linear(128, 3)  # 3 classes output
).to(device)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net1.parameters(), lr=lr)

In [12]:
init_wandb()
train_model(net1, optimizer, loss_fn)
evaluate_model(net1)

Training model:   0%|          | 0/15 [00:00<?, ?it/s]

2025-12-06 13:00:23,009 - INFO - Epoch 1/15, Train Loss: 6.1536


Training model:   7%|▋         | 1/15 [00:00<00:06,  2.07it/s]

2025-12-06 13:00:23,166 - INFO - Epoch 2/15, Train Loss: 1.2604


Training model:  13%|█▎        | 2/15 [00:00<00:03,  3.43it/s]

2025-12-06 13:00:23,322 - INFO - Epoch 3/15, Train Loss: 0.9692


Training model:  20%|██        | 3/15 [00:00<00:02,  4.36it/s]

2025-12-06 13:00:23,479 - INFO - Epoch 4/15, Train Loss: 0.7975


Training model:  27%|██▋       | 4/15 [00:00<00:02,  4.98it/s]

2025-12-06 13:00:23,640 - INFO - Epoch 5/15, Train Loss: 0.6145


Training model:  33%|███▎      | 5/15 [00:01<00:01,  5.37it/s]

2025-12-06 13:00:23,797 - INFO - Epoch 6/15, Train Loss: 0.4358


Training model:  40%|████      | 6/15 [00:01<00:01,  5.67it/s]

2025-12-06 13:00:23,953 - INFO - Epoch 7/15, Train Loss: 0.2441


Training model:  47%|████▋     | 7/15 [00:01<00:01,  5.89it/s]

2025-12-06 13:00:24,111 - INFO - Epoch 8/15, Train Loss: 0.1393


Training model:  53%|█████▎    | 8/15 [00:01<00:01,  6.02it/s]

2025-12-06 13:00:24,268 - INFO - Epoch 9/15, Train Loss: 0.0723


Training model:  60%|██████    | 9/15 [00:01<00:00,  6.13it/s]

2025-12-06 13:00:24,424 - INFO - Epoch 10/15, Train Loss: 0.0801


Training model:  67%|██████▋   | 10/15 [00:01<00:00,  6.22it/s]

2025-12-06 13:00:24,581 - INFO - Epoch 11/15, Train Loss: 0.0704


Training model:  73%|███████▎  | 11/15 [00:02<00:00,  6.26it/s]

2025-12-06 13:00:24,737 - INFO - Epoch 12/15, Train Loss: 0.0691


Training model:  80%|████████  | 12/15 [00:02<00:00,  6.30it/s]

2025-12-06 13:00:24,894 - INFO - Epoch 13/15, Train Loss: 0.0702


Training model:  87%|████████▋ | 13/15 [00:02<00:00,  6.32it/s]

2025-12-06 13:00:25,051 - INFO - Epoch 14/15, Train Loss: 0.0541


Training model:  93%|█████████▎| 14/15 [00:02<00:00,  6.34it/s]

2025-12-06 13:00:25,207 - INFO - Epoch 15/15, Train Loss: 0.0463


Training model: 100%|██████████| 15/15 [00:02<00:00,  5.59it/s]

2025-12-06 13:00:25,247 - INFO - network accuracy: 28.57%



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


2025-12-06 13:00:25,248 - INFO - network precision: 88.15%
2025-12-06 13:00:25,249 - INFO - network recall: 28.57%
2025-12-06 13:00:25,249 - INFO - network F1 score: 28.66%
2025-12-06 13:00:25,259 - INFO - Detailed Classification Report: 
              precision    recall  f1-score   support

           0       0.17      1.00      0.29         7
           1       1.00      0.17      0.29        42
           2       0.00      0.00      0.00         0

    accuracy                           0.29        49
   macro avg       0.39      0.39      0.19        49
weighted avg       0.88      0.29      0.29        49



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
test_accuracy,▁
test_f1,▁
test_precision,▁
test_recall,▁
train_loss,█▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
epoch,15.0
test_accuracy,0.28571
test_f1,0.28656
test_precision,0.88153
test_recall,0.28571
train_loss,0.04627


In [16]:
# Use train test split, for getting validation metrics during training
x_train_tensor, x_val_tensor, y_train_tensor, y_val_tensor = train_test_split(
    x_train_tensor, y_train_tensor, test_size=0.2, random_state=42, stratify=y_train_tensor)

val_dataset = TensorDataset(x_val_tensor, y_val_tensor)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

train_dataset = TensorDataset(x_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [17]:
net1 = torch.nn.Sequential(
    torch.nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),  # 224x224 -> 112x112
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 112x112 -> 56x56
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(64 * 56 * 56, 128),  # 64 channels * 56 * 56 spatial size
    torch.nn.ReLU(),
    torch.nn.Linear(128, 3)  # 3 classes output
).to(device)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net1.parameters(), lr=lr)

In [18]:
init_wandb()
train_model(net1, optimizer, loss_fn, enable_early_stopping=True)
evaluate_model(net1)

Training model:   0%|          | 0/15 [00:00<?, ?it/s]

2025-12-06 13:14:47,486 - INFO - Epoch 1/15, Train Loss: 6.2651, Val Loss: 1.8797, Val Acc: 0.4167


Training model:   7%|▋         | 1/15 [00:00<00:07,  1.94it/s]

2025-12-06 13:14:47,680 - INFO - Epoch 2/15, Train Loss: 1.0908, Val Loss: 1.0649, Val Acc: 0.4444


Training model:  13%|█▎        | 2/15 [00:00<00:04,  3.07it/s]

2025-12-06 13:14:47,815 - INFO - Epoch 3/15, Train Loss: 0.8060, Val Loss: 1.2538, Val Acc: 0.5556
2025-12-06 13:14:47,815 - INFO - EarlyStopping counter: 1 out of 5


Training model:  20%|██        | 3/15 [00:00<00:02,  4.20it/s]

2025-12-06 13:14:47,943 - INFO - Epoch 4/15, Train Loss: 0.6775, Val Loss: 1.1957, Val Acc: 0.5556
2025-12-06 13:14:47,943 - INFO - EarlyStopping counter: 2 out of 5


Training model:  27%|██▋       | 4/15 [00:00<00:02,  5.14it/s]

2025-12-06 13:14:48,073 - INFO - Epoch 5/15, Train Loss: 0.5917, Val Loss: 1.2129, Val Acc: 0.5278
2025-12-06 13:14:48,073 - INFO - EarlyStopping counter: 3 out of 5


Training model:  33%|███▎      | 5/15 [00:01<00:01,  5.83it/s]

2025-12-06 13:14:48,257 - INFO - Epoch 6/15, Train Loss: 0.4481, Val Loss: 1.4125, Val Acc: 0.5556
2025-12-06 13:14:48,258 - INFO - EarlyStopping counter: 4 out of 5


Training model:  40%|████      | 6/15 [00:01<00:01,  5.69it/s]

2025-12-06 13:14:48,382 - INFO - Epoch 7/15, Train Loss: 0.3059, Val Loss: 1.4560, Val Acc: 0.5556
2025-12-06 13:14:48,383 - INFO - EarlyStopping counter: 5 out of 5
2025-12-06 13:14:48,383 - INFO - Early stopping triggered


Training model:  40%|████      | 6/15 [00:01<00:02,  4.25it/s]

2025-12-06 13:14:48,386 - INFO - Loaded best model weights
2025-12-06 13:14:48,422 - INFO - network accuracy: 65.31%
2025-12-06 13:14:48,423 - INFO - network precision: 86.83%
2025-12-06 13:14:48,424 - INFO - network recall: 65.31%
2025-12-06 13:14:48,425 - INFO - network F1 score: 70.95%
2025-12-06 13:14:48,436 - INFO - Detailed Classification Report: 
              precision    recall  f1-score   support

           0       0.30      0.86      0.44         7
           1       0.96      0.62      0.75        42
           2       0.00      0.00      0.00         0

    accuracy                           0.65        49
   macro avg       0.42      0.49      0.40        49
weighted avg       0.87      0.65      0.71        49




  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▂▃▅▆▇█
test_accuracy,▁
test_f1,▁
test_precision,▁
test_recall,▁
train_loss,█▂▂▁▁▁▁
val_accuracy,▁▂██▇██
val_loss,█▁▃▂▂▄▄

0,1
epoch,7.0
test_accuracy,0.65306
test_f1,0.70945
test_precision,0.86825
test_recall,0.65306
train_loss,0.30589
val_accuracy,0.55556
val_loss,1.45601
