In [1]:
import torch
import numpy as np
import wandb

from utils import parse_config, unflatten_dot
from dataset import get_loader
from learner import Learner
from models import ASTPretrained
from types import SimpleNamespace

In [2]:
SEED = 123

# Seed every random number generator this cell touches (PyTorch CPU,
# PyTorch CUDA, NumPy global state) with the same value.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(SEED)

# Ask cuDNN for deterministic algorithms and switch off its benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
def main():
    """Run one training job driven by the active wandb run's configuration.

    Starts a wandb run, reads its config, builds the "train" and "valid"
    data loaders from that config, instantiates ``ASTPretrained`` with 11
    output classes, and trains it through ``Learner`` after unfreezing.

    NOTE(review): ``fit`` is called with a hard-coded 4 while the sweep
    config also carries an ``EPOCHS`` entry — confirm which one ``Learner``
    actually honours.
    """
    wandb.init()
    run_config = wandb.config

    # Loaders are built before the model, in train -> valid order.
    train_loader = get_loader(run_config, subset="train")
    valid_loader = get_loader(run_config, subset="valid")

    # 11 presumably matches the number of target labels — verify against the dataset.
    network = ASTPretrained(n_classes=11)

    trainer = Learner(train_loader, valid_loader, network, run_config)
    trainer.unfreeze()
    trainer.fit(4)

In [None]:
import yaml

CONFIG_PATH = "./sweep_config.yaml"
WANDB_PROJECT = "IRMAS_allsync"
SWEEP_RUN_COUNT = 10

# Load the sweep definition from YAML (safe_load: plain data only).
with open(CONFIG_PATH) as file:
    config = yaml.safe_load(file)

# Register the sweep, then let a local agent call `main` once per run.
sweep_id = wandb.sweep(sweep=config, project=WANDB_PROJECT)
wandb.agent(sweep_id, function=main, count=SWEEP_RUN_COUNT)

Create sweep with ID: 06f92x4a
Sweep URL: https://wandb.ai/k-pintaric/IRMAS_allsync/sweeps/06f92x4a


[34m[1mwandb[0m: Agent Starting Run: v0l3uixf with config:
[34m[1mwandb[0m: 	EPOCHS: 6
[34m[1mwandb[0m: 	LLRD: {'base_lr': 6.615080968612837e-05, 'lr_decay_rate': 0.9533553746801784, 'weight_decay': 0.00012384424737481054}
[34m[1mwandb[0m: 	batch_size: 8
[34m[1mwandb[0m: 	loss: {'FocalLoss': {'alpha': 0.5, 'gamma': 2}}
[34m[1mwandb[0m: 	metrics: ['hamming_score', 'zero_one_score', 'mAP', 'mean_f1_score']
[34m[1mwandb[0m: 	num_accum: 4
[34m[1mwandb[0m: 	optimizer: {'AdamW': {'weight_decay': 0}}
[34m[1mwandb[0m: 	preprocess: {'PreprocessPipeline': {'target_sr': 16000}}
[34m[1mwandb[0m: 	save_best_model: True
[34m[1mwandb[0m: 	scheduler: {'CosineAnnealingLR': {'eta_min': 0}}
[34m[1mwandb[0m: 	signal_augments: {'RepeatAudio': {'max_repeats': 3}}
[34m[1mwandb[0m: 	spec_augments: {'MaskFrequency': {'max_mask_length': 6}, 'MaskTime': {'max_mask_length': 77}}
[34m[1mwandb[0m: 	train_dir: ./data/processed/all_sync/IRMAS_Training_Data
[34m[1mwandb[0m:

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/9220 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

| EPOCH: 1 | train_loss: 0.009 | val_loss: 0.066 |

hamming_score: 0.83
zero_one_score: 0.16
mAP: 0.73
mean_f1_score: 0.57


  0%|          | 0/9220 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

| EPOCH: 2 | train_loss: 0.006 | val_loss: 0.058 |

hamming_score: 0.85
zero_one_score: 0.20
mAP: 0.82
mean_f1_score: 0.62


  0%|          | 0/9220 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

| EPOCH: 3 | train_loss: 0.004 | val_loss: 0.089 |

hamming_score: 0.85
zero_one_score: 0.18
mAP: 0.80
mean_f1_score: 0.59


  0%|          | 0/9220 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

| EPOCH: 4 | train_loss: 0.003 | val_loss: 0.103 |

hamming_score: 0.83
zero_one_score: 0.17
mAP: 0.76
mean_f1_score: 0.58


  0%|          | 0/9220 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

| EPOCH: 5 | train_loss: 0.002 | val_loss: 0.113 |

hamming_score: 0.85
zero_one_score: 0.17
mAP: 0.78
mean_f1_score: 0.59


  0%|          | 0/9220 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

| EPOCH: 6 | train_loss: 0.001 | val_loss: 0.115 |

hamming_score: 0.85
zero_one_score: 0.17
mAP: 0.78
mean_f1_score: 0.59


0,1
epoch,▁▂▄▅▇█
hamming_score,▂█▆▁▆▇
lr_param_group_0,███████▇▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
lr_param_group_1,███████▇▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
lr_param_group_10,███████▇▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
lr_param_group_11,███████▇▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
lr_param_group_12,███████▇▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
lr_param_group_13,███████▇▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
lr_param_group_14,███████▇▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
lr_param_group_15,███████▇▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
epoch,6.0
hamming_score,0.85154
lr_param_group_0,0.0
lr_param_group_1,0.0
lr_param_group_10,0.0
lr_param_group_11,0.0
lr_param_group_12,0.0
lr_param_group_13,0.0
lr_param_group_14,0.0
lr_param_group_15,0.0


[34m[1mwandb[0m: Agent Starting Run: wyuhqr6h with config:
[34m[1mwandb[0m: 	EPOCHS: 9
[34m[1mwandb[0m: 	LLRD: {'base_lr': 1.6107190646216428e-05, 'lr_decay_rate': 0.9278405218193506, 'weight_decay': 3.0868878398734965e-06}
[34m[1mwandb[0m: 	batch_size: 8
[34m[1mwandb[0m: 	loss: {'FocalLoss': {'alpha': -1, 'gamma': 3}}
[34m[1mwandb[0m: 	metrics: ['hamming_score', 'zero_one_score', 'mAP', 'mean_f1_score']
[34m[1mwandb[0m: 	num_accum: 4
[34m[1mwandb[0m: 	optimizer: {'AdamW': {'weight_decay': 0}}
[34m[1mwandb[0m: 	preprocess: {'PreprocessPipeline': {'target_sr': 16000}}
[34m[1mwandb[0m: 	save_best_model: True
[34m[1mwandb[0m: 	scheduler: {'CosineAnnealingLR': {'eta_min': 0}}
[34m[1mwandb[0m: 	signal_augments: {'RepeatAudio': {'max_repeats': 2}}
[34m[1mwandb[0m: 	spec_augments: {'MaskFrequency': {'max_mask_length': 27}, 'MaskTime': {'max_mask_length': 2}}
[34m[1mwandb[0m: 	train_dir: ./data/processed/all_sync/IRMAS_Training_Data
[34m[1mwandb[0m:

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/9220 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

| EPOCH: 1 | train_loss: 0.010 | val_loss: 0.055 |

hamming_score: 0.86
zero_one_score: 0.20
mAP: 0.81
mean_f1_score: 0.63


  0%|          | 0/9220 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

| EPOCH: 2 | train_loss: 0.007 | val_loss: 0.078 |

hamming_score: 0.84
zero_one_score: 0.19
mAP: 0.79
mean_f1_score: 0.60


  0%|          | 0/9220 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

| EPOCH: 3 | train_loss: 0.006 | val_loss: 0.089 |

hamming_score: 0.84
zero_one_score: 0.21
mAP: 0.78
mean_f1_score: 0.61


  0%|          | 0/9220 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

| EPOCH: 4 | train_loss: 0.005 | val_loss: 0.097 |

hamming_score: 0.85
zero_one_score: 0.20
mAP: 0.78
mean_f1_score: 0.61


  0%|          | 0/9220 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

| EPOCH: 5 | train_loss: 0.004 | val_loss: 0.105 |

hamming_score: 0.85
zero_one_score: 0.22
mAP: 0.77
mean_f1_score: 0.61


  0%|          | 0/9220 [00:00<?, ?it/s]