# Initialization

In [1]:
import common
import numpy as np
import torch
from rich import print
import pickle
import plotly.graph_objects as go
from IPython.display import display, clear_output
import random
from torch.utils.data import Dataset
import astropy.units as u
from astroquery.ipac.irsa import Irsa
import pandas as pd
import scipy as sp
import sklearn
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
import time
import os
import sys

# Project root (relative to the notebook's working directory). The local helper
# modules imported below live in <ROOT>/lib, so that folder goes on the import path.
ROOT = os.path.join("./")
sys.path.append(os.path.join(ROOT, "lib"))

from plotly_style import update_layout
from flux_table import get_flux, LightSource
from fluxtable_to_tensor import fluxtable_to_tensor


# Make float64 the default dtype for tensors created without an explicit dtype.
torch.set_default_dtype(torch.float64)

# Catalog name passed to Irsa.query_region when fetching each object's table.
neowise = 'neowiser_p1bs_psd'
# Rows of a queried table are kept only when their "qual_frame" is >= this value.
min_qual_frame = 4
# Spreadsheet objects whose "qual" is below this value are skipped (never queried).
min_obj_qual = 3

# Fraction used when splitting each class's objects into train vs. held-out test buckets.
test_cutoff = 0.9

# Toggle querying. When False, the cached train buckets are loaded from disk instead.
QUERY = True # Toggle querying - only do this if you've changed the data CSVs

# Spreadsheets to Data Tables


In [2]:
if QUERY:
  # raw_classes[kind][objname] holds everything known about one object:
  #   "df"      - the quality-filtered query result as a DataFrame
  #   "np"      - the same table as a numpy array
  #   "cluster" - filled in by the clustering cell
  #   "flux"    - filled in by the flux cell
  raw_classes = {
      "null": {},
      "nova": {},
      "pulsating_var": {},
      "transit": {},
  }

  # Files found under object_spreadsheets/ that are not "<known class>.csv".
  skipped_files = []

  for roots, dirs, files in os.walk(ROOT + "object_spreadsheets/"):
    for spreadsheet in files:
      # The file name (without extension) is the class name.
      kind, ext = os.path.splitext(spreadsheet)
      if ext.lower() != ".csv" or kind not in raw_classes:
        # e.g. editor/OS droppings or checkpoint files; previously these raised
        # a KeyError (or were parsed as a spreadsheet) and aborted the whole run.
        skipped_files.append(os.path.join(roots, spreadsheet))
        continue
      bucket = raw_classes[kind]

      # Join with the directory os.walk is currently in, so files in
      # sub-folders are read from where they actually are.
      df = pd.read_csv(os.path.join(roots, spreadsheet))
      df.dropna(axis=0, how="any", inplace=True)

      rows = len(df)

      i = 0
      for objname, ra, dec, rad, qual in zip(df["Name"], df["RAJ2000"], df["DecJ2000"], df["query_rad"], df["qual"]):
        i += 1
        radius = float(rad)
        qual = int(qual)
        if qual < min_obj_qual:
          continue

        # Cone search around the object's coordinates; radius is given in arcseconds.
        tbl = Irsa.query_region("{} {}".format(ra, dec), catalog=neowise, spatial="Cone",
                            radius=radius * u.arcsec)

        tbl = tbl.to_pandas()
        tbl = tbl.loc[tbl["qual_frame"]>= min_qual_frame]

        bucket[objname] = {}
        bucket[objname]["df"] = tbl
        bucket[objname]["np"] = tbl.to_numpy()
        bucket[objname]["cluster"] = {}
        bucket[objname]["flux"] = {}

        # Progress: percentage of spreadsheet rows visited, then the number of
        # objects in this spreadsheet that pass the quality cut.
        print("{}%".format(i*100 // rows), end="\r")
        print(kind + ": ", len(df.loc[df["qual"] >= min_obj_qual]))
        clear_output(wait=True)

  if skipped_files:
    print("Skipped files that are not a known class spreadsheet:", skipped_files)



In [3]:
plotting = False

if QUERY:
  # Per-class DBSCAN neighbourhood radius. The values are written as x/3600,
  # i.e. arcseconds expressed in degrees - assumes the clustered columns are in
  # degrees; TODO confirm.
  eps = {
      "null": 1.25/3600,
      "nova": 2/3600,
      "transit": 1.5/3600,
      "pulsating_var": 1.5/3600
  }

  # Per-class DBSCAN min_samples. High values cut out some very sparse examples.
  min_pts = {
      "null": 7,
      "nova": 4,
      "transit": 5,
      "pulsating_var": 5
  }

  # Clusters with fewer rows than this are discarded.
  min_cluster_pts = 20

  for kind in raw_classes:
    for obj in raw_classes[kind]:
      record = raw_classes[kind][obj]
      np_tbl = record["np"]

      # A query that returned no rows (or lost them all to the quality cut)
      # cannot be clustered: DBSCAN.fit raises on zero samples.
      if len(np_tbl) == 0:
        record["cluster"] = {}
        continue

      # Cluster on the first two table columns only
      # (presumably ra/dec - verify against the catalog's column order).
      cluster_tbl = np_tbl[:, :2]

      clstr = DBSCAN(eps=eps[kind], min_samples=min_pts[kind]).fit(cluster_tbl)

      labels = clstr.labels_

      # Group full table rows by cluster label; label -1 is DBSCAN noise and is dropped.
      rows_by_label = {}
      for row, label in zip(np_tbl, labels):
        if label != -1:
          rows_by_label.setdefault(label, []).append(row)

      # TODO: Check that non-region query tables are only returning ONE CLUSTER.
      # Keep only clusters with enough points, each stored as a 2-D array of rows.
      record["cluster"] = {
          label: np.array(rows)
          for label, rows in rows_by_label.items()
          if len(rows) >= min_cluster_pts
      }


In [21]:
if QUERY:
  for kind in raw_classes:
    for objname, record in raw_classes[kind].items():
      columns = record["df"].columns.tolist()

      # One flux table per surviving cluster, keyed by the cluster label.
      for cluster_label, cluster_rows in record["cluster"].items():
        record["flux"][cluster_label] = get_flux(cluster_rows, columns)

      # Replace the literal string 'null' with 0 in this object's table.
      # This used to be a single statement after the loops acting on `tbl`, the
      # variable leaked from the query cell: it only ever touched the last table
      # queried and raised NameError when nothing had been queried.
      # NOTE(review): applying it to every table is the assumed intent - confirm.
      obj_df = record["df"]
      obj_df[obj_df == 'null'] = 0


In [5]:
# Per-class training examples (flux tables) and per-class held-out objects (raw DataFrames).
buckets = {"null": [], "nova": [], "pulsating_var": [], "transit": []}
buckets_test = {"null": [], "nova": [], "pulsating_var": [], "transit": []}

if not QUERY:
  # NOTE: pickle.load executes arbitrary code from the file; these are caches
  # written by this notebook, do not point this at untrusted files.
  with open(ROOT + "cached_data/train_buckets.pkl", "rb") as f: # Load cached buckets
    buckets = pickle.load(f)

  # Also restore the held-out objects, which were previously left empty in this branch.
  test_cache_path = ROOT + "cached_data/test_buckets.pkl"
  if os.path.exists(test_cache_path):
    with open(test_cache_path, "rb") as f:
      buckets_test = pickle.load(f)
else:
  for kind in buckets_test:
    n_objects = len(raw_classes[kind])
    # Objects with an index below this limit go to training, the rest are held out.
    train_limit = min(test_cutoff * n_objects - 1, n_objects - 1)

    # The index counts OBJECTS. It used to be incremented once per flux table,
    # which mis-sized the split whenever an object had zero or several clusters,
    # even though the limit is expressed in objects.
    for i, objname in enumerate(raw_classes[kind]):
      record = raw_classes[kind][objname]
      if i < train_limit:
        buckets[kind].extend(record["flux"].values()) # Populate buckets
      else:
        buckets_test[kind].append(record["df"])

  print("Examples removed for testing:   Null: ", len(buckets_test["null"]), " Nova: ", len(buckets_test["nova"]), " Pulsating Var: ", len(buckets_test["pulsating_var"]), " Transit: ", len(buckets_test["transit"]))
  with open(ROOT + "cached_data/train_buckets.pkl", "wb") as f: # Save buckets
    pickle.dump(buckets, f)

  with open(ROOT + "cached_data/test_buckets.pkl", "wb") as f: # Save buckets
    pickle.dump(buckets_test, f)

In [14]:
# Fraction of each class's examples that goes to training; the rest is validation.
amt_train = 0.7

# Start from the raw example count of each class, in label order.
class_weights = np.array([len(buckets[kind]) for kind in ("null", "nova", "pulsating_var", "transit")])

print("Total examples of each class:", class_weights)

# Weight of a class = share of examples that belong to the OTHER classes,
# divided by (number of classes - 1) so the weights add up to 1.
total = np.sum(class_weights)
class_weights = (total - class_weights) / ((class_weights.size - 1) * total)
print("Weights:", class_weights)

class_weights_ = torch.tensor(class_weights)


# Ordered split: the first amt_train of every class is training data.
buckets_train = {}
buckets_valid = {}

for name, data_dict_list in buckets.items():
    total_examples = len(data_dict_list)
    train_end = int(total_examples * amt_train)

    buckets_train[name], buckets_valid[name] = data_dict_list[:train_end], data_dict_list[train_end:]

# Dataset and Augmentation

In [17]:
from copy import copy

# Value (flux) columns of the tensors produced by FluxSet.to_np, whose column
# order is (w1flux, std, day). Only column 0 is a measured value there; column 2
# is the day axis and must not be flipped, noised or rescaled as if it were flux.
FLUX_VALUE_COLS = (0,)


def swap(pct):
  """Build an augmentation that swaps random pairs of rows.

  Args:
    pct: fraction of the rows involved in swaps; int(pct * n_rows) // 2 random
      pairs are exchanged.

  Returns:
    A function mapping a 2-D tensor to a modified clone of it.
  """
  def inner(data):
    new = torch.clone(data)
    n = int(pct*len(new)) // 2
    for _ in range(n):
      one, two = random.sample(range(len(new)), 2)

      # Clone before overwriting: new[one] is a view into `new`.
      temp = torch.clone(new[one])

      new[one] = new[two]
      new[two] = temp
    return new
  return inner


def flip_x(data):
  """Return `data` with its row order reversed."""
  return torch.flip(data, (0,))

def flip_y(data, value_cols=(0, 2)):
  """Return a clone of `data` with the given value columns negated.

  Args:
    data: 2-D tensor, one row per point.
    value_cols: indices of the columns to negate. The default (0, 2) keeps the
      historical behaviour; pass FLUX_VALUE_COLS for FluxSet tensors.
  """
  new = torch.clone(data)
  for col in value_cols:
    new[:, col] = -1 * new[:, col]
  return new

def resample(std, pct=0.25, value_cols=(0, 2)):
  """Build an augmentation that adds Gaussian noise to a random subset of rows.

  Args:
    std: standard deviation of the noise applied to the selected rows.
    pct: fraction of rows that receive noise (the others are left unchanged).
    value_cols: indices of the columns that are re-drawn. The default (0, 2)
      keeps the historical behaviour; pass FLUX_VALUE_COLS for FluxSet tensors.

  Returns:
    A function mapping a 2-D tensor to a modified clone of it.
  """
  def inner(data):
    new = torch.clone(data)
    n_rows = len(new)

    # Per-row standard deviation: `std` for the selected rows, 0 (no change) elsewhere.
    # Built in the data's own dtype/device so torch.normal never mixes types.
    stddev = torch.zeros(n_rows, dtype=new.dtype, device=new.device)
    picked = random.sample(range(n_rows), int(pct*n_rows))
    if picked:
      stddev[picked] = std

    for col in value_cols:
      new[:, col] = torch.normal(new[:, col], stddev)
    return new

  return inner

def rescale_x(data):
  """Return a clone of `data` whose last column is multiplied by a random factor in [0, 1.5)."""
  s = random.random() *1.5

  new = torch.clone(data)

  new[:, -1] = s * new[:, -1]

  return new

def rescale_y(data, value_cols=(0, 2)):
  """Return a clone of `data` whose value columns are multiplied by one random factor in [0, 1.5).

  Args:
    data: 2-D tensor, one row per point.
    value_cols: indices of the columns to scale. The default (0, 2) keeps the
      historical behaviour; pass FLUX_VALUE_COLS for FluxSet tensors.
  """
  s = random.random() * 1.5

  new = torch.clone(data)

  for col in value_cols:
    new[:, col] = s * new[:, col]

  return new


def _on_flux(fn):
  """Bind a value-column augmentation to the FluxSet column layout."""
  def inner(data):
    return fn(data, value_cols=FLUX_VALUE_COLS)
  return inner


# (function, times) stages handed to FluxSet. The value augmentations are bound
# to FLUX_VALUE_COLS so they no longer touch the day column (index 2).
aug = [
  (rescale_x, 1),
  (_on_flux(rescale_y), 1),
  (resample(0.65, 0.35, value_cols=FLUX_VALUE_COLS), 1),
  (flip_x, 1),
  (swap(0.1), 1),
  (_on_flux(flip_y), 1),
  (swap(0.2), 1),
  (resample(1, 0.25, value_cols=FLUX_VALUE_COLS), 1),
]

class FluxSet(Dataset):
  """Dataset of (timeseries tensor, one-hot label) pairs built from per-class flux dicts.

  Every flux dict is turned into an (n_points, 3) tensor by `to_np`. The label
  index follows the key order of `self.buckets`:
  null, nova, pulsating_var, transit.
  """

  def __init__(self, classes, augmentation=None, choose=3, augmentation_frac=3, equalize=False, class_weights=None):
    """Convert, optionally augment/equalize, and label all examples.

    Args:
      classes: mapping of class name -> list of flux dicts (see `to_np` for the
        keys that are read). Every class name must be a key of `self.buckets`.
      augmentation: optional list of (fn, times) stages. For each new example,
        `choose` stages are drawn with random.sample and applied in sequence.
      choose: number of stages drawn per augmented example.
      augmentation_frac: number of augmented examples created per original example.
      equalize: if True (only honoured when `augmentation` is given), smaller
        classes receive extra augmented examples until they match the largest.
      class_weights: optional per-class weights to store on the dataset. When
        omitted, the module-level `class_weights_` is used, as before.
    """
    self.choose = choose
    self.augmentation_frac = augmentation_frac

    self.buckets = {"null": [], "nova": [], "pulsating_var": [], "transit": []}
    self.class_weights = class_weights_ if class_weights is None else class_weights

    for kind, flux_dicts in classes.items():
      self.buckets[kind].extend(torch.tensor(self.to_np(flux)) for flux in flux_dicts)

    self.all = []

    if augmentation:
      # Originals first, then augmentation_frac augmented copies of each original.
      for kind, originals in self.buckets.items():
        expanded = list(originals)
        for ex in originals:
          for _ in range(self.augmentation_frac):
            expanded.extend(self.apply_pipeline([ex], random.sample(augmentation, choose)))
        self.buckets[kind] = expanded

      # Equalization
      if equalize:
        target = max(len(bucket) for bucket in self.buckets.values())
        for bucket in self.buckets.values():
          deficit = target - len(bucket)
          # The bucket grows while it is being walked, so once every existing
          # example has been used the walk continues over the newly appended
          # ones. An empty bucket has nothing to augment and stays empty.
          for ex in bucket:
            if deficit <= 0:
              break
            bucket.extend(self.apply_pipeline([ex], random.sample(augmentation, choose)))
            deficit -= 1

    # Data, Label pairing
    for class_idx, (kind, bucket) in enumerate(self.buckets.items()):
      for i, ex in enumerate(bucket):
        label = torch.zeros(4)
        label[class_idx] = 1

        pair = (ex, label)
        self.all.append(pair)
        bucket[i] = pair

  def apply_pipeline(self, examples, pipeline):
    """Run `examples` through the (fn, times) stages of `pipeline`, in order.

    Each stage replaces every current example by `times` results of `fn(example)`.
    An empty pipeline returns `examples` unchanged; `pipeline` is not modified.
    """
    current = examples
    for fn, times in pipeline:
      current = [fn(ex) for ex in current for _ in range(times)]
    return current

  def to_np(self, flux): # IMPORTANT! Defines order of data
      """Return an (n_points, 3) array with columns (w1flux, std, day).

      All values are read from flux["norm"]. `std` is the mean of "w1std" and
      "w2std", repeated once per point (the point count is taken from "w1").
      """
      norm = flux["norm"]

      std_val = (norm["w1std"] + norm["w2std"]) / 2
      std = np.full(len(norm["w1"]), std_val)

      # Len(pts) x 3 matrix
      # IMPORTANT! Defines order of data
      return np.stack((norm["w1flux"], std, norm["day"]), axis=0).T

  def __getitem__(self, idx):
    """Return the (tensor, one-hot label) pair at position `idx`."""
    return self.all[idx]

  def __len__(self):
    """Return the total number of examples across all classes."""
    return len(self.all)

In [18]:
# Training set: augmented and class-equalized. Validation set: examples as-is.
train = FluxSet(buckets_train, aug, equalize=True)
valid = FluxSet(buckets_valid)

# Persist both dataset objects, training first.
for out_path, dataset in (
    (ROOT + "processed_datasets/data_train.pt", train),
    (ROOT + "processed_datasets/data_valid.pt", valid),
):
  with open(out_path, "wb") as f:
    torch.save(dataset, f)

len(train)

train[0][0].dtype

torch.float64

In [20]:
# True as soon as one training example holds a NaN; any() stops at the first hit.
has_nan = any(bool(torch.isnan(example).any()) for example, _ in train)

print("The dataset contains NaN values." if has_nan else "The dataset does not contain NaN values.")

# Toy Dataset

In [None]:
def N(u, var):
  """Draw one sample from a normal distribution with mean `u` and variance `var`."""
  sigma = np.sqrt(var)
  return np.random.normal(loc=u, scale=sigma)

def sample(func, var, minx, maxx, pts, sparsity=0.5):
  """Sample `func` on a randomly thinned grid, add noise and z-score the values.

  Args:
    func: function of x to sample.
    var: noise variance, as a fraction of the value range of the sampled points.
    minx, maxx: the grid starts at `minx` and uses a step of (maxx - minx) / pts.
    pts: number of grid positions before thinning.
    sparsity: a grid position is dropped when a uniform draw is <= this value.

  Returns:
    (all_x, all_y): the kept x positions and their noisy, z-scored values.
  """
  rng = np.random.default_rng()
  step = (maxx - minx) / pts

  # One uniform draw per grid position decides whether it is kept.
  all_x = [minx + i*step for i in range(pts) if not (rng.random() <= sparsity)]

  clean_y = [func(x) for x in all_x]

  # Noise scales with the spread of the noiseless values.
  valrange = max(clean_y) - min(clean_y)
  noisy_y = [N(y, var*valrange) for y in clean_y]

  return (all_x, sp.stats.zscore(noisy_y))

def pos_encoding(unsorted):
  """Sort rows by their last column and replace that column with time deltas.

  After sorting in ascending order, the last column of each row becomes the
  difference to the previous row's value; the first row's delta is measured from 0.

  Args:
    unsorted: sequence of equally long rows whose last entry is an absolute time.

  Returns:
    A torch tensor holding the sorted rows with the re-encoded last column.
  """
  rows = np.array(unsorted)
  ordered = rows[np.argsort(rows[:, -1])] # sort ascending

  # change abs time to positional encoding
  ordered[:, -1] = np.diff(ordered[:, -1], prepend=0)
  return torch.tensor(ordered)

In [None]:
def n_examples(fn_gen, var, minx, maxx, pts, sparsity, N):
  """Generate `N` noisy toy timeseries of one function.

  `fn_gen()` is called once, so all N examples are independent noisy samplings
  of the same function.
  NOTE(review): if a new function per example was intended, move the call into the loop.

  Each point becomes a 6-value row:
    [y - jitter, jitter, y + jitter, jitter, x / max(x), x]
  where every jitter is an independent uniform draw scaled by var * 10. The rows
  are then passed through `pos_encoding`.

  Args:
    fn_gen: zero-argument factory returning the function to sample.
    var, minx, maxx, pts, sparsity: forwarded to `sample`.
    N: number of examples to generate.

  Returns:
    List of N tensors, one per example.
  """
  ex = []

  fn = fn_gen()

  for _ in range(N):
    x, y = sample(fn, var, minx, maxx, pts, sparsity)
    x_max = np.max(x)  # loop-invariant, so computed once per example

    ts = [
      np.array([
        ypt - np.random.rand()*var*10,
        np.random.rand()*var*10,
        ypt + np.random.rand()*var*10,
        np.random.rand()*var*10,
        xpt/x_max,
        xpt,
      ])
      for xpt, ypt in zip(x, y)
    ]

    # pos_encoding already returns a tensor; wrapping it in torch.tensor() again
    # only made an extra copy and triggered a copy-construct warning.
    ex.append(pos_encoding(ts))

  return ex

class PseudoSet(Dataset):
  """Dataset of (example, one-hot label) pairs; the label index is the class position."""

  def __init__(self, classes):
    """Pair every example of `classes[i]` with a one-hot label for index i.

    Args:
      classes: sequence of example collections, one collection per class.
    """
    self.classes = classes
    n_classes = len(self.classes)

    self.all = []
    for class_idx, examples in enumerate(self.classes):
      # One label tensor per class, shared by all of that class's examples.
      one_hot = torch.zeros(n_classes)
      one_hot[class_idx] = 1
      self.all.extend((ex, one_hot) for ex in examples)

  def __getitem__(self, idx):
    """Return the (example, label) pair at position `idx`."""
    return self.all[idx]

  def __len__(self):
    """Return the total number of examples."""
    return len(self.all)

In [None]:
# Examples generated per toy class, and the base noise level handed to n_examples.
num_ex = 100
var = 0.005

# The second value used to be assigned to `u`, which overwrote the astropy.units
# alias imported as `u` at the top of the notebook and broke the query cell
# (`radius * u.arcsec`) on any later re-run.
s, toy_u = np.random.rand()*75, np.random.rand()*310
# Gaussian density with standard deviation s and mean u, evaluated at x.
norm = lambda s, u, x: (1 / np.sqrt(s**2 * 2 * np.pi)) * np.exp(-0.5 * ((x - u) / s)**2)

def linefn_gen():
  """Return a line through the origin with a random slope drawn from [0, 100)."""
  slope = np.random.rand()*100

  def line(x):
    return slope * x

  return line

def logfn_gen():
  """Return x -> a * log(b * x), with a drawn from [0, 2) and b from [0, 75)."""
  scale = np.random.rand()*2
  rate = np.random.rand()*75

  def log_curve(x):
    return scale * np.log(rate * x)

  return log_curve

def bellfn_gen():
  """Return a Gaussian bell curve with random width and centre.

  The standard deviation is drawn from [0, 12) and the mean from [-3, 3).
  The normalisation uses np.pi instead of the truncated constant 3.141.
  """
  s = np.random.rand()*12
  u = (np.random.rand() - 0.5) * 6

  def bell(x):
    return (1 / np.sqrt(s**2 * 2 * np.pi)) * np.exp(-0.5 * ((x - u) / s)**2)

  return bell


# One list of num_ex examples per toy class; each class has its own x range and noise level.
lines = n_examples(linefn_gen, var=5*var, minx=0, maxx=1000, pts=267, sparsity=0.65, N=num_ex)
logs = n_examples(logfn_gen, var=var, minx=0.1, maxx=4, pts=267, sparsity=0.65, N=num_ex)
bells = n_examples(bellfn_gen, var=0.01*var, minx=-10, maxx=10, pts=267, sparsity=0.65, N=num_ex)

def displayex(ex):
  """Show a scatter plot of value columns 0 and 2 of `ex` against its second-to-last column."""
  fig = go.Figure()

  for value_col in (0, 2):
    fig.add_trace(go.Scatter(
        x=ex[:, -2],
        y=ex[:, value_col],
        mode='markers',
        marker=dict(size=5, opacity=.75)
    ))

  fig.show()

displayex(lines[15])
displayex(logs[15])
displayex(bells[15])


# Fraction of each toy class used for training (previously defined but the
# literal 0.7 was used below instead).
trainsplit = 0.7

splitidx = int(trainsplit * num_ex)

# Kept under their own names: assigning to `train`/`valid` here replaced the
# real FluxSet datasets built earlier, so later cells that index `train`
# silently received toy data under Restart & Run All.
toy_train = PseudoSet((lines[:splitidx], logs[:splitidx], bells[:splitidx]))
toy_valid = PseudoSet((lines[splitidx:], logs[splitidx:], bells[splitidx:]))

print(toy_train[0][0].shape)

with open(ROOT + "datasets/toy_data_train.pt", "wb") as f:
  torch.save(toy_train, f)
with open(ROOT + "datasets/toy_data_valid.pt", "wb") as f:
  torch.save(toy_valid, f)



# Data Vis

In [None]:
# Example to visualise below: the 8th transit flux table.
ex = buckets['transit'][7] # One Example

# Index of the null example with the fewest raw w1 points.
null_lengths = [len(entry["raw"]["w1"]) for entry in buckets['null']]
l = np.argmin(null_lengths)
print(l)

def get_trace(data_dict, key):
  """Build a marker scatter trace of `data_dict[key]` against `data_dict["day"]`, sorted by day.

  Args:
    data_dict: mapping holding a "day" entry and the entry named by `key`; both
      must support indexing with an index array.
    key: name of the series to plot; also used as the trace name.
  """
  days, values = data_dict["day"], data_dict[key]

  # Plot the points in chronological order.
  order = np.argsort(days)

  return go.Scatter(
    x=days[order],
    y=values[order],
    mode='markers',
    marker=dict(size=5, opacity=.75),
    name=key,
  )

data_dict = ex["norm"]

# One figure per series, same layout for both: first the flux, then w1.
for trace_key in ("w1flux", "w1"):
  fig = go.Figure()

  fig.add_trace(get_trace(data_dict, trace_key))

  update_layout(fig, legend_out=True)

  fig.layout.width = 1200
  fig.layout.height = 0.65 * fig.layout.width

  fig.show()

In [None]:
def plot_from_datadict(data_dict, use_norm=False):
  """Return a figure with the w1, w2, w1sig and w2sig series of one flux dict.

  Args:
    data_dict: flux dict with "raw" and "norm" sub-dicts.
    use_norm: plot the "norm" sub-dict when True, otherwise "raw".
  """
  series = data_dict["norm" if use_norm else "raw"]

  fig = go.Figure()
  for name in ("w1", "w2", "w1sig", "w2sig"):
    fig.add_trace(get_trace(series, name))

  update_layout(fig, legend_out=True)

  fig.layout.width = 800
  fig.layout.height = 0.65 * fig.layout.width

  return fig

def plot_from_tensor(data):
  """Return a figure plotting columns 0 and 1 of a dataset tensor against its last column.

  Args:
    data: 2-D CPU tensor, one row per point. For tensors built by FluxSet.to_np
      the columns are (w1flux, std, day).
  """
  fig = go.Figure()

  w1 = data[:, 0].numpy()
  std = data[:, 1].numpy()

  day = data[:, -1].numpy()

  fig.add_trace(go.Scatter(x=day, y=w1, marker=dict(size=5, opacity=0.7), name="w1mpro z-scored", mode='markers'))
  # Column 1 is the std column; the legend used to call it "w2mpro z-scored".
  fig.add_trace(go.Scatter(x=day, y=std, marker=dict(size=5, opacity=0.7), name="std", mode='markers'))

  update_layout(fig, legend_out=True)

  fig.layout.width = 800
  fig.layout.height = 0.65 * fig.layout.width

  return fig

# Show all novae



In [None]:
# Index the dataset directly; list(train) built every (tensor, label) pair
# just to read a single one.
d, l = train[500]
plot_from_tensor(d).show()
print(l)

In [None]:
$$x$$