In [1]:
# Enable IPython's autoreload extension in mode 2: modules are re-imported
# before each cell executes, so edits to the local project modules are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import time, datetime

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import torch.nn as nn

from tensorflow.keras.callbacks import TensorBoard

In [3]:
import sys

# Register the local project directories with the import system (same order
# as they were appended before) so the project modules can be imported.
sys.path.extend(['./models', './data', './pytorch'])

In [4]:
from data import data
from pytorch import torch_models
from models import models as keras_models

Using TensorFlow backend.


In [5]:
BATCH_SIZE = 2048
TEST_SIZE = 0.2  # fraction held out at each of the two splits below
SEED = 42        # random_state shared by both splits for reproducibility


def _standardize(train, *others):
    """Fit a StandardScaler on `train` only and apply it to every split.

    Fitting on the training split alone keeps validation/test statistics
    out of the scaler.

    Returns a list: the fitted scaler, followed by the transformed `train`
    and the transformed `others`, in the order given.
    """
    scaler = StandardScaler()
    scaler.fit(train)
    return [scaler] + [scaler.transform(split) for split in (train,) + others]


def _as_float_tensor(array):
    """Convert array-like data to a float32 tensor.

    Passing `dtype` directly avoids the intermediate tensor that
    `torch.tensor(array).float()` builds before casting.
    """
    return torch.tensor(array, dtype=torch.float32)


x_img, x_txt, y = data.get_unpacked_data()

# First split: hold out TEST_SIZE of all samples as the test set,
# stratified on the labels.
x_img_train, x_img_test, x_txt_train, x_txt_test, y_train, y_test = train_test_split(
    x_img,
    x_txt,
    y,
    test_size=TEST_SIZE,
    random_state=SEED,
    stratify=y
)

# Second split: hold out TEST_SIZE of the remaining training samples
# as the validation set, again stratified.
x_img_train, x_img_val, x_txt_train, x_txt_val, y_train, y_val = train_test_split(
    x_img_train,
    x_txt_train,
    y_train,
    test_size=TEST_SIZE,
    random_state=SEED,
    stratify=y_train
)

# Standardize each input modality independently with its own scaler.
img_sscaler, x_img_train, x_img_val, x_img_test = _standardize(
    x_img_train, x_img_val, x_img_test
)
txt_sscaler, x_txt_train, x_txt_val, x_txt_test = _standardize(
    x_txt_train, x_txt_val, x_txt_test
)

# Torch copies of every split; the (scaled) arrays above are kept as-is
# because the Keras runs consume them directly.
x_img_train_t, x_img_val_t, x_img_test_t = map(
    _as_float_tensor, (x_img_train, x_img_val, x_img_test)
)
x_txt_train_t, x_txt_val_t, x_txt_test_t = map(
    _as_float_tensor, (x_txt_train, x_txt_val, x_txt_test)
)
y_train_t, y_val_t, y_test_t = map(_as_float_tensor, (y_train, y_val, y_test))

train_ds = TensorDataset(x_img_train_t, x_txt_train_t, y_train_t)
val_ds = TensorDataset(x_img_val_t, x_txt_val_t, y_val_t)
test_ds = TensorDataset(x_img_test_t, x_txt_test_t, y_test_t)

# NOTE(review): train_loader iterates in a fixed order (no shuffle=True),
# whereas Keras' fit() shuffles by default -- confirm this asymmetry is
# intended before using these loaders for multi-epoch training.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

## With validation

### torch

In [8]:
# Benchmark: one training epoch of the PyTorch model with a validation
# loader supplied. Model and optimizer construction is excluded from the timing.
model = torch_models.NormModel()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
start_time = time.perf_counter()
torch_models.fit_topics_model(
    model=model,
    optimizer=optimizer,
    epochs=1,
    train_loader=train_loader,
    val_loader=val_loader
)
torch_time = time.perf_counter() - start_time  # elapsed seconds

epoch: 0 train_loss: tensor(1.7578, grad_fn=<NllLossBackward>) average train loss tensor(2.4707, grad_fn=<DivBackward0>)
val_acc: 0.5511699779249448 val_avg_loss: tensor(1.6586)


In [9]:
# Elapsed seconds for the single-epoch torch run with validation.
print(torch_time)

56.95693612098694


### keras

In [10]:
# Benchmark: one training epoch of the Keras model with validation data
# supplied. Model construction is excluded from the timing.
model_keras = keras_models.get_model_default_lr_wide()

# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
start_time_keras = time.perf_counter()
model_keras.fit(
    [x_img_train, x_txt_train],
    y_train,
    validation_data=([x_img_val, x_txt_val], y_val),
    batch_size=BATCH_SIZE,
    epochs=1
)
keras_time = time.perf_counter() - start_time_keras  # elapsed seconds

Train on 135897 samples, validate on 33975 samples


In [11]:
# Elapsed seconds for the single-epoch Keras run with validation.
print(keras_time)

23.96133828163147


## Without validation

### torch

In [14]:
# Benchmark: one training epoch of a fresh PyTorch model with no validation
# loader (val_loader=None). Model and optimizer construction is excluded
# from the timing.
model = torch_models.NormModel()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
start_time = time.perf_counter()
torch_models.fit_topics_model(
    model=model,
    optimizer=optimizer,
    epochs=1,
    train_loader=train_loader,
    val_loader=None
)
torch_time_no_val = time.perf_counter() - start_time  # elapsed seconds

epoch: 0 train_loss: tensor(1.7582, grad_fn=<NllLossBackward>) average train loss tensor(2.4760, grad_fn=<DivBackward0>)


In [15]:
# Elapsed seconds for the single-epoch torch run without validation.
print(torch_time_no_val)

27.76319408416748


### keras

In [12]:
# Benchmark: one training epoch of a fresh Keras model with no validation
# data. Model construction is excluded from the timing.
model_keras = keras_models.get_model_default_lr_wide()

# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
start_time = time.perf_counter()
model_keras.fit(
    [x_img_train, x_txt_train],
    y_train,
    epochs=1,
    # Use the shared constant (not a literal 2048) so every benchmarked run
    # is guaranteed to train with the same batch size.
    batch_size=BATCH_SIZE
)
keras_time_no_val = time.perf_counter() - start_time  # elapsed seconds

Train on 135897 samples


In [13]:
# Elapsed seconds for the single-epoch Keras run without validation.
print(keras_time_no_val)

21.78543186187744
