In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import torch.nn as nn

In [3]:
import sys
sys.path.append('./data')
sys.path.append('./pytorch')

In [4]:
from data import data
from pytorch import torch_models, radam

Using TensorFlow backend.


In [5]:
IMG_LEN = 1024
TXT_LEN = 300
N_CLASSES = 50

In [6]:
x_img, x_txt, y = data.get_unpacked_data()

In [7]:
x_img_train, x_img_test, x_txt_train, x_txt_test, y_train, y_test = train_test_split(
    x_img, 
    x_txt, 
    y, 
    test_size=0.2, 
    random_state=42,
    stratify=y
)

x_img_train, x_img_val, x_txt_train, x_txt_val, y_train, y_val = train_test_split(
    x_img_train,
    x_txt_train,
    y_train,
    test_size=0.2,
    random_state=42,
    stratify=y_train
)

In [8]:
img_sscaler = StandardScaler()
img_sscaler.fit(x_img_train)

x_img_train = img_sscaler.transform(x_img_train)
x_img_val = img_sscaler.transform(x_img_val)
x_img_test = img_sscaler.transform(x_img_test)

txt_sscaler = StandardScaler()
txt_sscaler.fit(x_txt_train)

x_txt_train = txt_sscaler.transform(x_txt_train)
x_txt_val = txt_sscaler.transform(x_txt_val)
x_txt_test = txt_sscaler.transform(x_txt_test)

In [9]:
x_img_train_t = torch.tensor(x_img_train).float()
x_img_val_t = torch.tensor(x_img_val).float()
x_img_test_t = torch.tensor(x_img_test).float()

x_txt_train_t = torch.tensor(x_txt_train).float()
x_txt_val_t = torch.tensor(x_txt_val).float()
x_txt_test_t = torch.tensor(x_txt_test).float()

y_train_t = torch.tensor(y_train).float()
y_val_t = torch.tensor(y_val).float()
y_test_t = torch.tensor(y_test).float()

In [10]:
train_ds = TensorDataset(x_img_train_t, x_txt_train_t, y_train_t)
val_ds = TensorDataset(x_img_val_t, x_txt_val_t, y_val_t)
test_ds = TensorDataset(x_img_test_t, x_txt_test_t, y_test_t)

In [33]:
# before 14.03.2020 it was 512
# experiments marked bs*number* is on batch_size == number
# otherwise batch_size == 512
BATCH_SIZE = 2048

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

In [12]:
model = torch_models.NormModelTrident(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bs2048_rs42_d128_wd0005_drop_05_te')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=60,
    writer=writer,
    train_loader=train_loader,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(2.1974, grad_fn=<DivBackward0>) average train loss tensor(2.7907, grad_fn=<DivBackward0>)
avg common loss: tensor(2.8358, grad_fn=<DivBackward0>) avg img loss: tensor(2.6939, grad_fn=<DivBackward0>) avg txt loss: tensor(2.8423, grad_fn=<DivBackward0>)
val common acc: 0.5197645327446652 val img acc: 0.4095364238410596 val txt acc: 0.45359823399558497 val_avg_loss: tensor(2.0510)
avg common val loss: tensor(1.7943) avg img val loss: tensor(2.2408) avg txt val loss: tensor(2.1179)
epoch: 1 train_loss: tensor(2.0479, grad_fn=<DivBackward0>) average train loss tensor(2.1448, grad_fn=<DivBackward0>)
avg common loss: tensor(1.8956, grad_fn=<DivBackward0>) avg img loss: tensor(2.3155, grad_fn=<DivBackward0>) avg txt loss: tensor(2.2233, grad_fn=<DivBackward0>)
val common acc: 0.5838704930095658 val img acc: 0.42590139808682853 val txt acc: 0.47602649006622516 val_avg_loss: tensor(1.9120)
avg common val loss: tensor(1.5607) avg img val loss: tensor(2.1630) avg txt va

val common acc: 0.6226931567328918 val img acc: 0.4503605592347314 val txt acc: 0.496747608535688 val_avg_loss: tensor(1.7855)
avg common val loss: tensor(1.3831) avg img val loss: tensor(2.0605) avg txt val loss: tensor(1.9130)
epoch: 17 train_loss: tensor(1.7370, grad_fn=<DivBackward0>) average train loss tensor(1.8817, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5410, grad_fn=<DivBackward0>) avg img loss: tensor(2.1048, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9993, grad_fn=<DivBackward0>)
val common acc: 0.6235761589403973 val img acc: 0.4522737306843267 val txt acc: 0.49686534216335543 val_avg_loss: tensor(1.7836)
avg common val loss: tensor(1.3830) avg img val loss: tensor(2.0568) avg txt val loss: tensor(1.9110)
epoch: 18 train_loss: tensor(1.7441, grad_fn=<DivBackward0>) average train loss tensor(1.8776, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5367, grad_fn=<DivBackward0>) avg img loss: tensor(2.1001, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9960, gr

epoch: 33 train_loss: tensor(1.6685, grad_fn=<DivBackward0>) average train loss tensor(1.8496, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5014, grad_fn=<DivBackward0>) avg img loss: tensor(2.0748, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9727, grad_fn=<DivBackward0>)
val common acc: 0.6257836644591611 val img acc: 0.4544812362030905 val txt acc: 0.500485651214128 val_avg_loss: tensor(1.7706)
avg common val loss: tensor(1.3661) avg img val loss: tensor(2.0457) avg txt val loss: tensor(1.9000)
epoch: 34 train_loss: tensor(1.6918, grad_fn=<DivBackward0>) average train loss tensor(1.8475, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4987, grad_fn=<DivBackward0>) avg img loss: tensor(2.0739, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9699, grad_fn=<DivBackward0>)
val common acc: 0.6263723325974981 val img acc: 0.4560117733627667 val txt acc: 0.5000441501103753 val_avg_loss: tensor(1.7682)
avg common val loss: tensor(1.3614) avg img val loss: tensor(2.0443) avg txt val 

val common acc: 0.6274025018395879 val img acc: 0.4544223693892568 val txt acc: 0.49983811626195734 val_avg_loss: tensor(1.7640)
avg common val loss: tensor(1.3571) avg img val loss: tensor(2.0384) avg txt val loss: tensor(1.8964)
epoch: 50 train_loss: tensor(1.6286, grad_fn=<DivBackward0>) average train loss tensor(1.8327, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4802, grad_fn=<DivBackward0>) avg img loss: tensor(2.0617, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9562, grad_fn=<DivBackward0>)
val common acc: 0.6279911699779249 val img acc: 0.4549227373068433 val txt acc: 0.5015746872700515 val_avg_loss: tensor(1.7630)
avg common val loss: tensor(1.3568) avg img val loss: tensor(2.0378) avg txt val loss: tensor(1.8944)
epoch: 51 train_loss: tensor(1.6516, grad_fn=<DivBackward0>) average train loss tensor(1.8336, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4811, grad_fn=<DivBackward0>) avg img loss: tensor(2.0622, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9574, g

In [13]:
model = torch_models.NormModelTrident(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bs2048_rs42_d128_wd0005_drop_05_100')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=100,
    writer=writer,
    train_loader=train_loader,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(2.2356, grad_fn=<DivBackward0>) average train loss tensor(2.7814, grad_fn=<DivBackward0>)
avg common loss: tensor(2.8003, grad_fn=<DivBackward0>) avg img loss: tensor(2.7005, grad_fn=<DivBackward0>) avg txt loss: tensor(2.8435, grad_fn=<DivBackward0>)
val common acc: 0.5174098601913172 val img acc: 0.41539367181751286 val txt acc: 0.4542457689477557 val_avg_loss: tensor(2.0460)
avg common val loss: tensor(1.7837) avg img val loss: tensor(2.2355) avg txt val loss: tensor(2.1187)
epoch: 1 train_loss: tensor(2.0575, grad_fn=<DivBackward0>) average train loss tensor(2.1415, grad_fn=<DivBackward0>)
avg common loss: tensor(1.8901, grad_fn=<DivBackward0>) avg img loss: tensor(2.3134, grad_fn=<DivBackward0>) avg txt loss: tensor(2.2210, grad_fn=<DivBackward0>)
val common acc: 0.5812509197939661 val img acc: 0.4293745401030169 val txt acc: 0.4767623252391464 val_avg_loss: tensor(1.9108)
avg common val loss: tensor(1.5600) avg img val loss: tensor(2.1621) avg txt val 

val common acc: 0.6226342899190581 val img acc: 0.45100809418690213 val txt acc: 0.49892568064753495 val_avg_loss: tensor(1.7844)
avg common val loss: tensor(1.3849) avg img val loss: tensor(2.0584) avg txt val loss: tensor(1.9099)
epoch: 17 train_loss: tensor(1.7406, grad_fn=<DivBackward0>) average train loss tensor(1.8783, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5383, grad_fn=<DivBackward0>) avg img loss: tensor(2.1032, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9934, grad_fn=<DivBackward0>)
val common acc: 0.6245180279617366 val img acc: 0.4519793966151582 val txt acc: 0.49969094922737306 val_avg_loss: tensor(1.7833)
avg common val loss: tensor(1.3829) avg img val loss: tensor(2.0585) avg txt val loss: tensor(1.9086)
epoch: 18 train_loss: tensor(1.7108, grad_fn=<DivBackward0>) average train loss tensor(1.8760, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5325, grad_fn=<DivBackward0>) avg img loss: tensor(2.1009, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9945,

epoch: 33 train_loss: tensor(1.6737, grad_fn=<DivBackward0>) average train loss tensor(1.8483, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4985, grad_fn=<DivBackward0>) avg img loss: tensor(2.0771, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9692, grad_fn=<DivBackward0>)
val common acc: 0.6246651949963208 val img acc: 0.4541280353200883 val txt acc: 0.5005739514348786 val_avg_loss: tensor(1.7697)
avg common val loss: tensor(1.3636) avg img val loss: tensor(2.0465) avg txt val loss: tensor(1.8989)
epoch: 34 train_loss: tensor(1.6784, grad_fn=<DivBackward0>) average train loss tensor(1.8478, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5001, grad_fn=<DivBackward0>) avg img loss: tensor(2.0764, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9669, grad_fn=<DivBackward0>)
val common acc: 0.6256953642384105 val img acc: 0.4543046357615894 val txt acc: 0.500868285504047 val_avg_loss: tensor(1.7692)
avg common val loss: tensor(1.3640) avg img val loss: tensor(2.0457) avg txt val 

val common acc: 0.6266666666666667 val img acc: 0.4546284032376748 val txt acc: 0.5027520235467255 val_avg_loss: tensor(1.7631)
avg common val loss: tensor(1.3545) avg img val loss: tensor(2.0419) avg txt val loss: tensor(1.8928)
epoch: 50 train_loss: tensor(1.6489, grad_fn=<DivBackward0>) average train loss tensor(1.8331, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4781, grad_fn=<DivBackward0>) avg img loss: tensor(2.0605, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9607, grad_fn=<DivBackward0>)
val common acc: 0.6276379690949228 val img acc: 0.45559970566593083 val txt acc: 0.5032818248712289 val_avg_loss: tensor(1.7624)
avg common val loss: tensor(1.3537) avg img val loss: tensor(2.0414) avg txt val loss: tensor(1.8920)
epoch: 51 train_loss: tensor(1.6474, grad_fn=<DivBackward0>) average train loss tensor(1.8348, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4840, grad_fn=<DivBackward0>) avg img loss: tensor(2.0610, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9594, g

epoch: 66 train_loss: tensor(1.6402, grad_fn=<DivBackward0>) average train loss tensor(1.8268, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4729, grad_fn=<DivBackward0>) avg img loss: tensor(2.0550, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9525, grad_fn=<DivBackward0>)
val common acc: 0.6278145695364239 val img acc: 0.45480500367917587 val txt acc: 0.5033112582781457 val_avg_loss: tensor(1.7574)
avg common val loss: tensor(1.3464) avg img val loss: tensor(2.0360) avg txt val loss: tensor(1.8897)
epoch: 67 train_loss: tensor(1.6426, grad_fn=<DivBackward0>) average train loss tensor(1.8239, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4675, grad_fn=<DivBackward0>) avg img loss: tensor(2.0509, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9532, grad_fn=<DivBackward0>)
val common acc: 0.6295217071376011 val img acc: 0.45689477557027225 val txt acc: 0.5023988226637234 val_avg_loss: tensor(1.7590)
avg common val loss: tensor(1.3501) avg img val loss: tensor(2.0368) avg txt v

val common acc: 0.628962472406181 val img acc: 0.45627667402501837 val txt acc: 0.5028108903605593 val_avg_loss: tensor(1.7580)
avg common val loss: tensor(1.3489) avg img val loss: tensor(2.0367) avg txt val loss: tensor(1.8884)
epoch: 83 train_loss: tensor(1.6523, grad_fn=<DivBackward0>) average train loss tensor(1.8206, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4640, grad_fn=<DivBackward0>) avg img loss: tensor(2.0499, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9479, grad_fn=<DivBackward0>)
val common acc: 0.6293156732891833 val img acc: 0.45695364238410596 val txt acc: 0.5025754231052244 val_avg_loss: tensor(1.7602)
avg common val loss: tensor(1.3524) avg img val loss: tensor(2.0403) avg txt val loss: tensor(1.8880)
epoch: 84 train_loss: tensor(1.6236, grad_fn=<DivBackward0>) average train loss tensor(1.8166, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4587, grad_fn=<DivBackward0>) avg img loss: tensor(2.0457, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9452, g

epoch: 99 train_loss: tensor(1.6396, grad_fn=<DivBackward0>) average train loss tensor(1.8143, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4541, grad_fn=<DivBackward0>) avg img loss: tensor(2.0423, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9465, grad_fn=<DivBackward0>)
val common acc: 0.6315231788079471 val img acc: 0.45521707137601175 val txt acc: 0.5026931567328918 val_avg_loss: tensor(1.7555)
avg common val loss: tensor(1.3440) avg img val loss: tensor(2.0359) avg txt val loss: tensor(1.8866)


In [12]:
model = torch_models.NormModelTridentBN(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bn_bs2048_rs42_d128_wd0005_drop_05_100')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=100,
    writer=writer,
    train_loader=train_loader,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(2.1816, grad_fn=<DivBackward0>) average train loss tensor(2.6583, grad_fn=<DivBackward0>)
avg common loss: tensor(2.6052, grad_fn=<DivBackward0>) avg img loss: tensor(2.6679, grad_fn=<DivBackward0>) avg txt loss: tensor(2.7018, grad_fn=<DivBackward0>)
val common acc: 0.5574098601913171 val img acc: 0.41845474613686534 val txt acc: 0.47049300956585727 val_avg_loss: tensor(2.0080)
avg common val loss: tensor(1.6924) avg img val loss: tensor(2.2498) avg txt val loss: tensor(2.0820)
epoch: 1 train_loss: tensor(1.9992, grad_fn=<DivBackward0>) average train loss tensor(2.0952, grad_fn=<DivBackward0>)
avg common loss: tensor(1.7858, grad_fn=<DivBackward0>) avg img loss: tensor(2.3150, grad_fn=<DivBackward0>) avg txt loss: tensor(2.1848, grad_fn=<DivBackward0>)
val common acc: 0.6010890360559235 val img acc: 0.4314054451802796 val txt acc: 0.4853568800588668 val_avg_loss: tensor(1.8835)
avg common val loss: tensor(1.4962) avg img val loss: tensor(2.1614) avg txt val

val common acc: 0.632317880794702 val img acc: 0.4512141280353201 val txt acc: 0.502869757174393 val_avg_loss: tensor(1.7514)
avg common val loss: tensor(1.3321) avg img val loss: tensor(2.0391) avg txt val loss: tensor(1.8830)
epoch: 17 train_loss: tensor(1.6046, grad_fn=<DivBackward0>) average train loss tensor(1.7855, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3926, grad_fn=<DivBackward0>) avg img loss: tensor(2.0440, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9198, grad_fn=<DivBackward0>)
val common acc: 0.6329948491537896 val img acc: 0.4539514348785872 val txt acc: 0.503252391464312 val_avg_loss: tensor(1.7522)
avg common val loss: tensor(1.3351) avg img val loss: tensor(2.0365) avg txt val loss: tensor(1.8849)
epoch: 18 train_loss: tensor(1.6036, grad_fn=<DivBackward0>) average train loss tensor(1.7818, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3878, grad_fn=<DivBackward0>) avg img loss: tensor(2.0394, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9183, grad_

epoch: 33 train_loss: tensor(1.5069, grad_fn=<DivBackward0>) average train loss tensor(1.7479, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3477, grad_fn=<DivBackward0>) avg img loss: tensor(2.0091, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8868, grad_fn=<DivBackward0>)
val common acc: 0.6350257542310522 val img acc: 0.4559823399558499 val txt acc: 0.5040176600441502 val_avg_loss: tensor(1.7432)
avg common val loss: tensor(1.3253) avg img val loss: tensor(2.0274) avg txt val loss: tensor(1.8768)
epoch: 34 train_loss: tensor(1.5099, grad_fn=<DivBackward0>) average train loss tensor(1.7466, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3447, grad_fn=<DivBackward0>) avg img loss: tensor(2.0101, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8849, grad_fn=<DivBackward0>)
val common acc: 0.6350257542310522 val img acc: 0.45757174392935984 val txt acc: 0.5040765268579838 val_avg_loss: tensor(1.7422)
avg common val loss: tensor(1.3239) avg img val loss: tensor(2.0260) avg txt va

val common acc: 0.6356144223693893 val img acc: 0.45559970566593083 val txt acc: 0.5045768947755702 val_avg_loss: tensor(1.7408)
avg common val loss: tensor(1.3239) avg img val loss: tensor(2.0239) avg txt val loss: tensor(1.8746)
epoch: 50 train_loss: tensor(1.4807, grad_fn=<DivBackward0>) average train loss tensor(1.7304, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3256, grad_fn=<DivBackward0>) avg img loss: tensor(1.9935, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8721, grad_fn=<DivBackward0>)
val common acc: 0.6345253863134658 val img acc: 0.45624724061810157 val txt acc: 0.5031935246504783 val_avg_loss: tensor(1.7417)
avg common val loss: tensor(1.3241) avg img val loss: tensor(2.0259) avg txt val loss: tensor(1.8752)
epoch: 51 train_loss: tensor(1.4806, grad_fn=<DivBackward0>) average train loss tensor(1.7287, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3236, grad_fn=<DivBackward0>) avg img loss: tensor(1.9896, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8731, 

epoch: 66 train_loss: tensor(1.4665, grad_fn=<DivBackward0>) average train loss tensor(1.7183, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3128, grad_fn=<DivBackward0>) avg img loss: tensor(1.9782, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8638, grad_fn=<DivBackward0>)
val common acc: 0.6378513612950699 val img acc: 0.45742457689477556 val txt acc: 0.5059308314937454 val_avg_loss: tensor(1.7381)
avg common val loss: tensor(1.3191) avg img val loss: tensor(2.0227) avg txt val loss: tensor(1.8724)
epoch: 67 train_loss: tensor(1.4614, grad_fn=<DivBackward0>) average train loss tensor(1.7179, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3084, grad_fn=<DivBackward0>) avg img loss: tensor(1.9801, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8651, grad_fn=<DivBackward0>)
val common acc: 0.6364091243561443 val img acc: 0.45863134657836646 val txt acc: 0.5059013980868285 val_avg_loss: tensor(1.7397)
avg common val loss: tensor(1.3219) avg img val loss: tensor(2.0250) avg txt v

val common acc: 0.6365562913907284 val img acc: 0.45736571008094185 val txt acc: 0.5066666666666667 val_avg_loss: tensor(1.7369)
avg common val loss: tensor(1.3183) avg img val loss: tensor(2.0234) avg txt val loss: tensor(1.8689)
epoch: 83 train_loss: tensor(1.4346, grad_fn=<DivBackward0>) average train loss tensor(1.7097, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3016, grad_fn=<DivBackward0>) avg img loss: tensor(1.9706, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8570, grad_fn=<DivBackward0>)
val common acc: 0.6375864606328182 val img acc: 0.4573951434878587 val txt acc: 0.5062545989698307 val_avg_loss: tensor(1.7368)
avg common val loss: tensor(1.3184) avg img val loss: tensor(2.0222) avg txt val loss: tensor(1.8700)
epoch: 84 train_loss: tensor(1.4473, grad_fn=<DivBackward0>) average train loss tensor(1.7095, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3009, grad_fn=<DivBackward0>) avg img loss: tensor(1.9694, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8583, g

epoch: 99 train_loss: tensor(1.4137, grad_fn=<DivBackward0>) average train loss tensor(1.7025, grad_fn=<DivBackward0>)
avg common loss: tensor(1.2943, grad_fn=<DivBackward0>) avg img loss: tensor(1.9621, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8512, grad_fn=<DivBackward0>)
val common acc: 0.6367623252391464 val img acc: 0.4571008094186902 val txt acc: 0.505813097866078 val_avg_loss: tensor(1.7377)
avg common val loss: tensor(1.3206) avg img val loss: tensor(2.0226) avg txt val loss: tensor(1.8698)


In [93]:
def fit_mtl(mtl, optimizer, epochs, train_loader, val_loader, writer):
    for epoch in range(epochs):
        print(mtl.sigma)
        mtl.train()

        loss_sum = 0.0
        loss_count = 0

        for x_img_cur, x_txt_cur, y_cur in train_loader:
            mtl.zero_grad()
            loss = mtl(x_img_cur, x_txt_cur, y_cur)
            loss.backward()

            loss_sum += loss
            loss_count += 1

            optimizer.step()

        print('epoch:', epoch, 'train_loss:', loss, 'average train loss', loss_sum / loss_count)
        if writer is not None:
            writer.add_scalar('train_loss', loss, epoch)
            writer.add_scalar('avg_train_loss', loss_sum / loss_count, epoch)

        if val_loader is not None:
            mtl.eval()

            correct = 0
            total = 0
            loss_sum = 0.0
            loss_count = 0

            with torch.no_grad():
                for x_img_cur, x_txt_cur, y_cur in val_loader:
                    output = mtl.model(x_img_cur, x_txt_cur)[0]
                    loss = F.nll_loss(output, torch.argmax(y_cur, dim=1))
                    loss_sum += loss
                    loss_count += 1
                    for idx, i in enumerate(output):
                        if torch.argmax(i) == torch.argmax(y_cur, dim=1)[idx]:
                            correct += 1
                        total += 1

            print('val_acc:', correct / total, 'val_avg_loss:', loss_sum / loss_count)
            if writer is not None:
                writer.add_scalar('val_acc', correct / total, epoch)
                writer.add_scalar('val_avg_loss', loss_sum / loss_count, epoch)
        

In [100]:
class TridentMTL(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch_models.NormModelTrident(drop=0.5)
        self.sigma = nn.Parameter(torch.ones(3))
        
    def forward(self, inp_img, inp_txt, target):
        outputs = self.model(inp_img, inp_txt)

        ls = [F.nll_loss(output, torch.argmax(target, dim=1)) for output in outputs]
        l = sum([(0.5 * le / s ** 2) for le, s in zip(ls, self.sigma)]) + torch.log(self.sigma.prod())
        return l

## Experiments with trainable loss merge

In [101]:
# model = torch_models.NormModelTrident(drop=0.5)
mtl = TridentMTL()
optimizer = optim.Adam(mtl.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_mtl_bs2048_rs42_d128_wd0005_drop_05_100')

fit_mtl(
    mtl=mtl,
    optimizer=optimizer,
    epochs=100,
    writer=writer,
    train_loader=train_loader,
    val_loader=val_loader
)

Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
epoch: 0 train_loss: tensor(3.2125, grad_fn=<AddBackward0>) average train loss tensor(4.0740, grad_fn=<DivBackward0>)
val_acc: 0.5265342163355409 val_avg_loss: tensor(1.7825)
Parameter containing:
tensor([1.0531, 1.0557, 1.0549], requires_grad=True)
epoch: 1 train_loss: tensor(2.7969, grad_fn=<AddBackward0>) average train loss tensor(3.0016, grad_fn=<DivBackward0>)
val_acc: 0.5858719646799116 val_avg_loss: tensor(1.5567)
Parameter containing:
tensor([1.0828, 1.1012, 1.0945], requires_grad=True)
epoch: 2 train_loss: tensor(2.6973, grad_fn=<AddBackward0>) average train loss tensor(2.8025, grad_fn=<DivBackward0>)
val_acc: 0.6001471670345843 val_avg_loss: tensor(1.4942)
Parameter containing:
tensor([1.1076, 1.1420, 1.1304], requires_grad=True)
epoch: 3 train_loss: tensor(2.5827, grad_fn=<AddBackward0>) average train loss tensor(2.6975, grad_fn=<DivBackward0>)
val_acc: 0.6055040470934511 val_avg_loss: tensor(1.4663)
Parameter co

val_acc: 0.6272259013980869 val_avg_loss: tensor(1.3653)
Parameter containing:
tensor([1.2296, 1.4485, 1.4093], requires_grad=True)
epoch: 33 train_loss: tensor(2.2934, grad_fn=<AddBackward0>) average train loss tensor(2.4177, grad_fn=<DivBackward0>)
val_acc: 0.6265489330389993 val_avg_loss: tensor(1.3670)
Parameter containing:
tensor([1.2283, 1.4485, 1.4090], requires_grad=True)
epoch: 34 train_loss: tensor(2.2896, grad_fn=<AddBackward0>) average train loss tensor(2.4196, grad_fn=<DivBackward0>)
val_acc: 0.6255776306107432 val_avg_loss: tensor(1.3674)
Parameter containing:
tensor([1.2280, 1.4485, 1.4087], requires_grad=True)
epoch: 35 train_loss: tensor(2.3016, grad_fn=<AddBackward0>) average train loss tensor(2.4182, grad_fn=<DivBackward0>)
val_acc: 0.6253715967623252 val_avg_loss: tensor(1.3645)
Parameter containing:
tensor([1.2275, 1.4484, 1.4085], requires_grad=True)
epoch: 36 train_loss: tensor(2.2929, grad_fn=<AddBackward0>) average train loss tensor(2.4185, grad_fn=<DivBackward

val_acc: 0.6263134657836644 val_avg_loss: tensor(1.3581)
Parameter containing:
tensor([1.2176, 1.4446, 1.4037], requires_grad=True)
epoch: 66 train_loss: tensor(2.2616, grad_fn=<AddBackward0>) average train loss tensor(2.4039, grad_fn=<DivBackward0>)
val_acc: 0.6279028697571744 val_avg_loss: tensor(1.3544)
Parameter containing:
tensor([1.2171, 1.4445, 1.4036], requires_grad=True)
epoch: 67 train_loss: tensor(2.2425, grad_fn=<AddBackward0>) average train loss tensor(2.4028, grad_fn=<DivBackward0>)
val_acc: 0.6292862398822664 val_avg_loss: tensor(1.3521)
Parameter containing:
tensor([1.2169, 1.4441, 1.4030], requires_grad=True)
epoch: 68 train_loss: tensor(2.2664, grad_fn=<AddBackward0>) average train loss tensor(2.4043, grad_fn=<DivBackward0>)
val_acc: 0.6284621044885945 val_avg_loss: tensor(1.3523)
Parameter containing:
tensor([1.2168, 1.4442, 1.4035], requires_grad=True)
epoch: 69 train_loss: tensor(2.2743, grad_fn=<AddBackward0>) average train loss tensor(2.4031, grad_fn=<DivBackward

val_acc: 0.6288741721854305 val_avg_loss: tensor(1.3483)
Parameter containing:
tensor([1.2114, 1.4411, 1.4009], requires_grad=True)
epoch: 99 train_loss: tensor(2.2512, grad_fn=<AddBackward0>) average train loss tensor(2.3951, grad_fn=<DivBackward0>)
val_acc: 0.6298749080206034 val_avg_loss: tensor(1.3490)
