In [1]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from fmnistmodels import*
import utils
from utils import*
import torch.utils.data as data
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.data import TensorDataset
import mlp_private
# from mlp_private import MLP
from mlp import MLP
from dataset import BBoxDtaset
import cv2
from torch.utils.data.sampler import SubsetRandomSampler
import nni
from nni.compression.torch import LevelPruner
from pruning_utils import*
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

In [2]:
# Hyperparameters and experiment configuration.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook touches so dataset/watermark generation and any
# stochastic evaluation are reproducible across Restart & Run All.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

epochs = 100
train_batch_size = 128
test_batch_size = 100
learning_rate = 0.001
momentum = 0.9
# NOTE(review): these are CIFAR-10 class names, but `dataset` below is
# 'fmnist5'; they look like a leftover from another notebook — confirm before use.
original_classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
watermark_classes = ('no key', 'top right', 'top left', 'down right', 'down left')
# NOTE(review): the identity strings below appear to embed personal names and
# ID numbers. They are kept verbatim because the embedded watermark depends on
# them, but consider loading them from a private config/env var before sharing.
identity_string = b'Y6588121SNAJEEBMOHARRAMSALIMJEBREEL02092020imageclassification_10_classes'
fake_identity_string = b'Y6588121NAJEEBMOHARRAMSALIMcrisesURV02092020classification10labels'
plagiarizer_identity_string = b'X7823579MRAMIJOSEPHAFFARcrisesURV2409022ResNet18classification10labels'
plagiarizer_fake_identity_string = b'X7823579WRAMIHAFFARJOSEP2509022ResNet18classificationofImage10labels'
key_size = 5
num_wm_samples = 25000
train_wm_black_box_epochs = 50
combined_original_model_epochs = 200
combined_private_model_epochs = 100
train_simultaneously_epochs = 100
dataset = 'fmnist5'
containerdataset = 'mnist'

In [3]:
def train(model, device, train_loader, optimizer, criterion, sparse_bn=False):
    """Run one training epoch on the model's first (original-task) output.

    Args:
        model: network whose forward returns a pair; only the first element is
            passed to ``criterion``.
        device: device every batch is moved to.
        train_loader: iterable of ``(inputs, targets, flags)`` batches; ``flags``
            is ignored here.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss applied to ``(first_output, targets)``.
        sparse_bn: if True, call ``updateBN(model)`` between ``backward()`` and
            ``step()`` (L1 regularization on BN layers).

    Returns:
        float: the loss of the last batch (also printed).

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    loss = None
    for inputs, target, _flags in train_loader:
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        output, _ = model(inputs)
        loss = criterion(output, target)
        loss.backward()
        # L1 regularization on BN layer: must run after gradients exist and
        # before the optimizer consumes them.
        if sparse_bn:
            updateBN(model)
        optimizer.step()

    if loss is None:
        raise ValueError('train: train_loader yielded no batches')
    # .item() only once, after the loop, to avoid a device sync per batch.
    final_loss = loss.item()
    print('Loss: {}'.format(final_loss))
    return final_loss


In [4]:
# Testing the original model on the original task
def test_original(model, criterion, test_loader, device):
   
        model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets, flags) in enumerate(test_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, _ = model(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

            print('Testing: epoch: %d | loss: %.3f | Acc: %.3f' %(1, test_loss/(batch_idx+1), 100.*correct/total))
            
        return  test_loss/(batch_idx+1), 100.*correct/total

In [5]:
# Testing the private model on the wm task
def test_private(model, criterion, test_loader, device):
   
        model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets, flags) in enumerate(test_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                targets = targets.type(torch.LongTensor).to(device)
                _, outputs = model(inputs)
                loss = criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

            print('Private model testing: epoch: %d | loss: %.3f | Acc: %.3f' %(1, test_loss/(batch_idx+1), 100.*correct/total))
            
            return  test_loss/(batch_idx+1), 100.*correct/total

In [7]:
# Load the watermarked ("combined") LeNet checkpoint and display its stored accuracy.
combined_checkpoint_path = './Checkpoints/best_combined_sequential_fmnist5lenetfmnist.t7'
combinedmodel = LeNetWM().to(device)
# Same values as before (0.001 / 0.9), now taken from the config cell.
combinedoptimizer = torch.optim.SGD(combinedmodel.parameters(), lr=learning_rate,
                                    momentum=momentum, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()
# map_location lets a checkpoint saved on GPU load when `device` fell back to CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(combined_checkpoint_path, map_location=device)
combinedmodel.load_state_dict(checkpoint['state_dict'])
combinedoptimizer.load_state_dict(checkpoint['optimizer'])
combinedstart_epoch = checkpoint['epoch']
combinedbest_acc = checkpoint['acc']
combinedbest_acc

89.94

In [8]:
# Flagged Fashion-MNIST loaders; batch sizes come from the config cell
# (previously hard-coded as 128 / 100). num_classes=5 presumably matches
# `dataset = 'fmnist5'` — confirm against get_flagged_fmnist_dataset.
trainset, testset, train_loader, test_loader = get_flagged_fmnist_dataset(
    train_batch_size, test_batch_size, num_classes=5)

In [9]:
# Watermark (signed) train/test sets built from the identity strings in the
# config cell; `containerdataset` is forwarded as-is to get_signed_dataset.
trainsetwm, testsetwm = get_signed_dataset(
    identity_string,
    fake_identity_string,
    key_size,
    num_wm_samples,
    containerdataset=containerdataset,
)

# Fixed-order loader so watermark accuracy is evaluated deterministically.
wm_testloader = torch.utils.data.DataLoader(
    testsetwm,
    batch_size=10,
    shuffle=False,
    num_workers=1,
)

In [11]:
# Pruning robustness sweep: for each sparsity level, start from a fresh copy of
# the marked model, prune it with NNI's LevelPruner (op_types 'default'), then
# reset every mask whose state_dict key contains 'wm' to all-ones so the
# watermark layers are left unpruned, and evaluate both tasks.
pruning_rates = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
combined_checkpoint_path = './Checkpoints/best_combined_sequential_fmnist5lenetfmnist.t7'

# The checkpoint is identical for every rate, so read it from disk once.
# load_state_dict copies tensors into the model, so reusing it is safe.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(combined_checkpoint_path, map_location=device)
criterion = nn.CrossEntropyLoss()
combinedstart_epoch = checkpoint['epoch']
combinedbest_acc = checkpoint['acc']

# One entry per rate: (rate, (orig_loss, orig_acc), (wm_loss, wm_acc)).
pruning_results = []

for pr in pruning_rates:
    # Fresh, unpruned model for every rate so sparsity levels don't compound.
    combinedmodel = LeNetWM().to(device)
    combinedoptimizer = torch.optim.SGD(combinedmodel.parameters(), lr=learning_rate,
                                        momentum=momentum, weight_decay=5e-4)
    combinedmodel.load_state_dict(checkpoint['state_dict'])
    print('Marked model accuracy', combinedbest_acc)

    print('Pruning rate: {}%'.format(pr*100))

    config_list = [{'sparsity': pr, 'op_types': ['default']}]
    pruner = LevelPruner(combinedmodel, config_list, combinedoptimizer)
    combinedmodel = pruner.compress()

    # Optional fine-tuning after pruning was tried here (train + pruner.update_epoch
    # for 10 epochs); it needs a `finetune_train_loader`, which this notebook
    # does not define.

    # Undo pruning on the watermark layers by replacing their masks with ones.
    combinemodel_dict = combinedmodel.state_dict()
    for key in list(combinemodel_dict):
        if 'wm' in key and 'mask' in key:
            combinemodel_dict[key] = torch.ones_like(combinemodel_dict[key])
    combinedmodel.load_state_dict(combinemodel_dict)

    original_metrics = test_original(combinedmodel, criterion, test_loader, device)
    watermark_metrics = test_private(combinedmodel, criterion, wm_testloader, device)
    pruning_results.append((pr, original_metrics, watermark_metrics))
    print('###########################################################################################')


Marked model accuracy 89.94
Pruning rate: 10.0%
Testing: epoch: 1 | loss: 0.441 | Acc: 89.980
Private model testing: epoch: 1 | loss: 0.016 | Acc: 99.585
###########################################################################################
Marked model accuracy 89.94
Pruning rate: 20.0%
Testing: epoch: 1 | loss: 0.450 | Acc: 89.140
Private model testing: epoch: 1 | loss: 0.024 | Acc: 99.337
###########################################################################################
Marked model accuracy 89.94
Pruning rate: 30.0%
Testing: epoch: 1 | loss: 0.489 | Acc: 87.980
Private model testing: epoch: 1 | loss: 0.029 | Acc: 99.144
###########################################################################################
Marked model accuracy 89.94
Pruning rate: 40.0%
Testing: epoch: 1 | loss: 0.813 | Acc: 81.820
Private model testing: epoch: 1 | loss: 0.190 | Acc: 92.933
###########################################################################################
Marked model acc