In [None]:
import torch
import torch.nn as nn

class Discriminator2(nn.Module):
    """Two-stage convolutional GAN discriminator for (possibly flattened) square images.

    Each input sample is reshaped to ``(img_channels, img_size, img_size)``. It then goes
    through two ``Conv2d(k=4, s=2, p=1) -> BatchNorm2d -> LeakyReLU(0.2)`` stages; each
    stage halves the spatial size, so the feature map is ``img_size // 4`` per side. The
    result is flattened and mapped to a single sigmoid score per sample.

    Args:
        img_channels: Number of image channels (e.g. 3 for RGB).
        img_size: Height/width of the square input images, at least 4. Defaults to 48,
            the size this model was originally hard-coded for.

    Raises:
        ValueError: If ``img_size`` is smaller than 4, since the second convolution
            would then have no valid output.

    Shape:
        Input: any tensor whose per-sample element count is
        ``img_channels * img_size * img_size``, e.g. ``(N, C*H*W)`` or ``(N, C, H, W)``.
        Output: ``(N, 1)`` with values in ``[0, 1]``.
    """

    def __init__(self, img_channels=3, img_size=48):
        # nn.Module.__init__ must run before any attribute is assigned on the subclass.
        super(Discriminator2, self).__init__()
        if img_size < 4:
            raise ValueError(f"img_size must be at least 4, got {img_size}")
        self.img_channels = img_channels
        self.img_size = img_size
        # Each Conv2d(kernel_size=4, stride=2, padding=1) maps a side of length H to H // 2.
        feat_size = img_size // 4
        # Attribute names (`main`, `out`) and layer order are kept for state_dict compatibility.
        self.main = nn.Sequential(
            nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.out = nn.Sequential(
            nn.Linear(128 * feat_size * feat_size, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one sigmoid score in ``[0, 1]`` per sample, shape ``(N, 1)``."""
        # reshape (unlike view) also accepts non-contiguous inputs; it copies only if needed.
        x = x.reshape(x.size(0), self.img_channels, self.img_size, self.img_size)
        x = self.main(x)
        x = x.flatten(1)
        return self.out(x)

discriminator2 = Discriminator2()

# Smoke test: a flattened batch of 64 random 3x48x48 images should give one score each.
# Eval mode + no_grad keep this random batch from building an autograd graph or from
# updating the BatchNorm running statistics of the model; training mode is restored after.
discriminator2.eval()
with torch.no_grad():
    output = discriminator2(torch.randn(64, 3 * 48 * 48))
discriminator2.train()
assert output.shape == (64, 1)
output.shape

In [None]:
class Generator2(nn.Module):
    """Upsampling convolutional GAN generator that emits flattened square images.

    A latent batch is projected by a linear layer to a ``128 x init_size x init_size``
    feature map, where ``init_size = img_size // 4``. It is then upsampled twice by 2x,
    with a ``Conv2d -> BatchNorm2d -> LeakyReLU(0.2)`` stage after each upsampling. A
    final convolution projects it to ``img_channels``, ``tanh`` squashes it to
    ``[-1, 1]``, and the result is flattened.

    Args:
        z_dim: Length of each latent vector.
        img_channels: Number of output image channels.
        img_size: Height/width of the generated square images; must be a positive
            multiple of 4 because of the two 2x upsamplings. Defaults to 48.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 4.

    Shape:
        Input: ``(N, z_dim)``; trailing dims such as ``(N, z_dim, 1, 1)`` are flattened.
        Output: ``(N, img_channels * img_size * img_size)``.
    """

    def __init__(self, z_dim=100, img_channels=3, img_size=48):
        super(Generator2, self).__init__()
        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(f"img_size must be a positive multiple of 4, got {img_size}")
        self.img_size = img_size
        self.init_size = img_size // 4  # Initial spatial size after the linear layer
        self.input_size = z_dim
        self.img_channels = img_channels
        # Attribute names (`l1`, `conv_blocks`) and layer order are kept for state_dict compatibility.
        self.l1 = nn.Sequential(nn.Linear(self.input_size, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, self.img_channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map latents ``z`` to flattened images of shape ``(N, C*H*W)`` in ``[-1, 1]``."""
        # flatten(1) is a no-op for (N, z_dim) and also accepts (N, z_dim, 1, 1) latents.
        out = self.l1(z.flatten(1))
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img.view(img.size(0), self.img_channels * self.img_size * self.img_size)
    
generator2 = Generator2(z_dim=100, img_channels=3)

# Smoke test: 64 latent vectors of length 100 should become 64 flattened 3x48x48 images.
# Eval mode + no_grad keep this random batch from building an autograd graph or from
# updating the BatchNorm running statistics of the model; training mode is restored after.
generator2.eval()
with torch.no_grad():
    output = generator2(torch.randn(64, 100))
generator2.train()
assert output.shape == (64, 3 * 48 * 48)
output.shape