In [1]:
import shutil, os, sys

# Only remove & clone if needed
if os.path.exists('./pix2pixHD'):
    shutil.rmtree('./pix2pixHD')
!git clone https://github.com/NVIDIA/pix2pixHD.git

sys.path.append('./pix2pixHD')
from models.networks import define_G, define_D

print("Repo cloned and imports successful!")


import torch
from models.networks import define_G
from torchvision import transforms
from PIL import Image
import numpy as np

# ----------------- Setup and Model Loading -----------------
# Set your paths
checkpoint_path = "/kaggle/input/pix2pix-hd-lighting/pytorch/epoch30/7/pix2pixhd_checkpoint_epoch_90.pth"  # update with your best/last epoch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model parameters (as in your training code)
input_nc = 28  # 3 RGB channels + 25 one-hot lighting channels (see preprocess_input)
output_nc = 3
# The remaining positional arguments go straight to pix2pixHD's define_G;
# presumably (ngf, netG, n_downsample_global, n_blocks_global,
# n_local_enhancers, n_blocks_local, norm, gpu_ids) - confirm against
# models/networks.py before changing any of them.
G = define_G(input_nc, output_nc, 64, 'global', 4, 9, 1, 3, 'instance', [])
# DataParallel exposes the wrapped network as `.module`, so the state-dict keys
# it expects carry a "module." prefix.
G = torch.nn.DataParallel(G)
G.to(device)
G.eval()

# Load EMA generator weights
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)
# strict=False tolerates key mismatches, so inspect the result instead of
# discarding it; otherwise a complete mismatch would go unnoticed and
# inference would run on randomly initialised weights.
load_result = G.load_state_dict(checkpoint["ema_generator_state_dict"], strict=False)  # We use EMA weights for inference
expected_key_count = len(G.state_dict())
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.missing_keys)} of {expected_key_count} generator keys "
        f"were missing from the checkpoint; {len(load_result.unexpected_keys)} checkpoint "
        f"keys were not used."
    )
if expected_key_count and len(load_result.missing_keys) == expected_key_count:
    raise RuntimeError(
        "No generator weights were loaded from 'ema_generator_state_dict': every "
        "expected key is missing (check for a 'module.' prefix mismatch)."
    )
print("Loaded EMA weights.")

# ----------------- Helper Functions -----------------
def preprocess_input(img_path, lighting_idx, n_lights=25, image_size=256):
    """Build the generator input: an RGB image stacked with a one-hot lighting map.

    Args:
        img_path: Path of the source image; it is opened with PIL and
            converted to RGB.
        lighting_idx: Index of the target lighting condition; must satisfy
            0 <= lighting_idx < n_lights.
        n_lights: Number of channels in the one-hot lighting encoding.
        image_size: Side length, in pixels, of the square the image is
            resized to.

    Returns:
        Tensor of shape [1, 3 + n_lights, image_size, image_size]
        ([1, 28, 256, 256] with the defaults).

    Raises:
        IndexError: If lighting_idx is outside [0, n_lights). Negative values
            are rejected explicitly because tensor indexing would otherwise
            wrap around and silently select a different lighting channel.
    """
    if not 0 <= lighting_idx < n_lights:
        raise IndexError(
            f"lighting_idx {lighting_idx} is out of range for {n_lights} lighting channels"
        )

    # Image Transform (must match training)
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    # Use a context manager so the underlying file handle is closed promptly.
    with Image.open(img_path) as source_image:
        img = transform(source_image.convert("RGB"))

    # Create one-hot lighting vector, broadcast over the spatial dimensions.
    lighting_vec = torch.zeros(n_lights, 1, 1)
    lighting_vec[lighting_idx] = 1
    lighting_vec = lighting_vec.expand(n_lights, image_size, image_size)

    # Concatenate on channel axis
    cat_input = torch.cat([img, lighting_vec], dim=0)  # shape: [3 + n_lights, H, W]
    return cat_input.unsqueeze(0)  # Add batch dimension

def save_output(tensor_img, save_path):
    """Save the first image of a batched tensor as an 8-bit image file.

    Args:
        tensor_img: Batched image tensor laid out as [N, 3, H, W]; only the
            first batch element is written. Values are clamped to [0, 1]
            before conversion.
        save_path: Destination path; PIL picks the file format from its
            extension.
    """
    # Cast to float32 first: autocast can hand back float16/bfloat16 tensors,
    # and NumPy has no bfloat16 dtype, so `.numpy()` would raise on those.
    arr = tensor_img.detach().float().cpu().clamp(0, 1).numpy()[0]  # [3, H, W]
    arr = np.transpose(arr, (1, 2, 0))  # [H, W, 3]
    # Round to the nearest 8-bit level; a bare astype() truncates and would
    # bias every pixel downwards (e.g. 254.7 -> 254).
    arr = np.rint(arr * 255).astype(np.uint8)
    Image.fromarray(arr).save(save_path)
    print(f"Saved output: {save_path}")

# ----------------- Inference -----------------
# Example usage:
input_img_path = "/kaggle/input/multi-illumination-jpg/14n_copyroom1/dir_0_mip2.jpg"
lighting_idx = 5  # for example, use target lighting index 5
# NOTE(review): `lighting_idx` is not read by the loop below, which renders one
# output per loop index instead - confirm whether it is still needed.

# NOTE(review): preprocess_input encodes 25 lighting channels (indices 0-24),
# but this range only covers 1..23 - confirm that skipping 0 and 24 is intended.
for i in range(1,24):
    cat_input = preprocess_input(input_img_path, i).to(device)
    with torch.no_grad(), torch.amp.autocast(device.type):
        output = G(cat_input)
    # Post-process outside autocast and in float32: under autocast the output
    # may be float16/bfloat16, and bfloat16 tensors cannot be converted to NumPy
    # when the image is saved.
    output = output.float().clamp(0., 1.)
    save_output(output, f"/kaggle/working/inference_output_{i}.png")

Cloning into 'pix2pixHD'...
remote: Enumerating objects: 343, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 343 (delta 0), reused 0 (delta 0), pack-reused 340 (from 1)[K
Receiving objects: 100% (343/343), 55.68 MiB | 39.16 MiB/s, done.
Resolving deltas: 100% (156/156), done.
Repo cloned and imports successful!
GlobalGenerator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(28, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_run

  checkpoint = torch.load(checkpoint_path, map_location=device)


Loaded EMA weights.
Saved output: /kaggle/working/inference_output_1.png
Saved output: /kaggle/working/inference_output_2.png
Saved output: /kaggle/working/inference_output_3.png
Saved output: /kaggle/working/inference_output_4.png
Saved output: /kaggle/working/inference_output_5.png
Saved output: /kaggle/working/inference_output_6.png
Saved output: /kaggle/working/inference_output_7.png
Saved output: /kaggle/working/inference_output_8.png
Saved output: /kaggle/working/inference_output_9.png
Saved output: /kaggle/working/inference_output_10.png
Saved output: /kaggle/working/inference_output_11.png
Saved output: /kaggle/working/inference_output_12.png
Saved output: /kaggle/working/inference_output_13.png
Saved output: /kaggle/working/inference_output_14.png
Saved output: /kaggle/working/inference_output_15.png
Saved output: /kaggle/working/inference_output_16.png
Saved output: /kaggle/working/inference_output_17.png
Saved output: /kaggle/working/inference_output_18.png
Saved output: /ka

In [2]:
# Bundle every PNG in /kaggle/working into file.zip; the archive name is a
# relative path, so it is written to the current working directory.
# The listing below shows each entry stored with its kaggle/working/ path
# prefix and reported as "deflated 0%", i.e. zip gains nothing on these PNGs.
!zip file.zip /kaggle/working/*.png

  adding: kaggle/working/inference_output_10.png (deflated 0%)
  adding: kaggle/working/inference_output_11.png (deflated 0%)
  adding: kaggle/working/inference_output_12.png (deflated 0%)
  adding: kaggle/working/inference_output_13.png (deflated 0%)
  adding: kaggle/working/inference_output_14.png (deflated 0%)
  adding: kaggle/working/inference_output_15.png (deflated 0%)
  adding: kaggle/working/inference_output_16.png (deflated 0%)
  adding: kaggle/working/inference_output_17.png (deflated 0%)
  adding: kaggle/working/inference_output_18.png (deflated 0%)
  adding: kaggle/working/inference_output_19.png (deflated 0%)
  adding: kaggle/working/inference_output_1.png (deflated 0%)
  adding: kaggle/working/inference_output_20.png (deflated 0%)
  adding: kaggle/working/inference_output_21.png (deflated 0%)
  adding: kaggle/working/inference_output_22.png (deflated 0%)
  adding: kaggle/working/inference_output_23.png (deflated 0%)
  adding: kaggle/working/inference_output_2.png (deflate