# In [None]:
import os
import timm
import torch
import opacus
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
import easydict
from prv_accountant.dpsgd import find_noise_multiplier
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader
from opacus.utils.batch_memory_manager import wrap_data_loader

criterion = nn.CrossEntropyLoss()
device = "cuda:0"

# Number of output classes for every supported torchvision dataset name.
DATASET_TO_CLASSES = dict(
    CIFAR10=10,
    CIFAR100=100,
    FashionMNIST=10,
    STL10=10,
)

# Square side length images are bicubically resized to before being fed to
# each supported timm backbone (consumed by get_ds -> get_features).
ARCH_TO_INTERP_SIZE = dict(
    beit_large_patch16_512=512,
    convnext_xlarge_384_in22ft1k=384,
    beitv2_large_patch16_224_in22k=224,
)

# Experiment configuration, accessible both as attributes and as dict keys.
args = easydict.EasyDict(
    {
        "dataset": "STL10",
        "arch": "beit_large_patch16_512",
        "lr": 0.01,
        "epochs": 1,
        "epsilon": 0.1,
        "dataset_path": "datasets/",
    }
)

# Calibrate the Gaussian noise multiplier for the target (epsilon, delta).
# sampling_probability=1.0 means every example is in every step (full-batch
# DP-SGD), so the number of steps equals the number of epochs.
args.sigma = find_noise_multiplier(
    sampling_probability=1.0,
    num_steps=args.epochs,
    target_epsilon=args.epsilon,
    target_delta=1e-5,
    eps_error=0.001,
    mu_max=5000,
)
args.num_classes = DATASET_TO_CLASSES[args.dataset]

def get_features(f, images, interp_size=224, batch=64, device="cuda"):
    """Embed ``images`` with the frozen feature extractor ``f`` in mini-batches.

    Each chunk of ``batch`` images is moved to ``device``, bicubically resized
    to ``interp_size`` x ``interp_size`` and passed through ``f`` without
    gradient tracking. Outputs are collected on the CPU and concatenated.

    Args:
        f: callable mapping an image batch to a feature batch.
        images: image tensor; assumed (N, C, H, W) floats, as produced by
            get_ds -- TODO confirm for any new dataset loader.
        interp_size: target spatial side length for the resize.
        batch: number of images processed per forward pass.
        device: device the forward pass runs on. The default ``"cuda"`` is
            the current CUDA device, the same one ``Tensor.cuda()`` uses.

    Returns:
        CPU tensor with one feature row per input image.
    """
    features = []
    # One no_grad context for the whole pass; the outputs then carry no
    # autograd history, so no per-batch .detach() is needed.
    with torch.no_grad():
        for img in tqdm(images.split(batch)):
            img = F.interpolate(img.to(device), size=(interp_size, interp_size), mode="bicubic")
            features.append(f(img).cpu())
    return torch.cat(features)

def _load_split(args, train):
    """Return one split of ``args.dataset`` as ``(images, labels)`` tensors.

    Images are scaled from uint8 to floats in [0, 1] and laid out channels-first
    with 3 channels. Labels are returned as int64 class indices.
    """
    dataset_cls = getattr(datasets, args.dataset)
    if args.dataset == "STL10":
        # STL10 selects the split by name; its ``data`` is used channels-first as-is.
        ds = dataset_cls(args.dataset_path, split="train" if train else "test", download=True)
        images = torch.tensor(ds.data) / 255.0
        raw_labels = ds.labels
    elif args.dataset == "FashionMNIST":
        ds = dataset_cls(args.dataset_path, train=train, download=True)
        # ``data`` is already a tensor, so convert it directly instead of
        # re-wrapping it in torch.tensor() (which copies and warns).
        # The single grayscale channel is repeated to 3 for the RGB backbone.
        images = ds.data.unsqueeze(1).repeat(1, 3, 1, 1).float() / 255.0
        raw_labels = ds.targets
    else:
        ds = dataset_cls(args.dataset_path, train=train, download=True)
        # CIFAR-style datasets store channels-last arrays; move channels first.
        images = torch.tensor(ds.data.transpose(0, 3, 1, 2)) / 255.0
        raw_labels = ds.targets
    # Normalise the label dtype across datasets. Some loaders return uint8
    # label arrays (torchvision's STL10 does), while CrossEntropyLoss
    # conventionally takes int64 class indices.
    labels = torch.as_tensor(raw_labels, dtype=torch.long)
    return images, labels


def get_ds(args):
    """Load ``args.dataset``, embed it with a frozen timm backbone, and build loaders.

    Side effect: sets ``args.batch_size`` to the size of the training set, so
    training is full-batch (one step per epoch). This matches the
    ``sampling_probability=1.0`` used to calibrate ``args.sigma``.

    Returns:
        ``(train_loader, test_loader, feature_dim, num_test_samples)``.
    """
    images_train, labels_train = _load_split(args, train=True)
    images_test, labels_test = _load_split(args, train=False)

    interp_size = ARCH_TO_INTERP_SIZE[args.arch]
    feature_extractor = nn.DataParallel(timm.create_model(args.arch, num_classes=0, pretrained=True)).eval().cuda()
    features_train = get_features(feature_extractor, images_train, interp_size=interp_size)
    features_test = get_features(feature_extractor, images_test, interp_size=interp_size)

    loader_kwargs = {"num_workers": 4, "pin_memory": True}
    ds_train = TensorDataset(features_train, labels_train)
    args.batch_size = len(ds_train)
    train_loader = DataLoader(ds_train, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    ds_test = TensorDataset(features_test, labels_test)
    test_loader = DataLoader(ds_test, batch_size=len(ds_test), shuffle=False, **loader_kwargs)
    return train_loader, test_loader, features_test.shape[-1], len(labels_test)

def train(model, train_loader, optimizer):
    """Run a single training epoch of ``model`` over ``train_loader``.

    For each batch the inputs and labels are moved to the module-level
    ``device``. Gradients are cleared, the module-level ``criterion`` is
    back-propagated, and ``optimizer`` takes a step.
    """
    model.train()
    progress = tqdm(train_loader)
    for inputs, labels in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

def test(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader``, in percent.

    The model is switched to eval mode once and evaluated without gradient
    tracking; batches are moved to the module-level ``device``. The
    denominator is the number of samples actually iterated, so the result
    does not rely on any externally maintained dataset length.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.numel()
    # An empty loader yields 0 rather than a ZeroDivisionError.
    return 100.0 * correct / total if total else 0.0

train_loader, test_loader, num_features, len_test = get_ds(args)

# Linear probe on the frozen features. It is placed on the same ``device`` that
# train()/test() move batches to, and zero-initialised via nn.init rather than
# by mutating ``.data`` directly.
model = nn.Linear(num_features, args.num_classes, bias=False).to(device)
nn.init.zeros_(model.weight)
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
privacy_engine = opacus.PrivacyEngine(accountant="gdp")
model, optimizer, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=args.sigma,
    max_grad_norm=1,
)
# Split each logical batch into physical chunks of at most 5000 samples so
# per-sample gradients fit in memory (see Opacus batch_memory_manager).
train_loader = wrap_data_loader(data_loader=train_loader, max_batch_size=5000, optimizer=optimizer)
for epoch in range(1, args.epochs + 1):
    train(model, train_loader, optimizer)
    print(f"Epoch {epoch} Test Accuracy {test(model, test_loader):.2f}")
# delta must match the target_delta used when calibrating args.sigma.
print(f"Epsilon {privacy_engine.accountant.get_epsilon(delta=1e-5):.3f}")