In [None]:
import torch
import torch.nn.functional as F

from math import floor
import os
import sys

import gc
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.cuda
import torch.multiprocessing as mp
import torch.nn as nn

from dataset_utils import get_train_and_test_dataloader
from models.resnet import ResNet50, ResNet18
from models.wide_resnet import Wide_ResNet
from optimizers.optimizer_utils import use_optimizer, use_lr_scheduler
import opacus
from opacus.validators import ModuleValidator
from tqdm.notebook import tqdm
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.utils.data import DataLoader

In [None]:
# Checkpoint whose file name suggests a ResNet18 with a 100-way head and GroupNorm
# layers (presumably trained on CIFAR-100 — TODO confirm).
CHECKPOINT_PATH = "./model18_100_GN-Copy1.pt"

net = ResNet18(num_classes=100)
net.train()
# Opacus cannot compute per-sample gradients through BatchNorm, so
# ModuleValidator.fix swaps those layers out (GroupNorm by default). This has to
# happen *before* load_state_dict so the module's parameter names/shapes match the
# GroupNorm checkpoint. Note the backbone is not a frozen black box here: its last
# conv/norm parameters are unfrozen and fine-tuned in the cells below.
net = ModuleValidator.fix(net.to("cpu"))
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True if the installed torch supports it).
net.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
net = net.to("cuda")

In [None]:
# Parameter names in registration order; the freezing cell slices this list by position.
ls_names = [param_name for param_name, _ in net.named_parameters()]

In [None]:
# Fine-tune only the tail of the network: the last 2 named parameters (presumably
# the classifier weight/bias — TODO confirm) and the 6 before them (presumably the
# last conv/norm layers). Every other parameter is frozen.
classifier_names = ls_names[-2:]
conv_norm_names = ls_names[-8:-2]
classifier_name_set = set(classifier_names)
conv_norm_name_set = set(conv_norm_names)

classifier_params, conv_norm_params = [], []
for name, param in net.named_parameters():
    is_classifier = name in classifier_name_set
    is_conv_norm = name in conv_norm_name_set
    # Trainable iff the parameter belongs to one of the two selected groups.
    param.requires_grad = is_classifier or is_conv_norm
    if is_classifier:
        classifier_params.append(param)
    elif is_conv_norm:
        conv_norm_params.append(param)

In [None]:
def count_parameters(model, all_param_flag=False):
    """Count scalar elements in ``model``'s parameters.

    Args:
        model: any ``nn.Module``.
        all_param_flag: when False (default) only parameters with
            ``requires_grad=True`` are counted; when True every parameter is.

    Returns:
        Total number of elements as an int (0 for a parameter-free module).
    """
    total = 0
    for tensor in model.parameters():
        if all_param_flag or tensor.requires_grad:
            total += tensor.numel()
    return total

In [None]:
# Number of trainable (requires_grad=True) parameter elements after freezing.
count_parameters(net, all_param_flag=False)

In [None]:
# Per-group learning rates: the classifier group uses a much larger LR than the
# unfrozen conv/norm group.
CLASSIFIER_LR = 0.8
CONV_NORM_LR = 0.01

criterion = nn.CrossEntropyLoss()
net_params = [
    {"params": classifier_params, "lr": CLASSIFIER_LR},
    {"params": conv_norm_params, "lr": CONV_NORM_LR},
]
optimizer = torch.optim.SGD(net_params, momentum=0.9, weight_decay=1e-4)

In [None]:
# NOTE(review): these are CIFAR-100 statistics and the model has a 100-way head,
# but the datasets below are CIFAR-10. Presumably intentional (reuse the
# normalization the backbone was trained with) — confirm.
CIF100_MEAN = [0.5071, 0.4867, 0.4408]
CIF100_STD = [0.2675, 0.2565, 0.2761]
DATA_ROOT = '../datasets/'

batch_size = 500
# Train and test share one preprocessing pipeline (ToTensor/Normalize are stateless).
cifar_transform = Compose([ToTensor(), Normalize(CIF100_MEAN, CIF100_STD)])

train_ds = CIFAR10(DATA_ROOT, train=True, download=True, transform=cifar_transform)
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

test_ds = CIFAR10(DATA_ROOT, train=False, download=True, transform=cifar_transform)
test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# Sanity check: pull one training batch to confirm the pipeline runs end to end.
x, y = next(iter(train_loader))

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR, ConstantLR, LambdaLR
epochs = 200
# DP-SGD budget: target (epsilon, delta) and per-sample gradient clipping bound.
TARGET_EPSILON = 1
TARGET_DELTA = 1e-5
MAX_GRAD_NORM = 1

privacy_engine = opacus.PrivacyEngine()
# Opacus picks the noise level for the given budget over `epochs` epochs and hands
# back replacements for the module, optimizer and data loader.
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
    module=net,
    optimizer=optimizer,
    data_loader=train_loader,
    epochs=epochs,
    target_epsilon=TARGET_EPSILON,
    target_delta=TARGET_DELTA,
    max_grad_norm=MAX_GRAD_NORM,
)
# Stepped once per epoch by the training loop, so the cosine spans the whole run.
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

In [None]:
from statistics import mean

def train(model, criterion, optimizer, train_loader, device='cuda'):
    """Run one training epoch and print sample-weighted accuracy and loss.

    For every batch: forward, ``criterion``, backward, ``optimizer.step()``,
    ``optimizer.zero_grad()``. Metrics are accumulated per sample rather than
    averaged per batch, so variable batch sizes (e.g. Opacus Poisson sampling)
    do not bias them, and empty batches are skipped for metrics instead of
    dividing by zero.

    Args:
        model: module to train; put into train mode at the start of the epoch.
        criterion: loss function; assumed to use 'mean' reduction (the default
            for ``nn.CrossEntropyLoss``), so it is re-weighted by batch size.
        optimizer: optimizer over the trainable parameters of ``model``.
        train_loader: iterable of ``(x, y)`` batches.
        device: device each batch is moved to (default ``'cuda'``).

    Returns:
        ``(accuracy, loss)`` averaged over all samples seen this epoch, or
        ``(nan, nan)`` if the loader yielded no samples.
    """
    model.train()
    n_seen = 0
    n_correct = 0
    loss_sum = 0.0
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        n_batch = len(y)
        # Poisson-sampled DP loaders can produce empty batches; they carry no
        # metric information and would otherwise divide by zero.
        if n_batch == 0:
            continue
        n_correct += int(logits.argmax(-1).eq(y).sum())
        loss_sum += float(loss) * n_batch
        n_seen += n_batch

    if n_seen == 0:
        print("Train: no samples seen this epoch")
        return float("nan"), float("nan")

    accuracy = n_correct / n_seen
    avg_loss = loss_sum / n_seen
    print(f"Train Accuracy: {accuracy:.6f} | Train Loss: {avg_loss:.6f}")
    return accuracy, avg_loss

In [None]:
    torch.cuda.empty_cache()
    gc.collect()


In [None]:
# Train through `model`, the module returned by make_private_with_epsilon, rather
# than the raw `net`, so the Opacus wrapper drives the forward pass — consistent
# with the evaluation call compute_acc(model, ...).
for epoch in tqdm(range(epochs)):
    train(model, criterion, optimizer, train_loader)
    scheduler.step()
    print(scheduler.get_last_lr())

In [None]:
import gc
def compute_acc(model, dataloader, device="cpu"):
    """Return the top-1 accuracy of ``model`` on ``dataloader``, in percent.

    The model is switched to eval mode for the pass and restored to its previous
    train/eval mode afterwards (even if an exception is raised), so calling this
    between training epochs does not leave the model in the wrong mode.

    Args:
        model: module mapping a batch of images to per-class scores; the
            prediction is the argmax over dim 1.
        dataloader: iterable of ``(images, labels)`` batches. If ``labels`` has
            two or more dims, only its first column is used as the target.
        device: device each batch is moved to; the model is assumed to already
            live there.

    Returns:
        ``100 * correct / total``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    # Free cached GPU memory / stray objects left over from training.
    torch.cuda.empty_cache()
    gc.collect()
    print(f"The device used is {device}.", flush=True)

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        # Evaluation only: no gradients needed.
        with torch.no_grad():
            for images, labels in tqdm(dataloader):
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                # Predicted class = index of the highest score.
                _, predicted = torch.max(outputs, 1)
                if labels.dim() >= 2:
                    labels = labels[:, 0]
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                # Drop references eagerly so device memory can be reused.
                del images, labels, outputs, predicted
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("compute_acc: dataloader yielded no samples")
    return 100 * correct / total

In [None]:
# Top-1 test accuracy (%) of the DP-fine-tuned model.
compute_acc(model=model, dataloader=test_loader, device="cuda")