論文<br>
https://arxiv.org/abs/2109.07161<br>
<br>
GitHub<br>
https://github.com/saic-mdal/lama<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/Lama_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 環境セットアップ

## GitHubからソースコードを取得

In [None]:
%cd /content
!git clone https://github.com/saic-mdal/lama.git

In [None]:
# Pytorch 1.8.0をインストール
# その他ライブラリをインストール
!pip install -r lama/requirements.txt --quiet
!pip install wget --quiet
!pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 torchtext==0.9 -f https://download.pytorch.org/whl/torch_stable.html --quiet


In [None]:
# avoid AttributeError: 'builtin_function_or_method' object has no attribute 'rfftn'
!sed -E -i "15i import torch.fft" /content/lama/saicinpainting/training/modules/ffc.py

## 学習済みモデルのセットアップ

In [None]:
% cd /content/lama

!curl -L $(yadisk-direct https://disk.yandex.ru/d/ouP6l8VJ0HpMZg) -o big-lama.zip
!unzip big-lama.zip

## ライブラリをインポート

In [None]:
import base64, os
from IPython.display import HTML, Image
from google.colab.output import eval_js
from base64 import b64decode
import matplotlib.pyplot as plt
import numpy as np
import wget
from shutil import copyfile
import shutil

# Canvasのセットアップ

In [None]:

# HTML/JS template for the interactive mask-drawing widget shown by draw().
# It is filled with the `%` operator; the placeholders are, in order:
#   1-2. width / height of the <canvas1> element,
#   3-4. width / height of the drawing <canvas>,
#   5.   image subtype used in the data URL (e.g. "png"),
#   6.   base64-encoded image data,
#   7.   stroke width (`%d`, so a float argument is truncated to an int).
# NOTE(review): <canvas1> is not a real <canvas> element, and `ctx1` is obtained from
# `canvas` (not `canvas1`), so the photo is drawn onto the same canvas the user
# paints on. The exported PNG therefore contains the photo with pure-red ('red')
# strokes on top, and the mask is later recovered from the pure-red pixels.
# Clicking "Finish" resolves the JS `data` promise with the canvas as a PNG data URL.
canvas_html = """
<style>
.button {
  background-color: #4CAF50;
  border: none;
  color: white;
  padding: 15px 32px;
  text-align: center;
  text-decoration: none;
  display: inline-block;
  font-size: 16px;
  margin: 4px 2px;
  cursor: pointer;
}
</style>
<canvas1 width=%d height=%d>
</canvas1>
<canvas width=%d height=%d>
</canvas>

<button class="button">Finish</button>
<script>
var canvas = document.querySelector('canvas')
var ctx = canvas.getContext('2d')

var canvas1 = document.querySelector('canvas1')
var ctx1 = canvas.getContext('2d')


ctx.strokeStyle = 'red';

var img = new Image();
img.src = "data:image/%s;charset=utf-8;base64,%s";
console.log(img)
img.onload = function() {
  ctx1.drawImage(img, 0, 0);
};
img.crossOrigin = 'Anonymous';

ctx.clearRect(0, 0, canvas.width, canvas.height);

ctx.lineWidth = %d
var button = document.querySelector('button')
var mouse = {x: 0, y: 0}

canvas.addEventListener('mousemove', function(e) {
  mouse.x = e.pageX - this.offsetLeft
  mouse.y = e.pageY - this.offsetTop
})
canvas.onmousedown = ()=>{
  ctx.beginPath()
  ctx.moveTo(mouse.x, mouse.y)
  canvas.addEventListener('mousemove', onPaint)
}
canvas.onmouseup = ()=>{
  canvas.removeEventListener('mousemove', onPaint)
}
var onPaint = ()=>{
  ctx.lineTo(mouse.x, mouse.y)
  ctx.stroke()
}

var data = new Promise(resolve=>{
  button.onclick = ()=>{
    resolve(canvas.toDataURL('image/png'))
  }
})
</script>
"""

In [None]:
def draw(imgm, filename='drawing.png', w=400, h=200, line_width=1, image_format=None):
  """Display an image on an HTML canvas, let the user paint over it, and save the result.

  Blocks until the user clicks "Finish", then writes the exported canvas (the image
  plus the user's strokes, encoded as PNG by the widget) to `filename`.

  Args:
    imgm: base64-encoded image data (str) to show underneath the strokes.
    filename: path the exported PNG bytes are written to.
    w, h: canvas width and height in pixels.
    line_width: brush width in pixels (truncated to an int by the `%d` placeholder).
    image_format: image subtype for the data URL, e.g. "jpeg" or "png". Defaults to
      the extension of `filename` (the original behavior). That default mislabels the
      image when `filename` is a .png path but `imgm` holds, for example, JPEG data.
  """
  if image_format is None:
    image_format = filename.split('.')[-1]
  display(HTML(canvas_html % (w, h, w, h, image_format, imgm, line_width)))
  # The JS `data` promise resolves with a data URL once "Finish" is clicked.
  data_url = eval_js("data")
  # Drop the "data:image/png;base64," header and decode the payload.
  binary = b64decode(data_url.split(',', 1)[1])
  with open(filename, 'wb') as f:
    f.write(binary)

# 画像のセットアップ
[使用画像1](https://www.pakutaso.com/shared/img/thumb/PAK85_oyakudachisimasu20140830_TP_V.jpg)<br>
[使用画像2](https://www.pakutaso.com/shared/img/thumb/TSU88_awaitoykyo_TP_V.jpg)<br>
[使用画像3](https://www.pakutaso.com/20211208341post-37933.html)

In [None]:
% cd /content/lama

from google.colab import files
files = files.upload()
fname = list(files.keys())[0]

shutil.rmtree('./data_for_prediction', ignore_errors=True)
! mkdir data_for_prediction

copyfile(fname, f'./data_for_prediction/{fname}')
os.remove(fname)
fname = f'./data_for_prediction/{fname}'

image64 = base64.b64encode(open(fname, 'rb').read())
image64 = image64.decode('utf-8')

print(f'Will use {fname} for inpainting')
img = np.array(plt.imread(f'{fname}')[:,:,:3])

# inpainting

In [None]:
# Name the mask after the image stem in the same directory
# ("./data_for_prediction/photo.jpg" -> "./data_for_prediction/photo_mask.png").
# os.path.splitext only strips the final extension, so file names that contain
# extra dots keep their full stem (split('.')[1] truncated them).
mask_path = f"{os.path.splitext(fname)[0]}_mask.png"
# Brush width is 4% of the image width.
draw(image64, filename=mask_path, w=img.shape[1], h=img.shape[0], line_width=0.04*img.shape[1])

# The exported canvas holds the photo plus the user's pure-red strokes. Mark exactly
# those pixels (R == 1, G == 0, B == 0 in plt.imread's float PNG values) for inpainting.
with_mask = np.array(plt.imread(mask_path)[:, :, :3])
mask = (with_mask[:, :, 0] == 1) & (with_mask[:, :, 1] == 0) & (with_mask[:, :, 2] == 0)
# Pin the colour range so True is always white and False always black. Autoscaling
# would map a constant mask (all painted / nothing painted) entirely to black.
plt.imsave(mask_path, mask, cmap='gray', vmin=0, vmax=1)

In [None]:
# Extension of the uploaded image, dot included (e.g. ".jpg"). It is passed to
# predict.py below as dataset.img_suffix.
splitpaths = os.path.splitext(os.path.basename(fname))
suffix = splitpaths[1]  # splitext always yields a (root, ext) pair

In [None]:
%cd /content/lama

!mkdir output/
copyfile(mask_path,os.path.join("./output/", os.path.basename(mask_path)))

!PYTHONPATH=. TORCH_HOME=$(pwd) python3 bin/predict.py \
  model.path=$(pwd)/big-lama \
  indir=$(pwd)/data_for_prediction \
  outdir=/content/lama/output \
  dataset.img_suffix={suffix}

plt.rcParams['figure.dpi'] = 200
plt.imshow(plt.imread(f"/content/lama/output/{fname.split('.')[1].split('/')[2]}_mask.png"))
_=plt.axis('off')
_=plt.title('inpainting result')
plt.show()
fname = None