In [1]:
import pandas as pd

data = pd.read_csv("../data/design-seeds.csv")

In [2]:
multiplier = 24

In [3]:
from repalette.utils.data import PairRecolorDataset

In [4]:
from sklearn.model_selection import train_test_split

train, test = train_test_split(data, test_size=0.2)
val, test = train_test_split(test, test_size=0.5)

In [5]:
from repalette.utils.data import ShuffleDataLoader

train_dataset = PairRecolorDataset(train, multiplier)
train_dataloader = ShuffleDataLoader(train_dataset, shuffle=False, num_workers=8, batch_size=8)

In [6]:
len(train_dataset)

4416

In [7]:
val_dataset = PairRecolorDataset(val, multiplier)
val_dataloader = ShuffleDataLoader(val_dataset, shuffle=False, num_workers=8, batch_size=8)

In [8]:
len(val_dataset)

552

In [9]:
test_dataset = PairRecolorDataset(test, multiplier)
test_dataloader = ShuffleDataLoader(test_dataset, shuffle=False, num_workers=8, batch_size=8)

In [10]:
from repalette.models import PaletteNet

In [11]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

In [12]:
from repalette.models import PaletteNet

In [13]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from repalette.constants import DEFAULT_LR, DEFAULT_BETAS, PL_LOGS_DIR, MODELS_DIR

checkpoint_callback = ModelCheckpoint(
    filepath=MODELS_DIR,
    monitor="Val/Loss",
    verbose=True,
    mode='min',
    save_top_k=2
)

early_stop_callback = EarlyStopping(
   monitor="Val/Loss",
   min_delta=0.00,
   patience=20,
   verbose=False,
   mode='min'
)

hparams={
    'lr': DEFAULT_LR, 'betas': DEFAULT_BETAS,
}
model = PaletteNet(
    train_dataloader, val_dataloader, test_dataloader,
    hparams=hparams,
)
logger = TensorBoardLogger(PL_LOGS_DIR, name='PaletteNet')
trainer = Trainer(
    gpus=1, # use gpu!
    logger=logger,
    callbacks=[early_stop_callback],
    checkpoint_callback=checkpoint_callback,
#     auto_lr_find="learning_rate"
)

# trainer.tune(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [14]:
# train_iter = iter(model.train_dataloader())

In [15]:
# batch = next(train_iter)

In [None]:
trainer.fit(model)


  | Name               | Type              | Params
---------------------------------------------------------
0 | feature_extractor  | FeatureExtractor  | 11 M  
1 | recoloring_decoder | RecoloringDecoder | 2 M   
2 | loss_fn            | MSELoss           | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

In [19]:
from skimage.color import lab2rgb
from torchvision.utils import make_grid

In [22]:
(original_img, _), (target_img, target_palette) = next(
    iter(val_dataloader)
)

original_img = original_img#.to(self.device)
target_img = target_img#.to(self.device)
target_palette = target_palette#.to(self.device)
model.eval()
model.to("cpu")

def lab_batch_to_rgb_image_grid(lab_batch):
    grid = make_grid(torch.stack([torch.from_numpy(lab2rgb(
        lab_image.permute(1, 2, 0).cpu())
    ).permute(2, 0, 1) for lab_image in lab_batch]), pad_value=1)
    return grid
    

with torch.no_grad():
    target_palette = nn.Flatten()(target_palette)
    recolored_img = model(original_img, target_palette)

In [103]:
original_grid = lab_batch_to_rgb_image_grid(original_img)

In [104]:
model.logger.experiment.add_image("original", original_grid, model.current_epoch)

In [105]:
target_grid = lab_batch_to_rgb_image_grid(target_img)

In [106]:
model.logger.experiment.add_image("target", target_grid, model.current_epoch)

In [23]:
target_palette_img = target_palette.view(-1, 3, 6, 1)
target_palette_grid = lab_batch_to_rgb_image_grid(target_palette_img)
model.logger.experiment.add_image("target_palette", target_palette_grid, model.current_epoch)

In [117]:
pad_value

In [24]:
target_palette_grid

tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000],
         [1.0000, 1.0000, 0.8877, 1.0000, 1.0000, 0.9245, 1.0000, 1.0000,
          0.9576, 1.0000, 1.0000, 0.9749, 1.0000, 1.0000, 0.9787, 1.0000,
          1.0000, 0.9830, 1.0000, 1.0000, 0.9868, 1.0000, 1.0000, 0.9621,
          1.0000, 1.0000],
         [1.0000, 1.0000, 0.1773, 1.0000, 1.0000, 0.1619, 1.0000, 1.0000,
          0.1459, 1.0000, 1.0000, 0.1355, 1.0000, 1.0000, 0.1345, 1.0000,
          1.0000, 0.1329, 1.0000, 1.0000, 0.1309, 1.0000, 1.0000, 0.1422,
          1.0000, 1.0000],
    

In [111]:
recolored_with_lum = torch.cat((original_img[:, 0:1, :, :], recolored_img), axis=1)

In [112]:
recolored_grid = lab_batch_to_rgb_image_grid(recolored_with_lum)

In [113]:
model.logger.experiment.add_image("recolored", recolored_grid, model.current_epoch)

In [79]:
with torch.no_grad():
    target_palette = nn.Flatten()(target_palette)
    recolored_img = model(original_img, target_palette)
    recolored = [torch.from_numpy(lab2rgb(
        torch.cat(
            (
                original_img_[0, :, :].unsqueeze(0),
                recolored_img_,
            ),
            axis=0,
        ).permute(1, 2, 0).cpu()
    ) for original_img_, recolored_img_ in zip(original_img, recolored_img))]

original_grid = make_grid(original_img)
target_grid = make_grid(target_img)
recolored_grid = make_grid(recolored)

torch.Size([432, 288])

In [96]:
recolored_img

tensor([[[[ 0.4306,  1.3108,  0.8856,  ...,  1.6597,  0.5155, -0.0754],
          [ 0.3227,  1.0049,  0.5212,  ...,  1.1652,  0.1366, -0.5990],
          [-0.2797,  0.2197, -0.0592,  ...,  1.1876,  0.4007, -0.5782],
          ...,
          [-0.3868,  0.4632,  0.7211,  ...,  1.3233,  0.5609, -0.0682],
          [ 1.0449,  2.3526,  2.4854,  ...,  2.2538,  1.8026,  0.9876],
          [ 1.0683,  2.3379,  2.2803,  ...,  1.8070,  1.3086,  0.9760]],

         [[ 0.6974,  0.9982,  1.1844,  ...,  0.3353,  0.5361,  0.4916],
          [ 0.6713,  1.0601,  1.1453,  ...,  0.4740,  0.6447,  0.8008],
          [ 0.4956,  0.6756,  0.7979,  ...,  0.3243,  0.4295,  0.7632],
          ...,
          [ 0.2500,  0.3411,  0.7670,  ...,  0.3854,  0.6659,  0.7979],
          [ 0.4134,  0.6028,  1.0891,  ...,  0.1348,  0.4815,  0.5672],
          [ 0.4168,  0.5767,  0.9490,  ..., -0.0307,  0.1862,  0.3163]]]])

In [16]:
torch.cat((np.zeros(2), np.zeros(3))).shape

TypeError: expected Tensor as element 0 in argument 0, but got numpy.ndarray