# Initialization

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
from rich import print
import pickle
import plotly.graph_objects as go
from IPython.display import clear_output
import astropy.units as u
from astroquery.ipac.irsa import Irsa
import pandas as pd
from sklearn.cluster import DBSCAN
import os
import sys

ROOT = os.path.join("./")
sys.path.append(ROOT + "lib")

from lightsource import LightSource
from sourceset import SourceSet, default_aug
from helpers import get_coordinates


# --- Numerics ---
# Newly created floating-point torch tensors default to float64.
torch.set_default_dtype(torch.float64)

# --- Run toggles ---
QUERY = False      # Re-query the catalog; only needed after the object spreadsheets (CSVs) have changed
RECLUSTER = False  # Re-run clustering and overwrite the cached raw classes

# --- Catalog query settings ---
neowise = "neowiser_p1bs_psd"  # catalog identifier handed to Irsa.query_region
min_obj_qual = 3               # spreadsheet rows whose "qual" is below this are not queried
min_qual_frame = 4             # detections whose "qual_frame" is below this are dropped

# --- Test hold-out ---
test_cutoff = 0.9  # scales each class size when deciding where the test hold-out begins

# Spreadsheets to Data Tables


In [2]:
# Build `raw_classes`: {class name -> {object name -> LightSource}}.
# With QUERY set, every object listed in the class spreadsheets is re-queried
# from IRSA; otherwise the previously pickled result is loaded from cached_data/.
# NOTE(review): this cell never writes the cache itself; the pickle is only
# refreshed by the RECLUSTER cell — confirm QUERY is always run with RECLUSTER.
if QUERY:
  raw_classes = {
      "null": {},
      "nova": {},
      "pulsating_var": {},
      "transit": {},
  }

  for roots, dirs, files in os.walk(ROOT + "object_spreadsheets/"):
    for spreadsheet in files:
      # The file stem names the class; anything that is not a CSV (OS/editor
      # artifacts, notes, ...) is skipped instead of failing the class lookup.
      kind, ext = os.path.splitext(spreadsheet)
      if ext.lower() != ".csv":
        continue
      bucket = raw_classes[kind]  # KeyError here means an unexpected class spreadsheet

      # Join with the directory os.walk is currently visiting so spreadsheets
      # inside subdirectories are read from where they actually live.
      df = pd.read_csv(os.path.join(roots, spreadsheet))
      df = df.dropna(axis=0, how="any")  # rows with any missing field are unusable

      total_rows = len(df)
      num_selected = len(df.loc[df["qual"] >= min_obj_qual])  # loop-invariant, computed once

      rows = zip(df["Name"], df["RAJ2000"], df["DecJ2000"], df["query_rad"], df["qual"])
      for i, (objname, ra, dec, rad, qual) in enumerate(rows, start=1):
        radius = float(rad)
        if int(qual) < min_obj_qual:
          continue  # low-quality object: not queried

        print("Querying {}...".format(objname))

        coordstring = "{} {}".format(ra, dec)
        c = get_coordinates(coordstring) # Safer

        # Cone search around the object; `radius` is interpreted as arcseconds.
        tbl = Irsa.query_region(c, catalog=neowise, spatial="Cone",
                            radius=radius * u.arcsec)

        tbl = tbl.to_pandas()
        tbl = tbl.loc[tbl["qual_frame"] >= min_qual_frame]  # drop low-quality frames

        bucket[objname] = LightSource(tbl)

        # Replace the previous progress output with the current percentage.
        clear_output(wait=True)
        print("{}%".format(i * 100 // total_rows), end="\r")
        print(kind + ": ", num_selected)

else:
  # NOTE: unpickling can execute arbitrary code; only load caches produced by
  # this notebook. The context manager closes the file handle after loading.
  with open(ROOT + "cached_data/raw_classes.pkl", "rb") as f:
    raw_classes = pickle.load(f)

  #Print sources and their number of detections
  # for kind in raw_classes:
  #   for objname in raw_classes[kind]:
  #       obj = raw_classes[kind][objname]
  #       l = len(obj.get_numpy())
  #       print("{} source: {} - {}dets".format(kind, objname, l))
  


In [3]:
# For every source, cluster its detections with DBSCAN, keep only the largest
# cluster, and write the cleaned classes back to the cache.
if RECLUSTER:
  # DBSCAN neighbourhood radius per class.
  # NOTE(review): presumably degrees (i.e. 0.75 and 1 arcsec) — confirm against
  # the units of the first two columns returned by LightSource.get_numpy().
  eps = {
      "null": 0.75/3600,
      "nova": 0.75/3600,
      "transit": 1/3600,
      "pulsating_var": 1/3600
  }

  # DBSCAN min_samples per class.
  min_pts = {
      "null": 15,
      "nova": 8,
      "transit": 12,
      "pulsating_var": 12
  }

  for kind in raw_classes:
    for obj in raw_classes[kind]:
      np_tbl = raw_classes[kind][obj].get_numpy()
      df = raw_classes[kind][obj].get_pandas()
      # assumes get_numpy() and get_pandas() return the same rows in the same
      # order, since labels from one are used to mask the other — TODO confirm

      cluster_tbl = np_tbl[:, :2]  # cluster on the first two columns only
      clstr = DBSCAN(eps=eps[kind], min_samples=min_pts[kind]).fit(cluster_tbl) # high minsamples cutting out some very sparse examples

      labels = clstr.labels_
      cluster_sizes = np.bincount(labels[labels != -1])  # -1 marks noise points
      if cluster_sizes.size == 0:
        # Every detection was labelled noise: there is no cluster to keep, and
        # argmax on an empty array would raise. Leave the source untouched.
        print("Warning: no cluster found for {} ({} pts); leaving it unfiltered".format(obj, len(np_tbl)))
        continue

      biggest_cluster = np.argmax(cluster_sizes)      # label of the largest cluster
      biggest_size = cluster_sizes[biggest_cluster]   # number of points in it
      num_removed = len(np_tbl) - biggest_size
      # Compare against the cluster's size (not its label, which may be 0).
      if num_removed / biggest_size >= 0.2:
        print("Warning: {}'s biggest cluster is sparse: {} to {} pts after clustering".format(obj, len(np_tbl), biggest_size))

      filter_mask = labels == biggest_cluster
      df = df[filter_mask]
      raw_classes[kind][obj] = LightSource(df)

  with open(ROOT + "cached_data/raw_classes.pkl", "wb") as f:
    pickle.dump(raw_classes, f)


In [4]:
# Hold out the tail of every class as a test set; everything before the cutoff
# stays in `buckets`. Both partitions are pickled under cached_data/.
class_names = ("null", "nova", "pulsating_var", "transit")
buckets = {name: [] for name in class_names}
buckets_test = {name: [] for name in class_names}

for kind in class_names:
  sources = raw_classes[kind]
  train_limit = test_cutoff * len(sources) - 1  # positions below this stay out of the test set
  for position, objname in enumerate(sources):
    destination = buckets if position < train_limit else buckets_test
    destination[kind].append(sources[objname])

held_out = {name: len(buckets_test[name]) for name in class_names}
print("Examples removed for testing:   Null: ", held_out["null"], " Nova: ", held_out["nova"], " Pulsating Var: ", held_out["pulsating_var"], " Transit: ", held_out["transit"])

# Persist both partitions.
for path, payload in (
    (ROOT + "cached_data/train_buckets.pkl", buckets),
    (ROOT + "cached_data/test_buckets.pkl", buckets_test),
):
  with open(path, "wb") as f:
    pickle.dump(payload, f)

In [5]:
amt_train = 0.7  # fraction of each class's non-test examples that goes to training

# Per-class example counts, in a fixed class order, and each class's share of the total.
class_order = ("null", "nova", "pulsating_var", "transit")
num_of_examples = np.array([len(buckets[label]) for label in class_order])
class_weights = num_of_examples / num_of_examples.sum()

print("Total examples of each class:", num_of_examples)
print("Weights:", class_weights)

class_weights_ = torch.tensor(class_weights)

# Split every class: the leading `amt_train` fraction trains, the remainder validates.
buckets_train = {}
buckets_valid = {}
for name, examples in buckets.items():
    split_at = int(len(examples) * amt_train)
    buckets_train[name], buckets_valid[name] = examples[:split_at], examples[split_at:]

# Dataset and Augmentation

In [6]:
# Training set is built with `default_aug` and equalize=True; the validation
# set uses SourceSet's defaults.
train = SourceSet(buckets_train, default_aug, equalize=True)
valid = SourceSet(buckets_valid)

# Serialize both datasets under processed_datasets/.
for out_path, dataset in (
    (ROOT + "processed_datasets/data_train.pt", train),
    (ROOT + "processed_datasets/data_valid.pt", valid),
):
  with open(out_path, "wb") as f:
    torch.save(dataset, f)

len(train), len(valid)

Length =  260
Length =  265
Length =  263
Length =  249
Length =  272
Length =  276
Length =  226
Length =  217
Length =  158
Length =  442
Length =  447
Length =  404
Length =  316
Length =  241
Length =  246
Length =  251
Length =  1853
newlen  999
Length =  762
Length =  31525
newlen  999
Length =  31700
newlen  999
Length =  31582
newlen  999
Length =  31908
newlen  999
Length =  31917
newlen  999
Length =  32011
newlen  999
Length =  31839
newlen  999
Length =  107
Length =  71
Length =  207
Length =  34
Length =  39
Length =  159
Length =  141
Length =  392
Length =  28
Length =  165
Length =  20
Length =  49
Length =  38
Length =  146
Length =  84
Length =  46
Length =  148
Length =  177
Length =  31
Length =  33
Length =  31
Length =  494
Length =  335
Length =  298
Length =  238
Length =  246
Length =  228
Length =  228
Length =  248
Length =  217
Length =  264
Length =  250
Length =  271
Length =  268
Length =  307
Length =  286
Length =  3932
newlen  999
Length =  318
Length

(480, 48)

In [7]:
# Walk the training set and stop at the first sample whose leading element
# contains a NaN; any() short-circuits, so later samples are not visited.
has_nan = any(torch.isnan(sample[0]).any() for sample in train)

print(
    "The dataset contains NaN values."
    if has_nan
    else "The dataset does not contain NaN values."
)