### Import libraries

In [None]:
import yaml
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.setup_utils import (
    get_configs,
    init_configs,
    init_settings,
)
from datasets.maker import DatasetMaker
from models.maker import ModelMaker
from torch.utils.data import DataLoader
from glob import glob
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import time
import datetime
import scipy.stats as stats

import matplotlib.pyplot as plt
%matplotlib inline

### Configs

In [None]:
args = init_configs(get_configs())
init_settings(args)

# Checkpoint sub-directory, joined onto args.CKPT_PATH when searching for *.ckpt files.
args.WEIGHT_PATH = "lightning_logs/version_0/checkpoints"

### Load data utils

In [None]:
class EvalDatasetMaker(DatasetMaker):
    """DatasetMaker variant whose loaders are suited to evaluation.

    Both splits are iterated in stored order (``shuffle=False``) and the last
    partial batch is kept (``drop_last=False``), so one pass over a loader
    visits every sample exactly once.
    """

    def load_data(self, args, transform, target_subject=None):
        """Build ``(train_dataloader, test_dataloader)`` for *target_subject*.

        Args:
            args: config namespace; ``batch_size`` and ``num_workers`` are read
                here and the whole object is forwarded to ``self.dataset``.
            transform: forwarded unchanged to ``self.dataset``.
            target_subject: forwarded unchanged to ``self.dataset``.

        Returns:
            A ``(train_dataloader, test_dataloader)`` tuple. The train split is
            built first (``is_test=False``), then the test split.
        """
        # NOTE(review): self.dataset is presumably a dataset class chosen by
        # DatasetMaker.__init__ — confirm against the base class.
        loaders = []
        for is_test in (False, True):
            split = self.dataset(
                args=args,
                target_subject=target_subject,
                is_test=is_test,
                transform=transform,
            )
            loaders.append(
                DataLoader(
                    split,
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=args.num_workers,
                    pin_memory=True,
                    drop_last=False,
                )
            )

        train_dataloader, test_dataloader = loaders
        return train_dataloader, test_dataloader

In [None]:
def load_data(target_subject, args):
    """Return the evaluation ``(train_dataloader, test_dataloader)`` pair.

    No transform is applied (``transform=None``).

    Args:
        target_subject: forwarded to ``EvalDatasetMaker.load_data``.
        args: config namespace; ``args.dataset`` selects the dataset maker
            (presumably a dataset name — confirm) and the whole object is
            forwarded to the loaders.
    """
    maker = EvalDatasetMaker(args.dataset)
    return maker.load_data(args, transform=None, target_subject=target_subject)

In [None]:
def make_dataset(dataloader):
    """Drain *dataloader* into two in-memory NumPy arrays.

    Every item the loader yields must be a ``(data, label)`` pair of CPU
    tensors. The batches are joined along axis 0.

    Returns:
        ``(data, labels)`` as concatenated ndarrays.

    Raises:
        ValueError: if the loader yields nothing (nothing to concatenate).
    """
    pairs = [(batch_x.numpy(), batch_y.numpy()) for batch_x, batch_y in dataloader]
    x_parts = [x for x, _ in pairs]
    y_parts = [y for _, y in pairs]
    return np.concatenate(x_parts), np.concatenate(y_parts)

In [None]:
def classwise_divide(data, label):
    """Group the samples of *data* by class label.

    Args:
        data: array indexed by sample along axis 0.
        label: NumPy integer array aligned with axis 0 of *data*. Every value
            must be one of 0-3, otherwise a ``KeyError`` is raised.

    Returns:
        dict mapping 'left' / 'right' / 'foot' / 'tongue' (labels 0 / 1 / 2 / 3)
        to that class's samples, passed through ``squeeze()``. Every length-1
        axis is dropped, including the sample axis when a class has only one
        sample.
    """
    names = {0: 'left', 1: 'right', 2: 'foot', 3: 'tongue'}
    return {
        names[cls]: np.squeeze(data[np.where(label == cls)])
        for cls in set(label)
    }

### Load Data

In [None]:
train_dataloader, test_dataloader = load_data(target_subject=0, args=args)

# Materialise each split as in-memory arrays, then group its samples by class name.
train_data, train_labels = make_dataset(train_dataloader)
train_dict = classwise_divide(train_data, train_labels)

test_data, test_labels = make_dataset(test_dataloader)
test_dict = classwise_divide(test_data, test_labels)

### Load Model

In [None]:
class EvalModelMaker(ModelMaker):
    """ModelMaker that restores weights from a Lightning-style checkpoint."""

    def load_ckpt(self, model, path):
        """Load the ``model.``-prefixed weights stored at *path* into *model*.

        The checkpoint is read onto the CPU and must contain a ``'state_dict'``
        entry. Only keys starting with ``'model.'`` are kept, with that prefix
        stripped; all other keys are discarded. Loading is non-strict, and the
        result of ``load_state_dict`` (missing / unexpected keys) is printed.

        NOTE(review): ``torch.load`` unpickles the file — only load trusted
        checkpoints.

        Args:
            model: module that receives the weights (modified in place).
            path: checkpoint file path.

        Returns:
            *model*, with the matching weights loaded.
        """
        prefix = 'model.'
        checkpoint = torch.load(path, map_location='cpu')

        # Build a fresh mapping instead of renaming keys in place. The old
        # rename-then-delete loop could delete a freshly stripped key when it
        # collided with an original unprefixed key processed later (e.g.
        # ['model.a', 'a']). It could also copy the wrong tensor for nested
        # prefixes (['model.model.x', 'model.x']).
        state_dict = {
            key[len(prefix):]: value
            for key, value in checkpoint['state_dict'].items()
            if key.startswith(prefix)
        }

        msg = model.load_state_dict(state_dict, strict=False)
        print(msg)

        return model

In [None]:
CKPT_DIR = f'{args.CKPT_PATH}/{args.WEIGHT_PATH}'
ckpt_list = sorted(glob(CKPT_DIR + '/*.ckpt'))
if not ckpt_list:
    # Fail with a readable message instead of an IndexError on ckpt_list[0].
    raise FileNotFoundError(f'No .ckpt files found in {CKPT_DIR}')
# NOTE(review): sorted() is lexicographic, so [0] is the alphabetically first
# file (e.g. 'epoch=10-...' sorts before 'epoch=9-...'). Confirm this is the
# intended checkpoint.
ckpt_path = ckpt_list[0]

In [None]:
model_maker = EvalModelMaker(args.model, args.litmodel)
encoder = model_maker.encoder(args)
# load_ckpt hands back the same module it was given, now with weights restored;
# Module.to moves it in place and returns it.
model = model_maker.load_ckpt(encoder, ckpt_path).to(args.device)
model.eval()

### Evaluation

In [None]:
def torch2np(x_torch):
    """Return *x_torch* as an ndarray, detached from autograd and moved to CPU.

    For a tensor that is already on the CPU, the result shares memory with it.
    """
    return x_torch.detach().cpu().numpy()

In [None]:
def pbs(model, criterion, data, labels, iters, alpha, clip_min=0, clip_max=1):
    """Run iterative sign-gradient ascent on *criterion* and record each step.

    The procedure starts from a detached copy of *data*. Each of the *iters*
    iterations records the current input, steps it by ``alpha * sign(grad)``
    (``grad`` is the gradient of ``criterion(model(x), labels)`` w.r.t. the
    input), and clamps it to ``[clip_min, clip_max]``.

    Args:
        model: differentiable callable mapping inputs to logits.
        criterion: loss callable ``criterion(outputs, labels)`` returning a scalar.
        data: input tensor; it is not modified.
        labels: targets passed to *criterion*.
        iters: number of update steps. Must be >= 1, because ``np.stack``
            cannot stack an empty history.
        alpha: step size multiplying the gradient sign.
        clip_min, clip_max: clamp bounds applied after every step. Pass ``None``
            for either to leave that side unbounded; both ``None`` disables
            clamping. The defaults keep the original ``[0, 1]`` behaviour.

    Returns:
        A ``(adv_data, history)`` tuple.
        ``adv_data`` is the input after the final step, detached and on
        ``data``'s device.
        ``history`` is an ndarray of shape ``(batch, iters, *data.shape[1:])``
        holding the input *before* each step. So ``history[:, 0]`` is the
        clean input, and the final ``adv_data`` is not included.
    """
    history = []
    adv_data = data.clone().detach()

    for _ in range(iters):
        history.append(adv_data.detach().cpu().numpy())
        adv_data.requires_grad_(True)
        cost = criterion(model(adv_data), labels)
        # A first-order gradient is enough. Only its sign is used and the
        # result is detached, so create_graph/retain_graph would just keep an
        # unused double-backward graph alive each step.
        grad = torch.autograd.grad(cost, adv_data)[0]

        adv_data = adv_data.detach() + alpha * grad.sign()
        if clip_min is not None or clip_max is not None:
            adv_data = torch.clamp(adv_data, min=clip_min, max=clip_max)
        adv_data = adv_data.detach()

    return adv_data, np.stack(history, axis=1)

In [None]:
preds = []
attack_preds = []
batch_size = 1024
num_iters = 100
all_probs = []
criterion = nn.CrossEntropyLoss()  # created once, reused for every batch
start_time = time.perf_counter()  # monotonic clock for interval timing

# Step size for the attack: the pooled std of the training data minus a lower
# bound on that std. The bound uses the F-distribution quantile at
# significance_level / 2, with N - 1 dof for both numerator and denominator
# (N = number of elements per channel block b * c * t).
b, _, c, t = train_data.shape
significance_level = 0.05
stds = train_data.std()
n_dof = b * c * t - 1
f_critical_low = stats.f.ppf(significance_level / 2, n_dof, n_dof)
lower_bound = stds * np.sqrt(f_critical_low)
alpha = stds - lower_bound


for idx in tqdm(range(0, train_data.shape[0], batch_size)):
    batch_targets = train_labels[idx:idx + batch_size]
    data = torch.tensor(train_data[idx:idx + batch_size], dtype=torch.float).to(args.device)
    labels = torch.tensor(batch_targets).to(args.device)

    # The attack needs autograd, so it runs outside no_grad.
    attacked_data, attacked_datas = pbs(model, criterion, data, labels, iters=num_iters, alpha=alpha)

    # Pure inference from here on. no_grad avoids building (and holding)
    # an autograd graph for every forward pass.
    with torch.no_grad():
        # One trajectory per sample: the true-class probability at each attack step.
        for trajectory, target in zip(attacked_datas, batch_targets):
            logit = model(torch.tensor(trajectory, dtype=torch.float).to(args.device))
            probs = F.softmax(logit, dim=1)[:, target]
            all_probs.append(torch2np(probs))

        preds.append(torch.argmax(model(data), dim=1))
        attack_preds.append(torch.argmax(model(attacked_data), dim=1))

    torch.cuda.empty_cache()

end_time = time.perf_counter()

sec = (end_time - start_time)
result = datetime.timedelta(seconds=sec)
print(f'GPU time: {result}')

preds = torch.concat(preds).detach().cpu().numpy()
attack_preds = torch.concat(attack_preds).detach().cpu().numpy()
all_probs = np.stack(all_probs)  # (num_samples, num_iters)
# Confidence score: mean true-class probability along each sample's trajectory.
confidence_scores = all_probs.mean(axis=1)

In [None]:
plt.figure(figsize=(15, 15))
# Plotting a 2-D array draws one line per column. Transposing to
# (num_iters, num_samples) draws every trajectory in a single call instead of
# one plt.plot call per sample.
plt.plot(all_probs.T, c='blue', alpha=0.1)
plt.ylim(0, 1)
plt.xlabel('attack step')
plt.ylabel('true-class probability')
plt.show()

In [None]:
# (clean accuracy, attacked accuracy, #correct clean, #correct attacked).
# ndarray.sum() counts matches in C; builtin sum() would iterate elementwise.
(
    accuracy_score(train_labels, preds),
    accuracy_score(train_labels, attack_preds),
    (train_labels == preds).sum(),
    (train_labels == attack_preds).sum(),
)

In [None]:
# Convert into new names instead of rebinding `data` / `attacked_data`.
# Rebinding made a re-run call torch2np on an ndarray and fail.
# Both tensors hold the last batch processed by the attack loop.
clean_np = torch2np(data)
attacked_np = torch2np(attacked_data)

plt.figure(figsize=(20, 5))

panels = [
    (clean_np[0, 0], 'clean input'),
    (attacked_np[0, 0], 'after attack'),
    (clean_np[0, 0] - attacked_np[0, 0], 'clean - attacked'),
]
for position, (image, title) in enumerate(panels, start=1):
    plt.subplot(3, 1, position)
    plt.imshow(image)
    plt.title(title)

plt.show()

In [None]:
import os

# Layout assumed by the original indices: 9 consecutive subject blocks of
# 288 * 2 scores each (presumably 288 trials x 2 sessions — confirm).
n_subjects = 9
trials_per_subject = 288 * 2
out_dir = './scores/pbs/bcic2a'
os.makedirs(out_dir, exist_ok=True)  # np.save does not create directories

all_idx = np.arange(n_subjects * trials_per_subject)
for subject_idx in range(n_subjects):
    start = subject_idx * trials_per_subject
    # Indices of every score except those of the held-out subject.
    score_index = np.delete(all_idx, np.s_[start:start + trials_per_subject])
    np.save(f'{out_dir}/S{subject_idx:02d}', confidence_scores[score_index])
    