In [3]:
import os
import torch
from omegaconf import OmegaConf
from collections import Counter
from tqdm import tqdm
from utils.dataset import get_data_loaders
from models.CNN import CNN
from models.resnet import ResNet, BasicBlock

In [15]:
# ---------------------------------------------------------------------------
# Configuration: paths, model hyper-parameters, training and data options.
# Every path below is derived from PROJECT_ROOT, i.e. the parent directory of
# the process's current working directory.
# ---------------------------------------------------------------------------
PROJECT_ROOT = os.path.abspath("../")

cfg = OmegaConf.create(dict(
    project_root=PROJECT_ROOT,
    verbose=True,
    wandb=True,
    sys_log=True,
    model="CNN",  # main() accepts "CNN" or "ResNet"
    CNN=dict(
        c1=16, c2=32, c3=64,
        k1=3, k2=3, k3=3,
        pk=2, ps=2,
        kernel_size=3, stride=1, padding=1,
    ),
    train=dict(
        epochs=50,
        batch_size=20,
        shuffle=True,
        train_ratio=0.8,
        print_label_frequencies=True,
    ),
    data=dict(
        data_path=f"{PROJECT_ROOT}/data/raw/derivatives/non-linear_anatomical_alignment",
        zarr_dir_path=f"{PROJECT_ROOT}/zarr_datasets",
        zarr_path=f"{PROJECT_ROOT}/zarr_datasets/pool_emotions",
        label_path=f"{PROJECT_ROOT}/data/updated_annotations/pooled_annotations_structured.tsv",
        sessions=["01", "02", "03", "04", "05", "06", "07", "08"],
        file_pattern_template="*_ses-forrestgump_task-forrestgump_rec-dico7Tad2grpbold7TadNL_run-{}_bold.nii.gz",
        subjects=["sub-04"],
        session_offsets=[0, 902, 1784, 2660, 3636, 4560, 5438, 6522],
        emotion_idx=dict(NONE=0, HAPPINESS=1, FEAR=2, SADNESS=3, LOVE=4, ANGER=5),
        normalization=False,
        weight_decay=0,
        learning_rate=0.0001,
        seed=42,
        save_model=True,
        # NOTE(review): main() in this notebook does not read this flag; it
        # loads load_model_path whenever that path is non-empty.
        load_model=False,
        save_model_path="output/models",
        load_model_path=f"{PROJECT_ROOT}/src/output/models/k94ke7h1.pth",
    ),
))

def evaluate_model(model, dataloader, device, desc="Evaluating"):
    """Compute classification accuracy and true-label frequencies over a dataloader.

    Args:
        model: ``torch.nn.Module`` producing per-class scores with the class
            axis on dim 1. It is switched to eval mode (and left in it).
        dataloader: iterable of dict batches with ``"data_tensor"`` and
            ``"label_tensor"`` entries.
        device: device each batch is moved to before the forward pass.
        desc: label for the tqdm progress bar, so that several evaluation
            passes (e.g. train vs. validation) can be told apart.

    Returns:
        ``(accuracy_percent, label_counts)``: accuracy as a percentage (0 for
        an empty dataloader), and a ``Counter`` mapping each true label, as a
        plain Python ``int``, to its number of occurrences.
    """
    model.eval()
    correct = 0
    total = 0
    label_counts = Counter()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=desc):
            data = batch["data_tensor"].float().to(device)
            labels = batch["label_tensor"].long().to(device)
            # 4-D batches get a singleton channel axis inserted at dim 1,
            # presumably (B, X, Y, Z) -> (B, 1, X, Y, Z) for 3-D convs; TODO confirm.
            if data.dim() == 4:
                data = data.unsqueeze(1)
            predictions = model(data).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
            # tolist() yields Python ints, so Counter keys are not numpy scalars.
            label_counts.update(labels.tolist())

    accuracy = correct / total if total > 0 else 0
    return accuracy * 100, label_counts

def _print_label_distribution(title, label_counts, inverse_emotion_idx):
    """Print ``title`` then one ``<emotion>: <count>`` line per label, sorted by label.

    A label absent from ``inverse_emotion_idx`` is printed as its raw value
    instead of raising ``KeyError``.
    """
    print(title)
    for label, count in sorted(label_counts.items()):
        print(f"{inverse_emotion_idx.get(label, label)}: {count}")


def main():
    """Evaluate a saved checkpoint on the train and validation splits.

    Builds the data loaders and the model selected by ``cfg.model``, loads the
    state dict at ``cfg.data.load_model_path`` (when set), then prints overall
    accuracy and the per-emotion label distribution for both splits.

    Raises:
        ValueError: if ``cfg.model`` is neither ``"CNN"`` nor ``"ResNet"``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # NOTE(review): "validation" accuracy is only meaningful if get_data_loaders
    # reproduces the split used at training time (presumably via cfg.data.seed);
    # TODO confirm, otherwise validation samples may overlap the training set.
    train_dataloader, val_dataloader = get_data_loaders(cfg)
    print(f"Loaded Observations: {len(train_dataloader.dataset) + len(val_dataloader.dataset)}")

    output_dim = len(cfg.data.emotion_idx)
    model_path_torch = cfg.data.load_model_path

    if cfg.model == "CNN":
        model = CNN(cfg=cfg, output_dim=output_dim)
    elif cfg.model == "ResNet":
        model = ResNet(BasicBlock, [1, 1, 1, 1], in_channels=1, num_classes=output_dim)
    else:
        raise ValueError(f"Invalid model specified: {cfg.model!r}")

    # Evaluation needs trained weights, so the checkpoint is loaded whenever a
    # path is configured; cfg.data.load_model is intentionally not consulted.
    if model_path_torch:
        # A state_dict holds only tensors in plain containers, so weights_only=True
        # suffices. It refuses arbitrary pickled objects and avoids the warning
        # torch emitted at this call in the recorded output.
        state_dict = torch.load(model_path_torch, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print(f"Loaded model from {model_path_torch}")
    else:
        print("No checkpoint configured; evaluating randomly initialized weights.")

    model.to(device)
    train_accuracy, train_label_counts = evaluate_model(model, train_dataloader, device)
    val_accuracy, val_label_counts = evaluate_model(model, val_dataloader, device)

    print(f"Training Accuracy: {train_accuracy:.2f}%")
    print(f"Validation Accuracy: {val_accuracy:.2f}%")

    inverse_emotion_idx = {v: k for k, v in cfg.data.emotion_idx.items()}
    _print_label_distribution("Training Label Distribution:", train_label_counts, inverse_emotion_idx)
    _print_label_distribution("Validation Label Distribution:", val_label_counts, inverse_emotion_idx)

main()


Using device: cuda
Dataset contains 8 files.
Spatial dimensions: (132, 175, 48)
Maximum timepoints per file: 542
Subjects: ['sub-04']
Sessions: ['01' '02' '03' '04' '05' '06' '07' '08']
Emotion categories: ['NONE', 'HAPPINESS', 'FEAR', 'SADNESS', 'LOVE', 'ANGER']
Total valid labeled timepoints: 799
Loaded Observations: 799


  model.load_state_dict(torch.load(model_path_torch, map_location=device))


Loaded model from /home/paperspace/DeepEmotion/src/output/models/k94ke7h1.pth


Evaluating: 100%|██████████| 32/32 [00:14<00:00,  2.15it/s]
Evaluating: 100%|██████████| 8/8 [00:03<00:00,  2.12it/s]

Training Accuracy: 99.53%
Validation Accuracy: 88.12%
Training Label Distribution:
HAPPINESS: 157
FEAR: 140
SADNESS: 194
LOVE: 86
ANGER: 62
Validation Label Distribution:
HAPPINESS: 49
FEAR: 30
SADNESS: 44
LOVE: 19
ANGER: 18



