In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class DNoizeConvTasBlock(nn.Module):
    """Residual depthwise-separable 1-D convolution block.

    The block applies, in order: a depthwise convolution (one filter per
    channel), a 1x1 pointwise convolution that mixes channels, GroupNorm
    with 8 groups, and GELU. The block input is then added back to the
    result, so the output has exactly the same shape as the input.

    Args:
        channels: Number of input and output channels. Must be divisible
            by 8, because GroupNorm is built with 8 groups.
        kernel: Width of the depthwise convolution. Odd and even widths
            are both supported.
    """

    def __init__(self, channels=96, kernel=3):
        super().__init__()
        self.dw = nn.Conv1d(
            channels, channels, kernel,
            groups=channels, padding=kernel // 2, bias=False,
        )
        self.pw = nn.Conv1d(channels, channels, 1, bias=False)
        self.norm = nn.GroupNorm(8, channels)
        self.act = nn.GELU()

    def forward(self, x):
        """Return ``x + act(norm(pw(dw(x))))`` with the time length preserved.

        Args:
            x: Tensor whose last dimension is time and whose channel
                dimension has ``channels`` entries, e.g. ``[B, channels, T]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        residual = x
        x = self.dw(x)
        # padding=kernel//2 keeps the length only for odd kernels. An even
        # kernel produces one surplus trailing frame, which would make the
        # residual addition fail; drop it so the shapes line up again.
        if self.dw.kernel_size[0] % 2 == 0:
            x = x[..., :residual.shape[-1]]
        x = self.pw(x)
        x = self.norm(x)
        x = self.act(x)
        return x + residual

In [None]:
class DnoizeConvTasNet(nn.Module):
    """
    Small waveform-in / waveform-out speech enhancement network aimed at
    cheap CPU inference.

    Pipeline: stride-2 Conv1d encoder -> GroupNorm(8) -> GELU ->
    ``num_blocks`` residual DNoizeConvTasBlock layers -> sigmoid gate
    computed by a 1x1 convolution and multiplied onto the block output ->
    stride-2 ConvTranspose1d decoder back to a single channel.

    Size for the defaults (channels=96, num_blocks=4, kernel=3), counted
    from the layer shapes:
      Params: 49,249
      Conv MACs/sec @16kHz: ~385M (about 48k per encoded frame, 8000 frames/s)

    Args:
        channels: Width of the latent representation; must be divisible by 8
            because the GroupNorm layers use 8 groups.
        num_blocks: Number of backbone blocks stacked after the encoder.
        kernel: Depthwise kernel width handed to every backbone block.
    """

    def __init__(self, channels=96, num_blocks=4, kernel=3):
        super().__init__()

        # Analysis front-end: halves the time resolution.
        self.enc = nn.Conv1d(
            in_channels=1,
            out_channels=channels,
            kernel_size=5,
            stride=2,
            padding=2,
            bias=False,
        )
        self.norm_enc = nn.GroupNorm(num_groups=8, num_channels=channels)
        self.act_enc = nn.GELU()

        # Backbone: identical residual blocks applied one after another.
        stack = (DNoizeConvTasBlock(channels, kernel) for _ in range(num_blocks))
        self.blocks = nn.Sequential(*stack)

        # Gate projection and synthesis back-end (doubles the time resolution).
        self.mask_head = nn.Conv1d(channels, channels, kernel_size=1)
        self.dec = nn.ConvTranspose1d(
            in_channels=channels,
            out_channels=1,
            kernel_size=5,
            stride=2,
            padding=2,
            output_padding=1,
        )

    def forward(self, x):
        """Enhance a waveform batch.

        Args:
            x: Waveform tensor shaped [B, 1, T] (16 kHz audio is the
                intended input).

        Returns:
            Tensor shaped [B, 1, T].
        """
        num_samples = x.shape[-1]

        # Encode, normalise and activate in one pass.
        latent = self.act_enc(self.norm_enc(self.enc(x)))

        # Backbone features.
        latent = self.blocks(latent)

        # Sigmoid gate in (0, 1), applied element-wise to the features.
        gate = self.mask_head(latent).sigmoid()
        wave = self.dec(latent * gate)

        # The decoder emits 2 * ceil(T / 2) samples, i.e. one extra sample
        # for odd T; cut back to the input length.
        return wave[..., :num_samples]