In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
!pwd

/home/walml/repos/zoobot/notebooks/multiq


In [3]:
!git pull

Already up-to-date.


In [4]:
import matplotlib

In [5]:
import os
import logging
import argparse
import glob
import json
from collections import Counter

import numpy as np
from matplotlib.ticker import StrMethodFormatter

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn import metrics
import tensorflow as tf
import pandas as pd
from astropy.table import Table  # for NSA
from astropy import units as u
from sklearn.metrics import confusion_matrix, roc_curve
from PIL import Image
from scipy.stats import binom
from IPython.display import display, Markdown

from sklearn.metrics import accuracy_score, mean_squared_error, mean_absolute_error

from shared_astro_utils import astropy_utils, matching_utils
from zoobot.estimators import make_predictions, bayesian_estimator_funcs
from zoobot.tfrecord import read_tfrecord
from zoobot.uncertainty import discrete_coverage
from zoobot.estimators import input_utils, losses
from zoobot.tfrecord import catalog_to_tfrecord
from zoobot.active_learning import metrics, simulated_metrics, acquisition_utils, check_uncertainty, simulation_timeline, run_estimator_config
from zoobot import label_metadata

In [6]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)


In [7]:
gpus

[]

In [8]:
os.chdir('/home/walml/repos/zoobot')

In [9]:
# start = 50
# end = 60
# start = 0
# end=5

### Load the (latest) model under `model_name` folder in `results_dir`

In [10]:
# catalog_loc = 'data/latest_labelled_catalog.csv
catalog_loc = 'data/decals/decals_master_catalog.csv'
# catalog_loc = 'data/gz2/gz2_master_catalog.csv'

catalog = pd.read_csv(catalog_loc, dtype={'subject_id': str})  # original catalog

catalog['file_loc'] = catalog['local_png_loc'].apply(lambda x: '/media/walml/beta/decals/png_native' + x[32:])
catalog['file_loc'] = catalog['local_png_loc'].apply(lambda x: '/media/walml/beta/decals/png_native' + x[32:])


# catalog_loc = 'data/decals/temp_calibration_catalog.csv'
# catalog = pd.read_csv(catalog_loc, dtype={'subject_id': str})  # original catalog


In [11]:


# Figures will be saved to here

analysis_dir = 'analysis/multiquestion'
# save_dir = f'{analysis_dir}/{model_name}'
# if not os.path.exists(save_dir):
#     os.mkdir(save_dir)

# results_dir = '/home/walml/repos/zoobot/data/experiments/live/no_cutouts/iteration_0'
# results_dir = '/home/walml/repos/zoobot/results/smooth_or_featured_offline'
# results_dir = '/home/walml/repos/zoobot/results/debug'
# results_dir = '/home/walml/repos/zoobot/results/latest/effnetB0_decals_mf_244px_256init_10k'
# results_dir = '/home/walml/repos/zoobot/results/latest/latest_offline_full'

# two identical models, trained at different times
# results_dir = '/home/walml/repos/zoobot/results/latest/latest_offline_10k_x2val_b256'
# results_dir = '/home/walml/repos/zoobot/results/latest/latest_offline_10k_x2val'

# results_dir = '/home/walml/repos/zoobot/results/latest/latest_dirichlet_c100_active_n3_layers'
# results_dir = '/home/walml/repos/zoobot/results/all_featp5_facep5_sim_256_arc_final_0'



# decals cols
questions = label_metadata.decals_questions
version = 'decals'
label_cols = label_metadata.decals_label_cols

# gz2 cols
# questions = label_metadata.gz2_questions
# version = 'gz2'
# label_cols = label_metadata.gz2_label_cols



schema = losses.Schema(label_cols, questions, version=version)

batch_size = 8
initial_size = 300
# initial_size = 128
crop_size = int(initial_size * 0.75)
# crop_size = 128
final_size = 224
channels = 3

n_samples = 5

# if loading single test tfrecord
# tfrecord_locs = [f'data/decals/shards/multilabel_{img_size}/eval/s{initial_size}_shard_0.tfrecord']
# tfrecord_locs = ['data/decals/shards/multilabel_master_256/train/s256_shard_0.tfrecord']
# tfrecord_locs = glob.glob(f'/media/walml/beta/decals/multilabel_master_{initial_size}/train/*.tfrecord')[start:end]
# tfrecord_locs = glob.glob(f'/home/walml/repos/zoobot/data/decals/shards/decals_multiq_{initial_size}_sim_init_2500_featp4/train_shards/*.tfrecord')[start:end]

# for all labelled decals galaxies
# train_locs = glob.glob(f'/home/walml/repos/zoobot/data/decals/shards/decals_multiq_{initial_size}/train_shards/*.tfrecord')
# eval_locs = glob.glob(f'/home/walml/repos/zoobot/data/decals/shards/decals_multiq_{initial_size}/eval_shards/*.tfrecord')
# tfrecord_locs = train_locs + eval_locs

# for all decals galaxies after filter
# train_locs = glob.glob(f'/home/walml/repos/zoobot/data/decals/shards/multilabel_master_filtered_{initial_size}/train/*.tfrecord')
# eval_locs = glob.glob(f'/home/walml/repos/zoobot/data/decals/shards/multilabel_master_filtered_{initial_size}/eval/*.tfrecord')
# tfrecord_locs = eval_locs

# tfrecord_locs = train_locs[:1]

# for calibration dr5 galaxies
# eval_locs = glob.glob(f'/home/walml/repos/zoobot/data/decals/shards/temp_calibration_shards_feat/train/*.tfrecord')
# tfrecord_locs = eval_locs

# for 10k labelled/filtered GZ2 galaxies
# train_locs = glob.glob(f'/home/walml/repos/zoobot/data/gz2/shards/all_featp5_facep5_sim_2p5_{initial_size}/train_shards/*.tfrecord')
# eval_locs = glob.glob(f'/home/walml/repos/zoobot/data/gz2/shards/all_featp5_facep5_sim_2p5_{initial_size}_/eval_shards/*.tfrecord')
# # tfrecord_locs = train_locs + eval_locs
# tfrecord_locs = eval_locs

# # for 10k UNFILTERED GZ2 galaxies
# eval_locs = glob.glob(f'/home/walml/repos/zoobot/data/gz2/shards/all_sim_2p5_unfiltered_{initial_size}/eval_shards/*.tfrecord')
# # tfrecord_locs = train_locs + eval_locs
# tfrecord_locs = eval_locs


# latest arc eval set
# eval_locs = glob.glob(f'/home/walml/repos/zoobot/data/gz2/all_actual_sim_2p5_unfiltered_{initial_size}_arc_eval_shards/eval_shards/*.tfrecord')
# # # tfrecord_locs = train_locs + eval_locs
# tfrecord_locs = eval_locs

eval_locs = glob.glob(f'/home/walml/repos/zoobot/results/temp/decals_n2_allq_m0_eval_shards/*.tfrecord')
# # # tfrecord_locs = train_locs + eval_locs
tfrecord_locs = eval_locs


print(tfrecord_locs)
eval_config = run_estimator_config.get_eval_config(tfrecord_locs, label_cols, batch_size, initial_size, final_size, channels)
# print(eval_config.greyscale)
# print(eval_config.permute_channels)
eval_config.drop_remainder = False
dataset = input_utils.get_input(config=eval_config)

feature_spec = input_utils.get_feature_spec({'id_str': 'string'})
id_str_dataset = input_utils.get_dataset(tfrecord_locs, feature_spec, batch_size=1, shuffle=False, repeat=False, drop_remainder=False)
id_strs = [str(d['id_str'].numpy().squeeze())[2:-1] for d in id_str_dataset]
id_strs[:5]

# n = 0
# for batch in id_str_dataset:
#     for id_str in batch:
#         n+=1
# print(n)

# counter = Counter()
# n = 0
# for g_batch, y_batch in dataset:
# #     for g in g_batch:
# #         counter[g.numpy().sum()] += 1
#         n+=tf.shape(g_batch)[0]
# print(n)



{smooth-or-featured, indices 0 to 2, asked after None: (0, 2), disk-edge-on, indices 3 to 4, asked after smooth-or-featured_featured-or-disk, index 1: (3, 4), has-spiral-arms, indices 5 to 6, asked after disk-edge-on_no, index 4: (5, 6), bar, indices 7 to 9, asked after disk-edge-on_no, index 4: (7, 9), bulge-size, indices 10 to 14, asked after disk-edge-on_no, index 4: (10, 14), how-rounded, indices 15 to 17, asked after smooth-or-featured_smooth, index 0: (15, 17), edge-on-bulge, indices 18 to 20, asked after disk-edge-on_yes, index 3: (18, 20), spiral-winding, indices 21 to 23, asked after has-spiral-arms_yes, index 5: (21, 23), spiral-arm-count, indices 24 to 29, asked after has-spiral-arms_yes, index 5: (24, 29), merging, indices 30 to 33, asked after None: (30, 33)}
['/home/walml/repos/zoobot/results/temp/decals_n2_allq_m0_eval_shards/s300_shard_2.tfrecord', '/home/walml/repos/zoobot/results/temp/decals_n2_allq_m0_eval_shards/s300_shard_1.tfrecord', '/home/walml/repos/zoobot/resu



['J101523.33+031030.9',
 'J112413.94+245257.5',
 'J160105.65-004226.9',
 'J112748.26+264900.1',
 'J155214.79+164831.8']

In [12]:
# fig, axes = plt.subplots(ncols=3)
# n=0
# for images, labels in dataset.take(3):
#     image = images[0]
#     axes[n].imshow(image.numpy().squeeze())
#     n+= 1

In [13]:
len(id_strs)

10000

In [106]:
# # if loading png

# print(catalog['file_loc'])
# assert all([os.path.isfile(x) for x in catalog['file_loc']])
# filenames = tf.constant(list(catalog['file_loc']), dtype=tf.string)
# dataset = tf.data.Dataset.from_tensor_slices(filenames)

# def parse_image(im):
#     im = tf.image.decode_png(im, channels=channels)
#     im = tf.image.convert_image_dtype(im, tf.float32)
#     im = tf.image.resize(im, [initial_size, initial_size])
#     return im


# # for im in dataset.take(1):
# #     print(im)

# # assert False

# dataset = dataset.map(tf.io.read_file)
# dataset = dataset.map(parse_image)

# config = default_estimator_params.get_eval_config(['do not use'], label_cols, batch_size, initial_size, final_size, channels)

# dataset = dataset.batch(batch_size)

# # for batch in dataset.take(1):
# #     print(tf.shape(batch))
# # #     plt.imshow(batch[0])


# # dataset = dataset.map(lambda x: check_shape(x))

# dataset = dataset.map(lambda x: input_utils.preprocess_images(x, config))


# # for batch in dataset.take(1):
# #     print(tf.shape(batch))
# #     plt.imshow(batch[0].numpy().squeeze())


# id_strs = catalog['iauname']



In [107]:
for images, _ in dataset.take(1):
    print(images.shape)

(8, 300, 300, 1)


In [108]:
model = run_estimator_config.get_model(schema, initial_size, crop_size, final_size)
# checkpoint_dir = 'results/debug_300/models/final'
# checkpoint_dir = 'results/latest/test/models/final'
# checkpoint_dir = 'results/temp/latest_offline_sim_unfiltered/in_progress'
# checkpoint_dir = 'results/latest/latest_dirichlet_baseline_a/final'

# checkpoint_dir = 'results/temp/gz2_all_q_warm_active/iteration_0/estimators/models/model_0/final'
# checkpoint_dir = 'results/temp/gz2_all_q_warm_baseline/iteration_2/estimators/models/model_2/final'
# checkpoint_dir = 'results/temp/gz2_all_q_opt4_warm_active/iteration_2/estimators/models/model_2/final'

# n2, not necc. the best model
# checkpoint_dir = 'results/temp/decals_n2_allq_m0/in_progress'

# n2, not necc. the best model
# checkpoint_dir = 'results/temp/decals_retired_allq_m0/models/final'

checkpoint_dir = 'results/temp/decals_retired_allq_m0/models/final'

# iteration 0!
# checkpoint_dir = f'{results_dir}/iteration_0/estimators/models'
# checkpoint_dir = f'{results_dir}/models'
# checkpoint_dir = f'{results_dir}/final'

print(checkpoint_dir)
load_status = model.load_weights(checkpoint_dir)




results/temp/decals_retired_allq_m0/models/final


In [109]:
load_status.assert_nontrivial_match()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f3d1581efd0>

In [110]:
load_status.assert_existing_objects_matched()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f3d1581efd0>

In [111]:
model.predict(dataset.take(5))[:10, :2]

array([[20.358316, 10.980032],
       [51.719475, 33.65394 ],
       [72.833336,  5.651007],
       [ 8.864109, 22.148716],
       [80.322975, 12.311265],
       [ 9.764767, 13.030967],
       [ 8.280451, 11.811539],
       [21.803785,  4.809663],
       [16.164717, 34.592922],
       [20.656013,  9.512716]], dtype=float32)

In [112]:
model.predict(dataset.take(5))[:10, 2:4]

array([[ 3.312257 ,  2.8522325],
       [ 8.755309 , 43.133865 ],
       [ 6.7384024,  6.765002 ],
       [ 2.7138705,  2.0403795],
       [ 8.486013 , 14.931778 ],
       [ 2.3010445,  2.8900008],
       [ 3.949268 ,  1.7112068],
       [ 4.211546 ,  2.9462953],
       [ 5.5215425, 62.367363 ],
       [ 2.8931277,  2.719283 ]], dtype=float32)

In [113]:
# model.evaluate(dataset)  # unfiltered = 2.04 73 secs for four q, ...

In [114]:
# assert False

In [115]:
# %timeit predictions = model.predict(dataset)

In [116]:
dataset

<MapDataset shapes: ((None, 300, 300, 1), (None, 34)), types: (tf.float32, tf.float32)>

In [117]:
# # copied from trust_the_model.ipynb
# def show_galaxies(galaxies, scale=3, nrows=6, ncols=6):
#     fig = plt.gcf()

#     plt.figure(figsize=(scale * nrows, scale * ncols * 1.025))
#     gs1 = gridspec.GridSpec(nrows, ncols)
#     gs1.update(wspace=0.0, hspace=0.0)
#     galaxy_n = 0
#     for row_n in range(nrows):
#         for col_n in range(ncols):
#             galaxy = np.squeeze(galaxies[galaxy_n])
#             ax = plt.subplot(gs1[row_n, col_n])
#             ax.imshow(galaxy)
# #             ax.text(10, 20, 'Smooth = {:.2f}'.format(galaxy['smooth-or-featured_smooth_fraction']), fontsize=12, color='r')
# #             ax.text(10, 50, r'$\rho = {:.2f}$, Var ${:.3f}$'.format(galaxy['median_prediction'], 3*galaxy['predictions_var']), fontsize=12, color='r')
# #             ax.text(10, 80, '$L = {:.2f}$'.format(galaxy['bcnn_likelihood']), fontsize=12, color='r')
#             ax.axis('off')
#             galaxy_n += 1
# #     print('Mean L: {:.2f}'.format(df[:nrows * ncols]['bcnn_likelihood'].mean()))
#     fig = plt.gcf()
# #     fig.tight_layout()
#     return fig

# for images, labels in dataset.take(1):
#     print(images.shape)
#     print(labels.shape)
    
# # plt.hist(images.numpy()[0].flatten())

# _ = show_galaxies(images.numpy(), nrows=2, ncols=2)

In [118]:
# pre_model = tf.keras.models.Sequential()
# pre_model.add(tf.keras.layers.Input(shape=(initial_size, initial_size, 1)))
# run_estimator_config.add_preprocessing_layers(pre_model, crop_size, final_size)

In [119]:
# output = pre_model.predict(dataset.take(1))
# print(output.shape)

In [120]:
# _ = show_galaxies(output, nrows=2, ncols=2)

In [121]:
# checkpoint_dir

In [122]:
# model = run_estimator_config.get_model(schema, initial_size, crop_size, final_size, weights_loc=checkpoint_dir)

In [123]:
# model.predict(dataset.take(1))

In [124]:
# from zoobot.estimators import efficientnet

# input_shape = (final_size, final_size, 1)
# effnet = efficientnet.EfficientNet_custom_top(
#     schema=schema,
#     input_shape=input_shape,
#     get_effnet=efficientnet.EfficientNetB0
#     # further kwargs will be passed to get_effnet
#     # dropout_rate=dropout_rate,
#     # drop_connect_rate=drop_connect_rate
# )
# # effnet.load_weights(checkpoint_dir)

# model = tf.keras.models.Sequential()
# model.add(pre_model)
# model.add(effnet)

# # model.load_weights(checkpoint_dir)

In [125]:
# run_config = run_estimator_config.get_run_config(initial_size, final_size, crop_size, False, '', train_locs, eval_locs, 12, schema, batch_size)


In [126]:
# import sqlite3

In [127]:
# db_loc = 'data/experiments/live/latest/iteration_0/iteration.db'
# db = sqlite3.connect(db_loc)

In [128]:
# df = database.get_all_subjects_df(db)

In [129]:
# df

In [130]:
# unlabelled_locs = glob.glob(f'/home/walml/repos/zoobot/data/gz2/shards/all_featp5_facep5_sim_{initial_size}/*.tfrecord')

In [131]:
# id_str_dataset = input_utils.get_dataset(unlabelled_locs, feature_spec, batch_size=1, shuffle=False, repeat=False, drop_remainder=False)
# id_strs = [str(d['id_str'].numpy().squeeze())[2:-1] for d in id_str_dataset]
# len(id_strs)

In [132]:
# from zoobot.active_learning import database

In [133]:
# predictions = database.make_predictions_on_tfrecord([unlabelled_locs[0]], model, run_config, db, n_samples, initial_size)

In [134]:
# unlabelled_subjects, samples = predictions

In [135]:
# len(unlabelled_subjects)

In [136]:
# unlabelled_subjects[0]

In [137]:
# unlabelled_subjects[0]['predictions'].shape

In [138]:
# unlabelled_config = input_utils.InputConfig(
#         name='eval',
#         tfrecord_loc=[unlabelled_locs[0]],
#         label_cols=[],
#         stratify=False,
#         shuffle=False,  # see above
#         repeat=False,
#         drop_remainder=False,
#         stratify_probs=None,
#         geometric_augmentation=False,
#         photographic_augmentation=False,
#         contrast_range=(0.98, 1.02),
#         batch_size=batch_size,
#         initial_size=initial_size,
#         final_size=final_size,
#         channels=channels,
#         greyscale=True,
#         zoom_central=False  # SMOOTH MODE
#         # zoom_central=True  # BAR MODE
#     )
# unlabelled_dataset = input_utils.get_input(config=unlabelled_config)

In [139]:
# predictions_direct = np.stack([model.predict(unlabelled_dataset) for n in range(n_samples)], axis=-1)

In [140]:
# predictions_direct.shape

In [141]:
# samples[0, 0]

In [142]:
# unlabelled_subjects[5]['predictions'][0:2]

In [143]:
# predictions_direct[5, 0:2]

In [144]:
# df

In [145]:
# schema.answers[2].text, schema.answers[3].text

In [146]:
# schema.answers[4].text, schema.answers[5].text

In [147]:
# fig, axes = plt.subplots(nrows=5, sharex=True)
# for subject_n in range(5):
#     ax = axes[subject_n]
#     ax.hist(predictions_direct[subject_n, 4] / predictions_direct[subject_n, 5], alpha=.5)
#     ax.hist(unlabelled_subjects[subject_n]['predictions'][4] / unlabelled_subjects[subject_n]['predictions'][5], alpha=.5)


In [148]:
# fig, axes = plt.subplots(nrows=5, sharex=True, figsize=(4, 12))
# for subject_n in range(5):
#     ax = axes[subject_n]
#     ax.imshow(np.array(Image.open(df['file_loc'][subject_n])))

In [149]:
# id_strs[0], unlabelled_subjects[0]['id_str']

In [150]:
predictions = np.stack([model.predict(dataset) for n in range(n_samples)], axis=-1)

In [151]:
predictions.shape

(10000, 34, 5)

In [152]:
predictions[0, 0]

array([11.27927 , 14.975777, 17.126745, 16.46605 , 14.626597],
      dtype=float32)

In [153]:
predictions[0, :2, :]

array([[11.27927  , 14.975777 , 17.126745 , 16.46605  , 14.626597 ],
       [18.48789  , 15.061994 , 12.538382 , 11.535164 , 13.7275915]],
      dtype=float32)

In [154]:
predictions[12, 0]

array([49.97901 , 59.155167, 58.436424, 69.242615, 68.42684 ],
      dtype=float32)

In [155]:
predictions[12, 4]

array([44.34936 , 47.56811 , 37.922855, 62.31121 , 37.526535],
      dtype=float32)

In [156]:
# plt.hist(np.std(predictions[:, 0, :], axis=-1), bins=30)
# plt.xlabel('Mean smooth/featured std')

In [157]:
# for batch_x, batch_y in dataset:
#     print(batch_y.numpy())
#     break

In [158]:
# labels = np.concatenate([batch_y for (_, batch_y) in test_dataset], axis=0)
# labels.shape

In [159]:
# fig, ax = plt.subplots()
# ax.hist(predictions[:, 0], alpha=0.5, label='Predictions', density=True)
# ax.hist(labels[:, 0] / labels[:, :2].sum(axis=1), alpha=0.5, label='Labels', density=True)
# plt.legend()

In [160]:
# print(predictions[:, 0].min(), labels[:, 0].min())
# print(predictions[:, 0].max(), labels[:, 0].max())

In [161]:
# acquisitions = acquisition_utils.mutual_info_acquisition_func_multiq(predictions, schema, retirement=40)
# acquisitions.shape
# acquisitions

# single_q_acquisitions = np.array(acquisition_utils.mutual_info_acquisition_func(predictions[:, 0], expected_votes=40))
# single_q_acquisitions[:5], acquisitions[0, :5], acquisitions[1, :5]  # smooth mutual acq should be identical, for both answers by symmmetry
# acquisitions[2, :5], acquisitions[3, :5]  # has-spiral-arms also by symmetry
# acquisitions[4, :5], acquisitions[5, :5], acquisitions[6, :5]  # but for spiral winding there are 3 answers so *not* identical
# predictions.shape, acquisitions.shape, len(id_strs)

In [162]:
predictions.shape, len(id_strs)

((10000, 34, 5), 10000)

In [163]:
predictions[0][0]

array([11.27927 , 14.975777, 17.126745, 16.46605 , 14.626597],
      dtype=float32)

In [164]:
# def prediction_to_row(prediction, id_str):
#     row = {
#         'id_str': id_str
#     }
#     for n, col in enumerate(label_cols):
#         answer = label_cols[n]

# #         row[answer + '_acquisition'] = acquisition[n]
#         row[answer + '_prediction_mean'] = float(prediction[n].mean())
# #         row[answer + '_acquisition'] = acquisition[n]
# #         row['total_acquisition'] = acquisition.sum()
#     return row

def prediction_to_row(prediction, id_str):
    row = {
        'id_str': id_str
    }
    for n, col in enumerate(label_cols):
        answer = label_cols[n]
        row[answer + '_concentration'] = json.dumps(list(prediction[n].astype(float)))
#         row[answer + '_acquisition'] = acquisition[n]
        row[answer + '_concentration_mean'] = float(prediction[n].mean())
#         row[answer + '_acquisition'] = acquisition[n]
#         row['total_acquisition'] = acquisition.sum()
    return row

In [165]:
# def all_to_row(prediction, acquisition, id_str):
#     row = {
#         'id_str': id_str
#     }
#     for n, col in enumerate(label_cols):
#         answer = label_cols[n]
#         row[answer + '_prediction'] = prediction[n]
# #         row[answer + '_acquisition'] = acquisition[n]
#         row[answer + '_prediction_mean'] = float(prediction[n].mean())
# #         row[answer + '_acquisition'] = acquisition[n]
# #         row['total_acquisition'] = acquisition.sum()
#     return row

In [166]:
data = [prediction_to_row(predictions[n], id_strs[n]) for n in range(len(predictions))]
# data = [all_to_row(predictions[n], acquisitions[n], id_strs[n]) for n in range(len(predictions))]
predictions_df = pd.DataFrame(data)

In [167]:
len(predictions_df)

10000

In [168]:
predictions_df.head()

Unnamed: 0,id_str,smooth-or-featured_smooth_concentration,smooth-or-featured_smooth_concentration_mean,smooth-or-featured_featured-or-disk_concentration,smooth-or-featured_featured-or-disk_concentration_mean,smooth-or-featured_artifact_concentration,smooth-or-featured_artifact_concentration_mean,disk-edge-on_yes_concentration,disk-edge-on_yes_concentration_mean,disk-edge-on_no_concentration,...,spiral-arm-count_cant-tell_concentration,spiral-arm-count_cant-tell_concentration_mean,merging_none_concentration,merging_none_concentration_mean,merging_minor-disturbance_concentration,merging_minor-disturbance_concentration_mean,merging_major-disturbance_concentration,merging_major-disturbance_concentration_mean,merging_merger_concentration,merging_merger_concentration_mean
0,J101523.33+031030.9,"[11.27927017211914, 14.975776672363281, 17.126...",14.894887,"[18.487890243530273, 15.061993598937988, 12.53...",14.270205,"[3.4109861850738525, 3.6765503883361816, 3.380...",3.37145,"[2.5431180000305176, 2.891010284423828, 2.7779...",2.771492,"[73.60625457763672, 58.1546630859375, 51.68944...",...,"[6.270047664642334, 8.276326179504395, 12.5538...",9.155304,"[16.36289405822754, 17.06951141357422, 14.9404...",13.893324,"[7.7284440994262695, 8.662688255310059, 7.3509...",7.001131,"[2.550448417663574, 2.3602328300476074, 1.9933...",2.03975,"[3.5474438667297363, 4.097710609436035, 4.0339...",4.585111
1,J112413.94+245257.5,"[55.1696891784668, 52.06551742553711, 56.90679...",54.522839,"[30.61265754699707, 31.842456817626953, 28.707...",31.032757,"[8.016695022583008, 8.519285202026367, 8.50559...",8.310714,"[37.72706985473633, 42.812862396240234, 31.658...",38.771183,"[6.244181156158447, 6.101380825042725, 6.82457...",...,"[16.076433181762695, 14.546131134033203, 18.80...",16.023624,"[92.08403778076172, 94.09244537353516, 88.3727...",92.109886,"[6.697113513946533, 6.890059947967529, 6.08404...",6.518374,"[1.4896457195281982, 1.5561150312423706, 1.476...",1.500419,"[1.2816667556762695, 1.238523006439209, 1.4674...",1.310111
2,J160105.65-004226.9,"[72.8885269165039, 62.529762268066406, 68.5110...",67.38269,"[6.07484245300293, 6.279238224029541, 5.604297...",5.960749,"[7.311710357666016, 6.899960994720459, 6.60113...",6.780965,"[8.313241958618164, 6.744441509246826, 6.20512...",6.890455,"[31.76457977294922, 30.494199752807617, 26.252...",...,"[18.716087341308594, 17.575172424316406, 20.37...",18.915979,"[84.1072006225586, 77.57942199707031, 71.59207...",76.379982,"[8.765542984008789, 8.74241828918457, 10.08154...",9.079222,"[1.8897054195404053, 1.8869670629501343, 1.846...",1.868123,"[1.0784064531326294, 1.0940051078796387, 1.112...",1.105595
3,J112748.26+264900.1,"[10.218161582946777, 9.084182739257812, 8.2000...",8.763184,"[19.81637191772461, 20.72105598449707, 20.2951...",20.260975,"[3.1757473945617676, 2.962010145187378, 2.6536...",2.941709,"[1.7829853296279907, 2.0597589015960693, 1.996...",1.903531,"[68.4912109375, 75.38482666015625, 81.52665710...",...,"[13.298824310302734, 11.231612205505371, 10.60...",11.610242,"[18.421968460083008, 20.00270652770996, 21.238...",18.143351,"[11.918533325195312, 12.801078796386719, 12.58...",11.604419,"[7.076321125030518, 7.527082443237305, 6.91198...",7.214477,"[1.628361701965332, 1.37748384475708, 1.325240...",1.466835
4,J155214.79+164831.8,"[70.9332275390625, 78.71728515625, 64.24332427...",76.407028,"[11.820240020751953, 10.94378662109375, 10.623...",11.463373,"[8.661657333374023, 9.23896598815918, 7.588526...",9.027744,"[16.46335220336914, 16.7069091796875, 15.18986...",18.117764,"[14.67911148071289, 14.895379066467285, 12.161...",...,"[17.69595718383789, 16.06884002685547, 17.0653...",16.611118,"[84.24547576904297, 90.71809387207031, 78.3669...",89.167542,"[8.992124557495117, 9.326862335205078, 8.07801...",8.414492,"[1.755096673965454, 1.8085262775421143, 1.6271...",1.710882,"[1.2377533912658691, 1.1164424419403076, 1.286...",1.155625


In [169]:
# catalog['dr7objid']

In [170]:
catalog['iauname'] = catalog['iauname'].astype(str)
# predictions_df['iauname'] = predictions_df['id_str'].astype(str)
# catalog['id_str'] = catalog['dr7objid'].apply(lambda x: 'dr7objid_' + str(x))


In [171]:
# predictions_df['id_str'].sort_values().values

In [172]:
# catalog['iauname'].sort_values().values
print(len(catalog))

309398


In [173]:
predictions_df['id_str']

0       J101523.33+031030.9
1       J112413.94+245257.5
2       J160105.65-004226.9
3       J112748.26+264900.1
4       J155214.79+164831.8
               ...         
9995    J002315.35-011046.1
9996    J100818.28-002136.7
9997    J120818.30+240608.8
9998    J143428.37+231812.9
9999    J165006.69+232633.6
Name: id_str, Length: 10000, dtype: object

In [174]:
catalog['id_str']

0         J094651.40-010228.5
1         J094630.85-004554.5
2         J094631.59-005917.7
3         J094744.18-004013.4
4         J094751.74-003242.0
                 ...         
309393    J143642.06-012542.2
309394    J143539.18-012924.0
309395    J230924.60-001458.1
309396    J235101.08-100042.7
309397    J235320.91-103238.7
Name: id_str, Length: 309398, dtype: object

In [175]:
df = pd.merge(catalog, predictions_df, how='inner', on='id_str')
print(len(df), len(predictions_df))
assert len(df) == len(predictions_df)

10000 10000


In [176]:
# for q in schema.questions:
#     a = q.answers[0]
#     print(a.text, mean_squared_error(df[a.text + '_fraction'], df[a.text + '_prediction_mean']))

In [177]:
# for q in schema.questions:
#     a = q.answers[0]
#     print(a.text, mean_absolute_error(df[a.text + '_fraction'], df[a.text + '_prediction_mean']))

In [178]:
# fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 12))
# axes = [ax for row in axes for ax in row]
# for n, q in enumerate(schema.questions):
#     a = q.answers[0]
#     sns.scatterplot(data=df, x=a.text + '_fraction', y=a.text + '_prediction_mean', ax=axes[n], alpha=0.3)
# fig.tight_layout()

In [179]:
# for q in schema.questions:
#     a = q.answers[0]
#     print(a.text, mean_squared_error(df[a.text + '_fraction'], df[a.text + '_prediction_mean']))

In [180]:
# assert False

In [181]:
# df.to_csv('temp/calibration_predictions.csv', index=False)
# df.to_csv('temp/10k_model_a_predictions.csv', index=False)
# df.to_csv('temp/10k_model_b_predictions.csv', index=False)
# df.to_csv('temp/dirichlet_concentrations_arc_al.csv', index=False)
# df.to_csv('temp/dirichlet_concentrations_arc_bl.csv', index=False)
# df.to_csv('temp/dirichlet_concentrations_arc_b.csv', index=False)
# df.to_csv('temp/gz2_all_2p5_eval.csv', index=False)
# df.to_csv('temp/gz2_filtered_2p5_eval.csv', index=False)
# df.to_csv('temp/gz2_allq_it0_m0.csv', index=False)
# df.to_csv('temp/gz2_allq_it0_m1.csv', index=False)

# df.to_csv('temp/gz2_allq_opt4_it2_m2.csv', index=False)
# df.to_csv('temp/gz2_allq_baseline_it2_m2.csv', index=False)

# df.to_csv('temp/decals_n2_allq_m0.csv', index=False)
df.to_csv('temp/decals_retired_allq_m0.csv', index=False)


In [182]:
assert False

AssertionError: 

In [14]:
# copied from trust_the_model.ipynb
def show_galaxies(df, scale=3, nrows=6, ncols=6):
    fig = plt.gcf()

    plt.figure(figsize=(scale * nrows, scale * ncols * 1.025))
    gs1 = gridspec.GridSpec(nrows, ncols)
    gs1.update(wspace=0.0, hspace=0.0)
    galaxy_n = 0
    for row_n in range(nrows):
        for col_n in range(ncols):
            galaxy = df.iloc[galaxy_n]
            image = Image.open(galaxy['file_loc'])
            ax = plt.subplot(gs1[row_n, col_n])
            ax.imshow(image)
#             ax.text(10, 20, 'Smooth = {:.2f}'.format(galaxy['smooth-or-featured_smooth_fraction']), fontsize=12, color='r')
#             ax.text(10, 50, r'$\rho = {:.2f}$, Var ${:.3f}$'.format(galaxy['median_prediction'], 3*galaxy['predictions_var']), fontsize=12, color='r')
#             ax.text(10, 80, '$L = {:.2f}$'.format(galaxy['bcnn_likelihood']), fontsize=12, color='r')
            ax.axis('off')
            galaxy_n += 1
#     print('Mean L: {:.2f}'.format(df[:nrows * ncols]['bcnn_likelihood'].mean()))
    fig = plt.gcf()
#     fig.tight_layout()
    return fig


In [15]:
df = pd.read_csv('temp/decals_retired_allq_m0.csv')

In [16]:
save_dir = 'results/temp/decals_allq_top_n'
save_top_n(df, schema, save_dir)

NameError: name 'save_top_n' is not defined

In [None]:
assert False  # moove later analysis elsewhere?

In [None]:
# fig, ax = plt.subplots()
# ax.scatter(predictions[:, 0], labels[:, 0] / labels[:, :2].sum(axis=1))
# plt.close()

In [None]:
matplotlib.get_backend()

In [None]:
# fig, ax = plt.subplots()
# ax.scatter(predictions[:, 4], labels[:, 4] / labels[:, 4:7].sum(axis=1))

In [None]:
# fig, ax = plt.subplots()
# ax.scatter(predictions[:, 5], labels[:, 5] / labels[:, 4:7].sum(axis=1))

In [None]:
# fig, ax = plt.subplots()
# ax.scatter(predictions[:, 6], labels[:, 6] / labels[:, 4:7].sum(axis=1))

In [None]:
# assert False

In [None]:
# to check that the right models have been loaded - should be around 40 for smooth, 0-40 for bars
# plt.hist(sim_model.total_votes), sim_model.total_votes.mean()

In [None]:

def galaxy_posterior_grid(df, schema):
    
    sns.set_context('paper', font_scale=1.5)
    
    scale = 1.5
    
    im_width = 2
    posterior_width = 3
    height = im_width
    
    n_galaxies = len(df)
    n_posteriors = len(schema.answers)
    
    fig = plt.figure(figsize=(scale * (im_width + posterior_width*n_posteriors), (scale * n_galaxies * height)))  # width, height format
    gs = gridspec.GridSpec(len(df) * height, im_width + posterior_width * len(schema.answers))  # y, x format
    image_axes = []
    posterior_axes = []  # (galaxy i.e. row, answer) shape
    
    # create the grid
    for galaxy_n in range(len(df)):
        y_slice = slice(galaxy_n*height, (galaxy_n+1)*height)
        image_axes.append(plt.subplot(gs[y_slice, :im_width]))
        
        temp_galaxy_axes = []
        for answer_n, answer in enumerate(schema.answers):
            x_slice = slice(im_width+answer_n*posterior_width, im_width+(answer_n+1)*posterior_width)
            temp_galaxy_axes.append(plt.subplot(gs[y_slice, x_slice]))
        posterior_axes.append(temp_galaxy_axes)
        
    
    # fill the images
    for ax_n, ax in enumerate(image_axes):
        plot_galaxy(df['file_loc'][ax_n], ax)
    
    # fill the posteriors
    for answer_n, answer in enumerate(schema.answers):
        samples, labels, total_votes = get_single_answer_data(df, answer)
        galaxy_axes = [axes[answer_n] for axes in posterior_axes]
        make_predictions.plot_samples(samples, labels, total_votes, fig, galaxy_axes, alpha=0.06)
    
    # fix x limits for comparison
    for row_n, axes in enumerate(posterior_axes):
        for answer_n, ax in enumerate(axes):
            ax.set_xlim([0, 50])
            if row_n == 0:
                ax.set_title(schema.answers[answer_n].text)
        
#     for n in range(len(labels)):
#         multiple_axes[n].set_ylabel(r'$p(k|N, D)$', visible=True)
#         multiple_axes[n].yaxis.set_visible(True)
#         single_axes[n].set_ylabel(r'$p(k|N, w)$', visible=True)
#         single_axes[n].yaxis.set_visible(True)
#         single_axes[n].yaxis.set_major_locator(plt.NullLocator())
#         multiple_axes[n].yaxis.set_major_locator(plt.NullLocator())
#         if n < len(labels) - 1:
#             single_axes[n].xaxis.set_major_locator(plt.NullLocator())
#             multiple_axes[n].xaxis.set_major_locator(plt.NullLocator())
    
# #     if QUESTION == 'bars':
# #         question = 'Bar'
# #     else:
# #         question = 'Smooth'
# #     single_axes[-1].set_xlabel(r"$k$ '{}' votes, of $N$ total".format(question))
# #     multiple_axes[-1].set_xlabel(r"$k$ '{}' votes, of $N$ total".format(question))
#     fig.tight_layout()

    
#     multiple_axes[0].legend(
#         loc='lower center', 
#         bbox_to_anchor=(0.5, 1.1),
#         ncol=1, 
#         fancybox=True, 
#         shadow=False
#     )

    fig.tight_layout()
    return fig

In [None]:
fig = galaxy_posterior_grid(df[:5], schema)
fig.savefig(save_dir + '/grid.pdf')
fig.savefig(save_dir + '/grid.png')

In [None]:

def custom_samples_with_galaxies(samples, labels, total_votes, png_locs):
    
    sns.set_context('paper', font_scale=1.5)
    
    im_width = 2
    single_width = 3
    multiple_width = 3
    height = im_width
    
    fig = plt.figure(figsize=(0.8 * len(labels) * height * 2., 0.8 * (im_width + single_width + multiple_width) * 1.75))
    gs = gridspec.GridSpec(len(labels) * height, im_width + single_width + multiple_width)  # y, x format
    image_axes = []
    single_axes = []
    multiple_axes = []
    for galaxy_n in range(len(labels)):
        x_slice = slice(galaxy_n*height, (galaxy_n+1)*height)
        image_axes.append(plt.subplot(gs[x_slice, :im_width]))
        single_axes.append(plt.subplot(gs[x_slice, im_width:im_width+single_width]))
        multiple_axes.append(plt.subplot(gs[x_slice, im_width+single_width:]))
    

#     fig, axes = plt.subplots(nrows=len(labels), figsize=(3, len(labels)*1.5), sharex=True)
    make_predictions.plot_samples(samples[:, :1], labels, total_votes, fig, single_axes, alpha=0.06)
    for ax in single_axes:
        ax.set_xlim([0, 50])

    make_predictions.plot_samples(samples, labels, total_votes, fig, multiple_axes, alpha=0.06)
    for ax in multiple_axes:
        ax.set_xlim([0, 50])
        
        
    for ax_n, ax in enumerate(image_axes):
        plot_galaxy(png_locs[ax_n], ax)
        
    
    for n in range(len(labels)):
        multiple_axes[n].set_ylabel(r'$p(k|N, D)$', visible=True)
        multiple_axes[n].yaxis.set_visible(True)
        single_axes[n].set_ylabel(r'$p(k|N, w)$', visible=True)
        single_axes[n].yaxis.set_visible(True)
        single_axes[n].yaxis.set_major_locator(plt.NullLocator())
        multiple_axes[n].yaxis.set_major_locator(plt.NullLocator())
        if n < len(labels) - 1:
            single_axes[n].xaxis.set_major_locator(plt.NullLocator())
            multiple_axes[n].xaxis.set_major_locator(plt.NullLocator())
    
#     if QUESTION == 'bars':
#         question = 'Bar'
#     else:
#         question = 'Smooth'
#     single_axes[-1].set_xlabel(r"$k$ '{}' votes, of $N$ total".format(question))
#     multiple_axes[-1].set_xlabel(r"$k$ '{}' votes, of $N$ total".format(question))
    fig.tight_layout()

    single_axes[0].legend(
        loc='lower center', 
        bbox_to_anchor=(0.5, 1.1),
        ncol=1, 
        fancybox=True, 
        shadow=False
    )
    
    multiple_axes[0].legend(
        loc='lower center', 
        bbox_to_anchor=(0.5, 1.1),
        ncol=1, 
        fancybox=True, 
        shadow=False
    )

    
    fig.tight_layout()
    return fig

In [None]:
def plot_galaxy(image_loc, ax, n_examples=10, crop=0):
    im_size = 424
    im = Image.open(image_loc)
#     if QUESTION == 'bars':
#         crop = 120
#     else:
    crop = 35
    cropped_im = im.crop((crop, crop, 424 - crop, 424 - crop))
    ax.imshow(cropped_im)
    ax.grid(False)
    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)


In [None]:
question = 'has-spiral-arms'
answer = 'yes'
n = 5
samples, labels, total_votes = get_single_answer_data(question, answer)
png_locs = df['file_loc'][:n]

# catalog = sim_model.catalog[selected_slice]

_ = custom_samples_with_galaxies(samples, labels, total_votes, png_locs)

In [None]:
# fig, axes = plt.subplots(1, 10, figsize=(20, 12))
# for ax_n, ax in enumerate(axes):
#     plot_galaxy(sim_model.catalog.iloc[ax_n]['png_loc'], ax)

In [None]:
# 1 2

In [None]:
selected = slice(80, 73, -1)  # smooth

# selected = slice(0, 7)

if QUESTION == 'bars':
    selected = slice(0, 7)

In [None]:
# np.array(sim_model.model.samples)[selected, :]
# np.array(sim_model.labels)[selected]
# sim_model.catalog['smooth-or-featured_total-votes'][selected]

In [None]:
fig = custom_samples_with_galaxies(sim_model, selected)
# fig.savefig(os.path.join(save_dir, 'mc_model_{}.png'.format(len(np.array(sim_model.labels)[selected]))))
# fig.savefig(os.path.join(save_dir, 'mc_model_{}.pdf'.format(len(np.array(sim_model.labels)[selected]))))

In [None]:
# be sure to switch label in custom_samples before running this
# fig = custom_samples(np.array(single_sim_model.model.samples)[selected, :1], np.array(single_sim_model.labels)[selected], total_votes=single_sim_model.total_votes)
# fig.savefig(os.path.join(save_dir, 'single_model_{}.png'.format(len(np.array(sim_model.labels)[selected]))))
# fig.savefig(os.path.join(save_dir, 'single_model_{}.eps'.format(len(np.array(sim_model.labels)[selected]))))

In [None]:
sns.set(font_scale=1.2)
sns.set_style('white')

fig, ax = plt.subplots()
ungrouped_coverage_df = discrete_coverage.evaluate_discrete_coverage(
    sim_model.labels, 
    sim_model.bin_probs)
coverage_df = ungrouped_coverage_df.groupby('max_state_error').agg({'prediction': 'sum', 'observed': 'sum'}).reset_index()

ungrouped_single_coverage_df = discrete_coverage.evaluate_discrete_coverage(
    single_sim_model.labels, 
    single_sim_model.bin_probs)
single_coverage_df = ungrouped_single_coverage_df.groupby('max_state_error').agg({'prediction': 'sum', 'observed': 'sum'}).reset_index()


plt.plot(coverage_df['max_state_error'], coverage_df['prediction'], label='MC Model Expects')
plt.plot(single_coverage_df['max_state_error'], single_coverage_df['prediction'], label='Single Model Expects')
plt.plot(single_coverage_df['max_state_error'], coverage_df['observed'], 'k--', label='Actual')

ax.set_xlabel('Max Allowed Vote Error')
ax.set_ylabel('Galaxies Within Max Error')
ax.legend()
ax.xaxis.set_major_formatter(StrMethodFormatter('{x:.0f}'))  # must expect 'x' kw arg

ax.set_xlim([0, 15])
fig.tight_layout()
fig.savefig(os.path.join(save_dir, 'coverage_comparison_200_samples.png'))
fig.savefig(os.path.join(save_dir, 'coverage_comparison_200_samples.pdf'))

In [None]:
ungrouped_coverage_df.to_csv(os.path.join(save_dir, QUESTION + '_ungrouped_coverage_df.csv'), index=False)
ungrouped_single_coverage_df.to_csv(os.path.join(save_dir, QUESTION + '_ungrouped_coverage_df.csv'), index=False)

In [None]:
coverage_df['error'] = coverage_df['prediction'] - coverage_df['observed']
coverage_df['relative_error'] = coverage_df['error'] / coverage_df['observed']
coverage_df.to_csv(os.path.join(save_dir, QUESTION + '_coverage_df.csv'), index=False)
coverage_df.head(20)

In [None]:
single_coverage_df['error'] = single_coverage_df['prediction'] - single_coverage_df['observed']
single_coverage_df['relative_error'] = single_coverage_df['error'] / single_coverage_df['observed']
single_coverage_df.to_csv(os.path.join(save_dir, QUESTION + '_single_coverage_df.csv'), index=False)
single_coverage_df.head(20)

TODO - I might consider adding an MSE model as a comparison, to hopefully beat. I think this might be quite similar though. Ideally I can compare this with previous work somehow.

In [None]:
sns.set(font_scale=1.2)
sns.set_style('white')
fig, ax = plt.subplots()
ax.hist(sim_model.abs_rho_error, bins=25)
# ax.axvline(sim_model.mean_abs_rho_error, color='r') 
ax.set_xlim([0, 1.])
ax.set_ylabel('Galaxies')
ax.set_xlabel(r'| Expected $\hat{\rho}$ - observed vote fraction $\frac{k}{N}$ |')
fig.tight_layout()
fig.savefig(os.path.join(save_dir, 'difference_in_rho.png'))
fig.savefig(os.path.join(save_dir, 'difference_in_rho.pdf'))

In [None]:
sim_model.abs_rho_error.mean(), single_sim_model.abs_rho_error.mean()

In [None]:
np.sqrt(sim_model.mean_abs_rho_error), np.sqrt(single_sim_model.mean_abs_rho_error)  

In [None]:
np.sqrt(sim_model.mean_square_rho_error), np.sqrt(single_sim_model.mean_square_rho_error) # this is the rmse

In [None]:
# alpha = 0.3
# n_bins = 25

# # dummy for bins
# fig, ax = plt.subplots()
# _, bins, _  = ax.hist(sim_model.labels / sim_model.total_votes, bins=n_bins, alpha=alpha, label=r'Observed $\rho$')
# ax.hist(sim_model.mean_rho_prediction, bins=n_bins, alpha=alpha, label=r'Mean Rho Prediction $\hat{\rho}}$')
# # ax.hist(single_sim_model.mean_rho_prediction, bins=bins, alpha=alpha, label=r'Single Rho Prediction $\hat{\rho}}$')

# fig, ax = plt.subplots()
# sns.set(font_scale=1.)
# sns.set_style('white')

# ax.hist(sim_model.mean_rho_prediction, bins=bins, alpha=alpha, label=r'Mean Rho Prediction $\hat{\rho}}$')
# # ax.hist(single_sim_model.mean_rho_prediction, bins=bins, alpha=alpha, label=r'Single Rho Prediction $\hat{\rho}}$')
# ax.hist(sim_model.labels / sim_model.total_votes, bins=bins, alpha=alpha, label=r'Observed $\rho$')
# ax.legend()
# ax.set_xlim([0., 1.])
# ax.set_ylabel('Galaxies')
# ax.set_xlabel(r'Typical vote fraction $\rho$')
# fig.tight_layout()
# fig.savefig(os.path.join(save_dir, 'typical_vote_fraction_distribution.png'))

# This is a repeat of the above histograms

In [None]:
np.sum(sim_model.mean_rho_prediction > 0.5), np.sum(single_sim_model.mean_rho_prediction > 0.5), np.sum((sim_model.labels / sim_model.total_votes) > 0.5)

In [None]:
(sim_model.labels / sim_model.total_votes).min(), (sim_model.labels / sim_model.total_votes).max()

In [None]:
sim_model.mean_rho_prediction.min(), sim_model.mean_rho_prediction.max()

In [None]:
single_sim_model.mean_rho_prediction.min(), single_sim_model.mean_rho_prediction.max()

In [None]:
sim_model.total_votes

## Save DataFrame of predictions + catalog (GZ2) for use elsewhere

In [None]:
import json

In [None]:
response_df = pd.DataFrame(data={
    'total_votes': sim_model.total_votes, 
    'k': sim_model.labels, 
    'vote_fraction': (sim_model.labels / sim_model.total_votes), 
    'rho_prediction': sim_model.mean_rho_prediction
#     'png_loc': sim_model.catalog.png_loc
})
safe_catalog_cols = list(set(sim_model.catalog.columns.values) - set(['total_votes', 'ra_subject', 'dec_subject']))
df = pd.concat([response_df, sim_model.catalog[safe_catalog_cols]], axis=1)
df['smooth'] = df['vote_fraction'] > 0.5
df['confidence_proxy'] = np.abs(0.5 - df['rho_prediction'])
df['rho_predictions'] = 0
for n in range(len(df)):
    df['rho_predictions'][n] = json.dumps(list(sim_model.model.samples[n, :]))
    df = df.sort_values('confidence_proxy', ascending=False)

In [None]:
df['rho_predictions']

In [None]:
df.head()

In [None]:
df.to_parquet('/data/repos/zoobot/notebooks/{}_test_predictions_and_gz2_catalog.parquet'.format(QUESTION))

### Replicate (ish) Sanchez 2017 ROC Curves

In [None]:
confusion_matrix((sim_model.labels / sim_model.total_votes) > 0.5, sim_model.mean_rho_prediction > 0.5)

In [None]:
 1 - ((66 + 99) / (490 + 1845 + 66 + 99))

In [None]:
 1 - ((189 + 81) / (1858 + 189 + 81 + 372))

In [None]:
fig, ax = plt.subplots()
sns.set(font_scale=1.2)
sns.set_style('white')

fpr, tpr, _ = roc_curve(df['smooth'], df['rho_prediction'])
ax.plot(fpr, tpr, label='All')
df_low_entropy = df[df['confidence_proxy'] > 0.3]
fpr, tpr, _ = roc_curve(df_low_entropy['smooth'], df_low_entropy['rho_prediction'])
ax.plot(fpr, tpr, label=r'"High Confidence" i.e. $\hat{\rho} < 0.2$ or $\hat{\rho} > 0.8$')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.legend()

fig.tight_layout()
fig.savefig(os.path.join(save_dir, 'roc_curve.png'))
fig.savefig(os.path.join(save_dir, 'roc_curve.pdf'))

In [None]:
len(df), len(df_low_entropy)

### Replicate(ish) Khan 2018 Confusion Matrices

> After selecting the OBJIDs from Table 2 based on the probability thresholds of 0.985 and 0.926 for spirals and ellipticals respectively,

In [None]:
df.sample(1)

In [None]:
cdf_array = binom.cdf((df['total_votes'] / 2.).astype(int), df['total_votes'], df['rho_prediction'])

In [None]:
(df['total_votes'] / 2.).astype(int).sample(10)

In [None]:
df['total_votes'].sample(10)

In [None]:
df['rho_prediction'].sample(10)

In [None]:
plt.hist(cdf_array, bins=30)

In [None]:
binom.cdf(20, 40, 0.88)

In [None]:
sum(1 - cdf_array > 0.985)

In [None]:
sum(cdf_array > 0.926)

In [None]:
high_prob_df = df[(cdf_array < (1 - 0.985)) | (cdf_array > 0.926)]

In [None]:
len(high_prob_df)

In [None]:
if QUESTION == 'smooth':
    spiral_pc_to_keep = 516 / 6677
    n_spirals = int(len(df) * spiral_pc_to_keep)
    elliptical_pc_to_keep = 550 / 5904
    n_ellipticals = int(len(df) * elliptical_pc_to_keep)
    print(spiral_pc_to_keep, n_spirals, elliptical_pc_to_keep, n_ellipticals)
    high_prob_df = pd.concat([
        df.sort_values('rho_prediction')[:n_spirals],
        df.sort_values('rho_prediction', ascending=False)[:n_ellipticals]
    ])
if QUESTION == 'bars':
    n_to_keep = int(len(df) * 0.08)
    high_prob_df = pd.concat([
        df.sort_values('rho_prediction')[:int(n_to_keep/2)],
        df.sort_values('rho_prediction', ascending=False)[:int(n_to_keep/2)]
    ])



In [None]:
high_prob_df.sample(20)

In [None]:
confusion_matrix(high_prob_df['vote_fraction'] >= 0.5, high_prob_df['rho_prediction'] >= 0.5)

In [None]:
error = high_prob_df[~(high_prob_df['vote_fraction'] > 0.5) & (high_prob_df['rho_prediction'] > 0.5)]

In [None]:
error

In [None]:
error['vote_fraction'] > 0.5, error['rho_prediction'] > 0.5

In [None]:
img = Image.open(error.iloc[0]['png_loc'])
plt.imshow(img)
fontdict = {'size': 16, 'color': 'white'}
plt.text(30, 360, r'Expected vote frac $\hat{\rho}$: 0.80', fontdict=fontdict)
plt.text(30, 400, r'Observed vote frac $\frac{k}{N}$: 0.50', fontdict=fontdict)
plt.axis('off')
plt.savefig(os.path.join(save_dir, 'high_prob_error_0.png'))
plt.savefig(os.path.join(save_dir, 'high_prob_error_0.eps'))

In [None]:
# img = Image.open(error.iloc[1]['png_loc'])
# plt.imshow(img)
# fontdict = {'size': 16, 'color': 'white'}
# plt.text(30, 360, r'Expected vote frac $\hat{\rho}$: 0.13', fontdict=fontdict)
# plt.text(30, 400, r'Observed vote frac $\frac{k}{N}$: 0.54', fontdict=fontdict)
# plt.axis('off')
# plt.savefig(os.path.join(save_dir, 'high_prob_error_1.png'))
# plt.savefig(os.path.join(save_dir, 'high_prob_error_1.eps'))

In [None]:
confusion_matrix(df['vote_fraction'][:int(len(df) / 2)] > 0.5, df['rho_prediction'][:int(len(df) / 2)] > 0.5)

In [None]:
if QUESTION == 'smooth':
    labels = ['Smooth', 'Featured']
    
#     cm = np.array([[ 232,    2], [   0, 191]])
#     name = 'confusion_matrix_high_confidence'
    
    cm = np.array([[ 490,   66],
       [  99, 1845]])
    name = 'confusion_matrix'
    
if QUESTION == 'bars':
    labels = ['No Bar', 'Bar']
    cm = np.array([[100,    0], [   0,   100]])
    name = 'confusion_matrix_high_confidence'
    
#     cm = np.array([[1858,    81], [   189,   372]])
#     name = 'confusion_matrix'

sns.set(font_scale=3.)
sns.set_style('white')

fig, ax = plt.subplots(figsize=(8, 8))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap="Blues", xticklabels=labels, yticklabels=labels, cbar=False, square=True, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Observed')
fig.tight_layout()
fig.savefig(os.path.join(save_dir, '{}.png'.format(name)))
fig.savefig(os.path.join(save_dir, '{}.pdf'.format(name)))

In [None]:
1 - (8 / (1159 + 83 + 8))

In [None]:
# sim_model.export_performance_metrics(save_dir)

In [None]:
# Draw a galaxy, infer a range of p, redraw, and measure accuracy - work in progress

### Plot other standard acquisition visualisations

In [None]:
new_acquisition_viz = False
if new_acquisition_viz:
    image_locs = sim_model.catalog['png_loc']
    images = np.stack([np.array(Image.open(loc)) for loc in image_locs])
    assert images.shape == (2500, 424, 424, 3)
    acquisition_utils.save_acquisition_examples(images, sim_model.mutual_info, 'mutual_info', save_dir)

In [None]:
# fig, row = plt.subplots(ncols=3, figsize=(12, 4))

In [None]:
# row = sim_model.acquisition_vs_volunteer_votes(row)

### Visualise Selection of Catalog Features w.r.t. Acquisition Function

In [None]:
fig = plt.figure(constrained_layout=True, figsize=(20, 12))
gs = gridspec.GridSpec(6, 5, figure=fig)

#### Smooth Votes

In [None]:
ax0 = plt.subplot(gs[:4, :])
sns.scatterplot(
    np.array(sim_model.catalog['smooth-or-featured_smooth_fraction'] * 40).astype(int),
    sim_model.model.acquisitions, hue=np.array(sim_model.model.acquisitions) > np.array(sim_model.model.acquisitions[103]),
    ax=ax0)
ax0.set_ylabel('Mutual Information')
ax0.set_xlabel('Smooth Votes')
ax0.legend([r'Top 10% $\mathcal{I}$', r'Bottom 90% $\mathcal{I}$'])

In [None]:
ax1 = plt.subplot(gs[4:, :])
ax1.hist(np.array(sim_model.labels * 40).astype(int), density=True, alpha=0.4)
ax1.hist(np.array(sim_model.labels * 40).astype(int)[:200], density=True, alpha=0.4)
ax1.set_ylabel('Frequency')
ax1.set_xlabel('Smooth Votes')

In [None]:
plt.savefig(os.path.join(save_dir, 'temp.png'))

#### Redshift

In [None]:
fig, ax = plt.subplots()
sns.jointplot(np.array(sim_model.labels * 40).astype(int), sim_model.catalog['redshift'], kind='kde')
ax0.set_ylabel('Redshift')
ax0.set_xlabel('Volunteer Votes')

In [None]:
ax0 = plt.subplot(gs[:2, :])
sns.jointplot(np.array(sim_model.labels * 40).astype(int), sim_model.catalog['redshift'], kind='kde', ax=ax0)
ax0.set_ylabel('Redshift')
ax0.set_xlabel('Volunteer Votes')

In [None]:
ax0 = plt.subplot(gs[:2, :])
sns.jointplot(sim_model.catalog['redshift'], sim_model.model.acquisitions, ax=ax0)
ax0.set_ylabel('Mutual Information')
ax0.set_xlabel('Redshift')

In [None]:
ax1 = plt.subplot(gs[4:, :])

ax1.hist(sim_model.catalog['redshift'], density=True, alpha=0.4)
# ax1.hist(sim_model.catalog['redshift'], density=True, alpha=0.4)
ax1.set_ylabel('Frequency')
ax1.set_xlabel('Smooth Votes')
# TODO sort by mutual information

### Below here is only relevant for DECALS, with extra questions. TODO update with GZ2 merger options?

In [None]:
assert False

In [None]:
merger_strs

In [None]:
merger_label = 'merging_major-disturbance'

In [None]:
ax0 = plt.subplot(gs[:2, :])
sns.scatterplot(sim_model.catalog[merger_label], sim_model.model.mutual_info, ax=ax0)
ax0.set_ylabel('Mutual Information')
ax0.set_xlabel(merger_label)

In [None]:
_ = plt.hist(sim_model.catalog[merger_label], bins=40)

In [None]:
featured_no_merger = sim_model.model.mutual_info[(sim_model.catalog[merger_label] == 0) & (sim_model.catalog['smooth-or-featured_smooth_fraction'] < 0.4)]
featured_merger = sim_model.model.mutual_info[(sim_model.catalog[merger_label] > 0) & (sim_model.catalog['smooth-or-featured_smooth_fraction'] < 0.4)]

In [None]:
featured_no_merger.mean()

In [None]:
featured_merger.mean()

In [None]:
fig, ax = plt.subplots()

In [None]:
ax.hist(featured_no_merger, alpha=0.3, density=True)

In [None]:
ax.hist(featured_merger, alpha=0.3, density=True)

In [None]:
fig

In [None]:
ax = sns.scatterplot(
    sim_model.catalog['smooth-or-featured_artifact'], 
    sim_model.model.mutual_info, 
    hue=sim_model.model.mutual_info > sim_model.model.mutual_info[103])
ax.set_xlim([0, 14.5])

In [None]:
sns.regplot(
    sim_model.catalog['smooth-or-featured_artifact'], 
    sim_model.model.mutual_info)

In [None]:
fig, ax = plt.subplots()

In [None]:
ax.hist(
    sim_model.catalog['smooth-or-featured_artifact'][sim_model.model.mutual_info > sim_model.model.mutual_info[100]],
    density=True,
    alpha=0.4,
)
ax.hist(
    sim_model.catalog['smooth-or-featured_artifact'][sim_model.model.mutual_info < sim_model.model.mutual_info[100]],
    density=True,
    alpha=0.4,
)

In [None]:
fig

In [None]:
has-spiral-arms_yes
spiral-winding_prediction-encoded

In [None]:
sns.scatterplot(
    sim_model.catalog['has-spiral-arms_yes'][sim_model.catalog['smooth-or-featured_smooth_fraction'] < 0.5], 
    sim_model.model.mutual_info[sim_model.catalog['smooth-or-featured_smooth_fraction'] < 0.5], 
    hue=sim_model.model.mutual_info > sim_model.model.mutual_info[100])

In [None]:
fig, ax = plt.subplots()
ax.hist(
    sim_model.catalog['has-spiral-arms_yes'][sim_model.model.mutual_info > sim_model.model.mutual_info[100]],
    density=True,
    alpha=0.4,
)
ax.hist(
    sim_model.catalog['has-spiral-arms_yes'][sim_model.model.mutual_info < sim_model.model.mutual_info[100]],
    density=True,
    alpha=0.4,
)

In [None]:
fig, ax = plt.subplots()
ax.hist(
    sim_model.catalog['redshift'][sim_model.model.mutual_info > sim_model.model.mutual_info[100]],
    density=True,
    alpha=0.4,
)
ax.hist(
    sim_model.catalog['redshift'][sim_model.model.mutual_info < sim_model.model.mutual_info[100]],
    density=True,
    alpha=0.4,
)

In [None]:
fig, ax = plt.subplots()
ax.hist(
    sim_model.catalog['merging_major-disturbance'][sim_model.model.mutual_info > sim_model.model.mutual_info[100]],
    density=True,
    alpha=0.4,
)
ax.hist(
    sim_model.catalog['merging_major-disturbance'][sim_model.model.mutual_info < sim_model.model.mutual_info[100]],
    density=True,
    alpha=0.4,
)

In [None]:
for merger_label in merger_strs:
    print('\n' + merger_label)
    print(sim_model.model.mutual_info[sim_model.catalog[merger_label] > 1].mean())
    print(sim_model.model.mutual_info[sim_model.catalog[merger_label] == 1].mean())

In [None]:
data = [
    {'Volunteer Response': 'Merging', 'Mean Mutual Information': sim_model.model.mutual_info[sim_model.catalog['merging_both-v1'] > 1].mean()},
    {'Volunteer Response': 'Major Disturbance', 'Mean Mutual Information': sim_model.model.mutual_info[sim_model.catalog['merging_major-disturbance'] > 1].mean()},
    {'Volunteer Response': 'Minor Disturbance', 'Mean Mutual Information': sim_model.model.mutual_info[sim_model.catalog['merging_minor-disturbance'] > 1].mean()},
    {'Volunteer Response': 'No Disturbance', 'Mean Mutual Information': sim_model.model.mutual_info[sim_model.catalog['merging_none'] > 20].mean()}
    ]

In [None]:
df = pd.DataFrame(data)

In [None]:
df.head()

In [None]:
fig, ax = plt.subplots()
ax = sns.barplot(data=df, y='Volunteer Response', x='Mean Mutual Information', ax=ax)
ax.set_xlim([0.2, 0.36])
fig.tight_layout()

In [None]:
print(sim_model.model.mutual_info[(sim_model.catalog['merging_minor-disturbance'] + sim_model.catalog['merging_major-disturbance']) > 0].mean())
print(sim_model.model.mutual_info[(sim_model.catalog['merging_minor-disturbance'] + sim_model.catalog['merging_major-disturbance']) == 0].mean())

In [None]:
plt.hist(sim_model.catalog['merging_tidal-debris-v1'], bins=40)