In [1]:
import torch
import torch.nn.functional as F

from math import floor
import os
import sys

import gc
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.cuda
import torch.multiprocessing as mp
import torch.nn as nn

from dataset_utils import get_train_and_test_dataloader
from models.resnet import ResNet50, ResNet18
from models.wide_resnet import Wide_ResNet
from optimizers.optimizer_utils import use_optimizer, use_lr_scheduler
import opacus
from opacus.validators import ModuleValidator
from tqdm.notebook import tqdm
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.utils.data import DataLoader

In [2]:
# Build a ResNet18 with a 100-way output layer and put it in training mode.
net = ResNet18(num_classes=100)
net.train()
# ModuleValidator.fix is applied BEFORE the checkpoint is loaded, so the state dict
# below is loaded into the already-fixed architecture.
# NOTE(review): an earlier comment here said BatchNorm layers could be kept because
# the backbone is never trained. That does not match this notebook: fix() is called
# on this very line (per the Opacus docs it replaces unsupported layers such as
# BatchNorm), and a later cell leaves the last conv/norm parameters trainable.
# The "GN" in the checkpoint name presumably stands for GroupNorm -- TODO confirm.
net = ModuleValidator.fix(net.to("cpu"))
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
net.load_state_dict(torch.load("./model18_100_GN-Copy1.pt", map_location="cpu"))
net = net.to("cuda")

In [3]:
# Names of all parameters in registration order; a later cell slices this list
# by position to pick the trainable subset.
ls_names = [param_name for param_name, _ in net.named_parameters()]

In [4]:
# The trainable subset is chosen purely by position in ls_names:
#   - the last two names form the "classifier" group,
#   - the six names before them form the "conv/norm" group.
classifier_names = ls_names[-2:]
conv_norm_names = ls_names[-8:-2]
classifier_params, conv_norm_params = [], []

for name, param in net.named_parameters():
    in_classifier = name in classifier_names
    in_conv_norm = name in conv_norm_names

    # Only the two selected groups keep gradients; every other tensor is frozen.
    param.requires_grad = in_classifier or in_conv_norm

    if in_classifier:
        classifier_params.append(param)
    elif in_conv_norm:
        conv_norm_params.append(param)

In [5]:
def count_parameters(model, all_param_flag=False):
    """Count the scalar entries held in ``model``'s parameters.

    Args:
        model: object exposing ``parameters()`` that yields tensors.
        all_param_flag: when True every parameter is counted; when False
            (default) only parameters with ``requires_grad`` set are counted.

    Returns:
        Total number of elements across the counted parameters.
    """
    n_elements = 0
    for tensor in model.parameters():
        if all_param_flag or tensor.requires_grad:
            n_elements += tensor.numel()
    return n_elements

In [6]:
# Show how many parameter entries are still trainable (all_param_flag keeps its
# default of False, so frozen tensors are excluded).
count_parameters(net)

4771940

In [7]:
# Cross-entropy loss plus SGD with one parameter group per trainable subset, so
# the classifier tensors and the conv/norm tensors get separate learning rates.
criterion = nn.CrossEntropyLoss()
net_params = [
    dict(params=classifier_params, lr=0.8),
    dict(params=conv_norm_params, lr=0.01),
]
optimizer = torch.optim.SGD(net_params, weight_decay=1e-4, momentum=0.9)

In [8]:
# Per-channel normalisation statistics applied to every image.
# NOTE(review): the names say CIFAR-100 while the datasets below are CIFAR10 --
# presumably chosen to match the checkpoint's preprocessing; confirm.
CIF100_MEAN = [0.5071, 0.4867, 0.4408]
CIF100_STD = [0.2675, 0.2565, 0.2761]

batch_size = 250


def _cifar10_split(is_train):
    """Return the CIFAR10 train or test split with ToTensor + Normalize applied."""
    preprocessing = Compose([ToTensor(), Normalize(CIF100_MEAN, CIF100_STD)])
    return CIFAR10('../datasets/', train=is_train, download=True, transform=preprocessing)


# Training split: shuffled, four worker processes.
train_ds = _cifar10_split(True)
train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
)

# Test split: fixed order, two worker processes.
test_ds = _cifar10_split(False)
test_loader = DataLoader(
    test_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True
)

# Pull a single training batch into x / y.
x, y = next(iter(train_loader))

Files already downloaded and verified
Files already downloaded and verified


In [9]:
from torch.optim.lr_scheduler import CosineAnnealingLR, ConstantLR, LambdaLR
# Number of training epochs; also passed to the privacy engine and used as the
# period (T_max) of the cosine learning-rate schedule.
epochs = 200

# Hand the network, optimizer and training loader to Opacus. The call returns
# wrapped versions of all three; `optimizer` and `train_loader` are rebound to
# the wrapped objects, and the wrapped network is stored in `model`.
# NOTE(review): because the inputs are rebound, re-running this cell passes
# already-wrapped objects back in -- restart the kernel before re-running.
privacy_engine = opacus.PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
    module=net,
    optimizer=optimizer,
    data_loader=train_loader,
    target_epsilon=1,
    target_delta=1e-5,
    epochs=epochs,
    max_grad_norm=1,
)

# The schedule is attached to the wrapped optimizer returned above.
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

  z = np.log((np.exp(t) + q - 1) / q)


In [10]:
from statistics import mean

def train(model, criterion, optimizer, train_loader):
    """Run one training epoch and print the mean per-batch accuracy and loss.

    Args:
        model: network to optimise. Batches are moved to the device of its
            first parameter (falls back to "cuda" if it has no parameters).
        criterion: callable mapping ``(logits, targets)`` to a scalar loss.
        optimizer: optimizer; ``step()`` and ``zero_grad()`` run once per batch.
        train_loader: iterable yielding ``(inputs, targets)`` batches.

    Returns:
        ``(mean_accuracy, mean_loss)`` averaged over the non-empty batches, or
        ``(nan, nan)`` when no non-empty batch was seen. Earlier versions
        returned ``None``; callers that ignore the result are unaffected.
    """
    torch.cuda.empty_cache()
    gc.collect()

    # Follow the model's own device instead of assuming a GPU is present.
    device = next((p.device for p in model.parameters()), torch.device("cuda"))
    model.train()

    accs = []
    losses = []
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        # A loader may hand out an empty batch (e.g. presumably an Opacus
        # Poisson-sampled loader). The optimizer step above still runs, but
        # the batch is left out of the metrics: dividing by len(y) == 0 would
        # raise and the loss of an empty batch is not meaningful.
        if len(y) == 0:
            continue

        preds = logits.argmax(-1)
        n_correct = float(preds.eq(y).sum())
        accs.append(n_correct / len(y))
        losses.append(float(loss))

    mean_acc = mean(accs) if accs else float("nan")
    mean_loss = mean(losses) if losses else float("nan")

    # Separator added: the two fields used to be printed glued together.
    print(
        f"Train Accuracy: {mean_acc:.6f}, "
        f"Train Loss: {mean_loss:.6f}"
    )
    return mean_acc, mean_loss

In [11]:
    # Release cached CUDA memory and run a garbage-collection pass before the
    # training cell; the number shown as this cell's output is gc.collect()'s
    # return value.
    torch.cuda.empty_cache()
    gc.collect()


0

In [None]:

# Train for `epochs` epochs, advancing the cosine schedule once per epoch and
# printing the learning rate of each parameter group afterwards.
# The Opacus-wrapped `model` returned by make_private_with_epsilon is used here
# (rather than the raw `net`) so training and the evaluation cell go through
# the same module object.
for epoch in tqdm(range(epochs)):
    train(model, criterion, optimizer, train_loader)
    scheduler.step()
    print(scheduler.get_last_lr())

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]



Train Accuracy: 0.362585Train Loss: 3.933345
[0.7999506529926643, 0.009999383162408303]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.465162Train Loss: 3.493554
[0.7998026241462927, 0.009997532801828657]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.491595Train Loss: 3.312850
[0.7995559499847881, 0.00999444937480985]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.506492Train Loss: 3.244404
[0.7992106913713087, 0.009990133642141357]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.519449Train Loss: 3.128940
[0.7987669334932513, 0.009984586668665639]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.542408Train Loss: 2.997734
[0.7982247858412321, 0.009977809823015398]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.560587Train Loss: 2.929552
[0.7975843821820721, 0.009969804777275897]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.554552Train Loss: 2.982682
[0.7968458805257913, 0.009960573506572387]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.549625Train Loss: 2.962554
[0.7960094630866231, 0.009950118288582784]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.552395Train Loss: 2.977118
[0.7950753362380553, 0.009938441702975684]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.567092Train Loss: 2.857122
[0.7940437304619097, 0.009925546630773864]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.545998Train Loss: 3.040939
[0.7929149002914756, 0.009911436253643439]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.560320Train Loss: 2.945844
[0.7916891242487064, 0.009896114053108824]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.552181Train Loss: 2.925902
[0.7903667047754991, 0.009879583809693733]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.555000Train Loss: 2.851081
[0.7889479681590708, 0.009861849601988378]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.560040Train Loss: 2.804778
[0.7874332644514526, 0.00984291580564315]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.557522Train Loss: 2.858290
[0.7858229673831194, 0.009822787092288985]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.566502Train Loss: 2.875778
[0.7841174742707774, 0.00980146842838471]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.564177Train Loss: 2.865518
[0.7823172059193322, 0.009778965073991645]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.543657Train Loss: 2.867622
[0.7804226065180616, 0.009755282581475762]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.556612Train Loss: 2.739664
[0.7784341435310183, 0.00973042679413772]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.559563Train Loss: 2.752260
[0.7763523075816903, 0.00970440384477112]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.566510Train Loss: 2.818708
[0.774177612331947, 0.00967722015414933]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.549518Train Loss: 2.936455
[0.7719105943553006, 0.00964888242944125]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.576134Train Loss: 2.783435
[0.7695518130045148, 0.009619397662556426]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.576843Train Loss: 2.808240
[0.7671018502735925, 0.009588773128419898]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.562447Train Loss: 2.794140
[0.7645613106541781, 0.009557016383177219]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.561090Train Loss: 2.680563
[0.7619308209864079, 0.009524135262330091]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.558242Train Loss: 2.631327
[0.7592110303042463, 0.009490137878803071]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.557753Train Loss: 2.737459
[0.7564026096753472, 0.009455032620941833]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.548705Train Loss: 2.795821
[0.7535062520354776, 0.009418828150443462]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.550780Train Loss: 2.719026
[0.7505226720175455, 0.009381533400219312]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.548211Train Loss: 2.781720
[0.7474526057752766, 0.00934315757219095]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.545345Train Loss: 2.751259
[0.7442968108015776, 0.009303710135019712]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.536884Train Loss: 2.731791
[0.7410560657416371, 0.009263200821770455]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.537842Train Loss: 2.750163
[0.7377311702008063, 0.009221639627510068]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.541318Train Loss: 2.774348
[0.7343229445473084, 0.009179036806841345]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.547593Train Loss: 2.735843
[0.730832229709825, 0.009135402871372801]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.562258Train Loss: 2.776162
[0.7272598869700097, 0.00909074858712511]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.563089Train Loss: 2.702818
[0.7236067977499793, 0.00904508497187473]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.568269Train Loss: 2.652431
[0.7198738633948366, 0.008998423292435446]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.575869Train Loss: 2.562416
[0.7160620049502764, 0.008950775061878443]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.558404Train Loss: 2.722077
[0.7121721629353323, 0.008902152036691641]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.557269Train Loss: 2.699428
[0.7082052971103161, 0.008852566213878938]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.541916Train Loss: 2.748186
[0.7041623862400128, 0.008802029828000147]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.541318Train Loss: 2.714685
[0.7000444278521842, 0.00875055534815229]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.555353Train Loss: 2.705633
[0.6958524379914441, 0.008698155474893039]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.546605Train Loss: 2.702601
[0.6915874509685649, 0.008644843137107049]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.553994Train Loss: 2.632632
[0.6872505191052758, 0.008590631488815935]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.539582Train Loss: 2.640923
[0.6828427124746193, 0.008535533905932728]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.542175Train Loss: 2.666243
[0.678365118636926, 0.008479563982961562]


  0%|          | 0/200 [00:00<?, ?it/s]

Train Accuracy: 0.554023Train Loss: 2.599572
[0.6738188423714757, 0.008422735529643434]


  0%|          | 0/200 [00:00<?, ?it/s]

In [None]:
import gc
def compute_acc(model, dataloader, device="cpu"):
    """Return the top-1 accuracy of ``model`` over ``dataloader``, in percent.

    Args:
        model: callable returning per-class scores with classes along dim 1.
            If it is an ``nn.Module`` it is switched to eval mode for the
            duration of the call and its previous mode is restored afterwards.
        dataloader: iterable yielding ``(images, labels)`` batches. Labels with
            two or more dimensions are reduced to their first column.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``100 * correct / total`` as a float, or ``nan`` when the dataloader
        yields no samples (this previously raised ZeroDivisionError).
    """
    torch.cuda.empty_cache()
    gc.collect()
    print(f"The device used is {device}.", flush=True)

    is_module = isinstance(model, torch.nn.Module)
    was_training = model.training if is_module else False
    if is_module:
        model.eval()

    correct = 0
    total = 0
    try:
        # Evaluation only: no gradients are needed for the outputs.
        with torch.no_grad():
            for images, labels in tqdm(dataloader):
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)
                # The highest-scoring class is the prediction.
                _, predicted = torch.max(outputs, 1)

                if labels.dim() >= 2:
                    labels = labels[:, 0]
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Leave the model in whatever mode the caller had it in.
        if is_module:
            model.train(was_training)

    if total == 0:
        return float("nan")
    return 100 * correct / total

In [None]:
# Evaluate the Opacus-wrapped `model` on the test loader, moving batches to the GPU.
compute_acc(model, test_loader, device="cuda")