In [2]:
import gc
import os
# Expose only GPU 0 to this process. The assignment sits before `import torch`
# so the setting is already in place when torch is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# SECURITY: a Weights & Biases API key used to be hard-coded on this line.
# Credentials must not live in a notebook: provide WANDB_API_KEY through the
# shell environment (or `wandb login`) before starting the kernel. No visible
# cell in this notebook reads the key. The key that was committed here should
# be treated as leaked and revoked.
import torch
import random
from tqdm import tqdm
from lora import LoRANetwork
from diffusers.utils import export_to_gif
from shap_e.diffusion.sample import sample_latents
from shap_e.models.download import load_model, load_config
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images

In [10]:
# --- Run selection -----------------------------------------------------------
# Name of the training run; used to build both the checkpoint path and the
# directory the GIFs are written to.
name = "armsslider_2024-03-23 13:46:14.541733"
lora_weight = f"/home/noamatia/repos/spic-e/outputs/{name}/model_final.pt"

# --- LoRA network ------------------------------------------------------------
rank = 4        # passed to LoRANetwork(...)
alpha = 1.0     # passed to LoRANetwork(...)
scales = [-1, 1]  # each value is handed to network.set_lora_slider(...)

# --- Sampling ----------------------------------------------------------------
n = 5                   # number of outer-loop iterations (one x_T per iteration)
prompt = ""             # forwarded as texts=[prompt]
guidance_scale = 7.5    # forwarded to sample_latents(...)
cond_drop_prob = 0.5    # assigned to model.wrapped.cond_drop_prob
sigma_max = 160         # forwarded to sample_latents(...) and used to scale x_T

# --- Rendering ---------------------------------------------------------------
size = 160              # forwarded to create_pan_cameras(...)
render_mode = "nerf"    # forwarded to decode_latent_images(rendering_mode=...)

# --- Output location (relative to the current working directory) -------------
output_dir = os.path.join('outputs', name, 'test')
os.makedirs(output_dir, exist_ok=True)

def flush(*args):
    """Run the garbage collector, then release PyTorch's cached CUDA memory.

    Args:
        *args: Accepted for backward compatibility only. A function cannot
            unbind the caller's variables, so passing objects here does not
            free them; the caller must ``del`` its own names (or let them go
            out of scope) before calling ``flush()``.

    Returns:
        None.
    """
    # Drop this frame's references to the arguments so that flush() itself
    # never keeps them alive during the collection below.
    del args
    # Collect first: objects that were only reachable through reference cycles
    # are freed here, so the empty_cache() call below can release their blocks.
    gc.collect()
    torch.cuda.empty_cache()

In [11]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `xm` is the model later handed to decode_latent_images(); `model` is the
# text-conditional model later handed to sample_latents().
xm = load_model('transmitter', device=device)
model = load_model('text300M', device=device)
model.wrapped.cond_drop_prob = cond_drop_prob
model.freeze_all_parameters()

# Wrap the base model with the LoRA network and restore the trained weights.
# map_location keeps the load working when the checkpoint was saved from a
# different device than the one selected above (e.g. CUDA checkpoint, CPU run).
network = LoRANetwork(model.wrapped, rank, alpha).to(device)
network.load_state_dict(torch.load(lora_weight, map_location=device))

diffusion = diffusion_from_config(load_config('diffusion'))
test_model_kwargs = dict(texts=[prompt])  # conditioning forwarded to sample_latents()
cameras = create_pan_cameras(size, device)

# For each of the n samples: draw one initial noise tensor, then sample and
# render it once per LoRA slider scale so the scales can be compared on the
# same starting noise. One GIF is written per (sample, scale) pair.
for i in tqdm(range(n)):
    # Draw a seed per sample and actually use it: the initial noise comes from
    # a generator seeded with it, and the seed is logged so that a given x_T
    # can be regenerated later.
    seed = random.randint(0, 5000)
    tqdm.write(f"sample {i}: seed={seed}")
    generator = torch.Generator(device=device).manual_seed(seed)
    x_T = torch.randn((1, model.d_latent), generator=generator, device=device) * sigma_max

    for scale in scales:
        network.set_lora_slider(scale)
        # `network` is used as a context manager around sampling; gradients
        # are not needed for inference.
        with network, torch.no_grad():
            test_latents = sample_latents(
                device=device,
                batch_size=1,
                model=model,
                diffusion=diffusion,
                guidance_scale=guidance_scale,
                model_kwargs=test_model_kwargs,
                clip_denoised=True,
                use_fp16=True,
                use_karras=True,
                karras_steps=64,
                sigma_min=1e-3,
                sigma_max=sigma_max,
                s_churn=0,
                progress=True,
                x_T=x_T,
            )
        images = decode_latent_images(xm, test_latents[0], cameras, rendering_mode=render_mode)
        result_path = os.path.join(output_dir, f'{i}_{scale}.gif')
        export_to_gif(images, result_path)

        # Drop our own references before flushing: passing an object to
        # flush() cannot unbind the names held in this scope.
        del test_latents, images
        flush()

create LoRA for SplitVectorDiffusion: 245 modules.


  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

 20%|██        | 1/5 [04:58<19:52, 298.23s/it]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

 40%|████      | 2/5 [09:52<14:48, 296.14s/it]

  0%|          | 0/64 [00:00<?, ?it/s]

 40%|████      | 2/5 [10:11<15:17, 305.95s/it]


KeyboardInterrupt: 