In [1]:
import os
from os import listdir
from os.path import join,isfile
from tqdm import tqdm

import time
import copy

import cv2
import numpy as np
from tqdm import tqdm

import torch
from torch import nn
from torch import optim
from torch import autograd
from torch.nn import functional as F
from torchvision import models

from misc_function import processImage, detail_enhance_lab, recreate_image, PreidictLabel, AdvLoss
from module import DeepGuidedFilter
from utils import Config
from my_model import Model

In [5]:
def run(config, dataset_path, dataset_smooth_path, image_list, idx,
        adv_path='../AdvImg_mobilenet_v2/', max_iters=5000, smooth_tol=1e-4,
        detect_anomaly=True, classifier=None):
    """Train the guided-filter smoother on one image and save any adversarial output.

    ``config.model`` is optimised so that its output matches a reference smoothed
    image, while ``AdvLoss`` on the detail-enhanced result pushes the classifier
    away from the prediction it makes on the original image. Whenever the
    enhanced image is classified differently from the original it is written to
    ``adv_path``. Training stops early once that happens *and* the L2 smoothing
    loss is below ``smooth_tol``.

    Args:
        config: Config providing ``SAVE``, ``LR``, ``clip``, ``model``,
            ``model_weight`` and a ``forward(x, gt, config)`` callable.
        dataset_path: directory containing the original images.
        dataset_smooth_path: directory containing the reference smoothed images
            (looked up with the same file name as in ``dataset_path``).
        image_list: sequence of image file names.
        idx: index into ``image_list`` of the image to process.
        adv_path: directory to write adversarial images into (created if missing).
        max_iters: maximum number of optimisation steps.
        smooth_tol: early-stop threshold on the L2 (MSE) smoothing loss.
        detect_anomaly: run forward/backward under autograd anomaly detection.
            This is a debugging aid and slows training considerably.
        classifier: optional, already-built classifier to reuse across calls.
            When None, a ``Model`` is constructed and ``config.model_weight`` is
            loaded into it. Either way it is put in eval mode, moved to CUDA and
            frozen.

    Side effects:
        Creates ``config.SAVE`` and ``adv_path`` if missing, writes adversarial
        images to ``adv_path``, and overwrites ``<config.SAVE>/model_latest.pth``.
        ``config.model`` is updated in place and is NOT re-initialised, so
        repeated calls keep training the same filter.
    """
    save_path = config.SAVE
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(adv_path, exist_ok=True)

    # Smoothing losses (L2 and L1 against the reference smoothed image).
    criterion_L1 = nn.L1Loss()
    criterion_L2 = nn.MSELoss()
    optimizer = optim.Adam(config.model.parameters(), lr=config.LR)

    with torch.cuda.device(0):
        config.model.cuda()
        criterion_L1.cuda()
        criterion_L2.cuda()

    # Load the classifier under attack (unless supplied) and freeze all its weights.
    if classifier is None:
        classifier = Model()
        classifier.load_state_dict(torch.load(config.model_weight))
    classifier.eval()
    classifier.cuda()
    for param in classifier.parameters():
        param.requires_grad = False

    img_name = image_list[idx]

    # Pre-process the original and the reference smoothed image into tensors.
    x = processImage(dataset_path, img_name)
    gt_smooth = processImage(dataset_smooth_path, img_name)

    # Class predicted for the unmodified image; the attack tries to change it.
    class_x, _ = PreidictLabel(x, classifier)

    for _ in range(max_iters):
        # gt_smooth: the reference smoothed image.
        # x_smooth : the network's learned approximation of that smoothing.
        with autograd.set_detect_anomaly(detect_anomaly):
            x_smooth = config.forward(x, gt_smooth, config)
            # Detail enhancement driven by the learned smoothing.
            enh = detail_enhance_lab(x, x_smooth)
            # Only the logits are needed here; the class is re-checked below
            # on the image as it will be saved.
            _, logit_enh = PreidictLabel(enh.permute(2, 0, 1).unsqueeze(dim=0), classifier)

            loss1 = criterion_L2(x_smooth, gt_smooth)
            loss2 = criterion_L1(x_smooth, gt_smooth)
            loss3 = AdvLoss(logit_enh, class_x)
            # Weighting of the smoothing losses vs. the adversarial loss.
            loss = 5 * loss1 + 5 * loss2 + loss3

            optimizer.zero_grad()
            loss.backward()
            if config.clip is not None:
                torch.nn.utils.clip_grad_norm_(config.model.parameters(), config.clip)
            optimizer.step()

        # Re-classify the enhanced image after conversion via recreate_image.
        # NOTE(review): the result is flipped along axis 0 before
        # classification; confirm this matches recreate_image's output layout.
        adv_img = recreate_image(enh)
        check_enh = torch.from_numpy(np.flip(adv_img, axis=0).copy()).cuda()
        class_enh, _ = PreidictLabel(check_enh.permute(2, 0, 1).unsqueeze(dim=0), classifier)

        if class_x != class_enh:
            cv2.imwrite(os.path.join(adv_path, img_name), adv_img)
            # Smoothing loss: the larger it is, the more the image was altered.
            if loss1 < smooth_tol:
                break

    # Save the filter network.
    torch.save(config.model.state_dict(), os.path.join(save_path, 'model_latest.pth'))

In [6]:
def forward(imgs, gt, config):
    """Run the guided-filter model stored in ``config.model`` on ``imgs``.

    ``imgs`` is passed as both positional arguments of ``config.model``.
    ``gt`` is accepted only so the signature matches the
    ``config.forward(x, gt_smooth, config)`` call made during training; it is
    not used.

    Returns:
        Whatever ``config.model`` returns for ``(imgs, imgs)``.
    """
    # NOTE(review): the same tensor feeds both model inputs; confirm this is
    # the intended input pairing for DeepGuidedFilter.
    return config.model(imgs, imgs)

dataset_path = '../Dataset/'
dataset_smooth_path = '../Smoothing_Imgs/'

default_config = Config(
    N_START=0,
    N_EPOCH=100,
    SAVE='ckpt',
    LR=0.001,
    # Max gradient norm for the filter; None disables clipping in run().
    clip=0.01,
    # Trainable guided-filter smoother.
    model=DeepGuidedFilter(),
    # Set below on the deep copy.
    forward=None,
    # Weights of the classifier being attacked.
    model_weight='../save/mobilenet_v2.ckpt',
)

# Sorted so images are processed in a reproducible order: os.listdir order is
# arbitrary, and the shared filter keeps training from one image to the next.
image_list = sorted(f for f in listdir(dataset_path) if isfile(join(dataset_path, f)))
NumImg = len(image_list)

# Working configuration.
config = copy.deepcopy(default_config)
config.forward = forward

In [7]:
# Process every image in turn. `config` is the same object on each call, so
# config.model is not re-initialised and keeps training across images.
for idx in tqdm(range(NumImg)):
    run(
        config,
        dataset_path,
        dataset_smooth_path,
        image_list,
        idx=idx,
    )

  0%|          | 0/15 [00:00<?, ?it/s]Using cache found in /home/kjwspecial/.cache/torch/hub/pytorch_vision_v0.6.0
  7%|▋         | 1/15 [14:42<3:25:53, 882.41s/it]Using cache found in /home/kjwspecial/.cache/torch/hub/pytorch_vision_v0.6.0
 13%|█▎        | 2/15 [16:15<2:19:54, 645.71s/it]Using cache found in /home/kjwspecial/.cache/torch/hub/pytorch_vision_v0.6.0
 20%|██        | 3/15 [18:02<1:36:48, 484.03s/it]Using cache found in /home/kjwspecial/.cache/torch/hub/pytorch_vision_v0.6.0
 27%|██▋       | 4/15 [32:28<1:49:43, 598.53s/it]Using cache found in /home/kjwspecial/.cache/torch/hub/pytorch_vision_v0.6.0
 33%|███▎      | 5/15 [38:16<1:27:15, 523.58s/it]Using cache found in /home/kjwspecial/.cache/torch/hub/pytorch_vision_v0.6.0
 40%|████      | 6/15 [52:48<1:34:12, 628.09s/it]Using cache found in /home/kjwspecial/.cache/torch/hub/pytorch_vision_v0.6.0
 47%|████▋     | 7/15 [57:13<1:09:13, 519.15s/it]Using cache found in /home/kjwspecial/.cache/torch/hub/pytorch_vision_v0.6.0
 53