In [1]:
from anything_vae import (
    Encoder,
    Decoder,
)

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models import vgg16
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from torchvision import transforms, models as torchvision_models
from pytorch_lightning import LightningModule, Trainer, loggers, callbacks
# import pytorch_lightning as pl
from torchmetrics import MeanSquaredError
from PIL import Image

import torch
from torch.utils.data import DataLoader
import torch.optim as optim

from collections import deque
import heapq
from sklearn.cluster import KMeans
import re
import torchvision.transforms as transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ColorizationDataset(Dataset):
    def __init__(self, data_folder, data_csv, transform=None, hint_offset=3):
        """
        Args:
            data_folder (string): Directory with all the images.
            data_csv (string): CSV file with image paths.
            transform (callable, optional): Optional transform to be applied on a sample.
            hint_offset (int): Number of images away to fetch the hint image.
        """
        self.data_folder = data_folder
        self.data_path = os.path.join(data_folder, data_csv)
        self.images = pd.read_csv(self.data_path)

        self.transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),  # Convert grayscale to RGB by replicating channels
            transforms.ToTensor()  # Convert images to PyTorch tensors
        ])

        self.tranform_output = transforms.Compose([transforms.ToTensor()])

        self.hint_offset = hint_offset

        # Extract show names from the file paths
        self.images['show'] = self.images['Sketch Path'].apply(
            lambda x: os.path.basename(os.path.dirname(x))
        )

        # Sort the DataFrame by show to group images from the same show
        self.images = self.images.sort_values(by=['show']).reset_index(drop=True)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        row = self.images.iloc[idx]
        sketch = row['Sketch Path']
        colored = row['Frame Path']
        show = row['show']

        # Get indices of all images from the same show
        show_indices = self.images.index[self.images['show'] == show].tolist()

        # Find the position of the current image within the show's indices
        pos_in_show = show_indices.index(idx)

        # Calculate hint index with offset and ensure it's within bounds
        hint_pos_in_show = pos_in_show + self.hint_offset
        hint_pos_in_show = max(0, min(hint_pos_in_show, len(show_indices) - 1))

        # Get the actual index of the hint image in self.images
        hint_idx = show_indices[hint_pos_in_show]

        hint_row = self.images.iloc[hint_idx]
        hint_sketch = hint_row['Sketch Path']
        hint_colored = hint_row['Frame Path']

        # Load images
        sketch_image = self.transform(self.__loadImage(sketch))
        colored_image = self.tranform_output(self.__loadImage(colored))
        hint_image = self.tranform_output(self.__loadImage(hint_colored))

        return sketch_image, colored_image, hint_image

    def viewImage(self, idx):
        sketch, colored = self.images.iloc[idx][['Sketch Path', 'Frame Path']]
        return self.__loadImage(sketch), self.__loadImage(colored)

    def __loadImage(self, image_path):
        return Image.open(os.path.join(self.data_folder, image_path))


In [3]:
class VGGPerceptualLoss(LightningModule):
    def __init__(self, vgg_model):
        super().__init__()
        self.vgg = vgg_model
        self.criterion = nn.MSELoss()
        self.features = list(self.vgg.features[:16])
        self.features = nn.Sequential(*self.features).eval()
        
        for params in self.features.parameters():
            params.requires_grad = False

    def forward(self, x, y):
        return self.criterion(self.features(x),self.features(y))

In [4]:
class ColorHintEmbedding(nn.Module):
    def __init__(self, n_colors=50, color_dim=3, embed_dim=256):
        super().__init__()
        self.color_embedding = nn.Linear(color_dim, embed_dim)
        self.position_embedding = nn.Embedding(n_colors, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        
    def forward(self, colors, positions):
        # colors: [batch_size, n_colors, 3]
        # positions: [batch_size, n_colors, 2] (x,y coordinates)
        b, n, _ = colors.shape
        color_embed = self.color_embedding(colors)  # [b, n, embed_dim]
        pos_embed = self.position_embedding(positions)  # [b, n, embed_dim]
        return self.norm(color_embed + pos_embed)

class ReferenceImageEncoder(nn.Module):
    def __init__(self, transformer_dim=256):
        super().__init__()
        # Reduce initial channels and add more aggressive pooling
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),  # Stride=2 reduces spatial dim
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # Further reduction
            nn.ReLU(),
            nn.Conv2d(64, transformer_dim, 3, stride=2, padding=1),  # Final reduction
            nn.ReLU()
        )
        self.norm = nn.GroupNorm(8, transformer_dim)
        
        # Add adaptive pooling to ensure consistent output size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((16, 16))  # Fixed output size
        
    def forward(self, ref_image):
        features = self.conv_blocks(ref_image)
        features = self.norm(features)
        features = self.adaptive_pool(features)  # Ensure consistent spatial dims
        return features

class SelfAttention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        
        # Reduce internal dimension
        self.dim_head = 32  # Reduced from dim // heads
        self.hidden_dim = self.dim_head * heads
        
        # Use smaller internal dimensions for QKV projections
        self.to_q = nn.Conv2d(dim, self.hidden_dim, 1, bias=False)
        self.to_k = nn.Conv2d(dim, self.hidden_dim, 1, bias=False)
        self.to_v = nn.Conv2d(dim, self.hidden_dim, 1, bias=False)
        
        self.to_out = nn.Sequential(
            nn.Conv2d(self.hidden_dim, dim, 1),
            nn.Dropout(0.1)  # Add dropout for regularization
        )

    def forward(self, x, context=None):
        b, c, h, w = x.shape
        
        if context is not None:
            # Resize context if needed
            if context.shape[-2:] != (h, w):
                context = F.interpolate(context, size=(h, w), mode='bilinear', align_corners=False)
            
            q = self.to_q(x)
            k = self.to_k(context)
            v = self.to_v(context)
        else:
            q = self.to_q(x)
            k = self.to_k(x)
            v = self.to_v(x)
            
        # Reshape with reduced dimension
        q = q.reshape(b, self.heads, self.dim_head, -1)
        k = k.reshape(b, self.heads, self.dim_head, -1)
        v = v.reshape(b, self.heads, self.dim_head, -1)
        
        # Efficient attention computation
        dots = torch.matmul(q.transpose(-2, -1), k) * self.scale
        
        # Use memory efficient softmax
        dots = dots - dots.max(dim=-1, keepdim=True)[0]  # Numerical stability
        attn = dots.softmax(dim=-1)
        
        # Optional memory optimization: use chunked matrix multiplication
        out = torch.matmul(attn, v.transpose(-2, -1))
        out = out.reshape(b, -1, h, w)
        
        return self.to_out(out)

class TransformerBlock(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.attention = SelfAttention(dim, heads)
        self.color_cross_attn = SelfAttention(dim, heads)
        self.ref_cross_attn = SelfAttention(dim, heads)
        
        self.norm1 = nn.GroupNorm(8, dim)
        self.norm2 = nn.GroupNorm(8, dim)
        self.norm_color = nn.GroupNorm(8, dim)
        self.norm_ref = nn.GroupNorm(8, dim)
        
        self.ffn = nn.Sequential(
            nn.Conv2d(dim, dim * 4, 1),
            nn.GELU(),
            nn.Conv2d(dim * 4, dim, 1)
        )

    def forward(self, x, color_hints=None, ref_features=None):
        # Self-attention
        x = x + self.attention(self.norm1(x))
        
        # Cross-attention with color hints if provided
        if color_hints is not None:
            b, c, h, w = x.shape
            # Reshape color hints to spatial dimension
            color_hints = color_hints.view(b, c, -1).permute(0, 2, 1)
            color_hints = color_hints.view(b, c, h, w)
            x = x + self.color_cross_attn(self.norm_color(x), color_hints)
            
        # Cross-attention with reference image features if provided
        if ref_features is not None:
            x = x + self.ref_cross_attn(self.norm_ref(x), ref_features)
        
        # FFN
        x = x + self.ffn(self.norm2(x))
        return x


In [5]:
class Colorizer(LightningModule):
    def __init__(self, checkpoint_path=None, transformer_dim=256, transformer_heads=8):
        super(Colorizer, self).__init__()
        
        if checkpoint_path is not None:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            
            self.encoder = Encoder()
            self.decoder = Decoder()
            self.quant_conv = nn.Conv2d(8, 8, kernel_size=1)
            self.post_quant_conv = nn.Conv2d(4, 4, kernel_size=1)
            
            self.encoder.load_state_dict(
                {k.replace('encoder.', ''): v for k, v in checkpoint['state_dict'].items() if k.startswith('encoder.')}
            )
            self.decoder.load_state_dict(
                {k.replace('decoder.', ''): v for k, v in checkpoint['state_dict'].items() if k.startswith('decoder.')}
            )
            self.quant_conv.load_state_dict(
                {k.replace('quant_conv.', ''): v for k, v in checkpoint['state_dict'].items() if k.startswith('quant_conv.')}
            )
            self.post_quant_conv.load_state_dict(
                {k.replace('post_quant_conv.', ''): v for k, v in checkpoint['state_dict'].items() if k.startswith('post_quant_conv.')}
            )
            
            vgg_model = vgg16(weights=True)
            self.loss_fn = VGGPerceptualLoss(vgg_model)
            self.mse_loss_fn = nn.MSELoss()
            
            for param in self.encoder.parameters():
                param.requires_grad = False
            for param in self.decoder.parameters():
                param.requires_grad = False
            for param in self.quant_conv.parameters():
                param.requires_grad = False
            for param in self.post_quant_conv.parameters():
                param.requires_grad = False
                
            print("Loaded pretrained weights from checkpoint")
        else:
            self.encoder = Encoder()
            self.decoder = Decoder()
            self.quant_conv = nn.Conv2d(8, 8, kernel_size=1)
            self.post_quant_conv = nn.Conv2d(4, 4, kernel_size=1)
            vgg_model = vgg16(weights=True)
            self.loss_fn = VGGPerceptualLoss(vgg_model)
            self.mse_loss_fn = nn.MSELoss()
            print("Initialized new model from scratch")
        
        # Initialize transformer and hint processing components
        self.to_transformer_dim = nn.Conv2d(4, transformer_dim, 1)
        self.transformer = TransformerBlock(transformer_dim, transformer_heads)
        self.from_transformer_dim = nn.Conv2d(transformer_dim, 4, 1)
        
        # Add color hint and reference image processing
        self.color_hint_processor = ColorHintEmbedding(
            n_colors=50,
            color_dim=3,
            embed_dim=transformer_dim
        )
        self.ref_image_encoder = ReferenceImageEncoder(transformer_dim)
        
        # Training monitoring
        self.num_high_loss_images = 50
        self.high_loss_heap = []
        self.current_min_high_loss = 0
        
        self.hparams.learning_rate = 0.0001

    
    def _freeze_autoencoder(self):
        """Freeze the autoencoder components."""
        components_to_freeze = [
            self.encoder,
            self.decoder,
            self.quant_conv,
            self.post_quant_conv
        ]
        
        for component in components_to_freeze:
            for param in component.parameters():
                param.requires_grad = False
    
    def configure_optimizers(self):
        # Only include trainable parameters
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable_params, lr=self.hparams.learning_rate)
    
    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location=None,
        strict=True,
        **kwargs
    ):
        """
        Custom load_from_checkpoint to handle freezing after loading.
        """
        # Load the checkpoint using parent class method
        model = super().load_from_checkpoint(
            checkpoint_path,
            map_location=map_location,
            strict=strict,
            **kwargs
        )
        
        # Reapply freezing if specified
        if model.hparams.freeze_autoencoder:
            model._freeze_autoencoder()
        
        return model

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        mean, logvar = torch.chunk(h, 2, dim=1)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mean + std * eps
        return z

    def decode(self, z, color_hints=None, ref_image=None):
        z = self.post_quant_conv(z)
        z = self.to_transformer_dim(z)
        
        # Process hints if provided
        color_features = None
        ref_features = None
        
        if color_hints is not None:
            colors, positions = color_hints
            color_features = self.color_hint_processor(colors, positions)
            
        if ref_image is not None:
            ref_features = self.ref_image_encoder(ref_image)

        z = self.transformer(z, color_features, ref_features)
        
        z = self.from_transformer_dim(z)
        x_recon = self.decoder(z)
        return x_recon

    def forward(self, x, color_hints=None, ref_image=None):
        z = self.encode(x)
        x_recon = self.decode(z, color_hints, ref_image)
        return x_recon

    def training_step(self, batch, batch_idx):
        # Unpack batch - now supports hints and reference images
        if len(batch) == 2:
            inputs, targets = batch
            color_hints = None
            ref_image = None
        elif len(batch) == 3:
            inputs, targets, ref_image = batch
            color_hints = None
        else:
            inputs, targets, color_hints, ref_image = batch
        
        outputs = self(inputs, color_hints, ref_image)
        
        perceptual_loss = self.loss_fn(outputs, targets)
        mse_loss = self.mse_loss_fn(outputs, targets)
        total_loss = perceptual_loss + mse_loss
        
        # Store high loss images with reference image
        self.store_high_loss_image(total_loss, inputs, targets, outputs, ref_image)
        
        # Logging
        self.log('train_loss', total_loss)
        self.log('perceptual_loss', perceptual_loss)
        self.log('mse_loss', mse_loss)
        
        # Visualization logic
        if (batch_idx + 1) % 100 == 0:
            self.visualize_high_loss_images(self.logger, self.global_step)
        
        if batch_idx % 500 == 0:
            num_images = min(4, inputs.shape[0])
            for i in range(num_images):
                grid = self.visualize_single_output(
                    inputs[i],
                    outputs[i],
                    targets[i],
                    ref_image[i] if ref_image is not None else None
                )
                self.logger.experiment.add_image(
                    f'Sample_Images/sample_{i+1}',
                    grid,
                    self.global_step
                )
        
        return total_loss

    def visualize_high_loss_images(self, logger, step):
        """Visualize stored high loss images with reference images if available"""
        if not self.high_loss_heap:
            return
            
        # Sort by loss in descending order
        sorted_entries = sorted(self.high_loss_heap, key=lambda x: x[0], reverse=True)
        
        # Log each high-loss image separately
        for idx, (loss_value, data) in enumerate(sorted_entries):
            grid = self.visualize_single_output(
                data['inputs'],
                data['outputs'],
                data['targets'],
                data['ref_image']
            )
            
            logger.experiment.add_image(
                f'High_Loss_Images/image',
                grid,
                step
            )


    def store_high_loss_image(self, loss, inputs, targets, outputs, ref_image=None):
        """Store high loss images with reference image if available"""
        # Convert to CPU and detach from computation graph
        cpu_data = {
            'loss': loss.item(),
            'inputs': inputs.detach().cpu(),
            'targets': targets.detach().cpu(),
            'outputs': outputs.detach().cpu(),
            'ref_image': ref_image.detach().cpu() if ref_image is not None else None
        }
        
        if len(self.high_loss_heap) < self.num_high_loss_images:
            heapq.heappush(self.high_loss_heap, (loss.item(), cpu_data))
            self.current_min_high_loss = min(loss.item(), self.current_min_high_loss if self.high_loss_heap else float('inf'))
        elif loss.item() > self.current_min_high_loss:
            heapq.heapreplace(self.high_loss_heap, (loss.item(), cpu_data))
            self.current_min_high_loss = self.high_loss_heap[0][0]


    def visualize_single_output(self, input_img, output_img, target_img, ref_image=None):
        """Helper function to create a grid with reference image if available"""
        # Ensure we're working with batched images
        if input_img.dim() == 3:
            input_img = input_img.unsqueeze(0)
            output_img = output_img.unsqueeze(0)
            target_img = target_img.unsqueeze(0)
            if ref_image is not None:
                ref_image = ref_image.unsqueeze(0)
        
        # Create row with input, output, target, and reference image if available
        images = [input_img, output_img, target_img]
        if ref_image is not None:
            images.append(ref_image)
            
        # Concatenate all images
        row = torch.cat(images, dim=0)
        
        # Handle grayscale images
        if row.shape[1] == 1:
            row = row.repeat(1, 3, 1, 1)
            
        # Create grid with all images side by side
        nrow = 4 if ref_image is not None else 3
        grid = torchvision.utils.make_grid(row, nrow=nrow, normalize=True, padding=2)
        return grid



In [6]:
chkpt_file = 'checkpoints/version_16.ckpt'
model = Colorizer(chkpt_file)



Loaded pretrained weights from checkpoint


In [7]:
# data_folder = 'data/toy'
data_folder = 'data/training'
data_csv = 'data.csv'
training_dataset = ColorizationDataset(data_folder, data_csv)
dataloader = DataLoader(training_dataset, batch_size=1, shuffle=True, num_workers=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
logger = loggers.TensorBoardLogger("tb_logs", name='image-hint-frozen-vae-transformer')
trainer = Trainer(accelerator="gpu", devices=1, max_epochs=20, logger=logger, log_every_n_steps=2)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                 | Type                  | Params
----------------------------------------------------------------
0  | encoder              | Encoder               | 34.2 M
1  | decoder              | Decoder               | 49.5 M
2  | quant_conv           | Conv2d                | 72    
3  | post_quant_conv      | Conv2d                | 20    
4  | loss_fn              | VGGPerceptualLoss     | 138 M 
5  | mse_loss_fn          | MSELoss               | 0     
6  | to_transformer_dim   | Conv2d                | 1.3 K 
7  | transformer          | TransformerBlock      | 1.3 M 
8  | from_transformer_dim | Conv2d                | 1.0 K 
9  | color_hint_processor | ColorHintEmbedding    | 14.3 K
10 | ref_image_encoder    | ReferenceImageEncoder | 167 K 
----------------------------------------------------------------
138 M     Trainable params
85.4 M    Non-trainable params
223 M     Total params
894.042   Total estimated model params

Epoch 0:  64%|███████████████████████▌             | 82499/129629 [16:06:09<9:11:57,  1.42it/s, v_num=1]

In [None]:
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image

def viewTensor(output):
    image = to_pil_image(output.squeeze())

    # Display the image
    plt.imshow(image)
    plt.axis('off')  # Turn off axis numbers and ticks
    plt.show()

In [None]:
model.eval()
data_folder = 'data/test'
data_csv = 'data.csv'
test_dataset = ColorizationDataset(data_folder, data_csv)
model.cpu()

In [None]:
idx = 10
x, y = test_dataset[idx]
output = model(x.unsqueeze(0))

In [None]:
viewTensor(x)

In [None]:
viewTensor(output[0])

In [None]:
viewTensor(y)