In [None]:
import os
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from scipy.signal import butter, filtfilt, iirnotch, resample

# ────────────────────────────────────────────────────────────────
# 1) PREPROCESSING HELPERS (no linked‐mastoid reference)
# ────────────────────────────────────────────────────────────────

def notch_filter(data, sfreq=1000.0, freqs=(50.0,), bandwidth=1.0):
    """
    Apply a zero-phase IIR notch at every frequency listed in `freqs`.

    Each notch is `bandwidth` Hz wide (quality factor Q = freq / bandwidth)
    and is run forward and backward with `filtfilt` along axis 1.

    data:      np.ndarray, 2-D, filtered along axis 1 (time)
    sfreq:     sampling rate in Hz
    freqs:     iterable of notch centre frequencies in Hz
    bandwidth: notch width in Hz
    Returns a filtered copy; the input array is not modified.
    """
    filtered = data.copy()
    for centre in freqs:
        coeffs = iirnotch(w0=centre, Q=centre / bandwidth, fs=sfreq)
        # coeffs is the (b, a) pair expected by filtfilt
        filtered = filtfilt(*coeffs, filtered, axis=1)
    return filtered

def bandpass_filter(data, sfreq=1000.0, low=0.5, high=80.0, order=5):
    """
    Zero-phase Butterworth band-pass keeping `low`..`high` Hz.

    The cut-offs are normalised by the Nyquist frequency (sfreq / 2) before
    the filter is designed; the filter is then applied forward and backward
    with `filtfilt` along axis 1.

    data:  np.ndarray, 2-D, filtered along axis 1 (time)
    sfreq: sampling rate in Hz
    order: Butterworth design order passed to `butter`
    """
    nyquist = 0.5 * sfreq
    band_edges = [low / nyquist, high / nyquist]
    numer, denom = butter(order, band_edges, btype='band')
    return filtfilt(numer, denom, data, axis=1)

def crop_window(data, sfreq=1000.0, tmin=0.04, tmax=0.44):
    """
    Keep only the samples between `tmin` and `tmax` seconds.

    Both times are converted to sample indices by rounding `t * sfreq`;
    the start index is inclusive and the end index exclusive.

    data: np.ndarray, 2-D, cropped along axis 1 (time)
    Returns a basic-slice view of `data` (no copy is made).
    """
    first, last = (int(np.round(t * sfreq)) for t in (tmin, tmax))
    return data[:, first:last]

def downsample(data, sfreq_old=1000.0, sfreq_new=128.0):
    """
    Fourier-based resampling of axis 1 from `sfreq_old` to `sfreq_new`.

    The output length is the input length scaled by sfreq_new / sfreq_old,
    rounded to the nearest integer.

    data: np.ndarray, 2-D, resampled along axis 1 (time)
    """
    ratio = sfreq_new / sfreq_old
    target_len = int(np.round(data.shape[1] * ratio))
    return resample(data, target_len, axis=1)

def standardize_dataset(all_trials):
    """
    Z-score a stack of trials per channel.

    Mean and standard deviation are pooled over axes 0 and 2 (trials and
    time) of the array that is passed in, so pass the training set if the
    statistics are to be reused for other splits.

    all_trials: np.ndarray, shape (n_trials, n_channels, n_times)
    Returns (standardized, mu, sigma); mu and sigma have shape
    (1, n_channels, 1). The returned sigma is the raw standard deviation,
    i.e. without the small epsilon that is added to the divisor.
    """
    pooled_axes = (0, 2)
    mu = np.mean(all_trials, axis=pooled_axes, keepdims=True)
    sigma = np.std(all_trials, axis=pooled_axes, keepdims=True)
    # epsilon keeps the division finite for constant (zero-variance) channels
    z = (all_trials - mu) / (sigma + 1e-7)
    return z, mu, sigma

def car_reference(data):
    """
    Common-average re-reference: at every time sample, subtract the mean
    taken across axis 0 (channels).

    data: np.ndarray, shape (n_channels, n_times)
    """
    across_channels = np.mean(data, axis=0, keepdims=True)
    return data - across_channels


def preprocess_single_trial(raw_data,
                            sfreq_raw=1000.0,
                            notch_freq=(50.0,100.0,150.0), notch_bw=1.0,
                            bp_low=0.5, bp_high=80.0, bp_order=5,
                            tmin=0.04, tmax=0.44,
                            sfreq_new=250.0,
                            vis_channels=None):
    """
    Entire preprocessing chain for one trial (no linked-mastoid step is
    applied here):
      1. Notch at every frequency in `notch_freq` (default 50/100/150 Hz)
      2. Band-pass `bp_low`-`bp_high` Hz (default 0.5-80 Hz)
      3. Common-average re-reference across channels
      4. Crop to [tmin, tmax] seconds (default 40-440 ms)
      5. Downsample from `sfreq_raw` to `sfreq_new` (default 250 Hz)

    raw_data:     np.ndarray, shape (n_channels, n_samples_raw)
    vis_channels: row indices of the channel subset to return alongside the
                  full data. When None (the default) the module-level
                  `vis_idx` is used, which must be defined before the call.

    Returns (data_ds, data_ds_vis): the fully processed trial and the same
    trial restricted to the `vis_channels` rows.
    """
    if vis_channels is None:
        # Resolved at call time so the default tracks the module-level table.
        vis_channels = vis_idx

    # 1) Notch filter (all requested line-noise frequencies)
    data_notch = notch_filter(raw_data, sfreq=sfreq_raw, freqs=notch_freq, bandwidth=notch_bw)

    # 2) Band-pass
    data_bp = bandpass_filter(data_notch, sfreq=sfreq_raw, low=bp_low, high=bp_high, order=bp_order)

    # 3) Common-average reference
    data_cr = car_reference(data_bp)

    # 4) Crop after filtering, so the filters run on the full-length trial
    data_crop = crop_window(data_cr, sfreq=sfreq_raw, tmin=tmin, tmax=tmax)

    # 5) Downsample to sfreq_new
    data_ds = downsample(data_crop, sfreq_old=sfreq_raw, sfreq_new=sfreq_new)

    data_ds_vis = data_ds[vis_channels, :]
    return data_ds, data_ds_vis

# ────────────────────────────────────────────────────────────────
# 2) CHANNEL NAMES (the 62 scalp electrodes are stored)
# ────────────────────────────────────────────────────────────────

# Names of the 62 stored electrodes. The position of a name in this list is
# used as the row index into each trial's EEG array.
# NOTE(review): assumes this order matches the dataset's channel order - confirm.
channel_names_62 = (
    ["FP1", "FPZ", "FP2", "AF3", "AF4"]
    + ["F7", "F5", "F3", "F1", "FZ", "F2", "F4", "F6", "F8"]
    + ["FT7", "FC5", "FC3", "FC1", "FCZ", "FC2", "FC4", "FC6", "FT8"]
    + ["T7", "C5", "C3", "C1", "CZ", "C2", "C4", "C6", "T8"]
    # M1, M2 are NOT present - assumed already referenced
    + ["TP7", "CP5", "CP3", "CP1", "CPZ", "CP2", "CP4", "CP6", "TP8"]
    + ["P7", "P5", "P3", "P1", "PZ", "P2", "P4", "P6", "P8"]
    + ["PO7", "PO5", "PO3", "POZ", "PO4", "PO6", "PO8"]
    + ["CB1", "O1", "OZ", "O2", "CB2"]
)

# The 23 "visual" channels we want to keep (T/TP, O, P and PO sites),
# followed by their row indices in channel_names_62.
vis_chs = [
    "T7", "T8", "TP7", "TP8",
    "O1", "OZ", "O2",
    "P7", "P5", "P3", "P1", "PZ", "P2", "P4", "P6", "P8",
    "PO7", "PO5", "PO3", "POZ", "PO4", "PO6", "PO8",
]
vis_idx = list(map(channel_names_62.index, vis_chs))

def segment_windows(epoch, win=64, hop=32):
    """
    Cut one epoch into fixed-length, possibly overlapping windows.

    Windows start at 0, hop, 2*hop, ... and only windows that fit entirely
    inside the epoch are kept (a trailing partial window is dropped).

    epoch: np.ndarray, shape (n_chans, n_times)
    win:   window length in samples (must be positive)
    hop:   step between window starts in samples (must be positive)

    Returns np.ndarray of shape (n_windows, n_chans, win). When the epoch is
    shorter than `win` the result is an empty (0, n_chans, win) array rather
    than an error, so callers can extend their lists with zero windows.

    Raises ValueError if `win` or `hop` is not positive.
    """
    if win <= 0 or hop <= 0:
        raise ValueError("win and hop must be positive")

    n_chans, n_times = epoch.shape
    starts = range(0, n_times - win + 1, hop)
    if not starts:
        # np.stack([]) would raise; return a correctly shaped empty result.
        return np.empty((0, n_chans, win), dtype=epoch.dtype)
    return np.stack([epoch[:, s:s + win] for s in starts], axis=0)

# ────────────────────────────────────────────────────────────────
# 3) LOAD & EXTRACT RAW TRIALS FROM .PTH
# ────────────────────────────────────────────────────────────────

# Location of the serialized dataset (path looks like a Colab Google Drive mount).
pth_file = r"/content/drive/MyDrive/ImageNet_Images/EEG-ImageNet-full.pth"
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code - only load this file from a trusted source.
d = torch.load(pth_file, weights_only=False)

trials = d["dataset"]   # list of 63 850 dicts
# Per-trial label and image identifier, collected in trial order.
wnids  = np.array([trial["label"] for trial in trials])  # (63850,), dtype=object
images = np.array([trial["image"] for trial in trials])

# Sanity check on what was loaded.
print("Total trials:", len(trials))        # → 63850
print("Unique labels:", np.unique(wnids).shape[0])  # → 80

# 4) PREPROCESS EVERY TRIAL (skip mastoid step)
sfreq_raw = 1000.0
sfreq_new = 250.0

# Parallel accumulators with one entry per window: the windowed data (all
# channels / visual subset) and the metadata of the trial each window came from.
preprocessed_list, preprocessedVis_list = [], []
wnids_list, imgs_list, gran_list, sub_list = [], [], [], []

for trial in trials:
    eeg_tensor = trial["eeg_data"]  # shape (62, 501)
    # Accept either a torch tensor or anything numpy can wrap.
    raw_np = (eeg_tensor.cpu().numpy()
              if isinstance(eeg_tensor, torch.Tensor)
              else np.asarray(eeg_tensor))

    # The dataset has already discarded M1/M2, so no linked-mastoid
    # re-referencing is done before the preprocessing chain.
    data_proc, data_proc_vis = preprocess_single_trial(
        raw_data=raw_np,
        sfreq_raw=sfreq_raw,
        notch_freq=(50.0, 100.0, 150.0), notch_bw=1.0,
        bp_low=0.5, bp_high=80.0, bp_order=5,
        tmin=0.04, tmax=0.44,
        sfreq_new=sfreq_new,
    )

    # Cut both versions into 64-sample windows with a 32-sample hop.
    windows = segment_windows(data_proc, win=64, hop=32)
    windows_vis = segment_windows(data_proc_vis, win=64, hop=32)
    n_w = len(windows)
    assert n_w == len(windows_vis)

    preprocessed_list.extend(windows)
    preprocessedVis_list.extend(windows_vis)

    # Repeat the trial's metadata once per window so every list stays aligned.
    wnids_list.extend([trial["label"]] * n_w)
    imgs_list.extend([trial["image"]] * n_w)
    gran_list.extend([trial["granularity"]] * n_w)
    sub_list.extend([trial["subject"]] * n_w)


# Stack the windows: one row per window, N_total windows overall, each
# window (n_channels, 64) with n_channels = all channels or len(vis_idx).
all_trials = np.stack(preprocessed_list, axis=0)
all_trials_vis = np.stack(preprocessedVis_list, axis=0)
wnids = np.array(wnids_list)           # (N_total,)
imgs = np.array(imgs_list)             # (N_total,)
gran = np.array(gran_list)             # (N_total,)
sub = np.array(sub_list, dtype=int)    # (N_total,)


# ────────────────────────────────────────────────────────────────
# 5) FOR EACH GRANULARITY: split & normalize & save
# ────────────────────────────────────────────────────────────────

# Fine-granularity class groups: each key names one group and maps to a set
# of 8 label IDs (presumably ImageNet WordNet IDs - confirm against the
# dataset). The split loop below selects windows whose granularity is "fine"
# and whose label belongs to the chosen group.
FINE_GROUPS = {
  "fine0": {"n07753275","n12144580","n07772935","n07756951","n07740461","n07749192","n07758680","n07745940"},  # set of the 8 wnids in group0
  "fine1": {"n03384352","n02901620","n04389033","n04465666","n03690473","n03790512","n02701002","n03845190"},
  "fine2": {"n02107142","n02110185","n02111889","n02099601","n02112826","n02106166","n02099712","n02106550"},
  "fine3": {"n04249415","n03884397","n02672831","n03372029","n02992211","n04487394","n03838899","n03495258"},
  "fine4": {"n01494475","n01456756","n02643566","n02630281","n01484850","n02655020","n01496331","n01443537"},

}

# The 40 label IDs accepted when the split loop runs with g == "coarse"
# (combined there with granularity == "coarse").
coarse_wnids = [
    "n02510455","n02106662","n03584829","n02124075","n13054560","n03445777",
    "n04120489","n02504458","n02607072","n03775071","n04044716","n04086273",
    "n02690373","n02992529","n11939491","n03063599","n03272562","n03180011",
    "n03888257","n07753592","n03297495","n03100240","n02281787","n02906734",
    "n02492035","n03773504","n07873807","n03877472","n03590841","n03709823",
    "n02389026","n02951358","n03452741","n04069434","n03982430","n03792782",
    "n03792972","n03376595","n03197337","n03272010"
]

for ch in ("all_channels",):
    for g in ("fine0", "fine1", "fine2", "fine3", "fine4"):
        # One output directory per (channel set, granularity group) pair.
        OUT_ROOT = f"/content/drive/MyDrive/ImageNet_Images/preprocessed_splits/granularity/Time/{ch}/{g}"
        os.makedirs(OUT_ROOT, exist_ok=True)

        # a) select the windows that belong to this group
        if g == "all":
            idx = np.arange(len(gran))
        elif g == "coarse":
            idx = (gran == "coarse") & np.isin(wnids, coarse_wnids)
        else:
            idx = (gran == "fine") & np.isin(wnids, list(FINE_GROUPS[g]))

        specs_g = (all_trials if ch == "all_channels" else all_trials_vis)[idx]
        wnids_g, imgs_g, sub_g = wnids[idx], imgs[idx], sub[idx]

        # b) split on unique images, stratified by label, so every window of
        #    a given image lands in exactly one of train / val / test
        unique_imgs, first = np.unique(imgs_g, return_index=True)
        uniq_wnids = wnids_g[first]
        # 95 % of the images go to train, the remaining 5 % are held out ...
        train_imgs, temp_imgs, _, temp_wnids = train_test_split(
            unique_imgs, uniq_wnids,
            test_size=0.05, stratify=uniq_wnids, random_state=42)
        # ... and the held-out part is divided 40 % val / 60 % test.
        val_imgs, test_imgs, _, _ = train_test_split(
            temp_imgs, temp_wnids,
            train_size=0.4, stratify=temp_wnids, random_state=42)

        # boolean masks over the windows of this group
        is_train = np.isin(imgs_g, train_imgs)
        is_val = np.isin(imgs_g, val_imgs)
        is_test = np.isin(imgs_g, test_imgs)

        Xtr, ytr, Itr, strain = specs_g[is_train], wnids_g[is_train], imgs_g[is_train], sub_g[is_train]
        Xva, yva, Iva, sva = specs_g[is_val], wnids_g[is_val], imgs_g[is_val], sub_g[is_val]
        Xte, yte, Ite, ste = specs_g[is_test], wnids_g[is_test], imgs_g[is_test], sub_g[is_test]
        # plain-list copies of the image identifiers, saved next to the arrays
        trainFnames, vFnames, testFnames = Itr.tolist(), Iva.tolist(), Ite.tolist()

        # c) per-channel z-score using statistics of the training split only
        mu = Xtr.mean(axis=(0, 2), keepdims=True)               # → (1, C, 1)
        sigma = Xtr.std(axis=(0, 2), keepdims=True) + 1e-7
        Xtr_z, Xva_z, Xte_z = ((X - mu) / sigma for X in (Xtr, Xva, Xte))

        def save_split(Xz, labels, images, subjects, split_name, imgFname):
            # Store float32 data with an extra axis inserted at position 1,
            # together with the metadata, as a single tuple in one .pt file.
            out_path = os.path.join(OUT_ROOT, f"{split_name}_timeNewHop32.pt")
            payload = torch.from_numpy(Xz.astype(np.float32)).unsqueeze(1)
            torch.save((payload, labels, images, subjects, imgFname), out_path)

        save_split(Xtr_z, ytr, Itr, strain, "train", trainFnames)
        save_split(Xva_z, yva, Iva, sva, "val", vFnames)
        save_split(Xte_z, yte, Ite, ste, "test", testFnames)

        print(f"→ saved {g}: "
              f"train={Xtr_z.shape[0]}, val={Xva_z.shape[0]}, test={Xte_z.shape[0]}")


Total trials: 63850
Unique labels: 80
→ saved fine0: train=11876, val=250, test=374
→ saved fine1: train=11874, val=250, test=376
→ saved fine2: train=11688, val=246, test=366
→ saved fine3: train=11782, val=248, test=370
→ saved fine4: train=11590, val=244, test=366
