### Imports

In [1]:
import os
import random

import torch

from src.alexnet import DataManager, ModelManager

### Constants

In [2]:

#? ------------- data constants -------------
# Dataset root. Defaults to the kagglehub cache; set the MINIIMAGENET_PATH
# environment variable to point somewhere else without editing the notebook.
DATA_PATH = os.environ.get(
    "MINIIMAGENET_PATH",
    os.path.expanduser("~/.cache/kagglehub/datasets/arjunashok33/miniimagenet/versions/1"),
)
# NOTE(review): not referenced anywhere in this notebook — confirm it is still needed.
LABELS_PATH = "labels.json"
BATCH_SIZE = 512
NUM_WORKERS = 12
TRAIN_SPLIT = 0.7
VAL_SPLIT = 0.15  # presumably the remaining 1 - TRAIN_SPLIT - VAL_SPLIT becomes the test split
assert 0 < TRAIN_SPLIT and 0 < VAL_SPLIT and TRAIN_SPLIT + VAL_SPLIT < 1, (
    "TRAIN_SPLIT and VAL_SPLIT must be positive and leave a non-empty test split"
)

#? ------------- reproducibility -------------
# Seed the Python and PyTorch RNGs so a fresh kernel (Restart & Run All) repeats
# any random split / weight init that draws from them.
# NOTE(review): this only fixes the split if DataManager draws it from these RNGs.
# A checkpoint trained under a different (or unseeded) split may have seen images
# that now fall in the test split — retrain under this seed before trusting test metrics.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

#? ------------- checkpoint path -------------
CHECKPOINT_PATH = "checkpoints/best_model.pth"

#? ------------- model constants -------------
TRAIN = False       # run model_manager.train() in the Train cell
LOAD_MODEL = True   # try to load CHECKPOINT_PATH before training / testing

NUM_EPOCHS = 100
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 3e-4
DROPOUT_RATE = 0.3
PATIENCE = 15
LABEL_SMOOTHING = 0.1

### Load Data and Model

In [3]:
# Report which accelerator is available.
# NOTE(review): `device` is only printed here; it is not passed to DataManager or
# ModelManager below, so those classes must choose a device themselves — confirm.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# NOTE(review): fixed message — nothing in this notebook enables or checks these
# optimizations, and it is printed even when running on CPU.
print("Optimizations applied: Mixed Precision, Optimized DataLoader, Larger Batch Size")

# Data pipeline: configured from the constants cell, then prepared via setup().
data_manager = DataManager(
    data_path=DATA_PATH,
    train_split=TRAIN_SPLIT,
    val_split=VAL_SPLIT,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)
data_manager.setup()

# Model / training wrapper around the prepared data manager and the hyperparameters.
model_manager = ModelManager(
    data_manager=data_manager,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    dropout_rate=DROPOUT_RATE,
    label_smoothing=LABEL_SMOOTHING,
    patience=PATIENCE,
)

Using device: cuda
Optimizations applied: Mixed Precision, Optimized DataLoader, Larger Batch Size
Loading dataset from /home/spina/.cache/kagglehub/datasets/arjunashok33/miniimagenet/versions/1
Dataset loaded: 60000 samples, 100 classes
Classes: ['kit fox', 'English setter', 'Siberian husky', 'Australian terrier', 'English springer', 'grey whale', 'lesser panda', 'Egyptian cat', 'ibex', 'Persian cat', 'cougar', 'gazelle', 'porcupine', 'sea lion', 'malamute', 'badger', 'Great Dane', 'Walker hound', 'Welsh springer spaniel', 'whippet', 'Scottish deerhound', 'killer whale', 'mink', 'African elephant', 'Weimaraner', 'soft-coated wheaten terrier', 'Dandie Dinmont', 'red wolf', 'Old English sheepdog', 'jaguar', 'otterhound', 'bloodhound', 'Airedale', 'hyena', 'meerkat', 'giant schnauzer', 'titi', 'three-toed sloth', 'sorrel', 'black-footed ferret', 'dalmatian', 'black-and-tan coonhound', 'papillon', 'skunk', 'Staffordshire bullterrier', 'Mexican hairless', 'Bouvier des Flandres', 'weasel', 

### Load Checkpoint and Train

In [4]:
#? -------------- Checkpoint --------------
# Optionally warm-start from a saved checkpoint. A failed or missing load is
# tolerated only when we are about to train anyway. In eval-only mode
# (TRAIN = False) it would leave the Test cell scoring a model that was neither
# loaded nor trained, so stop instead of reporting meaningless metrics.
checkpoint_loaded = False
checkpoint_error = None  # kept so the eval-only failure below can chain the cause
if LOAD_MODEL:
    if os.path.exists(CHECKPOINT_PATH):
        try:
            model_manager.load_model(CHECKPOINT_PATH)
        except Exception as e:  # best-effort warm start; escalated below if not training
            checkpoint_error = e
            print(f"Could not warm-start from {CHECKPOINT_PATH}: {e}")
        else:
            checkpoint_loaded = True
            print(f"Warm-started from {CHECKPOINT_PATH}")
    else:
        print(f"No checkpoint found at {CHECKPOINT_PATH}")

if not TRAIN and not checkpoint_loaded:
    raise RuntimeError(
        f"TRAIN is False but no checkpoint was loaded from {CHECKPOINT_PATH}; "
        "the Test cell would evaluate a model that was neither loaded nor trained. "
        "Set TRAIN = True, or set LOAD_MODEL = True with a valid CHECKPOINT_PATH."
    ) from checkpoint_error

if TRAIN:
    print("Training")

    #? -------------- Training --------------
    # Train the model; the returned history is kept in the namespace for inspection
    training_history = model_manager.train()

    # Plot training history
    model_manager.plot_training_history()
    #? ---------------------------------------


Model loaded from checkpoints/best_model.pth
Best validation loss: 2.7823
Best validation accuracy: 49.03%
Warm-started from checkpoints/best_model.pth


### Test

In [5]:
# Evaluate on the held-out test split; test() returns (loss, top-1 %, top-5 %).
(test_loss,
 test_accuracy,
 test_accuracy5) = model_manager.test()

print(
    f"Test Results - Loss: {test_loss:.4f}, "
    f"Acc@1: {test_accuracy:.2f}%, "
    f"Acc@5: {test_accuracy5:.2f}%"
)

Testing: 100%|██████████| 18/18 [00:07<00:00,  2.37it/s, Loss=2.8061, Acc@1=50.83%, Acc@5=72.23%]

Test Results - Loss: 2.7189, Acc@1: 50.83%, Acc@5: 72.23%



