### 디렉토리 구조
- dataset : 이미지 파일 저장 폴더
- pretrain : pre-trained 모델 파일 저장 폴더
- loaders : 모델 로드 관련 소스코드
- models : 모델 구조 관련 소스코드 
- result : dewarping 된 이미지 결과 저장

In [None]:
import sys, os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from glob import glob
import cv2


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable
from torch.utils import data


from models import get_model
from loaders import get_loader
from utils import convert_state_dict

# Paths of the pre-trained checkpoints handed to predict_image(), which derives
# each network's architecture name from the file-name prefix that precedes the
# first underscore (e.g. "unetnc" from "unetnc_doc3d.pkl").
wc_model_path = './pretrain/unetnc_doc3d.pkl' # shape network
bm_model_path = './pretrain/dnetccnl_doc3d.pkl' # texture mapping network

In [None]:
def unwarp(img, bm):
    """Resample ``img`` through the sampling grid ``bm``.

    Args:
        img: Image array indexed as (rows, cols, channels). It is divided by
            255 before sampling, so 0-255 pixel values are presumably expected.
        bm: Torch tensor indexed as (batch, channel, rows, cols). Only the
            first batch entry and the first two channels are used; they become
            the last-axis coordinate pair handed to ``F.grid_sample``.

    Returns:
        float64 NumPy array indexed as (rows, cols, channels) holding the
        resampled image on the same 0-1 scale as the normalised input.
    """
    rows, cols = img.shape[:2]

    # Move the channel axis to the end and keep only the first batch entry.
    grid = bm.permute(0, 2, 3, 1).detach().cpu().numpy()[0]

    # Smooth each coordinate channel with a 3x3 box filter, then stretch it to
    # the image resolution. cv2.resize takes its target size as (width, height).
    channels = []
    for idx in range(2):
        smoothed = cv2.blur(grid[:, :, idx], (3, 3))
        channels.append(cv2.resize(smoothed, (cols, rows)))
    grid = np.stack(channels, axis=-1)[np.newaxis, ...]
    grid = torch.from_numpy(grid).double()

    # Normalise the image and reorder it to (1, channels, rows, cols).
    pixels = (img.astype(float) / 255.0).transpose((2, 0, 1))[np.newaxis, ...]
    pixels = torch.from_numpy(pixels).double()

    # NOTE(review): align_corners is left at the library default, which is not
    # the same in every PyTorch release - confirm which one the checkpoints expect.
    sampled = F.grid_sample(input=pixels, grid=grid)
    return sampled[0].numpy().transpose((1, 2, 0))

In [None]:
def _model_name_from_path(model_path):
    """Return the architecture name encoded as the checkpoint file-name prefix."""
    stem = os.path.splitext(os.path.basename(model_path))[0]
    # partition() returns the whole stem when there is no underscore, instead
    # of silently chopping off the last character like a slice on find() == -1.
    return stem.partition('_')[0]


def _load_eval_model(model_path, n_classes, device):
    """Build the network named by ``model_path``, load its weights, move it to ``device``."""
    model = get_model(_model_name_from_path(model_path), n_classes, in_channels=3)
    # Always deserialise onto the CPU: the freshly built model lives there at
    # this point, and it avoids failures when the checkpoint was saved on a
    # GPU index that does not exist on this machine.
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(convert_state_dict(checkpoint['model_state']))
    model.eval()
    return model.to(device)


def predict_image(img_path, wc_model_path, bm_model_path):
    """Dewarp one image with the two-stage (shape -> texture mapping) pipeline.

    Args:
        img_path: Path of the image to dewarp.
        wc_model_path: Checkpoint of the first network (3 output channels).
            The architecture name passed to ``get_model`` is the file-name
            prefix before the first underscore.
        bm_model_path: Checkpoint of the second network (2 output channels),
            named the same way.

    Returns:
        Float NumPy array indexed as (rows, cols, channels) at the resolution
        of the input image, scaled by 255 and in BGR channel order (the RGB
        result of ``unwarp`` with its last axis reversed).

    Raises:
        ValueError: If ``img_path`` cannot be read as an image.
    """
    wc_n_classes = 3
    bm_n_classes = 2

    wc_img_size = (256, 256)
    bm_img_size = (128, 128)

    # Image Read
    print("Read Input Image from : {}".format(img_path))
    imgorg = cv2.imread(img_path)
    if imgorg is None:
        # cv2.imread signals failure with None rather than raising.
        raise ValueError(
            "Could not read image (missing file or unsupported format): {}".format(img_path)
        )
    imgorg = cv2.cvtColor(imgorg, cv2.COLOR_BGR2RGB)

    img = cv2.resize(imgorg, wc_img_size)
    # Reverse the channel axis again (RGB -> BGR) for the network input;
    # presumably the ordering the checkpoint was trained with.
    img = img[:, :, ::-1]
    img = img.astype(float) / 255.0
    img = img.transpose(2, 0, 1)  # HWC -> CHW
    img = np.expand_dims(img, 0)  # add the batch axis
    img = torch.from_numpy(img).float()

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Clamp the first network's output to [0, 1] before feeding the second.
    htan = nn.Hardtanh(0, 1.0)

    # Load both networks
    wc_model = _load_eval_model(wc_model_path, wc_n_classes, DEVICE)
    bm_model = _load_eval_model(bm_model_path, bm_n_classes, DEVICE)

    images = img.to(DEVICE)

    with torch.no_grad():
        wc_outputs = wc_model(images)
        pred_wc = htan(wc_outputs)
        bm_input = F.interpolate(pred_wc, bm_img_size)
        outputs_bm = bm_model(bm_input)

    # call unwarp on the full-resolution RGB image
    uwpred = unwarp(imgorg, outputs_bm)

    return uwpred[:, :, ::-1] * 255

In [None]:
# Dewarp every file under dataset/ and save a side-by-side comparison figure
# (input on the left, dewarped result on the right) into ./result/.
os.makedirs('./result', exist_ok=True)

images = glob('dataset/*.*')
for img_path in images:
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None for files it cannot decode; the glob pattern
        # matches any file with an extension, not only images.
        print("Skip unreadable image : {}".format(img_path))
        continue

    # prediction
    result = predict_image(img_path, wc_model_path, bm_model_path)
    result = np.array(result, dtype=np.uint8)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 16))

    # cv2.imread and predict_image both yield BGR arrays, while imshow
    # interprets three-channel input as RGB, so flip the channel axis.
    ax1.imshow(img[:, :, ::-1])
    ax2.imshow(result[:, :, ::-1])
    fig.show()

    # Build the output name from the stem so extensions of any length
    # (.jpeg, .tiff, ...) and any path separator are handled.
    out_name = os.path.splitext(os.path.basename(img_path))[0] + '.png'
    fig.savefig(os.path.join('./result', out_name))