In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import matplotlib

In [3]:
import os
import logging
import argparse
import glob
import json
from collections import Counter

import numpy as np
from matplotlib.ticker import StrMethodFormatter

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn import metrics
import tensorflow as tf
import pandas as pd
from astropy.table import Table  # for NSA
from astropy import units as u
from sklearn.metrics import confusion_matrix, roc_curve
from PIL import Image
from scipy.stats import binom
from IPython.display import display, Markdown

from sklearn.metrics import accuracy_score, mean_squared_error, mean_absolute_error

from shared_astro_utils import astropy_utils, matching_utils
from zoobot.estimators import make_predictions, bayesian_estimator_funcs
from zoobot.tfrecord import read_tfrecord
from zoobot.uncertainty import discrete_coverage
from zoobot.estimators import input_utils, losses
from zoobot.tfrecord import catalog_to_tfrecord
from zoobot.active_learning import metrics, simulated_metrics, acquisition_utils, check_uncertainty, simulation_timeline, run_estimator_config


In [4]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)


In [5]:
os.chdir('/home/walml/repos/zoobot')

In [6]:
# catalog_loc = 'data/latest_labelled_catalog.csv
# catalog_loc = 'data/decals/decals_master_catalog.csv'
catalog_loc = 'data/gz2/gz2_master_catalog.csv'
catalog = pd.read_csv(catalog_loc, dtype={'subject_id': str})  # original catalog
catalog['file_loc'] = catalog['local_png_loc'].apply(lambda x: '/media/walml/beta/decals/png_native' + x[32:])


# catalog_loc = 'data/decals/temp_calibration_catalog.csv'
# catalog = pd.read_csv(catalog_loc, dtype={'subject_id': str})  # original catalog


In [7]:
label_cols = [
    'has-spiral-arms_yes',
    'has-spiral-arms_no'
]


In [9]:

batch_size = 8
initial_size = 300
# initial_size = 128
crop_size = int(initial_size * 0.75)
# crop_size = 128
final_size = 224
channels = 3

In [10]:
train_locs = glob.glob(f'/home/walml/repos/zoobot/data/gz2/shards/all_featp5_facep5_sim_{initial_size}/train_shards/*.tfrecord')
eval_locs = glob.glob(f'/home/walml/repos/zoobot/data/gz2/shards/all_featp5_facep5_sim_{initial_size}/eval_shards/*.tfrecord')
# tfrecord_locs = train_locs + eval_locs
tfrecord_locs = eval_locs

print(tfrecord_locs)
eval_config = run_estimator_config.get_eval_config(tfrecord_locs, label_cols, batch_size, initial_size, final_size, channels)
# print(eval_config.greyscale)
# print(eval_config.permute_channels)
eval_config.drop_remainder = False
dataset = input_utils.get_input(config=eval_config)




['/home/walml/repos/zoobot/data/gz2/shards/all_featp5_facep5_sim_300/eval_shards/s300_shard_0.tfrecord']


In [297]:
def get_data_from_loc(loc):
    feature_spec = input_utils.get_feature_spec({'id_str': 'string'})
    id_str_dataset = input_utils.get_dataset(loc, feature_spec, batch_size=1, shuffle=False, repeat=False, drop_remainder=False)
    id_strs = [str(d['id_str'].numpy().squeeze())[2:-1] for d in id_str_dataset]
    eval_config = run_estimator_config.get_eval_config(loc, label_cols, batch_size, initial_size, final_size, channels)
    dataset = input_utils.get_input(config=eval_config)
#     all_fingerprints = [fingerprint_batch(b) for b, _ in dataset]
#     flat_fingerprints = [tuple(f.numpy().astype(float)) for b in all_fingerprints for f in b]
#     fingerprint_db = pd.DataFrame(data={'id_str': id_strs, 'fingerprint': flat_fingerprints})
    fingerprint_db = pd.DataFrame(data={'id_str': id_strs})
    return fingerprint_db

In [298]:
fingerprint_data = []
for loc in train_locs:
    fingerprint_db = get_data_from_loc(loc)
    name = os.path.basename(loc)
    fingerprint_db['name'] = name
    fingerprint_data.append(fingerprint_db)
fingerprint_df = pd.concat(fingerprint_data, axis=0)



In [299]:
fingerprint_df.head()

Unnamed: 0,id_str,name
0,dr7objid_587741602027012125,s300_shard_2.tfrecord
1,dr7objid_587733609088418036,s300_shard_2.tfrecord
2,dr7objid_587729409151926291,s300_shard_2.tfrecord
3,dr7objid_587739406264238209,s300_shard_2.tfrecord
4,dr7objid_587741725511778505,s300_shard_2.tfrecord


In [301]:
# fingerprint_df['fingerprint'].value_counts()

In [310]:
fingerprint_df['name'].value_counts()

s300_shard_1.tfrecord    4096
s300_shard_0.tfrecord    4096
s300_shard_2.tfrecord     323
Name: name, dtype: int64

In [303]:
multirecord = get_data_from_loc(train_locs)



In [306]:
len(multirecord), len(fingerprint_df)

(8515, 8515)

In [309]:
len(set(multirecord['id_str']).intersection(set(fingerprint_df['id_str'])))

8515

In [304]:
df = pd.merge(multirecord, fingerprint_df[['id_str', 'name']], on='id_str', how='left')

In [305]:
df['id_str'].value_counts()

dr7objid_587733441048740041    1
dr7objid_587742060554616867    1
dr7objid_587742014368907518    1
dr7objid_587732482218524798    1
dr7objid_587742014908334278    1
                              ..
dr7objid_587739721385247023    1
dr7objid_587728879260860524    1
dr7objid_587736543630852175    1
dr7objid_587741533312516203    1
dr7objid_587736542017487081    1
Name: id_str, Length: 8515, dtype: int64

In [311]:
df['name'].value_counts()

s300_shard_1.tfrecord    4096
s300_shard_0.tfrecord    4096
s300_shard_2.tfrecord     323
Name: name, dtype: int64

In [312]:
df['name']

0       s300_shard_2.tfrecord
1       s300_shard_1.tfrecord
2       s300_shard_0.tfrecord
3       s300_shard_2.tfrecord
4       s300_shard_1.tfrecord
                ...          
8510    s300_shard_0.tfrecord
8511    s300_shard_1.tfrecord
8512    s300_shard_0.tfrecord
8513    s300_shard_1.tfrecord
8514    s300_shard_0.tfrecord
Name: name, Length: 8515, dtype: object

So the loading pattern is cycle of length=3 until smallest record exhausted, then cycle of length=2, etc

In [245]:
def fingerprint_batch(image_batch, loc=None):
#     return np.array([fingerprint(image) for image in image_batch.numpy()])
    result = tf.reduce_sum(image_batch, axis=[1, 2])
    
    if loc is not None:
        with open(loc, 'a+') as f:
            for image_fingerprint in result.numpy().astype(float):
                f.write(json.dumps(tuple(image_fingerprint.flatten())) + '\n')
    return result

In [246]:
def fingerprint(image):
    return image.sum(axis=1)

In [247]:
fingerprint(np.random.rand(final_size, final_size, channels)).shape

(224, 3)

In [248]:
dummy_batch = tf.constant(np.random.rand(batch_size, final_size, final_size, channels), dtype=tf.float32)
result = fingerprint_batch(dummy_batch)
print(len(result))
print(result.shape)

8
(8, 3)


In [249]:
temp_loc = 'notebooks/multiq/loaded.text'
if os.path.isfile(temp_loc):
    os.remove(temp_loc)

In [250]:
model = tf.keras.Sequential(
    [
        tf.keras.layers.Lambda(lambda x: fingerprint_batch(x, loc=temp_loc), output_shape=[batch_size, final_size, channels]),
        tf.keras.layers.Lambda(lambda x: tf.random.uniform(shape=[batch_size]), output_shape=[batch_size])
    ]
)
model.run_eagerly = True

In [251]:
_ = model.predict(dummy_batch)

In [252]:
with open(temp_loc, 'r') as f:
    contents = f.readlines()

In [253]:
batch_fingerprints = [json.loads(line) for line in contents]

In [254]:
len(batch_fingerprints)

8

In [255]:
len(batch_fingerprints[0])

3

In [256]:
batch_fingerprints[0][0]

25051.34765625