In [1]:
%matplotlib inline

import os
from copy import deepcopy

import torch
import numpy as np
from torch import optim
import torchvision.utils
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,random_split

import config
from utils import imshow
from models import SiameseNetwork
from training import trainSiamese,inferenceSiamese
from datasets import SiameseNetworkDataset
from loss_functions import ContrastiveLoss

# generate_csv(config.training_dir)

# Make sure the output directory for saved model weights exists.
# NOTE(review): presumably trainSiamese writes its checkpoints here — confirm.
# exist_ok=True avoids the check-then-create race and is a no-op if the
# directory is already present.
os.makedirs('state_dict', exist_ok=True)

In [None]:
# Sweep the ContrastiveLoss margin: train one fresh SiameseNetwork per margin
# and save its train/validation loss curves to "<margin>.eps".
margins = np.arange(1, 7, 2)  # -> [1, 3, 5]

# Seed torch so the random train/validation split below is reproducible.
# NOTE(review): the dataset may draw pairs from other RNGs (random/numpy),
# which are not seeded here.
torch.manual_seed(0)

# The dataset, split and loaders do not depend on the margin, so build them
# once. Every margin is then trained and validated on the same split, which
# keeps the loss curves comparable across margins.
siamese_dataset = SiameseNetworkDataset(config.siamese_training_csv,
                                        transform=transforms.Compose([
                                            transforms.Resize((config.img_height, config.img_width)),
                                            transforms.ToTensor(),
                                            # NOTE(review): mean=0, std=1 leaves the tensor unchanged;
                                            # confirm whether dataset statistics were intended here.
                                            transforms.Normalize(0, 1)]),
                                        should_invert=False)

# 90% train / 10% validation split.
num_train = round(0.9 * len(siamese_dataset))
num_validate = len(siamese_dataset) - num_train
siamese_train, siamese_valid = random_split(siamese_dataset, [num_train, num_validate])

train_dataloader = DataLoader(siamese_train,
                              shuffle=True,
                              num_workers=8,
                              batch_size=config.train_batch_size)
valid_dataloader = DataLoader(siamese_valid,
                              shuffle=True,
                              num_workers=8,
                              batch_size=1)

for marg in margins:
    # Fresh model/optimizer/scheduler per margin so runs are independent.
    net = SiameseNetwork().cuda()
    criterion = ContrastiveLoss(margin=marg)
    optimizer = optim.Adam(net.parameters(), lr=config.learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, config.step_size, config.gamma)

    net, train_loss_history, valid_loss_history, dict_name = trainSiamese(
        net, criterion, optimizer, scheduler, train_dataloader,
        valid_dataloader, config.train_number_epochs, do_print=False)

    # Plot both loss curves on a log scale and save one figure per margin.
    fig, ax = plt.subplots()
    ax.plot(train_loss_history, label="train_loss")
    ax.plot(valid_loss_history, label="valid_loss")
    ax.legend()
    ax.set_yscale("log")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Contrastive loss")
    ax.set_title("margin = {}".format(int(marg)))
    fig.savefig(str(int(marg)) + '.eps', format='eps', dpi=1200)
    # Render now and release the figure instead of accumulating open figures
    # across iterations.
    plt.show()