# Install Packages

In [None]:
# Import Colab specific packages and mount to Google Drive folder
from google.colab.patches import cv2_imshow
from google.colab import drive
drive.mount('/content/gdrive/', force_remount=True)

# !apt-get update
# !apt install msttcorefonts -qq

# Import all necessary packages for code to run

import os
import pathlib
import numpy as np
import math
import pandas as pd
import glob
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.patches import Circle
import seaborn as sns
import copy
from scipy import ndimage
from scipy.spatial import procrustes
import warnings
import pickle
import sklearn
import tensorflow as tf
from tensorflow.python.client import device_lib
from random import randint
from tqdm import tqdm
import cv2
#from IPython.display import Image
import PIL
# !pip3 install fpdf
# from fpdf import FPDF
# #import PyPDF2
# !pip3 install PyPDF2
# #from PyPDF2 import PdfFileMerger, PdfFileReader
# from PyPDF2 import PdfMerger, PdfReader
import time # measure how long training takes
import skimage
from skimage.measure import label, find_contours
from skimage.metrics import hausdorff_distance #, hausdorff_pair
from skimage.morphology import skeletonize, convex_hull_image
from skimage import draw
import random

# # Packages and programs required to run Octave code in this notebook
# #print('ATTEMPTING TO INSTALL OCT2PY')
# !pip3 install oct2py #--no-deps
# #print('ATTEMPTING TO INSTALL OCTAVE')
# !apt install octave # makes it possible to run matlab scripts
# #print('ATTEMPTING TO INSTALL OCTAVE DEV TOOLS')
# !apt install liboctave-dev
# #!pip3 install --no-deps -e '/content/gdrive/My Drive/Colab Notebooks/Sector Project/oct2py-5.5.1'
# #print('BELOW IS AN ERROR IN SETUP!')
# import oct2py
# from oct2py import octave
# %load_ext oct2py.ipython

# Query TensorFlow for the visible GPUs and the local devices. Neither result is
# stored, and neither call is the last expression of the cell.
tf.config.list_physical_devices('GPU')
device_lib.list_local_devices()

# Get info on running device
# `!nvidia-smi` is IPython shell-capture syntax (not plain Python): the output
# lines of the command are captured and then joined back into one string.
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
# The runtime is reported as GPU-less when the captured text contains 'failed'.
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)

# Check if you are using a high-ram runtime
from psutil import virtual_memory
# Uses the `.total` field divided by 1e9 (decimal gigabytes); note the message
# below words this as "available" RAM.
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

# 20 GB is the threshold this notebook uses to call a runtime "high-RAM".
if ram_gb < 20:
    print('Not using a high-RAM runtime')
else:
    print('You are using a high-RAM runtime!')

# Functions to Load

## Image Processing

In [None]:
# Target height and width passed to cv2.resize / set_shape by the loaders in
# this cell, and the number of classes used as the one-hot depth of the masks.
H = 1024
W = 1024
num_classes = 3

def crop(img):
    """Crop ``img`` to the bounding box of its non-zero entries.

    Only the first two axes are used to locate the box.

    Returns:
        A pair ``(cropped, bounds)`` where ``bounds`` is
        ``[row_start, row_stop, col_start, col_stop]`` with exclusive stops.

    An input without any non-zero entry is not handled: taking the minimum of
    the empty index array raises.
    """
    rows, cols = img.nonzero()[:2]
    bounds = [rows.min(), rows.max() + 1, cols.min(), cols.max() + 1]
    top, bottom, left, right = bounds
    return img[top:bottom, left:right], bounds

def trim_border(x):
    """Centre-crop an array so that its first two axes have the same length.

    The longer of the two leading axes is trimmed on both sides. When the
    difference is odd, the extra row/column is removed from the leading
    (top/left) side. Any trailing axes (e.g. colour channels) are kept as they
    are, so both ``(rows, cols, channels)`` and plain ``(rows, cols)`` arrays
    are accepted.

    Args:
        x: array with at least two axes.

    Returns:
        ``x`` itself when it is already square, otherwise a sliced view of it.
    """
    height, width = x.shape[0], x.shape[1]
    if height == width:
        return x

    pixel_diff = abs(height - width)
    trailing = pixel_diff // 2          # removed from the bottom / right
    leading = pixel_diff - trailing     # removed from the top / left (gets the odd pixel)

    if height > width:
        return x[leading:height - trailing]
    return x[:, leading:width - trailing]

def read_image(x):
    """Load an image file as a float32 array of shape ``(H, W, 3)``.

    The file is read as a colour image, centre-cropped to a square with
    ``trim_border``, resized to ``W`` x ``H`` and divided by 255.

    Args:
        x: path of the image file.

    Raises:
        FileNotFoundError: if OpenCV could not read the file.
    """
    image = cv2.imread(x, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None;
        # fail here with the path instead of a later, unrelated error.
        raise FileNotFoundError('Could not read image file: ' + str(x))
    image = trim_border(image)
    image = cv2.resize(image, (W, H))
    image = image / 255.0
    return image.astype(np.float32)

def read_mask(x):
    """Load a segmentation mask as an int32 array of shape ``(H, W)``.

    The file is read as a single-channel image and resized to ``W`` x ``H``.
    When ``num_classes == 2`` the values are divided by 255 so that a 0/255
    binary image becomes 0/1 class ids.

    Args:
        x: path of the mask file.

    Raises:
        FileNotFoundError: if OpenCV could not read the file.
    """
    mask = cv2.imread(x, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError('Could not read mask file: ' + str(x))
    # Pixel values are class ids, so they must not be blended: nearest-neighbour
    # keeps every output pixel equal to an existing label, whereas the default
    # bilinear interpolation produces in-between values at region boundaries.
    mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)
    # NOTE(review): read_image centre-crops with trim_border before resizing but
    # the mask is not cropped here; for non-square inputs the two would not
    # line up - confirm that the inputs are always square.
    if num_classes == 2:
        # Binary masks are stored as 0 and 255; rescale them to 0 and 1.
        mask = mask / 255.0
    return mask.astype(np.int32)

def tf_dataset(x,y, batch=1):
    """Build the tf.data input pipeline for paired image/mask paths.

    The (x, y) pairs are shuffled with a 500-element buffer, mapped through
    ``preprocess``, batched, repeated indefinitely and prefetched two batches
    ahead.

    Args:
        x: image paths.
        y: mask paths, aligned with ``x``.
        batch: batch size.
    """
    return (
        tf.data.Dataset.from_tensor_slices((x, y))
        .shuffle(buffer_size=500)
        .map(preprocess)
        .batch(batch)
        .repeat()
        .prefetch(2)
    )

def preprocess(x,y):
    """Turn a pair of path tensors into an (image, one-hot mask) tensor pair.

    The files are loaded through ``tf.numpy_function`` with ``read_image`` and
    ``read_mask``; the mask is one-hot encoded with ``num_classes`` channels.
    Static shapes are set explicitly afterwards.
    """
    def _load_pair(image_path, mask_path):
        # The paths arrive as bytes and are decoded before reading.
        return read_image(image_path.decode()), read_mask(mask_path.decode())

    image, mask = tf.numpy_function(_load_pair, [x, y], [tf.float32, tf.int32])

    image.set_shape([H, W, 3])  # three colour channels, independent of num_classes
    mask = tf.one_hot(mask, num_classes, dtype=tf.int32)
    mask.set_shape([H, W, num_classes])  # last dimension follows the number of segmentation classes

    return image, mask

## Quick Processes

In [None]:
def find_mode(my_array):
    """Return ``(value, count)`` for the most frequent value in ``my_array``.

    When several values share the highest count, the smallest one is returned
    (the unique values are sorted and the first maximum is taken).
    """
    unique_values, occurrences = np.unique(my_array, return_counts=True)
    winner = occurrences.argmax()
    return unique_values[winner], occurrences[winner]

def get_count_breakdown(my_array):
    """Count how often each integer ``0..max(my_array)`` occurs in ``my_array``.

    Args:
        my_array: sequence or array of integer values.

    Returns:
        A NumPy array whose element ``i`` is the number of occurrences of the
        value ``i``. Values that do not occur get a count of 0 (previously a
        gap in the values raised ``ValueError``). Negative values are ignored.
        An empty input yields an empty array.
    """
    if len(my_array) == 0:
        return np.array([])

    vals, counts = np.unique(my_array, return_counts=True)
    max_val = np.max(vals)
    # Look counts up by value so that missing values fall back to 0.
    count_by_value = dict(zip(vals.tolist(), counts))
    sorted_counts = [count_by_value.get(value, 0) for value in range(0, max_val + 1)]
    return np.array(sorted_counts)

## Plotting

In [None]:
# Sentinel meaning "the axes object is a single Axes; do not index into it".
_NO_SUBPLOT = object()


def _annotate_bar_heights(x_positions, y, fs, axes, this_axis=_NO_SUBPLOT):
    """Write each ``y[i]`` as centred serif text 5 data units above ``y[i]``.

    Args:
        x_positions: iterable of x coordinates, one per label.
        y: bar heights; also used as the label text.
        fs: font size.
        axes: Axes (or indexable collection of Axes) to draw on; ``None`` means
            the notebook-level ``ax`` that the plotting cells create.
        this_axis: index into ``axes`` when it is a collection of subplots.
    """
    for i, x_pos in enumerate(x_positions):
        # The global ``ax`` is only looked up when there is something to draw
        # and no explicit axes was supplied, matching the historical behaviour.
        target = ax if axes is None else axes
        if this_axis is not _NO_SUBPLOT:
            target = target[this_axis]
        target.text(x_pos, y[i] + 5, y[i], ha = 'center', fontfamily="serif", fontsize=fs)


def addlabels_centered(x,y,fs, axes=None):
    """Label bars at the x coordinates given in ``x``."""
    _annotate_bar_heights((x[i] for i in range(len(x))), y, fs, axes)

def addlabels_initial(x,y,fs, axes=None):
    """Label bars drawn 0.25 to the left of each integer position."""
    _annotate_bar_heights((i-(0.25) for i in range(len(x))), y, fs, axes)

def addlabels_prediction(x,y,fs, axes=None):
    """Label bars drawn exactly on each integer position."""
    _annotate_bar_heights(range(len(x)), y, fs, axes)

def addlabels_truemarks(x,y,fs, axes=None):
    """Label bars drawn 0.25 to the right of each integer position."""
    _annotate_bar_heights((i+(0.25) for i in range(len(x))), y, fs, axes)

def addlabels_initial_ax(x,y,fs, this_axis, axes=None):
    """Like ``addlabels_initial`` but on subplot ``this_axis`` of the axes collection."""
    _annotate_bar_heights((i-(0.25) for i in range(len(x))), y, fs, axes, this_axis)

def addlabels_prediction_ax(x,y,fs, this_axis, axes=None):
    """Like ``addlabels_prediction`` but on subplot ``this_axis`` of the axes collection."""
    _annotate_bar_heights(range(len(x)), y, fs, axes, this_axis)

def addlabels_truemarks_ax(x,y,fs, this_axis, axes=None):
    """Like ``addlabels_truemarks`` but on subplot ``this_axis`` of the axes collection."""
    _annotate_bar_heights((i+(0.25) for i in range(len(x))), y, fs, axes, this_axis)

# Load Data Table of Colony Info

In [None]:
# Location of the pipeline folder on the mounted Google Drive and the merged
# per-colony table stored inside it.
sector_project_folder = '/content/gdrive/My Drive/Colab Notebooks/psi-sectored-classification/Pipeline'
merged_table = pd.read_csv(sector_project_folder + '/Pub Data/2021_07_01_merged_table.csv')

# A colony counts as correct when the predicted and annotated sector counts agree.
merged_table['Correct # Sectors'] = np.abs(merged_table['Pred # Sectors'] - merged_table['True # Sectors']) == 0
# The filters below use the column name with a trailing '?'. Provide it from the
# flag computed above when the CSV does not already contain it, so the lookup
# cannot fail with a KeyError; an existing column in the CSV is left untouched.
if 'Correct # Sectors?' not in merged_table.columns:
    merged_table['Correct # Sectors?'] = merged_table['Correct # Sectors']

correct_entries = merged_table[merged_table['Correct # Sectors?'] == True]
correct_entries_1 = correct_entries[correct_entries['Set'] == 1]
correct_entries_2 = correct_entries[correct_entries['Set'] == 2]
# Spot check: correctly counted single-sector colonies on one plate of set 2.
colony_groups = correct_entries[correct_entries['Pred # Sectors'] == 1]
print(colony_groups[(colony_groups['Set'] == 2) & (colony_groups['Plate Name'] == 'Plate_4.jpg')])
#incorrect_entries = merged_table[merged_table['Correct # Sectors?'] == False]
#colony_groups = incorrect_entries[incorrect_entries['True # Sectors'] == 2]
#print(colony_groups[(colony_groups['Set'] == 2)])
print(merged_table.columns)

In [None]:
# Show the merged table as the cell output (rich DataFrame rendering).
merged_table

# Place labels on every detection

In [None]:
# [PSI+]: 0 red regions
# [psi-]: 0 white regions
# Sx (Sectored x): At least 1 white region and exactly x red regions
# NA: Not quantifiable

# One label per row of merged_table for each of the three sources: the initial
# prediction ("before"), the corrected prediction ("after") and the manual
# annotation ("true"). Every entry starts as 'UNFILLED' and is overwritten below.
full_colony_states_before = np.full(len(merged_table), 'UNFILLED')
full_colony_states_after = np.full(len(merged_table), 'UNFILLED')
full_colony_states_true = np.full(len(merged_table), 'UNFILLED')
#colony_states_set = set(colony_states)
#print(colony_states_set)

max_sector_count_before = max(merged_table['Initial # Regions'])
max_sector_count_after = max(merged_table['Pred # Sectors'])
max_sector_count_true = max(merged_table['True # Sectors'])

# The Sx loop must reach the largest count of ALL three columns; otherwise
# colonies whose annotated count exceeds every predicted count would keep the
# 'UNFILLED' placeholder. nanmax + int() keeps range() usable even if a column
# is stored as float or contains missing values.
max_sector_count_all = int(np.nanmax([
    np.nanmax(merged_table['Initial # Regions']),
    np.nanmax(merged_table['Pred # Sectors']),
    np.nanmax(merged_table['True # Sectors']),
]))

# NA: Get all non-quantifiable colonies

#colony_states_before[merged_table['Quantifiable'] == False] = 'NA'
#colony_states_after[merged_table['Quantifiable'] == False] = 'NA'
#colony_states_true[merged_table['Quantifiable'] == False] = 'NA'

# [PSI+]: Get all quantifiable colonies with no red regions

full_colony_states_before[(merged_table['(BC) Stable'] == True)] = '[PSI+]'
full_colony_states_after[(merged_table['(AC) Stable'] == True)] = '[PSI+]'
full_colony_states_true[(merged_table['Quantifiable'] == True) & (merged_table['Quantifiable Stable'] == True)] = '[PSI+]'

# [psi-]: Get all quantifiable colonies with no white regions

full_colony_states_before[(merged_table['(BC) Cured'] == True)] = '[psi-]'
full_colony_states_after[(merged_table['(AC) Cured'] == True)] = '[psi-]'
full_colony_states_true[(merged_table['Quantifiable'] == True) & (merged_table['Quantifiable Cured'] == True)] = '[psi-]'

# Sx: Get all quantifiable colonies with at least 1 white region and exactly x red regions

for num_regions in range(1, max_sector_count_all+1):
    sector_label = 'S' + str(num_regions)
    full_colony_states_before[(merged_table['(BC) Cured'] == False) & (merged_table['(BC) Stable'] == False) & (merged_table['Initial # Regions'] == num_regions)] = sector_label
    full_colony_states_after[(merged_table['(AC) Cured'] == False) & (merged_table['(AC) Stable'] == False) & (merged_table['Pred # Sectors'] == num_regions)] = sector_label
    full_colony_states_true[(merged_table['Quantifiable'] == True) & (merged_table['Quantifiable Sectored'] == True) & (merged_table['True # Sectors'] == num_regions)] = sector_label

# Sanity check: list the distinct labels assigned for each source.
print(np.unique(full_colony_states_before))
print(np.unique(full_colony_states_after))
print(np.unique(full_colony_states_true))

#unmarked_locations = np.where(colony_states_true == 'UNFILLED')

# Make corrections to the table for unfilled locations


# Display any colony locations that are marked as UNFILLED

# colony_row = merged_table.iloc[unmarked_locations]
# print(colony_row)
# print(colony_row.index)

# If every location has been filled, then add these to the merged table
merged_table['Label Before'] = full_colony_states_before
merged_table['Label After'] = full_colony_states_after
merged_table['Label True'] = full_colony_states_true
# counter = 0

# for ind in colony_row.index:
#     colony_number = colony_row['Colony Number'].iloc[counter]
#     plate_name = colony_row['Plate Name'].iloc[counter]
#     set_number = colony_row['Set'].iloc[counter]

#     #if counter == 0:
#     #    merged_table['Quantifiable Stable'] = 

#     # Get image
#     if set_number == 2:
#         image_to_display = read_image(sector_project_folder + '/Real Images/Wes Plates/Set 2 Prepro/' + plate_name)*255
#     image_to_display = cv2.rectangle(image_to_display, (colony_row['Side Left'].iloc[counter], colony_row['Side Top'].iloc[counter]), (colony_row['Side Right'].iloc[counter], colony_row['Side Bottom'].iloc[counter]), (255, 0, 0), 2)
#     #cv2_imshow(image_to_display)


#     counter += 1

# Get Annotated Colony Data Only

## Colony States (without sector counts)

In [None]:
# Bar chart of the manually annotated state ([PSI+], [psi-], sectored) of every
# quantifiable colony.
quantifiable_colony_data = merged_table[merged_table['Quantifiable'] == True]
print('Number of quantifiable colonies:', str(len(quantifiable_colony_data)))

max_true = np.max(quantifiable_colony_data['True # Sectors'])

plot_labels = ['[PSI+]', '[psi-]', 'Sectored']

# Split the annotated labels into the three coarse states; every 'S<k>' label
# is pooled into the single 'Sectored' group.
true_labels = quantifiable_colony_data['Label True']
white_colony_data_true = quantifiable_colony_data[true_labels == '[PSI+]']
red_colony_data_true = quantifiable_colony_data[true_labels == '[psi-]']
sector_colony_data_true = quantifiable_colony_data[true_labels.str.startswith('S')]

true_correct_counts = [
    len(white_colony_data_true),
    len(red_colony_data_true),
    len(sector_colony_data_true),
]

x = np.arange(3)
width = 0.25

fig, ax = plt.subplots(figsize=(15,5), sharey=True)

rects1 = ax.bar(x, true_correct_counts, width, color='green', label='Manual Counts')

# Leave 10% headroom above the tallest bar for the count labels.
ax.set_ylim(bottom=0, top=(1.1*max(true_correct_counts)))

ax.set_title(f'Annotated States of Quantifiable Colonies (N={np.sum(true_correct_counts)})', fontfamily="serif", fontsize=18)
ax.set_xlabel('Colony States', fontsize=14)
ax.set_ylabel('Frequency', fontsize=14)

# Dashed separators between the three state groups.
for boundary in (0.5, 1.5):
    ax.axvline(x = boundary, color = 'k', linestyle = '--')

ax.set_xticks(x)
ax.set_xticklabels(plot_labels)
ax.tick_params(axis="both", labelsize=12)

addlabels_centered(x, true_correct_counts, 10)

plt.show()

## Colony states (with sector counts)

In [None]:
# Bar chart of the manually annotated state of every quantifiable colony, with
# the sectored colonies broken down by their number of sectors.
quantifiable_colony_data = merged_table[merged_table['Quantifiable'] == True]
print('Number of quantifiable colonies:', str(len(quantifiable_colony_data)))

max_true = np.max(quantifiable_colony_data['True # Sectors'])

plot_labels = ['[PSI+]', '[psi-]']

# Pure colonies first ...
true_labels = quantifiable_colony_data['Label True']
white_colony_data_true = quantifiable_colony_data[true_labels == '[PSI+]']
red_colony_data_true = quantifiable_colony_data[true_labels == '[psi-]']

true_correct_counts = [len(white_colony_data_true), len(red_colony_data_true)]

# ... then one bar per sector count S1 .. S(max_true + 1).
for j in range(1, max_true+2):
    sector_name = f'S{j}'
    sector_colony_data_true = quantifiable_colony_data[true_labels == sector_name]
    true_correct_counts.append(len(sector_colony_data_true))
    plot_labels.append(sector_name)

x = np.arange(max_true+3)
width = 0.25

fig, ax = plt.subplots(figsize=(15,5), sharey=True)

rects1 = ax.bar(x, true_correct_counts, width, color='green', label='Manual Counts')

# Leave 10% headroom above the tallest bar for the count labels.
ax.set_ylim(bottom=0, top=(1.1*max(true_correct_counts)))

ax.set_title(f'Annotated States of Quantifiable Colonies (N={np.sum(true_correct_counts)})', fontfamily="serif", fontsize=18)
ax.set_xlabel('Colony States', fontsize=14)
ax.set_ylabel('Frequency', fontsize=14)

# Dashed separators after the [PSI+] and [psi-] bars.
for boundary in (0.5, 1.5):
    ax.axvline(x = boundary, color = 'k', linestyle = '--')

ax.set_xticks(x)
ax.set_xticklabels(plot_labels)
ax.tick_params(axis="both", labelsize=12)

addlabels_centered(x, true_correct_counts, 10)

plt.show()

# Plot Predictions in General

## Colony States (without sector counts)

In [None]:
# Grouped bar chart of the predicted state of every detected colony, before and
# after the purity correction, with all sectored colonies pooled together.
plot_labels = ['[PSI+]', '[psi-]', 'Sectored']

print(np.unique(merged_table['Label Before']))

label_before = merged_table['Label Before']
label_after = merged_table['Label After']

# get the [PSI+] colonies
white_colony_data_before = merged_table[label_before == '[PSI+]']
white_colony_data_after = merged_table[label_after == '[PSI+]']

# get the [psi-] colonies
red_colony_data_before = merged_table[label_before == '[psi-]']
red_colony_data_after = merged_table[label_after == '[psi-]']

# get the sectored colonies (any 'S<k>' label)
max_sector_counts = np.nanmax( [ np.nanmax(merged_table['Initial # Regions']), np.nanmax(merged_table['Pred # Sectors']), np.nanmax(merged_table['True # Sectors'])]).astype(int)
sector_colony_data_before = merged_table[label_before.str.startswith('S')]
sector_colony_data_after = merged_table[label_after.str.startswith('S')]

post_counts_before = [
    len(white_colony_data_before),
    len(red_colony_data_before),
    len(sector_colony_data_before),
]
post_counts_after = [
    len(white_colony_data_after),
    len(red_colony_data_after),
    len(sector_colony_data_after),
]

x = np.arange(3)
width = 0.25

fig, ax = plt.subplots(figsize=(15,5), sharey=True)

# The two series sit side by side, half a bar width either side of each tick.
rects1 = ax.bar(x-width/2, post_counts_before, width, color='blue', label='Original Prediction')
rects2 = ax.bar(x+width/2, post_counts_after, width, color='red', label='With Purity Correction')

# Leave 10% headroom above the tallest bar for the count labels.
ax.set_ylim(bottom=0, top=(1.1*max(post_counts_before + post_counts_after)))

ax.set_title(f'Predicted States of Detected Colonies (N={np.sum(post_counts_before)})', fontfamily="serif", fontsize=18)
ax.set_xlabel('Colony States', fontsize=14)
ax.set_ylabel('Frequency', fontsize=14)

# Dashed separators between the three state groups.
for boundary in (0.5, 1.5):
    ax.axvline(x = boundary, color = 'k', linestyle = '--')

ax.set_xticks(x)
ax.set_xticklabels(plot_labels)
ax.tick_params(axis="both", labelsize=12)
ax.legend()

addlabels_centered(x-width/2, post_counts_before, 10)
addlabels_centered(x+width/2, post_counts_after, 10)

plt.show()

## Colony States (with sector counts)

In [None]:
# Grouped bar chart of the predicted state of every detected colony, before and
# after the purity correction, with sectored colonies split by sector count.
post_counts_before = []
post_counts_after = []

plot_labels = ['[PSI+]', '[psi-]']

print(np.unique(merged_table['Label Before']))

# get the [PSI+] colonies
white_colony_data_before = merged_table[merged_table['Label Before'] == '[PSI+]']
white_colony_data_after = merged_table[merged_table['Label After'] == '[PSI+]']

post_counts_before.append(len(white_colony_data_before))
post_counts_after.append(len(white_colony_data_after))

# get the [psi-] colonies
red_colony_data_before = merged_table[merged_table['Label Before'] == '[psi-]']
red_colony_data_after = merged_table[merged_table['Label After'] == '[psi-]']

post_counts_before.append(len(red_colony_data_before))
post_counts_after.append(len(red_colony_data_after))

# get the sectored colonies: one bar group per sector count S1 .. S<max>
max_sector_counts = np.nanmax( [ np.nanmax(merged_table['Initial # Regions']), np.nanmax(merged_table['Pred # Sectors']), np.nanmax(merged_table['True # Sectors'])]).astype(int)
for j in range(1, max_sector_counts+1):
    sector_colony_data_before = merged_table[merged_table['Label Before'] == 'S' + str(j)]
    sector_colony_data_after = merged_table[merged_table['Label After'] == 'S' + str(j)]

    post_counts_before.append(len(sector_colony_data_before))
    post_counts_after.append(len(sector_colony_data_after))

    plot_labels.append('S'+str(j))

# One x position per bar group. Deriving it from the labels keeps x, the count
# lists and the tick labels the same length for any number of sector counts
# (a fixed minimum of positions breaks ax.bar when fewer groups exist).
x = np.arange(len(plot_labels))
width = 0.25

fig, ax = plt.subplots(figsize=(15,5), sharey=True)

# The two series sit side by side, half a bar width either side of each tick.
rects1 = ax.bar(x-width/2, post_counts_before, width, color='blue', label='Original Prediction')
rects2 = ax.bar(x+width/2, post_counts_after, width, color='red', label='With Purity Correction')

# Leave 10% headroom above the tallest bar for the count labels.
ax.set_ylim(bottom=0, top=(1.1*max(post_counts_before + post_counts_after)))

ax.set_title('Predicted States of Detected Colonies (N=' + str(np.sum(post_counts_before)) + ')',fontfamily="serif", fontsize=18)
ax.set_xlabel('Colony States', fontsize=14)
ax.set_ylabel('Frequency', fontsize=14)

# Dashed separators after the [PSI+] and [psi-] groups.
ax.axvline(x = 0.5, color = 'k', linestyle = '--')
ax.axvline(x = 1.5, color = 'k', linestyle = '--')

ax.set_xticks(x)
ax.set_xticklabels(plot_labels)
ax.tick_params(axis="both", labelsize=12)
ax.legend()

addlabels_centered(x-width/2, post_counts_before, 10)
addlabels_centered(x+width/2, post_counts_after, 10)

plt.show()

## Sector Sizes for Colonies with One Sector

In [None]:
# Histogram of the fraction of colony area covered by the red sector, for the
# colonies labelled 'S1' after correction.
single_sector_colony_data_after = merged_table[merged_table['Label After'] == 'S1']
single_sector_data = single_sector_colony_data_after['Red Area (Corr)'] / single_sector_colony_data_after['Colony Area (Corr)']

prop_size_mean = np.mean(single_sector_data)
prop_size_med = np.median(single_sector_data)

fig,ax = plt.subplots()

ax.hist(single_sector_data, label='Proportions')
ax.set_title(f'Single Sector Colonies: Proportion of Sector\nComprising Colony (N={len(single_sector_data)})')
ax.set_xlabel('Proportion of Colony with Sector')
ax.set_ylabel('Frequency')
ax.set_xlim([0, 1])

# Mark mean and median with vertical lines spanning the current y-range. The
# limits are re-read before each line because drawing may change them.
for stat_value, stat_color, stat_name in (
    (prop_size_mean, 'black', 'Mean'),
    (prop_size_med, 'red', 'Median'),
):
    y_low, y_high = ax.get_ylim()
    ax.vlines(stat_value, y_low, y_high, color=stat_color, label=stat_name)

ax.tick_params(axis="both", labelsize=12)
ax.legend()
plt.show()

# Plot Quantifiable Colony Data with True Counts

## Colony States (without sector counts)

In [None]:
# Grouped bar chart comparing, for the quantifiable colonies, the original
# prediction, the purity-corrected prediction and the manual annotation, with
# all sectored colonies pooled together.
quantifiable_colony_data = merged_table[merged_table['Quantifiable'] == True].reset_index()
print('Number of quantifiable colonies:', str(len(quantifiable_colony_data)))

max_true = np.max(quantifiable_colony_data['True # Sectors'])

post_counts_before = []
post_counts_after = []
true_correct_counts = []

plot_labels = ['[PSI+]', '[psi-]', 'Sectored']

# get the [PSI+] and [psi-] colonies
white_colony_data_before = quantifiable_colony_data[quantifiable_colony_data['Label Before'] == '[PSI+]']
white_colony_data_after = quantifiable_colony_data[quantifiable_colony_data['Label After'] == '[PSI+]']
white_colony_data_true = quantifiable_colony_data[quantifiable_colony_data['Label True'] == '[PSI+]']

red_colony_data_before = quantifiable_colony_data[quantifiable_colony_data['Label Before'] == '[psi-]']
red_colony_data_after = quantifiable_colony_data[quantifiable_colony_data['Label After'] == '[psi-]']
red_colony_data_true = quantifiable_colony_data[quantifiable_colony_data['Label True'] == '[psi-]']

post_counts_before.append(len(white_colony_data_before))
post_counts_after.append(len(white_colony_data_after))

post_counts_before.append(len(red_colony_data_before))
post_counts_after.append(len(red_colony_data_after))

true_correct_counts.append(len(white_colony_data_true))
true_correct_counts.append(len(red_colony_data_true))

# Get the sectored colonies (any 'S<k>' label)
sector_colony_data_before = quantifiable_colony_data[quantifiable_colony_data['Label Before'].str.startswith('S')]
sector_colony_data_after = quantifiable_colony_data[quantifiable_colony_data['Label After'].str.startswith('S')]
sector_colony_data_true = quantifiable_colony_data[quantifiable_colony_data['Label True'].str.startswith('S')]

post_counts_before.append(len(sector_colony_data_before))
post_counts_after.append(len(sector_colony_data_after))
true_correct_counts.append(len(sector_colony_data_true))

x = np.arange(3)
width = 0.25

print(post_counts_before)
print(post_counts_after)
print(true_correct_counts)

fig, ax = plt.subplots(figsize=(15,5), sharey=True)

# Three half-width bars per group: before / after / manual.
rects1 = ax.bar(x-width/2, post_counts_before, width/2, color='blue', label='Original Prediction')
rects2 = ax.bar(x, post_counts_after, width/2, color='red', label='With Purity Correction')
rects3 = ax.bar(x+width/2, true_correct_counts, width/2, label='Manual Counts', color='green')

# Leave 10% headroom above the tallest bar for the count labels.
ax.set_ylim(bottom=0, top=(1.1*max(post_counts_before + post_counts_after + true_correct_counts)))

ax.set_title('Annotated States of Quantifiable Colonies (N=' + str(np.sum(true_correct_counts)) + ')',fontfamily="serif", fontsize=18)
ax.set_xlabel('Colony States', fontsize=14)
ax.set_ylabel('Frequency', fontsize=14)

# Dashed separators between the three state groups.
ax.axvline(x = 0.5, color = 'k', linestyle = '--')
ax.axvline(x = 1.5, color = 'k', linestyle = '--')

ax.set_xticks(x)
ax.set_xticklabels(plot_labels)
ax.tick_params(axis="both", labelsize=12)
# Three colour-coded series are drawn, so the legend is needed to tell them apart.
ax.legend()

addlabels_centered(x-width/2, post_counts_before, 10)
addlabels_centered(x, post_counts_after, 10)
addlabels_centered(x+width/2, true_correct_counts, 10)

plt.show()

## Colony States (with sector counts)

In [None]:
# Grouped bar chart comparing, for the quantifiable colonies, the original
# prediction, the purity-corrected prediction and the manual annotation, with
# sectored colonies split by sector count.
quantifiable_colony_data = merged_table[merged_table['Quantifiable'] == True].reset_index()
print('Number of quantifiable colonies:', str(len(quantifiable_colony_data)))

max_true = np.max(quantifiable_colony_data['True # Sectors'])

post_counts_before = []
post_counts_after = []
true_correct_counts = []

plot_labels = ['[PSI+]', '[psi-]']

# get the [PSI+] and [psi-] colonies
white_colony_data_before = quantifiable_colony_data[quantifiable_colony_data['Label Before'] == '[PSI+]']
white_colony_data_after = quantifiable_colony_data[quantifiable_colony_data['Label After'] == '[PSI+]']
white_colony_data_true = quantifiable_colony_data[quantifiable_colony_data['Label True'] == '[PSI+]']

red_colony_data_before = quantifiable_colony_data[quantifiable_colony_data['Label Before'] == '[psi-]']
red_colony_data_after = quantifiable_colony_data[quantifiable_colony_data['Label After'] == '[psi-]']
red_colony_data_true = quantifiable_colony_data[quantifiable_colony_data['Label True'] == '[psi-]']

post_counts_before.append(len(white_colony_data_before))
post_counts_after.append(len(white_colony_data_after))

post_counts_before.append(len(red_colony_data_before))
post_counts_after.append(len(red_colony_data_after))

true_correct_counts.append(len(white_colony_data_true))
true_correct_counts.append(len(red_colony_data_true))

# Largest sector count over the three sources. int() is required because
# range() below rejects a float, which is what the maximum becomes as soon as
# one of the columns is stored as float.
max_sector_counts = int(np.nanmax([np.nanmax(quantifiable_colony_data['Initial # Regions']), np.nanmax(quantifiable_colony_data['Pred # Sectors']), np.nanmax(quantifiable_colony_data['True # Sectors'])]))

# One bar group per sector count S1 .. S<max>.
for j in range(1, max_sector_counts+1):
    sector_colony_data_before = quantifiable_colony_data[quantifiable_colony_data['Label Before'] == 'S' + str(j)]
    sector_colony_data_after = quantifiable_colony_data[quantifiable_colony_data['Label After'] == 'S' + str(j)]
    sector_colony_data_true = quantifiable_colony_data[quantifiable_colony_data['Label True'] == 'S' + str(j)]

    post_counts_before.append(len(sector_colony_data_before))
    post_counts_after.append(len(sector_colony_data_after))
    true_correct_counts.append(len(sector_colony_data_true))

    plot_labels.append('S'+str(j))

# One x position per bar group, derived from the labels so that x, the count
# lists and the tick labels always have the same length.
x = np.arange(len(plot_labels))
width = 0.25

print(post_counts_before)
print(post_counts_after)
print(true_correct_counts)

fig, ax = plt.subplots(figsize=(15,5), sharey=True)

# Three full-width bars per group: before / after / manual.
rects1 = ax.bar(x-width, post_counts_before, width, color='blue', label='Original Prediction')
rects2 = ax.bar(x, post_counts_after, width, color='red', label='With Purity Correction')
rects3 = ax.bar(x+width, true_correct_counts, width, label='Manual Counts', color='green')

# Leave 10% headroom above the tallest bar for the count labels.
ax.set_ylim(bottom=0, top=(1.1*max(post_counts_before + post_counts_after + true_correct_counts)))

ax.set_title('Annotated States of Quantifiable Colonies (N=' + str(np.sum(true_correct_counts)) + ')',fontfamily="serif", fontsize=18)
ax.set_xlabel('Colony States', fontsize=14)
ax.set_ylabel('Frequency', fontsize=14)

# Dashed separators after the [PSI+] and [psi-] groups.
ax.axvline(x = 0.5, color = 'k', linestyle = '--')
ax.axvline(x = 1.5, color = 'k', linestyle = '--')

ax.set_xticks(x)
ax.set_xticklabels(plot_labels)
ax.tick_params(axis="both", labelsize=12)
# Three colour-coded series are drawn, so the legend is needed to tell them apart.
ax.legend()

addlabels_centered(x-width, post_counts_before, 10)
addlabels_centered(x, post_counts_after, 10)
addlabels_centered(x+width, true_correct_counts, 10)

plt.show()

# Plot Correct Quantifiable Colony Data

## Colony States (without sector counts)

In [None]:
# Grouped bar chart of how many quantifiable colonies received the same coarse
# state as the manual annotation, before and after the purity correction,
# next to the annotated totals.
quantifiable_colony_data = merged_table[merged_table['Quantifiable'] == True].reset_index()
print('Number of quantifiable colonies:', str(len(quantifiable_colony_data)))

max_true = np.max(quantifiable_colony_data['True # Sectors'])

correct_counts_before = []
correct_counts_after = []
true_correct_counts = []

plot_labels = ['[PSI+]', '[psi-]', 'Sectored']

# get the [PSI+] and [psi-] colonies whose predicted label matches the annotation
correct_white_colony_data_before = quantifiable_colony_data[(quantifiable_colony_data['Label Before'] == '[PSI+]') & (quantifiable_colony_data['Label True'] == '[PSI+]')]
correct_white_colony_data_after = quantifiable_colony_data[(quantifiable_colony_data['Label After'] == '[PSI+]') & (quantifiable_colony_data['Label True'] == '[PSI+]')]
white_colony_data_true = quantifiable_colony_data[quantifiable_colony_data['Label True'] == '[PSI+]']

correct_red_colony_data_before = quantifiable_colony_data[(quantifiable_colony_data['Label Before'] == '[psi-]') & (quantifiable_colony_data['Label True'] == '[psi-]')]
correct_red_colony_data_after = quantifiable_colony_data[(quantifiable_colony_data['Label After'] == '[psi-]') & (quantifiable_colony_data['Label True'] == '[psi-]')]
red_colony_data_true = quantifiable_colony_data[quantifiable_colony_data['Label True'] == '[psi-]']

correct_counts_before.append(len(correct_white_colony_data_before))
correct_counts_after.append(len(correct_white_colony_data_after))

correct_counts_before.append(len(correct_red_colony_data_before))
correct_counts_after.append(len(correct_red_colony_data_after))

true_correct_counts.append(len(white_colony_data_true))
true_correct_counts.append(len(red_colony_data_true))

# Get the sectored colonies: predicted sectored AND annotated sectored,
# regardless of the exact sector count.
correct_sector_colony_data_before = quantifiable_colony_data[(quantifiable_colony_data['Label Before'].str.startswith('S')) & (quantifiable_colony_data['Label True'].str.startswith('S'))]
correct_sector_colony_data_after = quantifiable_colony_data[(quantifiable_colony_data['Label After'].str.startswith('S')) & (quantifiable_colony_data['Label True'].str.startswith('S'))]
sector_colony_data_true = quantifiable_colony_data[quantifiable_colony_data['Label True'].str.startswith('S')]

correct_counts_before.append(len(correct_sector_colony_data_before))
correct_counts_after.append(len(correct_sector_colony_data_after))
true_correct_counts.append(len(sector_colony_data_true))

# Largest annotated sector count among the truly sectored colonies. The maximum
# is taken over the 'True # Sectors' column only; reducing the whole filtered
# DataFrame would mix unrelated (and non-numeric) columns.
true_sectored_mask = (quantifiable_colony_data['Label True'].str.startswith('S')) & (quantifiable_colony_data['True # Sectors'] > 0)
max_sector_counts = np.max(quantifiable_colony_data.loc[true_sectored_mask, 'True # Sectors'])

x = np.arange(3)
width = 0.25

print(correct_counts_before)
print(correct_counts_after)
print(true_correct_counts)

fig, ax = plt.subplots(figsize=(15,5), sharey=True)

# Three half-width bars per group: before / after / manual.
rects1 = ax.bar(x-width/2, correct_counts_before, width/2, color='blue', label='Original Prediction')
rects2 = ax.bar(x, correct_counts_after, width/2, color='red', label='With Purity Correction')
rects3 = ax.bar(x+width/2, true_correct_counts, width/2, label='Manual Counts', color='green')

# Leave 10% headroom above the tallest bar for the count labels.
ax.set_ylim(bottom=0, top=(1.1*max(correct_counts_before + correct_counts_after + true_correct_counts)))

ax.set_title('Correctly Classified of Quantifiable Colonies (N=' + str(np.sum(true_correct_counts)) + ')',fontfamily="serif", fontsize=18)
ax.set_xlabel('Colony States', fontsize=14)
ax.set_ylabel('Frequency', fontsize=14)

# Dashed separators between the three state groups.
ax.axvline(x = 0.5, color = 'k', linestyle = '--')
ax.axvline(x = 1.5, color = 'k', linestyle = '--')

ax.set_xticks(x)
ax.set_xticklabels(plot_labels)
ax.tick_params(axis="both", labelsize=12)
# Three colour-coded series are drawn, so the legend is needed to tell them apart.
ax.legend()

addlabels_centered(x-width/2, correct_counts_before, 10)
addlabels_centered(x, correct_counts_after, 10)
addlabels_centered(x+width/2, true_correct_counts, 10)

plt.show()

## Colony States (with sector counts)

In [None]:
# Count, for each colony state, how many quantifiable colonies were labelled
# correctly before and after purity correction, and plot those counts next to
# the manually labelled totals. Sectored colonies are split by their exact
# label (S1, S2, ...), so a sectored colony only counts as correct when the
# predicted label matches the true label exactly.
quantifiable_colony_data = merged_table[merged_table['Quantifiable'] == True].reset_index()
print('Number of quantifiable colonies:', str(len(quantifiable_colony_data)))

max_true = np.max(quantifiable_colony_data['True # Sectors'])

# Short aliases for the three label columns compared below
label_before = quantifiable_colony_data['Label Before']
label_after = quantifiable_colony_data['Label After']
label_true = quantifiable_colony_data['Label True']

# One entry per bar group, in the same order as plot_labels
correct_counts_before = []
correct_counts_after = []
true_correct_counts = []

plot_labels = ['[PSI+]', '[psi-]']

# get the [PSI+] and [psi-] colonies
is_white_true = label_true == '[PSI+]'
correct_white_colony_data_before = quantifiable_colony_data[(label_before == '[PSI+]') & is_white_true]
correct_white_colony_data_after = quantifiable_colony_data[(label_after == '[PSI+]') & is_white_true]
white_colony_data_true = quantifiable_colony_data[is_white_true]

is_red_true = label_true == '[psi-]'
correct_red_colony_data_before = quantifiable_colony_data[(label_before == '[psi-]') & is_red_true]
correct_red_colony_data_after = quantifiable_colony_data[(label_after == '[psi-]') & is_red_true]
red_colony_data_true = quantifiable_colony_data[is_red_true]

correct_counts_before.append(len(correct_white_colony_data_before))
correct_counts_after.append(len(correct_white_colony_data_after))

correct_counts_before.append(len(correct_red_colony_data_before))
correct_counts_after.append(len(correct_red_colony_data_after))

true_correct_counts.append(len(white_colony_data_true))
true_correct_counts.append(len(red_colony_data_true))

# Get the sectored colonies, one bar group per exact sector label.
# The largest true sector count decides how many S<j> groups are drawn; fall
# back to 0 (no sector groups) when there are no sectored colonies, and cast to
# a plain int so range() accepts it regardless of the column dtype.
sectored_true_counts = quantifiable_colony_data[(label_true.str.startswith('S')) & (quantifiable_colony_data['True # Sectors'] > 0)]['True # Sectors']
max_sector_counts = int(sectored_true_counts.max()) if len(sectored_true_counts) > 0 else 0
print(max_sector_counts)
for j in range(1, max_sector_counts+1):
    sector_label = 'S' + str(j)
    is_sector_true = label_true == sector_label

    correct_sector_colony_data_before = quantifiable_colony_data[(label_before == sector_label) & is_sector_true]
    correct_sector_colony_data_after = quantifiable_colony_data[(label_after == sector_label) & is_sector_true]
    sector_colony_data_true = quantifiable_colony_data[is_sector_true]

    correct_counts_before.append(len(correct_sector_colony_data_before))
    correct_counts_after.append(len(correct_sector_colony_data_after))
    true_correct_counts.append(len(sector_colony_data_true))

    plot_labels.append(sector_label)

# One x position per bar group; derived from the labels so that it always
# matches the length of the count lists.
x = np.arange(len(plot_labels))
width = 0.25

print(correct_counts_before)
print(correct_counts_after)
print(true_correct_counts)

# Tick labels rendered with mathtext superscripts for the two non-sectored states
formatted_plot_labels = copy.deepcopy(plot_labels)
formatted_plot_labels[0] = r'$[PSI^+]$'
formatted_plot_labels[1] = r'$[psi^-]$'

fig, ax = plt.subplots(figsize=(15,5), sharey=True)

# Three side-by-side bars per state: before correction, after correction, manual total
rects1 = ax.bar(x-width, correct_counts_before, width, color='blue', label='Original Prediction')
rects2 = ax.bar(x, correct_counts_after, width, color='red', label='With Purity Correction')
rects3 = ax.bar(x+width, true_correct_counts, width, label='Manual Counts', color='green')

# Leave 10% headroom above the tallest bar for the value labels
ax.set_ylim(bottom=0, top=(1.1*max(correct_counts_before + correct_counts_after + true_correct_counts)))

ax.set_title('Correctly Classified Quantifiable Colonies (N=' + str(np.sum(true_correct_counts)) + ')',fontfamily="sans-serif", fontsize=18)
ax.set_xlabel('Colony States', fontsize=16)
ax.set_ylabel('Frequency', fontsize=16)

ax.legend(loc='best', fontsize=12)

# Dashed separators between the [PSI+], [psi-] and sectored groups
ax.axvline(x = 0.5, color = 'k', linestyle = '--')
ax.axvline(x = 1.5, color = 'k', linestyle = '--')

ax.set_xticks(x)
ax.set_xticklabels(formatted_plot_labels)
ax.tick_params(axis="both", labelsize=14)

# addlabels_centered is defined elsewhere in this notebook; it is called once
# per bar series with the bar positions, the bar heights and a third numeric
# argument (presumably a font size - confirm against its definition).
addlabels_centered(x-width, correct_counts_before, 11)
addlabels_centered(x, correct_counts_after, 11)
addlabels_centered(x+width, true_correct_counts, 11)

plt.show()

## Confusion Matrices

In [None]:
# Build confusion matrices of manual ('Label True') versus predicted labels,
# before and after purity correction, and draw them four ways: raw counts,
# row-normalized percentages, counts plus percentages, and counts coloured by
# row-normalized proportion.

# Get labels
these_labels = ['[PSI+]', '[psi-]', 'S1' ,'S2', 'S3', 'S4', 'S5', 'S6' ,'S7', 'S8']

# Only the first few states ([PSI+], [psi-], S1-S3) are shown in the heatmaps
num_states_shown = 5

# Generate confusion matrices with counts (rows: true label, columns: predicted label)
conf_mat_before = sklearn.metrics.confusion_matrix(quantifiable_colony_data['Label True'], quantifiable_colony_data['Label Before'], labels=these_labels)
conf_mat_after = sklearn.metrics.confusion_matrix(quantifiable_colony_data['Label True'], quantifiable_colony_data['Label After'], labels=these_labels)

conf_mat_before_max = np.max(conf_mat_before)
conf_mat_after_max = np.max(conf_mat_after)

# Shared colour-scale maximum so both count heatmaps are directly comparable
conf_mat_max = np.max([conf_mat_before_max, conf_mat_after_max])


def _row_normalize(count_mat):
    """Return count_mat with every row divided by its row sum (all-zero rows stay zero)."""
    row_sums = count_mat.sum(axis=1, keepdims=True).astype(float)
    return np.divide(count_mat, row_sums, out=np.zeros(count_mat.shape, dtype=float), where=row_sums != 0)


def _annotate_counts_with_percent(count_mat, percent_mat):
    """Return a string matrix whose cells read '<count>\\n(<percent>%)'."""
    count_part = np.char.add(count_mat.astype(str), '\n(')
    percent_part = np.char.add(percent_mat.astype(str), '%)')
    return np.char.add(count_part, percent_part)


def _label_heatmap_axes(axis, tick_labels, title, label_size, title_size, show_ylabel=True):
    """Apply the axis labels, title and state tick labels shared by every heatmap panel."""
    axis.set_xlabel('Predicted States', fontsize=label_size)
    if show_ylabel:
        axis.set_ylabel('True States', fontsize=label_size)
    axis.set_title(title, fontsize=title_size)
    axis.set_xticklabels(tick_labels, fontsize=label_size)
    axis.set_yticklabels(tick_labels, rotation=0, fontsize=label_size)


def _percent_colorbar(axis):
    """Relabel the colorbar of the heatmap drawn on `axis` as 0%-100% and return it."""
    panel_cbar = axis.collections[0].colorbar
    panel_cbar.set_ticks([0, .2, .4, .6, .8, 1])
    panel_cbar.set_ticklabels(['0%', '20%', '40%', '60%', '80%', '100%'])
    return panel_cbar


# Get confusion matrices normalized by row
norm_mat_before = _row_normalize(conf_mat_before)
norm_mat_after = _row_normalize(conf_mat_after)

# Create matrices used for cell annotation: count on the first line, row
# percentage (one decimal) in parentheses on the second line
norm_mat_before_rounded = np.around(100*norm_mat_before, decimals=1)
norm_mat_after_rounded = np.around(100*norm_mat_after, decimals=1)

annotation_mat_before = _annotate_counts_with_percent(conf_mat_before, norm_mat_before_rounded)
annotation_mat_after = _annotate_counts_with_percent(conf_mat_after, norm_mat_after_rounded)

# Tick labels rendered with mathtext superscripts for the two non-sectored states
these_formatted_labels = copy.deepcopy(these_labels)
these_formatted_labels[0] = r'$[PSI^+]$'
these_formatted_labels[1] = r'$[psi^-]$'

shown_labels = these_formatted_labels[:num_states_shown]
panel_titles = ['Original Predictions', 'With Purity Correction']
count_panels = [conf_mat_before, conf_mat_after]
norm_panels = [norm_mat_before, norm_mat_after]
annotation_panels = [annotation_mat_before, annotation_mat_after]


# Count matrices plots

fig, ax = plt.subplots(1,2, figsize=(15,6))
for axis, count_mat, panel_title in zip(ax, count_panels, panel_titles):
    sns.heatmap(count_mat[:num_states_shown, :num_states_shown], ax=axis, annot=True, vmin=0, vmax=conf_mat_max, fmt='.0f', annot_kws={'fontsize':14})
    _label_heatmap_axes(axis, shown_labels, panel_title, 14, 16)
plt.show()


# Normalized matrices plots

fig, ax = plt.subplots(1,2, figsize=(15,6))
for axis, norm_mat, panel_title in zip(ax, norm_panels, panel_titles):
    sns.heatmap(norm_mat[:num_states_shown, :num_states_shown], ax=axis, annot=True, fmt='.0%', vmin=0, vmax=1, annot_kws={'fontsize':14})
    _label_heatmap_axes(axis, shown_labels, panel_title, 14, 16)
    # Format the colorbar belonging to THIS panel (each heatmap has its own)
    cbar = _percent_colorbar(axis)
plt.show()

# plot confusion matrices annotated with both numbers and percentages in each cell

fig, ax = plt.subplots(1,2, figsize=(15,6))
for axis, norm_mat, annotation_mat, panel_title in zip(ax, norm_panels, annotation_panels, panel_titles):
    sns.heatmap(norm_mat[:num_states_shown, :num_states_shown], ax=axis, annot=annotation_mat[:num_states_shown, :num_states_shown], vmin=0, vmax=1, fmt='s', annot_kws={'fontsize':16})
    _label_heatmap_axes(axis, shown_labels, panel_title, 14, 16)
plt.show()

# Normalized matrices plots, color coded by proportion in each row, but annotated with total

fig, ax = plt.subplots(1,2, figsize=(16,6))
for panel_idx, (norm_mat, count_mat, panel_title) in enumerate(zip(norm_panels, count_panels, panel_titles)):
    is_right_panel = panel_idx == 1
    sns.heatmap(norm_mat[:num_states_shown, :num_states_shown], ax=ax[panel_idx], annot=count_mat[:num_states_shown, :num_states_shown], fmt='.0f', vmin=0, vmax=1, annot_kws={'fontsize':16})
    # The y-axis label is only drawn on the left panel
    _label_heatmap_axes(ax[panel_idx], shown_labels, panel_title, 16, 18, show_ylabel=not is_right_panel)
    cbar = _percent_colorbar(ax[panel_idx])
    # The colorbar caption is only drawn on the right panel
    if is_right_panel:
        cbar.ax.set_ylabel('Proportion of Colonies by Manual Label',  size = 14)
    cbar.ax.tick_params(labelsize=14)
plt.show()

# Purity Score Comparison

In [None]:
# Summarize how purity correction changed each quantifiable colony's predicted
# label (left panel, stacked by true state) and compare the weighted purity
# score before and after correction (right panel).
quantifiable_merged_table = merged_table[merged_table['Quantifiable'] == True].reset_index()

# Set matrix used for plotting the bars.
# Rows (true state): [PSI+], [psi-], 1 sector, 2 sectors, 3+ sectors.
# Columns (outcome): remained correct, remained incorrect, became correct, became incorrect.
agg_count_mat = np.zeros((5,4), dtype=int)
table_length = len(quantifiable_merged_table)

# Column of agg_count_mat for each (correct before, correct after) combination
outcome_column = {
    (True, True): 0,
    (False, False): 1,
    (False, True): 2,
    (True, False): 3,
}


def _true_state_row(label_true):
    """Return the agg_count_mat row for a true label, or None when the label is not tallied."""
    if label_true == '[PSI+]':
        return 0
    if label_true == '[psi-]':
        return 1
    if label_true.startswith('S'):
        # Sectored labels carry the sector count after the leading 'S'
        num_sectors_true = int(label_true[1:])
        if num_sectors_true == 1:
            return 2
        if num_sectors_true == 2:
            return 3
        if num_sectors_true > 2:
            return 4
    return None


label_columns = zip(
    quantifiable_merged_table['Label Before'],
    quantifiable_merged_table['Label After'],
    quantifiable_merged_table['Label True'],
)
for this_label_before, this_label_after, this_label_true in label_columns:
    state_row = _true_state_row(this_label_true)
    if state_row is None:
        continue

    # A prediction is correct only when it matches the true label exactly
    correct_label_before = bool(this_label_before == this_label_true)
    correct_label_after = bool(this_label_after == this_label_true)

    agg_count_mat[state_row, outcome_column[(correct_label_before, correct_label_after)]] += 1


print(agg_count_mat)

category_list = [
    'Remained\nCorrect',
    'Remained\nIncorrect',
    'Became\nCorrect',
    'Became\nIncorrect',
]

x = np.arange(len(category_list))  # the label locations
width = 0.25  # the width of the bars

fig,ax = plt.subplots(1,2, figsize=(13,5))

# Left panel: one stacked bar per outcome, stacked by true state in row order
state_legend_labels = [r'$[PSI^+]$', r'$[psi^-]$', '1 Sector', '2 Sectors', '3+ Sectors']
stack_bottom = np.zeros(agg_count_mat.shape[1], dtype=int)
for state_counts, state_label in zip(agg_count_mat, state_legend_labels):
    ax[0].bar(x, state_counts, width, bottom=stack_bottom, label=state_label)
    stack_bottom = stack_bottom + state_counts

ax[0].set_xlabel('Prediction After Purity Correction')
ax[0].set_ylabel('Frequency')
ax[0].set_title('Influence of Purity Correction on\nPredicted Labels (N='+ str(table_length) + ')', fontsize=16)
ax[0].xaxis.label.set_fontsize(14)
ax[0].yaxis.label.set_fontsize(14)
ax[0].set_xticks(np.arange(0, 4, step=1))
ax[0].set_xticklabels(category_list, fontfamily="serif", fontsize=12)
ax[0].tick_params(axis='both', labelsize=12)

# addlabels_prediction_ax is defined elsewhere in this notebook; it receives the
# bar positions and the per-outcome totals (the meaning of the last two
# arguments should be confirmed against its definition).
addlabels_prediction_ax(x, np.sum(agg_count_mat, axis=0), 9, 0)
ax[0].legend()

# Do the same for the weighted purity scores: split the table by true state
quantitfiable_white = quantifiable_merged_table[quantifiable_merged_table['Label True'] == '[PSI+]']
quantitfiable_red = quantifiable_merged_table[quantifiable_merged_table['Label True'] == '[psi-]']
quantitfiable_sectored_1 = quantifiable_merged_table[quantifiable_merged_table['Label True'] == 'S1']
quantitfiable_sectored_2 = quantifiable_merged_table[quantifiable_merged_table['Label True'] == 'S2']
quantitfiable_sectored_3 = quantifiable_merged_table[(quantifiable_merged_table['Label True'].str.startswith('S')) & (quantifiable_merged_table['True # Sectors'] > 2)]

# Right panel: reference lines first (diagonal = unchanged score, plus a
# vertical segment at 0.5), then one scatter series per true state.
ax[1].plot([0, 1], [0,1], color='black')
ax[1].plot([0.5, 0.5], [0,0.5], color='black')
ax[1].set_axisbelow(True)

# NOTE(review): '(BC)' / '(AC)' are presumably 'before correction' / 'after
# correction', matching the axis labels below - confirm where the columns are built.
purity_series = [
    (quantitfiable_white, r'$[PSI^+]$'),
    (quantitfiable_red, r'$[psi^-]$'),
    (quantitfiable_sectored_1, '1 Sector'),
    (quantitfiable_sectored_2, '2 Sectors'),
    (quantitfiable_sectored_3, '3+ Sectors'),
]
for state_table, state_label in purity_series:
    ax[1].scatter(state_table['(BC) Weighted Full Average Score'], state_table['(AC) Weighted Full Average Score'], label=state_label, alpha=0.5)

ax[1].set_title('Weighted Colony Purity\nBefore and After Correction', fontsize=16)
ax[1].set_xlim(left=0, right=1)
ax[1].set_ylim(bottom=0, top=1)
ax[1].set_xlabel('Purity Before')
ax[1].set_ylabel('Purity After')
ax[1].xaxis.label.set_fontsize(14)
ax[1].yaxis.label.set_fontsize(14)
ax[1].tick_params(axis='both', labelsize=12)
ax[1].legend(loc='best')
plt.show()

# Precision and Recall of Classified Colonies

In [None]:
# Precision is TP/(TP+FP), while recall is TP/(TP+FN).
# True positives (TP) are colonies which were predicted to have a given label and were manually annotated with that label.
# False positives (FP) are colonies which were predicted to have a given label but were manually annotated with a different label.
# False negatives (FN) are colonies which were predicted to have any label besides a given label, but were manually annotated with that given label.
# True negatives (TN) are colonies which were predicted to have any label besides a given label, and have a manual annotation different from that same given label.

# Precision:
#   TP / (TP + FP)
#   Quantifiable colonies / (Quantifiable colonies + Detected Non-quantifiable colonies)
#   What proportion of colonies detected were quantifiable?

# Recall:
#   TP / (TP + FN)
#   Quantifiable colonies / (Quantifiable colonies + Undetected quantifiable colonies)
#   What proportion of quantifiable colonies were detected?


def _confusion_counts(pred_pos, true_pos, pred_neg=None, true_neg=None):
    """Count (TP, TN, FP, FN) for one class from boolean row masks.

    Parameters
    ----------
    pred_pos, true_pos : boolean Series
        Rows predicted / manually annotated as belonging to the class.
    pred_neg, true_neg : boolean Series, optional
        Rows predicted / annotated as NOT belonging to the class. When omitted,
        the element-wise negation of the corresponding positive mask is used.

    Returns
    -------
    tuple of int
        (true_positive, true_negative, false_positive, false_negative).
    """
    if pred_neg is None:
        pred_neg = ~pred_pos
    if true_neg is None:
        true_neg = ~true_pos
    true_positive = int((pred_pos & true_pos).sum())
    true_negative = int((pred_neg & true_neg).sum())
    false_positive = int((pred_pos & true_neg).sum())
    false_negative = int((pred_neg & true_pos).sum())
    return true_positive, true_negative, false_positive, false_negative


def _precision_recall_f1(true_positive, false_positive, false_negative):
    """Return (precision, recall, F1) from confusion counts.

    A ratio whose denominator is zero (no predicted or no true members of the
    class) is reported as NaN instead of raising ZeroDivisionError.
    """
    undefined = float('nan')
    predicted_total = true_positive + false_positive
    actual_total = true_positive + false_negative
    precision = true_positive / predicted_total if predicted_total > 0 else undefined
    recall = true_positive / actual_total if actual_total > 0 else undefined
    # A NaN operand makes the comparison False, so F1 is NaN whenever either input is.
    if precision + recall > 0:
        f1 = 2*(precision*recall) / (precision + recall)
    else:
        f1 = undefined
    return precision, recall, f1


# Columns shared by every per-class computation below.
qm_label_after = quantifiable_merged_table['Label After']
qm_label_true = quantifiable_merged_table['Label True']
qm_pred_sectors = quantifiable_merged_table['Pred # Sectors']
qm_true_sectors = quantifiable_merged_table['True # Sectors']

# Get true/false positives/negatives for [PSI+]
true_positive_white, true_negative_white, false_positive_white, false_negative_white = _confusion_counts(
    qm_label_after == '[PSI+]', qm_label_true == '[PSI+]')

precision_white, recall_white, f1_white = _precision_recall_f1(
    true_positive_white, false_positive_white, false_negative_white)

print(precision_white)
print(recall_white)
print(f1_white)

# Get true/false positives/negatives for [psi-]
true_positive_red, true_negative_red, false_positive_red, false_negative_red = _confusion_counts(
    qm_label_after == '[psi-]', qm_label_true == '[psi-]')

precision_red, recall_red, f1_red = _precision_recall_f1(
    true_positive_red, false_positive_red, false_negative_red)
# Specificity is TN / (TN + FP); NaN when there are no true negatives or false positives.
if (true_negative_red + false_positive_red) > 0:
    specificity_red = true_negative_red / (true_negative_red + false_positive_red)
else:
    specificity_red = float('nan')

print(precision_red)
print(recall_red)
print(f1_red)

# Sectored with any number of sectors: labels starting with 'S' are the positive
# class and labels starting with '[' are the negative class.
true_positive_sectored, true_negative_sectored, false_positive_sectored, false_negative_sectored = _confusion_counts(
    qm_label_after.str.startswith('S'),
    qm_label_true.str.startswith('S'),
    pred_neg=qm_label_after.str.startswith('['),
    true_neg=qm_label_true.str.startswith('['))

precision_sectored, recall_sectored, f1_sectored = _precision_recall_f1(
    true_positive_sectored, false_positive_sectored, false_negative_sectored)

print(precision_sectored)
print(recall_sectored)
print(f1_sectored)

# Sectored with one sector only
true_positive_sectored_1, true_negative_sectored_1, false_positive_sectored_1, false_negative_sectored_1 = _confusion_counts(
    qm_label_after == 'S1', qm_label_true == 'S1')

precision_sectored_1, recall_sectored_1, f1_sectored_1 = _precision_recall_f1(
    true_positive_sectored_1, false_positive_sectored_1, false_negative_sectored_1)

print(precision_sectored_1)
print(recall_sectored_1)
print(f1_sectored_1)

# Sectored with two sectors only
true_positive_sectored_2, true_negative_sectored_2, false_positive_sectored_2, false_negative_sectored_2 = _confusion_counts(
    qm_label_after == 'S2', qm_label_true == 'S2')

precision_sectored_2, recall_sectored_2, f1_sectored_2 = _precision_recall_f1(
    true_positive_sectored_2, false_positive_sectored_2, false_negative_sectored_2)

print(precision_sectored_2)
print(recall_sectored_2)
print(f1_sectored_2)

# Sectored with three or more sectors.
# The predicted class is decided by 'Pred # Sectors' and the true class by
# 'True # Sectors'. Every comparison is parenthesised because `&` binds more
# tightly than `<=` / `>`, so `mask & col <= 2` would evaluate `(mask & col) <= 2`.
pred_is_sectored_3 = qm_label_after.str.startswith('S') & (qm_pred_sectors > 2)
true_is_sectored_3 = qm_label_true.str.startswith('S') & (qm_true_sectors > 2)
pred_not_sectored_3 = qm_label_after.str.startswith('[') | (qm_label_after.str.startswith('S') & (qm_pred_sectors <= 2))
true_not_sectored_3 = qm_label_true.str.startswith('[') | (qm_label_true.str.startswith('S') & (qm_true_sectors <= 2))

true_positive_sectored_3, true_negative_sectored_3, false_positive_sectored_3, false_negative_sectored_3 = _confusion_counts(
    pred_is_sectored_3,
    true_is_sectored_3,
    pred_neg=pred_not_sectored_3,
    true_neg=true_not_sectored_3)

precision_sectored_3, recall_sectored_3, f1_sectored_3 = _precision_recall_f1(
    true_positive_sectored_3, false_positive_sectored_3, false_negative_sectored_3)

print(precision_sectored_3)
print(recall_sectored_3)
print(f1_sectored_3)
