In [61]:
import numpy as np
import torch 
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import random
import os

In [3]:
class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),
            nn.Dropout2d(p=.2),
            
            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=.2),
        )

        self.fc1 = nn.Sequential(
            nn.Linear(8*100*100, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 5)
        )

    def forward_once(self, x):
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2):
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2

In [4]:
siamese=SiameseNetwork()

In [16]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))


        return loss_contrastive

In [35]:
! ls ../dataset_iiit/valid/
! rm ../dataset_iiit/valid/.DS_Store

[34mcategory1[m[m [34mcategory2[m[m [34mcategory3[m[m [34mcategory4[m[m [34mcategory5[m[m [34mcategory6[m[m [34mcategory7[m[m [34mcategory8[m[m
rm: ../dataset_iiit/valid/.DS_Store: No such file or directory


In [117]:
class get_data(Dataset):
    
    def __init__(self,path):
        super(get_data,self).__init__()
        self.path=path
        self.files=os.listdir(path)
    def __getitem__(self,idx):
        same_class=random.randint(0,1)
        if same_class==0:
            cat=np.random.choice(self.files,1)[0]
            imag1,imag2=np.random.choice(os.listdir(f'{self.path}{cat}'),2,replace=False)
            label=0
        else:
            cat=np.random.choice(self.files,2,replace=False)
            imag1,imag2=np.random.choice(os.listdir(f'{self.path}{cat[0]}'),1)[0],np.random.choice(os.listdir(f'{self.path}{cat[1]}'),1)[0]
            label=1
        return imag1,imag2,label
    def __len__(self):
        return 500

In [118]:
x=get_data('../dataset_iiit/valid/')

In [116]:
x[0]

('n0007_0000180.jpg', 'n0004_0000698.jpg', 1)