In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class CausalConv1d(nn.Module):
    """1-D convolution that only looks at the current and past time steps.

    All of the receptive-field padding, ``(kernel_size - 1) * dilation``
    samples, is added on the LEFT of the last (time) axis. Output frame ``t``
    therefore sees input samples ``t*stride - padding .. t*stride`` only; it
    never sees later ones. An input of length ``L`` yields
    ``ceil(L / stride)`` output frames, so with ``stride == 1`` the length is
    preserved.

    The constructor arguments mirror ``nn.Conv1d``. ``padding`` is derived
    internally and exposed as ``self.padding``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 dilation=1, groups=1, bias=True):
        super().__init__()
        # Left-only zero padding needed so no output frame reads the future.
        self.padding = dilation * (kernel_size - 1)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              stride=stride, dilation=dilation,
                              groups=groups, bias=bias)

    def forward(self, x):
        """Pad the time axis on the left, then run the wrapped convolution."""
        left_padded = F.pad(x, (self.padding, 0))
        return self.conv(left_padded)

In [None]:
class CausalUpsample(nn.Module):
    """Nearest-neighbour 2x upsampling along the last (time) axis.

    Each input frame is emitted twice in a row:
    ``y[..., 2*t] == y[..., 2*t + 1] == x[..., t]``. Output sample ``n``
    therefore depends only on input frame ``n // 2``, so no look-ahead is
    introduced. It is implemented with ``repeat_interleave``, which runs
    natively on DirectML (no CPU fallback).
    """

    def forward(self, x):
        return torch.repeat_interleave(x, repeats=2, dim=-1)

In [None]:
class CausalDNoizeBlock(nn.Module):
    """Causal depthwise-separable residual block with GLU gating.

    Tensors are ``(batch, channels, time)``. The pipeline is:
    causal depthwise conv -> 1x1 conv to ``2 * channels`` -> GLU
    (``content * sigmoid(gate)``) -> per-frame LayerNorm over channels
    -> add the block input (residual).

    Why LayerNorm instead of InstanceNorm: ``nn.InstanceNorm1d`` computes its
    mean and variance over the entire time axis, so every output frame
    depended on future input frames. That silently broke the causality the
    left-padded depthwise conv is built to guarantee, and it rules out
    streaming inference. Normalising each frame over its channels only keeps
    the block strictly causal.

    ``self.norm`` still exposes ``weight``/``bias`` of shape ``(channels,)``,
    so old state dicts load. However, weights trained with InstanceNorm will
    behave differently and should be retrained or fine-tuned.
    NOTE(review): confirm ``nn.LayerNorm`` runs without CPU fallback on DirectML.

    Args:
        channels: number of input and output channels.
        kernel: depthwise kernel size.
        dilation: depthwise dilation; the block looks back
            ``(kernel - 1) * dilation`` frames.
    """

    def __init__(self, channels=96, kernel=3, dilation=1):
        super().__init__()
        self.dw = CausalConv1d(channels, channels, kernel,
                               groups=channels, dilation=dilation, bias=False)
        # GLU: double channels, split into content + gate
        self.pw = nn.Conv1d(channels, channels * 2, 1, bias=False)
        # Normalises the channel axis of each frame independently (causal).
        self.norm = nn.LayerNorm(channels, eps=1e-5)

    def forward(self, x):
        residual = x
        out = self.pw(self.dw(x))
        content, gate = out.chunk(2, dim=1)
        out = content * torch.sigmoid(gate)
        # nn.LayerNorm normalises the trailing axis: move channels last, then back.
        out = self.norm(out.transpose(1, 2)).transpose(1, 2)
        # The stride-1 causal depthwise conv preserves length; trim defensively.
        if out.shape[-1] != residual.shape[-1]:
            out = out[..., :residual.shape[-1]]
        return out + residual


In [None]:
class CausalDNoizeConvTasNet(nn.Module):
    """Lightweight time-domain speech-enhancement network built from causal convs.

    Data flow for an input waveform ``x`` of shape ``(batch, 1, samples)``:

    1. ``enc``: stride-2 causal conv (1 -> ``channels``), per-frame channel
       LayerNorm, then GELU. This produces ``ceil(samples / 2)`` frames.
    2. ``blocks``: ``num_repeats * num_blocks`` ``CausalDNoizeBlock``s. Their
       dilations are ``1, 2, ..., 2**(num_blocks - 1)``, repeated
       ``num_repeats`` times.
    3. ``mask_head``: a 1x1-conv MLP (channels -> 2*channels -> channels)
       plus sigmoid. Its output multiplies the block features.
    4. ``dec``: 2x causal upsample, causal conv + GELU, then a causal conv
       down to 1 channel.
    5. The output is ``x + correction_scale * tanh(dec_out)``, trimmed to
       ``samples``.

    With the defaults (channels=48, num_blocks=9, num_repeats=2) the model has
    about 109k parameters.
    NOTE(review): a back-of-envelope conv-MAC count for the same defaults
    gives roughly 0.95 GMAC/s at 16 kHz. The residual blocks run at 8k
    frames/s, and the 48->48 k=5 decoder conv runs at 16k samples/s. The
    older "~25-45M MACs/sec (faster than GTCRN)" claim does not match this
    count, so re-measure before quoting it.

    Args:
        channels: feature width of the encoder, blocks and decoder.
        num_blocks: number of blocks per dilation cycle.
        num_repeats: how many times the dilation cycle is repeated.
        kernel: depthwise kernel size inside each block.
        correction_scale: bound on the additive correction. Each output
            sample differs from the input by less than this value. The
            default of 0.1 is the previously hard-coded value.
    """

    def __init__(self, channels=48, num_blocks=9, num_repeats=2, kernel=3,
                 correction_scale=0.1):
        super().__init__()
        self.correction_scale = correction_scale

        # Encoder: stride-2 downsampling. The norm is a LayerNorm over channels,
        # applied per frame in forward(). InstanceNorm1d pooled statistics over
        # the whole time axis and leaked future samples into every frame.
        self.enc = CausalConv1d(1, channels, kernel_size=5, stride=2)
        self.norm_enc = nn.LayerNorm(channels, eps=1e-5)
        self.act_enc = nn.GELU()

        # Processing blocks; dilation doubles within each cycle of num_blocks.
        self.blocks = nn.Sequential(*[
            CausalDNoizeBlock(channels, kernel, dilation=2**(i % num_blocks))
            for i in range(num_repeats * num_blocks)
        ])

        # Bottleneck mask head
        self.mask_head = nn.Sequential(
            nn.Conv1d(channels, channels * 2, 1),
            nn.GELU(),
            nn.Conv1d(channels * 2, channels, 1)
        )

        # Causal decoder: CausalUpsample runs natively on DirectML.
        self.dec = nn.Sequential(
            CausalUpsample(),
            CausalConv1d(channels, channels, kernel_size=5, stride=1),
            nn.GELU(),
            CausalConv1d(channels, 1, kernel_size=3, stride=1)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-init all convs, then apply the mask-head and decoder special cases."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.InstanceNorm1d, nn.LayerNorm)):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # Mask head: zero only the LAST conv, so the initial mask is exactly
        # sigmoid(0) = 0.5 everywhere. Zeroing both convs makes the hidden
        # activation GELU(0) = 0 and the back-propagated signal W2^T g = 0.
        # Neither weight matrix would then ever receive a gradient, and the
        # mask would collapse to a learned per-channel constant.
        mask_out = self.mask_head[-1]
        nn.init.zeros_(mask_out.weight)
        if mask_out.bias is not None:
            nn.init.zeros_(mask_out.bias)

        # Decoder: apply the small gain to the OUTPUT conv only, to keep the
        # initial correction near zero. Shrinking the hidden conv as well
        # would compound the attenuation of both activations and gradients.
        dec_convs = [m for m in self.dec.modules() if isinstance(m, nn.Conv1d)]
        if dec_convs:
            out_conv = dec_convs[-1]
            nn.init.xavier_uniform_(out_conv.weight, gain=0.01)
            if out_conv.bias is not None:
                nn.init.zeros_(out_conv.bias)

    def forward(self, x):
        input_length = x.shape[-1]

        enc = self.enc(x)
        # nn.LayerNorm normalises the trailing axis: move channels last, then
        # back, so statistics are computed per frame (causal).
        enc = self.norm_enc(enc.transpose(1, 2)).transpose(1, 2)
        enc = self.act_enc(enc)

        feat = self.blocks(enc)

        mask = torch.sigmoid(self.mask_head(feat))
        masked = feat * mask

        out = self.dec(masked)

        # The stride-2 encoder yields ceil(L / 2) frames and CausalUpsample
        # doubles that, so the decoder output is L (even L) or L + 1 (odd L)
        # samples long. Trim the surplus; the pad branch is only a safeguard.
        if out.shape[-1] > input_length:
            out = out[..., :input_length]
        elif out.shape[-1] < input_length:
            out = F.pad(out, (0, input_length - out.shape[-1]))

        # Bounded additive correction on top of the input waveform.
        correction = torch.tanh(out) * self.correction_scale
        return x + correction

In [None]:
# import torch_directml
# dml = torch_directml.device()

# model = CausalDNoizeConvTasNet(channels=48, num_blocks=9, num_repeats=2).to(dml)
# model.train()

# x = torch.randn(4, 1, 48000).to(dml)
# y = torch.randn(4, 1, 48000).to(dml)

# pred = model(x)
# loss = F.l1_loss(pred, y)
# loss.backward()

# print(f"OK | full model forward+backward | loss={loss.item():.4f}")
# print(f"   input:  {x.shape}  output: {pred.shape}")

# # Check no CPU fallback warnings appeared above
# # Then run the power debug
# del model, x, y, pred, loss
# # gc.collect()

In [None]:
# import torch_directml
# dml = torch_directml.device()

# x = torch.randn(4, 48, 24000, requires_grad=True).to(dml)
# up = CausalUpsample().to(dml)

# out = up(x)
# loss = out.mean()
# loss.backward()
# print(f"OK | CausalUpsample forward+backward | in={x.shape} out={out.shape}")
# print(f"grad shape: {x.grad.shape}")

In [None]:
# # Verify causal decoder produces correct shapes and power
# model = CausalDNoizeConvTasNet(channels=48, num_blocks=9, num_repeats=2)

# x = torch.randn(4, 1, 48000)
# y = model(x)

# assert x.shape == y.shape, f"Shape mismatch: {x.shape} vs {y.shape}"
# print(f"Shape check:  {x.shape} → {y.shape}  ✓")

# # Power sanity check
# est_zm = y.squeeze(1) - y.squeeze(1).mean(dim=-1, keepdim=True)
# tgt_zm = x.squeeze(1) - x.squeeze(1).mean(dim=-1, keepdim=True)
# print(f"pred power: {est_zm.pow(2).sum(dim=-1).tolist()}")
# print(f"targ power: {tgt_zm.pow(2).sum(dim=-1).tolist()}")
# print(f"pred range: [{y.min().item():.4f}, {y.max().item():.4f}]")

# # Parameter count
# total = sum(p.numel() for p in model.parameters())
# print(f"Parameters: {total:,}")
# # ```

# # The parameter count increases by about 11.5k (≈12% of the whole model, ~48x for the decoder alone) — the new decoder has two `CausalConv1d` layers replacing one `ConvTranspose1d`. With `channels=48` (weights only; biases listed separately):
# # ```
# # Old: ConvTranspose1d(48→1, k=5)           =  48x1x5 = 240 params
# # New: CausalConv1d(48→48, k=5)             =  48x48x5 = 11,520 params (+48 bias)
# #      CausalConv1d(48→1, k=3)              =  48x1x3 = 144 params     (+1 bias)
# #      Total new decoder                    = 11,664 weights (11,713 incl. biases)

In [None]:
# # Debug — check what dec_convs actually contains
# # model = CausalDNoizeConvTasNet(channels=48, num_blocks=9, num_repeats=2)
# dec_convs = [m for m in model.dec.modules() if isinstance(m, nn.Conv1d)]
# for i, c in enumerate(dec_convs):
#     print(f"dec_convs[{i}]: in={c.in_channels} out={c.out_channels} k={c.kernel_size} "
#           f"weight_norm={c.weight.abs().mean().item():.6f}")

In [None]:
# model = CausalDNoizeConvTasNet(channels=48, num_blocks=9, num_repeats=2)
# model.eval()

# # Simulate realistic audio amplitude (VoiceBank-DEMAND is normalized to ~0.1-0.3 RMS)
# x = torch.randn(4, 1, 48000) * 0.15  # realistic RMS ~0.15

# with torch.no_grad():
#     enc = model.enc(x)
#     enc = model.norm_enc(enc)
#     enc = model.act_enc(enc)
#     feat = model.blocks(enc)
#     mask = torch.sigmoid(model.mask_head(feat))
#     masked = feat * mask
#     dec_out = model.dec(masked)
#     if dec_out.shape[-1] > x.shape[-1]:
#         dec_out = dec_out[..., :x.shape[-1]]
#     # Keep in sync with CausalDNoizeConvTasNet.forward, which scales the correction by 0.1 (not 0.25)
#     correction = torch.tanh(dec_out) * 0.1
#     final = x + correction

# print(f"x range:          [{x.min():.4f}, {x.max():.4f}]")
# print(f"dec_out range:    [{dec_out.min():.4f}, {dec_out.max():.4f}]")
# print(f"correction range: [{correction.min():.4f}, {correction.max():.4f}]")
# print(f"final range:      [{final.min():.4f}, {final.max():.4f}]")
# print(f"pred power: {final.squeeze(1).pow(2).sum(dim=-1).tolist()}")
# print(f"targ power: {x.squeeze(1).pow(2).sum(dim=-1).tolist()}")
# # ```

# # Expected output with realistic audio (illustrative; exact values depend on the random init):
# # ```
# # x range:          [-0.65, 0.68]
# # dec_out range:    [-0.02, 0.02]    ← small due to gain=0.01
# # correction range: [-0.002, 0.002]  ← tanh(dec_out) * 0.1, and tanh(z) ≈ z for small z
# # final range:      [-0.65, 0.68]    ← x dominates, correction is tiny
# # pred power ≈ targ power            ← ratio near 1.0