[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/diffusion-e2e-ft-jupyter/blob/main/geowizard_e2e_ft_depth_normals_jupyter.ipynb)

In [None]:
# Install the Python packages this notebook needs; xformers is pinned to 0.0.28.post1.
# NOTE(review): `transformers` and `opencv` are imported later but not installed here —
# presumably they are preinstalled on the Colab runtime; confirm.
!pip -q install diffusers xformers==0.0.28.post1

# Clone the project under /content and make the checkout the working directory.
# The later `from GeoWizard...` imports presumably resolve relative to this directory.
%cd /content
!git clone https://github.com/camenduru/geowizard-e2e-ft-hf
%cd /content/geowizard-e2e-ft-hf

In [None]:
import cv2
import numpy as np
from PIL import Image
import torch

from GeoWizard.geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline
from GeoWizard.geowizard.models.unet_2d_condition import UNet2DConditionModel
from diffusers import DDIMScheduler, AutoencoderKL
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

# Checkpoint identifier passed to every `from_pretrained` call below; each
# pipeline component is loaded from its own subfolder of that checkpoint.
checkpoint_path = "GonzaloMG/geowizard-e2e-ft"

vae = AutoencoderKL.from_pretrained(
    checkpoint_path,
    subfolder='vae',
)
# The scheduler config is loaded with `timestep_spacing` overridden to "trailing".
scheduler = DDIMScheduler.from_pretrained(
    checkpoint_path,
    subfolder='scheduler',
    timestep_spacing="trailing",
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    checkpoint_path,
    subfolder="image_encoder",
)
feature_extractor = CLIPImageProcessor.from_pretrained(
    checkpoint_path,
    subfolder="feature_extractor",
)
unet = UNet2DConditionModel.from_pretrained(
    checkpoint_path,
    subfolder="unet",
)

# Assemble the depth/normal pipeline from the individually loaded components.
pipe = DepthNormalEstimationPipeline(
    vae=vae,
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    unet=unet,
    scheduler=scheduler,
)
# Move the pipeline to the GPU and switch the UNet to evaluation mode.
pipe = pipe.to('cuda')
pipe.unet.eval()

def predict(image, processing_res_choice):
    """Run the module-level ``pipe`` once on ``image`` and unpack its outputs.

    The pipeline is called with fixed options: a single denoising step, an
    ensemble size of one, ``noise="zeros"`` and ``match_input_res=True``.

    Args:
        image: Input forwarded unchanged as the pipeline's first positional
            argument.
        processing_res_choice: Value forwarded as the pipeline's
            ``processing_res`` argument.

    Returns:
        A 4-tuple ``(depth_np, depth_colored, normal_np, normal_colored)``
        read from the attributes of the pipeline output.
    """
    inference_options = dict(
        denoising_steps=1,
        ensemble_size=1,
        noise="zeros",
        processing_res=processing_res_choice,
        match_input_res=True,
    )
    # Inference only: gradient tracking is disabled for the pipeline call.
    with torch.no_grad():
        result = pipe(image, **inference_options)
    return (
        result.depth_np,
        result.depth_colored,
        result.normal_np,
        result.normal_colored,
    )
  
# Value handed to the pipeline as `processing_res` by predict().
processing_res_choice = 768
# The repository is cloned to /content/geowizard-e2e-ft-hf (note the "-hf"
# suffix), so the example image has to be read from that checkout; the path
# without the suffix does not exist and raised FileNotFoundError.
# NOTE(review): assumes the checkout ships assets/examples/bottles.jpg — confirm.
image = Image.open("/content/geowizard-e2e-ft-hf/assets/examples/bottles.jpg").convert('RGB')
# Round-trip through a uint8 array so predict() receives a freshly built PIL image.
image_array = np.array(image).astype('uint8')
pil_image = Image.fromarray(image_array)
depth_pred, depth_colored, normal_pred, normal_colored = predict(pil_image, processing_res_choice)

In [None]:
# Show `depth_colored` (taken from the pipeline output by predict()) as this cell's output.
depth_colored

In [None]:
# Show `normal_colored` (taken from the pipeline output by predict()) as this cell's output.
normal_colored