In [1]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
import pandas as pd

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import os
# Change directory to your folder
os.chdir('/content/drive/MyDrive/MLHCImages/Two-Shot')

In [21]:
# Define global variables
color_map = {
    'Urban': [170,126,63],
    'Forest': [42, 164, 48],
    'Water': [40, 86, 201],
    'Clouds': [255,255,255],
}

# 20 images we look into
img_list = [
    "image_5001_2016-06-05",
    "image_5001_2017-01-29",
    "image_5001_2018-02-11",
    "image_5001_2018-12-02",
    "image_50001_2016-01-10",
    "image_50001_2016-03-06",
    "image_50001_2017-07-23",
    "image_50001_2018-09-02",
    "image_54001_2016-02-07",
    "image_54001_2016-04-17",
    "image_54001_2017-02-19",
    "image_54001_2018-02-18",
    "image_73001_2016-07-03",
    "image_73001_2017-04-02",
    "image_73001_2018-03-04",
    "image_73001_2018-10-21",
    "image_76001_2016-09-04",
    "image_76001_2017-02-05",
    "image_76001_2018-02-25",
    "image_76001_2018-12-16"
]

img_list = ["output_" + img for img in img_list] # comment out if you are using it for one-shot



In [15]:
def dice_coefficient(mask_true, mask_pred):
    """
    Calculate the Dice coefficient, a measure of set similarity.

    Parameters:
    mask_true (np.array): A binary array indicating the true regions of interest.
    mask_pred (np.array): A binary array indicating the predicted regions of interest.

    Returns:
    float: The Dice coefficient, which ranges from 0 (no overlap) to 1 (perfect overlap).
           This coefficient is calculated as (2 * intersection) / (sum of the sizes of both masks).

    """
    intersection = np.sum(mask_true == mask_pred)
    total = mask_true.size + mask_pred.size

    dice = (2.0 * intersection) / (total)

    return dice


In [19]:
def get_metrics(filename, threshold, class_label, showMasks=False):
    """
    Load an image and its labeled counterpart, apply a threshold to generate a prediction mask based on brightness, and
    calculate the F1 score and Dice coefficient comparing the prediction to the ground truth for a specified class.

    Parameters:
        filename (str): The name of the file without the file extension.
        threshold (float): The brightness threshold for generating the predicted mask.
        class_label (str): The label of the class for which the mask is generated. This label also determines the subdirectory for image storage.
        showMasks (bool, optional): A flag to indicate whether the predicted and true masks should be displayed. Default is False.

    Returns:
        tuple: A tuple containing the F1 score and Dice coefficient of the predicted mask compared to the true mask.

    The function first constructs the file paths for the original and labeled images based on the given filename and class label.
    It then processes the images to conform them to the needed size and color channels. The brightness of the original image is
    evaluated to create a binary prediction mask using the specified threshold. Both the original and labeled images are resized
    and the mask for the given class label is extracted from the labeled image. Optionally, both the prediction and true masks can
    be displayed. Finally, the function computes and returns the F1 score and Dice coefficient between the predicted and true masks.

    Note: To use this function for two-shot learning, comment out the filename replacement line.
    """
    # Open the image
    image_path = class_label+'/'+filename+".png"
    print(image_path)
    image = Image.open(image_path)
    image = image.convert("RGB")
    image = image.resize((448, 448))
    labeled_filename = filename.replace('output_', '')  # comment out if using for one-shot
    #labeled_filename = filename    # uncomment if using for one-shot
    labeled_image_path = 'labeled/'+labeled_filename+"_colored.png"
    print(labeled_image_path)
    labeled_image = Image.open(labeled_image_path)
    if labeled_image.getbands() != (3):
        labeled_image = labeled_image.convert('RGB')

    # Convert the image to a numpy array
    image_array = np.array(image)
    # Get brightness array
    brightness_array = np.mean(image_array, axis=2)
    # Threshold the brightness to get predicted mask
    pred = (brightness_array > threshold)

    # Convert ground truth image to numpy array
    labeled_image_array = np.array(labeled_image.resize((448, 448)))
    # Get mask for chosen class
    mask = np.all(labeled_image_array == color_map[class_label], axis=-1)
    # Plot ground truth mask if True
    if showMasks:
        print('Predicted Mask')
        plt.imshow(pred)
        plt.show()
        print('True Mask')
        plt.imshow(mask)
        plt.show()
    # Calculate F1 Score
    f1 = f1_score(mask.flatten(), pred.flatten())
    # Calculate Dice Score
    dice_score = dice_coefficient(mask, pred)
    return f1, dice_score


In [22]:
# Get metrics for all 20 images
f1_scores = []
dice_scores = []
for img in img_list:
  f1score, dice = get_metrics(img, 100 , 'Water')
  f1_scores.append(f1score)
  dice_scores.append(dice)

Water/output_image_5001_2016-06-05.png
labeled/image_5001_2016-06-05_colored.png
Water/output_image_5001_2017-01-29.png
labeled/image_5001_2017-01-29_colored.png
Water/output_image_5001_2018-02-11.png
labeled/image_5001_2018-02-11_colored.png
Water/output_image_5001_2018-12-02.png
labeled/image_5001_2018-12-02_colored.png
Water/output_image_50001_2016-01-10.png
labeled/image_50001_2016-01-10_colored.png
Water/output_image_50001_2016-03-06.png
labeled/image_50001_2016-03-06_colored.png
Water/output_image_50001_2017-07-23.png
labeled/image_50001_2017-07-23_colored.png
Water/output_image_50001_2018-09-02.png
labeled/image_50001_2018-09-02_colored.png
Water/output_image_54001_2016-02-07.png
labeled/image_54001_2016-02-07_colored.png


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


Water/output_image_54001_2016-04-17.png
labeled/image_54001_2016-04-17_colored.png
Water/output_image_54001_2017-02-19.png
labeled/image_54001_2017-02-19_colored.png
Water/output_image_54001_2018-02-18.png
labeled/image_54001_2018-02-18_colored.png
Water/output_image_73001_2016-07-03.png
labeled/image_73001_2016-07-03_colored.png
Water/output_image_73001_2017-04-02.png
labeled/image_73001_2017-04-02_colored.png
Water/output_image_73001_2018-03-04.png
labeled/image_73001_2018-03-04_colored.png
Water/output_image_73001_2018-10-21.png
labeled/image_73001_2018-10-21_colored.png
Water/output_image_76001_2016-09-04.png
labeled/image_76001_2016-09-04_colored.png
Water/output_image_76001_2017-02-05.png
labeled/image_76001_2017-02-05_colored.png
Water/output_image_76001_2018-02-25.png
labeled/image_76001_2018-02-25_colored.png
Water/output_image_76001_2018-12-16.png
labeled/image_76001_2018-12-16_colored.png


In [23]:
# example of output dataframe
result = pd.DataFrame({'Image': img_list, 'F1': f1_scores, 'Dice': dice_scores})
result

Unnamed: 0,Image,F1,Dice
0,output_image_5001_2016-06-05,0.0,0.994734
1,output_image_5001_2017-01-29,0.0,0.989308
2,output_image_5001_2018-02-11,0.0,0.948471
3,output_image_5001_2018-12-02,0.0,0.745556
4,output_image_50001_2016-01-10,0.0,0.944351
5,output_image_50001_2016-03-06,5.5e-05,0.820108
6,output_image_50001_2017-07-23,0.0,0.002566
7,output_image_50001_2018-09-02,0.0,0.917416
8,output_image_54001_2016-02-07,0.0,1.0
9,output_image_54001_2016-04-17,0.0,0.86803


In [24]:
# calculate metrics for all four classes
water_dice = []
for img in img_list:
  f1score, dice = get_metrics(img, 100 , 'Water')
  water_dice.append(dice)

urban_dice = []
for img in img_list:
  f1score, dice = get_metrics(img, 100 , 'Urban')
  urban_dice.append(dice)

cloud_dice = []
for img in img_list:
  f1score, dice = get_metrics(img, 100 , 'Clouds')
  f1_scores.append(f1score)
  cloud_dice.append(dice)

forest_dice = []
for img in img_list:
  f1score, dice = get_metrics(img, 100 , 'Forest')
  f1_scores.append(f1score)
  forest_dice.append(dice)

Water/output_image_5001_2016-06-05.png
labeled/image_5001_2016-06-05_colored.png
Water/output_image_5001_2017-01-29.png
labeled/image_5001_2017-01-29_colored.png
Water/output_image_5001_2018-02-11.png
labeled/image_5001_2018-02-11_colored.png
Water/output_image_5001_2018-12-02.png
labeled/image_5001_2018-12-02_colored.png
Water/output_image_50001_2016-01-10.png
labeled/image_50001_2016-01-10_colored.png
Water/output_image_50001_2016-03-06.png
labeled/image_50001_2016-03-06_colored.png
Water/output_image_50001_2017-07-23.png
labeled/image_50001_2017-07-23_colored.png
Water/output_image_50001_2018-09-02.png
labeled/image_50001_2018-09-02_colored.png
Water/output_image_54001_2016-02-07.png
labeled/image_54001_2016-02-07_colored.png


  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))


Water/output_image_54001_2016-04-17.png
labeled/image_54001_2016-04-17_colored.png
Water/output_image_54001_2017-02-19.png
labeled/image_54001_2017-02-19_colored.png
Water/output_image_54001_2018-02-18.png
labeled/image_54001_2018-02-18_colored.png
Water/output_image_73001_2016-07-03.png
labeled/image_73001_2016-07-03_colored.png
Water/output_image_73001_2017-04-02.png
labeled/image_73001_2017-04-02_colored.png
Water/output_image_73001_2018-03-04.png
labeled/image_73001_2018-03-04_colored.png
Water/output_image_73001_2018-10-21.png
labeled/image_73001_2018-10-21_colored.png
Water/output_image_76001_2016-09-04.png
labeled/image_76001_2016-09-04_colored.png
Water/output_image_76001_2017-02-05.png
labeled/image_76001_2017-02-05_colored.png
Water/output_image_76001_2018-02-25.png
labeled/image_76001_2018-02-25_colored.png
Water/output_image_76001_2018-12-16.png
labeled/image_76001_2018-12-16_colored.png
Urban/output_image_5001_2016-06-05.png
labeled/image_5001_2016-06-05_colored.png
Urban/

In [25]:
# Display final Dice Coefficients
final_result = pd.DataFrame({'Image': img_list, 'Urban': urban_dice, 'Forest': forest_dice, 'Water': water_dice, 'Clouds': cloud_dice})
final_result

Unnamed: 0,Image,Urban,Forest,Water,Clouds
0,output_image_5001_2016-06-05,0.179927,0.901063,0.994734,0.977375
1,output_image_5001_2017-01-29,0.243418,0.903236,0.989308,0.983528
2,output_image_5001_2018-02-11,0.276517,0.857811,0.948471,0.910311
3,output_image_5001_2018-12-02,0.454346,0.720703,0.745556,0.74857
4,output_image_50001_2016-01-10,0.446847,0.696065,0.944351,0.999437
5,output_image_50001_2016-03-06,0.48448,0.634855,0.820108,0.923449
6,output_image_50001_2017-07-23,0.002187,0.992865,0.002566,0.977763
7,output_image_50001_2018-09-02,0.573207,0.573058,0.917416,0.936877
8,output_image_54001_2016-02-07,0.052087,0.974555,1.0,0.999043
9,output_image_54001_2016-04-17,0.237878,0.793871,0.86803,0.950748
