In [1]:
import torch
import torch.nn.functional as F

from math import floor
import os
import sys

import gc
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.cuda
import torch.multiprocessing as mp
import torch.nn as nn

from dataset_utils import get_train_and_test_dataloader
from models.resnet import ResNet50, ResNet18
from models.wide_resnet import Wide_ResNet
from optimizers.optimizer_utils import use_optimizer, use_lr_scheduler
import opacus
from opacus.validators import ModuleValidator
from tqdm.notebook import tqdm
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import Compose, Normalize, ToTensor
from torch.utils.data import DataLoader

In [2]:
# Build ResNet-18 with a 100-way head (matching the checkpoint) and make it Opacus-compatible
# *before* loading weights. ModuleValidator.fix swaps out layers DP-SGD cannot handle
# (BatchNorm -> GroupNorm), so the checkpoint is presumably one saved from an already-fixed
# model ("GN" in its filename) -- TODO confirm. Note the backbone is NOT a frozen black box:
# a later cell leaves the last residual block and the classifier trainable.
net = ResNet18(num_classes=100)
net.train()
net = ModuleValidator.fix(net.to("cpu"))
# NOTE(review): torch.load unpickles arbitrary objects -- only use it on trusted checkpoints
# (torch>=1.13 accepts weights_only=True for plain state dicts).
net.load_state_dict(torch.load("./model18_100_GN-Copy1.pt", map_location="cpu"))
net = net.to("cuda")

In [3]:
# Every parameter name in registration order; the next cell selects the trainable tail by position.
ls_names = [param_name for param_name, _ in net.named_parameters()]

In [4]:
# Keep only the tail of the network trainable: the last two parameter tensors (classifier)
# and the six before them (presumably the last residual block's convs and norms -- consistent
# with the 4,771,940 trainable-parameter count printed below). Everything else is frozen.
classifier_names = ls_names[-2:]
conv_norm_names = ls_names[-8:-2]
classifier_params = []
conv_norm_params = []
for param_name, tensor in net.named_parameters():
    in_classifier = param_name in classifier_names
    in_conv_norm = param_name in conv_norm_names
    tensor.requires_grad = in_classifier or in_conv_norm
    if in_classifier:
        classifier_params.append(tensor)
    elif in_conv_norm:
        conv_norm_params.append(tensor)

In [5]:
def count_parameters(model, all_param_flag=False):
    """Count the scalar parameters of ``model``.

    Args:
        model: a ``torch.nn.Module``.
        all_param_flag: when True, count every parameter; otherwise count only
            parameters with ``requires_grad=True``.

    Returns:
        int: total number of elements across the selected parameter tensors.
    """
    total = 0
    for tensor in model.parameters():
        if all_param_flag or tensor.requires_grad:
            total += tensor.numel()
    return total

In [6]:
n_trainable_params = count_parameters(net, all_param_flag=False)
n_trainable_params

4771940

In [7]:
criterion = nn.CrossEntropyLoss()
# Two SGD parameter groups: an aggressive LR for the classifier head, a small one for the
# unfrozen last block. Momentum and weight decay are shared by both groups.
classifier_group = {"params" : classifier_params, "lr" : 0.8}
conv_norm_group = {"params" : conv_norm_params, "lr" : 0.01}
net_params = [classifier_group, conv_norm_group]
optimizer = torch.optim.SGD(net_params, momentum=0.9, weight_decay=1e-4)

In [8]:
# NOTE: the data is CIFAR-10, but it is normalized with CIFAR-100 statistics (presumably the
# ones used to pre-train the checkpoint -- TODO confirm), and the network keeps its 100-way
# head: CIFAR-10 targets 0-9 simply index the first 10 logits of the cross-entropy.
CIF100_MEAN = [0.5071, 0.4867, 0.4408]
CIF100_STD = [0.2675, 0.2565, 0.2761]

batch_size = 500
# One shared preprocessing pipeline for both splits (no augmentation).
cifar_transform = Compose([ToTensor(), Normalize(CIF100_MEAN, CIF100_STD)])

# make_private_with_epsilon (next cell) replaces train_loader with an Opacus loader; with its
# default Poisson sampling, batch_size here only sets the *expected* batch size.
train_ds = CIFAR10('../datasets/', train=True, download=True, transform=cifar_transform)
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

test_ds = CIFAR10('../datasets/', train=False, download=True, transform=cifar_transform)
test_loader = DataLoader(
    test_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# Sanity-check one training batch: 3x32x32 normalized images.
x, y = next(iter(train_loader))
assert tuple(x.shape[1:]) == (3, 32, 32), x.shape

Files already downloaded and verified
Files already downloaded and verified


In [9]:
from torch.optim.lr_scheduler import CosineAnnealingLR, ConstantLR, LambdaLR
# Must match the number of epochs actually trained below: the privacy budget is calibrated for it.
epochs = 200

# Wrap module, optimizer and loader for DP-SGD. Opacus picks the noise multiplier so that
# `epochs` epochs satisfy (epsilon=1, delta=1e-5)-DP with per-sample gradients clipped to norm 1.
privacy_engine = opacus.PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
    module=net,
    optimizer=optimizer,
    data_loader=train_loader,
    epochs=epochs,
    target_epsilon=1,
    target_delta=1e-5,
    max_grad_norm=1,
)
# Stepped once per epoch, so every group's LR follows a cosine from its initial value down to 0.
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

  z = np.log((np.exp(t) + q - 1) / q)


In [10]:
from statistics import mean

def train(model, criterion, optimizer, train_loader, device="cuda"):
    """Run one training epoch and print the epoch's mean batch accuracy and loss.

    Args:
        model: module to train; switched to training mode first (Opacus' grad-sample
            hooks are documented to act only on modules in training mode).
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: optimizer stepped once per batch (in this notebook, an Opacus DP optimizer).
        train_loader: iterable of ``(inputs, targets)`` batches.
        device: device each batch is moved to (default ``"cuda"``, as before).

    Returns:
        None; metrics are printed.
    """
    model.train()
    accs = []
    losses = []
    for x, y in tqdm(train_loader):
        x = x.to(device)
        y = y.to(device)

        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()

        # Step on every batch, empty ones included, exactly as before; the DP optimizer's
        # step is presumably what the privacy accountant tracks.
        optimizer.step()
        optimizer.zero_grad()

        # Poisson sampling can yield an empty batch: its accuracy would divide by zero and a
        # mean-reduced loss is NaN, so keep it out of the epoch averages.
        if len(y) == 0:
            continue

        preds = logits.argmax(-1)
        n_correct = float(preds.eq(y).sum())
        accs.append(n_correct / len(y))
        losses.append(loss.item())

    if not accs:
        # statistics.mean([]) would raise; report instead of crashing the epoch loop.
        print("Train Accuracy: n/a | Train Loss: n/a (no non-empty batches)")
        return

    # Separator added: previously the two metrics were printed glued together.
    print(
        f"Train Accuracy: {mean(accs):.6f} | "
        f"Train Loss: {mean(losses):.6f}"
    )

In [11]:
    # Free memory left over from earlier cells: collect unreachable Python objects first so
    # the CUDA tensors they hold become free, then return the cached blocks to the driver.
    # (This indented cell body relies on IPython stripping the common leading indent.)
    gc.collect()
    torch.cuda.empty_cache()


9

In [12]:
# Train through the Opacus-wrapped `model` returned by make_private_with_epsilon rather than
# the raw `net`, so training and evaluation use the same module the privacy engine prepared.
# One scheduler step per epoch, matching T_max=epochs.
for epoch in tqdm(range(epochs)):
    train(model, criterion, optimizer, train_loader)
    scheduler.step()
    print(scheduler.get_last_lr())

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]



Train Accuracy: 0.375102Train Loss: 3.019791
[0.7999506529926643, 0.009999383162408303]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.467345Train Loss: 3.131440
[0.7998026241462927, 0.009997532801828657]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.511372Train Loss: 2.834294
[0.7995559499847881, 0.00999444937480985]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.533197Train Loss: 2.664058
[0.7992106913713087, 0.009990133642141357]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.552341Train Loss: 2.557650
[0.7987669334932513, 0.009984586668665639]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.568353Train Loss: 2.460336
[0.7982247858412321, 0.009977809823015398]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.582998Train Loss: 2.420108
[0.7975843821820721, 0.009969804777275897]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.596876Train Loss: 2.314310
[0.7968458805257913, 0.009960573506572387]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.602663Train Loss: 2.253621
[0.7960094630866231, 0.009950118288582784]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.609498Train Loss: 2.267264
[0.7950753362380553, 0.009938441702975684]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.618061Train Loss: 2.296692
[0.7940437304619097, 0.009925546630773864]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.618020Train Loss: 2.282249
[0.7929149002914756, 0.009911436253643439]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.620318Train Loss: 2.271686
[0.7916891242487064, 0.009896114053108824]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.624425Train Loss: 2.146789
[0.7903667047754991, 0.009879583809693733]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.627106Train Loss: 2.139448
[0.7889479681590708, 0.009861849601988378]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.640824Train Loss: 2.068048
[0.7874332644514526, 0.00984291580564315]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.631974Train Loss: 2.048206
[0.7858229673831194, 0.009822787092288985]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.627910Train Loss: 2.031950
[0.7841174742707774, 0.00980146842838471]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.632542Train Loss: 1.975243
[0.7823172059193322, 0.009778965073991645]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.636315Train Loss: 1.929669
[0.7804226065180616, 0.009755282581475762]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.641317Train Loss: 1.926255
[0.7784341435310183, 0.00973042679413772]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.639332Train Loss: 1.913038
[0.7763523075816903, 0.00970440384477112]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.632789Train Loss: 2.013118
[0.774177612331947, 0.00967722015414933]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.636032Train Loss: 2.011101
[0.7719105943553006, 0.00964888242944125]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.640615Train Loss: 1.961405
[0.7695518130045148, 0.009619397662556426]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.636101Train Loss: 1.954764
[0.7671018502735925, 0.009588773128419898]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.628327Train Loss: 1.960575
[0.7645613106541781, 0.009557016383177219]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.641948Train Loss: 1.888245
[0.7619308209864079, 0.009524135262330091]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.645778Train Loss: 1.944286
[0.7592110303042463, 0.009490137878803071]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.644404Train Loss: 1.944059
[0.7564026096753472, 0.009455032620941833]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.639401Train Loss: 1.904171
[0.7535062520354776, 0.009418828150443462]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.648028Train Loss: 1.838453
[0.7505226720175455, 0.009381533400219312]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.647131Train Loss: 1.878483
[0.7474526057752766, 0.00934315757219095]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.647319Train Loss: 1.874650
[0.7442968108015776, 0.009303710135019712]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.645604Train Loss: 1.877968
[0.7410560657416371, 0.009263200821770455]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.643900Train Loss: 1.904234
[0.7377311702008063, 0.009221639627510068]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.650551Train Loss: 1.886994
[0.7343229445473084, 0.009179036806841345]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.654948Train Loss: 1.907862
[0.730832229709825, 0.009135402871372801]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.657630Train Loss: 1.864629
[0.7272598869700097, 0.00909074858712511]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.656664Train Loss: 1.852234
[0.7236067977499793, 0.00904508497187473]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.653840Train Loss: 1.827269
[0.7198738633948366, 0.008998423292435446]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.654354Train Loss: 1.858467
[0.7160620049502764, 0.008950775061878443]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.648745Train Loss: 1.885439
[0.7121721629353323, 0.008902152036691641]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.647577Train Loss: 1.872365
[0.7082052971103161, 0.008852566213878938]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.648522Train Loss: 1.869114
[0.7041623862400128, 0.008802029828000147]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.651860Train Loss: 1.904219
[0.7000444278521842, 0.00875055534815229]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.653043Train Loss: 1.875490
[0.6958524379914441, 0.008698155474893039]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.649848Train Loss: 1.903329
[0.6915874509685649, 0.008644843137107049]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.656059Train Loss: 1.818982
[0.6872505191052758, 0.008590631488815935]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.657963Train Loss: 1.818585
[0.6828427124746193, 0.008535533905932728]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.659461Train Loss: 1.753602
[0.678365118636926, 0.008479563982961562]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.658932Train Loss: 1.743159
[0.6738188423714757, 0.008422735529643434]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.654307Train Loss: 1.745789
[0.6692050054039095, 0.008365062567548856]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.653027Train Loss: 1.756507
[0.6645247461294609, 0.00830655932661825]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.655042Train Loss: 1.777245
[0.6597792193320736, 0.008247240241650909]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.655265Train Loss: 1.748559
[0.654969595899476, 0.00818711994874344]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.657126Train Loss: 1.748065
[0.6500970625342822, 0.008126213281678518]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.655188Train Loss: 1.776730
[0.6451628214611907, 0.008064535268264873]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.656757Train Loss: 1.824020
[0.6401680901303537, 0.008002101126629411]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.658471Train Loss: 1.823608
[0.6351141009169894, 0.007938926261462358]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.657384Train Loss: 1.832132
[0.6300021008173116, 0.007875026260216385]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.656400Train Loss: 1.805167
[0.6248333511408525, 0.007810416889260647]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.660893Train Loss: 1.727561
[0.6196091271992529, 0.007745114089990651]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.666465Train Loss: 1.699495
[0.6143307179915988, 0.007679133974894975]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.663261Train Loss: 1.725843
[0.6089994258863797, 0.007612492823579737]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.671362Train Loss: 1.700265
[0.6036165663001487, 0.00754520707875185]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.672515Train Loss: 1.648443
[0.5981834673729632, 0.007477293342162032]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.678278Train Loss: 1.641454
[0.5927014696406863, 0.00740876837050857]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.675989Train Loss: 1.678357
[0.5871719257042296, 0.007339649071302861]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.674522Train Loss: 1.718185
[0.5815961998958189, 0.007269952498697728]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.663908Train Loss: 1.762991
[0.5759756679423662, 0.007199695849279568]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.670445Train Loss: 1.723121
[0.5703117166260293, 0.0071288964578253575]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.680933Train Loss: 1.637066
[0.5646057434420438, 0.007057571793025539]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.685067Train Loss: 1.660777
[0.5588591562539125, 0.006985739453173897]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.678835Train Loss: 1.669212
[0.5530733729460362, 0.006913417161825444]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.685435Train Loss: 1.647946
[0.5472498210738715, 0.0068406227634233855]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.677796Train Loss: 1.686162
[0.5413899375117032, 0.006767374218896281]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.678786Train Loss: 1.647491
[0.5354951680981169, 0.006693689601226453]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.674620Train Loss: 1.656794
[0.5295669672792601, 0.006619587090990743]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.678412Train Loss: 1.676950
[0.5236067977499793, 0.006545084971874733]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.674100Train Loss: 1.674191
[0.517616130092922, 0.006470201626161516]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.676326Train Loss: 1.651096
[0.5115964424156921, 0.006394955530196143]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.676453Train Loss: 1.660818
[0.5055492199861494, 0.00631936524982686]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.676486Train Loss: 1.623062
[0.49947595486594215, 0.00624344943582427]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.678250Train Loss: 1.578089
[0.49337814554236253, 0.006167226819279525]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.675835Train Loss: 1.582690
[0.48725729655861744, 0.006090716206982711]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.671157Train Loss: 1.630506
[0.48111491814260543, 0.006013936476782561]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.668398Train Loss: 1.601556
[0.4749525258342903, 0.005936906572928622]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.663718Train Loss: 1.597226
[0.46877164011176414, 0.005859645501397045]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.662096Train Loss: 1.606935
[0.4625737860160927, 0.005782172325201152]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.667405Train Loss: 1.567294
[0.4563604927750335, 0.005704506159687911]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.672725Train Loss: 1.504022
[0.4501332934257221, 0.0056266661678215195]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.673157Train Loss: 1.551340
[0.44389372443641845, 0.005548671555455224]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.674943Train Loss: 1.514367
[0.437643325327406, 0.0054705415665925695]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.671924Train Loss: 1.536235
[0.4313836382911383, 0.005392295478639223]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.679667Train Loss: 1.528871
[0.42511620781172565, 0.005313952597646565]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.684077Train Loss: 1.511603
[0.4188425802838573, 0.00523553225354821]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.679490Train Loss: 1.525840
[0.41256430363125157, 0.005157053795390639]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.677778Train Loss: 1.527482
[0.4062829269247285, 0.005078536586559101]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.683144Train Loss: 1.514050
[0.4000000000000002, 0.004999999999999997]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.684094Train Loss: 1.507929
[0.393717073075272, 0.004921463413440894]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.683539Train Loss: 1.513275
[0.38743569636874886, 0.004842946204609355]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.679106Train Loss: 1.517218
[0.38115741971614314, 0.004764467746451784]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.683992Train Loss: 1.477098
[0.37488379218827483, 0.00468604740235343]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.677824Train Loss: 1.501009
[0.36861636170886225, 0.004607704521360773]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.683532Train Loss: 1.468469
[0.36235667467259447, 0.0045294584334074255]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.685297Train Loss: 1.465819
[0.3561062755635821, 0.004451328444544771]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.685052Train Loss: 1.489238
[0.34986670657427854, 0.004373333832178476]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.688040Train Loss: 1.441992
[0.3436395072249672, 0.004295493840312086]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.693960Train Loss: 1.456663
[0.3374262139839079, 0.004217827674798844]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.696201Train Loss: 1.430213
[0.33122835988823635, 0.00414035449860295]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.697433Train Loss: 1.428277
[0.3250474741657104, 0.004063093427071376]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.695520Train Loss: 1.389537
[0.3188850818573953, 0.003986063523217437]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.704435Train Loss: 1.353109
[0.31274270344138333, 0.003909283793017288]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.705132Train Loss: 1.334565
[0.306621854457638, 0.0038327731807204714]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.707672Train Loss: 1.339028
[0.3005240451340583, 0.003756550564175725]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.705135Train Loss: 1.352269
[0.29445078001385105, 0.0036806347501731345]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.700599Train Loss: 1.378356
[0.2884035575843085, 0.003605044469803853]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.701419Train Loss: 1.359336
[0.28238386990707864, 0.00352979837383848]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.698990Train Loss: 1.362341
[0.27639320225002134, 0.0034549150281252636]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.701728Train Loss: 1.332356
[0.2704330327207404, 0.0033804129090092517]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.713312Train Loss: 1.287294
[0.2645048319018836, 0.003306310398773542]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.707573Train Loss: 1.312880
[0.2586100624882974, 0.003232625781103714]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.710909Train Loss: 1.304599
[0.25275017892612905, 0.0031593772365766104]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.707686Train Loss: 1.292252
[0.24692662705396434, 0.0030865828381745515]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.706332Train Loss: 1.291485
[0.24114084374608785, 0.0030142605468260956]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.708439Train Loss: 1.283018
[0.23539425655795665, 0.0029424282069744553]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.708846Train Loss: 1.265664
[0.2296882833739711, 0.002871103542174636]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.709298Train Loss: 1.268330
[0.22402433205763408, 0.0028003041507204235]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.711065Train Loss: 1.263862
[0.21840380010418145, 0.002730047501302266]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.709634Train Loss: 1.284842
[0.21282807429577086, 0.0026603509286971336]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.714801Train Loss: 1.270168
[0.20729853035931395, 0.0025912316294914224]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.715618Train Loss: 1.266283
[0.20181653262703705, 0.0025227066578379616]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.711024Train Loss: 1.277373
[0.19638343369985156, 0.0024547929212481432]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.711857Train Loss: 1.252408
[0.19100057411362056, 0.0023875071764202557]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.717703Train Loss: 1.229288
[0.18566928200840133, 0.0023208660251050153]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.713137Train Loss: 1.237237
[0.1803908728007473, 0.0022548859100093403]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.717338Train Loss: 1.229772
[0.1751666488591478, 0.0021895831107393462]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.717330Train Loss: 1.252899
[0.16999789918268865, 0.002124973739783607]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.716747Train Loss: 1.236941
[0.16488589908301088, 0.0020610737385376348]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.716222Train Loss: 1.215247
[0.15983190986964652, 0.0019978988733705804]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.714764Train Loss: 1.211466
[0.15483717853880943, 0.001935464731735117]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.720174Train Loss: 1.181490
[0.14990293746571798, 0.0018737867183214738]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.720096Train Loss: 1.185255
[0.14503040410052415, 0.001812880051256551]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.719484Train Loss: 1.190679
[0.14022078066792662, 0.001752759758349082]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.719553Train Loss: 1.151855
[0.13547525387053935, 0.0016934406733817413]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.718377Train Loss: 1.180036
[0.13079499459609067, 0.0016349374324511328]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.717632Train Loss: 1.169072
[0.12618115762852455, 0.0015772644703565563]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.722629Train Loss: 1.144514
[0.12163488136307432, 0.0015204360170384284]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.721811Train Loss: 1.168958
[0.11715728752538106, 0.0014644660940672626]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.725289Train Loss: 1.135610
[0.11274948089472456, 0.0014093685111840565]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.723375Train Loss: 1.149094
[0.1084125490314355, 0.0013551568628929432]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.725564Train Loss: 1.149952
[0.10414756200855613, 0.001301844525106951]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.728090Train Loss: 1.169198
[0.09995557214781621, 0.0012494446518477018]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.723716Train Loss: 1.164325
[0.09583761375998767, 0.0011979701719998452]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.724185Train Loss: 1.157842
[0.09179470288968439, 0.001147433786121054]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.724986Train Loss: 1.152515
[0.08782783706466821, 0.001097847963308352]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.727825Train Loss: 1.138942
[0.08393799504972387, 0.0010492249381215478]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.727186Train Loss: 1.126718
[0.0801261366051638, 0.001001576707564547]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.726869Train Loss: 1.119674
[0.0763932022500211, 0.0009549150281252633]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.729009Train Loss: 1.121816
[0.07274011302999071, 0.0009092514128748835]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.726462Train Loss: 1.136235
[0.06916777029017535, 0.0008645971286271914]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.730971Train Loss: 1.100502
[0.06567705545269201, 0.0008209631931586497]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.727666Train Loss: 1.108088
[0.062268829799194084, 0.0007783603724899257]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.730346Train Loss: 1.111822
[0.058943934258363244, 0.0007367991782295402]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.725530Train Loss: 1.144006
[0.05570318919842252, 0.0006962898649802812]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.730706Train Loss: 1.107286
[0.0525473942247235, 0.0006568424278090434]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.727514Train Loss: 1.133476
[0.04947732798245459, 0.0006184665997806821]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.728595Train Loss: 1.115978
[0.04649374796452264, 0.0005811718495565327]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.730916Train Loss: 1.108903
[0.04359739032465291, 0.0005449673790581611]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.728886Train Loss: 1.110152
[0.04078896969575381, 0.0005098621211969223]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.728141Train Loss: 1.111953
[0.03806917901359228, 0.0004758647376699032]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.729522Train Loss: 1.095430
[0.035438689345822, 0.00044298361682277464]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.731123Train Loss: 1.099616
[0.03289814972640766, 0.0004112268715800954]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.735221Train Loss: 1.083809
[0.030448186995485415, 0.0003806023374435674]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.727379Train Loss: 1.119147
[0.02808940564469948, 0.00035111757055874323]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.729732Train Loss: 1.102896
[0.025822387668053066, 0.0003227798458506631]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.728606Train Loss: 1.120888
[0.023647692418309833, 0.0002955961552288727]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.730269Train Loss: 1.094629
[0.021565856468981893, 0.00026957320586227347]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.730691Train Loss: 1.105899
[0.0195773934819386, 0.0002447174185242323]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.734208Train Loss: 1.085273
[0.017682794080667964, 0.00022103492600834934]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.730864Train Loss: 1.093927
[0.01588252572922283, 0.0001985315716152852]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.733772Train Loss: 1.080487
[0.014177032616880823, 0.00017721290771101014]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.733796Train Loss: 1.089434
[0.012566735548547623, 0.00015708419435684514]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.732197Train Loss: 1.100840
[0.011052031840929386, 0.0001381503980116172]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.731467Train Loss: 1.088499
[0.009633295224501078, 0.00012041619030626336]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.731325Train Loss: 1.091568
[0.008310875751293662, 0.00010388594689117068]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.732628Train Loss: 1.080541
[0.007085099708524516, 8.856374635655638e-05]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.732473Train Loss: 1.082452
[0.005956269538090456, 7.445336922613064e-05]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.731691Train Loss: 1.088732
[0.0049246637619449396, 6.155829702431169e-05]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.729848Train Loss: 1.097201
[0.003990536913376988, 4.988171141721231e-05]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.731025Train Loss: 1.105375
[0.0031541194742088965, 3.9426493427611173e-05]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.734162Train Loss: 1.082777
[0.0024156178179281637, 3.0195222724102018e-05]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.730915Train Loss: 1.097903
[0.0017752141587680029, 2.2190176984600013e-05]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.735434Train Loss: 1.082737
[0.0012330665067488153, 1.5413331334360176e-05]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.731899Train Loss: 1.093879
[0.000789308628691377, 9.866357858642203e-06]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.731653Train Loss: 1.089569
[0.0004440500152120389, 5.5506251901504804e-06]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.734241Train Loss: 1.094644
[0.00019737585370736013, 2.467198171341999e-06]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.735884Train Loss: 1.081327
[4.934700733576495e-05, 6.168375916970613e-07]


  0%|          | 0/100 [00:00<?, ?it/s]

Train Accuracy: 0.731063Train Loss: 1.097824
[0.0, 0.0]


In [13]:
import gc
def compute_acc(model, dataloader, device="cpu"):
    """Return top-1 accuracy (percent) of ``model`` over ``dataloader``.

    The model is evaluated in eval mode under ``torch.no_grad()``; its previous
    train/eval mode is restored afterwards, even if evaluation raises.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        dataloader: iterable of ``(images, labels)``; labels may be 1-D or 2-D
            (for 2-D labels only the first column is used).
        device: device each batch is moved to.

    Returns:
        float: ``100 * correct / total``.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    # Collect unreachable Python objects first so their CUDA tensors are freed,
    # then release the now-unused cached blocks.
    gc.collect()
    torch.cuda.empty_cache()
    print(f"The device used is {device}.", flush=True)
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        # no gradients are needed for evaluation
        with torch.no_grad():
            for images, labels in tqdm(dataloader):
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                # predicted class = index of the largest logit
                predicted = outputs.argmax(dim=1)
                if labels.dim() >= 2:
                    labels = labels[:, 0]
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                # drop the batch before loading the next one to cap peak memory
                del images, labels, outputs
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("compute_acc: dataloader yielded no samples; accuracy is undefined")
    return 100 * correct / total

In [14]:
test_accuracy = compute_acc(model, test_loader, device="cuda")
test_accuracy

The device used is cuda.


  0%|          | 0/20 [00:00<?, ?it/s]

73.15