# Running Notebook

In [None]:
import torch

from scripts.models import CNN
from scripts.trainer import Trainer

from scripts.utils.data_loading import make_loaders

from scripts.config.experiments import EXPERIMENTS as experiments

In [None]:
# Run every configured experiment back to back. Each iteration builds its own
# data loaders, model, loss and optimizer, so no state is carried over from
# one experiment to the next.
for cfg in experiments:
    print(f"=== Running {cfg['name']} ===")

    # make_loaders returns both loaders plus the sample counts of each split.
    train_loader, test_loader, n_train, n_test = make_loaders(
        cfg["metadata_path"], cfg["batch_size"]
    )
    print(f"Train samples: {n_train}, Test samples: {n_test}")

    num_epochs = cfg["num_epochs"]
    device = cfg["device"]
    # Optional per-experiment learning rate. Falls back to the previous
    # hard-coded 1e-3, so configs without an "lr" key behave exactly as before.
    learning_rate = cfg.get("lr", 1e-3)

    model = CNN(clip_duration=cfg["clip_duration"]).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    trainer = Trainer(
        model,
        optimizer,
        criterion,
        device=device,
        log_dir=cfg["log_dir"],
        checkpoint_dir=cfg["checkpoint_dir"],
        segments_per_clip=cfg["segments_per_clip"],
        clip_duration=cfg["clip_duration"],
    )

    # NOTE(review): presumably fit() writes the per-epoch logs and checkpoints
    # under log_dir / checkpoint_dir -- confirm against scripts.trainer.
    trainer.fit(train_loader, test_loader, num_epochs=num_epochs)
    print(f"Finished {cfg['name']}. Logs: {cfg['log_dir']} | Checkpoints: {cfg['checkpoint_dir']}")
    # Live view while training: tensorboard --logdir runs/logs


=== Running full_30s ===
Train samples: 6392, Test samples: 1598
Epoch 001 | train_loss=2.0867, train_acc=0.13 | test_loss=2.0801, test_acc=0.12 | checkpoint=epoch_001.pt
Epoch 002 | train_loss=2.0800, train_acc=0.12 | test_loss=2.0806, test_acc=0.12 | checkpoint=epoch_002.pt
Epoch 003 | train_loss=2.0800, train_acc=0.13 | test_loss=2.0803, test_acc=0.12 | checkpoint=epoch_003.pt
Epoch 004 | train_loss=2.0799, train_acc=0.13 | test_loss=2.0803, test_acc=0.12 | checkpoint=epoch_004.pt
Epoch 005 | train_loss=2.0800, train_acc=0.12 | test_loss=2.0804, test_acc=0.12 | checkpoint=epoch_005.pt
Finished full_30s. Logs: runs/logs/CNN_exp_full | Checkpoints: runs/checkpoints/CNN_exp_full
=== Running ensemble_3s ===
Train samples: 63920, Test samples: 15980
Epoch 001 | train_loss=1.5077, train_acc=0.45 | test_loss=1.2995, test_acc=0.57 | checkpoint=epoch_001.pt
Epoch 002 | train_loss=1.2691, train_acc=0.55 | test_loss=1.2258, test_acc=0.59 | checkpoint=epoch_002.pt
Epoch 003 | train_loss=1.1715,