# Train Model

In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to the local project modules
# imported below take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from torchvision import transforms
import torch.optim as optim
import numpy as np
import csv
import pickle
import torch

from trainer import fit
from datasets import NUS_WIDE

from networks import TextEmbeddingNet, Resnet152EmbeddingNet, IntermodalTripletNet, Resnet18EmbeddingNet
from losses import InterTripletLoss


In [3]:
### PARAMETERS ###
# -- data ------------------------------------------------------------------
image_data_path = 'data/Flickr'   # passed to NUS_WIDE as `root`
feature_mode = 'resnet18'         # also selects the image embedding net ('resnet152' or 'resnet18')
random_seed = 21                  # seeds NumPy ahead of the train/validation shuffle
batch_size = 128                  # batch size for both DataLoaders

# -- model / loss ----------------------------------------------------------
output_embedding_size = 64        # `dim` given to the text and image embedding nets
margin = 5                        # argument of InterTripletLoss

# -- optimisation ----------------------------------------------------------
n_epochs = 5                      # number of epochs handed to fit()
lr = 1e-3                         # Adam learning rate
##################

In [4]:
# setting up dictionary
# Load the pickled word-embedding object that is later handed to NUS_WIDE as
# `word_embeddings`. The context manager closes the file handle as soon as
# unpickling finishes instead of leaving it open until garbage collection.
# NOTE: pickle.load can execute arbitrary code while loading; only point this
# at pickles produced by this project.
with open("pickles/word_embeddings/word_embeddings_tensors.p", "rb") as embeddings_file:
    text_dictionary = pickle.load(embeddings_file)

In [5]:
# Per-channel statistics given to transforms.Normalize.
# NOTE(review): these look like the usual ImageNet mean/std -- confirm they
# match the weights used by the image embedding nets.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Preprocessing applied to every image, in order:
# resize to 224x224 -> convert to tensor -> normalise per channel.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
image_transform = transforms.Compose(preprocessing_steps)

# Build the dataset with train=True; the validation subset is carved out of
# this same object by index in a later cell.
dataset_options = dict(
    root=image_data_path,
    transform=image_transform,
    feature_mode=feature_mode,
    word_embeddings=text_dictionary,
    train=True,
)
dataset = NUS_WIDE(**dataset_options)

In [6]:
# creating indices for training data and validation data
from torch.utils.data.sampler import SubsetRandomSampler

dataset_size = len(dataset)
validation_split = 0.2  # fraction of the samples held out for validation

indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))

# Seed both random number generators that this pipeline relies on:
#   * NumPy drives the one-off shuffle below, which fixes the train/val split.
#   * torch drives SubsetRandomSampler's per-pass ordering (and any torch-based
#     random initialisation performed in later cells).
# Previously only NumPy was seeded, so the split was repeatable but the batch
# order was not.
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Shuffle every index once; the first `split` entries become the validation
# set and the remainder the training set, so the two never overlap.
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Each sampler yields only the indices of its own subset, in random order.
train_sampler = SubsetRandomSampler(train_indices)
validation_sampler = SubsetRandomSampler(val_indices)

In [7]:
# making loaders
cuda = torch.cuda.is_available()

# Extra DataLoader options are only requested when a GPU is available;
# otherwise the DataLoader defaults are used.
if cuda:
    kwargs = {'num_workers': 32, 'pin_memory': True}
else:
    kwargs = {}

# Both loaders wrap the same dataset object and differ only in the sampler
# that decides which indices they are allowed to draw (train first, then
# validation, matching the unpacking order).
i_triplet_train_loader, i_triplet_val_loader = (
    torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=subset_sampler,
        **kwargs)
    for subset_sampler in (train_sampler, validation_sampler)
)

In [8]:
# Set up the network and training parameters
text_embedding_net = TextEmbeddingNet(dim=output_embedding_size)

# Choose the image branch that matches the configured feature_mode. Both
# branches are built with the same output dimension as the text branch.
if feature_mode == 'resnet152':
    image_embedding_net = Resnet152EmbeddingNet(dim=output_embedding_size)
elif feature_mode == 'resnet18':
    image_embedding_net = Resnet18EmbeddingNet(dim=output_embedding_size)
else:
    # Fail right here: without this branch an unknown mode would leave
    # image_embedding_net undefined (NameError in a later cell) or, worse,
    # silently reuse a stale value left in the kernel by an earlier run.
    raise ValueError(
        "Unsupported feature_mode {!r}; expected 'resnet152' or 'resnet18'".format(feature_mode))

In [9]:
# Combine the image and text embedding branches into a single model.
model = IntermodalTripletNet(image_embedding_net, text_embedding_net)

# Only call .cuda() when a GPU was detected when the loaders were built.
# NOTE(review): presumably IntermodalTripletNet is a torch.nn.Module, in which
# case this moves its parameters to the GPU in place -- confirm in networks.py.
if cuda:
    model.cuda()

In [10]:
# Triplet loss configured with the margin from the parameters cell.
loss_fn = InterTripletLoss(margin)

# Adam over every parameter the model exposes.
optimizer = optim.Adam(model.parameters(), lr=lr)

# Multiply the learning rate by 0.1 every 8 scheduler steps.
# NOTE(review): with n_epochs = 5 a step_size of 8 never triggers a decay if
# fit() steps the scheduler once per epoch -- confirm this is intended.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1)

# Passed to fit() as its last argument.
log_interval = 100

In [11]:
# Run training. Arguments are positional and grouped by role: the two data
# loaders, the dataset's intermodal triplet batch sampler, the model with its
# loss/optimizer/scheduler, and finally the run settings. The call prints the
# average train and validation loss after each epoch.
fit(
    i_triplet_train_loader, i_triplet_val_loader,
    dataset.intermodal_triplet_batch_sampler,
    model, loss_fn, optimizer, scheduler,
    n_epochs, cuda, log_interval,
)

Epoch: 1/5. Train set: Average loss: 7.3463
Epoch: 1/5. Validation set: Average loss: 6.5442
Epoch: 2/5. Train set: Average loss: 5.9937
Epoch: 2/5. Validation set: Average loss: 5.7254
Epoch: 3/5. Train set: Average loss: 5.6108
Epoch: 3/5. Validation set: Average loss: 5.5247
Epoch: 4/5. Train set: Average loss: 5.5171
Epoch: 4/5. Validation set: Average loss: 5.4994
Epoch: 5/5. Train set: Average loss: 5.4212
Epoch: 5/5. Validation set: Average loss: 5.4041


In [12]:
# Completion marker: only reached once the training cell above has returned.
print("Done!")

Done!


In [13]:
# Persist the entire trained model object with pickle. The context manager
# flushes and closes the file before the cell returns; passing a bare open()
# left that to garbage collection, which risks a truncated file if the kernel
# is interrupted or shut down first.
# NOTE(review): unpickling this file requires the model's class definitions to
# be importable; saving model.state_dict() with torch.save may be a more
# portable alternative if downstream loaders can be changed.
with open('pickles/models/entire_nuswide_model_5-18.p', 'wb') as model_file:
    pickle.dump(model, model_file)