In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [2]:
class ConvBlock(nn.Module):
    """Conv -> activation [-> conv] block, optionally wrapped in an identity skip connection.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by every conv in the block.
        kernel_size: spatial kernel size of the convs.
        stride: stride of the first conv. ``padding='same'`` is used when
            ``stride == 1``; PyTorch rejects ``'same'`` for strided convs, so
            ``kernel_size // 2`` padding is used otherwise (equivalent to
            'same' for odd kernels at stride 1).
        requires_sum: if True, a second conv is appended and the input is added
            back (residual block). If False, the block is a single
            conv + activation and that output is returned.
        generator: selects ``nn.PReLU`` (True) or ``nn.LeakyReLU`` (False).

    Raises:
        ValueError: if ``requires_sum`` is True but the residual add would be
            shape-incompatible (``in_channels != out_channels`` or ``stride != 1``).
    """

    def __init__(self, in_channels=64, out_channels=64, kernel_size=5, stride=1, requires_sum=True, generator=True):
        super().__init__()
        if requires_sum and (in_channels != out_channels or stride != 1):
            # Fail at construction instead of with an opaque shape error inside forward().
            raise ValueError(
                "requires_sum=True needs in_channels == out_channels and stride == 1 "
                f"for the residual add (got in_channels={in_channels}, "
                f"out_channels={out_channels}, stride={stride})"
            )
        self.requires_sum = requires_sum
        # padding='same' raises for strided convs; fall back to symmetric k // 2 padding.
        first_padding = "same" if stride == 1 else kernel_size // 2
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=first_padding),
            nn.PReLU() if generator else nn.LeakyReLU(),
            # The second conv consumes the first conv's output, so its input width is out_channels.
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding="same")
            if requires_sum
            else nn.Identity(),
        )

    def forward(self, x):
        """Apply the block; return ``x + block(x)`` if ``requires_sum`` else ``block(x)``."""
        out = self.block(x)
        return x + out if self.requires_sum else out
        



class Generator(nn.Module):
    """Fully convolutional generator: head conv, residual trunk with a global skip, conv tail.

    Every conv runs at stride 1 with ``padding='same'``, so the spatial size of
    the input is preserved in the output.

    Args:
        in_channel: channels of the input tensor.
        out_channel: feature width used inside the network (trunk and tail).
        blocks: number of residual ``ConvBlock``s in the trunk.
        final_channel: channels of the output tensor. Defaults to 2, the value
            that was previously hard-coded.
    """

    def __init__(self, in_channel=2, out_channel=64, blocks=4, final_channel=2):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=11, stride=1, padding='same'),
            nn.PReLU(),
        )

        # Trunk blocks must run at the network's feature width; a bare ConvBlock() would assume 64.
        self.blocks = nn.Sequential(
            *[ConvBlock(in_channels=out_channel, out_channels=out_channel) for _ in range(blocks)]
        )
        self.conv = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=5, stride=1, padding='same')

        # Two non-residual ConvBlocks, then an 11x11 projection to the output channel count.
        self.terminal = nn.Sequential(
            ConvBlock(in_channels=out_channel, out_channels=out_channel, requires_sum=False),
            ConvBlock(in_channels=out_channel, out_channels=out_channel, requires_sum=False),
            nn.Conv2d(in_channels=out_channel, out_channels=final_channel, kernel_size=11, stride=1, padding='same'),
        )

    def forward(self, z):
        """Map ``z`` (batch, in_channel, H, W) to (batch, final_channel, H, W)."""
        head = self.initial(z)
        trunk = self.conv(self.blocks(head))
        # Global skip connection from the head features around the whole residual trunk.
        return self.terminal(trunk + head)



       
class Discriminator(nn.Module):
    """Work-in-progress two-path discriminator (complex input + magnitude input).

    Only the magnitude path is built so far; ``forward`` is not implemented yet.

    Args:
        in_channels: currently unused. Presumably intended for the complex
            path (TODO confirm).
        out_channels: feature width of the magnitude-path convs.
        magnitude_size: spatial size (H == W) of the single-channel magnitude
            input. It determines the Linear layer's ``in_features``. The default
            of 32 reproduces the originally hard-coded 2048 (= 32 * 8 * 8).
        magnitude_out_features: output width of the magnitude path's Linear.
            The original left this blank; 1 (a single score) is an assumption,
            TODO(review) confirm once the fusion with the complex path is designed.
    """

    def __init__(self, in_channels=3, out_channels=32, magnitude_size=32, magnitude_out_features=1):
        super().__init__()
        # Each conv below (kernel 5, stride 2, padding 2) maps H -> ceil(H / 2); two of them give ceil(H / 4).
        reduced = -(-magnitude_size // 4)
        self.magnitude_path = nn.Sequential(
            # Explicit padding: PyTorch rejects padding='same' for strided convolutions.
            nn.Conv2d(1, out_channels, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.Flatten(1, -1),
            nn.Linear(in_features=out_channels * reduced * reduced, out_features=magnitude_out_features),
        )

    def forward(self, complex, magnitude):
        """Not implemented: the complex path and the fusion of both paths are still missing.

        Raises:
            NotImplementedError: always, instead of silently returning None.
        """
        raise NotImplementedError("Discriminator.forward: complex path and path fusion are not implemented yet")

In [4]:
# Smoke test: the generator should return a tensor with the same shape as its input.
torch.manual_seed(0)  # reproducible input and weight initialisation
x = torch.randn(4, 2, 64, 64)

gen = Generator(2, 64, 4)

with torch.no_grad():  # shape check only; no autograd graph needed
    gen_out_shape = gen(x).shape

assert gen_out_shape == x.shape, f"unexpected generator output shape {gen_out_shape}"
gen_out_shape

torch.Size([4, 2, 64, 64])

In [6]:
# Per-sample feature count of the input once flattened (C * H * W).
x.flatten(start_dim=1).shape

torch.Size([4, 8192])

In [9]:
# Presumably the Discriminator's Linear in_features: 32 channels x 8 x 8 spatial (TODO confirm).
32 * 8 ** 2

2048