In [1]:
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import matplotlib.ticker as ticker
import seaborn as sns
sns.set_context('notebook')
from sklearn.neighbors import NearestNeighbors
from PIL import Image

In [28]:
def show_galaxies(df, n_galaxies=36, label=None):
    if n_galaxies == 8:
        fig, all_axes = plt.subplots(1, 8, figsize=(20, 3.5))
    elif n_galaxies == 11:
        fig, all_axes = plt.subplots(1, 11, figsize=(20, 4.5))
    elif n_galaxies == 12:
        fig, axes = plt.subplots(2, 6, figsize=(20, 7))
        all_axes = [ax for row in axes for ax in row]
    else:
        fig, axes = plt.subplots(6, 6, figsize=(20, 20))
        all_axes = [ax for row in axes for ax in row]
    for ax_n, ax in enumerate(all_axes):
        img_loc = os.path.join('/Volumes/beta/decals/png_native/dr5', df.iloc[ax_n]['png_loc'].replace('/media/walml/beta1/decals/png_native/dr5/', ''))
        im = Image.open(img_loc)
        
        crop_pixels = 120
        initial_size = 424 # assumed, careful
        (left, upper, right, lower) = (crop_pixels, crop_pixels, initial_size-crop_pixels, initial_size-crop_pixels)
        im = im.crop((left, upper, right, lower))

        ax.imshow(np.array(im))
        
        if ax_n == 0:
            ax.patch.set_edgecolor('green')  
            ax.patch.set_linewidth('14')  
            # can't just disable axis as also disables border, do manually instead
            ax.xaxis.set_major_locator(ticker.NullLocator())
            ax.xaxis.set_minor_locator(ticker.NullLocator())
            ax.yaxis.set_major_locator(ticker.NullLocator())
            ax.yaxis.set_minor_locator(ticker.NullLocator())
            if label:
                ax.set_ylabel(label, labelpad=7, fontsize=14)
        else:
            ax.axis('off')
            
    fig.tight_layout(pad=1.)
    
    return fig

## Load the color embedding

In [4]:
# n_components = 10

# pca_df = pd.read_parquet('/Users/walml/repos/zoobot/data/results/dr5_color_pca{}_and_ids.parquet'.format(n_components)).reset_index()
# embed_cols = [col for col in pca_df if 'feat_' in col]
# print(len(pca_df))

In [5]:
# catalog_df = pd.read_parquet('/Volumes/beta/galaxy_zoo/decals/catalogs/dr5_nsa_v1_0_0_to_upload.parquet', columns=['iauname', 'png_loc'])
# catalog_df = catalog_df.rename(columns={'iauname': 'galaxy_id'})
# pca_df = pd.merge(pca_df, catalog_df, on='galaxy_id', how='inner').reset_index()
# print(len(pca_df))

## Or, load the greyscale version (needs manual embed)

In [6]:
df = pd.read_parquet('/Volumes/beta/cnn_features/decals/dr5_b0_full_features_and_safe_catalog.parquet')
wrong_size = pd.read_parquet('/Users/walml/repos/zoobot_private/gz_decals_volunteers_auto_posteriors_wrongsize.parquet', columns=['iauname', 'wrong_size_statistic', 'wrong_size_warning'])
print(len(df))
df = pd.merge(df, wrong_size, on='iauname', how='inner')
print(len(df))
df = df[~df['wrong_size_warning']]
df = df.reset_index()

305657
273722


In [219]:
from sklearn.decomposition import IncrementalPCA
import pickle

def get_embed(features, n_components, save=''):
    embedder = IncrementalPCA(n_components=n_components)
    embed = embedder.fit_transform(features) 
     # no train/test needed as unsupervised
#     if len(save) > 0:
#         plt.plot(embedder.explained_variance_)  # 5 would probably do?
#         plt.savefig(save)
#         plt.close()
    print(embedder.explained_variance_ratio_)
    print(embedder.explained_variance_ratio_.sum())
    return embed

feature_cols = [col for col in df if 'feat_' in col]
features = df[feature_cols].values


In [220]:
n_components = 10

In [222]:
X = get_embed(features, n_components=n_components)
# with open('pc{}_embed_for_similarity_nb.pickle'.format(n_components), 'wb') as f:
#     pickle.dump(X, f)

[0.33243682 0.14899685 0.09159593 0.06734716 0.06648994 0.04236355
 0.0303928  0.02884448 0.01898969 0.01681282]
0.8442700447879479


In [10]:
with open('pc{}_embed_for_similarity_nb.pickle'.format(n_components), 'rb') as f:
    embed = pickle.load(f)

In [11]:
pca_df = pd.DataFrame(data=embed, columns=['feat_{}_pca'.format(n) for n in range(n_components)])
pca_df['galaxy_id'] = df['iauname']
pca_df['png_loc'] = df['png_loc']

In [12]:
tags_df = pd.read_csv('/Users/walml/repos/recommender_hack/tags_for_shoaib.csv')
print(len(tags_df))
tags_df = tags_df[tags_df['iauname'].isin(pca_df['galaxy_id'])]
print(len(tags_df))

110764
93900


In [195]:
join = {
    'star-forming': 'starforming',
    'starformation': 'starforming',
    'star_forming': 'starforming',
    'lenticular-galaxy': 'lenticular', 
    'ringed': 'ring',
    'interacting': 'interaction',
    'disturbance': 'disturbed',
    'bright-core': 'core',
    'dusty': 'dust-lane',
    'dust': 'dust-lane',
    'dustlane': 'dust-lane',
    'foreground-star': 'star',
    'central-core': 'core',
    'wrong-size': 'wrong_size',
    'ringed': 'ring',
    'interaction': 'interacting',
    'overlap': 'overlapping',
    'tidal-debris': 'tidal',
    'merger': 'merging'
}

# pairs_to_replace = [
#     ('star-forming', 'starforming'),
#     ('starburst', 'starforming'),
#     ('starformation', 'starforming'),
#     ('dust-lane', 'dustlane'),
#     ('dust', 'dustlane'),
#     ('dusty', 'dustlane'),
#     ('edge-on', 'edgeon'),
#     ('seyfert-1-galaxy', 'seyfert-1'),
#     ('interaction', 'interacting'),
#     ('overlapping-object', 'overlap'),
#     ('overlapping', 'overlap'),
#     ('central-core', 'core'),
#     ('ringed', 'ring'),
#     ('wrong_size', 'wrong-size'),
#     ('tidal-debris', 'tidal'),
#     ('objects_that_need_more_research', 'need_more_research')
# ]

tags_df['tag_clean'] = tags_df['tag'].apply(lambda x: join.get(x, x))  # pretty smug about this - update only if it's in the dict

In [196]:
(tags_df['tag_clean'] != tags_df['tag']).sum()

8953

In [197]:
tag_changed = tags_df['tag_clean'] != tags_df['tag']
tags_df[tag_changed][['tag', 'tag_clean']]

Unnamed: 0,tag,tag_clean
2,overlap,overlapping
10,merger,merging
17,merger,merging
134,overlap,overlapping
135,dustlane,dust-lane
...,...,...
110709,merger,merging
110710,overlap,overlapping
110736,ringed,ring
110755,merger,merging


In [198]:
# (tags_df['tag'] == 'star-forming').sum()

In [199]:
# ml_df = pd.read_parquet('TODO')

tags_df['tag_clean'].value_counts()[:40]### Clean up tags

In [200]:
# tags_df['tag'].value_counts()[:40]

In [201]:
tags_df['tag_clean'].value_counts()[:40]

starforming                        10586
spiral                              6819
agn                                 6212
starburst                           4318
disturbed                           3638
ring                                3022
merging                             2728
overlapping                         2250
edge-on                             2177
bar                                 2087
dust-lane                           1897
barred-spiral                       1842
decals                              1662
sdss                                1586
irregular                           1450
tidal                                862
asteroid                             824
elliptical                           773
lenticular                           575
broadline                            542
core                                 530
spiral2                              517
hot                                  499
h-alpha-peak                         481
star            

In [162]:
# tags_df['tag_clean'].value_counts()[40:80]

In [204]:
skip_q = ['spiral', 'edge-on', 'bar', 'barred-spiral', 'elliptical', 'spiral2', 'galaxy', 'edgeon', 'strong-bar', 'barred', 'smooth', 'spiral-2', 'merging']
skip_meta = ['decals', 'sdss', 'broadline', 'agn', 'staburst', 'h-alpha-peak', 'infra-red-source', 'liner-type-agn', 'emission-line-galaxy', 'radio_galaxy', 'seyfert-1-galaxy', 'seyfert-1', 'radio-source', 'seyfert-2', 'disk', 'seyfert-2-galaxy', 'qso', 'radio-galaxy']
skip_color = ['red-galaxy', 'red', 'blue', 'green']

# wavelength_tags = ['agn', 'seyfert-1', 'infra-red-source', 'liner-type-agn', 'radio_galaxy', 'seyfert-2', 'seyfert-2-galaxy', 'red-galaxy', 'radio-source', 'blue', 'green', 'qso', 'emission-line-galaxy', 'radio-galaxy', 'h-alpha-peak', 'broadline']
# question_tags = ['spiral', 'edgeon', 'elliptical', 'merger', 'bar', 'barred-spiral', 'merging', 'spiral2', 'spiral-2', 'strong-bar', 'barred', 'smooth', 'disk', 'disturbance']
# metadata_tags = ['decals', 'sdss', 'galaxy']
# duplicate_tags = ['fuzzy', 'overlapping-star', 'main-belt']

# tags_to_skip = wavelength_tags + question_tags + metadata_tags + duplicate_tags

tags_to_skip = set(skip_q).union(set(skip_meta)).union(set(skip_color))
top_tags = [tag for tag in list(tags_df['tag_clean'].value_counts().index) if tag not in tags_to_skip][:50]

In [205]:
top_tags[:17]

['starforming',
 'starburst',
 'disturbed',
 'ring',
 'overlapping',
 'dust-lane',
 'irregular',
 'tidal',
 'asteroid',
 'lenticular',
 'core',
 'hot',
 'star',
 'wrong_size',
 'artifact',
 'diffuse',
 'objects_that_need_more_research']

In [206]:
# top_tags.sort()
# top_tags

### Count tags

In [207]:
tag_counts_by_iauname = {}

for tag in top_tags:
    tag_counts_this_tag = tags_df.query(f'tag == "{tag}"').groupby('iauname').agg({'tag': 'count'}).reset_index()
    tag_counts_by_iauname[tag] = tag_counts_this_tag

In [208]:
top_galaxy_by_tag = {}
for tag, iauname_counts in tag_counts_by_iauname.items():
    top_galaxy_by_tag[tag] = list(iauname_counts.sort_values('tag')['iauname'])[-1]

### Search for those galaxies

In [210]:
def get_neighbors(X, query_index, n_neighbors, metric):
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree', metric=metric).fit(X)
    distances, indices = nbrs.kneighbors(X[query_index].reshape(1, -1))
#     print(something)
    return np.squeeze(indices)  # ordered by similarity, will include itself

In [211]:
!pwd

/Users/walml/repos/morphology-tools/notebooks


In [213]:
max_galaxies = 8
embed_cols = [col for col in pca_df if 'feat_' in col]
embedding = pca_df[embed_cols].values

for tag_n, tag in enumerate(top_tags[:18]):
#     print(tag)
#     print(top_galaxy_by_tag[tag])
    galaxy_index = np.argmax(pca_df['galaxy_id'] == top_galaxy_by_tag[tag])
    assert galaxy_index # != 0
    indices = get_neighbors(embedding, galaxy_index, n_neighbors=max_galaxies, metric='euclidean')
#     print(indices)
#     break
    tag_label = f'"{tag.capitalize()}"'.replace('-', ' ').replace('_', ' ')
    fig = show_galaxies(pca_df.iloc[np.squeeze(indices)], n_galaxies=max_galaxies, label=tag_label)  # first is itself
    
#     to save:
    # fig.savefig(f'/Users/walml/repos/morphology-tools/notebooks/similar/pca10/similar_{tag_n}_tag_{tag}_n{max_galaxies}_pca10_talk.png')
    plt.close()

In [None]:
# referee suggests investigating how similarity in labels maps to similarity in representation
# e.g. are the most similar labelled galaxies also the most similar in representation
# specifically, for a query galaxy, compare those

"are the most similar images returned all belonging to similar human-labelled classes, and the model is simply picking
out things that end up in the same classification bucket. Or is learning that images that were given different human labels nevertheless end up near each other in representation space?"

I suspect that similar labels must end up in similar representation spaces (that's how the classifier works, after all), and that different human labels should not end up near each other (unless that label is significantly wrong - could check this)

The "general" claim is that for galaxies of equal votes, they are more similar in representation if they are visually similar beyond the labels

In [None]:
assert False

In [None]:
index_0 = 167864
index_1 = 110058

In [None]:
embedding[index_0]

In [None]:
embedding[index_1]

In [None]:
distances = np.zeros(len(embedding))
for n, index in enumerate(range(len(embedding))):
    distances[n] = np.sum(np.sqrt(np.abs(embedding[index_1] ** 2 - embedding[index] ** 2)))

In [None]:
np.argsort(distances)[:5]

In [None]:
np.sum(np.sqrt(np.abs(embedding[index_1] ** 2 - embedding[168590] ** 2)))

In [None]:
# show_galaxies(pca_df.iloc[np.argsort(distances)[:5]])

In [None]:

# for tag in top_tags:
#     iauname = top_galaxy_by_tag[tag]
#     query_index = np.argmax(pca_df['galaxy_id'] == iauname)
#     neighbor_indices = get_neighbors(pca_df[feature_cols].values, query_index, n_neighbors=max_galaxies, metric='manhattan')
#     fig = show_galaxies(pca_df.iloc[neighbor_indices])
    
#     tag_label = f'"{tag.capitalize()}"'.replace('-', ' ').replace('_', ' ')
#     fig = show_galaxies(pca_df.iloc[np.squeeze(indices)], n_galaxies=n_galaxies, label=tag_label)  # first is itself
#     fig.savefig(f'similar/pca10/similar_{tag_n}_tag_{tag}_n{n_galaxies}_pca10_talk.png'.format(n_galaxies))
# #     fig.savefig(f'similar/features/similar_tag_{tag}_n{n_galaxies}_features.png'.format(n_galaxies))
# #     plt.close()