In [None]:
!pip install tritonclient[all]

## 1: Imports

In [None]:
import os
import time
from pathlib import Path
import numpy as np
import cv2
import tritonclient.http as httpclient

## 2: Set Up Triton Client & Config

In [None]:
# Triton endpoint and model. Both can be overridden through environment variables
# so the notebook can target another server without code edits; the defaults are
# the original values.
TRITON_URL = os.environ.get("TRITON_URL", "10.67.32.50:8000")
MODEL_NAME = os.environ.get("TRITON_MODEL_NAME", "ddcolor_trt")
BATCH_SIZE = 4            # number of images sent per inference request
IMAGE_SIZE = (512, 512)   # model input size as (width, height) — the order cv2.resize expects

client = httpclient.InferenceServerClient(url=TRITON_URL)
# Fail fast so later cells do not run against an unavailable model.
if client.is_model_ready(MODEL_NAME):
    print(f"✅ Connected to Triton and model '{MODEL_NAME}' is ready.")
else:
    raise RuntimeError(f"❌ Model '{MODEL_NAME}' is not ready on Triton server {TRITON_URL}.")

## 3: Define Preprocessing & Postprocessing Functions

In [None]:
def preprocess_image(img_path, image_size=(512, 512)):
    """Load an image from disk and build the grayscale model input.

    The image is converted to Lab, its chroma (a/b) channels are zeroed, and the
    result is converted back to RGB. That gray RGB image is resized to
    ``image_size`` and laid out channels-first.

    Args:
        img_path: path to the image file, as a string.
        image_size: (width, height) of the model input, as passed to ``cv2.resize``.

    Returns:
        A tuple of three items:
        - the model input, a float32 array of shape (3, height, width);
        - the full-resolution Lab L channel, shape (H, W, 1), kept for recombination;
        - the original image size as (width, height).

    Raises:
        ValueError: if OpenCV cannot decode the file.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise ValueError(f"Cannot read {img_path}")

    height, width = bgr.shape[:2]
    rgb01 = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    luminance = cv2.cvtColor(rgb01, cv2.COLOR_RGB2LAB)[:, :, :1]

    # Keep only lightness: zero both chroma channels, then return to RGB space.
    zeros = np.zeros_like(luminance)
    gray_rgb = cv2.cvtColor(np.concatenate((luminance, zeros, zeros), axis=-1), cv2.COLOR_LAB2RGB)

    model_input = cv2.resize(gray_rgb, image_size).transpose(2, 0, 1).astype(np.float32)
    return model_input, luminance, (width, height)



def postprocess_image(output_ab, orig_l, orig_size):
    """Recombine predicted chroma with the original lightness into an RGB uint8 image.

    Args:
        output_ab: model output for one image, channels-first. It is assumed to be
            (2, h, w) holding the Lab a/b channels — TODO confirm against the
            model's output spec.
        orig_l: full-resolution L channel of shape (H, W, 1), as returned by
            ``preprocess_image``.
        orig_size: original image size as (width, height).

    Returns:
        An (H, W, 3) uint8 RGB image.
    """
    W, H = orig_size
    # cv2.resize / cvtColor need a supported dtype. The engine's output dtype is
    # not guaranteed to be float32 (e.g. an FP16 engine), so cast explicitly.
    # The cast also makes the transposed view contiguous for OpenCV.
    ab = np.ascontiguousarray(output_ab.transpose(1, 2, 0), dtype=np.float32)
    ab = cv2.resize(ab, (W, H))
    lab_out = np.concatenate([orig_l.astype(np.float32, copy=False), ab], axis=-1)
    rgb = cv2.cvtColor(lab_out, cv2.COLOR_LAB2RGB)
    rgb = np.clip(rgb, 0.0, 1.0)
    # Round rather than truncate: truncation maps e.g. 0.999 to 254 instead of 255
    # and darkens every pixel by about half a level on average.
    return np.round(rgb * 255.0).astype(np.uint8)


## 4: Inference

In [None]:
def infer_batch(batch_array, input_name="input", output_name="output", model_name=None):
    """Send one synchronous inference request to Triton and return the output tensor.

    Args:
        batch_array: the stacked batch of model inputs. ``colorize_directory`` builds
            it as (N, 3, H, W). It is converted to contiguous float32 because the
            request declares the input as FP32.
        input_name: name of the model's input tensor (default "input").
        output_name: name of the requested output tensor (default "output").
        model_name: Triton model to query. ``None`` uses the notebook-level
            ``MODEL_NAME``, looked up at call time.

    Returns:
        The requested output as a numpy array.

    Raises:
        RuntimeError: if the response does not contain ``output_name``.
    """
    if model_name is None:
        model_name = MODEL_NAME
    # set_data_from_numpy rejects arrays whose dtype does not match the declared "FP32".
    batch_array = np.ascontiguousarray(batch_array, dtype=np.float32)

    infer_input = httpclient.InferInput(input_name, batch_array.shape, "FP32")
    infer_input.set_data_from_numpy(batch_array)
    requested = [httpclient.InferRequestedOutput(output_name)]

    result = client.infer(model_name=model_name, inputs=[infer_input], outputs=requested)
    out_array = result.as_numpy(output_name)
    if out_array is None:
        raise RuntimeError(f"Triton response from '{model_name}' has no output '{output_name}'")
    return out_array

def _find_images(input_dir):
    """Return the supported image files directly inside ``input_dir``, sorted for a deterministic order."""
    suffixes = {".jpg", ".png", ".jpeg", ".bmp"}
    return sorted(
        p for p in Path(input_dir).glob("*")
        if p.is_file() and p.suffix.lower() in suffixes
    )


def colorize_directory(input_dir, output_dir, batch_size=None, image_size=None):
    """Colorize every supported image in ``input_dir`` and save the results to ``output_dir``.

    Each output is written as ``colorized_<original name>``. Images are sent to
    Triton in batches of ``batch_size``.

    Args:
        input_dir: folder containing the input images. Subfolders are not searched.
        output_dir: destination folder; created if it does not exist.
        batch_size: images per request. ``None`` uses the notebook's ``BATCH_SIZE``.
        image_size: model input (width, height). ``None`` uses the notebook's ``IMAGE_SIZE``.

    Raises:
        ValueError: if ``batch_size`` is less than 1, or an image cannot be read
            (raised by ``preprocess_image``).
        RuntimeError: if the server returns a different number of results than
            images sent.
        OSError: if an output image cannot be written.
    """
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    image_size = IMAGE_SIZE if image_size is None else image_size
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    os.makedirs(output_dir, exist_ok=True)
    image_files = _find_images(input_dir)
    if not image_files:
        print("❌ No images found.")
        return

    print(f"📂 Found {len(image_files)} images. Processing in batches of {batch_size}...")
    out_dir = Path(output_dir)
    infer_time = 0.0
    processed = 0
    wall_start = time.perf_counter()  # perf_counter is the right clock for measuring intervals

    for i in range(0, len(image_files), batch_size):
        batch_paths = image_files[i:i + batch_size]
        inputs, l_channels, sizes = [], [], []
        for p in batch_paths:
            inp, l, size = preprocess_image(str(p), image_size)
            inputs.append(inp)
            l_channels.append(l)
            sizes.append(size)
        batch_array = np.stack(inputs, axis=0)

        start = time.perf_counter()
        output = infer_batch(batch_array)
        infer_time += time.perf_counter() - start

        # Fail clearly instead of mis-pairing or silently dropping images.
        if output is None or len(output) != len(batch_paths):
            got = None if output is None else len(output)
            raise RuntimeError(f"Expected {len(batch_paths)} results from Triton, got {got}")

        for p, ab, l, size in zip(batch_paths, output, l_channels, sizes):
            rgb = postprocess_image(ab, l, size)
            out_path = out_dir / f"colorized_{p.name}"
            # cv2.imwrite reports failure through its boolean return value; do not ignore it.
            if not cv2.imwrite(str(out_path), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
                raise OSError(f"❌ Failed to write {out_path}")
            print(f"✅ Saved {out_path.name}")
            processed += 1

    wall_time = time.perf_counter() - wall_start
    # Report end-to-end time separately from inference-only time.
    print(f"\n🏁 Done. {processed} images processed in {wall_time:.2f}s total; "
          f"inference {infer_time:.2f}s ({infer_time / processed:.3f}s per image)")


In [None]:
# Folder of grayscale inputs, and the folder that receives the colorized copies.
input_dir = "./grayscale_images"
output_dir = "./colorized_images"
colorize_directory(input_dir=input_dir, output_dir=output_dir)