# In-class notebook: 2024-01-03

In this notebook, we will get familiar with the basic functions for downloading SDSS data using astroML, some basic plotting techniques, and a few implementations of trees used to speed up calculations.

This notebook is intended to support Chapters 1–2 of the textbook; the material is taken from the following astroML scripts:
* https://github.com/astroML/astroML_figures/blob/main/book_figures/chapter1/fig_SDSS_imaging.py
* https://github.com/astroML/astroML_figures/blob/main/book_figures/chapter1/fig_sdss_S82standards.py
* https://github.com/astroML/astroML_figures/blob/main/book_figures/chapter1/fig_S82_scatter_contour.py
* https://github.com/astroML/astroML_figures/blob/main/book_figures/chapter1/fig_S82_hess.py
* https://github.com/astroML/astroML_figures/blob/main/book_figures/chapter1/fig_sdss_spectrum.py
* https://github.com/astroML/astroML_figures/blob/main/book_figures/chapter2/fig_kdtree_example.py
* https://github.com/astroML/astroML_figures/blob/main/book_figures/chapter2/fig_balltree_example.py

In [None]:
# import basic stuff
import numpy as np
from matplotlib import pyplot as plt

## Download the SDSS imaging data and make some plots of the photometry for stars and galaxies

In [None]:
from astroML.datasets import fetch_imaging_sample

# Maximum number of objects of each class that is kept
Nstars, Ngals = 5000, 5000

# Load the imaging sample; the loader's implementation is at
# https://github.com/astroML/astroML/blob/main/astroML/datasets/imaging_sample.py
# (exercise: check where the files were downloaded)
data = fetch_imaging_sample()

print('total object counts', len(data))

# The 'type' column holds the object class: 6 marks stars, 3 marks galaxies
objtype = data['type']
is_star = objtype == 6
is_galaxy = objtype == 3

# Keep only the leading Nstars / Ngals rows of each class
stars = data[is_star][:Nstars]
galaxies = data[is_galaxy][:Ngals]

In [None]:
# Show the field names of the structured array. `stars` and `galaxies` are both
# selections from the same `data` array, so one listing covers both.
column_names = stars.dtype.names
print(column_names)

### Any idea what these columns are? 

You can read the schema to find out what is what - https://skyserver.sdss.org/dr7/en/help/browser/browser.asp

### Now plot

In [None]:
# Every panel uses black, unconnected, pixel-sized markers
plot_kwargs = {'color': 'k', 'linestyle': 'none', 'marker': ','}

# Colours shown on the panel axes, computed once per sample
gal_gr = galaxies['gRaw'] - galaxies['rRaw']
gal_ri = galaxies['rRaw'] - galaxies['iRaw']
star_gr = stars['gRaw'] - stars['rRaw']
star_ri = stars['rRaw'] - stars['iRaw']

fig = plt.figure(figsize=(5, 3.75))

# Left column: galaxies (top: r vs g-r, bottom: r-i vs g-r)
ax1 = fig.add_subplot(2, 2, 1)
ax1.plot(gal_gr, galaxies['rRaw'], **plot_kwargs)

ax2 = fig.add_subplot(2, 2, 3, sharex=ax1)
ax2.plot(gal_gr, gal_ri, **plot_kwargs)

# Right column: stars; axes are shared with the neighbouring panels so that
# the limits set below apply to both samples
ax3 = fig.add_subplot(2, 2, 2, sharey=ax1)
ax3.plot(star_gr, stars['rRaw'], **plot_kwargs)

ax4 = fig.add_subplot(2, 2, 4, sharex=ax3, sharey=ax2)
ax4.plot(star_gr, star_ri, **plot_kwargs)

# Labels and titles
ax1.set_ylabel(r'${\rm r}$')
ax2.set_ylabel(r'${\rm r - i}$')
ax2.set_xlabel(r'${\rm g - r}$')
ax4.set_xlabel(r'${\rm g - r}$')
ax1.set_title('Galaxies')
ax3.set_title('Stars')

# Axis limits; the r axis runs from 22.5 down to 14, i.e. it is inverted
ax2.set_xlim(-1, 3)
ax3.set_ylim(22.5, 14)
ax4.set_xlim(-1, 3)
ax4.set_ylim(-1, 2)

# Major ticks at every integer on both axes of all four panels
for ax in [ax1, ax2, ax3, ax4]:
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(plt.MultipleLocator(1))

### What do we see here?

Discuss with your neighbors. Make different plots if you need to justify your theory.

* selection in r band
* issue with scatter plots

## Download a set of SDSS standard stars and plot its color-color diagram

In [None]:
from astroML.datasets import fetch_sdss_S82standards

# Load the Stripe 82 standard-star data (loader source:
# https://github.com/astroML/astroML/blob/main/astroML/datasets/sdss_S82standards.py)
# and keep only the first 10000 rows
data = fetch_sdss_S82standards()[:10000]

# Mean magnitudes in the g, r and i bands
g, r, i = data['mmu_g'], data['mmu_r'], data['mmu_i']

### Plot

In [None]:
# Colour-colour diagram: g-r on the x axis, r-i on the y axis,
# drawn as small black dots with no connecting line
fig, ax = plt.subplots(figsize=(5, 3.75))
point_style = dict(marker='.', markersize=2, color='black', linestyle='none')
ax.plot(g - r, r - i, **point_style)

ax.set_xlim(left=-0.6, right=2.0)
ax.set_ylim(bottom=-0.6, top=2.5)

ax.set_xlabel(r'${\rm g - r}$')
ax.set_ylabel(r'${\rm r - i}$')

### What do we see here?

Discuss with your neighbors.

* More precise photometry
* The region with the highest point density is dominated by main sequence stars. The thin extension toward the lower-left
corner is dominated by the so-called blue horizontal branch stars and white dwarf stars.

### We can plot it differently, with more points: contours and Hess

In [None]:
# Load the Stripe 82 standard-star data again, this time without cutting it
# down to a subset (loader source:
# https://github.com/astroML/astroML/blob/main/astroML/datasets/sdss_S82standards.py)
data = fetch_sdss_S82standards()

# Mean g, r and i magnitudes for every row
g, r, i = (data[column] for column in ('mmu_g', 'mmu_r', 'mmu_i'))

In [None]:
from astroML.plotting import scatter_contour

fig, ax = plt.subplots(figsize=(5, 3.75))

# Option groups handed to scatter_contour; judging by the keyword names they
# configure the 2D histogram, the scattered points and the contours respectively
histogram_options = {'bins': 40}
point_options = {'marker': ',', 'linestyle': 'none', 'color': 'black'}
contour_options = {'cmap': plt.cm.bone}

scatter_contour(g - r, r - i,
                threshold=200, log_counts=True, ax=ax,
                histogram2d_args=histogram_options,
                plot_args=point_options,
                contour_args=contour_options)

ax.set_xlabel(r'${\rm g - r}$')
ax.set_ylabel(r'${\rm r - i}$')

ax.set_xlim(left=-0.6, right=2.5)
ax.set_ylim(bottom=-0.6, top=2.5)


In [None]:
import copy

# Compute the 2D histogram of (g - r, r - i) on a fixed grid:
# 50 bin edges per axis spanning -0.5 to 2.5, i.e. 49 x 49 bins
color_edges = np.linspace(-0.5, 2.5, 50)
H, xbins, ybins = np.histogram2d(g - r, r - i,
                                 bins=(color_edges, color_edges))

# Create a black and white color map where bad data (NaNs) are white.
# plt.get_cmap() is used because matplotlib.cm.get_cmap() was deprecated in
# Matplotlib 3.7 and removed in 3.9. The color map is copied so that
# set_bad() does not alter the registered 'binary' color map.
cmap = copy.copy(plt.get_cmap('binary'))
cmap.set_bad('w', 1.)

# Use the image display function imshow() to plot the result
fig, ax = plt.subplots(figsize=(5, 3.75))
H[H == 0] = 1  # prevent warnings in log10; empty bins then map to log10(1) = 0
# H is indexed [x, y]; transpose so that x runs along the horizontal image axis
ax.imshow(np.log10(H).T, origin='lower',
          extent=[xbins[0], xbins[-1], ybins[0], ybins[-1]],
          cmap=cmap, interpolation='nearest',
          aspect='auto')

ax.set_xlabel(r'${\rm g - r}$')
ax.set_ylabel(r'${\rm r - i}$')

ax.set_xlim(-0.6, 2.5)
ax.set_ylim(-0.6, 2.5)

## We can also download SDSS spectra and plot them

In [None]:
from astroML.datasets import fetch_sdss_spectrum

# A single spectrum is identified by its (plate, MJD, fiber) triple; more
# identifiers can be found using e.g. fetch_sdss_specgals
# https://github.com/astroML/astroML/blob/main/astroML/datasets/sdss_specgals.py
plate, mjd, fiber = 1615, 53166, 513

spec = fetch_sdss_spectrum(plate, mjd, fiber)

# ------------------------------------------------------------
# Draw the spectrum as a thin black line against wavelength
fig, ax = plt.subplots(figsize=(5, 3.75))
ax.plot(spec.wavelength(), spec.spectrum, '-k', linewidth=1)

ax.set_xlim(left=3000, right=10000)
ax.set_ylim(bottom=25, top=300)

ax.set_xlabel(r'$\lambda {(\rm \AA)}$')
ax.set_ylabel('Flux')

# Fill the title template from an explicit mapping of the three identifiers
spectrum_id = {'plate': plate, 'mjd': mjd, 'fiber': fiber}
ax.set_title('Plate = %(plate)i, MJD = %(mjd)i, Fiber = %(fiber)i' % spectrum_id)

### Plot a few more spectra; maybe try plotting red and blue galaxies separately?

* what are we looking at?
* what are some qualitative features?

## Now that you know how to get data, we are going to switch gears -- let's talk about trees

Trees are one common way to speed up calculations on large datasets. We will look at KD trees and Ball trees.

In [None]:
# Create a KDTree class which will recursively subdivide the
# space into rectangular regions.  Note that this is just an example
# and shouldn't be used for real computation; instead use the optimized
# code in scipy.spatial.cKDTree or sklearn.neighbors.BallTree

class KDTree:
    """Simple KD tree class for two-dimensional points (teaching example).

    Each node stores its points and a rectangular region ``[mins, maxs]``.
    A node holding more than one point sorts its points along the dimension
    in which its region is widest, cuts the region midway between the two
    middle points, and recursively builds one child node for each half.

    Parameters
    ----------
    data : array_like, shape (n_points, 2)
        Points to index. The input array is not modified.
    mins, maxs : array_like of length 2, optional
        Lower and upper corner of this node's region. Each defaults to the
        per-dimension minimum / maximum of ``data``.

    Attributes
    ----------
    data : ndarray
        The node's points (sorted along the split dimension when the node
        has children).
    mins, maxs, sizes : ndarray of float
        Region corners and the region's extent ``maxs - mins``.
    child1, child2 : KDTree or None
        Sub-trees for the upper and lower half of the points respectively;
        both are None for a node with at most one point.
    """

    # class initialization function
    def __init__(self, data, mins=None, maxs=None):
        self.data = np.asarray(data)

        # data should be two-dimensional
        assert self.data.shape[1] == 2

        # Use the converted array here so that list input works as well
        if mins is None:
            mins = self.data.min(0)
        if maxs is None:
            maxs = self.data.max(0)

        # Force float bounds: the split point written into copies of these
        # arrays below is a float and would be truncated in an integer array
        self.mins = np.asarray(mins, dtype=float)
        self.maxs = np.asarray(maxs, dtype=float)
        self.sizes = self.maxs - self.mins

        self.child1 = None
        self.child2 = None

        if len(self.data) > 1:
            # sort on the dimension in which this node's region is widest
            largest_dim = np.argmax(self.sizes)
            i_sort = np.argsort(self.data[:, largest_dim])
            # Index-array selection yields a new, sorted array, so the
            # caller's array keeps its original row order
            self.data = self.data[i_sort, :]

            # find split point: midway between the two middle points
            half_N = self.data.shape[0] // 2
            split_point = 0.5 * (self.data[half_N, largest_dim]
                                 + self.data[half_N - 1, largest_dim])

            # create subnodes (the split forms a line in the plane)
            mins1 = self.mins.copy()
            mins1[largest_dim] = split_point
            maxs2 = self.maxs.copy()
            maxs2[largest_dim] = split_point

            # Recursively build a KD-tree on each sub-node
            self.child1 = KDTree(self.data[half_N:], mins1, self.maxs)
            self.child2 = KDTree(self.data[:half_N], self.mins, maxs2)

    def draw_rectangle(self, ax, depth=None):
        """Recursively plot a visualization of the KD tree region.

        Parameters
        ----------
        ax : object with an ``add_patch`` method (e.g. a matplotlib Axes)
            Target that receives the rectangle patches.
        depth : int or None, optional
            With a non-negative int, only the nodes exactly ``depth`` levels
            below this node are drawn (0 draws this node alone). With None
            (the default), this node and all of its descendants are drawn.
        """
        # depth=None means "draw every level", so it has to draw here too
        if depth is None or depth == 0:
            rect = plt.Rectangle(self.mins, *self.sizes, ec='k', fc='none')
            ax.add_patch(rect)

        # Descend only while levels remain (or without limit for None)
        if self.child1 is not None and (depth is None or depth > 0):
            child_depth = None if depth is None else depth - 1
            for child in (self.child1, self.child2):
                child.draw_rectangle(ax, child_depth)


In [None]:
# Seed the global NumPy generator so the point set is the same on every run
np.random.seed(0)

# 30 points with both coordinates drawn uniformly from [-1, 1); the second
# coordinate is then scaled by 0.1 and the square of the first is added
X = 2 * np.random.random((30, 2)) - 1
X[:, 1] = 0.1 * X[:, 1] + X[:, 0] ** 2

plt.figure(figsize=(3, 3))
plt.scatter(*X.T)
plt.xlim((-1.1, 1.1))
plt.ylim((-0.1, 1.1))

In [None]:

# ------------------------------------------------------------
# Build the KD tree on X, with root region [-1.1, 1.1] x [-0.1, 1.1]
KDT = KDTree(X, [-1.1, -0.1], [1.1, 1.1])

# ------------------------------------------------------------
# One panel per tree level, arranged on a 2 x 2 grid
fig = plt.figure(figsize=(5, 5))
fig.subplots_adjust(left=0.1, right=0.9,
                    bottom=0.05, top=0.9,
                    wspace=0.1, hspace=0.15)

for depth in range(4):
    # panels and titles are numbered from 1, tree depths from 0
    level = depth + 1
    ax = fig.add_subplot(2, 2, level, xticks=[], yticks=[])
    ax.scatter(X[:, 0], X[:, 1], s=9)
    KDT.draw_rectangle(ax, depth=depth)

    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-0.15, 1.15)
    ax.set_title('level %i' % level)

# figure-wide title
fig.suptitle('$k$d-tree Example')

In [None]:
from matplotlib.patches import Circle

class BallTree:
    """Simple Ball tree class for two-dimensional points (teaching example).

    Each node stores its points together with a bounding circle: the centre
    is the mean of the points and the radius is the distance from the centre
    to the farthest point. A node holding more than one point sorts its
    points along the dimension with the largest spread and recursively
    builds one child node for each half.

    Parameters
    ----------
    data : array_like, shape (n_points, 2)
        Points to index. The input array is not modified.

    Attributes
    ----------
    data : ndarray
        The node's points (sorted along the split dimension when the node
        has children).
    loc : ndarray
        Centre of the bounding circle.
    radius : float
        Radius of the bounding circle.
    child1, child2 : BallTree or None
        Sub-trees for the upper and lower half of the points respectively;
        both are None for a node with at most one point.
    """

    # class initialization function
    def __init__(self, data):
        self.data = np.asarray(data)

        # data should be two-dimensional
        assert self.data.shape[1] == 2

        # Use the converted array here so that list input works as well
        self.loc = self.data.mean(0)
        self.radius = np.sqrt(np.max(np.sum((self.data - self.loc) ** 2, 1)))

        self.child1 = None
        self.child2 = None

        if len(self.data) > 1:
            # sort on the dimension with the largest spread
            largest_dim = np.argmax(self.data.max(0) - self.data.min(0))
            i_sort = np.argsort(self.data[:, largest_dim])
            # Index-array selection yields a new, sorted array, so the
            # caller's array keeps its original row order
            self.data = self.data[i_sort, :]

            # split the sorted points into two halves
            half_N = self.data.shape[0] // 2

            # recursively create subnodes
            self.child1 = BallTree(self.data[half_N:])
            self.child2 = BallTree(self.data[:half_N])

    def draw_circle(self, ax, depth=None):
        """Recursively plot a visualization of the Ball tree region.

        Parameters
        ----------
        ax : object with an ``add_patch`` method (e.g. a matplotlib Axes)
            Target that receives the circle patches.
        depth : int or None, optional
            With a non-negative int, only the nodes exactly ``depth`` levels
            below this node are drawn (0 draws this node alone). With None
            (the default), this node and all of its descendants are drawn.
        """
        if depth is None or depth == 0:
            circ = Circle(self.loc, self.radius, ec='k', fc='none')
            ax.add_patch(circ)

        # Descend only while levels remain (or without limit for None)
        if self.child1 is not None and (depth is None or depth > 0):
            child_depth = None if depth is None else depth - 1
            for child in (self.child1, self.child2):
                child.draw_circle(ax, child_depth)

In [None]:
# Build the Ball tree on the point set X
BT = BallTree(X)

# ------------------------------------------------------------
# Show tree levels 1 to 4, one panel each, on a 2 x 2 grid
fig = plt.figure(figsize=(5, 5))
layout = dict(wspace=0.1, hspace=0.15,
              left=0.1, right=0.9,
              bottom=0.05, top=0.9)
fig.subplots_adjust(**layout)

for level in (1, 2, 3, 4):
    ax = fig.add_subplot(2, 2, level, xticks=[], yticks=[])
    ax.scatter(X[:, 0], X[:, 1], s=9)
    # the panel for level N shows the nodes N - 1 steps below the root
    BT.draw_circle(ax, depth=level - 1)

    ax.set_xlim(-1.35, 1.35)
    ax.set_ylim(-1.0, 1.7)
    ax.set_title('level %i' % level)

# figure-wide title
fig.suptitle('Ball-tree Example')