In [None]:
import numpy as np
import pandas as pd
import os
import glob
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from astropy.io import fits
from astropy.wcs import WCS
from astropy.nddata import Cutout2D
from astropy.coordinates import SkyCoord
from astropy.time import Time
from astropy.wcs.wcs import FITSFixedWarning 
import astropy.units as u
from astroquery.casda import Casda
from astroquery.utils.tap.core import TapPlus
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from datetime import datetime

import warnings


# Silence astropy's FITSFixedWarning (raised when it rewrites non-standard WCS
# header values on load) so notebook output stays readable.
warnings.simplefilter('ignore', category=FITSFixedWarning)


# ==============================================================================
# 0. HELPER FUNCTIONS
# ==============================================================================

def create_directory(directory_path):
    """
    Make sure ``directory_path`` exists, creating intermediate directories too.

    An already-existing path is left alone and nothing is printed. Creation
    failures are reported on stdout instead of being raised, so the caller
    carries on (and any later write into the directory will surface the problem).
    """
    if os.path.exists(directory_path):
        return
    try:
        os.makedirs(directory_path)
    except OSError as err:
        print(f"Error creating directory '{directory_path}': {err}")
    else:
        print(f"Directory '{directory_path}' created successfully.")


def _download_single_file(url, savedir, casda_obj, file_num, total_files):
    """
    Download a single staged URL; meant to run as a thread-pool worker.

    Returns a ``(success, path_or_url, error)`` tuple:
    - on success: ``(True, <first path returned by download_files or 'unknown'>, None)``
    - on failure: ``(False, url, <error message>)``
    Any exception from the download (or from inspecting its result) is caught
    and reported, never propagated.
    """
    tag = f'[{file_num}/{total_files}]'
    try:
        saved = casda_obj.download_files([url], savedir=savedir)
        local_path = saved[0] if saved else 'unknown'
        print(f'  {tag} Downloaded: {os.path.basename(local_path)}')
    except Exception as exc:
        print(f'  {tag} ERROR downloading {url}: {exc}')
        return (False, url, str(exc))
    return (True, local_path, None)


def _parallel_download_files(url_list, savedir, casda_obj, max_workers=4):
    """
    Download every URL in ``url_list`` concurrently with a thread pool.

    Each URL is handed to ``_download_single_file``; results are gathered in
    completion order. A summary (and a list of failures, if any) is printed.

    Returns
    -------
    list
        Local paths of the successful downloads, in completion order.
    """
    total = len(url_list)
    succeeded = []
    failures = []

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {}
        for position, url in enumerate(url_list, start=1):
            job = pool.submit(_download_single_file, url, savedir, casda_obj, position, total)
            pending[job] = url

        for job in as_completed(pending):
            url = pending[job]
            try:
                ok, result, err = job.result()
            except Exception as exc:
                # Safety net: the worker normally reports its own errors.
                print(f'  Unexpected error processing {url}: {exc}')
                failures.append((url, str(exc)))
                continue
            if ok:
                succeeded.append(result)
            else:
                failures.append((url, err))

    print(f'\nDownload summary: {len(succeeded)}/{total} successful')
    if failures:
        print(f'Failed downloads ({len(failures)}):')
        for url, err in failures:
            print(f'  - {url}: {err}')

    return succeeded


def simplify_fits_header(hdu):
    """
    Reduce an N-D image HDU (N >= 2) to a 2-D plane plus a matching 2-D header.

    The first plane along every leading axis is taken (for ASKAP continuum
    images, shaped (Stokes, Freq, Dec, RA), that is Stokes/frequency index 0).
    All per-axis WCS keywords that refer to axes beyond 2 are stripped from a
    copy of the header. This covers NAXISn, CRPIXn, CTYPEn, CRVALn, CDELTn,
    CUNITn, CROTAn, CNAMEn, CRDERn and CSYERn, and the PCi_j / CDi_j / PVi_m /
    PSi_m terms, including alternate-WCS variants with a trailing letter.
    Stripping them stops WCS initialisation from failing with "operands could
    not be broadcast". NAXIS (and WCSAXES, if present) is set to 2.

    Parameters
    ----------
    hdu : FITS image HDU
        Anything exposing ``.data`` (array-like) and ``.header`` (mapping with
        ``copy``/``keys``/``__contains__``/``__delitem__``/``__setitem__``).

    Returns
    -------
    (data, header)
        ``data`` is a 2-D view of ``hdu.data``; ``header`` is a modified copy
        (``hdu.header`` itself is not changed).

    Raises
    ------
    ValueError
        If the HDU has no data or fewer than 2 dimensions.
    """
    if hdu.data is None:
        raise ValueError("HDU contains no data")

    data_shape = hdu.data.shape
    ndim = len(data_shape)
    if ndim < 2:
        raise ValueError(f"Unexpected data shape: {data_shape}")

    # Index 0 on every axis except the last two (image plane).
    data = hdu.data if ndim == 2 else hdu.data[(0,) * (ndim - 2)]

    header = hdu.header.copy()

    axis_key = re.compile(
        r'^(?:NAXIS|CRPIX|CTYPE|CRVAL|CDELT|CUNIT|CROTA|CNAME|CRDER|CSYER)(\d+)[A-Z]?$'
    )
    matrix_key = re.compile(r'^(?:PC|CD)(\d+)_(\d+)[A-Z]?$')
    param_key = re.compile(r'^(?:PV|PS)(\d+)_\d+[A-Z]?$')

    def _refers_to_higher_axis(key):
        """True if a header keyword describes an axis numbered above 2."""
        match = axis_key.match(key)
        if match:
            return int(match.group(1)) > 2
        match = matrix_key.match(key)
        if match:
            return int(match.group(1)) > 2 or int(match.group(2)) > 2
        match = param_key.match(key)
        if match:
            return int(match.group(1)) > 2
        return False

    doomed = [key for key in list(header.keys())
              if isinstance(key, str) and _refers_to_higher_axis(key)]
    for key in doomed:
        # Guard against a keyword listed twice in the snapshot above.
        if key in header:
            del header[key]

    header['NAXIS'] = 2
    if 'WCSAXES' in header:
        header['WCSAXES'] = 2

    return data, header


def reconstruct_cutout_metadata(source_name, cutout_dir):
    """
    Rebuild the cutout metadata table from cutout FITS files already on disk.

    Used for plot_only mode. Files are matched as
    ``<cutout_dir>/<clean_name>_*.fits``, where ``clean_name`` is
    ``source_name`` with spaces and parentheses removed (the same naming used
    when cutouts are written).

    Parameters
    ----------
    source_name : str
        Source name, stored verbatim in the ``source_name`` column.
    cutout_dir : str
        Directory holding the cutouts; a trailing separator is optional.

    Returns
    -------
    pandas.DataFrame
        One row per readable file with columns source_name, cutout_filename,
        image_filename (the cutout's basename), obs_date, obs_mjd and
        integration_time. It has no columns when nothing is found.
    """
    clean_name = source_name.replace(' ', '').replace('(', '').replace(')', '')
    # os.path.join copes with a missing trailing separator; glob.escape stops
    # characters such as '[' in a directory or source name acting as wildcards.
    search_pattern = os.path.join(glob.escape(cutout_dir), f"{glob.escape(clean_name)}_*.fits")
    files = sorted(glob.glob(search_pattern))

    records = []
    print(f"Found {len(files)} existing cutouts for {source_name}")

    for filepath in files:
        try:
            with fits.open(filepath) as hdul:
                header = hdul[0].header
                # OBSDATE / MJD-OBS / INT_TIME are written by make_cutouts.
                obs_date = header.get('OBSDATE', header.get('DATE-OBS', ''))
                mjd = header.get('MJD-OBS', 0)
                int_time = header.get('INT_TIME', 0)
        except Exception as e:
            # Best effort: one unreadable file must not abort the rebuild.
            print(f"Skipping corrupt file {filepath}: {e}")
            continue

        # If MJD is missing but a date string is present, derive it.
        if mjd == 0 and obs_date:
            try:
                mjd = Time(obs_date).mjd
            except (ValueError, TypeError):
                mjd = 0

        records.append({
            'source_name': source_name,
            'cutout_filename': filepath,
            'image_filename': os.path.basename(filepath),
            'obs_date': obs_date,
            'obs_mjd': mjd,
            'integration_time': int_time
        })

    return pd.DataFrame(records)


def pre_filter_valid_cutouts(cutout_info):
    """
    Drop cutout records whose FITS file is unreadable or holds no non-NaN pixel.

    Parameters
    ----------
    cutout_info : pandas.DataFrame
        Must have a ``cutout_filename`` column when non-empty.

    Returns
    -------
    pandas.DataFrame
        The subset of ``cutout_info`` rows that passed, with the original
        index, columns and dtypes preserved. Selection uses a boolean mask
        rather than rebuilding from row Series, which would coerce mixed
        columns to object dtype and lose the columns when nothing survives.
    """
    if cutout_info.empty:
        return cutout_info.copy()

    keep = np.zeros(len(cutout_info), dtype=bool)

    for pos, filepath in enumerate(cutout_info['cutout_filename']):
        try:
            with fits.open(filepath) as hdul:
                data = hdul[0].data
                # Valid = at least one non-NaN pixel (and some data at all).
                has_data = data is not None and bool(np.any(~np.isnan(data)))
        except Exception as e:
            print(f"  [Skipped] Corrupt file during plot preparation: {os.path.basename(filepath)} ({e})")
            continue

        if has_data:
            keep[pos] = True
        else:
            print(f"  [Skipped] Cutout file is empty or all NaN: {os.path.basename(filepath)}")

    filtered = cutout_info.loc[keep].copy()

    dropped = len(cutout_info) - len(filtered)
    if dropped:
        print(f"Filtered {dropped} invalid cutouts. {len(filtered)} valid cutouts remain.")

    return filtered


# ==============================================================================
# 1. CASDA QUERY & DOWNLOAD FUNCTIONS
# ==============================================================================

def query_casda_images(ra, dec, search_radius=10*u.deg, 
                       image_type='cont.restored.t0',
                       stokes='I',
                       filename_suffix=".restored.conv.fits",
                       max_records=50000):
    """
    Query CASDA's ObsCore table for released ASKAP images near a position.

    Parameters
    ----------
    ra, dec : float
        Target position in degrees.
    search_radius : astropy Quantity (angle)
        Maximum separation between the target and an image's centre
        (``s_ra``/``s_dec``). This is a centre-distance cut, not a footprint test.
    image_type : str
        Value required for ``dataproduct_subtype``.
    stokes : str
        Polarisation, matched as ``pol_states = '/<stokes>/'``.
    filename_suffix : str
        Only rows whose ``filename`` ends with this suffix are kept.
    max_records : int
        ``TOP`` limit of the ADQL query. Rows beyond it are not returned, and a
        warning is printed when the limit is reached.

    Returns
    -------
    pandas.DataFrame
        Matching rows (GOOD or UNCERTAIN quality, released, ASKAP obs_id).
    """
    tap = TapPlus(url="https://casda.csiro.au/casda_vo_tools/tap")
    stokes_query = f"pol_states = '/{stokes}/'"

    print(f"Querying CASDA for Stokes {stokes} images...")
    job = tap.launch_job_async(
        f"SELECT TOP {int(max_records)} * FROM ivoa.obscore WHERE "
        f"({stokes_query} AND dataproduct_subtype = '{image_type}')"
    )
    r = job.get_results()
    if len(r) >= int(max_records):
        print(f"WARNING: query hit the TOP {int(max_records)} limit; results may be incomplete.")

    data = r[(r['quality_level'] == 'GOOD') | (r['quality_level'] == 'UNCERTAIN')]
    public_data = Casda.filter_out_unreleased(data).to_pandas()
    # na=False: rows with a missing obs_id are dropped rather than breaking the mask.
    public_data = public_data[public_data['obs_id'].str.contains('ASKAP', regex=False, na=False)]

    print(f'Total ASKAP images in archive: {len(public_data)}')

    # Spatial filter on image centres.
    source_coord = SkyCoord(ra * u.deg, dec * u.deg)
    image_coords = SkyCoord(
        np.array(public_data['s_ra']) * u.deg,
        np.array(public_data['s_dec']) * u.deg
    )

    seps = source_coord.separation(image_coords)
    matches = np.where(seps < search_radius)[0]
    matching_data = public_data.iloc[matches].copy()
    print(f'Images covering source: {len(matching_data)}')

    # Filename suffix filter (missing filenames are dropped).
    suffix_mask = matching_data['filename'].str.endswith(filename_suffix, na=False)
    filtered_matches = matching_data[suffix_mask].copy()
    print(f'Filtered for "{filename_suffix}": {len(filtered_matches)} images')

    return filtered_matches


def download_casda_images(image_data, savedir, casda_obj, 
                          parallel=True, max_workers=4, chunk_size=50):
    """
    Stage and download the CASDA images in ``image_data`` that are not yet on disk.

    Files are staged in batches of ``chunk_size``; if a batch fails, its files
    are staged one by one. On both paths, staged URLs containing 'checksum' are
    skipped and duplicates are dropped, so only data files are downloaded.

    Parameters
    ----------
    image_data : pandas.DataFrame
        Query results. Needs a ``filename`` column plus whatever
        ``casda_obj.stage_data`` requires; rows are passed as a numpy record array.
    savedir : str
        Destination directory (created if missing).
    casda_obj : authenticated astroquery Casda instance
    parallel : bool
        Use a thread pool when more than one URL was staged.
    max_workers : int
        Thread-pool size for parallel downloads.
    chunk_size : int
        Files per staging request; values below 1 are treated as 1.

    Returns
    -------
    list
        Paths reported by the download call(s); ``[]`` if everything was
        already present.
    """
    create_directory(savedir)
    if savedir[-1] != '/': savedir += '/'

    # dict.fromkeys de-duplicates filenames while keeping their order.
    files_to_download = list(dict.fromkeys(
        filename for filename in image_data['filename']
        if not os.path.isfile(f'{savedir}{filename}')
    ))

    if not files_to_download:
        print('All files already downloaded.')
        return []

    print(f'Files to download: {len(files_to_download)}')

    chunk_size = max(1, int(chunk_size))
    print('Staging files (chunked batch mode)...')
    url_list = []
    seen_urls = set()

    def _collect(urls):
        """Append staged data-file URLs (skipping checksum files) not seen before."""
        for url in urls:
            if 'checksum' in url or url in seen_urls:
                continue
            seen_urls.add(url)
            url_list.append(url)

    for start in range(0, len(files_to_download), chunk_size):
        chunk = files_to_download[start:start + chunk_size]
        print(f'Staging chunk {start // chunk_size + 1}: {len(chunk)} files')

        # stage_data is given a numpy record array, not a DataFrame.
        chunk_table = image_data[image_data['filename'].isin(chunk)].to_records(index=False)

        try:
            _collect(casda_obj.stage_data(chunk_table))
        except Exception as e:
            print(f'Error staging chunk: {e}')
            # Fallback: stage each file on its own so one bad file doesn't sink the batch.
            for filename in chunk:
                try:
                    file_table = image_data[image_data['filename'] == filename].to_records(index=False)
                    _collect(casda_obj.stage_data(file_table))
                except Exception as e2:
                    print(f'Error staging {filename}: {e2}')

    print(f'Staged {len(url_list)} files')

    if parallel and len(url_list) > 1:
        print(f'Downloading {len(url_list)} files in parallel (max {max_workers} workers)...')
        downloaded = _parallel_download_files(url_list, savedir, casda_obj, max_workers)
    else:
        print('Downloading files (sequential)...')
        # astroquery's download_files takes the list of staged URLs.
        downloaded = casda_obj.download_files(url_list, savedir=savedir)

    return downloaded


# ==============================================================================
# 2. CUTOUT GENERATION (WITH WCS & BEAM)
# ==============================================================================

def _source_coord_at_epoch(src, img_epoch, ra_col, dec_col, pmra_col, pmdec_col,
                           epoch_col, apply_proper_motion):
    """
    Return a source's SkyCoord, propagated to ``img_epoch`` when requested.

    Falls back to the catalogue (RA, Dec) when proper motion is disabled, the
    proper-motion columns are absent, or the propagation raises.
    """
    ra_val = src[ra_col]
    dec_val = src[dec_col]
    src_coord = SkyCoord(ra_val * u.deg, dec_val * u.deg)

    if not (apply_proper_motion and pmra_col in src and pmdec_col in src):
        return src_coord

    try:
        # NaN proper motions are treated as zero.
        pmra_val = float(np.nan_to_num(src[pmra_col], nan=0.0))
        pmdec_val = float(np.nan_to_num(src[pmdec_col], nan=0.0))
        return SkyCoord(
            ra_val * u.deg, dec_val * u.deg,
            pm_ra_cosdec=pmra_val * u.mas / u.yr,
            pm_dec=pmdec_val * u.mas / u.yr,
            # apply_space_motion needs a distance; no radial velocity is given,
            # so the nominal 1 pc should not change the angular shift (confirm).
            frame='icrs', obstime=Time(src[epoch_col]), distance=1.0 * u.pc
        ).apply_space_motion(img_epoch)
    except Exception as pm_e:
        print(f"    WARNING: Proper motion calculation failed ({pm_e}). Using J2000 coordinates.")
        return src_coord


def make_cutouts(sources_df, image_data, 
                 image_dir, cutout_dir,
                 cutout_size=4.0*u.arcmin,
                 ra_col='ra', dec_col='dec',
                 pmra_col='pmra', pmdec_col='pmdec',
                 epoch_col='epoch', name_col='name',
                 apply_proper_motion=True,
                 field_radius=6*u.deg):
    """
    Cut postage stamps around each source from every downloaded CASDA image.

    For each image row, the sources whose catalogue position lies within
    ``field_radius`` of the image centre are selected. Each one gets a 2-D
    cutout of ``cutout_size``, centred on its position (propagated to the
    image's ``t_min`` epoch if proper motion is enabled). The cutout is written
    to ``<cutout_dir><clean_name>_<image filename>`` with the cutout WCS, the
    beam/unit/telescope keywords, and OBJECT/OBSDATE/MJD-OBS/INT_TIME.

    Cutouts that are entirely NaN are neither saved nor recorded. Cutout files
    that already exist are not rewritten, but ARE included in the returned
    table, so re-runs still plot every epoch. An image whose cutouts all exist
    already is not opened at all.

    Parameters
    ----------
    sources_df : pandas.DataFrame
        Source catalogue; column names are set by the ``*_col`` arguments.
    image_data : pandas.DataFrame
        CASDA query rows with filename, t_min, t_exptime, s_ra and s_dec.
    image_dir, cutout_dir : str
        Input image directory and output cutout directory.
    cutout_size : astropy Quantity (angle)
        Side length of the square cutout.
    apply_proper_motion : bool
        Propagate positions with the pmra/pmdec/epoch columns.
    field_radius : astropy Quantity (angle)
        Maximum source-to-image-centre separation (default 6 deg, as before).

    Returns
    -------
    pandas.DataFrame
        One row per cutout with columns source_name, image_filename,
        cutout_filename, obs_date, obs_mjd and integration_time.
    """
    from astropy.wcs.utils import proj_plane_pixel_scales

    create_directory(cutout_dir)

    if image_dir[-1] != '/': image_dir += '/'
    if cutout_dir[-1] != '/': cutout_dir += '/'

    cutout_records = []
    if len(sources_df) == 0:
        return pd.DataFrame(cutout_records)

    # Build the catalogue coordinates once instead of per source per image.
    source_coords = SkyCoord(
        np.asarray(sources_df[ra_col], dtype=float) * u.deg,
        np.asarray(sources_df[dec_col], dtype=float) * u.deg,
    )
    source_rows = [src for _, src in sources_df.iterrows()]

    for _, img_row in image_data.iterrows():
        filename = img_row['filename']
        filepath = f"{image_dir}{filename}"

        if not os.path.isfile(filepath):
            continue

        print(f"\nProcessing: {filename}")

        t_min_val = img_row['t_min']
        if not np.isfinite(t_min_val):
            print(f"  WARNING: Skipping image {filename} due to invalid t_min (MJD).")
            continue

        img_epoch = Time(t_min_val, format='mjd')

        # Sources in the field of this image.
        img_coord = SkyCoord(img_row['s_ra'] * u.deg, img_row['s_dec'] * u.deg)
        in_field = np.atleast_1d(source_coords.separation(img_coord) < field_radius)
        sources_in_field = [src for src, keep in zip(source_rows, in_field) if keep]

        if not sources_in_field:
            continue

        # Split into cutouts that already exist (just recorded) and ones to create.
        pending = []
        for src in sources_in_field:
            src_name = src[name_col]
            clean_name = str(src_name).replace(' ', '').replace('(', '').replace(')', '')
            cutout_filename = f"{cutout_dir}{clean_name}_{filename}"
            record = {
                'source_name': src_name,
                'image_filename': filename,
                'cutout_filename': cutout_filename,
                'obs_date': img_epoch.isot,
                'obs_mjd': t_min_val,
                'integration_time': img_row['t_exptime']
            }
            if os.path.isfile(cutout_filename):
                print(f"    Reusing existing cutout: {clean_name}")
                cutout_records.append(record)
            else:
                pending.append((src, src_name, clean_name, cutout_filename, record))

        if not pending:
            continue  # nothing new to cut from this image; skip opening it

        try:
            with fits.open(filepath) as hdul:
                hdu = hdul[0]
                original_header = hdu.header

                data, header = simplify_fits_header(hdu)
                wcs = WCS(header)

                # Works for both CDELT and CD-matrix headers (a direct
                # header['CDELT1'] lookup fails on the latter).
                pixel_scale = np.abs(proj_plane_pixel_scales(wcs)[0]) * u.deg
                cutout_pixels = max(
                    1, int((cutout_size / pixel_scale).to(u.dimensionless_unscaled).value)
                )

                for src, src_name, clean_name, cutout_filename, record in pending:
                    src_coord = _source_coord_at_epoch(
                        src, img_epoch, ra_col, dec_col, pmra_col, pmdec_col,
                        epoch_col, apply_proper_motion
                    )

                    try:
                        # Cut the 2-D plane using the matching 2-D WCS.
                        cutout = Cutout2D(data, src_coord, size=cutout_pixels, wcs=wcs)

                        if not np.any(~np.isnan(cutout.data)):
                            print(f"    Skipping empty cutout for {src_name} (All NaNs)")
                            continue

                        cutout_hdu = fits.PrimaryHDU(data=cutout.data)
                        cutout_hdu.header.update(cutout.wcs.to_header())

                        for key in ['BMAJ', 'BMIN', 'BPA', 'BUNIT', 'TELESCOP', 'INSTRUME']:
                            if key in original_header:
                                cutout_hdu.header[key] = original_header[key]

                        # Keys read back by reconstruct_cutout_metadata / plotting.
                        cutout_hdu.header['OBJECT'] = src_name
                        cutout_hdu.header['OBSDATE'] = img_epoch.isot
                        cutout_hdu.header['MJD-OBS'] = t_min_val
                        cutout_hdu.header['INT_TIME'] = img_row['t_exptime']

                        cutout_hdu.writeto(cutout_filename, overwrite=True)

                        cutout_records.append(record)
                        print(f"    Created cutout: {clean_name}")

                    except Exception as e:
                        print(f"    Failed cutout for {src_name}: {e}")

        except Exception as e:
            print(f"  Error processing image: {e}")

    return pd.DataFrame(cutout_records)


# ==============================================================================
# 3. PLOTTING & WORKFLOW
# ==============================================================================

def _draw_cutout_panel(fig, ax, data, header, row, target_coord):
    """
    Render one cutout onto a WCSAxes panel.

    Draws the image, a colour bar, RA/Dec labels, an optional target marker,
    the beam ellipse with a beam-sized box at the centre, and a date | SBID title.
    """
    # Beam (degrees) and pixel scale from the cutout header.
    bmaj = header.get('BMAJ', 0.0)
    bmin = header.get('BMIN', 0.0)
    bpa = header.get('BPA', 0.0)
    pix_scale = np.abs(header.get('CDELT2', header.get('CD2_2', 0.0)))

    # 1st-99th percentile stretch over finite pixels only; NaN/inf would
    # otherwise leak into vmin/vmax.
    finite = data[np.isfinite(data)]
    vmin, vmax = np.percentile(finite, [1, 99])

    im = ax.imshow(data, cmap='magma', origin='lower', vmin=vmin, vmax=vmax)
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label(header.get('BUNIT', 'Jy/beam'), fontsize=8)
    cbar.ax.tick_params(labelsize=7)

    # RA/Dec axes (requires the WCSAxes projection set up by the caller).
    ax.coords[0].set_axislabel('RA (J2000)', fontsize=8)
    ax.coords[1].set_axislabel('Dec (J2000)', fontsize=8)
    ax.coords[0].set_ticklabel(size=7)
    ax.coords[1].set_ticklabel(size=7)
    ax.coords[0].set_major_formatter('hh:mm:ss.s')
    ax.coords[1].set_major_formatter('dd:mm:ss')

    # Match the axis limits to the data extent (removes blank margins).
    ax.set_xlim(-0.5, data.shape[1] - 0.5)
    ax.set_ylim(-0.5, data.shape[0] - 0.5)

    if target_coord is not None:
        ax.plot_coord(target_coord, 'w+', markersize=12, markeredgewidth=2, markeredgecolor='gray',
                      label='Target', zorder=15)

    if bmaj > 0 and pix_scale > 0:
        major_pix = bmaj / pix_scale
        minor_pix = bmin / pix_scale

        # Beam drawn 10% in from the bottom-left corner.
        beam_x = data.shape[1] * 0.1
        beam_y = data.shape[0] * 0.1

        # BPA is measured from north through east. Before rotation the major
        # axis (height) points along +y (north), and matplotlib rotates
        # counter-clockwise, so angle=BPA swings it toward east (left). This
        # assumes the pixel grid is roughly north-up/east-left. The former
        # 90 + BPA added an extra quarter turn.
        beam_patch = patches.Ellipse(
            (beam_x, beam_y),
            width=minor_pix,
            height=major_pix,
            angle=bpa,
            edgecolor='cyan',
            facecolor='none',
            hatch='///',
            linewidth=1.5,
            zorder=10
        )
        ax.add_patch(beam_patch)

        # Beam-sized box around the cutout centre.
        center_x = data.shape[1] // 2
        center_y = data.shape[0] // 2
        rect = patches.Rectangle((center_x - major_pix / 2, center_y - major_pix / 2),
                                 major_pix, major_pix,
                                 linewidth=1, edgecolor='white', facecolor='none')
        ax.add_patch(rect)

    # Title: date | SBID, then integration time in minutes.
    obs_date = str(row['obs_date']).split('T')[0]
    int_time = round(row['integration_time'] / 60, 1)
    sbid_match = re.search(r'(SB\d+)', row['image_filename'])
    sbid_str = sbid_match.group(1) if sbid_match else ''
    ax.set_title(f"{obs_date} | {sbid_str}\n{int_time} min", fontsize=10)


def plot_cutout_grid(source_name, cutout_info, cutout_dir,
                     save_path=None, close_image=False,
                     target_ra=None, target_dec=None):
    """
    Plot all cutouts of one source in a 3-column grid and save it as a PNG.

    Each panel is a WCSAxes built from that cutout's own header, so RA/Dec
    ticks and ``plot_coord`` work. Panels are ordered by ``obs_mjd``. The
    figure is saved as ``<save_path><clean_name>_cutouts.png``.

    Parameters
    ----------
    source_name : str
        Matched against ``cutout_info['source_name']``; also the figure title.
    cutout_info : pandas.DataFrame
        Needs source_name, cutout_filename, image_filename, obs_date, obs_mjd
        and integration_time. It is assumed to hold only valid records
        (empty/corrupt files filtered out beforehand).
    cutout_dir : str
        Default output directory when ``save_path`` is None.
    save_path : str, optional
        Output directory for the PNG.
    close_image : bool
        Close the figure after saving.
    target_ra, target_dec : float, optional
        Degrees. If both are given, a crosshair is drawn at this position.
    """
    if cutout_dir[-1] != '/': cutout_dir += '/'
    if save_path is None: save_path = cutout_dir
    elif save_path[-1] != '/': save_path += '/'

    src_cutouts = cutout_info[cutout_info['source_name'] == source_name].sort_values(by='obs_mjd')
    n_cutouts = len(src_cutouts)

    if n_cutouts == 0:
        print(f"No valid cutouts found for {source_name} after filtering.")
        return

    n_cols = 3
    n_rows = (n_cutouts + n_cols - 1) // n_cols
    fig = plt.figure(figsize=(12, 3.8 * n_rows))  # height scales with rows

    target_coord = None
    if target_ra is not None and target_dec is not None:
        target_coord = SkyCoord(target_ra * u.deg, target_dec * u.deg)

    # Only the needed panels are created, so no empty trailing axes remain.
    for panel, (_, row) in enumerate(src_cutouts.iterrows(), start=1):
        ax = None
        try:
            with fits.open(row['cutout_filename']) as hdul:
                header = hdul[0].header.copy()
                # Copy pixels out so they stay valid after the file is closed.
                data = np.array(hdul[0].data, dtype=float)

            ax = fig.add_subplot(n_rows, n_cols, panel, projection=WCS(header))
            _draw_cutout_panel(fig, ax, data, header, row, target_coord)

        except Exception as e:
            # Should only happen for a genuinely corrupt or unusable file.
            if ax is None:
                ax = fig.add_subplot(n_rows, n_cols, panel)
            ax.text(0.5, 0.5, 'Error', ha='center', va='center', transform=ax.transAxes)
            print(f"Plot Error for {row['image_filename']}: {e}")

    fig.suptitle(source_name, fontsize=14, y=0.99)
    fig.tight_layout()

    clean_name = source_name.replace(' ', '').replace('(', '').replace(')', '')
    save_file = f"{save_path}{clean_name}_cutouts.png"
    fig.savefig(save_file, dpi=150, bbox_inches='tight')
    print(f"Saved figure: {save_file}")

    if close_image:
        plt.close(fig)

def process_single_source(name, ra, dec, 
                          image_dir, cutout_dir,
                          casda_obj=None,
                          pmra=0.0, pmdec=0.0, epoch='J2000',
                          stokes='I',
                          search_radius=10*u.deg,
                          cutout_size=4.0*u.arcmin,
                          parallel=True, max_workers=4,
                          plot_only=False,
                          mark_target=False,
                          close_image=False):
    """
    Run the complete cutout workflow for a single source.

    Full mode: query CASDA -> download images -> make cutouts -> filter ->
    plot. Plot-only mode (``plot_only=True``): rebuild metadata from cutouts
    already on disk -> filter -> plot; no CASDA access is needed.

    Parameters
    ----------
    name : str
        Source name; it also determines cutout file names.
    ra, dec : float
        Position in degrees.
    image_dir, cutout_dir : str
        Image download directory and cutout directory.
    casda_obj : astroquery Casda instance, optional
        Required (and must be logged in) unless ``plot_only`` is True.
    pmra, pmdec : float
        Proper motion (mas/yr). Propagation is enabled only if either is non-zero.
    epoch : str
        Catalogue epoch for the proper motion (parsed by astropy Time).
    stokes, search_radius, cutout_size, parallel, max_workers
        Forwarded to the query/download/cutout steps.
    plot_only : bool
        Skip the query/download/cutout steps.
    mark_target : bool
        Draw a crosshair at (ra, dec) on every panel (default off, as before).
    close_image : bool
        Close the figure after it is saved.

    Returns
    -------
    pandas.DataFrame or None
        The valid cutout records, or None if nothing could be plotted.

    Raises
    ------
    ValueError
        If ``casda_obj`` is missing while ``plot_only`` is False.
    """
    # Validate up front so a misconfigured call fails before doing anything.
    if not plot_only and casda_obj is None:
        raise ValueError("casda_obj is required when plot_only=False")

    total_steps = 3 if plot_only else 5
    step = 0

    def announce(message, leading_newline=True):
        """Print a consistently numbered 'Step k/N' progress line."""
        nonlocal step
        step += 1
        prefix = "\n" if leading_newline else ""
        print(f"{prefix}Step {step}/{total_steps}: {message}")

    print(f"\n{'='*60}")
    print(f"Processing: {name}")
    if plot_only:
        print("(PLOT ONLY MODE - Skipping CASDA Query & Download)")
    else:
        print(f"Coordinates: RA={ra:.4f}°, Dec={dec:.4f}°")
    print(f"{'='*60}\n")

    if plot_only:
        announce("Reconstructing metadata from existing files...", leading_newline=False)
        cutout_info = reconstruct_cutout_metadata(name, cutout_dir)
    else:
        sources_df = pd.DataFrame({
            'name': [name], 'ra': [ra], 'dec': [dec],
            'pmra': [pmra], 'pmdec': [pmdec], 'epoch': [epoch]
        })

        announce("Querying CASDA for images...", leading_newline=False)
        image_data = query_casda_images(ra, dec,
                                        search_radius=search_radius,
                                        stokes=stokes)

        if len(image_data) == 0:
            print(f"No images found for {name}")
            return None

        announce("Downloading images...")
        download_casda_images(image_data, image_dir, casda_obj,
                              parallel=parallel, max_workers=max_workers)

        announce("Creating cutouts...")
        cutout_info = make_cutouts(sources_df, image_data,
                                   image_dir, cutout_dir,
                                   cutout_size=cutout_size,
                                   apply_proper_motion=(pmra != 0.0 or pmdec != 0.0))

    if cutout_info is None or len(cutout_info) == 0:
        print(f"No cutouts available for {name}")
        return None

    # Drop files that were written earlier but contain no usable data.
    announce("Filtering out empty cutouts...")
    cutout_info = pre_filter_valid_cutouts(cutout_info)

    if len(cutout_info) == 0:
        print(f"No valid cutouts remaining for {name} after filtering.")
        return None

    announce("Creating visualization...")
    plot_cutout_grid(name, cutout_info, cutout_dir,
                     close_image=close_image,
                     target_ra=ra if mark_target else None,
                     target_dec=dec if mark_target else None)

    print(f"\n{'='*60}")
    print(f"✓ Complete! Displaying {len(cutout_info)} valid cutouts for {name}")
    print(f"{'='*60}\n")

    return cutout_info

In [None]:
# from casda_cutout_workflow import process_single_source, Casda, u

# E-mail address of your OPAL account (used for the CASDA login).
OPAL_USERNAME = '<INSERT YOUR OPAL EMAIL ADDRESS HERE>'

casda = Casda()
casda.login(username=OPAL_USERNAME)

# --- Full Run (Query, Download, Cutout, Plot) ---
# Images and cutouts that already exist on disk are NOT re-downloaded or
# overwritten. Delete the cutouts directory to force cutouts to be regenerated
# (e.g. after changing header handling).

start = datetime.now()
run_dt_stamp = start.strftime("%Y%m%d_%H%M")
name = "SN 2012dy"
ra, dec = 319.71125, -57.64514
cutout_info = process_single_source(
    name=name,
    ra=ra,
    dec=dec,
    image_dir=f'./CASDA_Cutouts/{name}/images/',
    cutout_dir=f'./CASDA_Cutouts/{name}/cutouts/',
    casda_obj=casda,
    plot_only=False,  # <--- MUST BE FALSE TO QUERY CASDA AND CREATE MISSING CUTOUTS
    cutout_size=1.0*u.arcmin,
    search_radius=3.0*u.deg,
    max_workers=8
)
end = datetime.now()
elapsed = end - start
print(f"{name} CASDA executed in seconds: {elapsed.total_seconds()}")

In [None]:

# --- Plot Only Run (if files already exist) ---
# Regenerates the figure from cutouts already on disk; no CASDA login,
# query or download is needed.
name = "SN 2012dy"
ra, dec = 319.71125, -57.64514
source_root = f'./CASDA_Cutouts/{name}'
cutout_info = process_single_source(
    name=name,
    ra=ra,
    dec=dec,
    image_dir=f'{source_root}/images/',
    cutout_dir=f'{source_root}/cutouts/',
    plot_only=True,
)