In [None]:
import sys
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from src.models.unet.backbones import Resnet18Backbone
import torchvision
from PIL import Image, ImageFilter
from src.dataset import SatelliteDataset, SatelliteDatasetRun
from src.utils import load_model
from src.models.unet.unet import UNet
from src.models.unet.blocks import UpBlock
from torchvision.io import ImageReadMode
import cv2
import random
import os
import matplotlib.pyplot as plt
from tqdm import tqdm

# Shadow Removal
Labeled test data from: https://github.com/RSrscoder/AISD

The images were converted from TIF to PNG so that they can be loaded with `torchvision.io.read_image` (see `SatelliteDataset`).


In [None]:
# Dataset read from ./data/test with histogram equalization switched off.
test_data = SatelliteDatasetRun(data_dir="./data/test", hist_equalization=False)
# The segmentation loop in the last section unpacks every batch as (path, size, img)
# and only reads path[0], so each batch has to hold exactly one sample.
dataloader = DataLoader(test_data, batch_size=1, shuffle=True) # batch_size must be 1!

In [None]:
# U-Net with a ResNet-18 backbone; up-sampling blocks are built with up_mode='upconv'.
model = UNet(backbone=Resnet18Backbone(), up_block_ctor=lambda ci: UpBlock(ci, up_mode='upconv'))
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# machine without CUDA; the notebook runs the model on CPU tensors anyway.
state_dict = torch.load('./out/shadows_2023-04-04_17-00-58/best_model.pth.tar', map_location="cpu")
model.load_state_dict(state_dict)
# Switch to evaluation mode. Assigning the result to `_` keeps the cell from
# printing the full module tree (same convention as the ShadowFormer cells).
_ = model.eval()

In [None]:
def remove_shadows(img: Image.Image, mask: np.ndarray, gain: float = 3.0):
    """Brighten the shadowed pixels of ``img`` and display the results.

    Four images are displayed in order: the input image, the shadow mask,
    the image with only the HSV value channel brightened, and the image with
    the first three (RGB) channels brightened.

    :param img: Pillow Image of the satellite image (assumed to be RGB/RGBA
        -- TODO confirm against the dataset).
    :param mask: 2-D numpy array with the same height and width as ``img``.
        Every non-zero entry marks a shadow pixel, so boolean, 0/1 and 0/255
        masks are all accepted.
    :param gain: Factor applied to shadowed pixels. The default of 3 equals
        the previously hard-coded ``1 + mask*2`` for a boolean mask.
    """
    # Normalise to bool first: with a uint8 mask the old ``1 + mask*2`` and the
    # following multiplication stayed in uint8 and silently wrapped around.
    shadow = np.asarray(mask).astype(bool)

    # display
    display(img)
    display(Image.fromarray(shadow.astype(np.uint8) * 255))

    # Per-pixel multiplier in floating point: `gain` inside the shadow, 1 outside.
    mult_mask = np.where(shadow, float(gain), 1.0)

    # shadow removal via HSV: scale only the value (brightness) channel
    img_hsv = np.asarray(img.convert("HSV")).copy()
    img_hsv[..., 2] = np.clip(img_hsv[..., 2] * mult_mask, 0, 255).astype(np.uint8)
    display(Image.fromarray(img_hsv, "HSV").convert("RGB"))

    # shadow removal via RGB: scale the three colour channels in one broadcast
    img_test = np.asarray(img).copy()
    img_test[..., :3] = np.clip(
        img_test[..., :3] * mult_mask[..., np.newaxis], 0, 255
    ).astype(np.uint8)
    display(Image.fromarray(img_test))

# Example tile together with its ground-truth shadow mask.
img = Image.open("data/shadows/images/chicago33_sub11.png")
mask = np.array(Image.open("data/shadows/groundtruth/chicago33_sub11.png"))
# Print a compact summary (shape, dtype, distinct values) instead of the raw array.
print(mask.shape, mask.dtype, np.unique(mask))
remove_shadows(img, mask)

In [None]:
def remove_shadows(img: Image.Image, mask: np.ndarray):
    """Visualise the shadow borders of ``mask`` and a normal for each border line.

    The mask is edge-detected with Canny, straight segments are extracted with
    the probabilistic Hough transform, and for every segment a second segment
    rotated by 90 degrees around its start point is computed. Three images are
    displayed: the mask, the edge image, and ``img`` with the detected
    segments (red) and their rotated counterparts (green) drawn on top.

    No shadow is removed yet; the function only produces the visualisation.

    :param img: Pillow Image of the satellite image.
    :param mask: 2-D numpy array; every non-zero entry marks a shadow pixel
        (boolean, 0/1 and 0/255 masks are accepted).
    """
    img_array = np.array(img)

    # Normalise to a 0/255 uint8 image; ``mask*255`` wraps around for uint8 0/255 masks.
    mask_u8 = np.asarray(mask).astype(bool).astype(np.uint8) * 255
    display(Image.fromarray(mask_u8))
    edge_img = cv2.Canny(mask_u8, 100, 100, apertureSize=3)

    # HoughLinesP returns None when nothing is found and an (N, 1, 4) array
    # otherwise. reshape(-1, 4) keeps the array 2-D even for a single line,
    # which a bare squeeze() would flatten to shape (4,).
    found = cv2.HoughLinesP(edge_img, 10, np.pi / 180, 50)
    if found is None:
        lines = np.empty((0, 4), dtype=np.int32)
    else:
        lines = found.reshape(-1, 4)

    normal_vectors = np.zeros_like(lines)
    img_norm = img_array.copy()
    # tolist() yields plain Python ints for the OpenCV drawing calls.
    for index, (x1, y1, x2, y2) in enumerate(lines.tolist()):
        dx = x2 - x1
        dy = y2 - y1
        cv2.line(img_norm, (x1, y1), (x2, y2), (255, 0, 0), 2)
        # rotate the line by 90 degree to get the normal vector
        normal_end = (x1 + dy, y1 - dx)
        cv2.line(img_norm, (x1, y1), normal_end, (0, 255, 0), 2)
        normal_vectors[index] = (x1, y1, normal_end[0], normal_end[1])

    display(Image.fromarray(edge_img))
    display(Image.fromarray(img_norm))

    # TODO: using the normal vectors, define a mapping from pixels just outside
    # of the shadow to pixels inside of the shadow. An earlier draft of that
    # step sat behind an unconditional `return` and never ran; it was removed.
    return



# Run the edge / normal-vector visualisation on one example tile and its
# ground-truth shadow mask.
img = Image.open("data/shadows/images/chicago33_sub11.png")
mask = np.array(Image.open("data/shadows/groundtruth/chicago33_sub11.png"))
remove_shadows(img, mask)

# ShadowFormer

Testing shadow removal using ShadowFormer pretrained on different datasets.

In [None]:
from src.ShadowFormer.model import ShadowFormer
from collections import OrderedDict

# Side length in pixels shared by the resize transform, the ShadowFormer
# constructors and the figure-size computations further down.
img_size = 400

def load_checkpoint(model, weights, device):
    """Load the ``"state_dict"`` entry of a checkpoint file into ``model``.

    The state dict is first loaded as stored. If that fails because the keys
    do not match, a leading ``module.`` (the prefix added by
    ``torch.nn.DataParallel``) is stripped from every key and the load is
    retried.

    :param model: ``torch.nn.Module`` that receives the weights (modified in place).
    :param weights: Path (or file-like object) of the checkpoint.
    :param device: Forwarded to ``torch.load`` as ``map_location``.
    :raises KeyError: If the checkpoint has no ``"state_dict"`` entry.
    :raises RuntimeError: If the keys still do not match after stripping the prefix.
    """
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(weights, map_location=device)
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Only a *leading* "module." is a wrapper prefix. Testing with `in`
        # would also cut keys such as "submodule.weight".
        prefix = "module."
        stripped = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        model.load_state_dict(stripped)

def output_to_image(model_output):
    """Turn a model output tensor into an 8-bit channels-last numpy image.

    Values are limited to [0, 1], all size-1 dimensions are dropped, the
    remaining three axes are reordered from (C, H, W) to (H, W, C), and the
    result is scaled by 255 and cast to ``uint8``.

    :param model_output: Tensor that squeezes down to three dimensions (C, H, W).
    :return: ``numpy.ndarray`` of dtype ``uint8`` with shape (H, W, C).
    """
    limited = model_output.clamp(0, 1)
    channels_first = limited.detach().cpu().numpy().squeeze()
    channels_last = np.transpose(channels_first, (1, 2, 0))
    scaled = channels_last * 255
    return scaled.astype(np.uint8)


# Preprocessing applied to both images and masks before they are passed to
# ShadowFormer: resize, then convert the PIL image to a tensor.
# NOTE(review): Resize with a single int scales the shorter edge to img_size
# and keeps the aspect ratio, so the output is img_size x img_size only for
# square inputs -- confirm that all inputs are square.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(img_size),
    torchvision.transforms.ToTensor()
])


# Two identically configured ShadowFormer instances; the cell below loads the
# ISTD checkpoint into the first and the ISTD+ checkpoint into the second.
srm_istd = ShadowFormer(img_size=img_size, embed_dim=32,win_size=10,token_projection="linear",token_mlp="leff")
srm_istd_plus = ShadowFormer(img_size=img_size, embed_dim=32,win_size=10,token_projection="linear",token_mlp="leff")

In [None]:
base_path = "./data/shadows/"
filenames = os.listdir(os.path.join(base_path, "images"))

# get 10 random images
random.shuffle(filenames)
n = 10
filenames = filenames[:n]

# load models
load_checkpoint(srm_istd, "./out/ISTD_model_latest.pth", device=torch.device("cpu"))
_ = srm_istd.eval()

load_checkpoint(srm_istd_plus, "./out/ISTD_plus_model_latest.pth", device=torch.device("cpu"))
_ = srm_istd_plus.eval()

# One row per image: original | mask | ISTD result | ISTD+ result.
fig = plt.figure(figsize=(img_size*4 / 100, img_size*n / 100))
# squeeze=False keeps `axes` two-dimensional even when n == 1.
axes = fig.subplots(nrows=n, ncols=4, squeeze=False)

# Inference only: without no_grad every forward pass would record an autograd graph.
with torch.no_grad():
    for i, filename in enumerate(tqdm(filenames)):
        pil_img = Image.open(os.path.join(base_path, "images", filename))
        pil_mask = Image.open(os.path.join(base_path, "groundtruth", filename))

        # add the batch dimension expected by the models
        img = transform(pil_img).unsqueeze(0)
        mask = transform(pil_mask).unsqueeze(0)

        srm_istd_out = srm_istd(img, mask)
        srm_istd_plus_out = srm_istd_plus(img, mask)

        axes[i][0].imshow(pil_img)
        axes[i][1].imshow(pil_mask)
        axes[i][2].imshow(output_to_image(srm_istd_out))
        axes[i][3].imshow(output_to_image(srm_istd_plus_out))


plt.suptitle("ShadowFormer on dataset with groundtruth shadows", y=.995)
axes[0][0].set_title("Original")
axes[0][1].set_title("Mask")
axes[0][2].set_title("ISTD")
axes[0][3].set_title("ISTD+")

# turns off axes (rows without an image stay blank)
for axi in axes.ravel():
    axi.set_axis_off()
# NOTE(review): plt.axis acts on the current axes only, not on the whole grid.
plt.axis("tight")  # gets rid of white border
plt.axis("image")  # square up the image instead of filling the "figure" space
plt.tight_layout()
plt.savefig("shadow.pdf")

# ShadowFormer on road segmentation output

In [None]:
# Resizes the predicted masks to the ShadowFormer input size. Derived from
# img_size (instead of a literal 400) because the mask has to match the image
# produced by `transform`, which also resizes to img_size.
upscale = torchvision.transforms.Resize(img_size)
n = 10
# One row per test image, four columns (original, raw mask, threshold mask, result).
fig = plt.figure(figsize=(img_size*4 / 100, img_size*n / 100))
# squeeze=False keeps `axes` two-dimensional even when n == 1.
axes = fig.subplots(nrows=n, ncols=4, squeeze=False)

def pil_from_tensor(tensor_image, change_channels=False):
    """Convert a float tensor with values in [0, 1] into an 8-bit Pillow image.

    All size-1 dimensions are squeezed away before the conversion.

    :param tensor_image: Tensor that squeezes down to (H, W), or to (C, H, W)
        when ``change_channels`` is True.
    :param change_channels: If True, reorder the axes from (C, H, W) to
        (H, W, C) before building the image.
    :return: Pillow Image built from the scaled ``uint8`` array.
    """
    # Clamp before scaling: values outside [0, 1] (common for raw network
    # output) would otherwise wrap around in the uint8 cast.
    values = tensor_image.detach().clamp(0, 1).squeeze().cpu().numpy()
    img = (values * 255).astype(np.uint8)
    return Image.fromarray(img.transpose(1, 2, 0) if change_channels else img)

# Inference only: without no_grad both networks would record an autograd graph.
with torch.no_grad():
    for i, (path, size, batch) in enumerate(tqdm(dataloader, total=n)):
        # predicted shadow mask and its binarised version (threshold 0.5)
        out = model(batch)
        out_threshold = (out > 0.5).float()

        # for shadow ShadowFormer use actual image without any transforms
        img = Image.open(path[0]).convert("RGB")  # assume size 400x400
        img_transformed = transform(img).unsqueeze(0)

        # upscale image and mask to 400x400 and remove shadow using ShadowFormer trained on ISTD
        out_threshold_scaled = upscale(out_threshold)
        removed_shadow = srm_istd(img_transformed, out_threshold_scaled)

        axes[i][0].imshow(img)
        axes[i][1].imshow(pil_from_tensor(upscale(out)))
        axes[i][2].imshow(pil_from_tensor(out_threshold_scaled))
        # output_to_image clamps to [0, 1] before the uint8 cast, so
        # out-of-range ShadowFormer output cannot wrap around.
        axes[i][3].imshow(output_to_image(removed_shadow))

        # the dataloader may hold more than n samples; stop after n rows
        if i == n - 1:
            break


plt.suptitle("ShadowFormer on road segmentation test dataset", y=.995)
axes[0][0].set_title("Original")
axes[0][1].set_title("Raw Mask")
axes[0][2].set_title("Threshold Mask (t=0.5)")
axes[0][3].set_title("ShadowFormer (ISTD)")

# turns off axes
for axi in axes.ravel():
    axi.set_axis_off()
# NOTE(review): plt.axis acts on the current axes only, not on the whole grid.
plt.axis("tight")  # gets rid of white border
plt.axis("image")  # square up the image instead of filling the "figure" space
plt.tight_layout()
plt.savefig("shadow_test_dataset.pdf")