論文  
https://arxiv.org/abs/2201.02233<br>
<br>
GitHub  
https://github.com/luoxuan-cs/PAMA<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/PAMA_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 環境セットアップ

## GPU確認

In [None]:
# Print the NVIDIA driver/GPU status table for this runtime.
!nvidia-smi

## GitHubからコード取得

In [None]:
# Work from /content so the repository ends up at /content/PAMA,
# the path every later cell changes into.
%cd /content

# Fetch the PAMA source code.
!git clone https://github.com/luoxuan-cs/PAMA.git

## ライブラリのインストール

In [None]:
%cd /content/PAMA

# gdown is used by the checkpoint cell to download from Google Drive;
# upgrade it to the latest release.
!pip install --upgrade gdown

## ライブラリのインポート

In [None]:
%cd /content/PAMA

import os
import gdown
import shutil
import argparse
import glob

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision.utils import save_image
from PIL import Image, ImageFile
from net import Net
from utils import DEVICE, train_transform, test_transform, FlatFolderDataset, InfiniteSamplerWrapper, plot_grad_flow, adjust_learning_rate
Image.MAX_IMAGE_PIXELS = None  
ImageFile.LOAD_TRUNCATED_IMAGES = True

## 学習済みモデルのダウンロード

In [None]:
%cd /content/PAMA
!mkdir checkpoints

# https://drive.google.com/file/d/1rPB_qnelVVSad6CtadmhRFi0PMI_RKdy/view
original = 'checkpoints/original_PAMA.zip'
if not os.path.exists(original):
  gdown.download('https://drive.google.com/uc?id='+'1rPB_qnelVVSad6CtadmhRFi0PMI_RKdy', original, quiet=False)
  shutil.unpack_archive(original, 'checkpoints')

# https://drive.google.com/file/d/1IrggOiutiZceJCrEb24cLnBjeA5I3N1D/view
wo_color = 'checkpoints/PAMA_without_color.zip'
if not os.path.exists(wo_color):
  gdown.download('https://drive.google.com/uc?id='+'1IrggOiutiZceJCrEb24cLnBjeA5I3N1D', wo_color, quiet=False)
  shutil.unpack_archive(wo_color, 'checkpoints')

# https://drive.google.com/file/d/1HXet2u_zk2QCVM_z5Llg2bcfvvndabtt/view
color = 'checkpoints/PAMA_1.5_color.zip'
if not os.path.exists(color):
  gdown.download('https://drive.google.com/uc?id='+'1HXet2u_zk2QCVM_z5Llg2bcfvvndabtt', color, quiet=False)
  shutil.unpack_archive(color, 'checkpoints')

# https://drive.google.com/file/d/13m7Lb9xwfG_DVOesuG9PyxDHG4SwqlNt/view
content = 'checkpoints/PAMA_1.5_content.zip'
if not os.path.exists(content):
  gdown.download('https://drive.google.com/uc?id='+"13m7Lb9xwfG_DVOesuG9PyxDHG4SwqlNt", content, quiet=False)
  shutil.unpack_archive(content, 'checkpoints')


# テスト画像取得

In [None]:
%cd /content/PAMA
# Input folders read by the inference cells: one for content images, one for styles.
!mkdir -p tests/contents tests/styles

# Content images (saved as test_1..3.jpg). `-c` resumes a partial download,
# `-O` sets the local file name.
!wget -c https://www.pakutaso.com/shared/img/thumb/20220227-A7401834_TP_V4.jpg \
      -O ./tests/contents/test_1.jpg
!wget -c https://www.pakutaso.com/shared/img/thumb/SAYA160312500I9A3721_TP_V4.jpg \
      -O ./tests/contents/test_2.jpg
!wget -c https://www.pakutaso.com/shared/img/thumb/unific528--8628_TP_V4.jpg \
      -O ./tests/contents/test_3.jpg
# Style images (saved as style_1..3.jpg).
!wget -c https://www.publicdomainpictures.net/pictures/80000/nahled/animal-sketch-13919381209K9.jpg \
      -O ./tests/styles/style_1.jpg
!wget -c https://www.publicdomainpictures.net/pictures/390000/velka/the-starry-night-van-gogh.jpg \
      -O ./tests/styles/style_2.jpg
!wget -c https://jojo-animation.com/img/top/mv_2.jpg \
      -O ./tests/styles/style_3.jpg

# Inference

## 使用モデル選択

In [None]:
# Colab form field: chooses which downloaded checkpoint set is copied into
# ./checkpoints for inference.
model_type = "ORIGINAL" #@param ["ORIGINAL", "WO_COLOR", "COLOR", "CONTENT"]

In [None]:
%cd /content/PAMA
!rm -rf ./checkpoints/*.pth
if model_type == "ORIGINAL":
  !cp ./checkpoints/original_PAMA/*.pth ./checkpoints
if model_type == "WO_COLOR":
  !cp ./checkpoints/original_PAMA/encoder.pth ./checkpoints
  !cp ./checkpoints/PAMA_without_color/*.pth ./checkpoints
if model_type == "COLOR":
  !cp ./checkpoints/original_PAMA/encoder.pth ./checkpoints
  !cp ./checkpoints/PAMA_1.5_color/*.pth ./checkpoints
if model_type == "CONTENT":
  !cp ./checkpoints/original_PAMA/encoder.pth ./checkpoints
  !cp ./checkpoints/PAMA_1.5_content/*.pth ./checkpoints


In [None]:
# Settings object handed to Net(args) and read by the inference cell.
# A plain Namespace is used because nothing is parsed from a command line;
# it exposes the same attributes an ArgumentParser instance was being used
# to carry.
args = argparse.Namespace(
  # Flags passed through to Net; their exact meaning is defined in net.py
  # (presumably: load pretrained weights / inference mode - confirm there).
  pretrained=True,
  requires_grad=True,
  training=False,
  # Directory where stylized results are written.
  outdir='/content/PAMA/tests_result',
  # True: stylize every content image with every style image in the two
  # folders below. False: `content` and `style` are single image paths.
  run_folder=True,
  content='/content/PAMA/tests/contents',
  style='/content/PAMA/tests/styles',
)
# Single-pair alternative:
# args.run_folder = False
# args.content = '/content/PAMA/tests/contents/test_1.jpg'
# args.style = '/content/PAMA/tests/styles/style_1.jpg'

In [None]:
def inference(model, device, content_path, style_path, out, transform=None):
  """Stylize one content image with one style image and save the result.

  Args:
    model: Callable invoked as ``model(content_batch, style_batch)``.
    device: Torch device the input tensors are moved to.
    content_path: Path of the content image file.
    style_path: Path of the style image file.
    out: Output directory; created if it does not exist.
    transform: Callable turning a PIL image into a tensor. When omitted,
      the notebook-level ``tf`` is used, so ``tf`` must be defined before
      the first call.

  The result is written to ``<out>/res_<style stem>_<content file name>``.
  """
  if transform is None:
    transform = tf

  def load_batch(path):
    # Close the file handle promptly and normalise the mode so grayscale,
    # palette and RGBA files all yield 3-channel tensors.
    # NOTE(review): assumes the network expects RGB input - confirm in net.py.
    with Image.open(path) as img:
      tensor = transform(img.convert("RGB")).to(device)
    # Add the leading batch dimension.
    return tensor.unsqueeze(dim=0)

  Ic = load_batch(content_path)
  Is = load_batch(style_path)

  # Inference only: no gradients are needed.
  with torch.no_grad():
    Ics = model(Ic, Is)

  os.makedirs(out, exist_ok=True)
  style_stem = os.path.splitext(os.path.basename(style_path))[0]
  filename = "res_" + style_stem + "_" + os.path.basename(content_path)
  # Save the first (only) image of the output batch.
  save_image(Ics[0], os.path.join(out, filename))

In [None]:
# Use the first CUDA device when available, otherwise the CPU.
# This rebinds the DEVICE name imported from utils in the import cell.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the network from the settings in `args` and switch it to eval mode.
model = Net(args)
model.eval()
model = model.to(DEVICE)

# Image -> tensor transform; inference() reads this global.
tf = test_transform()
if args.run_folder:
  # List both folders once (the style list does not change between content
  # images) and sort so the processing order is reproducible.
  content_paths = sorted(glob.glob(os.path.join(args.content, "*.*")))
  style_paths = sorted(glob.glob(os.path.join(args.style, "*.*")))
  # Every content image is stylized with every style image.
  for content_path in content_paths:
    for style_path in style_paths:
      inference(model, DEVICE, content_path, style_path, args.outdir)
else:
  # Single-pair mode: args.content / args.style are image files.
  inference(model, DEVICE, args.content, args.style, args.outdir)