In [1]:
# Project-local imports. The star import is kept ahead of the third-party imports on purpose:
# any name it re-exports (e.g. np, pd, torch) is then overridden by the explicit imports below.
from modeling.learner import get_preds
from modeling import ASTPretrained
from modeling.dataset import get_loader
from modeling.utils import parse_config
from modeling.models import average_model_weights
from modeling.metrics import *

# Third-party imports. numpy is imported explicitly because np.save is used further down;
# previously `np` was only available if the star import above happened to leak it.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
%matplotlib inline

In [2]:
# Path to the project configuration file, relative to this notebook's working directory.
CONFIG_PATH = "../configs/config.yaml"
# Parsed configuration; it is handed to get_loader() when the data loaders are built.
config = parse_config(CONFIG_PATH)

In [3]:
# Build the loaders for the "valid" and "test" subsets from the parsed config.
# Both are module-level names that evaluate() and the export cells read directly.
valid_dl = get_loader(config, subset="valid")
test_dl = get_loader(config, subset="test")

In [4]:
# Single model instance that every checkpoint is loaded into via load_state_dict().
# n_classes=11 lines up with the 11 thresholds printed at the end of the notebook.
# NOTE(review): download_weights=False presumably skips fetching pretrained weights because
# fine-tuned checkpoints are loaded afterwards — confirm against ASTPretrained.
base_model = ASTPretrained(n_classes=11, download_weights=False)

In [5]:
def evaluate(weight_list: list, base_model: nn.Module, valid_loader=None, test_loader=None,
             device: str = "cuda") -> pd.DataFrame:
    """Score each checkpoint on the validation and test sets.

    For every name in ``weight_list`` the checkpoint ``../weights/<name>.pth`` is loaded
    into ``base_model``, predictions are collected with ``get_preds`` and three metrics
    are computed per split: ``mean_average_precision`` (threshold-free), ``hamming_score``
    (reported as "Acc") and ``mean_f1_score``. The thresholds passed to the last two come
    from ``optimize_accuracy`` on the validation predictions and are reused unchanged on
    the test split, so the "Val Acc"/"Val F1" columns are measured on the same data the
    thresholds were tuned on.

    Args:
        weight_list: Checkpoint names (file stem only, without directory or ``.pth``).
        base_model: Model the checkpoints are loaded into. It is modified in place and
            is left holding the weights of the last checkpoint in ``weight_list``.
        valid_loader: Loader for the validation split. Defaults to the notebook-level
            ``valid_dl`` so existing two-argument calls keep working.
        test_loader: Loader for the test split. Defaults to the notebook-level ``test_dl``.
        device: Device string forwarded to ``get_preds``.

    Returns:
        A DataFrame with one row per checkpoint and the columns
        ``Weight, Val Acc, Val mAP, Val F1, Test Acc, Test mAP, Test F1``.
    """
    # Resolve the loaders at call time (not at definition time) so the function can be
    # defined before the loader cell has run, and so callers can inject other loaders.
    if valid_loader is None:
        valid_loader = valid_dl
    if test_loader is None:
        test_loader = test_dl

    columns = ['Weight', 'Val Acc', 'Val mAP', 'Val F1', 'Test Acc', 'Test mAP', 'Test F1']
    records = []

    for name in weight_list:
        # map_location="cpu" makes loading independent of the device the checkpoint was
        # saved from; load_state_dict copies the values into the model's own parameters.
        state_dict = torch.load(f"../weights/{name}.pth", map_location="cpu")
        base_model.load_state_dict(state_dict)

        # Validation split: tune the thresholds here, then score with them.
        preds, targets = get_preds(valid_loader, base_model, device=device)
        thresholds = optimize_accuracy(preds, targets)
        record = {
            'Weight': name,
            'Val mAP': mean_average_precision(preds, targets),
            'Val Acc': hamming_score(preds, targets, thresholds),
            'Val F1': mean_f1_score(preds, targets, thresholds),
        }

        # Test split: reuse the validation-tuned thresholds, never re-tune on test data.
        preds, targets = get_preds(test_loader, base_model, device=device)
        record['Test mAP'] = mean_average_precision(preds, targets)
        record['Test Acc'] = hamming_score(preds, targets, thresholds)
        record['Test F1'] = mean_f1_score(preds, targets, thresholds)

        records.append(record)

    # Passing `columns` fixes the column order and keeps the header for an empty weight_list.
    return pd.DataFrame(records, columns=columns)

In [6]:
# Checkpoint names (file stems under ../weights/). The order matters: later cells slice
# this list as [:3] for the "*_bce" checkpoints and [3:] for the "*_focal" checkpoints.
weight_list = ["bpm_sync_bce", "pitch_sync_bce", "no_sync2_bce", "bpm_sync_focal", "pitch_sync_focal", "no_sync2_focal"]

In [7]:
# Score every individual checkpoint on the validation and test sets.
# `results` is reassigned by the later evaluation cells, so this table only lives in the output.
results = evaluate(weight_list, base_model)
results

  0%|          | 0/187 [00:00<?, ?it/s]

  0%|          | 0/174 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]

  0%|          | 0/174 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]

  0%|          | 0/174 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]

  0%|          | 0/174 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]

  0%|          | 0/174 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]

  0%|          | 0/174 [00:00<?, ?it/s]

Unnamed: 0,Weight,Val Acc,Val mAP,Val F1,Test Acc,Test mAP,Test F1
0,bpm_sync_bce,0.934672,0.865198,0.764908,0.918412,0.8268,0.69793
1,pitch_sync_bce,0.931986,0.861343,0.757578,0.913226,0.820174,0.681324
2,no_sync2_bce,0.935832,0.871443,0.771809,0.915917,0.834264,0.694079
3,bpm_sync_focal,0.933146,0.852915,0.756612,0.916902,0.823464,0.690036
4,pitch_sync_focal,0.931498,0.860144,0.758313,0.912504,0.821003,0.684958
5,no_sync2_focal,0.935161,0.867765,0.774281,0.916049,0.833635,0.699292


In [8]:
# Average the first three checkpoints in weight_list (the "*_bce" names) into one state.
# NOTE(review): the [:3] slice depends on the ordering of weight_list — keep the two in sync.
weights = [f"../weights/{weight}.pth" for weight in weight_list[:3]]

# average_model_weights receives the checkpoint paths; its result is saved as a new
# checkpoint so that evaluate() can load it by name like any other.
average_weight = average_model_weights(weights)
torch.save(average_weight, "../weights/averaged_weights_bce.pth")

In [9]:
# Score the averaged "*_bce" checkpoint with the same procedure as the individual ones.
results = evaluate(["averaged_weights_bce"], base_model)
results

  0%|          | 0/187 [00:00<?, ?it/s]

  0%|          | 0/174 [00:00<?, ?it/s]

Unnamed: 0,Weight,Val Acc,Val mAP,Val F1,Test Acc,Test mAP,Test F1
0,averaged_weights_bce,0.934489,0.870258,0.763884,0.916114,0.830385,0.686354


In [10]:
# Average the last three checkpoints in weight_list (the "*_focal" names) into one state.
# NOTE(review): the [3:] slice depends on the ordering of weight_list — keep the two in sync.
weights = [f"../weights/{weight}.pth" for weight in weight_list[3:]]

# average_model_weights receives the checkpoint paths; its result is saved as a new
# checkpoint so that evaluate() can load it by name like any other.
average_weight = average_model_weights(weights)
torch.save(average_weight, "../weights/averaged_weights_focal.pth")

In [11]:
# Score the averaged "*_focal" checkpoint with the same procedure as the individual ones.
results = evaluate(["averaged_weights_focal"], base_model)
results

  0%|          | 0/187 [00:00<?, ?it/s]

  0%|          | 0/174 [00:00<?, ?it/s]

Unnamed: 0,Weight,Val Acc,Val mAP,Val F1,Test Acc,Test mAP,Test F1
0,averaged_weights_focal,0.93394,0.866946,0.764757,0.916508,0.833178,0.691408


In [17]:
# Load the checkpoint chosen for export and tune its decision thresholds on the validation set.
# NOTE(review): in the results table above bpm_sync_bce has the highest Test Acc but not the
# highest Val Acc (no_sync2_bce does) — confirm the choice was not made on test-set metrics.
# map_location="cpu" makes loading independent of the device the checkpoint was saved from;
# load_state_dict then copies the values into the model's own parameters.
best_weight = torch.load("../weights/bpm_sync_bce.pth", map_location="cpu")
base_model.load_state_dict(best_weight)
preds, targets = get_preds(valid_dl, base_model, "cuda")
best_model_thresholds = optimize_accuracy(preds, targets)
# np.save appends ".npy" when the file name lacks it, so this writes acc_model_thresh.npy.
np.save("acc_model_thresh", best_model_thresholds)

  0%|          | 0/187 [00:00<?, ?it/s]

In [18]:
# Example input for tracing: the first element of the first validation dataset item.
# NOTE(review): this bypasses the DataLoader, so it is an un-collated item — confirm the model
# accepts it without a leading batch dimension.
example = valid_dl.dataset[0][0]

In [20]:
# Export the model as a TorchScript file.
# torch.jit.trace freezes whichever train/eval behaviour the module has while it is traced,
# so switch to eval() first; otherwise train-mode layers (e.g. dropout) would be baked into
# the exported model. Both .to("cpu") and .eval() modify base_model in place and return it.
traced_model = torch.jit.trace(base_model.to("cpu").eval(), example)
torch.jit.save(traced_model, "acc_model_ast.pt")

In [21]:
# Display the tuned thresholds that were saved alongside the exported model.
best_model_thresholds

array([0.7094, 0.5323, 0.5094, 0.6527, 0.7043, 0.5403, 0.5436, 0.4266,
       0.3344, 0.7884, 0.2298])