In [1]:
import random
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import segmentation_models_pytorch as smp
from torchgeo.models import resnet50, ResNet50_Weights
from glob import glob
import os
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import cv2
from matplotlib import pyplot as plt
from tqdm import tqdm
import seaborn as sns
from PIL import Image

from data.dataset import FSSDataset
from model.DCAMA import DCAMA

# Fix the seeds of Python's and PyTorch's random number generators so that
# the random choices made further down are repeatable between runs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Output locations: per-tile masks, stitched per-region masks, and contours.
out_mask_dir = "outputs/out_masks"
full_mask_dir = "outputs/full_masks"
out_contour_dir = "outputs/out_contours"

# Create every output directory up front; existing directories are left alone.
for directory in (out_mask_dir, full_mask_dir, out_contour_dir):
    os.makedirs(directory, exist_ok=True)

In [3]:
# One training-run directory per fold; each holds that fold's best checkpoint.
_fold_runs = (
    "fold_0_0501_130849-exp1",
    "fold_1_0429_030114-exp1",
    "fold_2_0429_013102-exp1",
    "fold_3_0430_221854-exp1",
    "fold_4_0430_200305-exp1",
)
model_paths = [f"logs/train/{run}/best_model.pt" for run in _fold_runs]

In [4]:
# Load one DCAMA model per fold checkpoint and collect them in `models`.
models = []
for path in model_paths:
    model = DCAMA('resnet50', True)
    params = model.state_dict()
    state_dict = torch.load(path, map_location='cpu')

    # The checkpoint keys are mapped onto the model's keys purely by position
    # (presumably the names differ only by a wrapper prefix -- TODO confirm).
    # A positional mapping is only meaningful when both sides have the same
    # number of entries, so fail loudly instead of letting zip() truncate.
    if len(state_dict) != len(params):
        raise RuntimeError(
            f"Checkpoint {path} has {len(state_dict)} entries but the model "
            f"expects {len(params)}; cannot remap keys by position."
        )

    # Build a fresh mapping rather than popping/inserting in place, so a new
    # key that happens to equal a not-yet-processed old key cannot clobber it.
    state_dict = {
        new_key: state_dict[old_key]
        for old_key, new_key in zip(list(state_dict.keys()), params.keys())
    }

    model.load_state_dict(state_dict)
    # Inference only: put dropout / batch-norm layers into evaluation mode.
    model.eval()
    models.append(model)

In [5]:
# Per-channel normalisation statistics, one entry per input channel (11 in
# total). The last channel uses mean 0 / std 1, i.e. it passes through
# Normalize unchanged.
img_mean = [
    452.36395548, 669.48234239, 409.39103663, 987.94130831,
    2457.23722236, 2872.30241926, 3011.18175418, 3097.38396507,
    1786.85631331, 929.30668321,
    0,
]

img_std = [
    177.24756019, 144.58550688, 95.04011083, 224.49394865,
    485.14565224, 580.77737498, 649.43248944, 622.79759571,
    419.01506965, 298.34517013,
    1,
]

# Convert the array to a tensor, then standardise every channel.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=img_mean, std=img_std)
transform = transforms.Compose([_to_tensor, _normalize])

In [6]:
def eval_expression(exp: list, image: np.ndarray = None):
    """Evaluate a tokenised band-arithmetic expression on `image`.

    Tokens are translated left to right into a Python expression string:

    * ``"c<N>"``  -> ``(image[N] + 0.0001)`` (the offset avoids division by zero)
    * ``"sq"``    -> ``**2``
    * ``"sqrt"``  -> ``**0.5``
    * ``"="``     -> end of expression; later tokens are ignored
    * anything else is appended verbatim (operators, parentheses, numbers)

    Args:
        exp: List of string tokens.
        image: Object indexed as ``image[N]`` by the channel tokens.

    Returns:
        Whatever the assembled expression evaluates to.

    Raises:
        ValueError: If a channel token's suffix is not an integer literal.
    """
    expression = ""

    for token in exp:
        if token == "=":
            break
        if token == "sq":
            expression += "**2"
        elif token == "sqrt":
            expression += "**0.5"
        elif token.startswith("c"):
            # Parse the index with int() rather than eval(): a channel token
            # must never be able to execute arbitrary code.
            channel = int(token[1:])
            expression += f"(image[{channel}] + 0.0001)"  # To prevent divide by zero
        else:
            expression += token

    # SECURITY NOTE: verbatim tokens still reach eval(), so `exp` must come
    # from a trusted source. `image` is resolved from this function's scope.
    return eval(expression)

In [7]:
class InferenceDataset(Dataset):
    """Dataset over the ``*.npy`` tiles found in a directory.

    Each item is ``(tensor, path)``. The first 10 channels of the array are
    kept; when an expression is given, an extra index channel computed by
    ``eval_expression`` is z-scored, clipped to +/-3 and rescaled to [0, 1],
    then appended as an 11th channel.

    Args:
        image_dir: Directory containing the ``.npy`` tiles.
        expression: Token list understood by ``eval_expression``. ``None`` or
            an empty list disables the extra channel.
    """

    def __init__(self, image_dir: str, expression: list = None):
        super().__init__()
        # Sort so the sample order is deterministic across filesystems/runs.
        self.paths = sorted(glob(os.path.join(image_dir, "*.npy")))
        # Copy into a fresh list so no mutable default is shared by instances.
        self.expression = list(expression) if expression else []

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = np.load(path).astype(float)
        if self.expression:
            idx = eval_expression(self.expression, img)
            max_z = 3
            std = idx.std()
            if std == 0:
                # Constant index channel: every z-score is 0 by definition.
                # Dividing by the zero std would fill the channel with NaNs.
                idx = np.zeros_like(idx)
            else:
                idx = (idx - idx.mean()) / std
            # Clip to +/- max_z standard deviations and rescale to [0, 1].
            idx = (np.clip(idx, -max_z, max_z) + max_z) / (2 * max_z)
            img = np.concatenate([img[:10, :, :], idx[None, :, :]], axis=0)
        else:
            # NOTE(review): this path yields 10 channels, while the module
            # level `transform` normalises 11 -- confirm it is ever used.
            img = img[:10, :, :]
        # Channels-first -> channels-last before handing over to `transform`.
        img = img.transpose(1, 2, 0)

        return transform(img).float(), path

In [8]:
# Token list for the extra index channel; "=" marks the end of the expression.
exp = "c1 + c5 / c7 + c5 / c8 + c7 - c2 =".split()

In [9]:
# Inference input: tiles are read in dataset order, in batches of `batch_size`.
inference_dir = "datasets/inference"
batch_size = 2
dataset = InferenceDataset(image_dir=inference_dir, expression=exp)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

In [10]:
# Load the support images and masks of each of the 5 folds.
# Per fold this produces one image tensor and one mask tensor, each with a
# leading singleton dimension added by unsqueeze(0).
support_imgs_folds, support_masks_folds = [], []
for fold in range(5):
    fold_root = os.path.join("datasets/Serp/", str(fold))
    img_paths = sorted(glob(os.path.join(fold_root, "train", "*.npy")))
    mask_paths = sorted(glob(os.path.join(fold_root, "annotations", "train", "*.npy")))

    # Images: load, move channels last, normalise, stack.
    support_imgs = torch.stack(
        [transform(np.load(p).astype(float).transpose(1, 2, 0)) for p in img_paths]
    ).unsqueeze(0)

    # Masks: nearest-neighbour resize to the images' spatial size, then stack.
    target_hw = support_imgs.size()[-2:]
    support_masks = torch.stack(
        [
            F.interpolate(
                torch.tensor(np.load(p).astype(float))[None, None].float(),
                target_hw,
                mode="nearest",
            ).squeeze()
            for p in mask_paths
        ]
    ).unsqueeze(0)

    support_imgs_folds.append(support_imgs)
    support_masks_folds.append(support_masks)

In [11]:
# n_shot: number of support samples drawn for each query batch.
# max_shot: size of the index pool (0 .. max_shot - 1) they are drawn from.
n_shot, max_shot = 5, 18

In [12]:
# Ensemble inference: every fold model predicts a mask for each query tile and
# a pixel is kept when more than half of the models vote for it.
# pos_count / tot_count accumulate positive and total pixel counts.
pos_count = tot_count = 0
out_dir = "outputs/out_masks/"

# Never sample a support index that some fold does not actually have.
n_support = min([max_shot] + [fold_imgs.size(1) for fold_imgs in support_imgs_folds])
# Strict majority: with 5 models this is the original `> 2` threshold.
majority = len(models) // 2

with torch.no_grad():  # inference only, no autograd graph needed
    for query_img, paths in tqdm(dataloader):
        # The last batch may be smaller than `batch_size`.
        cur_bs = query_img.size(0)

        # Draw the same random support subset for all models of this batch.
        indices = list(range(n_support))
        random.shuffle(indices)
        indices = indices[:n_shot]

        votes = None
        for i, model in enumerate(models):
            # Tile the fold's (1, n_shot, ...) support tensors along the batch
            # axis. repeat() keeps the rank fixed, unlike stack(...).squeeze(),
            # which collapses dimensions when the batch holds a single sample.
            batch = {
                "query_img": query_img,
                "support_imgs": support_imgs_folds[i][:, indices].repeat(cur_bs, 1, 1, 1, 1).float(),
                "support_masks": support_masks_folds[i][:, indices].repeat(cur_bs, 1, 1, 1).float(),
                "org_query_imsize": [torch.tensor([256]), torch.tensor([256]), torch.tensor([11])]
            }
            new_mask = model.predict_mask_nshot(batch, nshot=n_shot).cpu().numpy()
            # Assumes one 2-D mask per query sample -- TODO confirm against
            # DCAMA.predict_mask_nshot. Reshape fails loudly otherwise.
            new_mask = new_mask.reshape(cur_bs, *new_mask.shape[-2:])
            # Accumulate the votes (previously the prediction was discarded).
            votes = new_mask.astype(float) if votes is None else votes + new_mask

        mask = votes > majority
        pos_count += int(mask.sum())
        tot_count += mask.size  # all pixels of all tiles in this batch
        for j, path in enumerate(paths):
            np.save(os.path.join(out_dir, os.path.basename(path)), mask[j])

  0%|          | 13/2994 [07:19<27:59:03, 33.80s/it]


FileNotFoundError: [Errno 2] No such file or directory: 'datasets/inference/10_8960_6528.npy'

In [15]:
# Fraction of predicted-positive pixels; NaN when no tile has been processed.
pos_count / tot_count if tot_count else float("nan")

0.0124908447265625

In [21]:
# Stitching configuration: where tile masks are read from and where the
# full-region masks are written.
mask_dir, out_dir = "outputs/out_masks", "outputs/full_masks"
img_size = 256  # side length of one tile, in pixels
tif_max_size = 11000  # side length of the square full-region canvas

In [22]:
# Stitch the per-tile masks into one boolean canvas per region.
# Paths are sorted, so all tiles of a region are visited consecutively; the
# canvas is flushed to disk whenever the region changes and once at the end.
paths = sorted(glob(os.path.join(mask_dir, "*.npy")))
prev_region = None
full_mask = np.zeros((tif_max_size, tif_max_size), dtype=bool)
for path in paths:
    # <region>_<x>_<y>.npy => 00_2342_6453.npy
    name, _ = os.path.splitext(os.path.basename(path))
    region, x, y = name.split("_")
    x, y = int(x), int(y)

    if prev_region is not None and prev_region != region:
        np.save(os.path.join(out_dir, f"{prev_region}.npy"), full_mask)
        full_mask = np.zeros((tif_max_size, tif_max_size), dtype=bool)
    prev_region = region

    mask = np.load(path)
    if mask.ndim != 2:
        # A 1-D or scalar array would be silently broadcast over the whole
        # tile area and corrupt the canvas, so reject it explicitly.
        raise ValueError(f"Expected a 2-D tile mask in {path}, got shape {mask.shape}")
    full_mask[x: x + img_size, y: y + img_size] = np.logical_or(full_mask[x: x + img_size, y: y + img_size], mask)

# Flush the last region; with no tiles at all there is nothing to save
# (the loop variable `region` would not even be defined).
if prev_region is not None:
    np.save(os.path.join(out_dir, f"{prev_region}.npy"), full_mask)

In [23]:
import cv2
import pickle

In [24]:
# Contour extraction reads the stitched masks and writes one pickle per region.
out_dir, out_contour_dir = "outputs/full_masks", "outputs/out_contours"

In [25]:
# Extract the outer contours of every stitched region mask and pickle them.
tif_paths = sorted(glob(os.path.join(out_dir, "*.npy")))
for path in tif_paths:
    tif_name, _ = os.path.splitext(os.path.basename(path))
    tif_mask = np.load(path).astype('uint8')

    contours, hierarchy = cv2.findContours(tif_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if hierarchy is None:
        # No contour found: nothing is written for this region.
        continue
    # OpenCV returns the hierarchy with a leading singleton axis, (1, N, 4).
    # Drop exactly that axis: squeeze() would also collapse N == 1 and leave a
    # flat (4,) array, which breaks the per-contour row lookup below.
    hierarchy = hierarchy[0]
    # Keep only contours without a parent (4th hierarchy entry == -1).
    contours_new = [
        contour for contour, relation in zip(contours, hierarchy) if relation[3] == -1
    ]
    print(len(contours), len(contours_new))
    with open(os.path.join(out_contour_dir, f"{tif_name}.pkl"), 'wb') as fp:
        pickle.dump(contours_new, fp)

421 421
611 611
782 782
2522 2522
