In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
from torchvision import transforms
import numpy as np
import torch
from skimage import measure
import scipy

In [None]:
from src.path import ProjPaths
from src.data.band3_binary_mask_data import Band3BinaryMaskDataset, RandomCropImgAndLabels, ToTensorImgAndLabels
from src.models.unet_ptl import UNet
from src.metrics import sample_logits_and_labels, logits_to_prediction, classification_cases, prediction_metrics, compute_true_false_classifications_for_sample_and_model
from src.visualization.visualize import show_image_and_true_false_classifications

In [None]:
# Test split: each sample is randomly cropped (size 384) and converted to tensors.
# NOTE(review): the crop is random and unseeded, so test_dataset[i] presumably returns a
# different crop on every access — confirm, and seed if reproducibility matters.
test_path = ProjPaths.interim_sn1_data_path / "test"
crop_and_to_tensor = transforms.Compose(
    [
        RandomCropImgAndLabels(384),
        ToTensorImgAndLabels(),
    ]
)
test_dataset = Band3BinaryMaskDataset(test_path, transform=crop_and_to_tensor)

## Evaluate model for individual samples

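Here we use our UNet model (trained with PyTorch Lightning) and inspect its per-pixel true/false positive/negative classifications for individual test samples.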

In [None]:
# UNet (PyTorch Lightning) checkpoint used for evaluation.
chkpt_path = ProjPaths.model_path / 'unet' / 'unet_ptl_v5' / 'checkpoints' / 'best_model-unet-epoch=15-val_loss=0.09.ckpt'
model = UNet.load_from_checkpoint(chkpt_path)

model.eval()  # inference mode

# Inference deliberately runs on the CPU (the previous code detected CUDA and then
# overwrote the result with 'cpu'). Set FORCE_CPU = False to use a GPU when available.
FORCE_CPU = True
DEVICE = 'cpu' if FORCE_CPU or not torch.cuda.is_available() else 'cuda'
model = model.to(DEVICE)

In [None]:
def compute_true_false_classifications(this_sample_id):
    """Run the UNet on one test sample and return its per-pixel classification map.

    Parameters
    ----------
    this_sample_id : int
        Index into ``test_dataset``.

    Returns
    -------
    sample
        The dataset sample as returned by ``test_dataset[this_sample_id]``.
    classes_masked : numpy.ma.MaskedArray
        Output of ``compute_true_false_classifications_for_sample_and_model`` with every
        pixel equal to 0 masked out, so those pixels are not drawn when plotted.
    """
    sample = test_dataset[this_sample_id]
    raw_classes = compute_true_false_classifications_for_sample_and_model(sample, model, DEVICE)
    return sample, np.ma.masked_where(raw_classes == 0, raw_classes)

In [None]:
# Test-set sample to analyse. Other samples worth inspecting: 132, 866, 634, 985.
this_sample_id = 512

sample, classes_masked = compute_true_false_classifications(this_sample_id)

## Find clusters

Show original model predictions

In [None]:
pred = ((classes_masked.data == 1) | (classes_masked.data == 3))*1 # get original prediction from true_pos and false_pos

In [None]:
labels = sample['labels']
labels = labels.cpu().detach().numpy()[0, :, :]

true_pos, true_neg, false_pos, false_neg = classification_cases(labels, pred)
pred_metrics = prediction_metrics(true_pos, true_neg, false_pos, false_neg)
 

In [None]:
show_image_and_true_false_classifications(sample, classes_masked)

In [None]:
 fig = plt.figure(figsize=(15,30))

plt.subplot(1,2,1)
plt.imshow(sample["image"].numpy().transpose(1, 2, 0))
plt.title('Input image')

plt.subplot(1,2,2)
plt.imshow(pred)
plt.title('Model predictions')
plt.show()

Get individual clusters

In [None]:
clusters_with_id = measure.label(pred)

In [None]:
plt.imshow(clusters_with_id)
plt.title('Individual clusters')
plt.show()

Compute cluster areas

In [None]:
# Area (pixel count) of every model cluster, indexed by its label in clusters_with_id.
# The index is taken from prop.label instead of assuming the labels are exactly 1..max,
# so it stays correct even if the label image has gaps.
properties = measure.regionprops(clusters_with_id)
cluster_areas = pd.DataFrame(
    {'area': [prop.area for prop in properties]},
    index=[prop.label for prop in properties],
)
cluster_areas = cluster_areas.sort_values('area')
cluster_areas.tail(5)

## Inspect clusters

In [None]:
def show_cluster(this_cluster, sample):
    """Plot one cluster mask next to the input image.

    Left panel: the sample image with the non-zero pixels of ``this_cluster`` overlaid
    in orange. Right panel: the raw cluster mask.

    Parameters
    ----------
    this_cluster : numpy.ndarray
        2-D mask; non-zero pixels belong to the cluster.
    sample
        Dataset sample; ``sample["image"]`` is a tensor in (C, H, W) layout.
    """
    # Mask the background so only the cluster is drawn on top of the image.
    overlay = np.ma.masked_where(this_cluster == 0, this_cluster)
    image_hwc = sample["image"].numpy().transpose(1, 2, 0)

    _, (ax_overlay, ax_mask) = plt.subplots(1, 2, figsize=(10, 20))

    ax_overlay.imshow(image_hwc)
    ax_overlay.imshow(overlay, alpha=0.8, cmap='Oranges', vmin=0, vmax=1)
    ax_overlay.set_title('Input image with cluster')

    ax_mask.imshow(this_cluster)
    ax_mask.set_title('Cluster')
    plt.show()

In [None]:
biggest_cluster_id = cluster_areas.sort_values('area').index[-1]
biggest_cluster = (clusters_with_id == biggest_cluster_id)*1
show_cluster(biggest_cluster, sample)

In [None]:
second_biggest_cluster_id = cluster_areas.index[-2]
second_biggest_cluster = (clusters_with_id == second_biggest_cluster_id)*1
show_cluster(second_biggest_cluster, sample)

In [None]:
xval, yval = scipy.ndimage.center_of_mass(biggest_cluster)
xval = int(np.round(xval))
yval = int(np.round(yval))
xval, yval

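Sanity check: the rounded centre of mass should lie inside the cluster (expected value 1). For strongly non-convex clusters it can fall outside (value 0).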

In [None]:
biggest_cluster[xval][yval]  # 1 if the rounded centre of mass lies inside the cluster

Pixel indices (rows and columns) of the biggest cluster

In [None]:
cluster_x_vals, cluster_y_vals = biggest_cluster.nonzero()

In [None]:
len(cluster_x_vals)

## Segment everything model

In [None]:
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator

In [None]:
# from https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb
def show_anns(anns):
    """Overlay SAM automatic-mask annotations on the current matplotlib axes.

    Each annotation is drawn in a random colour at 35% opacity. Larger annotations
    are drawn first so that smaller ones remain visible on top of them.

    Parameters
    ----------
    anns : list of dict
        Annotations as produced by ``SamAutomaticMaskGenerator.generate``; each dict
        needs an ``'area'`` entry and a 2-D ``'segmentation'`` mask.
    """
    if len(anns) == 0:
        return
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted(anns, key=lambda ann: ann['area'], reverse=True):
        mask = ann['segmentation']
        # Solid RGB layer in one random colour; the scaled mask becomes the alpha channel.
        color_mask = np.random.random(3)
        rgb = np.empty((mask.shape[0], mask.shape[1], 3))
        rgb[:, :] = color_mask
        ax.imshow(np.dstack((rgb, mask * 0.35)))
        
# adapted from https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
def show_points(coords, labels, ax, marker_size=375):
    """Draw SAM prompt points as stars: label 1 in green, label 0 in red.

    Parameters
    ----------
    coords : numpy.ndarray
        (N, 2) point coordinates; column 0 is plotted on the x axis, column 1 on the y axis.
    labels : numpy.ndarray
        (N,) point labels; 1 marks foreground points, 0 background points.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    marker_size : float
        Marker size, passed to ``ax.scatter`` as ``s``.
    """
    for point_label, point_color in ((1, 'green'), (0, 'red')):
        selected = coords[labels == point_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=point_color, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)

def show_mask(mask, ax, random_color=False):
    """Draw a binary mask on ``ax`` as a semi-transparent (alpha 0.6) colour layer.

    Parameters
    ----------
    mask : numpy.ndarray
        Mask whose last two dimensions are (H, W); leading singleton axes are allowed.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    random_color : bool
        Use a random RGB colour instead of the default blue.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    height, width = mask.shape[-2:]
    ax.imshow(mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1))

Define model to be used and set it up:

In [None]:
# Segment Anything (SAM) model with the ViT-H backbone, loaded from the local checkpoint.
sam_checkpoint = "sam_vit_h_4b8939.pth"
checkpoint_path = ProjPaths.model_path / 'sam' / sam_checkpoint
model_type = "vit_h"

# Reuse the device chosen for the UNet instead of hard-coding a second device setting.
device = DEVICE
sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
sam.to(device=device)
predictor = SamPredictor(sam)

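## Get segmentation for full image

Convert the sample into SAM's input format, get the image embedding (`predictor.set_image`, reused by the point prompts at the end of this notebook), and let the automatic mask generator segment the full image.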

In [None]:
def sample_to_sam_format(sample):
    """Convert a dataset sample's image tensor into a uint8 (H, W, C) array for SAM.

    Parameters
    ----------
    sample
        Dataset sample whose ``'image'`` entry is a torch tensor in (C, H, W) layout.
        Values are assumed to lie in [0, 1] (TODO confirm); anything outside is clipped.

    Returns
    -------
    numpy.ndarray
        (H, W, C) ``uint8`` array with values in 0..255.
    """
    img_vals = sample['image'].cpu().detach().numpy()  # to numpy
    img_vals = np.moveaxis(img_vals, 0, -1)  # change dimensions to HWC
    # Scale to 0-255 and clip before casting: float values outside [0, 255] would
    # otherwise wrap around / be undefined when converted to uint8.
    img_vals_255 = np.clip(np.round(img_vals * 255), 0, 255)
    return img_vals_255.astype(np.uint8)

In [None]:
sam_image = sample_to_sam_format(sample)

In [None]:
predictor.set_image(sam_image)

In [None]:
mask_generator = SamAutomaticMaskGenerator(sam)

In [None]:
masks = mask_generator.generate(sam_image)

In [None]:
# background_image = np.zeros(sam_image.shape)

# plt.figure(figsize=(15,15))
# plt.imshow(background_image)
# #plt.imshow(sam_image)
# show_anns(masks)
# plt.axis('off')
# plt.show() 

In [None]:
# Rasterise the SAM masks into one label image: value k means the pixel belongs to
# masks[k-1]. Where masks overlap, the mask later in the list wins.
all_mask_classes = np.zeros(clusters_with_id.shape)
for mask_id, mask_info in enumerate(masks, start=1):
    all_mask_classes[mask_info['segmentation']] = mask_id
# Note: 0 values correspond to non-existing SAM cluster

In [None]:
fig = plt.figure(figsize=(10,20))

plt.subplot(1,2,1)
plt.imshow(sam_image)
plt.title('Input image')

all_mask_classes_masked = np.ma.masked_where(all_mask_classes == 0, all_mask_classes)

plt.subplot(1,2,2)
plt.imshow(sam_image)
plt.imshow(all_mask_classes_masked, alpha=0.8)
plt.title('SAM clusters')
plt.show()

In [None]:
fig = plt.figure(figsize=(10,20))

plt.subplot(1,2,1)
plt.imshow(clusters_with_id)
plt.title('Model clusters')

plt.subplot(1,2,2)
plt.imshow(all_mask_classes_masked)
plt.title('SAM clusters')
plt.show()

Pick one cluster and intersect with SAM clusters

In [None]:
intersection = biggest_cluster * all_mask_classes

In [None]:
intersection.shape

In [None]:
cluster_size = np.sum(biggest_cluster)

In [None]:
cluster_size

In [None]:
intersect_cluster_ids = np.unique(intersection)
intersect_cluster_ids = [ii for ii in intersect_cluster_ids if ii > 0]
intersect_cluster_ids

In [None]:
sam_cluster_areas = []
intersection_areas = []

all_intersect_clusters = np.zeros(clusters_with_id.shape)
for ii in range(0, len(intersect_cluster_ids)):
    this_id = intersect_cluster_ids[ii]

    this_mask = all_mask_classes == this_id
    all_intersect_clusters[this_mask] = ii+1

    this_sam_cluster_area = np.sum(this_mask)
    sam_cluster_areas.append(this_sam_cluster_area)

    intersection_mask = intersection == this_id
    this_intersection_area = np.sum(intersection_mask)
    intersection_areas.append(this_intersection_area)

In [None]:
intersection_metrics = pd.DataFrame({'cluster_id': intersect_cluster_ids, 'cluster_area': sam_cluster_areas, 'intersection_area': intersection_areas})
intersection_metrics['target_size'] = cluster_size
intersection_metrics['overlap_ratio'] = intersection_metrics['intersection_area'] / intersection_metrics['cluster_area']
intersection_metrics

In [None]:
intersection_metrics['intersection_area'].sum()

In [None]:
fig, (ax_inter, ax_sam, ax_img) = plt.subplots(1, 3, figsize=(20, 30))

ax_inter.imshow(biggest_cluster * all_mask_classes)
ax_inter.set_title('Intersection with SAM clusters')

biggest_cluster_masked = np.ma.masked_where(biggest_cluster == 0, biggest_cluster)

ax_sam.imshow(all_intersect_clusters)
ax_sam.imshow(biggest_cluster_masked, alpha=0.9, cmap='Oranges', vmin=0, vmax=1)
ax_sam.set_title('Intersecting SAM clusters')

# Black out every pixel (all channels) that is not part of an intersecting SAM cluster.
image_masked = sam_image.copy()
image_masked[all_intersect_clusters == 0] = 0

ax_img.imshow(image_masked)
ax_img.imshow(biggest_cluster_masked, alpha=0.5, cmap='Oranges', vmin=0, vmax=1)
# Distinct title: this panel shows the image restricted to the intersecting clusters.
ax_img.set_title('Image restricted to intersecting SAM clusters')

plt.show()

Idea:
- add SAM clusters if more than a certain fraction of their area is covered by the model cluster
- otherwise keep the model cluster, but cut out the parts that fall into a (large) SAM cluster of which only a tiny fraction is covered
- optional: add SAM clusters that fully contain a model cluster, even if the overlap ratio is small

In [None]:
intersection_metrics

In [None]:
all_mask_classes

In [None]:
overlap_clusters = np.zeros(all_mask_classes.shape)
veto_clusters = np.zeros(all_mask_classes.shape)

overlap_threshold = 0.4
veto_threshold = 0.05

for idx, row in intersection_metrics.iterrows():
    
    this_cluster_id = row['cluster_id']
    xx_inds = all_mask_classes == this_cluster_id
    
    if row['overlap_ratio'] > overlap_threshold:
        
        overlap_clusters[xx_inds] = this_cluster_id
        
    if row['overlap_ratio'] < veto_threshold:
        
        veto_clusters[xx_inds] = 1

In [None]:
np.unique(biggest_cluster)

In [None]:
modified_cluster = biggest_cluster.copy()
modified_cluster[veto_clusters == 1] = 0
modified_cluster[overlap_clusters > 0] = 1

modified_cluster_masked = np.ma.masked_where(modified_cluster == 0, modified_cluster)

In [None]:
fig = plt.figure(figsize=(15,15))

plt.subplot(2,2,1)
plt.imshow(overlap_clusters)
plt.title('Union of meaningfully overlapping clusters')

plt.subplot(2,2,2)
plt.imshow(veto_clusters)
plt.title('Union of insignficantly overlapping clusters')

plt.subplot(2,2,3)
plt.imshow(modified_cluster_masked)
plt.imshow(biggest_cluster_masked, alpha=0.3, cmap='Oranges', vmin=0, vmax=1)
plt.title('Original vs modified model cluster')

image_masked = sam_image.copy()
xx_inds = modified_cluster == 0
image_masked[xx_inds, 0] = 255
image_masked[xx_inds, 1] = 255
image_masked[xx_inds, 2] = 255

plt.subplot(2,2,4)
# plt.imshow(image_masked)
# plt.title('Modified model cluster')

plt.imshow(sam_image)
plt.imshow(modified_cluster_masked, alpha=0.8, cmap='Oranges', vmin=0, vmax=1)
plt.title('Modified model cluster')

plt.show()

In [None]:
show_cluster(biggest_cluster, sample)

## Apply modification to clusters

Put everything into a function and use this to modify / polish all model predictions

Inputs:
- calibration parameters: the threshold values `overlap_threshold` and `veto_threshold`
- a single model cluster (binary mask)
- the SAM clusters (label image, 0 = no SAM cluster)

In [None]:
def modify_cluster(this_model_cluster, all_mask_classes, overlap_threshold=0.4, veto_threshold=0.05,
                   use_full_overlap=False):
    """Refine a single model cluster using the SAM clusters it intersects.

    For every SAM cluster that shares at least one pixel with the model cluster, the
    overlap ratio ``intersection_area / sam_cluster_area`` is computed. Then:

    * SAM clusters with ratio > ``overlap_threshold`` are added to the cluster in full;
    * pixels of SAM clusters with ratio < ``veto_threshold`` are removed from it;
    * optionally (``use_full_overlap=True``), SAM clusters that contain every pixel of
      the model cluster are added in full as well. This step is not well tested yet
      and is off by default.

    Parameters
    ----------
    this_model_cluster : numpy.ndarray
        2-D mask of one model cluster (1 inside, 0 outside).
    all_mask_classes : numpy.ndarray
        Label image of the SAM clusters with the same shape as ``this_model_cluster``;
        0 means "no SAM cluster", k > 0 is the id of a SAM cluster.
    overlap_threshold : float
        Overlap ratio a SAM cluster must exceed to be added.
    veto_threshold : float
        Overlap ratio below which a SAM cluster's pixels are removed.
    use_full_overlap : bool
        Also add SAM clusters that fully contain the model cluster.

    Returns
    -------
    numpy.ndarray
        Modified copy of ``this_model_cluster``; the input is not changed.

    Raises
    ------
    ValueError
        If the two arrays do not have the same shape.
    """
    if np.shape(this_model_cluster) != np.shape(all_mask_classes):
        raise ValueError(
            f"shape mismatch: model cluster {np.shape(this_model_cluster)} "
            f"vs SAM label image {np.shape(all_mask_classes)}"
        )

    cluster_size = np.sum(this_model_cluster)

    # Inside the model cluster each pixel carries the id of the SAM cluster under it; 0 elsewhere.
    intersection = this_model_cluster * all_mask_classes
    intersect_cluster_ids = [cluster_id for cluster_id in np.unique(intersection) if cluster_id > 0]

    add_mask = np.zeros(all_mask_classes.shape, dtype=bool)
    veto_mask = np.zeros(all_mask_classes.shape, dtype=bool)
    full_overlap_mask = np.zeros(all_mask_classes.shape, dtype=bool)

    for this_cluster_id in intersect_cluster_ids:
        xx_inds = all_mask_classes == this_cluster_id
        intersection_area = np.sum(intersection == this_cluster_id)
        overlap_ratio = intersection_area / np.sum(xx_inds)

        if overlap_ratio > overlap_threshold:
            add_mask |= xx_inds
        if overlap_ratio < veto_threshold:
            veto_mask |= xx_inds
        # This SAM cluster contains every pixel of the model cluster.
        if intersection_area == cluster_size:
            full_overlap_mask |= xx_inds

    modified_cluster = this_model_cluster.copy()
    modified_cluster[veto_mask] = 0
    modified_cluster[add_mask] = 1
    if use_full_overlap:
        modified_cluster[full_overlap_mask] = 1

    return modified_cluster

In [None]:
cluster_areas.tail(10)

In [None]:
this_cluster_id = cluster_areas.index[-2]
this_model_cluster = (clusters_with_id == this_cluster_id)*1
show_cluster(this_model_cluster, sample)

In [None]:
this_modified_cluster = modify_cluster(this_model_cluster, all_mask_classes, overlap_threshold=0.4, veto_threshold=0.05)
show_cluster(this_modified_cluster, sample)

In [None]:
min_pixel_size = 50

In [None]:
all_modified_clusters = np.zeros(clusters_with_id.shape)

In [None]:
plt.imshow(this_model_cluster)

In [None]:
plt.imshow(this_modified_cluster)

In [None]:
for ii in range(0, cluster_areas.shape[0]):
    
    this_cluster_id = cluster_areas.index[ii]
    this_area = cluster_areas.iloc[ii].squeeze()
    
    if this_area >= min_pixel_size:
        this_model_cluster = (clusters_with_id == this_cluster_id)*1
        this_modified_cluster = modify_cluster(this_model_cluster, all_mask_classes, overlap_threshold=0.4, veto_threshold=0.05)
        
        xx_inds = this_modified_cluster > 0
        all_modified_clusters[xx_inds] = 1

In [None]:
fig = plt.figure(figsize=(15,30))

plt.subplot(1,3,1)
plt.imshow(sample["image"].numpy().transpose(1, 2, 0))
plt.title('Input image')

plt.subplot(1,3,2)
plt.imshow(pred)
plt.title('Model predictions')

plt.subplot(1,3,3)
plt.imshow(all_modified_clusters)
plt.title('Modified model predictions')

plt.show()

In [None]:
np.unique(pred)

In [None]:
labels = sample['labels']
labels = labels.cpu().detach().numpy()[0, :, :]

In [None]:
true_pos, true_neg, false_pos, false_neg = classification_cases(labels, all_modified_clusters)
pred_metrics_modified = prediction_metrics(true_pos, true_neg, false_pos, false_neg)
    
# Translate into true / false positives / negatives:
classes = np.zeros(true_pos.shape)
classes[true_neg] = 0
classes[false_pos] = 1
classes[false_neg] = 2
classes[true_pos] = 3
classes_masked_modified = np.ma.masked_where(classes == 0, classes)

In [None]:
np.unique(classes)

In [None]:
show_image_and_true_false_classifications(sample, classes_masked_modified)

In [None]:
show_image_and_true_false_classifications(sample, classes_masked)

In [None]:
pred_metrics

In [None]:
pred_metrics_modified

## Predict for given points

Predict segmentation at given point(s):

In [None]:
input_points = np.array([[yval, xval]])
input_labels = np.array([1])

In [None]:
masks, scores, logits = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    multimask_output=True,
)

In [None]:
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(sam_image)
    show_mask(mask, plt.gca())
    show_points(input_points, input_labels, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  
  

Alternatively, with multiple points:

In [None]:
import random

In [None]:
rand_inds = random.sample(range(0, len(cluster_x_vals)), 10)

In [None]:
x_vals = cluster_x_vals[rand_inds]
y_vals = cluster_y_vals[rand_inds]

In [None]:
input_points = np.array([[y_vals[ii], x_vals[ii]] for ii in range(0, len(x_vals))])
input_labels = np.array([1 for ii in range(0, len(x_vals))])

In [None]:
masks, scores, logits = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    multimask_output=True,
)

In [None]:
masks.shape

In [None]:
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(img_vals_255_uint)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  
  