In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [15]:
class CausalConv1d(nn.Module):
    """1-D convolution that only looks at the present and the past.

    The input is left-padded with ``(kernel_size - 1) * dilation`` zeros and
    then passed through an unpadded ``nn.Conv1d``. Output frame ``i`` is
    therefore computed from input samples ``i*stride - padding .. i*stride``,
    never from later ones. The output length is ``ceil(L / stride)``, which
    equals ``L`` for ``stride=1``.

    Arguments mirror ``nn.Conv1d`` (minus ``padding``, which is fixed here).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 dilation=1, groups=1, bias=True):
        super().__init__()
        # The whole receptive field except the current sample goes on the left.
        self.padding = dilation * (kernel_size - 1)
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Left-pad ``x`` on its last axis, then convolve (x: ``[B, C_in, L]``)."""
        left_padded = F.pad(x, (self.padding, 0))
        return self.conv(left_padded)

In [16]:
class CausalDNoizeBlock(nn.Module):
    """Causal depthwise-separable residual block.

    Computes ``x + GELU(GroupNorm(pw(dw(x))))`` on ``[B, channels, T]`` input:
    a depthwise causal conv, then a pointwise (1x1) conv, normalization and
    GELU. The result is added back onto the input.

    The GroupNorm is applied to each time step independently: statistics are
    pooled over the channels of a group only. Calling ``nn.GroupNorm``
    directly on ``[B, C, T]`` would pool over the whole time axis as well.
    Every output frame would then depend on future frames, and the block
    would not be causal. Parameter names and shapes (``norm.weight``,
    ``norm.bias``) are unchanged, so state dicts still load. However, weights
    trained with time-pooled statistics will behave differently.

    Args:
        channels: number of feature channels; must be divisible by 8 (the
            number of GroupNorm groups).
        kernel: depthwise kernel size, i.e. the receptive field in frames.
    """
    def __init__(self, channels=96, kernel=3):
        super().__init__()
        self.dw = CausalConv1d(channels, channels, kernel, groups=channels,
                               bias=False)
        self.pw = nn.Conv1d(channels, channels, 1, bias=False)
        self.norm = nn.GroupNorm(8, channels)
        self.act = nn.GELU()

    @staticmethod
    def _framewise_norm(norm, x):
        """Apply ``norm`` to every time step of ``x`` ([B, C, T]) separately."""
        b, c, t = x.shape
        # [B, C, T] -> [B*T, C]: each frame becomes its own GroupNorm sample.
        frames = x.transpose(1, 2).reshape(b * t, c)
        return norm(frames).reshape(b, t, c).transpose(1, 2)

    def forward(self, x):
        # The causal stride-1 depthwise conv and the 1x1 conv both preserve
        # length, so the residual add needs no cropping.
        out = self.pw(self.dw(x))
        out = self.act(self._framewise_norm(self.norm, out))
        return out + x

In [17]:
class CausalDNoizeConvTasNet(nn.Module):
    """
    Lightweight causal mask-style speech enhancement model for CPU inference.

    Pipeline on a waveform ``x`` of shape ``[B, 1, L]``:
      1. Causal stride-2 encoder: ``L`` samples -> ``ceil(L/2)`` frames.
      2. Per-frame GroupNorm, then GELU.
      3. ``num_blocks`` ``CausalDNoizeBlock``s, all length-preserving.
      4. Sigmoid gate.
      5. Stride-2 transposed-conv decoder, cropped back to exactly ``L``
         samples.

    The layers owned by this class add no lookahead:
      * the encoder uses left-only padding;
      * the encoder norm computes statistics per frame, never across time;
      * the decoder is an unpadded transposed conv cropped on the right.
        Encoder frame ``i`` (which saw samples up to ``2i``) only writes
        output samples ``2i .. 2i+4``.
    The full model is causal only if every ``CausalDNoizeBlock`` is causal
    as well.

    Cost with the defaults (channels=96, num_blocks=4, kernel=3):
      * ~49k parameters;
      * ~0.39 GMAC per second of 16 kHz audio. The encoder emits 8k
        frames/s, and the five 96x96 1x1 convs (4 blocks + mask head)
        dominate at channels**2 MACs per frame each.
    """
    def __init__(self, channels=96, num_blocks=4, kernel=3):
        super().__init__()

        # Causal encoder, stride 2: [B, 1, L] -> [B, C, ceil(L/2)]
        self.enc = CausalConv1d(1, channels, kernel_size=5, stride=2)
        self.norm_enc = nn.GroupNorm(8, channels)
        self.act_enc = nn.GELU()

        # Processing blocks (same length)
        self.blocks = nn.Sequential(*[
            CausalDNoizeBlock(channels, kernel)
            for _ in range(num_blocks)
        ])

        # Sigmoid gate over the block features
        self.mask_head = nn.Conv1d(channels, channels, 1)

        # Decoder, stride 2 (upsamples by 2x). There is deliberately no
        # padding/output_padding: the full output (2*frames + 3 samples) is
        # cropped on the RIGHT in forward(). Symmetric padding=2 would drop
        # 2 samples on the left and give y[t] up to 2 samples of lookahead.
        # The weight shape is the same as with padding, so checkpoints load.
        self.dec = nn.ConvTranspose1d(channels, 1, kernel_size=5, stride=2)

    @staticmethod
    def _framewise_norm(norm, x):
        """Apply ``norm`` to every time step of ``x`` ([B, C, T]) separately."""
        b, c, t = x.shape
        # [B, C, T] -> [B*T, C]: each frame becomes its own GroupNorm sample,
        # so no statistics are shared across time.
        frames = x.transpose(1, 2).reshape(b * t, c)
        return norm(frames).reshape(b, t, c).transpose(1, 2)

    def forward(self, x):
        input_length = x.shape[-1]

        # Encode (downsample): [B, 1, L] -> [B, C, F], F = ceil(L/2)
        enc = self.enc(x)
        enc = self._framewise_norm(self.norm_enc, enc)
        enc = self.act_enc(enc)

        # Process (same size): [B, C, F] -> [B, C, F]
        feat = self.blocks(enc)

        # Gate. NOTE(review): Conv-TasNet applies its mask to the encoder
        # output; here the mask gates the block output. Confirm intended.
        mask = torch.sigmoid(self.mask_head(feat))
        masked = feat * mask

        # Decode (upsample): [B, C, F] -> [B, 1, 2F + 3]. Since 2F + 3 >= L,
        # a right crop always restores the input length and stays causal.
        out = self.dec(masked)
        return out[..., :input_length]

In [None]:
# Test the model: the output must keep the input shape for even and odd lengths.
# Seeded and run without autograd so Restart & Run All reproduces the output below.
torch.manual_seed(0)
model = CausalDNoizeConvTasNet(channels=96, num_blocks=4).eval()

with torch.no_grad():
    # Test with the actual audio length (batch of 8 clips, 80000 samples each)
    x = torch.randn(8, 1, 80000)
    y = model(x)

    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {y.shape}")
    assert x.shape == y.shape, "Shape mismatch!"
    print("✓ Shapes match!")

    # Odd lengths must round-trip too: the encoder yields ceil(L/2) frames
    # and the decoder output is cropped back to L.
    odd = torch.randn(1, 1, 1001)
    assert model(odd).shape == odd.shape, "Shape mismatch on odd length!"

Input shape:  torch.Size([8, 1, 80000])
Output shape: torch.Size([8, 1, 80000])
✓ Shapes match!
