In [1]:
from repalette.lightning.datamodules import AdversarialDataModule, AdversarialRecolorDataModule
from repalette.lightning.systems import PreTrainSystem, AdversarialSystem, AdversarialMSESystem

from repalette.constants import S3_PRETRAINED_MODEL_CHECKPOINT_PATH, LIGHTNING_LOGS_DIR, MODEL_CHECKPOINTS_DIR

In [2]:
# Warm-start: restore a previously trained AdversarialMSESystem and keep only its
# generator, which is handed to a new AdversarialMSESystem further down.
# The path is resolved relative to the kernel's current working directory.
GAN_CHECKPOINT_PATH = "gan.ckpt"
generator = AdversarialMSESystem.load_from_checkpoint(GAN_CHECKPOINT_PATH).generator

In [3]:
# Data for the adversarial fine-tuning run.
# NOTE(review): meanings of `multiplier` and `size` are inferred from their names only —
# confirm against AdversarialRecolorDataModule.
adversarial_datamodule = AdversarialRecolorDataModule(
    multiplier=16,   # presumably number of recolored variants per image — TODO confirm
    size=1.0,        # presumably fraction of the dataset to use (1.0 = all) — TODO confirm
    batch_size=8,
    num_workers=14,  # hard-coded for this machine; adjust to the available CPU cores
)

In [4]:
import pytorch_lightning as pl
from repalette.lightning.callbacks import LogAdversarialToTensorboard, LogAdversarialMSEToTensorboard
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, GPUStatsMonitor
import uuid
import os

# Callback that writes recoloring samples to TensorBoard during training.
log_recoloring_to_tensorboard = LogAdversarialMSEToTensorboard(batches=3)

# TensorBoard logger; runs are written under LIGHTNING_LOGS_DIR/gan/.
logger = TensorBoardLogger(
    LIGHTNING_LOGS_DIR,
    name="gan",
    log_graph=True,
)

# Checkpoints for this run go to MODEL_CHECKPOINTS_DIR/gan-from-gan.
# save_top_k=-1 keeps every checkpoint written. No `monitor` is configured, so `mode`
# has nothing to rank — hence the "None was not in top -1" lines in the training log.
adv_checkpoints = ModelCheckpoint(
    dirpath=os.path.join(MODEL_CHECKPOINTS_DIR, "gan-from-gan"),
    verbose=True,
    mode="min",
    save_top_k=-1,
)

In [5]:
# Trainer for the adversarial fine-tuning run.
# The ModelCheckpoint instance is registered through `callbacks`: passing an instance
# via `checkpoint_callback=` is deprecated since PyTorch Lightning 1.1 (removed in 1.3),
# where `checkpoint_callback` is meant to be a bool. Lightning does not add its default
# checkpoint callback when one is already present in `callbacks`.
trainer = pl.Trainer(
    gpus=-1,  # use all visible GPUs
    logger=logger,
    benchmark=True,  # enables torch.backends.cudnn.benchmark
    enable_pl_optimizer=True,
    callbacks=[
        log_recoloring_to_tensorboard,
        adv_checkpoints,
    ],
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [6]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.metrics.regression import MeanSquaredError
from torch import nn as nn

from repalette.constants import (
    DEFAULT_DISCRIMINATOR_LR,
    DEFAULT_GENERATOR_LR,
    DEFAULT_ADVERSARIAL_BETA_1,
    DEFAULT_ADVERSARIAL_BETA_2,
    DEFAULT_GENERATOR_WEIGHT_DECAY,
    DEFAULT_ADVERSARIAL_LAMBDA_MSE_LOSS,
)
from repalette.models import PaletteNet, Discriminator

In [7]:
# Adversarial system wrapping the warm-started generator; both networks are trained at
# half of their default learning rates.
adversarial_system = AdversarialMSESystem(
    generator=generator,
    generator_learning_rate=DEFAULT_GENERATOR_LR / 2,
    generator_weight_decay=0.003,
    discriminator_learning_rate=DEFAULT_DISCRIMINATOR_LR / 2,
    discriminator_weight_decay=0.01,
    k=5,     # NOTE(review): semantics defined inside AdversarialMSESystem — confirm
    p=0.15,  # NOTE(review): semantics defined inside AdversarialMSESystem — confirm
)

In [None]:
# Launch training. The datamodule is passed by keyword because the second positional
# parameter of Trainer.fit is `train_dataloader` (Lightning re-routes a datamodule given
# there, but the keyword form states the intent directly).
trainer.fit(adversarial_system, datamodule=adversarial_datamodule)


  | Name          | Type          | Params
------------------------------------------------
0 | generator     | PaletteNet    | 13.9 M
1 | discriminator | Discriminator | 2.8 M 
2 | MSE           | MSELoss       | 0     
------------------------------------------------
16.7 M    Trainable params
0         Non-trainable params
16.7 M    Total params


Validation sanity check: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validating: |          | 0/? [00:00<?, ?it/s]

Epoch 0, step 74189: None was not in top -1


Validating: |          | 0/? [00:00<?, ?it/s]

Epoch 1, step 148379: None was not in top -1


In [11]:
# NOTE(review): leftover debug print with no analytical purpose — TODO remove before sharing.
print("test")

test
