In [1]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [2]:
import math

In [3]:
import torch.nn as nn
import torch.nn.functional as F

In [4]:
from scipy.special import loggamma

In [26]:
def entropy(samples):
  N = len(samples)
  d = len(samples[0])
  distance = torch.cdist(samples, samples, p=2).to(device)
 #print(distance)
  distance = distance.masked_select(~torch.eye(N, dtype=bool).to(device)).view(N, N - 1)
  R = torch.min(distance, 1).values
  Y = torch.mean(math.log(N-1) + d * torch.log(R))
  pi = 3.14
  B = (d/2) * math.log(pi) - loggamma(d/2 + 1)
  euler = 0.577
  #return Y
  print(Y, B, Y + B + euler)
  return Y + B + euler

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [7]:
train_dataset = datasets.CIFAR10('./data', train=True, download=True,
                                  transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=100,
                            num_workers=2, drop_last=False, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data


In [8]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [9]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, name=''):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.name = name
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        #self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        #out = self.fc(out)
        return out

In [10]:
def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes
                  )


def ResNet34(num_classes=10):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes)


In [41]:
victim_model = ResNet34(num_classes=128)
stolen_model = ResNet34(num_classes=128)
random_model = ResNet34(num_classes=128)

In [12]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [42]:
checkpoint = torch.load('/content/drive/MyDrive/cifar10_checkpoint_200_infonce.pth.tar', map_location=device)
#checkpoint = torch.load('/content/drive/MyDrive/svhn_checkpoint_200_infonce.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']


for k in list(state_dict.keys()):

  if k.startswith('backbone.'):
    if k.startswith('backbone') and not k.startswith('backbone.fc'):
      # remove prefix
      state_dict[k[len("backbone."):]] = state_dict[k]
  del state_dict[k]


victim_model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [43]:
#checkpoint = torch.load('/content/drive/MyDrive/cifar10_checkpoint_200_infonce.pth.tar', map_location=device)
checkpoint = torch.load('/content/drive/MyDrive/svhn_checkpoint_200_infonce.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']


for k in list(state_dict.keys()):

  if k.startswith('backbone.'):
    if k.startswith('backbone') and not k.startswith('backbone.fc'):
      # remove prefix
      state_dict[k[len("backbone."):]] = state_dict[k]
  del state_dict[k]


random_model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [44]:
#checkpoint = torch.load('/content/drive/MyDrive/stolen_checkpoint_9000_infonce_cifar10.pth.tar', map_location=device)
checkpoint = torch.load('/content/drive/MyDrive/stolen_checkpoint_9000_mse_cifar10.pth.tar', map_location=device)
#checkpoint = torch.load('/content/drive/MyDrive/stolen_checkpoint_9000_softnn_cifar10.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']


for k in list(state_dict.keys()):

  if k.startswith('backbone.'):
    if k.startswith('backbone') and not k.startswith('backbone.fc'):
      # remove prefix
      state_dict[k[len("backbone."):]] = state_dict[k]
  del state_dict[k]


stolen_model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [45]:
victim_model = victim_model.to(device)
stolen_model = stolen_model.to(device)
random_model = random_model.to(device)


In [46]:
for name, param in victim_model.named_parameters():
    #if name not in ['fc.weight', 'fc.bias']:
    param.requires_grad = False

for name, param in stolen_model.named_parameters():
    #if name not in ['fc.weight', 'fc.bias']:
    param.requires_grad = False

for name, param in random_model.named_parameters():
    #if name not in ['fc.weight', 'fc.bias']:
    param.requires_grad = False

In [47]:
victim_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [48]:
stolen_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [49]:
random_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [50]:
def mutual_information(model1, model2, sample_size):
  representation1 = torch.zeros(sample_size, 512).to(device)
  representation2 = torch.zeros(sample_size, 512).to(device)
  representation_joint = torch.zeros(sample_size, 512 * 2).to(device)

  for i, (x_batch, _) in enumerate(train_loader):
    if i * 100 >= sample_size:
      break
    x_batch = x_batch.to(device)
    r = model1(x_batch)
    #r = representations(x_batch)
    representation1[i * 100: (i+1)*100] = r

  print("finish loading model 1")

  for i, (x_batch, _) in enumerate(train_loader):
    if i * 100 >= sample_size:
      break
    x_batch = x_batch.to(device)
    r = model2(x_batch)
    #r = representations(x_batch)
    representation2[i * 100: (i+1)*100] = r

  print("finish loading model 2")

  for i, (x_batch, _) in enumerate(train_loader):
    if i * 100 >= sample_size:
      break
    x_batch = x_batch.to(device)
    r1 = model1(x_batch)
    r2 = model2(x_batch)
    #r = representations(x_batch)
    r1 = F.normalize(r1)
    r2 = F.normalize(r2)
    representation_joint[i * 100: (i+1)*100] = torch.cat((r1, r2), 1)

  representation1 = F.normalize(representation1)
  representation2 = F.normalize(representation2)
  
  
  print("finish loading joint")


  return entropy(representation1)  + entropy(representation2) - entropy(representation_joint)


In [51]:
MI = 0
for _ in range(1):
  MI += mutual_information(victim_model, victim_model, 10000)
print(MI / 1)

finish loading model 1
finish loading model 2
finish loading joint
tensor(-436.1375, device='cuda:0') -874.3362417833189 tensor(-1309.8967, device='cuda:0')
tensor(-433.6175, device='cuda:0') -874.3362417833189 tensor(-1307.3767, device='cuda:0')
tensor(-525.8648, device='cuda:0') -2100.2183980672253 tensor(-2625.5063, device='cuda:0')
tensor(8.2329, device='cuda:0')


In [52]:
MI = 0
for _ in range(1):
  MI += mutual_information(victim_model, stolen_model, 10000)
print(MI/1)

finish loading model 1
finish loading model 2
finish loading joint
tensor(-433.7691, device='cuda:0') -874.3362417833189 tensor(-1307.5283, device='cuda:0')
tensor(-543.5029, device='cuda:0') -874.3362417833189 tensor(-1417.2621, device='cuda:0')
tensor(-594.2497, device='cuda:0') -2100.2183980672253 tensor(-2693.8914, device='cuda:0')
tensor(-30.8992, device='cuda:0')


In [53]:
MI = 0
for _ in range(1):
  MI += mutual_information(victim_model, random_model, 10000)
print(MI/1)

finish loading model 1
finish loading model 2
finish loading joint
tensor(-433.7000, device='cuda:0') -874.3362417833189 tensor(-1307.4592, device='cuda:0')
tensor(-503.4043, device='cuda:0') -874.3362417833189 tensor(-1377.1636, device='cuda:0')
tensor(-530.2346, device='cuda:0') -2100.2183980672253 tensor(-2629.8762, device='cuda:0')
tensor(-54.7466, device='cuda:0')
