In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import sklearn as sk
from sklearn import decomposition as dec
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.utils.prune as prune
import sklearn.manifold as nonlin


In [7]:
class NN2(nn.Module):
    def __init__(self, is_norm=False):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=12, kernel_size=(5, 5))
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv2 = nn.Conv2d(
            in_channels=12, out_channels=16, kernel_size=(5, 5))
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.soft = nn.Softmax(dim=0)
        self.is_norm = is_norm
        # self.conv1 = nn.Conv2d(3, 6, 5)
        # self.pool = nn.MaxPool2d(2, 2)
        # self.conv2 = nn.Conv2d(6, 16, 5)
        # self.fc1 = nn.Linear(16 * 5 * 5, 120)
        # self.fc2 = nn.Linear(120, 84)
        # self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        if(self.is_norm == True):
            x = nn.functional.normalize(self.fc3(x))
        else:
            x = self.fc3(x)
        return x


In [2]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5), (0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)


trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

mnist_train = torchvision.datasets.MNIST(root='./data', train=True,
                                         download=True, transform=transform)
mnist_test = torchvision.datasets.MNIST(root='./data', train=False,
                                        download=True, transform=transform)

trainloader_mnist = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                                shuffle=True, num_workers=2)
testloader_mnist = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size,
                                               shuffle=True, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
mnist_dataset_sample = iter(testloader_mnist)
sample_im, sample_lbl = mnist_dataset_sample.next()

cifar_dataset_sample = iter(testloader)
cifar_sample_im, cifar_sample_lbl = cifar_dataset_sample.next()


In [12]:
def similarity(model1, model2, image):
    original_network = model1(image)
    small_network = model2(image)

    
    print('original network')
    print(original_network)
    print('small network')
    print(small_network)

    dis = torch.pairwise_distance(original_network, small_network)
    for a in dis:
        sim = 1/(a + 1) * 100
        print(sim)

    print('cosine similarity')
    cos = nn.CosineSimilarity()
    print(cos(original_network, small_network) * 100)


In [9]:
path_cifar_original = './isomap_reduced_network/cifar/cifar_original_net.pt'
path_cifar_dr_trained = './isomap_reduced_network/cifar/cifar_trained_dr_network.pt'
path_cifar_dr_untrained = './isomap_reduced_network/cifar/cifar_untrained_dr_network.pt'
path_cifar_dr_small_net_network = './isomap_reduced_network/cifar/cifar_small_net_scratch.pt'

In [19]:
cifar_model_original = torch.load(path_cifar_original)
cifar_dr_model_untrained = torch.load(path_cifar_dr_untrained)
cifar_dr_model_trained = torch.load(path_cifar_dr_trained)
cifar_small_network = torch.load(path_cifar_dr_small_net_network)

In [21]:
similarity(cifar_small_network, cifar_model_original, cifar_sample_im)


original network
tensor([[ 0.7053, -1.6785, -0.5797,  3.5504, -2.3351,  2.1195, -1.0485, -0.5981,
          0.6870, -1.2187],
        [ 5.7613,  5.4604, -0.5905, -2.6626, -4.4966, -4.1544, -5.5769, -7.1222,
         11.2656,  2.3718],
        [ 2.1137,  3.1805, -0.5128, -0.8582, -1.8774, -1.7937, -2.4870, -2.5411,
          1.9065,  2.2078],
        [ 7.0270, -1.1213,  2.4684, -1.7373,  0.1265, -3.6915, -2.9764, -4.2366,
          2.5168,  0.2512]], grad_fn=<AddmmBackward0>)
small network
tensor([[-2.0142, -2.4808, -0.1272,  3.6393, -0.3308,  2.7923,  0.9873, -1.2523,
         -0.3304, -1.7410],
        [ 5.9155,  7.4279, -1.0477, -4.8658, -3.2031, -5.6176, -4.2044, -5.5136,
          6.8849,  3.6683],
        [ 2.6999,  2.7378,  0.4041, -1.9834, -0.8672, -2.6083, -1.8655, -2.5277,
          2.5955,  1.8648],
        [ 6.1695,  1.5710,  0.7067, -2.6754, -0.6255, -4.2521, -2.7747, -4.2289,
          5.2831,  0.6445]], grad_fn=<AddmmBackward0>)
tensor(18.8215, grad_fn=<MulBackward0>)
ten