# Using reference image

1. Use reference image from near by

In [1]:
import os
import torch
from torch import nn
import torchvision
from torchvision import transforms, models as torchvision_models
from torch.utils.data import Dataset, DataLoader
import timm
import pandas as pd
from PIL import Image
from pytorch_lightning import LightningModule, Trainer, loggers, callbacks
from diffusers import StableDiffusionPipeline, AutoencoderKL, DiffusionPipeline

from torchvision.models import vgg16

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image

def viewTensor(output):
    image = to_pil_image(output.squeeze())

    # Display the image
    plt.imshow(image)
    plt.axis('off')  # Turn off axis numbers and ticks
    plt.show()

In [2]:
class ColorizationDataset(Dataset):
    # data
    def __init__(self, data_folder, data_csv, transform=None):
        """
        Args:
            input_dir (string): Directory with all the input images.
            output_dir (string): Directory with all the target (color) images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data_folder = data_folder
        self.data_path = os.path.join(data_folder, data_csv)
        self.images = pd.read_csv(self.data_path)
        self.transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),  # Convert grayscale to RGB by replicating channels
            transforms.ToTensor()  # Convert images to PyTorch tensors
        ])
        self.tranform_output = transforms.Compose([transforms.ToTensor()])
        
    def __len__(self):
        return 1000
        # return len(self.images)

    def __getitem__(self, idx):
        sketch, colored = self.images.iloc[idx]
        sketch_image = self.transform(self.__loadImage(sketch))
        colored_image = self.tranform_output(self.__loadImage(colored))
        return sketch_image, colored_image

    def viewImage(self, idx):
        sketch, colored = self.images.iloc[idx]
        return self.__loadImage(sketch), self.__loadImage(colored)

    def __loadImage(self, image_path):
        return Image.open(os.path.join(self.data_folder, image_path))

class VGGPerceptualLoss(LightningModule):
    def __init__(self, vgg_model):
        super().__init__()
        self.vgg = vgg_model
        self.criterion = nn.MSELoss()
        self.features = list(self.vgg.features[:16])
        self.features = nn.Sequential(*self.features).eval()
        
        for params in self.features.parameters():
            params.requires_grad = False

    def forward(self, x, y):
        return self.criterion(self.features(x),self.features(y))

def color_histogram_loss(output, target, bins=256, min_value=0, max_value=1):
    hist_loss = 0.0
    for channel in range(3):
        output_hist = torch.histc(output[:, channel, :, :], bins=bins, min=min_value, max=max_value)
        target_hist = torch.histc(target[:, channel, :, :], bins=bins, min=min_value, max=max_value)
        output_hist /= output_hist.sum()
        target_hist /= target_hist.sum()
        hist_loss += torch.norm(output_hist - target_hist, p=2)
    return hist_loss / 3

class Colorizer(LightningModule):
    def __init__(self, vae):
        super().__init__()
        self.model = vae
        vgg_model = vgg16(weights=True)
        self.loss_fn = VGGPerceptualLoss(vgg_model)
        self.hparams.learning_rate = 0.00001

    def forward(self, x):
        return self.model(x)
    
    def configure_optimizers(self):
        return torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.hparams.learning_rate)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs).sample
        perceptual_loss = self.loss_fn(outputs, targets)
        histogram_loss = color_histogram_loss(outputs, targets)
        total_loss = perceptual_loss + histogram_loss * 2  # You can also use weights here if needed
        self.log('train_loss', total_loss)
        self.log('perceptual_loss', perceptual_loss)
        self.log('histogram_loss', histogram_loss)
        return total_loss

In [3]:
vae = torch.load('anything-vae.pth', map_location='cpu')

In [42]:
encoder = vae.encoder
decoder = vae.decoder
quant_conv = vae.quant_conv
post_quant_conv = vae.post_quant_conv

In [7]:
data_folder = 'data/test'
data_csv = 'data.csv'
test_dataset = ColorizationDataset(data_folder, data_csv)

In [8]:
x, y = test_dataset[0]

In [43]:
encoded_sketch = post_quant_conv(quant_conv(encoder(x.unsqueeze(0))))

RuntimeError: Given groups=1, weight of size [4, 4, 1, 1], expected input[1, 8, 64, 64] to have 4 channels, but got 8 channels instead

In [None]:
encoded_color = post_quant_conv(quant_conv(encoder(y.unsqueeze(0))))

In [None]:
encoded_color.shape

In [None]:
diff =  encoded_color - encoded_sketch

In [39]:
diff.shape

torch.Size([1, 8, 64, 64])

In [22]:
encoded_color.shape

torch.Size([1, 8, 64, 64])

In [40]:
decoder(diff)

RuntimeError: Given groups=1, weight of size [512, 4, 3, 3], expected input[1, 8, 64, 64] to have 4 channels, but got 8 channels instead

In [45]:
import torch
import torch.nn as nn

# Assuming your VAE model has these attributes, otherwise, you might need to create them
class ModifiedVAE(nn.Module):
    def __init__(self, vae):
        super(ModifiedVAE, self).__init__()
        self.encoder = vae.encoder
        self.decoder = vae.decoder
        self.quant_conv = vae.quant_conv
        self.post_quant_conv = vae.post_quant_conv

    def forward(self, x, y):
        # Encode inputs
        encoded_sketch = self.encoder(x.unsqueeze(0))
        encoded_color = self.encoder(y.unsqueeze(0))
        # Perform quantization
        quant_sketch = self.quant_conv(encoded_sketch)
        quant_color = self.quant_conv(encoded_color)
        # Compute the difference
        diff = quant_color - quant_sketch
        # Pass through post-quantization layer
        adjusted_diff = self.post_quant_conv(diff)
        # Decode the adjusted difference
        return self.decoder(adjusted_diff)

# Load the pre-trained VAE model
vae = torch.load('anything-vae.pth', map_location='cpu')

# Initialize the modified VAE with the adjusted channel layer
modified_vae = ModifiedVAE(vae)

# Example dataset and input
data_folder = 'data/test'
data_csv = 'data.csv'
test_dataset = ColorizationDataset(data_folder, data_csv)
x, y = test_dataset[0]

# Pass through the modified VAE
output = modified_vae(x, y)

print(output.shape)  # Verify the output shape


RuntimeError: Given groups=1, weight of size [4, 4, 1, 1], expected input[1, 8, 64, 64] to have 4 channels, but got 8 channels instead