In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import DatasetFolder 
import torch.backends.cudnn as cudnn
import torchvision
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import os
from models import *
from function import *
from tqdm import tqdm

!nvidia-smi
print("CUDA Available : ", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

Wed Sep 20 19:23:59 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 537.34                 Driver Version: 537.34       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090      WDDM  | 00000000:2B:00.0  On |                  N/A |
| 55%   57C    P8              46W / 350W |   8854MiB / 24576MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
# Argument
# Hyper-parameters shared by the training / evaluation cells below.
batch_size = 32
learning_rate = 1e-6
num_epochs = 50
momentum = 0.9
weight_decay = 1e-9
class_num = 100  # width of the classifier heads built in later cells
# NOTE(review): output_path is not referenced by the visible cells (images go to write_path / new_path)
# — confirm it is still needed.
output_path = 'test'

# Create the directory if it is missing; an existing directory is left untouched.
os.makedirs(output_path, exist_ok=True)

In [3]:
# Preprocessing
# Every image is resized to a fixed 160x160 and converted to a float tensor; no augmentation is applied.
_resize_and_tensorize = [transforms.Resize([160, 160]), transforms.ToTensor()]
test_tfm = transforms.Compose(_resize_and_tensorize)
# Inverse direction: turns a tensor back into a PIL image so it can be written to disk.
toPIL = transforms.ToPILImage()

In [4]:
# Casia
def rgb_loader(path):
    """Open an image file and return it as a 3-channel RGB PIL image.

    The file is opened in a context manager so its handle is always closed, and `.convert('RGB')`
    guarantees 3 channels even for grayscale/CMYK JPEGs (the U-Net below is built with n_channels=3).
    """
    with open(path, 'rb') as f:
        return Image.open(f).convert('RGB')


full_set = DatasetFolder("../../dataset/casia100_dataset/casia", loader=rgb_loader, extensions="jpg", transform=test_tfm)

# 80/20 split. The seeded generator makes the split identical after every kernel restart, so weights
# saved in one session are evaluated on the same held-out images in the next.
# NOTE(review): if 'pretrain/facenet_casia.pth' was trained on a different split of this folder, the
# accuracy measured later on valid_set may include images the classifier trained on — confirm.
split_seed = 0
train_set_size = int(len(full_set) * 0.8)
train_set, valid_set = torch.utils.data.random_split(
    full_set,
    [train_set_size, len(full_set) - train_set_size],
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, num_workers=0)

Debugger :  class -> idx 
{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '10': 10, '11': 11, '12': 12, '13': 13, '14': 14, '15': 15, '16': 16, '17': 17, '18': 18, '19': 19, '20': 20, '21': 21, '22': 22, '23': 23, '24': 24, '25': 25, '26': 26, '27': 27, '28': 28, '29': 29, '30': 30, '31': 31, '32': 32, '33': 33, '34': 34, '35': 35, '36': 36, '37': 37, '38': 38, '39': 39, '40': 40, '41': 41, '42': 42, '43': 43, '44': 44, '45': 45, '46': 46, '47': 47, '48': 48, '49': 49, '50': 50, '51': 51, '52': 52, '53': 53, '54': 54, '55': 55, '56': 56, '57': 57, '58': 58, '59': 59, '60': 60, '61': 61, '62': 62, '63': 63, '64': 64, '65': 65, '66': 66, '67': 67, '68': 68, '69': 69, '70': 70, '71': 71, '72': 72, '73': 73, '74': 74, '75': 75, '76': 76, '77': 77, '78': 78, '79': 79, '80': 80, '81': 81, '82': 82, '83': 83, '84': 84, '85': 85, '86': 86, '87': 87, '88': 88, '89': 89, '90': 90, '91': 91, '92': 92, '93': 93, '94': 94, '95': 95, '96': 96, '97': 97, '98': 98, '99

In [5]:
# Train a U-Net to reconstruct the clean image from both the clean and the PGD-perturbed input.
model = Unet(n_channels = 3, n_classes = 3, bilinear = False)
model = model.to(device)
write_path = 'casia_result'

# Attack oracle: CASIA-WebFace backbone whose head layers are replaced by fresh class_num-wide layers.
# NOTE(review): the replaced layers are randomly initialised and never trained in this cell, so PGD
# attacks an untrained head. A later cell loads fine-tuned weights from 'pretrain/facenet_casia.pth';
# consider loading them here as well — TODO confirm intent.
target_model = InceptionResnetV1(pretrained='casia-webface',classify=None).to(device)
target_model.logits = nn.Linear(in_features = target_model.logits.in_features, out_features = class_num)
target_model.last_linear = nn.Linear(in_features = target_model.last_linear.in_features, out_features = class_num, bias = True)
target_model.last_bn = nn.BatchNorm1d(num_features = class_num)
cudnn.benchmark = True
target_model.to(device)
# The classifier is only attacked, never trained: eval mode stops the attack's forward passes from
# updating BatchNorm running statistics, and frozen parameters stop weight gradients accumulating.
target_model.eval()
target_model.requires_grad_(False)
adversary = LinfPGDAttack(target_model)

os.makedirs(write_path, exist_ok=True)

optimizer = optim.RMSprop(model.parameters(), lr = learning_rate, weight_decay = weight_decay, momentum = momentum, foreach = True)
criterion = diceLoss(3)

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    epoch_loss = 0.0
    for imgs, labels in tqdm(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        # Detach so the denoiser's backward pass never reaches into the attack's graph.
        adv_imgs = adversary.perturb(imgs, labels).detach()

        # Reset gradients each step; otherwise they accumulate across all batches.
        optimizer.zero_grad()
        # Both the clean and the adversarial input must reconstruct the clean image.
        loss = criterion(model(imgs), imgs) + criterion(model(adv_imgs), imgs)
        loss.backward()
        optimizer.step()

        # .item() keeps a plain float so each batch's autograd graph can be freed.
        epoch_loss += loss.item()

    # Mean loss per batch.
    print(f'epoch : {epoch}  Training | loss = {epoch_loss / max(len(train_loader), 1)}')

    # ---- evaluation on adversarial inputs ----
    model.eval()
    eval_loss = 0.0
    idx = 0
    for imgs, labels in tqdm(test_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        # PGD needs input gradients, so it runs outside no_grad.
        adv_imgs = adversary.perturb(imgs, labels).detach()
        with torch.no_grad():
            output = model(adv_imgs)
            eval_loss += criterion(output, imgs).item()

        # ToPILImage scales floats by 255 and casts to uint8, so values outside [0, 1] would overflow;
        # clamp first (assumes the denoiser output is meant to lie in [0, 1] like its targets — TODO confirm).
        for img_tensor in output.clamp(0, 1):
            toPIL(img_tensor).convert('RGB').save(f'{write_path}/img_{idx}.png')
            idx += 1

    print(f'epoch : {epoch}  Evaluating | loss = {eval_loss / max(len(test_loader), 1)}')

os.makedirs('./pretrain', exist_ok=True)
torch.save(model.state_dict(), './pretrain/denoiser_casia.pth')

  0%|          | 0/399 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [7]:
# Reload the trained denoiser.
# strict loading (the default) turns any key mismatch into an error instead of silently leaving layers
# randomly initialised; map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model = Unet(n_channels = 3, n_classes = 3, bilinear = False)
model = model.to(device)
checkpoint = torch.load('pretrain/denoiser_casia.pth', map_location = device)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [8]:
# load networks
# Rebuild the classifier with class_num-wide last_linear / last_bn layers and load the fine-tuned weights.
# Device placement is done only through .to(device), so the cell also runs on CPU-only machines.
target_model = InceptionResnetV1(pretrained = None, classify = None, num_classes = None)
target_model.last_linear = nn.Linear(in_features = target_model.last_linear.in_features, out_features = class_num, bias = True)
target_model.last_bn = nn.BatchNorm1d(num_features = class_num)
target_model = target_model.to(device)
checkpoint = torch.load('pretrain/facenet_casia.pth', map_location = device)

# strict=False tolerates extra checkpoint entries (presumably a head this architecture does not define —
# TODO confirm), but every parameter/buffer this model owns must be loaded; otherwise evaluation would
# silently run on random weights.
load_result = target_model.load_state_dict(checkpoint, strict = False)
assert not load_result.missing_keys, f'facenet_casia.pth is missing keys: {load_result.missing_keys}'

# NOTE(review): criterion and adversary are not used by the accuracy cell below while its adversarial
# path is disabled.
criterion = nn.CrossEntropyLoss()

adversary = LinfPGDAttack(target_model)

In [9]:
# Clean (no attack, no denoiser) top-1 accuracy of the target classifier on the held-out split.
model.eval()
target_model.eval()
eval_loss = 0  # unused while the adversarial evaluation below is disabled
total = 0
accuracy = 0  # running count of correct predictions; divided by `total` at the end
new_path = 'casia_test'
os.makedirs(new_path, exist_ok=True)

# TODO: the adversarial + denoiser evaluation (perturb -> model -> target_model, saving denoised images
# to new_path) is currently disabled; only clean accuracy is measured.
# no_grad: pure inference, so no autograd graph needs to be built or kept.
with torch.no_grad():
    for imgs, labels in tqdm(test_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        logits = target_model(imgs)
        _, predicted = logits.max(1)
        accuracy += predicted.eq(labels).sum().item()
        total += imgs.size(0)

# Guard against an empty loader instead of raising ZeroDivisionError.
print(accuracy / total if total else float('nan'))
    

100%|██████████| 100/100 [00:05<00:00, 17.49it/s]

0.9799184185754628



