In [1]:
import torch
from colorizers import siggraph17 as Colorizer
import coremltools as ct

Torch version 2.7.0 has not been tested with coremltools. You may run into unexpected errors. Torch 2.5.0 is the most recent version that has been tested.


In [3]:
# Build the SIGGRAPH'17 colorizer with its pretrained weights (loaded by the
# constructor) and switch it to inference mode before tracing.
torch_model = Colorizer(pretrained=True).eval()

# Dummy single-channel input used only to drive the trace; its values are
# arbitrary, but its shape is reused below as the Core ML input shape, so the
# exported model accepts exactly (1, 1, 256, 256).
example_input = torch.rand(1, 1, 256, 256)
traced_model = torch.jit.trace(torch_model, example_input)

# Describe the Core ML input explicitly, then convert and write the package.
input_spec = ct.TensorType(name="input", shape=example_input.shape)
coreml_model = ct.convert(
    traced_model,
    inputs=[input_spec],
    minimum_deployment_target=ct.target.macOS15,
)
coreml_model.save("Colorizer.mlpackage")

Converting PyTorch Frontend ==> MIL Ops: 100%|█████████████████████████████████████████████████████▊| 364/365 [00:00<00:00, 9460.45 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 352.43 passes/s]
Running MIL default pipeline: 100%|██████████████████████████████████████████████████████████████████| 89/89 [00:01<00:00, 77.86 passes/s]
Running MIL backend_mlprogram pipeline: 100%|███████████████████████████████████████████████████████| 12/12 [00:00<00:00, 408.15 passes/s]


# Model Verification

In [12]:
import numpy as np
from PIL import Image
from skimage import color, transform        # pip install scikit-image
import coremltools as ct

# Load Core ML package (adjust path if needed)
coreml_model = ct.models.MLModel("Colorizer.mlpackage")

# The converted model has a fixed input shape taken from the trace example, and
# an auto-generated output name (e.g. "var_518") that can change on every
# re-conversion. Read both from the model spec instead of hard-coding them.
model_spec = coreml_model.get_spec()
input_desc = model_spec.description.input[0]
input_name = input_desc.name
output_name = model_spec.description.output[0].name
input_shape = [int(d) for d in input_desc.type.multiArrayType.shape]
assert len(input_shape) == 4, f"expected NCHW input, got shape {input_shape}"
model_h, model_w = input_shape[-2:]

# 1.  Prepare input
in_img = Image.open("image.png").convert("RGB")
in_rgb = np.asarray(in_img)
orig_h, orig_w = in_rgb.shape[:2]
in_lab = color.rgb2lab(in_rgb, channel_axis=2)                     # (H,W,3)
in_l_orig = in_lab[:, :, 0]                                         # (H,W) full-res L

# The network only accepts the size it was traced with, so resize the RGB image
# (bicubic) and take the L channel of the resized copy as the model input.
in_img_rs = in_img.resize((model_w, model_h), resample=Image.BICUBIC)
in_l_rs = color.rgb2lab(np.asarray(in_img_rs), channel_axis=2)[:, :, 0]
in_l = in_l_rs[np.newaxis, np.newaxis].astype(np.float32)           # (1,1,h,w)

# 2.  Inference -> a/b channels at model resolution
out_ab = coreml_model.predict({input_name: in_l})[output_name]      # (1,2,h,w)

# 3.  Upsample a/b to the original size, pair with the full-res L channel and
#     convert back to RGB. Keeping the original L preserves full-res detail.
out_ab_full = transform.resize(
    np.asarray(out_ab)[0].transpose((1, 2, 0)),                     # (h,w,2)
    (orig_h, orig_w),
    order=1,                                                         # bilinear
    preserve_range=True,
)                                                                    # (H,W,2)
out_lab = np.concatenate([in_l_orig[:, :, np.newaxis], out_ab_full], axis=2)  # (H,W,3)

# lab2rgb yields floats in [0, 1]; round (not truncate) when quantizing to 8 bit.
out_rgb_float = color.lab2rgb(out_lab, channel_axis=2)
out_rgb = np.clip(np.round(out_rgb_float * 255), 0, 255).astype(np.uint8)
out_img = Image.fromarray(out_rgb)
out_img.save("image_colorized.png")
out_img  # render inline instead of launching an external viewer