In [None]:
!pip install sox==1.4.1
!pip install PySoundFile==0.9.0.post1

In [None]:
import datetime
import os
from contextlib import suppress

import tqdm
import torch
import torchaudio
import torchsummary
import torchmetrics
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import seaborn as sns
import matplotlib.pyplot as plt; plt.style.use('seaborn')
from sklearn import metrics

import utils

In [None]:
def get_writer(name:str):
    """Create a TensorBoard ``SummaryWriter`` for one run.

    Args:
        name: run name; it becomes the last path component of the log
            directory ``logs/fit/esc50/<name>``.

    Returns:
        A new ``SummaryWriter``. The caller owns it and should ``close()`` it
        when the run is finished.
    """
    return SummaryWriter(log_dir=f"logs/fit/esc50/{name}")

@torch.no_grad()
def compute_metrics(y_pred, y_true):
    """Score multiclass predictions with top-1 accuracy, top-5 accuracy and AUROC.

    Both tensors are copied to the CPU before scoring, and no autograd graph is
    recorded (``torch.no_grad``).

    Args:
        y_pred: per-class scores. The class count is read from ``y_pred.shape[1]``,
            so this is expected to be (N, num_classes). The caller in this
            notebook passes softmax probabilities.
        y_true: class targets aligned with ``y_pred`` (presumably integer class
            ids of shape (N,); TODO confirm).

    Returns:
        dict mapping ``'Acc'``, ``'Acc5'`` and ``'AUC'`` to Python floats.
    """
    preds = y_pred.cpu()
    target = y_true.cpu()
    n_classes = preds.shape[1]
    tmf = torchmetrics.functional

    scores = {}
    scores['Acc'] = float(tmf.accuracy(preds, target, top_k=1, num_classes=n_classes))
    # NOTE(review): top-5 assumes more than 5 classes; torchmetrics rejects
    # top_k values that are too large for the class count.
    scores['Acc5'] = float(tmf.accuracy(preds, target, top_k=5, num_classes=n_classes))
    scores['AUC'] = float(tmf.auroc(preds, target, num_classes=n_classes))
    return scores

# Default compute device string: 'cuda' when torch can see a GPU, otherwise 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# 5-fold cross-validated training: each fold id is held out once as the test set
# while the model trains on the remaining folds. Per fold, the checkpoint with the
# best test AUC is saved and its metrics are stored in `results[held_out_fold]`.
# `device` is defined in the helpers cell above.
config = utils.tools.ConfigBase._get_config()
date = datetime.datetime.now().strftime("%Y.%m.%d - %H-%M")

ALL_FOLDS = (1, 2, 3, 4, 5)          # fold ids used to split the dataset
PATIENCE = 25                        # epochs without a test-AUC improvement before stopping
LR, WEIGHT_DECAY = 3e-4, 0.1         # AdamW hyper-parameters
WEIGHTS_DIR = 'output/weights/esc50'
# torch.save does not create missing directories.
os.makedirs(WEIGHTS_DIR, exist_ok=True)


def _make_loader(dataset, train):
    """Build a DataLoader: shuffled with the last partial batch dropped when
    `train` is True, ordered and complete otherwise."""
    return DataLoader(
        dataset,
        batch_size=config['train']['batch_size'],
        num_workers=0,
        drop_last=train,
        shuffle=train,
        collate_fn=dataset.collate_fn,
    )


results = {}
models_names = []
for fold, rm_fold in enumerate(ALL_FOLDS):
    torch.cuda.empty_cache()

    folds = [f for f in ALL_FOLDS if f != rm_fold]
    name = f'Fold {rm_fold} {date}'
    writer = get_writer(name)
    try:
        # --------------------------------------------------------- #
        train_dataset = utils.datasets.ESCDataset(audio_length=5, folds=folds)
        test_dataset = utils.datasets.ESCDataset(audio_length=5, folds=[rm_fold])

        # Cast: meta['y'].max() may be a numpy/pandas scalar; constructors below
        # get a plain Python int.
        num_classes = int(train_dataset.meta['y'].max()) + 1
        train_loader = _make_loader(train_dataset, train=True)
        test_loader = _make_loader(test_dataset, train=False)

        # --------------------------------------------------------- #
        # A single un-shuffled sample, used only to tell the model its input shape.
        x = next(iter(DataLoader(
            train_dataset,
            batch_size=1,
            num_workers=0,
            collate_fn=train_dataset.collate_fn,
        )))[0]

        # Use the detected device instead of hard-coding 'cuda' (crashed on CPU-only hosts).
        model = utils.nn.Model(x.shape, num_classes=num_classes).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

        if fold == 0:
            torchsummary.summary(model)

        trainer = utils.trainers.SupervisedClassification(
            model=model,
            device=device,
            optimizer=optimizer,
            loss_fn=torch.nn.CrossEntropyLoss(),
            accum_iter=config['train']['accum_iter'],
            grad_clip=config['train']['grad_clip'],
            metrics={
                "Acc": torchmetrics.Accuracy(num_classes=num_classes),
                "AUC": torchmetrics.AUROC(num_classes=num_classes),
            }
        )

        # --------------------------------------------------------- #
        checkpoint_path = f'{WEIGHTS_DIR}/{name}.torch'
        wait = 0
        epoch = 0
        best_score = -torch.inf
        best_metrics = None
        # Ctrl-C ends training of the current fold only; execution then
        # records the best metrics seen so far and continues with the next fold.
        with suppress(KeyboardInterrupt):
            while wait < PATIENCE:
                # ----------------------------------------- #
                # Train
                train_loss = trainer.train_epoch(train_loader)
                writer.add_scalar('Acc/train', trainer.metrics['Acc'].compute(), epoch)
                writer.add_scalar('AUC/train', trainer.metrics['AUC'].compute(), epoch)

                # ----------------------------------------- #
                # Test
                test_pred, test_true = trainer.evaluate(test_loader)
                test_pred = test_pred.softmax(dim=1)
                test_metrics = compute_metrics(test_pred, test_true)
                for metric_name, metric_value in test_metrics.items():
                    writer.add_scalar(f'{metric_name}/test', metric_value, epoch)

                # ----------------------------------------- #
                # Early stopping on test AUC; keep the best checkpoint on disk.
                wait, epoch = wait + 1, epoch + 1
                if test_metrics['AUC'] > best_score:
                    torch.save(trainer.checkpoint(), checkpoint_path)
                    best_score = test_metrics['AUC']
                    best_metrics = test_metrics
                    wait = 0
    finally:
        # Flush pending TensorBoard events and release the log file handle,
        # even if training raised.
        writer.close()

    models_names.append(name)
    results[rm_fold] = best_metrics

print(models_names)
print(results)