# Visualize EEG Artifacts and Masked Autoencoder's Masking Uncertainty

-----

## Load Packages

In [None]:
# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2
%cd ..

In [None]:
# Load some packages
import os
import gc
from copy import deepcopy
import hydra
from omegaconf import OmegaConf
import wandb
import pprint
import numpy as np
import torch
import torchvision
from tqdm.auto import tqdm
from functools import partial

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
from matplotlib.patches import FancyBboxPatch
from matplotlib.gridspec import GridSpec
import scienceplots
import mpl_interactions.ipyplot as iplt
from mpl_interactions import interactive_axvline
from mpl_interactions.controller import Controls
import mpl_interactions

# custom package
from run_train import check_device_env
from run_train import set_seed
from run_train import compose_dataset
from run_train import generate_model
from train.train_script import train_script
from datasets.caueeg_script import EegToTensor, EegDropChannels
from models.utils import count_parameters

In [None]:
# Other settings
%matplotlib inline
%config InlineBackend.figure_format = 'retina' # cleaner text

# `plt.style.use('default')` restores Matplotlib's default rcParams, so it must run
# BEFORE any rcParams are customised; otherwise those settings are silently reset.
# (Previously 'bicubic' interpolation was assigned first and then wiped out.)
# Other built-in styles are listed by `plt.style.available`
# (e.g. 'classic', 'ggplot', 'fivethirtyeight', 'bmh', 'dark_background').
plt.style.use('default')
plt.rcParams.update({
    'image.interpolation': 'bicubic',
    'font.size': 14,
    'font.family': 'Roboto Slab',  # 'NanumGothic' for Hangul in Windows
    'savefig.dpi': 1200,
})

---

## Load the pretrained network and freeze it for inference

In [None]:
# Run id of the pretrained checkpoint to visualise
# (other candidates: yap6fgxc p02vsovi | 3du3h4yl bco01cyz).
model_name = 'boej8vuk'
use_wandb, device = True, 'cuda:0'
# Directory that holds one sub-directory per run id, each containing checkpoint.pt.
model_path = r"E:\CAUEEG\checkpoint"

In [None]:
print('PyTorch version:', torch.__version__)

# Use the requested device only when CUDA is present; otherwise fall back to the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device(device if has_cuda else 'cpu')
print('cuda is available.' if has_cuda else 'cuda is unavailable.')

In [None]:
# Load the pretrained checkpoint and the configuration it was trained with.
# `model_name` may list several run ids separated by commas; the last one is used.
path = os.path.join(model_path, model_name.split(',')[-1], 'checkpoint.pt')
try:
    # NOTE: torch.load unpickles the file - only open checkpoints you trust.
    ckpt = torch.load(path, map_location=device)
    config = ckpt['config']
except Exception as e:
    print(e)
    print(f'- checkpoint cannot be opened: {path}')
    # Re-raise: carrying on would hit an undefined `config` or, after a notebook
    # re-run, silently reuse the `config`/`ckpt` of a previously loaded model.
    raise
pprint.pprint(config)

In [None]:
# Build the network from the checkpoint's configuration and restore its weights.
config["device"] = device
model = generate_model(config).to(device)
model.load_state_dict(ckpt["model_state"])

# Inference only: freeze every parameter and switch to evaluation mode
# (both calls return the module itself).
model = model.requires_grad_(False).eval()

model.art_net  # shown as the cell output

---

## Generate the DataLoader

In [None]:
# Adapt the training configuration for single-sample inference in this notebook.
config.pop('cwd', 0)
config['ddp'] = False
config['minibatch'] = 1
config['crop_multiple'] = 1
config['test_crop_multiple'] = 1
config['crop_timing_analysis'] = True  # presumably adds "crop_timing" to each sample (read below)
config['eval'] = True
config['device'] = device

config["task"] = 'abnormal'  # annotations were written with respect to the CAUEEG-Abnormal task data
train_loader, val_loader, test_loader, _ = compose_dataset(config)
# Keep only the part of each channel name before the first '-'.
signal_header = [channel.split('-')[0] for channel in config["signal_header"]]

In [None]:
# Take the first (randomly cropped) batch and remember its serial and crop offset,
# so the same window can be located again once cropping is disabled.
with torch.no_grad():
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        sample = first_batch
        print(sample)
        serial = sample["serial"]
        ct = sample["crop_timing"][0]

In [None]:
print("Previous Transform")
# Show the original pipeline even when this cell is re-run.
print(getattr(train_loader.dataset, "_transform_with_crop", train_loader.dataset.transform))
print("---" * 10)

# Skip the first transform (RandomCrop) so every sample is returned at full length.
# The untouched pipeline is stashed on the dataset the first time and the crop-free
# pipeline is always rebuilt from it. Re-running this cell, or loaders that share one
# dataset object, therefore no longer strip a second, unrelated transform.
for loader in [train_loader, val_loader, test_loader]:
    ds = loader.dataset
    if not hasattr(ds, "_transform_with_crop"):
        ds._transform_with_crop = ds.transform
    ds.transform = torchvision.transforms.Compose(
        list(ds._transform_with_crop.transforms[1:])
    )

print("Modified Transform")
print(val_loader.dataset.transform)

In [None]:
# Find the same recording again (now uncropped) and print the window that the
# random crop selected earlier, to check that the crop offset lines up.
with torch.no_grad():
    for sample in train_loader:
        if sample["serial"] != serial:
            continue
        window = slice(ct, ct + config["crop_length"])
        print(sample["signal"][:, :, window])
        break

In [None]:
# Recordings to analyse: zero-padded serials "00000" ... "01999".
# Alternative: the serials of the first training samples, e.g.
#     [train_loader.dataset[i]["serial"] for i in range(10)]
target_serials = list(map("{:05d}".format, range(2000)))

---

## Compute uncertainty

In [None]:
# Builds `results[serial]`: a per-sample artifact score for every target recording.
# A window of `crop_length` samples slides every `interval` samples; each crop goes
# through model.forward_artifact; its output is stretched back to `crop_length`
# (nearest-neighbour interpolation); overlapping windows are averaged via `count`.
# Kept commented out; the following cells save / load the cached `results`.
# NOTE(review): `range(0, L - crop_length, interval)` never starts a window at or
# after L - crop_length, so the last (up to `interval`) samples keep count 0 and
# end up with score 0 / 1e-8 = 0, i.e. they look artifact-free rather than unscored.
# Nothing is scored at all when L <= crop_length.
# interval = 16  # speed control
# results = {}

# with torch.no_grad():
#     for sample in tqdm(train_loader, desc="Data", leave=False):
#         serial = sample["serial"][0]
#         if serial in target_serials:
#             L = sample["signal"][0].shape[-1]
#             count = torch.zeros((L,))
#             score = torch.zeros((L,))

#             for t in tqdm(range(0, L - config["crop_length"], interval), desc="Crops", leave=False):
#                 s = deepcopy(sample)
#                 s["signal"] = s["signal"][:, :, t:t + config["crop_length"]]
#                 config["preprocess_test"](s)    
#                 out = model.forward_artifact(s["signal"], s["age"]).cpu()
                
#                 out = torch.nn.functional.interpolate(out.reshape(1, 1, 1, -1), 
#                                                       size=(1, config["crop_length"], ), mode="nearest")
#                 out = out.squeeze()
#                 count[t:t + config["crop_length"]] += 1
#                 score[t:t + config["crop_length"]] += out
                
#             results[serial] = score / (count + 1e-8)

In [None]:
# path = f'local/output/07_Visualize_MAE_Artifact_TrainingSets_{model_name}.pt'
# torch.save(results, os.path.join(path))

In [None]:
# Reload the cached uncertainty curves produced by the (commented-out) computation above.
path = f'local/output/07_Visualize_MAE_Artifact_TrainingSets_{model_name}.pt'
results = torch.load(path, map_location=torch.device('cpu'))

---

## Visualize

In [None]:
# Static overview: one strip chart of the estimated artifact uncertainty per recording.
# For paper-style figures, wrap the loop in
#     with plt.style.context(['ieee', 'science', 'default']):
sample_rate = config["sampling_rate"]  # loop-invariant: read once, not per recording
tick_step = sample_rate * 30  # one labelled x tick every 30 s

for serial, uncertainty in results.items():
    fig, ax = plt.subplots(1, 1, figsize=(25, 5), constrained_layout=True)

    r = uncertainty.numpy()
    ax.plot(r)

    x_ticks = np.arange(0, r.shape[0], tick_step)
    x_labels = [f"{round(tick / sample_rate)}" for tick in x_ticks]
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_labels)
    ax.set_xlim(0, r.shape[0])
    ax.set_ylim(0, 1.0)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Estimated Uncertainty')
    fig.suptitle(serial, fontsize=13, fontweight='semibold')
    plt.show()
    # plt.close releases the figure, so no separate fig.clear() is needed.
    plt.close(fig)

In [None]:
%matplotlib ipympl
# Recording of interest and the width (in samples) of the interactive window.
target_serial, duration = target_serials[0], 4000

def f1(signal, start, duration):
    """Return the ``duration``-long slice of ``signal`` that begins at index ``start``."""
    stop = start + duration
    return signal[start:stop]
    
def f2(result, start, duration):
    """Return the window ``result[start:start + duration]`` repeated as 50 identical rows.

    The 2-D array lets ``imshow`` draw the 1-D window as a horizontal strip.
    """
    window = result[start:start + duration]
    n_rows = 50
    return np.tile(window, (n_rows, 1))

def f4(start, duration, sampling_rate=None):
    """Return the time-range label ``"MM:SS.s - MM:SS.s s"`` for a window.

    Args:
        start: first sample index of the window.
        duration: window length in samples.
        sampling_rate: samples per second. When omitted, the notebook-global
            ``sample_rate`` is used (the original behaviour).
    """
    rate = sample_rate if sampling_rate is None else sampling_rate

    def _mm_ss(t):
        # Round to tenths of a second ONCE, then split into minutes/seconds, so e.g.
        # 59.96 s becomes "01:00.0" rather than "00:60.0". Width 4 zero-pads the
        # seconds ("05.0"); the former width 2 was shorter than "x.y" and never padded.
        tenths = int(round(t * 10 / rate))
        minutes, rest = divmod(tenths, 600)
        return f"{minutes:02d}:{rest / 10:04.1f}"

    return _mm_ss(start) + " - " + _mm_ss(start + duration) + " s"

def f5(signal, avg):
    """Smooth ``signal`` with a centred moving average ``avg`` samples wide.

    A falsy ``avg`` (e.g. 0) returns ``signal`` unchanged. Near both ends the
    window only partly overlaps the data. There the sum is divided by the number
    of samples actually covered instead of by ``avg``, so the curve is not
    artificially pulled towards zero at the edges.
    """
    if not avg:
        return signal
    kernel = np.ones(avg)
    window_sum = np.convolve(signal, kernel, 'same')
    # Number of real samples under the window at each position (== avg in the interior).
    window_count = np.convolve(np.ones(len(signal)), kernel, 'same')
    return window_sum / window_count

def moving_average(x, w):
    """Return the mean of every full length-``w`` window of ``x``.

    Only windows lying completely inside ``x`` are kept, so for ``w <= len(x)``
    the result has ``len(x) - w + 1`` values.
    """
    window_sum = np.convolve(x, np.ones(w), mode='valid')
    return window_sum / w

# Interactive viewer (ipympl backend) for one recording, laid out on a (C + 6)-row grid:
#   rows 0-2 : estimated artifact uncertainty over the whole recording, smoothed by f5
#              (slider "avg"), with purple dashed lines at the window [start, start + duration)
#   rows 3-4 : left empty
#   row 5    : "Pred" strip - the uncertainty inside the window, tiled by f2, on a 0-1 scale
#   rows 6.. : one axis per EEG channel showing the signal inside the window (f1)
# The "start" slider moves the window; "avg" sets the smoothing width.
with plt.style.context(['ieee', 'science', 'default']):  # science, ieee, default, fivethirtyeight
    # plt.rcParams.update({'font.family': 'Roboto Slab'})

    for sample in train_loader:
        serial = sample["serial"][0]
        # NOTE(review): the `serial == target_serial` filter is disabled, so the first
        # sample yielded by train_loader is shown (see `break` below) and
        # `target_serial` is unused.
        if True:# serial == target_serial:
            signal = sample["signal"][0].cpu().numpy()  # the only item of the batch (minibatch=1)
            sample_rate = config["sampling_rate"]  # also read as a global by f4
            C, L = signal.shape  # channels, samples
            # NOTE(review): KeyError if `serial` is not among the recordings cached in
            # `results` - confirm the cached file covers every training recording.
            r = results[serial].numpy()
                
            fig = plt.figure(num=1, clear=True, figsize=(30, 15))  # reuse and clear figure 1
            fig.subplots_adjust(hspace=0)
            # NOTE(review): called before any axes exist; presumably has no effect here.
            fig.tight_layout()
            gs = GridSpec(nrows=C + 6, ncols=1)
            # NOTE(review): the "start" range is empty when L <= duration.
            ctrls = Controls(start=np.arange(0, L - duration), avg=np.arange(0, 400))
            display(ctrls)  # show the slider widgets

            # Top panel: whole-recording uncertainty with the current window's edges.
            ax = fig.add_subplot(gs[:3])
            iplt.plot(partial(f5, signal=r), ax=ax, lw=0.6, controls=ctrls["avg"])
            # NOTE(review): `x=ctrls["start"]` passes the control itself as the line position;
            # presumably mpl_interactions resolves it to the current "start" value - confirm.
            mpl_interactions.interactive_axvline(x=ctrls["start"], ymin=0, ymax=1, ax=ax, 
                                                 color='purple', controls=ctrls["start"], ls="--")
            mpl_interactions.interactive_axvline(x=lambda start: start + duration, ymin=0, ymax=1, ax=ax, 
                                                 color='purple', controls=ctrls["start"], ls="--")
            # One labelled x tick every 30 s.
            x_ticks = np.arange(0, r.shape[0], sample_rate * 30)
            x_labels = [f"{round(tick / sample_rate)}" for tick in x_ticks]
            ax.set_xlim(0, r.shape[0])
            ax.set_xticks(x_ticks)
            ax.set_xticklabels(x_labels)
            ax.set_xlabel('Time (s)')
            ax.set_ylim(0, 1.0)
            ax.set_yticks([0])
            ax.set_yticklabels([])
            ax.set_ylabel("Artifact")
            
            # "Pred" strip: the uncertainty inside the current window as a 50-row image.
            ax = fig.add_subplot(gs[5])
            iplt.imshow(partial(f2, result=r, duration=duration), aspect="auto",
                        alpha=1.0, ax=ax, controls=ctrls["start"], vmin=0, vmax=1)
            ax.set_xticklabels([])
            ax.set_yticks([0])
            ax.set_yticklabels([])
            ax.set_ylabel("Pred")

            # One axis per channel showing the signal inside the current window.
            for c in range(C):
                ax = fig.add_subplot(gs[c + 6])
                iplt.plot(partial(f1, signal=signal[c], duration=duration), 
                          ax=ax, controls=ctrls["start"], lw=0.6)

                ax.set_xlim(0, duration)
                ax.set_ylabel(signal_header[c])
                # NOTE(review): issued once per channel and without `ax=`; presumably each call
                # labels the current (just-added) axis with the window's time range - confirm
                # whether a single label on the bottom axis was intended.
                mpl_interactions.interactive_xlabel(xlabel=partial(f4, duration=duration),
                                                    controls=ctrls["start"])
                # One unlabelled tick per second of the window.
                ax.set_xticks(np.arange(round(duration / sample_rate) + 1) * sample_rate)
                ax.set_xticklabels([])
                # ax.tick_params(axis='x', width=0.1, length=0.1)
                ax.set_yticks([0])
                ax.set_yticklabels([])
                
            fig.suptitle(serial, fontsize=13, fontweight='semibold')
            break  # only one recording is visualised

    plt.show()