In [39]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [40]:
!pwd

/home/walml/repos/zoobot


In [41]:
!git pull

Already up-to-date.


In [42]:
import matplotlib

In [43]:
import os
import logging
import argparse
import glob

import numpy as np
from matplotlib.ticker import StrMethodFormatter

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn import metrics
import tensorflow as tf
import pandas as pd
from astropy.table import Table  # for NSA
from astropy import units as u
from sklearn.metrics import confusion_matrix, roc_curve
from PIL import Image
from scipy.stats import binom
from IPython.display import display, Markdown

from shared_astro_utils import astropy_utils, matching_utils
from zoobot.estimators import make_predictions, bayesian_estimator_funcs
from zoobot.tfrecord import read_tfrecord
from zoobot.uncertainty import discrete_coverage
from zoobot.estimators import input_utils, losses
from zoobot.tfrecord import catalog_to_tfrecord
from zoobot.active_learning import metrics, simulated_metrics, acquisition_utils, check_uncertainty, simulation_timeline, default_estimator_params


In [44]:
# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
# Nothing happens when no GPU is found (the list is empty).
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)


In [45]:
os.chdir('/home/walml/repos/zoobot')

In [77]:
start = 50
end = 60

### Load the (latest) model under `model_name` folder in `results_dir`

In [78]:
from collections import Counter

In [79]:
# catalog_loc = 'data/latest_labelled_catalog.csv
catalog_loc = 'data/decals/decals_master_catalog.csv'
catalog = pd.read_csv(catalog_loc, dtype={'subject_id': str})  # original catalog
# catalog = catalog[:5000]  
catalog['file_loc'] = catalog['local_png_loc'].apply(lambda x: '/media/walml/beta/decals/png_native' + x[32:])

model_name = 'latest_offline'

In [80]:


# Figures will be saved to here
analysis_dir = 'analysis/multiquestion'
save_dir = f'{analysis_dir}/{model_name}'
# makedirs (rather than mkdir) also creates a missing `analysis_dir` parent,
# and exist_ok=True makes the cell safe to re-run.
os.makedirs(save_dir, exist_ok=True)

results_dir = 'results'

# Answer columns, in the order the model outputs them.
label_cols = [
    'smooth-or-featured_smooth',
    'smooth-or-featured_featured-or-disk',
    'has-spiral-arms_yes',
    'has-spiral-arms_no',
    'spiral-winding_tight',
    'spiral-winding_medium',
    'spiral-winding_loose',
    'bar_strong',
    'bar_weak',
    'bar_no',
    'bulge-size_dominant',
    'bulge-size_large',
    'bulge-size_moderate',
    'bulge-size_small',
    'bulge-size_none'
]

# Question prefixes; every entry of `label_cols` starts with one of these.
questions = [
    'smooth-or-featured',
    'has-spiral-arms',
    'spiral-winding',
    'bar',
    'bulge-size'
]


batch_size = 32
initial_size = 256
final_size = 64
channels = 3

# Number of forward passes stacked per galaxy when predicting.
n_samples = 5

# if loading single test tfrecord
# tfrecord_locs = [f'data/decals/shards/multilabel_{img_size}/eval/s{initial_size}_shard_0.tfrecord']
# tfrecord_locs = ['data/decals/shards/multilabel_master_256/train/s256_shard_0.tfrecord']
# NOTE(review): glob.glob returns paths in arbitrary (filesystem) order, so
# the [start:end] slice is only reproducible on the same machine/directory.
# Sorting would make it deterministic but would change which shards an
# existing (start, end) range refers to - confirm before changing.
tfrecord_locs = glob.glob(f'/media/walml/beta/decals/multilabel_master_{initial_size}/train/*.tfrecord')[start:end]
# Fail here, with a clear message, rather than deep inside the input pipeline.
assert tfrecord_locs, f'No tfrecords found for shards [{start}:{end}] - is the data drive mounted?'

# tfrecord_locs = ['data/decals/shards/multilabel_all_temp/train/s128_shard_0.tfrecord']
print(tfrecord_locs)
eval_config = default_estimator_params.get_eval_config(tfrecord_locs, label_cols, batch_size, initial_size, final_size, channels)
dataset = input_utils.get_input(config=eval_config)

# Second pass over the same shards, reading only the galaxy identifiers,
# unshuffled, so that they can be paired with the predictions.
# NOTE(review): this assumes `dataset` yields galaxies in the same order - confirm.
feature_spec = input_utils.get_feature_spec({'id_str': 'string'})
id_str_dataset = input_utils.get_dataset(tfrecord_locs, feature_spec, batch_size=1, shuffle=False, repeat=False)
# str() of the bytes value looks like "b'J1205...'"; [2:-1] strips the b'' wrapper.
id_strs = [str(d['id_str'].numpy().squeeze())[2:-1] for d in id_str_dataset]
id_strs[:5]

# n = 0
# for batch in id_str_dataset:
#     for id_str in batch:
#         n+=1
# print(n)

# counter = Counter()
# n = 0
# for g_batch, y_batch in dataset:
# #     for g in g_batch:
# #         counter[g.numpy().sum()] += 1
#         n+=tf.shape(g_batch)[0]
# print(n)



['/media/walml/beta/decals/multilabel_master_256/train/s256_shard_50.tfrecord']
loading filenames: <TensorSliceDataset shapes: (), types: tf.string>
loading filenames: <TensorSliceDataset shapes: (), types: tf.string>


['J120546.54+055750.8',
 'J134127.12+105353.5',
 'J124108.42+135738.7',
 'J103232.48+143246.7',
 'J212033.56+105801.3']

In [81]:
len(id_strs)

2415

In [82]:
# # if loading png

# print(catalog['file_loc'])
# assert all([os.path.isfile(x) for x in catalog['file_loc']])
# filenames = tf.constant(list(catalog['file_loc']), dtype=tf.string)
# dataset = tf.data.Dataset.from_tensor_slices(filenames)

# def parse_image(im):
#     im = tf.image.decode_png(im, channels=channels)
#     im = tf.image.convert_image_dtype(im, tf.float32)
#     im = tf.image.resize(im, [initial_size, initial_size])
#     return im


# # for im in dataset.take(1):
# #     print(im)

# # assert False

# dataset = dataset.map(tf.io.read_file)
# dataset = dataset.map(parse_image)

# config = default_estimator_params.get_eval_config(['do not use'], label_cols, batch_size, initial_size, final_size, channels)

# dataset = dataset.batch(batch_size)

# # for batch in dataset.take(1):
# #     print(tf.shape(batch))
# # #     plt.imshow(batch[0])


# # dataset = dataset.map(lambda x: check_shape(x))

# dataset = dataset.map(lambda x: input_utils.preprocess_images(x, config))


# # for batch in dataset.take(1):
# #     print(tf.shape(batch))
# #     plt.imshow(batch[0].numpy().squeeze())


# id_strs = catalog['iauname']



In [83]:
# Build the model for the answers in `label_cols`, grouped into `questions`,
# at the final (post-preprocessing) image size.
model = default_estimator_params.get_model(label_cols, questions, final_size)

# NOTE(review): despite the name, this string is handed straight to
# load_weights as the checkpoint location - confirm it is the checkpoint
# prefix rather than a directory that merely contains checkpoints.
checkpoint_dir = f'{results_dir}/{model_name}/results/models'
model.load_weights(checkpoint_dir)



{smooth-or-featured, indices 0 to 1, asked after None: (0, 1), has-spiral-arms, indices 2 to 3, asked after <zoobot.estimators.losses.Answer object at 0x7f4fb01f5e10>: (2, 3), spiral-winding, indices 4 to 6, asked after <zoobot.estimators.losses.Answer object at 0x7f4fa5b8efd0>: (4, 6), bar, indices 7 to 9, asked after <zoobot.estimators.losses.Answer object at 0x7f4fb01f5e10>: (7, 9), bulge-size, indices 10 to 14, asked after <zoobot.estimators.losses.Answer object at 0x7f4fb01f5e10>: (10, 14)}
Name: smooth-or-featured, start 0, end 1
Name: has-spiral-arms, start 2, end 3
Name: spiral-winding, start 4, end 6
Name: bar, start 7, end 9
Name: bulge-size, start 10, end 14


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f4fb0e7b9d0>

In [84]:
# %timeit predictions = model.predict(dataset)

In [85]:
dataset

<MapDataset shapes: ((None, 64, 64, 1), (None, 15)), types: (tf.float32, tf.float32)>

In [86]:
# model.predict(dataset).shape

In [87]:
predictions = np.stack([model.predict(dataset) for n in range(n_samples)], axis=-1)

[(0, 1), (2, 3), (4, 6), (7, 9), (10, 14)]
0 1
2 3
4 6
7 9
10 14


In [88]:
predictions.shape

(2415, 15, 5)

In [89]:
# for batch_x, batch_y in dataset:
#     print(batch_y.numpy())
#     break

In [90]:
# labels = np.concatenate([batch_y for (_, batch_y) in test_dataset], axis=0)
# labels.shape

In [91]:
labels = catalog[label_cols].values

In [92]:
# fig, ax = plt.subplots()
# ax.hist(predictions[:, 0], alpha=0.5, label='Predictions', density=True)
# ax.hist(labels[:, 0] / labels[:, :2].sum(axis=1), alpha=0.5, label='Labels', density=True)
# plt.legend()

In [93]:
# print(predictions[:, 0].min(), labels[:, 0].min())
# print(predictions[:, 0].max(), labels[:, 0].max())

In [94]:
schema = losses.Schema(label_cols, questions)
schema.questions

{smooth-or-featured, indices 0 to 1, asked after None: (0, 1), has-spiral-arms, indices 2 to 3, asked after <zoobot.estimators.losses.Answer object at 0x7f4fc49d7ed0>: (2, 3), spiral-winding, indices 4 to 6, asked after <zoobot.estimators.losses.Answer object at 0x7f4fc49d7e50>: (4, 6), bar, indices 7 to 9, asked after <zoobot.estimators.losses.Answer object at 0x7f4fc49d7ed0>: (7, 9), bulge-size, indices 10 to 14, asked after <zoobot.estimators.losses.Answer object at 0x7f4fc49d7ed0>: (10, 14)}


[smooth-or-featured, indices 0 to 1, asked after None,
 has-spiral-arms, indices 2 to 3, asked after <zoobot.estimators.losses.Answer object at 0x7f4fc49d7ed0>,
 spiral-winding, indices 4 to 6, asked after <zoobot.estimators.losses.Answer object at 0x7f4fc49d7e50>,
 bar, indices 7 to 9, asked after <zoobot.estimators.losses.Answer object at 0x7f4fc49d7ed0>,
 bulge-size, indices 10 to 14, asked after <zoobot.estimators.losses.Answer object at 0x7f4fc49d7ed0>]

In [95]:
# acquisitions = acquisition_utils.mutual_info_acquisition_func_multiq(predictions, schema, retirement=40)
# acquisitions.shape
# acquisitions

# single_q_acquisitions = np.array(acquisition_utils.mutual_info_acquisition_func(predictions[:, 0], expected_votes=40))
# single_q_acquisitions[:5], acquisitions[0, :5], acquisitions[1, :5]  # smooth mutual acq should be identical, for both answers by symmmetry
# acquisitions[2, :5], acquisitions[3, :5]  # has-spiral-arms also by symmetry
# acquisitions[4, :5], acquisitions[5, :5], acquisitions[6, :5]  # but for spiral winding there are 3 answers so *not* identical
# predictions.shape, acquisitions.shape, len(id_strs)

In [96]:
predictions.shape, len(id_strs)

((2415, 15, 5), 2415)

In [97]:
def prediction_to_row(prediction, id_str):
    """Build one output-table row (a dict) for a single galaxy.

    `prediction` is indexed by answer position, in the same order as the
    notebook-level `label_cols`. For every answer the row stores:

    - '<answer>_prediction': the raw entry `prediction[position]`
    - '<answer>_prediction_mean': the mean of that entry, as a Python float

    The galaxy identifier is stored under 'id_str'.
    """
    row = {'id_str': id_str}
    for position in range(len(label_cols)):
        answer_name = label_cols[position]
        answer_samples = prediction[position]
        row.update({
            answer_name + '_prediction': answer_samples,
            answer_name + '_prediction_mean': float(answer_samples.mean()),
        })
    return row

In [98]:
def all_to_row(prediction, acquisition, id_str):
    """Build one output-table row (a dict) for a single galaxy.

    Produces the same keys as `prediction_to_row`: 'id_str', then for each
    answer in the notebook-level `label_cols` (by position) the raw
    '<answer>_prediction' entry and its float '<answer>_prediction_mean'.

    `acquisition` is accepted but currently unused: the acquisition columns
    it once fed are disabled.
    """
    entries = [('id_str', id_str)]
    for n, answer in enumerate(label_cols):
        entries.append((answer + '_prediction', prediction[n]))
        entries.append((answer + '_prediction_mean', float(prediction[n].mean())))
    return dict(entries)

In [99]:
data = [prediction_to_row(predictions[n], id_strs[n]) for n in range(len(predictions))]
# data = [all_to_row(predictions[n], acquisitions[n], id_strs[n]) for n in range(len(predictions))]
predictions_df = pd.DataFrame(data)

In [100]:
len(predictions_df)

2415

In [101]:
predictions_df.head()

Unnamed: 0,id_str,smooth-or-featured_smooth_prediction,smooth-or-featured_smooth_prediction_mean,smooth-or-featured_featured-or-disk_prediction,smooth-or-featured_featured-or-disk_prediction_mean,has-spiral-arms_yes_prediction,has-spiral-arms_yes_prediction_mean,has-spiral-arms_no_prediction,has-spiral-arms_no_prediction_mean,spiral-winding_tight_prediction,...,bulge-size_dominant_prediction,bulge-size_dominant_prediction_mean,bulge-size_large_prediction,bulge-size_large_prediction_mean,bulge-size_moderate_prediction,bulge-size_moderate_prediction_mean,bulge-size_small_prediction,bulge-size_small_prediction_mean,bulge-size_none_prediction,bulge-size_none_prediction_mean
0,J120546.54+055750.8,"[0.30420297, 0.29689792, 0.1815976, 0.27818456...",0.252087,"[0.6957971, 0.7031021, 0.81840235, 0.7218154, ...",0.747913,"[0.77654797, 0.68017673, 0.9010634, 0.79541725...",0.805829,"[0.22345208, 0.31982327, 0.09893663, 0.2045827...",0.194171,"[0.36976084, 0.5811233, 0.54553777, 0.5828835,...",...,"[0.0246401, 0.02125156, 0.01801422, 0.02128763...",0.020114,"[0.06994944, 0.04650045, 0.053316157, 0.038661...",0.049567,"[0.5141608, 0.43578383, 0.4751809, 0.31585756,...",0.435906,"[0.36725083, 0.46015358, 0.4309142, 0.552231, ...",0.457002,"[0.023998888, 0.036310542, 0.022574492, 0.0719...",0.037411
1,J134127.12+105353.5,"[0.66698474, 0.6663146, 0.662525, 0.7146262, 0...",0.690463,"[0.33301523, 0.33368537, 0.33747497, 0.2853738...",0.309537,"[0.23337178, 0.30273125, 0.49857944, 0.2058643...",0.282494,"[0.76662827, 0.6972688, 0.50142056, 0.79413563...",0.717506,"[0.44853082, 0.250708, 0.31300673, 0.3163705, ...",...,"[0.07231821, 0.0875522, 0.0803811, 0.049855016...",0.072807,"[0.22716962, 0.28933305, 0.2419108, 0.15618472...",0.219773,"[0.5412547, 0.47177008, 0.51008976, 0.55140704...",0.52838,"[0.14002505, 0.121737115, 0.14817934, 0.219411...",0.156028,"[0.019232363, 0.029607613, 0.019438962, 0.0231...",0.023013
2,J124108.42+135738.7,"[0.5939548, 0.5553277, 0.63378006, 0.59075475,...",0.595311,"[0.40604523, 0.4446723, 0.36621988, 0.40924525...",0.404689,"[0.31524974, 0.50869465, 0.34540093, 0.4710049...",0.405067,"[0.6847502, 0.4913053, 0.6545991, 0.528995, 0....",0.594933,"[0.3221565, 0.32892954, 0.34041446, 0.26781067...",...,"[0.060038064, 0.044161692, 0.06280769, 0.06437...",0.057201,"[0.20401071, 0.16899005, 0.19661307, 0.2424580...",0.209927,"[0.5576686, 0.5837672, 0.56842136, 0.5288443, ...",0.544825,"[0.15632583, 0.18357164, 0.1519133, 0.14060122...",0.162504,"[0.021956755, 0.019509498, 0.020244569, 0.0237...",0.025542
3,J103232.48+143246.7,"[0.24313717, 0.296936, 0.2039256, 0.21983513, ...",0.21745,"[0.7568629, 0.70306396, 0.79607445, 0.78016484...",0.782551,"[0.91874164, 0.8905016, 0.91629034, 0.94856566...",0.928969,"[0.0812584, 0.10949836, 0.08370964, 0.05143435...",0.071031,"[0.48085463, 0.47088748, 0.47635913, 0.4744492...",...,"[0.025323467, 0.02414925, 0.029758457, 0.01648...",0.02247,"[0.056805905, 0.09140686, 0.063217394, 0.04846...",0.057665,"[0.37605965, 0.3769583, 0.34806174, 0.39062214...",0.357599,"[0.5147198, 0.43075505, 0.5246157, 0.5111978, ...",0.523047,"[0.027091227, 0.07673053, 0.03434673, 0.033224...",0.039219
4,J212033.56+105801.3,"[0.6517667, 0.73542684, 0.7160145, 0.7482792, ...",0.708516,"[0.3482333, 0.26457322, 0.2839855, 0.25172082,...",0.291484,"[0.08934232, 0.20825186, 0.117199324, 0.059358...",0.16087,"[0.9106577, 0.79174817, 0.8828007, 0.94064116,...",0.83913,"[0.24652085, 0.16471621, 0.12415182, 0.1637303...",...,"[0.118599854, 0.08799145, 0.14941905, 0.127707...",0.117834,"[0.4488062, 0.35998735, 0.30993697, 0.36788717...",0.372014,"[0.38391578, 0.49395758, 0.49997556, 0.4455365...",0.458517,"[0.036783103, 0.04825744, 0.03501271, 0.049017...",0.04263,"[0.011895055, 0.009806191, 0.0056557204, 0.009...",0.009005


In [102]:
# The catalog's name column is 'iauname'; copy 'id_str' into a column of
# that name so the two frames can be joined on it.
predictions_df['iauname'] = predictions_df['id_str']
df = pd.merge(catalog, predictions_df, how='inner', on='iauname')
len(df)
# Guard against the inner join changing the row count: ids missing from the
# catalog would drop rows, duplicated catalog names would add them.
assert len(df) == len(predictions_df)

In [103]:
df.columns.values

array(['Unnamed: 0', 'iauname', 'nsa_id', 'ra_subject', 'dec_subject',
       'petrotheta', 'petroth50', 'petroth90', 'redshift',
       'local_fits_loc', 'local_png_loc', 'fits_ready', 'fits_filled',
       'png_ready', 'best_match', 'sky_separation', 'nsa_version', 'mag',
       'ra', 'dec', 'file_loc', 'subject_id', 'bar_no', 'bar_strong',
       'bar_weak', 'bulge-size_dominant', 'bulge-size_large',
       'bulge-size_moderate', 'bulge-size_none', 'bulge-size_small',
       'disk-edge-on_no', 'disk-edge-on_yes', 'edge-on-bulge_boxy',
       'edge-on-bulge_none', 'edge-on-bulge_rounded',
       'has-spiral-arms_no', 'has-spiral-arms_yes',
       'how-rounded_cigar-shaped', 'how-rounded_in-between',
       'how-rounded_round', 'merging_both-v1',
       'merging_major-disturbance', 'merging_merger',
       'merging_minor-disturbance', 'merging_neither-v1', 'merging_none',
       'merging_tidal-debris-v1', 'smooth-or-featured_artifact',
       'smooth-or-featured_featured-or-disk', 'sm

In [104]:
df['subject_url']

0       https://panoptes-uploads.zooniverse.org/produc...
1       https://panoptes-uploads.zooniverse.org/produc...
2       https://panoptes-uploads.zooniverse.org/produc...
3       https://panoptes-uploads.zooniverse.org/produc...
4       https://panoptes-uploads.zooniverse.org/produc...
                              ...                        
2410                                                  NaN
2411                                                  NaN
2412                                                  NaN
2413                                                  NaN
2414                                                  NaN
Name: subject_url, Length: 2415, dtype: object

In [105]:
df['file_loc']

0       /media/walml/beta/decals/png_native/dr5/J114/J...
1       /media/walml/beta/decals/png_native/dr5/J114/J...
2       /media/walml/beta/decals/png_native/dr5/J115/J...
3       /media/walml/beta/decals/png_native/dr5/J120/J...
4       /media/walml/beta/decals/png_native/dr5/J115/J...
                              ...                        
2410    /media/walml/beta/decals/png_native/dr5/J104/J...
2411    /media/walml/beta/decals/png_native/dr5/J105/J...
2412    /media/walml/beta/decals/png_native/dr5/J140/J...
2413    /media/walml/beta/decals/png_native/dr5/J142/J...
2414    /media/walml/beta/decals/png_native/dr5/J143/J...
Name: file_loc, Length: 2415, dtype: object

In [106]:
df.to_csv('temp/master_256_predictions_{}_{}.csv'.format(start, end), index=False)

In [107]:
assert False

AssertionError: 

In [None]:
# copied from trust_the_model.ipynb
def show_galaxies(df, scale=3, nrows=3, ncols=3):
    """Show the first galaxies of `df` as a gap-free image grid.

    Args:
        df: DataFrame with a 'file_loc' column of image paths. Rows are
            taken in positional order, filling the grid row by row.
        scale: size of each grid cell in inches.
        nrows: number of grid rows.
        ncols: number of grid columns.

    Returns:
        The new matplotlib figure, which is also the current pyplot figure.
        If `df` has fewer than nrows * ncols rows, the remaining cells are
        left empty instead of raising IndexError.
    """
    # figsize is (width, height): width follows columns, height follows rows.
    # The 1.025 factor is the original height fudge; square grids are unchanged.
    fig = plt.figure(figsize=(scale * ncols, scale * nrows * 1.025))
    grid = gridspec.GridSpec(nrows, ncols)
    grid.update(wspace=0.0, hspace=0.0)

    n_to_show = min(len(df), nrows * ncols)
    for galaxy_n in range(n_to_show):
        row_n, col_n = divmod(galaxy_n, ncols)
        galaxy = df.iloc[galaxy_n]
        ax = fig.add_subplot(grid[row_n, col_n])
        # imshow copies the pixel data, so the file can be closed right after.
        with Image.open(galaxy['file_loc']) as image:
            ax.imshow(image)
        ax.axis('off')
    return fig


In [None]:


def save_top_n(schema, save_dir):
    """Save, for every answer in `schema`, a grid of its top-ranked galaxies.

    Galaxies come from the notebook-level `df`, ranked by the column
    '<answer>_prediction_mean' (highest first); the first nine are drawn
    with `show_galaxies`. Each figure is written to
    save_dir + '_' + <answer> + '.png' (plain string concatenation, so
    `save_dir` acts as a filename prefix) and then closed.
    """
    every_answer = (answer for question in schema.questions for answer in question.answers)
    for answer in every_answer:
        ranking_col = answer.text + '_prediction_mean'
        top_nine = df.sort_values(ranking_col, ascending=False)[:9]
        fig = show_galaxies(top_nine)
        fig.savefig(save_dir + '_' + answer.text + '.png')
        plt.close()

In [None]:
save_dir = 'results/temp'
save_top_n(schema, save_dir)

In [None]:
assert False  # TODO: move later analysis elsewhere? (deliberately stops "Run All" here)

In [None]:
# fig, ax = plt.subplots()
# ax.scatter(predictions[:, 0], labels[:, 0] / labels[:, :2].sum(axis=1))
# plt.close()

In [None]:
matplotlib.get_backend()

In [None]:
# fig, ax = plt.subplots()
# ax.scatter(predictions[:, 4], labels[:, 4] / labels[:, 4:7].sum(axis=1))

In [None]:
# fig, ax = plt.subplots()
# ax.scatter(predictions[:, 5], labels[:, 5] / labels[:, 4:7].sum(axis=1))

In [None]:
# fig, ax = plt.subplots()
# ax.scatter(predictions[:, 6], labels[:, 6] / labels[:, 4:7].sum(axis=1))

In [None]:
# assert False

In [None]:
# to check that the right models have been loaded - should be around 40 for smooth, 0-40 for bars
# plt.hist(sim_model.total_votes), sim_model.total_votes.mean()

### Visualise Posteriors

Check for a systematic offset: in general, does the model seem slightly skewed towards low $k$?

In [None]:

def get_single_answer_data(df, answer, n=10):
    """Pull the model samples and volunteer votes for one answer.

    Reads the first `n` rows of `df` from three columns:
    '<answer.text>_prediction' (stacked into one array), '<answer.text>'
    (votes for this answer) and '<answer.question.text>_total-votes'
    (votes for the parent question). Both vote columns are cast to int.

    Returns:
        (samples, labels, total_votes)
    """
    first_n = slice(None, n)
    prediction_col = f'{answer.text}_prediction'
    votes_col = f'{answer.question.text}_total-votes'

    samples = np.stack(df[prediction_col][first_n])
    labels = df[answer.text][first_n].values.astype(int)
    total_votes = df[votes_col][first_n].values.astype(int)
    return samples, labels, total_votes
    

In [None]:
def show_population_stats(samples, labels, total_votes, title):
    """Compare model and volunteer distributions over a population of galaxies.

    Draws two stacked histograms: vote fraction (top) and number of
    positive responses k (bottom), each overlaying 'Model' and 'Actual'.

    Args:
        samples: model samples per galaxy; the last axis is averaged to get
            each galaxy's predicted vote fraction
            (assumed shape (galaxy, sample) - TODO confirm).
        labels: observed positive votes k per galaxy.
        total_votes: observed total votes N per galaxy.
        title: title for the top panel.

    Returns:
        The matplotlib figure. It is not saved here; that is left to the caller.

    Galaxies with zero total votes are excluded from both the 'Model' and
    'Actual' histograms: their observed vote fraction is 0/0 = NaN, which
    makes the histogram call raise.
    """
    sns.set(font_scale=1.)
    sns.set_style('white')
    alpha = 0.5

    samples = np.asarray(samples)
    labels = np.asarray(labels)
    total_votes = np.asarray(total_votes)

    # Keep only galaxies where the question was actually answered.
    asked = total_votes > 0
    samples, labels, total_votes = samples[asked], labels[asked], total_votes[asked]

    mean_samples = samples.mean(axis=-1)
    expected_k = mean_samples * total_votes
    vote_fractions = labels / total_votes

    # Bin edges come from the observed data, so model and actual share bins.
    # Computing them directly avoids drawing (and closing) a throwaway figure.
    bins_rho = np.histogram_bin_edges(vote_fractions, bins=25)
    bins_k = np.histogram_bin_edges(labels, bins=25)

    fig, (ax0, ax1) = plt.subplots(nrows=2)

    ax0.hist(mean_samples, bins=bins_rho, alpha=alpha, label='Model')
    ax0.hist(vote_fractions, bins=bins_rho, alpha=alpha, label='Actual')
    ax0.set_xlabel(r'Vote Fraction $\rho$')
    ax0.set_ylabel('Galaxies')
    ax0.legend()
    ax0.set_xlim([0., 1.])

    ax1.hist(expected_k, bins=bins_k, alpha=alpha, label='Model')
    ax1.hist(labels, bins=bins_k, alpha=alpha, label='Actual')
    ax1.set_xlabel(r'Positive Responses $k$')
    ax1.set_ylabel('Galaxies')
    ax1.legend()
    ax1.set_xlim([0, 40])

    ax0.set_title(title)
    fig.tight_layout()
    return fig

In [None]:

answer = schema.get_answer('smooth-or-featured_smooth')
samples, labels, total_votes = get_single_answer_data(df, answer, n=len(df))
_ = show_population_stats(samples, labels, total_votes, answer.text)

In [None]:
# Save one population-level model-vs-actual figure per answer in the schema,
# using every row of `df`.
# NOTE(review): figures are not closed after saving, so all of them stay
# open until the kernel is restarted - consider plt.close(fig) if memory
# becomes a problem.
for answer in schema.answers:
    samples, labels, total_votes = get_single_answer_data(df, answer, n=len(df))
    fig = show_population_stats(samples, labels, total_votes, answer.text)
    fig.savefig(save_dir + '/population_distribution_' + answer.text + '.png')

In [None]:
def custom_samples(samples, labels, total_votes):
    """Plot the posterior samples of several galaxies, one row per galaxy.

    Args:
        samples: model samples, passed straight to
            `make_predictions.plot_samples`.
        labels: observed votes per galaxy; its length sets the number of rows.
        total_votes: observed total votes per galaxy.

    Returns:
        The matplotlib figure, with a legend above the first row.

    Works for a single galaxy too: the axes are always requested as a 2D
    grid, because `plt.subplots` otherwise returns a bare Axes for one row,
    which cannot be iterated or indexed.
    """
    sns.set_context('paper', font_scale=1.5)
    n_galaxies = len(labels)
    fig, axes_grid = plt.subplots(
        nrows=n_galaxies,
        figsize=(3, n_galaxies*1.5),
        sharex=True,
        squeeze=False
    )
    axes = axes_grid[:, 0]  # single column -> 1D array of Axes, one per galaxy

    make_predictions.plot_samples(samples, labels, total_votes, fig, axes, alpha=0.06)
    for ax in axes:
        ax.set_xlim([0, 50])
        ax.set_ylabel(r'$p(v|w)$', visible=True)
        ax.yaxis.set_visible(True)

    axes[-1].set_xlabel('Volunteer Votes')
    # Lay out once before placing the legend, then again to make room for it.
    fig.tight_layout()

    axes[0].legend(
        loc='lower center',
        bbox_to_anchor=(0.5, 1.1),
        ncol=1,
        fancybox=True,
        shadow=False
    )
    fig.tight_layout()
    return fig

In [None]:

# Posterior samples for the first `n` galaxies, for the 'smooth' answer.
# get_single_answer_data takes (df, answer object, n); passing the question
# and answer as plain strings raises AttributeError on `answer.text`.
answer = schema.get_answer('smooth-or-featured_smooth')
n = 5
samples, labels, total_votes = get_single_answer_data(df, answer, n=n)
_ = custom_samples(samples, labels, total_votes)

In [None]:

# Posterior samples for the first `n` galaxies, for the 'has spiral arms: yes' answer.
# get_single_answer_data takes (df, answer object, n); passing the question
# and answer as plain strings raises AttributeError on `answer.text`.
answer = schema.get_answer('has-spiral-arms_yes')
n = 5
samples, labels, total_votes = get_single_answer_data(df, answer, n=n)
_ = custom_samples(samples, labels, total_votes)

In [None]:

def galaxy_posterior_grid(df, schema):
    """Draw a grid with one row per galaxy: its image, then one posterior per answer.

    Args:
        df: DataFrame of galaxies to show (every row is drawn). It needs a
            'file_loc' column plus the columns read by
            `get_single_answer_data`. It may be a slice or a sorted frame:
            rows are read by position, not by index label.
        schema: object whose `.answers` lists the answers to plot; each
            answer's `.text` is used as the column title.

    Returns:
        The matplotlib figure.
    """
    sns.set_context('paper', font_scale=1.5)

    scale = 1.5

    # Grid units: each image is im_width x height cells, each posterior
    # is posterior_width x height cells.
    im_width = 2
    posterior_width = 3
    height = im_width

    n_galaxies = len(df)
    answers = schema.answers
    total_width = im_width + posterior_width * len(answers)

    fig = plt.figure(figsize=(scale * total_width, scale * n_galaxies * height))  # width, height format
    gs = gridspec.GridSpec(n_galaxies * height, total_width)  # y, x format
    image_axes = []
    posterior_axes = []  # (galaxy i.e. row, answer) shape

    # create the grid
    for galaxy_n in range(n_galaxies):
        y_slice = slice(galaxy_n * height, (galaxy_n + 1) * height)
        image_axes.append(plt.subplot(gs[y_slice, :im_width]))

        galaxy_axes = []
        for answer_n in range(len(answers)):
            left = im_width + answer_n * posterior_width
            galaxy_axes.append(plt.subplot(gs[y_slice, left:left + posterior_width]))
        posterior_axes.append(galaxy_axes)

    # fill the images. Read paths by position: df['file_loc'][n] would look up
    # the index *label* n, which fails (or picks the wrong galaxy) whenever
    # df does not have a fresh 0..N-1 index.
    for file_loc, ax in zip(df['file_loc'].values, image_axes):
        plot_galaxy(file_loc, ax)

    # fill the posteriors. n must cover every row, or frames longer than the
    # helper's default of 10 would get fewer samples than axes.
    for answer_n, answer in enumerate(answers):
        samples, labels, total_votes = get_single_answer_data(df, answer, n=n_galaxies)
        answer_axes = [axes[answer_n] for axes in posterior_axes]
        make_predictions.plot_samples(samples, labels, total_votes, fig, answer_axes, alpha=0.06)

    # fix x limits for comparison, and title the top row with the answer names
    for row_n, axes in enumerate(posterior_axes):
        for answer_n, ax in enumerate(axes):
            ax.set_xlim([0, 50])
            if row_n == 0:
                ax.set_title(answers[answer_n].text)

    fig.tight_layout()
    return fig

In [None]:
fig = galaxy_posterior_grid(df[:5], schema)
fig.savefig(save_dir + '/grid.pdf')
fig.savefig(save_dir + '/grid.png')

In [None]:

def custom_samples_with_galaxies(samples, labels, total_votes, png_locs):
    """Plot, per galaxy, its image next to single-sample and all-sample posteriors.

    Each row holds three panels: the image read from `png_locs[row]`, the
    posterior drawn from only the first sample (`samples[:, :1]`), and the
    posterior drawn from all of `samples`. Both posterior columns get a
    legend above their first row, and x tick marks only on the last row.

    Args:
        samples: model samples, indexed as [galaxy, sample].
        labels: observed votes per galaxy; its length sets the number of rows.
        total_votes: observed total votes per galaxy.
        png_locs: image paths, indexed by row number.

    Returns:
        The matplotlib figure.
    """
    sns.set_context('paper', font_scale=1.5)

    # Panel sizes in grid cells.
    im_width, single_width, multiple_width = 2, 3, 3
    height = im_width
    n_galaxies = len(labels)
    total_width = im_width + single_width + multiple_width

    fig = plt.figure(figsize=(0.8 * n_galaxies * height * 2., 0.8 * total_width * 1.75))
    gs = gridspec.GridSpec(n_galaxies * height, total_width)  # y, x format

    single_cols = slice(im_width, im_width + single_width)
    image_axes, single_axes, multiple_axes = [], [], []
    for galaxy_n in range(n_galaxies):
        rows = slice(galaxy_n * height, (galaxy_n + 1) * height)
        image_axes.append(plt.subplot(gs[rows, :im_width]))
        single_axes.append(plt.subplot(gs[rows, single_cols]))
        multiple_axes.append(plt.subplot(gs[rows, im_width + single_width:]))

    # Same plotting call for both posterior columns; only the samples differ.
    for axes_column, column_samples in ((single_axes, samples[:, :1]), (multiple_axes, samples)):
        make_predictions.plot_samples(column_samples, labels, total_votes, fig, axes_column, alpha=0.06)
        for ax in axes_column:
            ax.set_xlim([0, 50])

    for ax_n, image_ax in enumerate(image_axes):
        plot_galaxy(png_locs[ax_n], image_ax)

    last_row = n_galaxies - 1
    for n, (single_ax, multiple_ax) in enumerate(zip(single_axes, multiple_axes)):
        multiple_ax.set_ylabel(r'$p(k|N, D)$', visible=True)
        multiple_ax.yaxis.set_visible(True)
        single_ax.set_ylabel(r'$p(k|N, w)$', visible=True)
        single_ax.yaxis.set_visible(True)
        # y labels are shown, but without tick marks
        single_ax.yaxis.set_major_locator(plt.NullLocator())
        multiple_ax.yaxis.set_major_locator(plt.NullLocator())
        if n != last_row:
            # only the bottom row keeps its x ticks
            single_ax.xaxis.set_major_locator(plt.NullLocator())
            multiple_ax.xaxis.set_major_locator(plt.NullLocator())

    fig.tight_layout()

    legend_style = dict(
        loc='lower center',
        bbox_to_anchor=(0.5, 1.1),
        ncol=1,
        fancybox=True,
        shadow=False
    )
    single_axes[0].legend(**legend_style)
    multiple_axes[0].legend(**legend_style)

    fig.tight_layout()
    return fig

In [None]:
def plot_galaxy(image_loc, ax, n_examples=10, crop=0):
    """Show a centre-cropped galaxy image on `ax`, with grid and axes hidden.

    Args:
        image_loc: path of the image file to open.
        ax: matplotlib Axes to draw on.
        n_examples: unused; kept so existing calls keep working.
        crop: margin in pixels removed from every edge. The default of 0
            means "use the standard 35 pixel margin", which is what every
            call got before, when this argument was silently overwritten.

    The margin is measured from the image's actual size instead of assuming
    424 x 424 pixels, so images of other sizes are also cropped symmetrically.
    """
    default_margin = 35
    margin = crop if crop else default_margin
    # Image.crop loads the pixel data, so the file can be closed on exit.
    with Image.open(image_loc) as im:
        width, height = im.size
        cropped_im = im.crop((margin, margin, width - margin, height - margin))
    ax.imshow(cropped_im)
    ax.grid(False)
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)


In [None]:
# Images plus posteriors for the first `n` galaxies, 'has spiral arms: yes' answer.
# get_single_answer_data takes (df, answer object, n); the previous call left
# out `df` and passed plain strings.
answer = schema.get_answer('has-spiral-arms_yes')
n = 5
samples, labels, total_votes = get_single_answer_data(df, answer, n=n)
# Plain list, so the plotting function's png_locs[row] lookup is positional
# whatever index `df` carries.
png_locs = df['file_loc'][:n].tolist()

# catalog = sim_model.catalog[selected_slice]

_ = custom_samples_with_galaxies(samples, labels, total_votes, png_locs)

In [None]:
# fig, axes = plt.subplots(1, 10, figsize=(20, 12))
# for ax_n, ax in enumerate(axes):
#     plot_galaxy(sim_model.catalog.iloc[ax_n]['png_loc'], ax)

In [None]:
# 1 2

In [None]:
selected = slice(80, 73, -1)  # smooth

# selected = slice(0, 7)

if QUESTION == 'bars':
    selected = slice(0, 7)

In [None]:
# np.array(sim_model.model.samples)[selected, :]
# np.array(sim_model.labels)[selected]
# sim_model.catalog['smooth-or-featured_total-votes'][selected]

In [None]:
fig = custom_samples_with_galaxies(sim_model, selected)
# fig.savefig(os.path.join(save_dir, 'mc_model_{}.png'.format(len(np.array(sim_model.labels)[selected]))))
# fig.savefig(os.path.join(save_dir, 'mc_model_{}.pdf'.format(len(np.array(sim_model.labels)[selected]))))

In [None]:
# be sure to switch label in custom_samples before running this
# fig = custom_samples(np.array(single_sim_model.model.samples)[selected, :1], np.array(single_sim_model.labels)[selected], total_votes=single_sim_model.total_votes)
# fig.savefig(os.path.join(save_dir, 'single_model_{}.png'.format(len(np.array(sim_model.labels)[selected]))))
# fig.savefig(os.path.join(save_dir, 'single_model_{}.eps'.format(len(np.array(sim_model.labels)[selected]))))

In [None]:
# Coverage comparison: for each maximum allowed vote error, how many galaxies
# each model expects to fall within it versus how many actually do.
sns.set(font_scale=1.2)
sns.set_style('white')

fig, ax = plt.subplots()
ungrouped_coverage_df = discrete_coverage.evaluate_discrete_coverage(
    sim_model.labels, 
    sim_model.bin_probs)
# Sum expected and observed counts per allowed error.
coverage_df = ungrouped_coverage_df.groupby('max_state_error').agg({'prediction': 'sum', 'observed': 'sum'}).reset_index()

# Same evaluation for the single-sample model.
ungrouped_single_coverage_df = discrete_coverage.evaluate_discrete_coverage(
    single_sim_model.labels, 
    single_sim_model.bin_probs)
single_coverage_df = ungrouped_single_coverage_df.groupby('max_state_error').agg({'prediction': 'sum', 'observed': 'sum'}).reset_index()


plt.plot(coverage_df['max_state_error'], coverage_df['prediction'], label='MC Model Expects')
plt.plot(single_coverage_df['max_state_error'], single_coverage_df['prediction'], label='Single Model Expects')
# NOTE(review): x comes from single_coverage_df but y from coverage_df. This
# is only right if both frames share the same 'max_state_error' rows - confirm.
plt.plot(single_coverage_df['max_state_error'], coverage_df['observed'], 'k--', label='Actual')

ax.set_xlabel('Max Allowed Vote Error')
ax.set_ylabel('Galaxies Within Max Error')
ax.legend()
ax.xaxis.set_major_formatter(StrMethodFormatter('{x:.0f}'))  # must expect 'x' kw arg

ax.set_xlim([0, 15])
fig.tight_layout()
fig.savefig(os.path.join(save_dir, 'coverage_comparison_200_samples.png'))
fig.savefig(os.path.join(save_dir, 'coverage_comparison_200_samples.pdf'))

In [None]:
# Each frame needs its own file: both used to be written to
# '<QUESTION>_ungrouped_coverage_df.csv', so the single-model frame silently
# overwrote the MC-model one.
ungrouped_coverage_df.to_csv(os.path.join(save_dir, QUESTION + '_ungrouped_coverage_df.csv'), index=False)
ungrouped_single_coverage_df.to_csv(os.path.join(save_dir, QUESTION + '_ungrouped_single_coverage_df.csv'), index=False)

In [None]:
coverage_df['error'] = coverage_df['prediction'] - coverage_df['observed']
coverage_df['relative_error'] = coverage_df['error'] / coverage_df['observed']
coverage_df.to_csv(os.path.join(save_dir, QUESTION + '_coverage_df.csv'), index=False)
coverage_df.head(20)

In [None]:
single_coverage_df['error'] = single_coverage_df['prediction'] - single_coverage_df['observed']
single_coverage_df['relative_error'] = single_coverage_df['error'] / single_coverage_df['observed']
single_coverage_df.to_csv(os.path.join(save_dir, QUESTION + '_single_coverage_df.csv'), index=False)
single_coverage_df.head(20)

TODO: consider adding a model trained with an MSE loss as a baseline for this model to (hopefully) beat, although I suspect the two would perform quite similarly. Ideally, I can also compare this with previous work somehow.

In [None]:
sns.set(font_scale=1.2)
sns.set_style('white')
fig, ax = plt.subplots()
ax.hist(sim_model.abs_rho_error, bins=25)
# ax.axvline(sim_model.mean_abs_rho_error, color='r') 
ax.set_xlim([0, 1.])
ax.set_ylabel('Galaxies')
ax.set_xlabel(r'| Expected $\hat{\rho}$ - observed vote fraction $\frac{k}{N}$ |')
fig.tight_layout()
fig.savefig(os.path.join(save_dir, 'difference_in_rho.png'))
fig.savefig(os.path.join(save_dir, 'difference_in_rho.pdf'))

In [None]:
sim_model.abs_rho_error.mean(), single_sim_model.abs_rho_error.mean()

In [None]:
np.sqrt(sim_model.mean_abs_rho_error), np.sqrt(single_sim_model.mean_abs_rho_error)  

In [None]:
np.sqrt(sim_model.mean_square_rho_error), np.sqrt(single_sim_model.mean_square_rho_error) # this is the rmse

In [None]:
# alpha = 0.3
# n_bins = 25

# # dummy for bins
# fig, ax = plt.subplots()
# _, bins, _  = ax.hist(sim_model.labels / sim_model.total_votes, bins=n_bins, alpha=alpha, label=r'Observed $\rho$')
# ax.hist(sim_model.mean_rho_prediction, bins=n_bins, alpha=alpha, label=r'Mean Rho Prediction $\hat{\rho}}$')
# # ax.hist(single_sim_model.mean_rho_prediction, bins=bins, alpha=alpha, label=r'Single Rho Prediction $\hat{\rho}}$')

# fig, ax = plt.subplots()
# sns.set(font_scale=1.)
# sns.set_style('white')

# ax.hist(sim_model.mean_rho_prediction, bins=bins, alpha=alpha, label=r'Mean Rho Prediction $\hat{\rho}}$')
# # ax.hist(single_sim_model.mean_rho_prediction, bins=bins, alpha=alpha, label=r'Single Rho Prediction $\hat{\rho}}$')
# ax.hist(sim_model.labels / sim_model.total_votes, bins=bins, alpha=alpha, label=r'Observed $\rho$')
# ax.legend()
# ax.set_xlim([0., 1.])
# ax.set_ylabel('Galaxies')
# ax.set_xlabel(r'Typical vote fraction $\rho$')
# fig.tight_layout()
# fig.savefig(os.path.join(save_dir, 'typical_vote_fraction_distribution.png'))

# This is a repeat of the above histograms

In [None]:
np.sum(sim_model.mean_rho_prediction > 0.5), np.sum(single_sim_model.mean_rho_prediction > 0.5), np.sum((sim_model.labels / sim_model.total_votes) > 0.5)

In [None]:
(sim_model.labels / sim_model.total_votes).min(), (sim_model.labels / sim_model.total_votes).max()

In [None]:
sim_model.mean_rho_prediction.min(), sim_model.mean_rho_prediction.max()

In [None]:
single_sim_model.mean_rho_prediction.min(), single_sim_model.mean_rho_prediction.max()

In [None]:
sim_model.total_votes

## Save DataFrame of predictions + catalog (GZ2) for use elsewhere

In [None]:
import json

In [None]:
# Per-row volunteer responses alongside the model's mean rho prediction.
response_df = pd.DataFrame(data={
    'total_votes': sim_model.total_votes, 
    'k': sim_model.labels, 
    'vote_fraction': (sim_model.labels / sim_model.total_votes), 
    'rho_prediction': sim_model.mean_rho_prediction
#     'png_loc': sim_model.catalog.png_loc
})
# Catalog columns to append; 'total_votes' is already supplied by response_df, and the two
# subject coordinate columns are excluded as in the original selection.
safe_catalog_cols = list(set(sim_model.catalog.columns.values) - set(['total_votes', 'ra_subject', 'dec_subject']))
df = pd.concat([response_df, sim_model.catalog[safe_catalog_cols]], axis=1)
df['smooth'] = df['vote_fraction'] > 0.5
# Distance of the prediction from 0.5: larger means a more decisive prediction.
df['confidence_proxy'] = np.abs(0.5 - df['rho_prediction'])
# JSON-encode row n of the model samples for row n of df. Built in one pass (no chained
# assignment into an int column); .tolist() yields plain Python numbers so json.dumps accepts
# any numeric dtype.
df['rho_predictions'] = [
    json.dumps(np.asarray(sim_model.model.samples[n, :]).tolist())
    for n in range(len(df))
]
# Sort once, after the column is filled (previously this ran on every loop iteration).
df = df.sort_values('confidence_proxy', ascending=False)

In [None]:
# Inspect the JSON-encoded per-row sample strings.
df['rho_predictions']

In [None]:
# Preview the first rows of the combined predictions + catalog frame.
df.head()

In [None]:
# Write the combined frame to parquet; the file name is prefixed with QUESTION.
# NOTE(review): the output directory is a hard-coded absolute path - confirm it exists on the
# machine running this notebook.
df.to_parquet('/data/repos/zoobot/notebooks/{}_test_predictions_and_gz2_catalog.parquet'.format(QUESTION))

### Replicate (ish) Sanchez 2017 ROC Curves

In [None]:
# Confusion matrix of observed (vote fraction > 0.5) against predicted (mean rho prediction > 0.5).
confusion_matrix((sim_model.labels / sim_model.total_votes) > 0.5, sim_model.mean_rho_prediction > 0.5)

In [None]:
 1 - ((66 + 99) / (490 + 1845 + 66 + 99))  # 1 - off-diagonal / total, using the same counts as the hard-coded 'smooth' confusion matrix plotted later

In [None]:
 1 - ((189 + 81) / (1858 + 189 + 81 + 372))  # 1 - off-diagonal / total, using the counts of the commented-out 'bars' confusion matrix in the plotting cell

In [None]:
# ROC curves for the binary 'smooth' column scored by rho_prediction: once for every row, and
# once for the rows whose prediction is at least 0.3 away from 0.5.
fig, ax = plt.subplots()
sns.set(font_scale=1.2)
sns.set_style('white')

df_low_entropy = df[df['confidence_proxy'] > 0.3]
roc_subsets = [
    (df, 'All'),
    (df_low_entropy, r'"High Confidence" i.e. $\hat{\rho} < 0.2$ or $\hat{\rho} > 0.8$'),
]
for subset, curve_label in roc_subsets:
    fpr, tpr, _ = roc_curve(subset['smooth'], subset['rho_prediction'])
    ax.plot(fpr, tpr, label=curve_label)

ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.legend()

fig.tight_layout()
for roc_filename in ('roc_curve.png', 'roc_curve.pdf'):
    fig.savefig(os.path.join(save_dir, roc_filename))

In [None]:
# Row counts: all rows vs. the confidence_proxy > 0.3 subset.
len(df), len(df_low_entropy)

### Replicate (ish) Khan 2018 Confusion Matrices

> "After selecting the OBJIDs from Table 2 based on the probability thresholds of 0.985 and 0.926 for spirals and ellipticals respectively, ..." (Khan 2018)

In [None]:
# Show one randomly chosen row of df.
df.sample(1)

In [None]:
# Per row: binomial CDF evaluated at int(total_votes / 2) successes, with total_votes trials
# and success probability rho_prediction.
cdf_array = binom.cdf((df['total_votes'] / 2.).astype(int), df['total_votes'], df['rho_prediction'])

In [None]:
# Spot-check 10 random values of the half-vote threshold passed to binom.cdf.
(df['total_votes'] / 2.).astype(int).sample(10)

In [None]:
# Spot-check 10 random total_votes values.
df['total_votes'].sample(10)

In [None]:
# Spot-check 10 random rho_prediction values.
df['rho_prediction'].sample(10)

In [None]:
# Distribution of the per-row CDF values, in 30 bins.
plt.hist(cdf_array, bins=30)

In [None]:
# Single worked example: CDF at 20 successes out of 40 trials with success probability 0.88.
binom.cdf(20, 40, 0.88)

In [None]:
# Count rows where (1 - cdf) exceeds 0.985 (the subtraction binds tighter than the comparison).
sum(1 - cdf_array > 0.985)

In [None]:
# Count rows where the CDF exceeds 0.926.
sum(cdf_array > 0.926)

In [None]:
# Keep rows whose CDF is below 1 - 0.985 or above 0.926.
# Note: a later cell reassigns high_prob_df when QUESTION is 'smooth' or 'bars'.
high_prob_df = df[(cdf_array < (1 - 0.985)) | (cdf_array > 0.926)]

In [None]:
# Number of rows selected by the CDF thresholds.
len(high_prob_df)

In [None]:
# Rebuild high_prob_df from the most extreme rho_prediction rows at each end of the ranking.
# The kept fractions are hard-coded per question (presumably chosen to mirror Khan 2018 - TODO confirm).
if QUESTION == 'smooth':
    spiral_pc_to_keep = 516 / 6677
    elliptical_pc_to_keep = 550 / 5904
    n_spirals = int(len(df) * spiral_pc_to_keep)
    n_ellipticals = int(len(df) * elliptical_pc_to_keep)
    print(spiral_pc_to_keep, n_spirals, elliptical_pc_to_keep, n_ellipticals)
    # Lowest predictions are kept as "spirals", highest as "ellipticals".
    lowest_rho_rows = df.sort_values('rho_prediction').iloc[:n_spirals]
    highest_rho_rows = df.sort_values('rho_prediction', ascending=False).iloc[:n_ellipticals]
    high_prob_df = pd.concat([lowest_rho_rows, highest_rho_rows])
elif QUESTION == 'bars':
    # Keep 8% of rows in total, split evenly between the two extremes.
    n_to_keep = int(len(df) * 0.08)
    n_per_side = int(n_to_keep / 2)
    lowest_rho_rows = df.sort_values('rho_prediction').iloc[:n_per_side]
    highest_rho_rows = df.sort_values('rho_prediction', ascending=False).iloc[:n_per_side]
    high_prob_df = pd.concat([lowest_rho_rows, highest_rho_rows])



In [None]:
# Show 20 random rows of the high-probability subset.
high_prob_df.sample(20)

In [None]:
# Confusion matrix (observed vs. predicted) on the high-probability subset.
# Note the threshold here is >= 0.5, whereas the neighbouring cells compare with > 0.5.
confusion_matrix(high_prob_df['vote_fraction'] >= 0.5, high_prob_df['rho_prediction'] >= 0.5)

In [None]:
# Rows predicted above 0.5 whose observed vote fraction is NOT above 0.5
# (the ~ applies only to the first comparison).
error = high_prob_df[~(high_prob_df['vote_fraction'] > 0.5) & (high_prob_df['rho_prediction'] > 0.5)]

In [None]:
# Display the selected error rows.
error

In [None]:
# Observed and predicted > 0.5 flags for the error rows, shown as a pair of boolean Series.
error['vote_fraction'] > 0.5, error['rho_prediction'] > 0.5

In [None]:
# Show the image of the first error row, annotated with that row's predicted and observed vote
# fractions. The values are read from the row itself rather than typed in by hand, so the
# caption cannot go stale when the model or the selection changes.
first_error = error.iloc[0]
img = Image.open(first_error['png_loc'])
plt.imshow(img)
fontdict = {'size': 16, 'color': 'white'}
plt.text(30, 360, r'Expected vote frac $\hat{\rho}$: ' + '{:.2f}'.format(first_error['rho_prediction']), fontdict=fontdict)
plt.text(30, 400, r'Observed vote frac $\frac{k}{N}$: ' + '{:.2f}'.format(first_error['vote_fraction']), fontdict=fontdict)
plt.axis('off')
plt.savefig(os.path.join(save_dir, 'high_prob_error_0.png'))
plt.savefig(os.path.join(save_dir, 'high_prob_error_0.eps'))

In [None]:
# img = Image.open(error.iloc[1]['png_loc'])
# plt.imshow(img)
# fontdict = {'size': 16, 'color': 'white'}
# plt.text(30, 360, r'Expected vote frac $\hat{\rho}$: 0.13', fontdict=fontdict)
# plt.text(30, 400, r'Observed vote frac $\frac{k}{N}$: 0.54', fontdict=fontdict)
# plt.axis('off')
# plt.savefig(os.path.join(save_dir, 'high_prob_error_1.png'))
# plt.savefig(os.path.join(save_dir, 'high_prob_error_1.eps'))

In [None]:
# Confusion matrix on the first half of df's rows.
# Presumably the more confident half, assuming df is still sorted by confidence_proxy descending - verify.
confusion_matrix(df['vote_fraction'][:int(len(df) / 2)] > 0.5, df['rho_prediction'][:int(len(df) / 2)] > 0.5)

In [None]:
# Draw a hard-coded confusion matrix as an annotated heatmap and save it as PNG and PDF.
# NOTE(review): the counts are typed in by hand; sklearn's confusion_matrix orders boolean
# classes False then True, so confirm the tick label order matches how the counts were produced.
if QUESTION == 'smooth':
    labels = ['Smooth', 'Featured']
    # High-confidence alternative:
    #     cm = np.array([[ 232,    2], [   0, 191]])
    #     name = 'confusion_matrix_high_confidence'
    cm = np.array([[490, 66], [99, 1845]])
    name = 'confusion_matrix'
elif QUESTION == 'bars':
    labels = ['No Bar', 'Bar']
    cm = np.array([[100, 0], [0, 100]])
    name = 'confusion_matrix_high_confidence'
    # Full-sample alternative:
    #     cm = np.array([[1858,    81], [   189,   372]])
    #     name = 'confusion_matrix'

sns.set(font_scale=3.)
sns.set_style('white')

fig, ax = plt.subplots(figsize=(8, 8))
heatmap_options = dict(
    annot=True,
    fmt='d',
    cmap="Blues",
    xticklabels=labels,
    yticklabels=labels,
    cbar=False,
    square=True,
)
ax = sns.heatmap(cm, ax=ax, **heatmap_options)
ax.set_xlabel('Predicted')
ax.set_ylabel('Observed')
fig.tight_layout()
for filename_template in ('{}.png', '{}.pdf'):
    fig.savefig(os.path.join(save_dir, filename_template.format(name)))

In [None]:
# Scratch arithmetic: 1 - 8 / (1159 + 83 + 8).
# NOTE(review): the source of these counts is not shown in this notebook section - confirm or remove.
1 - (8 / (1159 + 83 + 8))

In [None]:
# sim_model.export_performance_metrics(save_dir)

In [None]:
# Draw a galaxy, infer a range of p, redraw, and measure accuracy - work in progress

### Plot other standard acquisition visualisations

In [None]:
# Optional (disabled by default): load every catalog image and save acquisition examples
# ranked by mutual information.
new_acquisition_viz = False
if new_acquisition_viz:
    image_locs = sim_model.catalog['png_loc']
    loaded_images = []
    for loc in image_locs:
        loaded_images.append(np.array(Image.open(loc)))
    images = np.stack(loaded_images)
    # Hard-coded expectation: 2500 RGB images of 424x424 pixels.
    assert images.shape == (2500, 424, 424, 3)
    acquisition_utils.save_acquisition_examples(images, sim_model.mutual_info, 'mutual_info', save_dir)

In [None]:
# fig, row = plt.subplots(ncols=3, figsize=(12, 4))

In [None]:
# row = sim_model.acquisition_vs_volunteer_votes(row)

### Visualise Selection of Catalog Features w.r.t. Acquisition Function

In [None]:
# 20x12 figure with a 6-row by 5-column GridSpec; later cells place subplots on slices of gs.
fig = plt.figure(constrained_layout=True, figsize=(20, 12))
gs = gridspec.GridSpec(6, 5, figure=fig)

#### Smooth Votes

In [None]:
# Acquisition value against smooth votes (smooth fraction * 40, truncated to int), coloured by
# whether the value exceeds the one stored at index 103.
# x/y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
ax0 = plt.subplot(gs[:4, :])
acquisition_values = np.array(sim_model.model.acquisitions)
sns.scatterplot(
    x=np.array(sim_model.catalog['smooth-or-featured_smooth_fraction'] * 40).astype(int),
    y=acquisition_values,
    hue=acquisition_values > acquisition_values[103],
    ax=ax0)
ax0.set_ylabel('Mutual Information')
ax0.set_xlabel('Smooth Votes')
ax0.legend([r'Top 10% $\mathcal{I}$', r'Bottom 90% $\mathcal{I}$'])

In [None]:
# Normalised histograms of labels * 40 (truncated to int): all rows, then the first 200 rows.
# NOTE(review): labels is used as a vote count elsewhere in this notebook - confirm the * 40 scaling.
ax1 = plt.subplot(gs[4:, :])
smooth_votes = np.array(sim_model.labels * 40).astype(int)
hist_style = dict(density=True, alpha=0.4)
ax1.hist(smooth_votes, **hist_style)
ax1.hist(smooth_votes[:200], **hist_style)
ax1.set_ylabel('Frequency')
ax1.set_xlabel('Smooth Votes')

In [None]:
# Save the current pyplot figure to temp.png inside save_dir.
plt.savefig(os.path.join(save_dir, 'temp.png'))

#### Redshift

In [None]:
# Joint KDE of volunteer votes (labels * 40, truncated to int) against redshift.
# jointplot is figure-level and builds its own figure, so no separate plt.subplots() is created
# and the axis labels are set on the returned JointGrid instead of an axes left over from an
# earlier cell.
joint_grid = sns.jointplot(
    x=np.array(sim_model.labels * 40).astype(int),
    y=sim_model.catalog['redshift'],
    kind='kde')
joint_grid.set_axis_labels('Volunteer Votes', 'Redshift')
# Keep `fig` / `ax` bound for later cells, now pointing at the figure that was actually drawn.
ax = joint_grid.ax_joint
fig = ax.figure

In [None]:
# Bivariate KDE of volunteer votes against redshift, drawn on the gridspec subplot.
# jointplot is figure-level and cannot draw onto an existing axes (ax= is ignored or rejected),
# so the axes-level kdeplot is used instead. Keyword x=/y= requires seaborn >= 0.11.
ax0 = plt.subplot(gs[:2, :])
sns.kdeplot(
    x=np.array(sim_model.labels * 40).astype(int),
    y=sim_model.catalog['redshift'],
    ax=ax0)
ax0.set_ylabel('Redshift')
ax0.set_xlabel('Volunteer Votes')

In [None]:
# Acquisition value against redshift, drawn on the gridspec subplot.
# jointplot is figure-level and cannot target an existing axes; its default kind is a scatter,
# so the axes-level scatterplot is the equivalent that honours ax=.
ax0 = plt.subplot(gs[:2, :])
sns.scatterplot(
    x=sim_model.catalog['redshift'],
    y=sim_model.model.acquisitions,
    ax=ax0)
ax0.set_ylabel('Mutual Information')
ax0.set_xlabel('Redshift')

In [None]:
# Normalised histogram of catalog redshift on the lower gridspec rows.
ax1 = plt.subplot(gs[4:, :])

ax1.hist(sim_model.catalog['redshift'], density=True, alpha=0.4)
# ax1.hist(sim_model.catalog['redshift'], density=True, alpha=0.4)
ax1.set_ylabel('Frequency')
# The plotted quantity is redshift; the previous 'Smooth Votes' label was a copy-paste leftover.
ax1.set_xlabel('Redshift')
# TODO sort by mutual information

### Below here is only relevant for DECALS, with extra questions. TODO update with GZ2 merger options?

In [None]:
# Always raises AssertionError; presumably a deliberate stop so "Run All" halts before the
# DECALS-only cells that follow.
assert False

In [None]:
# Display merger_strs (not defined in this part of the notebook - verify it exists on a fresh kernel).
merger_strs

In [None]:
# Catalog column used by the merger plots in the following cells.
merger_label = 'merging_major-disturbance'

In [None]:
# Mutual information against the selected merger column, drawn on the gridspec subplot.
# x/y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
ax0 = plt.subplot(gs[:2, :])
sns.scatterplot(
    x=sim_model.catalog[merger_label],
    y=sim_model.model.mutual_info,
    ax=ax0)
ax0.set_ylabel('Mutual Information')
ax0.set_xlabel(merger_label)

In [None]:
# Histogram of the selected merger column in 40 bins; assigning to _ suppresses the cell output.
_ = plt.hist(sim_model.catalog[merger_label], bins=40)

In [None]:
# Mutual information for rows with smooth fraction below 0.4, split by whether the selected
# merger column is zero or positive.
is_featured = sim_model.catalog['smooth-or-featured_smooth_fraction'] < 0.4
merger_votes = sim_model.catalog[merger_label]
featured_no_merger = sim_model.model.mutual_info[(merger_votes == 0) & is_featured]
featured_merger = sim_model.model.mutual_info[(merger_votes > 0) & is_featured]

In [None]:
# Mean mutual information of the featured rows with a zero merger value.
featured_no_merger.mean()

In [None]:
# Mean mutual information of the featured rows with a positive merger value.
featured_merger.mean()

In [None]:
# New figure and axes, filled in by the next two cells.
fig, ax = plt.subplots()

In [None]:
# Normalised histogram of mutual information for the zero-merger featured rows.
ax.hist(featured_no_merger, alpha=0.3, density=True)

In [None]:
# Normalised histogram of mutual information for the positive-merger featured rows, on the same axes.
ax.hist(featured_merger, alpha=0.3, density=True)

In [None]:
# Re-display the figure now that both histograms have been added.
fig

In [None]:
# Mutual information against the artifact column, coloured by whether it exceeds the value
# stored at index 103.
# x/y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
ax = sns.scatterplot(
    x=sim_model.catalog['smooth-or-featured_artifact'],
    y=sim_model.model.mutual_info,
    hue=sim_model.model.mutual_info > sim_model.model.mutual_info[103])
ax.set_xlim([0, 14.5])

In [None]:
# Regression plot of mutual information against the artifact column.
# x/y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
sns.regplot(
    x=sim_model.catalog['smooth-or-featured_artifact'],
    y=sim_model.model.mutual_info)

In [None]:
# New figure and axes for the artifact histograms drawn in the next cell.
fig, ax = plt.subplots()

In [None]:
# Normalised histograms of the artifact column, split into rows with mutual information above
# and below the value stored at index 100 (rows equal to it fall in neither group).
artifact_votes = sim_model.catalog['smooth-or-featured_artifact']
mi_values = sim_model.model.mutual_info
mi_cutoff = mi_values[100]
hist_style = dict(density=True, alpha=0.4)
ax.hist(artifact_votes[mi_values > mi_cutoff], **hist_style)
ax.hist(artifact_votes[mi_values < mi_cutoff], **hist_style)

In [None]:
# Re-display the figure with both artifact histograms.
fig

In [None]:
# Column names noted for reference. Left as bare text they parse as subtraction expressions
# and raise NameError, so they are kept as comments:
#   has-spiral-arms_yes
#   spiral-winding_prediction-encoded

In [None]:
# Mutual information against spiral-arm votes, restricted to rows with smooth fraction < 0.5.
# The hue is computed from the same filtered rows so that x, y and hue have matching lengths
# (previously hue covered every row while x and y were filtered).
# x/y are passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
is_featured = sim_model.catalog['smooth-or-featured_smooth_fraction'] < 0.5
featured_mutual_info = sim_model.model.mutual_info[is_featured]
sns.scatterplot(
    x=sim_model.catalog['has-spiral-arms_yes'][is_featured],
    y=featured_mutual_info,
    hue=featured_mutual_info > sim_model.model.mutual_info[100])

In [None]:
# Normalised histograms of spiral-arm votes, split into rows with mutual information above and
# below the value stored at index 100 (rows equal to it fall in neither group).
fig, ax = plt.subplots()
spiral_arm_votes = sim_model.catalog['has-spiral-arms_yes']
mi_values = sim_model.model.mutual_info
mi_cutoff = mi_values[100]
hist_style = dict(density=True, alpha=0.4)
ax.hist(spiral_arm_votes[mi_values > mi_cutoff], **hist_style)
ax.hist(spiral_arm_votes[mi_values < mi_cutoff], **hist_style)

In [None]:
# Normalised histograms of redshift, split into rows with mutual information above and below
# the value stored at index 100 (rows equal to it fall in neither group).
fig, ax = plt.subplots()
redshifts = sim_model.catalog['redshift']
mi_values = sim_model.model.mutual_info
mi_cutoff = mi_values[100]
hist_style = dict(density=True, alpha=0.4)
ax.hist(redshifts[mi_values > mi_cutoff], **hist_style)
ax.hist(redshifts[mi_values < mi_cutoff], **hist_style)

In [None]:
# Normalised histograms of the major-disturbance column, split into rows with mutual
# information above and below the value stored at index 100 (rows equal to it fall in neither group).
fig, ax = plt.subplots()
major_disturbance_votes = sim_model.catalog['merging_major-disturbance']
mi_values = sim_model.model.mutual_info
mi_cutoff = mi_values[100]
hist_style = dict(density=True, alpha=0.4)
ax.hist(major_disturbance_votes[mi_values > mi_cutoff], **hist_style)
ax.hist(major_disturbance_votes[mi_values < mi_cutoff], **hist_style)

In [None]:
# For each merger column, print the mean mutual information of rows with a value above 1,
# then of rows with a value of exactly 1. Note the loop rebinds merger_label.
for merger_label in merger_strs:
    votes_for_label = sim_model.catalog[merger_label]
    print('\n' + merger_label)
    print(sim_model.model.mutual_info[votes_for_label > 1].mean())
    print(sim_model.model.mutual_info[votes_for_label == 1].mean())

In [None]:
# One record per volunteer response: mean mutual information over rows whose catalog column
# exceeds the given threshold (1 for the three disturbance responses, 20 for 'merging_none').
response_thresholds = [
    ('Merging', 'merging_both-v1', 1),
    ('Major Disturbance', 'merging_major-disturbance', 1),
    ('Minor Disturbance', 'merging_minor-disturbance', 1),
    ('No Disturbance', 'merging_none', 20),
]
data = [
    {
        'Volunteer Response': response_name,
        'Mean Mutual Information': sim_model.model.mutual_info[sim_model.catalog[column_name] > vote_threshold].mean(),
    }
    for response_name, column_name, vote_threshold in response_thresholds
]

In [None]:
# Turn the per-response records into a DataFrame.
# NOTE(review): this rebinds `df`, replacing the predictions + catalog frame built earlier;
# earlier cells re-run after this point will see the small summary frame instead.
df = pd.DataFrame(data)

In [None]:
# Preview the per-response summary frame.
df.head()

In [None]:
# Horizontal bar chart of mean mutual information per volunteer response.
fig, ax = plt.subplots()
ax = sns.barplot(data=df, y='Volunteer Response', x='Mean Mutual Information', ax=ax)
# The x axis is clipped to [0.2, 0.36], so bars do not start at zero.
ax.set_xlim([0.2, 0.36])
fig.tight_layout()

In [None]:
# Mean mutual information for rows with any minor or major disturbance value, then for rows with none.
disturbance_votes = sim_model.catalog['merging_minor-disturbance'] + sim_model.catalog['merging_major-disturbance']
print(sim_model.model.mutual_info[disturbance_votes > 0].mean())
print(sim_model.model.mutual_info[disturbance_votes == 0].mean())

In [None]:
# Histogram of the tidal-debris column in 40 bins.
plt.hist(sim_model.catalog['merging_tidal-debris-v1'], bins=40)