In [None]:
import numpy as np
from scipy.io import loadmat, savemat
from scipy.signal import welch
from scipy.stats import pearsonr
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# =============================
# Dataset Definition
# =============================
class EcogDataset(Dataset):
    """Map-style dataset pairing input windows with their regression targets.

    Both arrays are converted to float32 tensors once, at construction time
    (``torch.tensor`` copies its input). Indexing returns an ``(x, y)`` pair
    taken from the same position of each tensor.
    """

    def __init__(self, X, Y):
        self.X, self.Y = (torch.tensor(arr, dtype=torch.float32) for arr in (X, Y))

    def __len__(self):
        # One sample per entry along the first axis of the inputs.
        return len(self.X)

    def __getitem__(self, idx):
        sample, target = self.X[idx], self.Y[idx]
        return sample, target

# =============================
# Attention Layer
# =============================
class Attention(nn.Module):
    """Learned soft-attention pooling over axis 1 of the input.

    A linear layer maps each position's ``input_dim`` features to one score.
    A softmax over axis 1 turns the scores into weights that sum to 1. The
    output is the weighted sum of the inputs over axis 1. In this file the
    input is the batch-first LSTM output, so axis 1 is the time axis.

    Args:
        input_dim: Size of the last axis of the input.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.attn = nn.Linear(input_dim, 1)

    def forward(self, x):
        scores = self.attn(x)             # same shape as x, last axis -> 1
        alpha = scores.softmax(dim=1)     # normalise across axis 1
        return torch.sum(alpha * x, dim=1)

# =============================
# Model Definition
# =============================
class CNNBiLSTMAttn(nn.Module):
    """Conv1d front end -> 2-layer bidirectional LSTM -> attention pooling -> MLP.

    Args:
        in_ch: Number of input features per time step.
        seq_len: Not used by the model; kept so existing call sites still work.
        out_dim: Number of regression outputs.

    ``forward`` takes ``[batch, time, in_ch]`` and returns ``[batch, out_dim]``.
    """

    def __init__(self, in_ch, seq_len, out_dim):
        super().__init__()
        # Two conv stages (kernel 3, padding 1 keeps the length), each
        # Conv -> ReLU -> BatchNorm, then dropout. Creation order fixes both the
        # state_dict keys (cnn.0, cnn.2, cnn.3, cnn.5) and the RNG draws used
        # for parameter initialisation, so keep it.
        layers = []
        for c_in, c_out in ((in_ch, 64), (64, 128)):
            layers.extend([
                nn.Conv1d(c_in, c_out, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.BatchNorm1d(c_out),
            ])
        layers.append(nn.Dropout(0.3))
        self.cnn = nn.Sequential(*layers)

        lstm_hidden = 64
        pooled_dim = 2 * lstm_hidden  # forward and backward states concatenated
        self.bilstm = nn.LSTM(128, lstm_hidden, num_layers=2,
                              batch_first=True, bidirectional=True)
        self.attn = Attention(pooled_dim)
        self.fc = nn.Sequential(
            nn.Linear(pooled_dim, 64),
            nn.ReLU(),
            nn.Linear(64, out_dim),
        )

    def forward(self, x):
        # Conv1d wants channels-first; the batch_first LSTM wants time-first.
        feats = self.cnn(x.transpose(1, 2)).transpose(1, 2)  # [B, T, 128]
        seq_out, _ = self.bilstm(feats)                       # [B, T, 128]
        return self.fc(self.attn(seq_out))

# =============================
# Feature Extraction
# =============================
def get_bandpower_feats(ecog, fs, win_len, step_len, bands):
    """Compute mean spectral power per channel and frequency band for sliding windows.

    Windows of ``win_len`` ms start every ``step_len`` ms, beginning at sample
    0; only windows that fit entirely inside ``ecog`` are used. For each window
    a single-segment Welch PSD is computed for every channel
    (``nperseg`` = window length). The PSD is then averaged over the frequency
    bins with ``low <= f <= high`` for each band.

    Args:
        ecog: 2-D array, (samples, channels).
        fs: Sampling rate in Hz.
        win_len: Window length in milliseconds.
        step_len: Hop between window starts in milliseconds.
        bands: Sequence of ``(low, high)`` frequency limits in Hz.

    Returns:
        Array of shape ``(n_windows, channels * len(bands))``, ordered
        channel-major: all bands of channel 0, then all bands of channel 1,
        and so on. A band that contains no frequency bin yields NaN. When the
        recording is shorter than one window the result has 0 rows but keeps
        the column count.

    Raises:
        ValueError: If the window or the hop is shorter than one sample.
    """
    # Multiply before dividing: 570/1000*100 == 56.99999999999999 in floating
    # point, which truncated to 56 samples instead of 57.
    win_s = int(win_len * fs / 1000)
    step_s = int(step_len * fs / 1000)
    if win_s <= 0 or step_s <= 0:
        raise ValueError(
            f"window ({win_len} ms) and step ({step_len} ms) must each span at "
            f"least one sample at fs={fs}"
        )

    n_ch = ecog.shape[1]
    rows = []
    masks = None
    for start in range(0, ecog.shape[0] - win_s + 1, step_s):
        # One Welch call for all channels; Pxx is (n_freqs, n_channels).
        f, pxx = welch(ecog[start:start + win_s], fs=fs, nperseg=win_s, axis=0)
        if masks is None:
            # The frequency grid depends only on fs and win_s, so build the
            # band masks once.
            masks = [(f >= low) & (f <= high) for low, high in bands]
        # (n_channels, n_bands) -> channel-major flat row.
        band_pow = np.stack([pxx[m].mean(axis=0) for m in masks], axis=1)
        rows.append(band_pow.ravel())

    if not rows:
        return np.empty((0, n_ch * len(bands)))
    return np.array(rows)

def make_windows(ecog, glove, win_len=1000, step_len=50, fs=1000):
    """Cut ``ecog`` into overlapping windows and label each with the mean of ``glove``.

    Args:
        ecog: Array with time on axis 0, e.g. (samples, channels).
        glove: Target array with time on axis 0. It must have at least as many
            samples as ``ecog``.
        win_len: Window length in milliseconds.
        step_len: Hop between window starts in milliseconds.
        fs: Sampling rate in Hz.

    Returns:
        ``(X, Y)`` where ``X[k] = ecog[k*step : k*step + win]`` and ``Y[k]`` is
        the mean of ``glove`` over the same samples. The label therefore
        summarises the whole window, not just its last sample. Only windows
        that fit entirely inside ``ecog`` are produced. For a recording shorter
        than one window both arrays have 0 rows but keep their trailing shape.

    Raises:
        ValueError: If the window or the hop is shorter than one sample, or if
            ``glove`` is shorter than ``ecog``.
    """
    # Multiply before dividing so e.g. 570 ms at 100 Hz is 57 samples, not the
    # 56 that truncating 570/1000*100 == 56.99999999999999 would give.
    win = int(win_len * fs / 1000)
    step = int(step_len * fs / 1000)
    if win <= 0 or step <= 0:
        raise ValueError(
            f"window ({win_len} ms) and step ({step_len} ms) must each span at "
            f"least one sample at fs={fs}"
        )
    n_samples = ecog.shape[0]
    if len(glove) < n_samples:
        # Otherwise trailing windows would average a truncated or empty slice.
        raise ValueError(
            f"glove has {len(glove)} samples but ecog has {n_samples}"
        )

    starts = range(0, n_samples - win + 1, step)
    if not starts:
        return (np.empty((0, win) + ecog.shape[1:], dtype=ecog.dtype),
                np.empty((0,) + glove.shape[1:]))
    X = np.array([ecog[i:i + win] for i in starts])
    Y = np.array([np.mean(glove[i:i + win], axis=0) for i in starts])
    return X, Y

# =============================
# Training Function
# =============================
def train_subject_model(ecog, glove, bands, device, *, win_len=700, step_len=50,
                        fs=1000, train_frac=0.7, epochs=30, batch_size=16,
                        lr=1e-3, score_idx=(0, 1, 2, 4)):
    """Train and evaluate a CNN-BiLSTM-attention regressor for one subject.

    ``ecog`` and ``glove`` are windowed identically. Each time step of a window
    gets the features ``[raw | first difference | standardised band power]``.
    The windows are split chronologically into train and test sets, the model
    is trained with Adam and MSE loss, and the Pearson correlation of every
    output column is reported on the test windows.

    Args:
        ecog: 2-D array (samples, channels).
        glove: 2-D target array with at least as many samples as ``ecog``.
        bands: Sequence of ``(low_hz, high_hz)`` band-power ranges.
        device: torch device used for training and inference.
        win_len: Window length in milliseconds.
        step_len: Hop between windows in milliseconds.
        fs: Sampling rate in Hz.
        train_frac: Fraction of windows, taken from the start of the
            recording, used for training.
        epochs: Number of passes over the training windows.
        batch_size: Mini-batch size for training and evaluation.
        lr: Adam learning rate.
        score_idx: Output columns averaged into ``mean_r``. The default skips
            column 3.

    Returns:
        ``(model, sc_freq, r_vals, mean_r)``:
        - ``model``: the trained network.
        - ``sc_freq``: the ``StandardScaler`` fitted on the training windows'
          band-power features.
        - ``r_vals``: per-column Pearson r on the test windows.
        - ``mean_r``: the mean of ``r_vals`` over ``score_idx``.

    Raises:
        ValueError: If there are too few windows for a non-empty training set
            plus at least two test windows after the purge gap, or if the
            time-domain and band-power windows disagree in count.
    """
    # 1) Windows, labels and band-power features, all on the same windowing.
    X_time, Y = make_windows(ecog, glove, win_len=win_len, step_len=step_len, fs=fs)
    X_freq = get_bandpower_feats(ecog, fs, win_len, step_len, bands)
    n_win = len(X_time)

    # 2) Chronological split with a purge gap. Train window k covers samples
    #    [k*step, k*step + win), so any test window starting before the end of
    #    the last train window shares samples (and label samples) with it.
    #    Skipping those windows keeps evaluation leak-free.
    win_s = int(win_len * fs / 1000)
    step_s = int(step_len * fs / 1000)
    split = int(train_frac * n_win)
    test_start = split + max(-(-win_s // step_s) - 1, 0)
    if split < 1 or n_win - test_start < 2:
        raise ValueError(
            f"{n_win} windows are too few for a training set plus >= 2 test "
            f"windows after the {test_start - split}-window purge gap"
        )
    if len(X_freq) != n_win:
        raise ValueError(
            f"time-domain ({n_win}) and band-power ({len(X_freq)}) window "
            "counts differ"
        )
    T, C = X_time.shape[1:]

    # 3) Standardise band power using training windows only; fitting on all
    #    windows would leak test-set statistics into training.
    sc_freq = StandardScaler().fit(X_freq[:split])
    X_freq_s = sc_freq.transform(X_freq)

    # 4) Assemble [raw | delta | band power] directly into one float32 buffer
    #    (the dtype the model consumes). This avoids full-size float64
    #    intermediates from diff/repeat/concatenate.
    F = X_freq_s.shape[1]
    X_final = np.empty((n_win, T, 2 * C + F), dtype=np.float32)
    X_final[:, :, :C] = X_time
    del X_time  # release the windowed copy before building tensors
    X_final[:, 0, C:2 * C] = 0.0  # first step has no predecessor
    np.subtract(X_final[:, 1:, :C], X_final[:, :-1, :C], out=X_final[:, 1:, C:2 * C])
    X_final[:, :, 2 * C:] = X_freq_s[:, None, :]  # broadcast over time steps
    # NOTE(review): the raw/delta channels are not standardised, because
    # flattening the windows for a scaler cost too much RAM. The first Conv1d
    # therefore sees them at raw scale; BatchNorm only follows it. A
    # per-channel scaler fitted on training samples would be cheap if input
    # scale matters.

    X_tr, Y_tr = X_final[:split], Y[:split]
    X_te, Y_te = X_final[test_start:], Y[test_start:]

    # 5) DataLoaders: no worker processes, small batches to limit memory.
    train_dl = DataLoader(EcogDataset(X_tr, Y_tr),
                          batch_size=batch_size, shuffle=True,
                          num_workers=0, pin_memory=False)
    test_dl = DataLoader(EcogDataset(X_te, Y_te),
                         batch_size=batch_size, shuffle=False,
                         num_workers=0, pin_memory=False)

    # 6) Model, optimiser, loss.
    model = CNNBiLSTMAttn(X_tr.shape[2], X_tr.shape[1], Y_tr.shape[1]).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    crit = nn.MSELoss()

    # 7) Training loop.
    for _ in range(epochs):
        model.train()
        for xb, yb in train_dl:
            xb, yb = xb.to(device), yb.to(device)
            loss = crit(model(xb), yb)
            opt.zero_grad()
            loss.backward()
            opt.step()

    # 8) Evaluation on the held-out (purged) windows.
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for xb, yb in test_dl:
            preds.append(model(xb.to(device)).cpu().numpy())
            trues.append(yb.numpy())

    Y_pred = np.vstack(preds)
    Y_true = np.vstack(trues)
    r_vals = [pearsonr(Y_true[:, i], Y_pred[:, i])[0] for i in range(Y_true.shape[1])]
    mean_r = np.mean([r_vals[i] for i in score_idx])
    return model, sc_freq, r_vals, mean_r

# =============================
# Main Execution
# =============================
if __name__ == "__main__":
    import warnings

    # Silence every warning for cleaner notebook output. This also hides
    # e.g. constant-input warnings from pearsonr, so remove it when debugging.
    warnings.filterwarnings("ignore")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device:", device)

    # Per-subject recordings are indexed as [subject, 0]; presumably these are
    # MATLAB cell arrays loaded as object arrays (verify against the .mat file).
    data = loadmat('/kaggle/input/bcijamesdataset/raw_training_data.mat')
    ecog_all, glove_all = data['train_ecog'], data['train_dg']
    bands = [(5, 15), (20, 30), (70, 115)]

    all_mean = []
    for s in range(3):
        print(f"\n--- Subject {s+1} ---")
        *_, r_vals, mr = train_subject_model(
            ecog_all[s, 0], glove_all[s, 0], bands, device
        )
        print("r:", np.round(r_vals, 3), "mean r (no ring):", mr)
        all_mean.append(mr)

    overall_mean = np.mean(all_mean)
    print("\nFinal mean r across subjects:", overall_mean)

    # Persist only the cross-subject mean correlation.
    savemat('model_summary.mat', {'mean_r': overall_mean})
