In [1]:
from anything_vae import (
    ResnetBlock2D,
    SelfAttention,
    Downsample2D,
    Upsample2D,
    DownEncoderBlock2D,
    UpDecoderBlock2D,
    UNetMidBlock2D,
    # Encoder,
    # Decoder,
    # AutoencoderKL,
    # VGGPerceptualLoss
)

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models import vgg16
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from torchvision import transforms, models as torchvision_models
from pytorch_lightning import LightningModule, Trainer, loggers, callbacks
# import pytorch_lightning as pl
from torchmetrics import MeanSquaredError
from PIL import Image

import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import kornia

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        # Change input channels from 1 (L channel) to 3 (LAB image)
        self.conv_in = nn.Conv2d(3, 128, kernel_size=3, padding=1)

        self.down_blocks = nn.ModuleList([
            DownEncoderBlock2D(128, 128, num_res_blocks=2, downsample=True),
            DownEncoderBlock2D(128, 256, num_res_blocks=2, downsample=True),
            DownEncoderBlock2D(256, 512, num_res_blocks=2, downsample=True),
            DownEncoderBlock2D(512, 512, num_res_blocks=2, downsample=False),
        ])

        self.mid_block = UNetMidBlock2D(512)

        self.conv_norm_out = nn.GroupNorm(32, 512, eps=1e-6, affine=True)
        self.conv_act = nn.SiLU()
        # The output channels remain the same
        self.conv_out = nn.Conv2d(512, 8, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv_in(x)  # x: [batch_size, 3, H, W]
        for block in self.down_blocks:
            x = block(x)
        x = self.mid_block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)  # Output h: [batch_size, 8, H', W']
        return x

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        # Input channels remain the same (latent vector z)
        self.conv_in = nn.Conv2d(4, 512, kernel_size=3, padding=1)

        self.mid_block = UNetMidBlock2D(512)

        self.up_blocks = nn.ModuleList([
            UpDecoderBlock2D(512, 512, num_res_blocks=3, upsample=True),
            UpDecoderBlock2D(512, 512, num_res_blocks=3, upsample=True),
            UpDecoderBlock2D(512, 256, num_res_blocks=3, upsample=True),
            UpDecoderBlock2D(256, 128, num_res_blocks=3, upsample=False),
        ])

        self.conv_norm_out = nn.GroupNorm(32, 128, eps=1e-6, affine=True)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(128, 3, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv_in(x)  # x: [batch_size, 4, H', W']
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        x = self.conv_norm_out(x)
        x = self.conv_act(x)
        x = self.conv_out(x)  # Output before activation: [batch_size, 3, H, W]
        x = torch.tanh(x)     # Ensure outputs are in [-1, 1]
        return x


In [3]:
class ColorizationDataset(Dataset):
    def __init__(self, data_folder, data_csv):
        self.data_folder = data_folder
        self.data_path = os.path.join(data_folder, data_csv)
        self.images = pd.read_csv(self.data_path)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        sketch_path, colored_path = self.images.iloc[idx]
        
        # Load and process the sketch image as grayscale
        sketch_image = self.transform(self.__loadImage(sketch_path).convert('L'))
        # sketch_image shape: [1, H, W]

        # Convert grayscale to RGB by repeating the channel or using gray_to_rgb
        sketch_image_rgb = sketch_image.repeat(3, 1, 1)
        # Alternatively:
        # sketch_image_rgb = kornia.color.gray_to_rgb(sketch_image.unsqueeze(0)).squeeze(0)
        # sketch_image_rgb shape: [3, H, W]

        # Apply rgb_to_lab conversion
        sketch_lab = kornia.color.rgb_to_lab(sketch_image_rgb.unsqueeze(0)).squeeze(0)
        sketch_lab_normalized = normalize_lab(sketch_lab)
        
        # Load and process the colored image
        colored_image = self.transform(self.__loadImage(colored_path).convert('RGB'))
        colored_lab = kornia.color.rgb_to_lab(colored_image.unsqueeze(0)).squeeze(0)
        colored_lab_normalized = normalize_lab(colored_lab)
        
        return sketch_lab_normalized, colored_lab_normalized  # Both are [3, H, W]
    
    def __loadImage(self, image_path):
        return Image.open(os.path.join(self.data_folder, image_path))


In [4]:
class VGGPerceptualLoss(LightningModule):
    def __init__(self, vgg_model):
        super().__init__()
        self.vgg = vgg_model
        self.criterion = nn.MSELoss()
        self.features = list(self.vgg.features[:16])
        self.features = nn.Sequential(*self.features).eval()
        
        for params in self.features.parameters():
            params.requires_grad = False

    def forward(self, x, y):
        return self.criterion(self.features(x),self.features(y))

In [5]:
def normalize_lab(lab_image):
    """Normalize LAB image channels."""
    # lab_image shape: [3, H, W]
    
    # Indexing without the batch dimension
    L_channel = lab_image[:1, :, :]        # L channel: [1, H, W]
    ab_channels = lab_image[1:, :, :]      # ab channels: [2, H, W]

    # Normalize the channels
    L_normalized = (L_channel / 50.0) - 1.0         # L from [0, 100] to [-1, 1]
    ab_normalized = (ab_channels + 128.0) / 127.5 - 1.0  # ab from [-128, 127] to [-1, 1]

    # Concatenate normalized channels back together
    normalized_lab = torch.cat([L_normalized, ab_normalized], dim=0)  # Shape: [3, H, W]
    return normalized_lab


def denormalize_lab(normalized_lab_image):
    """Denormalize LAB image channels back to original ranges."""
    L_normalized = normalized_lab_image[:, :1, :, :]
    ab_normalized = normalized_lab_image[:, 1:, :, :]

    L_channel = (L_normalized + 1.0) * 50.0  # L back to [0, 100]
    ab_channels = (ab_normalized + 1.0) * 127.5 - 128.0  # ab back to [-128, 127]

    return torch.cat([L_channel, ab_channels], dim=1)

def reconstruct_lab(L_channel, ab_channel):
    lab = torch.cat([L_channel, ab_channel], dim=1)
    rgb = kornia.color.lab_to_rgb(lab)
    return torch.clamp(rgb, 0.0, 1.0)

def reconstruct_lab_and_rgb(*images):
    """
    Reconstruct LAB and RGB images from normalized LAB tensors.

    Args:
        images: A sequence of normalized LAB images.

    Returns:
        A tuple of RGB images corresponding to the input LAB images.
    """
    rgb_images = []
    for image in images:
        # Denormalize LAB image
        lab_image = denormalize_lab(image)
        
        # Convert LAB to RGB
        rgb_image = kornia.color.lab_to_rgb(lab_image)
        
        # Clamp RGB values to [0, 1]
        rgb_image = torch.clamp(rgb_image, 0.0, 1.0)
        rgb_images.append(rgb_image)
    return tuple(rgb_images)


def visualize_model_output(inputs, outputs, targets, logger, global_step, tag='Visualization'):
    """Log model outputs as RGB images using the logger."""
    inputs_rgb, outputs_rgb, targets_rgb = reconstruct_lab_and_rgb(inputs, outputs, targets)
    
    if inputs_rgb.dim() == 4:  # Batch of images
        combined_images = torch.cat((inputs_rgb, targets_rgb, outputs_rgb), dim=3)
        grid = torchvision.utils.make_grid(combined_images, nrow=1)
    else:  # Single image
        combined_image = torch.cat((inputs_rgb, targets_rgb, outputs_rgb), dim=2)
        grid = torchvision.utils.make_grid(combined_image.unsqueeze(0), nrow=1)
    
    logger.experiment.add_image(tag, grid, global_step)


In [6]:
from pytorch_lightning import LightningModule
import torch
import torch.nn as nn
import torchvision
from torchvision.models import vgg16

class Colorizer(LightningModule):
    def __init__(self):
        super(Colorizer, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.quant_conv = nn.Conv2d(8, 8, kernel_size=1)  # Output 8 channels from quant_conv
        self.post_quant_conv = nn.Conv2d(4, 4, kernel_size=1)  # Expect 4 channels here
        
        vgg_model = vgg16(weights='DEFAULT')
        self.loss_fn = VGGPerceptualLoss(vgg_model)
        self.mse_loss_fn = nn.MSELoss()
        
        self.hparams.learning_rate = 0.0001
        self.high_loss_images = []  

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        mean, logvar = torch.chunk(h, 2, dim=1)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mean + std * eps
        return z

    def decode(self, z):
        z = self.post_quant_conv(z)
        x_recon = self.decoder(z)
        return torch.tanh(x_recon)  # Output is [B, 3, H, W]


    def forward(self, x):
        return self.decode(self.encode(x))

    def configure_optimizers(self):
        return torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.parameters()), 
            lr=self.hparams.learning_rate
        )
    
    def training_step(self, batch, batch_idx):
        inputs, targets = batch  # Both are [B, 3, H, W]
        outputs = self(inputs)   # Outputs: [B, 3, H, W]
        
        # Compute MSE loss in LAB space
        mse_loss = self.mse_loss_fn(outputs, targets)
        
        # Convert LAB to RGB for perceptual loss
        inputs_rgb, outputs_rgb, targets_rgb = reconstruct_lab_and_rgb(inputs, outputs, targets)
        perceptual_loss = self.loss_fn(outputs_rgb, targets_rgb)
        
        total_loss = perceptual_loss + mse_loss
        self.log('train_loss', total_loss)
        self.log('perceptual_loss', perceptual_loss)
        self.log('mse_loss', mse_loss)
        
        # Store high-loss images
        if total_loss.item() > 0.7:
            self.high_loss_images.append((total_loss.item(), inputs, targets, outputs))
        
        # Every 100 batches, log the highest-loss image
        if (batch_idx + 1) % 100 == 0 and self.high_loss_images:
            high_loss_image = max(self.high_loss_images, key=lambda x: x[0])
            _, input_img, target_img, output_img = high_loss_image
    
            visualize_model_output(
                input_img, output_img, target_img, self.logger, self.global_step, tag='High_Loss_Image'
            )
            
            # Clear high-loss images list for the next 100 batches
            self.high_loss_images.clear()
        
        # Log images every N batches
        if batch_idx % 1000 == 0:
            num_images = 4
            inputs_cpu = inputs[:num_images].detach().cpu()
            targets_cpu = targets[:num_images].detach().cpu()
            outputs_cpu = outputs[:num_images].detach().cpu()
    
            visualize_model_output(
                inputs_cpu, outputs_cpu, targets_cpu, self.logger, self.global_step, tag='Input_Target_Output'
            )
        
        return total_loss


In [7]:
# chkpt_file = '~/workspace/checkpoints/version_14.ckpt'
# model = Colorizer.load_from_checkpoint(chkpt_file)

In [8]:
model = Colorizer()

In [9]:
# pretrained_model = torch.load('anything-vae.pth', map_location='cpu')
# model = Colorizer()
# pretrained_state_dict = pretrained_model.state_dict()
# missing_keys, unexpected_keys = model.load_state_dict(pretrained_state_dict, strict=False)
# filtered_missing_keys = [key for key in missing_keys if not key.startswith('loss_fn')]
# assert len(filtered_missing_keys) == 0
# assert len(unexpected_keys) == 0

In [None]:
data_folder = 'data/training'
data_csv = 'data.csv'
training_dataset = ColorizationDataset(data_folder, data_csv)
dataloader = DataLoader(training_dataset, batch_size=1, shuffle=True, num_workers=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
logger = loggers.TensorBoardLogger("tb_logs")
trainer = Trainer(accelerator="gpu", devices=1, max_epochs=5, logger=logger, log_every_n_steps=2)

In [12]:
data_folder = 'data/training'
data_csv = 'data.csv'

dataset = ColorizationDataset(data_folder='path_to_data', data_csv='data.csv')

# Fetch a sample
sketch_lab_normalized, colored_lab_normalized = dataset[0]

print('Sketch LAB shape:', sketch_lab_normalized.shape)
print('Colored LAB shape:', colored_lab_normalized.shape)


Sketch LAB shape: torch.Size([3, 512, 512])
Colored LAB shape: torch.Size([3, 512, 512])


In [13]:
trainer.fit(model, dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type              | Params
------------------------------------------------------
0 | encoder         | Encoder           | 34.2 M
1 | decoder         | Decoder           | 49.5 M
2 | quant_conv      | Conv2d            | 72    
3 | post_quant_conv | Conv2d            | 20    
4 | loss_fn         | VGGPerceptualLoss | 138 M 
5 | mse_loss_fn     | MSELoss           | 0     
------------------------------------------------------
220 M     Trainable params
1.7 M     Non-trainable params
222 M     Total params
888.046   Total estimated model params size (MB)
/home/ubuntu/miniconda3/envs/dl-env/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 0:   3%|▎         | 3911/129629 [36:49<19:43:43,  1.77it/s, v_num=40]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 0:  11%|█         | 14233/129629 [2:14:06<18:07:15,  1.77it/s, v_num=40]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 0:  16%|█▌        | 20858/129629 [3:16:34<17:05:05,  1.77it/s, v_num=40]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 0:  34%|███▍      | 44247/129629 [6:57:01<13:24:42,  1.77it/s, v_num=40]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 0:  52%|█████▏    | 67459/129629 [10:35:48<9:45:57,  1.77it/s, v_num=40]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 0:  71%|███████   | 91547/129629 [14:22:56<5:58:58,  1.77it/s, v_num=40]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 0:  93%|█████████▎| 120270/129629 [18:53:49<1:28:13,  1.77it/s, v_num=40]

/home/ubuntu/miniconda3/envs/dl-env/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [14]:
# trainer = Trainer(model, training_dataset, device)
# trainer.train()

In [15]:
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
import torch

def viewTensor(lab_tensor):
    lab_tensor = lab_tensor.unsqueeze(0) if lab_tensor.dim() == 3 else lab_tensor  # Add batch dim if needed
    rgb_tensor = kornia.color.lab_to_rgb(lab_tensor)
    rgb_tensor = torch.clamp(rgb_tensor.squeeze(0), 0.0, 1.0)  # Remove batch dim and clamp to [0, 1]

    # Display image
    image = to_pil_image(rgb_tensor)
    plt.imshow(image)
    plt.axis('off')
    plt.show()


In [16]:
model.eval()
data_folder = 'data/test'
data_csv = 'data.csv'
test_dataset = ColorizationDataset(data_folder, data_csv)
model.cpu()

Colorizer(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
  

In [17]:
x, y = test_dataset[10]

In [18]:
y.shape

torch.Size([3, 512, 512])

In [19]:
idx = 10
x, y = test_dataset[idx]
output = model(x.unsqueeze(0))

KeyboardInterrupt: 

In [None]:
x.shape

In [None]:
viewTensor(x)

In [None]:
viewTensor(output[0])

In [None]:
viewTensor(y)