In [405]:
import torch
import torch.nn as nn
import math
from torch.nn import functional as F
from model import DownSample, UpSample

class Unet_model(nn.Module):
    def __init__(
        self, input_channel=1, out_channel=32, complex_input=True, log_flag=True, subband_flag=False):
        super().__init__()
        self.complex_input = complex_input
        self.log_flag = log_flag
        self.subband_flag = subband_flag
        
        self.down_1 = DownSample(
                input_channel=20,
                out_channel=64, 
                kernel_size=(4, 3), 
                stride=(4, 1),
                activation = nn.GELU(),
                normalization = nn.BatchNorm2d(64)
            )
        
        self.down_2 = DownSample(
                input_channel=64,
                out_channel=128, 
                kernel_size=(4, 3), 
                stride=(4, 1),
                activation = nn.GELU(),
                normalization = nn.BatchNorm2d(128)
            )
        
        self.lstm = nn.LSTM(input_size=256, 
                            hidden_size=256, 
                            num_layers=1, 
                            batch_first=True)
        
        
        self.up_1 = UpSample(
            input_channel=128+128, 
            out_channel=64, 
            kernel_size=(6, 3), 
            stride=(4, 1),
            activation = nn.GELU(),
            normalization = nn.BatchNorm2d(64)
        )    
        

        self.up_2 = UpSample(
            input_channel=64+64, 
            out_channel=32, 
            kernel_size=(4, 3), 
            stride=(4, 1),
            activation = nn.GELU(),
            normalization = nn.BatchNorm2d(32)
        )   
        
        self.final_conv = nn.Sequential(
            nn.Conv2d(32, 128, 1, 1, 0, bias=False),
            nn.GELU(),
            nn.Conv2d(128, 5, 1, 1, 0, bias=False)
        )
#         self.fc = nn.Sequential(
#             nn.Linear(40*32, 40*32),
#             nn.Sigmoid()
#         )
    
    def cac2cws(self, x):
        k = 5
        b,c,f,t = x.shape
        x = x.reshape(b,c,k,f//k,t)
        x = x.reshape(b,c*k,f//k,t)
        return x
    
    def cws2cac(self, x):
        k = 5
        b,c,f,t = x.shape
        x = x.reshape(b,c//k,k,f,t)
        x = x.reshape(b,c//k,f*k,t)
        return x
    
    def forward(self, x):
        B, C, F, T = x.shape
        
        if self.complex_input:
            x = torch.concat([x.real, x.imag], dim=1)
        elif self.log_flag:
            x = torch.log(x + 1e-5)
        
        if self.subband_flag:
            x = self.cac2cws(x)
        print(x.shape)
        x1 = self.down_1(x)
        x2 = self.down_2(x1)        
        _, C_c, F_c, _ = x2.shape
        x = x2.view(B, -1, T).permute(0, 2, 1) # (B, C, F, T) -> (B, C*F, T) -> (B, T, C*F)
        
        x = self.lstm(x)[0] + x
        
        x = x.permute(0, 2, 1).view(B, C_c, F_c, T)
        x = self.up_1(torch.concat([x, x2],dim=1))
        x = self.up_2(torch.concat([x, x1],dim=1))
        #x = self.final_conv(x)
        #x = self.cws2cac(x)
        x = x.view(B, -1, T).permute(0, 2, 1)
        #x = self.fc(x)
        
        return x#x.view(B, 1, F, T)

In [406]:
40*40

1600

In [407]:
real = torch.rand(1, 2, 200, 357)
imag = torch.rand(1, 2, 200, 357)
z = torch.complex(real, imag)

In [408]:
z.shape

torch.Size([1, 2, 200, 357])

In [413]:
sum(p.numel() for p in m.parameters() if p.requires_grad)

989664

In [410]:
m = Unet_model(subband_flag=True)

In [411]:
#%%timeit
m(z).shape

torch.Size([1, 20, 40, 357])


torch.Size([1, 357, 1280])

In [412]:
5 * 40

200

In [None]:
1280 -> 5, 40

In [399]:
m.final_conv(m(z)).shape

torch.Size([1, 20, 40, 357])


RuntimeError: Given groups=1, weight of size [128, 32, 1, 1], expected input[1, 1, 200, 357] to have 32 channels, but got 1 channels instead

In [290]:
Unet_model().cac2cws(z).shape

torch.Size([1, 20, 40, 357])

In [69]:
down_1 = DownSample(
                input_channel=10,
                out_channel=128, 
                kernel_size=(4, 3), 
                stride=(4, 1),
                activation = nn.GELU(),
                normalization = nn.BatchNorm2d(128)
            )

In [76]:
down_2 = DownSample(
                input_channel=128,
                out_channel=256, 
                kernel_size=(4, 3), 
                stride=(4, 1),
                activation = nn.GELU(),
                normalization = nn.BatchNorm2d(256)
            )

In [77]:
down_2(down_1(torch.rand((1, 10, 40, 100)))).shape

torch.Size([1, 256, 2, 100])

In [94]:
up_1 = UpSample(
            input_channel=256, 
            out_channel=128, 
            kernel_size=(4, 3), 
            stride=(4, 1),
            activation = nn.GELU(),
            normalization = nn.BatchNorm2d(128)
        )    

In [95]:
up_2 = UpSample(
            input_channel=128, 
            out_channel=32, 
            kernel_size=(4, 3), 
            stride=(4, 1),
            activation = nn.GELU(),
            normalization = nn.BatchNorm2d(32)
        )    

In [96]:
up_2(up_1(down_2(down_1(torch.rand((1, 10, 40, 100)))))).shape

torch.Size([1, 32, 32, 100])

In [97]:
32*32

1024