<a href="https://colab.research.google.com/github/IsHYuhi/BEDSR-Net_A_Deep_Shadow_Removal_Network_from_a_Single_Document_Image/blob/dev%2Ffix/demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BEDSR-Net: A Deep Shadow Removal Network from a Single Document Image

## Yun-Hsuan Lin, Wen-Chin Chen, Yung-Yu Chuang, National Taiwan University,

### This Colab notebook contains a demo of an unofficial reimplementation of Lin et al.'s CVPR 2020 paper by [IsHYuhi](https://github.com/IsHYuhi).
### More details can be found in the [paper](https://openaccess.thecvf.com/content_CVPR_2020/html/Lin_BEDSR-Net_A_Deep_Shadow_Removal_Network_From_a_Single_Document_CVPR_2020_paper.html). For details of the code, see the [repo](https://github.com/IsHYuhi/BEDSR-Net_A_Deep_Shadow_Removal_Network_from_a_Single_Document_Image).

\\

### Note
The results below were obtained on the Jung dataset.

You can test your own images using my pretrained model; see the section "Testing your own image".

In [1]:

%matplotlib inline

import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import numpy as np

from libs.models import get_model
from albumentations import (
    Compose,
    Normalize,
    Resize
)
from albumentations.pytorch import ToTensorV2

from utils.visualize import visualize, reverse_normalize
from libs.dataset import get_dataloader
from libs.loss_fn import get_criterion
from libs.helper_bedsrnet import do_one_iteration

def convert_show_image(tensor, idx=None):
    """Convert a normalized image batch into a displayable uint8 image.

    Args:
        tensor: 4-D NumPy array laid out as (batch, channels, height, width)
            whose values were normalized with mean 0.5 and std 0.5. Only
            1-channel and 3-channel batches are supported.
        idx: Index of the batch element to convert. When ``None`` the batch
            must contain exactly one image, which is the one converted.

    Returns:
        ``np.uint8`` array of shape (height, width, channels) with values in
        [0, 255].

    Raises:
        ValueError: If the channel dimension is neither 1 nor 3.
    """
    channels = tensor.shape[1]
    if channels == 3:
        img = reverse_normalize(tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    elif channels == 1:
        img = tensor * 0.5 + 0.5
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on `img`; fail early with a clear message instead.
        raise ValueError(
            f"convert_show_image expects 1 or 3 channels, got {channels}"
        )

    # Pick the requested image, or drop the singleton batch axis.
    img = img[idx] if idx is not None else img.squeeze(axis=0)

    # CHW -> HWC for matplotlib / OpenCV. Clip to [0, 1] before scaling so
    # values slightly outside the normalized range saturate instead of
    # wrapping around when cast to uint8.
    img = np.clip(img.transpose(1, 2, 0), 0.0, 1.0)
    return (img * 255).astype(np.uint8)

# Evaluation-time preprocessing: resize to 1024x768 (height x width),
# normalize with mean/std 0.5 and convert the image to a tensor.
test_transform = Compose(
    [
        Resize(1024, 768),
        Normalize(mean=(0.5, ), std=(0.5, )),
        ToTensorV2(),
    ]
)

# Test split of the Jung dataset: one image per batch, never shuffled, so
# results collected later can be matched to samples by position.
test_loader = get_dataloader(
    "Jung",
    "bedsrnet",
    "test",
    batch_size=1,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
    transform=test_transform,
)

# Use the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# BE-Net (background estimation); its wrapped model is moved to the device.
benet = get_model('cam_benet', in_channels=3, pretrained=True)
benet.model = benet.model.to(device)

# SR-Net comes back as a (generator, discriminator) pair.
srnet = get_model('srnet', pretrained=True)
generator = srnet[0].to(torch.device('cpu'))
discriminator = srnet[1].to(torch.device('cpu'))

# Switch both networks to inference mode, then move them to the device.
generator.eval()
generator.to(device)
discriminator.eval()
discriminator.to(device)

criterion = get_criterion("GAN", device)
lambda_dict = dict(lambda1=1.0, lambda2=0.01)

# Per-image results collected over the whole test set. The lists are indexed
# by the position of the sample in `test_loader` (which is not shuffled).
gts = []
preds = []
attmaps = []
bgcolors = []
psnrs = []
ssims = []

with torch.no_grad():
    for sample in test_loader:
        print(sample["img_path"][0])
        # Only the ground truth, prediction, BE-Net outputs and metrics are
        # kept. The first four returned values are discarded; in particular
        # the fourth is no longer bound to `input`, which shadowed the builtin.
        _, _, _, _, gt, pred, attention_map, back_ground, psnr, ssim = do_one_iteration(
            sample,
            generator,
            discriminator,
            benet,
            criterion,
            device,
            "evaluate",
            lambda_dict,
        )

        # Each output is a batch; store its elements individually.
        gts.extend(gt)
        preds.extend(pred)
        attmaps.extend(attention_map)
        bgcolors.extend(back_ground)
        psnrs.append(psnr)
        ssims.append(ssim)

# Mean metrics over the test set.
print(f"psnr: {np.mean(psnrs)}")
print(f"ssim: {np.mean(ssims)}")

KeyboardInterrupt: 

# Results from the testing set of Jung.
Left to right: input, ground truth, shadow removal image, background color, attention map from BE-Net.

In [None]:
# One row per test image, five panels per row:
# input | ground truth | shadow removal | background color | attention map.
n_rows = len(test_loader)
figure = plt.figure(figsize=(9 * 3, 2 * 3 * n_rows))

for idx, sample in enumerate(test_loader):
    img_path = sample['img_path'][0].split('/')[-1]
    first_panel = idx * 5 + 1  # subplot indices are 1-based

    # NOTE(review): a new array is built from each result list for every
    # panel. Keep it that way unless reverse_normalize is confirmed not to
    # modify its argument in place.
    plt.subplot(n_rows, 5, first_panel)
    plt.title(f"{img_path} input image")
    plt.imshow(convert_show_image(sample["img"].clone().cpu().numpy()))

    plt.subplot(n_rows, 5, first_panel + 1)
    plt.title(f"{img_path} Ground-Truth image")
    plt.imshow(convert_show_image(np.array(gts), idx=idx))

    plt.subplot(n_rows, 5, first_panel + 2)
    plt.title(f"{img_path} shadow removal image")
    plt.imshow(convert_show_image(np.array(preds), idx=idx))

    plt.subplot(n_rows, 5, first_panel + 3)
    plt.title(f"{img_path} back ground color image")
    plt.imshow(convert_show_image(np.array(bgcolors), idx=idx))

    plt.subplot(n_rows, 5, first_panel + 4)
    plt.title(f"{img_path} attention map")
    plt.imshow(convert_show_image(np.array(attmaps), idx=idx), cmap='jet', alpha=0.5)
    plt.colorbar()

plt.show()

# Input Test Images

In [None]:
# Grid of the network inputs, five images per row.
n_cols = 5
# Keep the original 4 rows as a minimum, and add rows when the loader yields
# more than 20 images; plt.subplot raises ValueError for an index beyond
# rows * cols.
n_rows = max(4, (len(test_loader) + n_cols - 1) // n_cols)
figure = plt.figure(figsize = (6*5, 6*5))

for idx, sample in enumerate(test_loader):
    plt.subplot(n_rows, n_cols, idx+1)
    # Title each panel with the image file name.
    plt.title(sample['img_path'][0].split('/')[-1])
    plt.imshow(convert_show_image(sample["img"].clone().detach().cpu().numpy()))

# Output Shadow Removal Images

In [None]:
# Grid of the shadow removal outputs collected in `preds`, five per row.
n_cols = 5
# Keep the original 4 rows as a minimum, and add rows when the loader yields
# more than 20 images; plt.subplot raises ValueError for an index beyond
# rows * cols.
n_rows = max(4, (len(test_loader) + n_cols - 1) // n_cols)
figure = plt.figure(figsize = (6*5, 6*5))

# The loader is iterated only to recover each image's file name for the title;
# the displayed image is the stored prediction at the same position.
for idx, sample in enumerate(test_loader):
    plt.subplot(n_rows, n_cols, idx+1)
    plt.title(sample['img_path'][0].split('/')[-1])
    plt.imshow(convert_show_image(np.array(preds), idx=idx))

# Ground Truth

In [None]:
# Grid of the ground-truth images collected in `gts`, five per row.
n_cols = 5
# Keep the original 4 rows as a minimum, and add rows when the loader yields
# more than 20 images; plt.subplot raises ValueError for an index beyond
# rows * cols.
n_rows = max(4, (len(test_loader) + n_cols - 1) // n_cols)
figure = plt.figure(figsize = (6*5, 6*5))

# The loader is iterated only to recover each image's file name for the title;
# the displayed image is the stored ground truth at the same position.
for idx, sample in enumerate(test_loader):
    plt.subplot(n_rows, n_cols, idx+1)
    plt.title(sample['img_path'][0].split('/')[-1])
    plt.imshow(convert_show_image(np.array(gts), idx=idx))

# Attention Map from BE-Net

In [None]:
# Grid of the BE-Net attention maps collected in `attmaps`, five per row.
n_cols = 5
# Keep the original 4 rows as a minimum, and add rows when the loader yields
# more than 20 images; plt.subplot raises ValueError for an index beyond
# rows * cols.
n_rows = max(4, (len(test_loader) + n_cols - 1) // n_cols)
figure = plt.figure(figsize = (6*5, 6*5))

# The loader is iterated only to recover each image's file name for the title;
# the displayed map is the stored attention map at the same position.
for idx, sample in enumerate(test_loader):
    plt.subplot(n_rows, n_cols, idx+1)
    plt.title(sample['img_path'][0].split('/')[-1])
    # Single-channel map rendered with the 'jet' colormap plus a colorbar.
    plt.imshow(convert_show_image(np.array(attmaps), idx=idx), cmap='jet', alpha=0.5)
    plt.colorbar()

# Back Ground Color from BE-Net

In [None]:
# Grid of the BE-Net background colors collected in `bgcolors`, five per row.
n_cols = 5
# Keep the original 4 rows as a minimum, and add rows when the loader yields
# more than 20 images; plt.subplot raises ValueError for an index beyond
# rows * cols.
n_rows = max(4, (len(test_loader) + n_cols - 1) // n_cols)
figure = plt.figure(figsize = (6*5, 6*5))

# The loader is iterated only to recover each image's file name for the title;
# the displayed image is the stored background color at the same position.
for idx, sample in enumerate(test_loader):
    plt.subplot(n_rows, n_cols, idx+1)
    plt.title(sample['img_path'][0].split('/')[-1])
    plt.imshow(convert_show_image(np.array(bgcolors), idx=idx))

# Testing your own image

In [None]:
# Mount Google Drive at /content/drive so images stored there can be read and
# results written back. The google.colab module is only available when the
# notebook runs on Colab; skip this cell elsewhere.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# you can put your own image path.
# NOTE(review): both paths are relative to the current working directory, so
# they only reach the mounted Drive if the notebook runs from a directory
# directly under /content — adjust them if your layout differs.
image_path = "../drive/MyDrive/shadow_image.jpg"
result_path = '../drive/MyDrive/shadow_removal_image.jpg'

image = cv2.imread(image_path)
if image is None:
    # cv2.imread signals a missing or unreadable file by returning None
    # rather than raising, which would otherwise surface as an obscure
    # error inside cvtColor.
    raise FileNotFoundError(
        f"Could not read an image from {image_path!r}; "
        "check the path and that Google Drive is mounted."
    )

# OpenCV loads images as BGR; the rest of the notebook works in RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
h, w, c = image.shape

# Apply the same preprocessing as the test set, then add a batch axis and
# move the tensor to the inference device.
transformed = test_transform(image=image)
tensor = transformed['image'].unsqueeze(0).to(device)

In [None]:
with torch.no_grad():
    # Always feed the networks from `device`. A previous run of this cell
    # leaves `tensor` on the CPU, which made re-running the cell fail when
    # the models live on the GPU.
    shadow_input = tensor.to(device)
    _, channels, height, width = shadow_input.shape

    # Gradients are re-enabled only around BE-Net; presumably its attention
    # map (model 'cam_benet') needs autograd — TODO confirm.
    with torch.set_grad_enabled(True):
        color, attmap, _ = benet(shadow_input)
        # Bring the attention map to the same mean/std 0.5 normalization as
        # the image channels.
        attmap = (attmap-0.5)/0.5
        # Expand the estimated background color to a full-resolution image.
        # NOTE(review): if `color` is (1, channels), repeat_interleave on
        # dim 0 followed by this reshape interleaves the channel values
        # instead of producing constant per-channel planes. Confirm how the
        # pretrained generator was trained before changing it.
        back_color = torch.repeat_interleave(color.detach(), height*width, dim=0)
        back_ground = back_color.reshape(1, channels, height, width).to(device)

    # Generator input: image, attention map and background color stacked
    # along the channel axis. Named so it does not shadow the builtin `input`.
    generator_input = torch.cat([shadow_input, attmap, back_ground], dim=1)

    # Keep CPU copies for visualization in the following cells.
    tensor = tensor.detach().cpu()
    attmap = attmap.detach().cpu()
    back_ground = back_ground.detach().cpu()
    shadow_removal_image = generator(generator_input).detach().cpu()


In [None]:
# Convert every tensor to a displayable uint8 image first. `removal` is kept
# as a top-level name because the saving cell writes it to disk.
input_view = convert_show_image(tensor.clone().detach().cpu().numpy())
removal = convert_show_image(shadow_removal_image.clone().detach().cpu().numpy())
background_view = convert_show_image(back_ground.clone().detach().cpu().numpy())
attention_view = convert_show_image(attmap.clone().detach().cpu().numpy())

# (title, image, extra imshow options) for each panel, left to right.
panels = [
    ('input image', input_view, {}),
    ('shadow removal image', removal, {}),
    ('back ground color image', background_view, {}),
    ('attention map', attention_view, {'cmap': 'jet', 'alpha': 0.5}),
]

figure = plt.figure(figsize=(9 * 3, 2 * 3))
for position, (title, picture, imshow_options) in enumerate(panels, start=1):
    plt.subplot(1, 4, position)
    plt.title(title)
    plt.imshow(picture, **imshow_options)

# The attention map is the last panel drawn, so the colorbar attaches to it.
plt.colorbar()
plt.show()

### saving result

In [None]:
# `removal` is RGB while OpenCV writes BGR, so convert before saving.
# cv2.imwrite reports failure through its return value; report that case too
# instead of finishing silently.
if cv2.imwrite(result_path, cv2.cvtColor(removal, cv2.COLOR_RGB2BGR)):
  print("saved")
else:
  print(f"failed to save: {result_path}")