In [32]:
import os
import numpy as np
import pandas as pd
import scipy.io as sio
import matplotlib.pyplot as plt
import math
import json
import re
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

import random
np.random.seed(42)
random.seed(42)

# Global Variables

In [33]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# BIAS is forwarded to the model constructors and selects the "b"/"nb" checkpoint directory.
BIAS = True
# Selects the dataset recipe in the "Data Generation" cell: 1-4 are explicit cases,
# any other value falls through to the EEG + EOG + EMG custom mix.
DATA_MODE = 5
# Fixed-point formats. float_to_q() treats int_bits + frac_bits as the full signed
# word width, so int_bits includes the sign bit.
# NOTE(review): the "Q9.14" entry uses int_bits=10 (24-bit word) although its name
# suggests 9 — confirm which one the hardware expects.
Q_CONFIGS = {
    "Q4.12": dict(frac_bits=12, int_bits=4, dtype=np.int16),
    "Q10.10": dict(frac_bits=10, int_bits=10, dtype=np.int32),
    "Q9.14": dict(frac_bits=14, int_bits=10, dtype=np.int32),
}

TYPE =  "Q9.14" # "Q4.12", "Q10.10","Q9.14", or "FLOAT"

# Per-stage list of integers; not referenced in the visible part of the notebook.
# NOTE(review): presumably sequence lengths per layer — confirm against the consumer.
TIME_REPEAT = {
    "encoder":    [256, 128, 64, 32],
    "bottleneck": [32]*8,
    "decoder":    [64, 128, 256, 512],
    "out":        [512]
}

# Derive the numeric settings from TYPE.
# NOTE(review): in FLOAT mode FRAC_BITS / INT_BITS are never defined, so
# write_tensor_mem() (which reads both) would raise NameError.
if TYPE == "FLOAT":
    MODE = "FLOAT"
    DTYPE = np.float32
else:
    MODE = "FIXED"
    cfg = Q_CONFIGS[TYPE]
    FRAC_BITS = cfg["frac_bits"]
    INT_BITS  = cfg["int_bits"]
    DTYPE     = cfg["dtype"]

# Data

## Import Data

In [34]:
# Dataset directory, resolved relative to the current working directory.
# Note: `data_path` is rebound to other directories by later cells.
data_path = os.path.abspath("../../data/eeg_denoise_net/")
print(data_path)

# Load the three epoch arrays (clean EEG, ocular artifacts, myogenic artifacts).
EEG = np.load(os.path.join(data_path, "EEG_all_epochs.npy"))
EOG = np.load(os.path.join(data_path, "EOG_all_epochs.npy"))
EMG = np.load(os.path.join(data_path, "EMG_all_epochs.npy"))

# Sanity check: the recorded run shows 2-D arrays with 512 samples per epoch.
print("EEG shape:", EEG.shape)
print("EOG shape:", EOG.shape)
print("EMG shape:", EMG.shape)

c:\Users\Aryo\PersonalMade\Programming\GAN\repo\data\eeg_denoise_net
EEG shape: (4514, 512)
EOG shape: (3400, 512)
EMG shape: (5598, 512)


In [35]:
def rms(x: np.ndarray) -> float:
    """
    Root-mean-square amplitude of a signal (Formula (3)).

    The input is converted with ``np.asarray`` and the mean is taken over
    every element.
    """
    samples = np.asarray(x)
    mean_square = np.mean(samples ** 2)
    return np.sqrt(mean_square)

def snr_db(clean: np.ndarray, noise: np.ndarray) -> float:
    """
    Signal-to-noise ratio in decibels (Formula (2)).

    Computed as 10 * log10 of the ratio between the RMS of ``clean`` and the
    RMS of ``noise``.
    """
    amplitude_ratio = rms(clean) / rms(noise)
    return 10 * np.log10(amplitude_ratio)

def compute_lambda(clean: np.ndarray, noise: np.ndarray, target_snr_db: float) -> float:
    """
    Noise gain (lambda) that makes ``clean + lambda * noise`` hit the target SNR.

    lambda = RMS(clean) / RMS(noise) * 10 ** (-target_snr_db / 10), i.e. the
    inverse of the 10 * log10 RMS-ratio definition used by ``snr_db``.
    """
    amplitude_ratio = rms(clean) / rms(noise)
    attenuation = 10 ** (-target_snr_db / 10)
    return amplitude_ratio * attenuation

def mix_signals(clean: np.ndarray, noise: np.ndarray, target_snr_db: float):
    """
    Contaminate ``clean`` with ``noise`` scaled to reach ``target_snr_db``.

    Returns:
        (mixed_signal, lambda_used) where mixed_signal = clean + lambda * noise.
    """
    gain = compute_lambda(clean, noise, target_snr_db)
    return clean + gain * noise, gain

def noisy_eeg_eog(eeg: np.ndarray, eog: np.ndarray, target_snr_db: float):
    """
    Mix an EEG epoch with an ocular (EOG) artifact at the requested SNR.

    Returns the ``(mixed_signal, lambda_used)`` pair produced by ``mix_signals``.
    """
    mixed, gain = mix_signals(eeg, eog, target_snr_db)
    return mixed, gain

def noisy_eeg_emg(eeg: np.ndarray, emg: np.ndarray, target_snr_db: float):
    """
    Mix an EEG epoch with a myogenic (EMG) artifact at the requested SNR.

    Returns the ``(mixed_signal, lambda_used)`` pair produced by ``mix_signals``.
    """
    mixed, gain = mix_signals(eeg, emg, target_snr_db)
    return mixed, gain

def noisy_eeg_eog_emg(
    eeg: np.ndarray,
    eog: np.ndarray,
    emg: np.ndarray,
    target_snr_db: float,
    eog_weight: float = 1.0,
    emg_weight: float = 1.0
):
    """
    Mix an EEG epoch with a weighted sum of EOG and EMG artifacts.

    The artifact is ``eog_weight * eog + emg_weight * emg``; the weights set
    the relative dominance of each source before the sum is scaled to the
    target SNR by ``mix_signals``.

    Returns:
        (mixed_signal, lambda_used)
    """
    weighted_eog = eog_weight * eog
    weighted_emg = emg_weight * emg
    artifact = weighted_eog + weighted_emg
    return mix_signals(eeg, artifact, target_snr_db)


In [36]:
def make_noisy_sample(
    noisy_signal,
    eeg_idx,
    eog_idx=None,
    emg_idx=None,
    snr_db=None,
    lambda_used=None
):
    """
    Bundle one contaminated epoch and its provenance into a dict.

    Keys (in insertion order): "noisy", "eeg_idx", "eog_idx", "emg_idx",
    "snr_db", "lambda". Index fields left at their default stay ``None``.
    """
    sample = {}
    sample["noisy"] = noisy_signal      # the mixed signal itself
    sample["eeg_idx"] = eeg_idx         # index of the clean EEG epoch
    sample["eog_idx"] = eog_idx         # index of the EOG epoch, or None
    sample["emg_idx"] = emg_idx         # index of the EMG epoch, or None
    sample["snr_db"] = snr_db           # SNR used for the mix
    sample["lambda"] = lambda_used      # noise gain used for the mix
    return sample

def sample_snr_uniform(low, high, rng):
    """Draw one SNR value from ``rng.uniform(low, high)``."""
    drawn = rng.uniform(low, high)
    return drawn

def fixed_snr_list(snrs, rng):
    """Pick one entry of ``snrs`` via ``rng.choice``."""
    picked = rng.choice(snrs)
    return picked

In [37]:
def mix_eeg_eog_paper(
    EEG,
    EOG,
    seed=42,
    mode="random"  # "exhaustive" | "random"
):
    """
    Build EEG+EOG train/test sample lists; epoch ``i`` of EEG is always paired
    with epoch ``i`` of EOG.

    Indices 0..2999 form the training split and 3000..3399 the test split.

    mode="exhaustive":
        train - 10 passes over a fresh permutation of the training indices,
                each sample with an SNR drawn uniformly from [-7, 2);
        test  - every test index at every integer SNR from -7 to 2.
    mode="random":
        train/test - one pass each, SNR picked at random from the integer grid.

    Returns:
        (train_samples, test_samples), lists of ``make_noisy_sample`` dicts.

    Raises:
        ValueError: if ``mode`` is neither "exhaustive" nor "random".
    """
    rng = np.random.default_rng(seed)

    train_ids = np.arange(3000)
    test_ids = np.arange(3000, 3400)
    snr_grid = np.array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2])

    def contaminate(idx, level):
        # Same index on both arrays: EEG[idx] is mixed with EOG[idx].
        mixture, gain = noisy_eeg_eog(EEG[idx], EOG[idx], level)
        return make_noisy_sample(
            mixture,
            eeg_idx=idx,
            eog_idx=idx,
            snr_db=level,
            lambda_used=gain
        )

    train_samples, test_samples = [], []

    if mode == "exhaustive":
        # Training: ten shuffled passes, continuous SNR.
        for _ in range(10):
            for idx in rng.permutation(train_ids):
                level = sample_snr_uniform(-7, 2, rng)
                train_samples.append(contaminate(idx, level))

        # Testing: full SNR grid x every test epoch (no randomness).
        test_samples.extend(
            contaminate(idx, level) for level in snr_grid for idx in test_ids
        )

    elif mode == "random":
        # Training first, then testing, so the RNG stream is consumed in that order.
        for idx in train_ids:
            level = rng.choice(snr_grid)
            train_samples.append(contaminate(idx, level))

        for idx in test_ids:
            level = rng.choice(snr_grid)
            test_samples.append(contaminate(idx, level))

    else:
        raise ValueError("mode must be 'exhaustive' or 'random'")

    return train_samples, test_samples

def mix_eeg_emg_paper(
    EEG,
    EMG,
    seed=42,
    mode="random"  # "exhaustive" | "random"
):
    """
    Build EEG+EMG train/test sample lists.

    Every EMG epoch is paired with an EEG epoch drawn with replacement; the
    pairs are shuffled, the first 5000 become the training split and pairs
    5000..5597 the test split.

    mode="exhaustive":
        train - 10 passes over the training pairs, SNR uniform in [-7, 2);
        test  - every test pair at every integer SNR from -7 to 2.
    mode="random":
        train/test - one pass each, SNR picked at random from the integer grid.

    Returns:
        (train_samples, test_samples), lists of ``make_noisy_sample`` dicts.

    Raises:
        ValueError: if ``mode`` is neither "exhaustive" nor "random".
    """
    rng = np.random.default_rng(seed)

    # One EEG partner per EMG epoch, then shuffle the pairing in place.
    eeg_partners = rng.choice(len(EEG), size=len(EMG), replace=True)
    pairs = list(zip(eeg_partners, np.arange(len(EMG))))
    rng.shuffle(pairs)

    train_pairs, test_pairs = pairs[:5000], pairs[5000:5598]
    snr_grid = np.array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2])

    def contaminate(eeg_i, emg_i, level):
        mixture, gain = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], level)
        return make_noisy_sample(
            mixture,
            eeg_idx=eeg_i,
            emg_idx=emg_i,
            snr_db=level,
            lambda_used=gain
        )

    train_samples, test_samples = [], []

    if mode == "exhaustive":
        # Training: ten passes, continuous SNR.
        for _ in range(10):
            for eeg_i, emg_i in train_pairs:
                level = sample_snr_uniform(-7, 2, rng)
                train_samples.append(contaminate(eeg_i, emg_i, level))

        # Testing: full SNR grid x every test pair (no randomness).
        test_samples.extend(
            contaminate(eeg_i, emg_i, level)
            for level in snr_grid
            for eeg_i, emg_i in test_pairs
        )

    elif mode == "random":
        # Training first, then testing, so the RNG stream is consumed in that order.
        for eeg_i, emg_i in train_pairs:
            level = rng.choice(snr_grid)
            train_samples.append(contaminate(eeg_i, emg_i, level))

        for eeg_i, emg_i in test_pairs:
            level = rng.choice(snr_grid)
            test_samples.append(contaminate(eeg_i, emg_i, level))

    else:
        raise ValueError("mode must be 'exhaustive' or 'random'")

    return train_samples, test_samples

def mix_custom(
    EEG,
    EOG=None,
    EMG=None,
    n_train=10000,
    n_test=2000,
    snr_range=(-7, 2),
    seed=42,
    eog_weight=1.0,
    emg_weight=1.0
):
    """
    Generate ``n_train + n_test`` randomly contaminated EEG samples.

    For each sample an EEG epoch and an SNR (uniform over ``snr_range``) are
    drawn, then an artifact epoch from every supplied source (EOG and/or EMG).
    With both sources present they are combined using ``eog_weight`` /
    ``emg_weight``.

    Returns:
        (train_samples, test_samples): the first ``n_train`` samples and the rest.

    Raises:
        ValueError: when a sample is drawn while both EOG and EMG are None.
    """
    rng = np.random.default_rng(seed)
    has_eog = EOG is not None
    has_emg = EMG is not None

    def draw_sample():
        # Draw order (EEG index, SNR, EOG index, EMG index) fixes the RNG stream.
        eeg_i = rng.integers(len(EEG))
        snr = rng.uniform(*snr_range)

        if not (has_eog or has_emg):
            raise ValueError("At least one of EOG or EMG must be provided.")

        eog_i = rng.integers(len(EOG)) if has_eog else None
        emg_i = rng.integers(len(EMG)) if has_emg else None

        if has_eog and has_emg:
            noisy, lam = noisy_eeg_eog_emg(
                EEG[eeg_i],
                EOG[eog_i],
                EMG[emg_i],
                snr,
                eog_weight=eog_weight,
                emg_weight=emg_weight
            )
        elif has_eog:
            noisy, lam = noisy_eeg_eog(EEG[eeg_i], EOG[eog_i], snr)
        else:
            noisy, lam = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], snr)

        return make_noisy_sample(
            noisy,
            eeg_idx=eeg_i,
            eog_idx=eog_i,
            emg_idx=emg_i,
            snr_db=snr,
            lambda_used=lam
        )

    samples = [draw_sample() for _ in range(n_train + n_test)]

    return samples[:n_train], samples[n_train:]


In [38]:
def normalize_noisy_sample(sample, EEG):
    """
    Scale one noisy sample and its clean reference by the noisy signal's
    standard deviation (EEGdenoiseNet protocol).

    Args:
        sample: dict produced by make_noisy_sample (needs "noisy" and "eeg_idx").
        EEG: clean EEG array, indexed by ``sample["eeg_idx"]`` for the ground truth.

    Returns:
        A new dict holding every entry of ``sample`` plus "noisy_norm",
        "clean_norm" and "sigma_y" (the divisor, kept for rescaling later).
        The input dict is not modified.

    Raises:
        ValueError: if the noisy signal has zero standard deviation.
    """
    noisy = sample["noisy"]
    clean = EEG[sample["eeg_idx"]]

    scale = np.std(noisy)
    if scale == 0:
        raise ValueError("Standard deviation of noisy signal is zero.")

    return {
        **sample,
        "noisy_norm": noisy / scale,
        "clean_norm": clean / scale,
        "sigma_y": scale,
    }

def normalize_dataset(samples, EEG):
    """
    Apply ``normalize_noisy_sample`` to every sample, preserving order.

    Args:
        samples: iterable of noisy sample dicts
        EEG: clean EEG array used for the ground-truth lookup

    Returns:
        list of normalized sample dicts
    """
    normalized = []
    for item in samples:
        normalized.append(normalize_noisy_sample(item, EEG))
    return normalized


## Data Generation

In [39]:
# --- generate noisy datasets (raw, unnormalized) according to DATA_MODE ---
if DATA_MODE == 1:  # EEG + EOG, paper protocol
    train_samples, test_samples = mix_eeg_eog_paper(EEG=EEG, EOG=EOG, seed=42)
elif DATA_MODE == 2:  # EEG + EMG, paper protocol
    train_samples, test_samples = mix_eeg_emg_paper(EEG=EEG, EMG=EMG, seed=42)
elif DATA_MODE == 3:  # EEG + EOG, custom random mix
    train_samples, test_samples = mix_custom(
        EEG=EEG,
        EOG=EOG,
        n_train=10000,
        n_test=2000,
        snr_range=(-7, 2),
        seed=42
    )
elif DATA_MODE == 4:  # EEG + EMG, custom random mix
    train_samples, test_samples = mix_custom(
        EEG=EEG,
        EMG=EMG,
        n_train=10000,
        n_test=2000,
        snr_range=(-7, 2),
        seed=42
    )
else:  # any other value: EEG + EOG + EMG, custom random mix
    train_samples, test_samples = mix_custom(
        EEG=EEG,
        EOG=EOG,
        EMG=EMG,
        n_train=10000,
        n_test=2000,
        snr_range=(-7, 2),
        seed=42,
        eog_weight=1.0,
        emg_weight=1.0
    )

# --- normalize according to EEGdenoiseNet protocol (same step for every mode) ---
train_samples_norm = normalize_dataset(train_samples, EEG)
test_samples_norm  = normalize_dataset(test_samples, EEG)


# Model

## Original Models

In [40]:
class ConvBlock1D(nn.Module):
    """
    Strided 1-D convolution followed by batch normalisation and an activation.

    ``act`` selects LeakyReLU(0.2) ("lrelu") or ReLU ("relu"); anything else
    raises ValueError.
    """
    def __init__(self, in_ch, out_ch, k=15, s=2, p=7, bias=True, act="lrelu"):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=k,
            stride=s,
            padding=p,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(num_features=out_ch)

        if act not in ("lrelu", "relu"):
            raise ValueError("act must be 'lrelu' or 'relu'")
        self.act = nn.LeakyReLU(0.2, inplace=True) if act == "lrelu" else nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.norm(features)
        return self.act(normalized)


class DeconvBlock1D(nn.Module):
    """
    Transposed 1-D convolution followed by batch normalisation and an activation.

    ``act`` selects ReLU ("relu") or LeakyReLU(0.2) ("lrelu"); anything else
    raises ValueError.
    """
    def __init__(self, in_ch, out_ch, k=4, s=2, p=1, bias=True, act="relu"):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=k,
            stride=s,
            padding=p,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(num_features=out_ch)

        if act not in ("relu", "lrelu"):
            raise ValueError("act must be 'relu' or 'lrelu'")
        self.act = nn.ReLU(inplace=True) if act == "relu" else nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        upsampled = self.deconv(x)
        normalized = self.norm(upsampled)
        return self.act(normalized)


class ResBlock1D(nn.Module):
    """
    Residual block: Conv-BN-ReLU, Conv-BN, add the input, final ReLU.

    Both convolutions use stride 1 and keep the channel count at ``ch``.
    """
    def __init__(self, ch, k=7, p=3, bias=True):
        super().__init__()
        conv_kwargs = dict(kernel_size=k, stride=1, padding=p, bias=bias)
        self.c1 = nn.Conv1d(ch, ch, **conv_kwargs)
        self.n1 = nn.BatchNorm1d(ch)
        self.c2 = nn.Conv1d(ch, ch, **conv_kwargs)
        self.n2 = nn.BatchNorm1d(ch)

    def forward(self, x):
        residual = self.c1(x)
        residual = F.relu(self.n1(residual))
        residual = self.c2(residual)
        residual = self.n2(residual)
        return F.relu(x + residual)
    
class MultiScaleResBlock1D(nn.Module):
    """
    Residual block with three parallel Conv-BN-ReLU branches (kernel 3, 5, 7).

    The branch outputs are summed, passed through a 1x1 Conv-BN fusion layer,
    added to the input and rectified. Channel count stays at ``ch``.
    """
    def __init__(self, ch, bias=True):
        super().__init__()

        def branch(kernel):
            # padding = kernel // 2 keeps the sequence length (1, 2, 3 for k = 3, 5, 7)
            return nn.Sequential(
                nn.Conv1d(ch, ch, kernel_size=kernel, padding=kernel // 2, bias=bias),
                nn.BatchNorm1d(ch),
                nn.ReLU(inplace=True),
            )

        self.b3 = branch(3)
        self.b5 = branch(5)
        self.b7 = branch(7)

        self.fuse = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=1, bias=bias),
            nn.BatchNorm1d(ch),
        )

    def forward(self, x):
        multi_scale = self.b3(x) + self.b5(x)
        multi_scale = multi_scale + self.b7(x)
        fused = self.fuse(multi_scale)
        return F.relu(x + fused)



In [41]:
# NN Generator (U-Net-ish + Res bottleneck)
class GeneratorCNNWGAN(nn.Module):
    """
    U-Net-style CNN generator for EEG denoising (WGAN).

    Four stride-2 encoder blocks, a stack of residual blocks at the bottleneck,
    four stride-2 transposed-conv decoder blocks (the first three followed by
    concatenation with the matching encoder output) and a linear conv head.

    Input : (B, 1, 512) noisy_norm
    Output: (B, 1, 512) clean_norm_hat
    """
    def __init__(self, base_ch=32, bottleneck_blocks=4, bias=True):
        super().__init__()

        c1, c2, c4, c8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8
        down = dict(k=16, s=2, p=7, bias=bias, act="lrelu")  # halves the length
        up = dict(k=4, s=2, p=1, bias=bias, act="relu")      # doubles the length

        # Encoder: 512 -> 256 -> 128 -> 64 -> 32
        self.e1 = ConvBlock1D(1, c1, **down)
        self.e2 = ConvBlock1D(c1, c2, **down)
        self.e3 = ConvBlock1D(c2, c4, **down)
        self.e4 = ConvBlock1D(c4, c8, **down)

        # Bottleneck: residual blocks at constant width and length
        self.bottleneck = nn.Sequential(
            *(ResBlock1D(c8, k=7, p=3, bias=bias) for _ in range(bottleneck_blocks))
        )

        # Decoder: 32 -> 64 -> 128 -> 256 -> 512; inputs of d2..d4 are skip-concatenated
        self.d1 = DeconvBlock1D(c8, c4, **up)
        self.d2 = DeconvBlock1D(c8, c2, **up)
        self.d3 = DeconvBlock1D(c4, c1, **up)
        self.d4 = DeconvBlock1D(c2, base_ch // 2, **up)

        # Head: no activation, output is a normalized signal
        self.out = nn.Conv1d(base_ch // 2, 1, kernel_size=7, stride=1, padding=3, bias=bias)

    def forward(self, y):
        # Encoder (keep every stage for the skip connections)
        skip1 = self.e1(y)       # (B, b, 256)
        skip2 = self.e2(skip1)   # (B, 2b, 128)
        skip3 = self.e3(skip2)   # (B, 4b, 64)
        deepest = self.e4(skip3) # (B, 8b, 32)

        hidden = self.bottleneck(deepest)

        # Decoder: upsample, then concatenate the encoder feature of equal length
        up1 = torch.cat([self.d1(hidden), skip3], dim=1)  # (B, 8b, 64)
        up2 = torch.cat([self.d2(up1), skip2], dim=1)     # (B, 4b, 128)
        up3 = torch.cat([self.d3(up2), skip1], dim=1)     # (B, 2b, 256)
        up4 = self.d4(up3)                                # (B, b/2, 512)

        return self.out(up4)                              # (B, 1, 512)
    
# Patch Critic (shared by CNN/ResCNN)
class CriticPatch1D(nn.Module):
    """
    Conditional patch critic for WGAN: D(y, x) -> per-patch scores.

    ``y`` and ``x`` are concatenated along the channel axis, reduced by four
    stride-2 convolutions (the first one without BatchNorm) and scored by a
    final convolution.

    y, x  : (B, 1, 512)
    output: (B, 1, 32)
    """
    def __init__(self, base_ch=32, bias=True):
        super().__init__()
        down = dict(k=16, s=2, p=7, bias=bias, act="lrelu")

        # 512 -> 256, plain conv; LeakyReLU is applied in forward()
        self.c1 = nn.Conv1d(2, base_ch, kernel_size=16, stride=2, padding=7, bias=bias)
        self.c2 = ConvBlock1D(base_ch, base_ch * 2, **down)      # 256 -> 128
        self.c3 = ConvBlock1D(base_ch * 2, base_ch * 4, **down)  # 128 -> 64
        self.c4 = ConvBlock1D(base_ch * 4, base_ch * 8, **down)  # 64 -> 32
        self.out = nn.Conv1d(base_ch * 8, 1, kernel_size=7, stride=1, padding=3, bias=bias)  # 32 -> 32

    def forward(self, y, x):
        pair = torch.cat([y, x], dim=1)  # (B, 2, 512)
        feat = F.leaky_relu(self.c1(pair), 0.2, inplace=True)
        for stage in (self.c2, self.c3, self.c4):
            feat = stage(feat)
        return self.out(feat)


## Fused Models

In [42]:
class FusedConvBlock1D(nn.Module):
    """
    Conv1d -> Activation, i.e. ConvBlock1D with the BatchNorm folded into the
    convolution weights.

    Args:
        in_ch, out_ch: input / output channel counts.
        k, s, p: kernel size, stride and padding of the convolution.
        bias: whether the convolution has a bias term.
        act: "lrelu" (LeakyReLU with slope 0.2) or "relu".

    Raises:
        ValueError: if ``act`` is not one of the supported names.
    """
    def __init__(self, in_ch, out_ch, k, s, p, bias=True, act="lrelu"):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, k, s, p, bias=bias)

        if act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        elif act == "relu":
            self.act = nn.ReLU(inplace=True)
        else:
            # Same message as ConvBlock1D so both block families fail alike.
            raise ValueError("act must be 'lrelu' or 'relu'")

    def forward(self, x):
        return self.act(self.conv(x))


class FusedDeconvBlock1D(nn.Module):
    """
    ConvTranspose1d -> Activation, i.e. DeconvBlock1D with the BatchNorm folded
    into the transposed-convolution weights.

    Args:
        in_ch, out_ch: input / output channel counts.
        k, s, p: kernel size, stride and padding of the transposed convolution.
        bias: whether the transposed convolution has a bias term.
        act: "relu" or "lrelu" (LeakyReLU with slope 0.2).

    Raises:
        ValueError: if ``act`` is not one of the supported names.
    """
    def __init__(self, in_ch, out_ch, k, s, p, bias=True, act="relu"):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(in_ch, out_ch, k, s, p, bias=bias)

        if act == "relu":
            self.act = nn.ReLU(inplace=True)
        elif act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            # Same message as DeconvBlock1D so both block families fail alike.
            raise ValueError("act must be 'relu' or 'lrelu'")

    def forward(self, x):
        return self.act(self.deconv(x))

class GeneratorCNNWGAN_Fused(nn.Module):
    """
    BatchNorm-free counterpart of GeneratorCNNWGAN.

    Same layer layout (encoder e1..e4, residual bottleneck, decoder d1..d4 with
    skip concatenation, conv head) but every block is Conv/Deconv + activation
    only. Each bottleneck entry is Conv -> ReLU -> Conv; the residual add and
    the final ReLU are applied in ``forward``.
    """
    def __init__(self, base_ch=32, bottleneck_blocks=4, bias=True):
        super().__init__()
        c1, c2, c4, c8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8

        # Encoder (kernel 16, stride 2, padding 7)
        self.e1 = FusedConvBlock1D(1, c1, 16, 2, 7, bias, "lrelu")
        self.e2 = FusedConvBlock1D(c1, c2, 16, 2, 7, bias, "lrelu")
        self.e3 = FusedConvBlock1D(c2, c4, 16, 2, 7, bias, "lrelu")
        self.e4 = FusedConvBlock1D(c4, c8, 16, 2, 7, bias, "lrelu")

        # Bottleneck: residual branches only; indices 0 and 2 hold the convolutions
        residual_branches = []
        for _ in range(bottleneck_blocks):
            residual_branches.append(
                nn.Sequential(
                    nn.Conv1d(c8, c8, 7, 1, 3, bias=bias),
                    nn.ReLU(inplace=True),
                    nn.Conv1d(c8, c8, 7, 1, 3, bias=bias),
                )
            )
        self.bottleneck = nn.Sequential(*residual_branches)

        # Decoder (kernel 4, stride 2, padding 1, default ReLU)
        self.d1 = FusedDeconvBlock1D(c8, c4, 4, 2, 1, bias)
        self.d2 = FusedDeconvBlock1D(c8, c2, 4, 2, 1, bias)
        self.d3 = FusedDeconvBlock1D(c4, c1, 4, 2, 1, bias)
        self.d4 = FusedDeconvBlock1D(c2, base_ch // 2, 4, 2, 1, bias)

        self.out = nn.Conv1d(base_ch // 2, 1, 7, 1, 3, bias=bias)

    def forward(self, y):
        skip1 = self.e1(y)
        skip2 = self.e2(skip1)
        skip3 = self.e3(skip2)
        hidden = self.e4(skip3)

        # Residual add + ReLU around every bottleneck branch
        for residual_branch in self.bottleneck:
            hidden = F.relu(hidden + residual_branch(hidden))

        up1 = torch.cat([self.d1(hidden), skip3], dim=1)
        up2 = torch.cat([self.d2(up1), skip2], dim=1)
        up3 = torch.cat([self.d3(up2), skip1], dim=1)
        up4 = self.d4(up3)

        return self.out(up4)



In [43]:
# Rebind `data_path` to the per-format output directory (named after TYPE) and
# create it if it does not exist yet.
data_path = os.path.abspath(f"../../test/input_sample_{TYPE}/")
os.makedirs(data_path, exist_ok=True)

In [44]:
def write_tensor_mem(tensor, name, path, y_first=True):
    """
    Quantize a 2-D tensor and dump it as two text memory files in ``path``.

    Files written (one value per line):
      <name>_hex.mem : value masked to 32 bits, 8 uppercase hex digits
      <name>_raw.mem : signed decimal integer

    Quantization uses the module-level FRAC_BITS / INT_BITS / DTYPE (selected
    by TYPE) through float_to_q, so this only works in fixed-point mode.

    Args:
        tensor: torch.Tensor of shape [X, Y].
        name: base name of the output files.
        path: existing directory to write into.
        y_first: element order, with element (x, y) = tensor[x, y]:
            True  -> x is the outer loop, y the inner loop (row-major);
            False -> y is the outer loop, x the inner loop (column-major).

    Raises:
        AssertionError: if ``tensor`` is not 2-D.
    """
    assert tensor.ndim == 2, f"{name} must be 2D"

    # Move to CPU -> numpy float
    arr_f = tensor.detach().cpu().numpy().astype(np.float32)

    # Quantize to the configured fixed-point format
    arr_q = float_to_q(
        arr_f,
        frac_bits=FRAC_BITS,
        int_bits=INT_BITS,
        dtype=DTYPE
    )

    X, Y = arr_q.shape
    out_hex_path = os.path.join(path, f"{name}_hex.mem")
    out_raw_path = os.path.join(path, f"{name}_raw.mem")

    # Row-major flattening visits x outer / y inner; flattening the transpose
    # visits y outer / x inner. One write loop then serves both orders.
    ordered = arr_q if y_first else arr_q.T
    values = ordered.flatten().tolist()

    with open(out_hex_path, "w") as f_hex, open(out_raw_path, "w") as f_raw:
        for v in values:
            v = int(v)
            f_hex.write(f"{v & 0xFFFFFFFF:08X}\n")
            f_raw.write(f"{v}\n")

    order_str = "y→x" if y_first else "x→y"
    print(
        f"Wrote {name}_hex.mem / {name}_raw.mem  shape=({X},{Y})  "
        f"entries={X*Y}  order={order_str}"
    )

## Model Helper

In [45]:
def fuse_conv_bn_1d(
    conv_weight,        # (Cout, Cin, K)
    conv_bias,          # (Cout,) or None
    running_mean,       # (Cout,)
    running_var,        # (Cout,)
    bn_weight,          # (Cout,) or None (gamma)
    bn_bias,            # (Cout,) or None (beta)
    eps=1e-5
):
    """
    Fold BatchNorm running statistics into the preceding Conv1d.

    With gain = gamma / sqrt(running_var + eps):
        fused_weight = conv_weight * gain   (per output channel)
        fused_bias   = (conv_bias - running_mean) * gain + beta

    A missing gamma defaults to ones; a missing beta or conv bias to zeros.

    Returns:
        (fused_weight, fused_bias)
    """
    out_channels = conv_weight.shape[0]
    like_weight = dict(device=conv_weight.device, dtype=conv_weight.dtype)

    gamma = torch.ones(out_channels, **like_weight) if bn_weight is None else bn_weight
    beta = torch.zeros(out_channels, **like_weight) if bn_bias is None else bn_bias
    bias = torch.zeros(out_channels, **like_weight) if conv_bias is None else conv_bias

    # Per-output-channel BatchNorm gain
    gain = gamma / torch.sqrt(running_var + eps)

    fused_weight = conv_weight * gain.reshape(-1, 1, 1)
    fused_bias = (bias - running_mean) * gain + beta

    return fused_weight, fused_bias

def fuse_deconv_bn_1d(
    deconv_weight,     # (Cin, Cout, K)
    deconv_bias,       # (Cout,) or None
    running_mean,      # (Cout,)
    running_var,       # (Cout,)
    bn_weight,         # (Cout,) or None (gamma)
    bn_bias,           # (Cout,) or None (beta)
    eps=1e-5
):
    """
    Fold BatchNorm running statistics into the preceding ConvTranspose1d.

    With scale = gamma / sqrt(running_var + eps):
        fused_weight = deconv_weight * scale   (on dim 1, the output channels)
        fused_bias   = (deconv_bias - running_mean) * scale + beta

    As in fuse_conv_bn_1d, a missing gamma defaults to ones, a missing beta or
    deconv bias to zeros, and ``eps`` defaults to 1e-5.

    Returns:
        (fused_weight, fused_bias)
    """
    Cin, Cout, K = deconv_weight.shape
    like_weight = dict(device=deconv_weight.device, dtype=deconv_weight.dtype)

    if bn_weight is None:
        bn_weight = torch.ones(Cout, **like_weight)
    if bn_bias is None:
        bn_bias = torch.zeros(Cout, **like_weight)
    if deconv_bias is None:
        deconv_bias = torch.zeros(Cout, **like_weight)

    # BN scale
    scale = bn_weight / torch.sqrt(running_var + eps)  # (Cout,)

    # ConvTranspose1d stores weights as (Cin, Cout, K), so scale along dim 1
    fused_weight = deconv_weight * scale.view(1, Cout, 1)

    # Fuse bias
    fused_bias = (deconv_bias - running_mean) * scale + bn_bias

    return fused_weight, fused_bias


In [46]:
def fuse_generator(G):
    """
    Build a GeneratorCNNWGAN_Fused whose conv/deconv layers absorb G's BatchNorm.

    The fused model is sized from G itself (base width from ``G.e1``, number of
    bottleneck blocks from ``len(G.bottleneck)``), moved to the module-level
    ``device`` and put in eval mode. Every fused layer carries a bias because
    the folded BatchNorm shift always yields one, even when G was built with
    ``bias=False``.

    Folding uses the BatchNorm *running* statistics, so the result matches G
    only when G is evaluated in eval mode.

    Args:
        G: a GeneratorCNNWGAN (attributes e1..e4, bottleneck, d1..d4, out).

    Returns:
        The fused generator.
    """
    base_ch = G.e1.conv.out_channels
    G_fused = GeneratorCNNWGAN_Fused(
        base_ch=base_ch,
        bottleneck_blocks=len(G.bottleneck),
        bias=True
    ).to(device)
    G_fused.eval()

    with torch.no_grad():
        # Encoder
        for i in range(1, 5):
            e = getattr(G, f"e{i}")
            fe = getattr(G_fused, f"e{i}")

            w, b = fuse_conv_bn_1d(
                e.conv.weight, e.conv.bias,
                e.norm.running_mean, e.norm.running_var,
                e.norm.weight, e.norm.bias,
                e.norm.eps
            )
            fe.conv.weight.copy_(w)
            fe.conv.bias.copy_(b)

        # Bottleneck
        for i, blk in enumerate(G.bottleneck):
            fblk = G_fused.bottleneck[i]

            # Fused branch layout is [conv, relu, conv] -> conv j lives at index 2*j
            for j, (c, n) in enumerate([(blk.c1, blk.n1), (blk.c2, blk.n2)]):
                w, b = fuse_conv_bn_1d(
                    c.weight, c.bias,
                    n.running_mean, n.running_var,
                    n.weight, n.bias,
                    n.eps
                )
                fblk[j*2].weight.copy_(w)
                fblk[j*2].bias.copy_(b)

        # Decoder
        for i in range(1, 5):
            d = getattr(G, f"d{i}")
            fd = getattr(G_fused, f"d{i}")

            w, b = fuse_deconv_bn_1d(
                d.deconv.weight, d.deconv.bias,
                d.norm.running_mean, d.norm.running_var,
                d.norm.weight, d.norm.bias,
                d.norm.eps
            )
            fd.deconv.weight.copy_(w)
            fd.deconv.bias.copy_(b)

        # Output head has no BatchNorm: copy as is. A generator trained with
        # bias=False has no head bias, which is equivalent to a zero bias.
        G_fused.out.weight.copy_(G.out.weight)
        if G.out.bias is None:
            G_fused.out.bias.zero_()
        else:
            G_fused.out.bias.copy_(G.out.bias)

    return G_fused


## Model Declaration

In [47]:
# Checkpoint directory for the current DATA_MODE / BIAS combination.
# Single quotes inside the replacement field keep this f-string valid on
# Python versions older than 3.12.
data_path = os.path.abspath(f"../models/main3_d{DATA_MODE}_{'b' if BIAS else 'nb'}/")
print(data_path)

# Matches cnn_G_<8 digits>_<6 digits>.pth and cnn_D_<8 digits>_<6 digits>.pth
pattern = re.compile(r"^cnn_([DG])_\d{8}_\d{6}\.pth$")

cnn_G_path = None
cnn_D_path = None

# os.listdir() order is arbitrary; sorting makes the choice deterministic.
# The digit fields are fixed-width, so the name that sorts last (highest
# digits) is the one kept for G and for D when several checkpoints exist.
for f in sorted(os.listdir(data_path)):
    m = pattern.match(f)
    if m:
        full = os.path.join(data_path, f)
        if m.group(1) == "G":
            cnn_G_path = full
        else:
            cnn_D_path = full

c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d5_b


In [48]:
# Instantiate generator and critic with the configured bias setting.
G = GeneratorCNNWGAN(bias=BIAS).to(device)
D = CriticPatch1D(bias=BIAS).to(device)

# The checkpoints are consumed as state dicts, so restrict unpickling to
# tensors and plain containers (weights_only=True) instead of running the full
# pickle machinery on the files.
G.load_state_dict(torch.load(cnn_G_path, map_location=device, weights_only=True))
D.load_state_dict(torch.load(cnn_D_path, map_location=device, weights_only=True))

# .eval() returns the module itself, so this switches to eval mode and prints
# the architecture in one step.
print(G.eval())
print(D.eval())

GeneratorCNNWGAN(
  (e1): ConvBlock1D(
    (conv): Conv1d(1, 32, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e2): ConvBlock1D(
    (conv): Conv1d(32, 64, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e3): ConvBlock1D(
    (conv): Conv1d(64, 128, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e4): ConvBlock1D(
    (conv): Conv1d(128, 256, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  

  G.load_state_dict(torch.load(cnn_G_path, map_location=device))
  D.load_state_dict(torch.load(cnn_D_path, map_location=device))


In [49]:
# Build the fused generator from the loaded generator G.
G_fused = fuse_generator(G)

# Disabled: would save the fused state dict next to the original checkpoint
# under a "fused_" prefix.
# fused_G_path = os.path.join(
#     data_path,
#     "fused_" + os.path.basename(cnn_G_path)
# )

# torch.save(G_fused.state_dict(), fused_G_path)
# print("Saved:", fused_G_path)

# Utility

In [111]:
def float_to_q(x, frac_bits, int_bits, dtype):
    """
    Convert floating-point values to signed fixed-point integers.

    Values are multiplied by 2**frac_bits, rounded with ``np.round``, saturated
    to the signed range of an (int_bits + frac_bits)-bit word and cast to
    ``dtype``.
    """
    half_range = 1 << (int_bits + frac_bits - 1)
    lowest, highest = -half_range, half_range - 1

    scaled = np.round(x * (1 << frac_bits))
    saturated = np.clip(scaled, lowest, highest)
    return saturated.astype(dtype)

def q_to_float(x, frac_bits):
    """Convert fixed-point integers back to float32 by dividing by 2**frac_bits."""
    one = 1 << frac_bits
    as_float = x.astype(np.float32)
    return as_float / one

def q_to_hex(arr, bits=32):
    """
    Format every element of an integer array as an upper-case hex word.

    Each value is masked to `bits` bits (so negative numbers appear in two's
    complement), zero-padded to the number of hex digits needed for that
    width, and terminated with a newline.

    Args:
        arr: numpy array of integer values.
        bits: Word width in bits. The default of 32 yields 8 hex digits per
            element regardless of the array's dtype.

    Returns:
        numpy array of strings with the same shape as `arr`. Empty arrays are
        supported and give an empty result.

    Raises:
        AssertionError: If `arr` is not a numpy array.
        ValueError: If `bits` is not positive.
    """
    assert isinstance(arr, np.ndarray), "Input must be a numpy array"
    if bits <= 0:
        raise ValueError("bits must be a positive integer")

    mask = (1 << bits) - 1
    width = (bits + 3) // 4  # hex digits needed to show `bits` bits

    # Built explicitly rather than with np.vectorize, which cannot infer an
    # output type from a size-0 input and raises on empty arrays.
    words = [f"{int(v) & mask:0{width}X}\n" for v in arr.ravel()]
    return np.array(words, dtype=f"U{width + 1}").reshape(arr.shape)

# Input and Output

## Initial Inputs

In [97]:
# Raw seeds used to pick the exported samples.
seed_arr_ori = [
    1234, 420, 67, 69,
    13523100, 13223051, 13223075,
    42, 21, 20,
]

# Fold each raw seed into a valid index of train_samples_norm.
seed_arr = [seed % len(train_samples_norm) for seed in seed_arr_ori]
print(seed_arr)

[1234, 420, 67, 69, 3100, 3051, 3075, 42, 21, 20]


In [98]:
# Show which fields one training record (index 1234) carries, one per line.
for field in train_samples_norm[1234]:
    print(field)

noisy
eeg_idx
eog_idx
emg_idx
snr_db
lambda
noisy_norm
clean_norm
sigma_y


In [99]:
# One record per raw seed: the seed, the index it maps to, and the normalized
# noisy signal prepared as generator input.
raw_data = []

for seed in seed_arr_ori:
    # Wrap the raw seed into the valid index range of train_samples_norm.
    mod_seed = seed % len(train_samples_norm)

    inp = train_samples_norm[mod_seed]["noisy_norm"]
    if isinstance(inp, np.ndarray):
        inp = torch.from_numpy(inp)

    # Cast to float, move to the generator's device, then prepend two
    # singleton axes.
    inp = inp.float().to(next(G.parameters()).device).unsqueeze(0).unsqueeze(0)

    seeded_data = {"raw_seed": seed, "mod_seed": mod_seed, "input": inp}
    raw_data.append(seeded_data)

In [100]:
# Sanity check: the stored input should have two leading singleton axes (added
# with unsqueeze when raw_data was built) ahead of the signal axis.
print(raw_data[0]["input"].shape)

torch.Size([1, 1, 512])


## Input Through Models

In [101]:
def _trace_generator(model, sample, suffix=""):
    """
    Run sample["input"] through `model` one stage at a time and store every
    intermediate tensor back into `sample`.

    Args:
        model: Generator exposing e1-e4, an indexable `bottleneck` with at
            least four entries, d1-d4 and `out`.
        sample: Dict holding the model input under "input"; it is updated in
            place with one entry per stage.
        suffix: Text appended to every stored key (e.g. "_fused").

    Keys written, in order: e1, e2, e3, e4, b0-b3, d1, d2_in, d2_out, d3_in,
    d3_out, d4_in, d4_out, out (each followed by `suffix`).
    """
    def keep(name, tensor):
        # Record the activation and hand it back so calls can be chained.
        sample[f"{name}{suffix}"] = tensor
        return tensor

    model.eval()
    with torch.no_grad():
        # Encoder
        e1 = keep("e1", model.e1(sample["input"]))
        e2 = keep("e2", model.e2(e1))
        e3 = keep("e3", model.e3(e2))
        hidden = keep("e4", model.e4(e3))

        # Bottleneck: only the first four entries are traced.
        for idx in range(4):
            hidden = keep(f"b{idx}", model.bottleneck[idx](hidden))

        # Decoder: from d2 onwards each stage receives the previous decoder
        # output concatenated on dim=1 with the matching encoder output.
        d1 = keep("d1", model.d1(hidden))
        d2_in = keep("d2_in", torch.cat([d1, e3], dim=1))
        d2_out = keep("d2_out", model.d2(d2_in))
        d3_in = keep("d3_in", torch.cat([d2_out, e2], dim=1))
        d3_out = keep("d3_out", model.d3(d3_in))
        d4_in = keep("d4_in", torch.cat([d3_out, e1], dim=1))
        d4_out = keep("d4_out", model.d4(d4_in))

        keep("out", model.out(d4_out))


# The loop variable is called `sample` (not `input`) so the builtin input()
# stays usable in the rest of the notebook.
for sample in raw_data:
    _trace_generator(G, sample)

for sample in raw_data:
    _trace_generator(G_fused, sample, suffix="_fused")

In [105]:
# for i, sample in enumerate(raw_data):
#     print(f"\n=== Sample {i} ===")
#     for k, v in sample.items():
#         if torch.is_tensor(v):
#             print(f"{k:10s}: {tuple(v.shape)}")
#     break

# Drop every leading singleton axis from the stored tensors. Repeating the
# squeeze until no leading size-1 axis remains makes this cell idempotent: a
# (1, 1, 512) tensor becomes (512,) in a single run instead of requiring the
# cell to be executed twice.
for sample in raw_data:
    for k, v in sample.items():
        if not torch.is_tensor(v):
            continue
        while v.dim() > 0 and v.shape[0] == 1:
            v = v.squeeze(0)
        # Re-binding an existing key does not change the dict's size, so it is
        # safe while iterating over items().
        sample[k] = v

# Report the resulting shapes for the first sample.
sample = raw_data[0]
for k, v in sample.items():
    if torch.is_tensor(v):
        print(f"{k:15s}: {tuple(v.shape)}")
    

input          : (512,)
e1             : (32, 256)
e2             : (64, 128)
e3             : (128, 64)
e4             : (256, 32)
b0             : (256, 32)
b1             : (256, 32)
b2             : (256, 32)
b3             : (256, 32)
d1             : (128, 64)
d2_in          : (256, 64)
d2_out         : (64, 128)
d3_in          : (128, 128)
d3_out         : (32, 256)
d4_in          : (64, 256)
d4_out         : (16, 512)
out            : (512,)
e1_fused       : (32, 256)
e2_fused       : (64, 128)
e3_fused       : (128, 64)
e4_fused       : (256, 32)
b0_fused       : (256, 32)
b1_fused       : (256, 32)
b2_fused       : (256, 32)
b3_fused       : (256, 32)
d1_fused       : (128, 64)
d2_in_fused    : (256, 64)
d2_out_fused   : (64, 128)
d3_in_fused    : (128, 128)
d3_out_fused   : (32, 256)
d4_in_fused    : (64, 256)
d4_out_fused   : (16, 512)
out_fused      : (512,)


# Quantize and Hex

In [None]:
# For every tensor stored in a sample, add a fixed-point copy under
# "<key>_q" and its hex-string form under "<key>_hex".
for sample in raw_data:
    # Snapshot the tensor entries up front: the loop below inserts new keys,
    # which is not allowed while iterating over the dict itself.
    tensor_items = [(k, v) for k, v in sample.items() if torch.is_tensor(v)]

    for k, v in tensor_items:
        v_f = v.detach().cpu().numpy().astype(np.float32)
        q = float_to_q(v_f, frac_bits=FRAC_BITS, int_bits=INT_BITS, dtype=DTYPE)

        sample[f"{k}_q"] = q
        sample[f"{k}_hex"] = q_to_hex(q)

# for i, sample in enumerate(raw_data):
#     print(f"\n=== Sample {i} (_q shapes) ===")
#     for k, v in sample.items():
#         if isinstance(v, np.ndarray):
#             print(f"{k:15s}: {v.shape}")
#     break   # remove if you want all samples


=== Sample 0 (_q shapes) ===
input_q        : (512,)
e1_q           : (32, 256)
e2_q           : (64, 128)
e3_q           : (128, 64)
e4_q           : (256, 32)
b0_q           : (256, 32)
b1_q           : (256, 32)
b2_q           : (256, 32)
b3_q           : (256, 32)
d1_q           : (128, 64)
d2_in_q        : (256, 64)
d2_out_q       : (64, 128)
d3_in_q        : (128, 128)
d3_out_q       : (32, 256)
d4_in_q        : (64, 256)
d4_out_q       : (16, 512)
out_q          : (512,)
e1_fused_q     : (32, 256)
e2_fused_q     : (64, 128)
e3_fused_q     : (128, 64)
e4_fused_q     : (256, 32)
b0_fused_q     : (256, 32)
b1_fused_q     : (256, 32)
b2_fused_q     : (256, 32)
b3_fused_q     : (256, 32)
d1_fused_q     : (128, 64)
d2_in_fused_q  : (256, 64)
d2_out_fused_q : (64, 128)
d3_in_fused_q  : (128, 128)
d3_out_fused_q : (32, 256)
d4_in_fused_q  : (64, 256)
d4_out_fused_q : (16, 512)
out_fused_q    : (512,)
input_hex      : (512,)
e1_hex         : (32, 256)
e2_hex         : (64, 128)
e3_hex  

# Exports

In [None]:
# Resolve the export directory to an absolute path and create it if missing.
group_data_path = os.path.abspath("../../test/io_sample/")
os.makedirs(group_data_path, exist_ok=True)