In [1]:
%load_ext autoreload
%autoreload 2

In [13]:
import time

import matplotlib.pyplot as plt

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import torch.nn as nn

import numpy as np

import pickle

from data import data
from pytorch import torch_models

In [10]:
IMG_LEN = 1024
TXT_LEN = 300
N_CLASSES = 50
BATCH_SIZE = 2048

In [5]:
x_img, x_txt, y = data.get_unpacked_data()

In [6]:
x_img_train, x_img_test, x_txt_train, x_txt_test, y_train, y_test = train_test_split(
    x_img, 
    x_txt, 
    y, 
    test_size=0.2, 
    random_state=42,
    stratify=y
)

x_img_train, x_img_val, x_txt_train, x_txt_val, y_train, y_val = train_test_split(
    x_img_train,
    x_txt_train,
    y_train,
    test_size=0.2,
    random_state=42,
    stratify=y_train
)

In [7]:
img_sscaler = StandardScaler()
img_sscaler.fit(x_img_train)

x_img_train = img_sscaler.transform(x_img_train)
x_img_val = img_sscaler.transform(x_img_val)
x_img_test = img_sscaler.transform(x_img_test)

txt_sscaler = StandardScaler()
txt_sscaler.fit(x_txt_train)

x_txt_train = txt_sscaler.transform(x_txt_train)
x_txt_val = txt_sscaler.transform(x_txt_val)
x_txt_test = txt_sscaler.transform(x_txt_test)

In [8]:
x_img_train_t = torch.tensor(x_img_train).float()
x_img_val_t = torch.tensor(x_img_val).float()
x_img_test_t = torch.tensor(x_img_test).float()

x_txt_train_t = torch.tensor(x_txt_train).float()
x_txt_val_t = torch.tensor(x_txt_val).float()
x_txt_test_t = torch.tensor(x_txt_test).float()

y_train_t = torch.tensor(y_train).float()
y_val_t = torch.tensor(y_val).float()
y_test_t = torch.tensor(y_test).float()

In [9]:
train_ds = TensorDataset(x_img_train_t, x_txt_train_t, y_train_t)
val_ds = TensorDataset(x_img_val_t, x_txt_val_t, y_val_t)
test_ds = TensorDataset(x_img_test_t, x_txt_test_t, y_test_t)

In [11]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

In [22]:
def fit_topics_trident_model_with_weights(
    model, 
    optimizer, 
    train_loader, 
    val_loader=None, 
    scheduler=None, 
    writer=None, 
    epochs=1,
    weight=None
):
    for epoch in range(epochs):
        model.train()

        loss_sum = 0.0
        loss_common_sum = 0.0
        loss_img_sum = 0.0
        loss_txt_sum = 0.0
        loss_count = 0

        for x_img_cur, x_txt_cur, y_cur in train_loader:
            model.zero_grad()
            out_common, out_img, out_txt = model(x_img_cur, x_txt_cur)
            target = torch.argmax(y_cur, dim=1)
            loss_common = F.nll_loss(out_common, target, weight=weight)
            loss_img = F.nll_loss(out_img, target, weight=weight)
            loss_txt = F.nll_loss(out_txt, target, weight=weight)
            loss = (loss_common + loss_img + loss_txt) / 3.0
            loss.backward()

            loss_common_sum += loss_common
            loss_img_sum += loss_img
            loss_txt_sum += loss_txt
            loss_sum += loss
            loss_count += 1

            optimizer.step()
            if scheduler is not None:
                scheduler.step()

        print('epoch:', epoch, 'train_loss:', loss, 'average train loss', loss_sum / loss_count)
        print( 
            'avg common loss:', loss_common_sum / loss_count, 
            'avg img loss:', loss_img_sum / loss_count,
            'avg txt loss:', loss_txt_sum / loss_count
        )
        if writer is not None:
            writer.add_scalar('train_loss', loss, epoch)
            writer.add_scalar('avg_train_loss', loss_sum / loss_count, epoch)
            writer.add_scalar('avg_train_loss_common', loss_common_sum / loss_count, epoch)
            writer.add_scalar('avg_train_loss_img', loss_img_sum / loss_count, epoch)
            writer.add_scalar('avg_train_loss_txt', loss_txt_sum / loss_count, epoch)



        if val_loader is not None:
            model.eval()

            correct_common = 0
            correct_img = 0
            correct_txt = 0
            total = 0
            loss_common_sum = 0.0
            loss_img_sum = 0.0
            loss_txt_sum = 0.0
            loss_sum = 0.0
            loss_count = 0

            with torch.no_grad():
                for x_img_cur, x_txt_cur, y_cur in val_loader:
                    out_common, out_img, out_txt = model(x_img_cur, x_txt_cur)
                    target = torch.argmax(y_cur, dim=1)
                    loss_common = F.nll_loss(out_common, target, weight=weight)
                    loss_img = F.nll_loss(out_img, target, weight=weight)
                    loss_txt = F.nll_loss(out_txt, target, weight=weight)
                    
                    loss = (loss_common + loss_img + loss_txt) / 3.0
                    
                    loss_common_sum += loss_common
                    loss_img_sum += loss_img
                    loss_txt_sum += loss_txt
                    loss_sum += loss
                    
                    loss_count += 1
                    for idx, i in enumerate(out_common):
                        if torch.argmax(i) == target[idx]:
                            correct_common += weight[target[idx]]
                        total += weight[target[idx]]
                    
                    for idx, i in enumerate(out_img):
                        if torch.argmax(i) == target[idx]:
                            correct_img += weight[target[idx]]
                           
                    for idx, i in enumerate(out_txt):
                        if torch.argmax(i) == target[idx]:
                            correct_txt += weight[target[idx]]
                    
            print(
                'val common acc:', correct_common / total,
                'val img acc:', correct_img / total,
                'val txt acc:', correct_txt / total,
                'val_avg_loss:', loss_sum / loss_count)
            print( 
                'avg common val loss:', loss_common_sum / loss_count, 
                'avg img val loss:', loss_img_sum / loss_count,
                'avg txt val loss:', loss_txt_sum / loss_count
            )
            if writer is not None:
                writer.add_scalar('val_acc', correct_common / total, epoch)
                writer.add_scalar('val_img_acc', correct_img / total, epoch)
                writer.add_scalar('val_txt_acc', correct_txt / total, epoch)
                writer.add_scalar('val_avg_loss', loss_sum / loss_count, epoch)

In [19]:
y_train_non_cat = np.argmax(y_train, axis=1)
y_train_weights = sklearn.utils.class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train_non_cat),
    y=y_train_non_cat
)

In [21]:
y_train_weights

array([ 0.49916253,  0.5604    ,  3.96779562,  0.48804812,  2.87308668,
        1.24333943,  0.7834938 ,  1.5602411 ,  1.69342056,  0.93916379,
        0.64867303,  0.75039757,  0.68324284,  0.84224977,  1.03857088,
        1.65123937,  0.45841457,  1.1400755 ,  1.04656912,  1.09506044,
        0.83092021,  1.54340716,  1.45812232,  0.61940292,  0.78462471,
        0.88330842,  0.59892904,  0.82411765,  0.41820896,  1.01340045,
        1.18068636,  0.8171798 ,  0.41406764,  1.77179922,  1.35558105,
        2.47310282,  1.92625089,  3.21269504,  4.90602888,  3.83348378,
        8.44080745,  1.5141727 , 20.28313433,  0.81035778,  4.20083462,
        1.08284462,  4.04455357,  8.44080745,  0.35668504,  2.14687204])

In [27]:
model = torch_models.NormModelTridentBN(drop=0.5)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.0005)
writer = SummaryWriter('runs/trident_bn_bs2048_rs42_d128_wd0005_drop05_weighted')

fit_topics_trident_model_with_weights(
    model=model,
    optimizer=optimizer,
    epochs=1,
    writer=writer,
    train_loader=train_loader,
    val_loader=val_loader,
    weight=torch.tensor(y_train_weights).float()
)

epoch: 0 train_loss: tensor(2.3515, grad_fn=<DivBackward0>) average train loss tensor(2.9298, grad_fn=<DivBackward0>)
avg common loss: tensor(2.9602, grad_fn=<DivBackward0>) avg img loss: tensor(2.8974, grad_fn=<DivBackward0>) avg txt loss: tensor(2.9318, grad_fn=<DivBackward0>)
val common acc: tensor(0.4905) val img acc: tensor(0.3671) val txt acc: tensor(0.4266) val_avg_loss: tensor(2.2432)
avg common val loss: tensor(1.9872) avg img val loss: tensor(2.4622) avg txt val loss: tensor(2.2803)
