In [None]:
import os
import warnings

import PIL
import PIL.Image  # `import PIL` alone does not load the Image submodule used below
import torch
from diffusers import DPMSolverMultistepScheduler
from tqdm.auto import tqdm

from tree_ring import InversableStableDiffusionPipeline
from utils import to_tensor

# Suppress all Python warnings for a quieter notebook.
# NOTE(review): this filter is process-wide, so it also hides warnings raised
# by our own code and any deprecation notices from the libraries used below.
warnings.filterwarnings("ignore")

# Load the diffusion model used to invert images back to initial latents.
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is only dependable on GPU: many CPU kernels have no half-precision
# implementation, so the CPU fallback must use float32 or inference fails.
model_dtype = torch.float16 if device == "cuda" else torch.float32

model_id = "stabilityai/stable-diffusion-2-1-base"
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = InversableStableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=model_dtype,
    # NOTE(review): newer diffusers releases prefer `variant="fp16"` over
    # `revision="fp16"`; the deprecation warning is hidden by the filter above.
    revision="fp16",
)
pipe = pipe.to(device)

num_inference_steps = 50
# At detection time the original generation prompt is assumed unknown,
# so the inversion is conditioned on the empty prompt.
tester_prompt = ""
text_embeddings = pipe.get_text_embedding(tester_prompt)

In [None]:
# Directory of attacked images to test. It can be overridden through the
# environment so the notebook also runs outside the original cluster; the
# default is the original hard-coded location.
image_dir = os.environ.get(
    "TREE_RING_IMAGE_DIR",
    "/fs/nexus-projects/HuangWM/datasets/attacked/diffusiondb/distortion_single_rotation-9-tree_ring",
)
image_name = "0.png"
# Passed straight to forward_diffusion. Presumably a value of 1 disables
# classifier-free guidance (diffusers convention: CFG only when > 1) — confirm.
guidance_scale = 1

# Open the file in a context manager so its handle is released. `.convert`
# forces pixel loading and 3-channel RGB: PNGs may be stored as RGBA, L or P,
# while Stable Diffusion's VAE encoder takes 3 input channels.
with PIL.Image.open(os.path.join(image_dir, image_name)) as raw_image:
    image = raw_image.convert("RGB")

# Match the pipeline's dtype and device in a single transfer.
image_transformed = to_tensor([image]).to(device=device, dtype=text_embeddings.dtype)
# sample=False: presumably takes the latent distribution's mean rather than a
# random sample, which keeps the inversion deterministic — confirm in tree_ring.
image_latents = pipe.get_image_latents(image_transformed, sample=False)

# Run the diffusion process forward from the image latents to recover the
# (approximate) initial noise latents that carry the watermark.
reversed_latents = pipe.forward_diffusion(
    latents=image_latents,
    text_embeddings=text_embeddings,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
)