In [2]:
from typing import Tuple, Optional, Union
import math
import torch as th
import torch
import torch.nn as nn
import torchaudio as ta
from torchaudio.transforms import PSD
import torch.nn.functional as F

In [65]:
class mask_estimator(nn.Module):
    """Estimate a time-frequency mask from a (possibly multichannel) spectrogram.

    Forward pipeline for an input of shape ``(B, C, F, T)``:

    1. Magnitude ``|spec|`` averaged over the channel axis -> ``(B, F, T)``;
       optionally ``log(mag + 1e-5)`` when ``log_flag`` is True.
    2. Three conv blocks (kernel ``(8, 1)``, stride ``(4, 1)``) that shrink only
       the frequency axis -> ``(B, 128, F_c, T)``. The LSTM expects ``F_c == 1``;
       with these kernel/stride settings that holds for ``148 <= F <= 211``.
    3. Residual block ``LSTM(x) + fc_b(x) + x`` over the time axis -> ``(B, T, 128)``.
    4. Three fully connected blocks -> ``(B, T, 200)``, returned as ``(B, 200, T)``.

    The output width (200) is fixed by ``fc_3`` regardless of ``F``. Every block,
    including the last, ends in GELU, so outputs are not constrained to ``[0, 1]``.

    Args:
        log_flag: apply ``log(mag + 1e-5)`` to the channel-averaged magnitude
            before the conv stack.
    """

    def __init__(self, log_flag: bool = False):
        super().__init__()
        self.log_flag = log_flag

        # Frequency-only downsampling (e.g. 200 bins -> 49 -> 11 -> 1); time is untouched.
        self.conv_1 = self._conv_block(1, 32)
        self.conv_2 = self._conv_block(32, 64)
        self.conv_3 = self._conv_block(64, 128)

        # hidden_size must equal input_size for the `+ x` residual in forward().
        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=128,
                            num_layers=1,
                            batch_first=True)

        self.fc_b = self._fc_block(128, 128)  # parallel branch of the residual block
        self.fc_1 = self._fc_block(128, 256)
        self.fc_2 = self._fc_block(256, 512)
        self.fc_3 = self._fc_block(512, 200)

    @staticmethod
    def _conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
        """Conv2d -> InstanceNorm2d -> GELU; kernel/stride span only the frequency axis."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=(8, 1), stride=(4, 1)),
            nn.InstanceNorm2d(out_channels),
            nn.GELU(),
        )

    @staticmethod
    def _fc_block(in_features: int, out_features: int) -> nn.Sequential:
        """Linear -> per-frame normalisation over the feature axis -> GELU.

        The inputs here are ``(B, T, D)``. ``InstanceNorm1d`` (used previously)
        reads that as ``(N, C, L)``: it normalised each frame over D while ignoring
        ``num_features`` and warning whenever ``T != num_features``. A parameter-free
        ``LayerNorm`` computes the same statistics (biased variance, eps=1e-5)
        without the mismatch.
        """
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features, elementwise_affine=False),
            nn.GELU(),
        )

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Map a ``(B, C, F, T)`` spectrogram to a ``(B, 200, T)`` mask estimate.

        Args:
            spec: real or complex tensor of shape ``(B, C, F, T)``; only its
                magnitude is used.

        Returns:
            Tensor of shape ``(B, 200, T)``.

        Raises:
            ValueError: if the conv stack does not collapse the frequency axis to
                a single bin, which the LSTM's input size requires.
        """
        batch, _, n_freq, n_frames = spec.shape

        mag = spec.abs().mean(dim=1)  # (B, F, T)
        if self.log_flag:
            mag = torch.log(mag + 1e-5)  # epsilon keeps log finite for zero bins

        x = mag.unsqueeze(1)  # (B, 1, F, T)
        x = self.conv_3(self.conv_2(self.conv_1(x)))  # (B, 128, F_c, T)

        n_freq_c = x.shape[2]
        if n_freq_c != 1:
            raise ValueError(
                f"mask_estimator: conv stack reduced {n_freq} frequency bins to "
                f"{n_freq_c}, but the LSTM (input_size={self.lstm.input_size}) needs "
                f"them collapsed to 1; with the current kernel/stride settings F "
                f"must be in [148, 211]."
            )

        x = x.flatten(1, 2).permute(0, 2, 1)  # (B, 128, 1, T) -> (B, 128, T) -> (B, T, 128)
        x = self.lstm(x)[0] + self.fc_b(x) + x

        x = self.fc_3(self.fc_2(self.fc_1(x)))  # (B, T, 200)
        return x.permute(0, 2, 1)

In [66]:
# Seed torch's global RNG so the weight initialisation (and any later torch.rand
# draws) are reproducible under Restart & Run All.
torch.manual_seed(0)
mask_m = mask_estimator()

In [67]:
# Dummy input laid out as (batch, channels, freq_bins, frames), values drawn from U[0, 1).
spect_shape = (1, 1, 200, 357)
spect = torch.rand(spect_shape)

In [68]:
# Forward pass on the dummy input; keep the result named instead of relying on `_`.
mask_pred = mask_m(spect)
# Invariant: batch and time axes are preserved; the mask axis has fc_3's output width.
assert mask_pred.shape == (spect.shape[0], mask_m.fc_3[0].out_features, spect.shape[-1])
mask_pred

tensor([[[ 0.1317, -0.1699, -0.1594,  ...,  0.7900, -0.0456, -0.1331],
         [ 0.6650,  1.2864,  0.6785,  ...,  0.3314,  0.8197, -0.0612],
         [ 0.6642, -0.1651, -0.1318,  ...,  0.1767,  0.3496,  0.1475],
         ...,
         [ 0.3783,  0.6874,  0.7025,  ...,  0.7915,  1.0150,  0.8384],
         [ 0.1408,  0.0096,  0.6543,  ...,  0.4268,  1.2694,  0.1685],
         [-0.1182, -0.1667, -0.1699,  ...,  0.3841, -0.1414,  0.3926]]],
       grad_fn=<PermuteBackward0>)

In [69]:
# Total number of scalar weights across every parameter tensor of the model.
sum(map(torch.Tensor.numel, mask_m.parameters()))

498216

In [71]:
%%timeit
mask_m(spect)

13.7 ms ± 503 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
