In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import os
from models import *
from function import *
from tqdm import tqdm

!nvidia-smi
print("CUDA Available : ", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

Wed Sep 20 19:32:06 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 537.34                 Driver Version: 537.34       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090      WDDM  | 00000000:2B:00.0  On |                  N/A |
|  0%   54C    P8              43W / 350W |  12362MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
# Argument
batch_size = 128
learning_rate = 1e-6
num_epochs = 50
momentum = 0.9
weight_decay = 1e-9
class_num = 10
#class_num = 2 # clean or adversarial
output_path = 'test'

if not os.path.isdir(output_path):
    os.mkdir(output_path)

In [3]:
# Preprocessing
test_tfm = transforms.Compose([
    transforms.ToTensor(),
])
toPIL = transforms.ToPILImage()

In [4]:
# Cifar10
train_set = torchvision.datasets.CIFAR10(root='../../dataset/cifar10_dataset', train=True, download=True, transform=test_tfm)
test_set = torchvision.datasets.CIFAR10(root='../../dataset/cifar10_dataset', train=False, download=True, transform=test_tfm)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
model = Unet(n_channels = 3, n_classes = 3, bilinear = False)
model = model.to(device)
write_path = 'result'

target_model = ResNet18()
target_model = target_model.to(device)
adversary = LinfPGDAttack(target_model)



if not os.path.isdir(write_path):
    os.mkdir(write_path)
    
optimizer = optim.RMSprop(model.parameters(), lr = learning_rate, weight_decay = weight_decay, momentum = momentum, foreach = True)
criterion = diceLoss(3)
#criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    #idx = 0
    # clean train
    for batch in tqdm(train_loader):
        imgs, labels = batch
        imgs, labels = imgs.to(device), labels.to(device)
        adv_imgs = adversary.perturb(imgs, labels)
        
        output = model(imgs)
        loss = criterion(output, imgs)
        
        output = model(adv_imgs)
        loss += criterion(output, imgs)
        
        
        epoch_loss += loss
        
        loss.backward()
        optimizer.step()

    print(f'epoch : {epoch}  Training | loss = {epoch_loss / imgs.shape[0]}')
    model.eval()
    eval_loss = 0
    idx = 0
    
    for batch in tqdm(test_loader):
        imgs, labels = batch
        imgs, labels = imgs.to(device), labels.to(device)
        adv_imgs = adversary.perturb(imgs, labels)
        with torch.no_grad():
            output = model(adv_imgs)
            loss = criterion(output, imgs)
            eval_loss += loss
        
        for i in range(imgs.size(0)):
            img = toPIL(output[i].squeeze(0)).convert('RGB')
            img.save(f'{write_path}/img_{idx}.png')
            idx += 1
            
    print(f'epoch : {epoch}  Evaluating | loss = {eval_loss / imgs.shape[0]}')
torch.save(model.state_dict(), './pretrain/denoiser_cifar10.pth')

100%|██████████| 391/391 [01:30<00:00,  4.31it/s]


epoch : 0  Training | loss = 2.8574464321136475


100%|██████████| 79/79 [00:28<00:00,  2.77it/s]


epoch : 0  Evaluating | loss = 0.5939290523529053


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]


epoch : 1  Training | loss = 0.276658833026886


100%|██████████| 79/79 [00:22<00:00,  3.58it/s]


epoch : 1  Evaluating | loss = 0.09100859612226486


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]


epoch : 2  Training | loss = 0.12343873828649521


100%|██████████| 79/79 [00:22<00:00,  3.58it/s]


epoch : 2  Evaluating | loss = 0.15097321569919586


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 3  Training | loss = 0.09596646577119827


100%|██████████| 79/79 [00:21<00:00,  3.61it/s]


epoch : 3  Evaluating | loss = 0.09010588377714157


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 4  Training | loss = 0.0858810618519783


100%|██████████| 79/79 [00:21<00:00,  3.63it/s]


epoch : 4  Evaluating | loss = 0.08271634578704834


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]


epoch : 5  Training | loss = 0.08059611171483994


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 5  Evaluating | loss = 0.0836331844329834


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 6  Training | loss = 0.07658075541257858


100%|██████████| 79/79 [00:21<00:00,  3.62it/s]


epoch : 6  Evaluating | loss = 0.06327248364686966


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 7  Training | loss = 0.07380706071853638


100%|██████████| 79/79 [00:21<00:00,  3.68it/s]


epoch : 7  Evaluating | loss = 0.06533219665288925


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 8  Training | loss = 0.07215628772974014


100%|██████████| 79/79 [00:21<00:00,  3.62it/s]


epoch : 8  Evaluating | loss = 0.09985854476690292


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 9  Training | loss = 0.07459316402673721


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 9  Evaluating | loss = 0.2776232063770294


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]


epoch : 10  Training | loss = 0.07948454469442368


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 10  Evaluating | loss = 0.14822785556316376


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]


epoch : 11  Training | loss = 0.07968715578317642


100%|██████████| 79/79 [00:21<00:00,  3.68it/s]


epoch : 11  Evaluating | loss = 0.14491082727909088


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 12  Training | loss = 0.06863374263048172


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 12  Evaluating | loss = 0.08341795951128006


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 13  Training | loss = 0.06459419429302216


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 13  Evaluating | loss = 0.1529986709356308


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 14  Training | loss = 0.06499915570020676


100%|██████████| 79/79 [00:21<00:00,  3.61it/s]


epoch : 14  Evaluating | loss = 0.054966047406196594


100%|██████████| 391/391 [01:23<00:00,  4.71it/s]


epoch : 15  Training | loss = 0.062483858317136765


100%|██████████| 79/79 [00:21<00:00,  3.68it/s]


epoch : 15  Evaluating | loss = 0.07420305162668228


100%|██████████| 391/391 [01:23<00:00,  4.71it/s]


epoch : 16  Training | loss = 0.06196196749806404


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 16  Evaluating | loss = 0.10857510566711426


100%|██████████| 391/391 [01:23<00:00,  4.71it/s]


epoch : 17  Training | loss = 0.061826933175325394


100%|██████████| 79/79 [00:21<00:00,  3.66it/s]


epoch : 17  Evaluating | loss = 0.07469014823436737


100%|██████████| 391/391 [01:23<00:00,  4.70it/s]


epoch : 18  Training | loss = 0.06272429972887039


100%|██████████| 79/79 [00:22<00:00,  3.59it/s]


epoch : 18  Evaluating | loss = 0.08077800273895264


100%|██████████| 391/391 [01:23<00:00,  4.71it/s]


epoch : 19  Training | loss = 0.06412392109632492


100%|██████████| 79/79 [00:21<00:00,  3.69it/s]


epoch : 19  Evaluating | loss = 0.05516789108514786


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 20  Training | loss = 0.06435241550207138


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 20  Evaluating | loss = 0.1266840696334839


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]


epoch : 21  Training | loss = 0.06425230950117111


100%|██████████| 79/79 [00:21<00:00,  3.66it/s]


epoch : 21  Evaluating | loss = 0.08734531700611115


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 22  Training | loss = 0.06286971271038055


100%|██████████| 79/79 [00:21<00:00,  3.65it/s]


epoch : 22  Evaluating | loss = 0.05005383864045143


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 23  Training | loss = 0.06277519464492798


100%|██████████| 79/79 [00:21<00:00,  3.70it/s]


epoch : 23  Evaluating | loss = 0.2715727984905243


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 24  Training | loss = 0.06291564553976059


100%|██████████| 79/79 [00:21<00:00,  3.65it/s]


epoch : 24  Evaluating | loss = 0.06855276226997375


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]


epoch : 25  Training | loss = 0.06355239450931549


100%|██████████| 79/79 [00:21<00:00,  3.65it/s]


epoch : 25  Evaluating | loss = 0.04652928560972214


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]


epoch : 26  Training | loss = 0.06225525215268135


100%|██████████| 79/79 [00:21<00:00,  3.63it/s]


epoch : 26  Evaluating | loss = 0.0360109768807888


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 27  Training | loss = 0.06266043335199356


100%|██████████| 79/79 [00:21<00:00,  3.67it/s]


epoch : 27  Evaluating | loss = 0.03404518589377403


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 28  Training | loss = 0.06187600642442703


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 28  Evaluating | loss = 0.03563418239355087


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 29  Training | loss = 0.06222005560994148


100%|██████████| 79/79 [00:21<00:00,  3.65it/s]


epoch : 29  Evaluating | loss = 0.03638802468776703


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 30  Training | loss = 0.06190810725092888


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 30  Evaluating | loss = 0.03216036036610603


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 31  Training | loss = 0.0637301430106163


100%|██████████| 79/79 [00:21<00:00,  3.67it/s]


epoch : 31  Evaluating | loss = 0.03874235600233078


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]


epoch : 32  Training | loss = 0.06297396868467331


100%|██████████| 79/79 [00:21<00:00,  3.63it/s]


epoch : 32  Evaluating | loss = 0.0376012809574604


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 33  Training | loss = 0.06302919238805771


100%|██████████| 79/79 [00:22<00:00,  3.55it/s]


epoch : 33  Evaluating | loss = 0.03008992038667202


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 34  Training | loss = 0.06147877499461174


100%|██████████| 79/79 [00:21<00:00,  3.63it/s]


epoch : 34  Evaluating | loss = 0.0364372655749321


100%|██████████| 391/391 [01:22<00:00,  4.71it/s]


epoch : 35  Training | loss = 0.06140700727701187


100%|██████████| 79/79 [00:21<00:00,  3.68it/s]


epoch : 35  Evaluating | loss = 0.028593098744750023


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 36  Training | loss = 0.06123320013284683


100%|██████████| 79/79 [00:21<00:00,  3.63it/s]


epoch : 36  Evaluating | loss = 0.03130478411912918


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 37  Training | loss = 0.061799097806215286


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 37  Evaluating | loss = 0.030790451914072037


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 38  Training | loss = 0.062467943876981735


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 38  Evaluating | loss = 0.036175716668367386


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 39  Training | loss = 0.06305914372205734


100%|██████████| 79/79 [00:21<00:00,  3.69it/s]


epoch : 39  Evaluating | loss = 0.029732581228017807


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 40  Training | loss = 0.06377999484539032


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 40  Evaluating | loss = 0.030101805925369263


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 41  Training | loss = 0.06277798861265182


100%|██████████| 79/79 [00:21<00:00,  3.65it/s]


epoch : 41  Evaluating | loss = 0.033508557826280594


100%|██████████| 391/391 [01:22<00:00,  4.72it/s]


epoch : 42  Training | loss = 0.06355450302362442


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 42  Evaluating | loss = 0.036230187863111496


100%|██████████| 391/391 [01:22<00:00,  4.73it/s]


epoch : 43  Training | loss = 0.06205055117607117


100%|██████████| 79/79 [00:21<00:00,  3.69it/s]


epoch : 43  Evaluating | loss = 0.04031214863061905


100%|██████████| 391/391 [01:22<00:00,  4.73it/s]


epoch : 44  Training | loss = 0.06108712777495384


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 44  Evaluating | loss = 0.031232591718435287


100%|██████████| 391/391 [01:22<00:00,  4.73it/s]


epoch : 45  Training | loss = 0.05955980345606804


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 45  Evaluating | loss = 0.031583014875650406


100%|██████████| 391/391 [01:22<00:00,  4.73it/s]


epoch : 46  Training | loss = 0.05906915292143822


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 46  Evaluating | loss = 0.033370498567819595


100%|██████████| 391/391 [01:22<00:00,  4.73it/s]


epoch : 47  Training | loss = 0.05989238619804382


100%|██████████| 79/79 [00:21<00:00,  3.69it/s]


epoch : 47  Evaluating | loss = 0.030488116666674614


100%|██████████| 391/391 [01:22<00:00,  4.73it/s]


epoch : 48  Training | loss = 0.06030337139964104


100%|██████████| 79/79 [00:21<00:00,  3.64it/s]


epoch : 48  Evaluating | loss = 0.04306329786777496


100%|██████████| 391/391 [01:22<00:00,  4.73it/s]


epoch : 49  Training | loss = 0.06171285733580589


100%|██████████| 79/79 [00:21<00:00,  3.66it/s]


epoch : 49  Evaluating | loss = 0.03361659124493599


In [6]:
target_model = ResNet18()
target_model = nn.DataParallel(target_model).to(device)
checkpoint = torch.load('pretrain/basic_training')
target_model.load_state_dict(checkpoint['net'])
target_model.eval()
adversary = LinfPGDAttack(target_model)

In [7]:
accuracy = 0
total = 0
idx = 0

new_path = 'cifar_test'
if not os.path.isdir(new_path):
    os.mkdir(new_path)

for batch in tqdm(test_loader):
    imgs, labels = batch
    imgs, labels = imgs.to(device), labels.to(device)
    adv_imgs = adversary.perturb(imgs, labels)
    with torch.no_grad():
        output = model(adv_imgs)
    
    logits = target_model(output)
    _, predicted = logits.max(1)
    accuracy += predicted.eq(labels).sum().item()
    total += imgs.size(0)
    
    for i in range(imgs.size(0)):
        img = toPIL(output[i].squeeze(0)).convert('RGB')
        img.save(f'{new_path}/img_{idx}.png')
        idx += 1
            
print(f'epoch : {epoch}  Evaluating | Accuracy = {accuracy / total}')

100%|██████████| 79/79 [00:21<00:00,  3.68it/s]

epoch : 49  Evaluating | Accuracy = 0.6814



