In [90]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.autograd import Variable
from torchvision.datasets import ImageFolder
from torch import optim
import torch.nn.functional as F
import random
from torchvision import transforms

In [88]:
class SiameseNetworkDataset(Dataset):
    
    def __init__(self,imageFolderDataset,transform=None, target_transform=None, should_invert=True):
        self.imageFolderDataset = imageFolderDataset    
        self.transform = transform
        self.should_invert = should_invert
        self.target_transform = target_transform
        
    def __getitem__(self,index):
        img0, lbl0 = random.choice(self.imageFolderDataset)
        #we need to make sure approx 50% of images are in the same class
        should_get_same_class = random.randint(0,1) 
        if should_get_same_class:
            while True:
                #keep looping till the same class image is found
                img1, lbl1 = random.choice(self.imageFolderDataset) 
                if lbl0 == lbl1:
                    break
        else:
            img1, lbl1 = random.choice(self.imageFolderDataset)
        
        if self.should_invert:
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)
            
        if self.target_transform:
            lbl0, lbl1 = map(self.target_transform, (lbl0, lbl1))
        
        return img0, img1 , lbl0 == lbl1
    
    def __len__(self):
        return len(self.imageFolderDataset.imgs)
    
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(3, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),
            nn.Dropout2d(p=.2),
            
            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),
        )

        self.fc1 = nn.Sequential(
            nn.Linear(401408, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 5)
        )

    def forward_once(self, x):
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2
    
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))


        return loss_contrastive


In [92]:
trainfolder = ImageFolder('dataset/train/facescrub')
validfolder = ImageFolder('dataset/valid/facescrub')

trfm_valid = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

trfm_train = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

target_trfrm = transforms.Compose([
    lambda x: [x],
    torch.Tensor
])

trainset = SiameseNetworkDataset(trainfolder, transform=trfm_train, target_transform=target_trfrm)
validset = SiameseNetworkDataset(trainfolder, transform=trfm_valid, target_transform=target_trfrm)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
validloader = DataLoader(validset, batch_size=32, shuffle=True)
len(trainloader), len(validloader)

(230, 230)

In [93]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = SiameseNetwork().to(device)
net = nn.DataParallel(net)
criterion = ContrastiveLoss().to(device)
optimizer = optim.Adam(net.parameters(),lr = 0.0005)

counter = []
loss_history = [] 

RuntimeError: CUDA out of memory. Tried to allocate 765.62 MiB (GPU 0; 11.17 GiB total capacity; 7.26 GiB already allocated; 550.06 MiB free; 370.13 MiB cached)

In [91]:
iteration_number= 0

for epoch in range(0,10):
    for i, data in enumerate(trainloader,0):
        img0, img1 , label = map(lambda i: i.to(device), data)
        output1,output2 = net(img0,img1)
        optimizer.zero_grad()
        loss_contrastive = criterion(output1,output2,label)
        loss_contrastive.backward()
        optimizer.step()
        if i %10 == 0 :
            print("Epoch number {}\n Current loss {}\n".format(epoch,loss_contrastive.data[0]))
            iteration_number +=10
            counter.append(iteration_number)
            loss_history.append(loss_contrastive.data[0])

RuntimeError: CUDA out of memory. Tried to allocate 765.62 MiB (GPU 0; 11.17 GiB total capacity; 7.26 GiB already allocated; 550.06 MiB free; 370.15 MiB cached)