In [None]:
# Imports and config
import sys
import os

# Make the project's `src` package importable. This assumes the notebook runs
# from a directory that is a direct child of the project root.
notebook_dir = os.getcwd()
project_root = os.path.dirname(notebook_dir)
# Guard the insert so re-running this cell does not stack duplicate entries
# onto sys.path (keeps the cell idempotent under repeated execution).
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import torch

from src.data_loader import get_data_loaders
from src.models.squeezenet import SqueezeNet
from src.models.mobilenet import MobileNetV1, MobileNetV2, MobileNetV3
from src.models.shiftnet import ShiftNet
from src.models.shufflenet import ShuffleNetV2
from src.evaluate import evaluate_model
from src.utils import save_results

# Config: the knobs a reader might want to tune for this evaluation run.
DATA_DIR = "../data/FER2013"  # dataset root, relative to the notebook's working directory
BATCH_SIZE = 64               # batch size handed to get_data_loaders
DEVICE = "cpu"                # device used by model.to(), torch.load and evaluate_model

# Load data. The evaluation cells below consume `test_loader`.
train_loader, test_loader = get_data_loaders(
    DATA_DIR,
    batch_size=BATCH_SIZE,
)
print("Device:", DEVICE)

In [None]:
import inspect

# Helper to load a saved checkpoint and evaluate it.
# Metrics for every evaluated model accumulate in `results`, keyed by model
# name; the whole dict is re-saved after each successful evaluation.
results = {}

def load_and_eval(model_cls, model_path, model_name, num_classes=7, in_channels=1,
                  results_path="../results.json"):
    """Build ``model_cls``, load its checkpoint from ``model_path`` and evaluate it.

    Only the constructor arguments the class actually declares (``num_classes``
    and/or ``in_channels``) are passed. A leading ``"model."`` prefix is stripped
    from checkpoint keys before loading. Metrics from ``evaluate_model`` are
    printed, stored in the module-level ``results`` dict under ``model_name``,
    and the accumulated dict is written with ``save_results``.

    Args:
        model_cls: Model class to instantiate.
        model_path: Path to a ``state_dict`` checkpoint file.
        model_name: Label used in printed output and as the key in ``results``.
        num_classes: Passed as ``num_classes`` if the constructor accepts it.
            Default 7 (presumably the FER2013 label count — TODO confirm).
        in_channels: Passed as ``in_channels`` if the constructor accepts it.
            Default 1 (assumes single-channel input — TODO confirm).
        results_path: Destination handed to ``save_results``.

    Returns:
        None. If ``model_path`` does not exist, a message is printed and
        nothing is built or evaluated.
    """
    if not os.path.exists(model_path):
        print(f"{model_name} checkpoint not found:", model_path)
        return

    # Pass only the constructor arguments this architecture actually accepts.
    init_params = inspect.signature(model_cls.__init__).parameters
    kwargs = {}
    if "num_classes" in init_params:
        kwargs["num_classes"] = num_classes
    if "in_channels" in init_params:
        kwargs["in_channels"] = in_channels
    model = model_cls(**kwargs).to(DEVICE)

    print("\n" + "=" * 50)
    print(f"Loading and evaluating {model_name} model...")
    print("=" * 50)

    state = torch.load(model_path, map_location=DEVICE, weights_only=True)
    # Checkpoints were presumably saved from a wrapper holding the network as
    # `self.model`; strip only that *leading* prefix. str.replace would also
    # mangle keys containing "model." further in (e.g. "submodel.weight").
    prefix = "model."
    fixed_state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }
    # strict=False tolerates mismatches; surface them rather than silently
    # evaluating a partially initialised network.
    incompatible = model.load_state_dict(fixed_state, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"Warning: {model_name} checkpoint does not fully match the model "
            f"({len(incompatible.missing_keys)} missing, "
            f"{len(incompatible.unexpected_keys)} unexpected keys)."
        )

    r = evaluate_model(model, test_loader, DEVICE)
    results[model_name] = r

    # Single quotes inside the f-strings keep this valid on Python < 3.12.
    print(f"\nAccuracy: {r['accuracy']:.2f}%")
    print(f"Inference Time: {r['inference_time_sec']:.2f}s")
    print(f"Energy Consumed: {r['energy_consumed_kwh']:.6f} kWh")
    print(f"CO₂ Emissions: {r['co2_emissions_kg']:.6f} kg")
    print(f"FLOPs: {r['flops']}")
    print(f"Parameters: {r['parameters']}\n")

    # Persist everything evaluated so far, not just this model.
    save_results(results, path=results_path)

In [None]:
# SqueezeNet: evaluate the best saved checkpoint on the test split.
squeezenet_path = "../models/squeezenet/best_model.pth"
load_and_eval(model_cls=SqueezeNet, model_path=squeezenet_path, model_name="SqueezeNet")

In [None]:
# MobileNetV1: evaluate the best saved checkpoint on the test split.
mobilenetv1_path = "../models/mobilenetv1/best_model.pth"
load_and_eval(model_cls=MobileNetV1, model_path=mobilenetv1_path, model_name="MobileNetV1")

In [None]:
# MobileNetV2: evaluate the best saved checkpoint on the test split.
mobilenetv2_path = "../models/mobilenetv2/best_model.pth"
load_and_eval(model_cls=MobileNetV2, model_path=mobilenetv2_path, model_name="MobileNetV2")

In [None]:
# MobileNetV3: evaluate the best saved checkpoint on the test split.
mobilenetv3_path = "../models/mobilenetv3/best_model.pth"
load_and_eval(model_cls=MobileNetV3, model_path=mobilenetv3_path, model_name="MobileNetV3")

In [None]:
# ShiftNet: evaluate the best saved checkpoint on the test split.
shiftnet_path = "../models/shiftnet/best_model.pth"
load_and_eval(model_cls=ShiftNet, model_path=shiftnet_path, model_name="ShiftNet")

In [None]:
# ShuffleNetV2: evaluate the best saved checkpoint on the test split.
shufflenet_path = "../models/shufflenetv2/best_model.pth"
load_and_eval(model_cls=ShuffleNetV2, model_path=shufflenet_path, model_name="ShuffleNetV2")