In [1]:
# Code to generate Figures 4f and 4g

In [None]:
import json
import os
from datetime import date

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px

from hsi_detect.classifier import HierarchicalKMeansUnmixer
from hsi_detect.image import HyperspectralImage
from hsi_detect.spectrum import Spectrum
from hsi_detect.utils import *

# Date stamp (format DDMonYYYY, e.g. '09Jul2024') prefixed to saved output filenames
today = date.today()
date_str = today.strftime('%d%b%Y')
print ('Date prefix:', date_str)

# Plotting parameters: TrueType (type 42) fonts in PDFs, 0.5-pt lines/spines/ticks, 7-pt Arial
mpl.rcParams.update({
    'pdf.fonttype': 42,
    'lines.linewidth': 0.5,
    'axes.linewidth': 0.5,
    'xtick.major.width': 0.5,
    'xtick.minor.width': 0.5,
    'ytick.major.width': 0.5,
    'ytick.minor.width': 0.5,
})
plt.rcParams.update({
    'font.sans-serif': 'Arial',
    'font.size': 7,
})

In [None]:
# NOTE(review): absolute, machine-specific paths — adjust for your environment.
IMAGE_PATH = '/Users/Itai/Desktop/172/results/REFLECTANCE_172.hdr'
# CONC_MAP_PATH = 'plots_concentrations.csv'
REFERENCE_SPECTRUM_PATH = '../../from_box/Grad/research/bioHSI/04_image_processing/00_data/absorbance_data/YF10_infered_absorbance_from_pellets_09Jul2024.npy'

## DEFAULT SAVE PATH
# Outputs go to a sibling folder of the image: <image dir>/<image stem>_outputs_from_analysis/
# The trailing '/' is kept because later cells append filenames directly (savedir + 'name').
# NOTE: the directory and the folder name used to be concatenated without a separator
# (".../resultsREFLECTANCE_172_outputs..."); move any previously saved coordinate JSONs.
_image_dir, _image_file = os.path.split(IMAGE_PATH)
savedir = os.path.join(_image_dir, f'{_image_file.split(".hdr")[0]}_outputs_from_analysis') + '/'

print (savedir)
if not os.path.isdir(savedir):
    os.makedirs(savedir)
    print ('Made directory:', savedir)

In [None]:
# Load the hyperspectral image (smoothing_window=11 passed to the loader) and save its RGB render
hsi_img = HyperspectralImage(IMAGE_PATH, smoothing_window=11)
hsi_img.show(dpi=300, savepath=savedir+f'{date_str}_reconstructed_RGB.png')

# Load and visualize the spectrum of the HSR, using the configured path constant
# rather than a duplicated literal so the two cannot drift apart
reference_spectrum = Spectrum(REFERENCE_SPECTRUM_PATH)
reference_spectrum.interpolate_spectrum(hsi_img.centers)  # resample onto the image's band centres
reference_spectrum.show()

In [None]:
# Fit the HierarchicalKMeansUnmixer on the image using the reference spectrum
# (filter_threshold=0.85, normalize=True), then classify against that spectrum.
# `scored_img` is displayed as a 2-D image in later cells (presumably one score per pixel).
hsi_classifier = HierarchicalKMeansUnmixer(filter_threshold=0.85, normalize=True)
hsi_classifier.fit(hsi_img, reference_spectrum)
scored_img = hsi_classifier.classify(reference_spectrum)

In [None]:
# The clustering results can be inspected with the unmixer's built-in visualizations
# (presumably the cluster assignments and the extracted endmember spectra).
hsi_classifier.visualize_clusters()
hsi_classifier.visualize_endmembers()

In [None]:
# Visualize classified image: score map rotated 90° clockwise (np.rot90 with k=-1),
# colour scale fixed to [0, 0.1], no ticks and no frame.
fig_scores, ax_scores = plt.subplots(dpi=500)
ax_scores.imshow(np.rot90(scored_img, k=-1), vmin=0, vmax=0.1, cmap='inferno')
ax_scores.set_xticks([])
ax_scores.set_yticks([])
ax_scores.set_frame_on(False)
plt.show()

### Define areas where HSR is present for analysis

In [26]:
import json
# Load manually defined rectangles ('TL' = top-left corners, 'BR' = bottom-right corners)
# used by the rectangle-based analysis cells below. Without the file those cells fail
# with NameError, so say so explicitly instead of skipping silently.
RECTANGLE_COORDS_PATH = f'{savedir}/manually_defined_rectangle_coordinates.json'
if os.path.exists(RECTANGLE_COORDS_PATH):
    with open(RECTANGLE_COORDS_PATH, 'r') as f:
        coords = json.load(f)
    rectangle_tls = coords['TL']
    rectangle_brs = coords['BR']
    print ('Loaded!')
else:
    print ('No saved rectangle coordinates at', RECTANGLE_COORDS_PATH,
           '- cells using rectangle_tls/rectangle_brs will fail.')


In [None]:
import numpy as np

RUNNING_LOCALLY = True # Set true if running on your own machine; False serves the Dash app on the IP address set further below

### Interactive code to define the coordinates of circles of interest

# Style variables
# CSS for the circle-definition Dash tool; the `styles` keys are referenced by the app layout below
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
styles = {
    'pre': {
        'border': 'thin lightgrey solid',
        'borderRadius': '5px',
        'padding': '10px',
        'width': '100%',
        'color': '#333',
        'backgroundColor': '#f8f9fa'
    },
    'container': {
        'margin': '20px',
        'fontFamily': 'Arial, sans-serif'
    },
    'title': {
        'textAlign': 'center',
        'color': '#2c3e50',
        'marginBottom': '20px'
    },
    'panel': {
        'width': '30%', 
        'display': 'inline-block', 
        'verticalAlign': 'top', 
        'overflowY': 'auto', 
        'maxHeight': '500px',
        'padding': '15px',
        'backgroundColor': '#ffffff',
        'boxShadow': '0 4px 6px rgba(0, 0, 0, 0.1)',
        'borderRadius': '8px',
        'margin': '0 10px'
    },
    'graph': {
        'width': '65%', 
        'display': 'inline-block',
        'boxShadow': '0 4px 6px rgba(0, 0, 0, 0.1)',
        'borderRadius': '8px',
        'backgroundColor': '#ffffff'
    },
    'header': {
        'color': '#3498db',
        'borderBottom': '2px solid #3498db',
        'paddingBottom': '5px',
        'marginBottom': '10px'
    },
    'instruction': {
        'backgroundColor': '#e8f4f8',
        'padding': '10px',
        'borderRadius': '5px',
        'marginBottom': '15px',
        'fontSize': '14px'
    }
}
# NOTE(review): not used by the circle tool; `cmap` is reassigned before its later uses
cmap = plt.get_cmap('tab20')

if not os.path.exists(f'{savedir}/manually_defined_circle_coordinates.json'): # Check if coordinates already exist
    from dash import Dash, dcc, html, Input, Output, callback

    # Click state shared with the callback below. Clicks alternate: an even-numbered
    # click sets a new circle centre, the following click sets that circle's radius.
    circle_centers = []  # [row, col] of each centre, in full-image pixel coordinates
    circle_radii = []    # distance in pixels from each centre to its radius click
    count = 0            # number of clicks recorded so far
    selected = []        # every clicked point, [row, col]
    rgb = hsi_img.rgb.astype(np.uint8)
    # Displayed window [[row_start, row_stop], [col_start, col_stop]]; default: no cropping
    window_to_show = [[0, rgb.shape[0]], [0, rgb.shape[1]]]

    fig = px.imshow(rgb[window_to_show[0][0]:window_to_show[0][1],
                        window_to_show[1][0]:window_to_show[1][1]])

    # Improve figure layout
    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        plot_bgcolor='white',
        paper_bgcolor='white',
        dragmode='pan',
        hovermode='closest'
    )

    app = Dash(__name__, external_stylesheets=external_stylesheets)
    app.layout = html.Div([
        html.H2("Circle Definition Tool", style=styles['title']),
        
        html.Div([
            html.Div([
                html.Div([
                    html.P("Instructions:", style={'fontWeight': 'bold'}),
                    html.P("1. Click once to set circle center"),
                    html.P("2. Click again to set circle radius"),
                    html.P("3. Repeat for additional circles")
                ], style=styles['instruction'])
            ]),
            dcc.Graph(
                id='img', 
                figure=fig,
                config={'scrollZoom': True, 'displayModeBar': True}
            )
        ], style=styles['graph']),
        
        html.Div([
            html.H4("Circle Centers", style=styles['header']),
            html.Div(id='center', style=styles['pre'])
        ], style=styles['panel']),
        
        html.Div([
            html.H4("Circle Radii", style=styles['header']),
            html.Div(id='radius', style=styles['pre'])
        ], style=styles['panel']),
        
    ], style=styles['container'])

    @callback(
        Output('center', 'children'),
        Output('radius', 'children'),
        Input('img', 'clickData'),
        Input('img', 'figure')
    )
    def onclick_x(clickData, _figure):
        """Record a click on the image and re-render the centre and radius lists.

        Args:
            clickData: Dash click payload for the graph (None before the first click).
                Its x/y are relative to the displayed window, so the window offset is
                added back to obtain full-image coordinates.
            _figure: the graph's figure (unused; declared as an Input by the decorator).

        Returns:
            Two lists of html.Div rows: one per circle centre, one per circle radius.
        """
        # NOTE(review): the 'figure' Input can also trigger this callback; if that ever
        # happens with a stale, non-None clickData, the same click is recorded twice.
        global count  # the lists are mutated in place; only the counter is rebound
        if clickData is not None:
            point = clickData['points'][0]
            adj_x = int(point['x']) + window_to_show[1][0]
            adj_y = int(point['y']) + window_to_show[0][0]
            if count % 2 == 0:
                circle_centers.append([adj_y, adj_x])
            else:
                # Radius = Euclidean distance from the pending centre to this click
                center_y, center_x = circle_centers[-1]
                radius = np.sqrt((adj_y - center_y)**2 + (adj_x - center_x)**2)
                circle_radii.append(radius)
            selected.append([adj_y, adj_x])
            count += 1

        # Format the outputs as lists of div elements with improved styling
        center_results = [
            html.Div(f"Circle {idx+1}: x:{c[1]}, y:{c[0]}", 
                    style={'padding': '5px', 'borderBottom': '1px solid #eee'}) 
            for idx, c in enumerate(circle_centers)
        ]
        
        radius_results = [
            html.Div(f"Circle {idx+1}: {r:.2f} pixels", 
                    style={'padding': '5px', 'borderBottom': '1px solid #eee'}) 
            for idx, r in enumerate(circle_radii)
        ]

        return center_results, radius_results

    if __name__ == '__main__':
        port = int(np.random.choice(range(8300, 8900)))  # random port to avoid clashes between sessions
        if RUNNING_LOCALLY:
            app.run(port=port)
        else:
            ip = '10.63.0.87' # Set to your local IP
            # debug=False: Flask/Werkzeug debug mode exposes an interactive debugger that can
            # execute arbitrary code, which must not be reachable on a network interface.
            app.run(host=ip, port=port, debug=False)
else:
    print ('Not overwriting existing saved coordinates. To reset, delete', f'{savedir}/manually_defined_circle_coordinates.json')

In [71]:
# Persist the circles drawn in the Dash tool. If they were saved in an earlier session
# (the Dash cell then skips and never defines them), load them instead, so that
# circle_centers / circle_radii always exist for the analysis cells below.
CIRCLE_COORDS_PATH = f'{savedir}/manually_defined_circle_coordinates.json'
if not os.path.exists(CIRCLE_COORDS_PATH):
    # An odd number of clicks leaves a centre without a radius. Drop it so the lists stay
    # aligned; otherwise zip(centers, radii) downstream would silently ignore it.
    del circle_centers[len(circle_radii):]
    with open(CIRCLE_COORDS_PATH, 'w') as f:
        json.dump({'circle_radii': [float(r) for r in circle_radii],
                   'circle_centers': circle_centers}, f)
else:
    with open(CIRCLE_COORDS_PATH, 'r') as f:
        saved_circles = json.load(f)
    circle_radii = saved_circles['circle_radii']
    circle_centers = saved_circles['circle_centers']
    print ('Loaded circle coordinates from', CIRCLE_COORDS_PATH)

In [75]:
import pandas as pd

def apply_function_to_circles (img, centers, radii, fxn):
    """Apply ``fxn`` to ``img`` restricted to each circular region.

    For every (center, radius) pair a mask is built with ``mask_ellipse`` (equal semi-axes,
    so a circle). Pixels where the mask is 0 are set to NaN, and the mask is multiplied
    into ``img`` (broadcast over a trailing band axis when ``img`` has one more dimension
    than the mask).

    Args:
        img: 2-D image, or image with a trailing band axis (rows, cols, bands).
        centers: circle centres, in whatever order ``mask_ellipse`` expects
            (this notebook passes [row, col] — NOTE(review): confirm).
        radii: circle radii in pixels, paired with ``centers`` by position.
        fxn: reducer applied to each NaN-masked image, so it should be NaN-aware
            (e.g. ``np.nanmean``).

    Returns:
        list: ``fxn`` output for each circle, in input order (extra unpaired entries of
        ``centers``/``radii`` are ignored by ``zip``).
    """
    returned = []
    for c, r in zip(centers, radii):
        # Float copy so the outside of the circle can hold NaN whatever dtype mask_ellipse returns
        mask = mask_ellipse(np.ones(img.shape[:2]), c, r, r).astype(float)
        mask[mask == 0] = np.nan
        if mask.ndim < img.ndim:
            masked = img * mask[:, :, np.newaxis]
        else:
            masked = img * mask
        returned.append(fxn(masked))
    return returned

img = hsi_img.image
# Mask of all circles, for display. Centres are stored as [row, col] and flipped here.
# NOTE(review): assumes make_circle_mask expects [col, row] centres — confirm.
masks = make_circle_mask(img.shape, [np.flip(c) for c in circle_centers], circle_radii)


def spec_fxn(x):
    """Mean spectrum over the two spatial axes, ignoring NaN (outside the region)."""
    return np.nanmean(x, axis=(0,1))


# Mean spectrum inside each circle (raw image)
mean_spec_ls = apply_function_to_circles(img, circle_centers, circle_radii, spec_fxn)

# Band at which the reference spectrum peaks
max_abs_idx = np.argmax(reference_spectrum.intensities)


def peak_fxn(x):
    """Mean value at the reference-peak band over the non-NaN pixels.

    nanmean, not mean: apply_function_to_circles sets every pixel outside the circle to
    NaN, so np.mean would always return NaN.
    """
    return np.nanmean(x[:, :, max_abs_idx])


unnormed_peak_of_interest = apply_function_to_circles(img, circle_centers, circle_radii, peak_fxn)
# NOTE(review): despite the name, this is computed on the same un-normalized image and is
# therefore identical to unnormed_peak_of_interest; compute it on a normalized image if intended.
normed_peak_of_interest = list(unnormed_peak_of_interest)


def mean_fxn(x):
    """Mean of all non-NaN values (i.e. over the circle)."""
    return np.nanmean(x)


mean_scores = apply_function_to_circles(scored_img, circle_centers, circle_radii, mean_fxn)


def sum_fxn(x):
    """Sum of all non-NaN values (i.e. over the circle)."""
    return np.nansum(x)


sum_scores = apply_function_to_circles(scored_img, circle_centers, circle_radii, sum_fxn)

# NOTE: later cells use `conc_map`, which is not defined otherwise — set CONC_MAP_PATH and uncomment:
# conc_map = pd.read_csv(CONC_MAP_PATH, header=None)

In [None]:
# Overlay the circle mask on the RGB render to check circle placement.
# Built from hsi_img directly: `rgb` only exists if the Dash tool cell actually ran.
plt.imshow(hsi_img.rgb.astype(np.uint8))
plt.imshow(masks, alpha=0.5)
plt.show()

In [None]:
# One line per circle: mean spectrum inside the circle vs. band centre
n_circles = len(mean_scores)
for circle_idx in range(n_circles):
    plt.plot(hsi_img.centers, mean_spec_ls[circle_idx])

In [None]:
# Raw image at the reference-peak band (reversed inferno colour map), no ticks or frame
plt.figure(dpi=600)
plt.imshow(img[:,:,max_abs_idx], cmap='inferno_r')
plt.xticks([])
plt.yticks([])
# plt.box() with no argument *toggles* the frame (state-dependent); hide it explicitly,
# as the other figures in this notebook do
plt.box(False)
plt.show()

In [None]:
# Hill-equation fit of the per-circle mean classification score vs. concentration
s = np.array(mean_scores)  # optional normalization was: / cfu * area
hill_plot_path = f'{savedir}/{date_str}_col1_insitu_RG_classified_hill_plot.pdf'

print ("Saved to:", hill_plot_path)

concs_transformed = conc_map.loc[:,0].values

plot_fit_curve(
    concs_transformed,
    s,
    hill_eqn,
    [0.5e-2, 1e3],   # presumably x-axis limits — TODO confirm against plot_fit_curve
    [-0., 0.4],      # presumably y-axis limits
    hill_plot_path,
    ymax_bound=np.inf,
    figsize=(1.1, 1.1),
    markersize=3,
    x_offset=2e-6,
    ylabel='Classification score',
    fit_lower_bounds=[-1, -1, 0, 0],
    fit_upper_bounds=[np.inf, 0.27, np.inf, 2.8],
    ignore_for_fit=[0, 1],
)

In [32]:
# NOTE(review): this binds a second name to the same array as `img` (no copy), so in-place
# edits to unnormed_img in later cells (e.g. setting zeros to NaN) also modify `img`.
unnormed_img = img
# peak_fxn value (mean at the reference-peak band) for each manually defined rectangle,
# presumably masked to the rectangle by apply_function_to_rectangles — confirm
unnormed_peak_of_interest = apply_function_to_rectangles (unnormed_img, zip(rectangle_tls, rectangle_brs), peak_fxn)


In [None]:
# Hill fit of (1 - value at the reference-peak band) per rectangle, shifted so the minimum is 0
s = 1 - np.array(unnormed_peak_of_interest)
s = s - np.nanmin(s)
print (s)

peak_wavelength = hsi_img.centers[max_abs_idx]
savename = f'{savedir}/{date_str}_col1_insitu_RG_classified_hill_plot_unnormed_reflectance_at-{peak_wavelength}.pdf'
print ("Saved to:", f'{savename}')

y_top = np.nanmax(s)
plot_fit_curve(
    conc_map.loc[:,0].values,
    s,
    hill_eqn,
    [0.5e-2, 1e3],
    [-y_top*0.05, y_top*1.1],  # small negative margin below 0, 10% headroom above the max
    savename,
    line_color='black',
    marker_color='black',
    edgecolor=None,
    figsize=(1.1, 1.1),
    markersize=3,
    x_offset=1e-6,
)

In [None]:
# Index labels of conc_map rows whose first column is >= 5; those rectangles are treated as positives.
# NOTE(review): assumes conc_map rows are in the same order as rectangle_tls / rectangle_brs — confirm.
pos_mask_idxs = conc_map.index[conc_map[0]>=5]


# Image-sized mask covering only the positive rectangles (presumably binary — confirm make_rectangle_mask)
mask = make_rectangle_mask(img.shape[:2], zip(np.array(rectangle_tls)[pos_mask_idxs], np.array(rectangle_brs)[pos_mask_idxs]))

plt.figure(dpi=400)
plt.imshow(mask)
plt.xticks([])
plt.yticks([])
plt.show()

In [None]:
# Full-frame RGB render with rectangles outlined in white: the positive rectangles
# (pos_mask_idxs) and the rectangles with indices 6:9.
# NOTE(review): the cropped panels outline rectangles[5:9] instead; confirm 6:9 is intended here.
f, ax = plt.subplots(1, dpi=400)

ax.imshow(hsi_img.rgb/255)

all_tls = np.array(rectangle_tls)
all_brs = np.array(rectangle_brs)
rect_groups = [
    (all_tls[pos_mask_idxs], all_brs[pos_mask_idxs]),
    (np.array(rectangle_tls[6:9]), np.array(rectangle_brs[6:9])),
]
for group_tls, group_brs in rect_groups:
    for tl, br in zip(group_tls, group_brs):
        # Corners are (row, col); the Rectangle is anchored at the top-left corner (x=col, y=row)
        h = br[0] - tl[0]
        w = br[1] - tl[1]
        square = plt.Rectangle((tl[1], tl[0]), w, h, fill=False, color='white', linewidth=0.4, alpha=0.5)
        ax.add_patch(square)

plt.xticks([])
plt.yticks([])
plt.box(False)
plt.savefig(f'{savedir}/{date_str}_high-flight_rgb_img_with_rectangles_marked.pdf', transparent=True)

In [None]:
# Map of the band ratio unnormed_img[..., max_abs_idx+10] / unnormed_img[..., max_abs_idx]
plt.figure(dpi=600)
# unnormed_img =  unnormed_img[unnormed_peak_of_interest] / unnormed_img[unnormed_peak_of_interest+10]

# Mark zero pixels as missing (NaN) so they render blank instead of producing inf/0 ratios.
# NOTE(review): unnormed_img is the same array as img (`unnormed_img = img`), so this also
# writes NaN into img; the later `nanmask` computation reads those NaNs from img.
unnormed_img[unnormed_img==0] = np.nan
# plt.imshow(1-unnormed_img[:,:,max_abs_idx], cmap='inferno', vmin=0)
# NOTE(review): assumes max_abs_idx+10 is a valid band index — TODO confirm for peaks near the last band
plt.imshow(unnormed_img[:,:,max_abs_idx+10] / unnormed_img[:,:,max_abs_idx])
plt.xticks([])
plt.yticks([])
plt.box(False)
plt.colorbar()
# plt.savefig(f'{savedir}/unnormalized_image_at-{hsi_img.centers[max_abs_idx]}.png', dpi=400)
plt.show()

In [None]:
# Zoomed-in RGB render with rectangles outlined in white: the positive rectangles, then
# rectangles 5:9. Each rectangle's top-left corner, width and height are printed.
f, ax = plt.subplots(1, dpi=400)

crop = [[700,920], [250,350]]  # [[row_start, row_stop], [col_start, col_stop]]
row0, col0 = crop[0][0], crop[1][0]
ax.imshow(hsi_img.rgb[row0:crop[0][1], col0:crop[1][1]].astype(int))

outlined = list(zip(np.array(rectangle_tls)[pos_mask_idxs], np.array(rectangle_brs)[pos_mask_idxs]))
outlined.extend(zip(np.array(rectangle_tls[5:9]), np.array(rectangle_brs[5:9])))

for tl, br in outlined:
    w = br[1] - tl[1]
    h = br[0] - tl[0]
    print (tl, w,h)
    # Anchor at the top-left corner, shifted into the cropped frame's coordinates
    square = plt.Rectangle((tl[1]-col0, tl[0]-row0), w, h, fill=False, color='white', linewidth=0.5)
    ax.add_patch(square)

plt.xticks([])
plt.yticks([])
plt.box(False)
plt.savefig(f'{savedir}/{date_str}_MAFAT_high_flight_rgb_img_with_rectangles_marked_cropped.pdf', dpi=300)

In [27]:
# Pixel-level ground truth for the ROC/PR analysis: the flattened `mask` of the positive rectangles.
# NOTE(review): the original question here was whether positives should be the rectangles at
# 10^1 and above; pos_mask_idxs currently uses conc_map[0] >= 5 — confirm the intended cutoff.

image_y = flatten_array(mask)

In [None]:
# import numpy as np
from sklearn.metrics import roc_curve, auc, precision_recall_curve

# Pixel-level ROC curves against the rectangle ground truth (image_y) for two detectors:
#   * the classifier score map (scored_img)
#   * an absorbance proxy: 1 - value at the reference-peak band
y_scores = flatten_array(scored_img.copy())
y_true = image_y.copy()
y_abs = 1-flatten_array(unnormed_img[:,:,max_abs_idx])

# Keep only pixels where BOTH detectors are defined, so both curves use the same pixel set.
# roc_curve raises on NaN, and unnormed_img can contain NaN (zeros are set to NaN in an earlier cell).
valid = ~np.isnan(y_scores) & ~np.isnan(y_abs)
y_true = y_true[valid]
y_abs = y_abs[valid]
y_scores = y_scores[valid]

# ROC curve and area per detector; (fpr, tpr, thresholds) and (abs_fpr, abs_tpr,
# abs_thresholds) are reused by the threshold-map cell below
fpr, tpr, thresholds = roc_curve(y_true, y_scores)
class_roc_auc = auc(fpr, tpr)

abs_fpr, abs_tpr, abs_thresholds = roc_curve(y_true, y_abs)
abs_roc_auc = auc(abs_fpr, abs_tpr)

# Plotting the ROC curves (labels are distinct so a legend, if enabled, is meaningful)
plt.figure(figsize=(2,2))
plt.plot(fpr, tpr, color='navy', lw=1, label='Classifier ROC (area = %0.2f)' % class_roc_auc)
print ('class', class_roc_auc)
plt.plot(abs_fpr, abs_tpr, color='darkgray', lw=1, label='Absorbance ROC (area = %0.2f)' % abs_roc_auc)
print ('abs', abs_roc_auc)
plt.xlim([-0.02, 1.02])
plt.ylim([-0.02, 1.02])
plt.xticks([0,0.5,1])
plt.yticks([0,0.5,1])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.minorticks_on()
plt.savefig(f'{savedir}/{date_str}_classification_ROC_curve.pdf')
plt.show()

In [None]:
# Pixel-level precision-recall curves for the same two detectors and ground truth as the ROC cell.
y_scores = flatten_array(scored_img.copy())
y_true = image_y.copy()
y_abs = 1-flatten_array(unnormed_img[:,:,max_abs_idx])

# Same pixel set for both detectors; precision_recall_curve rejects NaN input
valid = ~np.isnan(y_scores) & ~np.isnan(y_abs)
y_true = y_true[valid]
y_abs = y_abs[valid]
y_scores = y_scores[valid]

# precision_recall_curve returns (precision, recall, thresholds). The area uses recall as x,
# because sklearn's auc() requires monotonic x and precision is not monotonic.
# The thresholds get PR-specific names so they do not overwrite the ROC `thresholds` /
# `abs_thresholds` that the threshold-map cell indexes with the ROC tpr arrays.
precision, recall, pr_thresholds = precision_recall_curve(y_true, y_scores)
class_prc_auc = auc(recall, precision)

abs_precision, abs_recall, abs_pr_thresholds = precision_recall_curve(y_true, y_abs)
abs_prc_auc = auc(abs_recall, abs_precision)

# Plotting the PR curves (recall on x, precision on y)
plt.figure(figsize=(2,2))
plt.plot(recall, precision, color='navy', lw=1, label='Classifier PR curve (area = %0.2f)' % class_prc_auc)
print ('class', class_prc_auc)
plt.plot(abs_recall, abs_precision, color='darkgray', lw=1, label='Absorbance PR curve (area = %0.2f)' % abs_prc_auc)
print ('abs', abs_prc_auc)
plt.xlim([-0.02, 1.02])
plt.ylim([-0.02, 1.02])
plt.xticks([0,0.5,1])
plt.yticks([0,0.5,1])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.minorticks_on()
plt.show()

In [43]:
# Output folder for the supplementary ROC threshold maps
roc_supp_dir = f'{savedir}/roc_supp/'
if not os.path.isdir(roc_supp_dir):
    os.mkdir(roc_supp_dir)
    print (f'Made directory: {roc_supp_dir}')

from matplotlib.colors import ListedColormap
# Two-colour map for binary masks: 0 -> whitesmoke, 1 -> crimson
cmap = ListedColormap(["whitesmoke", "crimson"])

# Figure size in inches such that, at dpi=300, the canvas is roughly one pixel per image pixel
figsize = (img.shape[1]/300, img.shape[0]/300)

In [44]:
# Tighter window derived from `crop`: rows [row_start, row_stop - 100), cols [col_start + 20, col_stop - 45)
row_range, col_range = crop[0], crop[1]
crop2 = [
    [row_range[0], row_range[1] - 100],
    [col_range[0] + 20, col_range[1] - 45],
]

In [None]:
# Zoomed-in (crop2) RGB render with rectangles outlined in white: the positive rectangles,
# then rectangles 5:9. Each rectangle's top-left corner, width and height are printed.
f, ax = plt.subplots(1, dpi=400)

# Cast to uint8: imshow expects float RGB in [0, 1] and clips anything above 1, so a float
# render in 0-255 would display almost entirely white (other cells cast or divide by 255 too).
ax.imshow(hsi_img.rgb[crop2[0][0]:crop2[0][1], crop2[1][0]:crop2[1][1]].astype(np.uint8))

outlined = list(zip(np.array(rectangle_tls)[pos_mask_idxs], np.array(rectangle_brs)[pos_mask_idxs]))
outlined += list(zip(np.array(rectangle_tls[5:9]), np.array(rectangle_brs[5:9])))

for tl, br in outlined:
    w = br[1] - tl[1]
    h = br[0] - tl[0]
    print (tl, w,h)
    # Anchor at the rectangle's top-left corner, shifted into crop2 coordinates
    square = plt.Rectangle((tl[1]-crop2[1][0], tl[0]-crop2[0][0]), w, h, fill=False, color='white', linewidth=0.5)
    ax.add_patch(square)

plt.xticks([])
plt.yticks([])
plt.box(False)
plt.savefig(f'{savedir}/{date_str}_MAFAT_high_flight_rgb_img_with_rectangles_marked_cropped_smaller.pdf', dpi=300)

In [None]:
# Zoomed-in (crop2) classification-score map, colour scale capped at 0.4, with the
# positive rectangles and rectangles 5:9 outlined in white (corner/size printed for each).
f, ax = plt.subplots(1, dpi=400)

row_lo, row_hi = crop2[0][0], crop2[0][1]
col_lo, col_hi = crop2[1][0], crop2[1][1]
ax.imshow(scored_img[row_lo:row_hi, col_lo:col_hi], vmax=0.4, cmap='inferno')

tls_arr, brs_arr = np.array(rectangle_tls), np.array(rectangle_brs)
for sel_tls, sel_brs in ((tls_arr[pos_mask_idxs], brs_arr[pos_mask_idxs]),
                         (np.array(rectangle_tls[5:9]), np.array(rectangle_brs[5:9]))):
    for tl, br in zip(sel_tls, sel_brs):
        w = br[1] - tl[1]
        h = br[0] - tl[0]
        print (tl, w,h)
        # Top-left anchor, shifted into crop2 coordinates
        square = plt.Rectangle((tl[1]-col_lo, tl[0]-row_lo), w, h, fill=False, color='white', linewidth=0.5)
        ax.add_patch(square)

plt.xticks([])
plt.yticks([])
plt.box(False)
plt.savefig(f'{savedir}/{date_str}_MAFAT_high_flight_scored_img_with_rectangles_marked_cropped_smaller.pdf', dpi=300)

In [None]:
# Zoomed-in (crop2) map of 1 - value at the reference-peak band, colour range [0, 1], with the
# positive rectangles and rectangles 5:9 outlined in white (corner/size printed for each).
f, ax = plt.subplots(1, dpi=400)

peak_band = np.argmax(reference_spectrum.intensities)
crop_rows = slice(crop2[0][0], crop2[0][1])
crop_cols = slice(crop2[1][0], crop2[1][1])
ax.imshow(1-unnormed_img[crop_rows, crop_cols, peak_band], vmin=0, vmax=1, cmap='inferno')

outlined = list(zip(np.array(rectangle_tls)[pos_mask_idxs], np.array(rectangle_brs)[pos_mask_idxs]))
outlined += list(zip(np.array(rectangle_tls[5:9]), np.array(rectangle_brs[5:9])))

for tl, br in outlined:
    w = br[1] - tl[1]
    h = br[0] - tl[0]
    print (tl, w,h)
    square = plt.Rectangle((tl[1]-crop2[1][0], tl[0]-crop2[0][0]), w, h, fill=False, color='white', linewidth=0.5)
    ax.add_patch(square)

plt.xticks([])
plt.yticks([])
plt.box(False)
plt.savefig(f'{savedir}/{date_str}_high_flight_maxabs_img_with_rectangles_marked_cropped_smaller.pdf', dpi=300)

In [None]:
# Two-colour map for the binarized threshold maps below: 0 -> gray, 1 -> crimson
cmap = ListedColormap(["gray", "crimson"])

def binarize (arr, thresh):
    """Return a boolean mask of the entries of ``arr`` that are >= ``thresh``.

    This matches ``sklearn.metrics.roc_curve``, which counts a sample as positive when its
    score is >= the threshold. A map binarized at one of that function's thresholds
    therefore reproduces the corresponding TPR/FPR.

    The previous implementation first zeroed values below ``thresh`` and then set values
    above ``thresh`` to 1. For a negative threshold, the freshly zeroed values were then
    also set to 1, and values exactly equal to ``thresh`` were left unchanged.

    NaN entries compare False. ``arr`` is not modified.
    """
    return np.asarray(arr) >= thresh


# Figure size in inches such that, at dpi=300, the canvas is roughly one pixel per image pixel
figsize = (img.shape[1]/300, img.shape[0]/300)

# Target true-positive rates at which to show thresholded detection maps
target_tpr_ls = [0.25,0.5, 0.75, 0.99]
# Accumulates "tpr\nfpr\nthreshold\n\n" per saved map (abs maps first, then scored maps)
text_for_fig = ''

# 1 where img has data at the peak band, NaN where img is NaN; multiplying blanks no-data pixels
nanmask = img[:,:,max_abs_idx].copy()
nanmask[~np.isnan(nanmask)] = 1

# The maps index thresholds with the ROC tpr arrays; make sure they still belong together
# (e.g. were not overwritten by a precision-recall computation).
assert len(abs_thresholds) == len(abs_tpr) and len(thresholds) == len(tpr), \
    'ROC thresholds were overwritten; re-run the ROC cell before this one'


def _plot_target_tpr_maps(score_map, fpr_arr, tpr_arr, thr_arr, tag):
    """Save and show one binarized detection map per target TPR; return the summary text.

    For each TPR in ``target_tpr_ls``:
      1. pick the ROC operating point whose TPR is closest to the target;
      2. binarize ``score_map`` at that point's threshold and mask it with ``nanmask``;
      3. draw it with ``cmap``, outline the ``crop2`` region in white;
      4. save it to ``{savedir}/roc_supp/``.

    Args:
        score_map: 2-D score map that the ROC curve was computed from.
        fpr_arr, tpr_arr, thr_arr: ``roc_curve`` outputs for that map.
        tag: filename tag, e.g. 'abs' or 'scored'.

    Returns:
        str: '{tpr:.3f}\\n{fpr:.3f}\\n{thresh:.3f}\\n\\n' for each target TPR, in order.
    """
    summary = ''
    crop_w = crop2[1][1] - crop2[1][0]
    crop_h = crop2[0][1] - crop2[0][0]
    for t in target_tpr_ls:
        thresh_idx = np.argmin(np.abs(tpr_arr - t))
        op_fpr, op_tpr, op_thr = fpr_arr[thresh_idx], tpr_arr[thresh_idx], thr_arr[thresh_idx]
        savename = (f'{savedir}/roc_supp/'
                    f'target_tpr_{tag}_thresh_img_fpr-{op_fpr:.3f}_tpr-{op_tpr:.3f}_thresh-{op_thr:.2f}.pdf')
        print (f'fpr: {op_fpr}, tpr: {op_tpr}, thresh: {op_thr}')

        fig = plt.figure(figsize=figsize, dpi=300)
        plt.imshow(binarize(score_map, op_thr).astype(float)*nanmask, cmap=cmap, vmin=0, vmax=1,
                   aspect='equal', interpolation='none')
        plt.xticks([])
        plt.yticks([])
        plt.box(False)
        # Outline the crop2 region shown in the zoomed-in panels
        plt.gca().add_patch(plt.Rectangle((crop2[1][0], crop2[0][0]), crop_w, crop_h,
                                          fill=False, color='white', linewidth=0.6, alpha=1))

        print (savename)
        plt.tight_layout()
        plt.savefig(savename, dpi=300)
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop

        summary += '{:.3f}\n{:.3f}\n{:.3f}\n\n'.format(op_tpr, op_fpr, op_thr)
    return summary


text_for_fig += _plot_target_tpr_maps(1-unnormed_img[:,:,max_abs_idx], abs_fpr, abs_tpr, abs_thresholds, 'abs')
text_for_fig += _plot_target_tpr_maps(scored_img, fpr, tpr, thresholds, 'scored')

In [90]:
# Side-by-side HTML comparison of the RGB render and the classification score map.
# Scores are mapped to 0-255 after replacing NaN with 0 and clipping to [0, 1], so the uint8
# cast neither wraps out-of-range values nor depends on undefined NaN->integer conversion.
# Named `juxtaposed_html` so it does not shadow dash's `html` module (imported by the circle tool cell).
scored_u8 = (np.clip(np.nan_to_num(scored_img, nan=0.0), 0, 1) * 255).astype(np.uint8)
juxtaposed_html = make_juxtaposed_html(hsi_img.rgb.astype(np.uint8), scored_u8, height=100)

with open(f'{savedir}/juxtaposed_rgb-classified_images.html', 'w', encoding='utf-8') as f:
    f.write(juxtaposed_html)