In [1]:
import os
import numpy as np
import pandas as pd
import scipy.io as sio
import matplotlib.pyplot as plt
import math
import json
import re
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

import random
np.random.seed(42)
random.seed(42)

# Data

In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 1234

# Python & NumPy
random.seed(SEED)
np.random.seed(SEED)

# PyTorch
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Determinism flags
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
encoder_input = torch.randn(1, 512, device=device)
bottleneck_input = torch.randn(256, 32, device=device)
decoder_first_input = torch.randn(256, 32, device=device)
decoder_last_input = torch.randn(64, 256, device=device)
out_input = torch.randn(16, 512, device=device)

# Model Definition

In [4]:
class ConvBlock1D(nn.Module):
    """
    Conv1d -> BatchNorm1d -> Activation
    """
    def __init__(self, in_ch, out_ch, k=15, s=2, p=7, bias=True, act="lrelu"):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=bias)
        self.norm = nn.BatchNorm1d(out_ch)

        if act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        elif act == "relu":
            self.act = nn.ReLU(inplace=True)
        else:
            raise ValueError("act must be 'lrelu' or 'relu'")

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


class DeconvBlock1D(nn.Module):
    """
    ConvTranspose1d -> BatchNorm1d -> Activation
    """
    def __init__(self, in_ch, out_ch, k=4, s=2, p=1, bias=True, act="relu"):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=bias)
        self.norm = nn.BatchNorm1d(out_ch)

        if act == "relu":
            self.act = nn.ReLU(inplace=True)
        elif act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            raise ValueError("act must be 'relu' or 'lrelu'")

    def forward(self, x):
        return self.act(self.norm(self.deconv(x)))


class ResBlock1D(nn.Module):
    """
    Residual block: (Conv -> BN -> ReLU) x2 + skip
    Keeps same channel count and length.
    """
    def __init__(self, ch, k=7, p=3, bias=True):
        super().__init__()
        self.c1 = nn.Conv1d(ch, ch, kernel_size=k, stride=1, padding=p, bias=bias)
        self.n1 = nn.BatchNorm1d(ch)
        self.c2 = nn.Conv1d(ch, ch, kernel_size=k, stride=1, padding=p, bias=bias)
        self.n2 = nn.BatchNorm1d(ch)

    def forward(self, x):
        h = F.relu(self.n1(self.c1(x)))
        h = self.n2(self.c2(h))
        return F.relu(x + h)
    
class MultiScaleResBlock1D(nn.Module):
    """
    Multi-scale residual block: parallel conv branches (k=3,5,7) then fuse.
    Keeps same channel count and length.
    """
    def __init__(self, ch, bias=True):
        super().__init__()

        self.b3 = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=3, padding=1, bias=bias),
            nn.BatchNorm1d(ch),
            nn.ReLU(inplace=True),
        )
        self.b5 = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=5, padding=2, bias=bias),
            nn.BatchNorm1d(ch),
            nn.ReLU(inplace=True),
        )
        self.b7 = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=7, padding=3, bias=bias),
            nn.BatchNorm1d(ch),
            nn.ReLU(inplace=True),
        )

        self.fuse = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=1, bias=bias),
            nn.BatchNorm1d(ch),
        )

    def forward(self, x):
        h = self.b3(x) + self.b5(x) + self.b7(x)
        h = self.fuse(h)
        return F.relu(x + h)



In [5]:
# NN Generator (U-Net-ish + Res bottleneck)
class GeneratorCNNWGAN(nn.Module):
    """
    CNN U-Net-ish generator for EEG denoising (WGAN).
    Input : (B, 1, 512) noisy_norm
    Output: (B, 1, 512) clean_norm_hat
    """
    def __init__(self, base_ch=32, bottleneck_blocks=4, bias=True):
        super().__init__()

        # Encoder
        self.e1 = ConvBlock1D(1, base_ch,       k=16, s=2, p=7, bias=bias, act="lrelu")      # 512 -> 256
        self.e2 = ConvBlock1D(base_ch, base_ch*2, k=16, s=2, p=7, bias=bias, act="lrelu")    # 256 -> 128
        self.e3 = ConvBlock1D(base_ch*2, base_ch*4, k=16, s=2, p=7, bias=bias, act="lrelu")  # 128 -> 64
        self.e4 = ConvBlock1D(base_ch*4, base_ch*8, k=16, s=2, p=7, bias=bias, act="lrelu")  # 64 -> 32

        # Bottleneck
        bn_ch = base_ch * 8
        self.bottleneck = nn.Sequential(*[
            ResBlock1D(bn_ch, k=7, p=3, bias=bias) for _ in range(bottleneck_blocks)
        ])

        # Decoder (concat doubles channels)
        self.d1 = DeconvBlock1D(bn_ch, base_ch*4,   k=4, s=2, p=1, bias=bias, act="relu")     # 32 -> 64
        self.d2 = DeconvBlock1D(base_ch*8, base_ch*2, k=4, s=2, p=1,bias=bias, act="relu")   # 64 -> 128
        self.d3 = DeconvBlock1D(base_ch*4, base_ch,   k=4, s=2, p=1, bias=bias, act="relu")   # 128 -> 256
        self.d4 = DeconvBlock1D(base_ch*2, base_ch//2, k=4, s=2, p=1, bias=bias, act="relu")  # 256 -> 512

        # Head (linear output recommended for normalized signals)
        self.out = nn.Conv1d(base_ch//2, 1, kernel_size=7, stride=1, padding=3, bias=bias)

    def forward(self, y):
        # Encoder
        s1 = self.e1(y)   # (B, base, 256)
        s2 = self.e2(s1)  # (B, 2b, 128)
        s3 = self.e3(s2)  # (B, 4b, 64)
        s4 = self.e4(s3)  # (B, 8b, 32)

        # Bottleneck
        b = self.bottleneck(s4)

        # Decoder + skip connections
        d1 = self.d1(b)                  # (B, 4b, 64)
        d1 = torch.cat([d1, s3], dim=1)  # (B, 8b, 64)

        d2 = self.d2(d1)                 # (B, 2b, 128)
        d2 = torch.cat([d2, s2], dim=1)  # (B, 4b, 128)

        d3 = self.d3(d2)                 # (B, b, 256)
        d3 = torch.cat([d3, s1], dim=1)  # (B, 2b, 256)

        d4 = self.d4(d3)                 # (B, b/2, 512)

        return self.out(d4)              # (B, 1, 512)
    
# Patch Critic (shared by CNN/ResCNN)
class CriticPatch1D(nn.Module):
    """
    Conditional PatchGAN critic for WGAN:
      D(y, x) -> patch scores
    y,x: (B,1,512)
    output: (B,1,32)
    """
    def __init__(self, base_ch=32, bias=True):
        super().__init__()
        self.c1 = nn.Conv1d(2, base_ch, kernel_size=16, stride=2, padding=7, bias=bias)  # 512 -> 256
        self.c2 = ConvBlock1D(base_ch, base_ch*2, k=16, s=2, p=7, bias=bias, act="lrelu")    # 256 -> 128
        self.c3 = ConvBlock1D(base_ch*2, base_ch*4, k=16, s=2, p=7, bias=bias, act="lrelu")  # 128 -> 64
        self.c4 = ConvBlock1D(base_ch*4, base_ch*8, k=16, s=2, p=7, bias=bias, act="lrelu")  # 64 -> 32
        self.out = nn.Conv1d(base_ch*8, 1, kernel_size=7, stride=1, padding=3, bias=bias)   # 32 -> 32

    def forward(self, y, x):
        h = torch.cat([y, x], dim=1)  # (B,2,512)
        h = F.leaky_relu(self.c1(h), 0.2, inplace=True)
        h = self.c2(h)
        h = self.c3(h)
        h = self.c4(h)
        return self.out(h)


# Model Import

In [6]:
BIAS = True
DATA_MODE = 5 # up to 5

In [7]:
data_path = os.path.abspath(f"../models/main3_d{DATA_MODE}_{"b" if BIAS else "nb"}/")
print(data_path)

pattern = re.compile(r"^cnn_([DG])_\d{8}_\d{6}\.pth$")

cnn_G_path = None
cnn_D_path = None

for f in os.listdir(data_path):
    m = pattern.match(f)
    if m:
        full = os.path.join(data_path, f)
        if m.group(1) == "G":
            cnn_G_path = full
        else:
            cnn_D_path = full

print("G:", cnn_G_path)
print("D:", cnn_D_path)

c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d5_b
G: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d5_b\cnn_G_20260114_024826.pth
D: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d5_b\cnn_D_20260114_024826.pth


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = GeneratorCNNWGAN(bias=BIAS).to(device)
D = CriticPatch1D(bias=BIAS).to(device)

G.load_state_dict(torch.load(cnn_G_path, map_location=device))
D.load_state_dict(torch.load(cnn_D_path, map_location=device))

print(G.eval())
print(D.eval())

GeneratorCNNWGAN(
  (e1): ConvBlock1D(
    (conv): Conv1d(1, 32, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e2): ConvBlock1D(
    (conv): Conv1d(32, 64, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e3): ConvBlock1D(
    (conv): Conv1d(64, 128, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e4): ConvBlock1D(
    (conv): Conv1d(128, 256, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  

  G.load_state_dict(torch.load(cnn_G_path, map_location=device))
  D.load_state_dict(torch.load(cnn_D_path, map_location=device))


In [9]:
Q_CONFIGS = {
    "Q4.12": dict(frac_bits=12, int_bits=4, dtype=np.int16),
    "Q10.10": dict(frac_bits=10, int_bits=10, dtype=np.int32),
    "Q9.14": dict(frac_bits=14, int_bits=10, dtype=np.int32),
}

def float_to_q(x, frac_bits, int_bits, dtype):
    scale = 1 << frac_bits
    total_bits = int_bits + frac_bits
    min_val = -(1 << (total_bits - 1))
    max_val = (1 << (total_bits - 1)) - 1

    xq = np.round(x * scale)
    xq = np.clip(xq, min_val, max_val)
    return xq.astype(dtype)

def q_to_float(x, frac_bits):
    return x.astype(np.float32) / (1 << frac_bits)

In [10]:
TYPE =  "Q9.14" # "Q4.12", "Q10.10","Q9.14", or "FLOAT"

TIME_REPEAT = {
    "encoder":    [256, 128, 64, 32],
    "bottleneck": [32]*8,
    "decoder":    [64, 128, 256, 512],
    "out":        [512]
}

if TYPE == "FLOAT":
    MODE = "FLOAT"
    DTYPE = np.float32
else:
    MODE = "FIXED"
    cfg = Q_CONFIGS[TYPE]
    FRAC_BITS = cfg["frac_bits"]
    INT_BITS  = cfg["int_bits"]
    DTYPE     = cfg["dtype"]

# Result

In [11]:
output_path = os.path.abspath(f"../models/sample_output_{TYPE}/")
os.makedirs(output_path, exist_ok=True)

decoder_last_input_batched = decoder_last_input.unsqueeze(0)

In [12]:
G.eval()

with torch.no_grad():
    out_deconv_ref = G.d4.deconv(decoder_last_input_batched)
    out_bn_ref     = G.d4.norm(out_deconv_ref)
    out_d4_ref     = G.d4.act(out_bn_ref)

out_2d = out_d4_ref.squeeze(0)
out_f = out_2d.detach().cpu().numpy().astype(np.float32)

assert out_f.ndim == 2, "Expected 2D output [x, y]"
X, Y = out_f.shape

print("===== DECODER LAST OUTPUT =====")
print(f"Shape      : ({X}, {Y})")
print(f"Total vals : {X * Y}")

===== DECODER LAST OUTPUT =====
Shape      : (16, 512)
Total vals : 8192


In [13]:
# ---- QUANTIZE ----
out_q = float_to_q(
    out_f,
    frac_bits=FRAC_BITS,
    int_bits=INT_BITS,
    dtype=DTYPE
)

mem_name = f"decoder_last_output_d{DATA_MODE}.mem"
mem_path = os.path.join(output_path, mem_name)

with open(mem_path, "w") as f:
    for y in range(Y):
        for x in range(X):
            v = int(out_q[x, y]) & 0xFFFFFFFF
            f.write(f"{v:08X}\n")

print(f"Saved: {mem_path}")

Saved: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\sample_output_Q9.14\decoder_last_output_d5.mem
