# Mid-Term Report Experiments

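This notebook contains the mid-term experiments for the SAND challenge dataset. It consists of the following sections: 1) Setup and imports, 2) Config, 3) Dataset (data augmentation, data loaders, balanced batch sampling), 4) Data visualization, 5) Model architecture, 6) Training and validation, 7) Weights & Biases setup, 8) Experiment (training loop), and 9) Testing.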

## Setup

In [None]:
!nvidia-smi

In [None]:
!pip install wandb --quiet # Install WandB
!pip install pytorch_metric_learning --quiet # Install the Pytorch Metric Library
!pip install torchinfo --quiet # Install torchinfo

## Imports

In [None]:
import torch
from torchsummary import summary
import torchvision
import torchaudio
from torchvision.utils import make_grid
from torchvision import transforms
import torchvision.transforms.v2 as T
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import os
import gc
import random
from pathlib import Path
# from tqdm import tqdm
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn import metrics as mt
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import glob
import wandb
import matplotlib.pyplot as plt
from pytorch_metric_learning import samplers
import csv
import warnings
warnings.filterwarnings("ignore")

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
print("Device: ", DEVICE)


## Config

In [None]:
config = {
    ## Problem Configs
    "subset" : 1,
    "task_num" : 1,
    "fs" : 8000, # in Hz
    "max_len" : 5, # in seconds
    ## Data Configs
    "train_dir" : "/content/drive/MyDrive/Fall2025/11685/SAND_Challenge/Dataset/training_split_balanced/train",
    "val_dir" :  "/content/drive/MyDrive/Fall2025/11685/SAND_Challenge/Dataset/training_split_balanced/val",
    "test_dir" :  "/content/drive/MyDrive/Fall2025/11685/SAND_Challenge/Dataset/test",
    ## Model Configs
    "model" : "ViT_baseline",
    ## Training Configs
    'batch_size': 16, # Increase this if your GPU can handle it
    'lr': 1e-6,
    "weight_decay" : 1e-4,
    'epochs': 50,
    'num_classes': 5,
    'checkpoint_dir': "/content/drive/MyDrive/Fall2025/11685/SAND_Challenge/checkpoints",
    'augument': True,  # (sic) key name is read elsewhere in the notebook; do not rename
    'ablation_ID': 22
}

# Seed every RNG the notebook draws from so runs are repeatable:
#   - Python `random` for dataset sub-sampling and the balanced batch sampler,
#   - torch for weight initialisation and the random spectrogram masks,
#   - numpy for completeness.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Mount drive
# Mount Google Drive (Colab only) so the dataset and checkpoint paths under /content/drive in `config` resolve.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Task sub-folder names, in the order their clips are stacked for each subject:
# five "phonation<vowel>" folders followed by three "rhythm<syllable>" folders.
phonation_list = (
    [f"phonation{vowel}" for vowel in "AEIOU"]
    + [f"rhythm{syllable}" for syllable in ("KA", "PA", "TA")]
)

In [None]:
!pwd

## Dataset

### Data Augmentation

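TODO: replace this with speech-appropriate augmentation. The current pipeline converts each waveform to a log-Mel spectrogram and, when augmentation is enabled, applies random frequency and time masking to it.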

In [None]:
import torchaudio.transforms as AT

def create_audio_transforms(sample_rate: int = 8000, n_mels: int = 64, augment: bool = True,
                            n_fft: int = 1024, hop_length: int = 512,
                            freq_mask_param: int = 8, time_mask_param: int = 20) -> torch.nn.Sequential:
    """
    Build the waveform -> log-Mel spectrogram pipeline, optionally followed by
    SpecAugment-style frequency and time masking.

    Args:
        sample_rate (int): Sample rate (Hz) of the incoming waveforms.
        n_mels (int): Number of Mel filterbanks (height of the output spectrogram).
        augment (bool): If True, append random frequency and time masking.
        n_fft (int): FFT size of the underlying STFT.
        hop_length (int): Hop, in samples, between successive STFT frames.
        freq_mask_param (int): Maximum length of a frequency mask, in Mel bins.
        time_mask_param (int): Maximum length of a time mask, in frames.

    Returns:
        torch.nn.Sequential: Modules applied in order: MelSpectrogram, AmplitudeToDB,
        then (if ``augment``) FrequencyMasking and TimeMasking.
    """
    transform_list = [
        # Step 1: waveform -> (power) Mel spectrogram
        AT.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
        ),
        # Step 2: log-scale to compress the dynamic range
        AT.AmplitudeToDB(),
    ]

    # Step 3 (optional): random masking as data augmentation
    if augment:
        transform_list += [
            AT.FrequencyMasking(freq_mask_param=freq_mask_param),
            AT.TimeMasking(time_mask_param=time_mask_param),
        ]

    return torch.nn.Sequential(*transform_list)


### Task 1 Data Loaders

In [None]:
class AudioMultiPhonationDataset(Dataset):
    """
    Labelled dataset: one item per subject that has a recording in every phonation folder.

    Each item is ``(audio_tensors, label)``:
      - ``audio_tensors``: float32 tensor stacking one clip per entry of ``phonation_types``
        (in that order). Every clip is mono, resampled to ``config["fs"]``, truncated or
        zero-padded to ``config["max_len"]`` seconds, then passed through ``transform``.
      - ``label``: Python int, the metadata ``Class`` value minus 1 (0-indexed).
    """

    def __init__(self, root_dir, phonation_types, transform=None, metadata_file="sand_task_1.xlsx"):
        """
        Args:
            root_dir (str): Directory holding one sub-folder per phonation plus the metadata file.
            phonation_types (list): Phonation/rhythm folder names (e.g. ["phonationA", "phonationE", ...]).
            transform (callable, optional): Applied to each fixed-length waveform (e.g. waveform -> spectrogram).
            metadata_file (str): Excel file (inside ``root_dir``) with ``ID`` and ``Class`` columns.
        """
        self.root_dir = root_dir
        self.phonation_types = phonation_types
        self.transform = transform

        # Load metadata from Excel file
        metadata_path = os.path.join(root_dir, metadata_file)
        self.metadata_df = pd.read_excel(metadata_path)

        # Key labels by the *string* form of the ID: subject IDs are recovered from file
        # names (strings), so int-typed IDs in the spreadsheet would otherwise never match.
        self.id_to_label = {
            str(id_): cls for id_, cls in zip(self.metadata_df['ID'], self.metadata_df['Class'])
        }

        # Subjects that have audio in every phonation folder and a metadata entry
        self.subject_ids = self._collect_valid_subject_ids()

    def _collect_valid_subject_ids(self):
        """
        Return the sorted subject IDs present in every phonation folder and in the metadata.

        The ID of a file is the part of its name before the first underscore.
        """
        if not self.phonation_types:
            raise ValueError("phonation_types must contain at least one folder name")

        ids_per_folder = []
        for phonation in self.phonation_types:
            files = os.listdir(os.path.join(self.root_dir, phonation))
            ids_per_folder.append({f.split('_')[0] for f in files if f.endswith('.wav')})

        valid_ids = set.intersection(*ids_per_folder)
        valid_ids &= set(self.id_to_label)
        return sorted(valid_ids)

    def _load_clip(self, filepath, target_length):
        """Load one .wav as a mono waveform at config["fs"], cut/zero-padded to target_length samples, then transformed."""
        waveform, sample_rate = torchaudio.load(filepath)

        # Downmix multi-channel recordings so every clip has exactly one channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # target_length and the Mel transform both assume config["fs"]; resample mismatching files.
        if sample_rate != config["fs"]:
            waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=config["fs"])

        num_samples = waveform.shape[1]
        if num_samples > target_length:
            waveform = waveform[:, :target_length]
        elif num_samples < target_length:
            waveform = torch.nn.functional.pad(waveform, (0, target_length - num_samples))

        if self.transform:
            waveform = self.transform(waveform)
        return waveform

    def __len__(self):
        return len(self.subject_ids)

    def __getitem__(self, idx):
        subject_id = self.subject_ids[idx]
        target_length = config["max_len"] * config["fs"]

        clips = [
            self._load_clip(os.path.join(self.root_dir, phonation, f"{subject_id}_{phonation}.wav"), target_length)
            for phonation in self.phonation_types
        ]

        # final tensor shape: (P, 1, F, T) when transform yields (1, F, T)
        audio_tensors = torch.stack(clips).to(dtype=torch.float32)

        # scalar label (IMPORTANT: python int first so default collate builds a (B,) tensor)
        label = int(self.id_to_label[subject_id] - 1)

        return audio_tensors, label


In [None]:
class AudioMultiPhonationDataset_Test(Dataset):
    """
    Unlabelled test-split dataset: one item per subject with a recording in every phonation folder.

    Each item is a single float32 tensor stacking one clip per entry of ``phonation_types``
    (in that order). Every clip is mono, resampled to ``config["fs"]``, truncated or
    zero-padded to ``config["max_len"]`` seconds, then passed through ``transform``.
    No class label is returned.
    """

    def __init__(self, root_dir, phonation_types, transform=None, metadata_file="sand_task1_test.xlsx"):
        """
        Args:
            root_dir (str): Directory holding one sub-folder per phonation plus the metadata file.
            phonation_types (list): Phonation/rhythm folder names (e.g. ["phonationA", "phonationE", ...]).
            transform (callable, optional): Applied to each fixed-length waveform (e.g. waveform -> spectrogram).
            metadata_file (str): Excel file (inside ``root_dir``) with an ``ID`` column.
        """
        self.root_dir = root_dir
        self.phonation_types = phonation_types
        self.transform = transform

        # Load metadata from Excel file
        metadata_path = os.path.join(root_dir, metadata_file)
        self.metadata_df = pd.read_excel(metadata_path)

        # Test labels are not available yet, so no ID -> label mapping is built.
        self.subject_ids = self._collect_valid_subject_ids()

    @staticmethod
    def _normalize_id(raw_id):
        """Map a metadata ID to the file-name form: keep "ID..." strings, turn e.g. 7 into "ID007"."""
        text = str(raw_id).strip()
        return text if text.startswith("ID") else "ID" + text.zfill(3)

    def _collect_valid_subject_ids(self):
        """
        Return the sorted subject IDs present in every phonation folder and in the metadata ``ID`` column.

        The ID of a file is the part of its name before the first underscore.
        """
        if not self.phonation_types:
            raise ValueError("phonation_types must contain at least one folder name")

        ids_per_folder = []
        for phonation in self.phonation_types:
            files = os.listdir(os.path.join(self.root_dir, phonation))
            ids_per_folder.append({f.split('_')[0] for f in files if f.endswith('.wav')})

        valid_ids = set.intersection(*ids_per_folder)

        # Filter by metadata ID column only (no class filtering). Normalising avoids
        # producing "IDID001" when the spreadsheet already stores prefixed string IDs.
        metadata_ids = {self._normalize_id(id_) for id_ in self.metadata_df['ID']}
        return sorted(valid_ids & metadata_ids)

    def _load_clip(self, filepath, target_length):
        """Load one .wav as a mono waveform at config["fs"], cut/zero-padded to target_length samples, then transformed."""
        waveform, sample_rate = torchaudio.load(filepath)

        # Downmix multi-channel recordings so every clip has exactly one channel.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # target_length and the Mel transform both assume config["fs"]; resample mismatching files.
        if sample_rate != config["fs"]:
            waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=config["fs"])

        num_samples = waveform.shape[1]
        if num_samples > target_length:
            waveform = waveform[:, :target_length]  # Truncate
        elif num_samples < target_length:
            waveform = torch.nn.functional.pad(waveform, (0, target_length - num_samples))  # Pad at the end

        if self.transform:
            waveform = self.transform(waveform)
        return waveform

    def __len__(self):
        return len(self.subject_ids)

    def __getitem__(self, idx):
        """
        Returns:
            torch.Tensor: stacked clips, shape (P, 1, F, T) when transform yields (1, F, T).
        """
        subject_id = self.subject_ids[idx]
        target_length = config["max_len"] * config["fs"]

        clips = [
            self._load_clip(os.path.join(self.root_dir, phonation, f"{subject_id}_{phonation}.wav"), target_length)
            for phonation in self.phonation_types
        ]
        return torch.stack(clips).to(dtype=torch.float32)


In [None]:
# Training clips get the random masking augmentation when config['augument'] is set;
# validation and test each get their own instance of the deterministic pipeline.
train_transforms = create_audio_transforms(augment=config['augument'])
val_transforms, test_transforms = (create_audio_transforms(augment=False) for _ in range(2))

In [None]:
!pip install openpyxl

In [None]:
from torch.utils.data import Subset
import random

# subset function
def subset_function(dataset, subset_selection):
    """
    Optionally down-sample a dataset for quick experiments.

    Args:
        dataset: any sized, indexable dataset.
        subset_selection: fraction of samples to keep; anything that is not below 1.0
            returns ``dataset`` itself.

    Returns:
        A ``Subset`` over indices drawn without replacement via the global ``random``
        module, or the original ``dataset``.
    """
    if not subset_selection < 1.0:
        print(f"Using full dataset ({len(dataset)} samples)")
        return dataset

    subset_size = int(len(dataset) * subset_selection)
    chosen = random.sample(range(len(dataset)), subset_size)
    print(f"Using subset of {subset_size}/{len(dataset)} samples ({subset_selection*100:.0f}%)")
    return Subset(dataset, chosen)

In [None]:
# Train/val may be sub-sampled via config['subset'] for quick experiments.
train_dataset = AudioMultiPhonationDataset(
    root_dir=config["train_dir"],
    phonation_types=phonation_list,
    transform=train_transforms
)
train_dataset = subset_function(train_dataset, config['subset'])

val_dataset = AudioMultiPhonationDataset(
    root_dir=config["val_dir"],
    phonation_types=phonation_list,
    transform=val_transforms
)
val_dataset = subset_function(val_dataset, config['subset'])

# The test split is never sub-sampled: a prediction is needed for every test subject.
test_dataset = AudioMultiPhonationDataset_Test(
    root_dir=config["test_dir"],
    phonation_types=phonation_list,
    transform=test_transforms
)
print(f"Test dataset: {len(test_dataset)} samples (never sub-sampled)")

# Balanced Batch Sampler

In [None]:
# Disabled alternative: per-sample inverse-frequency weights fed to a WeightedRandomSampler.
# Kept for reference only; the training loader uses BalancedBatchSampler instead.
'''
# get labels for each id (0 indexed)
train_labels = torch.tensor([train_dataset.id_to_label[id] - 1 for id in train_dataset.subject_ids])

# get number of instances per class, create waits accordingly
class_counts = torch.bincount(train_labels)
class_weights = 1.0 / class_counts.float()

# assign weights for each sample according to class weights
sample_weights = class_weights[train_labels]

# create sampler for train loader using our weights
from torch.utils.data import WeightedRandomSampler
train_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)
'''

In [None]:
import random
from torch.utils.data import Sampler

# creating a balanced sampler so batches have roughly the same distribution

class BalancedBatchSampler(Sampler):
    """
    Batch sampler whose batches contain (nearly) the same number of samples from every class.

    Args:
        labels: per-sample integer class labels (list or 1-D tensor), aligned with dataset indices.
        batch_size: number of indices per yielded batch.
        num_classes: labels are expected in ``range(num_classes)``; other labels are never drawn.
        max_batches_per_epoch: batches per epoch; defaults to ``len(labels) // batch_size``
            (at least 1).

    Each class's indices are shuffled and consumed cyclically, reshuffled after every full
    pass, so minority classes are oversampled without repeating one fixed order. Classes
    with no samples are skipped and their share of each batch goes to the other classes.

    Raises:
        ValueError: if no label falls in ``range(num_classes)``.
    """

    def __init__(self, labels, batch_size=16, num_classes=5, max_batches_per_epoch=None):
        self.labels = labels
        self.batch_size = batch_size
        self.num_classes = num_classes

        # int() so tensor elements and Python ints compare the same way
        label_list = [int(lbl) for lbl in labels]

        # group indices by class, each pool in random order
        self.class_to_indices = {
            c: [i for i, lbl in enumerate(label_list) if lbl == c] for c in range(num_classes)
        }
        for pool in self.class_to_indices.values():
            random.shuffle(pool)
        self.class_pos = {c: 0 for c in range(num_classes)}

        # only classes that actually have samples can contribute (avoids modulo-by-zero)
        self.active_classes = [c for c in range(num_classes) if self.class_to_indices[c]]
        if not self.active_classes:
            raise ValueError("BalancedBatchSampler needs at least one label in range(num_classes)")

        # per-batch target counts
        self.base_count = batch_size // len(self.active_classes)
        self.remainder = batch_size % len(self.active_classes)

        # default: one pass worth of batches, but never zero (an empty epoch breaks averaging downstream)
        natural_batches = max(1, len(label_list) // batch_size)
        self.max_batches = max_batches_per_epoch if max_batches_per_epoch is not None else natural_batches

    def _next_index(self, c):
        """Return the next dataset index for class ``c``, reshuffling its pool at the start of each new pass."""
        pool = self.class_to_indices[c]
        pos = self.class_pos[c] % len(pool)
        if pos == 0 and self.class_pos[c] > 0:
            random.shuffle(pool)
        self.class_pos[c] += 1
        return pool[pos]

    def __iter__(self):
        for _ in range(self.max_batches):
            # classes that receive one extra sample when batch_size doesn't divide evenly
            extra_classes = set(random.sample(self.active_classes, self.remainder))

            batch_indices = []
            for c in self.active_classes:
                count = self.base_count + (1 if c in extra_classes else 0)
                batch_indices.extend(self._next_index(c) for _ in range(count))

            # mix classes within the batch
            random.shuffle(batch_indices)
            yield batch_indices

    def __len__(self):
        return self.max_batches



# Dataloaders

In [None]:
# Recover the subject IDs actually in train_dataset: subset_function may have wrapped it in a Subset,
# whose positions map to base-dataset indices via `.indices`.
if isinstance(train_dataset, Subset):
    _base_train = train_dataset.dataset
    _train_ids = [_base_train.subject_ids[i] for i in train_dataset.indices]
else:
    _base_train = train_dataset
    _train_ids = _base_train.subject_ids

# 1) 0-indexed label for every training sample, in train_dataset order
train_labels = torch.tensor([_base_train.id_to_label[sid] - 1 for sid in _train_ids])

# 2) count appearances of each class; minlength keeps position == class index even if a class is absent
class_sample_count = torch.bincount(train_labels, minlength=config['num_classes'])

# 3) inverse frequency = weight per class (inf for absent classes, which no sample indexes)
class_weights = 1.0 / class_sample_count.float()

# 4) assign to each sample the weight of its class (currently informational; the sampler below balances batches)
sample_weights = class_weights[train_labels]

# create sampler instance for training dataloader
train_sampler = BalancedBatchSampler(train_labels, batch_size=config["batch_size"], num_classes=config['num_classes'])

# pass custom sampler to dataloader
train_loader = DataLoader(train_dataset, num_workers=4, pin_memory=True, batch_sampler=train_sampler)

val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False,  num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False,  num_workers=4, pin_memory=True)

# Inspect Distribution of Batches

In [None]:
from collections import Counter

# check the distributions of training batches and val batches
def inspect_batches(dataloader, num_batches=5):
    """
    Print the per-class label counts of the first ``num_batches`` batches of ``dataloader``.

    Args:
        dataloader: iterable of ``(inputs, labels)`` pairs; ``labels`` is a tensor of class
            indices of any shape (it is flattened before counting).
        num_batches: number of batches to report before stopping.
    """
    for batch_num, (_, labels) in enumerate(dataloader, start=1):
        # reshape(-1) rather than squeeze(): a single-sample (1, 1) batch must stay a list, not a bare int
        counts = Counter(labels.reshape(-1).tolist())
        sorted_counts = dict(sorted(counts.items()))
        print(f"Batch {batch_num}: {sorted_counts}")

        # stopping point for reporting batch distributions
        if batch_num == num_batches:
            break

# Training batches come from BalancedBatchSampler and should have near-uniform class counts;
# validation batches are drawn in order and are not expected to.
for header, loader in (('training batches:', train_loader), ('\nval batches', val_loader)):
    print(header)
    inspect_batches(loader, num_batches=5)


## Data Visualization

Show a sample batch and one of its spectrograms to sanity-check the data pipeline. TODO: add audio playback to listen to the samples.

In [None]:
import matplotlib.pyplot as plt

# Sanity-check the data loaders on one validation batch.
# Note: val_loader applies no augmentation; use train_loader to inspect the masking.
frames, labels = next(iter(val_loader))
print(frames.shape, labels.shape)
print(labels)

# Plot the first phonation of the first sample. The log-Mel spectrogram is stored as
# (n_mels, time), so it is drawn untransposed: time on the x-axis, Mel bins on the y-axis.
fig, ax = plt.subplots(figsize=(10, 6))
ax.imshow(frames[0, 0, 0].numpy(), aspect='auto', origin='lower', cmap='viridis')
ax.set_xlabel('Time')
ax.set_ylabel('Features')
ax.set_title('Feature Representation')
plt.show()

## Model Architecture

In [None]:
# Infer the spectrogram width (number of time frames) from one training example;
# ViTNetwork below uses it as its default `time_steps`.
if len(train_dataset) == 0:
    raise RuntimeError("train_dataset is empty - check config['train_dir'] and the phonation folders")
x, y = train_dataset[0]
width = x.shape[-1]


import torch
import torch.nn as nn
from torchvision.models.vision_transformer import VisionTransformer, vit_b_16

class ViTNetwork(nn.Module):
    """
    Subject-level classifier over a stack of per-phonation spectrograms.

    Every spectrogram is encoded independently by a timm "vit_base_patch16_224" backbone
    (randomly initialised, classification head removed); the per-phonation embeddings are
    pooled (mean or LSTM) and fed to a dropout + linear classifier.
    """

    def __init__(self, num_classes, num_mels=64, time_steps=width,  # time_steps: width of spectrogram
                 embed_dim=768, num_heads=12, depth=6, patch_size=16,
                 aggregate="mean"):  # 'mean' or 'lstm'
        """
        Args:
            num_classes: number of output classes.
            num_mels: number of Mel bins (spectrogram height).
            time_steps: number of time frames per spectrogram (width).
            embed_dim, num_heads, depth, patch_size: NOTE(review) not forwarded to the backbone,
                which always uses timm's "vit_base_patch16_224" configuration. The feature size
                used by the aggregator and classifier is read from the backbone itself.
            aggregate: how to combine multiple phonations ('mean' or 'lstm').

        Raises:
            ValueError: if ``aggregate`` is not 'mean' or 'lstm'.
        """
        super().__init__()
        # Local import so the class does not depend on a later cell having imported timm.
        import timm

        if aggregate not in ("mean", "lstm"):
            raise ValueError("aggregate must be 'mean' or 'lstm'")

        # --- Vision Transformer encoder (num_classes=0 -> pooled embeddings, no head) ---
        self.vit = timm.create_model(
            "vit_base_patch16_224",
            pretrained=False,
            num_classes=0,
            img_size=(num_mels, time_steps)
        )
        feature_dim = self.vit.num_features

        # --- Aggregation of multiple phonations (P) ---
        self.aggregate = aggregate
        if aggregate == "lstm":
            self.lstm = nn.LSTM(input_size=feature_dim,
                                hidden_size=feature_dim,
                                num_layers=1,
                                batch_first=True)
        else:
            self.lstm = None

        # --- Final classifier ---
        self.cls_layer = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(feature_dim, num_classes)
        )

    def forward(self, x):
        """
        Args:
            x: tensor of shape (B, P, C, Mels, Time); B = batch size, P = phonations per participant.

        Returns:
            (B, num_classes) logits.
        """
        B, P, C, M, T = x.shape
        # reshape (not view) also accepts non-contiguous inputs
        x = x.reshape(B * P, C, M, T)

        # The backbone is built without `in_chans`, so it expects 3 channels; tile mono input.
        if C == 1:
            x = x.repeat(1, 3, 1, 1)

        feats = self.vit(x).reshape(B, P, -1)  # (B, P, feature_dim)

        # Aggregate across phonations
        if self.aggregate == "mean":
            pooled = feats.mean(dim=1)
        else:
            _, (h_n, _) = self.lstm(feats)
            pooled = h_n[-1]

        return self.cls_layer(pooled)

In [None]:
import timm
# Build the classifier with the configured number of classes, then move it to the selected device.
model = ViTNetwork(num_classes=config['num_classes'])
model = model.to(DEVICE)

In [None]:
# Smoke test: a random batch shaped (B=2, P=4, C=1, Mels=64, T=width) must map to (2, num_classes) logits.
x = torch.randn(2, 4, 1, 64, width).to(DEVICE)  # (B, P, 1, M, T)
model.eval()
with torch.no_grad():
    out = model(x)

print("Logits shape:", out.shape)
assert tuple(out.shape) == (2, config['num_classes']), f"unexpected logits shape {tuple(out.shape)}"

In [None]:
# --------------------------------------------------- #

# Loss: cross-entropy over the 0-indexed class labels produced by the datasets
criterion = torch.nn.CrossEntropyLoss()

# --------------------------------------------------- #

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])

# --------------------------------------------------- #

# Scheduler: the training loop steps it with validation macro-F1, hence mode='max';
# LR is multiplied by 0.8 after 6 epochs without improvement, floored at 1e-7.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.8, patience=6, min_lr=1e-7)

# --------------------------------------------------- #

# Mixed-precision loss scaling (useful on fp16-capable GPUs such as T4/V100).
# Enabled only on CUDA, where it is meaningful.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == 'cuda'))

## Training and Validation

In [None]:
def train(model, dataloader, optimizer, criterion):
    """
    Train ``model`` for one pass over ``dataloader``.

    On CUDA the forward pass runs under fp16 autocast and gradients go through the
    module-level ``scaler`` (GradScaler); inputs are moved to the module-level ``DEVICE``.

    Args:
        model: network mapping a batch of stacked spectrograms to class logits.
        dataloader: yields ``(audio_tensors, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(logits, labels)``.

    Returns:
        (mean batch loss, mean batch accuracy in [0, 1]); both 0.0 if no batch was produced.
    """
    model.train()
    tloss, tacc = 0.0, 0.0  # running sums of per-batch loss and accuracy
    num_seen = 0
    # Size the bar from the loader actually passed in, not a global.
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, leave=False, position=0, desc='Train')
    # fp16 autocast is intended for CUDA; other devices run in full precision.
    use_amp = DEVICE == 'cuda'

    for audio_tensors, labels in dataloader:

        ### Initialize Gradients
        optimizer.zero_grad()

        audio_tensors = audio_tensors.to(DEVICE)
        labels = labels.to(DEVICE)

        with torch.autocast(device_type=DEVICE, dtype=torch.float16, enabled=use_amp):
            logits = model(audio_tensors)
            loss = criterion(logits, labels)

        ### Scaled backward pass and optimizer step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        num_seen += 1
        tloss += loss.item()
        tacc += torch.sum(torch.argmax(logits, dim=1) == labels).item() / logits.shape[0]

        batch_bar.set_postfix(loss="{:.04f}".format(float(tloss / num_seen)),
                              acc="{:.04f}%".format(float(tacc * 100 / num_seen)))
        batch_bar.update()

        ### Release memory
        del audio_tensors, labels, logits
        torch.cuda.empty_cache()

    batch_bar.close()

    # Average over the batches actually processed (guards against an empty loader)
    if num_seen:
        tloss /= num_seen
        tacc /= num_seen

    return tloss, tacc

In [None]:
!pip install torchmetrics
from torchmetrics import F1Score

# Macro-averaged multiclass F1 over config['num_classes'] classes, on DEVICE to match the predictions it receives.
f1 = F1Score(num_classes=config['num_classes'], average='macro', task='multiclass').to(DEVICE)


In [None]:
def eval(model, dataloader):
    """
    Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Uses the module-level ``criterion``, ``f1`` (torchmetrics macro-F1) and ``DEVICE``.
    NOTE: the name shadows Python's builtin ``eval``; kept for the existing call sites.

    Args:
        model: network mapping a batch of stacked spectrograms to class logits.
        dataloader: yields ``(audio_tensors, labels)`` batches.

    Returns:
        (mean batch loss, mean batch accuracy in [0, 1], macro-F1 over all samples as a float)

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()  # set model in evaluation mode
    vloss, vacc = 0.0, 0.0  # running sums of per-batch loss and accuracy
    # Size the bar from the loader actually passed in, not a global.
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc='Val')

    all_preds = []
    all_labels = []

    for audio_tensors, labels in dataloader:

        audio_tensors = audio_tensors.to(DEVICE)
        labels = labels.to(DEVICE)

        # no autograd bookkeeping during evaluation
        with torch.inference_mode():
            logits = model(audio_tensors)
            loss = criterion(logits, labels)
            preds = torch.argmax(logits, dim=1)

        vloss += loss.item()
        vacc += torch.sum(preds == labels).item() / logits.shape[0]

        # accumulate predictions and labels for a dataset-level F1
        all_preds.append(preds)
        all_labels.append(labels)

        num_seen = len(all_preds)
        batch_bar.set_postfix(loss="{:.04f}".format(float(vloss / num_seen)),
                              acc="{:.04f}%".format(float(vacc * 100 / num_seen)))
        batch_bar.update()

        ### Release memory
        del audio_tensors, labels, logits, preds
        torch.cuda.empty_cache()

    batch_bar.close()

    if not all_preds:
        raise ValueError("eval() received a dataloader that yielded no batches")

    vloss /= len(all_preds)
    vacc /= len(all_preds)

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Calling the metric returns F1 for exactly these inputs; reset so no state carries into the next call.
    vf1 = f1(all_preds, all_labels).item()
    f1.reset()

    del all_preds, all_labels
    torch.cuda.empty_cache()

    return vloss, vacc, vf1

## Weights and Biases Setup

In [None]:
# Authenticate with Weights & Biases. No key is hard-coded here: wandb falls back to the
# WANDB_API_KEY environment variable or an interactive prompt.
wandb.login()

In [None]:
# Create the wandb run.
# To continue an earlier run, set RESUME_OLD_RUN = True and WANDB_RUN_ID to that run's id.
RESUME_OLD_RUN = False
WANDB_RUN_ID = None
# Use the same project for new and resumed runs so a resumed run is looked up where it was created.
WANDB_PROJECT = os.environ.get("WANDB_PROJECT", "SAND")

if RESUME_OLD_RUN:
    if WANDB_RUN_ID is None:
        raise ValueError("Set WANDB_RUN_ID to the id of the run to resume (resume='must' requires it).")
    print("Resuming previous WanDB run...")
    run = wandb.init(
        name    = f"ablation {config['ablation_ID']}",
        id      = WANDB_RUN_ID,
        resume  = "must",
        project = WANDB_PROJECT,
        config  = config
    )
else:
    print("Initializing new WanDB run...")
    run = wandb.init(
        name    = f"ablation {config['ablation_ID']}",
        reinit  = True,
        project = WANDB_PROJECT,
        config  = config
    )

In [None]:
# Record the model architecture as text and attach the file to the wandb run.
model_arch = str(model)

# `with` guarantees the file is closed (and flushed) before wandb picks it up.
with open("model_arch.txt", "w") as arch_file:
    arch_file.write(model_arch)

wandb.save('model_arch.txt')

# Experiment

In [None]:
# Train for config['epochs'] epochs, validating after each one; keep the best checkpoint by val macro-F1.
torch.cuda.empty_cache()
gc.collect()

best_val_f1 = 0.0
# Save under config['checkpoint_dir'] (on Drive) so checkpoints outlive the Colab runtime;
# each ablation gets its own sub-folder.
run_name = os.path.join(config['checkpoint_dir'], f"ablation_{config['ablation_ID']}")
os.makedirs(run_name, exist_ok=True)

for epoch in range(config['epochs']):

    print("\nEpoch {}/{}".format(epoch + 1, config['epochs']))

    curr_lr = float(optimizer.param_groups[0]['lr'])
    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    val_loss, val_acc, val_f1 = eval(model, val_loader)
    # The F1 may arrive as a 0-d tensor; a plain float prints, logs and compares cleanly.
    val_f1 = float(val_f1)

    print("\tTrain Acc {:.04f}%\tTrain Loss {:.04f}\t Learning Rate {:.07f}".format(train_acc * 100, train_loss, curr_lr))
    print("\tVal Acc {:.04f}%\tVal Loss {:.04f}".format(val_acc * 100, val_loss))
    print(f'F1 Score: {val_f1:.4f}')

    ## Log metrics at each epoch in your run
    wandb.log({'train_acc': train_acc * 100, 'train_loss': train_loss,
               'val_acc': val_acc * 100, 'valid_loss': val_loss, 'lr': curr_lr, 'valid_F1': val_f1})

    # step scheduler using F1 (scheduler was created with mode='max')
    scheduler.step(val_f1)

    # save model if f1 is best we've seen
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1

        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_f1': best_val_f1,
            'train_acc': train_acc,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'config': config
        }

        checkpoint_path = f'best_model_epoch_{epoch + 1}.pth'
        full_path = os.path.join(run_name, checkpoint_path)
        torch.save(checkpoint, full_path)

        print(f"Saved new best model with Val f1: {val_f1:.4f} at {full_path}")

wandb.finish()

## Testing

In [None]:
# Disabled: inference helper that returns the argmax class prediction for every test-loader
# batch, concatenated into one tensor. Re-enable by removing the surrounding triple quotes.
# NOTE(review): `sample_predictions` is computed but never used.
'''
def test(model, test_loader):
    ### What you call for model to perform inference?
    model.eval() # TODO train or eval?

    all_preds = []

    ### Which mode do you need to avoid gradients?
    with torch.no_grad(): # TODO

        for i, audio_tensors in enumerate(tqdm(test_loader)):

            audio_tensors   = audio_tensors.to(DEVICE)

            logits  = model(audio_tensors)

            preds = torch.argmax(logits, dim=1)

            all_preds.append(preds)

    all_preds = torch.cat(all_preds)

    ## SANITY CHECK
    sample_predictions = all_preds[:10]

    # Print a preview of predictions for manual inspection
    print("\n length of predictions:", len(all_preds))
    print("\nPredictions Generated successfully!")

    return all_preds
'''

In [None]:
# Disabled: runs test() and tries to collect ground-truth labels from test_loader.
# NOTE(review): AudioMultiPhonationDataset_Test.__getitem__ returns only the audio tensor
# (no label), so the `for audio_tensors, labels in test_loader` loop below will not yield
# labels if re-enabled; ground truth would have to come from another source.
'''
# Generate model test predictions
predictions = test(model, test_loader)

all_labels = []
for audio_tensors, labels in test_loader:
    all_labels.append(labels)

all_labels = torch.cat(all_labels)

print(f'length of ground truths: {len(all_labels)}')

#from sklearn.metrics import f1_score
#vf1 = f1_score(all_labels.cpu().numpy(), predictions.cpu().numpy(), average='macro')  # or 'weighted' if class imbalance

#print(f'f1 score on val dataset: {vf1}')
'''

In [None]:
# Disabled: would write predictions to submission.csv.
# NOTE(review): the `id` column is written as the row index `i`, not the subject ID from
# test_dataset.subject_ids — confirm the format the challenge expects before re-enabling.
'''
# Create CSV file with predictions

#with open("./submission.csv", "w+") as f:
#    f.write("id,label\n")
#    for i in range(len(predictions)):
#        f.write("{},{}\n".format(i, predictions[i]))
#
#    print("submission.csv file created successfully!")
'''