In [None]:
import os


import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import torch
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl 

GPUS = int(torch.cuda.is_available())
# torch.cuda.empty_cache() 
def print_cuda_summary(device=0):
    """Print a memory summary of one CUDA device, in GiB.

    Every allocator counter is queried exactly once, so the "free inside
    reserved" line is computed from the same reserved/allocated values that are
    printed (the previous version re-queried them and could disagree).

    Args:
        device: CUDA device index to report on (default 0, the index that was
            previously hard-coded).
    """
    gib = 1024 ** 3
    total = torch.cuda.get_device_properties(device).total_memory
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)
    max_reserved = torch.cuda.max_memory_reserved(device)
    free_in_reserved = reserved - allocated  # reserved by the allocator but not allocated

    print(f"torch.cuda.get_device_properties({device}).total_memory {total / gib:f}GB")
    print(f"torch.cuda.memory_allocated: {allocated / gib:f}GB")
    print(f"torch.cuda.memory_reserved: {reserved / gib:f}GB")
    print(f"torch.cuda.max_memory_reserved: {max_reserved / gib:f}GB")
    print(f"Free (res - alloc) {free_in_reserved / gib:f}GB")

if GPUS:
    print_cuda_summary()

In [None]:
# Data / training configuration
max_samples = 1000
val_ratio = 0.1
batch_size = 64


seed = 0
standardise_axes = (0, 1)  # per sample standardisation

num_workers = 0
# Pinned host memory only helps host->GPU transfers, so enable it only when a
# GPU is available (GPUS is set in the first cell).
pin_memory = bool(GPUS)

if num_workers > 0:
    # Presumably disables OpenCV's own threading to avoid oversubscription with
    # multi-process DataLoader workers — TODO confirm cv2 is used by the transforms.
    import cv2
    cv2.setNumThreads(0) 

In [None]:
# Release unoccupied cached GPU memory held by PyTorch's caching allocator
# before the datasets and model are (re)built.
torch.cuda.empty_cache() 

### TESS DATA

In [None]:
# Root folder of the TESS light curves. The default depends on the machine
# (GPU server vs. local laptop); set the TESS_LC_DIR environment variable to
# point elsewhere without editing the notebook.
_default_path = ("/state/partition1/mmorvan/data/TESS/lightcurves/0001" if GPUS
                 else "/Users/mario/data/TESS/lightcurves/0027")
path = os.environ.get("TESS_LC_DIR", _default_path)

train_path = os.path.join(path, 'processed_train')
test_path = os.path.join(path, 'processed_test')

In [None]:
from datasets.kepler_tess import TessDataset, Subset, split_indices
from transforms import Compose,StandardScaler, AddGaussianNoise, Mask, FillNans, RandomCrop, DownSample
from utils.loading import CollatePred

# Transform applied jointly to input and target during training: random crop of
# 800 steps, downsampling by 2, Mask(0.3, ...) (presumably masks ~30% of the
# points — see transforms.Mask), then StandardScaler along dim 0.
transform_both_train = Compose([RandomCrop(800, exclude_missing_threshold=0.8),
                                DownSample(2),
                                Mask(0.3, block_len=None, value=None, exclude_mask=True),
                                StandardScaler(dim=0),
                                # FillNans(0),
                                ])

# Same pipeline without the Mask step.
transform_both_2 = Compose([RandomCrop(800, exclude_missing_threshold=0.8),
                            DownSample(2),
                            StandardScaler(dim=0),
                            # FillNans(0),
                            ])

transform = None

dataset = TessDataset(train_path,
                      load_processed=True,
                      max_samples=max_samples,
                      transform=transform,
                      transform_both=transform_both_train,
                      use_cache=True,
                      )
test_dataset = TessDataset(test_path,
                           load_processed=True,
                           use_cache=True,
                           )

# TRAIN/VAL SPLIT, seeded so that it is reproducible across runs.
val_size = int(val_ratio * len(dataset))
train_size = len(dataset) - val_size
train_indices, val_indices = split_indices((train_size, val_size),
                                           generator=torch.Generator().manual_seed(seed))
train_dataset = Subset(dataset, train_indices)

# Two views of the val/test data: "1" keeps the training transform (with Mask),
# "2" replaces it by transform_both_2 (no Mask).
val_dataset1 = Subset(dataset, val_indices)
val_dataset2 = Subset(dataset, val_indices, replace_transform_both=transform_both_2)

test_dataset1 = Subset(test_dataset, replace_transform_both=transform_both_train)
test_dataset2 = Subset(test_dataset, replace_transform_both=transform_both_2)


def _make_loader(ds, shuffle=False):
    """Build a DataLoader with the notebook-wide batch size, workers and pinning."""
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, pin_memory=pin_memory)


train_loader = _make_loader(train_dataset, shuffle=True)
val_loader1 = _make_loader(val_dataset1)
val_loader2 = _make_loader(val_dataset2)
test_loader1 = _make_loader(test_dataset1)
test_loader2 = _make_loader(test_dataset2)

# Full-light-curve prediction: one untransformed test light curve per batch,
# split by CollatePred(400, step=350) (presumably 400-step windows every 350 steps).
loader_pred = DataLoader(test_dataset,
                         batch_size=1,
                         shuffle=False,
                         collate_fn=CollatePred(400, step=350),
                         num_workers=num_workers, pin_memory=pin_memory)

In [None]:
# Inspect one training sample after transform_both_train (random crop/mask,
# so it differs between runs).
x, y, m, meta = dataset[0]

fig, ax = plt.subplots()
ax.plot(x, label='x')
ax.plot(y, label='y')
ax.set(xlabel='time steps', ylabel='standardised flux', title='dataset[0]')
ax.legend()
plt.show()

In [None]:
## Stitching back
# X_o = inverse_standardise_batch(X, I['mu'], I['sigma'])

# plt.figure(figsize=(25,5))
# for i in range(len(X)):
#     x = X_o[i]
#     plt.plot(range(I['left_crop'][i], I['left_crop'][i]+400),
#              x)

### Dummy Data

In [None]:
# from datasets import DummyDataset

# dataset = DummyDataset(100)
# x, y, m, meta = dataset[0]

# plt.plot(x)
# plt.plot(y)

# Model

In [None]:
# ### TEST RUN
# lit_model = LitImputer(1, noise_scaling=True)
# trainer = pl.Trainer(max_epochs=10, gpus=GPUS)
# result = trainer.fit(lit_model, train_dataloaders=train_loader)

### Define Model

In [None]:
from models import LitImputer
# Seed all RNGs with the same configurable seed as the train/val split
# (previously a literal 0 that could drift from `seed`).
pl.seed_everything(seed)
if GPUS:
    print_cuda_summary()

In [None]:
# Transformer imputer on univariate light curves (n_dim=1), trained with an
# MAE loss on the 'noise' training unit.
lit_model = LitImputer(
    n_dim=1,
    d_model=64,
    dim_feedforward=128,
    lr=0.001,
    train_unit='noise',
    train_loss='mae',
)

### Training

In [None]:
import datetime
from pytorch_lightning.loggers import NeptuneLogger
# Neptune run tags: dataset size, a free-text note, and the model's training unit.
run_tags = [
    str(len(dataset)) + ' samples',
    # "continued_from_den-186",
    "test fix Masked Loss",
    "train - " + lit_model.train_unit,
]
logger = NeptuneLogger(
    project="denoising-transformer",
    name='tess_denoising',
    tags=run_tags,
)

In [None]:
# Short training run with validation after every epoch, logged to Neptune.
trainer = pl.Trainer(
    max_epochs=5,
    logger=logger,
    gpus=GPUS,
    check_val_every_n_epoch=1,
)

In [None]:
# Both validation views (training transform and transform_both_2) are evaluated.
val_loaders = [val_loader1, val_loader2]
result = trainer.fit(
    lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loaders,
)

In [None]:
# Evaluate on both views of the *test* set (training transform and
# transform_both_2). The second loader must be test_loader2; passing
# val_loader2 scored validation data instead of test data.
trainer.test(lit_model, dataloaders=[test_loader1, test_loader2])

### Loading a trained checkpoint

In [None]:
import glob
num_study = 299

# Locate the single checkpoint saved for Neptune run DEN-<num_study>.
ckpt_paths = sorted(glob.glob(f"./.neptune/tess_denoising/DEN-{num_study}/checkpoints/*.ckpt"))
assert len(ckpt_paths) == 1, (
    f"expected exactly one checkpoint for study DEN-{num_study}, "
    f"found {len(ckpt_paths)}: {ckpt_paths}")
ckpt_path = ckpt_paths[0]
print(f'successfully found ckpt file for study {num_study}: ', ckpt_path)

# ckpt_path = "./.neptune/Tess-denoising17-02-2022_01-34-35/DEN-167/checkpoints/epoch=382-step=4992.ckpt"
# ckpt_path = "./.neptune/Tess-denoising_17-02-2022_02-19-37/DEN-170/checkpoints/epoch=1032-step=15494.ckpt"
# ckpt_path = "./.neptune/Tess-denoising17-02-2022_11-11-26/DEN-186/checkpoints/epoch=601-step=15062.ckpt"
# ckpt_path = "./.neptune/Tess-denoising17-02-2022_11-41-36/DEN-189/checkpoints/epoch=636-step=15924.ckpt"
# ckpt_path = "./.neptune/tess_denoising/DEN-219/checkpoints/epoch=2999-step=86999.ckpt"
# ckpt_path = './.neptune/tess_denoising/DEN-245/checkpoints/epoch=1382-step=20744.ckpt'
# ckpt_path = "./.neptune/tess_denoising/DEN-258/checkpoints/epoch=2999-step=44999.ckpt"
# ckpt_path = './.neptune/tess_denoising/DEN-264/checkpoints/epoch=2999-step=44999.ckpt'
# ckpt_path = "./.neptune/tess_denoising/DEN-272/checkpoints/epoch=4999-step=74999.ckpt"  # MAE
# ckpt_path = "./.neptune/tess_denoising/DEN-289/checkpoints/epoch=4999-step=74999.ckpt"
# ckpt_path = "./.neptune/tess_denoising/DEN-294/checkpoints/epoch=4999-step=74999.ckpt"
# ckpt_path = "./.neptune/tess_denoising/DEN-295/checkpoints/epoch=1696-step=25454.ckpt"
# ckpt_path = "./.neptune/tess_denoising/DEN-296/checkpoints/epoch=4300-step=64514.ckpt"
# ckpt_path = "./.neptune/tess_denoising/DEN-299/checkpoints/epoch=4999-step=74999.ckpt"   # MAE + Geom mask

In [None]:
# `load_from_checkpoint` is a classmethod returning a new model loaded from the
# checkpoint; call it on the class (recent Lightning releases reject calls made
# on an instance).
lit_model = LitImputer.load_from_checkpoint(ckpt_path)

# Analysis

In [None]:
from utils.stats import estimate_noise
from utils.postprocessing import compute_rollout_attention

# Run on CUDA when available, otherwise on CPU.
device = 'cuda' if GPUS else 'cpu'

# Drop results of a previous run so their (possibly GPU-resident) memory can be
# reclaimed. Names are removed independently, so a partially defined set
# (e.g. X present but AR missing after an interrupted run) no longer raises.
for _stale in ('X', 'Y', 'M', 'I', 'Y_pred', 'attention_maps', 'AR'):
    globals().pop(_stale, None)

if GPUS:
    # Empty the allocator cache only after the old tensors were released.
    torch.cuda.empty_cache()
    print_cuda_summary()

lit_model.eval().to(device)

with torch.no_grad():
    # First batch of the full-light-curve prediction loader.
    X, Y, M, I = next(iter(loader_pred))
    Y_pred = lit_model(X.to(device)).detach().cpu().numpy()
    noise = estimate_noise(Y)
    attention_maps = lit_model.get_attention_maps(X.to(device), mask=M.to(device))
    AR = compute_rollout_attention(attention_maps)
print('N stars with white noise contribution smaller than 0.5:',(noise <= 0.5).sum().item())

In [None]:
### # plot diagnostic
from utils.postprocessing import plot_pred_diagnostic

#### All indices in their order
indices = range(len(X))
step = 1
indices = indices[::step]

##### Ordered selection wrt to noise contribution
# n_plots = 8 
# indices = np.argsort(noise.squeeze().detach().cpu())
# step = len(indices)//n_plots
# step = 2
# indices = indices[::step]


## Just a random sample
# indices = [np.random.randint(len(X))]

for i in indices:
    x = X[i,:,0].detach().cpu().numpy()
    y = Y[i,:,0].detach().cpu().numpy()
    mask = M[i,:,0].detach().cpu().numpy()
    info = {k:v[i].detach().cpu().item() for k,v in I.items()}
    y_pred = Y_pred[i,:,0]
    ar = AR[i].sum(1).detach().cpu().numpy()
    plot_pred_diagnostic(x, y, y_pred, mask=mask, ar=ar, targetid=info['targetid'], mu=info['mu'], sigma=info['sigma'])
    plt.show()


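### Plotting examples of attention
Appendix figure: predictions (red) over the targets (black points), with point size and opacity scaled by rollout attention.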

In [None]:
# Plot only attention - pred
# Grid of predictions over targets; target points are sized/shaded by the
# rollout attention they receive.
import os
import datetime

ncols = 4
nrows = 4
n = ncols * nrows

# Samples ordered by estimated white-noise level, keeping every 2nd one
# (the former `len(indices)//n` value was immediately overwritten by 2).
indices = np.argsort(noise.squeeze().detach().cpu().numpy())
step = 2
indices = indices[::step]
n_panels = min(n, len(indices))  # leave extra panels empty instead of IndexError

fig, ax = plt.subplots(nrows, ncols, figsize=(10,5), sharex=True, squeeze=False)
# Invisible full-figure axes that carries the shared axis labels.
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
plt.grid(False)
plt.ylabel('standardised flux')
plt.xlabel('time steps')

for k, axis in enumerate(ax.flat[:n_panels]):
    i = indices[k]
    y = Y[i,:,0].detach().cpu().numpy()
    y_pred = Y_pred[i,:,0]
    ar = AR[i].sum(1).detach().cpu().numpy()

    ma, Ma = np.min(ar), np.max(ar)
    span = (Ma - ma) if Ma > ma else 1.0  # avoid 0/0 when attention is flat
    ar_norm = (ar - ma) / span
    # matplotlib requires alpha in [0, 1]; the "+ ma" offset can push it above 1.
    alpha = np.clip(ar_norm / 1.002 + ma + 1e-5, 0.0, 1.0)
    s = (ar_norm + ma) * 20 + 1

    axis.scatter(range(len(y)), y, label='input',
                 color='black', s=s, alpha=alpha)
    axis.plot(y_pred, label='pred', color='red', lw=1, alpha=0.9)
    axis.set_yticks([])

# Filesystem-safe timestamp (no spaces/colons); create the output folder if needed.
out_dir = 'experiments/outputs'
os.makedirs(out_dir, exist_ok=True)
date = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
plt.savefig(f'{out_dir}/attention_preds_{date}.pdf', format="pdf", bbox_inches="tight")
plt.show()

# Baselines and metrics

### Batch eval

In [None]:
from utils.stats import estimate_noise
from utils.postprocessing import compute_rollout_attention

# Run on CUDA when available, otherwise on CPU.
device = 'cuda' if GPUS else 'cpu'

# Drop results of a previous run so their (possibly GPU-resident) memory can be
# reclaimed. Names are removed independently, so a partially defined set
# (e.g. X present but AR missing) no longer raises NameError.
for _stale in ('X', 'Y', 'M', 'I', 'Y_pred', 'attention_maps', 'AR'):
    globals().pop(_stale, None)

if GPUS:
    # Empty the allocator cache only after the old tensors were released.
    torch.cuda.empty_cache()
    print_cuda_summary()

lit_model.eval().to(device)

with torch.no_grad():
    # First batch of the test set (training transform view).
    X, Y, M, I = next(iter(test_loader1))
    Y_pred = lit_model(X.to(device)).cpu()
    noise = estimate_noise(Y)
    attention_maps = lit_model.get_attention_maps(X.to(device), mask=M.to(device))
    AR = compute_rollout_attention(attention_maps)
print('N stars with white noise contribution smaller than 0.5:',(noise <= 0.5).sum().item())

In [None]:
from utils.postprocessing import inverse_standardise_batch, detrend
from utils.stats import naniqr, compute_dw

# Map targets and predictions back through inverse_standardise_batch with the
# per-sample mu/sigma, then detrend the targets by the predictions.
mu, sigma = I['mu'], I['sigma']
Y_o = inverse_standardise_batch(Y, mu, sigma)
Y_pred_o = inverse_standardise_batch(Y_pred, mu, sigma)
Y_d = detrend(Y_o, Y_pred_o).numpy()

# Per-sample IQR of the detrended flux and compute_dw on the residuals around 1
# (presumably Durbin-Watson, hence the distance to 2 below).
iqr_dtst = naniqr(Y_d, dim=1)
dw_dtst = compute_dw(Y_d - 1, reduction='none')
dw_distance = np.abs(2 - dw_dtst)
(np.nanmean(iqr_dtst), np.nanmedian(iqr_dtst), np.nanmean(dw_distance))

In [None]:
import wotan
from models import predict_batch_wotan

# Sweep the wotan window length and report mean IQR, median IQR and mean
# |2 - DW| for the detrended batch.
# NOTE(review): the last two light curves of the batch are excluded ([:-2]);
# the reason is not visible here.
for window in [0.2, 0.3, 0.4, 0.5, 0.6, 0.7]:
    print('\t', window)
    Y_wotan_d, _ = predict_batch_wotan(Y_o.squeeze()[:-2], window_length=window)
    iqr_wotan = naniqr(Y_wotan_d, dim=1)
    dw_wotan = compute_dw(Y_wotan_d-1, reduction='none')
    # NaN-aware reductions, matching how the transformer's scores are summarised,
    # so a single NaN light curve does not turn the whole summary into NaN.
    print(np.nanmean(iqr_wotan), np.nanmedian(iqr_wotan), np.nanmean(np.abs(2-dw_wotan)))

In [None]:
# Plotting IQR scores
# plt.hist(iqr_dtst.flatten(), 50, range=(0, 0.05))
# plt.hist(iqr_wotan.flatten(), 50, range=(0, 0.05))

# pass

In [None]:
# ### plotting baseline
# Y_wotan_d, trend_wotan = predict_batch_wotan(Y_o, window_length=0.2)

# i = np.random.randint(len(Y))
# plt.plot(Y_o[i])
# plt.plot(trend_wotan[i])

# plot_pred_diagnostic(Y_o[i], Y_o[i], trend_wotan[i])

In [None]:
# a range of predictions
# Compare the IQR of the detrended flux for two wotan baselines and the
# transformer ("dtst") on individual samples of the current batch.

# Recompute the selection for the *current* batch: an `indices` left over from
# an earlier cell refers to a different batch and may be out of range here.
indices = np.argsort(noise.squeeze().detach().cpu().numpy())[::2]

for i in indices:
    y = Y[i,:,0].detach().cpu().numpy()
    info = {k:v[i].detach().cpu().item() for k,v in I.items()}
    # Y_pred may be a CPU tensor or a numpy array depending on which analysis
    # cell ran last; np.asarray makes the inputs below consistently numpy.
    y_pred = np.asarray(Y_pred[i,:,0])

    y_o = inverse_standardise_batch(y, info['mu'], info['sigma'])
    y_pred_o = inverse_standardise_batch(y_pred, info['mu'], info['sigma'])
    y_d = detrend(y_o, y_pred_o)

    # NOTE(review): /24 here vs /48 for the untransformed full light curves;
    # presumably because this batch went through DownSample(2) — TODO confirm.
    time = np.arange(len(y_d)) / 24
    y_d_biweight = wotan.flatten(time, y_o, method='biweight', return_trend=False)
    y_d_medfilt = wotan.flatten(time, y_o, window_length=49, method='medfilt', return_trend=False)

    print(f'IQR(biweight) : {naniqr(y_d_biweight):.5f}')
    print(f'IQR(medfilt) : {naniqr(y_d_medfilt):.5f}')
    print(f'IQR(dtst) : {naniqr(y_d):.5f}')
    print()

### Full light-curve (LC) predictions

In [None]:
# Predict all windows of one full test light curve, map them back to the
# original flux scale and detrend. Inputs follow the model's current device
# instead of assuming CUDA.
with torch.no_grad():
    X, Y, M, I = next(iter(loader_pred))
    Y_pred = lit_model(X.to(lit_model.device)).cpu()
    Y_o = inverse_standardise_batch(Y, I['mu'], I['sigma'])
    Y_pred_o = inverse_standardise_batch(Y_pred, I['mu'], I['sigma'])
    Y_d = detrend(Y_o, Y_pred_o).numpy()

# access original TS non transformed 
idx = I['idx'][0]
Y_intact = test_dataset.get_pretransformed_sample(idx)

from utils.postprocessing import fold_back
# Sanity check: folding the inverse-standardised windows back together must
# reproduce the untransformed light curve. `skip` is passed to fold_back
# (presumably points dropped at each window edge — TODO confirm).
skip = 25
seq_len = len(Y_intact)
Y_of = fold_back(Y_o, skip, seq_len=seq_len)
assert np.isclose(Y_intact.flatten(), Y_of.flatten(), equal_nan=True).all(), \
    "fold_back of the inverse-standardised windows does not match the original light curve"

In [None]:
# Show full pred with wotan pred too
# Overlay the folded light curve, the folded transformer prediction and a
# wotan trend (window_length=0.5 on the scaled time axis below).
Y_pred_of = fold_back(Y_pred_o, skip, seq_len=seq_len)

# NOTE(review): assumes 48 samples per day (30-min cadence) — TODO confirm.
time_days = np.arange(seq_len) / 48
Y_wotan_d, trend_wotan_of = wotan.flatten(time_days,
                                          Y_of,
                                          window_length=0.5,
                                          return_trend=True,
                                          )

fig, ax = plt.subplots(figsize=(25, 10))
ax.plot(Y_of, label='light curve')
ax.plot(Y_pred_of, label='transformer prediction')
ax.plot(trend_wotan_of, label='wotan trend (window 0.5)')
ax.set(xlabel='time steps', ylabel='flux', title=f'Full light curve (idx={idx})')
ax.legend()
plt.show()

### Full test dataset

In [None]:
from tqdm import tqdm
from datetime import datetime
from utils.postprocessing import eval_full_inputs

# Score the transformer on every full test light curve. 25 is the same skip
# value passed to fold_back above. The device follows GPU availability
# instead of a hard-coded 'cuda'.
device = 'cuda' if GPUS else 'cpu'
iqr, dw = eval_full_inputs(lit_model, loader_pred, test_dataset, 25, device)
print(iqr, dw)

In [None]:
# Produce full predictions for baselines
from tqdm import tqdm
from datetime import datetime

# wotan baselines: result key -> keyword arguments for wotan.flatten.
# (wotan reads medfilt's window_length as a number of points, biweight's in
# units of the time axis.) Add an entry here to benchmark another baseline.
baselines = {
    'biweight_0.5': dict(window_length=0.5, method='biweight'),
    'medfilt': dict(window_length=49, method='medfilt'),
}
d_pred = {name: [] for name in baselines}  # trends per light curve
d_time = {name: [] for name in baselines}  # wall-clock time per light curve

for X, Y, M, I in tqdm(loader_pred):
    # access original TS non transformed 
    idx = I['idx'][0]
    Y_intact = test_dataset.get_pretransformed_sample(idx).squeeze()
    seq_len = len(Y_intact)
    # Named time_days (not `time`) so the stdlib module name is not shadowed.
    time_days = np.arange(seq_len) / 48  # Get the actual time vector maybe

    for name, kwargs in baselines.items():
        t0 = datetime.now()
        trend = wotan.flatten(time_days, Y_intact, return_trend=True, **kwargs)[1]
        d_time[name].append(datetime.now() - t0)
        d_pred[name].append(trend)

# Stack trends to one row per light curve and average the per-curve timings.
for name in d_pred:
    if len(d_pred[name]):
        d_pred[name] = np.vstack(d_pred[name])

for name in d_time:
    if len(d_time[name]):
        d_time[name] = np.mean(d_time[name])

In [None]:
# Compute metrics
# Untransformed test light curves stacked one per row (assumes loader_pred,
# which is unshuffled, visits them in index order — TODO confirm).
target_test = np.vstack([test_dataset.get_pretransformed_sample(idx).squeeze() for idx in range(len(test_dataset))])
iqr = dict()
dw = dict()

# Score only the models whose trends were actually produced: hard-coding
# 'dtst' raised KeyError because it is disabled in the baseline loop.
for model, trend in d_pred.items():
    if not len(trend):
        continue
    # Divide each light curve by its trend; rows beyond those predicted are ignored.
    pred_d = target_test[:len(trend)] / trend
    iqr[model] = naniqr(pred_d, dim=1, reduction='mean')
    # NOTE(review): the other compute_dw calls in this notebook use
    # `reduction=`, not `reduce=`/`axis=` — confirm against utils.stats.compute_dw.
    dw[model] = np.abs(compute_dw(pred_d-1, axis=1, reduce='none')-2).mean()

#### Computation efficiency of various attention mechanisms

In [None]:
# # # size study
# list_seq_len = [100, 200,  400, 700, 1000, 2000]
# # for L in list_seq_len:
# #     torch.cuda.empty_cache()
# #     lit_model = LitImputer(n_dim, d_model=64, dim_feedforward=128,  num_layers=3, eye=eye, lr=0.001,
# #                        normal_ratio=0.2, keep_ratio=0., token_ratio=0.8, attention='linear', seq_len=L
# #                       )
    
# #     dataset.transform_both = Compose([RandomCrop(L),
# #                                       StandardScaler(dim=0),
# #                                      ])
# #     train_loader = DataLoader(dataset, batch_size=64, shuffle=True) 

# #     trainer = pl.Trainer(max_epochs=3, 
# #                          gpus=1, profiler='simple')

# #     result = trainer.fit(lit_model, 
# #                          train_dataloaders=train_loader,                     
# #                          )

# # result_full = [0.16705 , 0.24147, 0.51497, np.nan, np.nan, np.nan] #B256
# result_full = [0.2 , 0.27, 0.54, 1.2, np.nan, np.nan] #B64
# # result_prob = [ 0.30687, 0.33286, 0.38871, 0.54279, 0.68638, 1.1525 ] #B256
# result_prob = [0.47, 0.48, 0.53, 0.64, 0.81, 1.3] #B64

# ###result_lin = [0.25, 0.31, 0.46, 0.65, 0.87, 1.53] #B64
# result_lin = [0.21, 0.24, 0.32, 0.42, 0.52, 0.87]
# plt.plot(list_seq_len, result_full, marker='o', label='Full attention B64')#, marker='-o')
# plt.plot(list_seq_len, result_prob, marker='o', label='Prob attention B64')
# plt.plot(list_seq_len, result_lin, marker='o', label='Linformer attention B64')
# plt.legend()
# plt.xlabel('Sequence length')
# plt.ylabel('Training epoch time')
# print('batch size', train_loader.batch_size)

### Look at noise levels and computations

In [None]:
# from utils import nanstd
# X,_,_,_ = next(iter(train_loader))
# B, L, D = X.shape


# window = 10
# n_windows = L // window
# X.view(B, n_windows, window).shape

# noise = nanstd(X.view(B, n_windows, window), -1, keepdim=True).nanmedian(1, keepdim=True).values

# # Samples highlighted by non-white-noise contribution to the variance
# plt.figure(figsize=(30,20))
# for i in range(len(X)):
#     plt.plot(X[i,:,0], alpha=1-noise[i,0,0].item()**1.5, lw=1/noise[i,0,0])
# plt.title('Batch highlighted by inverse noise')
# plt.legend()
# plt.show()

# # Distribution of noise estimates
# plt.hist(noise[:,0,0].numpy(), 50)
# plt.title('White Noise Level')
# plt.show()

# # After correction
# plt.figure(figsize=(30,20))
# plt.plot((X / noise)[:,:,0].T)
# plt.show()

# # Check of better noise estimate

# def rolling_std(x, width=10):
#     return pd.Series(x).rolling(width, center=True, min_periods=1).std().values

# better_noise = []
# for i in range(len(X)):
#     better_noise += [np.nanmedian(rolling_std(X[i,:,0].numpy(), width=10))]
# plt.scatter(noise, better_noise)

# # dataset_temp = TessDataset(train_path)


# # x, y, m, i = dataset_temp[np.random.randint(len(dataset_temp))]
# # plt.plot((x - np.nanmedian(x))/np.nanmedian(x))
# pass

#### Attention visualisation (development)


In [None]:

# # attention_maps = lit_model.get_attention_maps(X.cuda(), mask=M.cuda())
# # ar = compute_rollout_attention(attention_maps)



# plt.show()
# plt.plot(ar)#.sum(0))

# plt.show()
# for l in range(len(attention_maps)):
#     plt.plot(attention_maps[l][i].mean(0).detach().cpu())

# chunk = range(290, 295)
# attention_maps[0][i, chunk].shape
# plt.show()
# for l in range(len(attention_maps)):
#     plt.plot(attention_maps[l][i, :, chunk].mean(-1).cpu().detach())
# plt.show()
# for l in range(len(attention_maps)):
#     plt.plot(attention_maps[l][i, chunk].mean(0).cpu().detach())

In [None]:
# from utils.stats import estimate_noise

# X,_,_,_ = next(iter(train_loader))

# noise = estimate_noise(X, reduce='nanmedian')
# # plt.hist(noise[9].flatten().numpy(), 50)

# print(torch.isnan(X/noise).sum(), torch.isnan(X).sum())
# # plt.plot((X/np.sqrt(noise))[:,:,0].T)