In [None]:
!wget https://hf.co/danjacobellis/walloc/resolve/main/RGB_Li_48c_J3_nf8_v1.0.2.pth

In [None]:
!wget https://hf.co/danjacobellis/walloc/resolve/main/RGB_Li_12c_J3_nf8_v1.0.2.pth

In [None]:
!wget https://huggingface.co/danjacobellis/LCCL/resolve/main/colorize_walloc_4x_256_p16.pth

In [None]:
!wget https://huggingface.co/danjacobellis/LCCL/resolve/main/colorize_walloc_16x_512_p32.pth

In [1]:
import torch
import torch.nn as nn
import numpy as np
import warnings
import IPython.display
import io
import time
from PIL import Image
from einops import rearrange
from datasets import load_dataset
from datasets import Image as HFImage
from torchvision.transforms import (
    RandomResizedCrop, Resize, Grayscale,
    PILToTensor, ToPILImage, 
    Compose, RandomHorizontalFlip )
from max_vit_with_register_tokens import MaxViT
from fastprogress.fastprogress import master_bar, progress_bar
from piq import LPIPS, DISTS, psnr, multi_scale_ssim
from walloc import walloc
class Config:
    """Empty attribute container.

    NOTE(review): presumably this name must exist so that ``torch.load`` with
    ``weights_only=False`` can unpickle the ``'config'`` objects stored in the
    checkpoints — confirm against how the checkpoints were saved.
    """

In [2]:
# Device for the first benchmarking pass; a later cell rebinds this to "cpu".
device = "cuda"

# ---- Codec -----------------------------------------------------------------
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects from the downloaded file. Only acceptable if the checkpoint source
# is trusted — confirm.
checkpoint = torch.load(
    "RGB_Li_12c_J3_nf8_v1.0.2.pth",
    map_location="cpu",
    weights_only=False,
)
codec_config = checkpoint['config']

# Constructor arguments are copied one-to-one from the saved codec config.
codec_fields = (
    "channels",
    "J",
    "Ne",
    "Nd",
    "latent_dim",
    "latent_bits",
    "lightweight_encode",
)
codec = walloc.Codec2D(
    **{field: getattr(codec_config, field) for field in codec_fields}
)
codec.load_state_dict(checkpoint['model_state_dict'])
codec = codec.to(device)
codec.eval()

# ---- Colorization model ----------------------------------------------------
# `checkpoint` is deliberately rebound: from here on it refers to the model
# checkpoint, not the codec checkpoint. Same weights_only=False caveat applies.
checkpoint = torch.load(
    "colorize_walloc_16x_512_p32.pth",
    map_location="cpu",
    weights_only=False,
)
config = checkpoint['config']

maxvit_kwargs = dict(
    # The model consumes the codec's latent channels rather than RGB.
    channels=codec_config.latent_dim,
    # Configured patch size divided by 2**J.
    # NOTE(review): presumably compensates for the codec's spatial
    # downsampling — confirm.
    patch_size=config.patch_size // (2 ** codec_config.J),
    num_classes=config.num_classes,
    dim=config.embed_dim,
    depth=config.depth,
    downsample=config.downsample,
    # `heads` and `mlp_dim` are intentionally not passed; per the original
    # notes they are calculated as dim//dim_head and 4*dim respectively.
    dim_head=config.dim_head,
    dim_conv_stem=config.dim_conv_stem,
    window_size=config.window_size,
    mbconv_expansion_rate=config.mbconv_expansion_rate,
    mbconv_shrinkage_rate=config.mbconv_shrinkage_rate,
    dropout=config.dropout,
    num_register_tokens=config.num_register_tokens,
    dense_prediction=True,
)
model = MaxViT(**maxvit_kwargs)
model = model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval();

In [3]:
# Perceptual metrics used by colorize_gpu; both are placed on the GPU.
lpips_loss, dists_loss = LPIPS().to("cuda"), DISTS().to("cuda")

# Reserved CUDA memory on device 0 (bytes / 1e6) after the models and metrics
# are resident; colorize_gpu subtracts this to report per-sample overhead.
gpu_mem_baseline = torch.cuda.memory_reserved(device=0) / 1e6



In [4]:
# Validation split used for both the GPU and CPU benchmarking passes.
dataset_valid = load_dataset(
    'danjacobellis/LSDIR_val',
    split='validation',
    trust_remote_code=True,
)

# Resize every image to the square resolution stored in the model config
# (Lanczos resampling), then convert the PIL image to a tensor.
resize_to_model_input = Resize(
    size=(config.image_size, config.image_size),
    interpolation=Image.Resampling.LANCZOS,
)
valid_transform = Compose([resize_to_model_input, PILToTensor()])

In [5]:
def colorize_gpu(sample):
    """Colorize one dataset sample, time each stage, and score the result.

    The image is converted to RGB, resized by ``valid_transform``, turned into
    a 3-channel grayscale tensor shifted to [-0.5, 0.5], and pushed through
    the codec analysis/encoder front end, the colorization model, and the
    codec decoder/synthesis. The prediction is resized back to the original
    image size and compared against the original RGB image.

    Reads the notebook globals ``codec``, ``model``, ``device``,
    ``valid_transform``, ``lpips_loss``, ``dists_loss`` and
    ``gpu_mem_baseline``.

    Args:
        sample: dataset row; ``sample['image']`` must support
            ``.convert("RGB")`` and expose ``.height`` / ``.width``
            (presumably a PIL image — confirm against the dataset).

    Returns:
        dict with:
            'colorized': lossless WEBP encoding of the prediction (buffer view),
            'encoding_time': seconds for wavelet analysis + ``codec.encoder[:2]``,
            'time': seconds for the model forward pass plus the final resize/clamp,
            'decoding_time': seconds for codec decoder + wavelet synthesis + clamp,
            'gpu_mem': reserved CUDA memory on device 0 (bytes / 1e6) minus the baseline,
            'PSNR', 'MSSIM': values returned by ``piq.psnr`` / ``piq.multi_scale_ssim``,
            'LPIPS_dB', 'DISTS_dB': ``-10 * log10`` of the respective loss.
    """
    def _sync(tensor):
        # CUDA kernels are launched asynchronously: without waiting for the
        # device, a host-side timer only measures launch overhead, not the
        # actual computation. No-op for CPU tensors.
        if tensor.is_cuda:
            torch.cuda.synchronize()

    with torch.no_grad():
        img = sample['image'].convert("RGB")
        # Ground truth: full-resolution RGB in [0, 1].
        y = PILToTensor()(img)
        # Model input: resized, grayscale replicated to 3 channels, in [-0.5, 0.5].
        x = Grayscale(num_output_channels=3)(valid_transform(img))
        x = x.to(torch.float)
        y = y.to(torch.float)
        x = (x / 255) - 0.5
        y = y / 255
        x = x.to(device)
        y = y.to(device)
        x = x.unsqueeze(0)
        y = y.unsqueeze(0)

        # Make sure the host-to-device copies above are finished so they are
        # not billed to the encoding stage.
        _sync(x)
        t0 = time.perf_counter()
        x = codec.wavelet_analysis(x,codec.J)
        x = codec.encoder[:2](x)
        _sync(x)
        encoding_time = time.perf_counter() - t0

        t0 = time.perf_counter()
        pred = model(x)
        _sync(pred)
        model_time = time.perf_counter() - t0

        t0 = time.perf_counter()
        pred = codec.decoder(pred)
        pred = codec.wavelet_synthesis(pred, codec.J)
        pred = codec.clamp(pred) + 0.5
        _sync(pred)
        decoding_time = time.perf_counter() - t0

        # Post-processing: back to the original resolution and valid range.
        t0 = time.perf_counter()
        pred = Resize(size=(img.height, img.width))(pred)
        pred = pred.clamp(0,1)
        _sync(pred)
        post_time = time.perf_counter() - t0

        # Serialize the prediction as lossless WEBP for storage in the dataset.
        colorized = ToPILImage()(pred[0])
        buff = io.BytesIO()
        colorized.save(buff, format='WEBP', lossless=True)
        colorized_bytes = buff.getbuffer()

        PSNR = psnr(pred,y)
        MSSIM = multi_scale_ssim(pred,y)
        LPIPS_dB = -10*np.log10(lpips_loss(pred, y).item())
        DISTS_dB = -10*np.log10(dists_loss(pred, y).item())

        gpu_mem = torch.cuda.memory_reserved(0)/1e6 - gpu_mem_baseline
        return {
            'colorized': colorized_bytes,
            'encoding_time': encoding_time,
            'time': model_time + post_time,
            'decoding_time': decoding_time,
            'gpu_mem': gpu_mem,
            'PSNR': PSNR,
            'MSSIM': MSSIM,
            'LPIPS_dB': LPIPS_dB,
            'DISTS_dB': DISTS_dB,
        }

In [6]:
# Run the GPU pass over every validation sample, then reinterpret the
# 'colorized' column (encoded WEBP data) as an image feature.
gpu = (
    dataset_valid
    .map(colorize_gpu)
    .cast_column('colorized', HFImage())
)

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

In [7]:
def colorize_cpu(sample):
    """Colorize one dataset sample and report per-stage timings only.

    Runs the same pipeline stages as the GPU benchmark (codec front end,
    colorization model, codec decoder, resize back to the original size) but
    computes no quality metrics and returns no image.

    Reads the notebook globals ``codec``, ``model``, ``device`` and
    ``valid_transform``. The function runs on whatever ``device`` currently
    is; the calling cell is responsible for setting it to ``"cpu"`` and moving
    ``model`` / ``codec`` there first.

    Args:
        sample: dataset row; ``sample['image']`` must support
            ``.convert("RGB")`` and expose ``.height`` / ``.width``
            (presumably a PIL image — confirm against the dataset).

    Returns:
        dict with:
            'encoding_time': seconds for wavelet analysis + ``codec.encoder[:2]``,
            'time': seconds for the model forward pass plus the final resize/clamp,
            'decoding_time': seconds for codec decoder + wavelet synthesis + clamp.
    """
    def _sync(tensor):
        # Only relevant if the global `device` is still a CUDA device: kernels
        # launch asynchronously, so wait before reading the timer. No-op on CPU.
        if tensor.is_cuda:
            torch.cuda.synchronize()

    with torch.no_grad():
        img = sample['image'].convert("RGB")
        # Model input: resized, grayscale replicated to 3 channels, in [-0.5, 0.5].
        # (No ground-truth tensor is built here: this pass computes no metrics.)
        x = Grayscale(num_output_channels=3)(valid_transform(img))
        x = x.to(torch.float)
        x = (x / 255) - 0.5
        x = x.to(device)
        x = x.unsqueeze(0)

        _sync(x)
        t0 = time.perf_counter()
        x = codec.wavelet_analysis(x,codec.J)
        x = codec.encoder[:2](x)
        _sync(x)
        encoding_time = time.perf_counter() - t0

        t0 = time.perf_counter()
        pred = model(x)
        _sync(pred)
        model_time = time.perf_counter() - t0

        t0 = time.perf_counter()
        pred = codec.decoder(pred)
        pred = codec.wavelet_synthesis(pred, codec.J)
        pred = codec.clamp(pred) + 0.5
        _sync(pred)
        decoding_time = time.perf_counter() - t0

        # Post-processing: back to the original resolution and valid range.
        t0 = time.perf_counter()
        pred = Resize(size=(img.height, img.width))(pred)
        pred = pred.clamp(0,1)
        _sync(pred)
        post_time = time.perf_counter() - t0

        return {
            'encoding_time': encoding_time,
            'time': model_time + post_time,
            'decoding_time': decoding_time,
        }

In [8]:
# Switch the whole pipeline to the CPU. This rebinds the global `device` that
# colorize_cpu reads and moves both networks, so any GPU cell re-run after
# this point would need the models moved back first.
device = "cpu"
model, codec = model.to(device), codec.to(device)

# Timing-only pass over the same validation samples.
cpu = dataset_valid.map(colorize_cpu)

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

In [9]:
# Append the CPU timings to the GPU results under 'cpu_'-prefixed column names.
combined = gpu
for combined_column, cpu_column in (
    ('cpu_time', 'time'),
    ('cpu_encoding_time', 'encoding_time'),
    ('cpu_decoding_time', 'decoding_time'),
):
    combined = combined.add_column(combined_column, cpu[cpu_column])

In [11]:
# Publish the merged GPU + CPU results as the 'validation' split on the Hub.
# Requires write access to the target dataset repository.
combined.push_to_hub(
    "danjacobellis/LSDIR_colorize_walloc_x16_512_p32",
    split='validation',
)

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Map:   0%|          | 0/125 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

Map:   0%|          | 0/125 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/danjacobellis/LSDIR_colorize_walloc_x16_512_p32/commit/2c61ca71b40868d5364b2c72bdc89ea6100d1c8e', commit_message='Upload dataset', commit_description='', oid='2c61ca71b40868d5364b2c72bdc89ea6100d1c8e', pr_url=None, pr_revision=None, pr_num=None)