In [None]:
import torch
import torch.nn as nn
from data.dataloader import dataloader
from .GANs.generator import Generator
from .GANs.discriminator import Discriminator
from .GANs.methods import weight_init, training

In [None]:
# Environment sanity check: report the installed PyTorch version and whether a
# CUDA device is visible to this kernel, one value per line.
print(torch.__version__, torch.cuda.is_available(), sep="\n")

In [None]:
# Run configuration.
# Seed torch's global RNG first so that everything drawn from it afterwards
# (the fixed noise below, and presumably the model weight initialisation and
# data shuffling) is reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
latent_dim = 100  # length of the generator's input noise vector

# Fixed latent batch of shape (batch_size, latent_dim, 1, 1). It is not used in
# this cell; presumably it is kept for sampling the generator during/after training.
noise_z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
train_loader, test_loader = dataloader('MNIST', batch_size=batch_size)

In [None]:
# Build the two networks for single-channel (MNIST) images.
# `biais=False` is passed through to the project's constructors; presumably it
# disables bias terms. TODO confirm against GANs.generator / GANs.discriminator.
generator = Generator(latent_dim=latent_dim, img_channels=1, biais=False)
discriminator = Discriminator(biais=False)

# Apply the project's custom weight initialisation recursively to every submodule.
generator.apply(weight_init)
discriminator.apply(weight_init)

# Place both networks on the same device as `noise_z` (and, presumably, the
# training batches). Moving again later (e.g. inside `training`) is harmless.
# The last expression keeps displaying the discriminator, as before.
generator.to(device)
discriminator.to(device)

In [None]:
from torchinfo import summary

# Inside Jupyter, torchinfo's `summary` defaults to verbose=0: it only returns
# the statistics object, which is rendered only when it is the cell's last
# expression. That silently dropped the generator table. Print both summaries
# explicitly; verbose=0 avoids a duplicate table when run outside a notebook.
print(summary(generator, input_size=(batch_size, latent_dim, 1, 1), verbose=0))
print(summary(discriminator, input_size=(batch_size, 1, 28, 28), verbose=0))

In [None]:
# Training hyper-parameters. lr=2e-4 with beta1=0.5 are the Adam settings
# recommended in the DCGAN paper; both networks share them here.
lr = 0.0002
betas = (0.5, 0.999)
epochs = 5
# Mixed precision is only requested on a CUDA device; on CPU it is turned off
# rather than relying on `training` to cope with AMP there (behaviour unknown).
use_amp = device.type == "cuda"

criterion = nn.BCELoss()
# NOTE(review): torch.autocast on CUDA rejects nn.BCELoss / binary_cross_entropy
# as "unsafe to autocast". If `training` evaluates `criterion` inside autocast,
# switch the discriminator to raw logits + nn.BCEWithLogitsLoss. TODO confirm.
optimizerG = torch.optim.Adam(generator.parameters(), lr=lr, betas=betas)
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

training(
    generator=generator,
    discriminator=discriminator,
    train_loader=train_loader,
    epochs=epochs,
    optimizerG=optimizerG,
    optimizerD=optimizerD,
    criterion=criterion,
    device=device,
    use_amp=use_amp
)