In [None]:
%matplotlib inline

import numpy as np
from pprint import pprint

from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import torchvision
from torchvision import models, datasets, transforms
import random
from tqdm import tqdm
import torch
from torch.autograd.functional import jvp
from defense import *
# Seed torch's global RNG once, before anything below is constructed.
torch.manual_seed(50)

# Size and starting offset of the small contiguous slice of the dataset.
batch_size = 2
idx = 30

import inversefed

# Runtime setup and the 'conservative' training strategy, both from inversefed.
setup = inversefed.utils.system_startup()
defs = inversefed.training_strategy('conservative')

# inversefed returns the loss function together with train/validation loaders.
loss_fn, trainloader, validloader = inversefed.construct_dataloaders('CIFAR10', defs)

print(torch.__version__, torchvision.__version__)

# Every subset built in this notebook is taken from the validation dataset.
dst = validloader.dataset

# `batch_size` consecutive samples starting at `idx`, served in a fixed order.
train_indices = [idx + offset for offset in range(batch_size)]
train_subset = Subset(dst, train_indices)
train_loader = DataLoader(train_subset, shuffle=False, batch_size=batch_size)

# The complete dataset, also in fixed order.
test_indices = list(range(0, len(dst)))
test_subset = Subset(dst, test_indices)
test_loader = DataLoader(test_subset, shuffle=False, batch_size=batch_size)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on %s" % device)



In [None]:
import inversefed

# Architecture name handed to inversefed's model factory.
arch = 'ConvNet32'

# construct_model returns a pair; only the first element (the model) is used.
net, _ = inversefed.construct_model(arch, num_classes=10, num_channels=3, seed=0)

# Report how many scalar parameters are trainable (requires_grad is set).
print('Num of Parameters Total')
trainable_param_count = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(trainable_param_count)

In [None]:
# Sweep over noise scales and defense variants: for each combination, train a
# freshly constructed model on a fixed 8-sample subset with defended gradients,
# then append the per-step and final metrics to two text files.
#
# Relies on names defined by earlier cells / star imports: `dst`, `device`,
# `loss_fn`, `arch`, `inversefed`, and (presumably from `defense`)
# `defense_prune`, `defense_dpsgd`, `label_to_onehot`,
# `compute_l2_norm_of_gradients_new` -- TODO confirm their origin.

# --- Configuration -----------------------------------------------------------
batch_size = 8    # samples per optimizer step
split_batch = 4   # number of micro-batches each batch is meant to be split into
num_epoch = 10
lr = 1e-3


def defense_dpsgd_clipped(x, y, z, w):
    """Alternative defense: `defense_dpsgd` with a fixed clipping threshold of 0.05."""
    return defense_dpsgd(x, y, z, w, clipping_threshold=0.05)


# Defense applied to every micro-batch gradient. Swap in `defense_dpsgd_clipped`
# to run the DP-SGD variant instead.
defense_method = defense_prune


def _evaluate_loader(model, loader, criterion):
    """Return (mean loss, mean accuracy) of `model`, averaged over the batches of `loader`.

    The model is put in eval mode for the pass (so BatchNorm/Dropout layers, if
    the architecture has any, neither use batch statistics nor update their
    running state) and its previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    try:
        with torch.no_grad():
            for data, label in loader:
                data, label = data.to(device), label.to(device)
                out = model(data)
                total_loss += criterion(out, label).item()
                total_acc += (torch.argmax(out, dim=1) == label).float().mean().item()
    finally:
        model.train(was_training)
    return total_loss / len(loader), total_acc / len(loader)


# --- Loop-invariant setup ------------------------------------------------------
# Train and "test" sets are the same eight samples (indices 1..8), fixed order.
train_indices = list(range(1, 9))
train_subset = Subset(dst, train_indices)
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=False)

test_indices = list(range(1, 9))
test_subset = Subset(dst, test_indices)
test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

# `loss_fn` returns a tuple; its first element is used as the scalar loss.
criterion = lambda x, y: loss_fn(x, y)[0]

# Micro-batch size; guarded so a batch_size smaller than split_batch still splits.
split_size = max(1, batch_size // split_batch)

# --- Sweep -----------------------------------------------------------------------
for noise_scale in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]:
    for num_samples in [10]:  # Num samples for random sketching
        for defense_type in ['default', 'ours']:
            print(f'\n\n\n\n___^^^___noise_scale:{noise_scale}, defense_method:{defense_type}___^^^___')

            # Identical initialisation for every configuration.
            torch.manual_seed(0)
            net, _ = inversefed.construct_model(arch, num_classes=10, num_channels=3, seed=0)
            net.to(device)
            loss = []
            acc = []
            norms = []
            defense = True
            optimizer = optim.SGD(net.parameters(), lr=lr)

            for epoch in range(num_epoch):
                print(f'______EPOCH {epoch+1}_______')
                t = tqdm(train_loader)
                for gt_data, gt_label in t:
                    gt_data, gt_label = gt_data.to(device), gt_label.to(device)
                    gt_onehot_label = label_to_onehot(gt_label, num_classes=10)

                    optimizer.zero_grad()

                    parts_data = torch.split(gt_data, split_size)
                    parts_label = torch.split(gt_label, split_size)
                    parts_onehot_label = torch.split(gt_onehot_label, split_size)

                    accumulated_gradients = None
                    part_losses = []

                    for part_data, part_label, part_onehot_label in zip(parts_data, parts_label, parts_onehot_label):
                        part_data.requires_grad_(True)
                        out = net(part_data)
                        part_loss = criterion(out, part_label)
                        part_losses.append(part_loss.item())
                        original_dy_dx = torch.autograd.grad(part_loss, net.parameters(), create_graph=True)
                        l2_norms_dy_dx = None

                        if defense:
                            if defense_type != 'default':
                                l2_norms_dy_dx = compute_l2_norm_of_gradients_new(net, part_data, part_onehot_label, criterion, num_samples=num_samples)
                            modified_dy_dx = defense_method(original_dy_dx, l2_norms_dy_dx, noise_scale, defense_type)
                        else:
                            modified_dy_dx = original_dy_dx

                        # Detach before accumulating: the gradients were produced with
                        # create_graph=True, and only their values are needed for the
                        # optimizer step. Keeping them attached would pin every
                        # micro-batch graph in memory via param.grad.
                        if accumulated_gradients is None:
                            accumulated_gradients = [g.detach().clone() for g in modified_dy_dx]
                        else:
                            for running, g in zip(accumulated_gradients, modified_dy_dx):
                                running += g.detach()

                    # Average over the number of chunks actually produced; a partial
                    # batch yields fewer than `split_batch` chunks.
                    num_parts = len(parts_data)
                    for running in accumulated_gradients:
                        running /= num_parts

                    with torch.no_grad():
                        for param, g in zip(net.parameters(), accumulated_gradients):
                            param.grad = g
                    optimizer.step()

                    # Log the mean loss over all micro-batches of this step, matching
                    # the averaged gradient that was just applied.
                    loss.append(sum(part_losses) / len(part_losses))
                    # Accuracy is a metric only, so no autograd graph is needed.
                    with torch.no_grad():
                        batch_acc = (torch.argmax(net(gt_data), dim=1) == gt_label).float().mean().item()
                    acc.append(batch_acc)
                    if defense and defense_type == 'ours':
                        # Per-parameter statistic computed from the LAST micro-batch's
                        # values of l2_norms_dy_dx and original_dy_dx.
                        norms.append([(torch.mean(torch.abs(l2_norms_dy_dx[i] / original_dy_dx[i])) ** 0.5).item() for i in range(len(original_dy_dx))])
                    t.set_description(f"Loss: {loss[-1]}, Acc: {acc[-1]}")

            final_test_loss, final_test_acc = _evaluate_loader(net, test_loader, criterion)
            print(f"Final Test Loss: {final_test_loss}, Final Test Acc: {final_test_acc}")

            final_train_loss, final_train_acc = _evaluate_loader(net, train_loader, criterion)
            print(f"Final Train Loss: {final_train_loss}, Final Train Acc: {final_train_acc}")

            # Results are appended, so repeated runs accumulate in the same files.
            with open('data_train_detail.txt', 'a') as file:
                file.write(f"k: {noise_scale}, Defense type: {defense_type}\nLoss: {loss}\nAcc: {acc}\n\n")

            with open('data_train.txt', 'a') as file:
                file.write(f"Final train loss: {final_train_loss}, Final test loss: {final_test_loss}, Final train acc: {final_train_acc}, Final test acc: {final_test_acc}, k: {noise_scale}, Defense type: {defense_type}, Num_Samples: {num_samples}\n")