In [None]:
import os


import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import torch
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl 

GPUS = int(torch.cuda.is_available())
# torch.cuda.empty_cache() 
def print_cuda_summary(device=0):
    """Print a short memory report for one CUDA device.

    Args:
        device: CUDA device to inspect. Defaults to 0, the device this
            notebook has always reported on.

    Prints, in units of 1024**3 bytes: the total device memory, the memory
    currently allocated and reserved by torch, the peak reserved memory and the
    reserved-but-unallocated remainder.
    """
    gib = 1024 ** 3
    total = torch.cuda.get_device_properties(device).total_memory
    # Query every counter exactly once so the derived "free" figure is
    # consistent with the values printed next to it.
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)
    peak_reserved = torch.cuda.max_memory_reserved(device)
    free_in_reserved = reserved - allocated  # free inside reserved

    print("torch.cuda.get_device_properties(%s).total_memory %fGB" % (device, total / gib))
    print("torch.cuda.memory_allocated: %fGB" % (allocated / gib))
    print("torch.cuda.memory_reserved: %fGB" % (reserved / gib))
    print("torch.cuda.max_memory_reserved: %fGB" % (peak_reserved / gib))
    print("Free (res - alloc) %fGB" % (free_in_reserved / gib))

if GPUS:
    print_cuda_summary()

In [None]:
# Dataset / loader configuration read by the cells below.
max_samples = 1000  # passed as `max_samples` to the training TessDataset
val_ratio = 0.1  # fraction of the loaded training set held out for validation
batch_size = 64  # batch size of the train / val / test DataLoaders


seed = 0  # seeds the torch.Generator used for the train/val index split
standardise_axes = (0, 1)  # per sample standardisation
# NOTE(review): standardise_axes is not read anywhere else in the visible notebook.

num_workers = 0  # forwarded to every DataLoader
pin_memory = True  # forwarded to every DataLoader

if num_workers > 0:
    # OpenCV is only configured when DataLoader worker processes are requested.
    # setNumThreads(0) turns off OpenCV's own threading — presumably to avoid
    # clashes with the worker processes; confirm this is still needed.
    import cv2
    cv2.setNumThreads(0) 

In [None]:
torch.cuda.empty_cache() 

### TESS DATA

In [None]:
# Root directory of the TESS light curves. The two machine-specific defaults are
# kept, but they can now be overridden without editing the notebook by exporting
# TESS_DATA_DIR before starting the kernel.
if GPUS:
    default_path = "/state/partition1/mmorvan/data/TESS/lightcurves/0001"
else:
    default_path = "/Users/mario/data/TESS/lightcurves/0027"
path = os.environ.get("TESS_DATA_DIR", default_path)

# Sub-directories holding the train and test splits.
train_path = os.path.join(path, 'processed_train')
test_path = os.path.join(path, 'processed_test')

In [None]:
# # split generation
# import os
# from datasets import TessDataset

# test_ratio = 0.2
    
# dataset = TessDataset(path, 
#                       processed=True, 
#                       save=False
#                       )

# #dataset.n_dim = 1
# # TRAIN/VAL SPLIT
# test_size = int(test_ratio * len(dataset))
# train_size = len(dataset) - test_size
# print(train_size, test_size)
# train_dataset, test_dataset = random_split(dataset, 
#                                           (train_size, test_size),
#                                            generator=torch.Generator().manual_seed(seed))

# for idx in train_dataset.indices:
#     dataset.save_item(idx, train_path)
# for idx in test_dataset.indices:
#     dataset.save_item(idx, test_path)
# #train_dataset.indices

# %%bash
# ls "/state/partition1/mmorvan/data/TESS/lightcurves/0001/processed_test" 

In [None]:
# from datasets import TessDataset

# test_dataset = TessDataset(test_path, load_processed=True)
# train_dataset = TessDataset(train_path, load_processed=True)
# len(test_dataset), len(train_dataset)

In [None]:

from datasets.kepler_tess import TessDataset, Subset, split_indices

from transforms import  Compose,StandardScaler, AddGaussianNoise, Mask, FillNans, RandomCrop, DownSample


# Training-time transform chain: RandomCrop -> DownSample -> Mask -> StandardScaler.
transform_both_train = Compose([RandomCrop(800, exclude_missing_threshold=0.8),
                                DownSample(2),
                                Mask(0.3, block_len=None, value=None, exclude_mask=True),
                                StandardScaler(dim=0),
                                #FillNans(0),
                                ])

# Same chain without the Mask step; used for the second validation / test view.
transform_both_2 = Compose([RandomCrop(800, exclude_missing_threshold=0.8),
                            DownSample(2),
                            # Mask(0.3, block_len=None, value=None, exclude_mask=True),
                            StandardScaler(dim=0),
                            #FillNans(0),
                            ])

transform = None

# `train_path` / `test_path` were already derived from `path` in the data-path cell,
# so the duplicated `path` selection that used to sit here was dropped.
dataset = TessDataset(train_path,
                      load_processed=True,
                      max_samples=max_samples,
                      transform=transform,
                      transform_both=transform_both_train,
                      use_cache=True,
                      )
test_dataset = TessDataset(test_path,
                           load_processed=True,
                           use_cache=True,
                           )


#dataset.n_dim = 1
# TRAIN/VAL SPLIT
val_size = int(val_ratio * len(dataset))
train_size = len(dataset) - val_size
train_indices, val_indices = split_indices((train_size, val_size),
                                           generator=torch.Generator().manual_seed(seed))
train_dataset = Subset(dataset, train_indices)

# Two views of the same validation indices: the dataset's own transform chain,
# and the chain without masking.
val_dataset1 = Subset(dataset, val_indices)
val_dataset2 = Subset(dataset, val_indices, replace_transform_both=transform_both_2)

# Two views of the full test set. Both wrap the base `test_dataset`: the previous
# code read `test_dataset1` before it was assigned, which raised NameError.
test_indices = range(len(test_dataset))
test_dataset1 = Subset(test_dataset, test_indices, replace_transform_both=transform_both_train)
test_dataset2 = Subset(test_dataset, test_indices, replace_transform_both=transform_both_2)


# Options shared by every loader; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader1 = DataLoader(val_dataset1, shuffle=False, **loader_kwargs)
val_loader2 = DataLoader(val_dataset2, shuffle=False, **loader_kwargs)
test_loader1 = DataLoader(test_dataset1, shuffle=False, **loader_kwargs)
test_loader2 = DataLoader(test_dataset2, shuffle=False, **loader_kwargs)



In [None]:
from torch import nn

class CollatePred(object):
    """Collate function that cuts every series of a batch into sliding windows.

    Each batch item is an ``(x, y, m, info)`` tuple. ``x``, ``y`` and ``m`` are
    unfolded along their first axis into windows of length ``window`` taken every
    ``step`` samples, and the windows of all items are concatenated along the
    first axis of the output.

    The pad/transpose logic assumes 2-D ``(seq_len, n_features)`` inputs, in which
    case each output is ``(n_windows, window, n_features)`` — TODO confirm against
    the dataset's item format.

    Args:
        window: number of samples per window.
        step: offset between the starts of two consecutive windows.
        standardise: accepted for backward compatibility; currently unused.
    """
    # May need to pad outside to ensure the totality is contained in output
    def __init__(self, window, step=1, standardise=None):
        self.window = window
        self.step = step
        # Fitted on the x windows of each item, then re-applied to its y windows.
        self.scaler = StandardScaler(dim=1)

    def __call__(self, batch):
        """Window, rescale and concatenate a list of ``(x, y, m, info)`` items.

        Returns:
            ``(X, Y, M, info)`` where ``info`` maps every key of the items' ``info``
            dicts (repeated once per window) plus ``'left_crop'`` (start index of
            each window in its series), ``'mu'`` and ``'sigma'`` (the scaler's
            ``centers`` and ``norms``) to tensors concatenated over all items.
        """
        x_out_list = []
        y_out_list = []
        m_out_list = []
        info_out_list = []

        for x, y, m, info in batch:
            x_out = torch.tensor(x)
            y_out = torch.tensor(y)
            m_out = torch.tensor(m)

            # Decide the padding per item (not from the first item only) and from the
            # instance's own window/step rather than same-named notebook globals.
            seq_len = x_out.shape[0]
            n_exact = (seq_len - self.window) / self.step + 1
            if int(n_exact) != n_exact:
                # The windows do not tile the series exactly: append `step` NaN
                # samples so that the last real samples fall inside a window.
                pad = nn.ConstantPad1d((0, self.step), value=np.nan)
                x_out = pad(x_out.T).T
                y_out = pad(y_out.T).T
                m_out = pad(m_out.T).T

            x_out = x_out.unfold(0, size=self.window, step=self.step).transpose(1, 2)
            y_out = y_out.unfold(0, size=self.window, step=self.step).transpose(1, 2)
            m_out = m_out.unfold(0, size=self.window, step=self.step).transpose(1, 2)
            n_windows = len(x_out)

            info_out = {k: torch.tensor([v] * n_windows) for k, v in info.items()}

            x_out = self.scaler.fit_transform(x_out)
            y_out = self.scaler.transform(y_out)
            # One start index per window actually produced: window k starts at
            # k * step. (arange(0, seq_len, step) could yield a different count.)
            info_out['left_crop'] = torch.arange(n_windows) * self.step
            info_out['mu'] = self.scaler.centers
            info_out['sigma'] = self.scaler.norms

            x_out_list.append(x_out)
            y_out_list.append(y_out)
            m_out_list.append(m_out)
            info_out_list.append(info_out)

        info_keys = list(info_out_list[0])
        return (torch.cat(x_out_list),
                torch.cat(y_out_list),
                torch.cat(m_out_list),
                {k: torch.cat([item_info[k] for item_info in info_out_list]) for k in info_keys})

# NOTE(review): `test_dataset_predict` was never assigned anywhere in the notebook,
# so this cell raised NameError on a fresh kernel. The later cells compare the
# stitched windows against `test_dataset.get_pretransformed_sample(idx)`, so the
# untransformed `test_dataset` is assumed to be the intended source — confirm.
test_dataset_predict = test_dataset

# One full series per batch; CollatePred cuts it into windows of 400 samples
# taken every 350 samples.
loader_pred = DataLoader(test_dataset_predict,
                         batch_size=1,
                         shuffle=False,
                         collate_fn=CollatePred(400, step=350),
                         num_workers=num_workers, pin_memory=pin_memory)

X, Y, M, I = next(iter(loader_pred))
X.shape, Y.shape, M.shape
# X.dtype, Y.dtype, M.dtype

In [None]:
x.T.shape

In [None]:
Y[:,25:-25].flatten().shape

In [None]:
## Stitching back
# X_o = inverse_standardise_batch(X, I['mu'], I['sigma'])

# plt.figure(figsize=(25,5))
# for i in range(len(X)):
#     x = X_o[i]
#     plt.plot(range(I['left_crop'][i], I['left_crop'][i]+400),
#              x)

In [None]:
x, y, m, meta = dataset[0]

plt.plot(x)
plt.plot(y)

## Noise Study

In [None]:
# from utils import nanstd
# X,_,_,_ = next(iter(train_loader))
# B, L, D = X.shape


# window = 10
# n_windows = L // window
# X.view(B, n_windows, window).shape

# noise = nanstd(X.view(B, n_windows, window), -1, keepdim=True).nanmedian(1, keepdim=True).values

# # Samples highlighted by non-white-noise contribution to the variance
# plt.figure(figsize=(30,20))
# for i in range(len(X)):
#     plt.plot(X[i,:,0], alpha=1-noise[i,0,0].item()**1.5, lw=1/noise[i,0,0])
# plt.title('Batch highlighted by inverse noise')
# plt.legend()
# plt.show()

# # Distribution of noise estimates
# plt.hist(noise[:,0,0].numpy(), 50)
# plt.title('White Noise Level')
# plt.show()

# # After correction
# plt.figure(figsize=(30,20))
# plt.plot((X / noise)[:,:,0].T)
# plt.show()

# # Check of better noise estimate

# def rolling_std(x, width=10):
#     return pd.Series(x).rolling(width, center=True, min_periods=1).std().values

# better_noise = []
# for i in range(len(X)):
#     better_noise += [np.nanmedian(rolling_std(X[i,:,0].numpy(), width=10))]
# plt.scatter(noise, better_noise)

# # dataset_temp = TessDataset(train_path)


# # x, y, m, i = dataset_temp[np.random.randint(len(dataset_temp))]
# # plt.plot((x - np.nanmedian(x))/np.nanmedian(x))
# pass

In [None]:
# from utils.stats import estimate_noise

# X,_,_,_ = next(iter(train_loader))

# noise = estimate_noise(X, reduce='nanmedian')
# # plt.hist(noise[9].flatten().numpy(), 50)

# print(torch.isnan(X/noise).sum(), torch.isnan(X).sum())
# # plt.plot((X/np.sqrt(noise))[:,:,0].T)

### Dummy Data

In [None]:
# from datasets import DummyDataset

# dataset = DummyDataset(100)


In [None]:
# x, y, m, meta = dataset[0]

# plt.plot(x)
# plt.plot(y)

### Model

In [None]:
# ### TEST RUN
# lit_model = LitImputer(1, noise_scaling=True)
# trainer = pl.Trainer(max_epochs=10, gpus=GPUS)
# result = trainer.fit(lit_model, train_dataloaders=train_loader)

In [None]:
# lit_model = LitImputer(1, noise_scaling=True)
# from utils.stats import estimate_noise
# from models.loss import MaskedMSELoss
# criterion = MaskedMSELoss()

# for X, Y, M, Info in train_loader:
#     Y_pred = lit_model(X)
#     noise = torch.sqrt(estimate_noise(Y))
#     assert not noise.requires_grad
#     assert not torch.isclose(noise, torch.zeros_like(noise), equal_nan=True).any()
#     noise[torch.isnan(noise)] = 1.

#     loss = criterion(Y_pred, Y, M)
#     assert not torch.isnan(loss)
    
#     pred_proxy = Y_pred / noise
#     Y_proxy = Y/noise
# #     torch.isnan(Y_pred).sum(), torch.isnan(Y).sum(), torch.isnan(pred_proxy).sum(), torch.isnan(Y_proxy).sum()
#     assert not torch.isinf(pred_proxy).any()
#     loss_scaled = criterion(pred_proxy, Y_proxy, M)
    
#     assert not torch.isnan(loss_scaled)

In [None]:
# ### TEST volume
# from pytorch_lightning.loggers import NeptuneLogger

# logger = NeptuneLogger(project="denoising-transformer",
#                            name='test')
# torch.manual_seed(0)


# lit_model = LitImputer(n_dim=1, d_model=8, dim_feedforward=16, num_layers=1,
# #                        attention='linear', seq_len=400,
#                        random_ratio=1, zero_ratio=0., keep_ratio=0.,token_ratio=0, 
#                        #, token_ratio=0.8, 
#                        noise_scaling="true",
# #                        attention='linear', seq_len=400
#                       )

# trainer = pl.Trainer(max_epochs=10000, 
#                      gpus=GPUS,
#                      logger=logger,
#                      check_val_every_n_epoch=1)

# result = trainer.fit(lit_model, 
#                      train_dataloaders=train_loader,
#                      val_dataloaders=val_loader, 
#                      )


### Define Model

In [None]:
from models import LitImputer
torch.manual_seed(0)
if GPUS:
    print_cuda_summary()

In [None]:
lit_model = LitImputer(n_dim=1, d_model=64, dim_feedforward=128, lr=0.001,
#                        attention='linear', seq_len=400,
                       train_unit = 'noise', train_loss='mae'
                       #, token_ratio=0.8, 
                       #noise_scaling="true",
#                        attention='linear', seq_len=400
                      )

In [None]:
# X, Y, M, I = next(iter(train_loader))

# X_masked, token_mask = lit_model.apply_mask(X, M)
# X_masked[token_mask.bool()] = np.nan
# # i = np.random.randint(len(X_masked))
# plt.figure(figsize=(12,6))
# #plt.scatter(range(400),X[i].detach(), s=50, alpha=0.6)
# plt.scatter(range(400), X_masked[i].detach(), s=5, color='red')
# plt.fill_between(range(len(M[i])), -3, 3, where=M[i].flatten(), alpha=0.3)
# plt.fill_between(range(len(M[i])), -3, 3, where=token_mask[i].flatten())


### Training

In [None]:
import datetime
from pytorch_lightning.loggers import NeptuneLogger
logger = NeptuneLogger(project="denoising-transformer",
                        name='tess_denoising',
                       tags=[str(len(dataset))+' samples',
                             #"continued_from_den-186"
                             #"large-training-set",
                             "test fix Masked Loss",
                             "train - " + lit_model.train_unit,
                             #"star-scaled"
                        #      "Noise2Noise",
                        #      "Imputation",
                        #      "cropped-800",
                        #      "downsampled-2",
                        #      'linformer'
                             #"replaced_normal",
                             #"Crop-400", 
#                              f"patch-{patch_size}",
#                              "TPT"
                            ])

In [None]:
# from pytorch_lightning import seed_everything
# seed_everything(1)

trainer = pl.Trainer(max_epochs=5, 
                     logger=logger, 
                     gpus=GPUS,
                     check_val_every_n_epoch=1)

result = trainer.fit(lit_model, 
                     train_dataloaders=train_loader,
                     val_dataloaders=[val_loader1, val_loader2], 
                     )



In [None]:
# Continue training with the same loaders as the first fit call. The previous
# `val_loader` name was never defined; the notebook only builds val_loader1 and
# val_loader2.
result = trainer.fit(lit_model, 
                     train_dataloaders=train_loader,
                     val_dataloaders=[val_loader1, val_loader2], 
                     )

In [None]:
# Evaluate on both test views (masked and unmasked transform chains). The previous
# `test_loader` name was never defined; the notebook only builds test_loader1 and
# test_loader2.
trainer.test(lit_model, 
             dataloaders=[test_loader1, test_loader2]
                     )

### Loading

In [None]:
%%bash
find . -mindepth 3 -maxdepth 3 -type d -name "DEN-293"

In [None]:
# # # ckpt_path = "/home/mmorvan/denoising-ts-transformer/.neptune/Tess-denoising17-02-2022_01-34-35/DEN-167/checkpoints/epoch=382-step=4992.ckpt"
# # ckpt_path = "./.neptune/Tess-denoising_17-02-2022_02-19-37/DEN-170/checkpoints/epoch=1032-step=15494.ckpt"
# ckpt_path = "./.neptune/Tess-denoising17-02-2022_11-11-26/DEN-186/checkpoints/epoch=601-step=15062.ckpt"
# ckpt_path = "./.neptune/Tess-denoising17-02-2022_11-41-36/DEN-189/checkpoints/epoch=636-step=15924.ckpt"
# # ckpt_path = "./.neptune/tess_denoising/DEN-219/checkpoints/epoch=2999-step=86999.ckpt"
# ckpt_path = './.neptune/tess_denoising/DEN-245/checkpoints/epoch=1382-step=20744.ckpt'
#ckpt_path = "./.neptune/tess_denoising/DEN-258/checkpoints/epoch=2999-step=44999.ckpt"
# ckpt_path = './.neptune/tess_denoising/DEN-264/checkpoints/epoch=2999-step=44999.ckpt'
# ckpt_path = "./.neptune/tess_denoising/DEN-272/checkpoints/epoch=4999-step=74999.ckpt"  # MAE
#ckpt_path = "./.neptune/tess_denoising/DEN-289/checkpoints/epoch=4999-step=74999.ckpt"
#ckpt_path = "./.neptune/tess_denoising/DEN-294/checkpoints/epoch=4999-step=74999.ckpt"
#ckpt_path = "./.neptune/tess_denoising/DEN-295/checkpoints/epoch=1696-step=25454.ckpt"
#ckpt_path = "./.neptune/tess_denoising/DEN-296/checkpoints/epoch=4300-step=64514.ckpt"
# Only one checkpoint is active at a time: the DEN-272 assignment used to be live
# but was immediately overwritten by the line below, so it is now commented out
# like the other candidates.
ckpt_path = "./.neptune/tess_denoising/DEN-299/checkpoints/epoch=4999-step=74999.ckpt"
# load_from_checkpoint is a classmethod that builds a new model; call it on the
# class instead of on the existing instance.
lit_model = LitImputer.load_from_checkpoint(ckpt_path)

# Analysis

In [None]:
print_cuda_summary()
torch.cuda.empty_cache()

In [None]:
from utils.stats import estimate_noise
from utils.postprocessing import compute_rollout_attention
lit_model.eval().cuda()

# del X, Y, M, I, AR
with torch.no_grad():
    X, Y, M, I = next(iter(loader_pred))
    Y_pred = lit_model(X.cuda()).detach().cpu().numpy()
    noise = estimate_noise(Y)
    attention_maps = lit_model.get_attention_maps(X.cuda(), mask=M.cuda())
    AR = compute_rollout_attention(attention_maps)
(noise <= 0.5).sum()

In [None]:
### # plot diagnostic
from utils.postprocessing import plot_pred_diagnostic

#### All indices in their order
indices = range(len(X))
step = 1

##### Ordered selection wrt to noise contribution
# n_plots = 8 
# indices = np.argsort(noise.squeeze().detach().cpu())
# step = len(indices)//n_plots
# step = 2
# step = 1


indices = indices[::step]

# i = np.argmin(noise.numpy().squeeze())
# #i = np.argmax(estimate_noise(Y).numpy().squeeze())
# i = np.random.randint(len(X))

for i in indices:
    x = X[i,:,0].detach().cpu().numpy()
    y = Y[i,:,0].detach().cpu().numpy()
    mask = M[i,:,0].detach().cpu().numpy()
    info = {k:v[i].detach().cpu().item() for k,v in I.items()}
    y_pred = Y_pred[i,:,0]
    ar = AR[i].sum(1).detach().cpu().numpy()
    plot_pred_diagnostic(x, y, y_pred, mask=mask, ar=ar, targetid=info['targetid'], mu=info['mu'], sigma=info['sigma'])
    plt.show()


In [None]:
# ### Attention visu dev

# # attention_maps = lit_model.get_attention_maps(X.cuda(), mask=M.cuda())
# # ar = compute_rollout_attention(attention_maps)



# plt.show()
# plt.plot(ar)#.sum(0))

# plt.show()
# for l in range(len(attention_maps)):
#     plt.plot(attention_maps[l][i].mean(0).detach().cpu())

# chunk = range(290, 295)
# attention_maps[0][i, chunk].shape
# plt.show()
# for l in range(len(attention_maps)):
#     plt.plot(attention_maps[l][i, :, chunk].mean(-1).cpu().detach())
# plt.show()
# for l in range(len(attention_maps)):
#     plt.plot(attention_maps[l][i, chunk].mean(0).cpu().detach())

### Plotting examples of attention
Appendix

In [None]:
# Plot only attention - pred

# Grid layout of the figure: nrows x ncols panels, one window per panel.
ncols = 4
nrows = 4
n = ncols * nrows
# Window indices ordered by increasing estimated noise.
indices = np.argsort(noise.squeeze().detach().cpu())
# NOTE(review): the stride computed here (spread n panels over the whole noise
# range) is immediately overwritten by the fixed stride of 2 on the next line.
step = len(indices)//n
step = 2
# The plotting loop below reads indices[k] for k in range(n), so with a stride of 2
# at least 2*n - 1 windows must be available or it raises IndexError.
indices = indices[::step]

fig, ax = plt.subplots(nrows, ncols, figsize=(10,5), sharex=True)
fig.add_subplot(111, frameon=False)
# hide tick and tick label of the big axes
#fig.tight_layout(w_pad=-1)
#plt.tight_layout(pad=0., w_pad=0.5, h_pad=1.0)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
plt.grid(False)
plt.ylabel('standardised flux')
plt.xlabel('time steps')

k = 0
for row in range(nrows):
    for col in range(ncols):
        i = indices[k]
        
        x = X[i,:,0].detach().cpu().numpy()
        y = Y[i,:,0].detach().cpu().numpy()
        mask = M[i,:,0].detach().cpu().numpy()
        info = {k:v[i].detach().cpu().item() for k,v in I.items()}
        y_pred = Y_pred[i,:,0]
        ar = AR[i].sum(1).detach().cpu().numpy()
        
#         ax[row, col].set_title('Prediction')
#         ax[row, col].set_ylabel('stand. flux')

        ma, Ma = np.min(ar), np.max(ar)
        alpha = (ar-ma)/(Ma-ma)/1.002 + ma + 1e-5
        s = (((ar-ma)/(Ma-ma)+ma)) * 20 + 1
        
        ax[row, col].scatter(range(len(x)), y, label='input',
                         color='black', s=s, alpha=alpha)

        ax[row, col].plot(y_pred, label='pred', color='red', lw=1, alpha=0.9)
        ax[row, col].set_yticks([])
        k += 1

# hide tick and tick label of the big axes
#plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)


#ax[2,0].set_ylabel('standardised flux')
#ax[0,0].set_xlabel('time steps')
        
import datetime
date = datetime.datetime.now()
plt.savefig(f'experiments/outputs/attention_preds_{date}.pdf', format="pdf", bbox_inches="tight")
plt.show()

# Baselines and metrics

### Batch eval

In [None]:
from utils.stats import estimate_noise
from utils.postprocessing import compute_rollout_attention
lit_model.eval().cuda()

# del X, Y, M, I, AR
with torch.no_grad():
    X, Y, M, I = next(iter(loader_pred))
    Y_pred = lit_model(X.cuda()).cpu()
    noise = estimate_noise(Y)
    attention_maps = lit_model.get_attention_maps(X.cuda(), mask=M.cuda())
    AR = compute_rollout_attention(attention_maps)
(noise <= 0.5).sum()

In [None]:
from utils.postprocessing import inverse_standardise_batch, detrend
from utils.stats import naniqr, compute_dw

Y_o = inverse_standardise_batch(Y, I['mu'], I['sigma'])
Y_pred_o = inverse_standardise_batch(Y_pred, I['mu'], I['sigma'])
Y_d = detrend(Y_o, Y_pred_o).numpy()

iqr_dtst = naniqr(Y_d, dim=1)
dw_dtst = compute_dw(Y_d-1, reduce='none')
iqr_dtst.mean(), np.median(iqr_dtst), np.abs(2-dw_dtst).mean()

In [None]:
# compute_dw(Y_d[0][~np.isnan(Y_d[0])]-1, axis=0), compute_dw(Y_d[0]-1, axis=0), 

In [None]:
import wotan

def predict_batch_wotan(y, cadence=1/24, method='biweight', **kwargs):
    """Detrend every series of a batch with ``wotan.flatten``.

    Args:
        y: batch of series in flux units (not standardised), given as
            ``(batch, seq_len)`` or with extra singleton axes such as
            ``(batch, seq_len, 1)``. Tensors and array-likes are accepted.
        cadence: spacing between consecutive samples, used to build the time
            vector ``arange(seq_len) * cadence`` passed to wotan.
        method: forwarded to ``wotan.flatten``.
        **kwargs: forwarded to ``wotan.flatten`` (e.g. ``window_length``).

    Returns:
        ``(flattened, trend)``: two numpy arrays stacking wotan's flattened
        series and trend for every item of the batch.
    """
    # y in flux units!
    y = torch.as_tensor(y).detach().cpu()
    if y.dim() > 2:
        # Drop singleton axes but never the batch axis: a bare squeeze() also
        # collapsed a batch of one and broke the shape unpacking below.
        y = y.reshape(y.shape[0], *[size for size in y.shape[1:] if size != 1])
    y = y.numpy()
    batch_size, len_seq = y.shape

    time = np.arange(len_seq) * cadence
    list_flat = []
    list_trend = []
    for flux in y:
        flattened_y, trend_y = wotan.flatten(time, flux, method=method, return_trend=True, **kwargs)
        list_flat.append(flattened_y)
        list_trend.append(trend_y)

    return np.stack(list_flat), np.stack(list_trend)

for window in [0.2, 0.3, 0.4, 0.5, 0.6, 0.7]:
    print('\t', window)
    Y_wotan_d, _ = predict_batch_wotan(Y_o, window_length=window)
    iqr_wotan = naniqr(Y_wotan_d, dim=1)
    dw_wotan = compute_dw(Y_wotan_d-1, reduce='none')
    print(iqr_wotan.mean(), np.median(iqr_wotan),  np.abs(2-dw_wotan).mean())

In [None]:
plt.hist((iqr_wotan - iqr_dtst.squeeze()), 100)
plt.show()
plt.hist((np.abs(2-dw_wotan) - np.abs(2-dw_dtst.squeeze())), 100)
pass

In [None]:
Y_wotan_d, trend_wotan = predict_batch_wotan(Y_o, window_length=0.2)

i = np.random.randint(len(Y))
plt.plot(Y_o[i])
plt.plot(trend_wotan[i])

plot_pred_diagnostic(Y_o[i], Y_o[i], trend_wotan[i])

In [None]:
plt.hist(iqr_dtst.flatten(), 50, range=(0, 0.05))
plt.hist(iqr_wotan.flatten(), 50, range=(0, 0.05))

pass

In [None]:


for i in indices:
    x = X[i,:,0].detach().cpu().numpy()
    y = Y[i,:,0].detach().cpu().numpy()
    mask = M[i,:,0].detach().cpu().numpy()
    info = {k:v[i].detach().cpu().item() for k,v in I.items()}
    y_pred = Y_pred[i,:,0]

    y_o = inverse_standardise_batch(y, info['mu'], info['sigma'])
    y_pred_o = inverse_standardise_batch(y_pred, info['mu'], info['sigma'])
    y_d = detrend(y_o, y_pred_o)
    
    time = np.arange(len(y_d)) / 24
    y_d_biweight = wotan.flatten(time, y_o, method='biweight', return_trend=False)
    y_d_medfilt = wotan.flatten(time, y_o, window_length=49, method='medfilt', return_trend=False)
    
    print(f'IQR(biweight) : {naniqr(y_d_biweight):.5f}')
    print(f'IQR(medfilt) : {naniqr(y_d_medfilt):.5f}')
    print(f'IQR(dtst) : {naniqr(y_d):.5f}')
    print()

### Full-LC predictions

In [None]:
with torch.no_grad():
    X, Y, M, I = next(iter(loader_pred))
    Y_pred = lit_model(X.cuda()).cpu()
    Y_o = inverse_standardise_batch(Y, I['mu'], I['sigma'])
    Y_pred_o = inverse_standardise_batch(Y_pred, I['mu'], I['sigma'])
    Y_d = detrend(Y_o, Y_pred_o).numpy()


In [None]:
# access original TS non transformed 
idx = I['idx'][0]
Y_intact = test_dataset.get_pretransformed_sample(idx)


In [None]:
def fold_back(x, skip=0, seq_len=None):
    """Stitch overlapping windows back into a single 1-D series.

    Args:
        x: windows stacked along the first axis, window samples along the second
            (torch tensor or numpy array).
        skip: number of samples discarded at both edges of every window. The
            first window keeps its leading ``skip`` samples and the last window
            keeps its trailing ``skip`` samples, so nothing is lost at either end
            of the series. The pieces only line up when consecutive windows
            overlap by ``2 * skip`` samples.
        seq_len: if given, the result is truncated to this length (used to drop
            padding appended before windowing).

    Returns:
        The flattened series, of the same array type as ``x``.
    """
    if skip == 0:
        out = x.flatten()
    else:
        pieces = [x[0, :skip].flatten(),      # head of the first window
                  x[:, skip:-skip].flatten(),  # central part of every window
                  x[-1, -skip:].flatten()]     # tail of the last window (was dropped before)
        if isinstance(x, torch.Tensor):
            out = torch.cat(pieces)
        else:
            # Covers numpy arrays and any other array-like; previously non-ndarray
            # inputs silently came back as a list of pieces.
            out = np.concatenate([np.asarray(piece) for piece in pieces])
    if seq_len is not None:
        out = out[:seq_len]
    return out

In [None]:
# check same as inverse transfo 
skip = 25
seq_len = len(Y_intact)
Y_of = fold_back(Y_o, skip, seq_len=seq_len)
assert np.isclose(Y_intact.flatten(), Y_of.flatten(), equal_nan=True).all()

In [None]:
# Show full pred with wotan pred too

Y_pred_of = fold_back(Y_pred_o, skip, seq_len=seq_len)

Y_wotan_d, trend_wotan_of = wotan.flatten(np.arange(seq_len) / 48,
                                          Y_of,
                                          window_length=0.5,
                                          return_trend = True
                                          )
plt.figure(figsize=(25,10))
plt.plot(Y_of)
plt.plot(Y_pred_of)
plt.plot(trend_wotan_of)
# plt.plot(Y_of/Y_pred_of)
#plt.xlim(3000,7000)

In [None]:
# metrics
#naniqr(Y_wotan_d), compute_dw(Y_wotan_d-1, 0)
#naniqr(Y_of/Y_pred_of), compute_dw((Y_of- Y_pred_of).numpy(), 0),

In [None]:
# Produce full predictions
from tqdm import tqdm
from datetime import datetime

d_pred = {'dtst': [],
          'biweight_0.5': [],
          'medfilt': [],
            }
d_time = {'dtst': [],
          'biweight_0.5': [],
          'medfilt': [],
            }

k = 0
for X, Y, M, I in tqdm(loader_pred):
    k+=1
#     if k> 20:
#         break
    # access original TS non transformed 
    idx = I['idx'][0]
    Y_intact = test_dataset.get_pretransformed_sample(idx).squeeze()
    seq_len = len(Y_intact)
    time = np.arange(seq_len) / 48  # Get the actual time vector maybe
    
    with torch.no_grad():
        
        t0 = datetime.now()
        Y_pred = lit_model(X.cuda()).cpu()
        Y_pred_o = inverse_standardise_batch(Y_pred, I['mu'], I['sigma'])
        Y_pred_of = fold_back(Y_pred_o, skip=25, seq_len=seq_len)
        d_pred['dtst'] += [Y_pred_of]
        d_time['dtst'] += [datetime.now()-t0]
        
        t0 = datetime.now()
        d_pred['biweight_0.5'] += [wotan.flatten(time, Y_intact, window_length=0.5, 
                                                 method='biweight', return_trend = True)[1]]
        d_time['biweight_0.5'] += [datetime.now()-t0]

        t0 = datetime.now()
        d_pred['medfilt'] += [wotan.flatten(time, Y_intact, window_length=49, 
                                            method='medfilt', return_trend = True)[1]]
        d_time['medfilt'] += [datetime.now() - t0]

for model in d_pred:
    if len(d_pred[model]):
        d_pred[model] = np.vstack(d_pred[model])
        
for model in d_time:
    if len(d_time[model]):
        d_time[model] = np.mean(d_time[model])

In [None]:
d_time

In [None]:
# Compute metrics

target_test = np.vstack([test_dataset.get_pretransformed_sample(idx).squeeze() for idx in range(len(test_dataset))])
iqr = dict()
dw = dict()

for model in ['dtst', 'biweight_0.5', 'medfilt']:
    pred_d = target_test / d_pred[model]
    iqr[model] = naniqr(pred_d, dim=1, reduction='mean')
    dw[model] = np.abs(compute_dw(pred_d-1, axis=1, reduce='none')-2).mean()

In [None]:
# Quick median filter
dim = 1
window_size = 50
step= 1
from utils.stats import nanstd
torch.nanmedian(X.unfold(dim, window_size, step),-1).values.shape

In [None]:
def median_filter(x, width=35):
    """Centred rolling median of a 1-D sequence.

    Args:
        x: 1-D sequence of values.
        width: size of the rolling window, in samples.

    Returns:
        numpy array with one value per input sample. Near the edges the median
        is taken over the samples available (``min_periods=1``).
    """
    return pd.Series(x).rolling(width, center=True, min_periods=1).median().values

# The stray `median_filter()` call without its required argument was removed:
# it raised TypeError every time this cell was run.

In [None]:
# The import statement was left unfinished (a syntax error); MaskedMSELoss is the
# only name from models.loss that this cell uses.
from models.loss import MaskedMSELoss

def median_filter(x, width=35):
    """Centred rolling median of a 1-D sequence (window shrinks at the edges)."""
    return pd.Series(x).rolling(width, center=True, min_periods=1).median().values

def mean_filter(x, width=10):
    """Centred rolling mean of a 1-D sequence (window shrinks at the edges)."""
    return pd.Series(x).rolling(width, center=True, min_periods=1).mean().values


# Baseline predictions and residuals, keyed by rolling-window width.
pred_mean = dict()
res_mean = dict()
res_median = dict()

pred_median = dict()

# Targets of the whole batch, first feature only. Residuals are taken against
# every row; previously only row `i`, a leftover loop index, was subtracted
# from all rows.
Y_first = Y[:, :, 0].detach().numpy()

for w in (10, 25, 40, 55):
    pred_mean[w] = np.vstack([mean_filter(np.asarray(x[:, 0]), w) for x in X])
    res_mean[w] = pred_mean[w] - Y_first
    pred_median[w] = np.vstack([median_filter(np.asarray(x[:, 0]), w) for x in X])
    res_median[w] = pred_median[w] - Y_first

    # i = np.random.randint(len(Y))
    # plt.plot(X[i])
    # plt.plot(Y[i])
    # plt.plot(pred[i])

    mse_median = MaskedMSELoss()(torch.tensor(pred_median[w]), Y[:,:,0]).item()
    mse_mean = MaskedMSELoss()(torch.tensor(pred_mean[w]), Y[:,:,0]).item()
    print(f'\twindow = {w}')
    print(f'median filter : {mse_median:.4f}')
    print(f'mean filter : {mse_mean:.4f}')

# NOTE(review): the original read an undefined `pred2`. `Y_pred`, the model output
# for the current X/Y batch, is assumed to be the intended prediction — confirm.
mse_tst = MaskedMSELoss()(torch.as_tensor(Y_pred[:, :, 0], device="cpu"), Y[:,:,0]).item()
print(f'\nTransformer : {mse_tst:.4f}')

In [None]:
# Compare the rolling-mean baselines with the transformer on one window.
# NOTE(review): this cell reads `target`, `pred`, `res`, `res_m` and `j`, none of
# which is assigned anywhere in the visible notebook, and `i` is whatever index an
# earlier cell left behind. It only runs with leftover kernel state; define these
# names explicitly before relying on it.

# Figure 1: target samples with every mean-filter prediction and the model output.
plt.figure(figsize=(13,6))
plt.scatter(range(len(target)), target, label='target', color='green', alpha=0.7, s=5)
#plt.plot(pred, label='Transformer', color='red', alpha=0.7)
for w in pred_mean:
    plt.plot(pred_mean[w][i], label=f'mean filter-{w}', alpha=0.7)

plt.plot(pred, label='Transformer', lw=2)
plt.legend()
plt.show()

# Figure 2: residuals of the width-40 mean filter next to `res` / `res_m`.
plt.figure(figsize=(13,6))
for w in [40]:
    plt.scatter(range(len(pred_mean[w][i])), 
                pred_mean[w][i]-Y[i,:,j].detach().numpy(), label=f'mean filter-{w}', alpha=0.7)
#     plt.scatter(range(len(res)), res, color='blue')

plt.scatter(range(len(res)), res, color='red')

plt.scatter(range(len(res)), res_m, marker="s", color='red', label='transfo')
#     plt.scatter(range(len(res)), res_m, marker="s", color='blue')
plt.legend()
plt.show()


# Figure 3: residual histograms; `w` is still 40 here, left over from the loop above.
plt.hist(res.numpy(), 50, range=(-3,3), label='transformer')
plt.hist(res_mean[w][i], 50, range=(-3,3), label=f'mean filter-{w}', alpha=0.7)
plt.legend()
plt.show()
pass

In [None]:
# Compare the rolling-median baselines with the transformer on one window.
# NOTE(review): this cell reads `target`, `pred`, `res`, `res_m` and `j`, none of
# which is assigned anywhere in the visible notebook, and `i` is whatever index an
# earlier cell left behind. It only runs with leftover kernel state; define these
# names explicitly before relying on it.

# Figure 1: target samples with every median-filter prediction and the model output.
plt.figure(figsize=(13,6))
plt.scatter(range(len(target)), target, label='target', color='green', alpha=0.7, s=5)
#plt.plot(pred, label='Transformer', color='red', alpha=0.7)
for w in pred_median:
    plt.plot(pred_median[w][i], label=f'median filter-{w}', alpha=0.7)

plt.plot(pred, label='Transformer')
plt.legend()
plt.show()

# Figure 2: residuals of the width-40 median filter next to `res_m`.
plt.figure(figsize=(13,6))
for w in [40]:
    plt.scatter(range(len(pred_median[w][i])), 
                pred_median[w][i]-Y[i,:,j].detach().numpy(), label=f'median filter-{w}', alpha=0.7)
#     plt.scatter(range(len(res)), res, color='blue')


plt.scatter(range(len(res)), res_m, marker="s", color='red', label='transfo')
#     plt.scatter(range(len(res)), res_m, marker="s", color='blue')
plt.legend()
plt.show()


# Figure 3: residual histograms; `w` is still 40 here, left over from the loop above.
plt.hist(res.numpy(), 50, range=(-3,3), label='transformer')
plt.hist(res_median[w][i], 50, range=(-3,3), label=f'median filter-{w}', alpha=0.7)
plt.legend()
plt.show()
pass

In [None]:
pred_mean[w].shape

### Visualise attention

In [None]:
# # size study
list_seq_len = [100, 200,  400, 700, 1000, 2000]
# for L in list_seq_len:
#     torch.cuda.empty_cache()
#     lit_model = LitImputer(n_dim, d_model=64, dim_feedforward=128,  num_layers=3, eye=eye, lr=0.001,
#                        normal_ratio=0.2, keep_ratio=0., token_ratio=0.8, attention='linear', seq_len=L
#                       )
    
#     dataset.transform_both = Compose([RandomCrop(L),
#                                       StandardScaler(dim=0),
#                                      ])
#     train_loader = DataLoader(dataset, batch_size=64, shuffle=True) 

#     trainer = pl.Trainer(max_epochs=3, 
#                          gpus=1, profiler='simple')

#     result = trainer.fit(lit_model, 
#                          train_dataloaders=train_loader,                     
#                          )

# result_full = [0.16705 , 0.24147, 0.51497, np.nan, np.nan, np.nan] #B256
result_full = [0.2 , 0.27, 0.54, 1.2, np.nan, np.nan] #B64
# result_prob = [ 0.30687, 0.33286, 0.38871, 0.54279, 0.68638, 1.1525 ] #B256
result_prob = [0.47, 0.48, 0.53, 0.64, 0.81, 1.3] #B64

###result_lin = [0.25, 0.31, 0.46, 0.65, 0.87, 1.53] #B64
result_lin = [0.21, 0.24, 0.32, 0.42, 0.52, 0.87]
plt.plot(list_seq_len, result_full, marker='o', label='Full attention B64')#, marker='-o')
plt.plot(list_seq_len, result_prob, marker='o', label='Prob attention B64')
plt.plot(list_seq_len, result_lin, marker='o', label='Linformer attention B64')
plt.legend()
plt.xlabel('Sequence length')
plt.ylabel('Training epoch time')
print('batch size', train_loader.batch_size)