In [None]:
from model.dcgan import Generator, Discriminator, initialize_weights
from utils.dataset import ImageFolder
from utils.gan_process import training_loops
from utils.transforms import _get_test_transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import torch
import torch.nn as nn

## Define hyperparameters

In [None]:
# Training configuration: every tunable value used by the cells below lives here.

# Reproducibility
SEED = 42  # fixed seed so a fresh "Restart & Run All" starts from the same RNG state

# Optimisation
LEARNING_RATE = 2e-4  # shared by both optimizers; could also use two lrs, one for gen and one for disc
BATCH_SIZE = 64
NUM_EPOCHS = 100

# Data
IMAGE_SIZE = 300  # handed to the image transform pipeline
CHANNELS_IMG = 3

# Model
Z_DIM = 200  # passed to the generator and to the training loop as the noise dimension
FEATURES_DISC = 128
FEATURES_GEN = 128

# Seed PyTorch's global RNG before any model or DataLoader is created.
# Only torch's generator is seeded here; NOTE(review): if the project helpers
# draw from `random` or numpy, seed those as well.
# The trailing semicolon keeps the returned Generator object out of the cell output.
torch.manual_seed(SEED);

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device  # last expression of the cell, so the selected device is displayed

## Init model

In [None]:
# Build the generator and the discriminator and move each onto `device`.
# NOTE(review): the meaning of the leading positional `1` passed to Generator
# is not visible from this notebook; confirm against model/dcgan.py.
gen_model = Generator(1, Z_DIM, FEATURES_GEN, CHANNELS_IMG).to(device)
disc_model = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)

# Apply the project's weight initialisation to both networks
# (generator first, then discriminator).
for network in (gen_model, disc_model):
    initialize_weights(network)

## Init data

In [None]:
# Transform pipeline built for IMAGE_SIZE by the project's helper.
image_transform = _get_test_transforms(IMAGE_SIZE)

# NOTE(review): root_dir is an empty string here; presumably a placeholder that
# must be pointed at the training image folder before running. Confirm.
dataset = ImageFolder(root_dir=r"", transform=image_transform)

# Shuffled mini-batches; num_workers=0 keeps data loading in the main process.
data_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

## Define optimizer

In [None]:
# Both networks get their own Adam optimizer with identical settings:
# the shared learning rate and betas=(0.5, 0.999), i.e. a first-moment decay
# lower than Adam's default of 0.9.
adam_settings = {"lr": LEARNING_RATE, "betas": (0.5, 0.999)}

gen_opt = optim.Adam(gen_model.parameters(), **adam_settings)
disc_opt = optim.Adam(disc_model.parameters(), **adam_settings)

## Define loss

In [None]:
# Binary cross-entropy, averaged over the batch (the default reduction, spelled out).
# NOTE(review): BCELoss expects probabilities in [0, 1]; presumably the
# discriminator ends in a sigmoid. Confirm in model/dcgan.py.
criterion = nn.BCELoss(reduction="mean")

## Training process

In [None]:
# Launch training via the project's training loop, wiring in everything built above.
training_loops(
    # networks and their optimizers
    gen_model=gen_model,
    gen_optimizer=gen_opt,
    disc_model=disc_model,
    disc_optimizer=disc_opt,
    # objective and data
    criterion=criterion,
    data_loader=data_loader,
    noise_dim=Z_DIM,
    # run settings
    num_epochs=NUM_EPOCHS,
    device=device,
    save_path=r"gan_checkpoints",
)