In [None]:
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torch
import os
import random

In [None]:
import numpy as np
# Sanity check of the np.nonzero idiom used by TripletDataset to collect the
# indices of each class: for a 1-D boolean mask it returns a 1-tuple of index
# arrays (one array per dimension), hence the trailing `[0]` there.
a = np.arange(10)
print(np.nonzero(a == 5))

In [None]:
class TripletDataset(Dataset):
    """Dataset that on each iteration provides an anchor image,
    a positive match, and a negative match, in that order.

    Triplets are sampled once, at construction time. Every image of the base
    dataset (in a class that has images) appears exactly once as an anchor.
    It is paired with a random other image of the same class (the positive)
    and a random image of a different class (the negative). The same
    triplets are reused on every epoch.

    Args:
        dataset: an ``ImageFolder``-like dataset exposing ``imgs`` (a list of
            ``(path, class_index)`` pairs), ``classes``, ``__len__`` and
            ``__getitem__`` returning ``(image_tensor, label)``.
        seed: optional seed for a private ``random.Random``. When ``None``
            (the default) the global ``random`` module is used, so
            ``random.seed(...)`` still controls sampling.

    Raises:
        ValueError: if the dataset has images but fewer than two non-empty
            classes, so that no negative can be drawn.
    """

    def __init__(self, dataset, seed=None):
        self.base_dataset = dataset
        rng = random if seed is None else random.Random(seed)

        # Base-dataset indices belonging to each class.
        image_classes = np.array([label for _, label in self.base_dataset.imgs], dtype=int)
        num_classes = len(self.base_dataset.classes)
        class_indices = [np.nonzero(image_classes == c)[0] for c in range(num_classes)]

        # Only classes that actually contain images can supply anchors or
        # negatives. Drawing from an empty class would crash in randint.
        nonempty = [c for c in range(num_classes) if len(class_indices[c]) > 0]
        if len(image_classes) > 0 and len(nonempty) < 2:
            raise ValueError(
                "TripletDataset needs images from at least two classes to "
                "sample negatives; got %d non-empty class(es)" % len(nonempty))

        triplets = []
        for rank, c in enumerate(nonempty):
            members = class_indices[c]
            num_examples = len(members)
            for j in range(num_examples):
                # Negative class: uniform over the *other* non-empty classes.
                # Draw from n-1 slots and step over this class's own slot.
                neg_rank = rng.randint(0, len(nonempty) - 2)
                if neg_rank >= rank:
                    neg_rank += 1
                neg_members = class_indices[nonempty[neg_rank]]
                neg_idx = rng.randint(0, len(neg_members) - 1)

                # Positive: a different image of the same class, drawn the
                # same way. A class with a single image has no other
                # candidate, so the anchor doubles as its own positive
                # instead of crashing on randint(0, -1).
                if num_examples > 1:
                    pos_idx = rng.randint(0, num_examples - 2)
                    if pos_idx >= j:
                        pos_idx += 1
                else:
                    pos_idx = j

                triplets.append((members[j], members[pos_idx], neg_members[neg_idx]))

        # One (anchor, positive, negative) row of base-dataset indices per anchor.
        self.trip_indexes = np.array(triplets, dtype=int).reshape(-1, 3)

    def __getitem__(self, index):
        """Return the triplet at ``index``.

        Returns:
            A tuple ``(images, target)``. ``images`` stacks the anchor,
            positive and negative image tensors along a new leading
            dimension (``3 x C x H x W``). ``target`` is the constant tensor
            ``[1, 0]``: the positive is a match (1), the negative is not (0).
        """
        images = [self.base_dataset[int(idx)][0] for idx in self.trip_indexes[index]]
        return torch.stack(images, dim=0), torch.tensor([1, 0])

    def __len__(self):
        """Number of triplets (one per anchor image)."""
        return len(self.trip_indexes)
    
# Model input resolution in pixels. Resize(int) scales the shorter image side
# to this size, and CenterCrop then cuts a square of this size.
input_size = 224

# Root folder expected to contain `train/` and `val/` ImageFolder trees.
# Set the TRIPLET_DATA_DIR environment variable to point elsewhere instead of
# editing this absolute, user-specific default.
data_dir = os.environ.get("TRIPLET_DATA_DIR",
                          "/home/kylecshan/data/images224/train_ms2000_v3/")

# Standard ImageNet per-channel normalization statistics (what torchvision's
# pretrained backbones expect), shared by both splits.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        # Only augmentation: random left/right flip, training split only.
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ]),
    'val': transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ]),
}




In [None]:
# Number of triplets per batch. Each item is a stack of 3 images, so one
# batch holds 3 * batch_size images.
batch_size = 256
phases = ['train', 'val']

# Raw ImageFolder datasets, one per split under data_dir.
folder_datasets = {phase: datasets.ImageFolder(os.path.join(data_dir, phase),
                                               data_transforms[phase])
                   for phase in phases}

# Wrap each split so that every item is an (anchor, positive, negative)
# triplet. This builds a new dict instead of overwriting the ImageFolder
# entries in place, so the raw datasets stay available (e.g. for `.classes`).
image_datasets = {phase: TripletDataset(folder_datasets[phase]) for phase in phases}

# Create training and validation dataloaders. Validation is shuffled as well:
# triplets are stored grouped by class, so an unshuffled partial pass over
# val would only ever see the first few classes.
dataloaders_dict = {phase: torch.utils.data.DataLoader(image_datasets[phase],
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       num_workers=4)
                    for phase in phases}



In [None]:
# Smoke-test the loader by pulling a single batch and checking its shapes,
# rather than iterating (and decoding every image of) the whole epoch.
# Expected: inputs (B, 3, C, H, W), i.e. B stacked triplets, and labels (B, 2).
inputs, labels = next(iter(dataloaders_dict['train']))
print(inputs.shape, labels.shape)

In [None]:
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F
import torch

class Pool(nn.Module):
    """Global average pooling followed by flattening to ``(N, C)``.

    Used (after a ReLU) between the DenseNet feature extractor and its
    classifier when the model is rebuilt as an ``nn.Sequential``.

    Args:
        batch: retained only for backward compatibility and no longer used.
            The output batch dimension is taken from the input, so any batch
            size works.
    """

    def __init__(self, batch = 8):
        super(Pool, self).__init__()
        self.batch = batch

    def forward(self, x):
        # Use the runtime batch size. The previous `view(self.batch, -1)` only
        # worked for batches of exactly 8. Other sizes either raised or, when
        # the element count was divisible by 8, were silently mis-shaped.
        return F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)

# Pretrained DenseNet-169 with a fresh 2000-way linear head.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favour of `weights=...`. It is kept here for compatibility with older installs.
model_ft = models.densenet169(pretrained=True)
num_ftrs = model_ft.classifier.in_features
# 2000 outputs, presumably matching the "ms2000" class set in data_dir. TODO confirm.
model_ft.classifier = nn.Linear(num_ftrs, 2000)

# Rebuild as nn.Sequential: feature extractor -> ReLU -> global pool -> head.
features = list(model_ft.children())[:-1]
features.append(nn.ReLU(inplace=True))
features.append(Pool())
features.append(list(model_ft.children())[-1])
model_ft = nn.Sequential(*features)

# Show only the three layers appended after the feature extractor. Printing
# the full DenseNet floods the notebook output.
print(features[-3:])

# Shape smoke test, run in eval mode without autograd. A train-mode forward
# would update every BatchNorm's running statistics from this dummy all-ones
# batch, shifting the pretrained statistics before fine-tuning starts.
tensor = torch.ones((8, 3, 224, 224))
model_ft.eval()
with torch.no_grad():
    out = model_ft(tensor)
model_ft.train()  # restore training mode (the module default) for later cells

print(out.shape)  # expected torch.Size([8, 2000])