# Can the Swin Transformer Embeddings Capture the Variance in Galaxies?
---
By Nabeel Rehemtulla and Ved Shah
07 May 2025
version 0.1
## Overview
This notebook is intended for beginners and experts alike. It takes the embeddings (i.e., representations of the input data, here galaxy images, as vectors in a high-dimensional space) produced by a Swin Transformer applied to Galaxy Zoo sources, and asks whether this state-of-the-art model can be used to: (i) classify galaxy morphology, (ii) identify anomalies, or (iii) pursue your own (more interesting?) idea.
This work is intentionally open-ended and designed for collaboration: please work with someone you have not collaborated with before. We think it may be possible to get interesting results within ~60-90 min, but most importantly we want you to start something.

### Galaxy Zoo Project
Galaxy Zoo is a citizen science project that enlists volunteers to visually classify galaxies by their morphology using images from various telescopes ([Lintott et al., 2008](https://doi.org/10.1111/j.1365-2966.2008.13689.x)). Thanks to the immense contributions of citizen scientists, this effort has yielded large, labeled data sets well suited to large-scale data analysis in astronomy. We will work with Galaxy Zoo 2, which provides classifications of ~240,000 galaxy images from the Sloan Digital Sky Survey ([Willett et al., 2013](https://doi.org/10.1093/mnras/stt1458); [Hart et al., 2016](https://academic.oup.com/mnras/article/461/4/3663/2608720)).

### Swin Transformer
We will process the galaxy images with a **Swin Transformer**, a hierarchical vision transformer that computes self-attention within local, shifted windows and achieved state-of-the-art performance on several computer vision tasks when it was introduced ([Liu et al., 2021](https://arxiv.org/abs/2103.14030)). Because attention is restricted to these windows, its computational cost grows linearly with image size (rather than quadratically, as with global self-attention), which makes it an efficient general-purpose backbone for a wide range of vision applications.
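To make "self-attention within shifted windows" concrete, here is a minimal NumPy sketch of the two tensor operations involved: partitioning an $H \times W \times C$ feature map into non-overlapping $M \times M$ windows (attention is then computed independently inside each window), and the cyclic shift by $\lfloor M/2 \rfloor$ pixels used in alternating blocks so that information can flow across window borders. The function names and toy sizes are our own; this illustrates the idea only and omits the attention computation and the masking the real model applies after the shift.

```python
import numpy as np

def window_partition(x, window_size):
    """Split an (H, W, C) feature map into non-overlapping windows.

    Returns shape (num_windows, window_size * window_size, C): one row of
    tokens per window, which is the unit self-attention operates on.
    """
    H, W, C = x.shape
    assert H % window_size == 0 and W % window_size == 0, "H and W must be divisible by window_size"
    x = x.reshape(H // window_size, window_size, W // window_size, window_size, C)
    # Bring the two window-grid axes together, then flatten each window's pixels into tokens
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, window_size * window_size, C)

def cyclic_shift(x, window_size):
    """Roll the feature map up and left by window_size // 2 along both spatial axes."""
    s = window_size // 2
    return np.roll(x, shift=(-s, -s), axis=(0, 1))

# Toy 8x8 feature map with 1 channel whose values are just the pixel indices 0..63
fmap = np.arange(64).reshape(8, 8, 1)

regular = window_partition(fmap, window_size=4)                   # 4 windows of 16 tokens each
shifted = window_partition(cyclic_shift(fmap, 4), window_size=4)  # windows now straddle the old borders

print(regular.shape)      # (4, 16, 1)
print(regular[0, :, 0])   # top-left block: pixels 0-3, 8-11, 16-19, 24-27
print(shifted[0, :, 0])   # starts at pixel (2, 2): 18-21, 26-29, 34-37, 42-45
```

Comparing the first window before and after the shift shows that it now mixes pixels that previously belonged to four different windows.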

Today, we'll work with a Swin Transformer pretrained on the ImageNet dataset: a large-scale image database containing millions of labeled images across 1,000 categories like cats, dogs, oil filters, and balloons. Although a model trained on such images may not appear useful for galaxy classification, we'll find that this pre-trained Swin Transformer is a generally effective tool for extracting information from images.

### Embeddings
Nabeel and Ved have run all ~240,000 Galaxy Zoo 2 images through this pre-trained Swin Transformer, producing **embeddings**. These embeddings are dense 1D vectors (768 numbers per galaxy) that ideally capture all the useful information in each image in a space of far lower dimension than the raw pixels.
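One practical consequence of capturing the useful information in an image is that galaxies with similar images should have nearby embeddings. A common way to measure "nearby" is cosine similarity between embedding vectors. The sketch below finds the k rows of an embedding matrix most similar to a query row. The helper `most_similar` and the random stand-in data are hypothetical, for illustration only; with the real data you could pass in the embeddings once they are loaded later in this notebook.

```python
import numpy as np

def most_similar(embeddings, query_idx, k=5):
    """Return indices and cosine similarities of the k rows most similar to row `query_idx`."""
    # L2-normalize each row so that a dot product equals cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)
    sims = unit @ unit[query_idx]
    sims[query_idx] = -np.inf               # exclude the query itself
    top = np.argsort(sims)[::-1][:k]        # highest similarity first
    return top, sims[top]

# Stand-in data: 1,000 random 768-dimensional vectors (NOT real Swin embeddings)
rng = np.random.default_rng(0)
fake_embs = rng.normal(size=(1000, 768))
# Plant a near-duplicate of row 0 at row 42 so we know what the answer should be
fake_embs[42] = fake_embs[0] + 0.01 * rng.normal(size=768)

idx, sim = most_similar(fake_embs, query_idx=0, k=3)
print(idx[0], sim[0])   # row 42 should be the top match, with similarity close to 1
```

For all ~240,000 galaxies this brute-force approach computes one similarity per galaxy per query, which is fine for a handful of queries; a nearest-neighbor index would be the usual choice for many queries.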

### Let's start with loading the data

In [None]:
# This notebook can be run in Google Colab, but we strongly recommend running it on your local machine.
# Start by cloning this repository: https://github.com/nabeelre/SkAI-Pillar1-GalaxyZoo
# You'll then need to install the necessary packages locally.
#   You probably already have most of them, but this will install everything:
#   pip install pillow umap-learn scikit-learn pandas numpy matplotlib
# This zip contains all the data used here:

# Alternatively, you can run this notebook in Google Colab and grant it access to files in your Google Drive.
# To do so, you'll first need to add a shortcut to this Google Drive folder to your own Drive:
#   Open https://drive.google.com/drive/folders/176O-BZ6OcNAAmruBdk32PQLdp4iaahy5?usp=share_link
#   Click on SkAI-Pillar-1-GalaxyZoo
#   Click on Organize -> Add Shortcut
#   Add it to My Drive/
# Then, load the Drive helper and mount it.
# This option is mostly functional, but Google Drive has trouble displaying the galaxy images.
import os

try:
    # google.colab is only importable inside Colab
    from google.colab import drive
    drive.mount('/content/drive')
    os.chdir("/content/drive/My Drive/SkAI-Pillar-1-GalaxyZoo")
except ImportError:
    # Running locally: run this notebook from the root of the cloned repository
    pass




In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# For these, run:
# pip install pillow umap-learn scikit-learn
from PIL import Image
import umap.umap_ as umap
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

def show_image_by_assetid(asset_id, images_dir="data/images/", ax=None):
    """Display the JPEG image of a galaxy given its Galaxy Zoo `asset_id`.

    If `ax` is given, draw into that Matplotlib axis; otherwise show a new figure.
    """
    file_path = os.path.join(images_dir, f"{asset_id}.jpg")
    if os.path.exists(file_path):
        img = Image.open(file_path)

        if ax is not None:
            ax.imshow(img)
            ax.axis('off')
        else:
            plt.imshow(img)
            plt.axis('off')
            plt.show()
    else:
        print(f"File not found: {file_path}")

In [None]:
gz_df = pd.read_csv('data/gz2_data.csv', index_col=None)

gz_df

| | specobjid | dr8objid | dr7objid | ra | dec | rastring | decstring | sample | gz2class | total_classifications | ... | embedding_758 | embedding_759 | embedding_760 | embedding_761 | embedding_762 | embedding_763 | embedding_764 | embedding_765 | embedding_766 | embedding_767 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.802675e+18 | NaN | 588017703996096547 | 160.990400 | 11.703790 | 10:43:57.70 | +11:42:13.6 | original | SBb?t | 44 | ... | -0.208791 | -0.655810 | -0.256993 | -0.297325 | -0.535887 | -0.155010 | 0.081996 | -0.204063 | 0.643326 | -0.644211 |
| 1 | 1.992984e+18 | NaN | 587738569780428805 | 192.410830 | 15.164207 | 12:49:38.60 | +15:09:51.1 | original | Ser | 45 | ... | -0.351440 | -0.361393 | -0.126954 | -0.577279 | -0.537606 | 0.012043 | 0.054425 | -0.085100 | 0.991432 | -0.594417 |
| 2 | 1.489569e+18 | NaN | 587735695913320507 | 210.802200 | 54.348953 | 14:03:12.53 | +54:20:56.2 | original | Sc+t | 46 | ... | -0.287179 | -1.003007 | -0.110764 | -0.315056 | -0.225965 | 0.261588 | -0.258050 | 0.034452 | 0.380732 | -0.283825 |
| 3 | 2.924084e+18 | 1.237668e+18 | 587742775634624545 | 185.303420 | 18.382704 | 12:21:12.82 | +18:22:57.7 | original | SBc(r) | 45 | ... | -0.303128 | -0.499811 | -0.352115 | -0.370614 | -0.393273 | 0.163165 | 0.246283 | -0.055247 | 0.847974 | -0.272424 |
| 4 | 1.387165e+18 | 1.237658e+18 | 587732769983889439 | 187.366790 | 8.749928 | 12:29:28.03 | +08:44:59.7 | extra | Ser | 49 | ... | -0.336349 | -0.375839 | -0.409501 | -0.145990 | -0.613818 | -0.088835 | 0.099277 | -0.235265 | 0.678835 | -0.124420 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 243429 | 2.013145e+18 | 1.237664e+18 | 587737807967092857 | 145.174330 | 63.154312 | 09:40:41.84 | +63:09:15.5 | original | Sc2l | 35 | ... | -0.132506 | -0.554635 | -0.449896 | -0.531359 | -0.229972 | -0.291191 | 0.217363 | 0.023003 | 0.484809 | 0.222627 |
| 243430 | 3.333027e+17 | 1.237649e+18 | 588848898847605001 | 199.857380 | -1.037356 | 13:19:25.77 | -01:02:14.5 | original | Sb | 40 | ... | -0.197414 | -0.472145 | -0.340123 | -0.278594 | -0.404097 | 0.042914 | 0.023667 | -0.064570 | 0.538044 | 0.581938 |
| 243431 | 1.959147e+18 | 1.237661e+18 | 587735348018872369 | 139.620040 | 10.317866 | 09:18:28.81 | +10:19:04.3 | original | Ei(o) | 49 | ... | -0.080637 | -0.410562 | -0.452735 | -0.010951 | -0.503455 | 0.099016 | -0.428303 | -0.041844 | -0.449666 | 0.118865 |
| 243432 | 4.673293e+17 | 1.237660e+18 | 588015508218577115 | 52.470364 | -0.750249 | 03:29:52.89 | +00:45:00.9 | stripe82 | Er | 43 | ... | -0.358848 | -0.647999 | -0.313777 | -0.438809 | -0.268666 | -0.021707 | 0.211087 | 0.021398 | 0.743367 | 0.252559 |


In [None]:
embs = gz_df[[f'embedding_{i}' for i in range(768)]]

embs

| | embedding_0 | embedding_1 | embedding_2 | embedding_3 | embedding_4 | embedding_5 | embedding_6 | embedding_7 | embedding_8 | embedding_9 | ... | embedding_758 | embedding_759 | embedding_760 | embedding_761 | embedding_762 | embedding_763 | embedding_764 | embedding_765 | embedding_766 | embedding_767 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.398954 | 0.470600 | -0.082388 | 0.026677 | -0.436941 | 0.339321 | -0.062581 | -0.024188 | 0.619561 | 0.812915 | ... | -0.208791 | -0.655810 | -0.256993 | -0.297325 | -0.535887 | -0.155010 | 0.081996 | -0.204063 | 0.643326 | -0.644211 |
| 1 | 0.804402 | 0.077317 | -0.054148 | 0.019534 | -0.522124 | 0.308817 | -0.024202 | -0.161937 | 0.638864 | 0.776197 | ... | -0.351440 | -0.361393 | -0.126954 | -0.577279 | -0.537606 | 0.012043 | 0.054425 | -0.085100 | 0.991432 | -0.594417 |
| 2 | 0.425019 | 0.602385 | -0.414535 | -0.417999 | -0.367586 | -0.253823 | -0.063896 | 0.241658 | 0.864711 | 0.522136 | ... | -0.287179 | -1.003007 | -0.110764 | -0.315056 | -0.225965 | 0.261588 | -0.258050 | 0.034452 | 0.380732 | -0.283825 |
| 3 | 0.581502 | 0.068087 | 0.007727 | -0.001978 | -0.361340 | 0.105308 | -0.035522 | 0.039886 | 0.411517 | 0.640379 | ... | -0.303128 | -0.499811 | -0.352115 | -0.370614 | -0.393273 | 0.163165 | 0.246283 | -0.055247 | 0.847974 | -0.272424 |
| 4 | 0.534941 | -0.113286 | 0.027658 | -0.374459 | -0.428486 | 0.013077 | -0.020676 | 0.036668 | 0.537464 | 0.756157 | ... | -0.336349 | -0.375839 | -0.409501 | -0.145990 | -0.613818 | -0.088835 | 0.099277 | -0.235265 | 0.678835 | -0.124420 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 243429 | 0.468131 | 0.216164 | -0.311012 | -0.577160 | -0.438113 | 0.014854 | -0.037592 | -0.049144 | 0.830118 | 0.743296 | ... | -0.132506 | -0.554635 | -0.449896 | -0.531359 | -0.229972 | -0.291191 | 0.217363 | 0.023003 | 0.484809 | 0.222627 |
| 243430 | 0.383623 | 0.385306 | -0.011959 | -0.340805 | -0.892922 | 0.107852 | -0.033983 | -0.189224 | 0.544229 | 0.671145 | ... | -0.197414 | -0.472145 | -0.340123 | -0.278594 | -0.404097 | 0.042914 | 0.023667 | -0.064570 | 0.538044 | 0.581938 |
| 243431 | 0.149528 | 0.732732 | -0.088457 | -0.697601 | -0.557609 | -0.112229 | -0.062160 | -0.342729 | 1.033721 | 0.915165 | ... | -0.080637 | -0.410562 | -0.452735 | -0.010951 | -0.503455 | 0.099016 | -0.428303 | -0.041844 | -0.449666 | 0.118865 |
| 243432 | 0.603055 | 0.452526 | -0.127767 | -0.518049 | -0.694753 | 0.084832 | -0.040239 | -0.186272 | 0.471817 | 0.788210 | ... | -0.358848 | -0.647999 | -0.313777 | -0.438809 | -0.268666 | -0.021707 | 0.211087 | 0.021398 | 0.743367 | 0.252559 |


This `DataFrame` contains information on the ~240,000 galaxies (243,433 rows) classified in the Galaxy Zoo 2 project. Notable columns include the following:

`morphology` is a string describing the shape/type of the galaxy. For astronomers: these are similar, but not identical, to the standard Hubble tuning fork types. Morphologies beginning with `E` are galaxies with smooth light profiles; their degree of roundness is encoded by the following character: `r` for completely round, `i` for in-between, and `c` for cigar-shaped. Morphologies beginning with `S` are galaxies with disks. Edge-on disks follow `S` with `er`, `eb`, or `en` for a round, boxy, or no bulge, respectively. Face-on disks have `B` after the `S` if they show a bar, followed by `a`, `b`, `c`, or `d` encoding the prominence of the bulge: dominant, obvious, just noticeable, or none, respectively.
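The naming scheme above can be decoded mechanically. Below is a minimal, hypothetical helper (`decode_morphology` is not part of the provided data or code) that interprets only the leading characters described in this paragraph, such as `Er`, `Seb`, `SBa`, or `Sd`. Real labels can carry extra characters (the `gz2class` column, for instance, contains strings like `SBb?t` and `Sc2l`), which this sketch ignores; any code it does not recognize raises a `ValueError`.

```python
ROUNDNESS = {"r": "completely round", "i": "in-between", "c": "cigar-shaped"}
EDGE_ON_BULGE = {"r": "round", "b": "boxy", "n": "none"}
FACE_ON_BULGE = {"a": "dominant", "b": "obvious", "c": "just noticeable", "d": "none"}

def decode_morphology(code):
    """Decode a simple Galaxy Zoo 2 morphology code into a dict of properties.

    Hypothetical helper: only the leading characters described in the text are
    interpreted; anything after them (e.g. arm or feature annotations) is ignored.
    """
    if code.startswith("E"):
        return {"type": "smooth", "roundness": ROUNDNESS[code[1]]}
    if code.startswith("Se"):
        return {"type": "edge-on disk", "bulge": EDGE_ON_BULGE[code[2]]}
    if code.startswith("S"):
        barred = code[1] == "B"
        bulge_char = code[2] if barred else code[1]
        return {"type": "face-on disk", "barred": barred, "bulge": FACE_ON_BULGE[bulge_char]}
    raise ValueError(f"Unrecognized morphology code: {code!r}")

for code in ["Er", "Ec", "Seb", "SBa", "Sd", "SBc(r)"]:
    print(code, decode_morphology(code))
```

Applied to the whole column, e.g. `gz_df['morphology'].map(decode_morphology)`, it would fail on any code outside this simple scheme, so you may want to wrap it in a `try`/`except` first.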

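`REDSHIFT` is a float containing the spectroscopic redshift of the galaxy. Redshift ($z$) is the fractional stretching of the galaxy's light toward longer wavelengths, $z = (\lambda_\mathrm{observed} - \lambda_\mathrm{emitted}) / \lambda_\mathrm{emitted}$; for these galaxies it mainly traces their line-of-sight velocity away from us caused by the expansion of the Universe. We use it as a proxy for the galaxy's distance: larger values indicate greater distances, and $z=0$ corresponds to no recession velocity, i.e., essentially no distance from us.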
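For nearby galaxies, redshift converts to rough velocity and distance estimates with two textbook approximations: recession velocity $v \approx cz$ and Hubble's law $d \approx v / H_0$. The sketch below assumes $H_0 = 70$ km/s/Mpc, a round value we chose for illustration rather than one used by the authors, and is only meaningful for $z \ll 1$; at larger redshifts a proper cosmological distance calculation (for example with `astropy.cosmology`) is needed.

```python
C_KM_S = 299_792.458   # speed of light in km/s
H0 = 70.0              # ASSUMED Hubble constant in km/s/Mpc (illustrative choice)

def low_z_velocity_and_distance(z, h0=H0):
    """Approximate recession velocity (km/s) and distance (Mpc) for small redshift z.

    Uses v ~= c*z and Hubble's law d ~= v / H0; inaccurate once z is no longer << 1.
    """
    v = C_KM_S * z
    return v, v / h0

v, d = low_z_velocity_and_distance(0.05)
print(f"z = 0.05 -> v ~ {v:.0f} km/s, d ~ {d:.0f} Mpc")   # z = 0.05 -> v ~ 14990 km/s, d ~ 214 Mpc
```

Because the function is plain arithmetic, it also works elementwise on a pandas column such as `gz_df['REDSHIFT']`.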

Where relevant, Galaxy Zoo citizen scientists also flag additional features, such as a ring shape (`ring`), an arc resulting from gravitational lensing (`lens/arc`), disturbed morphology (`disturbed`), irregular morphology (`irregular`), a galaxy merger (`merger`), a prominent dust lane (`dust lane`), or something else (`other`). For face-on disk galaxies, the number of spiral arms (`arm_count`; `+` indicates more than 4 arms) and how tightly they are wound (`arm_winding`; tightly `t`, moderately `m`, or loosely `l`) are also available.
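The plotting code further down selects flagged galaxies by comparing these columns against `1`, so flags can be combined with ordinary pandas boolean indexing. A minimal sketch on a small made-up frame: the values and `asset_id`s are invented for illustration, and only the column names come from the text above.

```python
import pandas as pd

# Tiny made-up stand-in for gz_df with a few of the flag columns described above
toy = pd.DataFrame({
    "asset_id":  [101, 102, 103, 104],
    "ring":      [1, 0, 0, 0],
    "merger":    [0, 1, 0, 1],
    "disturbed": [0, 1, 0, 0],
})

flag_cols = ["ring", "merger", "disturbed"]
# Galaxies with at least one of these features flagged
any_flag = toy[(toy[flag_cols] == 1).any(axis=1)]
# Galaxies flagged as both a merger and disturbed
merger_and_disturbed = toy[(toy["merger"] == 1) & (toy["disturbed"] == 1)]
# Number of galaxies carrying each flag
counts = (toy[flag_cols] == 1).sum()

print(any_flag["asset_id"].tolist())              # [101, 102, 104]
print(merger_and_disturbed["asset_id"].tolist())  # [102]
print(counts)
```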

`asset_id` identifies each galaxy's image file (`data/images/<asset_id>.jpg` by default). The `show_image_by_assetid()` helper defined above displays it.

`embedding_0` to `embedding_767` hold the 768-dimensional embedding vector we produced with the pre-trained Swin Transformer. We've extracted these columns into `embs`, a 2D table (a `DataFrame` with 768 columns) where each row is the embedding of the galaxy in the same row of `gz_df`.
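Most numerical tools accept the `DataFrame` directly, but it is often convenient to work with a plain NumPy array, for example to standardize each embedding dimension before PCA or a linear model. The sketch below uses a small random stand-in frame (`toy_df`, an assumption for illustration) with the same `embedding_{i}` column naming; with the real data you would select the columns from `gz_df` instead.

```python
import numpy as np
import pandas as pd

N_DIMS = 768
rng = np.random.default_rng(1)
# Stand-in for gz_df: 10 galaxies with the same embedding column naming as the real data
emb_cols = [f"embedding_{i}" for i in range(N_DIMS)]
toy_df = pd.DataFrame(rng.normal(size=(10, N_DIMS)), columns=emb_cols)

X = toy_df[emb_cols].to_numpy(dtype=np.float32)   # shape (n_galaxies, 768)

# Standardize each dimension (column) to zero mean and unit variance
X_std = (X - X.mean(axis=0)) / X.std(axis=0)

print(X.shape)   # (10, 768)
```

`sklearn.preprocessing.StandardScaler` performs the same column-wise standardization and remembers the training mean and standard deviation, which matters once you split the data into training and test sets.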

Many other features are also available, such as the galaxy's brightness in the `UGRIZ` bands. We also include the sky coordinates (`ra`, `dec`) and `SDSS` DR7 object IDs (`dr7objid`) in case you'd like to gather other information on these galaxies.

### Let's start by looking at some examples of each morphology type

In [None]:
morphologies = ['Er', 'Ei', 'Ec']
filtered_dfs = {morph: gz_df[gz_df['morphology'] == morph].sample(5) for morph in morphologies}

fig, axes = plt.subplots(5, 3, figsize=(15, 20))

for col, morph in enumerate(morphologies):
    for row, (_, row_data) in enumerate(filtered_dfs[morph].iterrows()):
        ax = axes[row, col]
        show_image_by_assetid(row_data['asset_id'], ax=ax)
        ax.set_title(f"{morph} - asset_id={row_data['asset_id']}")
        ax.axis('off')

plt.suptitle('Smooth galaxies with round, intermediate, and cigar shapes', fontsize=16, y=1.01)
plt.tight_layout()
plt.show()

File not found: images/81863.jpg
File not found: images/69889.jpg


In [None]:
morphologies = ['Ser', 'Seb', 'Sen']
filtered_dfs = {morph: gz_df[gz_df['morphology'] == morph].sample(5) for morph in morphologies}

fig, axes = plt.subplots(5, 3, figsize=(15, 20))

for col, morph in enumerate(morphologies):
    for row, (_, row_data) in enumerate(filtered_dfs[morph].iterrows()):
        ax = axes[row, col]
        show_image_by_assetid(row_data['asset_id'], ax=ax)
        ax.set_title(f"{morph} - asset_id={row_data['asset_id']}")
        ax.axis('off')

plt.suptitle('Edge-on disk galaxies with round, boxy, and no bulges', fontsize=16, y=1.01)
plt.tight_layout()
plt.show()

In [None]:
morphologies = ['SBa', 'SBd', 'Sa', 'Sd']
filtered_dfs = {morph: gz_df[gz_df['morphology'] == morph].sample(5) for morph in morphologies}

fig, axes = plt.subplots(5, 4, figsize=(15, 20))

for col, morph in enumerate(morphologies):
    for row, (_, row_data) in enumerate(filtered_dfs[morph].iterrows()):
        ax = axes[row, col]
        show_image_by_assetid(row_data['asset_id'], ax=ax)
        ax.set_title(f"{morph} - asset_id={row_data['asset_id']}")
        ax.axis('off')

plt.suptitle('Face-on disk galaxies with/without bars (B / no B) and dominant/no bulges (a/d)', fontsize=16, y=1.01)
plt.tight_layout()
plt.show()

In [None]:
flags = ['ring', 'lens/arc', 'merger', 'disturbed', 'other']
filtered_rows = [gz_df[gz_df[flag] == 1].sample(1).iloc[0] for flag in flags]

fig, axes = plt.subplots(1, 5, figsize=(20, 5))

for ax, flag, row_data in zip(axes, flags, filtered_rows):
    show_image_by_assetid(row_data['asset_id'], ax=ax)
    ax.set_title(f"{flag.capitalize()} - asset_id={row_data['asset_id']}")
    ax.axis('off')

plt.suptitle('Galaxies with some special features', fontsize=16, y=1.01)
plt.tight_layout()
plt.show()

### Let's see if we can see any clustering in the embeddings!
The embeddings should contain much of the information in those images. We can't visually inspect how the embeddings cluster in their native 768-dimensional space, so we'll first reduce their dimensionality. We'll start with [UMAP](https://umap-learn.readthedocs.io/en/latest/) (Uniform Manifold Approximation and Projection), a nonlinear dimensionality reduction technique that focuses on preserving local neighborhood structure while retaining more of the data's global structure than methods like t-SNE.

You can also try out [t-SNE](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html), [PCA](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html), and other similar tools.
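As a point of comparison with UMAP, PCA is simple enough to write directly in NumPy: center the data, take its singular value decomposition, and project onto the top right-singular vectors. The sketch below (`pca_project` is a hypothetical helper) runs on synthetic data with one deliberately dominant direction, an assumption for illustration; for the real embeddings, `sklearn.decomposition.PCA(n_components=2).fit_transform(embs)` is the library equivalent.

```python
import numpy as np

def pca_project(X, n_components=2):
    """Project rows of X onto their top principal components via SVD.

    Returns (projected data, fraction of variance explained by each kept component).
    Component signs are arbitrary, as with any PCA implementation.
    """
    Xc = X - X.mean(axis=0)                       # PCA requires centered data
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    projected = Xc @ Vt[:n_components].T
    explained = (S ** 2) / np.sum(S ** 2)
    return projected, explained[:n_components]

# Synthetic data: 500 points in 50 dimensions, stretched strongly along one axis
rng = np.random.default_rng(2)
X = rng.normal(size=(500, 50))
X[:, 0] *= 10.0                                   # dimension 0 now dominates the variance

Z, frac = pca_project(X, n_components=2)
print(Z.shape)        # (500, 2)
print(frac[0] > 0.5)  # True: the first component captures the stretched axis
```

Unlike UMAP, PCA is linear, so it preserves large-scale variance directions but can blur the nonlinear cluster structure UMAP is designed to reveal; plotting both side by side is a quick sanity check.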

In [None]:
# Run UMAP on the embeddings (this can take a while for ~240,000 galaxies x 768 dimensions)
umap_model = umap.UMAP()
embedding_2d = umap_model.fit_transform(embs)

# Store the 2D UMAP coordinates in our DataFrame
gz_df[['UMAP_1', 'UMAP_2']] = embedding_2d


In [None]:
plt.figure(figsize=(10, 8))
plt.scatter(gz_df['UMAP_1'], gz_df['UMAP_2'], s=5)
plt.xlabel('UMAP_1')
plt.ylabel('UMAP_2')
plt.show()

### Possible tasks with Galaxy Zoo 2 embeddings
1. Galaxy morphology classification

2. Anomalous galaxy detection

3. (*Challenge*) Fine-tuning a pre-trained model to compete with state-of-the-art performance

    - [Cao et al. 2024](https://www.aanda.org/articles/aa/pdf/2024/03/aa48544-23.pdf) report ~98% accuracy at 5-way classification on this dataset with a convolutional vision transformer.
        - Classes they used: In-between smooth, Completely round smooth, Edge-on, Spiral, Cigar-shape smooth

... Or whatever else you can think of!
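The scikit-learn imports at the top of this notebook suggest one starting point for tasks 1 and 2. The sketch below trains a `RandomForestClassifier` on feature vectors and, separately, ranks samples by how isolated they are using `IsolationForest`, which is our own choice of anomaly detector rather than one prescribed here. To keep it self-contained it runs on synthetic stand-in data from `make_classification`; with the real data you would use `X = embs.to_numpy()` and a label column such as `gz_df['morphology']` (possibly grouped into the five classes listed above, a mapping you would need to define yourself).

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import IsolationForest, RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

# Synthetic stand-in for (embeddings, morphology labels): 2,000 samples, 5 classes
X, y = make_classification(n_samples=2000, n_features=64, n_informative=20,
                           n_classes=5, random_state=0)

# --- Task 1: classification ---
# stratify=y keeps the class proportions the same in the train and test splits
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0)

clf = RandomForestClassifier(n_estimators=200, n_jobs=-1, random_state=0)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {acc:.3f}")
print(classification_report(y_test, y_pred))

# --- Task 2: anomaly ranking ---
iso = IsolationForest(n_estimators=200, random_state=0).fit(X)
scores = iso.score_samples(X)            # lower score = more anomalous
most_anomalous = np.argsort(scores)[:10]
print("Indices of the 10 most anomalous samples:", most_anomalous)
```

To see which classes get confused with each other, `ConfusionMatrixDisplay.from_predictions(y_test, y_pred)` draws the confusion matrix. For the anomaly ranking on real data, the indices map back to rows of `gz_df`, so you can inspect the top candidates with `show_image_by_assetid()`.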