In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class CausalConv1d(nn.Module):
    """1-D convolution that pads only on the left of the last (time) axis.

    The input is zero-padded with ``(kernel_size - 1) * dilation`` samples on
    the left and none on the right before an unpadded ``nn.Conv1d`` is
    applied, so an output step never reads input samples that come after it.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel width (an int; it is used in arithmetic).
        stride: Convolution stride.
        dilation: Spacing between kernel taps.
        groups: Channel groups, forwarded to ``nn.Conv1d``.
        bias: Whether the convolution has a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 dilation=1, groups=1, bias=True):
        super().__init__()
        # Left context required by the dilated kernel.
        self.padding = dilation * (kernel_size - 1)
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Left-pad ``x`` along its last axis and convolve it."""
        left_padded = F.pad(x, (self.padding, 0))
        return self.conv(left_padded)

In [None]:
class CausalDNoizeBlock(nn.Module):
    """Residual depthwise-separable block with GLU gating and InstanceNorm.

    Pipeline: causal depthwise conv -> 1x1 pointwise conv that doubles the
    channels -> split into content/gate halves -> ``content * sigmoid(gate)``
    -> InstanceNorm -> add the block input back.

    NOTE(review): only the depthwise convolution is causal by construction.
    ``nn.InstanceNorm1d`` computes its statistics over the whole time axis of
    each sample, so confirm this is acceptable if strict streaming causality
    is required.

    Args:
        channels: Number of input and output channels.
        kernel: Kernel width of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
    """

    def __init__(self, channels=96, kernel=3, dilation=1):
        super().__init__()
        # One filter per channel (groups == channels), left-padded in time.
        self.dw = CausalConv1d(
            channels, channels, kernel,
            groups=channels, dilation=dilation, bias=False,
        )
        # Pointwise projection to 2x channels: first half content, second half gate.
        self.pw = nn.Conv1d(channels, 2 * channels, kernel_size=1, bias=False)
        self.norm = nn.InstanceNorm1d(channels, eps=1e-5, affine=True)

    def forward(self, x):
        """Return ``x`` plus the gated, normalized branch output."""
        expanded = self.pw(self.dw(x))
        content, gate = torch.chunk(expanded, 2, dim=1)
        gated = self.norm(content * torch.sigmoid(gate))

        # Trim the branch if it is not the same length as the skip path.
        n_steps = x.shape[-1]
        if gated.shape[-1] != n_steps:
            gated = gated[..., :n_steps]
        return gated + x

In [None]:
class CausalDNoizeConvTasNet(nn.Module):
    """
    Lightweight Conv-TasNet-style speech enhancement network.

    Structure: strided causal encoder -> stack of dilated
    ``CausalDNoizeBlock``s -> sigmoid mask head applied to the features ->
    transposed-conv decoder -> bounded residual added to the input ->
    soft RMS matching against the input.

    Author's original estimates (not re-measured here):
        Params: ~80k-150k
        MACs/sec @16kHz: ~25-45M (faster than GTCRN)
    NOTE(review): a rough count for the default arguments (20 blocks of about
    4.75k MACs per encoded frame, 8k frames per second at 16 kHz) comes out
    far above the quoted MACs figure -- TODO re-measure before relying on it.

    NOTE(review): the convolutions in the encoder and blocks are causal, but
    ``InstanceNorm1d`` and the final RMS matching both reduce over the whole
    time axis, and the decoder uses symmetric padding. Confirm this is
    acceptable if sample-by-sample streaming is the target.

    Args:
        channels: Width of the encoded feature map.
        num_blocks: Blocks per repeat; block ``i`` uses dilation
            ``2 ** (i % num_blocks)``.
        num_repeats: How many times the dilation cycle is repeated.
        kernel: Kernel width of the depthwise convolution in each block.
    """
    def __init__(self, channels=48, num_blocks=10, num_repeats=2, kernel=3):
        super().__init__()

        # Encoder: stride 2 halves the time resolution (rounding up).
        self.enc = CausalConv1d(1, channels, kernel_size=5, stride=2)
        self.norm_enc = nn.InstanceNorm1d(channels, affine=True, eps=1e-5)
        self.act_enc = nn.GELU()

        # Dilation restarts at 1 at the beginning of every repeat.
        self.blocks = nn.Sequential(*[
            CausalDNoizeBlock(channels, kernel, dilation=2**(i % num_blocks))
            for i in range(num_repeats * num_blocks)
        ])

        # Bottleneck mask head
        self.mask_head = nn.Sequential(
            nn.Conv1d(channels, channels * 2, 1),
            nn.GELU(),
            nn.Conv1d(channels * 2, channels, 1)
        )

        # Decoder: with these settings the output has exactly twice as many
        # steps as its input, i.e. >= the original length; forward() trims it.
        self.dec = nn.ConvTranspose1d(
            channels, 1,
            kernel_size=5,
            stride=2,
            padding=2,
            output_padding=1
        )

        self._init_weights()

    def _init_weights(self):
        """Initialise all layers; the mask starts at a constant 0.5."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.InstanceNorm1d):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # Mask head: zero ONLY the final projection so the initial mask is
        # sigmoid(0) = 0.5. Zeroing every layer (as was done before) blocks
        # learning: the hidden activation is GELU(0) = 0, so the last layer's
        # weight gradient is always 0, and because that weight stays 0 no
        # gradient ever reaches the first layer either. Only the final bias
        # could train. Keeping the first layer at its Xavier init gives the
        # last layer a non-zero input and therefore a usable gradient.
        mask_convs = [m for m in self.mask_head.modules()
                      if isinstance(m, nn.Conv1d)]
        if mask_convs:
            final_proj = mask_convs[-1]
            nn.init.zeros_(final_proj.weight)
            if final_proj.bias is not None:
                nn.init.zeros_(final_proj.bias)

        # Decoder small init to prevent output power explosion
        nn.init.xavier_uniform_(self.dec.weight, gain=0.01)
        if self.dec.bias is not None:
            nn.init.zeros_(self.dec.bias)

    def forward(self, x):
        """Enhance a waveform batch.

        Args:
            x: Tensor whose last axis is time and whose channel axis has size
               1 (the encoder takes a single input channel), e.g.
               ``(batch, 1, samples)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        input_length = x.shape[-1]

        enc = self.enc(x)
        enc = self.norm_enc(enc)
        enc = self.act_enc(enc)

        feat = self.blocks(enc)

        mask = torch.sigmoid(self.mask_head(feat))
        masked = feat * mask

        out = self.dec(masked)

        # Match the decoder output to the input length: trim the extra tail
        # or zero-pad on the right if it came out short.
        length_diff = out.shape[-1] - input_length
        if length_diff > 0:
            out = out[..., :input_length]
        elif length_diff < 0:
            out = F.pad(out, (0, -length_diff))

        # The network predicts a correction bounded to (-0.5, 0.5) that is
        # added to the noisy input.
        out = torch.tanh(out) * 0.5
        out = x + out

        # Soft power normalization — nudge output toward input RMS
        # Uses a blend rather than hard normalization to preserve learned corrections
        eps = 1e-8
        input_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
        output_rms = out.pow(2).mean(dim=-1, keepdim=True).sqrt()

        # Blend 30% normalized, 70% original — soft correction not hard override
        scale = input_rms / (output_rms + eps)
        out = out * (0.7 + 0.3 * scale)

        return out

In [None]:
# Smoke test (currently disabled): uncomment the lines below to check that
# the model output has the same shape as its input.
# model = CausalDNoizeConvTasNet(channels=96, num_blocks=4)

# Use the batch size and clip length of your own data here.
# x = torch.randn(8, 1, 80000)  # (batch, channels, samples)
# y = model(x)

# print(f"Input shape:  {x.shape}")
# print(f"Output shape: {y.shape}")
# assert x.shape == y.shape, "Shape mismatch!"
# print("✓ Shapes match!")

Input shape:  torch.Size([8, 1, 80000])
Output shape: torch.Size([8, 1, 80000])
✓ Shapes match!
