In [2]:
import os
import numpy as np
import pandas as pd
import scipy.io as sio
import matplotlib.pyplot as plt
import math
import json
import re
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

import random
# NOTE: these seeds are superseded by the configuration cell below, which
# re-seeds both `random` and NumPy with SEED = 1234 before any data is drawn.
np.random.seed(42)
random.seed(42)

# Data

In [3]:
import torch

# Global run configuration: seed value and compute device.
SEED = 1234
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every random number generator used by the notebook
# (Python, NumPy, PyTorch CPU, PyTorch current GPU, PyTorch all GPUs).
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# cuDNN: pick deterministic kernels and disable the auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Synthetic (channels, length) test vectors drawn from the seeded torch RNG.
# The call order below determines which random values each tensor receives,
# so do not reorder these lines if previously exported files must stay reproducible.
# Shapes line up with the default generator (base_ch=32):
#   (1, 512)   -> input of e1
#   (256, 32)  -> input of the bottleneck / of d1
#   (64, 256)  -> input of d4 (after the skip concatenation)
#   (16, 512)  -> input of the output conv
encoder_input = torch.randn(1, 512, device=device)
bottleneck_input = torch.randn(256, 32, device=device)
decoder_first_input = torch.randn(256, 32, device=device)
decoder_last_input = torch.randn(64, 256, device=device)
out_input = torch.randn(16, 512, device=device)

# Model Definition

In [5]:
class ConvBlock1D(nn.Module):
    """
    Strided 1D convolution followed by batch normalisation and an activation.

    Args:
        in_ch, out_ch: input / output channel counts.
        k, s, p: kernel size, stride and padding of the convolution.
        bias: whether the convolution has a bias term.
        act: "lrelu" (LeakyReLU, slope 0.2) or "relu".

    Raises:
        ValueError: if `act` is neither "lrelu" nor "relu".
    """
    def __init__(self, in_ch, out_ch, k=15, s=2, p=7, bias=True, act="lrelu"):
        super().__init__()
        self.conv = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size=k,
            stride=s,
            padding=p,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(out_ch)

        # Guard clause first, then pick between the two supported activations.
        if act not in ("lrelu", "relu"):
            raise ValueError("act must be 'lrelu' or 'relu'")
        self.act = nn.LeakyReLU(0.2, inplace=True) if act == "lrelu" else nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        features = self.norm(features)
        return self.act(features)


class DeconvBlock1D(nn.Module):
    """
    Transposed 1D convolution followed by batch normalisation and an activation.

    Args:
        in_ch, out_ch: input / output channel counts.
        k, s, p: kernel size, stride and padding of the transposed convolution.
        bias: whether the transposed convolution has a bias term.
        act: "relu" or "lrelu" (LeakyReLU, slope 0.2).

    Raises:
        ValueError: if `act` is neither "relu" nor "lrelu".
    """
    def __init__(self, in_ch, out_ch, k=4, s=2, p=1, bias=True, act="relu"):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(
            in_ch,
            out_ch,
            kernel_size=k,
            stride=s,
            padding=p,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(out_ch)

        # Guard clause first, then pick between the two supported activations.
        if act not in ("relu", "lrelu"):
            raise ValueError("act must be 'relu' or 'lrelu'")
        self.act = nn.ReLU(inplace=True) if act == "relu" else nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        upsampled = self.deconv(x)
        upsampled = self.norm(upsampled)
        return self.act(upsampled)


class ResBlock1D(nn.Module):
    """
    Two-convolution residual block.

    Computes relu(x + BN(conv(relu(BN(conv(x)))))). Both convolutions use
    stride 1 and map `ch` channels to `ch` channels, so with the default
    k=7 / p=3 the output has the same shape as the input.
    """
    def __init__(self, ch, k=7, p=3, bias=True):
        super().__init__()
        conv_kwargs = dict(kernel_size=k, stride=1, padding=p, bias=bias)

        self.c1 = nn.Conv1d(ch, ch, **conv_kwargs)
        self.n1 = nn.BatchNorm1d(ch)
        self.c2 = nn.Conv1d(ch, ch, **conv_kwargs)
        self.n2 = nn.BatchNorm1d(ch)

    def forward(self, x):
        # First conv stage with activation.
        branch = self.c1(x)
        branch = self.n1(branch)
        branch = F.relu(branch)

        # Second conv stage; activation is applied after the skip addition.
        branch = self.c2(branch)
        branch = self.n2(branch)

        return F.relu(x + branch)
    
class MultiScaleResBlock1D(nn.Module):
    """
    Residual block with three parallel Conv-BN-ReLU branches (kernel 3, 5, 7).

    The branch outputs are summed, passed through a 1x1 Conv + BN "fuse" stage,
    added to the input and rectified. Channel count and length are unchanged.
    """
    def __init__(self, ch, bias=True):
        super().__init__()

        self.b3 = self._branch(ch, 3, bias)
        self.b5 = self._branch(ch, 5, bias)
        self.b7 = self._branch(ch, 7, bias)

        self.fuse = nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=1, bias=bias),
            nn.BatchNorm1d(ch),
        )

    @staticmethod
    def _branch(ch, kernel, bias):
        """Conv-BN-ReLU branch whose padding (kernel // 2) keeps the length."""
        return nn.Sequential(
            nn.Conv1d(ch, ch, kernel_size=kernel, padding=kernel // 2, bias=bias),
            nn.BatchNorm1d(ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Accumulate the branches in the fixed order k=3, k=5, k=7.
        mixed = self.b3(x)
        mixed = mixed + self.b5(x)
        mixed = mixed + self.b7(x)

        mixed = self.fuse(mixed)
        return F.relu(x + mixed)



In [6]:
# NN Generator (U-Net-ish + Res bottleneck)
class GeneratorCNNWGAN(nn.Module):
    """
    U-Net-style 1D CNN generator for EEG denoising (WGAN).

    Four stride-2 encoder stages, a stack of residual blocks at the bottleneck,
    and four stride-2 transposed-conv decoder stages. The outputs of the first
    three decoder stages are concatenated (channel axis) with the mirrored
    encoder feature maps before entering the next stage.

    Input : (B, 1, 512) noisy_norm
    Output: (B, 1, 512) clean_norm_hat
    """
    def __init__(self, base_ch=32, bottleneck_blocks=4, bias=True):
        super().__init__()

        ch1, ch2, ch4, ch8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8
        ch_half = base_ch // 2
        down = dict(k=16, s=2, p=7, bias=bias, act="lrelu")   # halves the length
        up = dict(k=4, s=2, p=1, bias=bias, act="relu")       # doubles the length

        # Encoder: 512 -> 256 -> 128 -> 64 -> 32
        self.e1 = ConvBlock1D(1, ch1, **down)
        self.e2 = ConvBlock1D(ch1, ch2, **down)
        self.e3 = ConvBlock1D(ch2, ch4, **down)
        self.e4 = ConvBlock1D(ch4, ch8, **down)

        # Bottleneck: residual blocks at 8 * base_ch channels
        self.bottleneck = nn.Sequential(
            *(ResBlock1D(ch8, k=7, p=3, bias=bias) for _ in range(bottleneck_blocks))
        )

        # Decoder: 32 -> 64 -> 128 -> 256 -> 512.
        # d2..d4 take twice the previous stage's output channels because of the skip concat.
        self.d1 = DeconvBlock1D(ch8, ch4, **up)
        self.d2 = DeconvBlock1D(ch8, ch2, **up)
        self.d3 = DeconvBlock1D(ch4, ch1, **up)
        self.d4 = DeconvBlock1D(ch2, ch_half, **up)

        # Linear head (no activation) for normalized signals
        self.out = nn.Conv1d(ch_half, 1, kernel_size=7, stride=1, padding=3, bias=bias)

    def forward(self, y):
        # Encoder feature maps, kept for the skip connections
        skip1 = self.e1(y)        # (B, b, 256)
        skip2 = self.e2(skip1)    # (B, 2b, 128)
        skip3 = self.e3(skip2)    # (B, 4b, 64)
        deepest = self.e4(skip3)  # (B, 8b, 32)

        h = self.bottleneck(deepest)

        # Decoder: upsample, then append the matching encoder map on the channel axis
        h = torch.cat([self.d1(h), skip3], dim=1)  # (B, 8b, 64)
        h = torch.cat([self.d2(h), skip2], dim=1)  # (B, 4b, 128)
        h = torch.cat([self.d3(h), skip1], dim=1)  # (B, 2b, 256)
        h = self.d4(h)                             # (B, b/2, 512)

        return self.out(h)                         # (B, 1, 512)
    
# Patch Critic (shared by CNN/ResCNN)
class CriticPatch1D(nn.Module):
    """
    Conditional PatchGAN-style critic for WGAN: D(y, x) -> patch scores.

    The two signals are stacked on the channel axis, reduced by four stride-2
    convolution stages and mapped to a single score channel.

    y, x  : (B, 1, 512)
    output: (B, 1, 32)
    """
    def __init__(self, base_ch=32, bias=True):
        super().__init__()
        down = dict(k=16, s=2, p=7, bias=bias, act="lrelu")

        # First stage is a bare convolution (no BatchNorm); its LeakyReLU lives in forward().
        self.c1 = nn.Conv1d(2, base_ch, kernel_size=16, stride=2, padding=7, bias=bias)  # 512 -> 256
        self.c2 = ConvBlock1D(base_ch, base_ch * 2, **down)        # 256 -> 128
        self.c3 = ConvBlock1D(base_ch * 2, base_ch * 4, **down)    # 128 -> 64
        self.c4 = ConvBlock1D(base_ch * 4, base_ch * 8, **down)    # 64 -> 32
        self.out = nn.Conv1d(base_ch * 8, 1, kernel_size=7, stride=1, padding=3, bias=bias)  # 32 -> 32

    def forward(self, y, x):
        pair = torch.cat([y, x], dim=1)  # (B, 2, 512)
        h = F.leaky_relu(self.c1(pair), 0.2, inplace=True)
        for stage in (self.c2, self.c3, self.c4):
            h = stage(h)
        return self.out(h)


In [7]:
class FusedConvBlock1D(nn.Module):
    """
    Conv1d -> Activation (no normalisation layer).

    Counterpart of ConvBlock1D for weights that already have the BatchNorm
    folded into the convolution.

    Args:
        in_ch, out_ch: input / output channel counts.
        k, s, p: kernel size, stride and padding of the convolution.
        bias: whether the convolution has a bias term.
        act: "lrelu" (LeakyReLU, slope 0.2) or "relu".

    Raises:
        ValueError: if `act` is neither "lrelu" nor "relu".
    """
    def __init__(self, in_ch, out_ch, k, s, p, bias=True, act="lrelu"):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, k, s, p, bias=bias)

        if act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        elif act == "relu":
            self.act = nn.ReLU(inplace=True)
        else:
            # Same message as ConvBlock1D so a bad `act` is diagnosable.
            raise ValueError("act must be 'lrelu' or 'relu'")

    def forward(self, x):
        return self.act(self.conv(x))


class FusedDeconvBlock1D(nn.Module):
    """
    ConvTranspose1d -> Activation (no normalisation layer).

    Counterpart of DeconvBlock1D for weights that already have the BatchNorm
    folded into the transposed convolution.

    Args:
        in_ch, out_ch: input / output channel counts.
        k, s, p: kernel size, stride and padding of the transposed convolution.
        bias: whether the transposed convolution has a bias term.
        act: "relu" or "lrelu" (LeakyReLU, slope 0.2).

    Raises:
        ValueError: if `act` is neither "relu" nor "lrelu".
    """
    def __init__(self, in_ch, out_ch, k, s, p, bias=True, act="relu"):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(in_ch, out_ch, k, s, p, bias=bias)

        if act == "relu":
            self.act = nn.ReLU(inplace=True)
        elif act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            # Same message as DeconvBlock1D so a bad `act` is diagnosable.
            raise ValueError("act must be 'relu' or 'lrelu'")

    def forward(self, x):
        return self.act(self.deconv(x))


In [8]:
class GeneratorCNNWGAN_Fused(nn.Module):
    """
    Normalisation-free generator: encoder, residual bottleneck, decoder with
    skip concatenation, and a linear output convolution.

    Every stage is a plain (transposed) convolution plus activation. Each
    bottleneck entry is a Sequential of [Conv1d, ReLU, Conv1d]; the residual
    addition and final ReLU of each block are applied in forward().
    """
    def __init__(self, base_ch=32, bottleneck_blocks=4, bias=True):
        super().__init__()

        ch1, ch2, ch4, ch8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8
        ch_half = base_ch // 2
        down = dict(k=16, s=2, p=7, bias=bias, act="lrelu")
        up = dict(k=4, s=2, p=1, bias=bias)  # activation left at the block default ("relu")

        self.e1 = FusedConvBlock1D(1, ch1, **down)
        self.e2 = FusedConvBlock1D(ch1, ch2, **down)
        self.e3 = FusedConvBlock1D(ch2, ch4, **down)
        self.e4 = FusedConvBlock1D(ch4, ch8, **down)

        def residual_body():
            # Index layout matters to weight-copy code: 0 = conv, 1 = ReLU, 2 = conv.
            return nn.Sequential(
                nn.Conv1d(ch8, ch8, 7, 1, 3, bias=bias),
                nn.ReLU(inplace=True),
                nn.Conv1d(ch8, ch8, 7, 1, 3, bias=bias),
            )

        self.bottleneck = nn.Sequential(
            *(residual_body() for _ in range(bottleneck_blocks))
        )

        self.d1 = FusedDeconvBlock1D(ch8, ch4, **up)
        self.d2 = FusedDeconvBlock1D(ch8, ch2, **up)
        self.d3 = FusedDeconvBlock1D(ch4, ch1, **up)
        self.d4 = FusedDeconvBlock1D(ch2, ch_half, **up)

        self.out = nn.Conv1d(ch_half, 1, 7, 1, 3, bias=bias)

    def forward(self, y):
        skip1 = self.e1(y)
        skip2 = self.e2(skip1)
        skip3 = self.e3(skip2)
        h = self.e4(skip3)

        # Residual blocks: skip addition and activation are done here, not in the Sequential.
        for body in self.bottleneck:
            h = F.relu(h + body(h))

        h = torch.cat([self.d1(h), skip3], dim=1)
        h = torch.cat([self.d2(h), skip2], dim=1)
        h = torch.cat([self.d3(h), skip1], dim=1)
        h = self.d4(h)

        return self.out(h)


In [9]:
def fuse_conv_bn_1d(
    conv_weight,        # (Cout, Cin, K)
    conv_bias,          # (Cout,) or None
    running_mean,       # (Cout,)
    running_var,        # (Cout,)
    bn_weight,          # (Cout,) or None (gamma)
    bn_bias,            # (Cout,) or None (beta)
    eps=1e-5
):
    """
    Fold BatchNorm statistics into the preceding Conv1d parameters.

    With scale = gamma / sqrt(running_var + eps):
        fused_weight = conv_weight * scale      (per output channel)
        fused_bias   = (conv_bias - running_mean) * scale + beta

    A missing gamma is treated as ones; a missing beta or conv bias as zeros.

    Returns:
        (fused_weight, fused_bias) with shapes (Cout, Cin, K) and (Cout,).
    """
    n_out = conv_weight.shape[0]
    like_weight = dict(device=conv_weight.device, dtype=conv_weight.dtype)

    gamma = torch.ones(n_out, **like_weight) if bn_weight is None else bn_weight
    beta = torch.zeros(n_out, **like_weight) if bn_bias is None else bn_bias
    bias0 = torch.zeros(n_out, **like_weight) if conv_bias is None else conv_bias

    # Per-output-channel multiplier applied by BatchNorm in inference mode.
    scale = gamma / torch.sqrt(running_var + eps)

    fused_weight = conv_weight * scale.reshape(-1, 1, 1)
    fused_bias = (bias0 - running_mean) * scale + beta

    return fused_weight, fused_bias

def fuse_deconv_bn_1d(
    deconv_weight,     # (Cin, Cout, K)
    deconv_bias,       # (Cout,) or None
    running_mean,      # (Cout,)
    running_var,       # (Cout,)
    bn_weight,         # (Cout,) or None (gamma)
    bn_bias,           # (Cout,) or None (beta)
    eps=1e-5
):
    """
    Fold BatchNorm statistics into the preceding ConvTranspose1d parameters.

    With scale = gamma / sqrt(running_var + eps):
        fused_weight = deconv_weight * scale    (along dim 1, the output channels)
        fused_bias   = (deconv_bias - running_mean) * scale + beta

    A missing gamma is treated as ones; a missing beta or deconv bias as zeros,
    mirroring fuse_conv_bn_1d. `eps` defaults to 1e-5 (the value shown for the
    BatchNorm1d layers in this notebook) and can still be passed explicitly.

    NOTE(review): dim 1 of the weight is taken as Cout, which assumes an
    ungrouped transposed convolution (groups=1) — confirm before reusing elsewhere.

    Returns:
        (fused_weight, fused_bias) with shapes (Cin, Cout, K) and (Cout,).
    """
    Cin, Cout, K = deconv_weight.shape
    like_weight = dict(device=deconv_weight.device, dtype=deconv_weight.dtype)

    if bn_weight is None:
        bn_weight = torch.ones(Cout, **like_weight)
    if bn_bias is None:
        bn_bias = torch.zeros(Cout, **like_weight)
    if deconv_bias is None:
        deconv_bias = torch.zeros(Cout, **like_weight)

    # BN scale
    scale = bn_weight / torch.sqrt(running_var + eps)  # (Cout,)

    # Fuse weights (scale on Cout dimension)
    fused_weight = deconv_weight * scale.view(1, Cout, 1)

    # Fuse bias
    fused_bias = (deconv_bias - running_mean) * scale + bn_bias

    return fused_weight, fused_bias


In [10]:
def fuse_generator(G):
    """
    Build a BatchNorm-free copy of a GeneratorCNNWGAN.

    Every Conv/Deconv + BatchNorm pair of `G` is folded into a single
    (transposed) convolution using the BatchNorm running statistics, and the
    result is copied into a freshly created GeneratorCNNWGAN_Fused.

    The fused model:
      * uses the same base channel count and number of bottleneck blocks as `G`
        (read from `G.e1.conv.out_channels` and `len(G.bottleneck)`),
      * lives on the same device / dtype as `G`'s parameters,
      * always has bias terms, because folding BatchNorm produces a bias even
        when the original convolution had none.

    The folding uses running statistics, so the fused model corresponds to `G`
    in eval mode.

    Args:
        G: a GeneratorCNNWGAN instance (typically with trained weights loaded).

    Returns:
        GeneratorCNNWGAN_Fused in eval mode.
    """
    ref_param = next(G.parameters())

    G_fused = GeneratorCNNWGAN_Fused(
        base_ch=G.e1.conv.out_channels,
        bottleneck_blocks=len(G.bottleneck),
        bias=True,
    ).to(device=ref_param.device, dtype=ref_param.dtype)
    G_fused.eval()

    with torch.no_grad():
        # Encoder
        for i in range(1, 5):
            e = getattr(G, f"e{i}")
            fe = getattr(G_fused, f"e{i}")

            w, b = fuse_conv_bn_1d(
                e.conv.weight, e.conv.bias,
                e.norm.running_mean, e.norm.running_var,
                e.norm.weight, e.norm.bias,
                eps=e.norm.eps,
            )
            fe.conv.weight.copy_(w)
            fe.conv.bias.copy_(b)

        # Bottleneck: fused block layout is [conv, relu, conv] -> indices 0 and 2
        for i, blk in enumerate(G.bottleneck):
            fblk = G_fused.bottleneck[i]

            for j, (c, n) in enumerate([(blk.c1, blk.n1), (blk.c2, blk.n2)]):
                w, b = fuse_conv_bn_1d(
                    c.weight, c.bias,
                    n.running_mean, n.running_var,
                    n.weight, n.bias,
                    eps=n.eps,
                )
                fblk[j*2].weight.copy_(w)
                fblk[j*2].bias.copy_(b)

        # Decoder
        for i in range(1, 5):
            d = getattr(G, f"d{i}")
            fd = getattr(G_fused, f"d{i}")

            w, b = fuse_deconv_bn_1d(
                d.deconv.weight, d.deconv.bias,
                d.norm.running_mean, d.norm.running_var,
                d.norm.weight, d.norm.bias,
                d.norm.eps
            )
            fd.deconv.weight.copy_(w)
            fd.deconv.bias.copy_(b)

        # Output head has no BatchNorm: plain copy. A bias-free head maps to a zero bias.
        G_fused.out.weight.copy_(G.out.weight)
        if G.out.bias is None:
            G_fused.out.bias.zero_()
        else:
            G_fused.out.bias.copy_(G.out.bias)

    return G_fused


# Model Import

In [11]:
# Checkpoint selection knobs.
# BIAS picks the "b" / "nb" suffix of the model directory and is passed as `bias=` to the models.
BIAS = True
# DATA_MODE is embedded in the model directory and output file names.
DATA_MODE = 5 # up to 5

In [12]:
# Locate the generator / critic checkpoints for the selected configuration.
# The tag is computed outside the f-string: reusing the same quote character
# inside an f-string expression only parses on Python 3.12+.
bias_tag = "b" if BIAS else "nb"
data_path = os.path.abspath(f"../models/main3_d{DATA_MODE}_{bias_tag}/")
print(data_path)

# File names look like cnn_G_YYYYMMDD_HHMMSS.pth / cnn_D_YYYYMMDD_HHMMSS.pth
pattern = re.compile(r"^cnn_([DG])_\d{8}_\d{6}\.pth$")

cnn_G_path = None
cnn_D_path = None

# os.listdir() returns entries in arbitrary order. Sorting makes the choice
# deterministic: the fixed-width timestamp sorts chronologically, so the
# newest checkpoint of each kind is the one that is kept.
for f in sorted(os.listdir(data_path)):
    m = pattern.match(f)
    if m:
        full = os.path.join(data_path, f)
        if m.group(1) == "G":
            cnn_G_path = full
        else:
            cnn_D_path = full

print("G:", cnn_G_path)
print("D:", cnn_D_path)

# Fail here with a clear message instead of later inside torch.load(None).
if cnn_G_path is None or cnn_D_path is None:
    raise FileNotFoundError(
        f"Expected cnn_G_*.pth and cnn_D_*.pth checkpoints in {data_path}"
    )

c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d5_b
G: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d5_b\cnn_G_20260114_024826.pth
D: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d5_b\cnn_D_20260114_024826.pth


In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the architectures, then load the trained weights into them.
G = GeneratorCNNWGAN(bias=BIAS).to(device)
D = CriticPatch1D(bias=BIAS).to(device)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot execute arbitrary code while being loaded.
G.load_state_dict(torch.load(cnn_G_path, map_location=device, weights_only=True))
D.load_state_dict(torch.load(cnn_D_path, map_location=device, weights_only=True))

# eval(): BatchNorm layers use their running statistics from here on.
print(G.eval())
print(D.eval())

GeneratorCNNWGAN(
  (e1): ConvBlock1D(
    (conv): Conv1d(1, 32, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e2): ConvBlock1D(
    (conv): Conv1d(32, 64, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e3): ConvBlock1D(
    (conv): Conv1d(64, 128, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e4): ConvBlock1D(
    (conv): Conv1d(128, 256, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  

  G.load_state_dict(torch.load(cnn_G_path, map_location=device))
  D.load_state_dict(torch.load(cnn_D_path, map_location=device))


In [14]:
G_fused = fuse_generator(G)

In [15]:
# Signed fixed-point formats. float_to_q() treats int_bits + frac_bits as the
# total two's-complement word width (sign bit included) when clipping, and
# `dtype` is the NumPy integer type used to store the quantized values.
# NOTE(review): the key "Q9.14" is declared with int_bits=10 (24-bit word);
# either the name or the value looks off by one — confirm against the HW format.
Q_CONFIGS = {
    "Q4.12": dict(frac_bits=12, int_bits=4, dtype=np.int16),
    "Q10.10": dict(frac_bits=10, int_bits=10, dtype=np.int32),
    "Q9.14": dict(frac_bits=14, int_bits=10, dtype=np.int32),
}

def float_to_q(x, frac_bits, int_bits, dtype):
    """
    Quantize float value(s) to signed fixed point.

    The input is multiplied by 2**frac_bits, rounded with np.round, saturated
    to the signed range of an (int_bits + frac_bits)-bit word, and cast to `dtype`.

    Args:
        x: NumPy array or scalar of float values.
        frac_bits: number of fractional bits.
        int_bits: number of integer bits (counted together with frac_bits as the word width).
        dtype: NumPy integer dtype of the result.
    """
    word_bits = frac_bits + int_bits
    half_range = 1 << (word_bits - 1)

    scaled = np.round(x * (1 << frac_bits))
    saturated = np.clip(scaled, -half_range, half_range - 1)
    return saturated.astype(dtype)

def q_to_float(x, frac_bits):
    """Convert fixed-point integers back to float32 by dividing by 2**frac_bits."""
    as_float = x.astype(np.float32)
    return as_float / (1 << frac_bits)

In [16]:
# Numeric format used for the exported reference data.
TYPE =  "Q9.14" # "Q4.12", "Q10.10","Q9.14", or "FLOAT"

# Per-stage lists of lengths (one entry per layer).
# NOTE(review): presumably the output time length of each layer, with two
# entries per bottleneck block — confirm against the consumer of this table.
TIME_REPEAT = {
    "encoder":    [256, 128, 64, 32],
    "bottleneck": [32]*8,
    "decoder":    [64, 128, 256, 512],
    "out":        [512]
}

if TYPE == "FLOAT":
    MODE = "FLOAT"
    DTYPE = np.float32
    # Always (re)define the fixed-point parameters so values from an earlier
    # fixed-point run cannot linger in the kernel and be used silently.
    FRAC_BITS = None
    INT_BITS  = None
elif TYPE in Q_CONFIGS:
    MODE = "FIXED"
    cfg = Q_CONFIGS[TYPE]
    FRAC_BITS = cfg["frac_bits"]
    INT_BITS  = cfg["int_bits"]
    DTYPE     = cfg["dtype"]
else:
    raise ValueError(
        f"Unknown TYPE {TYPE!r}; expected 'FLOAT' or one of {sorted(Q_CONFIGS)}"
    )

# Result

In [17]:
# Destination directory for the exported .mem files; one directory per numeric format (TYPE).
# exist_ok=True makes the cell safe to re-run.
output_path = os.path.abspath(f"../models/sample_output_{TYPE}/")
os.makedirs(output_path, exist_ok=True)

In [18]:
# Full generator forward pass on the synthetic encoder input
G.eval()
encoder_input_batched = encoder_input.unsqueeze(0)  # add batch axis: (1, 1, 512)

with torch.no_grad():
    out_all = G(encoder_input_batched)

# Drop the batch axis again and move to NumPy for quantization / export.
out_2d_all = out_all.squeeze(0)
out_f_all = out_2d_all.detach().cpu().numpy().astype(np.float32)

assert out_f_all.ndim == 2, "Expected 2D output [x, y]"
X_all, Y_all = out_f_all.shape

# Banner names what this cell actually computes (the whole generator, not only d4).
print("===== GENERATOR OUTPUT (FULL FORWARD PASS) =====")
print(f"Shape      : ({X_all}, {Y_all})")
print(f"Total vals : {X_all * Y_all}")

===== DECODER LAST OUTPUT =====
Shape      : (1, 512)
Total vals : 512


In [19]:
# Reference output of the last decoder stage (G.d4) on the synthetic d4 input.
# The stage is run layer by layer: transposed conv -> BatchNorm -> activation.
G.eval()
decoder_last_input_batched = decoder_last_input.unsqueeze(0)

with torch.no_grad():
    out_deconv_ref = G.d4.deconv(decoder_last_input_batched)
    out_bn_ref     = G.d4.norm(out_deconv_ref)
    # G.d4.act is nn.ReLU(inplace=True): it overwrites out_bn_ref, so after this
    # line out_bn_ref holds the activated values too (pre-activation values are gone).
    out_d4_ref     = G.d4.act(out_bn_ref)

# Drop the batch axis and move to NumPy for quantization / export.
out_2d_decoder_last = out_d4_ref.squeeze(0)
out_f_decoder_last = out_2d_decoder_last.detach().cpu().numpy().astype(np.float32)

assert out_f_decoder_last.ndim == 2, "Expected 2D output [x, y]"
X_decoder_last, Y_decoder_last = out_f_decoder_last.shape

print("===== DECODER LAST OUTPUT =====")
print(f"Shape      : ({X_decoder_last}, {Y_decoder_last})")
print(f"Total vals : {X_decoder_last * Y_decoder_last}")

===== DECODER LAST OUTPUT =====
Shape      : (16, 512)
Total vals : 8192


In [20]:
# ---- QUANTIZE ----
out_q_all = float_to_q(
    out_f_all,
    frac_bits=FRAC_BITS,
    int_bits=INT_BITS,
    dtype=DTYPE
)

mem_name_all_hex = f"encoder_output_d{DATA_MODE}_format.mem"
mem_name_all_raw = f"encoder_output_d{DATA_MODE}_raw.mem"

mem_path_all_hex = os.path.join(output_path, mem_name_all_hex)
mem_path_all_raw = os.path.join(output_path, mem_name_all_raw)

with open(mem_path_all_hex, "w") as f_hex, open(mem_path_all_raw, "w") as f_raw:
    for y in range(Y_all):
        for x in range(X_all):
            v = int(out_q_all[x, y])

            # HEX (two's complement, 32-bit)
            f_hex.write(f"{v & 0xFFFFFFFF:08X}\n")

            # RAW signed decimal
            f_raw.write(f"{v}\n")

print("Saved outputs:")
print(f"  HEX : {mem_name_all_hex}")
print(f"  RAW : {mem_name_all_raw}")
print(f"  Shape: ({X_all}, {Y_all}), total entries = {X_all * Y_all}")

Saved outputs:
  HEX : encoder_output_d5_format.mem
  RAW : encoder_output_d5_raw.mem
  Shape: (1, 512), total entries = 512


In [21]:
# ---- QUANTIZE ----
out_q_decoder_last = float_to_q(out_f_decoder_last, frac_bits=FRAC_BITS, int_bits=INT_BITS, dtype=DTYPE)

# ---- FILE NAMES / PATHS ----
mem_name_decoder_last_hex = f"decoder_last_output_d{DATA_MODE}_format.mem"
mem_name_decoder_last_raw = f"decoder_last_output_d{DATA_MODE}_raw.mem"
mem_path_decoder_last_hex, mem_path_decoder_last_raw = (
    os.path.join(output_path, name)
    for name in (mem_name_decoder_last_hex, mem_name_decoder_last_raw)
)

# ---- WRITE ----
# One value per line, row-major: the first array axis (x) is the outer index,
# the second axis (y) the inner one.
with open(mem_path_decoder_last_hex, "w") as f_hex, open(mem_path_decoder_last_raw, "w") as f_raw:
    for x, y in np.ndindex(X_decoder_last, Y_decoder_last):
        v = int(out_q_decoder_last[x, y])
        f_hex.write(f"{v & 0xFFFFFFFF:08X}\n")  # 32-bit two's-complement hex
        f_raw.write(f"{v}\n")                   # signed decimal

print(
    "Saved outputs:",
    f"  HEX : {mem_name_decoder_last_hex}",
    f"  RAW : {mem_name_decoder_last_raw}",
    f"  Shape: ({X_decoder_last}, {Y_decoder_last}), total entries = {X_decoder_last * Y_decoder_last}",
    sep="\n",
)

Saved outputs:
  HEX : decoder_last_output_d5_format.mem
  RAW : decoder_last_output_d5_raw.mem
  Shape: (16, 512), total entries = 8192


In [28]:
# Debug: show the synthetic input that is fed to the last decoder stage (G.d4).
print(decoder_last_input)

tensor([[ 0.6341, -1.2250,  0.0230,  ...,  0.2211,  0.2900, -0.1713],
        [-0.4107, -1.2087,  1.7604,  ..., -0.4288, -0.4538,  0.9026],
        [-0.3861,  0.3538,  0.8009,  ...,  0.5962, -1.7025, -0.3186],
        ...,
        [ 0.4007, -0.3718, -1.2671,  ...,  1.2490,  0.5307, -1.2224],
        [-0.6826,  0.6642,  1.0837,  ...,  0.5124, -0.3639,  0.7698],
        [-0.7566, -0.0156, -0.9006,  ..., -1.7253,  0.0457,  0.4532]],
       device='cuda:0')


In [29]:
# Debug: show the BatchNorm-folded weight and bias of the last decoder stage.
print(G_fused.d4.deconv.weight)
print(G_fused.d4.deconv.bias)

Parameter containing:
tensor([[[ 0.0228,  0.1653,  0.0051, -0.1294],
         [-0.0901, -0.0611,  0.1014,  0.0291],
         [ 0.0008, -0.0930,  0.0647, -0.0078],
         ...,
         [-0.0497, -0.1010, -0.0566,  0.0056],
         [-0.0846,  0.0546, -0.0636,  0.0020],
         [-0.0536,  0.0287, -0.0322,  0.0185]],

        [[ 0.0769, -0.0828,  0.1180,  0.0688],
         [-0.0002, -0.0019,  0.0623, -0.1138],
         [-0.0240, -0.0500, -0.0507, -0.0429],
         ...,
         [-0.1126,  0.0289, -0.1446,  0.1001],
         [ 0.0340, -0.0484,  0.0286,  0.1241],
         [ 0.0700, -0.0342,  0.0891,  0.0096]],

        [[-0.0995,  0.0514, -0.1144,  0.0107],
         [ 0.0730, -0.0616, -0.0276, -0.0601],
         [ 0.0722, -0.0767,  0.0243, -0.0774],
         ...,
         [-0.0025,  0.0232, -0.0602, -0.0437],
         [-0.1055, -0.0083, -0.1348, -0.0881],
         [ 0.0371,  0.0029,  0.0282, -0.1540]],

        ...,

        [[ 0.0081,  0.1087, -0.0645, -0.0016],
         [ 0.0173,  0.0

In [30]:
def deconv1d_output_at(
    x,          # input tensor, shape (Cin, Tin)
    W,          # deconv weight, shape (Cin, Cout, K)
    b,          # bias, shape (Cout,) or None
    oc,         # output channel index (0-based)
    t_out,      # output time index (0-based)
    FRAC_BITS,
    INT_BITS,
    DTYPE,
    stride=2,
    padding=1,
):
    """
    Compute a single ConvTranspose1d output sample y[oc, t_out] (pre-activation).

    This is a GOLDEN reference for HW:
    - No tiling
    - No PE logic
    - Pure math

    The sample is accumulated as a Python float over every kernel tap k and
    input channel for which t_in = (t_out + padding - k) / stride is an integer
    inside [0, Tin). Dilation and output_padding are not modelled (dilation 1,
    output_padding 0). No activation is applied; the caller does that.

    Args:
        x: input of shape (Cin, Tin).
        W: transposed-conv weight of shape (Cin, Cout, K).
        b: bias of shape (Cout,), or None for no bias.
        oc: output channel index.
        t_out: output time index; must be below (Tin - 1) * stride - 2 * padding + K.
        FRAC_BITS, INT_BITS, DTYPE: fixed-point format forwarded to float_to_q.
        stride, padding: transposed-conv hyper-parameters (defaults 2 and 1).

    Returns:
        (y_float, y_fixed): the float result and its quantized value.

    Raises:
        AssertionError: on channel mismatch between x and W, or an out-of-range
            `oc` / `t_out`.
    """

    Cin, Tin = x.shape
    w_cin, Cout, K = W.shape

    assert w_cin == Cin, "W and x disagree on the number of input channels"
    assert 0 <= oc < Cout

    # Without this check an out-of-range t_out would silently return the bias only.
    Tout = (Tin - 1) * stride - 2 * padding + K
    assert 0 <= t_out < Tout, f"t_out must be in [0, {Tout})"

    # --- FLOAT ACCUMULATION ---
    y_float = 0.0 if b is None else b[oc].item()

    for k_pos in range(K):
        # Transposed-conv inverse mapping: t_out = t_in * stride - padding + k_pos
        numerator = t_out + padding - k_pos

        # Must align with stride
        if numerator % stride != 0:
            continue

        t_in = numerator // stride

        # Bounds check
        if t_in < 0 or t_in >= Tin:
            continue

        for cin in range(Cin):
            y_float += (
                x[cin, t_in].item()
                * W[cin, oc, k_pos].item()
            )

    # --- FIXED-POINT ---
    y_fixed = float_to_q(
        y_float,
        frac_bits=FRAC_BITS,
        int_bits=INT_BITS,
        dtype=DTYPE
    )

    return y_float, y_fixed

In [39]:
# Spot-check one output sample of the fused last decoder stage.
probe_oc, probe_t_out = 0, 2

y_float, y_fixed = deconv1d_output_at(
    x=decoder_last_input,
    W=G_fused.d4.deconv.weight,
    b=G_fused.d4.deconv.bias,
    oc=probe_oc,
    t_out=probe_t_out,
    FRAC_BITS=FRAC_BITS,
    INT_BITS=INT_BITS,
    DTYPE=DTYPE,
)

# The helper returns the pre-activation value; clamp negatives to 0 here.
# Labels carry the actual probed indices instead of a fixed [0,0].
print(f"y_float[{probe_oc},{probe_t_out}] =", y_float if y_float > 0 else 0)
print(f"y_fixed[{probe_oc},{probe_t_out}] =", y_fixed if y_fixed > 0 else 0)

y_float[0,0] = 0
y_fixed[0,0] = 0
