In [1]:
# Reload edited modules (e.g. zoobot) automatically before each cell is executed
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook
%matplotlib inline

In [2]:
import os
import logging
import argparse
import glob
import json
from multiprocessing import Pool

from IPython.display import display, Markdown

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import StrMethodFormatter
import matplotlib.gridspec as gridspec
import seaborn as sns
from PIL import Image

import pandas as pd
from astropy.table import Table  # for NSA
from astropy import units as u

from scipy import stats, integrate
from scipy.stats import binom
import statsmodels.api as sm

from sklearn import metrics
from sklearn.metrics import confusion_matrix, plot_confusion_matrix, roc_curve, mean_squared_error, mean_absolute_error
import tensorflow as tf

from shared_astro_utils import astropy_utils, matching_utils
from zoobot.estimators import make_predictions, bayesian_estimator_funcs
from zoobot.tfrecord import read_tfrecord
from zoobot.uncertainty import discrete_coverage
from zoobot.estimators import input_utils, losses, dirichlet_stats
from zoobot.tfrecord import catalog_to_tfrecord
from zoobot.active_learning import metrics, simulated_metrics, acquisition_utils, check_uncertainty, simulation_timeline, run_estimator_config
from zoobot.active_learning import acquisition_utils
from zoobot import label_metadata


In [3]:
# Work from the zoobot repo root so the relative 'results/...' paths below resolve.
# Set ZOOBOT_DIR to run this elsewhere; defaults to the original author's checkout.
os.chdir(os.environ.get('ZOOBOT_DIR', '/home/walml/repos/zoobot'))

In [4]:
# DECaLS decision tree: its questions and the flat list of answer (label) columns
version = 'decals'
questions, label_cols = label_metadata.decals_questions, label_metadata.decals_label_cols

schema = losses.Schema(label_cols, questions, version=version)
schema.questions  # each question with its answer index range and the answer it is asked after

{smooth-or-featured, indices 0 to 2, asked after None: (0, 2), disk-edge-on, indices 3 to 4, asked after smooth-or-featured_featured-or-disk, index 1: (3, 4), has-spiral-arms, indices 5 to 6, asked after disk-edge-on_no, index 4: (5, 6), bar, indices 7 to 9, asked after disk-edge-on_no, index 4: (7, 9), bulge-size, indices 10 to 14, asked after disk-edge-on_no, index 4: (10, 14), how-rounded, indices 15 to 17, asked after smooth-or-featured_smooth, index 0: (15, 17), edge-on-bulge, indices 18 to 20, asked after disk-edge-on_yes, index 3: (18, 20), spiral-winding, indices 21 to 23, asked after has-spiral-arms_yes, index 5: (21, 23), spiral-arm-count, indices 24 to 29, asked after has-spiral-arms_yes, index 5: (24, 29), merging, indices 30 to 33, asked after None: (30, 33)}


[smooth-or-featured, indices 0 to 2, asked after None,
 disk-edge-on, indices 3 to 4, asked after smooth-or-featured_featured-or-disk, index 1,
 has-spiral-arms, indices 5 to 6, asked after disk-edge-on_no, index 4,
 bar, indices 7 to 9, asked after disk-edge-on_no, index 4,
 bulge-size, indices 10 to 14, asked after disk-edge-on_no, index 4,
 how-rounded, indices 15 to 17, asked after smooth-or-featured_smooth, index 0,
 edge-on-bulge, indices 18 to 20, asked after disk-edge-on_yes, index 3,
 spiral-winding, indices 21 to 23, asked after has-spiral-arms_yes, index 5,
 spiral-arm-count, indices 24 to 29, asked after has-spiral-arms_yes, index 5,
 merging, indices 30 to 33, asked after None]

## Combine predictions from all models into one catalog

In [5]:

# Which data release's predictions to combine into the catalog: 'dr1', 'dr2' or 'dr5'.
data_release = 'dr5'
# Number of models in the ensemble; their prediction files are suffixed m0 ... m{n_models - 1}.
n_models = 5

# Alternative DR5 source, directly from tfrecords (includes ~3k duplicates).
# Called eval_predictions, but is actually predictions on all tfrecords (labelled and unlabelled):
# predictions_locs = [f'results/decals_dr_train_labelled_m{n}_eval_predictions.csv' for n in range(n_models)]
predictions_locs = [
    f'results/folder_{data_release}_model_decals_dr_train_labelled_m{n}_predictions.csv'
    for n in range(n_models)
]
save_loc = f'latest_ml_catalog_{data_release}_only.parquet'


In [6]:
concentration_cols = [a.text + '_concentration' for a in schema.answers]

# Load every model's predicted concentrations (loading copied from performance_metrics.ipynb).
# Models are stacked along the last axis purely by row position, so every predictions file must list
# the same galaxies in the same order. Check this against the first file rather than assume it.
samples = []
first_png_locs = None
for predictions_loc in predictions_locs:
    predictions = pd.read_csv(predictions_loc)
    if first_png_locs is None:
        first_png_locs = predictions['png_loc'].to_numpy()
    else:
        assert np.array_equal(predictions['png_loc'].to_numpy(), first_png_locs), \
            f'{predictions_loc} does not list the same galaxies in the same order as {predictions_locs[0]}'
    model_samples = dirichlet_stats.load_all_concentrations(predictions, concentration_cols=concentration_cols)
    samples.append(model_samples)
samples = np.concatenate(samples, axis=2)  # last axis now holds the samples from every model

# NOTE: `predictions` is left holding the last file read; later cells take iauname/png_loc from it,
# which is only valid because of the row-order check above.
print(samples.shape)

(343128, 34, 25)


In [7]:
predictions.shape[0]  # number of rows (galaxies) in the last predictions file loaded

343128

In [8]:
predictions.columns.to_numpy()  # column names of the raw predictions file

array(['png_loc', 'smooth-or-featured_smooth_concentration',
       'smooth-or-featured_featured-or-disk_concentration',
       'smooth-or-featured_artifact_concentration',
       'disk-edge-on_yes_concentration', 'disk-edge-on_no_concentration',
       'has-spiral-arms_yes_concentration',
       'has-spiral-arms_no_concentration', 'bar_strong_concentration',
       'bar_weak_concentration', 'bar_no_concentration',
       'bulge-size_dominant_concentration',
       'bulge-size_large_concentration',
       'bulge-size_moderate_concentration',
       'bulge-size_small_concentration', 'bulge-size_none_concentration',
       'how-rounded_round_concentration',
       'how-rounded_in-between_concentration',
       'how-rounded_cigar-shaped_concentration',
       'edge-on-bulge_boxy_concentration',
       'edge-on-bulge_none_concentration',
       'edge-on-bulge_rounded_concentration',
       'spiral-winding_tight_concentration',
       'spiral-winding_medium_concentration',
       'spiral-wi

In [9]:
def _png_loc_to_iauname(png_loc):
    """Return the IAU name encoded in an image path.

    Drops the directories, then removes '.jpeg', '.png' and '_standard' (in that order)
    from the file name.
    """
    name = png_loc.split('/')[-1]
    for fragment in ('.jpeg', '.png', '_standard'):
        name = name.replace(fragment, '')
    return name


predictions['iauname'] = predictions['png_loc'].apply(_png_loc_to_iauname)

In [10]:
predictions.loc[:, 'iauname']  # derived IAU names, one per row

0         J223253.27-005423.9
1         J223445.65-010456.2
2         J223515.21-003519.5
3         J223402.99+001117.3
4         J223710.17-005700.4
                 ...         
343123    J022211.96-003834.5
343124    J022221.11-074059.5
343125    J022903.73-082413.6
343126    J022727.34-043858.2
343127    J022513.64-015438.9
Name: iauname, Length: 343128, dtype: object

In [11]:
# Identifier columns carried into the catalog alongside the ML predictions
id_cols = ['iauname', 'png_loc']
iauname_df = predictions[id_cols]

In [12]:
# Serialise each galaxy's concentrations for every answer (all samples from all models) as a JSON
# list string, so the catalog can be written to parquet as plain string columns.
concentration_col_names = [answer + '_concentration' for answer in label_cols]
data = []
for sample in samples:  # one galaxy at a time, to avoid a rounded copy of the whole array in memory
    rounded = np.around(sample, 4)
    # .tolist() gives Python floats, which json.dumps writes identically to list(...astype(float))
    data.append({
        col_name: json.dumps(answer_values)
        for col_name, answer_values in zip(concentration_col_names, rounded.tolist())
    })

concentration_df = pd.DataFrame(data=data)

In [13]:
concentration_df.iloc[:5]  # first five galaxies' JSON-serialised concentrations

Unnamed: 0,smooth-or-featured_smooth_concentration,smooth-or-featured_featured-or-disk_concentration,smooth-or-featured_artifact_concentration,disk-edge-on_yes_concentration,disk-edge-on_no_concentration,has-spiral-arms_yes_concentration,has-spiral-arms_no_concentration,bar_strong_concentration,bar_weak_concentration,bar_no_concentration,...,spiral-arm-count_1_concentration,spiral-arm-count_2_concentration,spiral-arm-count_3_concentration,spiral-arm-count_4_concentration,spiral-arm-count_more-than-4_concentration,spiral-arm-count_cant-tell_concentration,merging_none_concentration,merging_minor-disturbance_concentration,merging_major-disturbance_concentration,merging_merger_concentration
0,"[18.8731, 22.7141, 20.7146, 18.2101, 20.091, 1...","[39.0552, 51.8402, 43.3618, 49.3892, 43.879, 4...","[3.5365, 3.58, 3.7669, 3.2913, 3.9583, 3.2235,...","[3.501, 5.4355, 4.2553, 4.0008, 4.7009, 5.2416...","[42.9519, 36.9419, 42.9178, 57.9632, 34.6256, ...","[5.0956, 4.2247, 4.4058, 8.8174, 4.3745, 3.958...","[4.0325, 6.3041, 5.5175, 3.8162, 5.7515, 5.471...","[10.8883, 22.3576, 15.8604, 19.8254, 20.8928, ...","[6.6697, 6.7961, 7.1283, 9.6206, 6.4508, 5.426...","[7.0041, 7.3049, 7.5909, 6.4403, 5.4805, 7.024...",...,"[10.3803, 10.5544, 11.7161, 7.2947, 6.7669, 5....","[74.1174, 84.4166, 82.518, 91.4303, 79.7492, 9...","[2.9143, 1.7124, 1.9782, 1.6086, 2.2643, 1.353...","[1.0625, 1.0588, 1.0609, 1.0303, 1.1333, 1.029...","[1.0027, 1.0097, 1.0072, 1.0024, 1.0108, 1.006...","[22.7806, 18.1786, 20.8253, 10.4263, 19.3903, ...","[51.5797, 56.9599, 50.3964, 37.8404, 51.686, 3...","[14.1546, 12.8948, 11.3041, 8.4152, 12.9024, 9...","[4.6987, 4.8144, 3.9932, 2.6639, 4.2171, 3.049...","[1.4204, 1.4245, 1.6515, 1.7707, 1.6373, 2.010..."
1,"[100.2998, 100.9084, 98.8625, 96.6109, 95.5195...","[11.3494, 11.9764, 11.8131, 13.3513, 10.6898, ...","[21.192, 21.4996, 19.7249, 21.2485, 21.661, 22...","[7.3492, 11.0223, 7.2616, 6.4169, 5.3163, 6.09...","[100.9746, 100.9995, 100.9881, 93.2174, 99.654...","[1.066, 1.0411, 1.1538, 1.1333, 1.1911, 1.0931...","[48.0764, 67.1999, 42.7396, 49.795, 34.2735, 4...","[1.0816, 1.021, 1.0374, 1.2873, 1.2634, 1.0353...","[2.2426, 1.9689, 2.1154, 3.1913, 2.7703, 1.957...","[76.0942, 85.964, 65.8677, 50.1279, 54.3441, 5...",...,"[3.1159, 1.5191, 1.6137, 4.1671, 3.3328, 2.584...","[4.1307, 1.8367, 3.4813, 9.5983, 7.1829, 3.680...","[2.0401, 1.7039, 1.6891, 4.3883, 3.3034, 4.800...","[1.5661, 1.4354, 1.1941, 2.3287, 1.9341, 1.882...","[2.2376, 2.0628, 1.575, 2.4465, 2.0777, 2.3039...","[5.6598, 3.2196, 3.8623, 11.6139, 7.3158, 7.41...","[64.3555, 75.4268, 67.9457, 58.0742, 56.7317, ...","[2.9335, 2.95, 3.8461, 3.9544, 3.3731, 3.6801,...","[1.1221, 1.0583, 1.162, 1.2146, 1.1471, 1.2354...","[1.1054, 1.1007, 1.181, 1.3842, 1.2649, 1.261,..."
2,"[100.6286, 99.8829, 100.1223, 99.6535, 99.7411...","[9.5877, 12.1126, 10.7685, 10.212, 9.5006, 11....","[13.8786, 12.557, 12.7363, 13.4321, 13.8922, 1...","[7.2084, 7.5475, 9.4287, 7.4908, 8.8195, 6.983...","[99.0, 100.9875, 100.8836, 99.2058, 100.2806, ...","[1.1055, 1.1696, 1.1697, 1.2031, 1.174, 1.2169...","[43.6163, 41.2119, 40.3802, 36.7099, 41.3417, ...","[1.3803, 1.0141, 1.029, 1.5132, 1.1614, 1.1021...","[2.8964, 1.9269, 2.0382, 3.2041, 2.4041, 2.269...","[75.3437, 55.8791, 59.1298, 66.628, 64.1499, 5...",...,"[1.6679, 1.0134, 1.0149, 1.8143, 1.1796, 1.661...","[7.5613, 2.0045, 2.3126, 7.9255, 3.1963, 4.049...","[4.0654, 1.8939, 1.7705, 4.3699, 3.2378, 4.417...","[5.6245, 1.8998, 2.1777, 4.407, 3.3383, 2.4469...","[5.9456, 2.1951, 3.1281, 5.4452, 4.665, 2.6137...","[15.8112, 3.3296, 3.5174, 13.1405, 7.1652, 9.3...","[81.8863, 77.5161, 81.7811, 73.9158, 80.9015, ...","[4.9307, 4.2565, 5.0386, 4.4947, 5.0667, 4.605...","[1.3159, 1.0919, 1.1438, 1.2673, 1.1925, 1.196...","[1.1482, 1.1133, 1.1168, 1.2639, 1.1763, 1.140..."
3,"[71.2866, 82.2588, 75.9155, 66.3959, 77.7713, ...","[6.4053, 6.2552, 5.9803, 5.4375, 6.469, 5.3874...","[8.545, 8.136, 10.5003, 9.1182, 7.9304, 9.0336...","[7.1974, 12.0605, 6.5911, 6.3696, 11.9035, 4.4...","[28.5018, 36.8346, 24.632, 21.1756, 52.8328, 2...","[3.0817, 2.7644, 2.375, 2.474, 4.4556, 2.6432,...","[14.0236, 14.9046, 19.2895, 11.8857, 14.4581, ...","[2.9977, 2.5995, 2.9434, 2.555, 2.4383, 2.8204...","[6.9935, 7.3001, 6.3014, 6.4472, 9.795, 5.5185...","[40.8346, 44.7967, 38.0588, 39.491, 47.8499, 3...",...,"[6.0013, 1.8884, 3.7468, 3.1089, 1.407, 5.6581...","[34.1222, 28.1465, 20.2378, 24.8247, 44.7048, ...","[2.1075, 3.3162, 2.8961, 2.0058, 2.8996, 2.685...","[1.3432, 1.881, 1.7137, 1.3124, 1.8995, 1.5268...","[2.0902, 2.2166, 2.5132, 1.8517, 1.876, 2.1202...","[11.2085, 10.3061, 11.2904, 9.2863, 9.681, 15....","[66.5555, 76.7336, 64.9932, 67.0382, 83.479, 4...","[23.5164, 17.218, 21.946, 18.2494, 20.7623, 18...","[4.5913, 2.4923, 4.1921, 3.7639, 2.8685, 4.339...","[1.1565, 1.1684, 1.2812, 1.2579, 1.1201, 1.211..."
4,"[43.4332, 41.4422, 35.893, 50.8412, 40.0915, 3...","[11.328, 14.0638, 10.8818, 15.8887, 11.1832, 1...","[8.1987, 6.5369, 7.4544, 8.9742, 8.0949, 5.511...","[4.7779, 4.7764, 3.994, 4.4643, 5.0072, 3.6229...","[44.1802, 40.3991, 36.5893, 37.2032, 35.2251, ...","[2.5664, 3.0394, 2.5884, 2.7524, 2.4275, 3.717...","[10.6808, 9.1973, 8.9464, 14.7528, 8.8792, 9.3...","[1.4609, 1.675, 1.5845, 1.8152, 1.6327, 1.4944...","[6.4947, 6.9151, 6.4102, 7.1318, 5.835, 8.1248...","[63.03, 52.9313, 47.167, 53.8787, 46.8882, 49....",...,"[6.4754, 10.2901, 12.0365, 8.7904, 12.3105, 4....","[4.7843, 9.2146, 8.6585, 7.5258, 10.2603, 4.20...","[2.4778, 2.5925, 2.4461, 2.8785, 2.5333, 1.911...","[1.5907, 1.4149, 1.3169, 2.0221, 1.3608, 1.302...","[1.2753, 1.163, 1.1522, 1.518, 1.2171, 1.0554,...","[32.8962, 37.2702, 36.8641, 37.6862, 40.8479, ...","[22.7559, 21.1437, 24.3475, 24.4952, 22.2225, ...","[14.9245, 12.6263, 15.7483, 17.0295, 13.2497, ...","[6.7779, 4.2665, 6.5508, 7.0939, 5.3469, 7.484...","[1.6118, 1.9507, 1.5931, 1.9319, 1.6597, 1.727..."


In [14]:
# NOTE(review): disabled. The ensemble is instead built by JSON-serialising `samples` into concentration_df,
# so the concentration columns of `predictions` never reach the saved catalog.
# # overwrite the columns of the final predictions df with the samples from all models, making the ensemble catalog
# for col_n, col in enumerate(concentration_cols):
#     predictions[col] = predictions[col].apply(lambda x: samples[:, col_n])   # remember to convert back to json at end
# #     predictions.loc[:, col] = samples[:, col_n] # can't slice like this sadly

### Remove the volunteer predictions: this catalog should contain only the ML predictions (volunteer columns can be merged back in later if needed).

In [15]:
# NOTE(review): disabled. df is assembled only from iauname_df, concentration_df and fractions_df,
# so any volunteer answer/fraction columns in `predictions` do not reach the saved catalog anyway.
# for q in schema.questions:
#     for a in q.answers:
#         try:
#             del predictions[a.text]
#             del predictions[a.text + '_fraction']
#         except KeyError:
#             pass

## Calculate vote fractions

In [16]:
# copied from advanced_ml
def apply_over_questions(concentrations, question_index_groups, func):
    """Apply `func` to each question's block of answers and join the results.

    Args:
        concentrations: array whose axis 1 indexes answers.
        question_index_groups: sequence of (first, last) answer indices per question,
            both ends inclusive.
        func: callable applied to concentrations[:, first:last + 1] for each question.

    Returns:
        The per-question outputs of `func`, concatenated along axis 1.
    """
    per_question = [
        func(concentrations[:, group[0]:group[1] + 1])  # +1: the last index is inclusive
        for group in question_index_groups
    ]
    return np.concatenate(per_question, axis=1)

def mean_for_answers(concentrations_for_q):
    """Expected vote fractions for one question's answers.

    Wraps `concentrations_for_q` in dirichlet_stats.DirichletEqualMixture and returns its mean.
    Presumably each entry along the last axis is one equally weighted mixture component — TODO confirm.
    """
    mixture = dirichlet_stats.DirichletEqualMixture(concentrations_for_q)
    return mixture.mean()
    

In [17]:
# Mean predicted vote fraction for every answer, averaged over all models' samples.
# Uses the raw `samples` array because concentration_df holds JSON strings.
question_groups = schema.question_index_groups
mean_answers = apply_over_questions(samples, question_groups, mean_for_answers)
# For a single-model estimate, slice the last axis first, e.g. samples[:, :, :1]
mean_answers.shape  # (galaxies, mean prediction per answer)

(343128, 34)

In [18]:
# One column of mean vote fractions per answer, named <answer>_fraction
fraction_cols = [f'{answer}_fraction' for answer in schema.label_cols]
fractions_df = pd.DataFrame(mean_answers, columns=fraction_cols)

## Join identifiers, concentrations and vote fractions into one catalog

In [19]:
# All three frames carry a default 0..n-1 row index, so concat lines them up by position.
# concat(axis=1) would silently pad with NaN if the row counts differed, so check first.
assert len(iauname_df) == len(concentration_df) == len(fractions_df), \
    f'row counts differ: {len(iauname_df)}, {len(concentration_df)}, {len(fractions_df)}'
df = pd.concat([iauname_df, concentration_df, fractions_df], axis=1)

## Calculate proportion of volunteers who would be asked

In [20]:
# Expected votes per question with retirement set to 1; presumably this makes each value the
# proportion of volunteers who would be asked that question (as the column name says) — TODO confirm.
retirement = 1
for question in schema.questions:
    votes = acquisition_utils.get_expected_votes_human(df, question, retirement, schema, round_votes=False)
    column = question.text + '_proportion_volunteers_asked'
    # the helper may return a TensorFlow tensor; unwrap it before storing as a column
    df[column] = votes.numpy() if isinstance(votes, tf.Tensor) else votes

## Correct file paths (my use only)

In [21]:
predictions.loc[0, 'png_loc']  # original (cluster) path of the first galaxy's image

'/data/phys-zooniverse/chri5177/galaxy_zoo/decals/dr5/png/J223/J223253.27-005423.9.png'

In [22]:
# Strip the cluster-specific prefix so png_loc is relative to the decals directory (e.g. 'dr5/png/J223/...').
# regex=False: match the prefix literally (pandas changed the default of `regex` between versions).
# NOTE(review): this edits `predictions`, not `df`. df's png_loc column was copied from predictions when
# iauname_df was built, so the saved catalog keeps the original paths — confirm that is intended.
predictions['png_loc'] = predictions['png_loc'].str.replace(
    '/data/phys-zooniverse/chri5177/galaxy_zoo/decals/', '', regex=False
)

In [23]:
# Where the images live on my local drive, depending on the data release named in save_loc.
local_decals_dir = '/media/walml/beta/galaxy_zoo/decals'  # DR1 and DR2 share this layout
local_dr5_dir = '/media/walml/beta/decals/png_native'  # DR5, without the 'png' subfolder

if 'dr1' in save_loc or 'dr2' in save_loc:
    predictions['local_png_loc'] = predictions['png_loc'].apply(lambda x: os.path.join(local_decals_dir, x))
elif 'dr5' in save_loc:
    predictions['local_png_loc'] = predictions['png_loc'].apply(
        lambda x: os.path.join(local_dr5_dir, x.replace('/png/', '/'))
    )
else:
    # explicit error rather than `assert False`, which is stripped when Python runs with -O
    raise ValueError(f'Cannot tell data release (dr1, dr2 or dr5) from save_loc: {save_loc}')

In [24]:
predictions.loc[0, 'local_png_loc']  # local path of the first galaxy's image

'/media/walml/beta/decals/png_native/dr5/J223/J223253.27-005423.9.png'

In [25]:
# Spot-check that the rewritten local path of the first galaxy points at a real image file
first_local_png_loc = predictions['local_png_loc'][0]
assert os.path.isfile(first_local_png_loc), f'Image not found at {first_local_png_loc}'

In [26]:
# Some galaxies appear more than once in the shards (very few); keep only the first row per iauname.
# Show how many rows are duplicates (previously computed but never displayed, as it was not the last expression).
display(df['iauname'].duplicated().value_counts())
df = df.drop_duplicates(subset=['iauname'], keep='first')

In [27]:
# The catalog must have one row per galaxy (duplicates were dropped in the previous cell); check before writing.
assert df['iauname'].is_unique, 'duplicate iauname rows in catalog'
# Easy to save: concentrations are already JSON strings. The row index is just a positional counter, so omit it.
df.to_parquet(save_loc, index=False)

In [28]:
df.columns.to_numpy()  # columns of the saved catalog

array(['iauname', 'png_loc', 'smooth-or-featured_smooth_concentration',
       'smooth-or-featured_featured-or-disk_concentration',
       'smooth-or-featured_artifact_concentration',
       'disk-edge-on_yes_concentration', 'disk-edge-on_no_concentration',
       'has-spiral-arms_yes_concentration',
       'has-spiral-arms_no_concentration', 'bar_strong_concentration',
       'bar_weak_concentration', 'bar_no_concentration',
       'bulge-size_dominant_concentration',
       'bulge-size_large_concentration',
       'bulge-size_moderate_concentration',
       'bulge-size_small_concentration', 'bulge-size_none_concentration',
       'how-rounded_round_concentration',
       'how-rounded_in-between_concentration',
       'how-rounded_cigar-shaped_concentration',
       'edge-on-bulge_boxy_concentration',
       'edge-on-bulge_none_concentration',
       'edge-on-bulge_rounded_concentration',
       'spiral-winding_tight_concentration',
       'spiral-winding_medium_concentration',
      