In [1]:
import os

import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder

import cv2

from model_arch import TwoPicsGenerator, LitTwoPicsGenerator, VanillaGAN, IdentityBlock, LitGAN
from losses import LossForGenerator, LossForDiscriminator
from dataset import LitImageFolderDataModule

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

from demo_utils import interpolate_two_points, normalize_image

%matplotlib inline
%load_ext autoreload
%autoreload 2



In [2]:
# Target image size, passed to both the datamodule and the model.
HEIGHT = 32
WIDTH = 32
LOG_FOLDER_NAME = "lightning_logs"
EXPERIMENT_NAME = "vanilla_gan"
SEED = 42

# Seed python/numpy/torch (and dataloader worker processes) so weight
# initialisation and sampling are reproducible under Restart & Run All.
pl.seed_everything(SEED, workers=True)

# Portable path: the previous r"data\mnist" only resolved on Windows;
# os.path.join yields the same path there and "data/mnist" elsewhere.
dataset_folder = os.path.join("data", "mnist")

datamodule = LitImageFolderDataModule(dataset_folder)
# One-element mean/std lists: presumably single-channel (grayscale MNIST)
# statistics — TODO confirm they were computed on this dataset.
datamodule.setup(HEIGHT, WIDTH, data_mean=[0.1318], data_std=[0.2812], num_workers=4)

In [70]:
# Length of the noise vector handed to the generator (see the sampling cell,
# which draws z of shape (1, emb_size)).
emb_size = 3
model_gan = VanillaGAN(
    emb_size,
    datamodule.transforms,
    img_hw=(HEIGHT, WIDTH),
)

# Separate loss objects for the generator and the discriminator.
loss_G_fx = LossForGenerator()
loss_D_fx = LossForDiscriminator()

# Lightning wrapper around model + losses; this is what trainer.fit() trains.
gan = LitGAN(
    model_gan,
    loss_G_fx,
    loss_D_fx,
    experiment_name=EXPERIMENT_NAME,
)

generator_intermediate_hws: tensor([[ 4,  4],
        [ 8,  8],
        [16, 16],
        [32, 32]], dtype=torch.int32)


In [71]:
# Monitoring the epoch counter with mode='max' ranks checkpoints purely by
# recency, so this keeps the 3 most recent epochs.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=3,
    verbose=True,
    monitor='epoch',
    mode='max',
    dirpath=os.path.join(LOG_FOLDER_NAME, gan.experiment_name),
)

logger = TensorBoardLogger(LOG_FOLDER_NAME, name=gan.experiment_name)

# Hyper-parameters are overridden on the LightningModule before fit().
# NOTE(review): presumably consumed by LitGAN.configure_optimizers — confirm.
gan.enable_scheduler = False
gan.lr = 1e-5
gan.l2reg = 1e-9

MAX_EPOCHS = 1000
# Resume switch: set to a checkpoint path, e.g.
#   os.path.join(LOG_FOLDER_NAME, gan.experiment_name, "epoch=23-step=10752.ckpt")
# None (Lightning's default) starts from the weights `gan` currently holds —
# note that re-running this cell therefore continues from the last run's weights.
CKPT_PATH = None

trainer = pl.Trainer(max_epochs=MAX_EPOCHS, callbacks=[checkpoint_callback], logger=logger)
# NOTE(review): the training dataloader is reused for validation, so
# validation metrics are not measured on held-out data.
trainer.fit(
    gan,
    train_dataloaders=datamodule.dataloader,
    val_dataloaders=datamodule.dataloader,
    ckpt_path=CKPT_PATH,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                 | Params
---------------------------------------------------
0 | model     | VanillaGAN           | 28.2 K
1 | loss_G_fx | LossForGenerator     | 0     
2 | loss_D_fx | LossForDiscriminator | 0     
---------------------------------------------------
28.2 K    Trainable params
0         Non-trainable params
28.2 K    Total params
0.113     Total estimated model params size (MB)


CURRENT VERSION: 10


Training: 0it [00:00, ?it/s]

Epoch 0, global step 448: 'epoch' reached 0.00000 (best 0.00000), saving model to 'C:\\Users\\Admin\\Desktop\\pic_sampler\\lightning_logs\\vanilla_gan\\epoch=0-step=448.ckpt' as top 3
Epoch 1, global step 896: 'epoch' reached 1.00000 (best 1.00000), saving model to 'C:\\Users\\Admin\\Desktop\\pic_sampler\\lightning_logs\\vanilla_gan\\epoch=1-step=896.ckpt' as top 3
Epoch 2, global step 1344: 'epoch' reached 2.00000 (best 2.00000), saving model to 'C:\\Users\\Admin\\Desktop\\pic_sampler\\lightning_logs\\vanilla_gan\\epoch=2-step=1344.ckpt' as top 3
Epoch 3, global step 1792: 'epoch' reached 3.00000 (best 3.00000), saving model to 'C:\\Users\\Admin\\Desktop\\pic_sampler\\lightning_logs\\vanilla_gan\\epoch=3-step=1792.ckpt' as top 3
Epoch 4, global step 2240: 'epoch' reached 4.00000 (best 4.00000), saving model to 'C:\\Users\\Admin\\Desktop\\pic_sampler\\lightning_logs\\vanilla_gan\\epoch=4-step=2240.ckpt' as top 3
Epoch 5, global step 2688: 'epoch' reached 5.00000 (best 5.00000), saving 

In [None]:
# Draw one sample from the generator and display it in an OpenCV window.
gan.model.eval()
# Create the latent vector on the same device as the model's weights, so the
# cell works whether the model sits on CPU or CUDA when this runs.
device = next(gan.model.parameters()).device
with torch.no_grad():
    z = torch.randn(1, emb_size, device=device)
    print(z.shape)
    fake = gan.model.generate(z)
    print(fake.shape)

# Take the first sample and first channel -> 2-D array
# (assumes `fake` is (N, C, H, W) — TODO confirm against VanillaGAN.generate).
img = fake.cpu().numpy()[0, 0, ...]
# NOTE(review): presumably normalize_image rescales to 0..255; dividing by 255
# gives the 0..1 float range cv2.imshow expects for float images — confirm.
img = normalize_image(img) / 255
print(img.max(), img.min())

cv2.imshow("image", cv2.resize(img, (500, 500)))

# Wait for the user to press any key
# (this is necessary to keep the Python kernel from crashing).
cv2.waitKey(0)

# Close all open OpenCV windows.
cv2.destroyAllWindows()

