In [1]:
import torch
import torch.nn as nn
from model_arch import TwoPicsGenerator, LitTwoPicsGenerator
from dataset import LitTwoImageDataModule

import os

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

from demo_utils import interpolate_two_points

%matplotlib inline
%load_ext autoreload
%autoreload 2



In [2]:
HEIGHT = 100
WIDTH = 200
LOG_FOLDER_NAME = "lightning_logs"

datamodule = LitTwoImageDataModule("pic_1.png", "pic_2.png")
datamodule.setup(HEIGHT, WIDTH, data_mean=0.3259, data_std=0.4484, num_workers=4)

In [3]:
emb_size = 1
img_hw = (HEIGHT, WIDTH)

model = TwoPicsGenerator(emb_size, img_hw=img_hw)

In [4]:
loss_fx = nn.MSELoss()
pl_model = LitTwoPicsGenerator(model, loss_fx)

In [8]:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=3,
    verbose=True,
    monitor='avg loss',
    mode='min',
    dirpath=os.path.join(LOG_FOLDER_NAME, pl_model.experiment_name)
)

logger = TensorBoardLogger(LOG_FOLDER_NAME, name=pl_model.experiment_name)
pl_model.enable_scheduler = True
pl_model.lr = 1e-2
pl_model.l2reg = 1e-2

# CKPT_NAME = "epoch=68-step=1056390.ckpt"
# CKPT_PATH = os.path.join(LOG_FOLDER_NAME, pl_model.experiment_name, CKPT_NAME)

trainer = pl.Trainer(max_epochs=68, callbacks=[checkpoint_callback], logger=logger)
trainer.fit(pl_model, train_dataloaders=datamodule.dataloader)#, ckpt_path = CKPT_PATH)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | model   | TwoPicsGenerator | 12.6 K
1 | loss_fx | MSELoss          | 0     
---------------------------------------------
12.6 K    Trainable params
0         Non-trainable params
12.6 K    Total params
0.050     Total estimated model params size (MB)


CURRENT VERSION: 3


Training: 0it [00:00, ?it/s]

Epoch 0, global step 2048: 'avg loss' reached 0.40708 (best 0.40708), saving model to 'C:\\Users\\Admin\\Desktop\\pic_sampler\\lightning_logs\\experiment\\epoch=0-step=2048-v3.ckpt' as top 3
Epoch 1, global step 4096: 'avg loss' reached 0.40798 (best 0.40708), saving model to 'C:\\Users\\Admin\\Desktop\\pic_sampler\\lightning_logs\\experiment\\epoch=1-step=4096-v2.ckpt' as top 3
Epoch 2, global step 6144: 'avg loss' reached 0.40727 (best 0.40708), saving model to 'C:\\Users\\Admin\\Desktop\\pic_sampler\\lightning_logs\\experiment\\epoch=2-step=6144.ckpt' as top 3
Epoch 3, global step 8192: 'avg loss' reached 0.40520 (best 0.40520), saving model to 'C:\\Users\\Admin\\Desktop\\pic_sampler\\lightning_logs\\experiment\\epoch=3-step=8192.ckpt' as top 3


In [19]:
torch.cuda.empty_cache()

z_1 = torch.tensor((-0.5,))
z_2 = torch.tensor((2.5,))

interpolate_two_points(z_1, z_2, 300, pl_model.model, export_imgs_to="")