In [None]:
from pathlib import Path

from functools import partial, reduce
from PIL import Image

import pandas as pd
import torch as pt
import numpy as np
from numpy.typing import NDArray

import matplotlib.pyplot as plt 
import matplotlib as mpl

from matplotlib.gridspec import GridSpec
from matplotlib.patches import Patch, Rectangle
from matplotlib.lines import Line2D
from matplotlib.axes import Axes

import cv2

from pysnic.algorithms.snic import snic
import skimage as ski

from importables.pytorch.dataset import SegmentationDataset
from importables.project.cloud_classes import ClassRegistry
from importables.general.image_processing import UINT8_MAX
from importables.general.image_processing import get_false_color, boost_hsv
from importables.general.image_processing import display_hsv

In [None]:
# --- experiment identifiers (also select which split CSV is read below) ---
SEED = 1234
SUBSET_RATIO = 0.1

# --- project layout, relative to the notebook's working directory ---
PROJ_PATH = Path('../')
DATA_PATH = PROJ_PATH.joinpath('_data')
DATASET_PATH = DATA_PATH.joinpath('dataset')
AUX_PATH = DATA_PATH.joinpath('auxiliary')
SCENE_PATH = DATA_PATH.joinpath('raw', 'BC')

# Output folder for rendered figures; created on first run if it is missing.
PLOTS_PATH = PROJ_PATH.joinpath("plots")
PLOTS_PATH.mkdir(exist_ok=True)

# Split table for this seed / subset ratio; the first CSV column becomes the index.
SPLIT_CSV_PATH = DATA_PATH / f'seeds/{SEED}_{SUBSET_RATIO}_split.csv'
df = pd.read_csv(SPLIT_CSV_PATH, index_col=0)

In [None]:
import sys

# Put the bundled HQ-SAM directory at the front of the module search path
# *before* importing from it. The path is relative (PROJ_PATH is '../'), so
# this only resolves when the kernel's working directory is the notebook folder.
# NOTE(review): presumably `models/samhq/main.py` imports sibling modules by
# bare name, which is why the directory itself has to be on `sys.path` — confirm.
pkg_root = PROJ_PATH / 'models' / 'samhq'
sys.path.insert(0, str(pkg_root))

from models.samhq.main import SAM_predictor, SAM_predict

In [None]:
# Test-split dataset object; in this notebook only its `img_transform` is
# reused (by `get_datapoint`) to turn a PIL image into the model input tensor.
seg_ds = SegmentationDataset(SEED, 'test', DATA_PATH.absolute(), subset_ratio=SUBSET_RATIO)
transform = seg_ds.img_transform

# HQ-SAM predictor built from the ViT-B checkpoint, plus the `rng` object it
# returns; both are forwarded to `SAM_predict` by `run_SAM_prediction`.
predictor, rng = SAM_predictor(PROJ_PATH / 'models' / 'samhq' / 'sam_hq_vit_b.pth', SEED)

# Supplies `class_recolor_map` (label map -> displayable image) and
# `CLASS_HEX_COLORS` (legend marker colours) for the plotting cells.
class_reg = ClassRegistry()

def get_datapoint(name: str, device: str = "cuda") -> tuple[NDArray, pt.Tensor, NDArray, NDArray]:
    """Load one tile together with its reference masks.

    Args:
        name: Tile name without extension; `<name>.png` is read from the
            dataset `img/` and `label/` folders and from the auxiliary
            `fmask/` folder.
        device: Torch device the model input tensor is moved to. Defaults to
            "cuda", which is what the notebook originally hard-coded.

    Returns:
        A 4-tuple `(img_arr, img_tensor, mask, fmask)`:
        the image as a NumPy array, the transformed image with a leading batch
        dimension on `device`, the ground-truth label array, and the FMask
        label array.

    Raises:
        FileNotFoundError: if any of the three PNG files is missing.
    """
    img_path = DATASET_PATH / "img" / f"{name}.png"
    mask_path = DATASET_PATH / "label" / f"{name}.png"
    fmask_path = AUX_PATH / 'fmask' / f"{name}.png"

    # Context managers guarantee the file handles are released even when the
    # transform or the array conversion raises part-way through.
    with Image.open(img_path) as img:
        img_tensor = transform(img).unsqueeze(0).to(device)
        img_arr = np.array(img)

    with Image.open(mask_path) as mask_img:
        mask = np.array(mask_img)

    with Image.open(fmask_path) as fmask_img:
        fmask = np.array(fmask_img)

    return img_arr, img_tensor, mask, fmask

def run_SAM_prediction(img_arr: NDArray, lbl_arr: NDArray, num_points: int):
    """Run HQ-SAM on one tile using the notebook-level `predictor` and `rng`.

    Forwards `img_arr`, `lbl_arr` and `num_points` to `SAM_predict` and hands
    back whatever it returns, unchanged.
    """
    return SAM_predict(img_arr, lbl_arr, num_points, predictor, rng)

# Loaded TorchScript models keyed by (checkpoint path, device), so every
# checkpoint is read from disk once per kernel instead of once per tile.
# Cached models stay resident on `device`; re-running this cell drops them.
_MODEL_CACHE: dict[tuple[str, str], pt.nn.Module] = {}


def run_prediction(img_tensor: pt.Tensor, model_name: str, ver: str | int,
                   device: str = "cuda") -> NDArray:
    """Predict a per-pixel class map with a saved TorchScript model.

    Args:
        img_tensor: Batched model input, as produced by `get_datapoint`.
        model_name: Sub-folder of `models/` holding the run logs.
        ver: Run/version folder under `logs/seed-<SEED>/`. Integers are
            accepted and converted to their string form (joining a `Path`
            with a bare int raises `TypeError`).
        device: Torch device used for the model and the input. Defaults to
            "cuda", matching the original hard-coded behaviour.

    Returns:
        `uint8` NumPy array of class indices: argmax over dim 1 of the model
        output with dim 0 squeezed away.
        # assumes the model returns (1, C, H, W) scores — TODO confirm
    """
    model_path = PROJ_PATH / "models" / model_name / "logs" / f"seed-{SEED}" / str(ver) / "model.pt"

    cache_key = (str(model_path), str(device))
    model = _MODEL_CACHE.get(cache_key)
    if model is None:
        model = pt.jit.load(model_path).to(device)
        model.eval()
        _MODEL_CACHE[cache_key] = model

    with pt.no_grad():
        scores = model(img_tensor.to(device))
        pred = scores.argmax(dim=1).squeeze(dim=0).cpu()

    return pred.numpy().astype(np.uint8)

In [None]:
# Scene identifier = everything in the tile name before the first underscore.
df['scene'] = df['image_name'].map(lambda tile_name: tile_name.partition('_')[0])

In [None]:
print(df['biome'].unique())

# Row masks: test split, and membership in the sampled subset.
testing_idx = df['split'].eq('test')
subset_idx = df['subset'].eq(True)

# Optional biome exclusions (disabled):
#   sub_df = sub_df[sub_df['biome'] != 'Wetlands']
#   sub_df = sub_df[sub_df['biome'] != 'Barren']
#   sub_df = sub_df[sub_df['biome'] != 'Snow/Ice']

# Scene notes from manual inspection:
#   LC80060102014147LGN00 - Bad scene generation
#   LC82320072014226LGN00 - Too easy
#   LC80250022014232LGN00 - Decent, but Prithvi and SatlasNet are not expected though
#   LC80211222013361LGN00 - Can't tell land/thin cloud
#   LC80441162013330LGN00 - Decent
#   LC82171112014297LGN00 - Sucks

# Active coverage thresholds: at least 20% clear and at least 10% shadow.
# Alternative thresholds tried earlier (disabled):
#   sub_df = sub_df[sub_df['cloud'] >= 0.1]
#   sub_df = sub_df[sub_df['thin_cloud'] >= 0.1]
#   sub_df = sub_df[(sub_df['cloud'] + sub_df['thin_cloud'] - sub_df['shadow']).abs() <= .1]
coverage_idx = (df['clear'] >= .2) & (df['shadow'] >= .1)

sub_df = df.loc[testing_idx & subset_idx & coverage_idx]

print(len(sub_df))

In [None]:
# import ast

# seg_df = pd.read_csv(PROJ_PATH / 'seg.csv', index_col=0)
# seg_df = seg_df.iloc[sub_df.index]

# n = 3

# for col in seg_df.columns:
#     seg_df[col] = seg_df[col].apply(ast.literal_eval)

# seg_df['new_cloud_sizes'] = seg_df['cloud_sizes'] + seg_df['thin_cloud_sizes']
# seg_df['new_cloud_top_n_size'] = seg_df['new_cloud_sizes'].apply(lambda row: sum(sorted(row)[-n:])) / seg_df['new_cloud_sizes'].apply(lambda row: sum(row))

# shadow_analysis_idx = seg_df[seg_df['new_cloud_top_n_size'] >= .5].index
# print(len(shadow_analysis_idx))
# samples = sub_df.loc[shadow_analysis_idx]

In [None]:
# Columns to render after 'Image' and 'Ground truth'.
# Each entry is (kind, version, column title):
#   kind    - selects the branch in the plotting loop ('samhq', 'fmask',
#             'false_color', 'snic', 'xai'; anything else is treated as a
#             model folder name and run through `run_prediction`)
#   version - branch-specific argument (point count for 'samhq', the tuple
#             handed to `get_false_color`, the file suffix for 'xai', or the
#             run folder for a model)
#   title   - header text shown above the column
MODELS = [
    # ('snic', None, 'Segmentation Map'),
    # NOTE(review): the titles of the next two xai entries look swapped
    # relative to their version key ('prithvi' -> 'SatlasNet') — confirm.
    # ('xai', 'prithvi', 'SatlasNet'),
    # ('xai', 'satlas', 'Prithvi'),
    # ('xai', 'lsknet', 'LSKNet'),
    
    ('false_color', (6, 9, 10), 'SWIR + Cirrus + Thermal'),
    
    # ('fmask', None, 'Prediction'),
    
    # ('samhq', pts:=1, f'HQ-SAM\n({pts*8} Data Points)'),
    # ('samhq', pts:=10, f'HQ-SAM\n({pts*8} Data Points)'),
    # ('samhq', pts:=100, f'HQ-SAM\n({pts*8} Data Points)'),
    # ('samhq', pts:=1000, f'HQ-SAM\n({pts*8} Data Points)'),
    
    # ('prithvi', 'feature_extraction', 'Feature Extraction'),
    # ('prithvi', 'half_frozen', 'Half-frozen'),
    # ('prithvi', 'finetuning', 'Fine-tuning'),
    
    # ('satlas', 'feature_extraction', 'Feature Extraction'),
    # ('satlas', 'half_frozen', 'Half-frozen'),
    # ('satlas', 'finetuning', 'Fine-tuning'),
    
    # ('lsknet', 'feature_extraction', 'Feature Extraction'),
    # ('lsknet', 'half_frozen', 'Half-frozen'),
    # ('lsknet', 'finetuning', 'Fine-tuning'),
    
    # ('fmask', None, 'FMask 3.3'),
    # ('samhq', pts:=10, f'HQ-SAM\n({pts*8} Data Points)'),
    ('prithvi', 'finetuning', 'Prithvi'),
    ('satlas', 'finetuning', 'SatlasNet'),
    ('lsknet', 'finetuning', 'LSKNet'),
    
]
# Hand-picked tiles, one figure row each (tile names without the '.png'
# extension). Commented groups are earlier selections kept for reference;
# every commented entry carries a trailing comma so it can be re-enabled
# without accidentally concatenating adjacent string literals.
TILES = [
    # individual
    # 'LC81640502013179LGN01_23_24', # Barren
    # 'LC80160502014041LGN00_16_21', # Forest
    # 'LC81770262013254LGN00_4_28', # Urban
    # 'LC82171112014297LGN00_16_10', # Snow/Ice
    
    # snow/ice
    # 'LC80250022014232LGN00_6_18',
    # 'LC80441162013330LGN00_10_24',
    # 'LC82320072014226LGN00_14_27',
    # 'LC82320072014226LGN00_27_10',
    
    # together
    # 'LC80290372013257LGN00_11_14', # Grass/Crop
    # 'LC82150712013152LGN00_12_19', # Water
    # 'LC80320382013278LGN00_23_6', # Shrubland
    # 'LC81010142014189LGN00_18_17', # Wetland
    
    # cloud shadows - clouds
    # 'LC81010142014189LGN00_26_13', # Wetlands
    # 'LC81930452013126LGN01_21_14', # Barren
    # 'LC81750512013208LGN00_10_8', # Grass/Crops
    
    # prithvi - bad cloud shadows
    'LC80070662014234LGN00_18_30',
    'LC81750622013304LGN00_27_26',
    'LC81750622013304LGN00_28_21',
    'LC81750622013304LGN00_17_29',
    'LC81750622013304LGN00_22_25',
    'LC81001082014022LGN00_26_16',
    'LC81750622013304LGN00_6_26',
    'LC81001082014022LGN00_20_8'
]
# Base name of the saved figure; '/' is replaced with '_' when the file is written.
FN = 'Cloud and Cloud-shadows Predictions'

In [None]:
# Inspect the split-table row(s) for one specific tile.
df.loc[df['image_name'].eq('LC81750512013208LGN00_10_8.png')]

In [None]:
# Extend the hand-picked TILES with a random draw from the filtered test subset.
# NOTE: this cell appends to TILES, so re-running it without first re-running
# the TILES definition cell adds another batch of tiles.
N_RANDOM_TILES = 10

# Tile names without the file extension, excluding tiles that are already listed.
candidates = sub_df['image_name'].apply(lambda name: name.split('.')[0])
candidates = candidates[~candidates.isin(TILES)]

# Seeded so a fresh Restart & Run All reproduces the same figure, and clamped
# because `sample` raises ValueError when n exceeds the population size.
samples = candidates.sample(
    n=min(N_RANDOM_TILES, len(candidates)),
    replace=False,
    random_state=SEED,
).tolist()
print(samples)

TILES = TILES + samples

### Plotting Parameters

In [None]:
# --- typography: serif text in Times New Roman, math in Computer Modern ---
mpl.rcParams["font.family"] = "serif"
mpl.rcParams["font.serif"] = ["Times New Roman"]
mpl.rcParams["mathtext.fontset"] = "cm"

# Column headers: two fixed columns followed by one per MODELS entry.
TITLES = ['Image', 'Ground truth', *(entry[2] for entry in MODELS)]

# A colour bar replaces the class legend as soon as any XAI heatmap column is shown.
COLOR_BAR_FLAG = any(entry[0] == 'xai' for entry in MODELS)

# All *_PT values below are in points (72 pt = 1 inch).

# --- font sizes ---
SUPTITLE_FONT_SIZE = 8
SUBTITLE_FONT_SIZE = 0.6 * SUPTITLE_FONT_SIZE
LEGEND_FONT_SIZE = 0.55 * SUPTITLE_FONT_SIZE

# Largest number of line breaks found in any column title.
MAX_TITLE_ROWS = max((header.count('\n') for header in TITLES), default=0)

# --- grid dimensions ---
ROW_COUNT = len(TILES)
COL_COUNT = len(MODELS) + 2

# --- sizing parameters ---
TILE_PT = 72
BORDER_THICKNESS = 0.5  # rectangle borders are drawn inward on the image
GAP_W_PT = 5
GAP_H_PT = 12.5

LEGEND_SIZE_PT = 72 / 4
COLOR_BAR_SIZE_PT = 20

# --- positioning ---
TITLE_POS_PT = 12.5
SUBTITLE_POS_PT = 2.5
LEGEND_POS_PT = 10
COLOR_BAR_TXT_POS_PT = 2

# --- margins (the top margin grows with every extra title line) ---
LEFT_MARGIN_PT = 2.5
RIGHT_MARGIN_PT = 38
TOP_MARGIN_PT = 22 + (MAX_TITLE_ROWS) * SUPTITLE_FONT_SIZE + (MAX_TITLE_ROWS) * 1.2
BOTTOM_MARGIN_PT = 2.5

# --- figure sizing: tiles + gaps, then the outer margins ---
FIG_WIDTH_PT = (COL_COUNT*TILE_PT + (COL_COUNT-1)*GAP_W_PT) + (LEFT_MARGIN_PT + RIGHT_MARGIN_PT)

# Disabled: extra width reserved for a colour bar.
# if COLOR_BAR_FLAG:
#     FIG_WIDTH_PT += LEGEND_POS_PT + COLOR_BAR_TXT_POS_PT + COLOR_BAR_SIZE_PT

FIG_HEIGHT_PT = (ROW_COUNT*TILE_PT + (ROW_COUNT-1)*GAP_H_PT) + (TOP_MARGIN_PT + BOTTOM_MARGIN_PT)

# GridSpec spacing, expressed as a fraction of one tile.
WSPACE = GAP_W_PT / TILE_PT
HSPACE = GAP_H_PT / TILE_PT

### Plotting

In [None]:
def add_border(ax: Axes):
    """Outline the whole axes area with an unfilled black rectangle.

    The rectangle spans (0, 0)-(1, 1) in axes coordinates and uses
    `BORDER_THICKNESS` as its line width. Returns `ax` so calls can be chained.
    """
    outline = Rectangle(
        (0, 0), 1, 1,
        transform=ax.transAxes,
        fill=False,
        edgecolor='black',
        linewidth=BORDER_THICKNESS,
    )
    ax.add_patch(outline)
    return ax

def show_image(ax: Axes, img: np.ndarray):
    """Draw `img` on `ax` without interpolation smoothing and hide the axis.

    Returns `ax` so calls can be chained.
    """
    ax.imshow(img, interpolation='nearest')
    ax.set_axis_off()  # same effect as ax.axis('off')
    return ax
    
def show_title(ax: Axes, j: int):
    """Write the header of column `j` (taken from `TITLES`) above the tile.

    The text is centred horizontally and its bottom edge sits `TITLE_POS_PT`
    points above the top of the axes. Returns `ax` so calls can be chained.
    """
    # Vertical anchor in axes coordinates (1.0 is the top edge of the tile).
    y_anchor = (TILE_PT + TITLE_POS_PT)/TILE_PT
    ax.text(
        .5, y_anchor, TITLES[j],
        horizontalalignment='center',
        verticalalignment='bottom',
        fontsize=SUPTITLE_FONT_SIZE,
        transform=ax.transAxes,
    )
    return ax
    
def show_subtitle(ax: Axes, name: str):
    """Write `name` in small type just above the tile.

    The text is centred horizontally and its bottom edge sits
    `SUBTITLE_POS_PT` points above the top of the axes. Returns `ax` so calls
    can be chained.
    """
    # Vertical anchor in axes coordinates (1.0 is the top edge of the tile).
    y_anchor = (TILE_PT + SUBTITLE_POS_PT)/TILE_PT
    ax.text(
        .5, y_anchor, name,
        horizontalalignment='center',
        verticalalignment='bottom',
        fontsize=SUBTITLE_FONT_SIZE,
        transform=ax.transAxes,
    )
    return ax

In [None]:
# --- figure ---
# Figure size is given in inches (points / 72).
fig = plt.figure(
    figsize=(FIG_WIDTH_PT/72, FIG_HEIGHT_PT/72),
    dpi=600,
)

# Uniform grid of tiles; the outer margins are converted to figure fractions.
gs = GridSpec(
    ROW_COUNT, COL_COUNT,
    figure=fig,
    wspace=WSPACE,
    width_ratios=[1] * COL_COUNT,
    hspace=HSPACE,
    height_ratios=[1] * ROW_COUNT,
    left=LEFT_MARGIN_PT/FIG_WIDTH_PT,
    right=1 - RIGHT_MARGIN_PT/FIG_WIDTH_PT,
    bottom=BOTTOM_MARGIN_PT/FIG_HEIGHT_PT,
    top=1 - TOP_MARGIN_PT/FIG_HEIGHT_PT,
)

# --- legend ---
# One round marker per registered class colour, without a connecting line.
LEGEND_ELEMENTS = [
    Line2D(
        [], [],
        color="black",
        marker='o', markersize=5,
        markerfacecolor=face_color, markeredgewidth=.25,
        linewidth=0,
    )
    for face_color in class_reg.CLASS_HEX_COLORS.values()
]
# NOTE(review): assumes CLASS_HEX_COLORS yields its colours in this same order — confirm.
CLASSES = ['Clear', 'Cloud', 'Thin Cloud', 'Cloud-shadow']

# The class legend is only drawn when no colour bar is needed; it is anchored
# to the right of the top-right tile through an invisible helper axes.
if not COLOR_BAR_FLAG:
    ref_ax = fig.add_subplot(gs[0, -1])
    ref_ax.axis('off')

    legend_box = (
        1 + LEGEND_POS_PT/TILE_PT, .5,
        LEGEND_SIZE_PT/TILE_PT, LEGEND_SIZE_PT/TILE_PT,
    )
    ref_ax.legend(
        handles=LEGEND_ELEMENTS,
        labels=CLASSES,
        loc='upper center',
        bbox_to_anchor=legend_box,
        bbox_transform=ref_ax.transAxes,
        fontsize=LEGEND_FONT_SIZE,
        frameon=False,
        handletextpad=-.2,
        labelspacing=.8,
        alignment='left',
    )

def show_heatmap(ax: Axes, img: np.ndarray, show_bar=False):
    """Draw `img` as a 'coolwarm' heatmap on `ax`, optionally with a colour bar.

    Args:
        ax: Target axes; its axis decorations are hidden.
        img: 2-D array to display.
            # assumes values lie in [0, 1], since the bar is ticked at 0, 0.5
            # and 1 — TODO confirm
        show_bar: When true, add a colour bar in a new axes placed
            `LEGEND_POS_PT` points to the right of `ax`, spanning its height.

    Returns:
        `ax`, so calls can be chained.
    """
    heat = ax.imshow(img, interpolation='nearest', cmap='coolwarm')
    ax.axis('off')

    if show_bar:
        # Use the figure that owns `ax` for both the new axes and the bar,
        # instead of relying on a notebook-level `fig` being the same object.
        owner = ax.figure
        box = ax.get_position()

        # NOTE(review): the width multiplies a figure fraction by `box.width`
        # (itself a figure fraction), which yields a very thin bar — confirm
        # this is the intended size.
        cax = owner.add_axes([box.x1 + LEGEND_POS_PT/FIG_WIDTH_PT, box.y0,
                              COLOR_BAR_SIZE_PT/FIG_WIDTH_PT*box.width, box.height])

        cb = owner.colorbar(heat, cax=cax)
        cb.set_ticks([0.0, 0.5, 1.0])
        cb.ax.tick_params(labelsize=LEGEND_FONT_SIZE, length=0, pad=COLOR_BAR_TXT_POS_PT)

        # Match the colour-bar frame to the tile borders.
        for spine in cb.ax.spines.values():
            spine.set_linewidth(BORDER_THICKNESS)

    return ax

# --- individual tiles ---
for i, name in enumerate(TILES):
    img_arr, img_tensor, mask, fmask = get_datapoint(name)
    pretty_img_arr = boost_hsv(img_arr.astype(np.float32)/UINT8_MAX,
                               sat_gamma=.85, val_gamma=.85)
    pretty_img_arr = (UINT8_MAX * pretty_img_arr).round().astype(np.uint8)
    
    images = [pretty_img_arr, pretty_mask:=class_reg.class_recolor_map(mask)]
    plot_funcs = [
        [partial(show_image, img=pretty_img_arr), add_border],
        [partial(show_image, img=pretty_mask), add_border]
    ]
    
    for model_name, ver, subtitle in MODELS:
        plot_func = []
        
        match model_name:
            case 'samhq':
                output = run_SAM_prediction(img_arr, mask, ver)
                output = class_reg.class_recolor_map(output)
                plot_func += [partial(show_image, img=output), add_border]
                
            case 'fmask': 
                output = fmask
                output = class_reg.class_recolor_map(output)
                plot_func += [partial(show_image, img=output), add_border]
                
            case 'false_color': 
                output = get_false_color(SCENE_PATH, name, ver)
                plot_func += [partial(show_image, img=output), add_border]
                
            case 'snic':
                lab = cv2.cvtColor(pretty_img_arr, cv2.COLOR_RGB2LAB)
                seg_map = np.array(snic(lab, 50, 8)[0])
                output = ski.segmentation.mark_boundaries(pretty_img_arr, seg_map)
                plot_func += [partial(show_image, img=output), add_border]
                
            case 'xai':
                output = np.load(AUX_PATH / 'xai' / f'{name}_{ver}.npy')
                plot_func += [partial(show_heatmap, img=output), add_border]
                
            case _:
                output = run_prediction(img_tensor, model_name, ver)
                output = class_reg.class_recolor_map(output)
                plot_func += [partial(show_image, img=output), add_border]
        
        images.append(output)
        plot_funcs.append(plot_func)
        
    for j, (img, funcs) in enumerate(zip(images, plot_funcs)):
        ax = fig.add_subplot(gs[i, j])
        if i == 0: funcs.append(partial(show_title, j=j))
        if j == 0: funcs.append(partial(show_subtitle, name=TILES[i]))
        if i == 0 and j == len(images) - 1 and COLOR_BAR_FLAG: funcs[0] = partial(funcs[0], show_bar=True) 
        
        reduce(lambda x, f: f(x), funcs, ax)
    
    ax.axis('off')

plt.show()
fig.savefig(PLOTS_PATH / f"{FN.replace('/', '_')}.png", bbox_inches=None, pad_inches=0)
plt.close(fig)