In [None]:
import os


import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import torch
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl 

GPUS = int(torch.cuda.is_available())
# torch.cuda.empty_cache() 
def print_cuda_summary():
    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
    f = r-a  # free inside reserved

    print("torch.cuda.get_device_properties(0).total_memory %fGB"%(t/1024/1024/1024))
    print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
    print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
    print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
    print("Free (res - alloc) %fGB"%(f/1024/1024/1024))

if GPUS:
    print_cuda_summary()

In [None]:
max_samples = 1000   # cap on the number of training files passed to TessDataset
val_ratio = 0.1      # fraction of the training set held out for validation
batch_size = 64


seed = 0             # seed for the train/val random_split generator
standardise_axes = (0, 1)  # per sample standardisation

num_workers = 0
# Pinned host memory only helps host->GPU copies; without CUDA it is unused
# and PyTorch warns about it, so enable it only when a GPU is present.
pin_memory = torch.cuda.is_available()

if num_workers > 0:
    # Presumably to stop OpenCV's own thread pool from oversubscribing CPUs
    # inside DataLoader worker processes.
    import cv2
    cv2.setNumThreads(0)

### TESS DATA

In [None]:
# Root folder of the TESS light curves: the cluster path on GPU machines,
# the local path otherwise. Set the TESS_DATA_DIR environment variable to
# point the notebook elsewhere without editing it.
if GPUS:
    _default_data_dir = "/state/partition1/mmorvan/data/TESS/lightcurves/0001"
else:
    _default_data_dir = "/Users/mario/data/TESS/lightcurves/0027"
path = os.environ.get("TESS_DATA_DIR", _default_data_dir)

# Pre-processed train / test splits, as written by the (commented) split cell below.
train_path = os.path.join(path, 'processed_train')
test_path = os.path.join(path, 'processed_test')

In [None]:
# # split generation
# import os
# from datasets import TessDataset

# test_ratio = 0.2
    
# dataset = TessDataset(path, 
#                       processed=True, 
#                       save=False
#                       )

# #dataset.n_dim = 1
# # TRAIN/VAL SPLIT
# test_size = int(test_ratio * len(dataset))
# train_size = len(dataset) - test_size
# print(train_size, test_size)
# train_dataset, test_dataset = random_split(dataset, 
#                                           (train_size, test_size),
#                                            generator=torch.Generator().manual_seed(seed))

# for idx in train_dataset.indices:
#     dataset.save_item(idx, train_path)
# for idx in test_dataset.indices:
#     dataset.save_item(idx, test_path)
# #train_dataset.indices

# %%bash
# ls "/state/partition1/mmorvan/data/TESS/lightcurves/0001/processed_test" 

In [None]:
# from datasets import TessDataset

# test_dataset = TessDataset(test_path, load_processed=True)
# train_dataset = TessDataset(train_path, load_processed=True)
# len(test_dataset), len(train_dataset)

In [None]:

from datasets import TessDataset

from transforms import Compose, StandardScaler, AddGaussianNoise, Mask, FillNans, RandomCrop, DownSample


# Joint input/target transforms ("transform_both"). Training pipeline:
# RandomCrop(800) -> DownSample(2) -> Mask(0.3, ...) -> StandardScaler(dim=0).
# The names suggest a random 800-point crop, 2x downsampling, masking of part
# of the input and per-sample standardisation; confirm in `transforms`.
transform_both_train = Compose([RandomCrop(800, exclude_missing_threshold=0.8),
                                DownSample(2),
                                Mask(0.3, block_len=None, value=None, exclude_mask=True),
                                StandardScaler(dim=0),
                                # FillNans(0),
                                ])

# Test pipeline: identical but without Mask.
# NOTE(review): RandomCrop is presumably random at test time too, so test
# metrics may vary between runs unless the crop is seeded.
transform_both_test = Compose([RandomCrop(800, exclude_missing_threshold=0.8),
                               DownSample(2),
                               # Mask(0.3, block_len=None, value=None, exclude_mask=True),
                               StandardScaler(dim=0),
                               # FillNans(0),
                               ])

transform = None  # no input-only transform

# Datasets read the pre-split folders from the paths cell (`train_path`,
# `test_path`); `path` itself is not used here, so it is not reassigned.
dataset = TessDataset(train_path,
                      load_processed=True,
                      max_samples=max_samples,
                      transform=transform,
                      transform_both=transform_both_train,
                      use_cache=True,
                      )
test_dataset = TessDataset(test_path,
                           load_processed=True,
                           transform_both=transform_both_test,
                           use_cache=True,
                           )

# TRAIN/VAL SPLIT: hold out `val_ratio` of the training set, reproducibly via
# a seeded generator. Both subsets share `dataset`, so validation samples also
# go through transform_both_train (including Mask).
val_size = int(val_ratio * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset,
                                          (train_size, val_size),
                                          generator=torch.Generator().manual_seed(seed))

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=num_workers, pin_memory=pin_memory)


In [None]:
# GPU memory report after building datasets and loaders (skipped without a GPU).
if GPUS > 0:
    print_cuda_summary()

In [None]:
# Inspect one transformed training sample: input vs. target series.
x, y, m, meta = dataset[0]

fig, ax = plt.subplots()
ax.plot(x, label='input (x)')
ax.plot(y, label='target (y)')
ax.set(title='dataset[0]', xlabel='index')
ax.legend();

In [None]:
# for x,y,m, i in train_loader:
#     assert torch.isclose(x,y, equal_nan=True).all()

## Noise Study

In [None]:
# from utils import nanstd
# X,_,_,_ = next(iter(train_loader))
# B, L, D = X.shape


# window = 10
# n_windows = L // window
# X.view(B, n_windows, window).shape

# noise = nanstd(X.view(B, n_windows, window), -1, keepdim=True).nanmedian(1, keepdim=True).values

# # Samples highlighted by non-white-noise contribution to the variance
# plt.figure(figsize=(30,20))
# for i in range(len(X)):
#     plt.plot(X[i,:,0], alpha=1-noise[i,0,0].item()**1.5, lw=1/noise[i,0,0])
# plt.title('Batch highlighted by inverse noise')
# plt.legend()
# plt.show()

# # Distribution of noise estimates
# plt.hist(noise[:,0,0].numpy(), 50)
# plt.title('White Noise Level')
# plt.show()

# # After correction
# plt.figure(figsize=(30,20))
# plt.plot((X / noise)[:,:,0].T)
# plt.show()

# # Check of better noise estimate

# def rolling_std(x, width=10):
#     return pd.Series(x).rolling(width, center=True, min_periods=1).std().values

# better_noise = []
# for i in range(len(X)):
#     better_noise += [np.nanmedian(rolling_std(X[i,:,0].numpy(), width=10))]
# plt.scatter(noise, better_noise)

# # dataset_temp = TessDataset(train_path)


# # x, y, m, i = dataset_temp[np.random.randint(len(dataset_temp))]
# # plt.plot((x - np.nanmedian(x))/np.nanmedian(x))
# pass

In [None]:
# from utils.stats import estimate_noise

# X,_,_,_ = next(iter(train_loader))

# noise = estimate_noise(X, reduce='nanmedian')
# # plt.hist(noise[9].flatten().numpy(), 50)

# print(torch.isnan(X/noise).sum(), torch.isnan(X).sum())
# # plt.plot((X/np.sqrt(noise))[:,:,0].T)

### Dummy Data

In [None]:
# from datasets import DummyDataset

# dataset = DummyDataset(100)


In [None]:
# x, y, m, meta = dataset[0]

# plt.plot(x)
# plt.plot(y)

### Model

In [None]:
# ### TEST RUN
# lit_model = LitImputer(1, noise_scaling=True)
# trainer = pl.Trainer(max_epochs=10, gpus=GPUS)
# result = trainer.fit(lit_model, train_dataloaders=train_loader)

In [None]:
# lit_model = LitImputer(1, noise_scaling=True)
# from utils.stats import estimate_noise
# from models.loss import MaskedMSELoss
# criterion = MaskedMSELoss()

# for X, Y, M, Info in train_loader:
#     Y_pred = lit_model(X)
#     noise = torch.sqrt(estimate_noise(Y))
#     assert not noise.requires_grad
#     assert not torch.isclose(noise, torch.zeros_like(noise), equal_nan=True).any()
#     noise[torch.isnan(noise)] = 1.

#     loss = criterion(Y_pred, Y, M)
#     assert not torch.isnan(loss)
    
#     pred_proxy = Y_pred / noise
#     Y_proxy = Y/noise
# #     torch.isnan(Y_pred).sum(), torch.isnan(Y).sum(), torch.isnan(pred_proxy).sum(), torch.isnan(Y_proxy).sum()
#     assert not torch.isinf(pred_proxy).any()
#     loss_scaled = criterion(pred_proxy, Y_proxy, M)
    
#     assert not torch.isnan(loss_scaled)

In [None]:
# ### TEST volume
# from pytorch_lightning.loggers import NeptuneLogger

# logger = NeptuneLogger(project="denoising-transformer",
#                            name='test')
# torch.manual_seed(0)


# lit_model = LitImputer(n_dim=1, d_model=8, dim_feedforward=16, num_layers=1,
# #                        attention='linear', seq_len=400,
#                        random_ratio=1, zero_ratio=0., keep_ratio=0.,token_ratio=0, 
#                        #, token_ratio=0.8, 
#                        noise_scaling="true",
# #                        attention='linear', seq_len=400
#                       )

# trainer = pl.Trainer(max_epochs=10000, 
#                      gpus=GPUS,
#                      logger=logger,
#                      check_val_every_n_epoch=1)

# result = trainer.fit(lit_model, 
#                      train_dataloaders=train_loader,
#                      val_dataloaders=val_loader, 
#                      )


### Define Model

In [None]:
from models import LitImputer

# Seed model initialisation with the notebook-wide `seed` from the config cell
# instead of a separate hard-coded value.
torch.manual_seed(seed)
if GPUS:
    print_cuda_summary()

In [None]:
# Model: one input channel, d_model=64, feed-forward width 128, lr 1e-3,
# trained in the 'noise' unit. The *_ratio arguments presumably choose how
# masked positions are replaced (only the "random" option is active here);
# confirm in models.LitImputer.
# Alternatives tried before: attention='linear', seq_len=400;
# token_ratio=0.8; noise_scaling="true".
lit_model = LitImputer(
    n_dim=1,
    d_model=64,
    dim_feedforward=128,
    lr=0.001,
    random_ratio=1,
    zero_ratio=0.,
    keep_ratio=0.,
    token_ratio=0,
    train_unit='noise',
)

In [None]:
# X, Y, M, I = next(iter(train_loader))

# X_masked, token_mask = lit_model.apply_mask(X, M)
# X_masked[token_mask.bool()] = np.nan
# # i = np.random.randint(len(X_masked))
# plt.figure(figsize=(12,6))
# #plt.scatter(range(400),X[i].detach(), s=50, alpha=0.6)
# plt.scatter(range(400), X_masked[i].detach(), s=5, color='red')
# plt.fill_between(range(len(M[i])), -3, 3, where=M[i].flatten(), alpha=0.3)
# plt.fill_between(range(len(M[i])), -3, 3, where=token_mask[i].flatten())


### Loading

In [None]:
%%bash
# List directories named DEN-245 exactly three levels below the current
# directory, presumably the Neptune run folder (.neptune/<project>/DEN-245)
# that holds the checkpoint referenced in the next cell.
find . -mindepth 3 -maxdepth 3 -name "DEN-245" -type d

In [None]:
# # # ckpt_path = "/home/mmorvan/denoising-ts-transformer/.neptune/Tess-denoising17-02-2022_01-34-35/DEN-167/checkpoints/epoch=382-step=4992.ckpt"
# # ckpt_path = "./.neptune/Tess-denoising_17-02-2022_02-19-37/DEN-170/checkpoints/epoch=1032-step=15494.ckpt"
# ckpt_path = "./.neptune/Tess-denoising17-02-2022_11-11-26/DEN-186/checkpoints/epoch=601-step=15062.ckpt"
# ckpt_path = "./.neptune/Tess-denoising17-02-2022_11-41-36/DEN-189/checkpoints/epoch=636-step=15924.ckpt"
# # ckpt_path = "./.neptune/tess_denoising/DEN-219/checkpoints/epoch=2999-step=86999.ckpt"
# ckpt_path = './.neptune/tess_denoising/DEN-245/checkpoints/epoch=1382-step=20744.ckpt'
# lit_model = lit_model.load_from_checkpoint(ckpt_path)

### Training

In [None]:
import datetime
from pytorch_lightning.loggers import NeptuneLogger

# Neptune experiment logger; no credentials are passed here (NOTE(review):
# presumably read from the environment). Tags record the training-set size
# and the model's training unit.
# Tags used in earlier runs: "continued_from_den-186", "large-training-set",
# "star-scaled", "Noise2Noise", "Imputation", "cropped-800", "downsampled-2",
# 'linformer', "replaced_normal", "Crop-400", f"patch-{patch_size}", "TPT".
logger = NeptuneLogger(
    project="denoising-transformer",
    name='tess_denoising',
    log_model_checkpoints=True,
    tags=[
        str(len(dataset)) + ' samples',
        "train - " + lit_model.train_unit,
    ],
)

In [None]:
# from pytorch_lightning import seed_everything
# seed_everything(1)

# Up to 10000 epochs, validation after every epoch, GPU when available
# (GPUS is 0 or 1), metrics and checkpoints logged to Neptune.
trainer = pl.Trainer(
    max_epochs=10000,
    logger=logger,
    gpus=GPUS,
    check_val_every_n_epoch=1,
)

result = trainer.fit(
    lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)



In [None]:
# Second call to `fit` on the same trainer and model, presumably to resume
# after the previous cell was interrupted.
# NOTE(review): under "Restart & Run All" this calls fit a second time;
# remove or guard it before sharing the notebook.
result = trainer.fit(
    lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

In [None]:
# Evaluate the trained model on the held-out test split.
trainer.test(
    lit_model,
    dataloaders=test_loader,
)

In [None]:


# plt.figure(figsize=(13,6))
# plt.scatter(range(len(X[i])), X[i,:,j], label='input', color='blue', s=3, alpha=0.4)

# plt.plot(Y[i,:,j], label='target', color='green')

# plt.plot(pred.cpu().detach()[i,:,j], label='prediction', color='red')
# plt.legend()

In [None]:
from utils.stats import estimate_noise

# Predict one training batch in eval mode, on the GPU only when one exists
# (an unconditional .cuda() fails on CPU-only machines).
device = torch.device('cuda' if GPUS else 'cpu')
lit_model.eval().to(device)
with torch.no_grad():
    X, Y, M, I = next(iter(train_loader))
    Y_pred = lit_model(X.to(device)).detach().cpu().numpy()
    noise = estimate_noise(Y)

In [None]:
# plot diagnostic
from utils.postprocessing import plot_pred_diagnostic

# Pick the sample with the lowest estimated noise. nanargmin skips NaN noise
# estimates; plain argmin would return the position of the first NaN.
i = np.nanargmin(noise.numpy().squeeze())
#i = np.argmax(estimate_noise(Y).numpy().squeeze())
# i = np.random.randint(len(X))

x = X[i, :, 0].detach().cpu().numpy()
y = Y[i, :, 0].detach().cpu().numpy()
mask = M[i, :, 0].detach().cpu().numpy()
# Per-sample metadata as Python scalars (each entry of I holds one value per sample).
info = {k: v[i].detach().cpu().item() for k, v in I.items()}
y_pred = Y_pred[i, :, 0]

plot_pred_diagnostic(x, y, y_pred, mask=mask, targetid=info['targetid'], mu=info['mu'], sigma=info['sigma'])

In [None]:
# ### PLOT FOR ARTICLE
# res = y - y_pred

# f, ax = plt.subplots(1, 1, figsize=(17, 2))

# # PREDICTION
# ax.scatter(range(len(x)), y, label='input',
#                  color='black', s=20, alpha=0.8)
# if not np.isclose(x, y, equal_nan=True).all():
#     ax.scatter(range(len(x)), y, label='target',
#                      color='green', s=20, alpha=0.8)

# if mask is not None:
#     ymin, ymax = ax.get_ylim()
#     ax.fill_between(range(len(x)), [ymin]*len(x), [ymax]
#                           * len(x), where=mask, alpha=0.9, label='input mask')

# ax.set_xticks([])
# ax.set_yticks([])
# ax.plot(y_pred, label='pred', color='red', alpha=1, lw=2)


In [None]:
# Release cached, unused GPU memory and report it; skipped on CPU-only machines,
# where the memory summary cannot be computed.
if GPUS:
    torch.cuda.empty_cache()
    print_cuda_summary()

In [None]:
# Debugging

In [None]:
from utils.stats import estimate_noise

# Debug forward pass on one training batch, on the GPU only when one exists.
# No torch.no_grad() here, so Y_pred keeps its autograd graph; later cells
# call .detach() before converting it.
device = torch.device('cuda' if GPUS else 'cpu')
lit_model.eval().to(device)
# lit_model.train()
X, Y, M, I = next(iter(train_loader))
Y_pred = lit_model(X.to(device))
noise = estimate_noise(Y)

In [None]:
# Undo the per-sample standardisation (x * sigma + mu) using each sample's own
# statistics from the batch metadata `I`, not a single sample's `info` dict.
# The arithmetic is inlined because the helper functions are defined in a
# later cell. I['mu'] / I['sigma'] hold exactly one value per sample, so they
# are reshaped to (B, 1, 1) to broadcast over the 3-D batch tensors.
mu_b = I['mu'].reshape(-1, 1, 1)
sigma_b = I['sigma'].reshape(-1, 1, 1)
y_o = Y * sigma_b + mu_b
pred_o = Y_pred.cpu() * sigma_b + mu_b

# Debugging nans: drop missing targets from the mask and zero them out.
nans = torch.isnan(y_o)
m = M & ~nans
y_o[nans] = 0.

y_d = y_o / pred_o  # target detrended by the predicted trend
#loss = self.criterion(torch.ones_like(y_d), y_d, m)  # x or y
#loss = self.criterion(torch.ones_like(y_d), y_d, m)  # x or y

In [None]:
# Display the raw model output from the debug forward pass above.
Y_pred

In [None]:
def _per_sample_stat(stat, x):
    """Reshape a per-sample statistic so it broadcasts over batch ``x``.

    A tensor/array holding exactly one value per sample (shape ``(B,)`` or
    ``(B, 1, ..., 1)`` with ``B == x.shape[0]``) is reshaped to
    ``(B, 1, ..., 1)`` with ``x.ndim`` dimensions. Anything else (Python
    scalars, 0-d tensors, full-shape statistics) is returned unchanged.
    """
    shape = tuple(getattr(stat, 'shape', ()))
    if (len(shape) >= 1 and getattr(x, 'ndim', 0) >= 1
            and shape[0] == x.shape[0] and all(d == 1 for d in shape[1:])):
        return stat.reshape((-1,) + (1,) * (x.ndim - 1))
    return stat


def inverse_standardise_batch(x, mu, sigma):
    """Undo standardisation of a batch: return ``x * sigma + mu``.

    Parameters
    ----------
    x : torch.Tensor or numpy.ndarray, shape (B, ...)
        Standardised batch.
    mu, sigma : scalar, or tensor/array with one value per sample
        Centre and scale used for standardisation. Per-sample statistics of
        shape ``(B,)`` are aligned with the batch axis; without this,
        ``(B, L, D) * (B,)`` aligns with the last axis and silently yields a
        wrongly shaped result.
    """
    return x * _per_sample_stat(sigma, x) + _per_sample_stat(mu, x)


def detrend(x, trend):
    """Divide ``x`` elementwise by ``trend`` (multiplicative detrending)."""
    return x / trend

# Undo the standardisation of prediction and target with the per-sample
# statistics in `I`, then detrend the target by the prediction (no grad kept).
Y_pred_o = inverse_standardise_batch(Y_pred.cpu(), I['mu'], I['sigma'])
Y_o = inverse_standardise_batch(Y, I['mu'], I['sigma'])

Y_d = detrend(Y_o, Y_pred_o).detach()


In [None]:
# Visual check of the first sample of the debug batch: input, standardised
# target vs. prediction, de-standardised target, and detrended target.
fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)
axes[0].plot(X[0])
axes[0].set_title('input X[0]')
axes[1].plot(Y[0], label='target')
axes[1].plot(Y_pred[0].cpu().detach().numpy(), label='prediction')
axes[1].set_title('standardised target vs. prediction')
axes[1].legend()
axes[2].plot(Y_o[0])
axes[2].set_title('de-standardised target Y_o[0]')
axes[3].plot(Y_d[0])
axes[3].set_title('detrended target Y_d[0]')
axes[3].set_xlabel('index')
plt.show()


In [None]:
from models.loss import MaskedMSELoss

# NaN hunt: masked MSE of the detrended target against a flat trend of ones,
# followed by the NaN counts of Y_d and Y and the inf count of Y_d.
loss = MaskedMSELoss()(torch.ones_like(Y_d), Y_d, mask=M)
(
    loss,
    torch.isnan(Y_d).sum(),
    torch.isnan(Y).sum(),
    torch.isinf(Y_d).sum(),
)

In [None]:
# Detrended target, channel 0, one line per sample of the batch.
plt.plot(Y_d[..., 0].detach().numpy().T)

In [None]:
from transforms import StandardScaler

# Alternative de-standardisation: refit a StandardScaler on each sample's
# target as returned by `dataset[idx]` and invert prediction/target with it.
# NOTE(review): `dataset` applies transform_both_train (RandomCrop, Mask,
# StandardScaler(dim=0)), so dataset[i][1] is a fresh, already-standardised
# crop and is unlikely to reproduce the statistics of this batch (see the
# inline warning below). Assumes I['idx'] holds indices into `dataset`;
# TODO confirm.
# Side effects: the loop variables `k`, `i` and the last `scaler` stay in the
# notebook namespace (a later cell reads `scaler`).
Y_pred_o = Y_pred.clone()

Y_o = Y.clone()
for k, i in enumerate(I['idx']):    
    scaler = StandardScaler(dim=0)
    scaler.fit(dataset[i][1]) ### What was used for preproc?? Careful to post transfos
    Y_pred_o[k] = Y_pred[k] * scaler.norms.item() + scaler.centers.item()
    Y_o[k] = scaler.inverse_transform(Y[k])

In [None]:
# Recompute the de-standardised target/prediction from the batch statistics
# in `I` (replacing the scaler-based Y_o / Y_pred_o from the previous cell),
# and the detrended target. The helpers `inverse_standardise_batch` and
# `detrend` from the earlier debugging cell are reused rather than redefined;
# a second definition would silently shadow the first.
Y_o = inverse_standardise_batch(Y, I['mu'], I['sigma'])
Y_pred_o = inverse_standardise_batch(Y_pred.cpu(), I['mu'], I['sigma'])

Y_d = detrend(Y_o, Y_pred_o).detach()

In [None]:
# Shape check: de-standardised target vs. prediction (expected to match).
Y_o.shape, Y_pred_o.shape

In [None]:
from utils import nanstd

# Noise proxy: NaN-aware std of the detrended target, reduced along dim 1.
noise2 = nanstd(Y_d, 1).numpy()

fig, ax = plt.subplots()
ax.hist(noise2, 40)
ax.set(title='Std of detrended target', xlabel='nanstd(Y_d, dim=1)', ylabel='count')
plt.show()

# Compare with the estimate_noise(Y) values from the debug forward pass.
fig, ax = plt.subplots()
ax.scatter(noise, noise2)
ax.set_yscale('log')
ax.set(xlabel='estimate_noise(Y)', ylabel='nanstd(Y_d, dim=1)');

In [None]:
# Batch position of the sample with the smallest detrended-target std.
# nanargmin ignores NaN entries; np.argmin would return the first NaN.
i = np.nanargmin(noise2)

In [None]:
# Target of the selected sample (`i` is a position within the current batch).
plt.plot(Y[i])

In [None]:
# Shape check: batch input vs. per-sample sigma (presumably to debug how sigma
# broadcasts against the batch in inverse_standardise_batch).
X.shape, I['sigma'].shape

In [None]:
# `inverse_standardise_batch` and `detrend` are defined in the earlier
# debugging cell. They were redefined here verbatim; a redefinition would
# silently shadow the earlier definitions, so this cell no longer defines them.

In [None]:
# `i` is a position inside the current batch (argmin over noise2), not a
# dataset index. Map it through I['idx'] (assumed to hold the batch's dataset
# indices, as the scaler cell uses them) before indexing `dataset`.
# NOTE(review): `dataset` applies RandomCrop, so this may be a different window
# of the same light curve than the one in Y[i].
plt.plot(dataset[int(I['idx'][i])][1])

In [None]:
# De-standardised predictions, channel 0, one line per sample.
plt.plot(Y_pred_o[..., 0].detach().cpu().T)

In [None]:
# Detrended target: observed over predicted, after de-standardisation.
Y_d = Y_o / Y_pred_o.cpu()

plt.plot(Y_d[..., 0].detach().cpu().numpy().T)

In [None]:
# `i` is a batch position (argmin over noise2); map it to a dataset index via
# I['idx'] (assumed to hold the batch's dataset indices) before indexing `dataset`.
dataset[int(I['idx'][i])][1].shape

In [None]:
# Centre of the last StandardScaler fitted in the per-sample scaler loop.
scaler.centers.item()

In [None]:
# Qualitative check on one random validation sample: input, target and
# prediction with masked positions highlighted, then the residuals.
lit_model.eval()
X, Y, M, info = next(iter(val_loader))
#X = X/2

# Run on whatever device the model lives on (earlier cells may have moved it
# to the GPU). The all-False second argument is passed as before;
# NOTE(review): presumably an extra mask with nothing masked.
X_dev = X.to(lit_model.device)
with torch.no_grad():
    pred = lit_model(X_dev, torch.zeros_like(X_dev, dtype=torch.bool)).detach().cpu()
pred2 = pred.clone()  # full batch of predictions, reused by the baseline comparison
i = np.random.randint(len(X))
j = np.random.randint(X.shape[-1])  # channel index; X is (batch, length, n_dim)

# Residual, target and prediction for sample i / channel j. The *_m copies
# keep only masked positions (others set to NaN) for square-marker overlays.
res = pred[i, :, j] - Y[i, :, j]
res_m = res.clone()
res_m[~M[i, :, j]] = np.nan

target = Y[i, :, j].clone()
target_m = target.clone()
target_m[~M[i, :, j]] = np.nan

pred = pred[i, :, j].clone()  # from here on `pred` is the single series
pred_m = pred.clone()
pred_m[~M[i, :, j]] = np.nan

steps = range(len(res))

plt.figure(figsize=(13, 6))
plt.scatter(steps, X[i, :, j], label='input', color='blue', s=3, alpha=0.4)
plt.plot(target, label='target', color='green', alpha=0.7)
plt.scatter(steps, target_m, marker="s", color='green')
plt.plot(pred, label='pred', color='red', alpha=0.7)
plt.scatter(steps, pred_m, marker="s", color='red')
plt.title(f'Validation sample {i}, channel {j} (squares: masked positions)')
plt.legend()
plt.show()

plt.figure(figsize=(13, 6))
plt.scatter(steps, target, label='target', color='green', alpha=0.7, s=5)
plt.plot(pred, label='pred', color='red', alpha=0.7)
plt.title('Target vs. prediction')
plt.legend()
plt.show()

plt.figure(figsize=(13, 6))
plt.scatter(steps, res, color='blue', label='residual')
plt.scatter(steps, res_m, marker="s", color='blue', label='residual (masked)')
plt.title('Residual: prediction - target')
plt.legend();

In [None]:
# Baseline comparison: rolling filters vs. the transformer, scored with MaskedMSELoss.
from models.loss import MaskedMSELoss

def _centered_rolling(x, width):
    """Centered rolling window over ``x``; edge windows may be partial (min_periods=1)."""
    return pd.Series(x).rolling(window=width, center=True, min_periods=1)


def median_filter(x, width=10):
    """Centered rolling median of ``x`` over ``width`` points, as a NumPy array."""
    return _centered_rolling(x, width).median().to_numpy()


def mean_filter(x, width=10):
    """Centered rolling mean of ``x`` over ``width`` points, as a NumPy array."""
    return _centered_rolling(x, width).mean().to_numpy()


# Rolling mean / median baselines applied to the noisy input X (channel 0)
# for several window widths, scored with MaskedMSELoss (no mask) against the
# target; residuals are kept per width for the plots below.
pred_mean = dict()
res_mean = dict()
res_median = dict()
pred_median = dict()

Y0 = Y[:, :, 0].detach()  # target, channel 0, shape (batch, length)
Y0_np = Y0.numpy()
for w in (10, 25, 40, 55):
    pred_mean[w] = np.vstack([mean_filter(x[:, 0], w) for x in X])
    pred_median[w] = np.vstack([median_filter(x[:, 0], w) for x in X])
    # Row-wise residuals: each sample against its own target. Subtracting
    # Y[i] from every row was only correct for row i.
    res_mean[w] = pred_mean[w] - Y0_np
    res_median[w] = pred_median[w] - Y0_np

    mse_median = MaskedMSELoss()(torch.tensor(pred_median[w]), Y0).item()
    mse_mean = MaskedMSELoss()(torch.tensor(pred_mean[w]), Y0).item()
    print(f'\twindow = {w}')
    print(f'median filter : {mse_median:.4f}')
    print(f'mean filter : {mse_mean:.4f}')

# pred2: full batch of transformer predictions from the validation-plot cell.
# Convert without torch.tensor(tensor), which warns about copy-construction.
mse_tst = MaskedMSELoss()(pred2[:, :, 0].detach().cpu(), Y0).item()
print(f'\nTransformer : {mse_tst:.4f}')

In [None]:
# Transformer vs. rolling-mean baselines on sample i (channel j): overlays,
# residuals for one window width, and residual histograms.
w = 40  # width shown in the residual plot and histogram (a key of pred_mean)

plt.figure(figsize=(13, 6))
plt.scatter(range(len(target)), target, label='target', color='green', alpha=0.7, s=5)
for width in pred_mean:
    plt.plot(pred_mean[width][i], label=f'mean filter-{width}', alpha=0.7)
plt.plot(pred, label='Transformer', lw=2)
plt.title('Rolling-mean baselines vs. transformer')
plt.legend()
plt.show()

plt.figure(figsize=(13, 6))
plt.scatter(range(len(pred_mean[w][i])),
            pred_mean[w][i] - Y[i, :, j].detach().numpy(), label=f'mean filter-{w}', alpha=0.7)
plt.scatter(range(len(res)), res, color='red')
plt.scatter(range(len(res)), res_m, marker="s", color='red', label='transfo')
plt.title('Residuals: prediction - target')
plt.legend()
plt.show()

plt.hist(res.numpy(), 50, range=(-3, 3), label='transformer')
plt.hist(res_mean[w][i], 50, range=(-3, 3), label=f'mean filter-{w}', alpha=0.7)
plt.title('Residual distribution')
plt.legend()
plt.show()

In [None]:
# Transformer vs. rolling-median baselines on sample i (channel j): overlays,
# residuals for one window width, and residual histograms.
w = 40  # width shown in the residual plot and histogram (a key of pred_median)

plt.figure(figsize=(13, 6))
plt.scatter(range(len(target)), target, label='target', color='green', alpha=0.7, s=5)
for width in pred_median:
    plt.plot(pred_median[width][i], label=f'median filter-{width}', alpha=0.7)
plt.plot(pred, label='Transformer')
plt.title('Rolling-median baselines vs. transformer')
plt.legend()
plt.show()

plt.figure(figsize=(13, 6))
plt.scatter(range(len(pred_median[w][i])),
            pred_median[w][i] - Y[i, :, j].detach().numpy(), label=f'median filter-{w}', alpha=0.7)
plt.scatter(range(len(res)), res_m, marker="s", color='red', label='transfo')
plt.title('Residuals: prediction - target')
plt.legend()
plt.show()

plt.hist(res.numpy(), 50, range=(-3, 3), label='transformer')
plt.hist(res_median[w][i], 50, range=(-3, 3), label=f'median filter-{w}', alpha=0.7)
plt.title('Residual distribution')
plt.legend()
plt.show()

In [None]:
# Shape of the rolling-mean predictions for the last selected width `w`.
pred_mean[w].shape

### Attention cost vs. sequence length

In [None]:
# # size study
list_seq_len = [100, 200,  400, 700, 1000, 2000]
# for L in list_seq_len:
#     torch.cuda.empty_cache()
#     lit_model = LitImputer(n_dim, d_model=64, dim_feedforward=128,  num_layers=3, eye=eye, lr=0.001,
#                        normal_ratio=0.2, keep_ratio=0., token_ratio=0.8, attention='linear', seq_len=L
#                       )
    
#     dataset.transform_both = Compose([RandomCrop(L),
#                                       StandardScaler(dim=0),
#                                      ])
#     train_loader = DataLoader(dataset, batch_size=64, shuffle=True) 

#     trainer = pl.Trainer(max_epochs=3, 
#                          gpus=1, profiler='simple')

#     result = trainer.fit(lit_model, 
#                          train_dataloaders=train_loader,                     
#                          )

# result_full = [0.16705 , 0.24147, 0.51497, np.nan, np.nan, np.nan] #B256
result_full = [0.2 , 0.27, 0.54, 1.2, np.nan, np.nan] #B64
# result_prob = [ 0.30687, 0.33286, 0.38871, 0.54279, 0.68638, 1.1525 ] #B256
result_prob = [0.47, 0.48, 0.53, 0.64, 0.81, 1.3] #B64

###result_lin = [0.25, 0.31, 0.46, 0.65, 0.87, 1.53] #B64
result_lin = [0.21, 0.24, 0.32, 0.42, 0.52, 0.87]
plt.plot(list_seq_len, result_full, marker='o', label='Full attention B64')#, marker='-o')
plt.plot(list_seq_len, result_prob, marker='o', label='Prob attention B64')
plt.plot(list_seq_len, result_lin, marker='o', label='Linformer attention B64')
plt.legend()
plt.xlabel('Sequence length')
plt.ylabel('Training epoch time')
print('batch size', train_loader.batch_size)