In [1]:
import torch
import numpy as np
import wandb

from modeling.dataset import get_loader
from modeling.learner import Learner
from modeling.models import ASTPretrained
from modeling.utils import parse_config

In [2]:
# Single seed shared by every random number generator seeded in this cell.
SEED = 123

# cuDNN: request deterministic algorithms and disable the benchmark
# autotuner before any generator is seeded.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Seed torch (default and CUDA generators) and NumPy's global generator.
# The NumPy call is kept last on purpose: it returns None, so the cell
# produces no output.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
def main(config, n_classes=11):
    """Train the pretrained AST classifier once, tracked as a Weights & Biases run.

    Args:
        config: Parsed experiment configuration. It is forwarded unchanged to
            ``wandb.init``, ``get_loader`` and ``Learner``; this function
            itself only reads ``config.dropout``.
        n_classes: Number of output classes handed to ``ASTPretrained``.
            Defaults to 11, the value that used to be hard-coded here.

    Raises:
        Whatever the loader/model construction or ``learn.fit()`` raises. The
        exception propagates unchanged, but only after the W&B run has been
        closed so that a failed attempt does not leave a dangling run in the
        kernel.
    """
    wandb.init(config=config, anonymous="allow")

    # Pessimistic default: only switched to 0 once training ran to completion,
    # so an interrupted or crashed run is reported with a non-zero exit code.
    exit_code = 1
    try:
        train_dl = get_loader(config, subset="train")
        valid_dl = get_loader(config, subset="valid")

        model = ASTPretrained(n_classes=n_classes, dropout=config.dropout)

        learn = Learner(train_dl, valid_dl, model, config)
        learn.fit()

        exit_code = 0
    finally:
        # Always close the run, even when an exception is on its way out.
        wandb.finish(exit_code=exit_code)

In [4]:
# Location of the experiment configuration file. The path is relative, so it
# is resolved against the kernel's working directory.
CONFIG_PATH = "../configs/config.yaml"
# Parse the file with the project's own helper; the result is the object that
# main() forwards to W&B, the data loaders and the Learner.
config = parse_config(CONFIG_PATH)

# Launch a full training run; progress and metrics are printed below the cell.
main(config)

[34m[1mwandb[0m: Currently logged in as: [33mk-pintaric[0m. Use [1m`wandb login --relogin`[0m to force relogin


Some weights of the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 were not used when initializing ASTModel: ['classifier.layernorm.bias', 'classifier.dense.weight', 'classifier.layernorm.weight', 'classifier.dense.bias']
- This IS expected if you are initializing ASTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ASTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/839 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]

| EPOCH: 1 | train_loss: 0.142 | val_loss: 0.548 |

mean_average_precision: 0.42


  0%|          | 0/839 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]

| EPOCH: 2 | train_loss: 0.099 | val_loss: 0.480 |

mean_average_precision: 0.46


  0%|          | 0/839 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]

| EPOCH: 3 | train_loss: 0.081 | val_loss: 0.441 |

mean_average_precision: 0.52


  0%|          | 0/839 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]

| EPOCH: 4 | train_loss: 0.073 | val_loss: 0.416 |

mean_average_precision: 0.56


  0%|          | 0/839 [00:00<?, ?it/s]

  0%|          | 0/187 [00:00<?, ?it/s]

| EPOCH: 5 | train_loss: 0.069 | val_loss: 0.397 |

mean_average_precision: 0.60


VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▃▅▆█
lr_param_group_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_average_precision,▁▃▅▆█
test_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▄▂▁▁
train_loss_per_batch,██▇▆▆▅▅▄▄▄▄▃▃▃▃▂▃▂▂▂▂▂▁▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
train_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_loss,█▅▃▂▁
valid_loss_per_batch,▇▇▆▆▆▆█▇▄▄▄▄▄▅▅▅▄▃▃▄▄▃▄▃▃▃▄▃▄▃▄▃▂▁▃▄▂▂▁▁

0,1
epoch,5.0
lr_param_group_0,1e-05
mean_average_precision,0.60021
test_step,935.0
train_loss,0.06853
train_loss_per_batch,0.06165
train_step,4195.0
val_loss,0.39742
valid_loss_per_batch,0.4021
