In [2]:
import sys
import os
sys.path.insert(0, os.path.abspath('../..'))
import torch
import numpy as np
import utils
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from dataset import DataSet
import pandas as pd
from sklearn.decomposition import PCA
%load_ext autoreload
np.random.seed(2) ## keep same shuffled as was trained on 
torch.manual_seed(2)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


<torch._C.Generator at 0x107c75d90>

## Load the dataset and the checkpoint of the network trained on it

In [8]:
# split=0.02 keeps only a small train/valid portion; per the printed output this
# yields 51 train / 50 valid / 2508 test light curves out of 2559.
lcs = utils.get_data(
    '../../datasets/ZTF_g',
    split=0.02,
)


validated 3408 files out of 3408 for band='g'
created union_tp attribute of length 3500
dataset created w/ shape (2559, 1, 1974, 3)
train size: 51, valid size: 50, test size: 2508


In [9]:
# Restore the trained model plus its training state; only `net` (and the printed
# `args`) are used further down in this notebook.
net, optimizer, args, epoch, loss, train_loss, test_loss = utils.load_checkpoint(
    '../checkpoints/final/ZTF_g0.668415367603302.h5',
    lcs.data_obj,
)



=> loading checkpoint '../checkpoints/final/ZTF_g0.668415367603302.h5'
Namespace(n_union_tp=3500, data_folder='datasets/ZTF_g', checkpoint='datasets/ZTF_g0.7037358283996582.h5', start_col=1, inc_errors=False, print_at=1, embed_time=128, enc_num_heads=16, latent_dim=64, mixing='concat', num_ref_points=16, rec_hidden=128, width=512, save_at=30, patience=100, early_stopping=False, niters=20000, frac=0.5, batch_size=2, mse_weight=5.0, dropout=0.0, num_resamples=0, lr=1e-06, scheduler=False, warmup=4000, kl_zero=False, kl_annealing=True, net='hetvae', device='mps', const_var=False, var_per_dim=False, std=0.1, seed=2, save=True, k_iwae=1)


In [44]:
from tqdm import tqdm
def evaluate_hetvae(
    net,
    dim,
    dataloader,
    frac=0.5,
    k_iwae=1,
    device='mps',
    forecast=False,
    qz_mean=False
):
    """Evaluate a trained HeTVAE on every batch of ``dataloader`` without gradients.

    For each batch a context mask is drawn with ``utils.make_masks`` (a random
    subset of points, or the leading section when ``forecast=True``); points with a
    nonzero value column that are outside that mask form the reconstruction target.

    Args:
        net: model exposing ``compute_unsupervised_loss``.
        dim: unused; kept so existing call sites keep working.
        dataloader: iterable of tensors indexed as ``batch[:, :, :, c]``. Column 0 is
            passed as time stamps, column 1 as values, column 2 as error bars.
            Assumed layout ``(B, 1, T, 3)`` as in the printed dataset shapes -- TODO confirm.
        frac: fraction forwarded to ``utils.make_masks``.
        k_iwae: forwarded to ``compute_unsupervised_loss`` as ``num_samples``.
        device: torch device the forward pass runs on.
        forecast: forwarded to ``utils.make_masks``.
        qz_mean: forwarded to ``compute_unsupervised_loss``.

    Returns:
        ``(avg_nll, avg_mse, per_example_nll)``: batch-size-weighted means of
        ``-loss_info.loglik`` and ``loss_info.mse`` over all examples, and a 1-D
        numpy array of ``-loss_info.loglik_per_ex`` (first row of the
        concatenation over batches), in dataloader iteration order.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    NOTE(review): ``net.eval()`` is not called here; call it beforehand if the
    network uses dropout / batch-norm.
    """
    total_n = 0
    sum_loglik, sum_mse = 0, 0
    per_example_logliks = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch_len = batch.shape[0]
            # Forecasting if this mask covers the first section of points only,
            # rather than a random selection.
            subsampled_mask = utils.make_masks(batch, frac=frac, forecast=forecast)

            # Clone: indexing + swapaxes return a view of `batch`, and the in-place
            # log below would otherwise write log-errors back into the loader's data.
            errorbars = torch.swapaxes(batch[:, :, :, 2], 2, 1).clone()
            # Inverse error bars as weights; zero entries stay zero.
            weights = errorbars.clone()
            weights[weights != 0] = 1 / weights[weights != 0]
            errorbars[errorbars != 0] = torch.log(errorbars[errorbars != 0])
            logerr = errorbars.to(device)
            weights = weights.to(device)

            batch = batch.to(device)
            subsampled_mask = subsampled_mask.to(device)
            # logical_xor treats nonzero values as True, so this selects nonzero-valued
            # points outside the context mask (assuming the context mask only covers
            # such points -- TODO confirm against utils.make_masks).
            recon_mask = torch.logical_xor(subsampled_mask, batch[:, :, :, 1])

            context_y = torch.cat((
                batch[:, :, :, 1] * subsampled_mask, subsampled_mask
            ), 1).transpose(2, 1)
            recon_context_y = torch.cat((
                batch[:, :, :, 1] * recon_mask, recon_mask
            ), 1).transpose(2, 1)

            loss_info = net.compute_unsupervised_loss(
                batch[:, 0, :, 0],
                context_y,
                batch[:, 0, :, 0],
                recon_context_y,
                logerr,
                weights,
                num_samples=k_iwae,
                qz_mean=qz_mean,
            )

            per_example_logliks.append(loss_info.loglik_per_ex.cpu().numpy())
            sum_loglik += loss_info.loglik * batch_len
            sum_mse += loss_info.mse * batch_len
            total_n += batch_len

    if total_n == 0:
        raise ValueError("evaluate_hetvae: dataloader yielded no batches")

    avg_nll = -sum_loglik / total_n
    # Previously this returned -sum_loglik / total_n (the NLL) a second time.
    avg_mse = sum_mse / total_n
    return avg_nll, avg_mse, -1 * np.concatenate(per_example_logliks, axis=1)[0]


In [45]:
# split=1.0 sends every light curve to the train split
# (printed: train 8, valid 0, test 0).
lcs_rm = utils.get_data(
    '../../datasets/ZTF_rm_g',
    split=1.0,
)

validated 10 files out of 10 for band='g'
created union_tp attribute of length 3500
dataset created w/ shape (8, 1, 1312, 3)
train size: 8, valid size: 0, test size: 0


In [55]:
# Score every light curve in the RM set; qz_mean=True is forwarded to the model's loss.
# NOTE(review): `nlls` follows train_loader iteration order -- if that loader
# shuffles, it will not line up with the file list; confirm before labelling.
nll, mse, nlls = evaluate_hetvae(
    net,
    1,
    lcs_rm.data_obj['train_loader'],
    frac=0.5,
    device='mps',
    qz_mean=True,
)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.61it/s]


In [57]:
# Per-example negative log-likelihoods (negated loglik_per_ex) from evaluate_hetvae,
# in dataloader order. NOTE(review): verify the loader does not shuffle before
# mapping these values onto object names.
nlls

array([1.1668259 , 2.5462494 , 1.9417218 , 0.16350092, 2.396925  ,
       1.8516396 , 1.3945901 , 2.5341253 ], dtype=float32)

In [58]:
# TODO: for anomaly scores, evaluate with the latent means (qz_mean=True) over ~10 different `frac` values

Unnamed: 0,g
NGC5548,../../datasets/ZTF_rm_g/g/NGC5548_DR_gband.csv
3C120,../../datasets/ZTF_rm_g/g/3C120_DR_gband.csv
Mrk142,../../datasets/ZTF_rm_g/g/Mrk142_DR_gband.csv
Mrk876,../../datasets/ZTF_rm_g/g/Mrk876_DR_gband.csv
NGC2617,../../datasets/ZTF_rm_g/g/NGC2617_DR_gband.csv
H2106-099,../../datasets/ZTF_rm_g/g/H2106-099_DR_gband.csv
Mrk817,../../datasets/ZTF_rm_g/g/Mrk817_DR_gband.csv
MCG+08-11-011,../../datasets/ZTF_rm_g/g/MCG+08-11-011_DR_gba...
