In [None]:
!wget https://huggingface.co/danjacobellis/LCCL/resolve/main/colorize_pixels_128_p8.pth

In [1]:
import torch
import torch.nn as nn
import numpy as np
import warnings
import IPython.display
import io
import time
from PIL import Image
from einops import rearrange
from datasets import load_dataset
from datasets import Image as HFImage
from torchvision.transforms import (
    RandomResizedCrop, Resize, Grayscale,
    PILToTensor, ToPILImage, 
    Compose, RandomHorizontalFlip )
from max_vit_with_register_tokens import MaxViT
from fastprogress.fastprogress import master_bar, progress_bar
from piq import LPIPS, DISTS, psnr, multi_scale_ssim
class Config:
    """Empty attribute container.

    NOTE(review): presumably this has to exist under the name ``Config`` so the
    pickled ``checkpoint['config']`` object can be restored by ``torch.load``
    -- confirm against the training script that saved the checkpoint.
    """

In [2]:
device = "cuda"

# NOTE(review): weights_only=False uses the full pickle loader, which can
# execute arbitrary code from the file -- only load checkpoints you trust.
checkpoint = torch.load(
    "colorize_pixels_128_p8.pth",
    map_location="cpu",
    weights_only=False,
)
config = checkpoint['config']

# Hyperparameters that MaxViT takes under the same name they have on `config`.
# `heads` and `mlp_dim` are intentionally not forwarded: per the original
# notes they are calculated as dim//dim_head and 4*dim respectively.
maxvit_kwargs = {
    name: getattr(config, name)
    for name in (
        "channels",
        "patch_size",
        "num_classes",
        "depth",
        "downsample",
        "dim_head",
        "dim_conv_stem",
        "window_size",
        "mbconv_expansion_rate",
        "mbconv_shrinkage_rate",
        "dropout",
        "num_register_tokens",
    )
}

model = MaxViT(
    dim=config.embed_dim,  # the config stores this value as `embed_dim`
    dense_prediction=True,
    **maxvit_kwargs,
)
model = model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# Trailing semicolon suppresses the module repr in the cell output.
model.eval();

In [3]:
# Perceptual metrics used by the evaluation function. They are moved to the
# notebook-wide `device` (rather than a hard-coded "cuda") so they always sit
# on the same device as the tensors they are called with.
lpips_loss = LPIPS().to(device)
dists_loss = DISTS().to(device)



In [4]:
# NOTE(review): trust_remote_code=True allows the dataset repository to run
# its own loading script locally -- keep it only for repositories you trust.
dataset_valid = load_dataset(
    'danjacobellis/LSDIR_val',
    split='validation',
    trust_remote_code=True,
)

# Model input resolution: a square of side `config.image_size`.
eval_size = (config.image_size, config.image_size)

# PIL image -> Lanczos-resized PIL image -> tensor.
valid_transform = Compose([
    Resize(size=eval_size, interpolation=Image.Resampling.LANCZOS),
    PILToTensor(),
])

In [5]:
def colorize_eval(sample):
    """Colorize one sample and score the result against the original image.

    The image is converted to RGB; a copy is resized by ``valid_transform`` and
    turned into a 3-channel grayscale model input. The model output is resized
    back to the original height/width, clamped to [0, 1], and compared with the
    full-resolution RGB image.

    Args:
        sample: Mapping with an ``'image'`` entry that provides ``convert``,
            ``height`` and ``width`` (presumably a PIL image -- TODO confirm).

    Returns:
        dict with keys:
            ``'colorized'``: lossless WebP encoding of the prediction (bytes),
            ``'time'``: seconds spent on forward pass + resize + clamp,
            ``'PSNR'``, ``'MSSIM'``: metric values as Python floats,
            ``'LPIPS_dB'``, ``'DISTS_dB'``: ``-10*log10`` of the respective loss.
    """
    with torch.no_grad():
        img = sample['image'].convert("RGB")

        # Reference: the untouched full-resolution RGB image, scaled to [0, 1].
        y = PILToTensor()(img).to(torch.float) / 255
        y = y.to(device).unsqueeze(0)

        # Input: resized, then grayscale replicated to 3 channels, scaled to [0, 1].
        x = Grayscale(num_output_channels=3)(valid_transform(img))
        x = x.to(torch.float) / 255
        x = x.to(device).unsqueeze(0)

        # CUDA kernels are launched asynchronously, so the wall clock must be
        # read only after the device has drained its queue; otherwise the
        # measurement covers little more than the kernel launches.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        pred = model(x)
        pred = Resize(size=(img.height, img.width))(pred)
        pred = pred.clamp(0,1)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed_time = time.perf_counter() - t0

        colorized = ToPILImage()(pred[0])
        buff = io.BytesIO()
        colorized.save(buff, format='WEBP', lossless=True)
        # getvalue() yields plain bytes instead of a memoryview that keeps an
        # exported buffer open on `buff`.
        colorized_bytes = buff.getvalue()

        # .item() turns every metric into a plain Python float, matching how
        # the LPIPS/DISTS values are handled and avoiding returning tensors.
        PSNR = psnr(pred,y).item()
        MSSIM = multi_scale_ssim(pred,y).item()
        LPIPS_dB = -10*np.log10(lpips_loss(pred, y).item())
        DISTS_dB = -10*np.log10(dists_loss(pred, y).item())
        return {
            'colorized': colorized_bytes,
            'time': elapsed_time,
            'PSNR': PSNR,
            'MSSIM': MSSIM,
            'LPIPS_dB': LPIPS_dB,
            'DISTS_dB': DISTS_dB,
        }

In [6]:
# Run the evaluation over every validation sample, then declare the
# 'colorized' column as an image feature.
colorized = (
    dataset_valid
    .map(colorize_eval)
    .cast_column('colorized', HFImage())
)

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

In [7]:
# Publish the evaluated split to the Hugging Face Hub; the returned CommitInfo
# is shown as the cell output.
colorized.push_to_hub(
    repo_id="danjacobellis/LSDIR_colorize_pixels_128_p8",
    split='validation',
)

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Map:   0%|          | 0/125 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

Map:   0%|          | 0/125 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/danjacobellis/LSDIR_colorize_pixels_128_p8/commit/c724befe705eb4f478f259e71a5989177be28d01', commit_message='Upload dataset', commit_description='', oid='c724befe705eb4f478f259e71a5989177be28d01', pr_url=None, pr_revision=None, pr_num=None)