# Multi-task SELD using Conformer

In [1]:
import torch
from torchinfo import summary

from datasets import create_dataloaders
from multi_task import MultiTaskSELD, train_model

# --- Configuration shared by every cell below --------------------------------
SAMPLE_RATE = 24000               # samples per second (presumably of the input audio; confirm against `datasets`)
FRAME_LENGTH = SAMPLE_RATE // 10  # number of samples in one tenth of a second (2400)
MAX_EVENTS = 5                    # forwarded to the model as `num_events`
NUM_CLASSES = 13                  # forwarded to the model as `num_classes`
BATCH_SIZE = 128                  # handed to `create_dataloaders`

# Seed PyTorch's global RNG so that a Restart & Run All draws the same random
# numbers on every run.
# NOTE(review): any NumPy / `random` / DataLoader-worker randomness inside
# `datasets` or `multi_task` is not covered by this call; seed those too if used.
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [2]:
# Instantiate the model once, only to print its layer-by-layer summary.
model_args = dict(
    backbone='conformer',
    num_classes=NUM_CLASSES,
    num_events=MAX_EVENTS,
    input_dim=7,
    dropout=0.05,
    hidden_dim=64,
)

# Only the first loader is used in this cell; the second return value is discarded.
train_dataloader, _ = create_dataloaders(BATCH_SIZE)
model = MultiTaskSELD(**model_args)

# The first element of the first batch serves as the example input for torchinfo;
# whatever else the batch holds is ignored.
features, *_ = next(iter(train_dataloader))
summary(model, input_data=[features])

Layer (type:depth-idx)                                  Output Shape              Param #
MultiTaskSELD                                           [128, 50, 13, 4]          --
├─SELDConformerBackbone: 1-1                            [128, 50, 128]            --
│    └─ResNet: 2-1                                      [128, 96, 8, 250]         --
│    │    └─ResBlock: 3-1                               [128, 24, 32, 250]        7,080
│    │    └─ResBlock: 3-2                               [128, 48, 16, 250]        32,688
│    │    └─ResBlock: 3-3                               [128, 96, 8, 250]         129,888
│    └─Linear: 2-2                                      [128, 250, 128]           98,432
│    └─ModuleList: 2-3                                  --                        --
│    │    └─Conformer: 3-4                              [128, 250, 128]           381,056
│    │    └─Conformer: 3-5                              [128, 250, 128]           381,056
│    │    └─Conformer: 3-6        

## Evaluation

In [3]:
# Baseline run: default dataloaders (no `augments` argument), 250 epochs,
# `sde_weight` set to zero.
# NOTE(review): in the recorded output the closing "Micro @ epoch N" /
# "Macro @ epoch N" lines appear to carry each other's values when compared
# with the periodic Macro/Micro lines; confirm the labels in `train_model`.
model_args = dict(
    backbone='conformer',
    num_classes=NUM_CLASSES,
    num_events=MAX_EVENTS,
    input_dim=7,
    hidden_dim=64,
    dropout=0.05,
)
train_dataloader, test_dataloader = create_dataloaders(BATCH_SIZE)
model = train_model(
    model_args, train_dataloader, test_dataloader,
    epochs=250, device=device, sde_weight=0.0,
)

Epoch 1/250: 100%|██████████| 24/24 [00:24<00:00,  1.03s/it, loss=0.0424, test_loss=0.0315]
Epoch 2/250: 100%|██████████| 24/24 [00:23<00:00,  1.01it/s, loss=0.0239, test_loss=0.03]
Epoch 3/250: 100%|██████████| 24/24 [00:23<00:00,  1.03it/s, loss=0.0223, test_loss=0.0304]
Epoch 4/250: 100%|██████████| 24/24 [00:23<00:00,  1.01it/s, loss=0.0211, test_loss=0.0322]
Epoch 5/250: 100%|██████████| 24/24 [00:23<00:00,  1.01it/s, loss=0.0192, test_loss=0.0294]
Epoch 6/250: 100%|██████████| 24/24 [00:24<00:00,  1.04s/it, loss=0.0176, test_loss=0.0296]
Epoch 7/250: 100%|██████████| 24/24 [00:23<00:00,  1.00it/s, loss=0.016, test_loss=0.0294]
Epoch 8/250: 100%|██████████| 24/24 [00:24<00:00,  1.04s/it, loss=0.014, test_loss=0.0304]
Epoch 9/250: 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, loss=0.0128, test_loss=0.0302]
Epoch 10/250: 100%|██████████| 24/24 [00:26<00:00,  1.11s/it, loss=0.0118, test_loss=0.0271]


Macro: ER=0.77, F=0.09, LE=133.72, LR=0.18
Micro: ER=0.77, F=0.25, LE=29.65, LR=0.44


Epoch 11/250: 100%|██████████| 24/24 [00:27<00:00,  1.15s/it, loss=0.0105, test_loss=0.0285]
Epoch 12/250: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=0.00984, test_loss=0.0257]
Epoch 13/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00901, test_loss=0.029]
Epoch 14/250: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=0.00881, test_loss=0.0272]
Epoch 15/250: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=0.00828, test_loss=0.027]
Epoch 16/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00796, test_loss=0.0264]
Epoch 17/250: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=0.00754, test_loss=0.0262]
Epoch 18/250: 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, loss=0.00716, test_loss=0.0258]
Epoch 19/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00717, test_loss=0.0273]
Epoch 20/250: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=0.00687, test_loss=0.0264]


Macro: ER=0.73, F=0.14, LE=113.33, LR=0.24
Micro: ER=0.73, F=0.31, LE=30.73, LR=0.53


Epoch 21/250: 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, loss=0.00653, test_loss=0.0262]
Epoch 22/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00633, test_loss=0.0276]
Epoch 23/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00624, test_loss=0.0261]
Epoch 24/250: 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, loss=0.00601, test_loss=0.0255]
Epoch 25/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00595, test_loss=0.0261]
Epoch 26/250: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=0.00571, test_loss=0.0254]
Epoch 27/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00565, test_loss=0.0252]
Epoch 28/250: 100%|██████████| 24/24 [00:27<00:00,  1.14s/it, loss=0.00564, test_loss=0.0263]
Epoch 29/250: 100%|██████████| 24/24 [00:27<00:00,  1.13s/it, loss=0.00546, test_loss=0.0243]
Epoch 30/250: 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, loss=0.00526, test_loss=0.0259]


Macro: ER=0.76, F=0.15, LE=77.90, LR=0.28
Micro: ER=0.76, F=0.31, LE=30.74, LR=0.60


Epoch 31/250: 100%|██████████| 24/24 [00:27<00:00,  1.14s/it, loss=0.00524, test_loss=0.026]
Epoch 32/250: 100%|██████████| 24/24 [00:26<00:00,  1.10s/it, loss=0.00499, test_loss=0.0257]
Epoch 33/250: 100%|██████████| 24/24 [00:26<00:00,  1.11s/it, loss=0.00484, test_loss=0.0243]
Epoch 34/250: 100%|██████████| 24/24 [00:27<00:00,  1.13s/it, loss=0.0047, test_loss=0.0254]
Epoch 35/250: 100%|██████████| 24/24 [00:26<00:00,  1.10s/it, loss=0.00461, test_loss=0.0241]
Epoch 36/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00458, test_loss=0.0267]
Epoch 37/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00467, test_loss=0.0267]
Epoch 38/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00457, test_loss=0.0267]
Epoch 39/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.0045, test_loss=0.0261]
Epoch 40/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00437, test_loss=0.0245]


Macro: ER=0.73, F=0.19, LE=80.45, LR=0.34
Micro: ER=0.73, F=0.35, LE=28.11, LR=0.61


Epoch 41/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00415, test_loss=0.0252]
Epoch 42/250: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=0.00401, test_loss=0.025]
Epoch 43/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00396, test_loss=0.0256]
Epoch 44/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00396, test_loss=0.0268]
Epoch 45/250: 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, loss=0.00398, test_loss=0.0252]
Epoch 46/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00393, test_loss=0.0252]
Epoch 47/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00377, test_loss=0.0276]
Epoch 48/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.0038, test_loss=0.0266]
Epoch 49/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00369, test_loss=0.0252]
Epoch 50/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00366, test_loss=0.0253]


Macro: ER=0.78, F=0.18, LE=67.90, LR=0.37
Micro: ER=0.78, F=0.32, LE=30.87, LR=0.64


Epoch 51/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.0036, test_loss=0.0258]
Epoch 52/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00347, test_loss=0.0259]
Epoch 53/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00354, test_loss=0.0262]
Epoch 54/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00357, test_loss=0.0265]
Epoch 55/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00352, test_loss=0.0256]
Epoch 56/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00331, test_loss=0.0266]
Epoch 57/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00309, test_loss=0.0249]
Epoch 58/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00321, test_loss=0.0264]
Epoch 59/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00314, test_loss=0.0264]
Epoch 60/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00301, test_loss=0.0248]


Macro: ER=0.77, F=0.18, LE=44.99, LR=0.36
Micro: ER=0.77, F=0.34, LE=29.44, LR=0.63


Epoch 61/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00289, test_loss=0.0259]
Epoch 62/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00289, test_loss=0.0267]
Epoch 63/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00294, test_loss=0.0265]
Epoch 64/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00293, test_loss=0.0261]
Epoch 65/250: 100%|██████████| 24/24 [00:28<00:00,  1.21s/it, loss=0.0029, test_loss=0.0277]
Epoch 66/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00279, test_loss=0.0269]
Epoch 67/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00271, test_loss=0.0265]
Epoch 68/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00273, test_loss=0.026]
Epoch 69/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00264, test_loss=0.0252]
Epoch 70/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00267, test_loss=0.0267]


Macro: ER=0.80, F=0.19, LE=47.58, LR=0.40
Micro: ER=0.80, F=0.32, LE=31.31, LR=0.62


Epoch 71/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00257, test_loss=0.0257]
Epoch 72/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00248, test_loss=0.0264]
Epoch 73/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00238, test_loss=0.0264]
Epoch 74/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00241, test_loss=0.0256]
Epoch 75/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00243, test_loss=0.026]
Epoch 76/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00248, test_loss=0.0271]
Epoch 77/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00233, test_loss=0.027]
Epoch 78/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00229, test_loss=0.0257]
Epoch 79/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.0023, test_loss=0.026]
Epoch 80/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00234, test_loss=0.027]


Macro: ER=0.80, F=0.19, LE=47.96, LR=0.41
Micro: ER=0.80, F=0.33, LE=32.53, LR=0.64


Epoch 81/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00232, test_loss=0.0275]
Epoch 82/250: 100%|██████████| 24/24 [00:29<00:00,  1.22s/it, loss=0.00219, test_loss=0.0267]
Epoch 83/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00213, test_loss=0.0255]
Epoch 84/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00219, test_loss=0.0254]
Epoch 85/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00216, test_loss=0.028]
Epoch 86/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00205, test_loss=0.0256]
Epoch 87/250: 100%|██████████| 24/24 [00:28<00:00,  1.21s/it, loss=0.00197, test_loss=0.0254]
Epoch 88/250: 100%|██████████| 24/24 [00:29<00:00,  1.21s/it, loss=0.00194, test_loss=0.0265]
Epoch 89/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00196, test_loss=0.0252]
Epoch 90/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00196, test_loss=0.0263]


Macro: ER=0.88, F=0.18, LE=50.07, LR=0.42
Micro: ER=0.88, F=0.30, LE=30.03, LR=0.63


Epoch 91/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00193, test_loss=0.0271]
Epoch 92/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.0019, test_loss=0.0261]
Epoch 93/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00184, test_loss=0.0257]
Epoch 94/250: 100%|██████████| 24/24 [00:27<00:00,  1.16s/it, loss=0.00183, test_loss=0.026]
Epoch 95/250: 100%|██████████| 24/24 [00:26<00:00,  1.12s/it, loss=0.00185, test_loss=0.0255]
Epoch 96/250: 100%|██████████| 24/24 [00:27<00:00,  1.14s/it, loss=0.00184, test_loss=0.0268]
Epoch 97/250: 100%|██████████| 24/24 [00:27<00:00,  1.13s/it, loss=0.00178, test_loss=0.027]
Epoch 98/250: 100%|██████████| 24/24 [00:26<00:00,  1.12s/it, loss=0.00174, test_loss=0.0262]
Epoch 99/250: 100%|██████████| 24/24 [00:27<00:00,  1.14s/it, loss=0.00195, test_loss=0.0294]
Epoch 100/250: 100%|██████████| 24/24 [00:27<00:00,  1.13s/it, loss=0.00193, test_loss=0.0278]


Macro: ER=0.85, F=0.20, LE=45.93, LR=0.42
Micro: ER=0.85, F=0.33, LE=28.92, LR=0.61


Epoch 101/250: 100%|██████████| 24/24 [00:27<00:00,  1.14s/it, loss=0.00181, test_loss=0.0282]
Epoch 102/250: 100%|██████████| 24/24 [00:28<00:00,  1.17s/it, loss=0.0017, test_loss=0.0259]
Epoch 103/250: 100%|██████████| 24/24 [00:29<00:00,  1.22s/it, loss=0.00158, test_loss=0.0259]
Epoch 104/250: 100%|██████████| 24/24 [00:29<00:00,  1.21s/it, loss=0.0015, test_loss=0.0264]
Epoch 105/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00149, test_loss=0.0267]
Epoch 106/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.00149, test_loss=0.0267]
Epoch 107/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.00142, test_loss=0.0266]
Epoch 108/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00142, test_loss=0.0268]
Epoch 109/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00144, test_loss=0.0266]
Epoch 110/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.00143, test_loss=0.0265]


Macro: ER=0.80, F=0.20, LE=46.99, LR=0.41
Micro: ER=0.80, F=0.33, LE=29.09, LR=0.61


Epoch 111/250: 100%|██████████| 24/24 [00:30<00:00,  1.25s/it, loss=0.00139, test_loss=0.0267]
Epoch 112/250: 100%|██████████| 24/24 [00:30<00:00,  1.25s/it, loss=0.00141, test_loss=0.0273]
Epoch 113/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.00137, test_loss=0.026]
Epoch 114/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.00136, test_loss=0.0272]
Epoch 115/250: 100%|██████████| 24/24 [00:30<00:00,  1.27s/it, loss=0.00138, test_loss=0.0261]
Epoch 116/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.00134, test_loss=0.0256]
Epoch 117/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.00131, test_loss=0.0272]
Epoch 118/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.00134, test_loss=0.0266]
Epoch 119/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.00134, test_loss=0.0267]
Epoch 120/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.0013, test_loss=0.0266]


Macro: ER=0.87, F=0.20, LE=46.20, LR=0.40
Micro: ER=0.87, F=0.32, LE=30.62, LR=0.64


Epoch 121/250: 100%|██████████| 24/24 [00:30<00:00,  1.27s/it, loss=0.00125, test_loss=0.0252]
Epoch 122/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.00119, test_loss=0.0265]
Epoch 123/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.00117, test_loss=0.0259]
Epoch 124/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.00112, test_loss=0.0259]
Epoch 125/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.00112, test_loss=0.0257]
Epoch 126/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.00107, test_loss=0.0275]
Epoch 127/250: 100%|██████████| 24/24 [00:30<00:00,  1.25s/it, loss=0.00109, test_loss=0.0276]
Epoch 128/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.00108, test_loss=0.0256]
Epoch 129/250: 100%|██████████| 24/24 [00:29<00:00,  1.23s/it, loss=0.00106, test_loss=0.0258]
Epoch 130/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.00102, test_loss=0.0269]


Macro: ER=0.87, F=0.19, LE=47.48, LR=0.44
Micro: ER=0.87, F=0.31, LE=31.98, LR=0.67


Epoch 131/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.00108, test_loss=0.0265]
Epoch 132/250: 100%|██████████| 24/24 [00:30<00:00,  1.25s/it, loss=0.00103, test_loss=0.0268]
Epoch 133/250: 100%|██████████| 24/24 [00:30<00:00,  1.25s/it, loss=0.00106, test_loss=0.0281]
Epoch 134/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.0014, test_loss=0.027]
Epoch 135/250: 100%|██████████| 24/24 [00:29<00:00,  1.23s/it, loss=0.00129, test_loss=0.0266]
Epoch 136/250: 100%|██████████| 24/24 [00:29<00:00,  1.23s/it, loss=0.00126, test_loss=0.0275]
Epoch 137/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.00124, test_loss=0.0271]
Epoch 138/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.0012, test_loss=0.0267]
Epoch 139/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.00112, test_loss=0.0273]
Epoch 140/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.00106, test_loss=0.028]


Macro: ER=0.89, F=0.19, LE=47.87, LR=0.43
Micro: ER=0.89, F=0.29, LE=32.75, LR=0.66


Epoch 141/250: 100%|██████████| 24/24 [00:30<00:00,  1.25s/it, loss=0.000986, test_loss=0.0269]
Epoch 142/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.00102, test_loss=0.0267]
Epoch 143/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000936, test_loss=0.0276]
Epoch 144/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.000902, test_loss=0.0267]
Epoch 145/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.000965, test_loss=0.0298]
Epoch 146/250: 100%|██████████| 24/24 [00:29<00:00,  1.23s/it, loss=0.00124, test_loss=0.0276]
Epoch 147/250: 100%|██████████| 24/24 [00:30<00:00,  1.25s/it, loss=0.00116, test_loss=0.0284]
Epoch 148/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.0011, test_loss=0.0277]
Epoch 149/250: 100%|██████████| 24/24 [00:30<00:00,  1.25s/it, loss=0.00102, test_loss=0.0269]
Epoch 150/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000955, test_loss=0.0282]


Macro: ER=0.89, F=0.17, LE=59.47, LR=0.41
Micro: ER=0.89, F=0.30, LE=29.42, LR=0.61


Epoch 151/250: 100%|██████████| 24/24 [00:30<00:00,  1.25s/it, loss=0.000947, test_loss=0.0279]
Epoch 152/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000904, test_loss=0.0279]
Epoch 153/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.000854, test_loss=0.027]
Epoch 154/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.000844, test_loss=0.028]
Epoch 155/250: 100%|██████████| 24/24 [00:30<00:00,  1.27s/it, loss=0.000833, test_loss=0.0283]
Epoch 156/250: 100%|██████████| 24/24 [00:31<00:00,  1.31s/it, loss=0.000789, test_loss=0.0276]
Epoch 157/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.000755, test_loss=0.0276]
Epoch 158/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.000778, test_loss=0.0289]
Epoch 159/250: 100%|██████████| 24/24 [00:31<00:00,  1.30s/it, loss=0.000769, test_loss=0.0286]
Epoch 160/250: 100%|██████████| 24/24 [00:30<00:00,  1.29s/it, loss=0.000765, test_loss=0.027]


Macro: ER=0.83, F=0.19, LE=44.11, LR=0.43
Micro: ER=0.83, F=0.32, LE=28.21, LR=0.62


Epoch 161/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.000734, test_loss=0.0273]
Epoch 162/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.000681, test_loss=0.0262]
Epoch 163/250: 100%|██████████| 24/24 [00:30<00:00,  1.25s/it, loss=0.000661, test_loss=0.0265]
Epoch 164/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.000662, test_loss=0.0264]
Epoch 165/250: 100%|██████████| 24/24 [00:31<00:00,  1.30s/it, loss=0.000646, test_loss=0.027]
Epoch 166/250: 100%|██████████| 24/24 [00:31<00:00,  1.29s/it, loss=0.000609, test_loss=0.0271]
Epoch 167/250: 100%|██████████| 24/24 [00:30<00:00,  1.27s/it, loss=0.00059, test_loss=0.0271]
Epoch 168/250: 100%|██████████| 24/24 [00:30<00:00,  1.27s/it, loss=0.000579, test_loss=0.027]
Epoch 169/250: 100%|██████████| 24/24 [00:30<00:00,  1.27s/it, loss=0.000584, test_loss=0.0271]
Epoch 170/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.000596, test_loss=0.0271]


Macro: ER=0.84, F=0.20, LE=56.47, LR=0.43
Micro: ER=0.84, F=0.32, LE=29.86, LR=0.63


Epoch 171/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.000583, test_loss=0.0272]
Epoch 172/250: 100%|██████████| 24/24 [00:30<00:00,  1.25s/it, loss=0.000582, test_loss=0.0268]
Epoch 173/250: 100%|██████████| 24/24 [00:31<00:00,  1.30s/it, loss=0.000649, test_loss=0.0279]
Epoch 174/250: 100%|██████████| 24/24 [00:30<00:00,  1.29s/it, loss=0.000777, test_loss=0.0278]
Epoch 175/250: 100%|██████████| 24/24 [00:31<00:00,  1.30s/it, loss=0.00086, test_loss=0.0289]
Epoch 176/250: 100%|██████████| 24/24 [00:31<00:00,  1.29s/it, loss=0.000889, test_loss=0.0273]
Epoch 177/250: 100%|██████████| 24/24 [00:31<00:00,  1.31s/it, loss=0.00117, test_loss=0.0339]
Epoch 178/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.00177, test_loss=0.0273]
Epoch 179/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.00253, test_loss=0.028]
Epoch 180/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.00373, test_loss=0.0316]


Macro: ER=0.97, F=0.13, LE=59.89, LR=0.36
Micro: ER=0.97, F=0.25, LE=31.68, LR=0.55


Epoch 181/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.0032, test_loss=0.027]
Epoch 182/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.0027, test_loss=0.0261]
Epoch 183/250: 100%|██████████| 24/24 [00:30<00:00,  1.27s/it, loss=0.00212, test_loss=0.0263]
Epoch 184/250: 100%|██████████| 24/24 [00:30<00:00,  1.27s/it, loss=0.00147, test_loss=0.0274]
Epoch 185/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.00115, test_loss=0.0276]
Epoch 186/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.000895, test_loss=0.0265]
Epoch 187/250: 100%|██████████| 24/24 [00:30<00:00,  1.29s/it, loss=0.000792, test_loss=0.0265]
Epoch 188/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.00072, test_loss=0.0267]
Epoch 189/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.000647, test_loss=0.0266]
Epoch 190/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.000613, test_loss=0.0265]


Macro: ER=0.84, F=0.18, LE=45.92, LR=0.44
Micro: ER=0.84, F=0.32, LE=27.70, LR=0.63


Epoch 191/250: 100%|██████████| 24/24 [00:30<00:00,  1.29s/it, loss=0.00057, test_loss=0.0266]
Epoch 192/250: 100%|██████████| 24/24 [00:30<00:00,  1.29s/it, loss=0.000529, test_loss=0.0269]
Epoch 193/250: 100%|██████████| 24/24 [00:30<00:00,  1.27s/it, loss=0.000513, test_loss=0.027]
Epoch 194/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.000484, test_loss=0.0272]
Epoch 195/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.000483, test_loss=0.0265]
Epoch 196/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.000484, test_loss=0.0268]
Epoch 197/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.000466, test_loss=0.0266]
Epoch 198/250: 100%|██████████| 24/24 [00:29<00:00,  1.21s/it, loss=0.000456, test_loss=0.0269]
Epoch 199/250: 100%|██████████| 24/24 [00:28<00:00,  1.21s/it, loss=0.000433, test_loss=0.0265]
Epoch 200/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.000431, test_loss=0.0268]


Macro: ER=0.83, F=0.19, LE=46.03, LR=0.43
Micro: ER=0.83, F=0.33, LE=27.55, LR=0.62


Epoch 201/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.000436, test_loss=0.0263]
Epoch 202/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.000466, test_loss=0.0264]
Epoch 203/250: 100%|██████████| 24/24 [00:28<00:00,  1.21s/it, loss=0.000449, test_loss=0.0265]
Epoch 204/250: 100%|██████████| 24/24 [00:29<00:00,  1.21s/it, loss=0.000442, test_loss=0.0273]
Epoch 205/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.000431, test_loss=0.0265]
Epoch 206/250: 100%|██████████| 24/24 [00:28<00:00,  1.21s/it, loss=0.00044, test_loss=0.0274]
Epoch 207/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.000423, test_loss=0.0265]
Epoch 208/250: 100%|██████████| 24/24 [00:28<00:00,  1.18s/it, loss=0.000458, test_loss=0.0269]
Epoch 209/250: 100%|██████████| 24/24 [00:29<00:00,  1.22s/it, loss=0.000591, test_loss=0.0289]
Epoch 210/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000565, test_loss=0.028]


Macro: ER=0.89, F=0.16, LE=56.65, LR=0.41
Micro: ER=0.89, F=0.30, LE=28.43, LR=0.62


Epoch 211/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.000504, test_loss=0.0279]
Epoch 212/250: 100%|██████████| 24/24 [00:29<00:00,  1.23s/it, loss=0.000518, test_loss=0.0265]
Epoch 213/250: 100%|██████████| 24/24 [00:29<00:00,  1.23s/it, loss=0.000478, test_loss=0.0282]
Epoch 214/250: 100%|██████████| 24/24 [00:29<00:00,  1.22s/it, loss=0.00044, test_loss=0.0276]
Epoch 215/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.000429, test_loss=0.0276]
Epoch 216/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000417, test_loss=0.0281]
Epoch 217/250: 100%|██████████| 24/24 [00:29<00:00,  1.21s/it, loss=0.000392, test_loss=0.0267]
Epoch 218/250: 100%|██████████| 24/24 [00:30<00:00,  1.29s/it, loss=0.00037, test_loss=0.0265]
Epoch 219/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.000355, test_loss=0.0273]
Epoch 220/250: 100%|██████████| 24/24 [00:30<00:00,  1.26s/it, loss=0.000356, test_loss=0.0274]


Macro: ER=0.86, F=0.18, LE=57.00, LR=0.43
Micro: ER=0.86, F=0.32, LE=27.71, LR=0.62


Epoch 221/250: 100%|██████████| 24/24 [00:30<00:00,  1.27s/it, loss=0.000358, test_loss=0.0267]
Epoch 222/250: 100%|██████████| 24/24 [00:29<00:00,  1.21s/it, loss=0.000347, test_loss=0.0276]
Epoch 223/250: 100%|██████████| 24/24 [00:28<00:00,  1.21s/it, loss=0.000373, test_loss=0.0271]
Epoch 224/250: 100%|██████████| 24/24 [00:30<00:00,  1.29s/it, loss=0.000375, test_loss=0.0281]
Epoch 225/250: 100%|██████████| 24/24 [00:30<00:00,  1.27s/it, loss=0.00037, test_loss=0.0262]
Epoch 226/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.00038, test_loss=0.0275]
Epoch 227/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000379, test_loss=0.0272]
Epoch 228/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000394, test_loss=0.0274]
Epoch 229/250: 100%|██████████| 24/24 [00:29<00:00,  1.23s/it, loss=0.000388, test_loss=0.027]
Epoch 230/250: 100%|██████████| 24/24 [00:30<00:00,  1.29s/it, loss=0.000391, test_loss=0.0265]


Macro: ER=0.86, F=0.18, LE=45.96, LR=0.44
Micro: ER=0.86, F=0.31, LE=27.56, LR=0.62


Epoch 231/250: 100%|██████████| 24/24 [00:30<00:00,  1.28s/it, loss=0.000393, test_loss=0.0268]
Epoch 232/250: 100%|██████████| 24/24 [00:29<00:00,  1.21s/it, loss=0.000375, test_loss=0.0273]
Epoch 233/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.000375, test_loss=0.0269]
Epoch 234/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.000364, test_loss=0.0272]
Epoch 235/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.000363, test_loss=0.0268]
Epoch 236/250: 100%|██████████| 24/24 [00:29<00:00,  1.21s/it, loss=0.000463, test_loss=0.0276]
Epoch 237/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.000475, test_loss=0.0267]
Epoch 238/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.000475, test_loss=0.0268]
Epoch 239/250: 100%|██████████| 24/24 [00:28<00:00,  1.19s/it, loss=0.000491, test_loss=0.0266]
Epoch 240/250: 100%|██████████| 24/24 [00:29<00:00,  1.22s/it, loss=0.000517, test_loss=0.0279]


Macro: ER=0.94, F=0.20, LE=45.78, LR=0.47
Micro: ER=0.94, F=0.30, LE=28.94, LR=0.66


Epoch 241/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.000538, test_loss=0.0274]
Epoch 242/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000506, test_loss=0.0269]
Epoch 243/250: 100%|██████████| 24/24 [00:29<00:00,  1.25s/it, loss=0.000543, test_loss=0.0275]
Epoch 244/250: 100%|██████████| 24/24 [00:28<00:00,  1.20s/it, loss=0.000594, test_loss=0.0251]
Epoch 245/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000698, test_loss=0.0273]
Epoch 246/250: 100%|██████████| 24/24 [00:29<00:00,  1.23s/it, loss=0.000656, test_loss=0.027]
Epoch 247/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000605, test_loss=0.0283]
Epoch 248/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000712, test_loss=0.0293]
Epoch 249/250: 100%|██████████| 24/24 [00:29<00:00,  1.23s/it, loss=0.000822, test_loss=0.0269]
Epoch 250/250: 100%|██████████| 24/24 [00:29<00:00,  1.24s/it, loss=0.000924, test_loss=0.0302]

Macro: ER=0.97, F=0.19, LE=43.94, LR=0.48
Micro: ER=0.97, F=0.29, LE=31.27, LR=0.66
Micro @ epoch 68: ER=0.78, F=0.21, LE=56.28, LR=0.38
Macro @ epoch 68: ER=0.78, F=0.34, LE=30.98, LR=0.61





### Data augmentation

In [3]:
# Same model configuration as the baseline, but the dataloaders are built with
# augmentation ids 0..7 and training is limited to 40 epochs.
# NOTE(review): the meaning of each augmentation id lives in `datasets`; confirm there.
# NOTE(review): in the recorded output the closing "Micro @ epoch N" /
# "Macro @ epoch N" lines appear to carry each other's values when compared
# with the periodic Macro/Micro lines; confirm the labels in `train_model`.
model_args = dict(
    backbone='conformer',
    num_classes=NUM_CLASSES,
    num_events=MAX_EVENTS,
    input_dim=7,
    hidden_dim=64,
    dropout=0.05,
)
train_dataloader, test_dataloader = create_dataloaders(
    BATCH_SIZE,
    augments=list(range(8)),
)
model = train_model(
    model_args, train_dataloader, test_dataloader,
    epochs=40, device=device, sde_weight=0.0,
)

Epoch 1/40: 100%|██████████| 189/189 [02:14<00:00,  1.41it/s, loss=0.025, test_loss=0.0291]
Epoch 2/40: 100%|██████████| 189/189 [02:16<00:00,  1.39it/s, loss=0.0137, test_loss=0.024]
Epoch 3/40: 100%|██████████| 189/189 [02:17<00:00,  1.38it/s, loss=0.00987, test_loss=0.0234]
Epoch 4/40: 100%|██████████| 189/189 [02:18<00:00,  1.36it/s, loss=0.00856, test_loss=0.0219]
Epoch 5/40: 100%|██████████| 189/189 [02:17<00:00,  1.38it/s, loss=0.00785, test_loss=0.022]
Epoch 6/40: 100%|██████████| 189/189 [02:17<00:00,  1.38it/s, loss=0.00703, test_loss=0.0207]
Epoch 7/40: 100%|██████████| 189/189 [02:19<00:00,  1.35it/s, loss=0.00672, test_loss=0.0193]
Epoch 8/40: 100%|██████████| 189/189 [02:16<00:00,  1.39it/s, loss=0.00617, test_loss=0.0186]
Epoch 9/40: 100%|██████████| 189/189 [02:16<00:00,  1.39it/s, loss=0.00588, test_loss=0.02]
Epoch 10/40: 100%|██████████| 189/189 [02:16<00:00,  1.38it/s, loss=0.00575, test_loss=0.0196]


Macro: ER=0.62, F=0.18, LE=106.51, LR=0.28
Micro: ER=0.62, F=0.45, LE=20.75, LR=0.65


Epoch 11/40: 100%|██████████| 189/189 [02:13<00:00,  1.41it/s, loss=0.00571, test_loss=0.0179]
Epoch 12/40: 100%|██████████| 189/189 [02:15<00:00,  1.39it/s, loss=0.00522, test_loss=0.0195]
Epoch 13/40: 100%|██████████| 189/189 [02:16<00:00,  1.38it/s, loss=0.00492, test_loss=0.0195]
Epoch 14/40: 100%|██████████| 189/189 [02:12<00:00,  1.43it/s, loss=0.00471, test_loss=0.0198]
Epoch 15/40: 100%|██████████| 189/189 [02:16<00:00,  1.38it/s, loss=0.00442, test_loss=0.0193]
Epoch 16/40: 100%|██████████| 189/189 [02:16<00:00,  1.38it/s, loss=0.00425, test_loss=0.0211]
Epoch 17/40: 100%|██████████| 189/189 [02:17<00:00,  1.38it/s, loss=0.00419, test_loss=0.0206]
Epoch 18/40: 100%|██████████| 189/189 [02:16<00:00,  1.38it/s, loss=0.00392, test_loss=0.0197]
Epoch 19/40: 100%|██████████| 189/189 [02:17<00:00,  1.38it/s, loss=0.00378, test_loss=0.0206]
Epoch 20/40: 100%|██████████| 189/189 [02:14<00:00,  1.40it/s, loss=0.00354, test_loss=0.0208]


Macro: ER=0.67, F=0.22, LE=72.75, LR=0.34
Micro: ER=0.67, F=0.44, LE=22.99, LR=0.68


Epoch 21/40: 100%|██████████| 189/189 [02:17<00:00,  1.37it/s, loss=0.0033, test_loss=0.0216]
Epoch 22/40: 100%|██████████| 189/189 [02:14<00:00,  1.41it/s, loss=0.00378, test_loss=0.0213]
Epoch 23/40: 100%|██████████| 189/189 [02:16<00:00,  1.39it/s, loss=0.00346, test_loss=0.0219]
Epoch 24/40: 100%|██████████| 189/189 [02:19<00:00,  1.36it/s, loss=0.00298, test_loss=0.021]
Epoch 25/40: 100%|██████████| 189/189 [02:19<00:00,  1.36it/s, loss=0.00286, test_loss=0.0213]
Epoch 26/40: 100%|██████████| 189/189 [02:16<00:00,  1.39it/s, loss=0.00274, test_loss=0.0213]
Epoch 27/40: 100%|██████████| 189/189 [02:17<00:00,  1.38it/s, loss=0.0026, test_loss=0.0201]
Epoch 28/40: 100%|██████████| 189/189 [02:17<00:00,  1.37it/s, loss=0.0025, test_loss=0.0209]
Epoch 29/40: 100%|██████████| 189/189 [02:17<00:00,  1.37it/s, loss=0.00225, test_loss=0.0211]
Epoch 30/40: 100%|██████████| 189/189 [02:12<00:00,  1.43it/s, loss=0.00216, test_loss=0.0228]


Macro: ER=0.74, F=0.24, LE=36.54, LR=0.45
Micro: ER=0.74, F=0.41, LE=23.46, LR=0.70


Epoch 31/40: 100%|██████████| 189/189 [02:09<00:00,  1.46it/s, loss=0.00224, test_loss=0.0227]
Epoch 32/40: 100%|██████████| 189/189 [02:09<00:00,  1.46it/s, loss=0.0024, test_loss=0.0218]
Epoch 33/40: 100%|██████████| 189/189 [02:08<00:00,  1.47it/s, loss=0.00178, test_loss=0.0224]
Epoch 34/40: 100%|██████████| 189/189 [02:10<00:00,  1.45it/s, loss=0.00156, test_loss=0.0215]
Epoch 35/40: 100%|██████████| 189/189 [02:10<00:00,  1.45it/s, loss=0.0015, test_loss=0.0225]
Epoch 36/40: 100%|██████████| 189/189 [02:10<00:00,  1.45it/s, loss=0.00151, test_loss=0.0217]
Epoch 37/40: 100%|██████████| 189/189 [02:05<00:00,  1.50it/s, loss=0.00331, test_loss=0.0247]
Epoch 38/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.00236, test_loss=0.0236]
Epoch 39/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.00138, test_loss=0.0227]
Epoch 40/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.0012, test_loss=0.0231]

Macro: ER=0.76, F=0.25, LE=38.34, LR=0.51
Micro: ER=0.76, F=0.43, LE=23.60, LR=0.72
Macro @ epoch 39: ER=0.73, F=0.27, LE=38.09, LR=0.52
Micro @ epoch 39: ER=0.73, F=0.45, LE=23.65, LR=0.73





### Distance estimation

In [3]:
# Distance-estimation run: Conformer backbone, augment ids 0-7 requested
# from create_dataloaders, and sde_weight=0.5 handed to train_model.
model_args = dict(
    backbone='conformer',
    num_classes=NUM_CLASSES,
    num_events=MAX_EVENTS,
    input_dim=7,
    hidden_dim=64,
    dropout=0.05,
)
train_dataloader, test_dataloader = create_dataloaders(
    BATCH_SIZE,
    augments=list(range(8)),  # same list as [0, 1, ..., 7]
)
model = train_model(
    model_args,
    train_dataloader,
    test_dataloader,
    device=device,
    epochs=40,
    sde_weight=0.5,
)

Epoch 1/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.09, test_loss=0.131]
Epoch 2/40: 100%|██████████| 189/189 [02:06<00:00,  1.49it/s, loss=0.0595, test_loss=0.147]
Epoch 3/40: 100%|██████████| 189/189 [02:06<00:00,  1.49it/s, loss=0.0498, test_loss=0.126]
Epoch 4/40: 100%|██████████| 189/189 [02:08<00:00,  1.47it/s, loss=0.0431, test_loss=0.119]
Epoch 5/40: 100%|██████████| 189/189 [02:09<00:00,  1.46it/s, loss=0.0391, test_loss=0.121]
Epoch 6/40: 100%|██████████| 189/189 [02:16<00:00,  1.39it/s, loss=0.0368, test_loss=0.119]
Epoch 7/40: 100%|██████████| 189/189 [02:16<00:00,  1.38it/s, loss=0.0352, test_loss=0.126]
Epoch 8/40: 100%|██████████| 189/189 [02:16<00:00,  1.38it/s, loss=0.0335, test_loss=0.135]
Epoch 9/40: 100%|██████████| 189/189 [02:12<00:00,  1.43it/s, loss=0.0313, test_loss=0.124]
Epoch 10/40: 100%|██████████| 189/189 [02:10<00:00,  1.45it/s, loss=0.0336, test_loss=0.167]


Macro: ER=0.94, F=0.04, LE=139.06, LR=0.21
Micro: ER=0.94, F=0.12, LE=45.27, LR=0.55


Epoch 11/40: 100%|██████████| 189/189 [02:09<00:00,  1.46it/s, loss=0.03, test_loss=0.119]
Epoch 12/40: 100%|██████████| 189/189 [02:08<00:00,  1.47it/s, loss=0.0245, test_loss=0.114]
Epoch 13/40: 100%|██████████| 189/189 [02:08<00:00,  1.47it/s, loss=0.0228, test_loss=0.113]
Epoch 14/40: 100%|██████████| 189/189 [02:10<00:00,  1.45it/s, loss=0.0216, test_loss=0.109]
Epoch 15/40: 100%|██████████| 189/189 [02:09<00:00,  1.46it/s, loss=0.0211, test_loss=0.117]
Epoch 16/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.0205, test_loss=0.111]
Epoch 17/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.0196, test_loss=0.111]
Epoch 18/40: 100%|██████████| 189/189 [02:02<00:00,  1.54it/s, loss=0.0191, test_loss=0.118]
Epoch 19/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.0185, test_loss=0.119]
Epoch 20/40: 100%|██████████| 189/189 [02:02<00:00,  1.54it/s, loss=0.018, test_loss=0.114]


Macro: ER=0.77, F=0.16, LE=108.17, LR=0.30
Micro: ER=0.77, F=0.37, LE=25.27, LR=0.69


Epoch 21/40: 100%|██████████| 189/189 [02:02<00:00,  1.54it/s, loss=0.0195, test_loss=0.118]
Epoch 22/40: 100%|██████████| 189/189 [02:02<00:00,  1.55it/s, loss=0.0276, test_loss=0.124]
Epoch 23/40: 100%|██████████| 189/189 [02:02<00:00,  1.55it/s, loss=0.0215, test_loss=0.114]
Epoch 24/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.0175, test_loss=0.115]
Epoch 25/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.0166, test_loss=0.113]
Epoch 26/40: 100%|██████████| 189/189 [02:04<00:00,  1.51it/s, loss=0.0156, test_loss=0.117]
Epoch 27/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.0139, test_loss=0.113]
Epoch 28/40: 100%|██████████| 189/189 [02:02<00:00,  1.54it/s, loss=0.0124, test_loss=0.114]
Epoch 29/40: 100%|██████████| 189/189 [02:02<00:00,  1.54it/s, loss=0.0114, test_loss=0.113]
Epoch 30/40: 100%|██████████| 189/189 [02:03<00:00,  1.52it/s, loss=0.00989, test_loss=0.115]


Macro: ER=0.76, F=0.15, LE=98.99, LR=0.29
Micro: ER=0.76, F=0.36, LE=27.69, LR=0.66


Epoch 31/40: 100%|██████████| 189/189 [02:02<00:00,  1.54it/s, loss=0.00983, test_loss=0.115]
Epoch 32/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.00944, test_loss=0.115]
Epoch 33/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.00834, test_loss=0.118]
Epoch 34/40: 100%|██████████| 189/189 [02:02<00:00,  1.55it/s, loss=0.0156, test_loss=0.144]
Epoch 35/40: 100%|██████████| 189/189 [02:04<00:00,  1.52it/s, loss=0.0154, test_loss=0.117]
Epoch 36/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.0087, test_loss=0.116]
Epoch 37/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.00795, test_loss=0.114]
Epoch 38/40: 100%|██████████| 189/189 [02:03<00:00,  1.53it/s, loss=0.00759, test_loss=0.116]
Epoch 39/40: 100%|██████████| 189/189 [02:02<00:00,  1.54it/s, loss=0.00746, test_loss=0.114]
Epoch 40/40: 100%|██████████| 189/189 [02:02<00:00,  1.54it/s, loss=0.00726, test_loss=0.114]

Macro: ER=0.78, F=0.16, LE=87.83, LR=0.31
Micro: ER=0.78, F=0.36, LE=26.20, LR=0.68
Macro @ epoch 39: ER=0.76, F=0.16, LE=87.55, LR=0.31
Micro @ epoch 39: ER=0.76, F=0.37, LE=26.66, LR=0.68





### All together

In [3]:
# "All together" run: Conformer backbone, augment ids 0-7 plus
# normalized=True requested from create_dataloaders; the trained weights
# are written to data/conformer_all.pth once training returns.
model_args = dict(
    backbone='conformer',
    num_classes=NUM_CLASSES,
    num_events=MAX_EVENTS,
    input_dim=7,
    hidden_dim=64,
    dropout=0.05,
)
train_dataloader, test_dataloader = create_dataloaders(
    BATCH_SIZE,
    augments=list(range(8)),  # same list as [0, 1, ..., 7]
    normalized=True,
)
# NOTE(review): sde_weight is 0.0 here, whereas the "Distance estimation"
# cell passes 0.5 -- confirm that 0.0 is what this run is meant to use.
model = train_model(
    model_args,
    train_dataloader,
    test_dataloader,
    device=device,
    epochs=40,
    sde_weight=0.0,
)
torch.save(
    model.state_dict(),
    'data/conformer_all.pth',
)

Epoch 1/40: 100%|██████████| 189/189 [03:16<00:00,  1.04s/it, loss=0.0235, test_loss=0.0272]
Epoch 2/40: 100%|██████████| 189/189 [03:20<00:00,  1.06s/it, loss=0.0138, test_loss=0.0234]
Epoch 3/40: 100%|██████████| 189/189 [03:21<00:00,  1.07s/it, loss=0.0101, test_loss=0.023]
Epoch 4/40: 100%|██████████| 189/189 [03:22<00:00,  1.07s/it, loss=0.00867, test_loss=0.0222]
Epoch 5/40: 100%|██████████| 189/189 [03:18<00:00,  1.05s/it, loss=0.0078, test_loss=0.0216]
Epoch 6/40: 100%|██████████| 189/189 [03:13<00:00,  1.03s/it, loss=0.00714, test_loss=0.0216]
Epoch 7/40: 100%|██████████| 189/189 [03:13<00:00,  1.02s/it, loss=0.00667, test_loss=0.0186]
Epoch 8/40: 100%|██████████| 189/189 [03:13<00:00,  1.03s/it, loss=0.00627, test_loss=0.0198]
Epoch 9/40: 100%|██████████| 189/189 [03:13<00:00,  1.02s/it, loss=0.00604, test_loss=0.0203]
Epoch 10/40: 100%|██████████| 189/189 [03:14<00:00,  1.03s/it, loss=0.0059, test_loss=0.0187]


Macro: ER=0.62, F=0.19, LE=106.03, LR=0.27
Micro: ER=0.62, F=0.46, LE=18.97, LR=0.62


Epoch 11/40: 100%|██████████| 189/189 [03:14<00:00,  1.03s/it, loss=0.00563, test_loss=0.0189]
Epoch 12/40: 100%|██████████| 189/189 [03:12<00:00,  1.02s/it, loss=0.00529, test_loss=0.02]
Epoch 13/40: 100%|██████████| 189/189 [03:13<00:00,  1.03s/it, loss=0.0051, test_loss=0.0184]
Epoch 14/40: 100%|██████████| 189/189 [03:13<00:00,  1.02s/it, loss=0.00493, test_loss=0.0181]
Epoch 15/40: 100%|██████████| 189/189 [03:19<00:00,  1.06s/it, loss=0.00476, test_loss=0.0198]
Epoch 16/40: 100%|██████████| 189/189 [03:23<00:00,  1.08s/it, loss=0.00453, test_loss=0.0194]
Epoch 17/40: 100%|██████████| 189/189 [03:24<00:00,  1.08s/it, loss=0.00443, test_loss=0.0211]
Epoch 18/40: 100%|██████████| 189/189 [03:19<00:00,  1.06s/it, loss=0.00406, test_loss=0.0189]
Epoch 19/40: 100%|██████████| 189/189 [03:15<00:00,  1.03s/it, loss=0.00387, test_loss=0.0195]
Epoch 20/40: 100%|██████████| 189/189 [03:16<00:00,  1.04s/it, loss=0.00377, test_loss=0.0201]


Macro: ER=0.68, F=0.20, LE=82.40, LR=0.34
Micro: ER=0.68, F=0.42, LE=21.65, LR=0.69


Epoch 21/40: 100%|██████████| 189/189 [03:14<00:00,  1.03s/it, loss=0.00355, test_loss=0.0195]
Epoch 22/40: 100%|██████████| 189/189 [03:15<00:00,  1.03s/it, loss=0.00426, test_loss=0.0201]
Epoch 23/40: 100%|██████████| 189/189 [03:23<00:00,  1.08s/it, loss=0.00345, test_loss=0.0193]
Epoch 24/40: 100%|██████████| 189/189 [03:23<00:00,  1.08s/it, loss=0.00321, test_loss=0.0204]
Epoch 25/40: 100%|██████████| 189/189 [03:24<00:00,  1.08s/it, loss=0.00291, test_loss=0.0214]
Epoch 26/40: 100%|██████████| 189/189 [02:58<00:00,  1.06it/s, loss=0.00282, test_loss=0.0215]
Epoch 27/40: 100%|██████████| 189/189 [02:16<00:00,  1.38it/s, loss=0.00279, test_loss=0.0223]
Epoch 28/40: 100%|██████████| 189/189 [02:19<00:00,  1.36it/s, loss=0.0026, test_loss=0.0222]
Epoch 29/40: 100%|██████████| 189/189 [02:16<00:00,  1.38it/s, loss=0.00243, test_loss=0.0226]
Epoch 30/40: 100%|██████████| 189/189 [02:19<00:00,  1.36it/s, loss=0.00252, test_loss=0.0241]


Macro: ER=0.78, F=0.20, LE=47.57, LR=0.43
Micro: ER=0.78, F=0.36, LE=23.62, LR=0.67


Epoch 31/40: 100%|██████████| 189/189 [02:18<00:00,  1.36it/s, loss=0.00271, test_loss=0.0226]
Epoch 32/40: 100%|██████████| 189/189 [02:19<00:00,  1.36it/s, loss=0.00214, test_loss=0.0237]
Epoch 33/40: 100%|██████████| 189/189 [02:15<00:00,  1.40it/s, loss=0.00192, test_loss=0.0224]
Epoch 34/40: 100%|██████████| 189/189 [02:17<00:00,  1.37it/s, loss=0.00167, test_loss=0.0225]
Epoch 35/40: 100%|██████████| 189/189 [02:17<00:00,  1.37it/s, loss=0.00157, test_loss=0.0227]
Epoch 36/40: 100%|██████████| 189/189 [02:17<00:00,  1.37it/s, loss=0.00282, test_loss=0.0246]
Epoch 37/40: 100%|██████████| 189/189 [02:18<00:00,  1.37it/s, loss=0.00238, test_loss=0.0223]
Epoch 38/40: 100%|██████████| 189/189 [02:17<00:00,  1.38it/s, loss=0.0015, test_loss=0.0226]
Epoch 39/40: 100%|██████████| 189/189 [02:16<00:00,  1.39it/s, loss=0.00125, test_loss=0.022]
Epoch 40/40: 100%|██████████| 189/189 [02:19<00:00,  1.35it/s, loss=0.00145, test_loss=0.0232]

Macro: ER=0.76, F=0.23, LE=48.76, LR=0.51
Micro: ER=0.76, F=0.42, LE=21.75, LR=0.71
Macro @ epoch 39: ER=0.73, F=0.26, LE=45.73, LR=0.51
Micro @ epoch 39: ER=0.73, F=0.44, LE=20.86, LR=0.71



