UFORMER: A UNET BASED DILATED COMPLEX & REAL DUAL-PATH CONFORMER NETWORK FOR SIMULTANEOUS SPEECH ENHANCEMENT AND DEREVERBERATION  
https://arxiv.org/pdf/2111.06015.pdf

In [24]:
import torch
import torch.nn as nn
from torch import Tensor

## Encoder-Decoder Attention  

$$ \mathbf{G}_i = \sigma(\mathbf{W}_i^E * \mathbf{E}_i + \mathbf{W}_i^D * \mathbf{D}_i)\tag{10}$$  

$$ \hat{\mathbf{D}}_i = \sigma(\mathbf{W}^A_i * \mathbf{G}_i) \odot \mathbf{D}_i  \tag{11} $$

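The previous layer's output $\mathbf{D}_i$ and the encoder skip connection $\mathbf{E}_i$ each pass through their own Conv2d. The two results are summed and passed through a sigmoid to produce $\mathbf{G}_i$ in (10).  
$\mathbf{G}_i$ goes through another Conv2d and a sigmoid, and the result is used as a mask that multiplies the previous layer's output $\mathbf{D}_i$, as in (11).  
The masked output $\hat{\mathbf{D}}_i$ is then concatenated with the previous layer's output $\mathbf{D}_i$ and passed to the next layer.  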

=>  

Three Conv2d layers are therefore needed: $\mathbf{W}_i^E$, $\mathbf{W}_i^D$ and $\mathbf{W}_i^A$.  

```
For encoder-decoder attention, the kernel size of three Conv2ds is (2, 3)
```
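As a minimal sketch of the whole skip path described above, the function below applies (10) and (11) and then concatenates $\hat{\mathbf{D}}_i$ with $\mathbf{D}_i$. The function name `skip_attention_step`, the channel axis (`dim=1` for a `[B, C, T, F]` layout) and the concatenation order $[\hat{\mathbf{D}}_i, \mathbf{D}_i]$ are assumptions made for illustration; the notes above only say the two tensors are concatenated.

```python
import torch
import torch.nn as nn


def skip_attention_step(d: torch.Tensor, e: torch.Tensor,
                        w_e: nn.Conv2d, w_d: nn.Conv2d, w_a: nn.Conv2d) -> torch.Tensor:
    """Hypothetical helper: Eqs. (10)-(11) followed by the concatenation.
    d: previous layer output D_i, e: encoder skip connection E_i, both [B, C, T, F]."""
    g = torch.sigmoid(w_e(e) + w_d(d))    # (10) gate G_i
    mask = torch.sigmoid(w_a(g))          # attention mask with values in (0, 1)
    d_hat = mask * d                      # (11) masked output \hat{D}_i
    # Assumed order and axis: [\hat{D}_i, D_i] along channels -> [B, 2C, T, F]
    return torch.cat([d_hat, d], dim=1)


B, C, T, F = 2, 3, 4, 5
# Three Conv2d layers with the (2, 3) kernel quoted from the paper: W^E, W^D, W^A
w_e, w_d, w_a = (nn.Conv2d(C, C, kernel_size=(2, 3), padding="same") for _ in range(3))
d = torch.rand(B, C, T, F)
e = torch.rand(B, C, T, F)
out = skip_attention_step(d, e, w_e, w_d, w_a)
print(out.shape)  # torch.Size([2, 6, 4, 5])
```

Because the mask lies in (0, 1), the attention branch can only attenuate $\mathbf{D}_i$, while the concatenated copy of $\mathbf{D}_i$ is passed on unchanged, so the next layer doubles its input channels.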

In [54]:
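class EncoderDecoderAttention(nn.Module):
    """Encoder-decoder attention gate, Eqs. (10) and (11)."""

    def __init__(self, channels: int) -> None:
        super().__init__()

        # padding="same" keeps T and F unchanged; with the even time kernel (2)
        # PyTorch has to pad asymmetrically and emits a one-time UserWarning.
        # Encoder kernel W^E, applied to the skip connection E_i
        self.w_e = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(2, 3), stride=1, padding="same", dilation=1)
        # Decoder kernel W^D, applied to the previous layer output D_i
        self.w_d = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(2, 3), stride=1, padding="same", dilation=1)
        # Attention kernel W^A, applied to the gate G_i
        self.w_a = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(2, 3), stride=1, padding="same", dilation=1)

        self.sigma = nn.Sigmoid()

    def forward(self, d: Tensor, e: Tensor) -> Tensor:
        # d: previous layer output D_i, e: encoder skip connection E_i, both [B, C, T, F]
        ## (10) gate from the summed encoder and decoder convolutions
        g = self.sigma(self.w_e(e) + self.w_d(d))

        ## (11) multiplying D_i element-wise by the sigmoid attention mask
        d_hat = torch.mul(self.sigma(self.w_a(g)), d)
        return d_hat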
        
        

In [55]:
B = 2
C = 3
T = 4
F = 5
e = torch.rand(B,C,T,F)
d = torch.rand(B,C,T,F)
m = EncoderDecoderAttention(C)

y = m(d,e)

print(y.shape)

torch.Size([2, 3, 4, 5])
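The Conv2d layers above use `padding="same"`. With a time kernel of size 2, PyTorch cannot pad symmetrically and adds the extra zero frame on the trailing side, so each output frame also depends on the next input frame. This notebook does not say whether the time axis should be causal; if it should (an assumption), a minimal sketch pads `kt - 1` past frames explicitly with `nn.functional.pad`. The class name `CausalConv2d` is hypothetical.

```python
import torch
import torch.nn as nn


class CausalConv2d(nn.Module):
    """Hypothetical causal variant of the (2, 3) Conv2d for [B, C, T, F] inputs:
    zero-pads kt - 1 past frames on the time axis and keeps F unchanged."""

    def __init__(self, channels: int, kernel_size=(2, 3)) -> None:
        super().__init__()
        kt, kf = kernel_size
        # nn.functional.pad order for 4-D input: (F_left, F_right, T_left, T_right)
        self.pad = ((kf - 1) // 2, kf - 1 - (kf - 1) // 2, kt - 1, 0)
        self.conv = nn.Conv2d(channels, channels, kernel_size)  # no built-in padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(nn.functional.pad(x, self.pad))


torch.manual_seed(0)
x = torch.rand(2, 3, 4, 5)
causal = CausalConv2d(3)
same = nn.Conv2d(3, 3, (2, 3), padding="same")

# Perturb only the last frame and check which output frames react
x2 = x.clone()
x2[:, :, -1] += 1.0
print(causal(x).shape)  # torch.Size([2, 3, 4, 5])
# Causal: every frame before the last is unaffected
print(torch.allclose(causal(x)[:, :, :-1], causal(x2)[:, :, :-1]))  # True
# "same": frame T-2 sees frame T-1 through the trailing pad
print(torch.allclose(same(x)[:, :, -2], same(x2)[:, :, -2]))  # False
```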


## Dilated Conformer

In [None]:
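class DilatedConformer(nn.Module):
    # Runnable sketch. The order FF1 -> TA -> FA -> DC -> FF2 -> LN follows the attribute
    # names of the original stub; the residual connections, the depthwise dilated Conv1d
    # used for DC and the input layout [B, T, F, dim_in] are assumptions.
    def __init__(self, dim_in: int, dim_emb: int = 1024, num_heads: int = 8,
                 kernel_size: int = 3, dilation: int = 1) -> None:
        super().__init__()
        # FF1: project input features to the attention embedding size
        self.FF1 = nn.Linear(dim_in, dim_emb)
        # TA: multi-head self-attention along the time axis
        self.TA = nn.MultiheadAttention(embed_dim=dim_emb, num_heads=num_heads, batch_first=True)
        # FA: multi-head self-attention along the frequency axis
        self.FA = nn.MultiheadAttention(embed_dim=dim_emb, num_heads=num_heads, batch_first=True)
        # DC (assumed form): depthwise dilated Conv1d over time; an odd kernel_size keeps T
        self.DC = nn.Conv1d(dim_emb, dim_emb, kernel_size, padding=dilation * (kernel_size - 1) // 2,
                            dilation=dilation, groups=dim_emb)
        # FF2: project back to dim_in; LN: final layer normalization
        self.FF2 = nn.Linear(dim_emb, dim_in)
        self.LN = nn.LayerNorm(dim_in)

    def forward(self, x: Tensor) -> Tensor:
        # x: [B, T, F, dim_in] -> output of the same shape
        B, T, F, _ = x.shape
        h = self.FF1(x)                                   # [B, T, F, E]
        E = h.shape[-1]
        # TA: one length-T sequence per (batch, frequency bin)
        t = h.permute(0, 2, 1, 3).reshape(B * F, T, E)
        t = t + self.TA(t, t, t, need_weights=False)[0]
        h = t.reshape(B, F, T, E).permute(0, 2, 1, 3)     # back to [B, T, F, E]
        # FA: one length-F sequence per (batch, frame)
        f = h.reshape(B * T, F, E)
        f = f + self.FA(f, f, f, need_weights=False)[0]
        h = f.reshape(B, T, F, E)
        # DC: dilated convolution along time for each frequency bin
        c = h.permute(0, 2, 3, 1).reshape(B * F, E, T)
        c = c + self.DC(c)
        h = c.reshape(B, F, E, T).permute(0, 3, 1, 2)     # back to [B, T, F, E]
        return self.LN(x + self.FF2(h))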
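The point of a dilated convolution is that stacking layers with growing dilation widens the temporal context quickly. As a hedged illustration, assuming stacked blocks use dilations 1, 2, 4, 8 with kernel size 3 (hypothetical values, not taken from this notebook), the receptive field is $1 + \sum_l (k-1)\,d_l = 1 + 2(1+2+4+8) = 31$ frames. The sketch below checks that number with an impulse; `dilated_stack` and `receptive_field` are hypothetical helpers.

```python
import torch
import torch.nn as nn


def dilated_stack(channels: int, kernel_size: int = 3, dilations=(1, 2, 4, 8)) -> nn.Sequential:
    """Hypothetical stack of depthwise dilated Conv1d layers along time.
    padding = d * (k - 1) // 2 keeps the sequence length for odd k."""
    return nn.Sequential(*[
        nn.Conv1d(channels, channels, kernel_size,
                  padding=d * (kernel_size - 1) // 2, dilation=d, groups=channels)
        for d in dilations
    ])


def receptive_field(kernel_size: int, dilations) -> int:
    # Each layer widens the field by (k - 1) * d frames
    return 1 + sum((kernel_size - 1) * d for d in dilations)


stack = dilated_stack(channels=1)
with torch.no_grad():
    for conv in stack:
        conv.weight.fill_(1.0)  # all-ones weights, so contributions never cancel
        conv.bias.zero_()
    x = torch.zeros(1, 1, 64)
    x[0, 0, 32] = 1.0           # unit impulse in the middle of 64 frames
    y = stack(x)

print(receptive_field(3, (1, 2, 4, 8)))  # 31
print(int((y != 0).sum()))               # 31
```

Four layers of kernel size 3 reach 31 frames here, whereas the same four layers without dilation would reach only 9.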
