# Applying 3-class Model to Test and Real Data

In [67]:
### python packages
import os
from os import path
import numpy as np
import glob as glob
from random import random
import pandas as pd
import pickle
import time

### torch packages
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import Sampler

### sklearn packages
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.model_selection import KFold


### remove these later (for notebook version only)
'''
from tqdm import tqdm_notebook as tqdm
from bokeh.plotting import figure
from bokeh.io import output_notebook, show, export_png
from bokeh.layouts import row
output_notebook()
import matplotlib.pyplot as plt
import seaborn as sns
'''

### Seed the CUDA RNG. The fold models in this notebook are loaded from saved
### state dicts and contain no dropout, so this does not change predictions;
### presumably carried over from the training script.
torch.cuda.manual_seed(seed=42)

def GetSimData(file):
    '''
    Load one simulated TCE and return the model inputs plus its integer class label.

    PARAMETERS
    ----------
    file : str
        Path to the ``*_info2.npy`` file. The matching ``*_glob.npy`` and
        ``*_loc.npy`` view files are found by replacing that suffix.

    RETURNS
    -------
    (data_local, data_global, data_local_cen, data_global_cen, stelpars), label
        Each view is a float array of shape (1, L): column 0 of the saved view
        file is used as the light curve, column 1 as the centroid curve.
        ``stelpars`` has shape (1, n): info entries [7:13] followed by entries
        [-18:-6]. NaN/inf in the views and in ``stelpars`` are replaced via
        ``np.nan_to_num``.
        ``label`` is 1 if info[6] == 'PL', 0 if info[6] == 'UNK', otherwise 2.
    '''
    ### grab local and global views (presumably shape (L, 2): [flux, centroid] -- confirm)
    data_global = np.nan_to_num(np.load(file.replace('_info2.npy','_glob.npy'),encoding='latin1'))
    data_local = np.nan_to_num(np.load(file.replace('_info2.npy','_loc.npy'),encoding='latin1'))

    ### centroid views = column 1, light curves = column 0, each reshaped to a (1, L) row
    data_global_cen = data_global[:,1][np.newaxis,:]
    data_local_cen = data_local[:,1][np.newaxis,:]

    data_global = data_global[:,0][np.newaxis,:]
    data_local = data_local[:,0][np.newaxis,:]

    ### The info file holds both strings (class at index 6) and numbers, so it is
    ### likely an object array; numpy >= 1.16.3 refuses to load those unless
    ### allow_pickle=True. NOTE(review): only load trusted files -- unpickling can run code.
    data_info = np.load(file, encoding='latin1', allow_pickle=True)

    ### class index: UNK=0, PL=1, anything else (e.g. EB)=2
    if data_info[6]=='PL':
        label=1
    elif data_info[6]=='UNK':
        label=0
    else:
        label=2

    ### Presumed feature names for the 18 scalar inputs (confirm against the preprocessing):
    #collist=['TPERIOD','TDUR','DRRATIO','NTRANS','TSNR','TDEPTH','INDUR',
    #         'SESMES_LOG_RATIO','PRAD_LOG_RATIO','TDUR_LOG_RATIO','RADRATIO','IMPACT',
    #         'TESSMAG','RADIUS','PMTOTAL','LOGG','MH','TEFF']#from bls search, derived from transit model, from starpars
    stelpars=np.nan_to_num(np.hstack((data_info[7:13].astype(float),data_info[-18:-6].astype(float))))[np.newaxis,:]

    return (data_local.astype(float), data_global.astype(float), data_local_cen.astype(float), data_global_cen.astype(float), stelpars), label

def GetRealData(file):
    '''
    Load one real (unlabelled) TCE and return the model inputs; the label is np.nan.

    NOTE: this function is defined again in the following notebook cell; the
    later definition is the one in effect once both cells have run.

    PARAMETERS
    ----------
    file : str
        Path to the ``*_info2.npy`` file. The matching ``*_glob.npy`` and
        ``*_loc.npy`` view files are found by replacing that suffix.

    RETURNS
    -------
    (data_local, data_global, data_local_cen, data_global_cen, stelpars), np.nan
        Each view is a float array of shape (1, L): column 0 of the saved view
        file is used as the light curve, column 1 as the centroid curve.
        ``stelpars`` has shape (1, n): info entries [16, 24, 17, 30, 31, 21]
        followed by entries [-18:-6]; NaN/inf replaced via ``np.nan_to_num``.
    '''
    ### grab local and global views (presumably shape (L, 2): [flux, centroid] -- confirm)
    data_global = np.nan_to_num(np.load(file.replace('_info2.npy','_glob.npy'),encoding='latin1'))
    data_local = np.nan_to_num(np.load(file.replace('_info2.npy','_loc.npy'),encoding='latin1'))

    ### centroid views = column 1, light curves = column 0, each reshaped to a (1, L) row
    data_global_cen = data_global[:,1][np.newaxis,:]
    data_local_cen = data_local[:,1][np.newaxis,:]

    data_global = data_global[:,0][np.newaxis,:]
    data_local = data_local[:,0][np.newaxis,:]

    ### The info file may be an object array (mixed strings/numbers); numpy >= 1.16.3
    ### refuses to load those unless allow_pickle=True.
    ### NOTE(review): only load trusted files -- unpickling can run code.
    data_info = np.load(file, encoding='latin1', allow_pickle=True)

    ### Presumed feature names for the 18 scalar inputs (confirm against the preprocessing):
    #collist=['TPERIOD','TDUR','DRRATIO','NTRANS','TSNR','TDEPTH','INDUR',
    #         'SESMES_LOG_RATIO','PRAD_LOG_RATIO','TDUR_LOG_RATIO','RADRATIO','IMPACT',
    #         'TESSMAG','RADIUS','PMTOTAL','LOGG','MH','TEFF']#from bls search, derived from transit model, from starpars
    ### real-data info files store the first six features at different positions than simulated ones
    newixs=np.array([16, 24, 17, 30, 31, 21])
    stelpars=np.nan_to_num(np.hstack((data_info[newixs].astype(float),data_info[-18:-6].astype(float))))[np.newaxis,:]

    return (data_local.astype(float), data_global.astype(float), data_local_cen.astype(float), data_global_cen.astype(float), stelpars), np.nan


In [68]:
def GetRealData(file):
    '''
    Load one real (unlabelled) TCE and return the model inputs; the label is np.nan.

    PARAMETERS
    ----------
    file : str
        Path to the ``*_info2.npy`` file. The matching ``*_glob.npy`` and
        ``*_loc.npy`` view files are found by replacing that suffix.

    RETURNS
    -------
    (data_local, data_global, data_local_cen, data_global_cen, stelpars), np.nan
        Each view is a float array of shape (1, L): column 0 of the saved view
        file is used as the light curve, column 1 as the centroid curve.
        ``stelpars`` has shape (1, n): info entries [16, 24, 17, 30, 31, 21]
        followed by entries [-18:-6]; NaN/inf replaced via ``np.nan_to_num``.
    '''
    ### grab local and global views (presumably shape (L, 2): [flux, centroid] -- confirm)
    data_global = np.nan_to_num(np.load(file.replace('_info2.npy','_glob.npy'),encoding='latin1'))
    data_local = np.nan_to_num(np.load(file.replace('_info2.npy','_loc.npy'),encoding='latin1'))

    ### centroid views = column 1, light curves = column 0, each reshaped to a (1, L) row
    data_global_cen = data_global[:,1][np.newaxis,:]
    data_local_cen = data_local[:,1][np.newaxis,:]

    data_global = data_global[:,0][np.newaxis,:]
    data_local = data_local[:,0][np.newaxis,:]

    ### The info file may be an object array (mixed strings/numbers); numpy >= 1.16.3
    ### refuses to load those unless allow_pickle=True.
    ### NOTE(review): only load trusted files -- unpickling can run code.
    data_info = np.load(file, encoding='latin1', allow_pickle=True)

    ### Presumed feature names for the 18 scalar inputs (confirm against the preprocessing):
    #collist=['TPERIOD','TDUR','DRRATIO','NTRANS','TSNR','TDEPTH','INDUR',
    #         'SESMES_LOG_RATIO','PRAD_LOG_RATIO','TDUR_LOG_RATIO','RADRATIO','IMPACT',
    #         'TESSMAG','RADIUS','PMTOTAL','LOGG','MH','TEFF']#from bls search, derived from transit model, from starpars
    ### real-data info files store the first six features at different positions than simulated ones
    newixs=np.array([16, 24, 17, 30, 31, 21])
    stelpars=np.nan_to_num(np.hstack((data_info[newixs].astype(float),data_info[-18:-6].astype(float))))[np.newaxis,:]

    return (data_local.astype(float), data_global.astype(float), data_local_cen.astype(float), data_global_cen.astype(float), stelpars), np.nan


In [63]:
class Model(nn.Module):

    '''
    PURPOSE: DEFINE EXTRANET MODEL ARCHITECTURE (3-CLASS VERSION)
    INPUT: GLOBAL + LOCAL LIGHT CURVES AND CENTROID CURVES, STELLAR PARAMETERS
    OUTPUT: (batch, 3) TENSOR OF SOFTMAX CLASS PROBABILITIES; DOWNSTREAM CODE
            IN THIS NOTEBOOK READS THE COLUMNS AS UNK, PL, EB

    Each view's flux curve and centroid curve are concatenated as 2 input
    channels and pass through their own 1-D convolutional stack. The flattened
    outputs are concatenated with the scalar features and fed to a fully
    connected head.
    '''
    
    def __init__(self):

        ### initialize model
        super(Model, self).__init__()

        ### global-view convolutional stack: 5 x (two conv layers + MaxPool1d(5, stride 2)).
        ### padding=2 with kernel 5 keeps the length; only the pools shrink it.
        self.fc_global = nn.Sequential(
            nn.Conv1d(2, 16, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.Conv1d(16, 16, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(5, stride=2),
            nn.Conv1d(16, 32, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.Conv1d(32, 32, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(5, stride=2),
            nn.Conv1d(32, 64, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.Conv1d(64, 64, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(5, stride=2),
            nn.Conv1d(64, 128, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.Conv1d(128, 128, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(5, stride=2),
            nn.Conv1d(128, 256, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.Conv1d(256, 256, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(5, stride=2),
        )

        ### local-view convolutional stack: 2 x (two conv layers + MaxPool1d(7, stride 2))
        self.fc_local = nn.Sequential(
            nn.Conv1d(2, 16, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.Conv1d(16, 16, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(7, stride=2),
            nn.Conv1d(16, 32, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.Conv1d(32, 32, 5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(7, stride=2),
        )

        ### fully connected head combining both views and the scalar features.
        ### 7858 must equal 256*len(global out) + 32*len(local out) + n_scalar.
        ### e.g. 256*28 + 32*21 + 18 = 7858, consistent with a global view of
        ### length 1001, a local view of length 101 and 18 scalar features
        ### (assumption -- confirm against the preprocessing).
        self.final_layer = nn.Sequential(
            nn.Linear(7858, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            ### 3 output classes; Softmax(1) makes each row sum to 1.
            ### NOTE(review): nn.CrossEntropyLoss (as constructed when loading these
            ### models) already applies log-softmax, so training with this extra
            ### Softmax would apply it twice -- confirm against the training code.
            nn.Linear(256, 3),
            nn.Softmax(1))

    def forward(self, x_local, x_local_cen, x_global, x_global_cen, x_star):
        '''
        Run the network on one batch.

        Assumed input shapes (as built by the inference loop in this notebook;
        confirm for other callers):
            x_local, x_local_cen   : (batch, 1, L_local)  local flux / centroid
            x_global, x_global_cen : (batch, 1, L_global) global flux / centroid
            x_star                 : (batch, 1, n_scalar) scalar features; dim 1 is squeezed
        Returns:
            (batch, 3) tensor of class probabilities.
        '''
        
        ### stack flux and centroid as 2 channels per view
        x_local_all = torch.cat([x_local, x_local_cen], dim=1)
        x_global_all = torch.cat([x_global, x_global_cen], dim=1)

        ### get outputs of global and local convolutional layers
        out_global = self.fc_global(x_global_all)
        out_local = self.fc_local(x_local_all)
        
        ### flattening outputs (multi-dim tensor) from convolutional layers into vector
        out_global = out_global.view(out_global.shape[0], -1)
        out_local = out_local.view(out_local.shape[0], -1)

        ### join both views and the scalar features, then classify
        out = torch.cat([out_global, out_local, x_star.squeeze(1)], dim=1)
        out = self.final_layer(out)

        return out
    

In [64]:
def make_one_hot(labels, C=2):
    '''
    Convert integer class labels to a one-hot encoding.

    Two input forms are supported:

    * list / tuple / numpy array of N labels -> numpy float array of shape (N, C).
      Float-valued labels (e.g. 0.0, 2.0) are cast with ``astype(int)``.
    * torch tensor of shape N x 1 x H x W holding integer (long) class indices ->
      float32 tensor of shape N x C x H x W, allocated on the same device as
      ``labels`` (CPU or CUDA).

    Parameters
    ----------
    labels : list, tuple, np.ndarray or torch.Tensor
        Class index per sample.
    C : int
        Number of classes.

    Returns
    -------
    np.ndarray or torch.Tensor
        One-hot encoded labels, in the same family (numpy / torch) as the input.
    '''
    if isinstance(labels, (list, tuple, np.ndarray)):
        idx = np.asarray(labels).astype(int)
        target = np.zeros((len(idx), C))
        target[np.arange(len(idx)), idx] = 1
        return target

    # Allocate on the labels' own device so scatter_ gets matching devices
    # (the previous torch.cuda.FloatTensor failed for CPU labels).
    one_hot = torch.zeros(labels.size(0), C, labels.size(2), labels.size(3),
                          dtype=torch.float32, device=labels.device)
    target = one_hot.scatter_(1, labels.data, 1)
    # Variable is a no-op wrapper in modern torch; kept for older callers.
    return Variable(target)


In [96]:
### Load the cross-validation fold models (one saved state dict per fold) for ensemble inference.
#INPUTS (identical for every fold, so set once):
foldname    = "/home/hosborn/TESS/final_runs/exonet_multiclass3_CV_globcents3b_k8"
savedicname = "exonet_multiclass3_CV_globcents3b_k8_dic"
mod         = "Big"
aug         = "all"
fpath       = "101"
kcount      = "0"
#cont       = False
nfolds      = 8

#Assigning a GPU (only takes effect if set before CUDA is initialised in this process):
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

lr = 2.05e-5
# criterion/optimizer mirror the training setup; they are not used for inference here
criterion = nn.CrossEntropyLoss()
batch_size = 64

allmodels={}
for k in range(nfolds):
    savename = "exonet_CV_"+str(k)+".8_101_all_Big"
    allmodels[k]=Model().cuda()
    optimizer = torch.optim.Adam(allmodels[k].parameters(), lr=lr)
    allmodels[k].load_state_dict(torch.load(path.join(foldname,savename.replace('_comp','')+'_temp.pth')))
    # inference only; Model has no dropout/batch-norm today, but this keeps it correct if added
    allmodels[k].eval()

In [17]:
### Dataset / DataLoader wrappers for the simulated test set and the real data.
### NOTE(review): SimDataLoaderCrossVal / RealDataLoaderCrossVal are not defined in this
### notebook; they must be imported or defined before running this cell.
### The test glob below is the same test directory used by the inference loop (was `infofiles=`, a syntax error).
test_data = SimDataLoaderCrossVal(infofiles=glob.glob('/home/hosborn/TESS/processed_dv_101_centfixed2/test/*info2.npy'))
test_data_loader = DataLoader(test_data, batch_size=batch_size, num_workers=4)

#nan_data = SimDataLoaderCrossVal(infofiles=glob.glob('/home/hosborn/TESS/processed_dv_101_centfixed2/nans/*info2.npy'))
#nan_data_loader = DataLoader(nan_data, batch_size=batch_size, num_workers=4)

real_data = RealDataLoaderCrossVal(infofiles=glob.glob('/home/hosborn/TESS/Processed_RealDat_2/*info2.npy'))
real_data_loader = DataLoader(real_data, batch_size=batch_size, num_workers=4)

real_nan_data = RealDataLoaderCrossVal(infofiles=glob.glob('/home/hosborn/TESS/Processed_RealDat_2/nans/*info2.npy'))
real_nan_data_loader = DataLoader(real_nan_data, batch_size=batch_size, num_workers=4)


In [109]:
### Run every fold model on every file and collect per-class probabilities per dataset.
### Keys name the dataset; values are lists of *_info2.npy paths (not DataLoaders).
dataloaders={'test':glob.glob('/home/hosborn/TESS/processed_dv_101_centfixed2/test/*info2.npy'),
             'real':glob.glob('/home/hosborn/TESS/Processed_RealDat_2/*info2.npy'),
             'real_nan':glob.glob('/home/hosborn/TESS/Processed_RealDat_2/nans/*info2.npy')}#,'nan':nan_data_loader
model_output={}

### class names in model-output column order; 'BEB' is unused while nclasses=3
labels=['UNK','PL','EB','BEB']
nclasses=3


def _to_gpu_tensor(arr):
    '''Convert a (1, L) numpy array to a (1, 1, L) float32 CUDA tensor, as Model expects.'''
    return torch.from_numpy(arr).float().cuda().unsqueeze(1)


def _print_class_metrics(tag, probs, onehot, n):
    '''
    Print the "accuracy", recall and average precision of class n for one ensemble score.

    probs  : (N, nclasses) ensemble scores (median or mean over folds)
    onehot : (N, nclasses) one-hot ground truth
    tag    : prefix of the printed metric names ("MED" or " AV")

    NOTE: "accuracy" is TP / predicted positives, i.e. per-class precision; the
    printed text is kept unchanged so existing logs stay comparable.
    '''
    pred_n = np.argmax(probs,axis=1)==n
    true_pos = np.sum(pred_n*onehot[:,n].astype(bool))
    print(labels[n],tag+" accuracy:",true_pos/np.sum(pred_n))
    print(labels[n],tag+" recall:",true_pos/np.sum(onehot[:,n]))
    print(labels[n],tag+" A.P.:", average_precision_score(onehot,probs,average=None)[n])


for data_loader in dataloaders:
    rows=[]   # one pd.Series per file; the DataFrame is built once at the end
    
    for file in dataloaders[data_loader]:
        if data_loader=='test':
            x_data, ys = GetSimData(file)
        else:
            x_data, ys = GetRealData(file)

        ### local/global views, their centroid views and the scalar features, as CUDA tensors
        x_local, x_global, x_local_cent, x_global_cent, x_star = (_to_gpu_tensor(a) for a in x_data)
        
        preds={}
        with torch.no_grad():   # inference only: don't build autograd graphs
            for k in allmodels:
                ### Looping through each model to predict:
                modout=allmodels[k](x_local, x_local_cent, x_global, x_global_cent, x_star).cpu().detach().numpy().ravel()
                preds.update({labels[n]+'_'+str(k):modout[n] for n in range(nclasses)})
        if data_loader=='test':
            preds['gt']=ys
        rows.append(pd.Series(preds,name=file.split('/')[-1][:18]))
    ### DataFrame.append was removed in pandas 2.0 and is quadratic when used in a loop
    model_output[data_loader] = pd.DataFrame(rows)

    ### ensemble statistics over all fold models
    for n in range(nclasses):
        allks=model_output[data_loader][[labels[n]+'_'+str(k) for k in allmodels]].values
        model_output[data_loader][labels[n]+'_med'] = np.nanmedian(allks,axis=1)
        model_output[data_loader][labels[n]+'_av'] = np.average(allks,axis=1)
    if data_loader=='test':
        medvals=model_output['test'][[labels[nl]+'_med' for nl in range(nclasses)]].values
        avvals=model_output['test'][[labels[nl]+'_av' for nl in range(nclasses)]].values
        onehot=make_one_hot(model_output[data_loader]['gt'].values,nclasses)
        ### micro-averaged AP of the ensemble-mean scores (previously referenced the undefined `avallks`)
        print("AP: average:",average_precision_score(onehot,avvals,average='micro'))
        for n in range(nclasses):
            _print_class_metrics("MED", medvals, onehot, n)
            _print_class_metrics(" AV", avvals, onehot, n)


AP: average: 0.9711121403126206
UNK MED accuracy: 0.9467501957713391
UNK MED recall: 0.9482352941176471
UNK MED A.P.: 0.9686431917317233
UNK  AV accuracy: 0.9482758620689655
UNK  AV recall: 0.9490196078431372
UNK  AV A.P.: 0.9773092760006252
PL MED accuracy: 0.9007832898172323
PL MED recall: 0.8961038961038961
PL MED A.P.: 0.9397033318605802
PL  AV accuracy: 0.9036458333333334
PL  AV recall: 0.9012987012987013
PL  AV A.P.: 0.9457146336827876
EB MED accuracy: 0.9512893982808023
EB MED recall: 0.9512893982808023
EB MED A.P.: 0.9574337104175915
EB  AV accuracy: 0.9512893982808023
EB  AV recall: 0.9512893982808023
EB  AV A.P.: 0.9689583201448094


In [110]:
### Persist all ensemble outputs; the `with` block closes (and flushes) the file.
with open('3classmod.pickle','wb') as fh:
    pickle.dump(model_output, fh)

In [107]:
### Re-print the ensemble-average metrics for the simulated test set.
### medvals is not used here but is recomputed for the plotting cell.
medvals=model_output['test'][[labels[nl]+'_med' for nl in range(nclasses)]].values
avvals=model_output['test'][[labels[nl]+'_av' for nl in range(nclasses)]].values
onehot=make_one_hot(model_output['test']['gt'].values,nclasses)
### micro-averaged AP of the ensemble-mean scores (previously referenced the undefined `avallks`)
print("AP: average:",average_precision_score(onehot,avvals,average='micro'))
for n in range(nclasses):
    pred_n = np.argmax(avvals,axis=1)==n
    true_pos = np.sum(pred_n*onehot[:,n].astype(bool))
    # NOTE: "accuracy" here is TP / predicted positives, i.e. per-class precision.
    print(labels[n]," AV accuracy:",true_pos/np.sum(pred_n))
    print(labels[n]," AV recall:",true_pos/np.sum(onehot[:,n]))
    print(labels[n]," AV A.P.:", average_precision_score(onehot,avvals,average=None)[n])


AP: average: 0.9711121403126206
UNK  AV accuracy: 0.9482758620689655
UNK  AV recall: 0.9490196078431372
UNK  AV A.P.: 0.9773092760006252
PL  AV accuracy: 0.9036458333333334
PL  AV recall: 0.9012987012987013
PL  AV A.P.: 0.9457146336827876
EB  AV accuracy: 0.9512893982808023
EB  AV recall: 0.9512893982808023
EB  AV A.P.: 0.9689583201448094


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

%matplotlib inline
%config IPython.matplotlib.backend = "retina"
from matplotlib import rcParams
rcParams["savefig.dpi"] = 300
rcParams["figure.dpi"] = 300

plt.figure(figsize=(6,6))

labels=['UNK','PL','EB','BEB']
one_hot_gt=make_one_hot(model_output['test']['gt'].values,nclasses)
### plot values
for n in range(nclasses):
    P, R, _ =precision_recall_curve(one_hot_gt[:,n],medvals[:,n])
    steps=plt.step(R, P,label=labels[n]+'_Med',linewidth=2.5,alpha=0.75)
    P, R, _ =precision_recall_curve(one_hot_gt[:,n],avvals[:,n])
    plt.step(R, P,label=labels[n]+'_Av',linewidth=2.5,linestyle='--',color=steps[0].get_color(),alpha=0.75)
plt.legend()
plt.xlim(0.0,1.05)
plt.ylim(0.2,1.05)
plt.xlabel('Recall')
plt.ylabel('Precision')
savefile='3class_ensemble_testdat_avmedcomp'
plt.savefig(path.join('/home/hosborn/TESS/PaperFigures',savefile+'_PR.png'))
plt.savefig(path.join('/home/hosborn/TESS/PaperFigures',savefile+'_PR.pdf'))

In [113]:
# Display the ground-truth class index (0=UNK, 1=PL, 2=EB) of every test TCE.
model_output['test'].loc[:, 'gt']

000358070970_00_02    1.0
000248988932_00_04    0.0
000269594275_00_04    2.0
000220476100_00_01    0.0
000260356202_01_99    2.0
000372908550_16_03    0.0
000266716027_00_04    0.0
000279726530_06_99    2.0
000237339895_00_02    0.0
000234520087_01_99    2.0
000183985164_01_03    2.0
000266966948_00_01    0.0
000360736543_00_01    1.0
000183495016_03_99    0.0
000349192028_05_01    2.0
000123292130_00_04    1.0
000267114952_02_03    2.0
000150358991_03_02    1.0
000260078169_06_99    1.0
000033835925_00_01    0.0
000261542169_01_99    2.0
000031942864_00_01    0.0
000259863095_00_01    2.0
000220526640_02_02    2.0
000220432563_00_99    0.0
000425938854_04_03    2.0
000050381881_08_02    2.0
000144000467_00_02    1.0
000277147381_02_01    2.0
000294784092_00_99    0.0
                     ... 
000401839949_02_04    2.0
000010803257_00_04    2.0
000219984873_02_03    1.0
000220436792_00_01    0.0
000421940749_00_02    0.0
000373523595_04_03    2.0
000346704136_00_01    1.0
000167087044

In [53]:

### Debug: build the model inputs for one real-data TCE by hand (mirrors GetRealData)
### and convert them to (1, 1, L) CUDA tensors so the next cell can feed them to the models.
fbase = '/home/hosborn/TESS/Processed_RealDat_2/000001129033_00_04'

### grab local and global views
data_global = np.nan_to_num(np.load(fbase+'_glob.npy',encoding='latin1'))
data_local = np.nan_to_num(np.load(fbase+'_loc.npy',encoding='latin1'))

### centroid views = column 1, light curves = column 0, each as a (1, L) row
### (x_global_cent previously lacked [np.newaxis,:], unlike every other view)
x_global_cent = data_global[:,1][np.newaxis,:]
x_local_cent = data_local[:,1][np.newaxis,:]

x_global = data_global[:,0][np.newaxis,:]
x_local = data_local[:,0][np.newaxis,:]

### the info file may be an object array (mixed strings/numbers), hence allow_pickle
data_info = np.load(fbase+'_info2.npy',encoding='latin1',allow_pickle=True)

label=np.nan   # real data have no ground-truth class

### real-data positions of the first six scalar features, then entries [-18:-6]
newixs=np.array([16, 24, 17, 30, 31, 21])

x_star=np.nan_to_num(np.hstack((data_info[newixs].astype(float),data_info[-18:-6].astype(float))))[np.newaxis,:]

### numpy (1, L) -> float32 CUDA tensors of shape (1, 1, L), as in the inference loop
x_local, x_local_cent, x_global, x_global_cent, x_star = (
    torch.from_numpy(a.astype(float)).float().cuda().unsqueeze(1)
    for a in (x_local, x_local_cent, x_global, x_global_cent, x_star))

In [56]:
### Predict the hand-built example with every fold model; W77 collects one output per fold,
### and `pred` is left bound to the last fold's output.
W77=[]
with torch.no_grad():   # inference only: no autograd graph needed
    for k in allmodels:
        pred=allmodels[k](x_local, x_local_cent, x_global, x_global_cent, x_star)
        W77+=[pred.cpu().detach().numpy()]

In [57]:
# Display the per-fold model outputs collected in W77 (one array per fold model).
W77

[array([[1.2458412e-14, 1.0000000e+00, 1.2592762e-23]], dtype=float32),
 array([[1.2458412e-14, 1.0000000e+00, 1.2592762e-23]], dtype=float32),
 array([[1.2458412e-14, 1.0000000e+00, 1.2592762e-23]], dtype=float32),
 array([[1.2458412e-14, 1.0000000e+00, 1.2592762e-23]], dtype=float32),
 array([[1.2458412e-14, 1.0000000e+00, 1.2592762e-23]], dtype=float32),
 array([[1.2458412e-14, 1.0000000e+00, 1.2592762e-23]], dtype=float32),
 array([[1.2458412e-14, 1.0000000e+00, 1.2592762e-23]], dtype=float32),
 array([[1.2458412e-14, 1.0000000e+00, 1.2592762e-23]], dtype=float32)]

In [23]:
### Persist all ensemble outputs; the `with` block closes (and flushes) the file.
with open('3classmod.pickle','wb') as fh:
    pickle.dump(model_output, fh)

In [22]:
def _save_ids(out_name, dataset):
    '''Write ``dataset.ids`` to ``out_name`` with np.save.'''
    np.save(out_name, dataset.ids)

### id lists of the real-data datasets (nan subset first)
_save_ids('real_nan_idlist3.npy', real_nan_data)
_save_ids('real_idlist3.npy', real_data)

In [25]:
# IPython help lookup (notebook-only syntax): shows the docstring of sklearn's average_precision_score.
average_precision_score?

In [9]:
# Display the last computed model output as a numpy array.
pred.detach().cpu().numpy()

array([[3.04026753e-16, 1.00000000e+00, 3.36392283e-23],
       [1.00000000e+00, 4.62603018e-32, 4.59011311e-19],
       [1.35944629e-05, 1.94485413e-20, 9.99986410e-01],
       [1.00000000e+00, 5.81212925e-11, 9.34402382e-12],
       [3.87693632e-07, 5.12899892e-22, 9.99999642e-01],
       [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [1.00000000e+00, 1.08231546e-09, 3.43416322e-13],
       [1.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [1.00000000e+00, 0.00000000e+00, 5.74532370e-44],
       [1.01939613e-05, 2.09248008e-17, 9.99989748e-01],
       [3.25027304e-06, 1.25134565e-25, 9.99996781e-01],
       [1.00000000e+00, 3.31933647e-10, 3.08426147e-15],
       [2.97330960e-08, 1.00000000e+00, 1.65461330e-14],
       [1.00000000e+00, 4.08566872e-16, 5.78681677e-12],
       [1.49439719e-13, 8.98340809e-26, 1.00000000e+00],
       [9.55515445e-10, 1.00000000e+00, 5.84007940e-14],
       [2.59327239e-12, 1.35251816e-29, 1.00000000e+00],
       [2.67890164e-06, 9.99997