<a href="https://colab.research.google.com/github/jermwatt/yolo-diffusion/blob/shap-e-testing/shape-e-tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Single-object Shap-E demo

## 1.  Machine setup

### 1.1. Pre-launch machine check

Run the code cell below to double-check that you are using a GPU runtime. A GPU is optional, but the experiments, and the diffusion steps in particular, run significantly faster on one.

In [None]:
# check for a GPU runtime (nvidia-smi fails or is missing on CPU-only runtimes)
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if 'failed' in gpu_info or 'not found' in gpu_info:
  print('Not connected to a GPU')
else:
  print(gpu_info)

# check total RAM
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of total RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')
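If you want the same GPU check without IPython's `!` shell magic (for example in a plain Python script), a minimal standard-library sketch follows. It treats "`nvidia-smi` is on the PATH and exits with status 0" as a rough proxy for a usable NVIDIA GPU; that heuristic is an assumption, and `gpu_available` is a hypothetical helper, not part of this repo.

```python
import shutil
import subprocess


def gpu_available(timeout: float = 30.0) -> bool:
    """Hypothetical helper: rough GPU check, True if `nvidia-smi` exists and exits successfully."""
    exe = shutil.which("nvidia-smi")  # None when the binary is not on the PATH
    if exe is None:
        return False
    try:
        result = subprocess.run([exe], capture_output=True, text=True, timeout=timeout)
    except (OSError, subprocess.TimeoutExpired):
        return False
    # assumption: nvidia-smi exits nonzero when it cannot talk to the driver
    return result.returncode == 0


print("GPU detected" if gpu_available() else "Not connected to a GPU")
```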

### 1.2.1. Installs: segmentation

Run the cell below to install the packages these experiments need into your Colab environment.

In [None]:
# install required libraries
!pip install "ultralytics==8.0.111" "transformers==4.29.2" "timm==0.9.2" "diffusers==0.16.1" "safetensors==0.3.1" "accelerate==0.19.0"
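To confirm after installation that the versions pinned in the pip command above are what is actually installed, a small standard-library sketch is below. It compares exact version strings reported by `importlib.metadata`; `version_mismatches` and `PINNED_VERSIONS` are hypothetical names, not part of the repo.

```python
from importlib import metadata

# Versions pinned in the pip command above
PINNED_VERSIONS = {
    "ultralytics": "8.0.111",
    "transformers": "4.29.2",
    "timm": "0.9.2",
    "diffusers": "0.16.1",
    "safetensors": "0.3.1",
    "accelerate": "0.19.0",
}


def version_mismatches(pins: dict) -> dict:
    """Hypothetical helper: map each package whose installed version differs from its pin
    to the installed version string (None if the package is not installed)."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = installed
    return mismatches


problems = version_mismatches(PINNED_VERSIONS)
print("All pinned versions installed" if not problems else f"Mismatched or missing: {problems}")
```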

Run the cell below to download the helper modules this demo requires.

In [None]:
# download the helper modules (segmenter, diffuser, utilities, main) from GitHub
import urllib.request

def download_file(url, output_path):
    """Download the file at `url` and save it to `output_path`."""
    urllib.request.urlretrieve(url, output_path)

base_url = "https://raw.githubusercontent.com/jermwatt/morphi_lab/object_diffusion_collab_demo"
for module in ["segmenter.py", "diffuser.py", "utilities.py", "main.py"]:
    download_file(f"{base_url}/{module}", f"/content/{module}")
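Re-running the cell above downloads every module again. A minimal sketch of a variant that skips files already present and rejects empty downloads is below; `download_if_missing` is a hypothetical helper, and treating a zero-byte file as a failed download is an assumption about what counts as broken.

```python
import os
import urllib.request


def download_if_missing(url: str, output_path: str, overwrite: bool = False) -> str:
    """Hypothetical helper: download `url` to `output_path` unless a non-empty file is already there."""
    if not overwrite and os.path.isfile(output_path) and os.path.getsize(output_path) > 0:
        return output_path  # keep the existing copy
    urllib.request.urlretrieve(url, output_path)
    if os.path.getsize(output_path) == 0:  # assumption: an empty file means the download failed
        raise RuntimeError(f"Downloaded file is empty: {output_path}")
    return output_path
```

For example, `download_if_missing(url, "/content/segmenter.py")` reuses an existing copy, while `overwrite=True` forces a fresh download to pick up upstream changes.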

Run the cell below to download a sample image to experiment with.

In [None]:
# man holding donut - we'll use this one for testing
url = "https://www.shutterstock.com/image-photo/surprised-young-man-holding-donut-260nw-586330142.jpg"
output_path = "/content/test_donut.png"
download_file(url, output_path)
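The sample is saved as `test_donut.png` even though the URL ends in `.jpg`. Loaders such as PIL identify the format from the file contents rather than the extension, so this usually works. To see what was actually downloaded, a minimal sketch that checks the file's leading "magic" bytes is below; it recognizes only JPEG and PNG, and `sniff_image_format` is a hypothetical helper. A `None` result often means the host returned something else, such as an HTML error page.

```python
from typing import Optional


def sniff_image_format(path: str) -> Optional[str]:
    """Hypothetical helper: return 'jpeg', 'png', or None based on the file's leading bytes."""
    with open(path, "rb") as f:
        header = f.read(8)
    if header.startswith(b"\xff\xd8\xff"):  # JPEG start-of-image marker
        return "jpeg"
    if header == b"\x89PNG\r\n\x1a\n":  # fixed 8-byte PNG signature
        return "png"
    return None
```

Usage: `sniff_image_format("/content/test_donut.png")`.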

### 1.2.2. Installs: Shap-E

In [None]:
!git clone https://github.com/openai/shap-e.git /content/shap-e

In [None]:
# an editable install (pip install -e .) did not seem to work here, so do a regular install
!cd /content/shap-e && pip install .
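To confirm the install put `shap_e` on the import path before moving on, a small standard-library check is sketched below. It looks the module up with `importlib.util.find_spec` without importing it; `is_importable` is a hypothetical helper and only handles top-level module names.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Hypothetical helper: True if a top-level module can be found on sys.path (not imported)."""
    return importlib.util.find_spec(module_name) is not None


print("shap_e importable:", is_importable("shap_e"))
```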

### 1.3.  Segmentation

In [None]:
from segmenter import segment_image, label_lookup_dict

In [None]:
# segment the person (who is holding the donut) out of the test image; edit `labels` to target a different object class
img_path = "/content/test_donut.png"
labels = ['person']
img, mask, seg = segment_image(img_path,
                               labels=labels)

In [None]:
seg.show_result()

In [None]:
# binarize the mask: every nonzero entry becomes 1
mask[mask!=0] = 1

# multiply image and mask elementwise, zeroing out pixels outside the mask
img_masked = seg.orig_img*mask
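The two steps above (binarize the mask, then multiply) rely on NumPy broadcasting, which only works when the mask's shape lines up with the image's. This notebook does not show whether `segment_image` returns a 2-D `H x W` mask or one with a channel axis, so the sketch below accepts either and adds the channel axis when needed. `apply_binary_mask` is a hypothetical helper, not part of `segmenter.py`.

```python
import numpy as np


def apply_binary_mask(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: zero out pixels of an H x W x C image wherever the mask is zero.

    The mask may be H x W or H x W x 1; any nonzero value counts as "keep".
    """
    keep = np.asarray(mask) != 0          # boolean version of mask[mask != 0] = 1
    if keep.ndim == 2:
        keep = keep[..., np.newaxis]      # H x W -> H x W x 1 so it broadcasts over channels
    if keep.shape[:2] != image.shape[:2]:
        raise ValueError(f"mask {keep.shape[:2]} and image {image.shape[:2]} differ in size")
    return image * keep                   # bool * uint8 keeps the image's uint8 dtype
```

Usage: `img_masked = apply_binary_mask(seg.orig_img, mask)`.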

In [None]:
import matplotlib.pyplot as plt
plt.imshow(img_masked.astype(int))
plt.show()
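Shap-E works best when it sees only the object of interest. Beyond zeroing the background, one optional step (our assumption, not something this notebook does) is to crop the masked image to the mask's bounding box and pad it to a square. CLIP-style preprocessing typically resizes the shorter side and center-crops, so a square input keeps the edges of a tall subject from being cut off, assuming Shap-E's image encoder uses that standard preprocessing. `crop_to_mask` and `pad_to_square` are hypothetical helpers.

```python
import numpy as np


def crop_to_mask(image: np.ndarray, mask: np.ndarray, pad: int = 0) -> np.ndarray:
    """Hypothetical helper: crop `image` to the bounding box of the mask's nonzero pixels, plus `pad`."""
    mask2d = np.asarray(mask)
    if mask2d.ndim == 3:
        mask2d = mask2d.any(axis=-1)  # collapse a channel axis: keep a pixel if any channel is set
    ys, xs = np.nonzero(mask2d)
    if ys.size == 0:
        raise ValueError("mask is empty")
    y0 = max(int(ys.min()) - pad, 0)
    x0 = max(int(xs.min()) - pad, 0)
    y1 = min(int(ys.max()) + pad + 1, image.shape[0])
    x1 = min(int(xs.max()) + pad + 1, image.shape[1])
    return image[y0:y1, x0:x1]


def pad_to_square(image: np.ndarray, fill: int = 0) -> np.ndarray:
    """Hypothetical helper: center `image` on a square canvas whose side is max(height, width)."""
    h, w = image.shape[:2]
    side = max(h, w)
    canvas = np.full((side, side) + image.shape[2:], fill, dtype=image.dtype)
    top, left = (side - h) // 2, (side - w) // 2
    canvas[top:top + h, left:left + w] = image
    return canvas
```

Usage: `square = pad_to_square(crop_to_mask(img_masked, mask, pad=10))`, then pass `square` in place of `img_masked`.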

### 1.4.  Feed into Shap-E

In [None]:
import shap_e
import torch
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget
from shap_e.util.image_util import load_image
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# load in shap-e models
xm = load_model('transmitter', device=device)
model = load_model('image300M', device=device)
diffusion = diffusion_from_config(load_config('diffusion'))
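For a rough sense of scale next to the RAM check in section 1.1: the `image300M` name suggests about 300 million parameters, and weights alone take parameters × bytes per parameter (4 bytes in fp32, 2 in fp16). The sketch below does only that arithmetic. It ignores activations and the separate `transmitter` model, so treat it as a lower bound; the 300M figure is read from the model name, not measured, and `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Hypothetical helper: memory for model weights alone, in GB (1 GB = 1e9 bytes, as in the RAM check)."""
    return num_params * bytes_per_param / 1e9


# assumption: ~300M parameters, inferred from the name 'image300M'
for label, nbytes in [("fp32", 4), ("fp16", 2)]:
    print(f"~300M params in {label}: {weight_memory_gb(300e6, nbytes):.1f} GB")
```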

In [None]:
batch_size = 4
guidance_scale = 3.0

# For best results, remove the background so the model sees only the object of interest.
# Here we pass the masked image from section 1.3 rather than the raw photo; to try the
# unmasked photo instead, use: image = load_image("/content/test_donut.png")

latents = sample_latents(
    batch_size=batch_size,
    model=model,
    diffusion=diffusion,
    guidance_scale=guidance_scale,
    model_kwargs=dict(images=[img_masked] * batch_size),
    progress=True,
    clip_denoised=True,
    use_fp16=True,
    use_karras=True,
    karras_steps=64,
    sigma_min=1e-3,
    sigma_max=160,
    s_churn=0,
)
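Two of the arguments above can be made concrete without a GPU. `guidance_scale` is, we assume, the weight in standard classifier-free guidance, which combines the model's conditional and unconditional predictions as `uncond + s * (cond - uncond)`: `s = 1` gives the plain conditional prediction, and larger values push further toward the conditioning image. `use_karras=True` with `karras_steps`, `sigma_min` and `sigma_max` selects noise levels following the schedule of Karras et al. (2022), `sigma_i = (sigma_max**(1/rho) + i/(N-1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho`. The plain-Python sketch below implements both formulas; `rho = 7` is the value proposed in that paper and is our assumption for what Shap-E uses, and both helper names are hypothetical.

```python
from typing import List, Sequence


def classifier_free_guidance(cond: Sequence[float], uncond: Sequence[float], scale: float) -> List[float]:
    """Hypothetical helper: standard classifier-free guidance, uncond + scale * (cond - uncond), elementwise."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


def karras_sigmas(n: int, sigma_min: float, sigma_max: float, rho: float = 7.0) -> List[float]:
    """Hypothetical helper: n noise levels from sigma_max down to sigma_min, spaced as in Karras et al. (2022).

    rho = 7 is the paper's default and an assumption about Shap-E's setting.
    """
    if n < 2:
        return [sigma_max] * n
    max_inv = sigma_max ** (1.0 / rho)
    min_inv = sigma_min ** (1.0 / rho)
    # linear in sigma**(1/rho), then mapped back with **rho
    return [(max_inv + i / (n - 1) * (min_inv - max_inv)) ** rho for i in range(n)]


sigmas = karras_sigmas(64, sigma_min=1e-3, sigma_max=160)  # same settings as the sampling cell above
print(f"first: {sigmas[0]:.1f}, second: {sigmas[1]:.1f}, last: {sigmas[-1]:.4f}")
```

Because `rho > 1` spaces the steps linearly in `sigma**(1/rho)`, the noise level falls quickly at first and the steps become short near `sigma_min`, so most of the 64 steps are spent at low noise.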

In [None]:
render_mode = 'nerf' # you can change this to 'stf' for mesh rendering
size = 256 # this is the size of the renders; higher values take longer to render.

cameras = create_pan_cameras(size, device)
for i, latent in enumerate(latents):
    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)
    display(gif_widget(images))
    break  # render only the first of the batch_size latents; remove this line to render all of them
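`create_pan_cameras(size, device)` builds the ring of viewpoints the GIF pans through. Our reading of Shap-E's source (an assumption, not verified here) is that it places 20 cameras at angles from 0 to 2π inclusive, each 4 units from the origin and tilted to look toward it from slightly above; if so, the first and last views coincide. The standard-library sketch below reproduces that geometry under that assumption. It is not Shap-E's code and returns plain tuples rather than a camera batch; `pan_camera_poses` is a hypothetical helper.

```python
import math
from typing import List, Tuple

Vec3 = Tuple[float, float, float]


def _cross(a: Vec3, b: Vec3) -> Vec3:
    return (a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0])


def pan_camera_poses(num_views: int = 20, distance: float = 4.0,
                     tilt: float = -0.5) -> List[Tuple[Vec3, Vec3, Vec3, Vec3]]:
    """Hypothetical helper: (origin, x_axis, y_axis, z_axis) per camera on a ring around the origin.

    z_axis is the viewing direction (toward the origin); x_axis and y_axis span the image plane.
    Parameter defaults are assumptions meant to mirror our reading of create_pan_cameras.
    """
    poses = []
    for i in range(num_views):
        theta = 2 * math.pi * i / max(num_views - 1, 1)  # like np.linspace(0, 2*pi, num_views)
        raw = (math.sin(theta), math.cos(theta), tilt)   # negative z component: looking slightly down
        norm = math.sqrt(sum(c * c for c in raw))
        z_axis = (raw[0] / norm, raw[1] / norm, raw[2] / norm)
        origin = (-distance * z_axis[0], -distance * z_axis[1], -distance * z_axis[2])
        x_axis = (math.cos(theta), -math.sin(theta), 0.0)  # horizontal, perpendicular to z_axis
        y_axis = _cross(z_axis, x_axis)                    # completes the camera frame
        poses.append((origin, x_axis, y_axis, z_axis))
    return poses
```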