In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import decomposition
from sklearn.manifold import TSNE

In [None]:
# Load the per-image metadata table.
df = pd.read_csv('data/embed-non-negative.csv')
# The image id is the last '/'-separated component of the stored image path.
df['image_id'] = df['image_path'].map(lambda path: path.rsplit('/', 1)[-1])

In [None]:
# Flags selecting which model run is analysed. They determine both the results
# folder that is read and the sub-folder the figures are written to.
density = False      # True -> 'density', False -> 'malignancy'
pretrained = False   # True -> 'pretrained', False -> 'supervised'
percentage = 1.0     # suffix of the 'resnet18_<percentage>' results folder

In [None]:
# Figures are written to images/<task>/<training regime>/<percentage>.
output_base = 'images'
output_type1 = 'density' if density else 'malignancy'
output_type2 = 'pretrained' if pretrained else 'supervised'
output_type3 = str(percentage)
output_dir = os.path.join(output_base, output_type1, output_type2, output_type3)
# exist_ok=True replaces the separate existence check: no check-then-create
# race, and the cell can be re-run safely.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Number of leading columns that are treated as embedding features later on.
num_features = 512
# Results folder of the selected classification run.
prd_path = (
    "../classification/"
    f"{'density' if density else 'malignancy'}_"
    f"{'pretrained' if pretrained else 'supervised'}"
    f"/output/resnet18_{percentage}"
)
df_prd = pd.read_csv(os.path.join(prd_path, 'predictions.csv'))
df_emb = pd.read_csv(os.path.join(prd_path, 'embeddings.csv'))

In [None]:
# Put embeddings and predictions side by side (column-wise), then keep only
# the rows whose image_id also appears in the metadata frame.
df = (
    pd.concat([df_emb, df_prd], axis=1)
    .merge(df, how='inner', on=['image_id'])
)

In [None]:
# Flag rows whose description mentions "screen" (case-insensitive).
# na=False makes rows with a missing description count as False; without it
# the match result contains NA and cannot be used as a boolean mask.
df['is_screen'] = df.desc.str.contains("screen", case=False, na=False)

In [None]:
# Preview the first rows of the merged frame.
df.head()

In [None]:
# Feature matrix: the first `num_features` columns of the merged frame.
# NOTE(review): assumes the embedding columns come first - confirm against embeddings.csv.
embeddings = np.array(df.iloc[:, :num_features])

In [None]:
# Reduce the embeddings with PCA. A fractional n_components asks for as many
# components as are needed to explain that share (0.95) of the variance.
pca = decomposition.PCA(n_components=0.95, whiten=False)
embeddings_pca = pca.fit_transform(embeddings)

# Shape of the reduced matrix: (rows, retained components).
print(embeddings_pca.shape)

In [None]:
# Keep the first two principal components as plotting coordinates.
df['PCA 1'], df['PCA 2'] = embeddings_pca[:, 0], embeddings_pca[:, 1]

In [None]:
# 2-D t-SNE of the PCA-reduced embeddings. t-SNE is stochastic, so a fixed
# random_state is set to make the layout (and every figure derived from it)
# reproducible across kernel restarts.
tsne = TSNE(n_components=2, learning_rate='auto', random_state=42)
embeddings_tsne = tsne.fit_transform(embeddings_pca)

# Shape of the t-SNE output: (rows, 2).
print(embeddings_tsne.shape)

In [None]:
# Store both t-SNE coordinates as plotting columns.
df['t-SNE 1'], df['t-SNE 2'] = embeddings_tsne[:, 0], embeddings_tsne[:, 1]

In [None]:
# Shuffle the rows so that the drawing order of the points is not tied to the
# file order. The seed is fixed so the order is the same on every run.
df = df.sample(frac=1.0, random_state=42)

In [None]:
# Shared appearance settings read by the plotting helpers.
alpha = 0.6              # marker opacity
style = '.'              # marker symbol
markersize = 20          # passed as the scatter size `s`
color_palette = 'tab10'  # palette name passed to the per-attribute plots
kind = 'scatter'         # not referenced in the visible cells

In [None]:
# Columns used as plot coordinates. Swap the commented pair in to plot the
# first two principal components instead of the t-SNE layout.
# x = 'PCA 1'
# y = 'PCA 2'
x = 't-SNE 1'
y = 't-SNE 2'

In [None]:
def plot_scatter(data, hue, x, y, palette):
    """Draw a scatter plot of column `y` against column `x`, coloured by `hue`.

    Marker opacity, symbol and size come from the notebook-level `alpha`,
    `style` and `markersize` settings. The legend is moved to the right of
    the axes.

    Args:
        data: DataFrame holding the `x`, `y` and `hue` columns.
        hue: Name of the column used for colouring.
        x: Name of the column plotted on the horizontal axis.
        y: Name of the column plotted on the vertical axis.
        palette: Palette passed through to seaborn.
    """
    # Missing values are dropped before sorting: NaN cannot be ordered against
    # strings, so sorting the raw unique values raises TypeError for such columns.
    hue_order = sorted(data[hue].dropna().unique())
    sns.set_theme(style="white")
    ax = sns.scatterplot(data=data, x=x, y=y, hue=hue, hue_order=hue_order, alpha=alpha, marker=style, s=markersize, palette=palette)
    sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))

def plot_joint(data, hue, x, y, palette):
    """Draw a joint plot of `y` against `x` coloured by `hue` and save it.

    Marker opacity, symbol and size come from the notebook-level `alpha`,
    `style` and `markersize` settings. The marginal plots are normalised
    per hue level (`common_norm=False`). The figure is written to
    `<output_dir>/<hue>`.

    Args:
        data: DataFrame holding the `x`, `y` and `hue` columns.
        hue: Name of the column used for colouring; also the output file name.
        x: Name of the column plotted on the horizontal axis.
        y: Name of the column plotted on the vertical axis.
        palette: Palette passed through to seaborn.
    """
    # Missing values are dropped before sorting: NaN cannot be ordered against
    # strings, so sorting the raw unique values raises TypeError for such columns.
    hue_order = sorted(data[hue].dropna().unique())
    sns.set_theme(style="white")
    grid = sns.jointplot(data=data, x=x, y=y, hue=hue, hue_order=hue_order, alpha=alpha, marker=style, s=markersize, palette=palette, marginal_kws={'common_norm': False})
    sns.move_legend(grid.ax_joint, "upper left", bbox_to_anchor=(1.2, 1))
    # `.figure` is the current accessor (`.fig` is deprecated). The legend sits
    # outside the axes, so a tight bounding box is needed to keep it in the file.
    grid.figure.savefig(f"{output_dir}/{hue}", bbox_inches='tight')

In [None]:
# attribute = 'is_screen'

# print(df[attribute].value_counts(normalize=False))
# print('')
# print(df[attribute].value_counts(normalize=True))

# plot_joint(df, attribute, x, y, color_palette)

In [None]:
attribute = 'ViewPosition'

# Absolute counts, a blank separator line, then relative frequencies.
print(
    df[attribute].value_counts(normalize=False),
    '',
    df[attribute].value_counts(normalize=True),
    sep='\n',
)

plot_joint(data=df, hue=attribute, x=x, y=y, palette=color_palette)

In [None]:
attribute = 'asses'

# Absolute counts, a blank separator line, then relative frequencies.
print(
    df[attribute].value_counts(normalize=False),
    '',
    df[attribute].value_counts(normalize=True),
    sep='\n',
)

plot_joint(data=df, hue=attribute, x=x, y=y, palette=color_palette)

In [None]:
attribute = 'race'

# Absolute counts, a blank separator line, then relative frequencies.
print(
    df[attribute].value_counts(normalize=False),
    '',
    df[attribute].value_counts(normalize=True),
    sep='\n',
)

plot_joint(data=df, hue=attribute, x=x, y=y, palette=color_palette)

In [None]:
attribute = 'density'

# Absolute counts, a blank separator line, then relative frequencies.
print(
    df[attribute].value_counts(normalize=False),
    '',
    df[attribute].value_counts(normalize=True),
    sep='\n',
)

plot_joint(data=df, hue=attribute, x=x, y=y, palette=color_palette)

In [None]:
attribute = 'is_positive'

# Absolute counts, a blank separator line, then relative frequencies.
print(
    df[attribute].value_counts(normalize=False),
    '',
    df[attribute].value_counts(normalize=True),
    sep='\n',
)

plot_joint(data=df, hue=attribute, x=x, y=y, palette=color_palette)

In [None]:
attribute = 'class_1'

# Colour the layout by the 'class_1' column using the 'magma' palette.
plot_scatter(data=df, hue=attribute, x=x, y=y, palette='magma')

## Interactive model inspection

In [None]:
import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.express as px
from skimage.io import imread
from skimage.util import img_as_ubyte
from skimage.transform import resize
from matplotlib import cm
from ipywidgets import Output, HBox

# Root folder that the image paths are joined onto. Set the EMBED_DATA_DIR
# environment variable to point at a different location; the original
# hard-coded path remains the default.
data_dir = os.environ.get('EMBED_DATA_DIR', '/data2/EMBED/1024x768')

In [None]:
def rgb_to_hex(rgb):
    """Format the first three entries of `rgb` as a '#rrggbb' hex colour string.

    Each channel is written as two lowercase hexadecimal digits.
    """
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f'#{red:02x}{green:02x}{blue:02x}'

# Sample ten colours from the tab10 colormap and convert each to an integer
# [r, g, b] triple in 0-255; the list is repeated ten times.
color = cm.tab10(np.linspace(0, 1, 10))
colorlist = [(np.array(mpl.colors.to_rgb(c))*255).astype(int).tolist() for c in color]*10

# One hex colour per row, chosen by the row's is_positive value. int() makes
# the lookup work for boolean labels too: NumPy bool scalars are not accepted
# as list indices by recent NumPy versions.
colors = [rgb_to_hex(colorlist[int(c)]) for c in df.is_positive.values]

In [None]:
def preprocess(image, horizontal_flip=False):
    """Keep only the largest bright connected region of an image.

    The image is rescaled to [0, 1], thresholded, and split into connected
    components; every pixel outside the largest non-background component is
    set to 0 (the original comment calls this the breast mask).

    Args:
        image: 2-D image array. The input itself is not modified.
        horizontal_flip: If True, the result is mirrored left-to-right.

    Returns:
        A masked copy of `image`. If the image is constant or nothing exceeds
        the threshold, the result is all zeros.
    """
    # Rescale intensities to [0, 1]. A constant image has zero range; treat it
    # as pure background instead of dividing by zero.
    image_norm = image - np.min(image)
    peak = np.max(image_norm)
    if peak > 0:
        image_norm = image_norm / peak
    else:
        image_norm = np.zeros(image.shape, dtype=float)
    thresh = cv2.threshold(img_as_ubyte(image_norm), 5, 255, cv2.THRESH_BINARY)[1]

    # Connected components with stats.
    nb_components, output, stats, _ = cv2.connectedComponentsWithStats(thresh, connectivity=4)

    if nb_components > 1:
        # Label 0 is the background, so search labels 1.. for the largest area.
        # argmax returns the first maximum, so ties resolve to the lowest label.
        max_label = 1 + int(np.argmax(stats[1:, cv2.CC_STAT_AREA]))
        mask = output == max_label
    else:
        # No foreground component at all: nothing is kept.
        mask = np.zeros(output.shape, dtype=bool)

    image_masked = image.copy()
    image_masked[~mask] = 0

    if horizontal_flip:
        image_masked = image_masked[:, ::-1].copy()

    return image_masked

In [None]:
out = Output()
@out.capture(clear_output=True)
def handle_click(trace, points, state):
    """Show the clicked sample's original and preprocessed image in `out`.

    Registered as the click callback of the scatter trace. The first clicked
    point selects the row of `df` whose image is loaded from `data_dir`; all
    clicked points are drawn with an enlarged marker.

    Args:
        trace: The trace that received the click (unused).
        points: Click information; `points.point_inds` holds the positional
            indices of the clicked points.
        state: Input device state supplied with the click (unused).
    """
    # Nothing to display when the callback arrives without a selected point.
    if not points.point_inds:
        return

    sample = df.iloc[points.point_inds[0]]
    img_orig = imread(os.path.join(data_dir, sample.image_path))
    img_proc = preprocess(img_orig)

    # Reset every marker to the base size, then enlarge the clicked ones.
    s = [8] * len(df)
    for i in points.point_inds:
        s[i] = 16
    with fig.batch_update():
        scatter.marker.size = s

    f, (ax1, ax2) = plt.subplots(1,2, figsize=(8,8))
    ax1.imshow(img_orig, cmap='gray')
    ax1.set_title('original')
    ax1.axis('off')
    ax2.imshow(img_proc, cmap='gray')
    ax2.set_title('processed')
    ax2.axis('off')
    # plt.show() takes no figure argument; passing one hands it to the backend
    # as an unrelated positional option. Close the figure afterwards so that
    # repeated clicks do not accumulate open figures.
    plt.show()
    plt.close(f)
    
# Interactive scatter of the selected layout; hovering shows only the
# ManufacturerModelName column (the coordinates are hidden).
fig = go.FigureWidget(
    px.scatter(
        df,
        x=x,
        y=y,
        template='simple_white',
        hover_data={'ManufacturerModelName': True, x:False, y:False},
    )
)
fig.update_layout(width=600, height=600)

# Style the single scatter trace and attach the click callback to it.
scatter = fig.data[0]
scatter.marker.color = colors
scatter.marker.size = [8] * len(df)
scatter.on_click(handle_click)

# Plot on the left, clicked-image panel on the right.
HBox([fig, out])