# Efb0 infer

## Import tools and CFG

In [71]:
import os, gc
import pathlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import timm
from torch.utils.data import Dataset, DataLoader
import pywt, librosa

# Wavelet name forwarded to `denoise`; any falsy value skips the denoising step.
USE_WAVELET = None

# Short label for each electrode chain, index-aligned with FEATS.
NAMES = "LL LP RP RR LZ RZ".split()

# Electrode chains. Adjacent electrodes within a chain are differenced
# pairwise by `spectrogram_from_eeg`.
FEATS = [
    chain.split()
    for chain in (
        "Fp1 F7 T3 T5 O1",
        "Fp1 F3 C3 P3 O1",
        "Fp2 F8 T4 T6 O2",
        "Fp2 F4 C4 P4 O2",
        "Fp1 Fz Cz Pz O1",
        "Fp2 Fz Cz Pz O2",
    )
]

class CFG:
    """Paths, backbone name and loader settings for the EfficientNet-B0 run."""

    # Raw strings: sequences such as "\H" or "\c" are not valid escapes, and
    # relying on Python keeping them verbatim triggers a warning on newer
    # interpreters (and would silently break for e.g. "\n" or "\t").
    base_dir = pathlib.Path(r"E:\HMS2024")
    path_test = base_dir / "code/model1/validation.csv"
    path_submission = base_dir / "sample_submission.csv"
    spec_dir = base_dir / "validation_spectrograms"
    model_name = "tf_efficientnet_b0_ns"

    # Fold checkpoints, sorted so the ensemble always iterates in the same order.
    model_weights = sorted(
        pathlib.Path(r"E:\HMS2024\code\model1\efb0").glob("*_v2.pt")
    )

    # Echo the discovered checkpoints so a wrong path is noticed immediately.
    print(model_weights)

    # Resize applied to the network input.
    transform = transforms.Resize((768, 768), antialias=False)
    batch_size = 16
    # Output columns, in the order the model emits its six class scores.
    label_columns = [
        "seizure_vote",
        "lpd_vote",
        "gpd_vote",
        "lrda_vote",
        "grda_vote",
        "other_vote",
    ]

[WindowsPath('E:/HMS2024/code/model1/efb0/tf_efficientnet_b0.ns_jft_in1k_0_v2.pt'), WindowsPath('E:/HMS2024/code/model1/efb0/tf_efficientnet_b0.ns_jft_in1k_1_v2.pt'), WindowsPath('E:/HMS2024/code/model1/efb0/tf_efficientnet_b0.ns_jft_in1k_2_v2.pt'), WindowsPath('E:/HMS2024/code/model1/efb0/tf_efficientnet_b0.ns_jft_in1k_3_v2.pt'), WindowsPath('E:/HMS2024/code/model1/efb0/tf_efficientnet_b0.ns_jft_in1k_4_v2.pt')]


## Data processing

In [72]:
# DENOISE FUNCTION
def maddest(d, axis=None):
    """Return the mean absolute deviation of `d` about its mean along `axis`."""
    centre = np.mean(d, axis)
    deviation = np.absolute(d - centre)
    return np.mean(deviation, axis)

def denoise(x, wavelet='haar', level=1):
    """Wavelet-denoise the 1-D signal `x` by hard-thresholding detail coefficients.

    The noise scale is estimated from the `level`-th coefficient band counted
    from the end of the decomposition, and every detail band is thresholded
    with `sigma * sqrt(2 * log(len(x)))`.
    """
    bands = pywt.wavedec(x, wavelet, mode="per")

    # Noise estimate taken from the chosen detail band.
    sigma = (1/0.6745) * maddest(bands[-level])
    uthresh = sigma * np.sqrt(2*np.log(len(x)))

    # The approximation band (index 0) is kept as is; only details are thresholded.
    thresholded = [bands[0]]
    for detail in bands[1:]:
        thresholded.append(pywt.threshold(detail, value=uthresh, mode='hard'))

    return pywt.waverec(thresholded, wavelet, mode='per')

def spectrogram_from_eeg(parquet_path, display=False):
    """Turn one EEG parquet file into a (192, 768, 6) float32 spectrogram stack.

    The middle 10,000 rows of the recording are used. For each of the six
    electrode chains in ``FEATS`` the four neighbouring-electrode differences
    are converted to log-mel spectrograms and averaged into one image plane.

    Args:
        parquet_path: Path of the EEG parquet file. Its file stem is used as
            the label in plot titles.
        display: When true, plot the six spectrograms and the raw signals.

    Returns:
        ``np.ndarray`` of shape (192, 768, 6), dtype float32.
    """
    # Label for plot titles. Derived from the file name so the function no
    # longer depends on a global variable called `eeg_id` being set by the caller.
    eeg_label = pathlib.Path(str(parquet_path)).stem

    # LOAD MIDDLE 50 SECONDS OF EEG SERIES
    eeg = pd.read_parquet(parquet_path)
    middle = (len(eeg)-10_000)//2
    eeg = eeg.iloc[middle:middle+10_000]

    # VARIABLE TO HOLD SPECTROGRAM
    img = np.zeros((192,768,6),dtype='float32')

    if display: plt.figure(figsize=(10,12))
    signals = []
    for k in range(6):
        COLS = FEATS[k]

        for kk in range(4):

            # COMPUTE PAIR DIFFERENCES
            x = eeg[COLS[kk]].values - eeg[COLS[kk+1]].values

            # FILL NANS (with the signal mean; an all-NaN signal becomes zeros)
            m = np.nanmean(x)
            if np.isnan(x).mean()<1: x = np.nan_to_num(x,nan=m)
            else: x[:] = 0

            # DENOISE
            if USE_WAVELET:
                x = denoise(x, wavelet=USE_WAVELET)
            signals.append(x)

            # RAW SPECTROGRAM
            mel_spec = librosa.feature.melspectrogram(y=x, sr=200, hop_length=len(x)//768,
                  n_fft=1024, n_mels=192, fmin=0, fmax=20, win_length=128)

            # LOG TRANSFORM (width is trimmed to a multiple of 32)
            width = (mel_spec.shape[1]//32)*32
            mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max).astype(np.float32)[:,:width]

            # STANDARDIZE TO -1 TO 1
            mel_spec_db = (mel_spec_db+40)/40
            img[:,:,k] += mel_spec_db

        # AVERAGE THE 4 MONTAGE DIFFERENCES
        img[:,:,k] /= 4.0

        if display:
            plt.subplot(3,2,k+1)
            plt.imshow(img[:,:,k],aspect='auto',origin='lower')
            plt.title(f'EEG {eeg_label} - Spectrogram {NAMES[k]}')

    if display:
        plt.show()
        plt.figure(figsize=(10,7))
        offset = 0
        # NOTE(review): `signals` holds 24 traces (6 chains x 4 differences) but
        # only the first six are drawn, and the `3-k` indices go negative for
        # k > 3, so labels/offsets do not match the plotted trace. Left as is
        # because the intended layout cannot be determined from this file.
        for k in range(6):
            if k>0: offset -= signals[3-k].min()
            plt.plot(range(10_000),signals[k]+offset,label=NAMES[3-k])
            offset += signals[3-k].max()
        plt.legend()
        plt.title(f'EEG {eeg_label} Signals')
        plt.show()
        print(); print('#'*25); print()

    return img


def preprocess(x):
    """Standardise `x` using its global mean and standard deviation.

    A small constant is added to the standard deviation so constant inputs do
    not divide by zero.
    """
    centred = x - x.mean()
    return centred / (x.std() + 1e-6)

def preprocess2(x):
    """Clip `x` to [e^-6, e^10], take the natural log, then standardise.

    Standardisation uses the global mean and standard deviation of the
    log-values, with a small constant added to the denominator.
    """
    log_x = np.log(np.clip(x, np.exp(-6), np.exp(10)))
    return (log_x - log_x.mean()) / (log_x.std() + 1e-6)

## Dataloader

In [73]:
class SpecDataset(Dataset):
    """Dataset that pairs an EEG-derived image with the matching spectrogram file.

    Each item is built from two sources: the pre-computed image stored in the
    module-level `all_eegs2` dict (keyed by `eeg_id`) and the parquet file at
    `row.path`. Both are normalised, stacked along the first axis and returned
    as a single-channel tensor together with the row's label columns.
    """

    def __init__(self, df, transform=CFG.transform):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]

        # First input: the six planes of the EEG image laid out one below another.
        eeg_img = all_eegs2[row.eeg_id]
        planes = [eeg_img[:, :, channel] for channel in range(6)]
        x1 = preprocess(np.concatenate(planes, axis=0))

        # Second input: the spectrogram parquet, first column dropped, transposed.
        spec = pd.read_parquet(row.path).fillna(-1).values[:, 1:].T
        x2 = torch.Tensor(preprocess2(spec)[None, :])
        # This resize always goes through CFG.transform, independent of self.transform.
        x2 = np.array(CFG.transform(x2))[0]

        stacked = np.concatenate([x1, x2])
        x = torch.Tensor(stacked[None, :])
        if self.transform:
            x = self.transform(x)

        # Target: the label columns of this row as float32.
        y = torch.Tensor(np.array(row.loc[CFG.label_columns].values, 'float32'))
        return x, y

## Loading data

In [74]:
# Validation rows to score. Raw string: "\H", "\c" and "\m" are not valid
# escape sequences and should not be left for Python to pass through.
submission = pd.read_csv(r"E:\HMS2024\code\model3/validation.csv")[['eeg_id', 'seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote', 'spectrogram_id', 'patient_id']]
submission['path'] = submission['spectrogram_id'].apply(lambda x: f"E:/HMS2024/validation_spectrograms/{x}.parquet")

PATH2 = 'model1/validation_eegs/'
# Number of leading EEGs to plot while building spectrograms (False == 0 == none).
DISPLAY = False
EEG_IDS2 = submission.eeg_id.unique()
all_eegs2 = {}

# Pre-compute one spectrogram stack per unique EEG id; SpecDataset looks them up.
for i,eeg_id in enumerate(tqdm(EEG_IDS2)):
    img = spectrogram_from_eeg(f'{PATH2}{eeg_id}.parquet', i<DISPLAY)
    all_eegs2[eeg_id] = img

data_ds = SpecDataset(df=submission)
# Use the configured batch size (previously ignored, so batches had size 1) and
# keep row order so predictions line up with `submission`.
data_loader = DataLoader(dataset=data_ds, batch_size=CFG.batch_size, shuffle=False)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = timm.create_model(model_name=CFG.model_name, pretrained=False, num_classes=6, in_chans=1)
model.to(DEVICE)

# Running sum of per-fold probabilities, one row per submission row.
prediction = pd.DataFrame(0.0, columns=CFG.label_columns, index=submission.index)

  0%|          | 0/8 [00:00<?, ?it/s]

## Infer

In [75]:
# Fail loudly instead of silently producing an all-NaN result (0 / 0).
if not CFG.model_weights:
    raise FileNotFoundError("CFG.model_weights is empty: no checkpoint files were found")

for i, path_weight in enumerate(CFG.model_weights):
    print(f"Model {i}: {path_weight}")
    # map_location lets checkpoints saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(path_weight, map_location=DEVICE))
    model.eval()
    with torch.no_grad():
        res = []
        for x, y in tqdm(data_loader):
            x = x.to(DEVICE)
            pred = model(x)
            pred = F.softmax(pred, dim=1)
            pred = pred.detach().cpu().numpy()
            res.append(pred)
        res = np.concatenate(res)
        # Index by the full submission frame rather than a hard-coded first
        # 60 rows, so the loop works for any validation set size.
        res = pd.DataFrame(res, columns=CFG.label_columns, index=submission.index)

        prediction = prediction + res
        print("\n")


# Average the accumulated probabilities over the folds.
prediction = prediction / len(CFG.model_weights)

Model 0: E:\HMS2024\code\model1\efb0\tf_efficientnet_b0.ns_jft_in1k_0_v2.pt


  0%|          | 0/60 [00:00<?, ?it/s]



Model 1: E:\HMS2024\code\model1\efb0\tf_efficientnet_b0.ns_jft_in1k_1_v2.pt


  0%|          | 0/60 [00:00<?, ?it/s]



Model 2: E:\HMS2024\code\model1\efb0\tf_efficientnet_b0.ns_jft_in1k_2_v2.pt


  0%|          | 0/60 [00:00<?, ?it/s]



Model 3: E:\HMS2024\code\model1\efb0\tf_efficientnet_b0.ns_jft_in1k_3_v2.pt


  0%|          | 0/60 [00:00<?, ?it/s]



Model 4: E:\HMS2024\code\model1\efb0\tf_efficientnet_b0.ns_jft_in1k_4_v2.pt


  0%|          | 0/60 [00:00<?, ?it/s]





## display

In [54]:
# Replace the vote columns with the ensemble predictions and save them.
submission[CFG.label_columns] = prediction
submission = submission[["eeg_id"] + CFG.label_columns]
submission.to_csv("efficientnet_b0_v2_validation_result.csv", index=False)

# Reload the predictions under short column names and attach the ground-truth
# votes from train.csv for side-by-side inspection.
result = pd.read_csv('efficientnet_b0_v2_validation_result.csv')
result.columns = ['eeg_id', 'seizure', 'lpd', 'gpd', 'lrda', 'grda', 'other']
# Raw string: "\H" is not a valid escape sequence.
train_df = pd.read_csv(r'E:\HMS2024/train.csv')
res = train_df[train_df['eeg_id'].isin(result['eeg_id'])]
b0result = result.merge(res[['eeg_id', 'expert_consensus', 'seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']], on='eeg_id', how='left')
# The left merge can repeat rows when an eeg_id occurs several times on either side.
b0result = b0result.drop_duplicates()

# Efb1 infer

## Import toolkits and config

In [76]:
import gc
import os
import random
import warnings
import numpy as np
import pandas as pd
from IPython.display import display

import timm
import torch
import torch.nn as nn  
import torch.optim as optim
import torch.nn.functional as F

import torchvision.transforms as transforms

from tqdm.notebook import tqdm


# Suppress every warning category for the remainder of the session.
warnings.filterwarnings("ignore", category=Warning)

class Config:
    """Settings for the EfficientNet-B1 inference pass."""

    # Seed handed to `set_seed`.
    seed = 42

    # Number of fold checkpoints that are loaded and averaged.
    num_folds = 5

    # Each spectrogram is resized to 512x512 before being fed to the models.
    image_transform = transforms.Resize((512, 512))
    
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to all three random number generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    torch.backends.cudnn.deterministic = True
    # benchmark mode auto-tunes and may pick a different convolution algorithm
    # from run to run, which defeats the deterministic flag above; it must be
    # off for reproducible results.
    torch.backends.cudnn.benchmark = False
    
# Seed all RNGs once, before any model is constructed.
set_seed(Config.seed)

# Collects one loaded network per fold; filled by the loading loop below.
models = list()

## Load the data and models, then run inference

EfficientNetB1

In [77]:
# Load EfficientNetB1: one network per fold, appended to the shared `models` list.
# Fall back to CPU when no GPU is present instead of failing on `.to('cuda')`.
efb1_device = "cuda" if torch.cuda.is_available() else "cpu"

for k in range(Config.num_folds):
    # NOTE: the variable is called "b2" but the architecture built is efficientnet_b1.
    model_effnet_b2 = timm.create_model('efficientnet_b1', pretrained=False, num_classes=6, in_chans=1)
    model_effnet_b2.load_state_dict(torch.load(f'./model3/efb1_v2/efficientnet_b1_fold{k}_v2.pth', map_location=torch.device('cpu')))
    model_effnet_b2.to(efb1_device)
    # Inference only: switch to eval mode once here rather than for every sample.
    model_effnet_b2.eval()
    models.append(model_effnet_b2)

# Raw string: "\H", "\c" and "\m" are not valid escape sequences.
submission = pd.read_csv(r"E:\HMS2024\code\model3/validation.csv")[['eeg_id', 'seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote', 'spectrogram_id', 'patient_id']]

submission['path'] = submission['spectrogram_id'].apply(lambda x: f"E:/HMS2024/validation_spectrograms/{x}.parquet")

paths = submission['path'].values
test_preds = []

for path in tqdm(paths):
    eps = 1e-6
    # Drop the first column, transpose, then clip + log-scale the spectrogram.
    data = pd.read_parquet(path)
    data = data.fillna(-1).values[:, 1:].T
    data = np.clip(data, np.exp(-6), np.exp(10))
    data = np.log(data)

    # Standardise with the global mean/std of this spectrogram.
    data_mean = data.mean(axis=(0, 1))
    data_std = data.std(axis=(0, 1))
    data = (data - data_mean) / (data_std + eps)
    data_tensor = torch.unsqueeze(torch.Tensor(data), dim=0)
    data = Config.image_transform(data_tensor).to(efb1_device)

    test_pred = []

    for model in models:
        model.eval()
        with torch.no_grad():
            # Explicit class axis; the implicit-dim form of softmax is deprecated.
            pred = F.softmax(model(data.unsqueeze(0)), dim=1)[0]
            pred = pred.detach().cpu().numpy()
        test_pred.append(pred)

    # Average the folds (the slice guards against `models` holding extra
    # entries, e.g. after the loading loop has been run more than once).
    weighted_pred = np.mean(test_pred[:Config.num_folds], axis=0)
    test_preds.append(weighted_pred)

# Stack per-sample predictions into one (n_samples, 6) array.
test_preds = np.array(test_preds)

  0%|          | 0/60 [00:00<?, ?it/s]

## display

In [78]:
# One row of six class probabilities per sample.
l = [list(row) for row in test_preds]
eegid = submission['eeg_id']
sub = pd.DataFrame(l, columns=['seizure', 'lpd', 'gpd', 'lrda', 'grda', 'other'])
sub['eeg_id'] = eegid

# Save the predictions, then reload them and attach the ground-truth votes.
sub.to_csv('efficientnet_b1_v2_validation_result.csv', index=False)

result = pd.read_csv('efficientnet_b1_v2_validation_result.csv')
# Raw string: "\H" is not a valid escape sequence.
train_df = pd.read_csv(r'E:\HMS2024/train.csv')
res = train_df[train_df['eeg_id'].isin(result['eeg_id'])]
b1result = result.merge(res[['eeg_id', 'expert_consensus', 'seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']], on='eeg_id', how='left')
# The left merge can repeat rows when an eeg_id occurs several times on either side.
b1result = b1result.drop_duplicates()

# ResNet34-BiGRU infer

In [None]:
import os
import gc
import sys
import math
import time
import random
import datetime as dt
import numpy as np
import pandas as pd
import wandb

from glob import glob
from pathlib import Path
from typing import Dict, List, Union
import scipy.signal as scisig
from scipy.signal import butter, lfilter, freqz
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import (
    ReduceLROnPlateau,
    OneCycleLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
)
from torch.optim.optimizer import Optimizer
from sklearn.model_selection import GroupKFold,KFold

import warnings
# Silence all warnings for this section of the notebook.
warnings.filterwarnings("ignore")

# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

class CFG:
    """Configuration for the ResNet1D-GRU EEG inference run."""

    VERSION = 88
    model_name = "resnet1d_gru"

    # --- run / loader settings ---
    seed = 2024
    batch_size = 32
    num_workers = 0

    # --- network geometry ---
    fixed_kernel_size = 5
    kernels = [3, 5, 7, 9, 11]
    # Input width of the final linear layer; it depends on the crop length:
    #   448 -> full signal (10_000), 352 -> half (5_000), 304 -> one fifth (2_000)
    linear_layer_features = 304

    # --- signal geometry ---
    seq_length = 50  # seconds
    sampling_rate = 200  # Hz
    nsamples = seq_length * sampling_rate
    out_samples = nsamples // 5

    # --- filtering / augmentation ---
    freq_channels = []  # optional (low, high) bands, e.g. [(8.0, 12.0)] or [(0.5, 4.5)]
    filter_order = 2
    random_close_zone = 0.0  # probability of leaving a channel zeroed in train mode

    target_cols = [
        "seizure_vote",
        "lpd_vote",
        "gpd_vote",
        "lrda_vote",
        "grda_vote",
        "other_vote",
    ]

    # Electrode pairs whose difference forms one input channel each.
    map_features = [
        ("Fp1", "T3"), ("T3", "O1"),
        ("Fp1", "C3"), ("C3", "O1"),
        ("Fp2", "C4"), ("C4", "O2"),
        ("Fp2", "T4"), ("T4", "O2"),
    ]

    # Columns read from each EEG parquet, and their column position in the array.
    eeg_features = ["Fp1", "T3", "C3", "O1", "Fp2", "C4", "T4", "O2"]
    feature_to_index = {name: position for position, name in enumerate(eeg_features)}
    # Raw electrodes used as channels directly (none by default).
    simple_features = []

    # --- derived sizes ---
    n_map_features = len(map_features)
    in_channels = n_map_features * (1 + len(freq_channels)) + len(simple_features)
    target_size = len(target_cols)

    # --- locations ---
    PATH = "."
    test_eeg = "E:/HMS2024/code/model1/validation_eegs/"
    test_csv = "./validation.csv"

# Ensemble weight applied to every checkpoint matched by a file mask below.
koef_1 = 1.0

# One entry per preprocessing variant: the band-pass settings used for the
# inputs, and the checkpoint globs (with their weights) that expect them.
model_weights = [
    {
        'bandpass_filter': {'low': 0.5, 'high': 20, 'order': 2},
        'file_data': [
            # Raw string: "\H", "\c", "\m" and "\p" are not valid escape sequences.
            {'koef': koef_1, 'file_mask': r"E:\HMS2024\code\model2\pop_1_weight_oof_n9/*_best.pth"},  ## Set path
        ],
    },
]

def init_logger(log_file="./test.log"):
    """Return this module's logger, writing plain messages to the stream and a file.

    Calling the function again replaces the handlers it installed earlier, so
    repeated calls (e.g. re-running a notebook cell) no longer make every
    message appear several times. Handlers attached by other code are left
    untouched.

    Args:
        log_file: Path of the log file to write to.

    Returns:
        The configured ``logging.Logger`` at level INFO.
    """
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler

    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # Remove (and close) only the handlers a previous call of this function added.
    stale = [h for h in logger.handlers if getattr(h, "_installed_by_init_logger", False)]
    for handler in stale:
        logger.removeHandler(handler)
        handler.close()

    for handler in (StreamHandler(), FileHandler(filename=log_file)):
        handler.setFormatter(Formatter("%(message)s"))
        handler._installed_by_init_logger = True  # marker used by the cleanup above
        logger.addHandler(handler)
    return logger


def asMinutes(s):
    """Format a duration given in seconds as ``"<minutes>m <seconds>s"``."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return "%dm %ds" % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Return elapsed and estimated remaining time as a formatted string.

    Args:
        since: Start timestamp as returned by ``time.time()``.
        percent: Fraction of the work completed so far (must be non-zero).
    """
    elapsed = time.time() - since
    # Extrapolate the total duration from the fraction done so far.
    estimated_total = elapsed / (percent)
    remaining = estimated_total - elapsed
    return "%s (remain %s)" % (asMinutes(elapsed), asMinutes(remaining))


def quantize_data(data, classes):
    """Apply mu-law encoding to `data`, using `classes` as the mu parameter."""
    return mu_law_encoding(data, classes)


def mu_law_encoding(data, mu):
    """Compress `data` with the mu-law formula.

    Computes ``sign(x) * log(1 + mu * |x|) / log(mu + 1)`` element-wise.
    """
    signed_log = np.sign(data) * np.log(1 + mu * np.abs(data))
    return signed_log / np.log(mu + 1)


def mu_law_expansion(data, mu):
    """Expand mu-law compressed `data`.

    Computes ``sign(x) * (exp(|x| * log(mu + 1)) - 1) / mu`` element-wise.
    """
    grown = np.exp(np.abs(data) * np.log(mu + 1)) - 1
    return np.sign(data) * grown / mu


def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter and return its (b, a) coefficients.

    Args:
        lowcut: Lower band edge, in the same units as `fs`.
        highcut: Upper band edge, in the same units as `fs`.
        fs: Sampling frequency of the signal to be filtered.
        order: Filter order passed to ``scipy.signal.butter``.
    """
    band_edges = [lowcut, highcut]
    return butter(order, band_edges, btype="band", fs=fs)


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass `data` with a Butterworth filter applied in a single forward pass.

    The coefficients come from `butter_bandpass`; filtering uses
    ``scipy.signal.lfilter`` along its default axis.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(numerator, denominator, data)


def butter_lowpass_filter(
    data, cutoff_freq=20, sampling_rate=CFG.sampling_rate, order=4
):
    """Low-pass `data` along axis 0 with a digital Butterworth filter.

    Args:
        data: Array to filter; filtering runs along its first axis.
        cutoff_freq: Cut-off frequency, in the same units as `sampling_rate`.
        sampling_rate: Sampling frequency of `data`.
        order: Filter order passed to ``scipy.signal.butter``.
    """
    # butter() is called without `fs`, so the cut-off is expressed as a
    # fraction of the Nyquist frequency.
    normalised_cutoff = cutoff_freq / (0.5 * sampling_rate)
    numerator, denominator = butter(order, normalised_cutoff, btype="low", analog=False)
    return lfilter(numerator, denominator, data, axis=0)


def denoise_filter(x, lowcut=None, highcut=None):
    """Band-pass `x`, smooth it with a 4-sample moving average and keep every 4th sample.

    Args:
        x: 1-D signal sampled at ``CFG.sampling_rate``.
        lowcut: Lower band edge. Defaults to ``CFG.lowcut`` when that attribute
            exists, otherwise 0.5.
        highcut: Upper band edge. Defaults to ``CFG.highcut`` when that
            attribute exists, otherwise 20.

    Returns:
        The filtered signal, decimated by a factor of four.
    """
    # The function used to read CFG.lowcut / CFG.highcut unconditionally, but
    # CFG does not define them, so every call raised AttributeError. The
    # fallback values mirror the 0.5-20 band-pass configured in `model_weights`.
    if lowcut is None:
        lowcut = getattr(CFG, "lowcut", 0.5)
    if highcut is None:
        highcut = getattr(CFG, "highcut", 20)

    y = butter_bandpass_filter(x, lowcut, highcut, CFG.sampling_rate, order=6)
    # Moving average over each sample and its next three neighbours (np.roll
    # wraps around, so the last three outputs mix in the first samples).
    y = (y + np.roll(y, -1) + np.roll(y, -2) + np.roll(y, -3)) / 4
    y = y[0:-1:4]
    return y

def eeg_from_parquet(
    parquet_path: str, display: bool = False, seq_length=CFG.seq_length) -> np.ndarray:
    """Load the centred window of an EEG parquet file as a NaN-free array.

    Args:
        parquet_path: Path of the parquet file; only ``CFG.eeg_features``
            columns are read.
        display: When true, plot the loaded channels stacked vertically.
        seq_length: Window length in whole seconds. The window holds
            ``seq_length * CFG.sampling_rate`` rows; the default reproduces
            ``CFG.nsamples``. (This argument used to be ignored.)

    Returns:
        Array of shape ``(seq_length * CFG.sampling_rate, len(CFG.eeg_features))``.
        NaNs are replaced by the channel mean; an all-NaN channel becomes zeros.
    """
    nsamples = seq_length * CFG.sampling_rate

    eeg = pd.read_parquet(parquet_path, columns=CFG.eeg_features)

    # Keep the middle `nsamples` rows of the recording.
    start = (len(eeg) - nsamples) // 2
    eeg = eeg.iloc[start : start + nsamples]

    if display:
        plt.figure(figsize=(10, 5))
    # Vertical shift that keeps the plotted channels from overlapping.
    plot_offset = 0

    data = np.zeros((nsamples, len(CFG.eeg_features)))

    for index, feature in enumerate(CFG.eeg_features):
        x = eeg[feature].values.astype("float32")

        if np.isnan(x).mean() < 1:
            x = np.nan_to_num(x, nan=np.nanmean(x))
        else:
            # Entirely missing channel: nothing to average, use zeros.
            x[:] = 0
        data[:, index] = x

        if display:
            if index != 0:
                plot_offset += x.max()
            plt.plot(range(nsamples), x - plot_offset, label=feature)
            plot_offset -= x.min()

    if display:
        plt.legend()
        # Path handles the platform's separator, not only "/".
        name = Path(parquet_path).name.split(".")[0]
        plt.yticks([])
        plt.title(f"EEG {name}", size=16)
        plt.show()
    return data

class EEGDataset(Dataset):
    """Dataset that turns pre-loaded EEG arrays into filtered bipolar channels.

    Each item is a dict with:
      * ``"eeg"``: float32 tensor built from a ``CFG.out_samples``-row crop with
        ``CFG.in_channels`` columns (optionally strided by ``downsample``);
      * ``"labels"``: float32 tensor of ``CFG.target_cols`` values, or zeros
        when ``mode == "test"``.

    Args:
        df: Frame with an ``eeg_id`` column (and the target columns unless
            ``mode == "test"``).
        batch_size: Stored on the instance; not used by the class itself.
        eegs: Mapping from ``eeg_id`` to the raw array for that recording.
        mode: ``"train"`` enables random cropping and random augmentation;
            ``"test"`` additionally skips reading labels.
        downsample: If given, keep every ``downsample``-th row of the result.
        bandpass_filter: Dict with ``"low"``, ``"high"`` and ``"order"`` keys
            applied to every channel, or None to skip.
        rand_filter: Dict with ``"probab"``, ``"low"``, ``"high"``, ``"band"``
            and ``"order"`` keys for the train-only random band-pass, or None.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        batch_size: int,
        eegs: Dict[int, np.ndarray],
        mode: str = "train",
        downsample: int = None,
        bandpass_filter: Dict[str, Union[int, float]] = None,
        rand_filter: Dict[str, Union[int, float]] = None,
    ):
        self.df = df
        self.batch_size = batch_size
        self.mode = mode
        self.eegs = eegs
        self.downsample = downsample
        self.bandpass_filter = bandpass_filter
        self.rand_filter = rand_filter

    def __len__(self):
        """Return the number of rows in the backing frame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return the ``{"eeg", "labels"}`` tensors for row `index`."""
        X, y_prob = self.__data_generation(index)
        if self.downsample is not None:
            X = X[:: self.downsample, :]
        output = {
            "eeg": torch.tensor(X, dtype=torch.float32),
            "labels": torch.tensor(y_prob, dtype=torch.float32),
        }
        return output

    def __data_generation(self, index):
        """Build the channel matrix and label vector for row `index`.

        The double-underscore name is mangled by Python, so the method is only
        reachable from inside this class.
        """

        X = np.zeros(
            (CFG.out_samples, CFG.in_channels), dtype="float32"
        )  # Size=(CFG.out_samples, CFG.in_channels)

        row = self.df.iloc[index]
        data = self.eegs[row.eeg_id]  # rows = samples, columns indexed via CFG.feature_to_index
        if CFG.nsamples != CFG.out_samples:
            if self.mode != "train":
                # Deterministic centre crop outside training.
                offset = (CFG.nsamples - CFG.out_samples) // 2
            else:
                # Random crop start, quantised to thousandths of the free range.
                #offset = random.randint(0, CFG.nsamples - CFG.out_samples)
                offset = ((CFG.nsamples - CFG.out_samples) * random.randint(0, 1000)) // 1000
            data = data[offset:offset+CFG.out_samples,:]

        for i, (feat_a, feat_b) in enumerate(CFG.map_features):
            # Train-only augmentation: randomly leave this channel all zeros.
            if self.mode == "train" and CFG.random_close_zone > 0 and random.uniform(0.0, 1.0) <= CFG.random_close_zone:
                continue

            # Difference between the two electrodes of the pair.
            diff_feat = (
                data[:, CFG.feature_to_index[feat_a]]
                - data[:, CFG.feature_to_index[feat_b]]
            )  # Size=(CFG.out_samples,) after cropping

            if not self.bandpass_filter is None:
                diff_feat = butter_bandpass_filter(
                    diff_feat,
                    self.bandpass_filter["low"],
                    self.bandpass_filter["high"],
                    CFG.sampling_rate,
                    order=self.bandpass_filter["order"],
                )

            # Train-only augmentation: with probability "probab", additionally
            # band-pass around a randomly chosen lower edge.
            if (
                self.mode == "train"
                and not self.rand_filter is None
                and random.uniform(0.0, 1.0) <= self.rand_filter["probab"]
            ):
                lowcut = random.randint(
                    self.rand_filter["low"], self.rand_filter["high"]
                )
                highcut = lowcut + self.rand_filter["band"]
                diff_feat = butter_bandpass_filter(
                    diff_feat,
                    lowcut,
                    highcut,
                    CFG.sampling_rate,
                    order=self.rand_filter["order"],
                )

            X[:, i] = diff_feat

        # Columns after the pair differences hold the per-band copies, then
        # the simple (single-electrode) features; `n` is the next free column.
        n = CFG.n_map_features
        if len(CFG.freq_channels) > 0:
            for i in range(CFG.n_map_features):
                diff_feat = X[:, i]
                for j, (lowcut, highcut) in enumerate(CFG.freq_channels):
                    band_feat = butter_bandpass_filter(
                        diff_feat, lowcut, highcut, CFG.sampling_rate, order=CFG.filter_order,  # 6
                    )
                    X[:, n] = band_feat
                    n += 1

        for spml_feat in CFG.simple_features:
            feat_val = data[:, CFG.feature_to_index[spml_feat]]

            if not self.bandpass_filter is None:
                feat_val = butter_bandpass_filter(
                    feat_val,
                    self.bandpass_filter["low"],
                    self.bandpass_filter["high"],
                    CFG.sampling_rate,
                    order=self.bandpass_filter["order"],
                )

            # Same train-only random band-pass as for the pair differences.
            if (
                self.mode == "train"
                and not self.rand_filter is None
                and random.uniform(0.0, 1.0) <= self.rand_filter["probab"]
            ):
                lowcut = random.randint(
                    self.rand_filter["low"], self.rand_filter["high"]
                )
                highcut = lowcut + self.rand_filter["band"]
                feat_val = butter_bandpass_filter(
                    feat_val,
                    lowcut,
                    highcut,
                    CFG.sampling_rate,
                    order=self.rand_filter["order"],
                )

            X[:, n] = feat_val
            n += 1

        # Limit the amplitude to [-1024, 1024] ...
        X = np.clip(X, -1024, 1024)

        # ... replace any remaining NaN with 0 and rescale by 1/32.
        X = np.nan_to_num(X, nan=0) / 32.0

        # Final low-pass over all channels (cut-off is the helper's default).
        X = butter_lowpass_filter(X, order=CFG.filter_order)  # 4

        y_prob = np.zeros(CFG.target_size, dtype="float32")  # Size=(CFG.target_size,)
        if self.mode != "test":
            y_prob = row[CFG.target_cols].values.astype(np.float32)

        return X, y_prob
    
class ResNet_1D_Block(nn.Module):
    """Residual block for 1-D feature maps: two BN -> Hardswish -> dropout -> conv stages.

    The convolutional branch ends in a stride-2 max-pool and is added to
    ``downsampling(x)``, so `downsampling` must bring the block input to the
    same shape as the pooled branch.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
        kernel_size, stride, padding, dilation, groups: Shared by both
            convolutions (`dilation` is also used by the max-pool).
        downsampling: Module applied to the block input on the skip path.
        dropout: Dropout probability used before each convolution.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        downsampling,
        dilation=1,
        groups=1,
        dropout=0.0,
    ):
        super().__init__()

        # Settings shared by the two convolutions of the block.
        conv_settings = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )

        self.bn1 = nn.BatchNorm1d(num_features=in_channels)
        self.relu_1 = nn.Hardswish()
        self.relu_2 = nn.Hardswish()
        self.dropout = nn.Dropout(p=dropout, inplace=False)
        self.conv1 = nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, **conv_settings
        )

        self.bn2 = nn.BatchNorm1d(num_features=out_channels)
        self.conv2 = nn.Conv1d(
            in_channels=out_channels, out_channels=out_channels, **conv_settings
        )

        self.maxpool = nn.MaxPool1d(
            kernel_size=2,
            stride=2,
            padding=0,
            dilation=dilation,
        )
        self.downsampling = downsampling

    def forward(self, x):
        """Run the residual branch and add the downsampled input to it."""
        # First stage: normalise, activate, drop, convolve.
        branch = self.conv1(self.dropout(self.relu_1(self.bn1(x))))
        # Second stage, same pattern.
        branch = self.conv2(self.dropout(self.relu_2(self.bn2(branch))))

        branch = self.maxpool(branch)
        skip = self.downsampling(x)

        branch += skip
        return branch


class EEGNet(nn.Module):
    """1-D ResNet feature extractor combined with a bidirectional GRU.

    The input is expected as (batch, samples, channels), as produced by
    `EEGDataset`. Convolutional features and the GRU output are concatenated
    and mapped to `num_classes` logits by a single linear layer.

    Args:
        kernels: Kernel sizes of the parallel first-layer convolutions.
        in_channels: Number of input channels.
        fixed_kernel_size: Kernel size used after the parallel convolutions.
        num_classes: Number of output logits.
        linear_layer_features: Input width of the final linear layer; it must
            match the flattened conv features plus the 256 GRU features.
        dilation, groups: Passed to every convolution.
    """

    def __init__(
        self,
        kernels,
        in_channels,
        fixed_kernel_size,
        num_classes,
        linear_layer_features,
        dilation=1,
        groups=1,
    ):
        super().__init__()
        self.kernels = kernels
        self.planes = 24
        self.parallel_conv = nn.ModuleList()
        self.in_channels = in_channels

        # One un-padded convolution per kernel size; their outputs are later
        # concatenated along the time axis.
        for kernel_size in list(self.kernels):
            self.parallel_conv.append(
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=self.planes,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=0,
                    dilation=dilation,
                    groups=groups,
                    bias=False,
                )
            )

        self.bn1 = nn.BatchNorm1d(num_features=self.planes)
        self.relu_1 = nn.SiLU()
        self.relu_2 = nn.SiLU()

        self.conv1 = nn.Conv1d(
            in_channels=self.planes,
            out_channels=self.planes,
            kernel_size=fixed_kernel_size,
            stride=2,
            padding=2,
            dilation=dilation,
            groups=groups,
            bias=False,
        )

        self.block = self._make_resnet_layer(
            kernel_size=fixed_kernel_size,
            stride=1,
            dilation=dilation,
            groups=groups,
            padding=fixed_kernel_size // 2,
        )
        self.bn2 = nn.BatchNorm1d(num_features=self.planes)
        self.avgpool = nn.AvgPool1d(kernel_size=6, stride=6, padding=2)

        self.rnn = nn.GRU(
            input_size=self.in_channels,
            hidden_size=128,
            num_layers=1,
            bidirectional=True,
        )

        self.fc = nn.Linear(in_features=linear_layer_features, out_features=num_classes)

    def _make_resnet_layer(
        self,
        kernel_size,
        stride,
        dilation=1,
        groups=1,
        blocks=9,
        padding=0,
        dropout=0.0,
    ):
        """Stack `blocks` residual blocks, each with its own stride-2 max-pool skip path."""
        stacked = [
            ResNet_1D_Block(
                in_channels=self.planes,
                out_channels=self.planes,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                downsampling=nn.Sequential(
                    nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
                ),
                dilation=dilation,
                groups=groups,
                dropout=dropout,
            )
            for _ in range(blocks)
        ]
        return nn.Sequential(*stacked)

    def extract_features(self, x):
        """Return the concatenated convolutional and recurrent features for `x`."""
        # Conv1d wants channels before time.
        x = x.permute(0, 2, 1)

        # Multi-scale first layer: run every kernel size, join along time.
        branches = [conv(x) for conv in self.parallel_conv]
        out = torch.cat(branches, dim=2)

        out = self.conv1(self.relu_1(self.bn1(out)))

        out = self.block(out)
        out = self.avgpool(self.relu_2(self.bn2(out)))
        out = out.reshape(out.shape[0], -1)

        # NOTE(review): the GRU is built without batch_first=True but is fed a
        # batch-leading tensor, so it treats dim 0 as the sequence axis. The
        # checkpoints were presumably trained this way, so it is left
        # unchanged - confirm before altering.
        rnn_out, _ = self.rnn(x.permute(0, 2, 1))
        new_rnn_h = rnn_out[:, -1, :]

        return torch.cat([out, new_rnn_h], dim=1)

    def forward(self, x):
        """Return class logits for a batch `x`."""
        return self.fc(self.extract_features(x))
    
def inference_function(test_loader, model, device):
    """Run `model` over `test_loader` and collect softmax probabilities.

    Args:
        test_loader: Iterable of batches; each batch is a dict with an
            ``"eeg"`` tensor.
        model: Network returning one row of class logits per sample.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Dict with a single key ``"predictions"`` holding the row-wise
        concatenation of all batch probabilities as a NumPy array.
    """
    model.eval()  # inference mode for dropout / batch-norm layers
    to_probabilities = nn.Softmax(dim=1)
    batch_outputs = []

    with tqdm(test_loader, unit="test_batch", desc="Inference") as progress:
        for batch in progress:
            inputs = batch.pop("eeg").to(device)
            with torch.no_grad():
                logits = model(inputs)
            probabilities = to_probabilities(logits)
            batch_outputs.append(probabilities.to("cpu").numpy())

    return {"predictions": np.concatenate(batch_outputs)}

test_df = pd.read_csv(CFG.test_csv)
# print(f"Test dataframe shape is: {test_df.shape}")

test_eeg_parquet_paths = glob(CFG.test_eeg + "*.parquet")
if not test_eeg_parquet_paths:
    # Clear message instead of an IndexError on the next line.
    raise FileNotFoundError(f"No parquet files found in {CFG.test_eeg}")

# Peek at one file only to learn which raw columns are available.
test_eeg_df = pd.read_parquet(test_eeg_parquet_paths[0])
test_eeg_features = test_eeg_df.columns
# print(f"There are {len(test_eeg_features)} raw eeg features")
# print(list(test_eeg_features))
del test_eeg_df
_ = gc.collect()

all_eegs = {}
eeg_ids = test_df.eeg_id.unique()
# Wrap the ids (not the enumerate object) so tqdm knows the total and can
# show a real progress bar.
for i, eeg_id in enumerate(tqdm(eeg_ids)):
    # Save EEG to Python dictionary of numpy arrays
    eeg_path = CFG.test_eeg + str(eeg_id) + ".parquet"
    data = eeg_from_parquet(eeg_path)
    all_eegs[eeg_id] = data

In [None]:
koef_sum = 0      # sum of the weights of all checkpoints evaluated
koef_count = 0    # number of checkpoints evaluated
predictions = []  # one weighted probability array per checkpoint
files = []        # checkpoint paths, in evaluation order

for model_block in model_weights:
    # Each block has its own input filtering, so it needs its own dataset.
    test_dataset = EEGDataset(
        df=test_df,
        batch_size=CFG.batch_size,
        mode="test",
        eegs=all_eegs,
        bandpass_filter=model_block['bandpass_filter']
    )

    # Print the input shape once as a sanity check.
    if len(predictions) == 0:
        output = test_dataset[0]
        X = output["eeg"]
        print(f"X shape: {X.shape}")

    test_loader = DataLoader(
        test_dataset,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=CFG.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    model = EEGNet(
        kernels=CFG.kernels,
        in_channels=CFG.in_channels,
        fixed_kernel_size=CFG.fixed_kernel_size,
        num_classes=CFG.target_size,
        linear_layer_features=CFG.linear_layer_features,
    )

    for file_line in model_block['file_data']:
        koef = file_line['koef']
        for weight_model_file in glob(file_line['file_mask']):
            files.append(weight_model_file)
            checkpoint = torch.load(weight_model_file, map_location=device)
            model.load_state_dict(checkpoint["model"])
            model.to(device)
            prediction_dict = inference_function(test_loader, model, device)
            predict = prediction_dict["predictions"]
            predict *= koef
            koef_sum += koef
            koef_count += 1
            predictions.append(predict)
            torch.cuda.empty_cache()
            gc.collect()

if koef_count == 0:
    # Without this check a wrong path surfaces as an opaque ZeroDivisionError.
    masks = [line['file_mask'] for block in model_weights for line in block['file_data']]
    raise FileNotFoundError(f"No checkpoint files matched any of: {masks}")

# Weighted average: mean of the weighted predictions divided by the mean weight.
predictions = np.array(predictions)
koef_sum /= koef_count
predictions /= koef_sum
predictions = np.mean(predictions, axis=0)

In [None]:
# Assemble the result frame: one row per validation row, ids first, then the
# averaged class probabilities, and write it without the index column.
sub = pd.DataFrame({"eeg_id": test_df["eeg_id"].values})
sub[CFG.target_cols] = predictions
sub.to_csv("res1d_validation_result_v1_n9.csv", index=False)