# Import Library

In [1]:
import os
import numpy as np
import pandas as pd
import scipy.io as sio
import matplotlib.pyplot as plt
import math
import json
import re
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

import random
np.random.seed(42)
random.seed(42)

# Load Data

In [2]:
ENV = "local" # or "local"

In [3]:
if ENV == "kaggle":
    EEG = np.load("/kaggle/input/eeg-denoise-net/EEG_all_epochs.npy")
    EOG = np.load("/kaggle/input/eeg-denoise-net/EOG_all_epochs.npy")
    EMG = np.load("/kaggle/input/eeg-denoise-net/EMG_all_epochs.npy")
else:
    data_path = os.path.abspath("../../data/eeg_denoise_net/")
    print(data_path)
    
    EEG = np.load(os.path.join(data_path, "EEG_all_epochs.npy"))
    EOG = np.load(os.path.join(data_path, "EOG_all_epochs.npy"))
    EMG = np.load(os.path.join(data_path, "EMG_all_epochs.npy"))

print("EEG shape:", EEG.shape)
print("EOG shape:", EOG.shape)
print("EMG shape:", EMG.shape)

c:\Users\Aryo\PersonalMade\Programming\GAN\repo\data\eeg_denoise_net
EEG shape: (4514, 512)
EOG shape: (3400, 512)
EMG shape: (5598, 512)


# Data Setup

In [4]:
def rms(x: np.ndarray) -> float:
    """
    Root Mean Square (RMS) of a 1D signal.
    Implements Formula (3).
    """
    x = np.asarray(x)
    return np.sqrt(np.mean(x ** 2))

def snr_db(clean: np.ndarray, noise: np.ndarray) -> float:
    """
    Compute SNR in decibels.
    Implements Formula (2).
    """
    return 10 * np.log10(rms(clean) / rms(noise))

def compute_lambda(clean: np.ndarray, noise: np.ndarray, target_snr_db: float) -> float:
    """
    Compute lambda such that the mixed signal has the desired SNR (in dB).
    """
    return rms(clean) / rms(noise) * 10 ** (-target_snr_db / 10)

def mix_signals(clean: np.ndarray, noise: np.ndarray, target_snr_db: float):
    """
    Mix clean signal with noise at a target SNR (dB).
    
    Returns:
        mixed_signal
        lambda_used
    """
    lam = compute_lambda(clean, noise, target_snr_db)
    mixed = clean + lam * noise
    return mixed, lam

def noisy_eeg_eog(eeg: np.ndarray, eog: np.ndarray, target_snr_db: float):
    """
    EEG contaminated by ocular artifacts (EOG).
    """
    return mix_signals(eeg, eog, target_snr_db)

def noisy_eeg_emg(eeg: np.ndarray, emg: np.ndarray, target_snr_db: float):
    """
    EEG contaminated by myogenic artifacts (EMG).
    """
    return mix_signals(eeg, emg, target_snr_db)

def noisy_eeg_eog_emg(
    eeg: np.ndarray,
    eog: np.ndarray,
    emg: np.ndarray,
    target_snr_db: float,
    eog_weight: float = 1.0,
    emg_weight: float = 1.0
):
    """
    EEG contaminated by both ocular (EOG) and myogenic (EMG) artifacts.
    
    eog_weight / emg_weight allow control of relative artifact dominance.
    """
    combined_noise = eog_weight * eog + emg_weight * emg
    return mix_signals(eeg, combined_noise, target_snr_db)

def make_noisy_sample(
    noisy_signal,
    eeg_idx,
    eog_idx=None,
    emg_idx=None,
    snr_db=None,
    lambda_used=None
):
    return {
        "noisy": noisy_signal,      # np.ndarray (512,)
        "eeg_idx": eeg_idx,         # int
        "eog_idx": eog_idx,         # int or None
        "emg_idx": emg_idx,         # int or None
        "snr_db": snr_db,           # float
        "lambda": lambda_used       # float
    }

def sample_snr_uniform(low, high, rng):
    return rng.uniform(low, high)

def fixed_snr_list(snrs, rng):
    return rng.choice(snrs)



In [5]:
def mix_eeg_eog_paper(
    EEG,
    EOG,
    seed=42,
    mode="random"  # "exhaustive" | "random"
):
    rng = np.random.default_rng(seed)

    eeg_train_idx = np.arange(3000)
    eeg_test_idx  = np.arange(3000, 3400)

    snr_grid = np.array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2])

    train_samples = []
    test_samples = []

    if mode == "exhaustive":
        # ---- training: 10x uniform SNR ----
        for _ in range(10):
            perm = rng.permutation(eeg_train_idx)
            for i in perm:
                snr = sample_snr_uniform(-7, 2, rng)
                noisy, lam = noisy_eeg_eog(EEG[i], EOG[i], snr)

                train_samples.append(
                    make_noisy_sample(
                        noisy,
                        eeg_idx=i,
                        eog_idx=i,
                        snr_db=snr,
                        lambda_used=lam
                    )
                )

        # ---- testing: fixed SNR grid ----
        for snr in snr_grid:
            for i in eeg_test_idx:
                noisy, lam = noisy_eeg_eog(EEG[i], EOG[i], snr)

                test_samples.append(
                    make_noisy_sample(
                        noisy,
                        eeg_idx=i,
                        eog_idx=i,
                        snr_db=snr,
                        lambda_used=lam
                    )
                )

    elif mode == "random":
        # ---- training: single pass, random SNR from grid ----
        for i in eeg_train_idx:
            snr = rng.choice(snr_grid)
            noisy, lam = noisy_eeg_eog(EEG[i], EOG[i], snr)

            train_samples.append(
                make_noisy_sample(
                    noisy,
                    eeg_idx=i,
                    eog_idx=i,
                    snr_db=snr,
                    lambda_used=lam
                )
            )

        # ---- testing ----
        for i in eeg_test_idx:
            snr = rng.choice(snr_grid)
            noisy, lam = noisy_eeg_eog(EEG[i], EOG[i], snr)

            test_samples.append(
                make_noisy_sample(
                    noisy,
                    eeg_idx=i,
                    eog_idx=i,
                    snr_db=snr,
                    lambda_used=lam
                )
            )

    else:
        raise ValueError("mode must be 'exhaustive' or 'random'")

    return train_samples, test_samples


In [6]:
def mix_eeg_emg_paper(
    EEG,
    EMG,
    seed=42,
    mode="random"  # "exhaustive" | "random"
):
    rng = np.random.default_rng(seed)

    eeg_indices = rng.choice(len(EEG), size=len(EMG), replace=True)
    emg_indices = np.arange(len(EMG))

    pairs = list(zip(eeg_indices, emg_indices))
    rng.shuffle(pairs)

    train_pairs = pairs[:5000]
    test_pairs  = pairs[5000:5598]

    snr_grid = np.array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2])

    train_samples = []
    test_samples = []

    if mode == "exhaustive":
        # ---- training: 10x ----
        for _ in range(10):
            for eeg_i, emg_i in train_pairs:
                snr = sample_snr_uniform(-7, 2, rng)
                noisy, lam = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], snr)

                train_samples.append(
                    make_noisy_sample(
                        noisy,
                        eeg_idx=eeg_i,
                        emg_idx=emg_i,
                        snr_db=snr,
                        lambda_used=lam
                    )
                )

        # ---- testing: fixed grid ----
        for snr in snr_grid:
            for eeg_i, emg_i in test_pairs:
                noisy, lam = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], snr)

                test_samples.append(
                    make_noisy_sample(
                        noisy,
                        eeg_idx=eeg_i,
                        emg_idx=emg_i,
                        snr_db=snr,
                        lambda_used=lam
                    )
                )

    elif mode == "random":
        # ---- training ----
        for eeg_i, emg_i in train_pairs:
            snr = rng.choice(snr_grid)
            noisy, lam = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], snr)

            train_samples.append(
                make_noisy_sample(
                    noisy,
                    eeg_idx=eeg_i,
                    emg_idx=emg_i,
                    snr_db=snr,
                    lambda_used=lam
                )
            )

        # ---- testing ----
        for eeg_i, emg_i in test_pairs:
            snr = rng.choice(snr_grid)
            noisy, lam = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], snr)

            test_samples.append(
                make_noisy_sample(
                    noisy,
                    eeg_idx=eeg_i,
                    emg_idx=emg_i,
                    snr_db=snr,
                    lambda_used=lam
                )
            )

    else:
        raise ValueError("mode must be 'exhaustive' or 'random'")

    return train_samples, test_samples


In [7]:
def mix_custom(
    EEG,
    EOG=None,
    EMG=None,
    n_train=10000,
    n_test=2000,
    snr_range=(-7, 2),
    seed=42,
    eog_weight=1.0,
    emg_weight=1.0
):
    rng = np.random.default_rng(seed)

    samples = []

    def make_one():
        eeg_i = rng.integers(len(EEG))
        snr = rng.uniform(*snr_range)

        eog_i = None
        emg_i = None

        if EOG is not None and EMG is not None:
            eog_i = rng.integers(len(EOG))
            emg_i = rng.integers(len(EMG))
            noisy, lam = noisy_eeg_eog_emg(
                EEG[eeg_i],
                EOG[eog_i],
                EMG[emg_i],
                snr,
                eog_weight=eog_weight,
                emg_weight=emg_weight
            )

        elif EOG is not None:
            eog_i = rng.integers(len(EOG))
            noisy, lam = noisy_eeg_eog(EEG[eeg_i], EOG[eog_i], snr)

        elif EMG is not None:
            emg_i = rng.integers(len(EMG))
            noisy, lam = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], snr)

        else:
            raise ValueError("At least one of EOG or EMG must be provided.")

        return make_noisy_sample(
            noisy,
            eeg_idx=eeg_i,
            eog_idx=eog_i,
            emg_idx=emg_i,
            snr_db=snr,
            lambda_used=lam
        )

    for _ in range(n_train + n_test):
        samples.append(make_one())

    return samples[:n_train], samples[n_train:]

In [8]:
def normalize_noisy_sample(sample, EEG):
    """
    Normalize one noisy EEG sample according to EEGdenoiseNet protocol.
    
    Args:
        sample: dict produced by make_noisy_sample
        EEG: clean EEG array (for ground truth lookup)
    
    Returns:
        normalized_sample (new dict)
    """
    y = sample["noisy"]
    x = EEG[sample["eeg_idx"]]

    sigma_y = np.std(y)
    if sigma_y == 0:
        raise ValueError("Standard deviation of noisy signal is zero.")

    normalized_sample = sample.copy()
    normalized_sample.update({
        "noisy_norm": y / sigma_y,     # ŷ
        "clean_norm": x / sigma_y,     # x̂
        "sigma_y": sigma_y             # stored for rescaling later
    })

    return normalized_sample

def normalize_dataset(samples, EEG):
    """
    Normalize a list of noisy EEG samples.
    
    Args:
        samples: list of noisy sample dicts
        EEG: clean EEG array
    
    Returns:
        list of normalized sample dicts
    """
    return [normalize_noisy_sample(s, EEG) for s in samples]


# Data Config

In [9]:
# 1 = EEG + EOG Paper
# 2 = EEG + EMG Paper
# 3 = Custom EEG + EOG
# 4 = Custom EEG + EMG
# else Custom EEG + EOG + EMG
DATA_MODE = 2

In [10]:
if DATA_MODE == 1: # EEG + EOG Paper
    # --- generate noisy datasets (raw, unnormalized) ---
    train_samples, test_samples = mix_eeg_eog_paper(
        EEG=EEG,
        EOG=EOG,
        seed=42
    )

    # --- normalize according to EEGdenoiseNet protocol ---
    train_samples_norm = normalize_dataset(train_samples, EEG)
    test_samples_norm  = normalize_dataset(test_samples, EEG)
elif DATA_MODE == 2: # EEG + EMG Paper
    # --- generate noisy datasets (raw, unnormalized) ---
    train_samples, test_samples = mix_eeg_emg_paper(
        EEG=EEG,
        EMG=EMG,
        seed=42
    )

    # --- normalize ---
    train_samples_norm = normalize_dataset(train_samples, EEG)
    test_samples_norm  = normalize_dataset(test_samples, EEG)
elif DATA_MODE == 3:
    train_samples, test_samples = mix_custom(
        EEG=EEG,
        EOG=EOG,
        n_train=10000,
        n_test=2000,
        snr_range=(-7, 2),
        seed=42
    )

    train_samples_norm = normalize_dataset(train_samples, EEG)
    test_samples_norm  = normalize_dataset(test_samples, EEG)
elif DATA_MODE == 4:
    train_samples, test_samples = mix_custom(
        EEG=EEG,
        EMG=EMG,
        n_train=10000,
        n_test=2000,
        snr_range=(-7, 2),
        seed=42
    )

    train_samples_norm = normalize_dataset(train_samples, EEG)
    test_samples_norm  = normalize_dataset(test_samples, EEG)

else:
    train_samples, test_samples = mix_custom(
        EEG=EEG,
        EOG=EOG,
        EMG=EMG,
        n_train=10000,
        n_test=2000,
        snr_range=(-7, 2),
        seed=42,
        eog_weight=1.0,
        emg_weight=1.0
    )

    train_samples_norm = normalize_dataset(train_samples, EEG)
    test_samples_norm  = normalize_dataset(test_samples, EEG)


# Model Definition

In [11]:
class ConvBlock1D(nn.Module):
    """
    Conv1d -> BatchNorm1d -> Activation
    """
    def __init__(self, in_ch, out_ch, k=15, s=2, p=7, bias=True, act="lrelu"):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=bias)
        self.norm = nn.BatchNorm1d(out_ch)

        if act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        elif act == "relu":
            self.act = nn.ReLU(inplace=True)
        else:
            raise ValueError("act must be 'lrelu' or 'relu'")

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


class DeconvBlock1D(nn.Module):
    """
    ConvTranspose1d -> BatchNorm1d -> Activation
    """
    def __init__(self, in_ch, out_ch, k=4, s=2, p=1, bias=True, act="relu"):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=bias)
        self.norm = nn.BatchNorm1d(out_ch)

        if act == "relu":
            self.act = nn.ReLU(inplace=True)
        elif act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            raise ValueError("act must be 'relu' or 'lrelu'")

    def forward(self, x):
        return self.act(self.norm(self.deconv(x)))


class ResBlock1D(nn.Module):
    """
    Residual block: (Conv -> BN -> ReLU) x2 + skip
    Keeps same channel count and length.
    """
    def __init__(self, ch, k=7, p=3, bias=True):
        super().__init__()
        self.c1 = nn.Conv1d(ch, ch, kernel_size=k, stride=1, padding=p, bias=bias)
        self.n1 = nn.BatchNorm1d(ch)
        self.c2 = nn.Conv1d(ch, ch, kernel_size=k, stride=1, padding=p, bias=bias)
        self.n2 = nn.BatchNorm1d(ch)

    def forward(self, x):
        h = F.relu(self.n1(self.c1(x)))
        h = self.n2(self.c2(h))
        return F.relu(x + h)
    
class MultiScaleResBlock1D(nn.Module):
    """
    Multi-scale residual block: parallel conv branches (k=3,5,7) then fuse.
    Keeps same channel count and length.
    """
    def __init__(self, ch, bias=True):
        super().__init__()

        self.b3 = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=3, padding=1, bias=bias),
            nn.BatchNorm1d(ch),
            nn.ReLU(inplace=True),
        )
        self.b5 = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=5, padding=2, bias=bias),
            nn.BatchNorm1d(ch),
            nn.ReLU(inplace=True),
        )
        self.b7 = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=7, padding=3, bias=bias),
            nn.BatchNorm1d(ch),
            nn.ReLU(inplace=True),
        )

        self.fuse = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=1, bias=bias),
            nn.BatchNorm1d(ch),
        )

    def forward(self, x):
        h = self.b3(x) + self.b5(x) + self.b7(x)
        h = self.fuse(h)
        return F.relu(x + h)



In [12]:
# NN Generator (U-Net-ish + Res bottleneck)
class GeneratorCNNWGAN(nn.Module):
    """
    CNN U-Net-ish generator for EEG denoising (WGAN).
    Input : (B, 1, 512) noisy_norm
    Output: (B, 1, 512) clean_norm_hat
    """
    def __init__(self, base_ch=32, bottleneck_blocks=4, bias=True):
        super().__init__()

        # Encoder
        self.e1 = ConvBlock1D(1, base_ch,       k=16, s=2, p=7, bias=bias, act="lrelu")      # 512 -> 256
        self.e2 = ConvBlock1D(base_ch, base_ch*2, k=16, s=2, p=7, bias=bias, act="lrelu")    # 256 -> 128
        self.e3 = ConvBlock1D(base_ch*2, base_ch*4, k=16, s=2, p=7, bias=bias, act="lrelu")  # 128 -> 64
        self.e4 = ConvBlock1D(base_ch*4, base_ch*8, k=16, s=2, p=7, bias=bias, act="lrelu")  # 64 -> 32

        # Bottleneck
        bn_ch = base_ch * 8
        self.bottleneck = nn.Sequential(*[
            ResBlock1D(bn_ch, k=7, p=3, bias=bias) for _ in range(bottleneck_blocks)
        ])

        # Decoder (concat doubles channels)
        self.d1 = DeconvBlock1D(bn_ch, base_ch*4,   k=4, s=2, p=1, bias=bias, act="relu")     # 32 -> 64
        self.d2 = DeconvBlock1D(base_ch*8, base_ch*2, k=4, s=2, p=1,bias=bias, act="relu")   # 64 -> 128
        self.d3 = DeconvBlock1D(base_ch*4, base_ch,   k=4, s=2, p=1, bias=bias, act="relu")   # 128 -> 256
        self.d4 = DeconvBlock1D(base_ch*2, base_ch//2, k=4, s=2, p=1, bias=bias, act="relu")  # 256 -> 512

        # Head (linear output recommended for normalized signals)
        self.out = nn.Conv1d(base_ch//2, 1, kernel_size=7, stride=1, padding=3, bias=bias)

    def forward(self, y):
        # Encoder
        s1 = self.e1(y)   # (B, base, 256)
        s2 = self.e2(s1)  # (B, 2b, 128)
        s3 = self.e3(s2)  # (B, 4b, 64)
        s4 = self.e4(s3)  # (B, 8b, 32)

        # Bottleneck
        b = self.bottleneck(s4)

        # Decoder + skip connections
        d1 = self.d1(b)                  # (B, 4b, 64)
        d1 = torch.cat([d1, s3], dim=1)  # (B, 8b, 64)

        d2 = self.d2(d1)                 # (B, 2b, 128)
        d2 = torch.cat([d2, s2], dim=1)  # (B, 4b, 128)

        d3 = self.d3(d2)                 # (B, b, 256)
        d3 = torch.cat([d3, s1], dim=1)  # (B, 2b, 256)

        d4 = self.d4(d3)                 # (B, b/2, 512)

        return self.out(d4)              # (B, 1, 512)

In [13]:
# Patch Critic (shared by CNN/ResCNN)
class CriticPatch1D(nn.Module):
    """
    Conditional PatchGAN critic for WGAN:
      D(y, x) -> patch scores
    y,x: (B,1,512)
    output: (B,1,32)
    """
    def __init__(self, base_ch=32, bias=True):
        super().__init__()
        self.c1 = nn.Conv1d(2, base_ch, kernel_size=16, stride=2, padding=7, bias=bias)  # 512 -> 256
        self.c2 = ConvBlock1D(base_ch, base_ch*2, k=16, s=2, p=7, bias=bias, act="lrelu")    # 256 -> 128
        self.c3 = ConvBlock1D(base_ch*2, base_ch*4, k=16, s=2, p=7, bias=bias, act="lrelu")  # 128 -> 64
        self.c4 = ConvBlock1D(base_ch*4, base_ch*8, k=16, s=2, p=7, bias=bias, act="lrelu")  # 64 -> 32
        self.out = nn.Conv1d(base_ch*8, 1, kernel_size=7, stride=1, padding=3, bias=bias)   # 32 -> 32

    def forward(self, y, x):
        h = torch.cat([y, x], dim=1)  # (B,2,512)
        h = F.leaky_relu(self.c1(h), 0.2, inplace=True)
        h = self.c2(h)
        h = self.c3(h)
        h = self.c4(h)
        return self.out(h)


# Model Import

In [14]:
BIAS = True

In [15]:
data_path = os.path.abspath(f"../models/main3_d{DATA_MODE}_{"b" if BIAS else "nb"}/")
print(data_path)

pattern = re.compile(r"^cnn_([DG])_\d{8}_\d{6}\.pth$")

cnn_G_path = None
cnn_D_path = None

for f in os.listdir(data_path):
    m = pattern.match(f)
    if m:
        full = os.path.join(data_path, f)
        if m.group(1) == "G":
            cnn_G_path = full
        else:
            cnn_D_path = full

print("G:", cnn_G_path)
print("D:", cnn_D_path)

c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d2_b
G: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d2_b\cnn_G_20260113_175122.pth
D: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d2_b\cnn_D_20260113_175122.pth


In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = GeneratorCNNWGAN(bias=BIAS).to(device)
D = CriticPatch1D(bias=BIAS).to(device)

G.load_state_dict(torch.load(cnn_G_path, map_location=device))
D.load_state_dict(torch.load(cnn_D_path, map_location=device))

print(G.eval())
print(D.eval())


GeneratorCNNWGAN(
  (e1): ConvBlock1D(
    (conv): Conv1d(1, 32, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e2): ConvBlock1D(
    (conv): Conv1d(32, 64, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e3): ConvBlock1D(
    (conv): Conv1d(64, 128, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e4): ConvBlock1D(
    (conv): Conv1d(128, 256, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  

  G.load_state_dict(torch.load(cnn_G_path, map_location=device))
  D.load_state_dict(torch.load(cnn_D_path, map_location=device))


In [17]:
for name, module in G.named_modules():
    print(name, "->", module.__class__.__name__)

print("\n" + "=" * 80 +"\n")

for name, module in D.named_modules():
    print(name, "->", module.__class__.__name__)

print("\n" + "=" * 80 +"\n")

for name, param in G.named_parameters():
    print(name, param.shape, param.requires_grad)

print("\n" + "=" * 80 +"\n")

for name, param in D.named_parameters():
    print(name, param.shape, param.requires_grad)

print("\n" + "=" * 80 +"\n")

for name, buf in G.named_buffers():
    print(name, buf.shape)

print("\n" + "=" * 80 +"\n")

for name, buf in D.named_buffers():
    print(name, buf.shape)


 -> GeneratorCNNWGAN
e1 -> ConvBlock1D
e1.conv -> Conv1d
e1.norm -> BatchNorm1d
e1.act -> LeakyReLU
e2 -> ConvBlock1D
e2.conv -> Conv1d
e2.norm -> BatchNorm1d
e2.act -> LeakyReLU
e3 -> ConvBlock1D
e3.conv -> Conv1d
e3.norm -> BatchNorm1d
e3.act -> LeakyReLU
e4 -> ConvBlock1D
e4.conv -> Conv1d
e4.norm -> BatchNorm1d
e4.act -> LeakyReLU
bottleneck -> Sequential
bottleneck.0 -> ResBlock1D
bottleneck.0.c1 -> Conv1d
bottleneck.0.n1 -> BatchNorm1d
bottleneck.0.c2 -> Conv1d
bottleneck.0.n2 -> BatchNorm1d
bottleneck.1 -> ResBlock1D
bottleneck.1.c1 -> Conv1d
bottleneck.1.n1 -> BatchNorm1d
bottleneck.1.c2 -> Conv1d
bottleneck.1.n2 -> BatchNorm1d
bottleneck.2 -> ResBlock1D
bottleneck.2.c1 -> Conv1d
bottleneck.2.n1 -> BatchNorm1d
bottleneck.2.c2 -> Conv1d
bottleneck.2.n2 -> BatchNorm1d
bottleneck.3 -> ResBlock1D
bottleneck.3.c1 -> Conv1d
bottleneck.3.n1 -> BatchNorm1d
bottleneck.3.c2 -> Conv1d
bottleneck.3.n2 -> BatchNorm1d
d1 -> DeconvBlock1D
d1.deconv -> ConvTranspose1d
d1.norm -> BatchNorm1d
d1

# Fuse Weigths

In [18]:
def fuse_conv_bn_1d(
    conv_weight,        # (Cout, Cin, K)
    conv_bias,          # (Cout,) or None
    running_mean,       # (Cout,)
    running_var,        # (Cout,)
    bn_weight,          # (Cout,) or None (gamma)
    bn_bias,            # (Cout,) or None (beta)
    eps=1e-5
):
    Cout = conv_weight.shape[0]

    if bn_weight is None:
        bn_weight = torch.ones(Cout, device=conv_weight.device, dtype=conv_weight.dtype)
    if bn_bias is None:
        bn_bias = torch.zeros(Cout, device=conv_weight.device, dtype=conv_weight.dtype)
    if conv_bias is None:
        conv_bias = torch.zeros(Cout, device=conv_weight.device, dtype=conv_weight.dtype)

    denom = torch.sqrt(running_var + eps)          # (Cout,)
    scale = bn_weight / denom                      # (Cout,)

    # Fuse weight
    fused_weight = conv_weight * scale[:, None, None]

    # Fuse bias
    fused_bias = (conv_bias - running_mean) * scale + bn_bias

    return fused_weight, fused_bias

def fuse_deconv_bn_1d(
    deconv_weight,     # (Cin, Cout, K)
    deconv_bias,       # (Cout,) or None
    running_mean,      # (Cout,)
    running_var,       # (Cout,)
    bn_weight,         # (Cout,)
    bn_bias,           # (Cout,)
    eps
):
    Cin, Cout, K = deconv_weight.shape

    if deconv_bias is None:
        deconv_bias = torch.zeros(
            Cout,
            device=deconv_weight.device,
            dtype=deconv_weight.dtype
        )

    # BN scale
    scale = bn_weight / torch.sqrt(running_var + eps)  # (Cout,)

    # Fuse weights (scale on Cout dimension)
    fused_weight = deconv_weight * scale.view(1, Cout, 1)

    # Fuse bias
    fused_bias = (deconv_bias - running_mean) * scale + bn_bias

    return fused_weight, fused_bias


In [19]:
class FusedConvBlock1D(nn.Module):
    def __init__(self, in_ch, out_ch, k, s, p, bias=True, act="lrelu"):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, k, s, p, bias=bias)

        if act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        elif act == "relu":
            self.act = nn.ReLU(inplace=True)
        else:
            raise ValueError

    def forward(self, x):
        return self.act(self.conv(x))


class FusedDeconvBlock1D(nn.Module):
    def __init__(self, in_ch, out_ch, k, s, p, bias=True, act="relu"):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(in_ch, out_ch, k, s, p, bias=bias)

        if act == "relu":
            self.act = nn.ReLU(inplace=True)
        elif act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            raise ValueError

    def forward(self, x):
        return self.act(self.deconv(x))


In [20]:
class GeneratorCNNWGAN_Fused(nn.Module):
    def __init__(self, base_ch=32, bottleneck_blocks=4, bias=True):
        super().__init__()

        self.e1 = FusedConvBlock1D(1, base_ch, 16, 2, 7, bias, "lrelu")
        self.e2 = FusedConvBlock1D(base_ch, base_ch*2, 16, 2, 7, bias, "lrelu")
        self.e3 = FusedConvBlock1D(base_ch*2, base_ch*4, 16, 2, 7, bias, "lrelu")
        self.e4 = FusedConvBlock1D(base_ch*4, base_ch*8, 16, 2, 7, bias, "lrelu")

        self.bottleneck = nn.Sequential(*[
            nn.Sequential(
                nn.Conv1d(base_ch*8, base_ch*8, 7, 1, 3, bias=bias),
                nn.ReLU(inplace=True),
                nn.Conv1d(base_ch*8, base_ch*8, 7, 1, 3, bias=bias),
            )
            for _ in range(bottleneck_blocks)
        ])

        self.d1 = FusedDeconvBlock1D(base_ch*8, base_ch*4, 4, 2, 1, bias)
        self.d2 = FusedDeconvBlock1D(base_ch*8, base_ch*2, 4, 2, 1, bias)
        self.d3 = FusedDeconvBlock1D(base_ch*4, base_ch, 4, 2, 1, bias)
        self.d4 = FusedDeconvBlock1D(base_ch*2, base_ch//2, 4, 2, 1, bias)

        self.out = nn.Conv1d(base_ch//2, 1, 7, 1, 3, bias=bias)

    def forward(self, y):
        s1 = self.e1(y)
        s2 = self.e2(s1)
        s3 = self.e3(s2)
        s4 = self.e4(s3)

        b = s4
        for blk in self.bottleneck:
            b = F.relu(b + blk(b))

        d1 = torch.cat([self.d1(b), s3], dim=1)
        d2 = torch.cat([self.d2(d1), s2], dim=1)
        d3 = torch.cat([self.d3(d2), s1], dim=1)
        d4 = self.d4(d3)

        return self.out(d4)


In [21]:
def fuse_generator(G):
    G_fused = GeneratorCNNWGAN_Fused(bias=True).to(device)
    G_fused.eval()

    with torch.no_grad():
        # Encoder
        for i in range(1, 5):
            e = getattr(G, f"e{i}")
            fe = getattr(G_fused, f"e{i}")

            w, b = fuse_conv_bn_1d(
                e.conv.weight, e.conv.bias,
                e.norm.running_mean, e.norm.running_var,
                e.norm.weight, e.norm.bias
            )
            fe.conv.weight.copy_(w)
            fe.conv.bias.copy_(b)

        # Bottleneck
        for i, blk in enumerate(G.bottleneck):
            fblk = G_fused.bottleneck[i]

            for j, (c, n) in enumerate([(blk.c1, blk.n1), (blk.c2, blk.n2)]):
                w, b = fuse_conv_bn_1d(
                    c.weight, c.bias,
                    n.running_mean, n.running_var,
                    n.weight, n.bias
                )
                fblk[j*2].weight.copy_(w)
                fblk[j*2].bias.copy_(b)

        # Decoder
        for i in range(1, 5):
            d = getattr(G, f"d{i}")
            fd = getattr(G_fused, f"d{i}")

            w, b = fuse_deconv_bn_1d(
                d.deconv.weight, d.deconv.bias,
                d.norm.running_mean, d.norm.running_var,
                d.norm.weight, d.norm.bias,
                d.norm.eps
            )
            fd.deconv.weight.copy_(w)
            fd.deconv.bias.copy_(b)

        # Output
        G_fused.out.weight.copy_(G.out.weight)
        G_fused.out.bias.copy_(G.out.bias)

    return G_fused


In [22]:
G_fused = fuse_generator(G)

for name, module in G_fused.named_modules():
    print(name, "->", module.__class__.__name__)

 -> GeneratorCNNWGAN_Fused
e1 -> FusedConvBlock1D
e1.conv -> Conv1d
e1.act -> LeakyReLU
e2 -> FusedConvBlock1D
e2.conv -> Conv1d
e2.act -> LeakyReLU
e3 -> FusedConvBlock1D
e3.conv -> Conv1d
e3.act -> LeakyReLU
e4 -> FusedConvBlock1D
e4.conv -> Conv1d
e4.act -> LeakyReLU
bottleneck -> Sequential
bottleneck.0 -> Sequential
bottleneck.0.0 -> Conv1d
bottleneck.0.1 -> ReLU
bottleneck.0.2 -> Conv1d
bottleneck.1 -> Sequential
bottleneck.1.0 -> Conv1d
bottleneck.1.1 -> ReLU
bottleneck.1.2 -> Conv1d
bottleneck.2 -> Sequential
bottleneck.2.0 -> Conv1d
bottleneck.2.1 -> ReLU
bottleneck.2.2 -> Conv1d
bottleneck.3 -> Sequential
bottleneck.3.0 -> Conv1d
bottleneck.3.1 -> ReLU
bottleneck.3.2 -> Conv1d
d1 -> FusedDeconvBlock1D
d1.deconv -> ConvTranspose1d
d1.act -> ReLU
d2 -> FusedDeconvBlock1D
d2.deconv -> ConvTranspose1d
d2.act -> ReLU
d3 -> FusedDeconvBlock1D
d3.deconv -> ConvTranspose1d
d3.act -> ReLU
d4 -> FusedDeconvBlock1D
d4.deconv -> ConvTranspose1d
d4.act -> ReLU
out -> Conv1d


# Test Data

In [23]:
sample = test_samples_norm[0]
sample_scale = max(abs(sample["noisy_norm"].min()), abs(sample["noisy_norm"].max()))
x = torch.tensor(sample["noisy_norm"], dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

print("Len:", len(test_samples_norm[0]["noisy_norm"]))
print("Max:", sample["noisy_norm"].max())
print("Min:", sample["noisy_norm"].min())
print("Scale:", sample_scale)

Len: 512
Max: 5.067047251915556
Min: -4.0315014453972156
Scale: 5.067047251915556


In [24]:
all_vals = np.concatenate(
    [s["noisy_norm"] for s in test_samples_norm]
)

print("GLOBAL min:", all_vals.min())
print("GLOBAL max:", all_vals.max())


GLOBAL min: -10.46939142750513
GLOBAL max: 14.82895621443052


# Data and Weight Modification

In [25]:
# All 32 bit float (torch)
print(x.dtype)
print(G_fused.e1.conv.weight.dtype)
print(G_fused.e1.conv.bias.dtype)
print(G_fused.d1.deconv.weight.dtype)
print(G_fused.d1.deconv.bias.dtype)

torch.float32
torch.float32
torch.float32
torch.float32
torch.float32


In [26]:
def scan_weight_bias_ranges(model):
    w_min, w_max = float("inf"), float("-inf")
    b_min, b_max = float("inf"), float("-inf")

    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
            if module.weight is not None:
                w_min = min(w_min, module.weight.min().item())
                w_max = max(w_max, module.weight.max().item())

            if module.bias is not None:
                b_min = min(b_min, module.bias.min().item())
                b_max = max(b_max, module.bias.max().item())

    return w_min, w_max, b_min, b_max

w_min, w_max, b_min, b_max = scan_weight_bias_ranges(G_fused)

print("=== G_fused parameter ranges ===")
print(f"Weights: min = {w_min:.6e}, max = {w_max:.6e}")
print(f"Biases : min = {b_min:.6e}, max = {b_max:.6e}")

=== G_fused parameter ranges ===
Weights: min = -5.299317e-01, max = 5.558594e-01
Biases : min = -2.609087e+00, max = 2.154806e+00


In [None]:
def float_to_q1_15(x):
    # Q1.15 (int16)
    # Raw int range   : [-32768, 32767]
    # Real value range: [-1.0, +0.999969482421875]
    scale = 1 << 15
    x = np.round(x * scale)
    x = np.clip(x, -32768, 32767) 
    return x.astype(np.int16)

def float_to_q4_12(x):
    # Q4.12 (int16)
    # Raw int range   : [-32768, 32767]
    # Real value range: [-8.0, +7.999755859375]
    scale = 1 << 12
    x = np.round(x * scale)
    x = np.clip(x, -32768, 32767) 
    return x.astype(np.int16)

def float_to_q10_10(x):
    # Q10.10 (int32, 20-bit signed)
    # Raw int range   : [-524288, 524287]
    # Real value range: [-512.0, +511.9990234375]
    scale = 1 << 10
    x = np.round(x * scale)
    x = np.clip(x, -524288,  524287)
    return x.astype(np.int32)

def float_to_q(x, frac_bits, int_bits, dtype):
    scale = 1 << frac_bits
    total_bits = int_bits + frac_bits
    min_val = -(1 << (total_bits - 1))
    max_val = (1 << (total_bits - 1)) - 1

    xq = np.round(x * scale)
    xq = np.clip(xq, min_val, max_val)
    return xq.astype(dtype)

def q_to_float(x, frac_bits):
    return x.astype(np.float32) / (1 << frac_bits)

# Modification Helper

In [28]:
Q_CONFIGS = {
    "Q4.12": dict(frac_bits=12, int_bits=4, dtype=np.int16),
    "Q10.10": dict(frac_bits=10, int_bits=10, dtype=np.int32),
}

This:
- walks Conv1d + ConvTranspose1d
- quantizes weights and bias
- writes them back as float32 dequantized
- prints fixed-point min/max per layer

In [29]:
# For each layer
# FP32 weights
#     ↓ quantize (round + clip)
# INT representation (Q format)
#     ↓ dequantize
# FP32 tensor with quantization error

def quantize_model_fixed_point(model, qcfg, verbose=True):
    model_q = deepcopy(model)

    frac = qcfg["frac_bits"]
    itg  = qcfg["int_bits"]
    dt   = qcfg["dtype"]

    scale = 1 << frac
    
    print(f"\n=== Quantizing model to Q{itg}.{frac} ===")

    for name, module in model_q.named_modules():
        if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):

            # ---- weights ----
            w = module.weight.detach().cpu().numpy()
            w_q = float_to_q(w, frac, itg, dt)
            w_f = q_to_float(w_q, frac)

            module.weight.data = torch.from_numpy(w_f).to(module.weight.device)

            # ---- bias ----
            if module.bias is not None:
                b = module.bias.detach().cpu().numpy()
                b_q = float_to_q(b, frac, itg, dt)
                b_f = q_to_float(b_q, frac)
                module.bias.data = torch.from_numpy(b_f).to(module.bias.device)

            if verbose:
                wq_min, wq_max = w_q.min(), w_q.max()
                print(f"[{name}]")
                print(
                    f"  W_q min/max : {wq_min:>8} {wq_max:>8} "
                    f"(≈ {wq_min/scale:+.6f}, {wq_max/scale:+.6f})"
                )

                if module.bias is not None:
                    bq_min, bq_max = b_q.min(), b_q.max()
                    print(
                        f"  B_q min/max : {bq_min:>8} {bq_max:>8} "
                        f"(≈ {bq_min/scale:+.6f}, {bq_max/scale:+.6f})"
                    )


    return model_q


In [30]:
def quantize_input_fixed_point(x, qcfg):
    frac = qcfg["frac_bits"]
    itg  = qcfg["int_bits"]
    dt   = qcfg["dtype"]

    x_np = x.detach().cpu().numpy()
    x_q = float_to_q(x_np, frac, itg, dt)
    x_f = q_to_float(x_q, frac)

    print("\n=== Input quantization ===")
    print(f"x_q min/max: {x_q.min()} {x_q.max()}")

    return torch.from_numpy(x_f).to(x.device)


# Reference Evaluation

In [31]:
def rrmse(x_hat, x):
    return torch.sqrt(torch.mean((x_hat - x) ** 2)) / torch.sqrt(torch.mean(x ** 2))


def corrcoef(x_hat, x):
    x_hat = x_hat - x_hat.mean()
    x = x - x.mean()
    return torch.sum(x_hat * x) / (
        torch.sqrt(torch.sum(x_hat ** 2)) * torch.sqrt(torch.sum(x ** 2)) + 1e-8
    )


In [32]:
G.eval()
G_fused.eval()

with torch.no_grad():
    y_ref = G(x)
    y_fused = G_fused(x)

    assert y_ref.shape == y_fused.shape

    max_abs_diff = torch.max(torch.abs(y_ref - y_fused)).item()
    mean_abs_diff = torch.mean(torch.abs(y_ref - y_fused)).item()
    rrmse_val = rrmse(y_fused, y_ref).item()
    corr_val = corrcoef(y_fused, y_ref).item()

    print("=== G vs G_fused ===")
    print(f"Max abs diff : {max_abs_diff:.6e}")
    print(f"Mean abs diff: {mean_abs_diff:.6e}")
    print(f"RRMSE        : {rrmse_val:.6e}")
    print(f"Corr coef    : {corr_val:.6f}")

=== G vs G_fused ===
Max abs diff : 1.086950e-03
Mean abs diff: 2.378545e-04
RRMSE        : 5.670033e-04
Corr coef    : 1.000000


# Evaluation

In [33]:
def run_and_compare(G_fp32, G_q, x_fp32, x_q):
    G_fp32.eval()
    G_q.eval()

    with torch.no_grad():
        y_ref = G_fp32(x_fp32)
        y_q   = G_q(x_q)

        max_abs = torch.max(torch.abs(y_ref - y_q)).item()
        mean_abs = torch.mean(torch.abs(y_ref - y_q)).item()
        rrmse_v = rrmse(y_q, y_ref).item()
        corr_v  = corrcoef(y_q, y_ref).item()

    print("\n=== Fixed-point vs FP32 ===")
    print(f"Max abs diff : {max_abs:.6e}")
    print(f"Mean abs diff: {mean_abs:.6e}")
    print(f"RRMSE        : {rrmse_v:.6e}")
    print(f"Corr coef    : {corr_v:.6f}")

    return y_q

def print_qcfg_limits(qcfg):
    itg  = qcfg["int_bits"]   # includes sign
    frac = qcfg["frac_bits"]
    dt   = qcfg["dtype"]

    total_bits = np.dtype(dt).itemsize * 8
    scale = 1 << frac

    # signed integer limits from dtype
    int_min = -(1 << (total_bits - 1))
    int_max =  (1 << (total_bits - 1)) - 1

    real_min = int_min / scale
    real_max = int_max / scale

    # sanity: m = total_bits - frac
    # assert itg == total_bits - frac, \
    #     f"Mismatch: Q{itg}.{frac} but dtype implies Q{total_bits-frac}.{frac}"

    print("=== Fixed-point format ===")
    print(f"Q{itg}.{frac} ({dt.__name__})")
    print(f"Total bits         : {total_bits}")
    print(f"Scale factor       : 2^{frac} = {scale}")
    print(f"Integer range      : [{int_min}, {int_max}]")
    print(f"Real value range   : [{real_min:.6f}, {real_max:.6f}]")
    print()

def attach_activation_range_hooks(model, qcfg):
    hooks = []

    itg  = qcfg["int_bits"]   # includes sign
    frac = qcfg["frac_bits"]

    real_min = -(1 << (itg - 1))
    real_max =  (1 << (itg - 1)) - (1 / (1 << frac))

    def hook_fn(name):
        def fn(module, inp, out):
            out_min = out.min().item()
            out_max = out.max().item()

            print(f"[{name}] output min/max: "
                  f"{out_min:+.6f}, {out_max:+.6f}", end="")

            if out_min < real_min or out_max > real_max:
                print("  ⚠️  OUT OF RANGE")
            else:
                print("  ✓")

        return fn

    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
            hooks.append(module.register_forward_hook(hook_fn(name)))

    return hooks



In [34]:
# "Q4.12" or "Q10.10"
qcfg = Q_CONFIGS["Q10.10"]

print_qcfg_limits(qcfg)

G_q = quantize_model_fixed_point(G_fused, qcfg, verbose=False)
x_q = quantize_input_fixed_point(x, qcfg)
hooks = attach_activation_range_hooks(G_q, qcfg)

y_q = run_and_compare(G_fused, G_q, x, x_q)

for h in hooks:
    h.remove()

=== Fixed-point format ===
Q10.10 (int32)
Total bits         : 32
Scale factor       : 2^10 = 1024
Integer range      : [-2147483648, 2147483647]
Real value range   : [-2097152.000000, 2097151.999023]


=== Quantizing model to Q10.10 ===

=== Input quantization ===
x_q min/max: -4128 5189
[e1.conv] output min/max: -5.064058, +5.926579  ✓
[e2.conv] output min/max: -5.975633, +4.874141  ✓
[e3.conv] output min/max: -4.102404, +3.927611  ✓
[e4.conv] output min/max: -3.861659, +3.760465  ✓
[bottleneck.0.0] output min/max: -3.759048, +3.915655  ✓
[bottleneck.0.2] output min/max: -3.849937, +3.924349  ✓
[bottleneck.1.0] output min/max: -3.774267, +4.282374  ✓
[bottleneck.1.2] output min/max: -4.456840, +3.892466  ✓
[bottleneck.2.0] output min/max: -4.465076, +3.977906  ✓
[bottleneck.2.2] output min/max: -3.880110, +4.034831  ✓
[bottleneck.3.0] output min/max: -3.649748, +3.954165  ✓
[bottleneck.3.2] output min/max: -4.985993, +3.991021  ✓
[d1.deconv] output min/max: -3.993756, +4.326971  ✓
[d