In [69]:
import importlib.util
import os
import sys
from importlib import reload

import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from src.ShadowFormer.model import ShadowFormer
from src.dataset import SatelliteDataset, SatelliteDatasetRun
from src.models.unet.backbones import Resnet18Backbone
from src.models.unet.unet import UNet
from src.models.unet.blocks import UpBlock
from src.wrapper import PLWrapper

In [2]:
out_path = "../../out"
shadow_model_path = os.path.join(out_path, "shadow_model_2023-05-03_19-59-44")
normal_model_path = os.path.join(out_path, "test_run_2023-05-03_17-51-14")

In [3]:
sys.path.insert(0, shadow_model_path)
import config
SHADOW_TRAIN_CONFIG = config.TRAIN_CONFIG
sys.path.insert(0, normal_model_path)
NORMAL_TRAIN_CONFIG = reload(config).TRAIN_CONFIG
print(NORMAL_TRAIN_CONFIG["experiment_id"], SHADOW_TRAIN_CONFIG["experiment_id"])

test_run shadow_model


In [4]:
def load_model(ckpt_path: str, config: dict):
    """Build the segmentation model described by ``config`` and load weights from a checkpoint.

    Args:
        ckpt_path: Path to a checkpoint whose ``"state_dict"`` entry holds the wrapper's
            parameters, with the inner model's keys prefixed by ``"model."``.
        config: Training config of the run that produced the checkpoint. Its
            ``["model_config"]`` entry must provide ``model_cls``, ``backbone_cls`` and
            ``model_kwargs["up_block_ctor"]``.

    Returns:
        The model in eval mode, with tensors mapped to CPU.
    """
    # Use the config that was passed in, not a global, so each run is rebuilt with its own
    # architecture settings.
    model_config = config["model_config"]
    model_cls = model_config["model_cls"]
    backbone_cls = model_config["backbone_cls"]
    up_block_ctor = model_config["model_kwargs"]["up_block_ctor"]
    model = model_cls(backbone=backbone_cls(), up_block_ctor=up_block_ctor)

    # PLWrapper.load_from_checkpoint cannot be used because hyperparameters were not saved,
    # so load the raw checkpoint and copy the inner model's weights manually.
    checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))

    # The wrapper stores the network under "model.": keep only those entries and strip the
    # prefix. Any other wrapper-level entries do not belong to the bare model.
    prefix = "model."
    model_state = {
        key[len(prefix):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(prefix)
    }
    model.load_state_dict(model_state)
    model.eval()
    return model

# Checkpoint filenames record the epoch and validation accuracy at save time.
normal_model, shadow_model = (
    load_model(ckpt_path=os.path.join(run_dir, ckpt_name), config=train_config)
    for run_dir, ckpt_name, train_config in (
        (normal_model_path, "model-epoch=94-val_acc=0.93.ckpt", NORMAL_TRAIN_CONFIG),
        (shadow_model_path, "model-epoch=98-val_acc=0.89.ckpt", SHADOW_TRAIN_CONFIG),
    )
)



In [44]:
# Load the pretrained ShadowFormer (ISTD checkpoint) that performs the actual shadow removal.
shadow_former = ShadowFormer(img_size=400, embed_dim=32, win_size=10, token_projection="linear", token_mlp="leff")
state_dict = torch.load("../../pretrained/ISTD_model_latest.pth", map_location=torch.device("cpu"))
# Checkpoint keys carry a "module." prefix (presumably from an nn.DataParallel wrapper).
# Strip it only where present instead of blindly dropping the first 7 characters.
shadow_former.load_state_dict({
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in state_dict["state_dict"].items()
})
# Inference only: eval mode disables train-time behaviour such as dropout.
# The trailing ';' stops the notebook from displaying the whole module repr.
shadow_former.eval();




# Remove shadows from dataset images (the split is selected by `data_dir` below)

In [73]:
base_path = "../../data"
data_dir = "test"              # split to clean
output_dir = "test_no_shadow"  # created under base_path
rm_shadow_dir = data_dir
data = SatelliteDatasetRun(data_dir=os.path.join(base_path, data_dir), hist_equalization=False)
# batch_size must be 1: the processing loop re-opens each source image via img_path[0].
# No shuffling: this is a single deterministic export pass, so a random order only makes
# progress and partial outputs harder to reason about.
dataloader = DataLoader(data, batch_size=1, shuffle=False)

In [74]:
to_tensor = transforms.ToTensor()
to_pil = transforms.ToPILImage()
# ShadowFormer was trained on 400x400 images. Resize photo and mask to exactly that square
# so they stay spatially aligned (an int size would only fix the shorter edge).
SF_INPUT_SIZE = 400
upscale = transforms.Resize((SF_INPUT_SIZE, SF_INPUT_SIZE))

# The output directory is the same for every image: create it once.
new_image_dir = os.path.join(base_path, output_dir)
os.makedirs(new_image_dir, exist_ok=True)

# Inference only: make sure neither model runs with train-mode layers.
shadow_model.eval()
shadow_former.eval()

# No autograd graph is needed, which saves memory and time per image.
with torch.no_grad():
    for img_path, original_size, img in tqdm(dataloader):
        # NOTE(review): original_size is unused. Outputs are saved at 400x400, not at the
        # source resolution; confirm that is intended.
        o_img = Image.open(img_path[0]).convert("RGB")

        # Binary shadow mask from the segmentation model.
        shadow_mask = (shadow_model(img) > 0.5).float()  # 1x1x224x224

        upscaled_image = upscale(to_tensor(o_img)).unsqueeze(dim=0)
        sf_image = shadow_former(upscaled_image, upscale(shadow_mask))
        # ToPILImage multiplies floats by 255 and casts to uint8, so values outside [0, 1]
        # would wrap around (e.g. 1.02 -> 4). Clamp first.
        new_image = to_pil(sf_image.squeeze().clamp(0, 1))

        _, filename = os.path.split(img_path[0])
        new_image.save(os.path.join(new_image_dir, filename), "PNG")

100%|██████████| 144/144 [21:05<00:00,  8.79s/it]
