In [1]:
import torch
import torch.nn as nn
import math
from torch.nn import functional as F

In [2]:
class DownSample(nn.Module):
    """Convolution block: ``Conv2d`` -> ``InstanceNorm2d`` -> activation.

    Args:
        input_channel: Number of channels of the incoming tensor.
        out_channel: Number of channels produced by the convolution; the
            instance norm is built for the same channel count.
        kernel_size: Forwarded to ``nn.Conv2d``.
        stride: Forwarded to ``nn.Conv2d``.
        padding: Forwarded to ``nn.Conv2d``.
        activation: Module appended after the normalisation layer.
    """

    def __init__(self, input_channel, out_channel, kernel_size, stride,
                 padding, activation):
        super().__init__()
        conv_layer = nn.Conv2d(
            input_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        norm_layer = nn.InstanceNorm2d(out_channel)
        # The container stays registered as ``conv`` with the layers in this
        # order so parameter names (``conv.0.weight``, ...) are unchanged.
        self.conv = nn.Sequential(conv_layer, norm_layer, activation)

    def forward(self, x):
        """Apply convolution, instance normalisation and activation to ``x``."""
        block = self.conv
        return block(x)

In [3]:
class UpSample(nn.Module):
    """Transposed-convolution block: ``ConvTranspose2d`` -> ``InstanceNorm2d`` -> activation.

    Args:
        input_channel: Number of channels of the incoming tensor.
        out_channel: Number of channels produced by the transposed
            convolution; the instance norm is built for the same count.
        kernel_size: Forwarded to ``nn.ConvTranspose2d``.
        stride: Forwarded to ``nn.ConvTranspose2d``.
        activation: Module appended after the normalisation layer.
    """

    def __init__(self, input_channel, out_channel, kernel_size, stride,
                 activation):
        super().__init__()
        deconv_layer = nn.ConvTranspose2d(
            input_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
        )
        norm_layer = nn.InstanceNorm2d(out_channel)
        # The container stays registered as ``conv`` with the layers in this
        # order so parameter names (``conv.0.weight``, ...) are unchanged.
        self.conv = nn.Sequential(deconv_layer, norm_layer, activation)

    def forward(self, x):
        """Apply transposed convolution, instance normalisation and activation to ``x``."""
        block = self.conv
        return block(x)

In [4]:
# Shape check: a 3x3 transposed convolution with stride 1 grows each spatial
# side by 2, so a 200x200 input becomes 202x202 (1 -> 2 channels).
UpSample(1, 2, 3, 1, nn.GELU())(torch.rand(1,1,200,200)).shape

torch.Size([1, 2, 202, 202])

In [5]:
class Unet_model(nn.Module):
    """U-Net style network that down- and up-samples along the frequency axis only.

    Structure (all kernels/strides act on the frequency axis; the time axis
    uses kernel 1 / stride 1 and is therefore preserved):

    * encoder: three ``DownSample`` stages, kernel ``(8, 1)``, stride ``(4, 1)``,
      widths 32 / 64 / 128;
    * bottleneck: single-layer LSTM (128 -> 128) run over the time axis with a
      residual connection;
    * decoder: three ``UpSample`` stages with kernels ``(11, 1)``, ``(9, 1)``,
      ``(8, 1)``, stride ``(4, 1)``, widths 64 / 32 / 16, each fed with the sum
      of the previous decoder output and the matching encoder output;
    * head: 1x1 convolution to one channel followed by a sigmoid.

    The layer sizes are tuned for 200 frequency bins: 200 -> 49 -> 11 -> 1 in
    the encoder and 1 -> 11 -> 49 -> 200 in the decoder.  The bottleneck must
    collapse to a single frequency bin because the LSTM expects exactly 128
    features per time step.

    Args:
        input_channel: Accepted for interface compatibility but not used;
            ``forward`` averages the input over its channel axis, so the first
            convolution always receives one channel.
        out_channel: Accepted for interface compatibility but not used; the
            layer widths are fixed.
        log_flag: If true, ``forward`` applies ``log(x + 1e-5)`` to the
            channel-averaged magnitude before the encoder.
    """

    def __init__(
        self, input_channel=1, out_channel=32, log_flag=False):
        super().__init__()
        # Empty containers; no layer is registered in them.  Kept so existing
        # attribute access on the model keeps working.
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.log_flag = log_flag

        # NOTE: sub-modules are created in the same order as before so that a
        # fixed RNG seed yields the same initial weights.
        freq_stride = (4, 1)
        no_padding = (0, 0)

        self.down_1 = DownSample(
            input_channel=1,
            out_channel=32,
            kernel_size=(8, 1),
            stride=freq_stride,
            padding=no_padding,
            activation=nn.GELU(),
        )
        self.down_2 = DownSample(
            input_channel=32,
            out_channel=64,
            kernel_size=(8, 1),
            stride=freq_stride,
            padding=no_padding,
            activation=nn.GELU(),
        )
        self.down_3 = DownSample(
            input_channel=64,
            out_channel=128,
            kernel_size=(8, 1),
            stride=freq_stride,
            padding=no_padding,
            activation=nn.GELU(),
        )

        self.lstm = nn.LSTM(input_size=128,
                            hidden_size=128,
                            num_layers=1,
                            batch_first=True)

        # Decoder kernels are chosen so each stage reproduces the frequency
        # size of the matching encoder stage: 1 -> 11 -> 49 -> 200.
        self.up_1 = UpSample(
            input_channel=128,
            out_channel=64,
            kernel_size=(11, 1),
            stride=freq_stride,
            activation=nn.GELU(),
        )
        self.up_2 = UpSample(
            input_channel=64,
            out_channel=32,
            kernel_size=(9, 1),
            stride=freq_stride,
            activation=nn.GELU(),
        )
        self.up_3 = UpSample(
            input_channel=32,
            out_channel=16,
            kernel_size=(8, 1),
            stride=freq_stride,
            activation=nn.GELU(),
        )

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=16,
                out_channels=1,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the network on a 4-D input.

        Args:
            x: Tensor of shape ``(B, C, F, T)``.  Its absolute value is
                averaged over the channel axis before the encoder.

        Returns:
            Tensor of shape ``(B, 1, F_out, T)`` with values in ``(0, 1)``
            (sigmoid output); ``F_out`` is 200 for a 200-bin input.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
            RuntimeError: If the encoder output does not flatten to the
                feature size expected by the LSTM.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input of shape (B, C, F, T), got {tuple(x.shape)}"
            )
        # Local names deliberately avoid ``F`` so the conventional
        # ``torch.nn.functional as F`` alias is not shadowed.
        batch, _, n_freq, frames = x.shape

        # Channel-averaged magnitude: (B, C, F, T) -> (B, 1, F, T).
        x = x.abs().mean(dim=1, keepdim=True)
        if self.log_flag:
            x = torch.log(x + 1e-5)

        x1 = self.down_1(x)
        x2 = self.down_2(x1)
        x3 = self.down_3(x2)

        _, bott_ch, bott_freq, _ = x3.shape
        if bott_ch * bott_freq != self.lstm.input_size:
            # Fail here with an actionable message instead of inside the LSTM.
            raise RuntimeError(
                f"encoder bottleneck has {bott_ch} channels x {bott_freq} "
                f"frequency bins = {bott_ch * bott_freq} features per frame, "
                f"but the LSTM expects {self.lstm.input_size}; the network is "
                f"sized for 200 input frequency bins (got {n_freq})"
            )

        # (B, C, F, T) -> (B, C*F, T) -> (B, T, C*F)
        seq = x3.reshape(batch, -1, frames).permute(0, 2, 1)
        seq = self.lstm(seq)[0] + seq
        # ``reshape`` rather than ``view``: the permuted tensor is not
        # contiguous, and ``view`` only accepts it when bott_freq == 1.
        x = seq.permute(0, 2, 1).reshape(batch, bott_ch, bott_freq, frames)

        # Skip connections.  ``x3`` is added here a second time (it is already
        # part of the LSTM residual above); kept as-is to preserve the
        # behaviour of existing weights.
        x = self.up_1(x + x3)
        x = self.up_2(x + x2)
        x = self.up_3(x + x1)
        return self.conv(x)

In [6]:
# Build the model with its default arguments; the forward pass is exercised in the following cells.
unet = Unet_model()

In [7]:
# Forward pass on a random (B=1, C=1, F=200, T=375) input; the model ends in a
# sigmoid, so every output value lies in (0, 1).
unet(torch.rand(1, 1, 200, 375))

tensor([[[[0.5143, 0.4415, 0.4996,  ..., 0.5754, 0.5614, 0.5253],
          [0.4868, 0.4478, 0.4431,  ..., 0.5357, 0.5049, 0.5173],
          [0.5588, 0.5326, 0.4751,  ..., 0.5113, 0.5297, 0.5057],
          ...,
          [0.5550, 0.5480, 0.5402,  ..., 0.5816, 0.4968, 0.5821],
          [0.4998, 0.5027, 0.5096,  ..., 0.4832, 0.5342, 0.4980],
          [0.5163, 0.4946, 0.4933,  ..., 0.4789, 0.5332, 0.5021]]]],
       grad_fn=<SigmoidBackward0>)

In [8]:
%%timeit
# Times one forward pass on a (1, 1, 200, 400) input. The measurement also
# includes creating the random tensor and, since no torch.no_grad() is used,
# recording the autograd graph.
unet(torch.rand(1, 1, 200, 400))

46.9 ms ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [15]:
# NOTE(review): presumably a real-time factor, i.e. 3000 ms of input divided by
# a ~50 ms forward pass (cf. the %%timeit result); confirm the units.
print(f"RT:{3000 // 50}")

RT:60


In [16]:
# Looks like the convolution output-length formula floor((n - k + 2p) / s) + 1
# with n=49, k=8, s=4 and a total padding of 2.
# NOTE(review): the DownSample stages in Unet_model use padding (0, 0); the
# result is 11 with or without the +2.
(49-8+2)//4 + 1

11

In [17]:
# Looks like the transposed-convolution output length (n - 1) * s + (k - 1) + 1
# with n=11, s=4, k=8: 48, one short of the 49 bins produced by the first
# encoder stage on a 200-bin input (up_2 uses a kernel of 9 instead).
(11 - 1)*4 + (8-1) + 1

48

In [18]:
# Same expression as two cells above: convolution output length for n=49, k=8,
# s=4 with a total padding of 2 (evaluates to 11).
(49-8+2)//4 + 1

11