In [None]:
# Import necessary libraries
import os
import pandas as pd
from astropy.coordinates import SkyCoord
from astroquery.mast import Catalogs
from astropy import units as u
from astropy.io import fits
from astropy.io.fits import CompImageHDU, getheader
from tqdm import tqdm
import glob
import requests
import urllib
import asyncio
import nest_asyncio
import aiohttp
import aiofiles
from io import BytesIO

In [None]:
# Step 1.1: Extract the SDSS DR18 LRG sample using CasJobs
# The result is stored in sdss_dr18_lrg_sample.csv (sample size: 1,174,900)
# Selection criteria (CasJobs SQL query):

# SELECT
  # s.specobjid,
  # s.z AS redshift,                   
  # s.veldisp,
  # p.ra, p.dec, p.u, p.g, p.r, p.i, p.modelMag_r,
  # s.programname, s.plate, s.fiberid, s.mjd into mydb.LRG_full_catalog from SpecObjAll AS s
# JOIN PhotoObjAll AS p ON s.bestobjid = p.objid
# WHERE
  # s.class = 'GALAXY'
  # AND s.z BETWEEN 0.1 AND 0.7
  # AND s.veldisp > 0 AND s.veldisp < 500
  # AND s.programname IN ('boss', 'eboss')
  # AND (p.r - p.i) > 0.5
  # AND (p.g - p.r) > 0.7
  # AND p.modelMag_r BETWEEN 16 AND 21

In [None]:
# Cross-match SDSS data with the PS1 catalog
#
# Settings for the batched cross-match: where the SDSS sample is read from,
# where each finished batch is written, and how many objects form one batch.

INPUT_CSV = "sdss_dr18_lrg_sample.csv"
OUTPUT_FOLDER = "sdss_lrg_queried_objects"
OUTPUT_CSV_TEMPLATE = "sdss_lrg_ps1_matched_batch_{batch_num}.csv"
BATCH_SIZE = 10000

# Full SDSS sample; the batch loop slices it BATCH_SIZE rows at a time.
df = pd.read_csv(INPUT_CSV)
n_objects = df.shape[0]

def _ps1_cell(match, column):
    """Return ``match[column]``, or None when the matched row has no such column."""
    try:
        return match[column]
    except KeyError:
        return None


def _ps1_value_as_float(value):
    """Return ``value`` as a float, or NaN when it is missing, masked, or not numeric."""
    try:
        # numpy's masked constant carries a truthy ``mask``; plain numbers have no such attribute.
        if value is None or bool(getattr(value, "mask", False)):
            return float("nan")
        return float(value)
    except (TypeError, ValueError):
        # e.g. the '--' placeholder string used for masked table cells
        return float("nan")


def query_ps1_batch(batch_df):
    """Cross-match every row of ``batch_df`` against the Pan-STARRS DR2 catalog.

    For each row a 2.5 arcsec cone search is run at (``ra``, ``dec``) and the
    first returned source supplies the PS1 columns. Rows with no match, and
    rows whose query raised, get NaN in every PS1 column.

    Args:
        batch_df: DataFrame with ``ra`` and ``dec`` columns in degrees.
            It is not modified.

    Returns:
        A new DataFrame (index reset) holding the input columns plus
        ``ps1_{g,r,i}_mag``, ``ps1_{g,r,i}_depth`` and ``ps1_{g,r,i}_fwhm``,
        restricted to rows with at least two numeric PS1 magnitudes.

    Notes:
        Failed queries are not re-raised (the cross-match is best-effort), but
        their count and the first error are printed so that a network outage
        is not mistaken for a batch without matches.
    """
    # Output column -> PS1 catalog column, in the order the columns are added.
    ps1_columns = {
        'ps1_g_mag': 'gMeanPSFMag',
        'ps1_r_mag': 'rMeanPSFMag',
        'ps1_i_mag': 'iMeanPSFMag',
        'ps1_g_depth': 'gMeanDepth',
        'ps1_r_depth': 'rMeanDepth',
        'ps1_i_depth': 'iMeanDepth',
        'ps1_g_fwhm': 'gFWHM',
        'ps1_r_fwhm': 'rFWHM',
        'ps1_i_fwhm': 'iFWHM',
    }
    unmatched = dict.fromkeys(ps1_columns, float("nan"))

    records = []
    failures = 0
    first_error = None

    for _, row in tqdm(batch_df.iterrows(), total=len(batch_df)):
        record = dict(unmatched)
        try:
            coord = SkyCoord(ra=row['ra'], dec=row['dec'], unit='deg', frame='icrs')
            result = Catalogs.query_region(coord, radius=2.5 * u.arcsec, catalog='PanSTARRS', data_release='dr2')
            if len(result) > 0:
                best = result[0]
                # Built as a whole and only then assigned, so an error half-way
                # through cannot leave a partially filled record behind.
                record = {
                    out_col: _ps1_value_as_float(_ps1_cell(best, ps1_col))
                    for out_col, ps1_col in ps1_columns.items()
                }
        except Exception as exc:
            failures += 1
            if first_error is None:
                first_error = exc
        # Exactly one record per input row keeps every column aligned with batch_df.
        records.append(record)

    if failures:
        print(
            f"Warning: {failures} of {len(batch_df)} PS1 queries failed and were "
            f"treated as unmatched (first error: {first_error!r})"
        )

    # Work on a copy so the caller's frame is left untouched.
    matched_df = batch_df.copy()
    for out_col in ps1_columns:
        matched_df[out_col] = [record[out_col] for record in records]

    # Filter: keep only rows with at least 2 non-null PS1 magnitudes
    # (masked / placeholder values were already turned into NaN above).
    mask = (
        matched_df[['ps1_g_mag', 'ps1_r_mag', 'ps1_i_mag']]
        .notnull()
        .sum(axis=1) >= 2
    )
    filtered_df = matched_df[mask].reset_index(drop=True)
    print(f"Filtered from {len(batch_df)} → {len(filtered_df)} rows with ≥ 2 PS1 bands")
    return filtered_df

# Drive the cross-match batch by batch. Each finished batch is stored in its
# own CSV, which doubles as the "already done" marker when the cell is re-run.

# The output folder is the same for every batch, so create it once up front.
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

for i in range(0, n_objects, BATCH_SIZE):
    batch_num = i // BATCH_SIZE + 1
    output_file = os.path.join(OUTPUT_FOLDER, OUTPUT_CSV_TEMPLATE.format(batch_num=batch_num))

    if os.path.exists(output_file):
        print(f"Batch {batch_num} already done (found {output_file}), skipping...")
        continue

    batch_df = df.iloc[i:i+BATCH_SIZE].copy()
    print(f"\nProcessing batch {batch_num} ({i} to {i + len(batch_df) - 1})...")
    filtered_df = query_ps1_batch(batch_df)

    # Write under a scratch name and rename afterwards: an interrupted write
    # must never leave a truncated file that the resume check above would skip.
    # The ".part" suffix also keeps it out of later scans for "*.csv".
    partial_file = output_file + ".part"
    filtered_df.to_csv(partial_file, index=False)
    os.replace(partial_file, output_file)
    print(f"Saved filtered batch {batch_num} to {output_file}")

print("\nAll batches processed!")

In [None]:
# Compute one PS1 color per object.
# Priority: g-r, then r-i, then g-i; the result is written back into each batch
# CSV as the columns color_name and color_value.

# Folder holding the per-batch CSVs that this cell rewrites in place
INPUT_FOLDER = "sdss_lrg_queried_objects"
# Number of rows read from a CSV at a time
CHUNK_SIZE = 10000

def determine_color(row):
    """Attach the preferred PS1 color to ``row`` and return it.

    The first band pair with both magnitudes present wins, tried in the order
    g-r, r-i, g-i. ``row['color_name']`` receives the pair's label and
    ``row['color_value']`` the magnitude difference; both are None when no
    pair is complete. ``row`` is modified in place.
    """
    g_mag = row['ps1_g_mag']
    r_mag = row['ps1_r_mag']
    i_mag = row['ps1_i_mag']

    # (label, bluer magnitude, redder magnitude) in priority order
    candidates = (
        ('g-r', g_mag, r_mag),
        ('r-i', r_mag, i_mag),
        ('g-i', g_mag, i_mag),
    )

    chosen_name, chosen_value = None, None
    for label, bluer, redder in candidates:
        if pd.notnull(bluer) and pd.notnull(redder):
            chosen_name, chosen_value = label, bluer - redder
            break

    row['color_name'] = chosen_name
    row['color_value'] = chosen_value
    return row

# Rewrite every batch CSV in INPUT_FOLDER with the color columns added.
# Each file is streamed in chunks into a scratch file that replaces the
# original only after all chunks have been written.
for filename in sorted(os.listdir(INPUT_FOLDER)):
    # Skip non-CSV files, and "temp_*.csv" scratch files that an interrupted
    # earlier run may have left behind (they are not batch files).
    if not filename.endswith(".csv") or filename.startswith("temp_"):
        continue

    input_path = os.path.join(INPUT_FOLDER, filename)
    # The ".tmp" suffix keeps a half-written scratch file from ever being
    # picked up by a scan for "*.csv".
    temp_output_path = os.path.join(INPUT_FOLDER, f"temp_{filename}.tmp")
    first_batch = True

    print(f"\nProcessing: {filename}")

    for chunk in pd.read_csv(input_path, chunksize=CHUNK_SIZE):
        # Replace '--' with NaN and convert to float
        for col in ['ps1_g_mag', 'ps1_r_mag', 'ps1_i_mag']:
            if col in chunk.columns:
                chunk[col] = pd.to_numeric(chunk[col], errors='coerce')

        if chunk.empty:
            # DataFrame.apply on a frame without rows hands the frame back
            # without the new columns, so add them explicitly to keep the
            # header identical across all batch files.
            chunk = chunk.assign(color_name=None, color_value=None)
        else:
            # Apply color computation
            chunk = chunk.apply(determine_color, axis=1)

        # Save to temp file
        chunk.to_csv(
            temp_output_path,
            mode='w' if first_batch else 'a',
            index=False,
            header=first_batch
        )
        first_batch = False
        print(f"  Processed chunk with {len(chunk)} rows")

    if first_batch:
        # Nothing was written, so there is no scratch file to move into place.
        print(f"  No rows read from {filename}; file left unchanged")
        continue

    # Overwrite original file
    os.replace(temp_output_path, input_path)
    print(f"Updated file saved: {filename}")

print("All CSVs processed with color info added.")

In [None]:
# Downloading PS1 cutouts for each object that has a valid PS1 magnitude
# (one FITS file per object and band, saved as OUTPUT_FOLDER/<band>/<specobjid>_<band>.fits).

nest_asyncio.apply()  # For Jupyter

# Folder containing the per-batch CSVs whose objects are downloaded
INPUT_FOLDER = "sdss_lrg_queried_objects"
# Root folder for the downloaded cutouts.
# NOTE(review): this rebinds the OUTPUT_FOLDER name that the cross-match cell
# also assigns (there: "sdss_lrg_queried_objects"); the value in effect depends
# on which cell ran last.
OUTPUT_FOLDER = "cutouts"
# Bands to download; each needs a matching ps1_<band>_mag column in the CSVs
FILTERS = ['g', 'r', 'i']
# Sent as the `size` parameter of the cutout request (presumably pixels — confirm)
CUTOUT_SIZE = 80
# Cutout responses shorter than this many bytes are rejected as invalid
MIN_FITS_BYTES = 10_000
MAX_CONCURRENT_REQUESTS = 50  # Tune this depending on your bandwidth

# Ensure output folders exist
for band in FILTERS:
    os.makedirs(os.path.join(OUTPUT_FOLDER, band), exist_ok=True)

# Shared by every download_band_fits call to cap how many run at the same time
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)

def is_valid_mag(val):
    """Tell whether ``val`` is a usable magnitude.

    Missing values (None / NaN) and the ``"--"`` placeholder string, with or
    without surrounding whitespace, are rejected; everything else is accepted.
    """
    if pd.isna(val):
        return False
    if isinstance(val, str):
        return val.strip() != "--"
    return True

async def fetch_with_retries(session, url, timeout=15, retries=3, backoff=1.5):
    """GET ``url`` and return the response body, retrying on any failure.

    Args:
        session: Object providing an aiohttp-style ``get(url, timeout=...)``
            that returns an async context manager.
        url: Address to fetch.
        timeout: Total timeout in seconds, or an ``aiohttp.ClientTimeout``.
        retries: Maximum number of attempts; must be at least 1.
        backoff: Base of the delay between attempts; the wait after attempt
            ``k`` (0-based) is ``backoff ** k`` seconds.

    Returns:
        The response body as bytes.

    Raises:
        ValueError: If ``retries`` is smaller than 1.
        Exception: The error of the final attempt, including ``"HTTP <status>"``
            for any response whose status is not 200.
    """
    if retries < 1:
        # With zero attempts the loop below would fall through and return None.
        raise ValueError("retries must be at least 1")

    # aiohttp's documented form is a ClientTimeout; a bare number is only the legacy spelling.
    if not isinstance(timeout, aiohttp.ClientTimeout):
        timeout = aiohttp.ClientTimeout(total=timeout)

    for attempt in range(retries):
        try:
            async with session.get(url, timeout=timeout) as resp:
                if resp.status == 200:
                    return await resp.read()
                raise Exception(f"HTTP {resp.status}")
        except Exception:
            if attempt == retries - 1:
                # Out of attempts: re-raise with the original traceback intact.
                raise
            await asyncio.sleep(backoff ** attempt)

async def download_band_fits(session, specobjid, ra, dec, band):
    """Download one PS1 FITS cutout and store it under ``OUTPUT_FOLDER/<band>/``.

    The PS1 file listing is queried for the image covering (``ra``, ``dec``)
    in ``band``; a cutout of ``CUTOUT_SIZE`` around that position is then
    requested and its primary HDU is written to
    ``<specobjid>_<band>.fits``. The whole operation holds one slot of the
    module-level ``semaphore``.

    Args:
        session: HTTP session handed through to ``fetch_with_retries``.
        specobjid: Object identifier used in the output file name.
        ra: Right ascension placed in the request URLs.
        dec: Declination placed in the request URLs.
        band: Filter name; selects both the image and the output sub-folder.

    Returns:
        A status message. Failures are reported through the message and are
        never raised.
    """
    async with semaphore:
        save_path = os.path.join(OUTPUT_FOLDER, band, f"{specobjid}_{band}.fits")
        if os.path.exists(save_path):
            return f"Skipped {specobjid} {band} (already exists)"

        # Step 1: Get full FITS file path from ps1filenames.py
        # (https, the same scheme the cutout request below uses)
        script_url = f"https://ps1images.stsci.edu/cgi-bin/ps1filenames.py?ra={ra}&dec={dec}&filters={band}"

        try:
            text = (await fetch_with_retries(session, script_url)).decode("utf-8")
            lines = text.strip().splitlines()
            if len(lines) < 2:
                return f"No FITS entry for {specobjid} {band}"

            # Find the path through the header's "filename" column; fall back to
            # the historical fixed position when the header does not name it.
            header = lines[0].split()
            filename_col = header.index("filename") if "filename" in header else 7
            fits_path = lines[1].split()[filename_col]  # full FITS file path on server

            # Step 2: Build FITS cutout URL with parameters
            base_cutout_url = "https://ps1images.stsci.edu/cgi-bin/fitscut.cgi"
            cutout_url = (
                f"{base_cutout_url}"
                f"?red={fits_path}"
                f"&format=fits"
                f"&x={ra}"
                f"&y={dec}"
                f"&size={CUTOUT_SIZE}"
                f"&wcs=1"
                f"&imagename=cutout_{os.path.basename(fits_path)}"
            )

            content = await fetch_with_retries(session, cutout_url, timeout=30)

            if len(content) < MIN_FITS_BYTES:
                return f"Invalid FITS cutout too small for {specobjid} {band}"

            # Write under a scratch name and rename afterwards: a write cut short
            # must not leave a file that the "already exists" check would accept.
            partial_path = save_path + ".part"
            with fits.open(BytesIO(content)) as hdu:
                fits.writeto(partial_path, hdu[0].data, hdu[0].header, overwrite=True)
            os.replace(partial_path, save_path)

            return f"Downloaded cutout {specobjid} {band}"

        except Exception as e:
            return f"Error downloading cutout {specobjid} {band}: {e}"

async def process_row(session, specobjid, ra, dec, row):
    """Download the cutouts of one object for every band with a valid magnitude.

    A band from ``FILTERS`` is requested only when ``row`` has a
    ``ps1_<band>_mag`` entry that ``is_valid_mag`` accepts. The selected
    downloads run concurrently and each status message is printed.
    """
    wanted_bands = [
        band
        for band in FILTERS
        if f"ps1_{band}_mag" in row and is_valid_mag(row[f"ps1_{band}_mag"])
    ]
    downloads = [
        download_band_fits(session, specobjid, ra, dec, band)
        for band in wanted_bands
    ]
    for message in await asyncio.gather(*downloads):
        print(message)

async def process_file(session, filepath, index, total):
    """Download the cutouts for every object listed in one batch CSV.

    Rows are handled in waves of ``MAX_CONCURRENT_REQUESTS`` objects that run
    concurrently, so the shared download semaphore can actually fill up;
    awaiting one row at a time would keep at most one object's bands in
    flight.

    Args:
        session: HTTP session passed on to ``process_row``.
        filepath: CSV with at least ``specobjid``, ``ra`` and ``dec`` columns.
        index: 1-based position of this file, used only in the progress line.
        total: Number of files, used only in the progress line.
    """
    print(f"[{index}/{total}] Processing {filepath}...")
    df = pd.read_csv(filepath)

    # Per-row dicts keep every column's own dtype. iterrows() builds one Series
    # per row and may upcast it to float64, which cannot hold a large integer
    # specobjid exactly.
    records = df.to_dict("records")

    rows_per_wave = max(1, MAX_CONCURRENT_REQUESTS)
    for start in range(0, len(records), rows_per_wave):
        # Resolve all arguments first so a malformed row fails before any
        # coroutine of the wave has been created.
        jobs = [
            (str(int(record['specobjid'])), record['ra'], record['dec'], record)
            for record in records[start:start + rows_per_wave]
        ]
        await asyncio.gather(*(
            process_row(session, specobjid, ra, dec, record)
            for specobjid, ra, dec, record in jobs
        ))
    print(f"Finished {filepath}")

async def main():
    """Process every CSV in ``INPUT_FOLDER``, in sorted name order.

    A single HTTP session is opened and reused for all files; the files are
    handled one after another.
    """
    csv_files = [name for name in os.listdir(INPUT_FOLDER) if name.endswith(".csv")]
    csv_files.sort()
    n_files = len(csv_files)
    print(f"Found {n_files} CSV files")

    async with aiohttp.ClientSession() as session:
        for offset, filename in enumerate(csv_files):
            csv_path = os.path.join(INPUT_FOLDER, filename)
            # Progress positions are reported 1-based.
            await process_file(session, csv_path, offset + 1, n_files)

# Run in Jupyter: a notebook cell accepts top-level `await`.
# NOTE(review): in a plain Python script this line is a syntax error;
# asyncio.run(main()) would be the equivalent there.
await main()