In [1]:
import torch
import torch.nn as nn
import torch.optim as optim 
from torch.utils.data import DataLoader
from torch.autograd import Variable

from training_utils import train
from layers import Flatten
from read_in_data import AmazonDataset

In [None]:
## Configuration: everything in this notebook runs on the CPU.
dtype = torch.FloatTensor  # tensor type that inputs, the model and the loss are cast to

## Data sources and checkpoint destination, relative to this notebook.
csv_path = '../../data/train_v2.csv'
img_path = '../../data/train-jpg'
save_model_path = "model_state_dict.pkl"

## Training data, served in shuffled mini-batches of 256 by a single worker process.
training_dataset = AmazonDataset(csv_path, img_path, dtype)
train_loader = DataLoader(training_dataset, batch_size=256, shuffle=True, num_workers=1)
# When running on CUDA, pin_memory=True can additionally be passed to DataLoader.
## Convolutional feature extractor followed by a fully connected multi-label head.
def make_feature_layers():
    """Build fresh (untrained) instances of the convolutional feature-extractor layers.

    Returns:
        list of nn.Module: (Conv2d -> ReLU -> BatchNorm2d -> AdaptiveMaxPool2d) twice,
        followed by ``Flatten``. The list is used both for the shape probe and for the
        real model, so the two stacks cannot drift apart.

    The first Conv2d expects 4 input channels. Both pooling layers are adaptive, so the
    spatial size leaving the second pool is fixed at 64x64 whatever the input image size.
    """
    return [
        nn.Conv2d(4, 16, kernel_size=3, stride=1),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(16),
        nn.AdaptiveMaxPool2d(128),
        nn.Conv2d(16, 32, kernel_size=3, stride=1),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(32),
        nn.AdaptiveMaxPool2d(64),
        Flatten(),
    ]


# Infer the flattened feature width (in_features of the first Linear layer) by pushing
# one image from the first batch through a throwaway copy of the extractor. Only the
# output shape is used; the probe and its random weights are discarded afterwards.
feature_probe = nn.Sequential(*make_feature_layers()).type(dtype)
feature_probe.eval()  # shape probe only: don't update BatchNorm running statistics
x_probe, _ = next(iter(train_loader))
size = feature_probe(Variable(x_probe[:1].type(dtype))).size()
del feature_probe, x_probe

# Classifier head: size[1] flattened features -> 1024 hidden units -> 17 outputs,
# one score per label for MultiLabelSoftMarginLoss below.
classifier_layers = [
    nn.Linear(size[1], 1024),
    nn.ReLU(inplace=True),
    nn.Linear(1024, 17),
]
# List concatenation keeps the Sequential indices (0..11) identical to the original
# hand-written stack, so saved state_dict keys remain compatible.
model = nn.Sequential(*(make_feature_layers() + classifier_layers))

model.type(dtype)
model.train()
loss_fn = nn.MultiLabelSoftMarginLoss().type(dtype)

# NOTE(review): 5e-2 is 50x Adam's documented default (1e-3); confirm this is intended.
learning_rate = 5e-2
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Run one epoch of training, reporting the loss after every iteration.
train(
    train_loader,
    model,
    loss_fn,
    optimizer,
    dtype,
    num_epochs=1,
    print_every=1
)

# Write the learned weights to disk, then read the checkpoint back into the model
# so the save/load round trip is exercised immediately.
torch.save(model.state_dict(), save_model_path)
state_dict = torch.load(save_model_path)
model.load_state_dict(state_dict)

Starting epoch 1 / 1
Variable containing:
( 0 , 0 ,.,.) = 
  0.7373  0.7333  0.7412  ...   0.8275  0.8196  0.8157
  0.7569  0.7569  0.7647  ...   0.8196  0.8157  0.8157
  0.7765  0.7765  0.7843  ...   0.8078  0.8078  0.8157
           ...             ⋱             ...          
  0.8196  0.8275  0.8353  ...   0.8627  0.8627  0.8627
  0.8118  0.8196  0.8235  ...   0.8549  0.8549  0.8549
  0.8039  0.8078  0.8118  ...   0.8471  0.8510  0.8510

( 0 , 1 ,.,.) = 
  0.6980  0.7020  0.7137  ...   0.7686  0.7647  0.7608
  0.7059  0.7137  0.7255  ...   0.7608  0.7569  0.7608
  0.7137  0.7216  0.7333  ...   0.7490  0.7529  0.7529
           ...             ⋱             ...          
  0.7608  0.7686  0.7765  ...   0.8196  0.8157  0.8118
  0.7529  0.7608  0.7686  ...   0.8118  0.8078  0.8039
  0.7451  0.7529  0.7608  ...   0.8078  0.8039  0.7961

( 0 , 2 ,.,.) = 
  0.7412  0.7529  0.7686  ...   0.8000  0.7882  0.7843
  0.7569  0.7686  0.7804  ...   0.7961  0.7882  0.7843
  0.7804  0.7882  0.7961 

t = 1, loss = 0.7118
Variable containing:
( 0 , 0 ,.,.) = 
  0.6706  0.6745  0.6784  ...   0.7098  0.7020  0.6902
  0.6745  0.6745  0.6745  ...   0.7137  0.7059  0.6980
  0.6784  0.6745  0.6745  ...   0.7216  0.7137  0.7098
           ...             ⋱             ...          
  0.6549  0.6745  0.6863  ...   0.7373  0.7373  0.7412
  0.6667  0.6863  0.6941  ...   0.7294  0.7333  0.7333
  0.6745  0.6941  0.7020  ...   0.7255  0.7294  0.7294

( 0 , 1 ,.,.) = 
  0.6275  0.6275  0.6314  ...   0.6824  0.6667  0.6471
  0.6275  0.6275  0.6314  ...   0.6784  0.6667  0.6549
  0.6314  0.6275  0.6275  ...   0.6706  0.6667  0.6667
           ...             ⋱             ...          
  0.6275  0.6431  0.6510  ...   0.7020  0.7059  0.7098
  0.6314  0.6510  0.6510  ...   0.6980  0.7020  0.7059
  0.6353  0.6549  0.6549  ...   0.6902  0.6980  0.7020

( 0 , 2 ,.,.) = 
  0.6471  0.6549  0.6627  ...   0.6745  0.6627  0.6549
  0.6510  0.6549  0.6588  ...   0.6745  0.6627  0.6549
  0.6549  0.6549  0.6549 