In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn.manifold import TSNE
import os
import pickle
import random 

In [None]:
# Load the pickled MCP feature rows and the matching targets file.
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
with open('/Users/christoph/Downloads/mcp_sne.pkl', 'rb') as f:
    data = pickle.load(f)
# Default whitespace splitting already reads a one-value-per-line file;
# an explicit delimiter='\n' is rejected by NumPy >= 1.23.
labels = np.loadtxt('/Users/christoph/Downloads/mcp_targets.txt')

# Every feature row needs exactly one target.
assert len(data) == len(labels)

In [None]:
# One named column per feature dimension, plus the numeric target 'y'
# and a string copy of it in 'label'.
cols = [f'emb_{i}' for i, _ in enumerate(data[0])]
df = pd.DataFrame(data, columns=cols)
df['y'] = labels
df['label'] = df['y'].map(str)

In [None]:
# Project the feature columns down to two dimensions with t-SNE.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter`;
# confirm against the installed version.
data_tsne = df[cols].to_numpy()
tsne = TSNE(
    n_components=2,
    perplexity=40,
    n_iter=300,
    random_state=42,
    verbose=1,
).fit_transform(data_tsne)

In [None]:
# Store both t-SNE coordinates as columns so seaborn can address them by name.
df['tsne-2d-one'], df['tsne-2d-two'] = tsne[:, 0], tsne[:, 1]

fig, ax = plt.subplots(figsize=(16, 10))
sns.scatterplot(
    data=df,
    x="tsne-2d-one",
    y="tsne-2d-two",
    hue="y",
    palette=sns.color_palette("hls", 2),
    legend="full",
    alpha=0.3,
    ax=ax,
)

In [None]:
# Load the pickled CIFAR-10 feature rows and the matching targets file.
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
with open('/Users/christoph/Downloads/mcp_sne_cifar10.pkl', 'rb') as f:
    data = pickle.load(f)
# Default whitespace splitting already reads a one-value-per-line file;
# an explicit delimiter='\n' is rejected by NumPy >= 1.23.
labels = np.loadtxt('/Users/christoph/Downloads/mcp_targets_cifar10.txt')

# Every feature row needs exactly one target.
assert len(data) == len(labels)

In [None]:
# One named column per feature dimension, plus the numeric target 'y'
# and a string copy of it in 'label'.
cols2 = [f'emb_{i}' for i, _ in enumerate(data[0])]
df2 = pd.DataFrame(data, columns=cols2)
df2['y'] = labels
df2['label'] = df2['y'].map(str)

In [None]:
# Project the feature columns down to two dimensions with t-SNE.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter`;
# confirm against the installed version.
data_tsne2 = df2[cols2].to_numpy()
tsne2 = TSNE(
    n_components=2,
    perplexity=40,
    n_iter=300,
    random_state=42,
    verbose=1,
).fit_transform(data_tsne2)

In [None]:
# Store both t-SNE coordinates as columns so seaborn can address them by name.
df2['tsne-2d-one'], df2['tsne-2d-two'] = tsne2[:, 0], tsne2[:, 1]

fig, ax = plt.subplots(figsize=(16, 10))
sns.scatterplot(
    data=df2,
    x="tsne-2d-one",
    y="tsne-2d-two",
    hue="y",
    palette=sns.color_palette("hls", 10),
    legend="full",
    alpha=0.3,
    ax=ax,
)

In [None]:
# Two panels in one figure: `df` on the left, `df2` on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

# Arguments common to both panels.
scatter_kwargs = dict(
    x="tsne-2d-one",
    y="tsne-2d-two",
    hue="y",
    legend="full",
    alpha=0.3,
)

sns.scatterplot(data=df, palette=sns.color_palette("hls", 2), ax=ax1, **scatter_kwargs)

sns.scatterplot(data=df2, palette=sns.color_palette("hls", 10), ax=ax2, **scatter_kwargs)

In [None]:
# OOD class
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
with open('/Users/christoph/Downloads/mcp_sne.pkl', 'rb') as f:
    data = pickle.load(f)
# Default whitespace splitting already reads a one-value-per-line file;
# an explicit delimiter='\n' is rejected by NumPy >= 1.23.
labels = np.loadtxt('/Users/christoph/Downloads/mcp_targets.txt')

assert len(data) == len(labels)

with open('/Users/christoph/Downloads/mcp_sne_ood.pkl', 'rb') as f:
    data_ood = pickle.load(f)
# Every OOD row gets the extra target value 2.0.
labels_ood = [2.0] * len(data_ood)

assert len(data_ood) == len(labels_ood)

In [None]:
# Frame for the regular rows: one column per feature dimension plus the target.
cols = [f'emb_{i}' for i in range(len(data[0]))]
df = pd.DataFrame(data, columns=cols)
df['y'] = labels

# Frame for the OOD rows, using the same column layout.
df_ood = pd.DataFrame(data_ood, columns=cols)
df_ood['y'] = labels_ood

# DataFrame.append was removed in pandas 2.0; pd.concat is the supported way
# to stack frames. ignore_index=True gives the combined frame a unique
# 0..n-1 index instead of repeating the row labels of both inputs.
df = pd.concat([df, df_ood], ignore_index=True)
df['label'] = df['y'].apply(str)

In [None]:
# Project the combined (regular + OOD) feature columns to two dimensions.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter`;
# confirm against the installed version.
data_tsne = df[cols].to_numpy()
tsne = TSNE(
    n_components=2,
    perplexity=40,
    n_iter=400,
    random_state=42,
    verbose=1,
).fit_transform(data_tsne)

In [None]:
# Store both t-SNE coordinates as columns so seaborn can address them by name.
df['tsne-2d-one'], df['tsne-2d-two'] = tsne[:, 0], tsne[:, 1]

fig, ax = plt.subplots(figsize=(16, 10))
sns.scatterplot(
    data=df,
    x="tsne-2d-one",
    y="tsne-2d-two",
    hue="y",
    palette=sns.color_palette("hls", 3),
    legend="full",
    alpha=0.3,
    ax=ax,
)
ax.set_title('MCP Baseline CheXpert')

In [None]:
# ODIN
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
with open('/Users/christoph/Downloads/odin_sne.pkl', 'rb') as f:
    data_odin = pickle.load(f)
# The targets are read from the MCP targets file. Default whitespace splitting
# already reads a one-value-per-line file; an explicit delimiter='\n' is
# rejected by NumPy >= 1.23.
labels_odin = np.loadtxt('/Users/christoph/Downloads/mcp_targets.txt')

assert len(data_odin) == len(labels_odin)

with open('/Users/christoph/Downloads/odin_sne_ood.pkl', 'rb') as f:
    data_ood_odin = pickle.load(f)
# Every OOD row gets the extra target value 2.0.
labels_ood_odin = [2.0] * len(data_ood_odin)

assert len(data_ood_odin) == len(labels_ood_odin)

# One column per feature dimension plus the target, for both frames.
cols_odin = [f'emb_{i}' for i in range(len(data_odin[0]))]
df_odin = pd.DataFrame(data_odin, columns=cols_odin)
df_odin['y'] = labels_odin

df_ood_odin = pd.DataFrame(data_ood_odin, columns=cols_odin)
df_ood_odin['y'] = labels_ood_odin

# DataFrame.append was removed in pandas 2.0; pd.concat is the supported way
# to stack frames. ignore_index=True gives the combined frame a unique
# 0..n-1 index instead of repeating the row labels of both inputs.
df_odin = pd.concat([df_odin, df_ood_odin], ignore_index=True)
df_odin['label'] = df_odin['y'].apply(str)

In [None]:
# Project the combined (regular + OOD) ODIN feature columns to two dimensions.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter`;
# confirm against the installed version.
data_tsne_odin = df_odin[cols_odin].to_numpy()
tsne_odin = TSNE(
    n_components=2,
    perplexity=40,
    n_iter=400,
    random_state=42,
    verbose=1,
).fit_transform(data_tsne_odin)

In [None]:
# Store both t-SNE coordinates as columns so seaborn can address them by name.
df_odin['tsne-2d-one'], df_odin['tsne-2d-two'] = tsne_odin[:, 0], tsne_odin[:, 1]

fig, ax = plt.subplots(figsize=(16, 10))
sns.scatterplot(
    data=df_odin,
    x="tsne-2d-one",
    y="tsne-2d-two",
    hue="y",
    palette=sns.color_palette("hls", 3),
    legend="full",
    alpha=0.3,
    ax=ax,
)
ax.set_title('ODIN CheXpert')

In [None]:
# Two panels in one figure: the MCP frame on the left, the ODIN frame on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

# Arguments common to both panels.
scatter_kwargs = dict(
    x="tsne-2d-one",
    y="tsne-2d-two",
    hue="y",
    palette=sns.color_palette("hls", 3),
    legend="full",
    alpha=0.3,
)

sns.scatterplot(data=df, ax=ax1, **scatter_kwargs)
ax1.set_title('MCP CheXpert vs OOD')

sns.scatterplot(data=df_odin, ax=ax2, **scatter_kwargs)
ax2.set_title('ODIN CheXpert vs OOD')

In [None]:

# reusable function
def tsne_method(method='mcp', data='cifar10', subset=False, n_ood_samples=1000,
                seed=42, data_dir='/Users/christophberger/Downloads'):
    """Load regular and OOD feature rows for one method/dataset and embed them with t-SNE.

    Files read from ``data_dir``:
      * ``{method}_sne_{data}.pkl``      - pickled regular feature rows
      * ``mcp_targets_{data}.txt``       - targets for the regular rows (always the
                                           ``mcp_`` file, whatever ``method`` is)
      * ``{method}_sne_ood_{data}.pkl``  - pickled OOD feature rows

    Parameters
    ----------
    method : str
        File-name prefix of the pickles to load (e.g. 'mcp', 'odin').
    data : str
        Dataset tag used in the file names. 'cifar10' labels the OOD rows 10.0;
        any other value labels them 2.0.
    subset : bool
        If True, keep only a random sample of the OOD rows.
    n_ood_samples : int
        Maximum number of OOD rows kept when ``subset`` is True.
    seed : int or None
        Seed for the OOD subsample, so repeated calls pick the same rows.
        Pass None for an unseeded draw.
    data_dir : str
        Directory that holds the input files.

    Returns
    -------
    pandas.DataFrame
        One row per sample with the ``emb_*`` feature columns, the numeric target
        ``y``, its string form ``label``, and the two t-SNE coordinates
        ``tsne-2d-one`` / ``tsne-2d-two``.

    Raises
    ------
    AssertionError
        If the number of regular rows and targets differ.
    """
    # NOTE(review): other cells in this notebook read from '/Users/christoph/Downloads';
    # confirm which directory is the right default.
    # NOTE: pickle.load can execute arbitrary code; only open files you trust.
    with open(os.path.join(data_dir, f'{method}_sne_{data}.pkl'), 'rb') as f:
        features_in = pickle.load(f)
    # Default whitespace splitting already reads a one-value-per-line file;
    # an explicit delimiter='\n' is rejected by NumPy >= 1.23.
    targets_in = np.loadtxt(os.path.join(data_dir, f'mcp_targets_{data}.txt'))

    assert len(features_in) == len(targets_in)

    with open(os.path.join(data_dir, f'{method}_sne_ood_{data}.pkl'), 'rb') as f:
        features_ood = pickle.load(f)

    # The OOD rows get the first target value not used by the regular classes.
    ood_target = 10.0 if data == 'cifar10' else 2.0

    if subset:
        # A local, seeded generator keeps the subsample repeatable without touching
        # the global `random` state. list() lets array-like inputs be sampled too,
        # and the min() avoids a ValueError when fewer rows are available.
        rng = random.Random(seed)
        sample_size = min(n_ood_samples, len(features_ood))
        features_ood = rng.sample(list(features_ood), sample_size)

    cols = [f'emb_{i}' for i in range(len(features_in[0]))]
    df_in = pd.DataFrame(features_in, columns=cols)
    df_in['y'] = targets_in

    df_ood = pd.DataFrame(features_ood, columns=cols)
    df_ood['y'] = ood_target

    # DataFrame.append was removed in pandas 2.0; pd.concat is the supported way
    # to stack frames. ignore_index=True gives the result a unique 0..n-1 index.
    df_all = pd.concat([df_in, df_ood], ignore_index=True)
    df_all['label'] = df_all['y'].apply(str)

    embedding = TSNE(
        random_state=42, n_components=2, verbose=1, perplexity=40, n_iter=400
    ).fit_transform(df_all[cols].values)

    df_all['tsne-2d-one'] = embedding[:, 0]
    df_all['tsne-2d-two'] = embedding[:, 1]
    return df_all


In [None]:
# CIFAR-10 frames for both methods, each with a subsampled OOD set.
df1 = tsne_method(method='mcp', data='cifar10', subset=True)
df2 = tsne_method(method='odin', data='cifar10', subset=True)

In [None]:
# Side-by-side CIFAR-10 panels: df1 on the left, df2 on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

# Legend text in hue order: the ten CIFAR-10 classes followed by the OOD class.
cifar_mapping = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck', 'SVHN (OOD)']

sns.scatterplot(
    x="tsne-2d-one", y="tsne-2d-two",
    hue="y",
    palette=sns.color_palette("hls", 11),
    data=df1,
    legend="full",
    alpha=0.3,
    ax=ax1
).set_title('MCP CIFAR-10 vs OOD')
# Only the handles are needed; indexing avoids overwriting the notebook-level `labels`.
handles = ax1.get_legend_handles_labels()[0]
ax1.legend(handles, cifar_mapping)

sns.scatterplot(
    x="tsne-2d-one", y="tsne-2d-two",
    hue="y",
    palette=sns.color_palette("hls", 11),
    data=df2,
    legend="full",
    alpha=0.3,
    ax=ax2
).set_title('ODIN CIFAR-10 vs OOD')
handles = ax2.get_legend_handles_labels()[0]
ax2.legend(handles, cifar_mapping)

# Save only after both legends are relabelled; saving earlier wrote a file whose
# right-hand legend still showed the raw numeric targets. Both file names are
# kept because the cell has always produced both.
fig.savefig('cifar10_tsne.pdf')
fig.savefig('cifar_tsne.pdf')

In [None]:
# CheXpert frames for both methods (all OOD rows kept).
df1 = tsne_method(data='chexpert')
df2 = tsne_method('odin', data='chexpert')
# CIFAR-10 frames with a subsampled OOD set. The four-panel figure further down
# reads df3 and df4, so they must be defined for a fresh-kernel run to succeed.
df3 = tsne_method(subset=True)
df4 = tsne_method('odin', subset=True)

In [None]:
marker_size = 20
# Legend text in hue order: the two CheXpert classes followed by the OOD class.
chexpert_legend = ['Cardiomegaly', 'Pneumothorax', 'Fracture (OOD)']

# plt.subplots creates the figure itself; an additional plt.figure() call
# would only leave an empty extra figure in the cell output.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(16, 8))

sns.scatterplot(
    x="tsne-2d-one", y="tsne-2d-two",
    hue="y",
    s=marker_size,
    palette=sns.color_palette("hls", 3),
    data=df1,
    legend="full",
    alpha=0.3,
    ax=ax1
).set_title('Baseline CheXpert vs OOD')
# Only the handles are needed; indexing avoids overwriting the notebook-level `labels`.
handles = ax1.get_legend_handles_labels()[0]
ax1.legend(handles, chexpert_legend)

sns.scatterplot(
    x="tsne-2d-one", y="tsne-2d-two",
    hue="y",
    s=marker_size,
    palette=sns.color_palette("hls", 3),
    data=df2,
    legend="full",
    alpha=0.3,
    ax=ax2
).set_title('ODIN CheXpert vs OOD')
handles = ax2.get_legend_handles_labels()[0]
ax2.legend(handles, chexpert_legend)

fig.savefig('chexpert_tsne.pdf')

In [None]:
marker_size = 20

# Set the font size BEFORE the figure is created: rcParams are read when the
# artists are built, so updating them after plotting only affected later runs.
plt.rcParams.update({'font.size': 18})

fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=1, ncols=4, sharex=True, sharey=True, figsize=(24, 6))
fig.subplots_adjust(wspace=0.025, hspace=0.05)

# (axes, frame, number of hue colours, extra scatter arguments, title), left to right.
# df3/df4 fill the two CIFAR-10 panels and df1/df2 the two CheXpert panels;
# only the CheXpert panels use the explicit marker size.
panels = [
    (ax1, df3, 11, {}, 'Baseline CIFAR-10'),
    (ax2, df4, 11, {}, 'ODIN CIFAR-10'),
    (ax3, df1, 3, {'s': marker_size}, 'Baseline CheXpert'),
    (ax4, df2, 3, {'s': marker_size}, 'ODIN CheXpert'),
]

for ax, frame, n_colors, extra_kwargs, title in panels:
    sns.scatterplot(
        x="tsne-2d-one", y="tsne-2d-two",
        hue="y",
        palette=sns.color_palette("hls", n_colors),
        data=frame,
        legend=False,
        alpha=0.3,
        ax=ax,
        **extra_kwargs
    ).set_title(title, fontsize=20)
    # The t-SNE axes carry no meaningful units, so the axis labels are blanked.
    ax.set_ylabel('')
    ax.set_xlabel('')

# NOTE(review): this writes to the same file name as the two-panel CheXpert
# figure and therefore overwrites it; confirm that is intended.
fig.savefig('chexpert_tsne.pdf')