In [1]:
'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
print(torch.cuda.is_available())

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, embedding_size=12):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, embedding_size)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])


def test():
    net = ResNet18()
    y = net(torch.randn(1,3,32,32))
    print(y.size())
  
test()


True
torch.Size([1, 12])


In [0]:
embed1=ResNet18()
embed2= ResNet18()
embed3= ResNet18()
embed4= ResNet18()
embed5= ResNet18()

In [0]:
class ProxyNCA(torch.nn.Module):
    def __init__(self, batch_num, sz_embed):
        torch.nn.Module.__init__(self)
        self.proxies = torch.nn.Parameter(torch.randn(batch_num, sz_embed) / 8)
#         self.proxies= torch.nn.Parameter(torch.FloatTensor([[4,5,6],[4,5,6]])) 
# answer of this is 0.0025
#         print(self.proxies.shape)

    def pairwise_distance(self,a,b):

      return torch.sub(a,b).pow(2).to(device)
    
    def proxyNCAloss(self,X,P):
		
#       pdist = nn.PairwiseDistance(p=2, keepdim=True)

      nume_exp=self.pairwise_distance(X,P).to(device)

      vect_denom= self.pairwise_distance_self(X).to(device)
      denom= (torch.exp(-1*vect_denom)).to(device).sum(-1).to(device)
      
      nume_exp=torch.exp(-1*nume_exp).to(device)
      
      
      proxy_nca_loss= (nume_exp.to(device)/denom.to(device)).to(device)

      return proxy_nca_loss.sum()
    
    
    
    def pairwise_distance_self(self,X):

        final_tensor=X
        final_tensor.to(device)
        processed=torch.zeros([X.size()[0],X.size()[1]]).to(device)
        for i in range(0,X.size()[1]):
            for j in range(0,X.size()[1]):
                if(i==j):
                    continue
                processed[0,i]=torch.add(torch.sub(final_tensor[0,i],final_tensor[0,j]).pow(2),processed[0,i])
        return  processed

     	


    def forwardold(self, X):

      P = self.proxies.double()
      P = 3 * F.normalize(P, p = 2, dim = -1)
#       X = 3 * F.normalize(X, p = 2, dim = -1)

      pdist = nn.PairwiseDistance(p=2, keepdim=True)

      nume_exp=self.pairwise_distance(X,P)

      vect_denom= self.pairwise_distance_self(X)
      denom= (torch.exp(-1*vect_denom)).sum()

#       proxy_nca_loss=-1* torch.log( (torch.exp(-1*nume_exp)/denom).sum())
      proxy_nca_loss= (torch.exp(-1*nume_exp)/denom).sum()

  
      return proxy_nca_loss
  
    def forward(self, X):
      P = self.proxies
      P = 3 * F.normalize(P, p = 2, dim = -1)

      X = 3 * F.normalize(X, p = 2, dim = -1)

      batchwise_lose = torch.zeros_like(X)


      if(X.size()[0]==1):
        return self.forwardold(X)
      
      for i in range(0,X.size()[0]):
#         >>> a[0].resize(a[0].size()[0],1).t().shape

#         print(batchwise_lose[i].size())
#         print(X[i].resize(X[i].size()[0],1).t())
        
        batchwise_lose[i]=self.proxyNCAloss(X[i].resize(X[i].size()[0],1).t(), 
                                            P[i].resize(P[i].size()[0],1).t()).to(device)
#      
#       taking sum of all the losses within a batch and then taking mean of all 
#       the batches as a representative loss
#       print(batchwise_lose)
      return batchwise_lose.sum(-1).mean()




# testing

# criterion= ProxyNCA(2,3)


# # print(criterion())
# print(torch.FloatTensor([[1,2,3],[1,2,3]]).shape)

# loss= criterion(torch.FloatTensor([[1,2,3],[1,2,3]]))
# print(loss.item())
# loss.backward()

In [0]:
import numpy as np



def get_class_i(x, y, i):
    """
    x: trainset.train_data or testset.test_data
    y: trainset.train_labels or testset.test_labels
    i: class label, a number between 0 to 9
    return: x_i
    """
    # Convert to a numpy array
    y = np.array(y)
    # Locate position of labels that equal to i
    pos_i = np.argwhere(y == i)
    # Convert the result into a 1-D list
    pos_i = list(pos_i[:,0])
    # Collect all data that match the desired label
    x_i = [x[j] for j in pos_i]

    return x_i

In [0]:
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms


# Transformations
RC   = transforms.RandomCrop(32, padding=4)
RHF  = transforms.RandomHorizontalFlip()
RVF  = transforms.RandomVerticalFlip()
NRM  = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
TT   = transforms.ToTensor()
TPIL = transforms.ToPILImage()

# Transforms object for trainset with augmentation
transform_with_aug = transforms.Compose([TPIL, RC, RHF, TT, NRM])
# Transforms object for testset with NO augmentation
transform_no_aug   = transforms.Compose([TT, NRM])

class DatasetMaker(Dataset):
    def __init__(self, datasets, transformFunc = transform_no_aug):
        """
        datasets: a list of get_class_i outputs, i.e. a list of list of images for selected classes
        """
        self.datasets = datasets
        self.lengths  = [len(d) for d in self.datasets]
        self.transformFunc = transformFunc
    def __getitem__(self, i):
        class_label, index_wrt_class = self.index_of_which_bin(self.lengths, i)
        img = self.datasets[class_label][index_wrt_class]
        img = self.transformFunc(img)
        return img, class_label

    def __len__(self):
        return sum(self.lengths)
    
    def index_of_which_bin(self, bin_sizes, absolute_index, verbose=False):
        """
        Given the absolute index, returns which bin it falls in and which element of that bin it corresponds to.
        """
        # Which class/bin does i fall into?
        accum = np.add.accumulate(bin_sizes)
        if verbose:
            print("accum =", accum)
        bin_index  = len(np.argwhere(accum <= absolute_index))
        if verbose:
            print("class_label =", bin_index)
        # Which element of the fallent class/bin does i correspond to?
        index_wrt_class = absolute_index - np.insert(accum, 0, 0)[bin_index]
        if verbose:
            print("index_wrt_class =", index_wrt_class)

        return bin_index, index_wrt_class


In [6]:
from torchvision.datasets import CIFAR10
trainset  = CIFAR10(root='./data', train=True , download=True)#, transform = transform_with_aug)
testset   = CIFAR10(root='./data', train=False, download=True)#, transform = transform_no_aug)

# print(trainset.train_list)
classDict = {'plane':0, 'car':1, 'bird':2, 'cat':3, 'deer':4, 'dog':5, 'frog':6, 'horse':7, 'ship':8, 'truck':9}

# print(trainset.targets)

x_train  = trainset.data
x_test   = testset.data
y_train  = trainset.targets
y_test   = testset.targets
# print(get_class_i(x_train, y_train, classDict['cat']))
cat_dog_trainset = \
    DatasetMaker(
        [get_class_i(x_train, y_train, classDict['cat']), get_class_i(x_train, y_train, classDict['dog'])],
        transform_with_aug
    )
plane_ship_trainset=\
    DatasetMaker(
        [get_class_i(x_train, y_train, classDict['plane']), get_class_i(x_train, y_train, classDict['ship'])],
        transform_with_aug
    )
car_truck_trainset= \
    DatasetMaker(
        [get_class_i(x_train, y_train, classDict['car']), get_class_i(x_train, y_train, classDict['truck'])],
        transform_with_aug
    )
deer_horse_trainset= \
    DatasetMaker(
        [get_class_i(x_train, y_train, classDict['deer']), get_class_i(x_train, y_train, classDict['horse'])],
        transform_with_aug
    )
bird_frog_trainset= \
    DatasetMaker(
        [get_class_i(x_train, y_train, classDict['bird']), get_class_i(x_train, y_train, classDict['frog'])],
        transform_with_aug
    )


# One test set
car_truck_testset  = \
    DatasetMaker(
        [get_class_i(x_test , y_test , classDict['car']), get_class_i(x_test , y_test , classDict['truck'])],
        transform_no_aug
    )

# size of datasets

Files already downloaded and verified
Files already downloaded and verified


In [0]:
kwargs = {'num_workers': 2, 'pin_memory': False}

# Create datasetLoaders from trainset and testset
embed1_loader= DataLoader(cat_dog_trainset, batch_size=100, shuffle=True , **kwargs)
embed2_loader= DataLoader(plane_ship_trainset, batch_size=100, shuffle=True , **kwargs)
embed3_loader= DataLoader(car_truck_trainset, batch_size=100, shuffle=True , **kwargs)
embed4_loader= DataLoader(deer_horse_trainset, batch_size=100, shuffle=True , **kwargs)
embed5_loader= DataLoader(bird_frog_trainset, batch_size=100, shuffle=True , **kwargs)


testsetLoader    = DataLoader(car_truck_testset , batch_size=64, shuffle=False, **kwargs)

In [0]:
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse


# data,target=next(embed1_it)
# print(data.shape)

def set_status_train():
#     embed1.double()
#     embed2.double()
#     embed3.double()
#     embed4.double()
#     embed5.double()
    
    embed1.train()
    embed2.train()
    embed3.train()
    embed4.train()
    embed5.train()



criterion1 = ProxyNCA(100,12).to(device)
criterion2 = ProxyNCA(100,12).to(device)
criterion3 = ProxyNCA(100,12).to(device)
criterion4 = ProxyNCA(100,12).to(device)
criterion5 = ProxyNCA(100,12).to(device)

optimizer1 = optim.SGD(embed1.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
optimizer2 = optim.SGD(embed2.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
optimizer3 = optim.SGD(embed3.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
optimizer4 = optim.SGD(embed4.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
optimizer5 = optim.SGD(embed5.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)


def train(epoch):
  
    embed1_it= iter(embed1_loader)
    embed2_it= iter(embed2_loader)
    embed3_it= iter(embed3_loader)
    embed4_it= iter(embed4_loader)
    embed5_it= iter(embed5_loader)

    print('\nEpoch: %d' % epoch)
    train_loss = 0
    correct = 0
    total = 0
    it=0
    max_accuracy=0
    final_loss=0
    
    set_status_train()


    for batch_idx, (inputs, targets) in enumerate(embed1_loader):
      
        try:
            data1, target1 = next(embed1_it)
            data2, target2 = next(embed2_it)
            data3, target3 = next(embed3_it)
            data4, target4 = next(embed4_it)
            data5, target5 = next(embed5_it)
        except StopIteration:
            embed1_it= iter(embed1_loader)
            embed2_it= iter(embed2_loader)
            embed3_it= iter(embed3_loader)
            embed4_it= iter(embed4_loader)
            embed5_it= iter(embed5_loader)

            data1, target1 = next(embed1_it)
            data2, target2 = next(embed2_it)
            data3, target3 = next(embed3_it)
            data4, target4 = next(embed4_it)
            data5, target5 = next(embed5_it)
            
        it=it+1
        data1, target1 = data1.to(device), target1.to(device)
        data2, target2 = data2.to(device), target2.to(device)
        data3, target3 = data3.to(device), target3.to(device)
        data4, target4 = data4.to(device), target4.to(device)
        data5, target5 = data5.to(device), target5.to(device)
        
#          
          
        if(data1.size()[0]<10 or data2.size()[0]<10 or data3.size()[0]<10 or data4.size()[0]<10
          or data5.size()[0]<10):
          continue
          
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        optimizer4.zero_grad()
        optimizer5.zero_grad()
        
        embed1.to(device)
        embed2.to(device)
        embed3.cuda()
        embed4.cuda()
        embed5.cuda()
        
        
        
#         outputs = torch.cat([embed1(data1),embed2(data2),embed3(data3),embed4(data4),embed5(data5)],-1)
#         print(outputs.size())
        loss1=criterion1(embed1(data1)).to(device)
        loss2=criterion2(embed2(data1)).to(device)
        loss3=criterion3(embed3(data1))
        loss4=criterion4(embed4(data1))
        loss5=criterion5(embed5(data1))
        

        loss1.backward()
        loss2.backward()
        loss3.backward()
        loss4.backward()
        loss5.backward()
        
        optimizer1.step()
        optimizer2.step()
        optimizer3.step()
        optimizer4.step()
        optimizer5.step()
        
        if(it%10==0):
          print("Iteration : ", it)
          print("Loss 1: ",loss1.item())
          print("Loss 2: ",loss2.item())
          print("Loss 3: ",loss3.item())
          print("Loss 4: ",loss4.item())
          print("Loss 5: ",loss5.item())
          print("_________________________________________________________")
          
          
for epoch in range(2):          
  train(epoch)
# try:
#             data, target = next(dataloader_iterator)
#         except StopIteration:
#             dataloader_iterator = iter(dataloader)
#             data, target = next(dataloader_iterator)


Epoch: 0




Iteration :  10
Loss 1:  13147.51171875
Loss 2:  55.974002838134766
Loss 3:  39.5790901184082
Loss 4:  4134.5341796875
Loss 5:  16709.517578125
_________________________________________________________
Iteration :  20
Loss 1:  13087.8837890625
Loss 2:  55.95524215698242
Loss 3:  39.57583999633789
Loss 4:  4134.1123046875
Loss 5:  16704.390625
_________________________________________________________
Iteration :  30
Loss 1:  13118.4794921875
Loss 2:  55.97800064086914
Loss 3:  39.5772705078125
Loss 4:  4134.140625
Loss 5:  16715.396484375
_________________________________________________________
Iteration :  40
Loss 1:  13080.44140625
Loss 2:  55.969242095947266
Loss 3:  39.57606506347656
Loss 4:  4134.564453125
Loss 5:  16718.783203125
_________________________________________________________
Iteration :  50
Loss 1:  13129.3974609375
Loss 2:  55.96794891357422
Loss 3:  39.57728576660156
Loss 4:  4133.83935546875
Loss 5:  16699.673828125
_________________________________________________