In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import torch.nn as nn

import numpy as np

import pickle

In [3]:
import sys
sys.path.append('./data')
sys.path.append('./pytorch')

In [4]:
from data import data
from pytorch import torch_models, radam

Using TensorFlow backend.


In [5]:
IMG_LEN = 1024
TXT_LEN = 300
N_CLASSES = 50

In [6]:
x_img, x_txt, y = data.get_unpacked_data()

In [7]:
x_img_train, x_img_test, x_txt_train, x_txt_test, y_train, y_test = train_test_split(
    x_img, 
    x_txt, 
    y, 
    test_size=0.2, 
    random_state=42,
    stratify=y
)

x_img_train, x_img_val, x_txt_train, x_txt_val, y_train, y_val = train_test_split(
    x_img_train,
    x_txt_train,
    y_train,
    test_size=0.2,
    random_state=42,
    stratify=y_train
)

In [8]:
img_sscaler = StandardScaler()
img_sscaler.fit(x_img_train)

x_img_train = img_sscaler.transform(x_img_train)
x_img_val = img_sscaler.transform(x_img_val)
x_img_test = img_sscaler.transform(x_img_test)

txt_sscaler = StandardScaler()
txt_sscaler.fit(x_txt_train)

x_txt_train = txt_sscaler.transform(x_txt_train)
x_txt_val = txt_sscaler.transform(x_txt_val)
x_txt_test = txt_sscaler.transform(x_txt_test)

In [9]:
x_img_train_t = torch.tensor(x_img_train).float()
x_img_val_t = torch.tensor(x_img_val).float()
x_img_test_t = torch.tensor(x_img_test).float()

x_txt_train_t = torch.tensor(x_txt_train).float()
x_txt_val_t = torch.tensor(x_txt_val).float()
x_txt_test_t = torch.tensor(x_txt_test).float()

y_train_t = torch.tensor(y_train).float()
y_val_t = torch.tensor(y_val).float()
y_test_t = torch.tensor(y_test).float()

In [10]:
train_ds = TensorDataset(x_img_train_t, x_txt_train_t, y_train_t)
val_ds = TensorDataset(x_img_val_t, x_txt_val_t, y_val_t)
test_ds = TensorDataset(x_img_test_t, x_txt_test_t, y_test_t)

In [12]:
training_indices = np.random.randint(low=0, high=len(x_img_train), size=20000)
x_img_train_20k = x_img_train[training_indices]
x_txt_train_20k = x_txt_train[training_indices]
y_train_20k = y_train[training_indices]

x_img_train_20k_t = torch.tensor(x_img_train_20k).float()
x_txt_train_20k_t = torch.tensor(x_txt_train_20k).float()
y_train_20k_t = torch.tensor(y_train_20k).float()

train_ds_20k = TensorDataset(x_img_train_20k_t, x_txt_train_20k_t, y_train_20k_t)

In [11]:
training_indices = np.random.randint(low=0, high=len(x_img_train), size=2000)
x_img_train_2k = x_img_train[training_indices]
x_txt_train_2k = x_txt_train[training_indices]
y_train_2k = y_train[training_indices]

x_img_train_2k_t = torch.tensor(x_img_train_2k).float()
x_txt_train_2k_t = torch.tensor(x_txt_train_2k).float()
y_train_2k_t = torch.tensor(y_train_2k).float()

train_ds_2k = TensorDataset(x_img_train_2k_t, x_txt_train_2k_t, y_train_2k_t)

In [12]:
# before 14.03.2020 it was 512
# experiments marked bs*number* is on batch_size == number
# otherwise batch_size == 512
BATCH_SIZE = 2048

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

In [14]:
train_loader_20k = DataLoader(train_ds_20k, batch_size=BATCH_SIZE)

In [14]:
train_loader_2k = DataLoader(train_ds_2k, batch_size=BATCH_SIZE)

In [23]:
model = torch_models.NormModelTrident(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bs2048_rs42_d128_wd0005_drop_05_2k')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=300,
    writer=writer,
    train_loader=train_loader_2k,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(3.9252, grad_fn=<DivBackward0>) average train loss tensor(3.9252, grad_fn=<DivBackward0>)
avg common loss: tensor(3.9157, grad_fn=<DivBackward0>) avg img loss: tensor(3.9325, grad_fn=<DivBackward0>) avg txt loss: tensor(3.9273, grad_fn=<DivBackward0>)
val common acc: 0.03334805003679176 val img acc: 0.13036055923473142 val txt acc: 0.03820456217807211 val_avg_loss: tensor(3.8642)
avg common val loss: tensor(3.9075) avg img val loss: tensor(3.8140) avg txt val loss: tensor(3.8711)
epoch: 1 train_loss: tensor(3.8594, grad_fn=<DivBackward0>) average train loss tensor(3.8594, grad_fn=<DivBackward0>)
avg common loss: tensor(3.9041, grad_fn=<DivBackward0>) avg img loss: tensor(3.7996, grad_fn=<DivBackward0>) avg txt loss: tensor(3.8744, grad_fn=<DivBackward0>)
val common acc: 0.036850625459896987 val img acc: 0.15546725533480502 val txt acc: 0.11143487858719647 val_avg_loss: tensor(3.8110)
avg common val loss: tensor(3.8952) avg img val loss: tensor(3.7096) avg tx

val common acc: 0.23072847682119205 val img acc: 0.32647534952170715 val txt acc: 0.2555114054451803 val_avg_loss: tensor(3.0289)
avg common val loss: tensor(3.3253) avg img val loss: tensor(2.6624) avg txt val loss: tensor(3.0990)
epoch: 17 train_loss: tensor(2.8888, grad_fn=<DivBackward0>) average train loss tensor(2.8888, grad_fn=<DivBackward0>)
avg common loss: tensor(3.2937, grad_fn=<DivBackward0>) avg img loss: tensor(2.3650, grad_fn=<DivBackward0>) avg txt loss: tensor(3.0076, grad_fn=<DivBackward0>)
val common acc: 0.23376011773362768 val img acc: 0.32777041942604856 val txt acc: 0.26822663723325973 val_avg_loss: tensor(2.9870)
avg common val loss: tensor(3.2685) avg img val loss: tensor(2.6424) avg txt val loss: tensor(3.0501)
epoch: 18 train_loss: tensor(2.8400, grad_fn=<DivBackward0>) average train loss tensor(2.8400, grad_fn=<DivBackward0>)
avg common loss: tensor(3.2400, grad_fn=<DivBackward0>) avg img loss: tensor(2.3194, grad_fn=<DivBackward0>) avg txt loss: tensor(2.960

val common acc: 0.3665636497424577 val img acc: 0.35158204562178075 val txt acc: 0.3717439293598234 val_avg_loss: tensor(2.5389)
avg common val loss: tensor(2.5407) avg img val loss: tensor(2.5459) avg txt val loss: tensor(2.5300)
epoch: 33 train_loss: tensor(2.1347, grad_fn=<DivBackward0>) average train loss tensor(2.1347, grad_fn=<DivBackward0>)
avg common loss: tensor(2.3569, grad_fn=<DivBackward0>) avg img loss: tensor(1.7117, grad_fn=<DivBackward0>) avg txt loss: tensor(2.3355, grad_fn=<DivBackward0>)
val common acc: 0.37289183222958056 val img acc: 0.35214128035320086 val txt acc: 0.3753936718175129 val_avg_loss: tensor(2.5192)
avg common val loss: tensor(2.5005) avg img val loss: tensor(2.5483) avg txt val loss: tensor(2.5087)
epoch: 34 train_loss: tensor(2.0917, grad_fn=<DivBackward0>) average train loss tensor(2.0917, grad_fn=<DivBackward0>)
avg common loss: tensor(2.3002, grad_fn=<DivBackward0>) avg img loss: tensor(1.6930, grad_fn=<DivBackward0>) avg txt loss: tensor(2.2819,

val common acc: 0.4462693156732892 val img acc: 0.35137601177336275 val txt acc: 0.4111552612214864 val_avg_loss: tensor(2.3874)
avg common val loss: tensor(2.1667) avg img val loss: tensor(2.6716) avg txt val loss: tensor(2.3239)
epoch: 49 train_loss: tensor(1.5562, grad_fn=<DivBackward0>) average train loss tensor(1.5562, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5582, grad_fn=<DivBackward0>) avg img loss: tensor(1.2354, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8751, grad_fn=<DivBackward0>)
val common acc: 0.4501250919793966 val img acc: 0.3520235467255335 val txt acc: 0.4115673289183223 val_avg_loss: tensor(2.3872)
avg common val loss: tensor(2.1594) avg img val loss: tensor(2.6824) avg txt val loss: tensor(2.3197)
epoch: 50 train_loss: tensor(1.5456, grad_fn=<DivBackward0>) average train loss tensor(1.5456, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5211, grad_fn=<DivBackward0>) avg img loss: tensor(1.2425, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8734, g

val common acc: 0.4692568064753495 val img acc: 0.34095658572479765 val txt acc: 0.42075055187637966 val_avg_loss: tensor(2.4726)
avg common val loss: tensor(2.2308) avg img val loss: tensor(2.8834) avg txt val loss: tensor(2.3038)
epoch: 65 train_loss: tensor(1.1512, grad_fn=<DivBackward0>) average train loss tensor(1.1512, grad_fn=<DivBackward0>)
avg common loss: tensor(0.9939, grad_fn=<DivBackward0>) avg img loss: tensor(0.8801, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5796, grad_fn=<DivBackward0>)
val common acc: 0.4682560706401766 val img acc: 0.34077998528329656 val txt acc: 0.42160412067696834 val_avg_loss: tensor(2.4823)
avg common val loss: tensor(2.2437) avg img val loss: tensor(2.8972) avg txt val loss: tensor(2.3060)
epoch: 66 train_loss: tensor(1.1511, grad_fn=<DivBackward0>) average train loss tensor(1.1511, grad_fn=<DivBackward0>)
avg common loss: tensor(1.0001, grad_fn=<DivBackward0>) avg img loss: tensor(0.8690, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5841

val common acc: 0.4688447387785136 val img acc: 0.33536423841059604 val txt acc: 0.42042678440029435 val_avg_loss: tensor(2.6360)
avg common val loss: tensor(2.4435) avg img val loss: tensor(3.1067) avg txt val loss: tensor(2.3579)
epoch: 81 train_loss: tensor(0.9129, grad_fn=<DivBackward0>) average train loss tensor(0.9129, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7125, grad_fn=<DivBackward0>) avg img loss: tensor(0.6721, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3541, grad_fn=<DivBackward0>)
val common acc: 0.4686092715231788 val img acc: 0.3356291390728477 val txt acc: 0.420103016924209 val_avg_loss: tensor(2.6464)
avg common val loss: tensor(2.4577) avg img val loss: tensor(3.1185) avg txt val loss: tensor(2.3629)
epoch: 82 train_loss: tensor(0.8816, grad_fn=<DivBackward0>) average train loss tensor(0.8816, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6822, grad_fn=<DivBackward0>) avg img loss: tensor(0.6220, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3406, g

val common acc: 0.468962472406181 val img acc: 0.3355408388520971 val txt acc: 0.4181015452538631 val_avg_loss: tensor(2.8102)
avg common val loss: tensor(2.6679) avg img val loss: tensor(3.3197) avg txt val loss: tensor(2.4431)
epoch: 97 train_loss: tensor(0.7316, grad_fn=<DivBackward0>) average train loss tensor(0.7316, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5088, grad_fn=<DivBackward0>) avg img loss: tensor(0.5060, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1799, grad_fn=<DivBackward0>)
val common acc: 0.4655187637969095 val img acc: 0.33445180279617365 val txt acc: 0.4176600441501104 val_avg_loss: tensor(2.8202)
avg common val loss: tensor(2.6840) avg img val loss: tensor(3.3291) avg txt val loss: tensor(2.4476)
epoch: 98 train_loss: tensor(0.7297, grad_fn=<DivBackward0>) average train loss tensor(0.7297, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4892, grad_fn=<DivBackward0>) avg img loss: tensor(0.5214, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1786, gr

val common acc: 0.4672553348050037 val img acc: 0.3343340691685063 val txt acc: 0.41300956585724796 val_avg_loss: tensor(2.9700)
avg common val loss: tensor(2.8888) avg img val loss: tensor(3.4791) avg txt val loss: tensor(2.5421)
epoch: 113 train_loss: tensor(0.6041, grad_fn=<DivBackward0>) average train loss tensor(0.6041, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3542, grad_fn=<DivBackward0>) avg img loss: tensor(0.4177, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0404, grad_fn=<DivBackward0>)
val common acc: 0.4682560706401766 val img acc: 0.3346578366445916 val txt acc: 0.4133038999264165 val_avg_loss: tensor(2.9817)
avg common val loss: tensor(2.9044) avg img val loss: tensor(3.4920) avg txt val loss: tensor(2.5486)
epoch: 114 train_loss: tensor(0.5979, grad_fn=<DivBackward0>) average train loss tensor(0.5979, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3329, grad_fn=<DivBackward0>) avg img loss: tensor(0.4228, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0380,

val common acc: 0.46713760117733627 val img acc: 0.33398086828550405 val txt acc: 0.4108903605592347 val_avg_loss: tensor(3.0903)
avg common val loss: tensor(3.0478) avg img val loss: tensor(3.5905) avg txt val loss: tensor(2.6325)
epoch: 129 train_loss: tensor(0.5282, grad_fn=<DivBackward0>) average train loss tensor(0.5282, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2746, grad_fn=<DivBackward0>) avg img loss: tensor(0.3760, grad_fn=<DivBackward0>) avg txt loss: tensor(0.9340, grad_fn=<DivBackward0>)
val common acc: 0.46769683590875644 val img acc: 0.3337748344370861 val txt acc: 0.41171449595290655 val_avg_loss: tensor(3.1007)
avg common val loss: tensor(3.0602) avg img val loss: tensor(3.6040) avg txt val loss: tensor(2.6380)
epoch: 130 train_loss: tensor(0.5341, grad_fn=<DivBackward0>) average train loss tensor(0.5341, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2662, grad_fn=<DivBackward0>) avg img loss: tensor(0.3771, grad_fn=<DivBackward0>) avg txt loss: tensor(0.95

val common acc: 0.4666372332597498 val img acc: 0.33224429727740984 val txt acc: 0.4078292862398823 val_avg_loss: tensor(3.1906)
avg common val loss: tensor(3.1737) avg img val loss: tensor(3.6795) avg txt val loss: tensor(2.7187)
epoch: 145 train_loss: tensor(0.4736, grad_fn=<DivBackward0>) average train loss tensor(0.4736, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2139, grad_fn=<DivBackward0>) avg img loss: tensor(0.3363, grad_fn=<DivBackward0>) avg txt loss: tensor(0.8708, grad_fn=<DivBackward0>)
val common acc: 0.4656364974245769 val img acc: 0.33277409860191315 val txt acc: 0.40747608535688007 val_avg_loss: tensor(3.1959)
avg common val loss: tensor(3.1815) avg img val loss: tensor(3.6834) avg txt val loss: tensor(2.7229)
epoch: 146 train_loss: tensor(0.4845, grad_fn=<DivBackward0>) average train loss tensor(0.4845, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2436, grad_fn=<DivBackward0>) avg img loss: tensor(0.3350, grad_fn=<DivBackward0>) avg txt loss: tensor(0.874

val common acc: 0.4647240618101545 val img acc: 0.3312141280353201 val txt acc: 0.4062104488594555 val_avg_loss: tensor(3.2662)
avg common val loss: tensor(3.2782) avg img val loss: tensor(3.7209) avg txt val loss: tensor(2.7995)
epoch: 161 train_loss: tensor(0.4304, grad_fn=<DivBackward0>) average train loss tensor(0.4304, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1838, grad_fn=<DivBackward0>) avg img loss: tensor(0.3154, grad_fn=<DivBackward0>) avg txt loss: tensor(0.7920, grad_fn=<DivBackward0>)
val common acc: 0.46513612950699046 val img acc: 0.3308609271523179 val txt acc: 0.40650478292862396 val_avg_loss: tensor(3.2714)
avg common val loss: tensor(3.2868) avg img val loss: tensor(3.7251) avg txt val loss: tensor(2.8023)
epoch: 162 train_loss: tensor(0.4354, grad_fn=<DivBackward0>) average train loss tensor(0.4354, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1907, grad_fn=<DivBackward0>) avg img loss: tensor(0.3212, grad_fn=<DivBackward0>) avg txt loss: tensor(0.7944

val common acc: 0.46578366445916114 val img acc: 0.3318027961736571 val txt acc: 0.40479764532744666 val_avg_loss: tensor(3.3256)
avg common val loss: tensor(3.3532) avg img val loss: tensor(3.7581) avg txt val loss: tensor(2.8654)
epoch: 177 train_loss: tensor(0.4109, grad_fn=<DivBackward0>) average train loss tensor(0.4109, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1717, grad_fn=<DivBackward0>) avg img loss: tensor(0.3111, grad_fn=<DivBackward0>) avg txt loss: tensor(0.7498, grad_fn=<DivBackward0>)
val common acc: 0.46516556291390726 val img acc: 0.33230316409124355 val txt acc: 0.40503311258278146 val_avg_loss: tensor(3.3323)
avg common val loss: tensor(3.3632) avg img val loss: tensor(3.7628) avg txt val loss: tensor(2.8710)
epoch: 178 train_loss: tensor(0.4068, grad_fn=<DivBackward0>) average train loss tensor(0.4068, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1612, grad_fn=<DivBackward0>) avg img loss: tensor(0.2968, grad_fn=<DivBackward0>) avg txt loss: tensor(0.7

val common acc: 0.46445916114790287 val img acc: 0.3303605592347314 val txt acc: 0.4048270787343635 val_avg_loss: tensor(3.3678)
avg common val loss: tensor(3.3702) avg img val loss: tensor(3.7986) avg txt val loss: tensor(2.9344)
epoch: 193 train_loss: tensor(0.3828, grad_fn=<DivBackward0>) average train loss tensor(0.3828, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1469, grad_fn=<DivBackward0>) avg img loss: tensor(0.2946, grad_fn=<DivBackward0>) avg txt loss: tensor(0.7070, grad_fn=<DivBackward0>)
val common acc: 0.464429727740986 val img acc: 0.33015452538631346 val txt acc: 0.40467991169977924 val_avg_loss: tensor(3.3704)
avg common val loss: tensor(3.3719) avg img val loss: tensor(3.8013) avg txt val loss: tensor(2.9379)
epoch: 194 train_loss: tensor(0.3846, grad_fn=<DivBackward0>) average train loss tensor(0.3846, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1408, grad_fn=<DivBackward0>) avg img loss: tensor(0.2899, grad_fn=<DivBackward0>) avg txt loss: tensor(0.7232

val common acc: 0.4672553348050037 val img acc: 0.33200883002207504 val txt acc: 0.4017954378219279 val_avg_loss: tensor(3.4157)
avg common val loss: tensor(3.4495) avg img val loss: tensor(3.8053) avg txt val loss: tensor(2.9924)
epoch: 209 train_loss: tensor(0.3705, grad_fn=<DivBackward0>) average train loss tensor(0.3705, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1305, grad_fn=<DivBackward0>) avg img loss: tensor(0.2804, grad_fn=<DivBackward0>) avg txt loss: tensor(0.7007, grad_fn=<DivBackward0>)
val common acc: 0.46678440029433405 val img acc: 0.3324797645327447 val txt acc: 0.4017954378219279 val_avg_loss: tensor(3.4173)
avg common val loss: tensor(3.4502) avg img val loss: tensor(3.8059) avg txt val loss: tensor(2.9959)
epoch: 210 train_loss: tensor(0.3521, grad_fn=<DivBackward0>) average train loss tensor(0.3521, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1063, grad_fn=<DivBackward0>) avg img loss: tensor(0.2755, grad_fn=<DivBackward0>) avg txt loss: tensor(0.6745

val common acc: 0.4634878587196468 val img acc: 0.32980132450331123 val txt acc: 0.4027373068432671 val_avg_loss: tensor(3.4415)
avg common val loss: tensor(3.4670) avg img val loss: tensor(3.8225) avg txt val loss: tensor(3.0352)
epoch: 225 train_loss: tensor(0.3469, grad_fn=<DivBackward0>) average train loss tensor(0.3469, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1119, grad_fn=<DivBackward0>) avg img loss: tensor(0.2716, grad_fn=<DivBackward0>) avg txt loss: tensor(0.6573, grad_fn=<DivBackward0>)
val common acc: 0.4643414275202355 val img acc: 0.3302722590139809 val txt acc: 0.40253127299484914 val_avg_loss: tensor(3.4390)
avg common val loss: tensor(3.4569) avg img val loss: tensor(3.8215) avg txt val loss: tensor(3.0386)
epoch: 226 train_loss: tensor(0.3435, grad_fn=<DivBackward0>) average train loss tensor(0.3435, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1137, grad_fn=<DivBackward0>) avg img loss: tensor(0.2677, grad_fn=<DivBackward0>) avg txt loss: tensor(0.6490

val common acc: 0.46231052244297277 val img acc: 0.3288005886681383 val txt acc: 0.4012362030905077 val_avg_loss: tensor(3.4648)
avg common val loss: tensor(3.4802) avg img val loss: tensor(3.8376) avg txt val loss: tensor(3.0766)
epoch: 241 train_loss: tensor(0.3396, grad_fn=<DivBackward0>) average train loss tensor(0.3396, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1125, grad_fn=<DivBackward0>) avg img loss: tensor(0.2610, grad_fn=<DivBackward0>) avg txt loss: tensor(0.6454, grad_fn=<DivBackward0>)
val common acc: 0.46278145695364237 val img acc: 0.3291832229580574 val txt acc: 0.4016482707873436 val_avg_loss: tensor(3.4653)
avg common val loss: tensor(3.4816) avg img val loss: tensor(3.8356) avg txt val loss: tensor(3.0788)
epoch: 242 train_loss: tensor(0.3376, grad_fn=<DivBackward0>) average train loss tensor(0.3376, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1051, grad_fn=<DivBackward0>) avg img loss: tensor(0.2752, grad_fn=<DivBackward0>) avg txt loss: tensor(0.6324

val common acc: 0.4672553348050037 val img acc: 0.3318616629874908 val txt acc: 0.40114790286975716 val_avg_loss: tensor(3.4805)
avg common val loss: tensor(3.4931) avg img val loss: tensor(3.8383) avg txt val loss: tensor(3.1100)
epoch: 257 train_loss: tensor(0.3326, grad_fn=<DivBackward0>) average train loss tensor(0.3326, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0960, grad_fn=<DivBackward0>) avg img loss: tensor(0.2670, grad_fn=<DivBackward0>) avg txt loss: tensor(0.6347, grad_fn=<DivBackward0>)
val common acc: 0.46707873436350256 val img acc: 0.3315084621044886 val txt acc: 0.40105960264900664 val_avg_loss: tensor(3.4815)
avg common val loss: tensor(3.4911) avg img val loss: tensor(3.8400) avg txt val loss: tensor(3.1132)
epoch: 258 train_loss: tensor(0.3311, grad_fn=<DivBackward0>) average train loss tensor(0.3311, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1016, grad_fn=<DivBackward0>) avg img loss: tensor(0.2596, grad_fn=<DivBackward0>) avg txt loss: tensor(0.632

val common acc: 0.467579102281089 val img acc: 0.33177336276674024 val txt acc: 0.40044150110375276 val_avg_loss: tensor(3.4941)
avg common val loss: tensor(3.5063) avg img val loss: tensor(3.8366) avg txt val loss: tensor(3.1395)
epoch: 273 train_loss: tensor(0.3210, grad_fn=<DivBackward0>) average train loss tensor(0.3210, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0960, grad_fn=<DivBackward0>) avg img loss: tensor(0.2587, grad_fn=<DivBackward0>) avg txt loss: tensor(0.6083, grad_fn=<DivBackward0>)
val common acc: 0.4671670345842531 val img acc: 0.33183222958057396 val txt acc: 0.4005003679175865 val_avg_loss: tensor(3.4956)
avg common val loss: tensor(3.5079) avg img val loss: tensor(3.8378) avg txt val loss: tensor(3.1412)
epoch: 274 train_loss: tensor(0.3246, grad_fn=<DivBackward0>) average train loss tensor(0.3246, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1054, grad_fn=<DivBackward0>) avg img loss: tensor(0.2648, grad_fn=<DivBackward0>) avg txt loss: tensor(0.6036

val common acc: 0.46245768947755705 val img acc: 0.32900662251655627 val txt acc: 0.3983517292126564 val_avg_loss: tensor(3.5184)
avg common val loss: tensor(3.5386) avg img val loss: tensor(3.8434) avg txt val loss: tensor(3.1733)
epoch: 289 train_loss: tensor(0.3099, grad_fn=<DivBackward0>) average train loss tensor(0.3099, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0802, grad_fn=<DivBackward0>) avg img loss: tensor(0.2554, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5940, grad_fn=<DivBackward0>)
val common acc: 0.4618984547461369 val img acc: 0.32868285504047096 val txt acc: 0.3984400294334069 val_avg_loss: tensor(3.5189)
avg common val loss: tensor(3.5388) avg img val loss: tensor(3.8425) avg txt val loss: tensor(3.1752)
epoch: 290 train_loss: tensor(0.3080, grad_fn=<DivBackward0>) average train loss tensor(0.3080, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0862, grad_fn=<DivBackward0>) avg img loss: tensor(0.2513, grad_fn=<DivBackward0>) avg txt loss: tensor(0.586

In [24]:
model = torch_models.NormModelTridentBN(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bn_bs2048_rs42_d128_wd0005_drop_05_2k')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=300,
    writer=writer,
    train_loader=train_loader_2k,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(4.1888, grad_fn=<DivBackward0>) average train loss tensor(4.1888, grad_fn=<DivBackward0>)
avg common loss: tensor(4.0863, grad_fn=<DivBackward0>) avg img loss: tensor(4.2585, grad_fn=<DivBackward0>) avg txt loss: tensor(4.2216, grad_fn=<DivBackward0>)
val common acc: 0.019690949227373068 val img acc: 0.08724061810154525 val txt acc: 0.048653421633554086 val_avg_loss: tensor(3.8843)
avg common val loss: tensor(3.9028) avg img val loss: tensor(3.8571) avg txt val loss: tensor(3.8931)
epoch: 1 train_loss: tensor(3.9500, grad_fn=<DivBackward0>) average train loss tensor(3.9500, grad_fn=<DivBackward0>)
avg common loss: tensor(3.9980, grad_fn=<DivBackward0>) avg img loss: tensor(3.7773, grad_fn=<DivBackward0>) avg txt loss: tensor(4.0747, grad_fn=<DivBackward0>)
val common acc: 0.032259013980868285 val img acc: 0.19293598233995585 val txt acc: 0.0847682119205298 val_avg_loss: tensor(3.8485)
avg common val loss: tensor(3.8989) avg img val loss: tensor(3.7824) avg t

epoch: 16 train_loss: tensor(2.4737, grad_fn=<DivBackward0>) average train loss tensor(2.4737, grad_fn=<DivBackward0>)
avg common loss: tensor(2.8327, grad_fn=<DivBackward0>) avg img loss: tensor(2.0876, grad_fn=<DivBackward0>) avg txt loss: tensor(2.5009, grad_fn=<DivBackward0>)
val common acc: 0.33115526122148636 val img acc: 0.34301692420897717 val txt acc: 0.3591464311994113 val_avg_loss: tensor(2.9331)
avg common val loss: tensor(3.1658) avg img val loss: tensor(2.7153) avg txt val loss: tensor(2.9183)
epoch: 17 train_loss: tensor(2.4277, grad_fn=<DivBackward0>) average train loss tensor(2.4277, grad_fn=<DivBackward0>)
avg common loss: tensor(2.7795, grad_fn=<DivBackward0>) avg img loss: tensor(2.0527, grad_fn=<DivBackward0>) avg txt loss: tensor(2.4510, grad_fn=<DivBackward0>)
val common acc: 0.3423105224429728 val img acc: 0.3459896983075791 val txt acc: 0.36568064753495216 val_avg_loss: tensor(2.8730)
avg common val loss: tensor(3.0759) avg img val loss: tensor(2.6897) avg txt 

val common acc: 0.44571008094186904 val img acc: 0.34875643855776306 val txt acc: 0.4166887417218543 val_avg_loss: tensor(2.4098)
avg common val loss: tensor(2.2553) avg img val loss: tensor(2.5925) avg txt val loss: tensor(2.3818)
epoch: 33 train_loss: tensor(1.6814, grad_fn=<DivBackward0>) average train loss tensor(1.6814, grad_fn=<DivBackward0>)
avg common loss: tensor(1.8545, grad_fn=<DivBackward0>) avg img loss: tensor(1.3282, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8616, grad_fn=<DivBackward0>)
val common acc: 0.4479764532744665 val img acc: 0.3479617365710081 val txt acc: 0.4178072111846946 val_avg_loss: tensor(2.4004)
avg common val loss: tensor(2.2334) avg img val loss: tensor(2.5967) avg txt val loss: tensor(2.3710)
epoch: 34 train_loss: tensor(1.6350, grad_fn=<DivBackward0>) average train loss tensor(1.6350, grad_fn=<DivBackward0>)
avg common loss: tensor(1.8053, grad_fn=<DivBackward0>) avg img loss: tensor(1.2742, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8257, 

val common acc: 0.46501839587932303 val img acc: 0.338925680647535 val txt acc: 0.42013245033112584 val_avg_loss: tensor(2.3654)
avg common val loss: tensor(2.0831) avg img val loss: tensor(2.6906) avg txt val loss: tensor(2.3225)
epoch: 49 train_loss: tensor(1.1717, grad_fn=<DivBackward0>) average train loss tensor(1.1717, grad_fn=<DivBackward0>)
avg common loss: tensor(1.2072, grad_fn=<DivBackward0>) avg img loss: tensor(0.8467, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4611, grad_fn=<DivBackward0>)
val common acc: 0.4661074319352465 val img acc: 0.33898454746136863 val txt acc: 0.41980868285504047 val_avg_loss: tensor(2.3665)
avg common val loss: tensor(2.0768) avg img val loss: tensor(2.6992) avg txt val loss: tensor(2.3234)
epoch: 50 train_loss: tensor(1.1360, grad_fn=<DivBackward0>) average train loss tensor(1.1360, grad_fn=<DivBackward0>)
avg common loss: tensor(1.1665, grad_fn=<DivBackward0>) avg img loss: tensor(0.8052, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4362,

epoch: 65 train_loss: tensor(0.8241, grad_fn=<DivBackward0>) average train loss tensor(0.8241, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7584, grad_fn=<DivBackward0>) avg img loss: tensor(0.5500, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1638, grad_fn=<DivBackward0>)
val common acc: 0.47175864606328183 val img acc: 0.3330389992641648 val txt acc: 0.41468727005150846 val_avg_loss: tensor(2.4176)
avg common val loss: tensor(2.0606) avg img val loss: tensor(2.8275) avg txt val loss: tensor(2.3648)
epoch: 66 train_loss: tensor(0.7901, grad_fn=<DivBackward0>) average train loss tensor(0.7901, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7290, grad_fn=<DivBackward0>) avg img loss: tensor(0.5262, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1151, grad_fn=<DivBackward0>)
val common acc: 0.472317880794702 val img acc: 0.33306843267108166 val txt acc: 0.4140103016924209 val_avg_loss: tensor(2.4235)
avg common val loss: tensor(2.0634) avg img val loss: tensor(2.8379) avg txt v

epoch: 81 train_loss: tensor(0.6093, grad_fn=<DivBackward0>) average train loss tensor(0.6093, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4985, grad_fn=<DivBackward0>) avg img loss: tensor(0.3936, grad_fn=<DivBackward0>) avg txt loss: tensor(0.9356, grad_fn=<DivBackward0>)
val common acc: 0.4741427520235467 val img acc: 0.3299484915378955 val txt acc: 0.40862398822663726 val_avg_loss: tensor(2.4906)
avg common val loss: tensor(2.0939) avg img val loss: tensor(2.9414) avg txt val loss: tensor(2.4364)
epoch: 82 train_loss: tensor(0.5844, grad_fn=<DivBackward0>) average train loss tensor(0.5844, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4619, grad_fn=<DivBackward0>) avg img loss: tensor(0.3986, grad_fn=<DivBackward0>) avg txt loss: tensor(0.8927, grad_fn=<DivBackward0>)
val common acc: 0.474083885209713 val img acc: 0.32933038999264164 val txt acc: 0.4083002207505519 val_avg_loss: tensor(2.4947)
avg common val loss: tensor(2.0964) avg img val loss: tensor(2.9475) avg txt va

val common acc: 0.4758204562178072 val img acc: 0.3315084621044886 val txt acc: 0.4052391464311994 val_avg_loss: tensor(2.5597)
avg common val loss: tensor(2.1443) avg img val loss: tensor(3.0262) avg txt val loss: tensor(2.5085)
epoch: 98 train_loss: tensor(0.4709, grad_fn=<DivBackward0>) average train loss tensor(0.4709, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3052, grad_fn=<DivBackward0>) avg img loss: tensor(0.3242, grad_fn=<DivBackward0>) avg txt loss: tensor(0.7833, grad_fn=<DivBackward0>)
val common acc: 0.4759381898454746 val img acc: 0.33015452538631346 val txt acc: 0.4046504782928624 val_avg_loss: tensor(2.5640)
avg common val loss: tensor(2.1477) avg img val loss: tensor(3.0318) avg txt val loss: tensor(2.5125)
epoch: 99 train_loss: tensor(0.4722, grad_fn=<DivBackward0>) average train loss tensor(0.4722, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3113, grad_fn=<DivBackward0>) avg img loss: tensor(0.3182, grad_fn=<DivBackward0>) avg txt loss: tensor(0.7870, g

val common acc: 0.47596762325239145 val img acc: 0.3275643855776306 val txt acc: 0.40241353936718177 val_avg_loss: tensor(2.6177)
avg common val loss: tensor(2.1884) avg img val loss: tensor(3.0907) avg txt val loss: tensor(2.5739)
epoch: 114 train_loss: tensor(0.4003, grad_fn=<DivBackward0>) average train loss tensor(0.4003, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2140, grad_fn=<DivBackward0>) avg img loss: tensor(0.2857, grad_fn=<DivBackward0>) avg txt loss: tensor(0.7011, grad_fn=<DivBackward0>)
val common acc: 0.47549668874172185 val img acc: 0.32818248712288445 val txt acc: 0.4021486387049301 val_avg_loss: tensor(2.6226)
avg common val loss: tensor(2.1936) avg img val loss: tensor(3.0946) avg txt val loss: tensor(2.5797)
epoch: 115 train_loss: tensor(0.3951, grad_fn=<DivBackward0>) average train loss tensor(0.3951, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2005, grad_fn=<DivBackward0>) avg img loss: tensor(0.2836, grad_fn=<DivBackward0>) avg txt loss: tensor(0.70

val common acc: 0.473701250919794 val img acc: 0.3269168506254599 val txt acc: 0.40097130242825607 val_avg_loss: tensor(2.6661)
avg common val loss: tensor(2.2324) avg img val loss: tensor(3.1343) avg txt val loss: tensor(2.6316)
epoch: 130 train_loss: tensor(0.3524, grad_fn=<DivBackward0>) average train loss tensor(0.3524, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1453, grad_fn=<DivBackward0>) avg img loss: tensor(0.2655, grad_fn=<DivBackward0>) avg txt loss: tensor(0.6463, grad_fn=<DivBackward0>)
val common acc: 0.4743782192788815 val img acc: 0.32388520971302426 val txt acc: 0.4008830022075055 val_avg_loss: tensor(2.6687)
avg common val loss: tensor(2.2346) avg img val loss: tensor(3.1377) avg txt val loss: tensor(2.6338)
epoch: 131 train_loss: tensor(0.3513, grad_fn=<DivBackward0>) average train loss tensor(0.3513, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1445, grad_fn=<DivBackward0>) avg img loss: tensor(0.2644, grad_fn=<DivBackward0>) avg txt loss: tensor(0.6449,

val common acc: 0.4722590139808683 val img acc: 0.32700515084621046 val txt acc: 0.3996467991169978 val_avg_loss: tensor(2.7089)
avg common val loss: tensor(2.2790) avg img val loss: tensor(3.1651) avg txt val loss: tensor(2.6826)
epoch: 146 train_loss: tensor(0.3185, grad_fn=<DivBackward0>) average train loss tensor(0.3185, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1071, grad_fn=<DivBackward0>) avg img loss: tensor(0.2546, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5937, grad_fn=<DivBackward0>)
val common acc: 0.4719352465047829 val img acc: 0.32747608535688005 val txt acc: 0.399793966151582 val_avg_loss: tensor(2.7138)
avg common val loss: tensor(2.2851) avg img val loss: tensor(3.1708) avg txt val loss: tensor(2.6854)
epoch: 147 train_loss: tensor(0.3186, grad_fn=<DivBackward0>) average train loss tensor(0.3186, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1082, grad_fn=<DivBackward0>) avg img loss: tensor(0.2534, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5942,

val common acc: 0.46996320824135396 val img acc: 0.3230022075055188 val txt acc: 0.3985577630610743 val_avg_loss: tensor(2.7349)
avg common val loss: tensor(2.3049) avg img val loss: tensor(3.1861) avg txt val loss: tensor(2.7137)
epoch: 162 train_loss: tensor(0.2983, grad_fn=<DivBackward0>) average train loss tensor(0.2983, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0826, grad_fn=<DivBackward0>) avg img loss: tensor(0.2374, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5751, grad_fn=<DivBackward0>)
val common acc: 0.4706990434142752 val img acc: 0.3230905077262693 val txt acc: 0.39829286239882267 val_avg_loss: tensor(2.7361)
avg common val loss: tensor(2.3071) avg img val loss: tensor(3.1841) avg txt val loss: tensor(2.7171)
epoch: 163 train_loss: tensor(0.2979, grad_fn=<DivBackward0>) average train loss tensor(0.2979, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0858, grad_fn=<DivBackward0>) avg img loss: tensor(0.2421, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5657

val common acc: 0.4736718175128771 val img acc: 0.3252980132450331 val txt acc: 0.39876379690949226 val_avg_loss: tensor(2.7490)
avg common val loss: tensor(2.3235) avg img val loss: tensor(3.1829) avg txt val loss: tensor(2.7405)
epoch: 178 train_loss: tensor(0.2788, grad_fn=<DivBackward0>) average train loss tensor(0.2788, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0659, grad_fn=<DivBackward0>) avg img loss: tensor(0.2320, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5384, grad_fn=<DivBackward0>)
val common acc: 0.47405445180279615 val img acc: 0.32538631346578367 val txt acc: 0.39861662987490804 val_avg_loss: tensor(2.7513)
avg common val loss: tensor(2.3247) avg img val loss: tensor(3.1859) avg txt val loss: tensor(2.7433)
epoch: 179 train_loss: tensor(0.2802, grad_fn=<DivBackward0>) average train loss tensor(0.2802, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0634, grad_fn=<DivBackward0>) avg img loss: tensor(0.2347, grad_fn=<DivBackward0>) avg txt loss: tensor(0.54

val common acc: 0.47105224429727743 val img acc: 0.32426784400294334 val txt acc: 0.39729212656364976 val_avg_loss: tensor(2.7708)
avg common val loss: tensor(2.3501) avg img val loss: tensor(3.1961) avg txt val loss: tensor(2.7663)
epoch: 194 train_loss: tensor(0.2777, grad_fn=<DivBackward0>) average train loss tensor(0.2777, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0612, grad_fn=<DivBackward0>) avg img loss: tensor(0.2325, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5395, grad_fn=<DivBackward0>)
val common acc: 0.47137601177336275 val img acc: 0.3240029433406917 val txt acc: 0.3967917586460633 val_avg_loss: tensor(2.7708)
avg common val loss: tensor(2.3491) avg img val loss: tensor(3.1968) avg txt val loss: tensor(2.7666)
epoch: 195 train_loss: tensor(0.2745, grad_fn=<DivBackward0>) average train loss tensor(0.2745, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0589, grad_fn=<DivBackward0>) avg img loss: tensor(0.2185, grad_fn=<DivBackward0>) avg txt loss: tensor(0.54

val common acc: 0.4699926416482708 val img acc: 0.3229727740986019 val txt acc: 0.39655629139072845 val_avg_loss: tensor(2.7779)
avg common val loss: tensor(2.3627) avg img val loss: tensor(3.1827) avg txt val loss: tensor(2.7884)
epoch: 210 train_loss: tensor(0.2670, grad_fn=<DivBackward0>) average train loss tensor(0.2670, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0514, grad_fn=<DivBackward0>) avg img loss: tensor(0.2305, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5192, grad_fn=<DivBackward0>)
val common acc: 0.4715526122148639 val img acc: 0.3227961736571008 val txt acc: 0.3970272259013981 val_avg_loss: tensor(2.7753)
avg common val loss: tensor(2.3582) avg img val loss: tensor(3.1786) avg txt val loss: tensor(2.7891)
epoch: 211 train_loss: tensor(0.2590, grad_fn=<DivBackward0>) average train loss tensor(0.2590, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0464, grad_fn=<DivBackward0>) avg img loss: tensor(0.2220, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5087,

val common acc: 0.47128771155261223 val img acc: 0.32341427520235466 val txt acc: 0.3959970566593083 val_avg_loss: tensor(2.7895)
avg common val loss: tensor(2.3746) avg img val loss: tensor(3.1978) avg txt val loss: tensor(2.7961)
epoch: 226 train_loss: tensor(0.2577, grad_fn=<DivBackward0>) average train loss tensor(0.2577, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0425, grad_fn=<DivBackward0>) avg img loss: tensor(0.2176, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5130, grad_fn=<DivBackward0>)
val common acc: 0.47066961000735835 val img acc: 0.32259013980868284 val txt acc: 0.39584988962472406 val_avg_loss: tensor(2.7889)
avg common val loss: tensor(2.3740) avg img val loss: tensor(3.1953) avg txt val loss: tensor(2.7976)
epoch: 227 train_loss: tensor(0.2585, grad_fn=<DivBackward0>) average train loss tensor(0.2585, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0431, grad_fn=<DivBackward0>) avg img loss: tensor(0.2219, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5

val common acc: 0.47043414275202355 val img acc: 0.32120676968359085 val txt acc: 0.3932597498160412 val_avg_loss: tensor(2.7962)
avg common val loss: tensor(2.3861) avg img val loss: tensor(3.1954) avg txt val loss: tensor(2.8071)
epoch: 242 train_loss: tensor(0.2538, grad_fn=<DivBackward0>) average train loss tensor(0.2538, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0392, grad_fn=<DivBackward0>) avg img loss: tensor(0.2247, grad_fn=<DivBackward0>) avg txt loss: tensor(0.4976, grad_fn=<DivBackward0>)
val common acc: 0.4704635761589404 val img acc: 0.3220309050772627 val txt acc: 0.39417218543046356 val_avg_loss: tensor(2.7944)
avg common val loss: tensor(2.3837) avg img val loss: tensor(3.1923) avg txt val loss: tensor(2.8073)
epoch: 243 train_loss: tensor(0.2582, grad_fn=<DivBackward0>) average train loss tensor(0.2582, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0435, grad_fn=<DivBackward0>) avg img loss: tensor(0.2211, grad_fn=<DivBackward0>) avg txt loss: tensor(0.509

val common acc: 0.47196467991169977 val img acc: 0.32206033848417953 val txt acc: 0.3951729212656365 val_avg_loss: tensor(2.8033)
avg common val loss: tensor(2.3905) avg img val loss: tensor(3.1970) avg txt val loss: tensor(2.8223)
epoch: 258 train_loss: tensor(0.2546, grad_fn=<DivBackward0>) average train loss tensor(0.2546, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0388, grad_fn=<DivBackward0>) avg img loss: tensor(0.2169, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5082, grad_fn=<DivBackward0>)
val common acc: 0.4722295805739514 val img acc: 0.3223841059602649 val txt acc: 0.39479028697571744 val_avg_loss: tensor(2.8001)
avg common val loss: tensor(2.3884) avg img val loss: tensor(3.1905) avg txt val loss: tensor(2.8213)
epoch: 259 train_loss: tensor(0.2490, grad_fn=<DivBackward0>) average train loss tensor(0.2490, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0325, grad_fn=<DivBackward0>) avg img loss: tensor(0.2144, grad_fn=<DivBackward0>) avg txt loss: tensor(0.500

val common acc: 0.4686681383370125 val img acc: 0.3216777041942605 val txt acc: 0.3936129506990434 val_avg_loss: tensor(2.8012)
avg common val loss: tensor(2.4044) avg img val loss: tensor(3.1752) avg txt val loss: tensor(2.8239)
epoch: 274 train_loss: tensor(0.2480, grad_fn=<DivBackward0>) average train loss tensor(0.2480, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0310, grad_fn=<DivBackward0>) avg img loss: tensor(0.2110, grad_fn=<DivBackward0>) avg txt loss: tensor(0.5019, grad_fn=<DivBackward0>)
val common acc: 0.4691096394407653 val img acc: 0.3225312729948491 val txt acc: 0.39408388520971305 val_avg_loss: tensor(2.7993)
avg common val loss: tensor(2.4028) avg img val loss: tensor(3.1728) avg txt val loss: tensor(2.8223)
epoch: 275 train_loss: tensor(0.2446, grad_fn=<DivBackward0>) average train loss tensor(0.2446, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0296, grad_fn=<DivBackward0>) avg img loss: tensor(0.2152, grad_fn=<DivBackward0>) avg txt loss: tensor(0.4891,

val common acc: 0.4696983075791023 val img acc: 0.321972038263429 val txt acc: 0.3925533480500368 val_avg_loss: tensor(2.8045)
avg common val loss: tensor(2.4102) avg img val loss: tensor(3.1826) avg txt val loss: tensor(2.8209)
epoch: 290 train_loss: tensor(0.2441, grad_fn=<DivBackward0>) average train loss tensor(0.2441, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0282, grad_fn=<DivBackward0>) avg img loss: tensor(0.2099, grad_fn=<DivBackward0>) avg txt loss: tensor(0.4941, grad_fn=<DivBackward0>)
val common acc: 0.46893303899926414 val img acc: 0.32311994113318615 val txt acc: 0.392906548933039 val_avg_loss: tensor(2.8040)
avg common val loss: tensor(2.4106) avg img val loss: tensor(3.1826) avg txt val loss: tensor(2.8188)
epoch: 291 train_loss: tensor(0.2432, grad_fn=<DivBackward0>) average train loss tensor(0.2432, grad_fn=<DivBackward0>)
avg common loss: tensor(0.0299, grad_fn=<DivBackward0>) avg img loss: tensor(0.2104, grad_fn=<DivBackward0>) avg txt loss: tensor(0.4894, 

In [15]:
model = torch_models.NormModelTrident(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bs2048_rs42_d128_wd0005_drop_05_20k')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=100,
    writer=writer,
    train_loader=train_loader_20k,
    val_loader=val_loader
)

[autoreload of pytorch.torch_models failed: Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/usr/local/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 450, in superreload
    update_generic(old_obj, new_obj)
  File "/usr/local/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 387, in update_generic
    update(a, b)
  File "/usr/local/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 357, in update_class
    update_instances(old, new)
  File "/usr/local/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 317, in update_instances
    update_instances(old, new, obj, visited)
  File "/usr/local/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 317, in update_instances
    update_instances(old, new, obj, visited)
  File "/usr/local/lib/python3.7/site-packages/IPyt

epoch: 0 train_loss: tensor(3.4513, grad_fn=<DivBackward0>) average train loss tensor(3.6764, grad_fn=<DivBackward0>)
avg common loss: tensor(3.8209, grad_fn=<DivBackward0>) avg img loss: tensor(3.4996, grad_fn=<DivBackward0>) avg txt loss: tensor(3.7087, grad_fn=<DivBackward0>)
val common acc: 0.05612950699043414 val img acc: 0.2790581309786608 val txt acc: 0.17236203090507726 val_avg_loss: tensor(3.3131)
avg common val loss: tensor(3.6301) avg img val loss: tensor(2.9089) avg txt val loss: tensor(3.4003)
epoch: 1 train_loss: tensor(3.0608, grad_fn=<DivBackward0>) average train loss tensor(3.2024, grad_fn=<DivBackward0>)
avg common loss: tensor(3.5223, grad_fn=<DivBackward0>) avg img loss: tensor(2.8278, grad_fn=<DivBackward0>) avg txt loss: tensor(3.2572, grad_fn=<DivBackward0>)
val common acc: 0.2543046357615894 val img acc: 0.35481972038263426 val txt acc: 0.29368653421633556 val_avg_loss: tensor(2.8871)
avg common val loss: tensor(3.1978) avg img val loss: tensor(2.5284) avg txt v

val common acc: 0.5828108903605592 val img acc: 0.4193377483443709 val txt acc: 0.4755849889624724 val_avg_loss: tensor(1.9318)
avg common val loss: tensor(1.5666) avg img val loss: tensor(2.2032) avg txt val loss: tensor(2.0255)
epoch: 17 train_loss: tensor(1.7124, grad_fn=<DivBackward0>) average train loss tensor(1.7686, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4570, grad_fn=<DivBackward0>) avg img loss: tensor(1.9069, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9419, grad_fn=<DivBackward0>)
val common acc: 0.5847534952170714 val img acc: 0.41927888153053716 val txt acc: 0.47420161883738043 val_avg_loss: tensor(1.9316)
avg common val loss: tensor(1.5660) avg img val loss: tensor(2.2027) avg txt val loss: tensor(2.0259)
epoch: 18 train_loss: tensor(1.7014, grad_fn=<DivBackward0>) average train loss tensor(1.7492, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4310, grad_fn=<DivBackward0>) avg img loss: tensor(1.8823, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9343, 

epoch: 33 train_loss: tensor(1.4313, grad_fn=<DivBackward0>) average train loss tensor(1.5241, grad_fn=<DivBackward0>)
avg common loss: tensor(1.1651, grad_fn=<DivBackward0>) avg img loss: tensor(1.6141, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7930, grad_fn=<DivBackward0>)
val common acc: 0.5862840323767476 val img acc: 0.4210448859455482 val txt acc: 0.47823399558498897 val_avg_loss: tensor(1.9630)
avg common val loss: tensor(1.5957) avg img val loss: tensor(2.2678) avg txt val loss: tensor(2.0256)
epoch: 34 train_loss: tensor(1.4236, grad_fn=<DivBackward0>) average train loss tensor(1.5159, grad_fn=<DivBackward0>)
avg common loss: tensor(1.1515, grad_fn=<DivBackward0>) avg img loss: tensor(1.6069, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7893, grad_fn=<DivBackward0>)
val common acc: 0.5863428991905814 val img acc: 0.4177777777777778 val txt acc: 0.47832229580573954 val_avg_loss: tensor(1.9672)
avg common val loss: tensor(1.6036) avg img val loss: tensor(2.2715) avg txt v

val common acc: 0.5826637233259749 val img acc: 0.4143635025754231 val txt acc: 0.47885209713024285 val_avg_loss: tensor(2.0318)
avg common val loss: tensor(1.6912) avg img val loss: tensor(2.3655) avg txt val loss: tensor(2.0385)
epoch: 50 train_loss: tensor(1.2677, grad_fn=<DivBackward0>) average train loss tensor(1.3595, grad_fn=<DivBackward0>)
avg common loss: tensor(0.9748, grad_fn=<DivBackward0>) avg img loss: tensor(1.3992, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7046, grad_fn=<DivBackward0>)
val common acc: 0.5817218543046357 val img acc: 0.4134216335540839 val txt acc: 0.47823399558498897 val_avg_loss: tensor(2.0310)
avg common val loss: tensor(1.6862) avg img val loss: tensor(2.3669) avg txt val loss: tensor(2.0399)
epoch: 51 train_loss: tensor(1.2564, grad_fn=<DivBackward0>) average train loss tensor(1.3466, grad_fn=<DivBackward0>)
avg common loss: tensor(0.9684, grad_fn=<DivBackward0>) avg img loss: tensor(1.3835, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6880, 

epoch: 66 train_loss: tensor(1.1271, grad_fn=<DivBackward0>) average train loss tensor(1.2361, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8365, grad_fn=<DivBackward0>) avg img loss: tensor(1.2360, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6359, grad_fn=<DivBackward0>)
val common acc: 0.5763944076526858 val img acc: 0.409654157468727 val txt acc: 0.4787343635025754 val_avg_loss: tensor(2.0935)
avg common val loss: tensor(1.7728) avg img val loss: tensor(2.4524) avg txt val loss: tensor(2.0554)
epoch: 67 train_loss: tensor(1.1282, grad_fn=<DivBackward0>) average train loss tensor(1.2282, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8312, grad_fn=<DivBackward0>) avg img loss: tensor(1.2296, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6239, grad_fn=<DivBackward0>)
val common acc: 0.576953642384106 val img acc: 0.4097718910963944 val txt acc: 0.47864606328182485 val_avg_loss: tensor(2.1038)
avg common val loss: tensor(1.7907) avg img val loss: tensor(2.4627) avg txt val 

val common acc: 0.5726563649742458 val img acc: 0.40856512141280354 val txt acc: 0.47932303164091244 val_avg_loss: tensor(2.1560)
avg common val loss: tensor(1.8587) avg img val loss: tensor(2.5378) avg txt val loss: tensor(2.0716)
epoch: 83 train_loss: tensor(1.0130, grad_fn=<DivBackward0>) average train loss tensor(1.1320, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7207, grad_fn=<DivBackward0>) avg img loss: tensor(1.1050, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5701, grad_fn=<DivBackward0>)
val common acc: 0.5720676968359087 val img acc: 0.4094775570272259 val txt acc: 0.47782192788815303 val_avg_loss: tensor(2.1633)
avg common val loss: tensor(1.8706) avg img val loss: tensor(2.5468) avg txt val loss: tensor(2.0725)
epoch: 84 train_loss: tensor(1.0301, grad_fn=<DivBackward0>) average train loss tensor(1.1328, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7276, grad_fn=<DivBackward0>) avg img loss: tensor(1.1016, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5692,

epoch: 99 train_loss: tensor(0.9651, grad_fn=<DivBackward0>) average train loss tensor(1.0583, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6471, grad_fn=<DivBackward0>) avg img loss: tensor(0.9971, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5307, grad_fn=<DivBackward0>)
val common acc: 0.5669757174392936 val img acc: 0.40382634289919056 val txt acc: 0.4775570272259014 val_avg_loss: tensor(2.2177)
avg common val loss: tensor(1.9500) avg img val loss: tensor(2.6174) avg txt val loss: tensor(2.0857)


In [17]:
model = torch_models.NormModelTrident(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bs2048_rs42_d128_wd0005_drop_05_20k_300')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=300,
    writer=writer,
    train_loader=train_loader_20k,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(3.4447, grad_fn=<DivBackward0>) average train loss tensor(3.6726, grad_fn=<DivBackward0>)
avg common loss: tensor(3.8194, grad_fn=<DivBackward0>) avg img loss: tensor(3.4957, grad_fn=<DivBackward0>) avg txt loss: tensor(3.7026, grad_fn=<DivBackward0>)
val common acc: 0.08532744665194997 val img acc: 0.2706401766004415 val txt acc: 0.15602649006622515 val_avg_loss: tensor(3.3243)
avg common val loss: tensor(3.6383) avg img val loss: tensor(2.9335) avg txt val loss: tensor(3.4012)
epoch: 1 train_loss: tensor(3.0723, grad_fn=<DivBackward0>) average train loss tensor(3.2135, grad_fn=<DivBackward0>)
avg common loss: tensor(3.5362, grad_fn=<DivBackward0>) avg img loss: tensor(2.8380, grad_fn=<DivBackward0>) avg txt loss: tensor(3.2661, grad_fn=<DivBackward0>)
val common acc: 0.23428991905813099 val img acc: 0.35252391464311994 val txt acc: 0.28008830022075054 val_avg_loss: tensor(2.9037)
avg common val loss: tensor(3.2257) avg img val loss: tensor(2.5397) avg txt 

val common acc: 0.5856953642384106 val img acc: 0.4242236938925681 val txt acc: 0.47567328918322294 val_avg_loss: tensor(1.9266)
avg common val loss: tensor(1.5592) avg img val loss: tensor(2.1956) avg txt val loss: tensor(2.0249)
epoch: 17 train_loss: tensor(1.7101, grad_fn=<DivBackward0>) average train loss tensor(1.7661, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4515, grad_fn=<DivBackward0>) avg img loss: tensor(1.9022, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9445, grad_fn=<DivBackward0>)
val common acc: 0.585813097866078 val img acc: 0.42478292862398825 val txt acc: 0.47526122148638705 val_avg_loss: tensor(1.9273)
avg common val loss: tensor(1.5599) avg img val loss: tensor(2.1971) avg txt val loss: tensor(2.0248)
epoch: 18 train_loss: tensor(1.6870, grad_fn=<DivBackward0>) average train loss tensor(1.7469, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4302, grad_fn=<DivBackward0>) avg img loss: tensor(1.8731, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9374, 

epoch: 33 train_loss: tensor(1.4577, grad_fn=<DivBackward0>) average train loss tensor(1.5304, grad_fn=<DivBackward0>)
avg common loss: tensor(1.1694, grad_fn=<DivBackward0>) avg img loss: tensor(1.6188, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8028, grad_fn=<DivBackward0>)
val common acc: 0.5878440029433407 val img acc: 0.42204562178072114 val txt acc: 0.47908756438557765 val_avg_loss: tensor(1.9611)
avg common val loss: tensor(1.5923) avg img val loss: tensor(2.2653) avg txt val loss: tensor(2.0256)
epoch: 34 train_loss: tensor(1.4665, grad_fn=<DivBackward0>) average train loss tensor(1.5168, grad_fn=<DivBackward0>)
avg common loss: tensor(1.1531, grad_fn=<DivBackward0>) avg img loss: tensor(1.6047, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7927, grad_fn=<DivBackward0>)
val common acc: 0.586813833701251 val img acc: 0.4208094186902134 val txt acc: 0.47955849889624724 val_avg_loss: tensor(1.9634)
avg common val loss: tensor(1.5941) avg img val loss: tensor(2.2710) avg txt v

val common acc: 0.5831640912435614 val img acc: 0.4156291390728477 val txt acc: 0.47970566593083147 val_avg_loss: tensor(2.0255)
avg common val loss: tensor(1.6759) avg img val loss: tensor(2.3631) avg txt val loss: tensor(2.0375)
epoch: 50 train_loss: tensor(1.2600, grad_fn=<DivBackward0>) average train loss tensor(1.3544, grad_fn=<DivBackward0>)
avg common loss: tensor(0.9721, grad_fn=<DivBackward0>) avg img loss: tensor(1.3872, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7038, grad_fn=<DivBackward0>)
val common acc: 0.5836350257542311 val img acc: 0.4144812362030905 val txt acc: 0.4799705665930831 val_avg_loss: tensor(2.0317)
avg common val loss: tensor(1.6788) avg img val loss: tensor(2.3747) avg txt val loss: tensor(2.0416)
epoch: 51 train_loss: tensor(1.2575, grad_fn=<DivBackward0>) average train loss tensor(1.3531, grad_fn=<DivBackward0>)
avg common loss: tensor(0.9650, grad_fn=<DivBackward0>) avg img loss: tensor(1.3924, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7018, g

epoch: 66 train_loss: tensor(1.1351, grad_fn=<DivBackward0>) average train loss tensor(1.2338, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8329, grad_fn=<DivBackward0>) avg img loss: tensor(1.2301, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6384, grad_fn=<DivBackward0>)
val common acc: 0.5782192788815306 val img acc: 0.41189109639440763 val txt acc: 0.4775275938189846 val_avg_loss: tensor(2.0967)
avg common val loss: tensor(1.7729) avg img val loss: tensor(2.4630) avg txt val loss: tensor(2.0541)
epoch: 67 train_loss: tensor(1.1143, grad_fn=<DivBackward0>) average train loss tensor(1.2281, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8299, grad_fn=<DivBackward0>) avg img loss: tensor(1.2215, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6329, grad_fn=<DivBackward0>)
val common acc: 0.5764238410596026 val img acc: 0.4118027961736571 val txt acc: 0.4777630610743194 val_avg_loss: tensor(2.0979)
avg common val loss: tensor(1.7745) avg img val loss: tensor(2.4628) avg txt va

val common acc: 0.5709492273730684 val img acc: 0.40618101545253865 val txt acc: 0.47670345842531275 val_avg_loss: tensor(2.1527)
avg common val loss: tensor(1.8507) avg img val loss: tensor(2.5370) avg txt val loss: tensor(2.0705)
epoch: 83 train_loss: tensor(1.0268, grad_fn=<DivBackward0>) average train loss tensor(1.1452, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7406, grad_fn=<DivBackward0>) avg img loss: tensor(1.1031, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5919, grad_fn=<DivBackward0>)
val common acc: 0.5725680647534952 val img acc: 0.40862398822663726 val txt acc: 0.4765268579838116 val_avg_loss: tensor(2.1565)
avg common val loss: tensor(1.8580) avg img val loss: tensor(2.5426) avg txt val loss: tensor(2.0691)
epoch: 84 train_loss: tensor(1.0406, grad_fn=<DivBackward0>) average train loss tensor(1.1400, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7321, grad_fn=<DivBackward0>) avg img loss: tensor(1.1069, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5811,

val common acc: 0.5710963944076527 val img acc: 0.40568064753495214 val txt acc: 0.47626195732155996 val_avg_loss: tensor(2.1997)
avg common val loss: tensor(1.9139) avg img val loss: tensor(2.6005) avg txt val loss: tensor(2.0847)
epoch: 99 train_loss: tensor(0.9673, grad_fn=<DivBackward0>) average train loss tensor(1.0698, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6551, grad_fn=<DivBackward0>) avg img loss: tensor(1.0123, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5419, grad_fn=<DivBackward0>)
val common acc: 0.570242825607064 val img acc: 0.40712288447387784 val txt acc: 0.47688005886681384 val_avg_loss: tensor(2.2062)
avg common val loss: tensor(1.9248) avg img val loss: tensor(2.6068) avg txt val loss: tensor(2.0868)
epoch: 100 train_loss: tensor(0.9491, grad_fn=<DivBackward0>) average train loss tensor(1.0594, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6489, grad_fn=<DivBackward0>) avg img loss: tensor(0.9964, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5330

val common acc: 0.5635025754231052 val img acc: 0.4036497424576895 val txt acc: 0.4782634289919058 val_avg_loss: tensor(2.2531)
avg common val loss: tensor(1.9958) avg img val loss: tensor(2.6655) avg txt val loss: tensor(2.0978)
epoch: 115 train_loss: tensor(0.9191, grad_fn=<DivBackward0>) average train loss tensor(1.0022, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5864, grad_fn=<DivBackward0>) avg img loss: tensor(0.9224, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4979, grad_fn=<DivBackward0>)
val common acc: 0.5622075055187638 val img acc: 0.40403237674760856 val txt acc: 0.4770860927152318 val_avg_loss: tensor(2.2554)
avg common val loss: tensor(1.9937) avg img val loss: tensor(2.6721) avg txt val loss: tensor(2.1004)
epoch: 116 train_loss: tensor(0.8877, grad_fn=<DivBackward0>) average train loss tensor(0.9963, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5786, grad_fn=<DivBackward0>) avg img loss: tensor(0.9213, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4890,

val common acc: 0.564532744665195 val img acc: 0.40432671081677707 val txt acc: 0.4765857247976453 val_avg_loss: tensor(2.2866)
avg common val loss: tensor(2.0446) avg img val loss: tensor(2.7045) avg txt val loss: tensor(2.1108)
epoch: 131 train_loss: tensor(0.8525, grad_fn=<DivBackward0>) average train loss tensor(0.9512, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5265, grad_fn=<DivBackward0>) avg img loss: tensor(0.8572, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4699, grad_fn=<DivBackward0>)
val common acc: 0.5631199411331862 val img acc: 0.40297277409860194 val txt acc: 0.4769094922737307 val_avg_loss: tensor(2.2921)
avg common val loss: tensor(2.0547) avg img val loss: tensor(2.7096) avg txt val loss: tensor(2.1122)
epoch: 132 train_loss: tensor(0.8448, grad_fn=<DivBackward0>) average train loss tensor(0.9487, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5270, grad_fn=<DivBackward0>) avg img loss: tensor(0.8538, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4653,

val common acc: 0.5595584988962472 val img acc: 0.403944076526858 val txt acc: 0.475467255334805 val_avg_loss: tensor(2.3300)
avg common val loss: tensor(2.1195) avg img val loss: tensor(2.7473) avg txt val loss: tensor(2.1232)
epoch: 147 train_loss: tensor(0.8209, grad_fn=<DivBackward0>) average train loss tensor(0.9223, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5021, grad_fn=<DivBackward0>) avg img loss: tensor(0.8226, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4422, grad_fn=<DivBackward0>)
val common acc: 0.5610301692420898 val img acc: 0.40297277409860194 val txt acc: 0.47520235467255334 val_avg_loss: tensor(2.3317)
avg common val loss: tensor(2.1175) avg img val loss: tensor(2.7556) avg txt val loss: tensor(2.1219)
epoch: 148 train_loss: tensor(0.8369, grad_fn=<DivBackward0>) average train loss tensor(0.9106, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4857, grad_fn=<DivBackward0>) avg img loss: tensor(0.8110, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4351, 

val common acc: 0.5571743929359824 val img acc: 0.404532744665195 val txt acc: 0.47426048565121415 val_avg_loss: tensor(2.3531)
avg common val loss: tensor(2.1472) avg img val loss: tensor(2.7813) avg txt val loss: tensor(2.1307)
epoch: 163 train_loss: tensor(0.7906, grad_fn=<DivBackward0>) average train loss tensor(0.8866, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4638, grad_fn=<DivBackward0>) avg img loss: tensor(0.7765, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4195, grad_fn=<DivBackward0>)
val common acc: 0.5574981604120677 val img acc: 0.40206033848417955 val txt acc: 0.47426048565121415 val_avg_loss: tensor(2.3591)
avg common val loss: tensor(2.1544) avg img val loss: tensor(2.7897) avg txt val loss: tensor(2.1332)
epoch: 164 train_loss: tensor(0.7981, grad_fn=<DivBackward0>) average train loss tensor(0.8813, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4579, grad_fn=<DivBackward0>) avg img loss: tensor(0.7733, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4128

val common acc: 0.5585577630610743 val img acc: 0.404532744665195 val txt acc: 0.47420161883738043 val_avg_loss: tensor(2.3838)
avg common val loss: tensor(2.1985) avg img val loss: tensor(2.8126) avg txt val loss: tensor(2.1402)
epoch: 179 train_loss: tensor(0.7507, grad_fn=<DivBackward0>) average train loss tensor(0.8574, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4356, grad_fn=<DivBackward0>) avg img loss: tensor(0.7413, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3953, grad_fn=<DivBackward0>)
val common acc: 0.5594996320824135 val img acc: 0.40426784400294336 val txt acc: 0.47564385577630613 val_avg_loss: tensor(2.3778)
avg common val loss: tensor(2.1972) avg img val loss: tensor(2.7922) avg txt val loss: tensor(2.1441)
epoch: 180 train_loss: tensor(0.7828, grad_fn=<DivBackward0>) average train loss tensor(0.8595, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4362, grad_fn=<DivBackward0>) avg img loss: tensor(0.7438, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3985

val common acc: 0.5561736571008095 val img acc: 0.4023252391464312 val txt acc: 0.4724944812362031 val_avg_loss: tensor(2.4007)
avg common val loss: tensor(2.2263) avg img val loss: tensor(2.8242) avg txt val loss: tensor(2.1516)
epoch: 195 train_loss: tensor(0.7449, grad_fn=<DivBackward0>) average train loss tensor(0.8359, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4136, grad_fn=<DivBackward0>) avg img loss: tensor(0.7174, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3768, grad_fn=<DivBackward0>)
val common acc: 0.5558204562178072 val img acc: 0.40311994113318617 val txt acc: 0.47281824871228845 val_avg_loss: tensor(2.4059)
avg common val loss: tensor(2.2318) avg img val loss: tensor(2.8343) avg txt val loss: tensor(2.1517)
epoch: 196 train_loss: tensor(0.7677, grad_fn=<DivBackward0>) average train loss tensor(0.8379, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4266, grad_fn=<DivBackward0>) avg img loss: tensor(0.7163, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3709

val common acc: 0.5555555555555556 val img acc: 0.4015011037527594 val txt acc: 0.472317880794702 val_avg_loss: tensor(2.4291)
avg common val loss: tensor(2.2654) avg img val loss: tensor(2.8608) avg txt val loss: tensor(2.1610)
epoch: 211 train_loss: tensor(0.7316, grad_fn=<DivBackward0>) average train loss tensor(0.8172, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3933, grad_fn=<DivBackward0>) avg img loss: tensor(0.6972, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3610, grad_fn=<DivBackward0>)
val common acc: 0.55467255334805 val img acc: 0.40297277409860194 val txt acc: 0.4738484179543782 val_avg_loss: tensor(2.4228)
avg common val loss: tensor(2.2595) avg img val loss: tensor(2.8483) avg txt val loss: tensor(2.1606)
epoch: 212 train_loss: tensor(0.7451, grad_fn=<DivBackward0>) average train loss tensor(0.8212, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4030, grad_fn=<DivBackward0>) avg img loss: tensor(0.6962, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3644, gr

val common acc: 0.5543782192788815 val img acc: 0.4007358351729213 val txt acc: 0.4743193524650478 val_avg_loss: tensor(2.4285)
avg common val loss: tensor(2.2633) avg img val loss: tensor(2.8557) avg txt val loss: tensor(2.1665)
epoch: 227 train_loss: tensor(0.7110, grad_fn=<DivBackward0>) average train loss tensor(0.8004, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3818, grad_fn=<DivBackward0>) avg img loss: tensor(0.6752, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3443, grad_fn=<DivBackward0>)
val common acc: 0.5566445916114791 val img acc: 0.4020897718910964 val txt acc: 0.4733186166298749 val_avg_loss: tensor(2.4322)
avg common val loss: tensor(2.2638) avg img val loss: tensor(2.8670) avg txt val loss: tensor(2.1658)
epoch: 228 train_loss: tensor(0.7129, grad_fn=<DivBackward0>) average train loss tensor(0.7950, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3701, grad_fn=<DivBackward0>) avg img loss: tensor(0.6697, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3453, 

val common acc: 0.5531125827814569 val img acc: 0.4004120676968359 val txt acc: 0.4728476821192053 val_avg_loss: tensor(2.4412)
avg common val loss: tensor(2.2835) avg img val loss: tensor(2.8650) avg txt val loss: tensor(2.1751)
epoch: 243 train_loss: tensor(0.7104, grad_fn=<DivBackward0>) average train loss tensor(0.7939, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3748, grad_fn=<DivBackward0>) avg img loss: tensor(0.6671, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3398, grad_fn=<DivBackward0>)
val common acc: 0.5525827814569536 val img acc: 0.3980868285504047 val txt acc: 0.4722295805739514 val_avg_loss: tensor(2.4387)
avg common val loss: tensor(2.2726) avg img val loss: tensor(2.8638) avg txt val loss: tensor(2.1796)
epoch: 244 train_loss: tensor(0.7094, grad_fn=<DivBackward0>) average train loss tensor(0.7918, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3746, grad_fn=<DivBackward0>) avg img loss: tensor(0.6730, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3279, 

val common acc: 0.5507873436350258 val img acc: 0.400794701986755 val txt acc: 0.47299484915378953 val_avg_loss: tensor(2.4590)
avg common val loss: tensor(2.3164) avg img val loss: tensor(2.8829) avg txt val loss: tensor(2.1777)
epoch: 259 train_loss: tensor(0.6885, grad_fn=<DivBackward0>) average train loss tensor(0.7725, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3502, grad_fn=<DivBackward0>) avg img loss: tensor(0.6370, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3303, grad_fn=<DivBackward0>)
val common acc: 0.552111846946284 val img acc: 0.40061810154525385 val txt acc: 0.47299484915378953 val_avg_loss: tensor(2.4658)
avg common val loss: tensor(2.3370) avg img val loss: tensor(2.8802) avg txt val loss: tensor(2.1800)
epoch: 260 train_loss: tensor(0.7148, grad_fn=<DivBackward0>) average train loss tensor(0.7815, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3585, grad_fn=<DivBackward0>) avg img loss: tensor(0.6571, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3290,

val common acc: 0.554289919058131 val img acc: 0.4002649006622517 val txt acc: 0.4717880794701987 val_avg_loss: tensor(2.4653)
avg common val loss: tensor(2.3339) avg img val loss: tensor(2.8785) avg txt val loss: tensor(2.1833)
epoch: 275 train_loss: tensor(0.6891, grad_fn=<DivBackward0>) average train loss tensor(0.7743, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3584, grad_fn=<DivBackward0>) avg img loss: tensor(0.6396, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3249, grad_fn=<DivBackward0>)
val common acc: 0.551523178807947 val img acc: 0.39938189845474614 val txt acc: 0.47152317880794703 val_avg_loss: tensor(2.4582)
avg common val loss: tensor(2.3130) avg img val loss: tensor(2.8776) avg txt val loss: tensor(2.1839)
epoch: 276 train_loss: tensor(0.6869, grad_fn=<DivBackward0>) average train loss tensor(0.7666, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3441, grad_fn=<DivBackward0>) avg img loss: tensor(0.6428, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3128, 

val common acc: 0.5528476821192053 val img acc: 0.4000588668138337 val txt acc: 0.47258278145695365 val_avg_loss: tensor(2.4704)
avg common val loss: tensor(2.3425) avg img val loss: tensor(2.8778) avg txt val loss: tensor(2.1908)
epoch: 291 train_loss: tensor(0.6882, grad_fn=<DivBackward0>) average train loss tensor(0.7545, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3362, grad_fn=<DivBackward0>) avg img loss: tensor(0.6221, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3052, grad_fn=<DivBackward0>)
val common acc: 0.552317880794702 val img acc: 0.3994701986754967 val txt acc: 0.4724650478292862 val_avg_loss: tensor(2.4760)
avg common val loss: tensor(2.3468) avg img val loss: tensor(2.8887) avg txt val loss: tensor(2.1925)
epoch: 292 train_loss: tensor(0.6930, grad_fn=<DivBackward0>) average train loss tensor(0.7546, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3354, grad_fn=<DivBackward0>) avg img loss: tensor(0.6256, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3029, 

In [16]:
train_loader_20k_512 = DataLoader(train_ds_20k, batch_size=512)

model = torch_models.NormModelTrident(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bs512_rs42_d128_wd0005_drop_05_20k')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=100,
    writer=writer,
    train_loader=train_loader_20k_512,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(2.6999, grad_fn=<DivBackward0>) average train loss tensor(3.2891, grad_fn=<DivBackward0>)
avg common loss: tensor(3.4968, grad_fn=<DivBackward0>) avg img loss: tensor(3.0599, grad_fn=<DivBackward0>) avg txt loss: tensor(3.3106, grad_fn=<DivBackward0>)
val common acc: 0.3263576158940397 val img acc: 0.3547608535688006 val txt acc: 0.3514643119941133 val_avg_loss: tensor(2.6120)
avg common val loss: tensor(2.7100) avg img val loss: tensor(2.4784) avg txt val loss: tensor(2.6475)
epoch: 1 train_loss: tensor(1.8353, grad_fn=<DivBackward0>) average train loss tensor(2.4763, grad_fn=<DivBackward0>)
avg common loss: tensor(2.4287, grad_fn=<DivBackward0>) avg img loss: tensor(2.4780, grad_fn=<DivBackward0>) avg txt loss: tensor(2.5221, grad_fn=<DivBackward0>)
val common acc: 0.4801766004415011 val img acc: 0.3881972038263429 val txt acc: 0.4357615894039735 val_avg_loss: tensor(2.1708)
avg common val loss: tensor(1.9600) avg img val loss: tensor(2.3277) avg txt val l

val common acc: 0.5894628403237675 val img acc: 0.41916114790286974 val txt acc: 0.47693892568064755 val_avg_loss: tensor(1.9424)
avg common val loss: tensor(1.5624) avg img val loss: tensor(2.2362) avg txt val loss: tensor(2.0285)
epoch: 17 train_loss: tensor(0.4625, grad_fn=<DivBackward0>) average train loss tensor(1.6532, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3110, grad_fn=<DivBackward0>) avg img loss: tensor(1.8220, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8267, grad_fn=<DivBackward0>)
val common acc: 0.588550404709345 val img acc: 0.41813097866078 val txt acc: 0.4777336276674025 val_avg_loss: tensor(1.9441)
avg common val loss: tensor(1.5653) avg img val loss: tensor(2.2378) avg txt val loss: tensor(2.0292)
epoch: 18 train_loss: tensor(0.4691, grad_fn=<DivBackward0>) average train loss tensor(1.6418, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3026, grad_fn=<DivBackward0>) avg img loss: tensor(1.8061, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8168, gra

epoch: 33 train_loss: tensor(0.3582, grad_fn=<DivBackward0>) average train loss tensor(1.4552, grad_fn=<DivBackward0>)
avg common loss: tensor(1.0843, grad_fn=<DivBackward0>) avg img loss: tensor(1.5845, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6968, grad_fn=<DivBackward0>)
val common acc: 0.5879028697571744 val img acc: 0.41660044150110376 val txt acc: 0.47635025754231053 val_avg_loss: tensor(2.0020)
avg common val loss: tensor(1.6383) avg img val loss: tensor(2.3190) avg txt val loss: tensor(2.0487)
epoch: 34 train_loss: tensor(0.3156, grad_fn=<DivBackward0>) average train loss tensor(1.4408, grad_fn=<DivBackward0>)
avg common loss: tensor(1.0755, grad_fn=<DivBackward0>) avg img loss: tensor(1.5662, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6806, grad_fn=<DivBackward0>)
val common acc: 0.5886387049300956 val img acc: 0.4160412067696836 val txt acc: 0.47799852832965417 val_avg_loss: tensor(1.9976)
avg common val loss: tensor(1.6208) avg img val loss: tensor(2.3215) avg txt 

val common acc: 0.5860485651214128 val img acc: 0.41477557027225903 val txt acc: 0.47555555555555556 val_avg_loss: tensor(2.0461)
avg common val loss: tensor(1.6850) avg img val loss: tensor(2.3797) avg txt val loss: tensor(2.0737)
epoch: 50 train_loss: tensor(0.3282, grad_fn=<DivBackward0>) average train loss tensor(1.3106, grad_fn=<DivBackward0>)
avg common loss: tensor(0.9229, grad_fn=<DivBackward0>) avg img loss: tensor(1.4060, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6028, grad_fn=<DivBackward0>)
val common acc: 0.5855481972038263 val img acc: 0.41247976453274465 val txt acc: 0.47602649006622516 val_avg_loss: tensor(2.0586)
avg common val loss: tensor(1.7104) avg img val loss: tensor(2.3924) avg txt val loss: tensor(2.0729)
epoch: 51 train_loss: tensor(0.2950, grad_fn=<DivBackward0>) average train loss tensor(1.3100, grad_fn=<DivBackward0>)
avg common loss: tensor(0.9283, grad_fn=<DivBackward0>) avg img loss: tensor(1.4051, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5964

epoch: 66 train_loss: tensor(0.3054, grad_fn=<DivBackward0>) average train loss tensor(1.2248, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8359, grad_fn=<DivBackward0>) avg img loss: tensor(1.2984, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5399, grad_fn=<DivBackward0>)
val common acc: 0.5840176600441501 val img acc: 0.41033112582781456 val txt acc: 0.4756144223693893 val_avg_loss: tensor(2.0995)
avg common val loss: tensor(1.7586) avg img val loss: tensor(2.4463) avg txt val loss: tensor(2.0935)
epoch: 67 train_loss: tensor(0.2400, grad_fn=<DivBackward0>) average train loss tensor(1.2126, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8194, grad_fn=<DivBackward0>) avg img loss: tensor(1.2832, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5353, grad_fn=<DivBackward0>)
val common acc: 0.580691685062546 val img acc: 0.4118027961736571 val txt acc: 0.4745253863134658 val_avg_loss: tensor(2.1124)
avg common val loss: tensor(1.7816) avg img val loss: tensor(2.4584) avg txt val

epoch: 82 train_loss: tensor(0.2311, grad_fn=<DivBackward0>) average train loss tensor(1.1451, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7396, grad_fn=<DivBackward0>) avg img loss: tensor(1.2028, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4928, grad_fn=<DivBackward0>)
val common acc: 0.5752170713760117 val img acc: 0.40594554819720385 val txt acc: 0.473701250919794 val_avg_loss: tensor(2.1540)
avg common val loss: tensor(1.8352) avg img val loss: tensor(2.5110) avg txt val loss: tensor(2.1157)
epoch: 83 train_loss: tensor(0.2638, grad_fn=<DivBackward0>) average train loss tensor(1.1357, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7362, grad_fn=<DivBackward0>) avg img loss: tensor(1.1973, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4736, grad_fn=<DivBackward0>)
val common acc: 0.5791611479028698 val img acc: 0.40865342163355406 val txt acc: 0.47364238410596027 val_avg_loss: tensor(2.1480)
avg common val loss: tensor(1.8212) avg img val loss: tensor(2.5047) avg txt v

val common acc: 0.5766887417218544 val img acc: 0.4037674760853569 val txt acc: 0.47305371596762325 val_avg_loss: tensor(2.1803)
avg common val loss: tensor(1.8555) avg img val loss: tensor(2.5492) avg txt val loss: tensor(2.1363)
epoch: 99 train_loss: tensor(0.2378, grad_fn=<DivBackward0>) average train loss tensor(1.0856, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6825, grad_fn=<DivBackward0>) avg img loss: tensor(1.1353, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4391, grad_fn=<DivBackward0>)
val common acc: 0.5776306107431936 val img acc: 0.40479764532744666 val txt acc: 0.4729654157468727 val_avg_loss: tensor(2.1827)
avg common val loss: tensor(1.8583) avg img val loss: tensor(2.5522) avg txt val loss: tensor(2.1376)


In [19]:
model = torch_models.NormModelTridentBN(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bn_bs2048_rs42_d128_wd0005_drop_05_20k')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=300,
    writer=writer,
    train_loader=train_loader_20k,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(3.2411, grad_fn=<DivBackward0>) average train loss tensor(3.6148, grad_fn=<DivBackward0>)
avg common loss: tensor(3.7788, grad_fn=<DivBackward0>) avg img loss: tensor(3.3874, grad_fn=<DivBackward0>) avg txt loss: tensor(3.6781, grad_fn=<DivBackward0>)
val common acc: 0.23532008830022075 val img acc: 0.3222663723325975 val txt acc: 0.3106401766004415 val_avg_loss: tensor(3.3641)
avg common val loss: tensor(3.6734) avg img val loss: tensor(2.9961) avg txt val loss: tensor(3.4228)
epoch: 1 train_loss: tensor(2.7566, grad_fn=<DivBackward0>) average train loss tensor(2.9154, grad_fn=<DivBackward0>)
avg common loss: tensor(3.1108, grad_fn=<DivBackward0>) avg img loss: tensor(2.7306, grad_fn=<DivBackward0>) avg txt loss: tensor(2.9049, grad_fn=<DivBackward0>)
val common acc: 0.38178072111846945 val img acc: 0.3684179543782193 val txt acc: 0.3923767476085357 val_avg_loss: tensor(2.6779)
avg common val loss: tensor(2.7857) avg img val loss: tensor(2.5579) avg txt val

val common acc: 0.5913465783664459 val img acc: 0.41874908020603385 val txt acc: 0.47502575423105226 val_avg_loss: tensor(1.9419)
avg common val loss: tensor(1.5562) avg img val loss: tensor(2.2390) avg txt val loss: tensor(2.0304)
epoch: 17 train_loss: tensor(1.5416, grad_fn=<DivBackward0>) average train loss tensor(1.5977, grad_fn=<DivBackward0>)
avg common loss: tensor(1.2299, grad_fn=<DivBackward0>) avg img loss: tensor(1.7381, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8250, grad_fn=<DivBackward0>)
val common acc: 0.5902869757174393 val img acc: 0.4183075791022811 val txt acc: 0.47723325974981606 val_avg_loss: tensor(1.9455)
avg common val loss: tensor(1.5603) avg img val loss: tensor(2.2472) avg txt val loss: tensor(2.0291)
epoch: 18 train_loss: tensor(1.4828, grad_fn=<DivBackward0>) average train loss tensor(1.5731, grad_fn=<DivBackward0>)
avg common loss: tensor(1.1970, grad_fn=<DivBackward0>) avg img loss: tensor(1.7132, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8090,

epoch: 33 train_loss: tensor(1.1408, grad_fn=<DivBackward0>) average train loss tensor(1.2644, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8424, grad_fn=<DivBackward0>) avg img loss: tensor(1.3421, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6087, grad_fn=<DivBackward0>)
val common acc: 0.5877262693156733 val img acc: 0.4112729948491538 val txt acc: 0.47649742457689476 val_avg_loss: tensor(2.0291)
avg common val loss: tensor(1.6655) avg img val loss: tensor(2.3706) avg txt val loss: tensor(2.0512)
epoch: 34 train_loss: tensor(1.1410, grad_fn=<DivBackward0>) average train loss tensor(1.2407, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8183, grad_fn=<DivBackward0>) avg img loss: tensor(1.3119, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5920, grad_fn=<DivBackward0>)
val common acc: 0.5851655629139073 val img acc: 0.40962472406181016 val txt acc: 0.47511405445180277 val_avg_loss: tensor(2.0390)
avg common val loss: tensor(1.6757) avg img val loss: tensor(2.3862) avg txt 

val common acc: 0.5795143487858719 val img acc: 0.4037674760853569 val txt acc: 0.4732597498160412 val_avg_loss: tensor(2.1344)
avg common val loss: tensor(1.8015) avg img val loss: tensor(2.5129) avg txt val loss: tensor(2.0889)
epoch: 50 train_loss: tensor(0.9210, grad_fn=<DivBackward0>) average train loss tensor(1.0480, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6088, grad_fn=<DivBackward0>) avg img loss: tensor(1.0723, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4628, grad_fn=<DivBackward0>)
val common acc: 0.5785430463576159 val img acc: 0.4052391464311994 val txt acc: 0.47137601177336275 val_avg_loss: tensor(2.1396)
avg common val loss: tensor(1.8083) avg img val loss: tensor(2.5175) avg txt val loss: tensor(2.0931)
epoch: 51 train_loss: tensor(0.9346, grad_fn=<DivBackward0>) average train loss tensor(1.0408, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5992, grad_fn=<DivBackward0>) avg img loss: tensor(1.0692, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4540, g

epoch: 66 train_loss: tensor(0.8242, grad_fn=<DivBackward0>) average train loss tensor(0.9160, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4707, grad_fn=<DivBackward0>) avg img loss: tensor(0.9040, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3734, grad_fn=<DivBackward0>)
val common acc: 0.5752170713760117 val img acc: 0.40076526857983813 val txt acc: 0.47061074319352464 val_avg_loss: tensor(2.2134)
avg common val loss: tensor(1.8978) avg img val loss: tensor(2.6177) avg txt val loss: tensor(2.1247)
epoch: 67 train_loss: tensor(0.8052, grad_fn=<DivBackward0>) average train loss tensor(0.9083, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4665, grad_fn=<DivBackward0>) avg img loss: tensor(0.8925, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3659, grad_fn=<DivBackward0>)
val common acc: 0.5721854304635762 val img acc: 0.3988815305371597 val txt acc: 0.4692568064753495 val_avg_loss: tensor(2.2206)
avg common val loss: tensor(1.9084) avg img val loss: tensor(2.6262) avg txt v

val common acc: 0.5698307579102281 val img acc: 0.3976747608535688 val txt acc: 0.46922737306843265 val_avg_loss: tensor(2.2723)
avg common val loss: tensor(1.9682) avg img val loss: tensor(2.6903) avg txt val loss: tensor(2.1584)
epoch: 83 train_loss: tensor(0.7307, grad_fn=<DivBackward0>) average train loss tensor(0.8266, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3858, grad_fn=<DivBackward0>) avg img loss: tensor(0.7893, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3046, grad_fn=<DivBackward0>)
val common acc: 0.5715378955114054 val img acc: 0.398027961736571 val txt acc: 0.4693156732891832 val_avg_loss: tensor(2.2762)
avg common val loss: tensor(1.9710) avg img val loss: tensor(2.6988) avg txt val loss: tensor(2.1587)
epoch: 84 train_loss: tensor(0.7401, grad_fn=<DivBackward0>) average train loss tensor(0.8177, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3774, grad_fn=<DivBackward0>) avg img loss: tensor(0.7809, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2947, gr

epoch: 99 train_loss: tensor(0.7023, grad_fn=<DivBackward0>) average train loss tensor(0.7693, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3369, grad_fn=<DivBackward0>) avg img loss: tensor(0.7207, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2502, grad_fn=<DivBackward0>)
val common acc: 0.5707137601177337 val img acc: 0.3998233995584989 val txt acc: 0.4647534952170714 val_avg_loss: tensor(2.3179)
avg common val loss: tensor(2.0132) avg img val loss: tensor(2.7464) avg txt val loss: tensor(2.1942)
epoch: 100 train_loss: tensor(0.6741, grad_fn=<DivBackward0>) average train loss tensor(0.7573, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3288, grad_fn=<DivBackward0>) avg img loss: tensor(0.7026, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2403, grad_fn=<DivBackward0>)
val common acc: 0.5668874172185431 val img acc: 0.40038263428991905 val txt acc: 0.46622516556291393 val_avg_loss: tensor(2.3185)
avg common val loss: tensor(2.0163) avg img val loss: tensor(2.7505) avg txt 

epoch: 115 train_loss: tensor(0.6537, grad_fn=<DivBackward0>) average train loss tensor(0.7254, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2963, grad_fn=<DivBackward0>) avg img loss: tensor(0.6699, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2100, grad_fn=<DivBackward0>)
val common acc: 0.5671228844738778 val img acc: 0.3965857247976453 val txt acc: 0.46578366445916114 val_avg_loss: tensor(2.3521)
avg common val loss: tensor(2.0557) avg img val loss: tensor(2.7907) avg txt val loss: tensor(2.2100)
epoch: 116 train_loss: tensor(0.6454, grad_fn=<DivBackward0>) average train loss tensor(0.7241, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2973, grad_fn=<DivBackward0>) avg img loss: tensor(0.6684, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2067, grad_fn=<DivBackward0>)
val common acc: 0.5674172185430464 val img acc: 0.3973509933774834 val txt acc: 0.46554819720382634 val_avg_loss: tensor(2.3504)
avg common val loss: tensor(2.0522) avg img val loss: tensor(2.7855) avg txt

epoch: 131 train_loss: tensor(0.6213, grad_fn=<DivBackward0>) average train loss tensor(0.6928, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2713, grad_fn=<DivBackward0>) avg img loss: tensor(0.6239, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1831, grad_fn=<DivBackward0>)
val common acc: 0.565504047093451 val img acc: 0.394878587196468 val txt acc: 0.4640176600441501 val_avg_loss: tensor(2.3740)
avg common val loss: tensor(2.0808) avg img val loss: tensor(2.8107) avg txt val loss: tensor(2.2304)
epoch: 132 train_loss: tensor(0.6316, grad_fn=<DivBackward0>) average train loss tensor(0.6867, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2670, grad_fn=<DivBackward0>) avg img loss: tensor(0.6176, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1754, grad_fn=<DivBackward0>)
val common acc: 0.564738778513613 val img acc: 0.39614422369389257 val txt acc: 0.4628108903605592 val_avg_loss: tensor(2.3793)
avg common val loss: tensor(2.0905) avg img val loss: tensor(2.8153) avg txt val

epoch: 147 train_loss: tensor(0.6144, grad_fn=<DivBackward0>) average train loss tensor(0.6679, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2520, grad_fn=<DivBackward0>) avg img loss: tensor(0.6011, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1508, grad_fn=<DivBackward0>)
val common acc: 0.5632671081677704 val img acc: 0.3950551876379691 val txt acc: 0.4632818248712288 val_avg_loss: tensor(2.3852)
avg common val loss: tensor(2.0934) avg img val loss: tensor(2.8129) avg txt val loss: tensor(2.2492)
epoch: 148 train_loss: tensor(0.6062, grad_fn=<DivBackward0>) average train loss tensor(0.6746, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2593, grad_fn=<DivBackward0>) avg img loss: tensor(0.6118, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1526, grad_fn=<DivBackward0>)
val common acc: 0.5623841059602649 val img acc: 0.3945253863134658 val txt acc: 0.46286975717439294 val_avg_loss: tensor(2.3867)
avg common val loss: tensor(2.0874) avg img val loss: tensor(2.8215) avg txt 

epoch: 163 train_loss: tensor(0.5822, grad_fn=<DivBackward0>) average train loss tensor(0.6453, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2344, grad_fn=<DivBackward0>) avg img loss: tensor(0.5806, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1208, grad_fn=<DivBackward0>)
val common acc: 0.5625312729948492 val img acc: 0.39670345842531274 val txt acc: 0.45839587932303166 val_avg_loss: tensor(2.4043)
avg common val loss: tensor(2.1110) avg img val loss: tensor(2.8366) avg txt val loss: tensor(2.2653)
epoch: 164 train_loss: tensor(0.5887, grad_fn=<DivBackward0>) average train loss tensor(0.6482, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2350, grad_fn=<DivBackward0>) avg img loss: tensor(0.5734, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1363, grad_fn=<DivBackward0>)
val common acc: 0.5617365710080942 val img acc: 0.3968800588668138 val txt acc: 0.45951434878587194 val_avg_loss: tensor(2.4143)
avg common val loss: tensor(2.1319) avg img val loss: tensor(2.8430) avg tx

epoch: 179 train_loss: tensor(0.5670, grad_fn=<DivBackward0>) average train loss tensor(0.6333, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2249, grad_fn=<DivBackward0>) avg img loss: tensor(0.5607, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1144, grad_fn=<DivBackward0>)
val common acc: 0.5618543046357616 val img acc: 0.39370125091979397 val txt acc: 0.4617512877115526 val_avg_loss: tensor(2.4225)
avg common val loss: tensor(2.1324) avg img val loss: tensor(2.8556) avg txt val loss: tensor(2.2795)
epoch: 180 train_loss: tensor(0.5708, grad_fn=<DivBackward0>) average train loss tensor(0.6296, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2214, grad_fn=<DivBackward0>) avg img loss: tensor(0.5588, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1086, grad_fn=<DivBackward0>)
val common acc: 0.5612656364974246 val img acc: 0.3934363502575423 val txt acc: 0.4603679175864606 val_avg_loss: tensor(2.4190)
avg common val loss: tensor(2.1287) avg img val loss: tensor(2.8479) avg txt 

epoch: 195 train_loss: tensor(0.5766, grad_fn=<DivBackward0>) average train loss tensor(0.6203, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2181, grad_fn=<DivBackward0>) avg img loss: tensor(0.5495, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0934, grad_fn=<DivBackward0>)
val common acc: 0.5627078734363502 val img acc: 0.39287711552612214 val txt acc: 0.4596615158204562 val_avg_loss: tensor(2.4259)
avg common val loss: tensor(2.1373) avg img val loss: tensor(2.8476) avg txt val loss: tensor(2.2928)
epoch: 196 train_loss: tensor(0.5572, grad_fn=<DivBackward0>) average train loss tensor(0.6250, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2212, grad_fn=<DivBackward0>) avg img loss: tensor(0.5596, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0944, grad_fn=<DivBackward0>)
val common acc: 0.5616777041942604 val img acc: 0.3938189845474614 val txt acc: 0.4597498160412068 val_avg_loss: tensor(2.4321)
avg common val loss: tensor(2.1449) avg img val loss: tensor(2.8614) avg txt 

epoch: 211 train_loss: tensor(0.5421, grad_fn=<DivBackward0>) average train loss tensor(0.6015, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1993, grad_fn=<DivBackward0>) avg img loss: tensor(0.5260, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0793, grad_fn=<DivBackward0>)
val common acc: 0.5632671081677704 val img acc: 0.3964091243561442 val txt acc: 0.4594554819720383 val_avg_loss: tensor(2.4372)
avg common val loss: tensor(2.1610) avg img val loss: tensor(2.8527) avg txt val loss: tensor(2.2981)
epoch: 212 train_loss: tensor(0.5338, grad_fn=<DivBackward0>) average train loss tensor(0.6006, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2009, grad_fn=<DivBackward0>) avg img loss: tensor(0.5273, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0738, grad_fn=<DivBackward0>)
val common acc: 0.5602649006622517 val img acc: 0.39455481972038264 val txt acc: 0.4590139808682855 val_avg_loss: tensor(2.4359)
avg common val loss: tensor(2.1507) avg img val loss: tensor(2.8579) avg txt 

epoch: 227 train_loss: tensor(0.5382, grad_fn=<DivBackward0>) average train loss tensor(0.5987, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1998, grad_fn=<DivBackward0>) avg img loss: tensor(0.5249, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0714, grad_fn=<DivBackward0>)
val common acc: 0.5613539367181751 val img acc: 0.3929654157468727 val txt acc: 0.459896983075791 val_avg_loss: tensor(2.4327)
avg common val loss: tensor(2.1459) avg img val loss: tensor(2.8505) avg txt val loss: tensor(2.3018)
epoch: 228 train_loss: tensor(0.5399, grad_fn=<DivBackward0>) average train loss tensor(0.5990, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1985, grad_fn=<DivBackward0>) avg img loss: tensor(0.5339, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0646, grad_fn=<DivBackward0>)
val common acc: 0.5599705665930832 val img acc: 0.3949668874172185 val txt acc: 0.4585430463576159 val_avg_loss: tensor(2.4428)
avg common val loss: tensor(2.1494) avg img val loss: tensor(2.8711) avg txt va

epoch: 243 train_loss: tensor(0.5352, grad_fn=<DivBackward0>) average train loss tensor(0.5928, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1969, grad_fn=<DivBackward0>) avg img loss: tensor(0.5248, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0566, grad_fn=<DivBackward0>)
val common acc: 0.5600883002207505 val img acc: 0.393495217071376 val txt acc: 0.4582192788815305 val_avg_loss: tensor(2.4504)
avg common val loss: tensor(2.1647) avg img val loss: tensor(2.8660) avg txt val loss: tensor(2.3206)
epoch: 244 train_loss: tensor(0.5370, grad_fn=<DivBackward0>) average train loss tensor(0.5919, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1963, grad_fn=<DivBackward0>) avg img loss: tensor(0.5195, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0598, grad_fn=<DivBackward0>)
val common acc: 0.5634731420161884 val img acc: 0.3952906548933039 val txt acc: 0.45736571008094185 val_avg_loss: tensor(2.4512)
avg common val loss: tensor(2.1681) avg img val loss: tensor(2.8621) avg txt v

val common acc: 0.5624135393671817 val img acc: 0.3966740250183959 val txt acc: 0.4563061074319352 val_avg_loss: tensor(2.4498)
avg common val loss: tensor(2.1705) avg img val loss: tensor(2.8537) avg txt val loss: tensor(2.3253)
epoch: 260 train_loss: tensor(0.5281, grad_fn=<DivBackward0>) average train loss tensor(0.5790, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1894, grad_fn=<DivBackward0>) avg img loss: tensor(0.5080, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0395, grad_fn=<DivBackward0>)
val common acc: 0.5636203090507726 val img acc: 0.3968800588668138 val txt acc: 0.45574687270051506 val_avg_loss: tensor(2.4593)
avg common val loss: tensor(2.1779) avg img val loss: tensor(2.8684) avg txt val loss: tensor(2.3315)
epoch: 261 train_loss: tensor(0.5198, grad_fn=<DivBackward0>) average train loss tensor(0.5791, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1875, grad_fn=<DivBackward0>) avg img loss: tensor(0.5051, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0448,

val common acc: 0.5615599705665931 val img acc: 0.3958793230316409 val txt acc: 0.45624724061810157 val_avg_loss: tensor(2.4624)
avg common val loss: tensor(2.1857) avg img val loss: tensor(2.8591) avg txt val loss: tensor(2.3425)
epoch: 276 train_loss: tensor(0.5212, grad_fn=<DivBackward0>) average train loss tensor(0.5801, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1894, grad_fn=<DivBackward0>) avg img loss: tensor(0.5195, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0315, grad_fn=<DivBackward0>)
val common acc: 0.5603826342899191 val img acc: 0.39788079470198673 val txt acc: 0.4541280353200883 val_avg_loss: tensor(2.4640)
avg common val loss: tensor(2.1903) avg img val loss: tensor(2.8597) avg txt val loss: tensor(2.3421)
epoch: 277 train_loss: tensor(0.5144, grad_fn=<DivBackward0>) average train loss tensor(0.5745, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1824, grad_fn=<DivBackward0>) avg img loss: tensor(0.5073, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0337

val common acc: 0.5634142752023547 val img acc: 0.3944370860927152 val txt acc: 0.4576011773362767 val_avg_loss: tensor(2.4530)
avg common val loss: tensor(2.1617) avg img val loss: tensor(2.8536) avg txt val loss: tensor(2.3437)
epoch: 292 train_loss: tensor(0.5220, grad_fn=<DivBackward0>) average train loss tensor(0.5695, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1752, grad_fn=<DivBackward0>) avg img loss: tensor(0.5083, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0250, grad_fn=<DivBackward0>)
val common acc: 0.5634142752023547 val img acc: 0.3939072847682119 val txt acc: 0.45742457689477556 val_avg_loss: tensor(2.4593)
avg common val loss: tensor(2.1752) avg img val loss: tensor(2.8555) avg txt val loss: tensor(2.3473)
epoch: 293 train_loss: tensor(0.5090, grad_fn=<DivBackward0>) average train loss tensor(0.5750, grad_fn=<DivBackward0>)
avg common loss: tensor(0.1814, grad_fn=<DivBackward0>) avg img loss: tensor(0.5085, grad_fn=<DivBackward0>) avg txt loss: tensor(1.0352,

In [18]:
train_loader_20k_512 = DataLoader(train_ds_20k, batch_size=512)

model = torch_models.NormModelTrident(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bs512_rs42_d128_wd0005_drop_05_20k_300')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=300,
    writer=writer,
    train_loader=train_loader_20k_512,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(2.5762, grad_fn=<DivBackward0>) average train loss tensor(3.2758, grad_fn=<DivBackward0>)
avg common loss: tensor(3.4849, grad_fn=<DivBackward0>) avg img loss: tensor(3.0421, grad_fn=<DivBackward0>) avg txt loss: tensor(3.3003, grad_fn=<DivBackward0>)
val common acc: 0.32891832229580575 val img acc: 0.3604120676968359 val txt acc: 0.34425312729948493 val_avg_loss: tensor(2.5999)
avg common val loss: tensor(2.6964) avg img val loss: tensor(2.4644) avg txt val loss: tensor(2.6388)
epoch: 1 train_loss: tensor(1.8442, grad_fn=<DivBackward0>) average train loss tensor(2.4874, grad_fn=<DivBackward0>)
avg common loss: tensor(2.4546, grad_fn=<DivBackward0>) avg img loss: tensor(2.4795, grad_fn=<DivBackward0>) avg txt loss: tensor(2.5280, grad_fn=<DivBackward0>)
val common acc: 0.46334069168506253 val img acc: 0.38710816777041945 val txt acc: 0.4329654157468727 val_avg_loss: tensor(2.1802)
avg common val loss: tensor(1.9863) avg img val loss: tensor(2.3250) avg txt v

val common acc: 0.5913465783664459 val img acc: 0.41792494481236203 val txt acc: 0.4769683590875644 val_avg_loss: tensor(1.9376)
avg common val loss: tensor(1.5555) avg img val loss: tensor(2.2334) avg txt val loss: tensor(2.0238)
epoch: 17 train_loss: tensor(0.5598, grad_fn=<DivBackward0>) average train loss tensor(1.6529, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3121, grad_fn=<DivBackward0>) avg img loss: tensor(1.8187, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8279, grad_fn=<DivBackward0>)
val common acc: 0.5912288447387785 val img acc: 0.4193966151582046 val txt acc: 0.47967623252391467 val_avg_loss: tensor(1.9445)
avg common val loss: tensor(1.5721) avg img val loss: tensor(2.2407) avg txt val loss: tensor(2.0206)
epoch: 18 train_loss: tensor(0.4667, grad_fn=<DivBackward0>) average train loss tensor(1.6445, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3072, grad_fn=<DivBackward0>) avg img loss: tensor(1.8053, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8210, 

epoch: 33 train_loss: tensor(0.3492, grad_fn=<DivBackward0>) average train loss tensor(1.4488, grad_fn=<DivBackward0>)
avg common loss: tensor(1.0846, grad_fn=<DivBackward0>) avg img loss: tensor(1.5770, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6848, grad_fn=<DivBackward0>)
val common acc: 0.5898160412067697 val img acc: 0.4163355408388521 val txt acc: 0.4784988962472406 val_avg_loss: tensor(1.9863)
avg common val loss: tensor(1.6133) avg img val loss: tensor(2.3022) avg txt val loss: tensor(2.0433)
epoch: 34 train_loss: tensor(0.3304, grad_fn=<DivBackward0>) average train loss tensor(1.4393, grad_fn=<DivBackward0>)
avg common loss: tensor(1.0697, grad_fn=<DivBackward0>) avg img loss: tensor(1.5675, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6806, grad_fn=<DivBackward0>)
val common acc: 0.5908756438557763 val img acc: 0.41677704194260484 val txt acc: 0.4785871964679912 val_avg_loss: tensor(1.9927)
avg common val loss: tensor(1.6216) avg img val loss: tensor(2.3137) avg txt va

val common acc: 0.5841648270787344 val img acc: 0.4134805003679176 val txt acc: 0.47711552612214864 val_avg_loss: tensor(2.0561)
avg common val loss: tensor(1.6961) avg img val loss: tensor(2.4055) avg txt val loss: tensor(2.0667)
epoch: 50 train_loss: tensor(0.3342, grad_fn=<DivBackward0>) average train loss tensor(1.3118, grad_fn=<DivBackward0>)
avg common loss: tensor(0.9320, grad_fn=<DivBackward0>) avg img loss: tensor(1.4085, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5947, grad_fn=<DivBackward0>)
val common acc: 0.5850772626931567 val img acc: 0.4121265636497425 val txt acc: 0.47799852832965417 val_avg_loss: tensor(2.0516)
avg common val loss: tensor(1.6951) avg img val loss: tensor(2.3958) avg txt val loss: tensor(2.0640)
epoch: 51 train_loss: tensor(0.2242, grad_fn=<DivBackward0>) average train loss tensor(1.3061, grad_fn=<DivBackward0>)
avg common loss: tensor(0.9282, grad_fn=<DivBackward0>) avg img loss: tensor(1.3973, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5929, 

epoch: 66 train_loss: tensor(0.3313, grad_fn=<DivBackward0>) average train loss tensor(1.2213, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8392, grad_fn=<DivBackward0>) avg img loss: tensor(1.2959, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5289, grad_fn=<DivBackward0>)
val common acc: 0.5838704930095658 val img acc: 0.4093598233995585 val txt acc: 0.4750551876379691 val_avg_loss: tensor(2.0932)
avg common val loss: tensor(1.7429) avg img val loss: tensor(2.4493) avg txt val loss: tensor(2.0873)
epoch: 67 train_loss: tensor(0.2159, grad_fn=<DivBackward0>) average train loss tensor(1.2246, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8374, grad_fn=<DivBackward0>) avg img loss: tensor(1.3048, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5317, grad_fn=<DivBackward0>)
val common acc: 0.5781898454746137 val img acc: 0.40862398822663726 val txt acc: 0.4746136865342163 val_avg_loss: tensor(2.1013)
avg common val loss: tensor(1.7636) avg img val loss: tensor(2.4564) avg txt va

val common acc: 0.5791905813097866 val img acc: 0.40556291390728477 val txt acc: 0.47587932303164093 val_avg_loss: tensor(2.1404)
avg common val loss: tensor(1.8053) avg img val loss: tensor(2.5125) avg txt val loss: tensor(2.1034)
epoch: 83 train_loss: tensor(0.2810, grad_fn=<DivBackward0>) average train loss tensor(1.1423, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7479, grad_fn=<DivBackward0>) avg img loss: tensor(1.1881, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4909, grad_fn=<DivBackward0>)
val common acc: 0.5772774098601913 val img acc: 0.4074466519499632 val txt acc: 0.47655629139072847 val_avg_loss: tensor(2.1494)
avg common val loss: tensor(1.8238) avg img val loss: tensor(2.5226) avg txt val loss: tensor(2.1018)
epoch: 84 train_loss: tensor(0.2444, grad_fn=<DivBackward0>) average train loss tensor(1.1444, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7510, grad_fn=<DivBackward0>) avg img loss: tensor(1.1959, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4864,

epoch: 99 train_loss: tensor(0.2550, grad_fn=<DivBackward0>) average train loss tensor(1.0947, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6980, grad_fn=<DivBackward0>) avg img loss: tensor(1.1414, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4447, grad_fn=<DivBackward0>)
val common acc: 0.5745106696100074 val img acc: 0.4036497424576895 val txt acc: 0.4727299484915379 val_avg_loss: tensor(2.1773)
avg common val loss: tensor(1.8625) avg img val loss: tensor(2.5484) avg txt val loss: tensor(2.1209)
epoch: 100 train_loss: tensor(0.2856, grad_fn=<DivBackward0>) average train loss tensor(1.1028, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7061, grad_fn=<DivBackward0>) avg img loss: tensor(1.1400, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4625, grad_fn=<DivBackward0>)
val common acc: 0.5740691685062546 val img acc: 0.4033848417954378 val txt acc: 0.4739072847682119 val_avg_loss: tensor(2.1694)
avg common val loss: tensor(1.8446) avg img val loss: tensor(2.5472) avg txt va

epoch: 115 train_loss: tensor(0.2824, grad_fn=<DivBackward0>) average train loss tensor(1.0483, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6494, grad_fn=<DivBackward0>) avg img loss: tensor(1.0669, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4286, grad_fn=<DivBackward0>)
val common acc: 0.5742163355408388 val img acc: 0.40244297277409863 val txt acc: 0.4726710816777042 val_avg_loss: tensor(2.2091)
avg common val loss: tensor(1.9027) avg img val loss: tensor(2.5955) avg txt val loss: tensor(2.1291)
epoch: 116 train_loss: tensor(0.2563, grad_fn=<DivBackward0>) average train loss tensor(1.0428, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6452, grad_fn=<DivBackward0>) avg img loss: tensor(1.0618, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4213, grad_fn=<DivBackward0>)
val common acc: 0.5749227373068433 val img acc: 0.4032671081677704 val txt acc: 0.47175864606328183 val_avg_loss: tensor(2.2099)
avg common val loss: tensor(1.9013) avg img val loss: tensor(2.5905) avg txt

epoch: 131 train_loss: tensor(0.2556, grad_fn=<DivBackward0>) average train loss tensor(1.0117, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6018, grad_fn=<DivBackward0>) avg img loss: tensor(1.0308, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4026, grad_fn=<DivBackward0>)
val common acc: 0.5734805003679175 val img acc: 0.40456217807211187 val txt acc: 0.4719941133186166 val_avg_loss: tensor(2.2452)
avg common val loss: tensor(1.9608) avg img val loss: tensor(2.6255) avg txt val loss: tensor(2.1493)
epoch: 132 train_loss: tensor(0.2493, grad_fn=<DivBackward0>) average train loss tensor(1.0137, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6069, grad_fn=<DivBackward0>) avg img loss: tensor(1.0317, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4025, grad_fn=<DivBackward0>)
val common acc: 0.5737748344370861 val img acc: 0.40288447387785137 val txt acc: 0.4716997792494481 val_avg_loss: tensor(2.2367)
avg common val loss: tensor(1.9490) avg img val loss: tensor(2.6177) avg txt

epoch: 147 train_loss: tensor(0.2827, grad_fn=<DivBackward0>) average train loss tensor(0.9820, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5701, grad_fn=<DivBackward0>) avg img loss: tensor(0.9855, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3904, grad_fn=<DivBackward0>)
val common acc: 0.5725386313465783 val img acc: 0.405916114790287 val txt acc: 0.47299484915378953 val_avg_loss: tensor(2.2659)
avg common val loss: tensor(1.9886) avg img val loss: tensor(2.6565) avg txt val loss: tensor(2.1527)
epoch: 148 train_loss: tensor(0.2225, grad_fn=<DivBackward0>) average train loss tensor(0.9914, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5921, grad_fn=<DivBackward0>) avg img loss: tensor(1.0020, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3802, grad_fn=<DivBackward0>)
val common acc: 0.5702722590139808 val img acc: 0.4011184694628403 val txt acc: 0.4714643119941133 val_avg_loss: tensor(2.2568)
avg common val loss: tensor(1.9651) avg img val loss: tensor(2.6443) avg txt v

epoch: 163 train_loss: tensor(0.1932, grad_fn=<DivBackward0>) average train loss tensor(0.9710, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5629, grad_fn=<DivBackward0>) avg img loss: tensor(0.9790, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3710, grad_fn=<DivBackward0>)
val common acc: 0.5707431935246505 val img acc: 0.40229580573951434 val txt acc: 0.4729359823399559 val_avg_loss: tensor(2.2617)
avg common val loss: tensor(1.9664) avg img val loss: tensor(2.6615) avg txt val loss: tensor(2.1571)
epoch: 164 train_loss: tensor(0.2295, grad_fn=<DivBackward0>) average train loss tensor(0.9591, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5411, grad_fn=<DivBackward0>) avg img loss: tensor(0.9642, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3719, grad_fn=<DivBackward0>)
val common acc: 0.5707726269315674 val img acc: 0.40247240618101543 val txt acc: 0.4728476821192053 val_avg_loss: tensor(2.2916)
avg common val loss: tensor(2.0363) avg img val loss: tensor(2.6777) avg txt

epoch: 179 train_loss: tensor(0.2015, grad_fn=<DivBackward0>) average train loss tensor(0.9447, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5358, grad_fn=<DivBackward0>) avg img loss: tensor(0.9406, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3578, grad_fn=<DivBackward0>)
val common acc: 0.56906548933039 val img acc: 0.40188373804267846 val txt acc: 0.47061074319352464 val_avg_loss: tensor(2.2860)
avg common val loss: tensor(2.0041) avg img val loss: tensor(2.6784) avg txt val loss: tensor(2.1754)
epoch: 180 train_loss: tensor(0.2300, grad_fn=<DivBackward0>) average train loss tensor(0.9447, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5368, grad_fn=<DivBackward0>) avg img loss: tensor(0.9415, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3557, grad_fn=<DivBackward0>)
val common acc: 0.5695364238410596 val img acc: 0.4020897718910964 val txt acc: 0.4701692420897719 val_avg_loss: tensor(2.2894)
avg common val loss: tensor(2.0058) avg img val loss: tensor(2.6923) avg txt v

epoch: 195 train_loss: tensor(0.2103, grad_fn=<DivBackward0>) average train loss tensor(0.9197, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5038, grad_fn=<DivBackward0>) avg img loss: tensor(0.9102, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3451, grad_fn=<DivBackward0>)
val common acc: 0.5661810154525386 val img acc: 0.4027961736571008 val txt acc: 0.469168506254599 val_avg_loss: tensor(2.3137)
avg common val loss: tensor(2.0614) avg img val loss: tensor(2.6994) avg txt val loss: tensor(2.1802)
epoch: 196 train_loss: tensor(0.2140, grad_fn=<DivBackward0>) average train loss tensor(0.9200, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5059, grad_fn=<DivBackward0>) avg img loss: tensor(0.9078, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3464, grad_fn=<DivBackward0>)
val common acc: 0.566298749080206 val img acc: 0.4001766004415011 val txt acc: 0.4699926416482708 val_avg_loss: tensor(2.3140)
avg common val loss: tensor(2.0517) avg img val loss: tensor(2.7096) avg txt val

val common acc: 0.5677115526122148 val img acc: 0.4005003679175865 val txt acc: 0.47108167770419423 val_avg_loss: tensor(2.3147)
avg common val loss: tensor(2.0536) avg img val loss: tensor(2.7018) avg txt val loss: tensor(2.1886)
epoch: 212 train_loss: tensor(0.2055, grad_fn=<DivBackward0>) average train loss tensor(0.9142, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5086, grad_fn=<DivBackward0>) avg img loss: tensor(0.8987, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3354, grad_fn=<DivBackward0>)
val common acc: 0.5672111846946284 val img acc: 0.40264900662251657 val txt acc: 0.4706990434142752 val_avg_loss: tensor(2.3008)
avg common val loss: tensor(2.0202) avg img val loss: tensor(2.6944) avg txt val loss: tensor(2.1878)
epoch: 213 train_loss: tensor(0.2334, grad_fn=<DivBackward0>) average train loss tensor(0.9134, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5008, grad_fn=<DivBackward0>) avg img loss: tensor(0.8988, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3406

val common acc: 0.5660338484179543 val img acc: 0.4016482707873436 val txt acc: 0.4697866077998528 val_avg_loss: tensor(2.2950)
avg common val loss: tensor(2.0089) avg img val loss: tensor(2.6905) avg txt val loss: tensor(2.1854)
epoch: 228 train_loss: tensor(0.2507, grad_fn=<DivBackward0>) average train loss tensor(0.9075, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4945, grad_fn=<DivBackward0>) avg img loss: tensor(0.9049, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3231, grad_fn=<DivBackward0>)
val common acc: 0.5692715231788079 val img acc: 0.40153053715967624 val txt acc: 0.4681089036055924 val_avg_loss: tensor(2.3179)
avg common val loss: tensor(2.0619) avg img val loss: tensor(2.7002) avg txt val loss: tensor(2.1917)
epoch: 229 train_loss: tensor(0.2407, grad_fn=<DivBackward0>) average train loss tensor(0.8891, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4737, grad_fn=<DivBackward0>) avg img loss: tensor(0.8781, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3155,

val common acc: 0.5663870493009566 val img acc: 0.3991169977924945 val txt acc: 0.47090507726269315 val_avg_loss: tensor(2.3354)
avg common val loss: tensor(2.0875) avg img val loss: tensor(2.7184) avg txt val loss: tensor(2.2002)
epoch: 244 train_loss: tensor(0.2237, grad_fn=<DivBackward0>) average train loss tensor(0.8925, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4793, grad_fn=<DivBackward0>) avg img loss: tensor(0.8787, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3194, grad_fn=<DivBackward0>)
val common acc: 0.5638557763061074 val img acc: 0.4034731420161884 val txt acc: 0.4701103752759382 val_avg_loss: tensor(2.3503)
avg common val loss: tensor(2.1124) avg img val loss: tensor(2.7383) avg txt val loss: tensor(2.2001)
epoch: 245 train_loss: tensor(0.2634, grad_fn=<DivBackward0>) average train loss tensor(0.8925, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4723, grad_fn=<DivBackward0>) avg img loss: tensor(0.8786, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3266,

val common acc: 0.564150110375276 val img acc: 0.3982045621780721 val txt acc: 0.46899190581309785 val_avg_loss: tensor(2.3322)
avg common val loss: tensor(2.0840) avg img val loss: tensor(2.7124) avg txt val loss: tensor(2.2002)
epoch: 260 train_loss: tensor(0.1990, grad_fn=<DivBackward0>) average train loss tensor(0.8829, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4724, grad_fn=<DivBackward0>) avg img loss: tensor(0.8699, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3062, grad_fn=<DivBackward0>)
val common acc: 0.5654157468727005 val img acc: 0.4008830022075055 val txt acc: 0.469551140544518 val_avg_loss: tensor(2.3387)
avg common val loss: tensor(2.0942) avg img val loss: tensor(2.7202) avg txt val loss: tensor(2.2018)
epoch: 261 train_loss: tensor(0.2118, grad_fn=<DivBackward0>) average train loss tensor(0.8752, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4611, grad_fn=<DivBackward0>) avg img loss: tensor(0.8603, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3042, g

val common acc: 0.5663576158940398 val img acc: 0.4010890360559235 val txt acc: 0.47028697571743927 val_avg_loss: tensor(2.3384)
avg common val loss: tensor(2.0766) avg img val loss: tensor(2.7291) avg txt val loss: tensor(2.2095)
epoch: 276 train_loss: tensor(0.2083, grad_fn=<DivBackward0>) average train loss tensor(0.8707, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4597, grad_fn=<DivBackward0>) avg img loss: tensor(0.8516, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3008, grad_fn=<DivBackward0>)
val common acc: 0.5638852097130242 val img acc: 0.4007358351729213 val txt acc: 0.46869757174392934 val_avg_loss: tensor(2.3453)
avg common val loss: tensor(2.0991) avg img val loss: tensor(2.7252) avg txt val loss: tensor(2.2115)
epoch: 277 train_loss: tensor(0.2494, grad_fn=<DivBackward0>) average train loss tensor(0.8776, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4622, grad_fn=<DivBackward0>) avg img loss: tensor(0.8699, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3006

val common acc: 0.5628844738778513 val img acc: 0.39952906548933037 val txt acc: 0.46913907284768214 val_avg_loss: tensor(2.3520)
avg common val loss: tensor(2.1171) avg img val loss: tensor(2.7265) avg txt val loss: tensor(2.2125)
epoch: 292 train_loss: tensor(0.2610, grad_fn=<DivBackward0>) average train loss tensor(0.8743, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4595, grad_fn=<DivBackward0>) avg img loss: tensor(0.8599, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3034, grad_fn=<DivBackward0>)
val common acc: 0.5639146431199411 val img acc: 0.40032376747608533 val txt acc: 0.46961000735835173 val_avg_loss: tensor(2.3596)
avg common val loss: tensor(2.1260) avg img val loss: tensor(2.7447) avg txt val loss: tensor(2.2080)
epoch: 293 train_loss: tensor(0.2233, grad_fn=<DivBackward0>) average train loss tensor(0.8691, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4565, grad_fn=<DivBackward0>) avg img loss: tensor(0.8583, grad_fn=<DivBackward0>) avg txt loss: tensor(1.29

In [12]:
model = torch_models.NormModelTrident(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bs2048_rs42_d128_wd0005_drop_05_te')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=60,
    writer=writer,
    train_loader=train_loader,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(2.1974, grad_fn=<DivBackward0>) average train loss tensor(2.7907, grad_fn=<DivBackward0>)
avg common loss: tensor(2.8358, grad_fn=<DivBackward0>) avg img loss: tensor(2.6939, grad_fn=<DivBackward0>) avg txt loss: tensor(2.8423, grad_fn=<DivBackward0>)
val common acc: 0.5197645327446652 val img acc: 0.4095364238410596 val txt acc: 0.45359823399558497 val_avg_loss: tensor(2.0510)
avg common val loss: tensor(1.7943) avg img val loss: tensor(2.2408) avg txt val loss: tensor(2.1179)
epoch: 1 train_loss: tensor(2.0479, grad_fn=<DivBackward0>) average train loss tensor(2.1448, grad_fn=<DivBackward0>)
avg common loss: tensor(1.8956, grad_fn=<DivBackward0>) avg img loss: tensor(2.3155, grad_fn=<DivBackward0>) avg txt loss: tensor(2.2233, grad_fn=<DivBackward0>)
val common acc: 0.5838704930095658 val img acc: 0.42590139808682853 val txt acc: 0.47602649006622516 val_avg_loss: tensor(1.9120)
avg common val loss: tensor(1.5607) avg img val loss: tensor(2.1630) avg txt va

val common acc: 0.6226931567328918 val img acc: 0.4503605592347314 val txt acc: 0.496747608535688 val_avg_loss: tensor(1.7855)
avg common val loss: tensor(1.3831) avg img val loss: tensor(2.0605) avg txt val loss: tensor(1.9130)
epoch: 17 train_loss: tensor(1.7370, grad_fn=<DivBackward0>) average train loss tensor(1.8817, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5410, grad_fn=<DivBackward0>) avg img loss: tensor(2.1048, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9993, grad_fn=<DivBackward0>)
val common acc: 0.6235761589403973 val img acc: 0.4522737306843267 val txt acc: 0.49686534216335543 val_avg_loss: tensor(1.7836)
avg common val loss: tensor(1.3830) avg img val loss: tensor(2.0568) avg txt val loss: tensor(1.9110)
epoch: 18 train_loss: tensor(1.7441, grad_fn=<DivBackward0>) average train loss tensor(1.8776, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5367, grad_fn=<DivBackward0>) avg img loss: tensor(2.1001, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9960, gr

epoch: 33 train_loss: tensor(1.6685, grad_fn=<DivBackward0>) average train loss tensor(1.8496, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5014, grad_fn=<DivBackward0>) avg img loss: tensor(2.0748, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9727, grad_fn=<DivBackward0>)
val common acc: 0.6257836644591611 val img acc: 0.4544812362030905 val txt acc: 0.500485651214128 val_avg_loss: tensor(1.7706)
avg common val loss: tensor(1.3661) avg img val loss: tensor(2.0457) avg txt val loss: tensor(1.9000)
epoch: 34 train_loss: tensor(1.6918, grad_fn=<DivBackward0>) average train loss tensor(1.8475, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4987, grad_fn=<DivBackward0>) avg img loss: tensor(2.0739, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9699, grad_fn=<DivBackward0>)
val common acc: 0.6263723325974981 val img acc: 0.4560117733627667 val txt acc: 0.5000441501103753 val_avg_loss: tensor(1.7682)
avg common val loss: tensor(1.3614) avg img val loss: tensor(2.0443) avg txt val 

val common acc: 0.6274025018395879 val img acc: 0.4544223693892568 val txt acc: 0.49983811626195734 val_avg_loss: tensor(1.7640)
avg common val loss: tensor(1.3571) avg img val loss: tensor(2.0384) avg txt val loss: tensor(1.8964)
epoch: 50 train_loss: tensor(1.6286, grad_fn=<DivBackward0>) average train loss tensor(1.8327, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4802, grad_fn=<DivBackward0>) avg img loss: tensor(2.0617, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9562, grad_fn=<DivBackward0>)
val common acc: 0.6279911699779249 val img acc: 0.4549227373068433 val txt acc: 0.5015746872700515 val_avg_loss: tensor(1.7630)
avg common val loss: tensor(1.3568) avg img val loss: tensor(2.0378) avg txt val loss: tensor(1.8944)
epoch: 51 train_loss: tensor(1.6516, grad_fn=<DivBackward0>) average train loss tensor(1.8336, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4811, grad_fn=<DivBackward0>) avg img loss: tensor(2.0622, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9574, g

In [13]:
model = torch_models.NormModelTrident(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bs2048_rs42_d128_wd0005_drop_05_100')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=100,
    writer=writer,
    train_loader=train_loader,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(2.2356, grad_fn=<DivBackward0>) average train loss tensor(2.7814, grad_fn=<DivBackward0>)
avg common loss: tensor(2.8003, grad_fn=<DivBackward0>) avg img loss: tensor(2.7005, grad_fn=<DivBackward0>) avg txt loss: tensor(2.8435, grad_fn=<DivBackward0>)
val common acc: 0.5174098601913172 val img acc: 0.41539367181751286 val txt acc: 0.4542457689477557 val_avg_loss: tensor(2.0460)
avg common val loss: tensor(1.7837) avg img val loss: tensor(2.2355) avg txt val loss: tensor(2.1187)
epoch: 1 train_loss: tensor(2.0575, grad_fn=<DivBackward0>) average train loss tensor(2.1415, grad_fn=<DivBackward0>)
avg common loss: tensor(1.8901, grad_fn=<DivBackward0>) avg img loss: tensor(2.3134, grad_fn=<DivBackward0>) avg txt loss: tensor(2.2210, grad_fn=<DivBackward0>)
val common acc: 0.5812509197939661 val img acc: 0.4293745401030169 val txt acc: 0.4767623252391464 val_avg_loss: tensor(1.9108)
avg common val loss: tensor(1.5600) avg img val loss: tensor(2.1621) avg txt val 

val common acc: 0.6226342899190581 val img acc: 0.45100809418690213 val txt acc: 0.49892568064753495 val_avg_loss: tensor(1.7844)
avg common val loss: tensor(1.3849) avg img val loss: tensor(2.0584) avg txt val loss: tensor(1.9099)
epoch: 17 train_loss: tensor(1.7406, grad_fn=<DivBackward0>) average train loss tensor(1.8783, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5383, grad_fn=<DivBackward0>) avg img loss: tensor(2.1032, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9934, grad_fn=<DivBackward0>)
val common acc: 0.6245180279617366 val img acc: 0.4519793966151582 val txt acc: 0.49969094922737306 val_avg_loss: tensor(1.7833)
avg common val loss: tensor(1.3829) avg img val loss: tensor(2.0585) avg txt val loss: tensor(1.9086)
epoch: 18 train_loss: tensor(1.7108, grad_fn=<DivBackward0>) average train loss tensor(1.8760, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5325, grad_fn=<DivBackward0>) avg img loss: tensor(2.1009, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9945,

epoch: 33 train_loss: tensor(1.6737, grad_fn=<DivBackward0>) average train loss tensor(1.8483, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4985, grad_fn=<DivBackward0>) avg img loss: tensor(2.0771, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9692, grad_fn=<DivBackward0>)
val common acc: 0.6246651949963208 val img acc: 0.4541280353200883 val txt acc: 0.5005739514348786 val_avg_loss: tensor(1.7697)
avg common val loss: tensor(1.3636) avg img val loss: tensor(2.0465) avg txt val loss: tensor(1.8989)
epoch: 34 train_loss: tensor(1.6784, grad_fn=<DivBackward0>) average train loss tensor(1.8478, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5001, grad_fn=<DivBackward0>) avg img loss: tensor(2.0764, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9669, grad_fn=<DivBackward0>)
val common acc: 0.6256953642384105 val img acc: 0.4543046357615894 val txt acc: 0.500868285504047 val_avg_loss: tensor(1.7692)
avg common val loss: tensor(1.3640) avg img val loss: tensor(2.0457) avg txt val 

val common acc: 0.6266666666666667 val img acc: 0.4546284032376748 val txt acc: 0.5027520235467255 val_avg_loss: tensor(1.7631)
avg common val loss: tensor(1.3545) avg img val loss: tensor(2.0419) avg txt val loss: tensor(1.8928)
epoch: 50 train_loss: tensor(1.6489, grad_fn=<DivBackward0>) average train loss tensor(1.8331, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4781, grad_fn=<DivBackward0>) avg img loss: tensor(2.0605, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9607, grad_fn=<DivBackward0>)
val common acc: 0.6276379690949228 val img acc: 0.45559970566593083 val txt acc: 0.5032818248712289 val_avg_loss: tensor(1.7624)
avg common val loss: tensor(1.3537) avg img val loss: tensor(2.0414) avg txt val loss: tensor(1.8920)
epoch: 51 train_loss: tensor(1.6474, grad_fn=<DivBackward0>) average train loss tensor(1.8348, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4840, grad_fn=<DivBackward0>) avg img loss: tensor(2.0610, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9594, g

epoch: 66 train_loss: tensor(1.6402, grad_fn=<DivBackward0>) average train loss tensor(1.8268, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4729, grad_fn=<DivBackward0>) avg img loss: tensor(2.0550, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9525, grad_fn=<DivBackward0>)
val common acc: 0.6278145695364239 val img acc: 0.45480500367917587 val txt acc: 0.5033112582781457 val_avg_loss: tensor(1.7574)
avg common val loss: tensor(1.3464) avg img val loss: tensor(2.0360) avg txt val loss: tensor(1.8897)
epoch: 67 train_loss: tensor(1.6426, grad_fn=<DivBackward0>) average train loss tensor(1.8239, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4675, grad_fn=<DivBackward0>) avg img loss: tensor(2.0509, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9532, grad_fn=<DivBackward0>)
val common acc: 0.6295217071376011 val img acc: 0.45689477557027225 val txt acc: 0.5023988226637234 val_avg_loss: tensor(1.7590)
avg common val loss: tensor(1.3501) avg img val loss: tensor(2.0368) avg txt v

val common acc: 0.628962472406181 val img acc: 0.45627667402501837 val txt acc: 0.5028108903605593 val_avg_loss: tensor(1.7580)
avg common val loss: tensor(1.3489) avg img val loss: tensor(2.0367) avg txt val loss: tensor(1.8884)
epoch: 83 train_loss: tensor(1.6523, grad_fn=<DivBackward0>) average train loss tensor(1.8206, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4640, grad_fn=<DivBackward0>) avg img loss: tensor(2.0499, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9479, grad_fn=<DivBackward0>)
val common acc: 0.6293156732891833 val img acc: 0.45695364238410596 val txt acc: 0.5025754231052244 val_avg_loss: tensor(1.7602)
avg common val loss: tensor(1.3524) avg img val loss: tensor(2.0403) avg txt val loss: tensor(1.8880)
epoch: 84 train_loss: tensor(1.6236, grad_fn=<DivBackward0>) average train loss tensor(1.8166, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4587, grad_fn=<DivBackward0>) avg img loss: tensor(2.0457, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9452, g

epoch: 99 train_loss: tensor(1.6396, grad_fn=<DivBackward0>) average train loss tensor(1.8143, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4541, grad_fn=<DivBackward0>) avg img loss: tensor(2.0423, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9465, grad_fn=<DivBackward0>)
val common acc: 0.6315231788079471 val img acc: 0.45521707137601175 val txt acc: 0.5026931567328918 val_avg_loss: tensor(1.7555)
avg common val loss: tensor(1.3440) avg img val loss: tensor(2.0359) avg txt val loss: tensor(1.8866)


In [12]:
model = torch_models.NormModelTridentBN(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bn_bs2048_rs42_d128_wd0005_drop_05_100')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=100,
    writer=writer,
    train_loader=train_loader,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(2.1816, grad_fn=<DivBackward0>) average train loss tensor(2.6583, grad_fn=<DivBackward0>)
avg common loss: tensor(2.6052, grad_fn=<DivBackward0>) avg img loss: tensor(2.6679, grad_fn=<DivBackward0>) avg txt loss: tensor(2.7018, grad_fn=<DivBackward0>)
val common acc: 0.5574098601913171 val img acc: 0.41845474613686534 val txt acc: 0.47049300956585727 val_avg_loss: tensor(2.0080)
avg common val loss: tensor(1.6924) avg img val loss: tensor(2.2498) avg txt val loss: tensor(2.0820)
epoch: 1 train_loss: tensor(1.9992, grad_fn=<DivBackward0>) average train loss tensor(2.0952, grad_fn=<DivBackward0>)
avg common loss: tensor(1.7858, grad_fn=<DivBackward0>) avg img loss: tensor(2.3150, grad_fn=<DivBackward0>) avg txt loss: tensor(2.1848, grad_fn=<DivBackward0>)
val common acc: 0.6010890360559235 val img acc: 0.4314054451802796 val txt acc: 0.4853568800588668 val_avg_loss: tensor(1.8835)
avg common val loss: tensor(1.4962) avg img val loss: tensor(2.1614) avg txt val

val common acc: 0.632317880794702 val img acc: 0.4512141280353201 val txt acc: 0.502869757174393 val_avg_loss: tensor(1.7514)
avg common val loss: tensor(1.3321) avg img val loss: tensor(2.0391) avg txt val loss: tensor(1.8830)
epoch: 17 train_loss: tensor(1.6046, grad_fn=<DivBackward0>) average train loss tensor(1.7855, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3926, grad_fn=<DivBackward0>) avg img loss: tensor(2.0440, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9198, grad_fn=<DivBackward0>)
val common acc: 0.6329948491537896 val img acc: 0.4539514348785872 val txt acc: 0.503252391464312 val_avg_loss: tensor(1.7522)
avg common val loss: tensor(1.3351) avg img val loss: tensor(2.0365) avg txt val loss: tensor(1.8849)
epoch: 18 train_loss: tensor(1.6036, grad_fn=<DivBackward0>) average train loss tensor(1.7818, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3878, grad_fn=<DivBackward0>) avg img loss: tensor(2.0394, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9183, grad_

epoch: 33 train_loss: tensor(1.5069, grad_fn=<DivBackward0>) average train loss tensor(1.7479, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3477, grad_fn=<DivBackward0>) avg img loss: tensor(2.0091, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8868, grad_fn=<DivBackward0>)
val common acc: 0.6350257542310522 val img acc: 0.4559823399558499 val txt acc: 0.5040176600441502 val_avg_loss: tensor(1.7432)
avg common val loss: tensor(1.3253) avg img val loss: tensor(2.0274) avg txt val loss: tensor(1.8768)
epoch: 34 train_loss: tensor(1.5099, grad_fn=<DivBackward0>) average train loss tensor(1.7466, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3447, grad_fn=<DivBackward0>) avg img loss: tensor(2.0101, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8849, grad_fn=<DivBackward0>)
val common acc: 0.6350257542310522 val img acc: 0.45757174392935984 val txt acc: 0.5040765268579838 val_avg_loss: tensor(1.7422)
avg common val loss: tensor(1.3239) avg img val loss: tensor(2.0260) avg txt va

val common acc: 0.6356144223693893 val img acc: 0.45559970566593083 val txt acc: 0.5045768947755702 val_avg_loss: tensor(1.7408)
avg common val loss: tensor(1.3239) avg img val loss: tensor(2.0239) avg txt val loss: tensor(1.8746)
epoch: 50 train_loss: tensor(1.4807, grad_fn=<DivBackward0>) average train loss tensor(1.7304, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3256, grad_fn=<DivBackward0>) avg img loss: tensor(1.9935, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8721, grad_fn=<DivBackward0>)
val common acc: 0.6345253863134658 val img acc: 0.45624724061810157 val txt acc: 0.5031935246504783 val_avg_loss: tensor(1.7417)
avg common val loss: tensor(1.3241) avg img val loss: tensor(2.0259) avg txt val loss: tensor(1.8752)
epoch: 51 train_loss: tensor(1.4806, grad_fn=<DivBackward0>) average train loss tensor(1.7287, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3236, grad_fn=<DivBackward0>) avg img loss: tensor(1.9896, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8731, 

epoch: 66 train_loss: tensor(1.4665, grad_fn=<DivBackward0>) average train loss tensor(1.7183, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3128, grad_fn=<DivBackward0>) avg img loss: tensor(1.9782, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8638, grad_fn=<DivBackward0>)
val common acc: 0.6378513612950699 val img acc: 0.45742457689477556 val txt acc: 0.5059308314937454 val_avg_loss: tensor(1.7381)
avg common val loss: tensor(1.3191) avg img val loss: tensor(2.0227) avg txt val loss: tensor(1.8724)
epoch: 67 train_loss: tensor(1.4614, grad_fn=<DivBackward0>) average train loss tensor(1.7179, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3084, grad_fn=<DivBackward0>) avg img loss: tensor(1.9801, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8651, grad_fn=<DivBackward0>)
val common acc: 0.6364091243561443 val img acc: 0.45863134657836646 val txt acc: 0.5059013980868285 val_avg_loss: tensor(1.7397)
avg common val loss: tensor(1.3219) avg img val loss: tensor(2.0250) avg txt v

val common acc: 0.6365562913907284 val img acc: 0.45736571008094185 val txt acc: 0.5066666666666667 val_avg_loss: tensor(1.7369)
avg common val loss: tensor(1.3183) avg img val loss: tensor(2.0234) avg txt val loss: tensor(1.8689)
epoch: 83 train_loss: tensor(1.4346, grad_fn=<DivBackward0>) average train loss tensor(1.7097, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3016, grad_fn=<DivBackward0>) avg img loss: tensor(1.9706, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8570, grad_fn=<DivBackward0>)
val common acc: 0.6375864606328182 val img acc: 0.4573951434878587 val txt acc: 0.5062545989698307 val_avg_loss: tensor(1.7368)
avg common val loss: tensor(1.3184) avg img val loss: tensor(2.0222) avg txt val loss: tensor(1.8700)
epoch: 84 train_loss: tensor(1.4473, grad_fn=<DivBackward0>) average train loss tensor(1.7095, grad_fn=<DivBackward0>)
avg common loss: tensor(1.3009, grad_fn=<DivBackward0>) avg img loss: tensor(1.9694, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8583, g

epoch: 99 train_loss: tensor(1.4137, grad_fn=<DivBackward0>) average train loss tensor(1.7025, grad_fn=<DivBackward0>)
avg common loss: tensor(1.2943, grad_fn=<DivBackward0>) avg img loss: tensor(1.9621, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8512, grad_fn=<DivBackward0>)
val common acc: 0.6367623252391464 val img acc: 0.4571008094186902 val txt acc: 0.505813097866078 val_avg_loss: tensor(1.7377)
avg common val loss: tensor(1.3206) avg img val loss: tensor(2.0226) avg txt val loss: tensor(1.8698)


### model trained on selected data

In [29]:
smth = pickle.load(open('data/trident_x3_20k_2.pkl', 'rb'))

In [32]:
len(smth[0][0])

20000

In [34]:
(x_img_bald_2, x_txt_bald_2), y_bald_2 = pickle.load(open('data/trident_x3_20k_2.pkl', 'rb'))

In [36]:
x_img_bald_2_t = torch.tensor(x_img_bald_2).float()
x_txt_bald_2_t = torch.tensor(x_txt_bald_2).float()
y_bald_2_t = torch.tensor(y_bald_2).float()

train_ds_bald_2 = TensorDataset(x_img_bald_2_t, x_txt_bald_2_t, y_bald_2_t)
train_loader_bald_2 = DataLoader(train_ds_bald_2, batch_size=2048)

In [37]:
model = torch_models.NormModelTridentBN(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bn_bs2048_rs42_d128_wd0005_drop05_bald2')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=300,
    writer=writer,
    train_loader=train_loader_bald_2,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(3.5291, grad_fn=<DivBackward0>) average train loss tensor(3.8318, grad_fn=<DivBackward0>)
avg common loss: tensor(3.9235, grad_fn=<DivBackward0>) avg img loss: tensor(3.6686, grad_fn=<DivBackward0>) avg txt loss: tensor(3.9034, grad_fn=<DivBackward0>)
val common acc: 0.10545989698307578 val img acc: 0.3169389256806475 val txt acc: 0.3102869757174393 val_avg_loss: tensor(3.4998)
avg common val loss: tensor(3.7855) avg img val loss: tensor(3.1466) avg txt val loss: tensor(3.5673)
epoch: 1 train_loss: tensor(3.0967, grad_fn=<DivBackward0>) average train loss tensor(3.2911, grad_fn=<DivBackward0>)
avg common loss: tensor(3.4802, grad_fn=<DivBackward0>) avg img loss: tensor(3.1086, grad_fn=<DivBackward0>) avg txt loss: tensor(3.2845, grad_fn=<DivBackward0>)
val common acc: 0.3423693892568065 val img acc: 0.3487270051508462 val txt acc: 0.36944812362030904 val_avg_loss: tensor(2.8652)
avg common val loss: tensor(3.0605) avg img val loss: tensor(2.6408) avg txt val

val common acc: 0.5799852832965415 val img acc: 0.4112729948491538 val txt acc: 0.4714643119941133 val_avg_loss: tensor(1.9528)
avg common val loss: tensor(1.5585) avg img val loss: tensor(2.2538) avg txt val loss: tensor(2.0461)
epoch: 17 train_loss: tensor(1.7505, grad_fn=<DivBackward0>) average train loss tensor(1.9717, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5535, grad_fn=<DivBackward0>) avg img loss: tensor(2.1799, grad_fn=<DivBackward0>) avg txt loss: tensor(2.1816, grad_fn=<DivBackward0>)
val common acc: 0.5808388520971303 val img acc: 0.4113612950699043 val txt acc: 0.4712582781456954 val_avg_loss: tensor(1.9457)
avg common val loss: tensor(1.5499) avg img val loss: tensor(2.2431) avg txt val loss: tensor(2.0440)
epoch: 18 train_loss: tensor(1.7259, grad_fn=<DivBackward0>) average train loss tensor(1.9420, grad_fn=<DivBackward0>)
avg common loss: tensor(1.5110, grad_fn=<DivBackward0>) avg img loss: tensor(2.1531, grad_fn=<DivBackward0>) avg txt loss: tensor(2.1620, gr

epoch: 33 train_loss: tensor(1.3381, grad_fn=<DivBackward0>) average train loss tensor(1.5862, grad_fn=<DivBackward0>)
avg common loss: tensor(1.0827, grad_fn=<DivBackward0>) avg img loss: tensor(1.7347, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9412, grad_fn=<DivBackward0>)
val common acc: 0.5907579102281089 val img acc: 0.4002354672553348 val txt acc: 0.472700515084621 val_avg_loss: tensor(1.9793)
avg common val loss: tensor(1.6066) avg img val loss: tensor(2.3067) avg txt val loss: tensor(2.0245)
epoch: 34 train_loss: tensor(1.3301, grad_fn=<DivBackward0>) average train loss tensor(1.5686, grad_fn=<DivBackward0>)
avg common loss: tensor(1.0627, grad_fn=<DivBackward0>) avg img loss: tensor(1.6959, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9471, grad_fn=<DivBackward0>)
val common acc: 0.5900809418690214 val img acc: 0.40020603384841796 val txt acc: 0.4729654157468727 val_avg_loss: tensor(1.9834)
avg common val loss: tensor(1.6121) avg img val loss: tensor(2.3141) avg txt val

val common acc: 0.5861662987490802 val img acc: 0.39311258278145694 val txt acc: 0.47187637969094925 val_avg_loss: tensor(2.0546)
avg common val loss: tensor(1.7155) avg img val loss: tensor(2.4113) avg txt val loss: tensor(2.0370)
epoch: 50 train_loss: tensor(1.0977, grad_fn=<DivBackward0>) average train loss tensor(1.3294, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7998, grad_fn=<DivBackward0>) avg img loss: tensor(1.3949, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7936, grad_fn=<DivBackward0>)
val common acc: 0.5860779985283296 val img acc: 0.3913171449595291 val txt acc: 0.47240618101545256 val_avg_loss: tensor(2.0570)
avg common val loss: tensor(1.7242) avg img val loss: tensor(2.4105) avg txt val loss: tensor(2.0361)
epoch: 51 train_loss: tensor(1.0902, grad_fn=<DivBackward0>) average train loss tensor(1.3302, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8013, grad_fn=<DivBackward0>) avg img loss: tensor(1.3850, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8043,

epoch: 66 train_loss: tensor(0.9291, grad_fn=<DivBackward0>) average train loss tensor(1.1709, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6489, grad_fn=<DivBackward0>) avg img loss: tensor(1.1789, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6848, grad_fn=<DivBackward0>)
val common acc: 0.5813097866077999 val img acc: 0.39055187637969097 val txt acc: 0.4706401766004415 val_avg_loss: tensor(2.1172)
avg common val loss: tensor(1.8041) avg img val loss: tensor(2.4914) avg txt val loss: tensor(2.0561)
epoch: 67 train_loss: tensor(0.9460, grad_fn=<DivBackward0>) average train loss tensor(1.1524, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6207, grad_fn=<DivBackward0>) avg img loss: tensor(1.1649, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6716, grad_fn=<DivBackward0>)
val common acc: 0.5807505518763797 val img acc: 0.38543046357615895 val txt acc: 0.47143487858719646 val_avg_loss: tensor(2.1177)
avg common val loss: tensor(1.8028) avg img val loss: tensor(2.4951) avg txt 

val common acc: 0.5786313465783665 val img acc: 0.38328182487122886 val txt acc: 0.469551140544518 val_avg_loss: tensor(2.1769)
avg common val loss: tensor(1.8793) avg img val loss: tensor(2.5753) avg txt val loss: tensor(2.0761)
epoch: 83 train_loss: tensor(0.8470, grad_fn=<DivBackward0>) average train loss tensor(1.0404, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5237, grad_fn=<DivBackward0>) avg img loss: tensor(0.9941, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6034, grad_fn=<DivBackward0>)
val common acc: 0.5763355408388521 val img acc: 0.3808094186902134 val txt acc: 0.4688741721854305 val_avg_loss: tensor(2.1706)
avg common val loss: tensor(1.8674) avg img val loss: tensor(2.5693) avg txt val loss: tensor(2.0750)
epoch: 84 train_loss: tensor(0.8186, grad_fn=<DivBackward0>) average train loss tensor(1.0250, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5061, grad_fn=<DivBackward0>) avg img loss: tensor(0.9825, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5866, gr

epoch: 99 train_loss: tensor(0.7305, grad_fn=<DivBackward0>) average train loss tensor(0.9461, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4421, grad_fn=<DivBackward0>) avg img loss: tensor(0.8637, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5326, grad_fn=<DivBackward0>)
val common acc: 0.5723325974981605 val img acc: 0.3771596762325239 val txt acc: 0.46646063281824873 val_avg_loss: tensor(2.2267)
avg common val loss: tensor(1.9441) avg img val loss: tensor(2.6384) avg txt val loss: tensor(2.0977)
epoch: 100 train_loss: tensor(0.7651, grad_fn=<DivBackward0>) average train loss tensor(0.9545, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4474, grad_fn=<DivBackward0>) avg img loss: tensor(0.8762, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5399, grad_fn=<DivBackward0>)
val common acc: 0.5748050036791759 val img acc: 0.38092715231788077 val txt acc: 0.46707873436350256 val_avg_loss: tensor(2.2152)
avg common val loss: tensor(1.9256) avg img val loss: tensor(2.6234) avg txt

epoch: 115 train_loss: tensor(0.7124, grad_fn=<DivBackward0>) average train loss tensor(0.8820, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3922, grad_fn=<DivBackward0>) avg img loss: tensor(0.7771, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4766, grad_fn=<DivBackward0>)
val common acc: 0.5692420897718911 val img acc: 0.38007358351729215 val txt acc: 0.46460632818248715 val_avg_loss: tensor(2.2552)
avg common val loss: tensor(1.9654) avg img val loss: tensor(2.6850) avg txt val loss: tensor(2.1152)
epoch: 116 train_loss: tensor(0.6975, grad_fn=<DivBackward0>) average train loss tensor(0.8765, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3855, grad_fn=<DivBackward0>) avg img loss: tensor(0.7687, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4755, grad_fn=<DivBackward0>)
val common acc: 0.5695658572479765 val img acc: 0.379514348785872 val txt acc: 0.4642825607064018 val_avg_loss: tensor(2.2532)
avg common val loss: tensor(1.9614) avg img val loss: tensor(2.6808) avg txt 

epoch: 131 train_loss: tensor(0.6755, grad_fn=<DivBackward0>) average train loss tensor(0.8339, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3605, grad_fn=<DivBackward0>) avg img loss: tensor(0.7092, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4322, grad_fn=<DivBackward0>)
val common acc: 0.5680058866813834 val img acc: 0.3815746872700515 val txt acc: 0.4611037527593819 val_avg_loss: tensor(2.2831)
avg common val loss: tensor(1.9947) avg img val loss: tensor(2.7143) avg txt val loss: tensor(2.1404)
epoch: 132 train_loss: tensor(0.6433, grad_fn=<DivBackward0>) average train loss tensor(0.8333, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3517, grad_fn=<DivBackward0>) avg img loss: tensor(0.7084, grad_fn=<DivBackward0>) avg txt loss: tensor(1.4398, grad_fn=<DivBackward0>)
val common acc: 0.5692715231788079 val img acc: 0.37671817512877115 val txt acc: 0.4615158204562178 val_avg_loss: tensor(2.2699)
avg common val loss: tensor(1.9688) avg img val loss: tensor(2.7046) avg txt 

epoch: 147 train_loss: tensor(0.6261, grad_fn=<DivBackward0>) average train loss tensor(0.7913, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3228, grad_fn=<DivBackward0>) avg img loss: tensor(0.6566, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3945, grad_fn=<DivBackward0>)
val common acc: 0.5647976453274467 val img acc: 0.3744812362030905 val txt acc: 0.4585430463576159 val_avg_loss: tensor(2.3010)
avg common val loss: tensor(2.0069) avg img val loss: tensor(2.7437) avg txt val loss: tensor(2.1523)
epoch: 148 train_loss: tensor(0.5914, grad_fn=<DivBackward0>) average train loss tensor(0.7831, grad_fn=<DivBackward0>)
avg common loss: tensor(0.3200, grad_fn=<DivBackward0>) avg img loss: tensor(0.6473, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3820, grad_fn=<DivBackward0>)
val common acc: 0.5647093451066961 val img acc: 0.3763944076526858 val txt acc: 0.46019131714495953 val_avg_loss: tensor(2.3034)
avg common val loss: tensor(2.0042) avg img val loss: tensor(2.7481) avg txt 

epoch: 163 train_loss: tensor(0.5892, grad_fn=<DivBackward0>) average train loss tensor(0.7532, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2939, grad_fn=<DivBackward0>) avg img loss: tensor(0.6001, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3657, grad_fn=<DivBackward0>)
val common acc: 0.5665930831493745 val img acc: 0.37742457689477554 val txt acc: 0.4601618837380427 val_avg_loss: tensor(2.3270)
avg common val loss: tensor(2.0325) avg img val loss: tensor(2.7830) avg txt val loss: tensor(2.1654)
epoch: 164 train_loss: tensor(0.5848, grad_fn=<DivBackward0>) average train loss tensor(0.7538, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2958, grad_fn=<DivBackward0>) avg img loss: tensor(0.6078, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3577, grad_fn=<DivBackward0>)
val common acc: 0.5620603384841796 val img acc: 0.3755408388520971 val txt acc: 0.4606328182487123 val_avg_loss: tensor(2.3253)
avg common val loss: tensor(2.0328) avg img val loss: tensor(2.7790) avg txt 

epoch: 179 train_loss: tensor(0.5661, grad_fn=<DivBackward0>) average train loss tensor(0.7262, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2766, grad_fn=<DivBackward0>) avg img loss: tensor(0.5790, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3229, grad_fn=<DivBackward0>)
val common acc: 0.5638557763061074 val img acc: 0.37356880058866815 val txt acc: 0.4594554819720383 val_avg_loss: tensor(2.3351)
avg common val loss: tensor(2.0394) avg img val loss: tensor(2.7916) avg txt val loss: tensor(2.1744)
epoch: 180 train_loss: tensor(0.5865, grad_fn=<DivBackward0>) average train loss tensor(0.7297, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2803, grad_fn=<DivBackward0>) avg img loss: tensor(0.5732, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3357, grad_fn=<DivBackward0>)
val common acc: 0.5652097130242826 val img acc: 0.3759234731420162 val txt acc: 0.4596026490066225 val_avg_loss: tensor(2.3355)
avg common val loss: tensor(2.0406) avg img val loss: tensor(2.7889) avg txt 

epoch: 195 train_loss: tensor(0.5730, grad_fn=<DivBackward0>) average train loss tensor(0.7159, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2709, grad_fn=<DivBackward0>) avg img loss: tensor(0.5654, grad_fn=<DivBackward0>) avg txt loss: tensor(1.3114, grad_fn=<DivBackward0>)
val common acc: 0.5587049300956586 val img acc: 0.3728329654157469 val txt acc: 0.45842531272994846 val_avg_loss: tensor(2.3575)
avg common val loss: tensor(2.0626) avg img val loss: tensor(2.8140) avg txt val loss: tensor(2.1960)
epoch: 196 train_loss: tensor(0.5480, grad_fn=<DivBackward0>) average train loss tensor(0.6989, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2582, grad_fn=<DivBackward0>) avg img loss: tensor(0.5515, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2871, grad_fn=<DivBackward0>)
val common acc: 0.5619131714495953 val img acc: 0.37801324503311257 val txt acc: 0.4571596762325239 val_avg_loss: tensor(2.3525)
avg common val loss: tensor(2.0561) avg img val loss: tensor(2.8142) avg txt

epoch: 211 train_loss: tensor(0.5315, grad_fn=<DivBackward0>) average train loss tensor(0.6825, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2445, grad_fn=<DivBackward0>) avg img loss: tensor(0.5258, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2771, grad_fn=<DivBackward0>)
val common acc: 0.5607064017660044 val img acc: 0.376747608535688 val txt acc: 0.457542310522443 val_avg_loss: tensor(2.3637)
avg common val loss: tensor(2.0681) avg img val loss: tensor(2.8200) avg txt val loss: tensor(2.2030)
epoch: 212 train_loss: tensor(0.5398, grad_fn=<DivBackward0>) average train loss tensor(0.6792, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2420, grad_fn=<DivBackward0>) avg img loss: tensor(0.5183, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2771, grad_fn=<DivBackward0>)
val common acc: 0.5555261221486387 val img acc: 0.3726857983811626 val txt acc: 0.45748344370860927 val_avg_loss: tensor(2.3682)
avg common val loss: tensor(2.0799) avg img val loss: tensor(2.8182) avg txt va

val common acc: 0.5571743929359824 val img acc: 0.37542310522442973 val txt acc: 0.4565121412803532 val_avg_loss: tensor(2.3770)
avg common val loss: tensor(2.0770) avg img val loss: tensor(2.8377) avg txt val loss: tensor(2.2163)
epoch: 228 train_loss: tensor(0.5131, grad_fn=<DivBackward0>) average train loss tensor(0.6711, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2360, grad_fn=<DivBackward0>) avg img loss: tensor(0.5098, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2673, grad_fn=<DivBackward0>)
val common acc: 0.5598528329654158 val img acc: 0.3776600441501104 val txt acc: 0.4553348050036792 val_avg_loss: tensor(2.3671)
avg common val loss: tensor(2.0658) avg img val loss: tensor(2.8211) avg txt val loss: tensor(2.2143)
epoch: 229 train_loss: tensor(0.5552, grad_fn=<DivBackward0>) average train loss tensor(0.6670, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2378, grad_fn=<DivBackward0>) avg img loss: tensor(0.5080, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2552,

val common acc: 0.5547314201618837 val img acc: 0.36883002207505516 val txt acc: 0.4531567328918322 val_avg_loss: tensor(2.3841)
avg common val loss: tensor(2.0827) avg img val loss: tensor(2.8415) avg txt val loss: tensor(2.2281)
epoch: 244 train_loss: tensor(0.4965, grad_fn=<DivBackward0>) average train loss tensor(0.6522, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2330, grad_fn=<DivBackward0>) avg img loss: tensor(0.4814, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2422, grad_fn=<DivBackward0>)
val common acc: 0.5577630610743194 val img acc: 0.37630610743193527 val txt acc: 0.4551582045621781 val_avg_loss: tensor(2.3889)
avg common val loss: tensor(2.0928) avg img val loss: tensor(2.8478) avg txt val loss: tensor(2.2259)
epoch: 245 train_loss: tensor(0.5298, grad_fn=<DivBackward0>) average train loss tensor(0.6525, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2312, grad_fn=<DivBackward0>) avg img loss: tensor(0.4855, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2409

val common acc: 0.5539072847682119 val img acc: 0.3751287711552612 val txt acc: 0.45542310522442975 val_avg_loss: tensor(2.3965)
avg common val loss: tensor(2.1075) avg img val loss: tensor(2.8488) avg txt val loss: tensor(2.2331)
epoch: 260 train_loss: tensor(0.5030, grad_fn=<DivBackward0>) average train loss tensor(0.6362, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2169, grad_fn=<DivBackward0>) avg img loss: tensor(0.4719, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2199, grad_fn=<DivBackward0>)
val common acc: 0.5535835172921265 val img acc: 0.37380426784400295 val txt acc: 0.45306843267108166 val_avg_loss: tensor(2.3979)
avg common val loss: tensor(2.1000) avg img val loss: tensor(2.8568) avg txt val loss: tensor(2.2370)
epoch: 261 train_loss: tensor(0.4991, grad_fn=<DivBackward0>) average train loss tensor(0.6381, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2185, grad_fn=<DivBackward0>) avg img loss: tensor(0.4801, grad_fn=<DivBackward0>) avg txt loss: tensor(1.215

val common acc: 0.5548197203826343 val img acc: 0.3754525386313466 val txt acc: 0.45312729948491537 val_avg_loss: tensor(2.4083)
avg common val loss: tensor(2.1147) avg img val loss: tensor(2.8591) avg txt val loss: tensor(2.2510)
epoch: 276 train_loss: tensor(0.4934, grad_fn=<DivBackward0>) average train loss tensor(0.6289, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2086, grad_fn=<DivBackward0>) avg img loss: tensor(0.4602, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2179, grad_fn=<DivBackward0>)
val common acc: 0.555467255334805 val img acc: 0.373392200147167 val txt acc: 0.45245033112582783 val_avg_loss: tensor(2.3984)
avg common val loss: tensor(2.0984) avg img val loss: tensor(2.8505) avg txt val loss: tensor(2.2464)
epoch: 277 train_loss: tensor(0.5142, grad_fn=<DivBackward0>) average train loss tensor(0.6313, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2171, grad_fn=<DivBackward0>) avg img loss: tensor(0.4703, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2064, 

val common acc: 0.552700515084621 val img acc: 0.3742457689477557 val txt acc: 0.45068432671081676 val_avg_loss: tensor(2.4111)
avg common val loss: tensor(2.1119) avg img val loss: tensor(2.8614) avg txt val loss: tensor(2.2599)
epoch: 292 train_loss: tensor(0.4742, grad_fn=<DivBackward0>) average train loss tensor(0.6195, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2105, grad_fn=<DivBackward0>) avg img loss: tensor(0.4518, grad_fn=<DivBackward0>) avg txt loss: tensor(1.1960, grad_fn=<DivBackward0>)
val common acc: 0.5567034584253128 val img acc: 0.3765121412803532 val txt acc: 0.4501250919793966 val_avg_loss: tensor(2.4001)
avg common val loss: tensor(2.0990) avg img val loss: tensor(2.8446) avg txt val loss: tensor(2.2566)
epoch: 293 train_loss: tensor(0.4998, grad_fn=<DivBackward0>) average train loss tensor(0.6249, grad_fn=<DivBackward0>)
avg common loss: tensor(0.2175, grad_fn=<DivBackward0>) avg img loss: tensor(0.4524, grad_fn=<DivBackward0>) avg txt loss: tensor(1.2048, 

In [38]:
model = torch_models.NormModelTrident(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bs2048_rs42_d128_wd0005_drop05_bald2')

torch_models.fit_topics_trident_model(
    model=model,
    optimizer=optimizer,
    epochs=300,
    writer=writer,
    train_loader=train_loader_bald_2,
    val_loader=val_loader
)

epoch: 0 train_loss: tensor(3.6056, grad_fn=<DivBackward0>) average train loss tensor(3.7562, grad_fn=<DivBackward0>)
avg common loss: tensor(3.8418, grad_fn=<DivBackward0>) avg img loss: tensor(3.6411, grad_fn=<DivBackward0>) avg txt loss: tensor(3.7856, grad_fn=<DivBackward0>)
val common acc: 0.060103016924208975 val img acc: 0.27240618101545255 val txt acc: 0.11817512877115526 val_avg_loss: tensor(3.4352)
avg common val loss: tensor(3.6991) avg img val loss: tensor(3.0637) avg txt val loss: tensor(3.5427)
epoch: 1 train_loss: tensor(3.3138, grad_fn=<DivBackward0>) average train loss tensor(3.4447, grad_fn=<DivBackward0>)
avg common loss: tensor(3.6770, grad_fn=<DivBackward0>) avg img loss: tensor(3.1661, grad_fn=<DivBackward0>) avg txt loss: tensor(3.4910, grad_fn=<DivBackward0>)
val common acc: 0.2147167034584253 val img acc: 0.3368064753495217 val txt acc: 0.24585724797645328 val_avg_loss: tensor(3.0923)
avg common val loss: tensor(3.4595) avg img val loss: tensor(2.6434) avg txt 

val common acc: 0.5751876379690949 val img acc: 0.41253863134657837 val txt acc: 0.47090507726269315 val_avg_loss: tensor(1.9548)
avg common val loss: tensor(1.5688) avg img val loss: tensor(2.2298) avg txt val loss: tensor(2.0658)
epoch: 17 train_loss: tensor(1.9483, grad_fn=<DivBackward0>) average train loss tensor(2.1305, grad_fn=<DivBackward0>)
avg common loss: tensor(1.7752, grad_fn=<DivBackward0>) avg img loss: tensor(2.3163, grad_fn=<DivBackward0>) avg txt loss: tensor(2.3002, grad_fn=<DivBackward0>)
val common acc: 0.577159676232524 val img acc: 0.41109639440765267 val txt acc: 0.4712582781456954 val_avg_loss: tensor(1.9480)
avg common val loss: tensor(1.5615) avg img val loss: tensor(2.2200) avg txt val loss: tensor(2.0626)
epoch: 18 train_loss: tensor(1.9168, grad_fn=<DivBackward0>) average train loss tensor(2.1130, grad_fn=<DivBackward0>)
avg common loss: tensor(1.7524, grad_fn=<DivBackward0>) avg img loss: tensor(2.3001, grad_fn=<DivBackward0>) avg txt loss: tensor(2.2866, 

epoch: 33 train_loss: tensor(1.6692, grad_fn=<DivBackward0>) average train loss tensor(1.8737, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4542, grad_fn=<DivBackward0>) avg img loss: tensor(2.0187, grad_fn=<DivBackward0>) avg txt loss: tensor(2.1484, grad_fn=<DivBackward0>)
val common acc: 0.5881089036055923 val img acc: 0.41462840323767475 val txt acc: 0.47611479028697573 val_avg_loss: tensor(1.9419)
avg common val loss: tensor(1.5528) avg img val loss: tensor(2.2368) avg txt val loss: tensor(2.0360)
epoch: 34 train_loss: tensor(1.6434, grad_fn=<DivBackward0>) average train loss tensor(1.8624, grad_fn=<DivBackward0>)
avg common loss: tensor(1.4336, grad_fn=<DivBackward0>) avg img loss: tensor(2.0126, grad_fn=<DivBackward0>) avg txt loss: tensor(2.1410, grad_fn=<DivBackward0>)
val common acc: 0.5901398086828551 val img acc: 0.41515820456217806 val txt acc: 0.4763208241353937 val_avg_loss: tensor(1.9405)
avg common val loss: tensor(1.5492) avg img val loss: tensor(2.2392) avg txt 

val common acc: 0.5913760117733627 val img acc: 0.4101545253863135 val txt acc: 0.4765268579838116 val_avg_loss: tensor(1.9626)
avg common val loss: tensor(1.5845) avg img val loss: tensor(2.2750) avg txt val loss: tensor(2.0284)
epoch: 50 train_loss: tensor(1.4665, grad_fn=<DivBackward0>) average train loss tensor(1.6905, grad_fn=<DivBackward0>)
avg common loss: tensor(1.2319, grad_fn=<DivBackward0>) avg img loss: tensor(1.7835, grad_fn=<DivBackward0>) avg txt loss: tensor(2.0560, grad_fn=<DivBackward0>)
val common acc: 0.5895805739514349 val img acc: 0.4110669610007358 val txt acc: 0.4761442236938926 val_avg_loss: tensor(1.9697)
avg common val loss: tensor(1.5908) avg img val loss: tensor(2.2851) avg txt val loss: tensor(2.0332)
epoch: 51 train_loss: tensor(1.4336, grad_fn=<DivBackward0>) average train loss tensor(1.6712, grad_fn=<DivBackward0>)
avg common loss: tensor(1.2094, grad_fn=<DivBackward0>) avg img loss: tensor(1.7620, grad_fn=<DivBackward0>) avg txt loss: tensor(2.0421, gr

epoch: 66 train_loss: tensor(1.3188, grad_fn=<DivBackward0>) average train loss tensor(1.5439, grad_fn=<DivBackward0>)
avg common loss: tensor(1.0639, grad_fn=<DivBackward0>) avg img loss: tensor(1.5798, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9880, grad_fn=<DivBackward0>)
val common acc: 0.5878440029433407 val img acc: 0.40618101545253865 val txt acc: 0.4772038263428992 val_avg_loss: tensor(2.0008)
avg common val loss: tensor(1.6411) avg img val loss: tensor(2.3349) avg txt val loss: tensor(2.0265)
epoch: 67 train_loss: tensor(1.3042, grad_fn=<DivBackward0>) average train loss tensor(1.5424, grad_fn=<DivBackward0>)
avg common loss: tensor(1.0620, grad_fn=<DivBackward0>) avg img loss: tensor(1.5886, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9767, grad_fn=<DivBackward0>)
val common acc: 0.5880500367917586 val img acc: 0.4037674760853569 val txt acc: 0.4781456953642384 val_avg_loss: tensor(2.0029)
avg common val loss: tensor(1.6465) avg img val loss: tensor(2.3349) avg txt va

val common acc: 0.5845768947755703 val img acc: 0.4002649006622517 val txt acc: 0.47655629139072847 val_avg_loss: tensor(2.0385)
avg common val loss: tensor(1.6974) avg img val loss: tensor(2.3867) avg txt val loss: tensor(2.0313)
epoch: 83 train_loss: tensor(1.1984, grad_fn=<DivBackward0>) average train loss tensor(1.4190, grad_fn=<DivBackward0>)
avg common loss: tensor(0.9190, grad_fn=<DivBackward0>) avg img loss: tensor(1.4172, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9207, grad_fn=<DivBackward0>)
val common acc: 0.5861368653421634 val img acc: 0.3982634289919058 val txt acc: 0.47567328918322294 val_avg_loss: tensor(2.0398)
avg common val loss: tensor(1.7010) avg img val loss: tensor(2.3889) avg txt val loss: tensor(2.0297)
epoch: 84 train_loss: tensor(1.2251, grad_fn=<DivBackward0>) average train loss tensor(1.4212, grad_fn=<DivBackward0>)
avg common loss: tensor(0.9317, grad_fn=<DivBackward0>) avg img loss: tensor(1.4164, grad_fn=<DivBackward0>) avg txt loss: tensor(1.9156, 

epoch: 99 train_loss: tensor(1.1044, grad_fn=<DivBackward0>) average train loss tensor(1.3277, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8318, grad_fn=<DivBackward0>) avg img loss: tensor(1.2987, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8525, grad_fn=<DivBackward0>)
val common acc: 0.5847829286239882 val img acc: 0.3981456953642384 val txt acc: 0.4761442236938926 val_avg_loss: tensor(2.0713)
avg common val loss: tensor(1.7473) avg img val loss: tensor(2.4310) avg txt val loss: tensor(2.0355)
epoch: 100 train_loss: tensor(1.1025, grad_fn=<DivBackward0>) average train loss tensor(1.3247, grad_fn=<DivBackward0>)
avg common loss: tensor(0.8262, grad_fn=<DivBackward0>) avg img loss: tensor(1.2839, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8640, grad_fn=<DivBackward0>)
val common acc: 0.5822810890360559 val img acc: 0.39867549668874175 val txt acc: 0.4762325239146431 val_avg_loss: tensor(2.0829)
avg common val loss: tensor(1.7663) avg img val loss: tensor(2.4464) avg txt v

epoch: 115 train_loss: tensor(1.0651, grad_fn=<DivBackward0>) average train loss tensor(1.2593, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7622, grad_fn=<DivBackward0>) avg img loss: tensor(1.1840, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8317, grad_fn=<DivBackward0>)
val common acc: 0.5799558498896247 val img acc: 0.39426048565121413 val txt acc: 0.4759087564385578 val_avg_loss: tensor(2.1009)
avg common val loss: tensor(1.7883) avg img val loss: tensor(2.4778) avg txt val loss: tensor(2.0365)
epoch: 116 train_loss: tensor(1.0447, grad_fn=<DivBackward0>) average train loss tensor(1.2563, grad_fn=<DivBackward0>)
avg common loss: tensor(0.7503, grad_fn=<DivBackward0>) avg img loss: tensor(1.1857, grad_fn=<DivBackward0>) avg txt loss: tensor(1.8329, grad_fn=<DivBackward0>)
val common acc: 0.5819278881530537 val img acc: 0.39537895511405446 val txt acc: 0.47543782192788814 val_avg_loss: tensor(2.0980)
avg common val loss: tensor(1.7878) avg img val loss: tensor(2.4695) avg tx

epoch: 131 train_loss: tensor(0.9731, grad_fn=<DivBackward0>) average train loss tensor(1.1888, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6856, grad_fn=<DivBackward0>) avg img loss: tensor(1.0964, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7845, grad_fn=<DivBackward0>)
val common acc: 0.5761000735835173 val img acc: 0.3925239146431199 val txt acc: 0.47473142016188374 val_avg_loss: tensor(2.1273)
avg common val loss: tensor(1.8273) avg img val loss: tensor(2.5115) avg txt val loss: tensor(2.0431)
epoch: 132 train_loss: tensor(0.9546, grad_fn=<DivBackward0>) average train loss tensor(1.1888, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6859, grad_fn=<DivBackward0>) avg img loss: tensor(1.1006, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7800, grad_fn=<DivBackward0>)
val common acc: 0.577159676232524 val img acc: 0.39461368653421636 val txt acc: 0.47426048565121415 val_avg_loss: tensor(2.1317)
avg common val loss: tensor(1.8403) avg img val loss: tensor(2.5132) avg txt

epoch: 147 train_loss: tensor(0.9248, grad_fn=<DivBackward0>) average train loss tensor(1.1314, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6337, grad_fn=<DivBackward0>) avg img loss: tensor(1.0165, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7441, grad_fn=<DivBackward0>)
val common acc: 0.5751876379690949 val img acc: 0.3918469462840324 val txt acc: 0.4739072847682119 val_avg_loss: tensor(2.1577)
avg common val loss: tensor(1.8773) avg img val loss: tensor(2.5484) avg txt val loss: tensor(2.0474)
epoch: 148 train_loss: tensor(0.9094, grad_fn=<DivBackward0>) average train loss tensor(1.1321, grad_fn=<DivBackward0>)
avg common loss: tensor(0.6357, grad_fn=<DivBackward0>) avg img loss: tensor(1.0130, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7475, grad_fn=<DivBackward0>)
val common acc: 0.5720971302428256 val img acc: 0.3893745401030169 val txt acc: 0.47373068432671084 val_avg_loss: tensor(2.1551)
avg common val loss: tensor(1.8711) avg img val loss: tensor(2.5459) avg txt 

epoch: 163 train_loss: tensor(0.8889, grad_fn=<DivBackward0>) average train loss tensor(1.0876, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5850, grad_fn=<DivBackward0>) avg img loss: tensor(0.9563, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7216, grad_fn=<DivBackward0>)
val common acc: 0.5713024282560707 val img acc: 0.3895217071376012 val txt acc: 0.4743193524650478 val_avg_loss: tensor(2.1765)
avg common val loss: tensor(1.9053) avg img val loss: tensor(2.5724) avg txt val loss: tensor(2.0519)
epoch: 164 train_loss: tensor(0.9013, grad_fn=<DivBackward0>) average train loss tensor(1.0870, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5892, grad_fn=<DivBackward0>) avg img loss: tensor(0.9532, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7187, grad_fn=<DivBackward0>)
val common acc: 0.5678587196467991 val img acc: 0.38966887417218543 val txt acc: 0.4743193524650478 val_avg_loss: tensor(2.1802)
avg common val loss: tensor(1.9034) avg img val loss: tensor(2.5825) avg txt 

epoch: 179 train_loss: tensor(0.8604, grad_fn=<DivBackward0>) average train loss tensor(1.0500, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5477, grad_fn=<DivBackward0>) avg img loss: tensor(0.9012, grad_fn=<DivBackward0>) avg txt loss: tensor(1.7010, grad_fn=<DivBackward0>)
val common acc: 0.5684473877851361 val img acc: 0.38775570272259013 val txt acc: 0.47340691685062547 val_avg_loss: tensor(2.2004)
avg common val loss: tensor(1.9424) avg img val loss: tensor(2.6013) avg txt val loss: tensor(2.0573)
epoch: 180 train_loss: tensor(0.8560, grad_fn=<DivBackward0>) average train loss tensor(1.0454, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5452, grad_fn=<DivBackward0>) avg img loss: tensor(0.8973, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6936, grad_fn=<DivBackward0>)
val common acc: 0.5691243561442237 val img acc: 0.38684326710816774 val txt acc: 0.47299484915378953 val_avg_loss: tensor(2.1997)
avg common val loss: tensor(1.9422) avg img val loss: tensor(2.6000) avg t

epoch: 195 train_loss: tensor(0.8159, grad_fn=<DivBackward0>) average train loss tensor(1.0177, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5185, grad_fn=<DivBackward0>) avg img loss: tensor(0.8575, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6772, grad_fn=<DivBackward0>)
val common acc: 0.5675643855776306 val img acc: 0.38905077262693155 val txt acc: 0.47205298013245034 val_avg_loss: tensor(2.2248)
avg common val loss: tensor(1.9797) avg img val loss: tensor(2.6320) avg txt val loss: tensor(2.0626)
epoch: 196 train_loss: tensor(0.8501, grad_fn=<DivBackward0>) average train loss tensor(1.0144, grad_fn=<DivBackward0>)
avg common loss: tensor(0.5188, grad_fn=<DivBackward0>) avg img loss: tensor(0.8545, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6700, grad_fn=<DivBackward0>)
val common acc: 0.5680941869021339 val img acc: 0.38887417218543047 val txt acc: 0.4720824135393672 val_avg_loss: tensor(2.2176)
avg common val loss: tensor(1.9616) avg img val loss: tensor(2.6260) avg tx

epoch: 211 train_loss: tensor(0.8005, grad_fn=<DivBackward0>) average train loss tensor(0.9852, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4814, grad_fn=<DivBackward0>) avg img loss: tensor(0.8135, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6608, grad_fn=<DivBackward0>)
val common acc: 0.5651802796173657 val img acc: 0.3857542310522443 val txt acc: 0.47025754231052247 val_avg_loss: tensor(2.2361)
avg common val loss: tensor(1.9987) avg img val loss: tensor(2.6416) avg txt val loss: tensor(2.0680)
epoch: 212 train_loss: tensor(0.7943, grad_fn=<DivBackward0>) average train loss tensor(0.9872, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4890, grad_fn=<DivBackward0>) avg img loss: tensor(0.8161, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6566, grad_fn=<DivBackward0>)
val common acc: 0.5660927152317881 val img acc: 0.3866961000735835 val txt acc: 0.4715526122148639 val_avg_loss: tensor(2.2355)
avg common val loss: tensor(1.9955) avg img val loss: tensor(2.6437) avg txt 

epoch: 227 train_loss: tensor(0.7918, grad_fn=<DivBackward0>) average train loss tensor(0.9626, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4654, grad_fn=<DivBackward0>) avg img loss: tensor(0.7822, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6401, grad_fn=<DivBackward0>)
val common acc: 0.5598233995584989 val img acc: 0.3843119941133186 val txt acc: 0.4706990434142752 val_avg_loss: tensor(2.2508)
avg common val loss: tensor(2.0137) avg img val loss: tensor(2.6648) avg txt val loss: tensor(2.0740)
epoch: 228 train_loss: tensor(0.7890, grad_fn=<DivBackward0>) average train loss tensor(0.9619, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4685, grad_fn=<DivBackward0>) avg img loss: tensor(0.7819, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6352, grad_fn=<DivBackward0>)
val common acc: 0.5589698307579102 val img acc: 0.3855481972038263 val txt acc: 0.47037527593818984 val_avg_loss: tensor(2.2535)
avg common val loss: tensor(2.0226) avg img val loss: tensor(2.6620) avg txt 

epoch: 243 train_loss: tensor(0.7686, grad_fn=<DivBackward0>) average train loss tensor(0.9454, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4581, grad_fn=<DivBackward0>) avg img loss: tensor(0.7483, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6297, grad_fn=<DivBackward0>)
val common acc: 0.5566740250183959 val img acc: 0.38251655629139075 val txt acc: 0.47137601177336275 val_avg_loss: tensor(2.2546)
avg common val loss: tensor(2.0132) avg img val loss: tensor(2.6761) avg txt val loss: tensor(2.0747)
epoch: 244 train_loss: tensor(0.7702, grad_fn=<DivBackward0>) average train loss tensor(0.9332, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4452, grad_fn=<DivBackward0>) avg img loss: tensor(0.7495, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6049, grad_fn=<DivBackward0>)
val common acc: 0.5603826342899191 val img acc: 0.38363502575423103 val txt acc: 0.47061074319352464 val_avg_loss: tensor(2.2580)
avg common val loss: tensor(2.0253) avg img val loss: tensor(2.6738) avg t

epoch: 259 train_loss: tensor(0.7507, grad_fn=<DivBackward0>) average train loss tensor(0.9145, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4254, grad_fn=<DivBackward0>) avg img loss: tensor(0.7179, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6001, grad_fn=<DivBackward0>)
val common acc: 0.561766004415011 val img acc: 0.3851066961000736 val txt acc: 0.4697866077998528 val_avg_loss: tensor(2.2732)
avg common val loss: tensor(2.0470) avg img val loss: tensor(2.6920) avg txt val loss: tensor(2.0805)
epoch: 260 train_loss: tensor(0.7501, grad_fn=<DivBackward0>) average train loss tensor(0.9162, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4261, grad_fn=<DivBackward0>) avg img loss: tensor(0.7215, grad_fn=<DivBackward0>) avg txt loss: tensor(1.6012, grad_fn=<DivBackward0>)
val common acc: 0.5596762325239146 val img acc: 0.3828697571743929 val txt acc: 0.4711111111111111 val_avg_loss: tensor(2.2804)
avg common val loss: tensor(2.0573) avg img val loss: tensor(2.7011) avg txt va

epoch: 275 train_loss: tensor(0.7450, grad_fn=<DivBackward0>) average train loss tensor(0.9041, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4195, grad_fn=<DivBackward0>) avg img loss: tensor(0.7055, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5872, grad_fn=<DivBackward0>)
val common acc: 0.5576747608535688 val img acc: 0.3821927888153054 val txt acc: 0.47019867549668876 val_avg_loss: tensor(2.2868)
avg common val loss: tensor(2.0599) avg img val loss: tensor(2.7111) avg txt val loss: tensor(2.0893)
epoch: 276 train_loss: tensor(0.7367, grad_fn=<DivBackward0>) average train loss tensor(0.9061, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4164, grad_fn=<DivBackward0>) avg img loss: tensor(0.7051, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5966, grad_fn=<DivBackward0>)
val common acc: 0.5607358351729212 val img acc: 0.38236938925680647 val txt acc: 0.4714054451802796 val_avg_loss: tensor(2.2876)
avg common val loss: tensor(2.0714) avg img val loss: tensor(2.7089) avg txt

epoch: 291 train_loss: tensor(0.7461, grad_fn=<DivBackward0>) average train loss tensor(0.8909, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4094, grad_fn=<DivBackward0>) avg img loss: tensor(0.6848, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5786, grad_fn=<DivBackward0>)
val common acc: 0.5572626931567329 val img acc: 0.3816335540838852 val txt acc: 0.47084621044885944 val_avg_loss: tensor(2.2947)
avg common val loss: tensor(2.0822) avg img val loss: tensor(2.7147) avg txt val loss: tensor(2.0872)
epoch: 292 train_loss: tensor(0.7313, grad_fn=<DivBackward0>) average train loss tensor(0.8922, grad_fn=<DivBackward0>)
avg common loss: tensor(0.4052, grad_fn=<DivBackward0>) avg img loss: tensor(0.6860, grad_fn=<DivBackward0>) avg txt loss: tensor(1.5854, grad_fn=<DivBackward0>)
val common acc: 0.554289919058131 val img acc: 0.3783075791022811 val txt acc: 0.4701692420897719 val_avg_loss: tensor(2.2877)
avg common val loss: tensor(2.0632) avg img val loss: tensor(2.7133) avg txt v

## Experiments with trainable loss merge

In [28]:
class TridentMTL(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch_models.NormModelTrident(drop=0.5)
        self.sigma = nn.Parameter(torch.ones(3))
        
    def forward(self, inp_img, inp_txt, target):
        outputs = self.model(inp_img, inp_txt)

        ls = [F.nll_loss(output, torch.argmax(target, dim=1)) for output in outputs]
        l = sum([(le / s ** 2) for le, s in zip(ls, self.sigma)]) + sum([torch.log(s) for s in self.sigma])
        return l, ls

In [33]:
def fit_mtl(mtl, optimizer, epochs, train_loader, val_loader, writer):
    for epoch in range(epochs):
        print(mtl.sigma)
        mtl.train()

        trainable_loss_sum = 0.0
        common_loss_sum = 0.0
        img_loss_sum = 0.0
        txt_loss_sum = 0.0
        loss_count = 0

        for x_img_cur, x_txt_cur, y_cur in train_loader:
            mtl.zero_grad()
            trainable_loss, raw_losses = mtl(x_img_cur, x_txt_cur, y_cur)
            trainable_loss.backward()

            trainable_loss_sum += trainable_loss
            common_loss_sum += raw_losses[0]
            img_loss_sum += raw_losses[1]
            txt_loss_sum += raw_losses[2]
            loss_count += 1

            optimizer.step()

        print('epoch:', epoch, 'average trainable train loss', trainable_loss_sum / loss_count)
        if writer is not None:
            writer.add_scalar('train_loss', trainable_loss, epoch)
            writer.add_scalar('avg_train_loss', trainable_loss_sum / loss_count, epoch)
            writer.add_scalar('avg_train_loss_common', common_loss_sum / loss_count, epoch)
            writer.add_scalar('avg_train_loss_img', img_loss_sum / loss_count, epoch)
            writer.add_scalar('avg_train_loss_txt', txt_loss_sum / loss_count, epoch)

        if val_loader is not None:
            mtl.eval()

            correct_common = 0
            correct_img = 0
            correct_txt = 0
            total = 0
            
            trainable_val_loss_sum = 0.0
            main_val_loss_sum = 0.0
            img_val_loss_sum = 0.0
            txt_val_loss_sum = 0.0
            loss_count = 0

            with torch.no_grad():
                for x_img_cur, x_txt_cur, y_cur in val_loader:
                    outputs = mtl.model(x_img_cur, x_txt_cur)
                    
                    for idx, i in enumerate(outputs[0]):
                        if torch.argmax(i) == torch.argmax(y_cur, dim=1)[idx]:
                            correct_common += 1
                        total += 1
                    
                    for idx, i in enumerate(outputs[1]):
                        if torch.argmax(i) == torch.argmax(y_cur, dim=1)[idx]:
                            correct_img += 1
                        
                    for idx, i in enumerate(outputs[2]):
                        if torch.argmax(i) == torch.argmax(y_cur, dim=1)[idx]:
                            correct_txt += 1
                        
                    trainable_val_loss, raw_val_losses = mtl(x_img_cur, x_txt_cur, y_cur)
                    trainable_val_loss_sum += trainable_val_loss
                    main_val_loss_sum += raw_val_losses[0]
                    img_val_loss_sum += raw_val_losses[1]
                    txt_val_loss_sum += raw_val_losses[2]
                    loss_count += 1
                    
            print('val_acc:', correct_common / total)
            if writer is not None:
                writer.add_scalar('val_acc', correct_common / total, epoch)
                writer.add_scalar('val_img_acc', correct_img / total, epoch)
                writer.add_scalar('val_txt_acc', correct_txt / total, epoch)
                
                writer.add_scalar('val_avg_trainable_loss', trainable_val_loss_sum / loss_count, epoch) 
                writer.add_scalar('val_avg_loss', main_val_loss_sum / loss_count, epoch)
                writer.add_scalar('val_avg_img_loss', img_val_loss_sum / loss_count, epoch)
                writer.add_scalar('val_avg_txt_loss', txt_val_loss_sum / loss_count, epoch)
        

In [34]:
# model = torch_models.NormModelTrident(drop=0.5)
mtl = TridentMTL()
optimizer = optim.Adam(mtl.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bn_mtl_bs2048_rs42_d128_wd0005_drop05_100')

fit_mtl(
    mtl=mtl,
    optimizer=optimizer,
    epochs=100,
    writer=writer,
    train_loader=train_loader,
    val_loader=val_loader
)

Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
epoch: 0 average trainable train loss tensor(8.0436, grad_fn=<DivBackward0>)
val_acc: 0.5212656364974245
Parameter containing:
tensor([1.0560, 1.0581, 1.0573], requires_grad=True)
epoch: 1 average trainable train loss tensor(5.7566, grad_fn=<DivBackward0>)
val_acc: 0.581486387049301
Parameter containing:
tensor([1.0945, 1.1085, 1.1034], requires_grad=True)
epoch: 2 average trainable train loss tensor(5.2210, grad_fn=<DivBackward0>)
val_acc: 0.5990286975717439
Parameter containing:
tensor([1.1304, 1.1554, 1.1469], requires_grad=True)
epoch: 3 average trainable train loss tensor(4.8846, grad_fn=<DivBackward0>)
val_acc: 0.6069757174392936
Parameter containing:
tensor([1.1643, 1.1990, 1.1880], requires_grad=True)
epoch: 4 average trainable train loss tensor(4.6413, grad_fn=<DivBackward0>)
val_acc: 0.6113024282560706
Parameter containing:
tensor([1.1964, 1.2396, 1.2268], requires_grad=True)
epoch: 5 average trainable train loss 

val_acc: 0.6282855040470935
Parameter containing:
tensor([1.6980, 1.9507, 1.9146], requires_grad=True)
epoch: 46 average trainable train loss tensor(3.4357, grad_fn=<DivBackward0>)
val_acc: 0.6281383370125092
Parameter containing:
tensor([1.7002, 1.9570, 1.9201], requires_grad=True)
epoch: 47 average trainable train loss tensor(3.4359, grad_fn=<DivBackward0>)
val_acc: 0.6277557027225902
Parameter containing:
tensor([1.7022, 1.9629, 1.9253], requires_grad=True)
epoch: 48 average trainable train loss tensor(3.4360, grad_fn=<DivBackward0>)
val_acc: 0.6288153053715968
Parameter containing:
tensor([1.7039, 1.9684, 1.9304], requires_grad=True)
epoch: 49 average trainable train loss tensor(3.4357, grad_fn=<DivBackward0>)
val_acc: 0.6270198675496689
Parameter containing:
tensor([1.7056, 1.9737, 1.9349], requires_grad=True)
epoch: 50 average trainable train loss tensor(3.4317, grad_fn=<DivBackward0>)
val_acc: 0.6281089036055924
Parameter containing:
tensor([1.7069, 1.9785, 1.9391], requires_gra

epoch: 91 average trainable train loss tensor(3.4341, grad_fn=<DivBackward0>)
val_acc: 0.6296100073583517
Parameter containing:
tensor([1.7116, 2.0342, 1.9805], requires_grad=True)
epoch: 92 average trainable train loss tensor(3.4340, grad_fn=<DivBackward0>)
val_acc: 0.629168506254599
Parameter containing:
tensor([1.7119, 2.0342, 1.9806], requires_grad=True)
epoch: 93 average trainable train loss tensor(3.4328, grad_fn=<DivBackward0>)
val_acc: 0.6287270051508462
Parameter containing:
tensor([1.7114, 2.0344, 1.9802], requires_grad=True)
epoch: 94 average trainable train loss tensor(3.4350, grad_fn=<DivBackward0>)
val_acc: 0.6288447387785137
Parameter containing:
tensor([1.7121, 2.0345, 1.9804], requires_grad=True)
epoch: 95 average trainable train loss tensor(3.4328, grad_fn=<DivBackward0>)
val_acc: 0.6284032376747608
Parameter containing:
tensor([1.7116, 2.0345, 1.9802], requires_grad=True)
epoch: 96 average trainable train loss tensor(3.4322, grad_fn=<DivBackward0>)
val_acc: 0.6289330