In [4]:
import torch
import torchcomplex.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

In [47]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels=64, out_channels=64,kernel_size=5, stride=1,requires_sum=True,generator=True):
        super(ConvBlock, self).__init__()
        self.requires_sum = requires_sum
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,stride=stride, padding= 2 if generator else 2),
            nn.zReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size,stride=stride, padding= 2 ))
        
    def forward(self, x):
        if self.requires_sum is True:
            out = self.block(x)
            return x + out
        else:
            out = self.block(x)
            return out
        
class Generator(nn.Module):
    def __init__(self,in_channel=2, out_channel=64, blocks=4):
        super(Generator, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=11,stride=1,padding=5),
            nn.zReLU())
        

        self.blocks = nn.Sequential(*[ConvBlock() for _ in range(blocks)])
        self.conv = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=5,stride=1,padding=2)

        self.terminal = nn.Sequential(
            ConvBlock(in_channels=out_channel, out_channels=out_channel,requires_sum=False),
            ConvBlock(in_channels=out_channel, out_channels=out_channel,requires_sum=False),
            nn.Conv2d(in_channels=out_channel, out_channels=2,kernel_size=11,stride=1,padding=5)
        )

    def forward(self, z):
        z = self.initial(z)
        print('clear 01')
        out = self.blocks(z)
        print('clear 02')
        out = self.conv(out)
        out = out + z
        return self.terminal(out)

In [64]:
x = torch.randn(4,2,4,4,dtype=torch.complex64).to('cpu')
gen = Generator(2, 64, 1)
gen.to('cpu')
gen(x).shape

t = gen(x)



torch.istft(t[0],n_fft=4)

clear 01
clear 02
clear 01
clear 02



tensor([[-4.1264e-03, -2.0781e-03,  9.3113e-05],
        [ 3.4261e-03,  1.8364e-03,  1.2095e-03]], grad_fn=<DivBackward0>)

In [33]:
block = nn.Sequential(
        nn.Conv2d(in_channels=2, out_channels=8, kernel_size=5,stride=1, padding=2),
        nn.zReLU(),
        nn.Conv2d(in_channels=8, out_channels=8, kernel_size=5,stride=1, padding=2))

In [34]:
x = torch.randn(1,2,8,8,dtype=torch.complex64)

In [35]:
block(x).shape

torch.Size([1, 8, 8, 8])