<a href="https://colab.research.google.com/github/hodaka/MakeTrimmingMap/blob/work/%E8%87%AA%E5%8B%95%E3%83%88%E3%83%AA%E3%83%9F%E3%83%B3%E3%82%B0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 人物を切り抜く（マッティングする）準備として、トリミングマップ（trimap）を作成するためのスクリプト

DeepLabV3 で人物領域を推定してトリミングマップを作り、それを使って FBA Matting で前景を推定・保存する。


---


In [None]:
import torch
import os
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt

from torchvision.models.segmentation import deeplabv3_resnet101
from torchvision import transforms
from IPython.display import Image
from google.colab import files
from IPython.display import Image, display

# Mount Google Drive at /gdrive; the user's files then appear under
# /gdrive/MyDrive, which is where read_path (set further down) points.
from google.colab import drive
drive.mount('/gdrive')

def make_deeplab(device):
    """Return a pretrained DeepLabV3-ResNet101 on *device*, ready for inference.

    ``nn.Module.to`` and ``nn.Module.eval`` both return the module itself, so
    the model can be built, moved and switched to eval mode in one chain.
    NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    favour of ``weights=``; kept as-is for compatibility with older versions.
    """
    return deeplabv3_resnet101(pretrained=True).to(device).eval()

# Input transform for DeepLab: image -> float CHW tensor scaled to [0, 1]
# (ToTensor), then per-channel normalization with ImageNet mean/std.
# The statistics are listed in R, G, B order, so the image should be RGB.
deeplab_preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def apply_deeplab(deeplab, img, device, class_index=15):
    """Return a boolean mask of the pixels DeepLab labels as ``class_index``.

    Args:
        deeplab: segmentation model whose output dict has an ``'out'`` tensor
            of per-class scores (as torchvision's DeepLabV3 models do).
        img: image accepted by ``deeplab_preprocess`` (e.g. an H x W x 3 uint8
            array). Its normalization statistics are RGB-ordered, so pass RGB
            rather than OpenCV's default BGR.
        device: torch device the model lives on; the input batch is moved there.
        class_index: label to extract. The default 15 is "person" in the
            Pascal VOC label set used by torchvision's pretrained DeepLabV3.

    Returns:
        numpy bool array with the spatial size of the model output.
    """
    batch = deeplab_preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        scores = deeplab(batch)['out'][0]
    labels = scores.argmax(0).cpu().numpy()
    return labels == class_index

# Build a trimap (definite foreground / background plus an unknown boundary band).
def make_trimap(masking_image, kernel_size=12):
    """Build a single-channel trimap from a segmentation mask.

    The foreground (mask > 0) and background (mask == 0) regions are each
    eroded with an elliptical kernel; pixels that survive in neither region
    form the "unknown" band around the mask boundary.

    Args:
        masking_image: 2-D mask (bool or numeric); values > 0 are foreground.
        kernel_size: size in pixels of the elliptical erosion kernel; larger
            values give a wider unknown band. Defaults to 12.

    Returns:
        float64 array (H, W): 1.0 = foreground, 0.0 = background, 0.5 = unknown.
    """
    mask = np.asarray(masking_image)
    foreground = (mask > 0).astype(np.float64)
    background = (mask == 0).astype(np.float64)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    foreground = cv2.erode(foreground, kernel)
    background = cv2.erode(background, kernel)
    # fg -> 1 + 0/2 = 1, bg -> 0 + 0/2 = 0, neither -> 0 + 1/2 = 0.5
    return foreground + (1 - (foreground + background)) / 2

# Folder containing the input images to process.
read_path = "/gdrive/MyDrive/ImageData/NamioHarukawaNoChoice"
# Folder that receives trimaps and matting results. The trailing "/" matters:
# output file names are appended to this string by plain concatenation.
output_path = read_path + "/ResultTrimap/"

# Create the output folder if needed. exist_ok avoids a check-then-create race
# and makes re-running this cell harmless.
os.makedirs(output_path, exist_ok=True)


# Fetch the FBA-Matting repository and make it the working directory so that
# its `demo`, `dataloader` and `networks` modules can be imported below.
# NOTE(review): this cell is not idempotent — `%cd` changes the notebook's cwd,
# so re-running it presumably clones/enters a nested FBA-Matting copy.
!git clone https://github.com/MarcoForte/FBA-Matting.git
%cd FBA-Matting

from demo import np_to_torch, pred, scale_input
from dataloader import read_image, read_trimap
from networks.models import build_model

class Args:
  encoder = 'resnet50_GN_WS'
  decoder = 'fba_decoder'
  weights = 'FBA.pth'
args=Args()
try:
    model = build_model(args)
except:
    !gdown  https://drive.google.com/uc?id=1T_oiKDE_biWf2kqexMEN7ObWqtXAzbB1
    model = build_model(args)

print('モデル作成')

device = torch.device("cpu")
deeplab = make_deeplab(device)

print('ループ開始')

# Process every file in read_path (sorted so output indices are reproducible).
for idx, imageFilePath in enumerate(sorted(glob.glob(read_path + "/*.*"))):
    print('targetPath:' , imageFilePath)
    # Load as 3-channel BGR (OpenCV's channel order). imread returns None for
    # files it cannot decode, e.g. non-image files matched by "*.*".
    img = cv2.imread(imageFilePath, 1)
    if img is None:
        print('skip (not a readable image):', imageFilePath)
        continue

    # DeepLab's preprocessing normalizes with RGB-ordered ImageNet statistics,
    # so convert from BGR before segmenting the person.
    mask_img = apply_deeplab(deeplab, cv2.cvtColor(img, cv2.COLOR_BGR2RGB), device)
    # Single-channel trimap: 1 = foreground, 0 = background, 0.5 = unknown band.
    trimap_im = make_trimap(mask_img)
    print('トリミングマップ完成')

    # Colour visualisation (matplotlib's default colormap), for inspection only.
    out_color_trimap_path = output_path + 'trimap_c_' + str(idx) + '.jpg'
    plt.imsave(out_color_trimap_path, trimap_im)
    # True grayscale trimap (0 / 128 / 255), written with OpenCV because
    # plt.imsave would apply a colormap to a 2-D array. JPEG is lossy, so this
    # file is for inspection; matting below uses the exact in-memory trimap.
    out_gray_trimap_path = output_path + 'trimap_g_' + str(idx) + '.jpg'
    cv2.imwrite(out_gray_trimap_path, np.round(trimap_im * 255).astype(np.uint8))

    target_image = read_image(imageFilePath)
    # FBA's trimap layout is (H, W, 2): channel 0 = definite background,
    # channel 1 = definite foreground. Build it directly from trimap_im instead
    # of re-reading a saved JPEG, whose compression (and any colormap) perturbs
    # the exact 0/255 values that mark definite background/foreground.
    target_trimap = np.stack([trimap_im == 0, trimap_im == 1], axis=-1).astype(np.float64)
    foreground, background, alpha = pred(target_image, target_trimap, model)

    # NOTE(review): this saves FBA's foreground colour estimate, not an
    # alpha-composited cut-out; `alpha` is available if a cut-out is wanted.
    out_result_path = output_path + 'result_' + str(idx) + '.jpg'
    # plt.imsave rejects float RGB values outside [0, 1]; clip defensively.
    plt.imsave(out_result_path, np.clip(foreground, 0, 1))

print('処理完了')