GitHub<br>
https://github.com/zhangbaijin/SpA-Former-shadow-removal<br>
論文<br>
https://arxiv.org/abs/2206.10910v1<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/SpA_Former_shadow_removal_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 環境セットアップ

## GPU確認

In [None]:
# Show the GPU attached to this Colab runtime (the inference cells below move
# the model and input to CUDA when args.cuda is set).
!nvidia-smi

## GitHubからソースコード取得

In [None]:
%cd /content

# Clone the SpA-Former repository into /content/SpA-Former-shadow-removal.
# NOTE: on a re-run git refuses to clone because the directory already exists.
!git clone https://github.com/zhangbaijin/SpA-Former-shadow-removal.git

## ライブラリのインストール

In [None]:
%cd /content/SpA-Former-shadow-removal

# gdown fetches files from Google Drive (checkpoints and dataset below);
# pinned to a fixed version. einops is presumably needed by the SpA_Former
# model code — TODO confirm against the repository's requirements.
!pip install gdown==4.5.1
!pip install einops

## ライブラリのインポート

In [None]:
%cd /content/SpA-Former-shadow-removal

# Standard-library helpers used by the download checks and image listing below.
import os
import glob

# 学習済みモデルのセットアップ

In [None]:
%cd /content/SpA-Former-shadow-removal
!mkdir -p ./checkpoint

# Download the epoch-160 generator and discriminator weights from Google Drive,
# skipping any file that already exists locally.
# Only the generator checkpoint is loaded by the inference code in this
# notebook (args.pretrained); the discriminator file is not used here.
if not os.path.exists('checkpoint/gen_model_epoch_160.pth'):
  !gdown https://drive.google.com/uc?id=1gLOu4jkqslu_fWpQUd0N4hcNz4qCD8Xc -O checkpoint/gen_model_epoch_160.pth
if not os.path.exists('checkpoint/dis_model_epoch_160.pth'):
  !gdown https://drive.google.com/uc?id=1AcJSAV4oHiYneYUxCwaV0M_W1QU4MdZX -O checkpoint/dis_model_epoch_160.pth

# テスト画像のセットアップ

In [None]:
%cd /content/SpA-Former-shadow-removal
# Recreate the test-image directory from scratch on every run.
!rm -rf ./test_imgs
!mkdir -p ./test_imgs
!mkdir -p ./datasets

# Four sample test images from Pixabay, saved as test_01.jpg .. test_04.jpg.
!wget -c https://cdn.pixabay.com/photo/2016/04/11/03/00/run-1321278_960_720.jpg \
      -O ./test_imgs/test_01.jpg

!wget -c https://cdn.pixabay.com/photo/2013/10/20/22/05/shadow-198682_960_720.jpg \
      -O ./test_imgs/test_02.jpg

!wget -c https://cdn.pixabay.com/photo/2015/09/09/20/33/travel-933171_960_720.jpg \
      -O ./test_imgs/test_03.jpg

!wget -c https://cdn.pixabay.com/photo/2021/08/02/18/21/stairs-6517488_960_720.jpg \
      -O ./test_imgs/test_04.jpg

# Try to download the ISTD dataset archive directly (skipped if present).
# NOTE(review): the extraction cell that follows reads the archive from the
# mounted Drive (/content/drive/MyDrive/ISTD_Dataset.rar), not from this local
# datasets/ISTD_Dataset.rar, so this download is currently unused.
if not os.path.exists('datasets/ISTD_Dataset.rar'):
  !gdown https://drive.google.com/uc?id=1I0qw-65KBA6np8vIZzO6oeiOvcDBttAY -O datasets/ISTD_Dataset.rar

In [None]:
# Mount the user's Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# https://drive.google.com/file/d/1I0qw-65KBA6np8vIZzO6oeiOvcDBttAY/view
# Run the command below only if you have added a shortcut to the file above
# to your own Google Drive (it must appear as MyDrive/ISTD_Dataset.rar).
# Extracts the archive into ./datasets; unrar's file listing is discarded.
!unrar x /content/drive/MyDrive/ISTD_Dataset.rar /content/SpA-Former-shadow-removal/datasets > /dev/null

# Shadow Removal



In [None]:
%cd /content/SpA-Former-shadow-removal

import numpy as np
import argparse
from cv2 import cv2
import matplotlib.pyplot as plt
import time

import torch
from torch.autograd import Variable

from utils import gpu_manage, heatmap
from SpA_Former import Generator
from google.colab.patches import cv2_imshow

def predict(args):
    """Run SpA-Former shadow removal on one image and display the result.

    Loads the generator weights from ``args.pretrained`` and runs it on the
    image at ``args.test_filepath``. It then shows a horizontal strip of
    (input | model output | attention heatmap) with ``cv2_imshow``.

    Args:
        args: Attribute container. Fields read directly here:
            ``pretrained`` (path to the generator state dict),
            ``cuda`` (bool; move model and input to the GPU),
            ``gpu_ids`` (forwarded to ``Generator``),
            ``test_filepath`` (path of the image to process).
            The object is also passed to ``gpu_manage`` (from utils), which may
            read or modify further fields — see that helper.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``args.test_filepath``.
    """
    gpu_manage(args)

    ### MODELS LOAD ###
    print('===> Loading models')
    gen = Generator(gpu_ids=args.gpu_ids)
    # Deserialize onto the CPU first so a checkpoint saved from a GPU also
    # loads on a CPU-only runtime; the model is moved to the GPU below if asked.
    param = torch.load(args.pretrained, map_location='cpu')
    gen.load_state_dict(param)
    # NOTE(review): gen.eval() is never called, so any dropout/batch-norm layers
    # run in training mode. Confirm whether that is intended for this model.
    if args.cuda:
        gen = gen.cuda(0)
    print('<=== Model loaded')

    print('===> Loading test image')
    # Flag 1 == cv2.IMREAD_COLOR: a 3-channel BGR uint8 array.
    bgr = cv2.imread(args.test_filepath, 1)
    if bgr is None:
        # imread reports failure by returning None rather than raising.
        raise FileNotFoundError('could not read image: %s' % args.test_filepath)
    img = bgr.astype(np.float32) / 255
    # HWC -> CHW, then prepend a batch axis -> (1, C, H, W).
    img = img.transpose(2, 0, 1)[None]
    print('<=== test image loaded')

    with torch.no_grad():
        x = torch.from_numpy(img)
        if args.cuda:
            x = x.cuda()

        print('===> Removing the shadow...')
        start_time = time.time()
        att, out = gen(x)
        print('<=== finish! %.3fs cost.' % (time.time() - start_time))

        # Show the input exactly as read; no lossy float -> uint8 round trip.
        x_rgb = bgr

        out_ = out.cpu().numpy()[0]
        # Keep the first three channels, clamp to [0, 1], then CHW -> HWC uint8.
        out_rgb = np.clip(out_[:3], 0, 1) * 255
        out_rgb = out_rgb.transpose(1, 2, 0).astype('uint8')

        # Clamp before the uint8 cast so values outside [0, 1] saturate instead
        # of wrapping around; this is a no-op when att is already in range.
        att_ = np.clip(att.cpu().numpy()[0], 0, 1) * 255
        # heatmap (from utils) is assumed to return a batch of CHW images with
        # the same H x W as the input — TODO confirm.
        att_heatmap = heatmap(att_.astype('uint8'))[0]
        att_heatmap = att_heatmap.transpose(1, 2, 0)

        allim = np.hstack((x_rgb, out_rgb, att_heatmap))

        cv2_imshow(allim)

# Process every downloaded test image in a stable (sorted) order;
# glob.glob itself returns files in arbitrary order.
images = sorted(glob.glob('./test_imgs/*.jpg'))

# predict() only reads attributes from `args`, so use a plain Namespace rather
# than an ArgumentParser that is never parsed.
args = argparse.Namespace(
    pretrained='./checkpoint/gen_model_epoch_160.pth',
    # Fall back to the CPU when the runtime has no GPU attached.
    cuda=torch.cuda.is_available(),
    # NOTE(review): gpu_ids stays [0] even on CPU; confirm gpu_manage/Generator
    # ignore it when cuda is False.
    gpu_ids=[0],
    manualSeed=12,
)

for img in images:
    args.test_filepath = img
    predict(args)

In [None]:
import random

# Shadow images from the ISTD test split; present only after the archive has
# been extracted. Sorted so the sampling population has a fixed order.
datasets = sorted(glob.glob('/content/SpA-Former-shadow-removal/datasets/ISTD_Dataset/test/test_A/*.png'))

# random.sample raises ValueError when asked for more items than exist
# (e.g. the dataset was not extracted), so cap the sample size.
num_samples = min(5, len(datasets))
if num_samples == 0:
    print('No ISTD test images found; run the dataset extraction cell first.')
tests = random.sample(datasets, num_samples)

for img in tests:
  args.test_filepath = img
  predict(args)