# Colorize with Unet

In [13]:
import os
import torch
from torch import nn
import torchvision
from torchvision import transforms, models as torchvision_models
from torch.utils.data import Dataset, DataLoader
import timm
import pandas as pd
from PIL import Image
from pytorch_lightning import LightningModule, Trainer, loggers, callbacks
from diffusers import StableDiffusionPipeline, AutoencoderKL, DiffusionPipeline

from diffusers import models

from torchvision.models import vgg16

In [14]:
class ColorizationDataset(Dataset):
    # data
    def __init__(self, data_folder, data_csv, transform=None):
        """
        Args:
            input_dir (string): Directory with all the input images.
            output_dir (string): Directory with all the target (color) images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data_folder = data_folder
        self.data_path = os.path.join(data_folder, data_csv)
        self.images = pd.read_csv(self.data_path)
        self.transform = transforms.Compose([
            # transforms.Grayscale(num_output_channels=3),  # Convert grayscale to RGB by replicating channels
            transforms.ToTensor()  # Convert images to PyTorch tensors
        ])
        self.tranform_output = transforms.Compose([transforms.ToTensor()])
        
    def __len__(self):
        return 1000
        # return len(self.images)

    def __getitem__(self, idx):
        sketch, colored = self.images.iloc[idx]
        sketch_image = self.transform(self.__loadImage(sketch))
        colored_image = self.tranform_output(self.__loadImage(colored))
        return sketch_image, colored_image

    def viewImage(self, idx):
        sketch, colored = self.images.iloc[idx]
        return self.__loadImage(sketch), self.__loadImage(colored)

    def __loadImage(self, image_path):
        return Image.open(os.path.join(self.data_folder, image_path)).convert('RGBA')

class VGGPerceptualLoss(LightningModule):
    def __init__(self, vgg_model):
        super().__init__()
        self.vgg = vgg_model
        self.criterion = nn.MSELoss()
        self.features = list(self.vgg.features[:16])
        self.features = nn.Sequential(*self.features).eval()
        
        for params in self.features.parameters():
            params.requires_grad = False

    def forward(self, x, y):
        return self.criterion(self.features(x),self.features(y))

def color_histogram_loss(output, target, bins=256, min_value=0, max_value=1):
    hist_loss = 0.0
    for channel in range(3):
        output_hist = torch.histc(output[:, channel, :, :], bins=bins, min=min_value, max=max_value)
        target_hist = torch.histc(target[:, channel, :, :], bins=bins, min=min_value, max=max_value)
        output_hist /= output_hist.sum()
        target_hist /= target_hist.sum()
        hist_loss += torch.norm(output_hist - target_hist, p=2)
    return hist_loss / 3

In [15]:

class Colorizer(LightningModule):
    def __init__(self, unet):
        super().__init__()
        self.model = unet
        vgg_model = vgg16(weights=True)
        self.loss_fn = VGGPerceptualLoss(vgg_model)
        self.hparams.learning_rate = 0.00001

    def forward(self, x):
        return self.model(x)
    
    def configure_optimizers(self):
        return torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.hparams.learning_rate)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self(inputs).sample
        perceptual_loss = self.loss_fn(outputs, targets)
        histogram_loss = color_histogram_loss(outputs, targets)
        total_loss = perceptual_loss + histogram_loss * 2  # You can also use weights here if needed
        self.log('train_loss', total_loss)
        self.log('perceptual_loss', perceptual_loss)
        self.log('histogram_loss', histogram_loss)
        return total_loss

In [16]:
pipeline = DiffusionPipeline.from_pretrained("AdamOswald1/Anything-Preservation")

safety_checker/model.safetensors not found


In [17]:
model = Colorizer(pipeline.unet)



In [18]:
unet = pipeline.unet

In [19]:
data_folder = 'data/training'
data_csv = 'data.csv'
training_dataset = ColorizationDataset(data_folder, data_csv)
dataloader = DataLoader(training_dataset, batch_size=1, shuffle=True, num_workers=1)

In [20]:
x, y = training_dataset[0]

In [21]:
x.shape

torch.Size([4, 512, 512])

In [22]:
unet.eval()

UNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0-1): 2 x Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_fe

In [23]:
import ipdb

In [None]:
result = unet(x.unsqueeze(0) , encoder_hidden_states=pipeline.vae.encoder, timestep=1)

--Return--
None
> [0;32m/tmp/ipykernel_1436/2410953089.py[0m(1)[0;36m<module>[0;34m()[0m
[0;32m----> 1 [0;31m[0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      2 [0;31m[0mresult[0m [0;34m=[0m [0munet[0m[0;34m([0m[0mx[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;36m0[0m[0;34m)[0m [0;34m,[0m [0mencoder_hidden_states[0m[0;34m=[0m[0mpipeline[0m[0;34m.[0m[0mvae[0m[0;34m.[0m[0mencoder[0m[0;34m,[0m [0mtimestep[0m[0;34m=[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  
ipdb>  d


*** Newest frame


ipdb>  n


[0;31m    [... skipped 1 hidden frame][0m

> [0;32m/home/ubuntu/miniconda3/envs/dl-env/lib/python3.8/site-packages/IPython/core/interactiveshell.py[0m(3511)[0;36mrun_code[0;34m()[0m
[0;32m   3510 [0;31m                [0;31m# Reset our crash handler in place[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 3511 [0;31m                [0msys[0m[0;34m.[0m[0mexcepthook[0m [0;34m=[0m [0mold_excepthook[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   3512 [0;31m        [0;32mexcept[0m [0mSystemExit[0m [0;32mas[0m [0me[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  u


> [0;32m/home/ubuntu/miniconda3/envs/dl-env/lib/python3.8/site-packages/IPython/core/interactiveshell.py[0m(3448)[0;36mrun_ast_nodes[0;34m()[0m
[0;32m   3447 [0;31m                    [0masy[0m [0;34m=[0m [0mcompare[0m[0;34m([0m[0mcode[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 3448 [0;31m                [0;32mif[0m [0;32mawait[0m [0mself[0m[0;34m.[0m[0mrun_code[0m[0;34m([0m[0mcode[0m[0;34m,[0m [0mresult[0m[0;34m,[0m [0masync_[0m[0;34m=[0m[0masy[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   3449 [0;31m                    [0;32mreturn[0m [0;32mTrue[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  u


*** all frames above hidden, use `skip_hidden False` to get get into those.


ipdb>  n


[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

[0;31m    [... skipped 1 hidden frame][0m

> [0;32m/home/ubuntu/miniconda3/envs/dl-env/lib/python3.8/site-packages/IPython/core/interactiveshell.py[0m(3436)[0;36mrun_ast_nodes[0;34m()[0m
[0;32m   3435 [0;31m[0;34m[0m[0m
[0m[0;32m-> 3436 [0;31m            [0;32mfor[0m [0mnode[0m[0;34m,[0m [0mmode[0m [0;32min[0m [0mto_run[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   3437 [0;31m                [0;32mif[0m [0mmode[0m [0;34m==[0m [0;34m"exec"[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/ubuntu/miniconda3/envs/dl-env/lib/python3.8/site-packages/IPython/core/interactiveshell.py[0m(3437)[0;36mrun_ast_nodes[0;34m()[0m
[0;32m   3436 [0;31m            [0;32mfor[0m [0mnode[0m[0;34m,[0m [0mmode[0m [0;32min[0m [0mto_run[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 3437 [0;31m                [0;32mif[0m [0mmode[0m [0;34m==[0m [0;34m"exec"[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   3438 [0;31m                    [0mmod[0m [0;34m=[0m [0mModule[0m[0;34m([0m[0;34m[[0m[0mnode[0m[0;34m][0m[0;34m,[0m [0;34m[[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/home/ubuntu/miniconda3/envs/dl-env/lib/python3.8/site-packages/IPython/core/interactiveshell.py[0m(3438)[0;36mrun_ast_nodes[0;34m()[0m
[0;32m   3437 [0;31m                [0;32mif[0m [0mmode[0m [0;34m==[0m [0;34m"exec"[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 3438 [0;31m                    [0mmod[0m [0;34m=[0m [0mModule[0m[0;34m([0m[0;34m[[0m[0mnode[0m[0;34m][0m[0;34m,[0m [0;34m[[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   3439 [0;31m                [0;32melif[0m [0mmode[0m [0;34m==[0m [0;34m"single"[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


In [None]:
result

In [33]:
unet = models.unet_2d.UNet2DModel(in_channels= 4, out_channels=4)

In [36]:
unet.eval()

UNet2DModel(
  (conv_in): Conv2d(4, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=224, out_features=896, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=896, out_features=896, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock2D(
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 224, eps=1e-05, affine=True)
          (conv1): Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=896, out_features=224, bias=True)
          (norm2): GroupNorm(32, 224, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): Conv2d(22

In [41]:
x.unsqueeze(0).shape

torch.Size([1, 4, 512, 512])

In [39]:
result = unet(x.unsqueeze(0), timestep=1)

KeyboardInterrupt: 