In [1]:
import torch
from torch.utils.data import DataLoader
from utils.data_loading import BasicDataset
from utils.dice_score import dice_coeff
from pathlib import Path
from unet import UNet
from PIL import Image
import torchvision.transforms as T
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np

def preprocess(pil_img, is_mask, scale=1, mask_threshold=50):
    """Turn a PIL image or mask into a NumPy array ready to be wrapped in a tensor.

    Args:
        pil_img: PIL image exposing ``size`` and ``resize``.
        is_mask: When true the input is resized with nearest-neighbour
            resampling and binarised; otherwise it is resized with bicubic
            resampling, moved to channel-first layout and divided by 255.
        scale: Multiplier applied to width and height before resizing.
            The default of 1 keeps the original size.
        mask_threshold: Mask pixels strictly greater than this value become 1,
            all others become 0.

    Returns:
        numpy.ndarray: for a mask, an array of 0/1 with the dtype of the
        source pixels; for an image, a channel-first array divided by 255.

    Raises:
        AssertionError: if the scaled width or height rounds down to zero.
    """
    w, h = pil_img.size
    newW, newH = int(scale * w), int(scale * h)
    assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
    resample = Image.NEAREST if is_mask else Image.BICUBIC
    pil_img = pil_img.resize((newW, newH), resample=resample)
    img_ndarray = np.asarray(pil_img)

    if is_mask:
        # Build a new array instead of assigning in place: the array obtained
        # from np.asarray may be read-only, in which case item assignment
        # raises "assignment destination is read-only".
        return (img_ndarray > mask_threshold).astype(img_ndarray.dtype)

    if img_ndarray.ndim == 2:
        # Single-channel image: add a leading channel axis.
        img_ndarray = img_ndarray[np.newaxis, ...]
    else:
        # (H, W, C) -> (C, H, W)
        img_ndarray = img_ndarray.transpose((2, 0, 1))

    return img_ndarray / 255

def get_image_from_path(img_path, mask_path):
    """Load an image and its mask from disk and return them as float tensors.

    Args:
        img_path: Path of the input image.
        mask_path: Path of the matching mask.

    Returns:
        tuple: ``(image, mask)`` as contiguous ``torch.float32`` tensors built
        from the arrays returned by ``preprocess``.
    """
    # Image.open keeps the underlying file open; the context managers release
    # the handles as soon as the pixel data has been converted to arrays.
    with Image.open(img_path) as pil_img:
        img = preprocess(pil_img, is_mask=False)
    with Image.open(mask_path) as pil_mask:
        mask = preprocess(pil_mask, is_mask=True)

    # .copy() guarantees torch receives a writable, owned buffer.
    img_tensor = torch.as_tensor(img.copy()).float().contiguous()
    mask_tensor = torch.as_tensor(mask.copy()).float().contiguous()
    return img_tensor, mask_tensor

def eval_by_path(folder="dataset/aerial_raw_250",img_name="513647232_0_0.jpg", return_pil=True):
    """Run the network on one image and score it against its ground-truth mask.

    Relies on the notebook-level globals ``net`` and ``device``.

    Args:
        folder: Dataset folder containing ``imgs/`` and ``masks/`` sub-folders.
        img_name: File name shared by the image and its mask.
        return_pil: When true, the image and both masks are converted with
            ``torchvision``'s ``ToPILImage`` and tagged with a ``filename``
            attribute (the source paths, and ``"dice: <score>"`` for the
            prediction). When false, CPU tensors are returned.

    Returns:
        tuple: ``(image, mask_true, mask_pred, dice_score)``.
    """
    img_path = f"{folder}/imgs/{img_name}"
    mask_path = f"{folder}/masks/{img_name}"

    with torch.no_grad():
        single_image, single_mask = get_image_from_path(img_path, mask_path)

        # Add a leading batch axis (batch of one) and move to the target
        # device; the dtype conversions follow evaluate.py.
        batch_image = single_image.unsqueeze(dim=0).to(device=device, dtype=torch.float32)
        batch_mask = single_mask.unsqueeze(dim=0).to(device=device, dtype=torch.long)

        # Class scores -> class index per pixel (argmax over dim 1).
        batch_pred = net(batch_image).argmax(dim=1)
        dice_score = dice_coeff(batch_pred, batch_mask, reduce_batch_first=False).item()

        # Back to CPU, dropping every size-1 axis (including the batch axis).
        image = batch_image.cpu().squeeze()
        mask_true = batch_mask.cpu().squeeze().float()
        mask_pred = batch_pred.cpu().squeeze().float()

        if return_pil:
            to_pil = T.ToPILImage()
            image, mask_true, mask_pred = to_pil(image), to_pil(mask_true), to_pil(mask_pred)
            # Helpers that read `.filename` (e.g. display_images) use these as labels.
            image.filename = img_path
            mask_true.filename = mask_path
            mask_pred.filename = f"dice: {dice_score}"

        return image, mask_true, mask_pred, dice_score

def display_images(
    images,
    columns=5, width=20, height=8, max_images=15,
    label_wrap_length=50, label_font_size=8):
    """Show a list of images in a grid of matplotlib subplots.

    Args:
        images: Sequence of images accepted by ``plt.imshow``. Items that have
            a ``filename`` attribute get its basename as subplot title.
        columns: Number of grid columns.
        width: Figure width passed to ``plt.figure``.
        height: Height of one grid row; the figure height is this value times
            the number of rows.
        max_images: Only the first ``max_images`` images are shown.
        label_wrap_length: Titles are wrapped to this many characters per line.
        label_font_size: Font size of the titles.

    Returns:
        None.
    """
    import math, os, textwrap
    if not images:
        print("No images to display.")
        return

    if len(images) > max_images:
        print(f"Showing {max_images} images of {len(images)}:")
        images = images[0:max_images]

    # ceil() gives exactly the rows needed; the former int(n / columns + 1)
    # produced an extra empty row whenever n was a multiple of `columns`, and
    # the figure height was derived from a different row count.
    rows = math.ceil(len(images) / columns)
    height = max(height, rows * height)
    plt.figure(figsize=(width, height))
    for i, image in enumerate(images):
        plt.subplot(rows, columns, i + 1)
        plt.imshow(image)

        if hasattr(image, 'filename'):
            title = image.filename
            if title.endswith("/"):
                title = title[0:-1]
            title = os.path.basename(title)
            title = "\n".join(textwrap.wrap(title, label_wrap_length))
            plt.title(title, fontsize=label_font_size)

def save_pred(mask_pred, folder, targets):
    """Save a predicted mask as PNG under ``<folder>/<targets>/preds/``.

    Args:
        mask_pred: Object with a ``filename`` attribute and a ``save(path)``
            method (e.g. a PIL image). The extension of ``filename`` is
            replaced by ``.png``.
        folder: Dataset root folder.
        targets: Name of the dataset sub-folder.

    Returns:
        None.
    """
    import os
    save_folder = os.path.join(folder, targets, "preds")
    # makedirs(exist_ok=True) also creates missing parents and does not fail
    # if the folder appears between an existence check and the creation.
    os.makedirs(save_folder, exist_ok=True)
    stem = os.path.splitext(mask_pred.filename)[0]
    mask_pred.save(os.path.join(save_folder, f"{stem}.png"))

def update_dice_log(text, folder, targets):
    """Append one line to ``<folder>/<targets>/dicelog.txt``.

    The previous body was an unfinished stub that always raised: it tried to
    remove the log when it did NOT exist, opened an empty path, and never
    wrote ``text``.

    Args:
        text: Content of the log line; a trailing newline is added when missing.
        folder: Dataset root folder.
        targets: Name of the dataset sub-folder.

    Returns:
        None.
    """
    import os
    log_dir = os.path.join(folder, targets)
    os.makedirs(log_dir, exist_ok=True)
    target_file = os.path.join(log_dir, "dicelog.txt")

    line = str(text)
    if not line.endswith("\n"):
        line += "\n"
    # Append mode keeps earlier entries; the context manager closes the file.
    with open(target_file, "a") as log_file:
        log_file.write(line)

In [21]:
# Dataset root and the tile set to evaluate.
# Other tile sets used with this notebook: "aerial_raw_250", "NB_raw_100".
folder = "dataset"
targets = "PB_raw_100"

device = 'cpu'

# Image / mask folders of the selected tile set.
dir_img = Path(folder, targets, "imgs")
dir_mask = Path(folder, targets, "masks")
dataset = BasicDataset(dir_img, dir_mask, 1)

loader_args = {"batch_size": 2, "num_workers": 4, "pin_memory": True}
dataloader = DataLoader(dataset, **loader_args)

# Checkpoint named after the tile set; load it into a fresh UNet and switch
# the network to evaluation mode.
weight = Path("checkpoints") / f"{targets}_5epochs.pth"
net = UNet(n_channels=3, n_classes=2, bilinear=True)
net.load_state_dict(torch.load(weight, map_location=device))
net.to(device)
_ = net.eval()

In [22]:
import os
import shutil

# Reset prediction results so stale tiles from an earlier run cannot survive.
save_folder = f"{folder}/{targets}/preds/"
if os.path.exists(save_folder):
    shutil.rmtree(save_folder)
os.makedirs(save_folder, exist_ok=True)

transform = T.ToPILImage()
# The `with` statement closes the log even when the loop fails, and unlike a
# `finally: dicelog.close()` it cannot raise NameError when the failure
# happens before the file was opened.
with torch.no_grad(), open(f"{folder}/{targets}/dicelog.txt", "w") as dicelog:
    for batch in tqdm(dataloader):
        # list of filenames
        names = batch['name']
        images = batch['image'].to(device=device, dtype=torch.float32)
        mask_trues = batch['mask'].to(device=device, dtype=torch.long)
        # Predict, then reduce class scores to a class index per pixel.
        mask_preds = net(images).argmax(dim=1)

        for mask_pred, mask_true, name in zip(mask_preds, mask_trues, names):
            # Score every image on its own (as a batch of one, like
            # eval_by_path does); scoring the whole batch would log the
            # batch mean under each image name.
            dice_score = dice_coeff(
                mask_pred.unsqueeze(0), mask_true.unsqueeze(0),
                reduce_batch_first=False,
            ).item()
            mask_pred = transform(mask_pred.cpu().float())
            mask_pred.filename = name
            save_pred(mask_pred, folder=folder, targets=targets)
            dicelog.write(f"{name}:{dice_score}\n")

  0%|          | 0/800 [00:00<?, ?it/s]

# Reconstruct Image

In [23]:
import os
import numpy as np
from PIL import Image
import glob

def get_coords(big_img, basepath):
    """Collect the tile coordinates available for one large image.

    Tiles are files named ``<big_img>_<x>_<y>.<ext>`` inside ``basepath``.

    Args:
        big_img: Name of the large image (tile file name without the trailing
            ``_<x>_<y>`` and extension).
        basepath: Folder containing the tiles.

    Returns:
        tuple: ``(xs, ys)`` — two ascending lists of the distinct integer x
        and y values found.
    """
    xs, ys = set(), set()
    # Escape both parts so glob metacharacters in paths are taken literally.
    pattern = os.path.join(glob.escape(basepath), f"{glob.escape(big_img)}_*")
    for path in glob.glob(pattern):
        parts = os.path.splitext(os.path.basename(path))[0].split('_')
        # A bare prefix match is not enough: "A_1" must not pick up the tiles
        # of "A_1_2". Everything before the last two fields has to be big_img.
        if "_".join(parts[:-2]) != big_img:
            continue
        try:
            x, y = int(parts[-2]), int(parts[-1])
        except ValueError:
            # Not a tile (last two fields are not integers).
            continue
        xs.add(x)
        ys.add(y)
    return sorted(xs), sorted(ys)

def get_img(glob_str):
    """Resolve a glob pattern that must match exactly one file.

    Args:
        glob_str: Pattern passed to ``glob.glob``.

    Returns:
        tuple: ``(root, ext)`` of the matched path as returned by
        ``os.path.splitext``.

    Raises:
        AssertionError: if the pattern matches no file or more than one file.
    """
    matches = glob.glob(glob_str)
    # Raised explicitly so the check is not stripped under ``python -O`` and
    # the message tells "nothing found" apart from "ambiguous".
    if not matches:
        raise AssertionError(f"no file matches {glob_str!r}")
    if len(matches) > 1:
        raise AssertionError(f"too many files return. {matches}")
    return os.path.splitext(matches[0])

In [26]:
folder = "dataset"
# Other tile sets used with this notebook: "aerial_raw_250", "NB_raw_100".
targets = "PB_raw_100"

# Which tile family to stitch: 'imgs', 'masks' or 'preds'.
target_img = 'preds'

basepath = f"{folder}/{targets}/{target_img}"
names = next(os.walk(f"{basepath}"), (None, None, []))[2]  # [] if no file
big_imgs = set()
for name in names:
    # The last two underscore-separated fields are the (x, y) pixel offsets
    # of the tile, e.g. Sentinel2_PB_Urban_1_1_0_0.jpg -> Sentinel2_PB_Urban_1_1
    name_array = os.path.splitext(name)[0].split('_')
    # Skip files that are not tiles (hidden files, logs, ...); otherwise they
    # would register a bogus, possibly empty, image name.
    if len(name_array) < 3 or not (name_array[-2].isdigit() and name_array[-1].isdigit()):
        continue
    big_imgs.add("_".join(name_array[:-2]))

# Sorted so the processing order is the same on every run, e.g.
# ['Sentinel2_PB_Urban_1_1', 'Sentinel2_PB_Urban_2_1', ...]
big_imgs = sorted(big_imgs)
# "PB_raw_100" => 100
step = int(targets.split('_')[-1])

In [27]:
# Stitch the tiles of every large image back together and save the result
# next to the tile folders as <big_img>_<target_img><ext>.
for big_img in big_imgs:
    print(big_img)
    xs, ys = get_coords(big_img, basepath=basepath)
    xs.sort()
    ys.sort()
    # Build each column top to bottom, then join the columns left to right.
    cols = []
    for x in xs:
        col = []
        for y in ys:
            imgname, ext = get_img(f'{basepath}/{big_img}_{x}_{y}.*')
            # Close every tile file once its pixels are copied into an array;
            # otherwise one handle per tile stays open.
            with Image.open(f"{imgname}{ext}") as tile:
                img = np.asarray(tile)
            col.append(img)
        col = np.concatenate(col, axis=0)
        cols.append(col)
    big_img_array = np.concatenate(cols, axis=1)
    big_img_pil = Image.fromarray(np.uint8(big_img_array))
    # The output reuses the extension of the last tile read.
    big_img_pil.save(f"{folder}/{targets}/{big_img}_{target_img}{ext}")

Sentinel2_PB_Urban_2_2
Sentinel2_PB_Urban_2_1
Sentinel2_PB_Urban_1_1
Sentinel2_PB_Urban_3_1
