## Overview
**Adam?** Two sentences on Pillar I, what the intention of this activity is ...


### Galaxy Zoo Project
Galaxy Zoo is a citizen science project that enlists volunteers to visually classify galaxies by their morphology using images from various telescopes ([Lintott et al., 2008](https://doi.org/10.1111/j.1365-2966.2008.13689.x)). Because so many volunteers have contributed classifications, the project has produced large labeled data sets that are well suited to large-scale data analysis in astronomy. We will work with Galaxy Zoo 2, which provides classifications for ~240,000 galaxy images from the Sloan Digital Sky Survey ([Willett et al., 2013](https://doi.org/10.1093/mnras/stt1458); [Hart et al., 2016](https://academic.oup.com/mnras/article/461/4/3663/2608720)).

### Swin Transformer
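We will process the galaxy images with a **Swin Transformer**, a hierarchical vision transformer that computes self-attention within local windows that are shifted between successive layers. When it was published, it achieved state-of-the-art results on image classification, object detection, and semantic segmentation benchmarks ([Liu et al., 2021](https://arxiv.org/abs/2103.14030)). Because attention is restricted to fixed-size windows, its computational cost grows linearly with image size rather than quadratically, and its hierarchical feature maps make it a practical general-purpose backbone across image resolutions.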
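The sketch below is a simplified NumPy illustration of what "self-attention within shifted windows" means, not the Swin implementation itself. It splits a feature map into non-overlapping windows, cyclically shifts the map by half a window so that the next layer's windows straddle the previous layer's boundaries, and runs attention separately inside each window. The function names and the toy 8×8 feature map are our own, and the attention here deliberately omits learned query/key/value projections, multiple heads, relative position biases, and the mask the real model applies to wrapped-around regions.

```python
import numpy as np

def window_partition(x, ws):
    """Split an (H, W, C) feature map into (num_windows, ws, ws, C) non-overlapping windows."""
    H, W, C = x.shape
    assert H % ws == 0 and W % ws == 0, "H and W must be multiples of the window size"
    x = x.reshape(H // ws, ws, W // ws, ws, C)
    x = x.transpose(0, 2, 1, 3, 4)  # bring the two window-index axes together
    return x.reshape(-1, ws, ws, C)

def cyclic_shift(x, ws):
    """Roll the map up and left by half a window before re-partitioning into windows."""
    s = ws // 2
    return np.roll(x, shift=(-s, -s), axis=(0, 1))

def window_self_attention(windows):
    """Simplified single-head attention inside each window (no projections, no mask)."""
    n, ws, _, C = windows.shape
    tokens = windows.reshape(n, ws * ws, C)  # each window becomes a short token sequence
    scores = tokens @ tokens.transpose(0, 2, 1) / np.sqrt(C)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability for the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return (weights @ tokens).reshape(n, ws, ws, C)

# Toy 8x8 feature map with 3 channels and window size 4 -> 4 windows of 16 tokens each
rng = np.random.default_rng(0)
fmap = rng.normal(size=(8, 8, 3))
regular = window_self_attention(window_partition(fmap, ws=4))
shifted = window_self_attention(window_partition(cyclic_shift(fmap, ws=4), ws=4))
print(regular.shape, shifted.shape)  # (4, 4, 4, 3) (4, 4, 4, 3)
```

Each token attends only to the `ws * ws` tokens in its own window, so the total attention cost scales with the number of windows (and hence with image area) instead of with the square of the total number of tokens.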

Today, we'll work with a Swin Transformer pretrained on the ImageNet dataset: a large-scale image database containing millions of labeled images across 1,000 categories such as cats, dogs, oil filters, and balloons. Although a model trained on such images may not seem useful for galaxy classification, we'll find that this pretrained Swin Transformer is an effective general-purpose tool for extracting information from images.

### Embeddings
Nabeel and Ved have run all ~240,000 Galaxy Zoo 2 images through this pre-trained Swin Transformer to produce **embeddings**: dense 1D vectors (768 numbers per galaxy) that ideally capture the useful information in each image in a space of far lower dimension than the raw pixels.
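If the embeddings capture what an image looks like, visually similar galaxies should have nearby embedding vectors. A common way to measure "nearby" is cosine similarity. The sketch below, which uses random vectors as stand-ins for real 768-dimensional embeddings, returns the row indices most similar to a query row. The function name `most_similar` is our own, and cosine similarity is one reasonable choice of metric, not necessarily the one the embeddings were designed for.

```python
import numpy as np

def most_similar(embs, query_idx, k=5):
    """Return indices of the k rows of `embs` with the highest cosine similarity to row `query_idx`."""
    embs = np.asarray(embs, dtype=float)
    unit = embs / np.linalg.norm(embs, axis=1, keepdims=True)  # L2-normalize each row
    sims = unit @ unit[query_idx]                               # cosine similarity to the query
    sims[query_idx] = -np.inf                                   # exclude the query itself
    return np.argsort(sims)[::-1][:k]

# Stand-in data: 1,000 random 768-dimensional "embeddings"
rng = np.random.default_rng(0)
fake_embs = rng.normal(size=(1000, 768))
fake_embs[42] = fake_embs[7] + 0.01 * rng.normal(size=768)  # make row 42 a near-copy of row 7
print(most_similar(fake_embs, query_idx=7, k=3)[0])  # 42
```

With the real data, the returned row positions could be passed to `gz_df.iloc[...]` to look up the corresponding galaxies and view their images.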

### Let's start with loading the data

In [None]:
# Data handling and plotting
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from utils import show_image_by_assetid  # image helper provided with this notebook

# Modeling and analysis
import torch
from torch import nn
import umap.umap_ as umap
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

In [None]:
gz_df = pd.read_csv('data/gz2_data.csv', index_col=None)

gz_df

In [None]:
# (N, 768) NumPy array; row i is the embedding of the galaxy in row i of gz_df
embs = gz_df[[f'embedding_{i}' for i in range(768)]].to_numpy()

embs

`gz_df` contains information on the ~240,000 galaxies classified in the Galaxy Zoo 2 project. Notable columns include the following:

`morphology` is a string describing the shape/type of the galaxy. For astronomers: these types are similar, but not identical, to the standard Hubble tuning fork types. Morphologies beginning with `E` are galaxies with smooth light profiles; their roundness is encoded by the next character: `r` for completely round, `i` for in-between, and `c` for cigar-shaped. Morphologies beginning with `S` are galaxies with disks. For edge-on disks, `S` is followed by `er`, `eb`, or `en` for a round, boxy, or no bulge, respectively. For face-on disks, `S` is followed by `B` if the galaxy shows a bar, and then by `a`, `b`, `c`, or `d` encoding the prominence of the bulge: dominant, obvious, just noticeable, or none, respectively.
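The naming rules above can be captured in a small decoder. This is a sketch that assumes `morphology` strings follow exactly the patterns just described (`E` + roundness, `Se` + bulge shape, or `S` + optional `B` + bulge prominence); strings of any other form, for example ones carrying extra characters, return `None` rather than a guess. The function and dictionary names are our own.

```python
SMOOTH_SHAPE = {'r': 'completely round', 'i': 'in-between', 'c': 'cigar-shaped'}
EDGE_ON_BULGE = {'r': 'round', 'b': 'boxy', 'n': 'none'}
FACE_ON_BULGE = {'a': 'dominant', 'b': 'obvious', 'c': 'just noticeable', 'd': 'none'}

def describe_morphology(code):
    """Decode a morphology string per the rules above; return None if it doesn't match them."""
    if not isinstance(code, str):          # e.g. missing values (NaN)
        return None
    # Smooth galaxies: 'E' followed by a roundness character
    if len(code) == 2 and code[0] == 'E' and code[1] in SMOOTH_SHAPE:
        return {'type': 'smooth', 'roundness': SMOOTH_SHAPE[code[1]]}
    # Edge-on disks: 'Se' followed by a bulge-shape character
    if len(code) == 3 and code.startswith('Se') and code[2] in EDGE_ON_BULGE:
        return {'type': 'edge-on disk', 'bulge': EDGE_ON_BULGE[code[2]]}
    # Face-on disks: 'S', an optional 'B' for a bar, then a bulge-prominence character
    if code.startswith('S'):
        rest = code[1:]
        barred = rest.startswith('B')
        if barred:
            rest = rest[1:]
        if len(rest) == 1 and rest in FACE_ON_BULGE:
            return {'type': 'face-on disk', 'barred': barred, 'bulge': FACE_ON_BULGE[rest]}
    return None

print(describe_morphology('SBc'))
# {'type': 'face-on disk', 'barred': True, 'bulge': 'just noticeable'}
```

Applied to the whole table, `gz_df['morphology'].map(describe_morphology)` would give one decoded description (or `None`) per galaxy.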

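`REDSHIFT` is a float containing the spectroscopic redshift of the galaxy. Redshift ($z$) measures how much a galaxy's light has been stretched to longer wavelengths by the expansion of the universe: $z = (\lambda_{\text{observed}} - \lambda_{\text{emitted}})/\lambda_{\text{emitted}}$. In simplest terms, it serves as a measure of distance: larger values indicate more distant galaxies, and values near $z=0$ indicate very nearby ones.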
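For small redshifts, a rough distance follows from Hubble's law, $d \approx cz/H_0$. This approximation holds only for $z \ll 1$ and drifts noticeably as $z$ approaches a few tenths. The sketch below assumes $H_0 = 70$ km/s/Mpc, a round value we chose for illustration (measured values differ by several percent); for careful work, a cosmology-aware tool such as `astropy.cosmology` is the better choice. The H-alpha example uses a rest wavelength of about 656.3 nm.

```python
import numpy as np

C_KM_S = 299_792.458  # speed of light in km/s
H0_ASSUMED = 70.0     # assumed Hubble constant in km/s/Mpc (illustrative round value)

def redshift_from_wavelengths(observed, emitted):
    """z = (observed - emitted) / emitted, for any consistent wavelength unit."""
    return (np.asarray(observed) - emitted) / emitted

def approx_distance_mpc(z, h0=H0_ASSUMED):
    """Low-redshift Hubble-law distance d ~ c z / H0 in megaparsecs; only valid for z << 1."""
    return C_KM_S * np.asarray(z) / h0

# H-alpha (rest ~656.28 nm) observed at ~721.91 nm implies z ~ 0.1
z = redshift_from_wavelengths(721.908, 656.28)
print(round(float(z), 4))                       # 0.1
print(round(float(approx_distance_mpc(z)), 1))  # 428.3
```

Because both functions accept arrays, something like `approx_distance_mpc(gz_df['REDSHIFT'])` would convert the whole column at once.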

Where relevant, Galaxy Zoo citizen scientists also flagged additional features: a ring shape (`ring`), an arc produced by gravitational lensing (`lens/arc`), disturbed morphology (`disturbed`), irregular morphology (`irregular`), a galaxy merger (`merger`), a prominent dust lane (`dust lane`), or something else (`other`). For face-on disk galaxies, the number of spiral arms (`arm_count`; `+` indicates more than 4 arms) and how tightly they are wound (`arm_winding`: `t` for tightly, `m` for moderately, and `l` for loosely) are also available.
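The notebook selects flagged galaxies with `gz_df[flag] == 1`, which suggests these feature columns are 0/1 indicators. Under that assumption, counting how often each feature was flagged is short. The sketch below demonstrates it on a small made-up DataFrame (only the column names match the real table; the values are invented) and also cross-tabulates arm count against winding, treating those columns as plain labels.

```python
import pandas as pd

def flag_counts(df, flags):
    """Number of rows where each 0/1 flag column equals 1, sorted from most to least common."""
    return df[flags].eq(1).sum().sort_values(ascending=False)

# Made-up stand-in for a few rows of gz_df (values are illustrative only)
toy = pd.DataFrame({
    'ring':        [1, 0, 0, 0],
    'merger':      [0, 1, 1, 0],
    'dust lane':   [0, 0, 0, 0],
    'arm_count':   ['2', '2', '+', '3'],
    'arm_winding': ['t', 'm', 'l', 't'],
})
print(flag_counts(toy, ['ring', 'merger', 'dust lane']))

# Joint distribution of arm count and winding (rows with missing values are dropped)
print(pd.crosstab(toy['arm_count'], toy['arm_winding']))
```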

`asset_id` can be used for fetching the image of a galaxy. We've provided a helpful `show_image_by_assetid()` function.

`embedding_0` to `embedding_767` hold the 768-dimensional embedding vector we produced for each galaxy with the pre-trained Swin Transformer. We've extracted these into a 2D array, `embs`, where each row is the embedding of the galaxy in the same row of `gz_df`.
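Because row *i* of the embedding array belongs to row *i* of `gz_df`, any boolean selection on the DataFrame can be reused on the embeddings by converting it to a NumPy mask (call `.to_numpy()` on `embs` first if it is still a `DataFrame`). The sketch below shows this on small made-up data; the `toy_df` and `toy_embs` names and values are hypothetical. The same pattern would select, for example, the embeddings of all smooth galaxies with `gz_df['morphology'].str.startswith('E')`. Keep the two objects in sync: re-sorting or filtering only one of them breaks the alignment.

```python
import numpy as np
import pandas as pd

# Made-up stand-ins: 4 galaxies with 3-dimensional "embeddings" (real ones have 768)
toy_df = pd.DataFrame({'asset_id': [10, 11, 12, 13], 'morphology': ['Er', 'Sc', 'Ei', 'SBa']})
toy_embs = np.arange(12, dtype=float).reshape(4, 3)

# Boolean mask built on the DataFrame, applied positionally to the array
is_smooth = toy_df['morphology'].str.startswith('E').to_numpy()
smooth_embs = toy_embs[is_smooth]

# Look up a single galaxy's embedding by asset_id via its row position
row_pos = np.flatnonzero(toy_df['asset_id'].to_numpy() == 12)[0]
print(smooth_embs.shape, toy_embs[row_pos])  # (2, 3) [6. 7. 8.]
```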

Many other features are also available, such as the brightness of each galaxy in the `UGRIZ` bands. We also include sky coordinates and `SDSS` DR7 object IDs in case you'd like to gather other information on these galaxies.

### Let's start by looking at some examples in each morphology type

In [None]:
morphologies = ['Er', 'Ei', 'Ec']
filtered_dfs = {morph: gz_df[gz_df['morphology'] == morph].sample(5) for morph in morphologies}

fig, axes = plt.subplots(5, 3, figsize=(15, 20))

for col, morph in enumerate(morphologies):
    for row, (_, row_data) in enumerate(filtered_dfs[morph].iterrows()):
        ax = axes[row, col]
        show_image_by_assetid(row_data['asset_id'], ax=ax)
        ax.set_title(f"{morph} - asset_id={row_data['asset_id']}")
        ax.axis('off')

plt.suptitle('Smooth galaxies with round, intermediate, and cigar shapes', fontsize=16, y=1.01)
plt.tight_layout()
plt.show()

In [None]:
morphologies = ['Ser', 'Seb', 'Sen']
filtered_dfs = {morph: gz_df[gz_df['morphology'] == morph].sample(5) for morph in morphologies}

fig, axes = plt.subplots(5, 3, figsize=(15, 20))

for col, morph in enumerate(morphologies):
    for row, (_, row_data) in enumerate(filtered_dfs[morph].iterrows()):
        ax = axes[row, col]
        show_image_by_assetid(row_data['asset_id'], ax=ax)
        ax.set_title(f"{morph} - asset_id={row_data['asset_id']}")
        ax.axis('off')

plt.suptitle('Edge-on disk galaxies with round, boxy, and no bulges', fontsize=16, y=1.01)
plt.tight_layout()
plt.show()

In [None]:
morphologies = ['SBa', 'SBd', 'Sa', 'Sd']
filtered_dfs = {morph: gz_df[gz_df['morphology'] == morph].sample(5) for morph in morphologies}

fig, axes = plt.subplots(5, 4, figsize=(15, 20))

for col, morph in enumerate(morphologies):
    for row, (_, row_data) in enumerate(filtered_dfs[morph].iterrows()):
        ax = axes[row, col]
        show_image_by_assetid(row_data['asset_id'], ax=ax)
        ax.set_title(f"{morph} - asset_id={row_data['asset_id']}")
        ax.axis('off')

plt.suptitle('Face-on disk galaxies with/without bars (SB vs. S) and dominant/no bulges (a vs. d)', fontsize=16, y=1.01)
plt.tight_layout()
plt.show()

In [None]:
flags = ['ring', 'lens/arc', 'merger', 'disturbed', 'other']
filtered_rows = [gz_df[gz_df[flag] == 1].sample(1).iloc[0] for flag in flags]

fig, axes = plt.subplots(1, 5, figsize=(20, 5))

for ax, flag, row_data in zip(axes, flags, filtered_rows):
    show_image_by_assetid(row_data['asset_id'], ax=ax)
    ax.set_title(f"{flag.capitalize()} - asset_id={row_data['asset_id']}")
    ax.axis('off')

plt.suptitle('Galaxies with some special features', fontsize=16, y=1.01)
plt.tight_layout()
plt.show()

### Let's check for clustering in the embeddings!
The embeddings should contain much of the information in those images. We can't inspect how they cluster directly in their native 768-dimensional space, so we'll first reduce their dimensionality. We'll start with [UMAP](https://umap-learn.readthedocs.io/en/latest/) (Uniform Manifold Approximation and Projection), a nonlinear dimensionality reduction technique that primarily preserves local neighborhood structure while retaining more of the data's global structure than methods like t-SNE typically do.

You can also try out [t-SNE](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html), [PCA](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html), and other similar tools.
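PCA is a convenient companion because it is linear, fast, and deterministic. One common pattern (our suggestion, not a requirement here) is to reduce the 768 dimensions to a few dozen principal components before running UMAP or t-SNE, which cuts runtime on ~240,000 points, and to check how much variance those components retain. The sketch runs on random stand-in data, so its variance figure says nothing about the real embeddings, and `n_components=50` is an arbitrary choice.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
fake_embs = rng.normal(size=(2000, 768))  # stand-in for the real (N, 768) embedding array

# Reduce to 50 components, e.g. as a cheaper input for UMAP or t-SNE
pca = PCA(n_components=50, random_state=0)
embs_50d = pca.fit_transform(fake_embs)

# Fraction of total variance kept by those 50 components
kept = pca.explained_variance_ratio_.sum()
print(embs_50d.shape, f"{kept:.1%} of variance retained")

# A 2-component PCA can also be plotted directly, like the UMAP coordinates
embs_2d = PCA(n_components=2).fit_transform(fake_embs)
```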

In [None]:
# Run UMAP on the embeddings
umap_model = umap.UMAP()
embedding_2d = umap_model.fit_transform(embs)

# Store the 2D UMAP coordinates in our DataFrame
gz_df[['UMAP_1', 'UMAP_2']] = embedding_2d


In [None]:
plt.figure(figsize=(10, 8))
plt.scatter(gz_df['UMAP_1'], gz_df['UMAP_2'], s=5)
plt.xlabel('UMAP_1')
plt.ylabel('UMAP_2')
plt.show()
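A single-color scatter plot shows the overall shape of the projection but not whether it separates galaxy types. One option (ours, not part of the original analysis) is to color each point by a label, for example the first character of `morphology` (`E` for smooth, `S` for disk). The helper below does this for any 2D coordinates and labels; `scatter_by_label` is a hypothetical name, and here it runs on random stand-in points.

```python
import numpy as np
import matplotlib.pyplot as plt

def scatter_by_label(coords, labels, ax=None, s=2, alpha=0.3):
    """Scatter 2D coordinates with one color (and legend entry) per unique label."""
    coords = np.asarray(coords)
    labels = np.asarray(labels).astype(str)  # str so missing values become 'nan' instead of failing
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 8))
    for lab in np.unique(labels):
        mask = labels == lab
        ax.scatter(coords[mask, 0], coords[mask, 1], s=s, alpha=alpha, label=lab)
    ax.set_xlabel('UMAP_1')
    ax.set_ylabel('UMAP_2')
    ax.legend(markerscale=5)  # enlarge legend markers since the points are tiny
    return ax

# Stand-in data; with the real table you might pass
# gz_df[['UMAP_1', 'UMAP_2']] and gz_df['morphology'].str[0]
rng = np.random.default_rng(0)
fake_coords = rng.normal(size=(500, 2))
fake_labels = rng.choice(['E', 'S'], size=500)
ax = scatter_by_label(fake_coords, fake_labels)
```

A small marker size and low `alpha` help when ~240,000 points overlap; follow with `plt.show()` in the notebook.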

### Possible tasks with Galaxy Zoo 2 embeddings
1. Galaxy morphology classification

2. Anomalous galaxy detection

3. (*Challenge*) Fine-tuning a pre-trained model to compete with state-of-the-art performance

   [Cao et al., 2024](https://www.aanda.org/articles/aa/pdf/2024/03/aa48544-23.pdf) claim ~98% accuracy at 5-way classification on this dataset with a convolutional vision transformer.

...or whatever else you can think of!
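As a starting point for tasks 1 and 2, here is a baseline sketch built from the scikit-learn tools imported at the top of the notebook, plus `NearestNeighbors`. For classification, it trains a random forest on the embeddings and prints a per-class report. For anomaly detection, it scores each point by its mean distance to its k nearest neighbours in embedding space, so isolated points score high (scikit-learn's `IsolationForest` would be an equally reasonable choice). Both parts run on synthetic clusters from `make_blobs` standing in for the embeddings; the hyperparameters (`n_estimators=200`, `k=10`) are arbitrary. With the real data, you would pass `embs` and, for example, `gz_df['morphology']` as labels.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors

# Synthetic stand-in: 3 well-separated "morphology" classes in 32 dimensions
X, y = make_blobs(n_samples=600, n_features=32, centers=3, random_state=0)

# --- Task 1: classification baseline ---
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0)
clf = RandomForestClassifier(n_estimators=200, random_state=0)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred))

# --- Task 2: k-nearest-neighbour anomaly score ---
def knn_anomaly_scores(X, k=10):
    """Mean distance from each point to its k nearest other points (higher = more isolated)."""
    nn = NearestNeighbors(n_neighbors=k + 1).fit(X)  # +1: each point is its own nearest neighbour
    dists, _ = nn.kneighbors(X)
    return dists[:, 1:].mean(axis=1)                  # drop the zero self-distance

# Append 3 points placed far from every cluster and see where they rank
offsets = X.max() + 50 + 10 * np.arange(3)
outliers = np.repeat(offsets[:, None], X.shape[1], axis=1)
X_all = np.vstack([X, outliers])
scores = knn_anomaly_scores(X_all, k=10)
print(np.argsort(scores)[-3:])  # indices of the 3 highest-scoring (most isolated) points
```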