### Imports

In [1]:
import os
import sys
import os
import torch

sys.path.insert(0, os.path.abspath("src"))

from config import *
from DataManager import DataManager
from AlexNet import AlexNet
from AllenDataLoader import AllenDataLoader

In [5]:
!bat src/config.py

]10;?]11;?[c[38;5;246m───────┬────────────────────────────────────────────────────────────────────────[0m
       [38;5;246m│ [0mFile: [1msrc/config.py[0m
[38;5;246m───────┼────────────────────────────────────────────────────────────────────────[0m
[38;5;246m   1[0m   [38;5;246m│[0m [38;5;203mimport[0m[38;5;231m [0m[38;5;231mos[0m
[38;5;246m   2[0m   [38;5;246m│[0m [38;5;203mfrom[0m[38;5;231m [0m[38;5;231mpathlib[0m[38;5;231m [0m[38;5;203mimport[0m[38;5;231m [0m[38;5;231mPath[0m
[38;5;246m   3[0m   [38;5;246m│[0m 
[38;5;246m   4[0m   [38;5;246m│[0m [38;5;242m"""[0m[38;5;242mProject-wide configuration constants.[0m
[38;5;246m   5[0m   [38;5;246m│[0m 
[38;5;246m   6[0m   [38;5;246m│[0m [38;5;242mAll file-system locations are defined with `pathlib.Path` and resolved [0m
[38;5;246m    [0m   [38;5;246m│[0m [38;5;242mrelative[0m
[38;5;246m   7[0m   [38;5;246m│[0m [38;5;242mto the repository layout to ensure portability 

### Load Data and Model

In [9]:
# detect device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("Optimizations applied: Mixed Precision, Optimized DataLoader, Larger Batch Size")

data_manager = DataManager(
    data_path=DATA_PATH,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    train_split=TRAIN_SPLIT,
    val_split=VAL_SPLIT
)
data_manager.setup()

model_manager = AlexNet(
    data_manager=data_manager,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    dropout_rate=DROPOUT_RATE,
    patience=PATIENCE,
    label_smoothing=LABEL_SMOOTHING
)

Using device: cuda
Optimizations applied: Mixed Precision, Optimized DataLoader, Larger Batch Size
Loading dataset from /home/spina/.cache/kagglehub/datasets/arjunashok33/miniimagenet/versions/1
Dataset loaded: 60000 samples, 100 classes
Classes: ['kit fox', 'English setter', 'Siberian husky', 'Australian terrier', 'English springer', 'grey whale', 'lesser panda', 'Egyptian cat', 'ibex', 'Persian cat', 'cougar', 'gazelle', 'porcupine', 'sea lion', 'malamute', 'badger', 'Great Dane', 'Walker hound', 'Welsh springer spaniel', 'whippet', 'Scottish deerhound', 'killer whale', 'mink', 'African elephant', 'Weimaraner', 'soft-coated wheaten terrier', 'Dandie Dinmont', 'red wolf', 'Old English sheepdog', 'jaguar', 'otterhound', 'bloodhound', 'Airedale', 'hyena', 'meerkat', 'giant schnauzer', 'titi', 'three-toed sloth', 'sorrel', 'black-footed ferret', 'dalmatian', 'black-and-tan coonhound', 'papillon', 'skunk', 'Staffordshire bullterrier', 'Mexican hairless', 'Bouvier des Flandres', 'weasel', 

#### Load Model

In [10]:
if LOAD_MODEL:
    if os.path.exists(CHECKPOINT_PATH):
        try:
            model_manager.load_model(CHECKPOINT_PATH)
        except Exception as e:
            print(f"Could not warm-start from {CHECKPOINT_PATH}: {e}")
    else:
        print(f"No checkpoint found at {CHECKPOINT_PATH}")

Model loaded from /home/spina/Desktop/units/DL/project/mice-representation/checkpoints/best_model.pth
Best validation loss: 2.7823
Best validation accuracy: 49.03%


#### Train

In [11]:
if TRAIN:
    print(f"Training")

    #? -------------- Training --------------
    # Train the model
    training_history = model_manager.train()

    # Plot training history
    model_manager.plot_training_history()
    #? ---------------------------------------


### Test

In [12]:
test_loss, test_accuracy, test_accuracy5 = model_manager.test()

print(f"Test Results - Loss: {test_loss:.4f}, Acc-Top1: {test_accuracy:.2f}%, Acc-Top5: {test_accuracy5:.2f}%")

Testing: 100%|██████████| 18/18 [00:07<00:00,  2.32it/s, Loss=2.6530, Acc@1=50.88%, Acc@5=72.31%]

Test Results - Loss: 2.7168, Acc-Top1: 50.88%, Acc-Top5: 72.31%





### Allen Dataset

In [2]:
dataset = "neuropixels"   # or "calcium"
download = False          # True to download the PKL
out = None                # output path for the PKL
path ="AllenData/neuropixels.zarr"               # path to an existing PKL

data_loader = AllenDataLoader(dataset=dataset)

if download:
    saved_path = data_loader.download(out_path=out)

INFO:Initialized AllenDataLoader for neuropixels dataset


Convert data to a compatible type

In [14]:
!REMOVE_ENV=1  src/data/data_converter.sh

Creating conda environment
Converting 1 files
OK: AllenData/neuropixels.pkl -> AllenData/neuropixels.zarr
File successfully converted
File not found: AllenData/calcium.pkl


Inspect data

In [3]:
data_loader.inspect(pkl_path=path)

INFO:Top-level keys: ['neural_data', 'stimuli']
INFO:stimuli: shape=(118, 918, 1174)
INFO:neural_data: visual_areas=['VISal', 'VISam', 'VISl', 'VISp', 'VISpm', 'VISrl']


Inspection complete. Enable --verbose for detailed logs.
