### Explore Siamese Dataset

In [3]:
import numpy as np
from PIL import Image

from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler

In [4]:
class SiameseMNIST(Dataset):
    """
    Wrap an MNIST-style dataset so that every item is a pair of images.

    Train: each ``__getitem__`` call flips a fair coin and returns either a
    positive pair (a *different* sample with the same label, target 1) or a
    negative pair (a sample drawn from a different label, target 0).

    Test: a fixed list of pairs is generated once in ``__init__`` from a seeded
    ``np.random.RandomState`` so evaluation is reproducible. Samples at even
    indices are paired positively, samples at odd indices negatively.

    Args:
        mnist_dataset: the wrapped dataset. It must expose ``train``,
            ``transform`` and either ``train_labels``/``train_data`` (train mode)
            or ``test_labels``/``test_data`` (test mode).

    Items are ``((img1, img2), target)``.
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset

        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        if self.train:
            self.train_labels = self.mnist_dataset.train_labels
            self.train_data = self.mnist_dataset.train_data
            self.labels_set, self.label_to_indices = self._index_labels(self.train_labels)
        else:
            # generate fixed pairs for testing
            self.test_labels = self.mnist_dataset.test_labels
            self.test_data = self.mnist_dataset.test_data
            self.labels_set, self.label_to_indices = self._index_labels(self.test_labels)
            self.test_pairs = self._make_test_pairs()

    @staticmethod
    def _index_labels(labels):
        """Return ``(labels_set, {label: np.ndarray of sample indices with that label})``."""
        labels_np = labels.numpy()  # convert once instead of once per use
        labels_set = set(labels_np)
        label_to_indices = {label: np.where(labels_np == label)[0]
                            for label in labels_set}
        return labels_set, label_to_indices

    def _make_test_pairs(self, seed=29):
        """Build the fixed test pairs as ``[index1, index2, target]`` lists."""
        # Every random draw goes through this single seeded generator. The
        # negative-label choice previously used the global ``np.random``, which
        # made the supposedly fixed test pairs differ between runs.
        random_state = np.random.RandomState(seed)
        n_samples = len(self.test_data)

        positive_pairs = [[i,
                           random_state.choice(self.label_to_indices[self.test_labels[i].item()]),
                           1]
                          for i in range(0, n_samples, 2)]

        negative_pairs = []
        for i in range(1, n_samples, 2):
            label = self.test_labels[i].item()
            # sorted() pins the candidate order so the seeded draw is stable
            other_label = random_state.choice(sorted(self.labels_set - {label}))
            negative_pairs.append([i,
                                   random_state.choice(self.label_to_indices[other_label]),
                                   0])
        return positive_pairs + negative_pairs

    def __getitem__(self, index):
        """Return ``((img1, img2), target)``; target is 1 for a same-label pair, 0 otherwise."""
        if self.train:
            target = np.random.randint(0, 2)  # half-open interval [0, 2): 0 or 1
            img1, label1 = self.train_data[index], self.train_labels[index].item()

            if target == 1:  # positive pair: a different sample with the same label
                same_label_indices = self.label_to_indices[label1]
                if len(same_label_indices) < 2:
                    # Without this check the rejection loop below never terminates.
                    raise ValueError(
                        f"cannot build a positive pair for label {label1}: "
                        "it has only one sample")
                siamese_index = index
                while siamese_index == index:
                    siamese_index = np.random.choice(same_label_indices)
            else:  # negative pair: any sample carrying a different label
                siamese_label = np.random.choice(list(self.labels_set - set([label1])))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])
            img2 = self.train_data[siamese_index]
        else:
            index1, index2, target = self.test_pairs[index]
            img1 = self.test_data[index1]
            img2 = self.test_data[index2]

        # mode='L' is 8-bit grayscale; presumably the stored images are uint8
        # (H, W) tensors as in torchvision MNIST -- confirm for other datasets.
        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return (img1, img2), target

    def __len__(self):
        """Number of pairs; equals the length of the wrapped dataset."""
        return len(self.mnist_dataset)

In [5]:
import torch
from torchvision.datasets import FashionMNIST
from torchvision import transforms

# Normalisation constants for the single grayscale channel (presumably the
# pixel mean/std of the FashionMNIST training split after ToTensor -- confirm).
mean, std = 0.28604059698879553, 0.35302424451492237
# NOTE(review): batch_size is not used by any of the cells shown here.
batch_size = 256

# Training split: downloaded on first use, converted to tensors, then normalised.
train_dataset = FashionMNIST(
    root='../data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean=(mean,), std=(std,))]
    ),
)

In [6]:
# Pair-producing view over the FashionMNIST training split defined above.
Siamese_dataset = SiameseMNIST(mnist_dataset=train_dataset)



In [7]:
# Distinct class labels found in the wrapped training labels (shown as cell output).
Siamese_dataset.labels_set

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}

In [8]:
# Mapping label -> numpy array of training-sample indices that carry that label.
Siamese_dataset.label_to_indices

{0: array([    1,     2,     4, ..., 59974, 59985, 59998]),
 1: array([   16,    21,    38, ..., 59989, 59991, 59996]),
 2: array([    5,     7,    27, ..., 59977, 59981, 59993]),
 3: array([    3,    20,    25, ..., 59971, 59980, 59997]),
 4: array([   19,    22,    24, ..., 59984, 59986, 59990]),
 5: array([    8,     9,    12, ..., 59983, 59995, 59999]),
 6: array([   18,    32,    33, ..., 59973, 59987, 59988]),
 7: array([    6,    14,    41, ..., 59951, 59979, 59992]),
 8: array([   23,    35,    57, ..., 59962, 59967, 59994]),
 9: array([    0,    11,    15, ..., 59932, 59970, 59978])}

### Stepping through `__getitem__` by hand

In [10]:
# Coin flip as in __getitem__: high is exclusive, so the result is 0 or 1.
target = np.random.randint(low=0, high=2)

In [14]:
# Mirror the first step of __getitem__ for sample 0: its image and plain-int label.
img1 = Siamese_dataset.train_data[0]
label1 = Siamese_dataset.train_labels[0].item()

In [15]:
# Inspect the coin flip drawn above: 1 -> positive pair, 0 -> negative pair.
target

0