# Import Library

In [1]:
import os
import numpy as np
import pandas as pd
import scipy.io as sio
import matplotlib.pyplot as plt
import math
import json
import re

import torch
import torch.nn as nn
import torch.nn.functional as F

import random
np.random.seed(42)
random.seed(42)

# Data Import

In [2]:
ENV = "local" # or "local"

In [3]:
if ENV == "kaggle":
    EEG = np.load("/kaggle/input/eeg-denoise-net/EEG_all_epochs.npy")
    EOG = np.load("/kaggle/input/eeg-denoise-net/EOG_all_epochs.npy")
    EMG = np.load("/kaggle/input/eeg-denoise-net/EMG_all_epochs.npy")
else:
    data_path = os.path.abspath("../../data/eeg_denoise_net/")
    print(data_path)
    
    EEG = np.load(os.path.join(data_path, "EEG_all_epochs.npy"))
    EOG = np.load(os.path.join(data_path, "EOG_all_epochs.npy"))
    EMG = np.load(os.path.join(data_path, "EMG_all_epochs.npy"))

print("EEG shape:", EEG.shape)
print("EOG shape:", EOG.shape)
print("EMG shape:", EMG.shape)

c:\Users\Aryo\PersonalMade\Programming\GAN\repo\data\eeg_denoise_net
EEG shape: (4514, 512)
EOG shape: (3400, 512)
EMG shape: (5598, 512)


# Data Setup

In [4]:
def rms(x: np.ndarray) -> float:
    """
    Root Mean Square (RMS) of a 1D signal.
    Implements Formula (3).
    """
    x = np.asarray(x)
    return np.sqrt(np.mean(x ** 2))

def snr_db(clean: np.ndarray, noise: np.ndarray) -> float:
    """
    Compute SNR in decibels.
    Implements Formula (2).
    """
    return 10 * np.log10(rms(clean) / rms(noise))

def compute_lambda(clean: np.ndarray, noise: np.ndarray, target_snr_db: float) -> float:
    """
    Compute lambda such that the mixed signal has the desired SNR (in dB).
    """
    return rms(clean) / rms(noise) * 10 ** (-target_snr_db / 10)

In [5]:
def mix_signals(clean: np.ndarray, noise: np.ndarray, target_snr_db: float):
    """
    Mix clean signal with noise at a target SNR (dB).
    
    Returns:
        mixed_signal
        lambda_used
    """
    lam = compute_lambda(clean, noise, target_snr_db)
    mixed = clean + lam * noise
    return mixed, lam

def noisy_eeg_eog(eeg: np.ndarray, eog: np.ndarray, target_snr_db: float):
    """
    EEG contaminated by ocular artifacts (EOG).
    """
    return mix_signals(eeg, eog, target_snr_db)

def noisy_eeg_emg(eeg: np.ndarray, emg: np.ndarray, target_snr_db: float):
    """
    EEG contaminated by myogenic artifacts (EMG).
    """
    return mix_signals(eeg, emg, target_snr_db)

def noisy_eeg_eog_emg(
    eeg: np.ndarray,
    eog: np.ndarray,
    emg: np.ndarray,
    target_snr_db: float,
    eog_weight: float = 1.0,
    emg_weight: float = 1.0
):
    """
    EEG contaminated by both ocular (EOG) and myogenic (EMG) artifacts.
    
    eog_weight / emg_weight allow control of relative artifact dominance.
    """
    combined_noise = eog_weight * eog + emg_weight * emg
    return mix_signals(eeg, combined_noise, target_snr_db)


In [6]:
def make_noisy_sample(
    noisy_signal,
    eeg_idx,
    eog_idx=None,
    emg_idx=None,
    snr_db=None,
    lambda_used=None
):
    return {
        "noisy": noisy_signal,      # np.ndarray (512,)
        "eeg_idx": eeg_idx,         # int
        "eog_idx": eog_idx,         # int or None
        "emg_idx": emg_idx,         # int or None
        "snr_db": snr_db,           # float
        "lambda": lambda_used       # float
    }

def sample_snr_uniform(low, high, rng):
    return rng.uniform(low, high)

def fixed_snr_list(snrs, rng):
    return rng.choice(snrs)


# Mixers

In [7]:
def mix_eeg_eog_paper(
    EEG,
    EOG,
    seed=42,
    mode="random"  # "exhaustive" | "random"
):
    rng = np.random.default_rng(seed)

    eeg_train_idx = np.arange(3000)
    eeg_test_idx  = np.arange(3000, 3400)

    snr_grid = np.array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2])

    train_samples = []
    test_samples = []

    if mode == "exhaustive":
        # ---- training: 10x uniform SNR ----
        for _ in range(10):
            perm = rng.permutation(eeg_train_idx)
            for i in perm:
                snr = sample_snr_uniform(-7, 2, rng)
                noisy, lam = noisy_eeg_eog(EEG[i], EOG[i], snr)

                train_samples.append(
                    make_noisy_sample(
                        noisy,
                        eeg_idx=i,
                        eog_idx=i,
                        snr_db=snr,
                        lambda_used=lam
                    )
                )

        # ---- testing: fixed SNR grid ----
        for snr in snr_grid:
            for i in eeg_test_idx:
                noisy, lam = noisy_eeg_eog(EEG[i], EOG[i], snr)

                test_samples.append(
                    make_noisy_sample(
                        noisy,
                        eeg_idx=i,
                        eog_idx=i,
                        snr_db=snr,
                        lambda_used=lam
                    )
                )

    elif mode == "random":
        # ---- training: single pass, random SNR from grid ----
        for i in eeg_train_idx:
            snr = rng.choice(snr_grid)
            noisy, lam = noisy_eeg_eog(EEG[i], EOG[i], snr)

            train_samples.append(
                make_noisy_sample(
                    noisy,
                    eeg_idx=i,
                    eog_idx=i,
                    snr_db=snr,
                    lambda_used=lam
                )
            )

        # ---- testing ----
        for i in eeg_test_idx:
            snr = rng.choice(snr_grid)
            noisy, lam = noisy_eeg_eog(EEG[i], EOG[i], snr)

            test_samples.append(
                make_noisy_sample(
                    noisy,
                    eeg_idx=i,
                    eog_idx=i,
                    snr_db=snr,
                    lambda_used=lam
                )
            )

    else:
        raise ValueError("mode must be 'exhaustive' or 'random'")

    return train_samples, test_samples


In [8]:
def mix_eeg_emg_paper(
    EEG,
    EMG,
    seed=42,
    mode="random"  # "exhaustive" | "random"
):
    rng = np.random.default_rng(seed)

    eeg_indices = rng.choice(len(EEG), size=len(EMG), replace=True)
    emg_indices = np.arange(len(EMG))

    pairs = list(zip(eeg_indices, emg_indices))
    rng.shuffle(pairs)

    train_pairs = pairs[:5000]
    test_pairs  = pairs[5000:5598]

    snr_grid = np.array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2])

    train_samples = []
    test_samples = []

    if mode == "exhaustive":
        # ---- training: 10x ----
        for _ in range(10):
            for eeg_i, emg_i in train_pairs:
                snr = sample_snr_uniform(-7, 2, rng)
                noisy, lam = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], snr)

                train_samples.append(
                    make_noisy_sample(
                        noisy,
                        eeg_idx=eeg_i,
                        emg_idx=emg_i,
                        snr_db=snr,
                        lambda_used=lam
                    )
                )

        # ---- testing: fixed grid ----
        for snr in snr_grid:
            for eeg_i, emg_i in test_pairs:
                noisy, lam = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], snr)

                test_samples.append(
                    make_noisy_sample(
                        noisy,
                        eeg_idx=eeg_i,
                        emg_idx=emg_i,
                        snr_db=snr,
                        lambda_used=lam
                    )
                )

    elif mode == "random":
        # ---- training ----
        for eeg_i, emg_i in train_pairs:
            snr = rng.choice(snr_grid)
            noisy, lam = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], snr)

            train_samples.append(
                make_noisy_sample(
                    noisy,
                    eeg_idx=eeg_i,
                    emg_idx=emg_i,
                    snr_db=snr,
                    lambda_used=lam
                )
            )

        # ---- testing ----
        for eeg_i, emg_i in test_pairs:
            snr = rng.choice(snr_grid)
            noisy, lam = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], snr)

            test_samples.append(
                make_noisy_sample(
                    noisy,
                    eeg_idx=eeg_i,
                    emg_idx=emg_i,
                    snr_db=snr,
                    lambda_used=lam
                )
            )

    else:
        raise ValueError("mode must be 'exhaustive' or 'random'")

    return train_samples, test_samples


In [9]:
def mix_custom(
    EEG,
    EOG=None,
    EMG=None,
    n_train=10000,
    n_test=2000,
    snr_range=(-7, 2),
    seed=42,
    eog_weight=1.0,
    emg_weight=1.0
):
    rng = np.random.default_rng(seed)

    samples = []

    def make_one():
        eeg_i = rng.integers(len(EEG))
        snr = rng.uniform(*snr_range)

        eog_i = None
        emg_i = None

        if EOG is not None and EMG is not None:
            eog_i = rng.integers(len(EOG))
            emg_i = rng.integers(len(EMG))
            noisy, lam = noisy_eeg_eog_emg(
                EEG[eeg_i],
                EOG[eog_i],
                EMG[emg_i],
                snr,
                eog_weight=eog_weight,
                emg_weight=emg_weight
            )

        elif EOG is not None:
            eog_i = rng.integers(len(EOG))
            noisy, lam = noisy_eeg_eog(EEG[eeg_i], EOG[eog_i], snr)

        elif EMG is not None:
            emg_i = rng.integers(len(EMG))
            noisy, lam = noisy_eeg_emg(EEG[eeg_i], EMG[emg_i], snr)

        else:
            raise ValueError("At least one of EOG or EMG must be provided.")

        return make_noisy_sample(
            noisy,
            eeg_idx=eeg_i,
            eog_idx=eog_i,
            emg_idx=emg_i,
            snr_db=snr,
            lambda_used=lam
        )

    for _ in range(n_train + n_test):
        samples.append(make_one())

    return samples[:n_train], samples[n_train:]

In [10]:
def normalize_noisy_sample(sample, EEG):
    """
    Normalize one noisy EEG sample according to EEGdenoiseNet protocol.
    
    Args:
        sample: dict produced by make_noisy_sample
        EEG: clean EEG array (for ground truth lookup)
    
    Returns:
        normalized_sample (new dict)
    """
    y = sample["noisy"]
    x = EEG[sample["eeg_idx"]]

    sigma_y = np.std(y)
    if sigma_y == 0:
        raise ValueError("Standard deviation of noisy signal is zero.")

    normalized_sample = sample.copy()
    normalized_sample.update({
        "noisy_norm": y / sigma_y,     # ŷ
        "clean_norm": x / sigma_y,     # x̂
        "sigma_y": sigma_y             # stored for rescaling later
    })

    return normalized_sample

def normalize_dataset(samples, EEG):
    """
    Normalize a list of noisy EEG samples.
    
    Args:
        samples: list of noisy sample dicts
        EEG: clean EEG array
    
    Returns:
        list of normalized sample dicts
    """
    return [normalize_noisy_sample(s, EEG) for s in samples]


# Data Config

In [11]:
# 1 = EEG + EOG Paper
# 2 = EEG + EMG Paper
# 3 = Custom EEG + EOG
# 4 = Custom EEG + EMG
# else Custom EEG + EOG + EMG
DATA_MODE = 1 

In [12]:
if DATA_MODE == 1: # EEG + EOG Paper
    # --- generate noisy datasets (raw, unnormalized) ---
    train_samples, test_samples = mix_eeg_eog_paper(
        EEG=EEG,
        EOG=EOG,
        seed=42
    )

    # --- normalize according to EEGdenoiseNet protocol ---
    train_samples_norm = normalize_dataset(train_samples, EEG)
    test_samples_norm  = normalize_dataset(test_samples, EEG)
elif DATA_MODE == 2: # EEG + EMG Paper
    # --- generate noisy datasets (raw, unnormalized) ---
    train_samples, test_samples = mix_eeg_emg_paper(
        EEG=EEG,
        EMG=EMG,
        seed=42
    )

    # --- normalize ---
    train_samples_norm = normalize_dataset(train_samples, EEG)
    test_samples_norm  = normalize_dataset(test_samples, EEG)
elif DATA_MODE == 3:
    train_samples, test_samples = mix_custom(
        EEG=EEG,
        EOG=EOG,
        n_train=10000,
        n_test=2000,
        snr_range=(-7, 2),
        seed=42
    )

    train_samples_norm = normalize_dataset(train_samples, EEG)
    test_samples_norm  = normalize_dataset(test_samples, EEG)
elif DATA_MODE == 4:
    train_samples, test_samples = mix_custom(
        EEG=EEG,
        EMG=EMG,
        n_train=10000,
        n_test=2000,
        snr_range=(-7, 2),
        seed=42
    )

    train_samples_norm = normalize_dataset(train_samples, EEG)
    test_samples_norm  = normalize_dataset(test_samples, EEG)

else:
    train_samples, test_samples = mix_custom(
        EEG=EEG,
        EOG=EOG,
        EMG=EMG,
        n_train=10000,
        n_test=2000,
        snr_range=(-7, 2),
        seed=42,
        eog_weight=1.0,
        emg_weight=1.0
    )

    train_samples_norm = normalize_dataset(train_samples, EEG)
    test_samples_norm  = normalize_dataset(test_samples, EEG)


## Model Definition

In [13]:
class ConvBlock1D(nn.Module):
    """
    Conv1d -> BatchNorm1d -> Activation
    """
    def __init__(self, in_ch, out_ch, k=15, s=2, p=7, bias=True, act="lrelu"):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=bias)
        self.norm = nn.BatchNorm1d(out_ch)

        if act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        elif act == "relu":
            self.act = nn.ReLU(inplace=True)
        else:
            raise ValueError("act must be 'lrelu' or 'relu'")

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


class DeconvBlock1D(nn.Module):
    """
    ConvTranspose1d -> BatchNorm1d -> Activation
    """
    def __init__(self, in_ch, out_ch, k=4, s=2, p=1, bias=True, act="relu"):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=bias)
        self.norm = nn.BatchNorm1d(out_ch)

        if act == "relu":
            self.act = nn.ReLU(inplace=True)
        elif act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            raise ValueError("act must be 'relu' or 'lrelu'")

    def forward(self, x):
        return self.act(self.norm(self.deconv(x)))


class ResBlock1D(nn.Module):
    """
    Residual block: (Conv -> BN -> ReLU) x2 + skip
    Keeps same channel count and length.
    """
    def __init__(self, ch, k=7, p=3, bias=True):
        super().__init__()
        self.c1 = nn.Conv1d(ch, ch, kernel_size=k, stride=1, padding=p, bias=bias)
        self.n1 = nn.BatchNorm1d(ch)
        self.c2 = nn.Conv1d(ch, ch, kernel_size=k, stride=1, padding=p, bias=bias)
        self.n2 = nn.BatchNorm1d(ch)

    def forward(self, x):
        h = F.relu(self.n1(self.c1(x)))
        h = self.n2(self.c2(h))
        return F.relu(x + h)
    
class MultiScaleResBlock1D(nn.Module):
    """
    Multi-scale residual block: parallel conv branches (k=3,5,7) then fuse.
    Keeps same channel count and length.
    """
    def __init__(self, ch, bias=True):
        super().__init__()

        self.b3 = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=3, padding=1, bias=bias),
            nn.BatchNorm1d(ch),
            nn.ReLU(inplace=True),
        )
        self.b5 = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=5, padding=2, bias=bias),
            nn.BatchNorm1d(ch),
            nn.ReLU(inplace=True),
        )
        self.b7 = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=7, padding=3, bias=bias),
            nn.BatchNorm1d(ch),
            nn.ReLU(inplace=True),
        )

        self.fuse = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=1, bias=bias),
            nn.BatchNorm1d(ch),
        )

    def forward(self, x):
        h = self.b3(x) + self.b5(x) + self.b7(x)
        h = self.fuse(h)
        return F.relu(x + h)



In [14]:
# NN Generator (U-Net-ish + Res bottleneck)
class GeneratorCNNWGAN(nn.Module):
    """
    CNN U-Net-ish generator for EEG denoising (WGAN).
    Input : (B, 1, 512) noisy_norm
    Output: (B, 1, 512) clean_norm_hat
    """
    def __init__(self, base_ch=32, bottleneck_blocks=4, bias=True):
        super().__init__()

        # Encoder
        self.e1 = ConvBlock1D(1, base_ch,       k=16, s=2, p=7, bias=bias, act="lrelu")      # 512 -> 256
        self.e2 = ConvBlock1D(base_ch, base_ch*2, k=16, s=2, p=7, bias=bias, act="lrelu")    # 256 -> 128
        self.e3 = ConvBlock1D(base_ch*2, base_ch*4, k=16, s=2, p=7, bias=bias, act="lrelu")  # 128 -> 64
        self.e4 = ConvBlock1D(base_ch*4, base_ch*8, k=16, s=2, p=7, bias=bias, act="lrelu")  # 64 -> 32

        # Bottleneck
        bn_ch = base_ch * 8
        self.bottleneck = nn.Sequential(*[
            ResBlock1D(bn_ch, k=7, p=3, bias=bias) for _ in range(bottleneck_blocks)
        ])

        # Decoder (concat doubles channels)
        self.d1 = DeconvBlock1D(bn_ch, base_ch*4,   k=4, s=2, p=1, bias=bias, act="relu")     # 32 -> 64
        self.d2 = DeconvBlock1D(base_ch*8, base_ch*2, k=4, s=2, p=1,bias=bias, act="relu")   # 64 -> 128
        self.d3 = DeconvBlock1D(base_ch*4, base_ch,   k=4, s=2, p=1, bias=bias, act="relu")   # 128 -> 256
        self.d4 = DeconvBlock1D(base_ch*2, base_ch//2, k=4, s=2, p=1, bias=bias, act="relu")  # 256 -> 512

        # Head (linear output recommended for normalized signals)
        self.out = nn.Conv1d(base_ch//2, 1, kernel_size=7, stride=1, padding=3, bias=bias)

    def forward(self, y):
        # Encoder
        s1 = self.e1(y)   # (B, base, 256)
        s2 = self.e2(s1)  # (B, 2b, 128)
        s3 = self.e3(s2)  # (B, 4b, 64)
        s4 = self.e4(s3)  # (B, 8b, 32)

        # Bottleneck
        b = self.bottleneck(s4)

        # Decoder + skip connections
        d1 = self.d1(b)                  # (B, 4b, 64)
        d1 = torch.cat([d1, s3], dim=1)  # (B, 8b, 64)

        d2 = self.d2(d1)                 # (B, 2b, 128)
        d2 = torch.cat([d2, s2], dim=1)  # (B, 4b, 128)

        d3 = self.d3(d2)                 # (B, b, 256)
        d3 = torch.cat([d3, s1], dim=1)  # (B, 2b, 256)

        d4 = self.d4(d3)                 # (B, b/2, 512)

        return self.out(d4)              # (B, 1, 512)

In [15]:
# Patch Critic (shared by CNN/ResCNN)
class CriticPatch1D(nn.Module):
    """
    Conditional PatchGAN critic for WGAN:
      D(y, x) -> patch scores
    y,x: (B,1,512)
    output: (B,1,32)
    """
    def __init__(self, base_ch=32, bias=True):
        super().__init__()
        self.c1 = nn.Conv1d(2, base_ch, kernel_size=16, stride=2, padding=7, bias=bias)  # 512 -> 256
        self.c2 = ConvBlock1D(base_ch, base_ch*2, k=16, s=2, p=7, bias=bias, act="lrelu")    # 256 -> 128
        self.c3 = ConvBlock1D(base_ch*2, base_ch*4, k=16, s=2, p=7, bias=bias, act="lrelu")  # 128 -> 64
        self.c4 = ConvBlock1D(base_ch*4, base_ch*8, k=16, s=2, p=7, bias=bias, act="lrelu")  # 64 -> 32
        self.out = nn.Conv1d(base_ch*8, 1, kernel_size=7, stride=1, padding=3, bias=bias)   # 32 -> 32

    def forward(self, y, x):
        h = torch.cat([y, x], dim=1)  # (B,2,512)
        h = F.leaky_relu(self.c1(h), 0.2, inplace=True)
        h = self.c2(h)
        h = self.c3(h)
        h = self.c4(h)
        return self.out(h)


# Model Import

In [16]:
BIAS = True

In [17]:
data_path = os.path.abspath(f"../models/main3_d{DATA_MODE}_{"b" if BIAS else "nb"}/")
print(data_path)

pattern = re.compile(r"^cnn_([DG])_\d{8}_\d{6}\.pth$")

cnn_G_path = None
cnn_D_path = None

for f in os.listdir(data_path):
    m = pattern.match(f)
    if m:
        full = os.path.join(data_path, f)
        if m.group(1) == "G":
            cnn_G_path = full
        else:
            cnn_D_path = full

print("G:", cnn_G_path)
print("D:", cnn_D_path)

c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d1_b
G: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d1_b\cnn_G_20260113_074150.pth
D: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d1_b\cnn_D_20260113_074150.pth


In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = GeneratorCNNWGAN(bias=BIAS).to(device)
D = CriticPatch1D(bias=BIAS).to(device)

G.load_state_dict(torch.load(cnn_G_path, map_location=device))
D.load_state_dict(torch.load(cnn_D_path, map_location=device))

print(G.eval())
print(D.eval())


GeneratorCNNWGAN(
  (e1): ConvBlock1D(
    (conv): Conv1d(1, 32, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e2): ConvBlock1D(
    (conv): Conv1d(32, 64, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e3): ConvBlock1D(
    (conv): Conv1d(64, 128, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e4): ConvBlock1D(
    (conv): Conv1d(128, 256, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  

  G.load_state_dict(torch.load(cnn_G_path, map_location=device))
  D.load_state_dict(torch.load(cnn_D_path, map_location=device))


In [19]:
for name, module in G.named_modules():
    print(name, "->", module.__class__.__name__)

print("\n" + "=" * 80 +"\n")

for name, module in D.named_modules():
    print(name, "->", module.__class__.__name__)

 -> GeneratorCNNWGAN
e1 -> ConvBlock1D
e1.conv -> Conv1d
e1.norm -> BatchNorm1d
e1.act -> LeakyReLU
e2 -> ConvBlock1D
e2.conv -> Conv1d
e2.norm -> BatchNorm1d
e2.act -> LeakyReLU
e3 -> ConvBlock1D
e3.conv -> Conv1d
e3.norm -> BatchNorm1d
e3.act -> LeakyReLU
e4 -> ConvBlock1D
e4.conv -> Conv1d
e4.norm -> BatchNorm1d
e4.act -> LeakyReLU
bottleneck -> Sequential
bottleneck.0 -> ResBlock1D
bottleneck.0.c1 -> Conv1d
bottleneck.0.n1 -> BatchNorm1d
bottleneck.0.c2 -> Conv1d
bottleneck.0.n2 -> BatchNorm1d
bottleneck.1 -> ResBlock1D
bottleneck.1.c1 -> Conv1d
bottleneck.1.n1 -> BatchNorm1d
bottleneck.1.c2 -> Conv1d
bottleneck.1.n2 -> BatchNorm1d
bottleneck.2 -> ResBlock1D
bottleneck.2.c1 -> Conv1d
bottleneck.2.n1 -> BatchNorm1d
bottleneck.2.c2 -> Conv1d
bottleneck.2.n2 -> BatchNorm1d
bottleneck.3 -> ResBlock1D
bottleneck.3.c1 -> Conv1d
bottleneck.3.n1 -> BatchNorm1d
bottleneck.3.c2 -> Conv1d
bottleneck.3.n2 -> BatchNorm1d
d1 -> DeconvBlock1D
d1.deconv -> ConvTranspose1d
d1.norm -> BatchNorm1d
d1

In [20]:
for name, param in G.named_parameters():
    print(name, param.shape, param.requires_grad)

print("\n" + "=" * 80 +"\n")

for name, param in D.named_parameters():
    print(name, param.shape, param.requires_grad)


e1.conv.weight torch.Size([32, 1, 16]) True
e1.conv.bias torch.Size([32]) True
e1.norm.weight torch.Size([32]) True
e1.norm.bias torch.Size([32]) True
e2.conv.weight torch.Size([64, 32, 16]) True
e2.conv.bias torch.Size([64]) True
e2.norm.weight torch.Size([64]) True
e2.norm.bias torch.Size([64]) True
e3.conv.weight torch.Size([128, 64, 16]) True
e3.conv.bias torch.Size([128]) True
e3.norm.weight torch.Size([128]) True
e3.norm.bias torch.Size([128]) True
e4.conv.weight torch.Size([256, 128, 16]) True
e4.conv.bias torch.Size([256]) True
e4.norm.weight torch.Size([256]) True
e4.norm.bias torch.Size([256]) True
bottleneck.0.c1.weight torch.Size([256, 256, 7]) True
bottleneck.0.c1.bias torch.Size([256]) True
bottleneck.0.n1.weight torch.Size([256]) True
bottleneck.0.n1.bias torch.Size([256]) True
bottleneck.0.c2.weight torch.Size([256, 256, 7]) True
bottleneck.0.c2.bias torch.Size([256]) True
bottleneck.0.n2.weight torch.Size([256]) True
bottleneck.0.n2.bias torch.Size([256]) True
bottlene

In [21]:
for name, buf in G.named_buffers():
    print(name, buf.shape)

print("\n" + "=" * 80 +"\n")

for name, buf in D.named_buffers():
    print(name, buf.shape)

e1.norm.running_mean torch.Size([32])
e1.norm.running_var torch.Size([32])
e1.norm.num_batches_tracked torch.Size([])
e2.norm.running_mean torch.Size([64])
e2.norm.running_var torch.Size([64])
e2.norm.num_batches_tracked torch.Size([])
e3.norm.running_mean torch.Size([128])
e3.norm.running_var torch.Size([128])
e3.norm.num_batches_tracked torch.Size([])
e4.norm.running_mean torch.Size([256])
e4.norm.running_var torch.Size([256])
e4.norm.num_batches_tracked torch.Size([])
bottleneck.0.n1.running_mean torch.Size([256])
bottleneck.0.n1.running_var torch.Size([256])
bottleneck.0.n1.num_batches_tracked torch.Size([])
bottleneck.0.n2.running_mean torch.Size([256])
bottleneck.0.n2.running_var torch.Size([256])
bottleneck.0.n2.num_batches_tracked torch.Size([])
bottleneck.1.n1.running_mean torch.Size([256])
bottleneck.1.n1.running_var torch.Size([256])
bottleneck.1.n1.num_batches_tracked torch.Size([])
bottleneck.1.n2.running_mean torch.Size([256])
bottleneck.1.n2.running_var torch.Size([256])

In [22]:
for name, module in G.named_modules():
    if isinstance(module, torch.nn.BatchNorm1d):
        print(f"\nBN layer: {name}")
        print(" running_mean:", module.running_mean.shape)
        print(" running_var :", module.running_var.shape)
        print(" momentum    :", module.momentum)
        print(" eps         :", module.eps)
        print(" affine      :", module.affine)



BN layer: e1.norm
 running_mean: torch.Size([32])
 running_var : torch.Size([32])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: e2.norm
 running_mean: torch.Size([64])
 running_var : torch.Size([64])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: e3.norm
 running_mean: torch.Size([128])
 running_var : torch.Size([128])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: e4.norm
 running_mean: torch.Size([256])
 running_var : torch.Size([256])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: bottleneck.0.n1
 running_mean: torch.Size([256])
 running_var : torch.Size([256])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: bottleneck.0.n2
 running_mean: torch.Size([256])
 running_var : torch.Size([256])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: bottleneck.1.n1
 running_mean: torch.Size([256])
 running_var : torch.Size([256])
 momentum    : 0.1
 eps  

# Test Data

In [23]:
def rrmse(x_hat, x):
    return torch.sqrt(torch.mean((x_hat - x) ** 2)) / torch.sqrt(torch.mean(x ** 2))


def corrcoef(x_hat, x):
    x_hat = x_hat - x_hat.mean()
    x = x - x.mean()
    return torch.sum(x_hat * x) / (
        torch.sqrt(torch.sum(x_hat ** 2)) * torch.sqrt(torch.sum(x ** 2)) + 1e-8
    )


In [24]:
print(len(test_samples_norm[0]["noisy_norm"]))

sample = test_samples_norm[0]

512


## Generator Inference

In [25]:
y = torch.tensor(sample["noisy_norm"], dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device) # (1, 1, T)
x = torch.tensor(sample["clean_norm"], dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

G.eval()

with torch.no_grad():
    x_hat_ref = G(y)

## Manual Inference

In [26]:
def free(*tensors):
    for t in tensors:
        del t
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [27]:
def manual_transconv(
    x,              # (N, Cin, T_in)
    weight_value,   # (Cin, Cout, K)
    bias_value=None,# (Cout,) or None
    s=2,
    p=1,
    output_padding=0
):
    """
    Manual ConvTranspose1d (inference / eval semantics)

    Matches PyTorch nn.ConvTranspose1d:
    - weight shape: (Cin, Cout, K)
    """

    N, Cin, T_in = x.shape
    Cin_w, Cout, K = weight_value.shape
    assert Cin == Cin_w

    # PyTorch output length formula
    T_out = (T_in - 1) * s - 2 * p + K + output_padding

    out = torch.zeros(
        (N, Cout, T_out),
        device=x.device,
        dtype=x.dtype
    )

    for n in range(N):
        for ci in range(Cin):
            for ti in range(T_in):
                base = ti * s - p

                for k in range(K):
                    t_out = base + k
                    if 0 <= t_out < T_out:
                        for co in range(Cout):
                            out[n, co, t_out] += (
                                x[n, ci, ti]
                                * weight_value[ci, co, k]
                            )

    if bias_value is not None:
        out += bias_value[None, :, None]

    return out

In [28]:
def manual_batchnorm1d(
    x,              # (N, C, T)
    running_mean,   # (C,)
    running_var,    # (C,)
    weight=None,    # (C,) or None
    bias=None,      # (C,) or None
    eps=1e-5
):
    N, C, T = x.shape

    out = torch.zeros_like(x)

    for n in range(N):
        for c in range(C):
            mean = running_mean[c]
            var = running_var[c]
            denom = torch.sqrt(var + eps)

            for t in range(T):
                y = (x[n, c, t] - mean) / denom

                if weight is not None:
                    y = y * weight[c]

                if bias is not None:
                    y = y + bias[c]

                out[n, c, t] = y

    return out

def manual_batchnorm1d_fast(x, bn):
    mean = bn.running_mean[None, :, None]
    var = bn.running_var[None, :, None]

    y = (x - mean) / torch.sqrt(var + bn.eps)

    if bn.affine:
        y = y * bn.weight[None, :, None] + bn.bias[None, :, None]

    return y


In [29]:
def manual_relu(x):
    out = torch.zeros_like(x)

    mask_pos = x >= 0
    mask_neg = x < 0

    out[mask_pos] = x[mask_pos]
    out[mask_neg] = 0

    return out

def manual_leaky_relu(x, negative_slope=0.2):
    out = torch.zeros_like(x)

    mask_pos = x >= 0
    mask_neg = x < 0

    out[mask_pos] = x[mask_pos]
    out[mask_neg] = negative_slope * x[mask_neg]

    return out


In [30]:
def fuse_deconv_bn_1d(
    deconv_weight,     # (Cin, Cout, K)
    deconv_bias,       # (Cout,) or None
    running_mean,      # (Cout,)
    running_var,       # (Cout,)
    bn_weight,         # (Cout,)
    bn_bias,           # (Cout,)
    eps
):
    Cin, Cout, K = deconv_weight.shape

    if deconv_bias is None:
        deconv_bias = torch.zeros(
            Cout,
            device=deconv_weight.device,
            dtype=deconv_weight.dtype
        )

    # BN scale
    scale = bn_weight / torch.sqrt(running_var + eps)  # (Cout,)

    # Fuse weights (scale on Cout dimension)
    fused_weight = deconv_weight * scale.view(1, Cout, 1)

    # Fuse bias
    fused_bias = (deconv_bias - running_mean) * scale + bn_bias

    return fused_weight, fused_bias


In [32]:
# --- MANUAL D4 BLOCK (FUSED DECONV + BN) ---
G.eval()

print("G.training:", G.training)
print("d4.training:", G.d4.training)
print("d4.deconv.training:", G.d4.deconv.training)
print("d4.norm.training:", G.d4.norm.training)
print("d4.act.training:", G.d4.act.training)

# -------------------------------------------------
# Create VALID random input for d4
# -------------------------------------------------
B = 1
device = next(G.parameters()).device

x = torch.randn(B, 64, 256, device=device)

# -------------------------------------------------
# 1) Fuse deconv + BN
# -------------------------------------------------
W_fused, b_fused = fuse_deconv_bn_1d(
    G.d4.deconv.weight,
    G.d4.deconv.bias,
    G.d4.norm.running_mean,
    G.d4.norm.running_var,
    G.d4.norm.weight,
    G.d4.norm.bias,
    eps=G.d4.norm.eps
)

# -------------------------------------------------
# 2) Fused Transposed Conv
# -------------------------------------------------
out_fused_manual = manual_transconv(
    x,
    W_fused,
    b_fused,
    s=G.d4.deconv.stride[0],
    p=G.d4.deconv.padding[0],
    output_padding=G.d4.deconv.output_padding[0]
)

# -------------------------------------------------
# 3) ReLU
# -------------------------------------------------
out_d4_manual = manual_relu(out_fused_manual)

# -------------------------------------------------
# PyTorch reference
# -------------------------------------------------
with torch.no_grad():
    out_deconv_ref = G.d4.deconv(x)
    out_bn_ref     = G.d4.norm(out_deconv_ref)
    out_d4_ref     = G.d4.act(out_bn_ref)

# -------------------------------------------------
# Verification
# -------------------------------------------------
print("Fused Conv vs (Conv+BN) max abs diff:",
      (out_bn_ref - out_fused_manual).abs().max().item())

diff = (out_d4_ref - out_d4_manual).abs()
print("E1 fused max abs diff:", diff.max().item())
print("E1 fused mean abs diff:", diff.mean().item())
print("Final RMSE:", torch.sqrt(torch.mean(diff ** 2)).item())


G.training: False
d4.training: False
d4.deconv.training: False
d4.norm.training: False
d4.act.training: False
Fused Conv vs (Conv+BN) max abs diff: 3.7915029525756836
E1 fused max abs diff: 0.0008767843246459961
E1 fused mean abs diff: 0.0001021545467665419
Final RMSE: 0.00016998803766909987
