In [1]:
import models.models_multi_task as md_multi
from models.multitask_training_session import TrainingSession
import datasets.iemocap as ds
from constants import *
from torchsummary import summary
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import pandas as pd 

# Проба многозадачного обучения IEMOCAP

## Модель:

Подготовил следующую архитектуру, основанную на однозадачной AlexNet-like архитектуре: 

In [2]:
model = md_multi.AlexNetMultiTask(num_emotions=4, num_speakers=10, num_genders=2)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
device = torch.device("cpu")
model = model.to(device)
print(model)
summary(model, (1, 224, 224), batch_size=32,
        device='cpu'
       )
model=None
torch.cuda.empty_cache()

AlexNetMultiTask(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(5, 5))
  (joint_classifier): Sequential(
    (0): Linear(in_features=6400, out_features=2048, bias=True)
    (1): Dropout(p=0.5, i

## Проверяю модель и свой код для обучения на работоспособность

Мой код для обучения построен вокруг класса TrainingSession в модуле multitask_training_session.py <br>
У этого класса есть метод overfit_one_batch, этот метод позволяет обучить модель на малой выборке данных, это полезно для того, чтобы проверить на работоспособность свой код. <br>
На этой стадии я убрал из модели регуляризацию.

In [3]:
train_ds = ds.IemocapDataset(  # Без препроцессинга, тренировочный
    PATH_TO_PICKLE, IEMOCAP_PATH_TO_WAVS, IEMOCAP_PATH_TO_EGEMAPS, IEMOCAP_PATH_FOR_PARSER, 
    base_name='IEMOCAP-4', label_type='four', mode='train', preprocessing=False, 
    augmentation=False, padding='repeat', spectrogram_shape=224, spectrogram_type='melspec', tasks=('emotion', 'speaker', 'gender') 
)
test_ds = ds.IemocapDataset(  # Без препроцессинга, тестовый
    PATH_TO_PICKLE, IEMOCAP_PATH_TO_WAVS, IEMOCAP_PATH_TO_EGEMAPS, IEMOCAP_PATH_FOR_PARSER, 
    base_name='IEMOCAP-4', label_type='four', mode='test', preprocessing=False, 
    augmentation=False, padding='repeat', spectrogram_shape=224, spectrogram_type='melspec', tasks=('emotion', 'speaker', 'gender') 
)



In [4]:
model = md_multi.AlexNetMultiTask(4, 10, 2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
# device = torch.device("cpu") 
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [5]:
# Инициализируем экземпляр класса TrainingSession
ts = TrainingSession(name='FirstTry',
                      model=model,
                      train_dataset=train_ds,
                      test_dataset=test_ds,
                      criterion=criterion,
                      optimizer=optimizer,
                      num_epochs=200,
                      batch_size=32,
                      device=device,
                     path_to_weights=WEIGHTS_FOLDER,
                     path_to_results=RESULTS_FOLDER)

INITIALIZING TRAINING SESSION...
Loaders ready
TRAINING SESSION FirstTry__IEMOCAP-4_four_prep-false_224_train INITIALIZED
Trying to load checkpoint from file
Found file
Loading file models\training_sessions\FirstTry__IEMOCAP-4_four_prep-false_224_train.pt
Updating model...
Updating optimizer...
Success!


In [9]:
ts.overfit_one_batch(num_epochs=5, batch_size=32)  # 5 эпох на одном батче размером 32, в одной эпохе 50 одинаковых батчей

TRAIN SIZE 2858
TEST SIZE 714
Epoch 0
Emotion loss | Speaker loss | Gender loss | Total loss
0: 1.4414201974868774 | 2.3414275646209717 | 0.44765982031822205 | 4.230507850646973
1: 0.9333392381668091 | 1.9413270950317383 | 0.43485891819000244 | 3.3095250129699707
2: 0.8018255233764648 | 1.745065689086914 | 0.32746991515159607 | 2.874361038208008
3: 0.8215715289115906 | 1.4112733602523804 | 0.10227905958890915 | 2.3351237773895264
4: 0.5799142122268677 | 1.30913245677948 | 0.06190920248627663 | 1.950955867767334
5: 0.6207345724105835 | 1.1703191995620728 | 0.047992024570703506 | 1.839045763015747
6: 0.5136659741401672 | 1.0797524452209473 | 0.028099965304136276 | 1.6215183734893799
7: 0.46535301208496094 | 1.0105384588241577 | 0.07338234037160873 | 1.5492738485336304
8: 0.36521920561790466 | 1.0084718465805054 | 0.014499970711767673 | 1.3881911039352417
9: 0.3800300657749176 | 0.7452290654182434 | 0.012217322364449501 | 1.1374764442443848
10: 0.21810980141162872 | 0.6663102507591248 | 0

36: 0.012218687683343887 | 0.03701668605208397 | 0.12943458557128906 | 0.17866995930671692
37: 0.0037573757581412792 | 0.366527259349823 | 1.3485373528965283e-06 | 0.3702859878540039
38: 0.09070053696632385 | 0.0444038026034832 | 0.00030247398535721004 | 0.13540682196617126
39: 0.0009722106624394655 | 0.011130941100418568 | 0.011927219107747078 | 0.024030370637774467
40: 0.013412228785455227 | 0.04771655797958374 | 0.0013015508884564042 | 0.06243033707141876
41: 0.005402492359280586 | 0.012367974035441875 | 0.00048231796245090663 | 0.01825278252363205
42: 0.005212982185184956 | 0.0824459046125412 | 0.01985791139304638 | 0.1075168028473854
43: 0.03840407729148865 | 0.08887898176908493 | 0.009803637862205505 | 0.13708670437335968
44: 0.11791308224201202 | 0.05791176110506058 | 0.003193353768438101 | 0.17901819944381714
45: 0.022144043818116188 | 0.01419439073652029 | 4.3993472900183406e-06 | 0.03634283319115639
46: 0.017170283943414688 | 0.016948655247688293 | 0.0006145340739749372 | 0.0

16: 0.005717601161450148 | 0.0010633096098899841 | 1.110121161218558e-06 | 0.006782020907849073
17: 0.0003339503309689462 | 0.002387040061876178 | 8.940686058167557e-08 | 0.002721079858019948
18: 0.0005621638265438378 | 0.007715967018157244 | 2.2351736461700966e-08 | 0.008278152905404568
19: 0.0858149379491806 | 0.0020198433194309473 | 6.183922778291162e-07 | 0.0878354012966156
20: 0.007901272736489773 | 0.0037698964588344097 | 8.456351565655496e-07 | 0.011672014370560646
21: 0.00012759340461343527 | 0.0006163700018078089 | 0.00032811093842610717 | 0.0010720742866396904
22: 7.725906471023336e-05 | 0.0027982445899397135 | 6.46609187242575e-05 | 0.0029401646461337805
23: 0.04838520288467407 | 0.062128085643053055 | 0.0020728365052491426 | 0.1125861182808876
24: 0.0006428712513297796 | 0.01046313438564539 | 2.188776488765143e-05 | 0.011127893812954426
25: 0.008424758911132812 | 0.023545067757368088 | 1.0765209481178317e-05 | 0.031980592757463455
26: 0.00012566197256091982 | 0.003842133795

# Time passed: 7 s
# Epoch losses | emotion = 0.0288 | speaker = 0.0210 | gender = 0.0161 |
# Train accuracies | emotion = 0.99625 | speaker = 0.995 | gender = 0.9975 |


Модель переобучается, а значит наш код работает нормально. Можно попробовать обучить модель на всем датасете. 

## Обучаем модель на IEMOCAP

Для этого у класса TrainingSession есть метод execute().  <br>
Он обучает модель, сохраняет ее после каждой эпохи, а так же копию модели в случае, если на этой эпохе был достигнут лучший результат. На этом этапе регуляризация включена.

In [10]:
ts.execute()

Epoch #1
# Time passed: 68 s
# Epoch losses | emotion = 1.3595 | speaker = 2.4053 | gender = 0.7198 | total = 16019.1389 |
# Train accuracies | emotion = 0.4008958566629339 | speaker = 0.1187010078387458 | gender = 0.5663493840985442 |
# Validation process on validation set
# Validation losses | emotion = 0.0371 | speaker = 0.0726 | gender = 0.0213 | total = 7401.1727 |
# Validation accuracies | emotion = 0.4322508398656215 | speaker = 0.14669652855543114 | gender = 0.5856662933930571 |
# Saving checkpoint...
## Saving best model_alex
# Done and done!
Epoch #2
# Time passed: 66 s
# Epoch losses | emotion = 1.2164 | speaker = 2.2841 | gender = 0.6160 | total = 14704.1351 |
# Train accuracies | emotion = 0.4209126539753639 | speaker = 0.12234042553191489 | gender = 0.618421052631579 |
# Validation process on validation set
# Validation losses | emotion = 0.0356 | speaker = 0.0739 | gender = 0.0177 | total = 7132.9266 |
# Validation accuracies | emotion = 0.4176931690929451 | speaker = 0.

# Time passed: 60 s
# Epoch losses | emotion = 0.9956 | speaker = 1.9256 | gender = 0.1356 | total = 10918.8157 |
# Train accuracies | emotion = 0.5290753479443289 | speaker = 0.18267077267637177 | gender = 0.8722804351303791 |
# Validation process on validation set
# Validation losses | emotion = 0.0309 | speaker = 0.0647 | gender = 0.0089 | total = 5395.1400 |
# Validation accuracies | emotion = 0.606942889137738 | speaker = 0.3023516237402016 | gender = 0.9384098544232923 |
# Saving checkpoint...
## Saving best model_alex
# Done and done!
Epoch #15
# Time passed: 63 s
# Epoch losses | emotion = 1.0048 | speaker = 1.8686 | gender = 0.1515 | total = 10805.0271 |
# Train accuracies | emotion = 0.5326241134751774 | speaker = 0.19010824934677117 | gender = 0.8770250093318402 |
# Validation process on validation set
# Validation losses | emotion = 0.0255 | speaker = 0.0617 | gender = 0.0066 | total = 5212.1861 |
# Validation accuracies | emotion = 0.606942889137738 | speaker = 0.318029115

# Validation losses | emotion = 0.0277 | speaker = 0.0397 | gender = 0.0014 | total = 4642.8207 |
# Validation accuracies | emotion = 0.6013437849944009 | speaker = 0.41993281075027994 | gender = 0.9406494960806271 |
# Saving checkpoint...
# Done and done!
Epoch #28
# Time passed: 53 s
# Epoch losses | emotion = 0.9255 | speaker = 1.3175 | gender = 0.0967 | total = 8357.1652 |
# Train accuracies | emotion = 0.5692889137737962 | speaker = 0.294432890737482 | gender = 0.9148636218205087 |
# Validation process on validation set
# Validation losses | emotion = 0.0356 | speaker = 0.0522 | gender = 0.0060 | total = 4493.1566 |
# Validation accuracies | emotion = 0.6349384098544233 | speaker = 0.46472564389697646 | gender = 0.9552071668533034 |
# Saving checkpoint...
# Done and done!
Epoch #29
# Time passed: 50 s
# Epoch losses | emotion = 0.9132 | speaker = 1.2896 | gender = 0.0889 | total = 8185.9788 |
# Train accuracies | emotion = 0.571485114105881 | speaker = 0.30171448430320114 | gender

# Time passed: 50 s
# Epoch losses | emotion = 0.8451 | speaker = 1.0786 | gender = 0.0720 | total = 7128.6494 |
# Train accuracies | emotion = 0.5925354382323218 | speaker = 0.37378799879824104 | gender = 0.9319367437795318 |
# Validation process on validation set
# Validation losses | emotion = 0.0316 | speaker = 0.0323 | gender = 0.0159 | total = 4060.5910 |
# Validation accuracies | emotion = 0.6506159014557671 | speaker = 0.5531914893617021 | gender = 0.9540873460246361 |
# Saving checkpoint...
# Done and done!
Epoch #42
# Time passed: 50 s
# Epoch losses | emotion = 0.8552 | speaker = 1.0710 | gender = 0.0716 | total = 7136.1463 |
# Train accuracies | emotion = 0.5938649816029435 | speaker = 0.3787127392950461 | gender = 0.9329107342825148 |
# Validation process on validation set
# Validation losses | emotion = 0.0177 | speaker = 0.0402 | gender = 0.0066 | total = 4159.9014 |
# Validation accuracies | emotion = 0.6427771556550952 | speaker = 0.5464725643896976 | gender = 0.947368

# Done and done!
Epoch #55
# Time passed: 49 s
# Epoch losses | emotion = 0.7826 | speaker = 0.9457 | gender = 0.0625 | total = 6396.6078 |
# Train accuracies | emotion = 0.6123282093046931 | speaker = 0.43357426448131936 | gender = 0.942502290542604 |
# Validation process on validation set
# Validation losses | emotion = 0.0346 | speaker = 0.0322 | gender = 0.0056 | total = 4022.4774 |
# Validation accuracies | emotion = 0.6349384098544233 | speaker = 0.5666293393057111 | gender = 0.9652855543113102 |
# Saving checkpoint...
# Done and done!
Epoch #56
# Time passed: 50 s
# Epoch losses | emotion = 0.7734 | speaker = 0.9358 | gender = 0.0637 | total = 6332.6761 |
# Train accuracies | emotion = 0.6136018237082067 | speaker = 0.43717505199168133 | gender = 0.9430741081426972 |
# Validation process on validation set
# Validation losses | emotion = 0.0252 | speaker = 0.0387 | gender = 0.0022 | total = 4041.9369 |
# Validation accuracies | emotion = 0.6528555431131019 | speaker = 0.574468085

# Done and done!
Epoch #69
# Time passed: 60 s
# Epoch losses | emotion = 0.7293 | speaker = 0.8020 | gender = 0.0528 | total = 5658.1953 |
# Train accuracies | emotion = 0.629854585585147 | speaker = 0.4807155492802311 | gender = 0.9499732216758362 |
# Validation process on validation set
# Validation losses | emotion = 0.0301 | speaker = 0.0358 | gender = 0.0179 | total = 4090.4267 |
# Validation accuracies | emotion = 0.6338185890257558 | speaker = 0.5621500559910414 | gender = 0.9507278835386338 |
# Saving checkpoint...
# Done and done!
Epoch #70
# Time passed: 61 s
# Epoch losses | emotion = 0.7049 | speaker = 0.8028 | gender = 0.0546 | total = 5580.5698 |
# Train accuracies | emotion = 0.6310990241561351 | speaker = 0.48376259798432253 | gender = 0.9503799392097264 |
# Validation process on validation set
# Validation losses | emotion = 0.0369 | speaker = 0.0370 | gender = 0.0094 | total = 4085.2251 |
# Validation accuracies | emotion = 0.6371780515117581 | speaker = 0.5946248600

# Done and done!
Epoch #83
# Time passed: 61 s
# Epoch losses | emotion = 0.6674 | speaker = 0.7450 | gender = 0.0449 | total = 5205.6116 |
# Train accuracies | emotion = 0.6457352365790149 | speaker = 0.5191583804422618 | gender = 0.9551599454930585 |
# Validation process on validation set
# Validation losses | emotion = 0.0274 | speaker = 0.0347 | gender = 0.0253 | total = 4431.4551 |
# Validation accuracies | emotion = 0.6472564389697648 | speaker = 0.574468085106383 | gender = 0.948488241881299 |
# Saving checkpoint...
# Done and done!
Epoch #84
# Time passed: 61 s
# Epoch losses | emotion = 0.6745 | speaker = 0.7326 | gender = 0.0411 | total = 5173.0467 |
# Train accuracies | emotion = 0.6468265077587586 | speaker = 0.5215565509518477 | gender = 0.9555271156614942 |
# Validation process on validation set
# Validation losses | emotion = 0.0323 | speaker = 0.0384 | gender = 0.0077 | total = 4479.6075 |
# Validation accuracies | emotion = 0.6349384098544233 | speaker = 0.557670772676

# Done and done!
Epoch #97
# Time passed: 60 s
# Epoch losses | emotion = 0.6085 | speaker = 0.6316 | gender = 0.0422 | total = 4580.3048 |
# Train accuracies | emotion = 0.6609309520785953 | speaker = 0.550917214070491 | gender = 0.9594699899562462 |
# Validation process on validation set
# Validation losses | emotion = 0.0213 | speaker = 0.0500 | gender = 0.0053 | total = 4519.2555 |
# Validation accuracies | emotion = 0.6405375139977604 | speaker = 0.5722284434490481 | gender = 0.9451287793952967 |
# Saving checkpoint...
# Done and done!
Epoch #98
# Time passed: 58 s
# Epoch losses | emotion = 0.6135 | speaker = 0.6951 | gender = 0.0582 | total = 4882.4820 |
# Train accuracies | emotion = 0.661976940832324 | speaker = 0.5528344036382751 | gender = 0.9596978769111227 |
# Validation process on validation set
# Validation losses | emotion = 0.0324 | speaker = 0.0395 | gender = 0.0003 | total = 4662.7024 |
# Validation accuracies | emotion = 0.6338185890257558 | speaker = 0.556550951847

In [6]:
ts.create_results_file()

In [18]:
results = pd.read_csv(os.path.join(ts.path_to_results, f'{ts.name}_results.csv'), delimiter=';')
results_test = results.loc[results['subset']=='test']
results_test_accuracy = results_test.loc[results_test['metric']=='accuracy']
results_test_accuracy.drop('epoch', axis=1, inplace=True)
results_test_accuracy.groupby('task').max()

Unnamed: 0_level_0,result,subset,metric
task,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
emotion,0.653975,test,accuracy
gender,0.965286,test,accuracy
speaker,0.603583,test,accuracy


### К сожалению, при обычном суммировании потерь, модель не опережает по точности однозадачную. Нужно пробовать взвешенную сумму потерь.

В класс TrainingSession я добавил параметр loss_weighter. Этот параметр принимает функцию от отдельных потерь, и возвращает вычисленную определенным образом сумму. В файле multitask_training_session.py на данный момент есть следующие функции для взвешивания потерь: <br>
1. Невзвешенная сумма unweighted_sum(loss_1, loss_2, loss_3). Параметр по умолчанию.
2. Усредненная сумма averaged_sum(loss_1, loss_2, loss_3). Подсчитывает итоговый лосс по формуле loss = a x loss_1 + b x loss_2 + c x loss_3, где a = loss_1 / (loss_1 + loss_2 + loss_3), b и с - аналогично. 
3. Автоматическое взвешенное суммивание функций потерь с помощью оценки алеаторической неопределенности модели по отношению к отдельной задаче. Описан в классе AutomaticWeightedLoss(nn.Module), добро пожаловать в исходный код. <br> Является имплементацией метода, описанного в статье [Liebel L, Körner M. Auxiliary tasks in multi-task learning[J]. arXiv preprint arXiv:1805.06334, 2018.], https://arxiv.org/pdf/1805.06334.pdf (это более практическая статья)<br> Метод является небольшим усовершенствованием другого метода, описанного в статье [Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics], https://openaccess.thecvf.com/content_cvpr_2018/html/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.html (в этой статье описана вся теория метода)