In [None]:
# Imports and config
import sys
import os
sys.path.insert(0, os.path.dirname(os.getcwd()))
import torch
from src.data_loader import get_data_loaders
from src.train import train_model
from src.models.mobilenet import MobileNet
from src.models.squeezenet import SqueezeNet, SqueezeNetCompact
from src.models.shiftnet import ShiftNet

# --- Run configuration: everything a reader might want to tune ---
DATA_DIR = "../data/FER2013"  # dataset root, relative to the notebook's working directory
BATCH_SIZE = 64
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is visible
NUM_EPOCHS = 20   # change to 5-10 for quick runs
LR = 1e-3
SEED = 42         # fixed seed so a fresh "Restart & Run All" starts from the same RNG state

# Seed PyTorch before any loader or model is created, so weight initialisation
# and any torch-driven shuffling begin from a known state.
# NOTE(review): if get_data_loaders / train_model also draw from Python's
# `random` or NumPy, those generators need seeding inside those modules.
torch.manual_seed(SEED)

print("Device:", DEVICE)
# Build the loaders once; the training cells below share them.
train_loader, test_loader = get_data_loaders(DATA_DIR, batch_size=BATCH_SIZE)

In [None]:
# Make sure each model's save folder exists; exist_ok=True makes re-running
# this cell a no-op for folders that are already present.
for checkpoint_dir in ("../models/mobilenet", "../models/squeezenet", "../models/shiftnet"):
    os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# SqueezeNet: build a fresh (untrained) network with 7 output classes and train it.
squeezenet = SqueezeNet(num_classes=7)
print(squeezenet)  # dump the architecture for reference

# Training uses the shared loaders and the hyper-parameters from the config cell.
squeezenet_trained = train_model(
    squeezenet,
    train_loader,
    test_loader,
    device=DEVICE,
    num_epochs=NUM_EPOCHS,
    lr=LR,
)
# Per the original note: train_model saves the best checkpoint automatically
# into models/<classname>/best_model.pth

In [None]:
# Train SqueezeNetCompact from scratch
squeezenetcompact = SqueezeNetCompact(num_classes=7)
print(squeezenetcompact)
# Store the result under its own name so it does not overwrite `squeezenet_trained`
# from the SqueezeNet cell; both trained models stay available for later comparison.
squeezenetcompact_trained = train_model(squeezenetcompact, train_loader, test_loader, device=DEVICE,
                                        num_epochs=NUM_EPOCHS, lr=LR)
# saved automatically by train_model into models/<classname>/best_model.pth
# NOTE(review): no folder for this class is created in the "Prepare model save
# folders" cell -- confirm train_model creates its own output directory.