# Import Libraries

In [286]:
import os
import numpy as np
import pandas as pd
import scipy.io as sio
import matplotlib.pyplot as plt
import math
import json
import re

import torch
import torch.nn as nn
import torch.nn.functional as F

import random
np.random.seed(42)
random.seed(42)

## Model Definition

In [287]:
class ConvBlock1D(nn.Module):
    """
    1-D convolution block: Conv1d -> BatchNorm1d -> activation.

    `act` selects the non-linearity: "lrelu" (LeakyReLU, slope 0.2) or
    "relu" (ReLU). Any other value raises ValueError.
    """
    def __init__(self, in_ch, out_ch, k=15, s=2, p=7, bias=True, act="lrelu"):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=k,
            stride=s,
            padding=p,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(num_features=out_ch)

        # Resolve the activation by name; unknown names are rejected below.
        activation = None
        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "lrelu":
            activation = nn.LeakyReLU(0.2, inplace=True)
        if activation is None:
            raise ValueError("act must be 'lrelu' or 'relu'")
        self.act = activation

    def forward(self, x):
        # One stage per line: convolution, normalisation, activation.
        conv_out = self.conv(x)
        normed = self.norm(conv_out)
        return self.act(normed)


class DeconvBlock1D(nn.Module):
    """
    1-D transposed-convolution block: ConvTranspose1d -> BatchNorm1d -> activation.

    `act` selects the non-linearity: "relu" (ReLU) or "lrelu" (LeakyReLU,
    slope 0.2). Any other value raises ValueError.
    """
    def __init__(self, in_ch, out_ch, k=4, s=2, p=1, bias=True, act="relu"):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=k,
            stride=s,
            padding=p,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(num_features=out_ch)

        # Resolve the activation by name; unknown names are rejected below.
        activation = None
        if act == "lrelu":
            activation = nn.LeakyReLU(0.2, inplace=True)
        elif act == "relu":
            activation = nn.ReLU(inplace=True)
        if activation is None:
            raise ValueError("act must be 'relu' or 'lrelu'")
        self.act = activation

    def forward(self, x):
        # One stage per line: transposed convolution, normalisation, activation.
        upsampled = self.deconv(x)
        normed = self.norm(upsampled)
        return self.act(normed)


class ResBlock1D(nn.Module):
    """
    Residual block: Conv -> BN -> ReLU -> Conv -> BN, then the block input
    is added back and a final ReLU is applied.

    Channel count is unchanged; with the default k=7, p=3 (stride 1) the
    sequence length is unchanged too, which the skip addition requires.
    """
    def __init__(self, ch, k=7, p=3, bias=True):
        super().__init__()
        # Both convolutions share the same geometry.
        conv_kwargs = dict(kernel_size=k, stride=1, padding=p, bias=bias)
        self.c1 = nn.Conv1d(ch, ch, **conv_kwargs)
        self.n1 = nn.BatchNorm1d(ch)
        self.c2 = nn.Conv1d(ch, ch, **conv_kwargs)
        self.n2 = nn.BatchNorm1d(ch)

    def forward(self, x):
        out = self.c1(x)
        out = self.n1(out)
        out = F.relu(out)
        out = self.c2(out)
        out = self.n2(out)
        # Skip connection, then the closing non-linearity.
        return F.relu(x + out)
    
class MultiScaleResBlock1D(nn.Module):
    """
    Multi-scale residual block.

    Three parallel Conv -> BN -> ReLU branches (kernel sizes 3, 5 and 7,
    each padded to preserve length) are summed, passed through a 1x1
    Conv -> BN fusion layer, added to the block input and rectified.
    Channel count and length are unchanged.
    """
    def __init__(self, ch, bias=True):
        super().__init__()

        def make_branch(kernel):
            # padding = kernel // 2 keeps the length for these odd kernels.
            return nn.Sequential(
                nn.Conv1d(ch, ch, kernel_size=kernel, padding=kernel // 2, bias=bias),
                nn.BatchNorm1d(ch),
                nn.ReLU(inplace=True),
            )

        self.b3 = make_branch(3)
        self.b5 = make_branch(5)
        self.b7 = make_branch(7)

        pointwise = nn.Conv1d(ch, ch, kernel_size=1, bias=bias)
        self.fuse = nn.Sequential(pointwise, nn.BatchNorm1d(ch))

    def forward(self, x):
        # Accumulate the branches in the order b3, b5, b7.
        multi = self.b3(x)
        multi = multi + self.b5(x)
        multi = multi + self.b7(x)
        fused = self.fuse(multi)
        return F.relu(x + fused)



In [288]:
# NN Generator (U-Net-ish + Res bottleneck)
class GeneratorCNNWGAN(nn.Module):
    """
    CNN U-Net-style generator for EEG denoising (WGAN).

    Four stride-2 encoder blocks, a stack of residual blocks at the
    bottleneck, and four stride-2 decoder blocks; the outputs of the first
    three encoder stages are concatenated onto the decoder path as skip
    connections. The head is a plain convolution (no activation).

    Input : (B, 1, 512) noisy_norm
    Output: (B, 1, 512) clean_norm_hat
    """
    def __init__(self, base_ch=32, bottleneck_blocks=4, bias=True):
        super().__init__()

        # Channel widths at each depth of the network.
        c1, c2, c4, c8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8
        head_ch = base_ch // 2

        # Every encoder stage halves the length; every decoder stage doubles it.
        enc = dict(k=16, s=2, p=7, bias=bias, act="lrelu")
        dec = dict(k=4, s=2, p=1, bias=bias, act="relu")

        # Encoder: 512 -> 256 -> 128 -> 64 -> 32 for a length-512 input.
        self.e1 = ConvBlock1D(1, c1, **enc)
        self.e2 = ConvBlock1D(c1, c2, **enc)
        self.e3 = ConvBlock1D(c2, c4, **enc)
        self.e4 = ConvBlock1D(c4, c8, **enc)

        # Bottleneck: residual blocks at the coarsest resolution.
        self.bottleneck = nn.Sequential(
            *(ResBlock1D(c8, k=7, p=3, bias=bias) for _ in range(bottleneck_blocks))
        )

        # Decoder: d2..d4 take twice the previous stage's output channels
        # because the matching encoder feature map is concatenated first.
        self.d1 = DeconvBlock1D(c8, c4, **dec)
        self.d2 = DeconvBlock1D(c8, c2, **dec)
        self.d3 = DeconvBlock1D(c4, c1, **dec)
        self.d4 = DeconvBlock1D(c2, head_ch, **dec)

        # Linear output head (no activation), suited to normalised signals.
        self.out = nn.Conv1d(head_ch, 1, kernel_size=7, stride=1, padding=3, bias=bias)

    def forward(self, y):
        # Encoder path; keep the first three outputs for the skip connections.
        skip1 = self.e1(y)
        skip2 = self.e2(skip1)
        skip3 = self.e3(skip2)
        deepest = self.e4(skip3)

        up = self.bottleneck(deepest)

        # Decoder path: upsample, then concatenate the matching encoder output
        # along the channel axis before the next stage.
        up = self.d1(up)
        up = self.d2(torch.cat([up, skip3], dim=1))
        up = self.d3(torch.cat([up, skip2], dim=1))
        up = self.d4(torch.cat([up, skip1], dim=1))

        return self.out(up)

In [289]:
# Patch Critic (shared by CNN/ResCNN)
class CriticPatch1D(nn.Module):
    """
    Conditional PatchGAN-style critic for WGAN: D(y, x) -> patch scores.

    The two inputs are stacked along the channel axis and passed through
    four stride-2 convolution stages (the first one without batch norm)
    followed by a stride-1 scoring convolution.

    y, x  : (B, 1, 512)
    output: (B, 1, 32)
    """
    def __init__(self, base_ch=32, bias=True):
        super().__init__()
        c1, c2, c4, c8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8
        down = dict(k=16, s=2, p=7, bias=bias, act="lrelu")

        # First stage is a bare convolution; its LeakyReLU is applied in forward().
        self.c1 = nn.Conv1d(2, c1, kernel_size=16, stride=2, padding=7, bias=bias)
        self.c2 = ConvBlock1D(c1, c2, **down)
        self.c3 = ConvBlock1D(c2, c4, **down)
        self.c4 = ConvBlock1D(c4, c8, **down)
        # Length-preserving head producing one score per remaining position.
        self.out = nn.Conv1d(c8, 1, kernel_size=7, stride=1, padding=3, bias=bias)

    def forward(self, y, x):
        pair = torch.cat((y, x), dim=1)
        feat = F.leaky_relu(self.c1(pair), negative_slope=0.2, inplace=True)
        for stage in (self.c2, self.c3, self.c4):
            feat = stage(feat)
        return self.out(feat)


# Model Import

In [290]:
# Passed as `bias=` when the generator and critic are constructed, and selects
# the "b" / "nb" suffix of the checkpoint directory name.
BIAS = True
# Index interpolated into the checkpoint directory name (main3_d{DATA_MODE}_...).
DATA_MODE = 1 # up to 5

In [291]:
# Directory holding the checkpoints for the selected configuration.
# The suffix is computed outside the f-string: reusing double quotes inside a
# replacement field is a SyntaxError on Python versions before 3.12.
bias_tag = "b" if BIAS else "nb"
data_path = os.path.abspath(f"../models/main3_d{DATA_MODE}_{bias_tag}/")
print(data_path)

# Checkpoint files are named cnn_G_<8 digits>_<6 digits>.pth (generator) and
# cnn_D_<8 digits>_<6 digits>.pth (critic).
pattern = re.compile(r"^cnn_([DG])_\d{8}_\d{6}\.pth$")

cnn_G_path = None
cnn_D_path = None

# os.listdir() returns entries in arbitrary order. Sorting makes the choice
# deterministic: the last match of each kind wins, i.e. the one with the
# highest numeric stamp (presumably the most recent run).
for f in sorted(os.listdir(data_path)):
    m = pattern.match(f)
    if m:
        full = os.path.join(data_path, f)
        if m.group(1) == "G":
            cnn_G_path = full
        else:
            cnn_D_path = full

print("G:", cnn_G_path)
print("D:", cnn_D_path)

# Fail here with a clear message rather than later inside torch.load(None).
if cnn_G_path is None or cnn_D_path is None:
    raise FileNotFoundError(
        f"Expected both cnn_G_*.pth and cnn_D_*.pth checkpoints in {data_path}"
    )

c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d1_b
G: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d1_b\cnn_G_20260113_074150.pth
D: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d1_b\cnn_D_20260113_074150.pth


In [292]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architectures with the same bias setting as the checkpoints.
G = GeneratorCNNWGAN(bias=BIAS).to(device)
D = CriticPatch1D(bias=BIAS).to(device)

# The checkpoints are plain state dicts, so weights_only=True is sufficient.
# It restricts unpickling to tensors and basic containers (no arbitrary code
# execution) and avoids the FutureWarning torch.load emits without it.
G.load_state_dict(torch.load(cnn_G_path, map_location=device, weights_only=True))
D.load_state_dict(torch.load(cnn_D_path, map_location=device, weights_only=True))

# eval() switches BatchNorm to its running statistics and returns the module,
# so printing it also shows the architecture.
print(G.eval())
print(D.eval())

GeneratorCNNWGAN(
  (e1): ConvBlock1D(
    (conv): Conv1d(1, 32, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e2): ConvBlock1D(
    (conv): Conv1d(32, 64, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e3): ConvBlock1D(
    (conv): Conv1d(64, 128, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (e4): ConvBlock1D(
    (conv): Conv1d(128, 256, kernel_size=(16,), stride=(2,), padding=(7,))
    (norm): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): LeakyReLU(negative_slope=0.2, inplace=True)
  

  G.load_state_dict(torch.load(cnn_G_path, map_location=device))
  D.load_state_dict(torch.load(cnn_D_path, map_location=device))


In [293]:
# List every sub-module (qualified name -> class name) of G, then of D,
# with a ruler line between the two listings.
for idx, net in enumerate((G, D)):
    if idx > 0:
        print("\n" + "=" * 80 +"\n")
    for name, module in net.named_modules():
        print(name, "->", type(module).__name__)

 -> GeneratorCNNWGAN
e1 -> ConvBlock1D
e1.conv -> Conv1d
e1.norm -> BatchNorm1d
e1.act -> LeakyReLU
e2 -> ConvBlock1D
e2.conv -> Conv1d
e2.norm -> BatchNorm1d
e2.act -> LeakyReLU
e3 -> ConvBlock1D
e3.conv -> Conv1d
e3.norm -> BatchNorm1d
e3.act -> LeakyReLU
e4 -> ConvBlock1D
e4.conv -> Conv1d
e4.norm -> BatchNorm1d
e4.act -> LeakyReLU
bottleneck -> Sequential
bottleneck.0 -> ResBlock1D
bottleneck.0.c1 -> Conv1d
bottleneck.0.n1 -> BatchNorm1d
bottleneck.0.c2 -> Conv1d
bottleneck.0.n2 -> BatchNorm1d
bottleneck.1 -> ResBlock1D
bottleneck.1.c1 -> Conv1d
bottleneck.1.n1 -> BatchNorm1d
bottleneck.1.c2 -> Conv1d
bottleneck.1.n2 -> BatchNorm1d
bottleneck.2 -> ResBlock1D
bottleneck.2.c1 -> Conv1d
bottleneck.2.n1 -> BatchNorm1d
bottleneck.2.c2 -> Conv1d
bottleneck.2.n2 -> BatchNorm1d
bottleneck.3 -> ResBlock1D
bottleneck.3.c1 -> Conv1d
bottleneck.3.n1 -> BatchNorm1d
bottleneck.3.c2 -> Conv1d
bottleneck.3.n2 -> BatchNorm1d
d1 -> DeconvBlock1D
d1.deconv -> ConvTranspose1d
d1.norm -> BatchNorm1d
d1

In [294]:
# List every parameter (name, shape, requires_grad) of G, then of D,
# with a ruler line between the two listings.
for idx, net in enumerate((G, D)):
    if idx > 0:
        print("\n" + "=" * 80 +"\n")
    for name, param in net.named_parameters():
        print(name, param.shape, param.requires_grad)


e1.conv.weight torch.Size([32, 1, 16]) True
e1.conv.bias torch.Size([32]) True
e1.norm.weight torch.Size([32]) True
e1.norm.bias torch.Size([32]) True
e2.conv.weight torch.Size([64, 32, 16]) True
e2.conv.bias torch.Size([64]) True
e2.norm.weight torch.Size([64]) True
e2.norm.bias torch.Size([64]) True
e3.conv.weight torch.Size([128, 64, 16]) True
e3.conv.bias torch.Size([128]) True
e3.norm.weight torch.Size([128]) True
e3.norm.bias torch.Size([128]) True
e4.conv.weight torch.Size([256, 128, 16]) True
e4.conv.bias torch.Size([256]) True
e4.norm.weight torch.Size([256]) True
e4.norm.bias torch.Size([256]) True
bottleneck.0.c1.weight torch.Size([256, 256, 7]) True
bottleneck.0.c1.bias torch.Size([256]) True
bottleneck.0.n1.weight torch.Size([256]) True
bottleneck.0.n1.bias torch.Size([256]) True
bottleneck.0.c2.weight torch.Size([256, 256, 7]) True
bottleneck.0.c2.bias torch.Size([256]) True
bottleneck.0.n2.weight torch.Size([256]) True
bottleneck.0.n2.bias torch.Size([256]) True
bottlene

In [295]:
# List every buffer (name, shape) of G, then of D, with a ruler line between
# the two listings. For these models the buffers are the BatchNorm statistics.
for idx, net in enumerate((G, D)):
    if idx > 0:
        print("\n" + "=" * 80 +"\n")
    for name, buf in net.named_buffers():
        print(name, buf.shape)

e1.norm.running_mean torch.Size([32])
e1.norm.running_var torch.Size([32])
e1.norm.num_batches_tracked torch.Size([])
e2.norm.running_mean torch.Size([64])
e2.norm.running_var torch.Size([64])
e2.norm.num_batches_tracked torch.Size([])
e3.norm.running_mean torch.Size([128])
e3.norm.running_var torch.Size([128])
e3.norm.num_batches_tracked torch.Size([])
e4.norm.running_mean torch.Size([256])
e4.norm.running_var torch.Size([256])
e4.norm.num_batches_tracked torch.Size([])
bottleneck.0.n1.running_mean torch.Size([256])
bottleneck.0.n1.running_var torch.Size([256])
bottleneck.0.n1.num_batches_tracked torch.Size([])
bottleneck.0.n2.running_mean torch.Size([256])
bottleneck.0.n2.running_var torch.Size([256])
bottleneck.0.n2.num_batches_tracked torch.Size([])
bottleneck.1.n1.running_mean torch.Size([256])
bottleneck.1.n1.running_var torch.Size([256])
bottleneck.1.n1.num_batches_tracked torch.Size([])
bottleneck.1.n2.running_mean torch.Size([256])
bottleneck.1.n2.running_var torch.Size([256])

In [296]:
# Report the statistics shapes and hyper-parameters of every BatchNorm1d
# layer in the generator.
bn_layers = (
    (name, module)
    for name, module in G.named_modules()
    if isinstance(module, torch.nn.BatchNorm1d)
)
for name, module in bn_layers:
    print(f"\nBN layer: {name}")
    print(" running_mean:", module.running_mean.shape)
    print(" running_var :", module.running_var.shape)
    print(" momentum    :", module.momentum)
    print(" eps         :", module.eps)
    print(" affine      :", module.affine)



BN layer: e1.norm
 running_mean: torch.Size([32])
 running_var : torch.Size([32])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: e2.norm
 running_mean: torch.Size([64])
 running_var : torch.Size([64])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: e3.norm
 running_mean: torch.Size([128])
 running_var : torch.Size([128])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: e4.norm
 running_mean: torch.Size([256])
 running_var : torch.Size([256])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: bottleneck.0.n1
 running_mean: torch.Size([256])
 running_var : torch.Size([256])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: bottleneck.0.n2
 running_mean: torch.Size([256])
 running_var : torch.Size([256])
 momentum    : 0.1
 eps         : 1e-05
 affine      : True

BN layer: bottleneck.1.n1
 running_mean: torch.Size([256])
 running_var : torch.Size([256])
 momentum    : 0.1
 eps  

# Convolution and BatchNorm Fusion

In [297]:
def fuse_conv_bn_1d(
    conv_weight,        # (Cout, Cin, K)
    conv_bias,          # (Cout,) or None
    running_mean,       # (Cout,)
    running_var,        # (Cout,)
    bn_weight,          # (Cout,) or None (gamma)
    bn_bias,            # (Cout,) or None (beta)
    eps=1e-5
):
    """
    Fold an inference-mode BatchNorm1d into the Conv1d that precedes it.

    With scale = gamma / sqrt(running_var + eps), the returned pair is
        fused_weight = conv_weight * scale   (per output channel)
        fused_bias   = (conv_bias - running_mean) * scale + beta
    so a convolution with these values reproduces BN(conv(x)) when BN uses
    its running statistics.

    Missing terms take their neutral values: gamma -> ones, beta -> zeros,
    conv_bias -> zeros (same device and dtype as conv_weight).

    Returns (fused_weight, fused_bias).
    """
    n_out = conv_weight.shape[0]
    like_weight = dict(device=conv_weight.device, dtype=conv_weight.dtype)

    gamma = torch.ones(n_out, **like_weight) if bn_weight is None else bn_weight
    beta = torch.zeros(n_out, **like_weight) if bn_bias is None else bn_bias
    bias = torch.zeros(n_out, **like_weight) if conv_bias is None else conv_bias

    # Per-output-channel multiplier applied by the normalisation.
    scale = gamma / torch.sqrt(running_var + eps)

    # Output channels are dim 0 of a Conv1d weight, so broadcast over (Cin, K).
    fused_weight = conv_weight * scale.reshape(-1, 1, 1)
    fused_bias = (bias - running_mean) * scale + beta

    return fused_weight, fused_bias

def fuse_deconv_bn_1d(
    deconv_weight,     # (Cin, Cout, K)
    deconv_bias,       # (Cout,) or None
    running_mean,      # (Cout,)
    running_var,       # (Cout,)
    bn_weight,         # (Cout,) or None (gamma)
    bn_bias,           # (Cout,) or None (beta)
    eps=1e-5
):
    """
    Fold an inference-mode BatchNorm1d into the ConvTranspose1d that precedes it.

    With scale = gamma / sqrt(running_var + eps), the returned pair is
        fused_weight = deconv_weight * scale   (per output channel, dim 1)
        fused_bias   = (deconv_bias - running_mean) * scale + beta
    so a transposed convolution with these values reproduces BN(deconv(x))
    when BN uses its running statistics.

    Missing terms take their neutral values, matching fuse_conv_bn_1d:
    gamma -> ones, beta -> zeros, deconv_bias -> zeros (same device and dtype
    as deconv_weight). `eps` defaults to 1e-5, as in fuse_conv_bn_1d.

    Returns (fused_weight, fused_bias).
    """
    Cin, Cout, K = deconv_weight.shape
    like_weight = dict(device=deconv_weight.device, dtype=deconv_weight.dtype)

    # A non-affine BatchNorm has no gamma/beta; substitute the identity.
    if bn_weight is None:
        bn_weight = torch.ones(Cout, **like_weight)
    if bn_bias is None:
        bn_bias = torch.zeros(Cout, **like_weight)
    if deconv_bias is None:
        deconv_bias = torch.zeros(Cout, **like_weight)

    # BN scale
    scale = bn_weight / torch.sqrt(running_var + eps)  # (Cout,)

    # ConvTranspose1d stores weights as (Cin, Cout, K), so the output-channel
    # scale broadcasts along dim 1 (not dim 0 as for Conv1d).
    fused_weight = deconv_weight * scale.view(1, Cout, 1)

    # Fuse bias
    fused_bias = (deconv_bias - running_mean) * scale + bn_bias

    return fused_weight, fused_bias


In [298]:
class FusedConvBlock1D(nn.Module):
    """
    Conv1d -> activation, i.e. ConvBlock1D with the BatchNorm layer removed.

    Intended to hold convolution weights into which the batch-norm statistics
    have already been folded. `act` is "lrelu" (LeakyReLU, slope 0.2) or
    "relu"; any other value raises ValueError.
    """
    def __init__(self, in_ch, out_ch, k, s, p, bias=True, act="lrelu"):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, k, s, p, bias=bias)

        if act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        elif act == "relu":
            self.act = nn.ReLU(inplace=True)
        else:
            # Same message as ConvBlock1D so a bad `act` is diagnosable.
            raise ValueError("act must be 'lrelu' or 'relu'")

    def forward(self, x):
        return self.act(self.conv(x))


class FusedDeconvBlock1D(nn.Module):
    """
    ConvTranspose1d -> activation, i.e. DeconvBlock1D with the BatchNorm
    layer removed.

    Intended to hold transposed-convolution weights into which the batch-norm
    statistics have already been folded. `act` is "relu" or "lrelu"
    (LeakyReLU, slope 0.2); any other value raises ValueError.
    """
    def __init__(self, in_ch, out_ch, k, s, p, bias=True, act="relu"):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(in_ch, out_ch, k, s, p, bias=bias)

        if act == "relu":
            self.act = nn.ReLU(inplace=True)
        elif act == "lrelu":
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            # Same message as DeconvBlock1D so a bad `act` is diagnosable.
            raise ValueError("act must be 'relu' or 'lrelu'")

    def forward(self, x):
        return self.act(self.deconv(x))


In [299]:
class GeneratorCNNWGAN_Fused(nn.Module):
    """
    BatchNorm-free counterpart of GeneratorCNNWGAN.

    Same encoder / residual bottleneck / decoder layout and skip
    connections, but every stage is a bare (transposed) convolution plus
    activation, so it can hold weights with the batch norm already folded in.
    Each bottleneck entry is Conv -> ReLU -> Conv; the residual addition and
    the final ReLU of each block are applied in forward().
    """
    def __init__(self, base_ch=32, bottleneck_blocks=4, bias=True):
        super().__init__()

        # Channel widths at each depth of the network.
        c1, c2, c4, c8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8
        head_ch = base_ch // 2

        enc = dict(k=16, s=2, p=7, bias=bias, act="lrelu")
        self.e1 = FusedConvBlock1D(1, c1, **enc)
        self.e2 = FusedConvBlock1D(c1, c2, **enc)
        self.e3 = FusedConvBlock1D(c2, c4, **enc)
        self.e4 = FusedConvBlock1D(c4, c8, **enc)

        def residual_body():
            # Inner path of one residual block (skip add happens in forward).
            return nn.Sequential(
                nn.Conv1d(c8, c8, 7, 1, 3, bias=bias),
                nn.ReLU(inplace=True),
                nn.Conv1d(c8, c8, 7, 1, 3, bias=bias),
            )

        self.bottleneck = nn.Sequential(
            *(residual_body() for _ in range(bottleneck_blocks))
        )

        # Decoder stages use the block's default ReLU activation.
        dec = dict(k=4, s=2, p=1, bias=bias)
        self.d1 = FusedDeconvBlock1D(c8, c4, **dec)
        self.d2 = FusedDeconvBlock1D(c8, c2, **dec)
        self.d3 = FusedDeconvBlock1D(c4, c1, **dec)
        self.d4 = FusedDeconvBlock1D(c2, head_ch, **dec)

        self.out = nn.Conv1d(head_ch, 1, 7, 1, 3, bias=bias)

    def forward(self, y):
        # Encoder path; keep the first three outputs for the skip connections.
        skip1 = self.e1(y)
        skip2 = self.e2(skip1)
        skip3 = self.e3(skip2)
        feat = self.e4(skip3)

        # Residual bottleneck: relu(input + body(input)) per block.
        for body in self.bottleneck:
            feat = F.relu(feat + body(feat))

        # Decoder path with channel-wise skip concatenation.
        up = self.d1(feat)
        up = self.d2(torch.cat([up, skip3], dim=1))
        up = self.d3(torch.cat([up, skip2], dim=1))
        up = self.d4(torch.cat([up, skip1], dim=1))

        return self.out(up)


In [300]:
def fuse_generator(G):
    """
    Build a BatchNorm-free copy of generator `G`.

    Every Conv/ConvTranspose + BatchNorm pair of `G` (encoder e1..e4, both
    convolutions of each bottleneck block, decoder d1..d4) is folded into a
    single layer using the BN running statistics and each BN layer's own
    `eps`; the output head is copied unchanged.

    The fused model mirrors `G`'s geometry (base channel count and number of
    bottleneck blocks are read from `G`) and is created on the device of
    `G`'s parameters. It is always built with bias terms, because folding a
    BatchNorm produces a bias even when the original convolution had none.

    Returns the fused model in eval mode.
    """
    target_device = next(G.parameters()).device
    G_fused = GeneratorCNNWGAN_Fused(
        base_ch=G.e1.conv.out_channels,
        bottleneck_blocks=len(G.bottleneck),
        bias=True,
    ).to(target_device)
    G_fused.eval()

    with torch.no_grad():
        # Encoder
        for i in range(1, 5):
            e = getattr(G, f"e{i}")
            fe = getattr(G_fused, f"e{i}")

            w, b = fuse_conv_bn_1d(
                e.conv.weight, e.conv.bias,
                e.norm.running_mean, e.norm.running_var,
                e.norm.weight, e.norm.bias,
                eps=e.norm.eps,
            )
            fe.conv.weight.copy_(w)
            fe.conv.bias.copy_(b)

        # Bottleneck: fused blocks are Sequential(conv, relu, conv), so the
        # two folded convolutions land at indices 0 and 2.
        for blk, fblk in zip(G.bottleneck, G_fused.bottleneck):
            pairs = ((blk.c1, blk.n1), (blk.c2, blk.n2))
            targets = (fblk[0], fblk[2])
            for (c, n), target in zip(pairs, targets):
                w, b = fuse_conv_bn_1d(
                    c.weight, c.bias,
                    n.running_mean, n.running_var,
                    n.weight, n.bias,
                    eps=n.eps,
                )
                target.weight.copy_(w)
                target.bias.copy_(b)

        # Decoder
        for i in range(1, 5):
            d = getattr(G, f"d{i}")
            fd = getattr(G_fused, f"d{i}")

            w, b = fuse_deconv_bn_1d(
                d.deconv.weight, d.deconv.bias,
                d.norm.running_mean, d.norm.running_var,
                d.norm.weight, d.norm.bias,
                d.norm.eps
            )
            fd.deconv.weight.copy_(w)
            fd.deconv.bias.copy_(b)

        # Output head has no BatchNorm: copy as-is. A generator built with
        # bias=False has no head bias, which is equivalent to a zero bias.
        G_fused.out.weight.copy_(G.out.weight)
        if G.out.bias is None:
            G_fused.out.bias.zero_()
        else:
            G_fused.out.bias.copy_(G.out.bias)

    return G_fused


In [301]:
# Fold the batch-norm layers of the loaded generator.
G_fused = fuse_generator(G)

# Store the result in the same directory as the source checkpoint, under the
# same file name prefixed with "fused_".
fused_name = "fused_" + os.path.basename(cnn_G_path)
fused_G_path = os.path.join(data_path, fused_name)

fused_state = G_fused.state_dict()
torch.save(fused_state, fused_G_path)
print("Saved:", fused_G_path)


Saved: c:\Users\Aryo\PersonalMade\Programming\GAN\repo\src\models\main3_d1_b\fused_cnn_G_20260113_074150.pth


In [302]:
# Round-trip check: the saved fused weights must load into a freshly
# constructed fused generator.
G_check = GeneratorCNNWGAN_Fused().to(device)
# The file is a plain state dict, so weights_only=True is sufficient. It
# restricts unpickling to tensors and basic containers and avoids the
# FutureWarning torch.load emits without it.
G_check.load_state_dict(
    torch.load(fused_G_path, map_location=device, weights_only=True)
)
G_check.eval()

for name, module in G_check.named_modules():
    print(name, "->", module.__class__.__name__)


 -> GeneratorCNNWGAN_Fused
e1 -> FusedConvBlock1D
e1.conv -> Conv1d
e1.act -> LeakyReLU
e2 -> FusedConvBlock1D
e2.conv -> Conv1d
e2.act -> LeakyReLU
e3 -> FusedConvBlock1D
e3.conv -> Conv1d
e3.act -> LeakyReLU
e4 -> FusedConvBlock1D
e4.conv -> Conv1d
e4.act -> LeakyReLU
bottleneck -> Sequential
bottleneck.0 -> Sequential
bottleneck.0.0 -> Conv1d
bottleneck.0.1 -> ReLU
bottleneck.0.2 -> Conv1d
bottleneck.1 -> Sequential
bottleneck.1.0 -> Conv1d
bottleneck.1.1 -> ReLU
bottleneck.1.2 -> Conv1d
bottleneck.2 -> Sequential
bottleneck.2.0 -> Conv1d
bottleneck.2.1 -> ReLU
bottleneck.2.2 -> Conv1d
bottleneck.3 -> Sequential
bottleneck.3.0 -> Conv1d
bottleneck.3.1 -> ReLU
bottleneck.3.2 -> Conv1d
d1 -> FusedDeconvBlock1D
d1.deconv -> ConvTranspose1d
d1.act -> ReLU
d2 -> FusedDeconvBlock1D
d2.deconv -> ConvTranspose1d
d2.act -> ReLU
d3 -> FusedDeconvBlock1D
d3.deconv -> ConvTranspose1d
d3.act -> ReLU
d4 -> FusedDeconvBlock1D
d4.deconv -> ConvTranspose1d
d4.act -> ReLU
out -> Conv1d


  G_check.load_state_dict(torch.load(fused_G_path, map_location=device))


In [303]:
def count_params_and_constants(model):
    """
    Count the scalar elements stored in `model`.

    Returns a tuple (n_params, n_buffers, total): elements across all
    parameters, elements across all buffers, and their sum.
    """
    n_params = 0
    for tensor in model.parameters():
        n_params += tensor.numel()

    n_buffers = 0
    for tensor in model.buffers():
        n_buffers += tensor.numel()

    total = n_params + n_buffers
    return n_params, n_buffers, total


def pretty_count(name, model):
    """
    Print a short report for `model` under the heading `name`: parameter
    elements, buffer elements and their total, followed by a blank line.
    """
    n_learnable, n_buffer, n_total = count_params_and_constants(model)
    report = [
        f"{name}",
        f"  Learnable parameters : {n_learnable:,}",
        f"  Buffers (BN stats)   : {n_buffer:,}",
        f"  TOTAL constants      : {n_total:,}",
        "",  # yields the trailing blank line
    ]
    print("\n".join(report))


# Per-model reports.
for title, net in (("Original Generator", G), ("Fused Generator", G_fused)):
    pretty_count(title, net)

print("=" * 60)

# Element counts before (0) and after (1) folding the batch norms.
p0, b0, t0 = count_params_and_constants(G)
p1, b1, t1 = count_params_and_constants(G_fused)

params_removed = p0 - p1
buffers_removed = b0 - b1
total_removed = t0 - t1

print("DIFFERENCE")
print(f"  Params removed  : {params_removed:,}")
print(f"  Buffers removed : {buffers_removed:,}")
print(f"  Total reduction : {total_removed:,}")


Original Generator
  Learnable parameters : 4,584,161
  Buffers (BN stats)   : 5,552
  TOTAL constants      : 4,589,713

Fused Generator
  Learnable parameters : 4,578,625
  Buffers (BN stats)   : 0
  TOTAL constants      : 4,578,625

DIFFERENCE
  Params removed  : 5,536
  Buffers removed : 5,552
  Total reduction : 11,088


# Prepare for PYNQ

In [304]:
# Flatten every tensor of the fused generator into one contiguous array and
# record, per tensor, where it starts and how to reshape it.
weights = []
offsets = {}

# Single source of truth for the export precision: the cast below and the
# "dtype" recorded in the offsets file are both derived from it, so the
# metadata cannot drift from the data.
export_dtype = np.float16

cursor = 0
for name, param in G_fused.state_dict().items():
    arr = param.detach().cpu().numpy().astype(export_dtype)
    size = arr.size

    offsets[name] = {
        "offset": int(cursor),        # start index in elements, not bytes
        "shape": list(arr.shape),
        "dtype": str(arr.dtype),
    }

    weights.append(arr.reshape(-1))
    cursor += size

weights_flat = np.concatenate(weights)

# Both files are written next to the fused checkpoint.
export_dir = os.path.dirname(fused_G_path)
weight_path = os.path.join(
    export_dir,
    f"fused_cnn_G_d{DATA_MODE}b_weights.npy"
)
offset_path = os.path.join(
    export_dir,
    f"fused_cnn_G_d{DATA_MODE}b_offsets.json"
)
np.save(weight_path, weights_flat)
with open(offset_path, "w", encoding="utf-8") as f:
    json.dump(offsets, f, indent=2)

print("Total floats:", weights_flat.size)

Total floats: 4578625


In [305]:
# Range and mean of the exported weights.
weight_min = weights_flat.min()
weight_max = weights_flat.max()
weight_avg = weights_flat.mean()
print("Min:", weight_min)
print("Max:", weight_max)
print("Avg:", weight_avg)

Min: -2.121
Max: 2.738
Avg: -0.000262
