# Inputs

In [1]:
import os
import pandas as pd
import yaml

In [2]:
def _load_yaml_config(path):
    """Parse the YAML file at ``path`` with the safe loader and return its contents.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes, instead of being left open for the garbage collector.
    """
    with open(path) as config_file:
        return yaml.load(config_file, Loader=yaml.SafeLoader)


# Each config file is read once; where a top-level section key is given,
# only that section is kept.
DATASET_PARAMS = _load_yaml_config("configs/dataset.yaml")['DATASET']
PREPROCESSING_PARAMS = _load_yaml_config("configs/preprocessing.yaml")['PREPROCESSING']
EFFICIENTCAPSNET_PARAMS = _load_yaml_config("configs/efficientcapsnet.yaml")['CAPSNET']
# The training config is used whole (no section key is selected).
TRAINING_PARAMS = _load_yaml_config("configs/training.yaml")

In [3]:
# Class labels, as listed in the dataset config.
categories = DATASET_PARAMS['CATEGORIES']

# The dataset location is resolved relative to the current working directory.
BASE_DIR = os.getcwd()
DATASET_DIR = os.path.join(BASE_DIR, DATASET_PARAMS['DATA_PATH'])

# One sub-directory per split under the dataset root.
TRAIN_DIR, VAL_DIR, TEST_DIR = (
    os.path.join(DATASET_DIR, split_name)
    for split_name in ("train", "val", "test")
)

In [4]:
import torch

In [5]:
# Display the installed PyTorch version string so the run environment is on record.
torch.__version__

'2.6.0+cu126'

In [6]:
from src.utils import get_device
# Ask the project helper for the device to run on and keep it in DEVICE.
# NOTE(review): DEVICE is not referenced again in the visible cells; the training
# cell calls get_device() directly. Confirm whether one of the two is redundant.
DEVICE = get_device()

# Dataset

In [7]:
from src.utils import generate_filenames_df

In [8]:
# Build the file listing for the train and validation splits (in that order),
# using the same category list for both.
train_filenames_df, val_filenames_df = (
    generate_filenames_df(split_dir, categories)
    for split_dir in (TRAIN_DIR, VAL_DIR)
)

## Preprocessor

In [9]:
from src.preprocessing import preprocess

# Target image size from the preprocessing config, converted to a tuple.
target_input_size = tuple(PREPROCESSING_PARAMS['INPUT_SIZE'])

# Training transform: every augmentation setting is forwarded to `preprocess`
# under the lower-cased name of its key in PREPROCESSING_PARAMS.
train_transform = preprocess(
    target_input_size=target_input_size,
    **{
        config_key.lower(): PREPROCESSING_PARAMS[config_key]
        for config_key in (
            'ROTATION_RANGE',
            'WIDTH_SHIFT_RANGE',
            'HEIGHT_SHIFT_RANGE',
            'BRIGHTNESS_RANGE',
            'ZOOM_RANGE',
            'HORIZONTAL_FLIP',
            'VERTICAL_FLIP',
            'CHANNEL_SHIFT_RANGE',
            'FILL_MODE',
            'DROP_OUT',
            'SHEAR_RANGE',
        )
    },
)

# Validation transform: no augmentation settings are passed, only the target size
# (i.e. only rescaling).
val_transform = preprocess(target_input_size=target_input_size)

In [10]:
from src.dataset import Dataset

# Dataset over the training file listing, using the training transform.
dataset = Dataset(dataframe=train_filenames_df, transform=train_transform)

## Train and validation loaders

In [11]:
from src.loader import Loader

# Both loaders share batch size and worker count; they differ only in the file
# listing, the transform, and whether samples are shuffled (train only).
train_loader, val_loader = (
    Loader(
        filenames_df,
        batch_size=DATASET_PARAMS['BATCH_SIZE'],
        num_workers=EFFICIENTCAPSNET_PARAMS['NUM_WORKERS'],
        transform=split_transform,
        shuffle=do_shuffle,
    )
    for filenames_df, split_transform, do_shuffle in (
        (train_filenames_df, train_transform, True),
        (val_filenames_df, val_transform, False),
    )
)


# Model

### EfficientCapsNet

In [12]:
from src.model import EfficientCapsNet

# EfficientCapsNet instance sized from the CAPSNET config section.
# Note: the DenseNet cell further down rebinds `model`, so whichever of the two
# model cells ran last is the one the optimizer and training cells pick up.
model = EfficientCapsNet(input_size=EFFICIENTCAPSNET_PARAMS['INPUT_SIZE'])



### DenseNet

In [18]:
from src.densenet import DenseNet121

# Rebinds `model`: any model created by an earlier cell is replaced by this one.
# The class count is read from the CAPSNET config section.
model = DenseNet121(
    num_classes=EFFICIENTCAPSNET_PARAMS['NUM_CLASSES'],
)




# Training

## Loss, optimizer, and metrics

In [19]:
# Loss
from src.loss import MarginLoss, marginLoss

# `marginLoss` is bound under the name later passed to training as the criterion.
margin_loss = marginLoss

# Optimizer
from torch.optim import Adam

# Adam over the parameters of whichever model is currently bound to `model`.
optimizer = Adam(
    model.parameters(),
    lr=TRAINING_PARAMS['LEARNING_RATE'],
)

# Metrics
from src.metrics import (
    MatthewsCorrelationCoefficient,
    CustomPrecision,
    CustomRecall,
    CustomAUC,
    CustomSpecificity
)

# MCC takes only the class count; the remaining metrics also take the averaging
# mode from the training config, so they are built from one (name, class) table.
metrics = {
    "mcc": MatthewsCorrelationCoefficient(num_classes=EFFICIENTCAPSNET_PARAMS['NUM_CLASSES']),
    **{
        metric_name: metric_cls(
            num_classes=EFFICIENTCAPSNET_PARAMS['NUM_CLASSES'],
            average=TRAINING_PARAMS['AVERAGE'],
        )
        for metric_name, metric_cls in (
            ("precision", CustomPrecision),
            ("recall", CustomRecall),
            ("auc", CustomAUC),
            ("specificity", CustomSpecificity),
        )
    },
}

In [20]:
from src.train import train
from src.utils import get_device

# Run training and keep the value returned by `train` in `history`.
# The epoch/logging/checkpoint settings are forwarded from TRAINING_PARAMS under
# the lower-cased name of each config key.
history = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=margin_loss,
    optimizer=optimizer,
    device=get_device(),
    metrics=metrics,
    **{
        config_key.lower(): TRAINING_PARAMS[config_key]
        for config_key in (
            'NUM_EPOCHS',
            'PRINT_EVERY',
            'SAVE_PATIENCE',
            'SAVE_PATH',
            'SAVE_MODEL',
            'SAVE_METRICS',
        )
    },
)


Training...

Epoch 1/10
data.shape: torch.Size([4, 3, 299, 299])
model_outputs_raw.shape: torch.Size([4, 3])
outputs: tensor([[-0.1638,  0.3862,  0.0384],
        [-0.0500,  0.0915,  0.0565],
        [-0.1316,  0.1138,  0.1088],
        [-0.0330,  0.2585,  0.1249]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
targets: tensor([0, 0, 0, 0], device='cuda:0')
Batch 1/1474
Running loss: 1.0057408809661865
Metrics: {'mcc': array([nan,  0.,  0., ...,  0.,  0.,  0.]), 'precision': array([nan,  0.,  0., ...,  0.,  0.,  0.]), 'recall': array([nan,  0.,  0., ...,  0.,  0.,  0.]), 'auc': array([nan,  0.,  0., ...,  0.,  0.,  0.]), 'specificity': array([nan,  0.,  0., ...,  0.,  0.,  0.])}
data.shape: torch.Size([4, 3, 299, 299])
model_outputs_raw.shape: torch.Size([4, 3])
outputs: tensor([[-0.0872,  0.4080, -0.0318],
        [ 0.3081,  0.2507,  0.0895],
        [-0.0207,  0.3128, -0.0331],
        [ 0.2416,  0.0767, -0.0772]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
targets: tensor

KeyboardInterrupt: 