In [1]:
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import sys

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Make the vendored score-SDE likelihood code (the `sde_lib` / `likelihood`
# modules imported further down) importable from this notebook.
# NOTE(review): absolute, machine-specific path — adjust for your checkout.
SONG_LIKELIHOOD_DIR = '/home/claserken/Developer/dataunlearning/metrics/song_likelihood'
# Guard the append so re-running this cell does not pile up duplicate entries.
if SONG_LIKELIHOOD_DIR not in sys.path:
    sys.path.append(SONG_LIKELIHOOD_DIR)

In [3]:
# Pretrained DDPM pipeline from the Hugging Face Hub (CelebA-HQ, 256 px).
model_id = "google/ddpm-celebahq-256"

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the pipeline (its scheduler is inspected below) and move it to `device`.
# DDIMPipeline or PNDMPipeline can be swapped in here for faster inference.
ddpm = DDPMPipeline.from_pretrained(model_id).to(device)

diffusion_pytorch_model.safetensors not found
Loading pipeline components...: 100%|██████████| 2/2 [00:00<00:00,  4.76it/s]


In [4]:
def load_image(image_path, target_device=None):
    """Load an image file as a normalized tensor for the diffusion model.

    The image is forced to 3-channel RGB and converted by ``ToTensor`` to a
    float tensor of shape (3, H, W) with values in [0, 1]. It is then mapped to
    [-1, 1] via ``(x - 0.5) / 0.5``.

    Args:
        image_path: Path to the image file.
        target_device: Device to move the tensor to. Defaults to the
            notebook-level ``device`` (the original behaviour).

    Returns:
        torch.Tensor of shape (3, H, W) on ``target_device``, values in [-1, 1].
    """
    if target_device is None:
        target_device = device  # notebook-wide device chosen in the model-loading cell

    transform = transforms.Compose([
        transforms.ToTensor(),                        # PIL (H, W, C) uint8 -> float (C, H, W) in [0, 1]
        transforms.Normalize(mean=[0.5], std=[0.5]),  # [0, 1] -> [-1, 1], broadcast over all channels
    ])

    # Context manager closes the underlying file handle. Converting to RGB makes
    # grayscale / RGBA / palette files still yield exactly 3 channels.
    with Image.open(image_path) as img:
        image = transform(img.convert('RGB'))
    return image.to(target_device)

In [25]:
# Load one CelebA-HQ 256 example image as a normalized tensor in [-1, 1].
# NOTE(review): relative path — resolved against the kernel's working directory.
train_img = load_image('../data/examples/celeba_hq_256/10000.jpg')

In [6]:
# The pipeline's noise scheduler; its config is displayed below to check that
# it lines up with the VP-SDE used for likelihood evaluation.
scheduler = ddpm.scheduler

In [7]:
import sde_lib
from likelihood import get_likelihood_fn

In [8]:
# Show the DDPM scheduler config. Its linear schedule (beta_start=0.0001,
# beta_end=0.02, num_train_timesteps=1000) scales to a continuous VP-SDE with
# beta_min = 1000 * 0.0001 = 0.1 and beta_max = 1000 * 0.02 = 20.
# NOTE(review): presumably these equal sde_lib.VPSDE's defaults
# (beta_min=0.1, beta_max=20, N=1000) — confirm in sde_lib.py.
scheduler # matches default VPSDE!

DDPMScheduler {
  "_class_name": "DDPMScheduler",
  "_diffusers_version": "0.27.2",
  "beta_end": 0.02,
  "beta_schedule": "linear",
  "beta_start": 0.0001,
  "clip_sample": true,
  "clip_sample_range": 1.0,
  "dynamic_thresholding_ratio": 0.995,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "rescale_betas_zero_snr": false,
  "sample_max_value": 1.0,
  "steps_offset": 0,
  "thresholding": false,
  "timestep_spacing": "leading",
  "trained_betas": null,
  "variance_type": "fixed_small"
}

In [9]:
# Build the VP-SDE from the checkpoint's own DDPM noise schedule instead of
# relying on sde_lib's defaults happening to match it. A linear discrete
# schedule beta_i in [beta_start, beta_end] over N steps corresponds to a
# continuous beta(t) running from N * beta_start to N * beta_end.
# For this checkpoint (0.0001, 0.02, 1000) this gives beta_min=0.1,
# beta_max=20, N=1000.
assert scheduler.config.beta_schedule == "linear", "VPSDE assumes a linear beta schedule"
n_train_steps = scheduler.config.num_train_timesteps
sde = sde_lib.VPSDE(
    beta_min=scheduler.config.beta_start * n_train_steps,
    beta_max=scheduler.config.beta_end * n_train_steps,
    N=n_train_steps,
)

In [11]:
# Build the likelihood evaluator for this SDE; it is called below as
# likelihood_fn(model, batch).
# NOTE(review): judging by the outputs it returns a tuple whose first element
# is a per-sample scalar (likely bits/dim) — confirm in likelihood.py.
likelihood_fn = get_likelihood_fn(sde)

In [19]:
from diffusers import UNet2DModel

# Fine-tuned ("deletion") UNet checkpoint whose likelihoods are evaluated below.
# NOTE(review): absolute, machine-specific path — adjust for your checkout.
UNET_CHECKPOINT_DIR = '/home/claserken/Developer/dataunlearning/checkpoints/celeb/deletion/2024-08-30_04-04-45/checkpoint-20/unet'
# Use the notebook-wide `device` (CUDA if available, else CPU) rather than a
# hard-coded 'cuda', so the notebook also runs on CPU-only machines.
model = UNet2DModel.from_pretrained(UNET_CHECKPOINT_DIR).to(device)

In [22]:
# Seed the global RNGs (CPU and CUDA) so the random probe input below is
# reproducible under Restart & Run All. Any randomness inside likelihood_fn
# is presumably covered as well.
torch.manual_seed(0)
# Likelihood of pure Gaussian noise shaped like one 3x256x256 image, to compare
# against the real image evaluated further down. Created on the shared `device`
# instead of a hard-coded 'cuda'.
x = likelihood_fn(model, torch.randn(1, 3, 256, 256, device=device))

In [23]:
# NOTE(review): displays the whole result tuple, including an image-shaped
# tensor; consider showing only x[0] (the per-sample score, presumably bits/dim).
x

(tensor([9.1719], device='cuda:0'),
 tensor([[[[-1.9878, -0.1042,  0.2928,  ..., -0.6378, -0.5451, -0.1143],
           [-0.0117, -0.2000,  0.2598,  ..., -1.2336, -0.0825, -1.0213],
           [ 0.4781,  0.7857,  0.8542,  ...,  2.1790,  2.5461,  0.3577],
           ...,
           [-0.9474, -1.2104,  0.9481,  ..., -0.1677, -1.3332,  0.4809],
           [-0.1778, -0.4073, -2.5542,  ..., -1.0388,  1.2164,  0.4333],
           [-1.0357,  1.7204, -2.4044,  ...,  0.3515, -0.1471, -1.4446]],
 
          [[-0.8550, -0.4344,  0.9014,  ..., -0.2972,  0.1901,  0.2502],
           [-2.1319,  0.7314,  0.4469,  ..., -0.1239, -0.1865, -0.1318],
           [-0.6346,  1.5013,  0.3284,  ..., -0.6705, -0.9352,  0.7832],
           ...,
           [ 1.3237, -1.3118, -0.1264,  ...,  0.5369, -0.1172, -0.7119],
           [ 0.5856, -0.1026,  0.7952,  ...,  0.8756,  0.1312,  0.7500],
           [-2.0602, -0.4493,  1.8116,  ..., -0.3420, -0.8863,  0.1094]],
 
          [[ 1.5364, -0.9941,  0.7757,  ..., -0.50

In [26]:
# Same evaluation on the real image. unsqueeze(0) adds a leading batch
# dimension, matching the (1, 3, 256, 256) noise input used for `x`.
y = likelihood_fn(model, train_img.unsqueeze(0))

In [27]:
# NOTE(review): displays the whole result tuple; y[0] (presumably bits/dim)
# is the number to compare against x[0].
y

(tensor([1.2540], device='cuda:0'),
 tensor([[[[-1.2867, -0.8036, -0.8242,  ..., -0.5800,  0.4676, -0.0913],
           [ 0.3851,  1.2567,  1.0880,  ...,  1.2021,  1.1049,  1.5980],
           [-0.1945,  0.3565,  0.0395,  ...,  0.5469,  0.7246,  1.6694],
           ...,
           [ 0.5186,  0.8699,  0.4212,  ...,  0.1673,  0.9407, -0.5756],
           [-0.0361,  0.1206,  0.0865,  ...,  0.0918,  1.4830,  0.9634],
           [-1.5456, -0.5031, -0.0382,  ..., -0.4738, -0.2906,  0.5479]],
 
          [[-0.3052,  0.6698,  1.3276,  ...,  0.0816, -0.0766, -1.0778],
           [ 0.6772,  1.4056,  1.4685,  ...,  0.7519,  1.0482, -0.5629],
           [ 0.3534,  1.1467,  0.9835,  ...,  0.4097,  0.8647, -0.1984],
           ...,
           [ 0.3395,  0.8517,  0.5164,  ..., -0.7547, -2.2606, -0.4302],
           [-0.5895, -0.0075, -0.0501,  ...,  0.4433, -0.3146,  0.5051],
           [-0.7600,  0.7295,  0.9603,  ...,  0.7245,  1.0796, -0.0232]],
 
          [[ 0.1133, -0.4922,  0.5320,  ..., -0.25