# Inference times of models trained from scratch

In [1]:
import os
# Expose only GPU index 2 to this process. This is set before torch is
# imported so the restriction is already in place when CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
import time
import numpy as np
from diffusers import DDPMPipeline

# Run on the visible GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def measure_time(pipe_name, model_id, n_samples=10):
    """Benchmark single-image sampling latency of a DDPM pipeline.

    Loads ``model_id`` through ``DDPMPipeline.from_pretrained``, moves the
    pipeline to the module-level ``device``, generates ``n_samples`` images
    one at a time and prints the mean and standard deviation of the
    per-image wall-clock time.

    Args:
        pipe_name: Human-readable label used in the printed report.
        model_id: Identifier handed to ``DDPMPipeline.from_pretrained``.
        n_samples: Number of timed generations; must be at least 1.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1 (the statistics would
            otherwise be NaN).
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    print(f"\nEvaluating {pipe_name}")
    pipe = DDPMPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
    pipe.to(device)
    on_gpu = device.type == "cuda"

    times = []
    for i in range(n_samples):
        # CUDA kernels are launched asynchronously: drain pending work before
        # starting the clock and again before stopping it so the interval
        # covers exactly one generation.
        if on_gpu:
            torch.cuda.synchronize(device)
        # perf_counter is monotonic, unlike time.time(), so clock adjustments
        # during a multi-minute run cannot corrupt a sample.
        start = time.perf_counter()
        img = pipe(batch_size=1).images
        if on_gpu:
            torch.cuda.synchronize(device)
        times.append(time.perf_counter() - start)
        if i == 0:
            print(f"Generated image shape: {img[0].size}")

    mean_time = np.mean(times)
    std_time = np.std(times)

    print(f"{pipe_name} | Mean: {mean_time:.3f} s | Std: {std_time:.3f} s")

    # Drop the pipeline so consecutive benchmarks do not accumulate models in
    # GPU memory.
    del pipe
    if on_gpu:
        torch.cuda.empty_cache()

# Benchmark every from-scratch checkpoint, one after another.
for _label, _repo_id in (
    ("DDPM 64", "benetraco/brain_ddpm_64"),
    ("DDPM 128", "benetraco/brain_ddpm_128"),
    ("DDPM 256", "benetraco/brain_ddpm_256"),
    ("Latent DDPM", "benetraco/latent_scratch"),
):
    measure_time(_label, _repo_id)


2025-06-09 12:03:04.867864: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749463384.891989   73719 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749463384.899196   73719 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1749463384.918677   73719 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1749463384.918707   73719 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1749463384.918709   73719 computation_placer.cc:177] computation placer alr


Evaluating DDPM 64


Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generated image shape: (64, 64)


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

DDPM 64 | Mean: 28.619 s | Std: 0.566 s

Evaluating DDPM 128


Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generated image shape: (128, 128)


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

DDPM 128 | Mean: 31.880 s | Std: 0.217 s

Evaluating DDPM 256


Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generated image shape: (256, 256)


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

DDPM 256 | Mean: 67.942 s | Std: 0.051 s

Evaluating Latent DDPM


Loading pipeline components...:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Generated image shape: (32, 32)


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Latent DDPM | Mean: 27.853 s | Std: 0.261 s


# Generate samples with different guidance values (SD fine-tuned model)

In [None]:
import os
# Expose only GPU index 2 to this process.
# NOTE(review): this only takes effect if CUDA has not been initialised yet
# in the running interpreter — confirm when re-running in an existing kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import torch
import gc
from tqdm import tqdm
from diffusers import StableDiffusionPipeline
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path

class GuidedSampler:
    """Sample a Stable Diffusion checkpoint at several guidance scales.

    The sampler loads a ``StableDiffusionPipeline``, keeps handles to its
    components and runs a hand-written classifier-free-guidance denoising
    loop so that the same starting noise can be denoised with different
    guidance values and compared side by side.
    """

    def __init__(self, model_id="benetraco/latent_finetuning", resolution=32,
                 num_inference_steps=999, device="cuda", seed=17844):
        """Load the pipeline and prepare the scheduler.

        Args:
            model_id: Identifier passed to
                ``StableDiffusionPipeline.from_pretrained``.
            resolution: Spatial size (height and width) of the latent tensor.
            num_inference_steps: Number of steps given to
                ``scheduler.set_timesteps``.
            device: Device the pipeline is moved to.
            seed: Seed given to ``torch.manual_seed``; the generator it
                returns is used to draw the initial latents.
        """
        self.device = device
        self.seed = seed
        self.resolution = resolution
        self.generator = torch.manual_seed(seed)
        self.num_inference_steps = num_inference_steps

        self.pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
        self.pipe.to(device)

        self.unet = self.pipe.unet
        self.vae = self.pipe.vae
        self.scheduler = self.pipe.scheduler
        self.tokenizer = self.pipe.tokenizer
        self.text_encoder = self.pipe.text_encoder

        self.scheduler.set_timesteps(self.num_inference_steps)

    def _get_embeddings(self, prompt):
        """Return the text encoder's last hidden state for ``prompt``.

        The prompt is tokenized with padding to 77 tokens and the token
        tensors are moved to ``self.device`` before encoding.
        """
        tokens = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
        tokens = {k: v.to(self.device) for k, v in tokens.items()}
        with torch.no_grad():
            return self.text_encoder(**tokens).last_hidden_state

    def _denoise(self, latent, text_emb, uncond_emb, guidance):
        """Denoise a copy of ``latent`` with classifier-free guidance.

        Args:
            latent: Initial noise, already multiplied by
                ``scheduler.init_noise_sigma``. It is not modified.
            text_emb: Embedding of the conditioning prompt.
            uncond_emb: Embedding of the empty prompt.
            guidance: Weight ``g`` in
                ``noise_uncond + g * (noise_text - noise_uncond)``.

        Returns:
            The latent after the last scheduler step.
        """
        # Start every trajectory from a freshly initialised schedule so that
        # any per-trajectory state a scheduler keeps (multistep history, step
        # counters) does not leak from the previous guidance value.
        self.scheduler.set_timesteps(self.num_inference_steps)

        lat = latent.clone()
        for t in self.scheduler.timesteps:
            # Only the UNet input is scaled; scheduler.step must receive the
            # unscaled sample it produced on the previous iteration.
            model_input = self.scheduler.scale_model_input(lat, t)
            with torch.no_grad():
                noise_uncond = self.unet(model_input, t, encoder_hidden_states=uncond_emb).sample
                noise_text = self.unet(model_input, t, encoder_hidden_states=text_emb).sample
                noise = noise_uncond + guidance * (noise_text - noise_uncond)
            lat = self.scheduler.step(noise, t, lat).prev_sample
        return lat

    def _decode(self, lat):
        """Decode a latent with the VAE and map the result into ``[0, 1]``."""
        with torch.no_grad():
            decoded = self.vae.decode(lat / self.vae.config.scaling_factor).sample.cpu()
        # Shift from [-1, 1] to [0, 1] and clip anything outside that range.
        return ((decoded + 1.0) / 2.0).clamp(0, 1)

    def visualize_guidance_effect(self, prompts, guidance_values, output_dir="guidance_effect"):
        """Render every prompt at every guidance value and save comparisons.

        For each prompt a single initial latent is drawn from
        ``self.generator`` and reused for all guidance values. Each decoded
        image is written to ``<output_dir>/<tag>_g<g>.png`` and a one-row
        comparison figure to ``<output_dir>/<tag>_comparison.png``, where
        ``tag`` is the prompt with spaces replaced by underscores.

        Args:
            prompts: Iterable of prompt strings.
            guidance_values: Non-empty sequence of guidance weights.
            output_dir: Directory for the images; created if missing.

        Raises:
            ValueError: If ``guidance_values`` is empty.
        """
        # Fail before any model work; an empty grid would otherwise only be
        # rejected later by matplotlib.
        if len(guidance_values) == 0:
            raise ValueError("guidance_values must contain at least one guidance scale")

        os.makedirs(output_dir, exist_ok=True)
        uncond_emb = self._get_embeddings("")

        for prompt in prompts:
            text_emb = self._get_embeddings(prompt)
            tag = prompt.replace(" ", "_")

            latent = torch.randn(1, 4, self.resolution, self.resolution, generator=self.generator).to(self.device)
            latent *= self.scheduler.init_noise_sigma

            # squeeze=False keeps a 2-D axes array even for a single column,
            # so indexing works for one guidance value as well.
            fig, axes = plt.subplots(1, len(guidance_values),
                                     figsize=(3 * len(guidance_values), 3), squeeze=False)
            fig.suptitle(f"Prompt: {prompt}", fontsize=12)

            try:
                for ax, g in zip(axes[0], guidance_values):
                    decoded = self._decode(self._denoise(latent, text_emb, uncond_emb, g))

                    image_path = os.path.join(output_dir, f"{tag}_g{g}.png")
                    save_image(decoded[0], image_path)

                    # Display exactly what was written to disk; the context
                    # manager releases the file handle right away.
                    with Image.open(image_path) as img:
                        ax.imshow(img, cmap="gray")
                    ax.axis("off")
                    ax.set_title(f"g={g}")

                    del decoded
                    gc.collect()
                    torch.cuda.empty_cache()

                fig.tight_layout()
                fig.subplots_adjust(top=0.85)
                fig_path = os.path.join(output_dir, f"{tag}_comparison.png")
                fig.savefig(fig_path)
            finally:
                # Close the figure even when sampling or saving fails so
                # figures do not pile up in memory.
                plt.close(fig)
            print(f"Saved: {fig_path}")

# Demo: compare guidance strengths for one prompt per scanner vendor.
if __name__ == "__main__":
    scanner_prompts = ["Philips FLAIR MRI", "Siemens FLAIR MRI", "GE FLAIR MRI"]
    guidance_scales = [0.0, 1.0, 2.0, 3.0, 5.0]

    sampler = GuidedSampler(
        model_id="benetraco/latent_finetuning_scanners_healthy",
        resolution=32,
        num_inference_steps=999,
    )
    sampler.visualize_guidance_effect(
        prompts=scanner_prompts,
        guidance_values=guidance_scales,
        output_dir="guidance_effect_visualization",
    )


2025-06-09 15:55:25.283917: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749477325.309338  156432 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749477325.317128  156432 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1749477325.337947  156432 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1749477325.337986  156432 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1749477325.337989  156432 computation_placer.cc:177] computation placer alr

Fetching 14 files:   0%|          | 0/14 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]