In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import vgg16
try:
    from torchvision.models import VGG16_Weights
    vgg_weights = VGG16_Weights.IMAGENET1K_V1
except ImportError:
    vgg_weights = None

import numpy as np
from PIL import Image
import os
from skimage.color import rgb2lab, lab2rgb
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import lpips
from colormath.color_objects import LabColor, sRGBColor
from colormath.color_conversions import convert_color
from colormath.color_diff import delta_e_cie2000
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import warnings

# Suppress every RuntimeWarning for the rest of the session.
warnings.filterwarnings('ignore', category=RuntimeWarning)

# Presumably a compatibility shim for colormath, which still calls
# numpy.asscalar after newer NumPy releases dropped it: re-create it on top of
# ndarray.item() only when the running NumPy lacks it.
if not hasattr(np, 'asscalar'):
    def _np_asscalar(array):
        return array.item()

    np.asscalar = _np_asscalar

In [None]:
# Configuration
MODEL_ID = 'C_lambda1_0.5x'
USE_GAN = True
USE_L1 = True
USE_PERCEPTUAL = True
LAMBDA_L1 = 50.0  # 0.5× baseline
LAMBDA_PERCEPTUAL = 10.0

BATCH_SIZE = 64
NUM_EPOCHS = 50
LEARNING_RATE = 2e-4
IMAGE_SIZE = 256
NUM_WORKERS = 4
SEED = 42  # global seed for torch and NumPy

DATA_DIR = '../data/colorize_dataset/data'
MODEL_DIR = f'../models/{MODEL_ID}'
RESULTS_DIR = '../results'

os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Seed torch (all devices) and NumPy's global RNG so weight init, shuffling and
# augmentation should be repeatable; some GPU kernels may still be non-deterministic.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer CUDA, then Apple MPS, then CPU. getattr guards torch builds that have no
# `torch.backends.mps` module at all.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')
print(f'Lambda L1: {LAMBDA_L1} (0.5× baseline)')

In [None]:
class ColorizeDataset(Dataset):
    """Images from a flat directory, returned as normalized ``(L, ab)`` tensor pairs.

    Each item is ``(L, ab)``: ``L`` has shape (1, size, size) scaled as
    ``L / 50 - 1`` and ``ab`` has shape (2, size, size) scaled as ``ab / 110``.

    Args:
        color_dir: directory holding .jpg/.jpeg/.png files (extension match is
            case-insensitive; sub-directories are not searched).
        size: output height and width in pixels.
        split: ``'train'`` enables random augmentation; any other value uses a
            deterministic resize only.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, color_dir, size=256, split='train'):
        self.color_dir = color_dir
        self.size = size
        self.split = split
        # Sorted so the index -> file mapping is the same on every run/platform.
        self.images = self._list_images(color_dir)

        if split == 'train':
            # NOTE(review): ColorJitter's saturation/hue jitter also changes the ab
            # target the model learns to reproduce; kept to preserve the experiment.
            self.transform = transforms.Compose([
                transforms.Resize((size, size)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
                transforms.ToTensor()
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((size, size)),
                transforms.ToTensor()
            ])

    @classmethod
    def _list_images(cls, color_dir):
        """Sorted names of regular files in ``color_dir`` with an image extension."""
        return sorted(
            name for name in os.listdir(color_dir)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(color_dir, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.color_dir, self.images[idx])
        # Close the file handle promptly; convert() returns an independent loaded image.
        with Image.open(img_path) as pil_img:
            img = pil_img.convert('RGB')
        img = self.transform(img)

        # (3, H, W) tensor -> (H, W, 3) array for skimage.
        img_np = img.permute(1, 2, 0).numpy()
        lab = rgb2lab(img_np)

        L = lab[:, :, 0] / 50.0 - 1.0
        ab = lab[:, :, 1:] / 110.0

        L = torch.from_numpy(L).float().unsqueeze(0)
        ab = torch.from_numpy(ab).float().permute(2, 0, 1)

        return L, ab


In [None]:
class UNetBlock(nn.Module):
    """One U-Net stage: 4x4 stride-2 conv -> BatchNorm? -> Dropout(0.5)? -> activation.

    ``down=True`` uses ``Conv2d`` (halves even H and W) followed by LeakyReLU(0.2);
    ``down=False`` uses ``ConvTranspose2d`` (doubles H and W) followed by ReLU.
    Disabled BatchNorm/Dropout slots are filled with ``nn.Identity``, which keeps
    the Sequential indices (and therefore state_dict keys) fixed across configurations.
    """

    def __init__(self, in_channels, out_channels, down=True, use_dropout=False, use_batchnorm=True):
        super().__init__()
        conv_type = nn.Conv2d if down else nn.ConvTranspose2d
        stages = [conv_type(in_channels, out_channels, 4, 2, 1, bias=False)]
        stages.append(nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity())
        stages.append(nn.Dropout(0.5) if use_dropout else nn.Identity())
        stages.append(nn.LeakyReLU(0.2) if down else nn.ReLU())
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stage to an (N, in_channels, H, W) tensor."""
        return self.conv(x)

class UNetGenerator(nn.Module):
    """U-Net mapping ``in_channels`` (L) to ``out_channels`` (ab) with skip connections.

    Eight stride-2 encoder blocks (64 -> 512 channels) are mirrored by seven
    decoder blocks plus a final transposed conv. Every decoder input after the
    first is the previous decoder output concatenated with the encoder feature
    map of the same resolution. Dropout is active in the first three decoder
    blocks. The output goes through Tanh, so values lie in (-1, 1).

    H and W should be multiples of 256 (eight halvings reach 1x1 at 256).
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        def encoder(c_in, c_out, batchnorm=True):
            return UNetBlock(c_in, c_out, down=True, use_batchnorm=batchnorm)

        def decoder(c_in, c_out, dropout=False):
            return UNetBlock(c_in, c_out, down=False, use_dropout=dropout)

        # Encoder: the outermost and innermost blocks skip BatchNorm.
        self.down1 = encoder(in_channels, 64, batchnorm=False)
        self.down2 = encoder(64, 128)
        self.down3 = encoder(128, 256)
        self.down4 = encoder(256, 512)
        self.down5 = encoder(512, 512)
        self.down6 = encoder(512, 512)
        self.down7 = encoder(512, 512)
        self.down8 = encoder(512, 512, batchnorm=False)

        # Decoder: input widths after up1 are doubled by the skip concatenation.
        self.up1 = decoder(512, 512, dropout=True)
        self.up2 = decoder(1024, 512, dropout=True)
        self.up3 = decoder(1024, 512, dropout=True)
        self.up4 = decoder(1024, 512)
        self.up5 = decoder(1024, 256)
        self.up6 = decoder(512, 128)
        self.up7 = decoder(256, 64)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoder_blocks = (self.down1, self.down2, self.down3, self.down4,
                          self.down5, self.down6, self.down7, self.down8)
        decoder_blocks = (self.up2, self.up3, self.up4, self.up5, self.up6, self.up7)

        skips = []
        for block in encoder_blocks:
            x = block(x)
            skips.append(x)

        out = self.up1(skips.pop())  # bottleneck: down8 output
        for block in decoder_blocks:
            # Pops down7, down6, ..., down2 outputs in turn.
            out = block(torch.cat([out, skips.pop()], 1))
        return self.final(torch.cat([out, skips.pop()], 1))  # down1 output

In [None]:
class PatchGANDiscriminator(nn.Module):
    """Conditional PatchGAN scoring (L, ab) pairs with a grid of raw real/fake scores.

    Four 4x4 stride-2 conv blocks (64 -> 512 channels; BatchNorm on all but the
    first; LeakyReLU(0.2) after each) are followed by an asymmetric zero-pad and
    a 4x4 conv down to one channel. A 256x256 input yields an (N, 1, 16, 16) map.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        widths = (in_channels, 64, 128, 256, 512)
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if stage > 0:  # the first block is left un-normalized
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.extend([
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, img_L, img_ab):
        """Score the channel-wise concatenation of ``img_L`` and ``img_ab``."""
        return self.model(torch.cat((img_L, img_ab), 1))

In [None]:
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG16 feature maps of two 3-channel images.

    Uses the first 16 layers of torchvision's pretrained VGG16 ``features``
    (in torchvision's layout this should end at relu3_3), frozen and in eval mode.
    Inputs are passed to VGG as-is; no ImageNet normalization is applied.
    """

    def __init__(self):
        super(PerceptualLoss, self).__init__()
        if vgg_weights is not None:
            vgg = vgg16(weights=vgg_weights).features
        else:
            # Older torchvision without the VGG16_Weights enum.
            vgg = vgg16(pretrained=True).features
        self.features = nn.Sequential(*list(vgg)[:16]).eval()
        for param in self.features.parameters():
            param.requires_grad = False
        self.criterion = nn.L1Loss()

    def forward(self, pred, target):
        """Return the mean absolute difference of VGG features of ``pred`` and ``target``.

        Raises:
            ValueError: if either input is not a 4-D (N, 3, H, W) tensor. VGG's
                first conv needs 3 channels; checking here gives a clear message
                instead of a shape error deep inside the network.
        """
        for name, tensor in (('pred', pred), ('target', target)):
            if tensor.dim() != 4 or tensor.size(1) != 3:
                raise ValueError(
                    f'PerceptualLoss expects {name} of shape (N, 3, H, W), got {tuple(tensor.shape)}'
                )
        pred_features = self.features(pred)
        target_features = self.features(target)
        return self.criterion(pred_features, target_features)

class GANLoss(nn.Module):
    """Adversarial loss of a prediction map against an all-real or all-fake target.

    Args:
        use_lsgan: ``True`` uses mean-squared error (LSGAN); ``False`` uses
            ``BCEWithLogitsLoss``, so predictions are treated as logits.

    Called as ``loss(prediction, target_is_real)``: the target is a tensor of
    ones when ``target_is_real`` is truthy and zeros otherwise, shaped like
    ``prediction``.
    """

    def __init__(self, use_lsgan=True):
        super(GANLoss, self).__init__()
        self.loss = nn.MSELoss() if use_lsgan else nn.BCEWithLogitsLoss()

    # Implemented as forward() (not __call__) so nn.Module's call machinery and
    # hooks keep working; loss(prediction, flag) still dispatches here.
    def forward(self, prediction, target_is_real):
        target = torch.ones_like(prediction) if target_is_real else torch.zeros_like(prediction)
        return self.loss(prediction, target)

In [None]:
# Instantiate networks, losses, optimizers and data loaders, then report sizes.
generator = UNetGenerator().to(device)
discriminator = PatchGANDiscriminator().to(device)

criterion_L1 = nn.L1Loss().to(device)
perceptual_loss = PerceptualLoss().to(device)
gan_loss = GANLoss().to(device)

# Both networks share the same Adam hyperparameters.
adam_config = dict(lr=LEARNING_RATE, betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), **adam_config)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_config)

# Split directories follow the '<split>_color' naming convention.
train_dataset, val_dataset = (
    ColorizeDataset(os.path.join(DATA_DIR, f'{split}_color'), IMAGE_SIZE, split=split)
    for split in ('train', 'val')
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')
for net_label, net in (('Generator', generator), ('Discriminator', discriminator)):
    n_params = sum(p.numel() for p in net.parameters())
    print(f'{net_label} parameters: {n_params:,}')


In [None]:
def validate_epoch(generator, val_loader, criterion_L1, perceptual_loss, device,
                   lambda_l1=None, lambda_perceptual=None):
    """Average weighted validation loss (L1 + perceptual) of ``generator``.

    The adversarial term is not included, so this is not directly comparable
    to the training generator loss.

    Args:
        generator: model called as ``generator(L)`` returning predicted ab.
        val_loader: iterable of ``(L, ab)`` batches.
        criterion_L1: loss callable applied to ``(fake_ab, ab)``.
        perceptual_loss: loss callable applied to 3-channel ``(L, a, b)`` tensors.
        device: device the batches are moved to.
        lambda_l1, lambda_perceptual: loss weights; ``None`` falls back to the
            notebook's ``LAMBDA_L1`` / ``LAMBDA_PERCEPTUAL``.

    Returns:
        Mean of the per-batch losses.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    if lambda_l1 is None:
        lambda_l1 = LAMBDA_L1
    if lambda_perceptual is None:
        lambda_perceptual = LAMBDA_PERCEPTUAL

    generator.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for L, ab in val_loader:
            L, ab = L.to(device), ab.to(device)
            fake_ab = generator(L)

            loss_l1 = criterion_L1(fake_ab, ab)

            # The perceptual network needs 3 input channels: compare the
            # normalized Lab images (L, a, b) of prediction and ground truth.
            fake_lab = torch.cat([L, fake_ab], dim=1)
            real_lab = torch.cat([L, ab], dim=1)
            loss_perceptual = perceptual_loss(fake_lab, real_lab)

            loss = loss_l1 * lambda_l1 + loss_perceptual * lambda_perceptual
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError('val_loader yielded no batches')
    return total_loss / num_batches

def visualize_validation_samples(generator, val_loader, device, num_samples=8, save_path=None):
    """Plot grayscale input, colorized prediction and ground truth for one batch.

    Args:
        generator: model called as ``generator(L)`` returning normalized ab.
        val_loader: iterable of ``(L, ab)`` batches; only the first batch is used.
        device: device the generator runs on.
        num_samples: maximum number of columns; capped at the first batch's size.
        save_path: optional file path; when given, the figure is saved there
            before it is shown.

    Raises:
        ValueError: if ``val_loader`` yields no batches.

    Leaves ``generator`` in eval mode.
    """
    generator.eval()
    first_batch = next(iter(val_loader), None)
    if first_batch is None:
        raise ValueError('val_loader yielded no batches')

    with torch.no_grad():
        L, ab = first_batch
        num_samples = min(num_samples, L.size(0))
        L, ab = L[:num_samples].to(device), ab[:num_samples].to(device)
        fake_ab = generator(L)

        def lab_to_rgb(L_tensor, ab_tensor):
            """Undo the dataset scaling (L/50 - 1, ab/110) and convert each image to RGB."""
            lab = torch.cat([(L_tensor + 1.0) * 50.0, ab_tensor * 110.0], dim=1)
            lab = lab.permute(0, 2, 3, 1).cpu().numpy()
            return np.array([lab2rgb(lab_img) for lab_img in lab])

        pred_rgb = lab_to_rgb(L, fake_ab)
        target_rgb = lab_to_rgb(L, ab)
        grayscale = L.cpu().numpy()

    # squeeze=False keeps `axes` 2-D even when only one column is drawn.
    fig, axes = plt.subplots(3, num_samples, figsize=(20, 8), squeeze=False)
    rows = (
        ('Grayscale', grayscale[:, 0], {'cmap': 'gray'}),
        ('Colorized', np.clip(pred_rgb, 0, 1), {}),
        ('Ground Truth', np.clip(target_rgb, 0, 1), {}),
    )
    for row, (title, images, imshow_kwargs) in enumerate(rows):
        for col, image in enumerate(images):
            ax = axes[row, col]
            ax.imshow(image, **imshow_kwargs)
            ax.axis('off')
            if col == 0:
                ax.set_title(title, fontsize=10)

    plt.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)


In [None]:
# Training loop: one discriminator update and one generator update per batch.
# USE_GAN / USE_L1 / USE_PERCEPTUAL (configuration cell) choose which terms enter
# the generator objective. L1 is always logged; disabled GAN/perceptual terms
# are logged as 0.
if not (USE_GAN or USE_L1 or USE_PERCEPTUAL):
    raise ValueError('Enable at least one of USE_GAN, USE_L1, USE_PERCEPTUAL.')
if len(train_loader) == 0:
    raise ValueError('train_loader is empty; check the training data directory.')

history = {
    'epoch': [], 'train_loss': [], 'val_loss': [], 'gen_loss': [], 'disc_loss': [],
    # Unweighted per-epoch means of the individual generator loss terms.
    'loss_l1': [], 'loss_perceptual': [], 'loss_gan': [],
}
best_val_loss = float('inf')
zero_loss = torch.zeros((), device=device)


def save_checkpoint(path, epoch_num, loss_value):
    """Write both networks' weights and both optimizer states to ``path``."""
    torch.save({
        'epoch': epoch_num,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_G_state_dict': optimizer_G.state_dict(),
        'optimizer_D_state_dict': optimizer_D.state_dict(),
        'loss': loss_value,
    }, path)


for epoch in range(NUM_EPOCHS):
    generator.train()
    discriminator.train()
    epoch_g_loss = 0.0
    epoch_l1_loss = 0.0
    epoch_perceptual_loss = 0.0
    epoch_gan_loss = 0.0
    epoch_d_loss = 0.0

    for L, ab in tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS}', leave=False):
        L, ab = L.to(device), ab.to(device)

        # A single generator forward pass feeds both updates: D trains on a
        # detached copy, G backpropagates through the same prediction.
        fake_ab = generator(L)

        # ---- Discriminator update ----
        if USE_GAN:
            optimizer_D.zero_grad()
            d_loss_fake = gan_loss(discriminator(L, fake_ab.detach()), False)
            d_loss_real = gan_loss(discriminator(L, ab), True)
            d_loss = (d_loss_fake + d_loss_real) * 0.5
            d_loss.backward()
            optimizer_D.step()
        else:
            d_loss = zero_loss

        # ---- Generator update ----
        optimizer_G.zero_grad()
        loss_l1 = criterion_L1(fake_ab, ab)
        loss_gan_g = gan_loss(discriminator(L, fake_ab), True) if USE_GAN else zero_loss
        if USE_PERCEPTUAL:
            # PerceptualLoss runs VGG16, which needs 3 input channels: compare the
            # normalized Lab images (L, a, b) of prediction and ground truth.
            loss_perceptual = perceptual_loss(torch.cat([L, fake_ab], dim=1),
                                              torch.cat([L, ab], dim=1))
        else:
            loss_perceptual = zero_loss

        g_terms = []
        if USE_GAN:
            g_terms.append(loss_gan_g)
        if USE_L1:
            g_terms.append(LAMBDA_L1 * loss_l1)
        if USE_PERCEPTUAL:
            g_terms.append(LAMBDA_PERCEPTUAL * loss_perceptual)
        g_loss = sum(g_terms)

        g_loss.backward()
        optimizer_G.step()

        epoch_g_loss += g_loss.item()
        epoch_l1_loss += loss_l1.item()
        epoch_perceptual_loss += loss_perceptual.item()
        epoch_gan_loss += loss_gan_g.item()
        epoch_d_loss += d_loss.item()

    num_batches = len(train_loader)
    avg_train_loss = epoch_g_loss / num_batches
    avg_d_loss = epoch_d_loss / num_batches

    val_loss = validate_epoch(generator, val_loader, criterion_L1, perceptual_loss, device)

    history['epoch'].append(epoch + 1)
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(val_loss)
    history['gen_loss'].append(avg_train_loss)
    history['disc_loss'].append(avg_d_loss)
    history['loss_l1'].append(epoch_l1_loss / num_batches)
    history['loss_perceptual'].append(epoch_perceptual_loss / num_batches)
    history['loss_gan'].append(epoch_gan_loss / num_batches)

    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}] Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, D_Loss: {avg_d_loss:.4f}')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(os.path.join(MODEL_DIR, 'best_model.pt'), epoch + 1, val_loss)
        print(f'Saved best model at epoch {epoch+1} with val_loss: {val_loss:.4f}')

    if (epoch + 1) % 10 == 0:
        save_checkpoint(os.path.join(MODEL_DIR, f'checkpoint_epoch_{epoch+1}.pt'), epoch + 1, avg_train_loss)
        visualize_validation_samples(generator, val_loader, device)

    history_df = pd.DataFrame(history)
    history_df.to_csv(os.path.join(RESULTS_DIR, f'training_history_{MODEL_ID}.csv'), index=False)

print('Training completed!')


In [None]:
class ColorimetricEvaluator:
    """Compute PSNR, SSIM, CIEDE2000 and LPIPS for Lab-space colorization output.

    The ``calculate_*`` methods take RGB batches as NumPy arrays of shape
    (N, H, W, 3) with values expected in [0, 1] (values outside are clipped),
    for example the output of :meth:`lab_to_rgb`. Each returns the mean score
    over the N images. :meth:`evaluate` runs a generator over a dataloader and
    scores it batch by batch, so the full dataset is never held in memory at once.

    Args:
        device: torch device for the LPIPS network and the evaluated batches.
    """

    def __init__(self, device):
        self.device = device
        # LPIPS with the AlexNet backbone, used for inference only.
        self.lpips_fn = lpips.LPIPS(net='alex').to(device).eval()

    def lab_to_rgb(self, L, ab):
        """Convert normalized Lab tensors to an (N, H, W, 3) RGB NumPy array.

        Undoes the dataset scaling (``L / 50 - 1``, ``ab / 110``) and applies
        skimage's ``lab2rgb`` to each image.
        """
        L = (L + 1.0) * 50.0
        ab = ab * 110.0
        Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
        return np.array([lab2rgb(lab_img) for lab_img in Lab])

    # ---- per-image scores -------------------------------------------------

    @staticmethod
    def _psnr_values(pred, target):
        """Per-image PSNR (data_range 1.0) after clipping to [0, 1]."""
        pred = np.clip(pred, 0, 1)
        target = np.clip(target, 0, 1)
        return [psnr(t, p, data_range=1.0) for p, t in zip(pred, target)]

    @staticmethod
    def _ssim_values(pred, target):
        """Per-image SSIM over the colour axis after clipping to [0, 1]."""
        pred = np.clip(pred, 0, 1)
        target = np.clip(target, 0, 1)
        return [ssim(t, p, data_range=1.0, channel_axis=2) for p, t in zip(pred, target)]

    @staticmethod
    def _ciede2000_values(pred, target, per_pixel=False):
        """Per-image CIEDE2000 colour difference.

        With ``per_pixel=False`` (default) each value is the ΔE2000 between the
        *mean* Lab colours of the two images, so spatial errors that average out
        are not penalized. With ``per_pixel=True`` each value is the mean of the
        per-pixel ΔE2000 map instead.
        """
        # Both batches are quantized to 8 bits before the Lab conversion.
        pred = np.clip(pred * 255, 0, 255).astype(np.uint8)
        target = np.clip(target * 255, 0, 255).astype(np.uint8)
        if per_pixel:
            from skimage.color import deltaE_ciede2000

        values = []
        for p, t in zip(pred, target):
            p_lab = rgb2lab(p / 255.0)
            t_lab = rgb2lab(t / 255.0)
            if per_pixel:
                values.append(float(np.mean(deltaE_ciede2000(t_lab, p_lab))))
            else:
                p_mean = np.mean(p_lab, axis=(0, 1))
                t_mean = np.mean(t_lab, axis=(0, 1))
                p_color = LabColor(p_mean[0], p_mean[1], p_mean[2])
                t_color = LabColor(t_mean[0], t_mean[1], t_mean[2])
                values.append(delta_e_cie2000(p_color, t_color))
        return values

    def _lpips_values(self, pred, target, batch_size=32):
        """Per-image LPIPS, evaluated ``batch_size`` images at a time."""
        values = []
        for start in range(0, len(pred), batch_size):
            stop = start + batch_size
            # Clip to [0, 1], go NHWC -> NCHW, then rescale to [-1, 1].
            pred_tensor = torch.from_numpy(np.clip(pred[start:stop], 0, 1)).permute(0, 3, 1, 2)
            target_tensor = torch.from_numpy(np.clip(target[start:stop], 0, 1)).permute(0, 3, 1, 2)
            pred_tensor = pred_tensor.float().to(self.device) * 2 - 1
            target_tensor = target_tensor.float().to(self.device) * 2 - 1
            with torch.no_grad():
                scores = self.lpips_fn(pred_tensor, target_tensor)
            values.extend(scores.flatten().tolist())
        return values

    # ---- batch means ------------------------------------------------------

    def calculate_psnr(self, pred, target):
        """Mean PSNR over the batch."""
        return np.mean(self._psnr_values(pred, target))

    def calculate_ssim(self, pred, target):
        """Mean SSIM over the batch."""
        return np.mean(self._ssim_values(pred, target))

    def calculate_ciede2000(self, pred, target, per_pixel=False):
        """Mean CIEDE2000 over the batch (see ``_ciede2000_values`` for ``per_pixel``)."""
        return np.mean(self._ciede2000_values(pred, target, per_pixel))

    def calculate_lpips(self, pred, target, batch_size=32):
        """Mean LPIPS over the batch, computed in chunks of ``batch_size`` images."""
        return np.mean(self._lpips_values(pred, target, batch_size))

    def evaluate(self, generator, dataloader, per_pixel_ciede=False):
        """Score ``generator`` on every ``(L, ab)`` batch of ``dataloader``.

        Returns:
            dict with keys 'PSNR', 'SSIM', 'CIEDE2000', 'LPIPS': each is the mean
            of the per-image scores over the whole dataloader.

        Raises:
            ValueError: if the dataloader yields no images.
        """
        generator.eval()
        scores = {'PSNR': [], 'SSIM': [], 'CIEDE2000': [], 'LPIPS': []}

        with torch.no_grad():
            for L, ab in dataloader:
                L, ab = L.to(self.device), ab.to(self.device)
                fake_ab = generator(L)

                pred_rgb = self.lab_to_rgb(L, fake_ab)
                target_rgb = self.lab_to_rgb(L, ab)

                scores['PSNR'].extend(self._psnr_values(pred_rgb, target_rgb))
                scores['SSIM'].extend(self._ssim_values(pred_rgb, target_rgb))
                scores['CIEDE2000'].extend(
                    self._ciede2000_values(pred_rgb, target_rgb, per_pixel_ciede))
                scores['LPIPS'].extend(self._lpips_values(pred_rgb, target_rgb))

        if not scores['PSNR']:
            raise ValueError('dataloader yielded no images to evaluate')
        return {name: np.mean(values) for name, values in scores.items()}

In [None]:
# Train vs. validation loss per epoch. Train loss includes the GAN term while
# validation loss does not, so compare trends rather than absolute levels.
fig, ax = plt.subplots(figsize=(10, 6))
for key, label, marker in (('train_loss', 'Train Loss', 'o'), ('val_loss', 'Val Loss', 's')):
    ax.plot(history['epoch'], history[key], label=label, marker=marker)
ax.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss')
ax.legend()
ax.grid(True)
plt.tight_layout()
loss_plot_path = os.path.join(RESULTS_DIR, f'train_val_loss_{MODEL_ID}.png')
plt.savefig(loss_plot_path, dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Evaluate the best checkpoint (lowest validation loss) on the test split.
# split='test' selects the deterministic resize-only transform; the dataset's
# default split is 'train', which would randomly flip/rotate/jitter test images.
test_dataset = ColorizeDataset(os.path.join(DATA_DIR, 'test_color'), IMAGE_SIZE, split='test')
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

evaluator = ColorimetricEvaluator(device)

checkpoint = torch.load(os.path.join(MODEL_DIR, 'best_model.pt'), map_location=device)
generator.load_state_dict(checkpoint['generator_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']}")

print('Evaluating model...')
metrics = evaluator.evaluate(generator, test_loader)

print('\nEvaluation Results:')
for metric, value in metrics.items():
    print(f'{metric}: {value:.4f}')

In [None]:
# Persist the test metrics as a one-row CSV tagged with the model id.
metrics_df = pd.DataFrame([{**metrics, 'model': MODEL_ID}])
metrics_path = os.path.join(RESULTS_DIR, f'metrics_{MODEL_ID}.csv')
metrics_df.to_csv(metrics_path, index=False)
print(f'Metrics saved to {RESULTS_DIR}/metrics_{MODEL_ID}.csv')

In [None]:
# Six-panel training summary. Total generator / discriminator losses come from
# the 'gen_loss' / 'disc_loss' history keys. The component panels read
# 'loss_gan', 'loss_l1' and 'loss_perceptual'; a panel whose keys are missing
# from `history` says so instead of raising KeyError.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

panel_specs = [
    # (axes position, title, [(history key, legend label or None), ...])
    ((0, 0), 'Generator Loss', [('gen_loss', None)]),
    ((0, 1), 'Discriminator Loss', [('disc_loss', None)]),
    ((0, 2), 'GAN Loss', [('loss_gan', 'GAN Loss')]),
    ((1, 0), f'L1 Loss (λ₁ = {LAMBDA_L1})', [('loss_l1', 'L1 Loss')]),
    ((1, 1), 'Perceptual Loss', [('loss_perceptual', 'Perceptual Loss')]),
    ((1, 2), 'All Loss Components',
     [('loss_l1', 'L1'), ('loss_perceptual', 'Perceptual'), ('loss_gan', 'GAN')]),
]

for (row, col), title, series in panel_specs:
    ax = axes[row, col]
    alpha = 0.7 if len(series) > 1 else None
    present = [(key, label) for key, label in series if key in history]
    for key, label in present:
        ax.plot(history['epoch'], history[key], label=label, alpha=alpha)
    if not present:
        ax.text(0.5, 0.5, 'not recorded in history', ha='center', va='center',
                transform=ax.transAxes)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    if any(label for _, label in present):
        ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(RESULTS_DIR, f'training_history_{MODEL_ID}.png'), dpi=300, bbox_inches='tight')
plt.show()

print(f'Training history plot saved to {RESULTS_DIR}/training_history_{MODEL_ID}.png')

In [None]:
# Qualitative check: grayscale input, prediction and ground truth for up to 8
# images from the first test batch, saved next to the metrics.
num_show = 8
generator.eval()
with torch.no_grad():
    sample_L, sample_ab = next(iter(test_loader))
    num_show = min(num_show, sample_L.size(0))
    sample_L, sample_ab = sample_L[:num_show].to(device), sample_ab[:num_show].to(device)
    fake_ab = generator(sample_L)

    # Reuse the evaluator from the evaluation cell; lab_to_rgb needs no second
    # LPIPS network.
    pred_rgb = evaluator.lab_to_rgb(sample_L, fake_ab)
    target_rgb = evaluator.lab_to_rgb(sample_L, sample_ab)
    grayscale = sample_L.cpu().numpy()

# squeeze=False keeps `axes` 2-D even when only one column is drawn.
fig, axes = plt.subplots(3, num_show, figsize=(20, 8), squeeze=False)
for i in range(num_show):
    panels = (
        ('Grayscale', grayscale[i, 0], {'cmap': 'gray'}),
        ('Colorized', np.clip(pred_rgb[i], 0, 1), {}),
        ('Ground Truth', np.clip(target_rgb[i], 0, 1), {}),
    )
    for row, (title, image, imshow_kwargs) in enumerate(panels):
        ax = axes[row, i]
        ax.imshow(image, **imshow_kwargs)
        ax.axis('off')
        if i == 0:
            ax.set_title(title, fontsize=10)

plt.tight_layout()
plt.savefig(os.path.join(RESULTS_DIR, f'colorization_results_{MODEL_ID}.png'), dpi=300, bbox_inches='tight')
plt.show()

print(f'Colorization results saved to {RESULTS_DIR}/colorization_results_{MODEL_ID}.png')