In [1]:
import numpy as np
import pandas as pd
from PIL import Image

from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler

In [4]:
# Load the DeepFashion category CSV into a DataFrame for inspection.
data_df = pd.read_csv('DeepFashion1/deepfashion1_categoryData.csv')

In [5]:
# Preview the first rows to check the column layout.
data_df.head()

Unnamed: 0,images,category_label,dataset,category,label
0,img/Sheer_Pleated-Front_Blouse/img_00000001.jpg,3,train,Blouse,2
1,img/Sheer_Pleated-Front_Blouse/img_00000002.jpg,3,train,Blouse,2
2,img/Sheer_Pleated-Front_Blouse/img_00000003.jpg,3,val,Blouse,2
3,img/Sheer_Pleated-Front_Blouse/img_00000004.jpg,3,train,Blouse,2
4,img/Sheer_Pleated-Front_Blouse/img_00000005.jpg,3,test,Blouse,2


In [204]:
class DeepFashionDataset():
    """
    One split ('train', 'val' or 'test') of the DeepFashion category CSV.

    Attributes set by the constructor:
        dataset_type: the requested split name.
        train: True only when the split is 'train'.
        alldata: the full CSV as a DataFrame.
        data: rows of the requested split, restricted to the
            'images' and 'label' columns.
        transform: whatever was passed as `transforms` (may be None).
    """

    def __init__(self, filepath, dataset_type, transforms=None):
        # Reject unknown split names before touching the file.
        assert(dataset_type in ['train', 'val', 'test'])

        self.dataset_type = dataset_type
        self.train = self.dataset_type == 'train'

        self.alldata = pd.read_csv(filepath)
        # Boolean mask selecting the rows that belong to this split.
        in_split = self.alldata.dataset == self.dataset_type
        self.data = self.alldata[in_split][['images', 'label']]

        self.transform = transforms

In [211]:
class SiameseDeepFashion(Dataset):
    """
    Siamese-pair view over one split of a DeepFashionDataset.

    Train: for each sample, randomly creates a positive (same label) or a
    negative (different label) pair on every access.
    Test: creates fixed pairs once, at construction time, from a seeded RNG.

    Every item is ``((img1, img2), target)`` where ``target`` is 1 for a
    same-label pair and 0 for a different-label pair.

    Arguments:
        deepfashion_dataset: object exposing ``train`` (bool), ``transform``
            (callable or None) and ``data`` (DataFrame with ``images`` and
            ``label`` columns), e.g. a DeepFashionDataset.
        dataset_folder_name: folder prefixed (with '/') to every image path.

    Raises:
        ValueError: from numpy, when a negative pair is needed but the split
            contains a single label (there is nothing to sample from).
    """

    # Seed of the private RNG that builds the fixed evaluation pairs.
    TEST_PAIR_SEED = 29

    def __init__(self, deepfashion_dataset, dataset_folder_name):

        self.datafolder = dataset_folder_name
        self.deepfashion_dataset = deepfashion_dataset

        self.train = self.deepfashion_dataset.train
        self.transform = self.deepfashion_dataset.transform

        labels = self.deepfashion_dataset.data.label.values
        images = self.deepfashion_dataset.data.images.values
        self.labels_set = set(labels)
        # label -> positions (into `labels` / `images`) of samples with that label
        self.label_to_indices = {label: np.where(labels == label)[0]
                                 for label in self.labels_set}

        if self.train:
            self.train_labels = labels
            self.train_data = images
        else:
            self.test_labels = labels
            self.test_data = images
            # generate fixed pairs for testing
            self.test_pairs = self._build_fixed_test_pairs()

    def _other_labels(self, label):
        """Return the labels different from `label` (negative-pair candidates)."""
        return list(self.labels_set - set([label]))

    def _build_fixed_test_pairs(self):
        """Build the `[first_index, second_index, target]` evaluation pairs.

        Even positions become positive pairs, odd positions negative pairs.
        Every random draw goes through one seeded RandomState so the pairs
        are identical across runs; drawing the negative label from the global
        `np.random` state (as was done before) made them change per run.
        """
        random_state = np.random.RandomState(self.TEST_PAIR_SEED)

        positive_pairs = [[i,
                           random_state.choice(self.label_to_indices[self.test_labels[i].item()]),
                           1]
                          for i in range(0, len(self.test_data), 2)]

        negative_pairs = [[i,
                           random_state.choice(self.label_to_indices[
                               random_state.choice(
                                   self._other_labels(self.test_labels[i].item())
                               )
                           ]),
                           0]
                          for i in range(1, len(self.test_data), 2)]
        return positive_pairs + negative_pairs

    def _sample_train_partner(self, index):
        """Pick the second element of a training pair for sample `index`.

        Returns `(partner_index, target)`. A positive pair never pairs the
        sample with itself. If the sample is the only member of its class no
        positive partner exists, so a negative pair is produced instead of
        looping forever.
        """
        target = np.random.randint(0, 2)
        label = self.train_labels[index].item()
        same_label = self.label_to_indices[label]

        if target == 1 and len(same_label) < 2:
            # Singleton class: the rejection loop below could never terminate.
            target = 0

        if target == 1:
            partner = index
            while partner == index:
                partner = np.random.choice(same_label)
        else:
            other_label = np.random.choice(self._other_labels(label))
            partner = np.random.choice(self.label_to_indices[other_label])
        return partner, target

    def __getitem__(self, index):
        """Return `((img1, img2), target)` for position `index`."""
        if self.train:
            siamese_index, target = self._sample_train_partner(index)
            img1 = self.train_data[index]
            img2 = self.train_data[siamese_index]
        else:
            first, second, target = self.test_pairs[index]
            img1 = self.test_data[first]
            img2 = self.test_data[second]

        img1 = Image.open(self.datafolder+'/'+img1)
        img2 = Image.open(self.datafolder+'/'+img2)
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return (img1, img2), target

    def __len__(self):
        """Number of samples in the wrapped split (one pair per sample)."""
        return len(self.deepfashion_dataset.data)

In [212]:
# Metadata CSV shared by both splits.
category_csv = 'DeepFashion1/deepfashion1_categoryData.csv'

# Resize((224, 224)) is left disabled, so ToTensor is the only transform applied.
# NOTE(review): `transforms` is not imported in the visible cells; presumably
# `from torchvision import transforms` -- confirm it runs before this cell.
train_transform = transforms.Compose([transforms.ToTensor()])

train_data = DeepFashionDataset(category_csv, 'train', train_transform)
# No transform is passed for the test split, so its `transform` stays None.
test_data = DeepFashionDataset(category_csv, 'test')

In [213]:
# (rows, columns) of the filtered train and test frames.
train_data.data.shape, test_data.data.shape

((209222, 2), (40000, 2))

In [214]:
# Wrap each split in the Siamese pair dataset. Despite the `_loader` names,
# these are SiameseDeepFashion (Dataset) objects, not DataLoader instances.
train_loader = SiameseDeepFashion(train_data, 'DeepFashion1')
test_loader = SiameseDeepFashion(test_data, 'DeepFashion1')

In [215]:
# Fetch one training pair through the indexing protocol as a smoke test.
train_loader[1]

((tensor([[[0.9412, 0.9412, 0.9412,  ..., 0.9137, 0.9098, 0.9098],
           [0.9412, 0.9412, 0.9412,  ..., 0.9137, 0.9098, 0.9098],
           [0.9412, 0.9412, 0.9412,  ..., 0.9137, 0.9137, 0.9098],
           ...,
           [0.9490, 0.9490, 0.9490,  ..., 0.9333, 0.9294, 0.9294],
           [0.9490, 0.9490, 0.9490,  ..., 0.9333, 0.9333, 0.9294],
           [0.9490, 0.9490, 0.9490,  ..., 0.9333, 0.9333, 0.9333]],
  
          [[0.9412, 0.9412, 0.9412,  ..., 0.9176, 0.9137, 0.9137],
           [0.9412, 0.9412, 0.9412,  ..., 0.9176, 0.9137, 0.9137],
           [0.9412, 0.9412, 0.9412,  ..., 0.9176, 0.9176, 0.9137],
           ...,
           [0.9490, 0.9490, 0.9490,  ..., 0.9333, 0.9294, 0.9294],
           [0.9490, 0.9490, 0.9490,  ..., 0.9333, 0.9333, 0.9294],
           [0.9490, 0.9490, 0.9490,  ..., 0.9333, 0.9333, 0.9333]],
  
          [[0.9490, 0.9490, 0.9490,  ..., 0.9255, 0.9216, 0.9216],
           [0.9490, 0.9490, 0.9490,  ..., 0.9255, 0.9216, 0.9216],
           [0.9490, 0.

In [216]:
# Number of pairs served per epoch (one per sample in the train split).
len(train_loader)

209222