In [32]:
# -*- coding: utf-8 -*
import argparse
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import models
import os
import multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
from tqdm import tqdm
from tensors_dataset_path import TensorDatasetPath
from tensors_dataset_img import TensorDatasetImg
import random
import sys
from utils import *
from models import *
from data_transform import *

In [62]:
# Setup reprouducible environment

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(False)

setup_seed(20)

In [54]:
!python knowledge_distill_dataset.py
!python data_compression.py

model_name:  resnets
checkpoint:  resnets_clean
Files already downloaded and verified
[2023-06-22 15:11:04.510 pytorch-1-12-gpu-py-ml-g4dn-xlarge-60d67c819515f85736f0b2ea671f:339 INFO utils.py:28] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2023-06-22 15:11:04.981 pytorch-1-12-gpu-py-ml-g4dn-xlarge-60d67c819515f85736f0b2ea671f:339 INFO profiler_config_parser.py:111] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.
^C
Traceback (most recent call last):
  File "data_compression.py", line 27, in <module>
    dataset = torch.load("./dataset/distill_" + distill_data_name)
  File "/opt/conda/lib/python3.8/site-packages/torch/serialization.py", line 712, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/opt/conda/lib/python3.8/site-packages/torch/serialization.py", line 1049, in _load
    result = unpickler.load()
  File "/opt/conda/lib/python3.8/site-packages/PIL/Image.py", line 718, in __setstate__
    def __s

In [64]:
params = read_config()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

In [63]:
model_name = params["model"]
model_set = {
    "resnets": ResNetS(nclasses=10),
    "vgg_face": VGG_16(),
    "gtsrb": gtsrb(),
    "resnet50": models.resnet50(),
}
print("model_name: ", model_name)
model = model_set[model_name]

ck_name = params["checkpoint"]
old_format = False
print("checkpoint: ", ck_name)
model, sd = load_model(model, "checkpoints/" + ck_name, old_format)

model_name:  resnets
checkpoint:  resnets_clean


In [65]:
if torch.cuda.is_available():
    model = model.cuda()
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
model.to(device)

for name, value in model.named_parameters():
    if name == "layer4.0.conv1.weight":
        break
    value.requires_grad = False

model.eval()

print('model loaded')

model loaded


In [67]:
# Load training dataset

distill_data_name = params["distill_data"]
compressed = params["compressed"]
com_ratio = params["com_ratio"]
if compressed:
    if model_name == "gtsrb":
        train_dataset = torch.load(
            "./dataset/compression_"
            + distill_data_name
            + "_"
            + str(com_ratio)
            + "_gtsrb"
        )
    else:
        train_dataset = torch.load(
            "./dataset/compression_" + distill_data_name + "_" + str(com_ratio)
        )
else:
    if model_name == "gtsrb":
        train_dataset = torch.load("./dataset/distill_" + distill_data_name + "_gtsrb")
    else:
        train_dataset = torch.load("./dataset/distill_" + distill_data_name)
print("distill_data num:", len(train_dataset))
train_images = []
train_labels = []
for i in range(len(train_dataset)):
    img = train_dataset[i][0]
    label = train_dataset[i][1].cpu()
    train_images.append(img)
    train_labels.append(label)
train_images = np.array(train_images)
train_labels = np.array(train_labels)

# train_images = np.load('train_images.npy', allow_pickle = True)
# train_labels = np.load('train_images.npy', allow_pickle = True)
print("load train data finished")

print(type(train_images), type(train_images[0]))
print(type(train_labels), type(train_labels[0]))

distill_data num: 20000
load train data finished
<class 'numpy.ndarray'> <class 'PIL.Image.Image'>
<class 'numpy.ndarray'> <class 'torch.Tensor'>


In [68]:
# Load test dataset

dataset_name = params["data"]

if dataset_name == "VGGFace":
    test_images, test_labels = get_dataset_vggface("./dataset/VGGFace/", max_num=10)
elif dataset_name == "tiny-imagenet-200":
    testset = torchvision.datasets.ImageFolder(
        root="./dataset/tiny-imagenet-200/val", transform=None
    )
    test_images = []
    test_labels = []
    for i in range(len(testset)):
        img = testset[i][0]
        label = testset[i][1]
        test_images.append(img)
        test_labels.append(label)
    test_images = np.array(test_images)
    test_labels = np.array(test_labels)
elif dataset_name == "cifar10":
    _dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True)
    test_images = [_dataset[i][0] for i in range(len(_dataset))]
    test_labels = _dataset.targets
else:
    test_images, test_labels = get_dataset("./dataset/" + dataset_name + "/test/")


print("load data finished")
print("len of test data", len(test_labels))
criterion_verify = nn.CrossEntropyLoss()

Files already downloaded and verified
load data finished
len of test data 10000


In [69]:
batch_size = 320

if model_name == "resnets":
    train_loader = DataLoader(
        TensorDatasetImg(train_images, train_labels, transform=cifar100_transforms),
        shuffle=True,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=True,
    )

    test_loader = DataLoader(
        TensorDatasetImg(
            test_images,
            test_labels,
            transform=cifar10_transforms_test,
            mode="test",
            test_poisoned="False",
            transform_name="cifar10_transforms_test",
        ),
        shuffle=False,
        batch_size=64,
        num_workers=0,
        pin_memory=True,
    )

    test_loader_poison = DataLoader(
        TensorDatasetImg(
            test_images,
            test_labels,
            transform=cifar10_transforms_test,
            mode="test",
            test_poisoned="True",
            transform_name="cifar10_transforms_test",
        ),
        shuffle=False,
        batch_size=64,
        num_workers=0,
        pin_memory=True,
    )
elif model_name == "vgg_face":
    train_loader = DataLoader(
        TensorDatasetImg(train_images, train_labels, transform=LFW_transforms),
        shuffle=True,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
    )

    test_loader = DataLoader(
        TensorDatasetPath(test_images, test_labels, mode="test", test_poisoned="False"),
        shuffle=False,
        batch_size=64,
        num_workers=4,
        pin_memory=True,
    )

    test_loader_poison = DataLoader(
        TensorDatasetPath(test_images, test_labels, mode="test", test_poisoned="True"),
        shuffle=False,
        batch_size=64,
        num_workers=4,
        pin_memory=True,
    )
elif model_name == "gtsrb":
    train_loader = DataLoader(
        TensorDatasetImg(train_images, train_labels, transform=cifar100_transforms),
        shuffle=True,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
    )

    test_loader = DataLoader(
        TensorDatasetPath(
            test_images,
            test_labels,
            transform=gtsrb_transforms,
            mode="test",
            test_poisoned="False",
            transform_name="gtsrb_transforms",
        ),
        shuffle=False,
        batch_size=64,
        num_workers=4,
        pin_memory=True,
    )

    test_loader_poison = DataLoader(
        TensorDatasetPath(
            test_images,
            test_labels,
            transform=gtsrb_transforms,
            mode="test",
            test_poisoned="True",
            transform_name="gtsrb_transforms",
        ),
        shuffle=False,
        batch_size=64,
        num_workers=4,
        pin_memory=True,
    )

elif model_name == "resnet50":
    train_loader = DataLoader(
        TensorDatasetImg(train_images, train_labels, transform=imagenet_transforms),
        shuffle=True,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
    )

    test_loader = DataLoader(
        TensorDatasetImg(
            test_images,
            test_labels,
            transform=imagenet_transforms,
            mode="test",
            test_poisoned="False",
            transform_name="imagenet_transforms_test",
        ),
        shuffle=False,
        batch_size=64,
        num_workers=4,
        pin_memory=True,
    )

    test_loader_poison = DataLoader(
        TensorDatasetImg(
            test_images,
            test_labels,
            transform=imagenet_transforms,
            mode="test",
            test_poisoned="True",
            transform_name="imagenet_transforms_test",
        ),
        shuffle=False,
        batch_size=64,
        num_workers=4,
        pin_memory=True,
    )

print("poison data finished")

poison data finished


In [70]:
lr = params["lr"]
epochs = params["epochs"]

# optimizer_poison = optim.SGD(model.parameters(), lr=lr)
# scheduler_poison = lr_scheduler.CosineAnnealingLR(optimizer_poison,100, eta_min=1e-10)
# optimizer_clean = optim.SGD(model.parameters(), lr=lr/2*1.0)
# scheduler_clean = lr_scheduler.CosineAnnealingLR(optimizer_clean,100, eta_min=1e-10)
optimizer = optim.SGD(model.parameters(), lr=lr)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=1e-10)
criterion = nn.MSELoss()

###########------------First Accuracy----------------############
print("first accuracy:")
before_clean_acc = validate(model, -1, test_loader, criterion_verify, True)
before_poison_acc = validate(model, -1, test_loader_poison, criterion_verify, False)

first accuracy:
epoch: -1
clean accuracy: 0.9039
epoch: -1
attack accuracy: 0.1031


In [None]:
lambda1 = 1
alpha = 0.05

import warnings
warnings.filterwarnings("ignore") 

# TODO: remove this
epochs = 1000

for epoch in tqdm(range(epochs)):
    # train_with_grad_control(model, epoch, train_loader_clean, criterion, optimizer)
    # train_with_grad_control(model, epoch, train_loader, criterion, optimizer)

    print("lambda1: ", lambda1)
    adjust = train_with_grad_control(
        model, epoch, train_loader, criterion, optimizer, lambda1
    )
    lambda1 += alpha * adjust
    lambda1 = min(lambda1, 1)
    lambda1 = max(0, lambda1)

    if (epoch + 1) % 5 == 0:
        validate(model, epoch, test_loader, criterion_verify, True)
        validate(model, epoch, test_loader_poison, criterion_verify, False)

    state = {
        "net": model.state_dict(),
        "masks": [w for name, w in model.named_parameters() if "mask" in name],
        "epoch": epoch,
        # 'error_history': error_history,
    }
    torch.save(state, "checkpoints/cifar10_optim_1.t7")
    scheduler.step()

print("model train finished")

  0%|          | 0/1000 [00:00<?, ?it/s]

lambda1:  1


  0%|          | 1/1000 [00:07<1:57:38,  7.07s/it]

epoch: 0 train loss: 33.91948876953125
lambda1:  1


  0%|          | 2/1000 [00:14<1:58:40,  7.14s/it]

epoch: 1 train loss: 17.58334841156006
lambda1:  0.9975214702909391


  0%|          | 3/1000 [00:21<1:55:51,  6.97s/it]

epoch: 2 train loss: 16.967582443237305
lambda1:  0.9909437504497008


  0%|          | 4/1000 [00:27<1:55:15,  6.94s/it]

epoch: 3 train loss: 11.360336799621582
lambda1:  0.9808720286975547
epoch: 4 train loss: 13.370220275878907
epoch: 4
clean accuracy: 0.7547


  0%|          | 5/1000 [00:45<3:00:15, 10.87s/it]

epoch: 4
attack accuracy: 0.3645
lambda1:  0.9691116247509135


  1%|          | 6/1000 [00:52<2:36:20,  9.44s/it]

epoch: 5 train loss: 10.532078033447265
lambda1:  0.9576099864488791


  1%|          | 7/1000 [00:59<2:23:04,  8.65s/it]

epoch: 6 train loss: 13.8108193359375
lambda1:  0.9454612380696628


  1%|          | 8/1000 [01:06<2:12:27,  8.01s/it]

epoch: 7 train loss: 11.00649666595459
lambda1:  0.9321443798729081


  1%|          | 9/1000 [01:12<2:05:56,  7.62s/it]

epoch: 8 train loss: 12.636567859649658
lambda1:  0.9179072718416615
epoch: 9 train loss: 12.544793395996093
epoch: 9
clean accuracy: 0.6474


  1%|          | 10/1000 [01:30<2:58:00, 10.79s/it]

epoch: 9
attack accuracy: 0.5215
lambda1:  0.9035801264722578


  1%|          | 11/1000 [01:37<2:37:28,  9.55s/it]

epoch: 10 train loss: 12.265967330932618
lambda1:  0.8887413780477655
epoch: 11 train loss: 13.028268280029296
lambda1:  0.8739958183245631


 72%|███████▎  | 725/1000 [1:48:36<49:42, 10.85s/it]

epoch: 724
attack accuracy: 0.9140
lambda1:  0.42981835472205604


 73%|███████▎  | 726/1000 [1:48:43<43:54,  9.61s/it]

epoch: 725 train loss: 1.2352729940414429
lambda1:  0.42983647279810605


 73%|███████▎  | 727/1000 [1:48:49<39:32,  8.69s/it]

epoch: 726 train loss: 1.1741918230056763
lambda1:  0.42959986708238934


 73%|███████▎  | 728/1000 [1:48:56<36:33,  8.07s/it]

epoch: 727 train loss: 1.2725818300247191
lambda1:  0.4293636509467738


 73%|███████▎  | 729/1000 [1:49:02<34:35,  7.66s/it]

epoch: 728 train loss: 1.2462092475891113
lambda1:  0.4293479229124334
epoch: 729 train loss: 1.1836968727111816
epoch: 729
clean accuracy: 0.8670


 73%|███████▎  | 730/1000 [1:49:21<48:39, 10.81s/it]

epoch: 729
attack accuracy: 0.9139
lambda1:  0.42936623904300414


 73%|███████▎  | 731/1000 [1:49:28<43:27,  9.69s/it]

epoch: 730 train loss: 1.2242414207458496
lambda1:  0.4293869094063185


 73%|███████▎  | 732/1000 [1:49:35<39:39,  8.88s/it]

epoch: 731 train loss: 1.19093789768219
lambda1:  0.42915053292789085


 73%|███████▎  | 733/1000 [1:49:41<36:33,  8.22s/it]

epoch: 732 train loss: 1.2816142320632935
lambda1:  0.428914195854197


 73%|███████▎  | 734/1000 [1:49:48<34:35,  7.80s/it]

epoch: 733 train loss: 1.189736351966858
lambda1:  0.42867845087896833
epoch: 734 train loss: 1.272521604537964
epoch: 734
clean accuracy: 0.8671


 74%|███████▎  | 735/1000 [1:50:06<47:56, 10.86s/it]

epoch: 734
attack accuracy: 0.9148
lambda1:  0.42871883854930065


 74%|███████▎  | 736/1000 [1:50:13<42:20,  9.62s/it]

epoch: 735 train loss: 1.267507194519043
lambda1:  0.42848372503932297


 74%|███████▎  | 737/1000 [1:50:20<38:29,  8.78s/it]

epoch: 736 train loss: 1.2361222314834595
lambda1:  0.42824808661990693


 74%|███████▍  | 738/1000 [1:50:26<35:34,  8.15s/it]

epoch: 737 train loss: 1.2868414363861085
lambda1:  0.4282492714190358


 74%|███████▍  | 739/1000 [1:50:33<33:34,  7.72s/it]

epoch: 738 train loss: 1.2377480382919313
lambda1:  0.4280146135145869
epoch: 739 train loss: 1.2138819732666015
epoch: 739
clean accuracy: 0.8673


 74%|███████▍  | 740/1000 [1:50:51<47:08, 10.88s/it]

epoch: 739
attack accuracy: 0.9151
lambda1:  0.42806452677037515


 74%|███████▍  | 741/1000 [1:50:58<42:06,  9.75s/it]

epoch: 740 train loss: 1.1614771919250488
lambda1:  0.4278309597595131


 74%|███████▍  | 742/1000 [1:51:05<37:57,  8.83s/it]

epoch: 741 train loss: 1.3226261692047119
lambda1:  0.4275977674793863


 74%|███████▍  | 743/1000 [1:51:12<35:21,  8.25s/it]

epoch: 742 train loss: 1.2584245300292969
lambda1:  0.42736468479407325


 74%|███████▍  | 744/1000 [1:51:19<33:17,  7.80s/it]

epoch: 743 train loss: 1.2680336952209472
lambda1:  0.42713131208119337
epoch: 744 train loss: 1.3331907844543458
epoch: 744
clean accuracy: 0.8671


 74%|███████▍  | 745/1000 [1:51:37<46:22, 10.91s/it]

epoch: 744
attack accuracy: 0.9159
lambda1:  0.4271340346318438


 75%|███████▍  | 746/1000 [1:51:44<40:49,  9.64s/it]

epoch: 745 train loss: 1.1535805139541626
lambda1:  0.42690116544332174


 75%|███████▍  | 747/1000 [1:51:50<36:41,  8.70s/it]

epoch: 746 train loss: 1.2023500366210937
lambda1:  0.4266698296999948


 75%|███████▍  | 748/1000 [1:51:57<33:51,  8.06s/it]

epoch: 747 train loss: 1.2264280953407287
lambda1:  0.42643851063745236


 75%|███████▍  | 749/1000 [1:52:03<31:55,  7.63s/it]

epoch: 748 train loss: 1.335500717163086
lambda1:  0.42648175246733394
epoch: 749 train loss: 1.1482988901138305
epoch: 749
clean accuracy: 0.8681


 75%|███████▌  | 750/1000 [1:52:21<44:38, 10.72s/it]

epoch: 749
attack accuracy: 0.9163
lambda1:  0.42625019799307784


 75%|███████▌  | 751/1000 [1:52:28<39:23,  9.49s/it]

epoch: 750 train loss: 1.1068943567276002
lambda1:  0.4260204551364699


 75%|███████▌  | 752/1000 [1:52:35<35:59,  8.71s/it]

epoch: 751 train loss: 1.2090474367141724
lambda1:  0.4257916666776286


 75%|███████▌  | 753/1000 [1:52:42<33:31,  8.14s/it]

epoch: 752 train loss: 1.1624666557312011
lambda1:  0.425563309502287


 75%|███████▌  | 754/1000 [1:52:48<31:35,  7.71s/it]

epoch: 753 train loss: 1.2102461442947388
lambda1:  0.42533557315410664
epoch: 754 train loss: 1.2028548564910888
epoch: 754
clean accuracy: 0.8684


 76%|███████▌  | 755/1000 [1:53:06<43:55, 10.76s/it]

epoch: 754
attack accuracy: 0.9157
lambda1:  0.42510747053963144


 76%|███████▌  | 756/1000 [1:53:13<38:59,  9.59s/it]

epoch: 755 train loss: 1.276431450843811
lambda1:  0.4248791960882697


 76%|███████▌  | 757/1000 [1:53:20<35:29,  8.76s/it]

epoch: 756 train loss: 1.1743194484710693
lambda1:  0.4246513285782556


 76%|███████▌  | 758/1000 [1:53:27<32:55,  8.16s/it]

epoch: 757 train loss: 1.1311965789794922
lambda1:  0.42442549911748917


 76%|███████▌  | 759/1000 [1:53:34<31:20,  7.80s/it]

epoch: 758 train loss: 1.1737771735191345
lambda1:  0.42420043525704265
epoch: 759 train loss: 1.2298637685775757
epoch: 759
clean accuracy: 0.8684


 76%|███████▌  | 760/1000 [1:53:51<43:15, 10.82s/it]

epoch: 759
attack accuracy: 0.9175
lambda1:  0.42424240947489694


 76%|███████▌  | 761/1000 [1:53:58<38:19,  9.62s/it]

epoch: 760 train loss: 1.1974610929489136
lambda1:  0.4240175875860746


 76%|███████▌  | 762/1000 [1:54:05<34:38,  8.73s/it]

epoch: 761 train loss: 1.1878929681777954
lambda1:  0.4240491696593004


 76%|███████▋  | 763/1000 [1:54:12<32:05,  8.13s/it]

epoch: 762 train loss: 1.134434681892395
lambda1:  0.4238256119010462


 76%|███████▋  | 764/1000 [1:54:18<30:25,  7.73s/it]

epoch: 763 train loss: 1.179407636642456
lambda1:  0.4238401706597948
epoch: 764 train loss: 1.2439562702178955
epoch: 764
clean accuracy: 0.8681


 76%|███████▋  | 765/1000 [1:54:36<41:59, 10.72s/it]

epoch: 764
attack accuracy: 0.9213
lambda1:  0.4241071457079976


 77%|███████▋  | 766/1000 [1:54:43<37:05,  9.51s/it]

epoch: 765 train loss: 1.1398525276184082
lambda1:  0.4241367526279908


 77%|███████▋  | 767/1000 [1:54:49<33:23,  8.60s/it]

epoch: 766 train loss: 1.1170984401702881
lambda1:  0.42391515091596654


 77%|███████▋  | 768/1000 [1:54:56<31:05,  8.04s/it]

epoch: 767 train loss: 1.1562726392745972
lambda1:  0.42369546594661883


 77%|███████▋  | 769/1000 [1:55:03<29:34,  7.68s/it]

epoch: 768 train loss: 1.1892215719223023
lambda1:  0.4234760357274058
epoch: 769 train loss: 1.1882394638061524
epoch: 769
clean accuracy: 0.8686


 77%|███████▋  | 770/1000 [1:55:21<41:45, 10.89s/it]

epoch: 769
attack accuracy: 0.9210
lambda1:  0.4234800895703879


 77%|███████▋  | 771/1000 [1:55:28<36:51,  9.66s/it]

epoch: 770 train loss: 1.147001609325409
lambda1:  0.4235187486616179


 77%|███████▋  | 772/1000 [1:55:35<33:42,  8.87s/it]

epoch: 771 train loss: 1.0851769342422486
lambda1:  0.42330135006628966


 77%|███████▋  | 773/1000 [1:55:42<31:01,  8.20s/it]

epoch: 772 train loss: 1.1685858764648438
lambda1:  0.4230853746837185


 77%|███████▋  | 774/1000 [1:55:48<29:16,  7.77s/it]

epoch: 773 train loss: 1.1313722820281982
lambda1:  0.42286786741317733
epoch: 774 train loss: 1.2286461038589478
epoch: 774
clean accuracy: 0.8688


 78%|███████▊  | 775/1000 [1:56:06<40:36, 10.83s/it]

epoch: 774
attack accuracy: 0.9220
lambda1:  0.42291197125166113


 78%|███████▊  | 776/1000 [1:56:13<35:59,  9.64s/it]

epoch: 775 train loss: 1.1407939329147339
lambda1:  0.4226960066995769


 78%|███████▊  | 777/1000 [1:56:20<33:02,  8.89s/it]

epoch: 776 train loss: 1.2344470834732055
lambda1:  0.422710002680151


 78%|███████▊  | 778/1000 [1:56:27<30:27,  8.23s/it]

epoch: 777 train loss: 1.189177911758423
lambda1:  0.42299118280561543


 78%|███████▊  | 779/1000 [1:56:34<29:02,  7.88s/it]

epoch: 778 train loss: 1.1157593383789062
lambda1:  0.4230142554135155
epoch: 779 train loss: 1.2013758826255798
epoch: 779
clean accuracy: 0.8685


 78%|███████▊  | 780/1000 [1:56:52<39:42, 10.83s/it]

epoch: 779
attack accuracy: 0.9252
lambda1:  0.4228002027053869


 78%|███████▊  | 781/1000 [1:56:59<34:58,  9.58s/it]

epoch: 780 train loss: 1.1415141286849975
lambda1:  0.42258656287199586


 78%|███████▊  | 782/1000 [1:57:06<32:00,  8.81s/it]

epoch: 781 train loss: 1.1134369502067565
lambda1:  0.4223740135646501


 78%|███████▊  | 783/1000 [1:57:12<29:36,  8.19s/it]

epoch: 782 train loss: 1.1537314262390137
lambda1:  0.4223989036547945


 78%|███████▊  | 784/1000 [1:57:19<28:15,  7.85s/it]

epoch: 783 train loss: 1.1860897998809814
lambda1:  0.4221870320354684
epoch: 784 train loss: 1.1243696041107178
epoch: 784
clean accuracy: 0.8685


 78%|███████▊  | 785/1000 [1:57:37<39:06, 10.91s/it]

epoch: 784
attack accuracy: 0.9269
lambda1:  0.42217398241578924


 79%|███████▊  | 786/1000 [1:57:44<34:20,  9.63s/it]

epoch: 785 train loss: 1.0919253616333007
lambda1:  0.4219631851847565


 79%|███████▊  | 787/1000 [1:57:51<30:48,  8.68s/it]

epoch: 786 train loss: 1.0947505249977112
lambda1:  0.4217535830599932


 79%|███████▉  | 788/1000 [1:57:57<28:34,  8.09s/it]

epoch: 787 train loss: 1.0832799005508422
lambda1:  0.42181875766890997


 79%|███████▉  | 789/1000 [1:58:04<27:00,  7.68s/it]

epoch: 788 train loss: 1.1394996161460877
lambda1:  0.4216104258569326
epoch: 789 train loss: 1.0571467008590698
epoch: 789
clean accuracy: 0.8693


 79%|███████▉  | 790/1000 [1:58:22<37:48, 10.80s/it]

epoch: 789
attack accuracy: 0.9267
lambda1:  0.4214026236713316


 79%|███████▉  | 791/1000 [1:58:29<33:24,  9.59s/it]

epoch: 790 train loss: 1.1730936832427978
lambda1:  0.42143761113784445


 79%|███████▉  | 792/1000 [1:58:36<30:30,  8.80s/it]

epoch: 791 train loss: 1.09386598443985
lambda1:  0.4212307492000401


 79%|███████▉  | 793/1000 [1:58:42<28:04,  8.14s/it]

epoch: 792 train loss: 1.205395327091217
lambda1:  0.4212707968645783


 79%|███████▉  | 794/1000 [1:58:49<26:11,  7.63s/it]

epoch: 793 train loss: 1.098266505241394
lambda1:  0.42106447438097105
epoch: 794 train loss: 1.0249155826568603
epoch: 794
clean accuracy: 0.8687


 80%|███████▉  | 795/1000 [1:59:07<36:47, 10.77s/it]

epoch: 794
attack accuracy: 0.9303
lambda1:  0.42085974284967703


 80%|███████▉  | 796/1000 [1:59:14<32:40,  9.61s/it]

epoch: 795 train loss: 1.1072887077331544
lambda1:  0.42089897555503486


 80%|███████▉  | 797/1000 [1:59:21<29:52,  8.83s/it]

epoch: 796 train loss: 1.2472507162094115
lambda1:  0.4209539025993297


 80%|███████▉  | 798/1000 [1:59:28<27:31,  8.18s/it]

epoch: 797 train loss: 1.0100950260162354
lambda1:  0.42074955638639755


 80%|███████▉  | 799/1000 [1:59:34<25:55,  7.74s/it]

epoch: 798 train loss: 1.2751419162750244
lambda1:  0.420790808655504
epoch: 799 train loss: 1.0528152008056642
epoch: 799
clean accuracy: 0.8682


 80%|████████  | 800/1000 [1:59:52<35:49, 10.75s/it]

epoch: 799
attack accuracy: 0.9345
lambda1:  0.4205878211904824


 80%|████████  | 801/1000 [1:59:59<31:43,  9.56s/it]

epoch: 800 train loss: 1.0558764595985413
lambda1:  0.4203864253009133


 80%|████████  | 802/1000 [2:00:06<28:50,  8.74s/it]

epoch: 801 train loss: 1.04551424407959
lambda1:  0.42018570777595826


 80%|████████  | 803/1000 [2:00:12<26:41,  8.13s/it]

epoch: 802 train loss: 1.058867880821228
lambda1:  0.41998422203096497


 80%|████████  | 804/1000 [2:00:19<25:10,  7.71s/it]

epoch: 803 train loss: 1.1338241348266602
lambda1:  0.4200342440829366
epoch: 804 train loss: 1.1217199716567994
epoch: 804
clean accuracy: 0.8685


 80%|████████  | 805/1000 [2:00:37<35:06, 10.80s/it]

epoch: 804
attack accuracy: 0.9366
lambda1:  0.4198326500532456


 81%|████████  | 806/1000 [2:00:44<30:54,  9.56s/it]

epoch: 805 train loss: 1.164138949394226
lambda1:  0.4201166498087259


 81%|████████  | 807/1000 [2:00:51<28:17,  8.80s/it]

epoch: 806 train loss: 0.9461798639297485
lambda1:  0.4199177638746695


 81%|████████  | 808/1000 [2:00:58<26:27,  8.27s/it]

epoch: 807 train loss: 1.0692960796356201
lambda1:  0.4197207654210656


 81%|████████  | 809/1000 [2:01:05<24:56,  7.83s/it]

epoch: 808 train loss: 1.0505843086242677
lambda1:  0.4197666928978513


 81%|████████  | 809/1000 [2:01:10<28:36,  8.99s/it]


KeyboardInterrupt: 