# Apply model to sources

This notebook ...

# Input

In [1]:
import pandas as pd
import numpy as np
from ENID.acquisition import *
from ENID.interpolation import *
from alerce.core import Alerce
from time import gmtime, strftime
import matplotlib.pyplot as plt
import torch.nn as nn
import george

In [2]:
datasources = pd.read_csv('sourceone.csv',header=None) # Give your sources in a .csv file with the ZTF-names. An example with blue continuum sources is attached

In [3]:
# Model wanted
ModelWeights = 'Models/Simple GRU/Weights'

# Name of the saved files
Version = strftime("%Y%m%d%H%M%S",gmtime()) #Default used UTC time and date.

# Number of points before
timediff = 14
noofpoint = 2

# Plotting?
plotbool = 'n'

# Define model

In [4]:
class simple_ENID(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers):
        super(simple_ENID, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        
        self.GRU = nn.GRU(input_dim, hidden_dim, n_layers, batch_first=True, bidirectional=False) 
        self.Dense = nn.Linear(hidden_dim, output_dim)
        self.predict = nn.Softmax(dim=1)
        
    def forward(self, x):
        
        x_, h_ = self.GRU(x)
        x_ = self.Dense(x_[:,-1])
        x_ = self.predict(x_)
        
        return x_

# Apply model

In [5]:
##### Retrieve light curves part

# Get ready by making it a list anc checking that they are not floats
ztf_raw = list(datasources[0])
ztf_names = [x for x in ztf_raw if type(x) != float]

# Find the sources on ALeRCE
print('Number of sources :', len(ztf_raw), '\n')
alerce_found, alerce_missing = source_search_alerce(ztf_names)

# Tell how many were found
print('Found ', alerce_found.shape[0], 'objects')
print('Missing ', len(alerce_missing), 'objects')

if len(alerce_missing) > 0:
    for mis in alerce_missing:
        print('Missing ', alerce_missing['oid'][mis])

# Get the object id for each source
sources = list(alerce_found['oid'])

# Retrieve the lightcurves for the sources
object_dictionary = {"Name": [], "Data": [], "Label": []}

for i in range(len(sources)):
    
    
    lightcurve = lc_compile(sources[i])
    if len(lightcurve['R_mag']) > 1 and len(lightcurve['G_mag']) > 1:
        object_dictionary['Data'].append(lightcurve)
        object_dictionary['Name'].append(ztf_names[i])
        object_dictionary['Label'].append(['SN II', 1])
    else:
        print('Warning : Not enought detections. Ignoring entry.')
        

    
    
save_file = open("pickle_lightcurves.pickle", mode='wb')
pickle.dump(object_dictionary, save_file)
save_file.close()    
    

Number of sources : 1 

Found  1 objects
Missing  0 objects

Importing lightcurve and metadata for  [1mZTF19abzwbxy[0m


In [None]:
##### Adding in non-detections & interpolating the light curves
#processed_data = preprocessing('pickle_lightcurves.pickle',Version,timediff,noofpoint)



In [7]:
print(datadict['Data'][0]['R_mag_wn'])
print(datadict['Data'][0]['R_mjd_wn'])
print(datadict['Data'][0]['R_err_wn'])

[19.863    16.580412 18.9924   16.364395 16.240137 16.66374  16.995054
 17.247751 18.422678 18.880629]
[58743.2079745 58747.2773032 58750.2082639 59335.4405671 59338.4472106
 59340.441169  59342.4675463 59345.4266088 59348.4435648 59354.4437384]
[-1.          0.0389301   0.198369    0.03921192  0.03599869  0.04902661
  0.04449942  0.04719749  0.07475115  0.15144286]


In [6]:
file = open('pickle_lightcurves.pickle', "rb")
datadict = pickle.load(file)
dummy_dict, source_idx = initialise(datadict,timediff,noofpoint)

In [None]:
lightcurve_R, error_R = interpolate(datadict['Data'][0]['R_mag_wn'], 
                                               datadict['Data'][0]['R_err_wn'], 
                                               datadict['Data'][0]['R_mjd_wn'])

In [None]:
for row in range(len(source_idx)):

    index = source_idx[row]

    try:

        # Interpolate the data
        interpolated = source_interpolate(datadict['Data'][index])
        dummy_dict = array_update(dummy_dict, interpolated, row)

    except:

        print('\nERROR : Interpolation failed. Discarding entry.\n')
        dummy_dict['failed'].append([index, datadict['Name'][index], datadict['Label'][index]])

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53]

In [83]:
datafile = 'Data/'+Version+'/data_lc.npy'
labelsfile = 'Data/'+Version+'/labels.npy'
failsfile = 'Data/'+Version+'/fails.npy'

data = np.load(datafile)
labels = np.load(labelsfile)
fails = np.load(failsfile)

print('Data Dimensions :', data.shape)
print('Label Dimensions :', labels.shape)
print('Fails Dimensions :', fails.shape)

ValueError: Object arrays cannot be loaded when allow_pickle=False

In [15]:
##### Plot lightcuves AND interpolated light curves

if plotbool == 1 or "Yes" or 'y' or 'Y':
    for scs in processed_data:
        fig, ax = plt.subplots(1,2)
        fig.suptitle("NAME")
        # Real data
        ax[0,0].plt.scatter(datadict['Data'][index]['R_mjd'], datadict['Data'][index]['R_mag'], 60, color=[(232/256, 63/256, 72/256)], label='R-band')
        ax[0,0].plt.scatter(datadict['Data'][index]['G_mjd'], datadict['Data'][index]['G_mag'], 60, color=[(31/256, 208/256, 130/256)], label='G-band')
        ax[0,0].plt.xlabel('Time [MJD]', fontsize=20)
        ax[0,0].plt.ylabel('Apparent Magnitude', fontsize=20)
        ax[0,0].plt.gca().invert_yaxis()
        ax[0,0].plt.grid()
        ax[0,0].plt.legend()
        ax[0,0].plt.title(datadict['Label'][index][0], fontsize=20)
        # Interpolated data
        ax[0,1].plt.scatter(datadict['Data'][index]['R_mjd'], datadict['Data'][index]['R_mag'], 60, color=[(232/256, 63/256, 72/256)], label='R-band')
        ax[0,1].plt.scatter(datadict['Data'][index]['G_mjd'], datadict['Data'][index]['G_mag'], 60, color=[(31/256, 208/256, 130/256)], label='G-band')
        ax[0,1].plt.xlabel('Normalized time', fontsize=20)
        ax[0,1].plt.ylabel('Normalized Flux', fontsize=20)
        ax[0,1].plt.gca().invert_yaxis()
        ax[0,1].plt.grid()
        ax[0,1].plt.legend()
        ax[0,1].plt.title(datadict['Label'][index][0], fontsize=20)
        plt.show()
        
        
    
    
    

NameError: name 'processed_data' is not defined

In [None]:
##### Classify it



X_train = torch()

# Model Building
input_dim = X_train.shape[2]
num_classes = Y_train.shape[1]
hidden_dim = 64
n_layers = 1

Net = simple_ENID(input_dim=input_dim, hidden_dim =hidden_dim, output_dim=num_classes, n_layers=n_layers)
print(Net)

Net.load_state_dict(torch.load(ModelWeights))




#???

' ZTF19abzwbxy'

# Function

In [None]:
def ApplyENID(datasources,Version,ModelWeights,timediff,noofpoints,plotbool):

        
    
    