In [2]:
%load_ext autoreload
%autoreload 2


import sys
sys.path.append('../')

In [3]:
from diffusion_adapters import StableDiffusionPipelineAdapterEmbeddings, StableDiffusionXLPipelineAdapterEmbeddings
from prior_models import TransformerEmbeddingDiffusionModelv2
from eval_utils import compute_embeddings
from utils import map_embeddings_to_ratings, split_recommender_data, store_eval_images, store_eval_images_per_user

from diffusers import DDPMScheduler
import pandas as pd
import torch
import os
from diffusers.utils import load_image
from tqdm import tqdm

from diffusion_adapters import StableDiffusionPipelineAdapterEmbeddings
from sampling import sample_from_diffusion

In [4]:
noise_scheduler = DDPMScheduler(num_train_timesteps=6000)
device = "cuda"
users = 94
images_per_user = 30

In [30]:
diffusion_prior_model = TransformerEmbeddingDiffusionModelv2(
    img_embed_dim=1024,
    num_users=94,    # So user embedding covers your entire user set
    n_heads=16,
    num_tokens=1,
    num_user_tokens=4,
    num_layers=8,
    dim_feedforward=2048,
    whether_use_user_embeddings=True
).to(device)

savepath = "../data/flickr/evaluation/diffusion_priors/models/weights/sd15_nl8_heads16_dim_feedforward2048_lr0.0001_it1_ut4_adamw_reduce_on_plateau_bs64_nslinear_spu80_timesteps6000_objnoise-pred_useueTrue.pth"
llm_savepath = "../data/flickr/evaluation/baselines/llm_profiling/generated_images/per_user/"
T2_savepath = "../data/flickr/evaluation/baselines/t2_prompt/"
gt_savepath = "../data/flickr/evaluation/ground_truth/usrthrs_100/liked_per_user"

diffusion_prior_model.load_state_dict(torch.load(savepath, weights_only=True))
diffusion_prior_model.eval()


model_id = "runwayml/stable-diffusion-v1-5"                                                                                                                                                                                                                
pipe = StableDiffusionPipelineAdapterEmbeddings.from_pretrained(model_id).to("cuda")
pipe.safety_checker = None
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") 

t2_embeddings = torch.load(f"{T2_savepath}/embeddings/sd15_embeddings.pth")



C:\Users\Gabriel\.cache\huggingface\hub\models--runwayml--stable-diffusion-v1-5\snapshots\451f4fe16113bff5a5d2269ed5ad43b0592e9a14


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [31]:
import numpy as np
def compute_precision_recall_at_k(
    generated,  # (n, latent_dim)
    encoded_real,  # (N, latent_dim)
    R,  # (N, num_users)
    top_k=1
):
    from sklearn.metrics.pairwise import cosine_similarity
    sims = cosine_similarity(generated.cpu().numpy(), encoded_real.cpu().numpy())  # (n, N)
    topk_indices = sims.argsort(axis=1)[:, -top_k:]  # (N, n)
    recommended_items = np.unique(topk_indices.flatten())
    liked_items = np.where(R == 1)[0]
    relevant_retrieved = np.intersect1d(recommended_items, liked_items)
    precision = len(relevant_retrieved) / len(recommended_items)
    recall = len(relevant_retrieved) / len(liked_items)
    return precision, recall

In [36]:
torch.manual_seed(0)

data = dict.fromkeys(range(users))
for user_idx in tqdm(range(users)):
    user_dict = dict.fromkeys(["rebeca_embeddings", "llm_embeddings", "T2"])
    like = 1
    score_tensor = torch.tensor(like).expand(images_per_user).long().to(device)
    user_tensor = torch.tensor(user_idx).expand(images_per_user).to(device)
    user_ids_uncond_tensor = torch.full_like(user_tensor, fill_value = 94).to(device)
    score_uncond_tensor = torch.full_like(score_tensor, fill_value = 2).to(device)

    sampled_img_embs = sample_from_diffusion(
                model=diffusion_prior_model,
                user_ids_cond=user_tensor,
                scores_cond=score_tensor,
                user_ids_uncond=user_ids_uncond_tensor,
                scores_uncond=score_uncond_tensor,
                img_embedding_size=1024,
                scheduler=noise_scheduler,
                guidance_scale=10,
                prediction_type="epsilon",
                device="cuda",
            ).detach()
    
    llm_path_images = f"{llm_savepath}/{user_idx}"
    llm_embs = []
    with torch.no_grad():
        for unq_img_path in tqdm(os.listdir(llm_path_images)):
            path = os.path.join(llm_path_images, unq_img_path)
            pil_image = load_image(path)
            image_emb = pipe.encode_image(pil_image, device="cuda", num_images_per_prompt=1)[0].squeeze()
            llm_embs.append(image_emb.cpu())
    llm_embs_tensor = torch.stack(llm_embs)[torch.randperm(50)[:images_per_user]]

    user_dict["rebeca_embeddings"] = sampled_img_embs.detach().cpu()
    user_dict["llm_embeddings"] = llm_embs_tensor
    user_dict["T2"] = t2_embeddings[torch.randperm(t2_embeddings.size(0))[:images_per_user]]
    data[user_idx]  = user_dict

100%|██████████| 50/50 [00:02<00:00, 21.85it/s]
100%|██████████| 50/50 [00:02<00:00, 20.83it/s]]
100%|██████████| 50/50 [00:02<00:00, 20.67it/s]]
100%|██████████| 50/50 [00:02<00:00, 21.43it/s]]
100%|██████████| 50/50 [00:02<00:00, 21.22it/s]]
100%|██████████| 50/50 [00:02<00:00, 20.98it/s]]
100%|██████████| 50/50 [00:02<00:00, 20.48it/s]]
100%|██████████| 50/50 [00:02<00:00, 20.92it/s]]
100%|██████████| 50/50 [00:02<00:00, 21.43it/s]]
100%|██████████| 50/50 [00:02<00:00, 21.23it/s]]
100%|██████████| 50/50 [00:02<00:00, 20.71it/s]t]
100%|██████████| 50/50 [00:02<00:00, 21.65it/s]t]
100%|██████████| 50/50 [00:02<00:00, 20.98it/s]t]
100%|██████████| 50/50 [00:02<00:00, 21.22it/s]t]
100%|██████████| 50/50 [00:02<00:00, 21.34it/s]t]
100%|██████████| 50/50 [00:02<00:00, 20.26it/s]t]
100%|██████████| 50/50 [00:02<00:00, 22.05it/s]t]
100%|██████████| 50/50 [00:02<00:00, 21.80it/s]t]
100%|██████████| 50/50 [00:02<00:00, 21.27it/s]t]
100%|██████████| 50/50 [00:02<00:00, 21.07it/s]t]
100%|██████

In [35]:
data[0]["llm_embeddings"].shape

torch.Size([50, 1024])

In [None]:


for user in range(94):

    user_gt_embs = torch.load(f"../data/flickr/evaluation/ground_truth/usrthrs_100/users/{user}/embeddings.pt")
    user_gt_ratings = torch.load(f"../data/flickr/evaluation/ground_truth/usrthrs_100/users/{user}/ratings.pt")

    rebeca_prec, rebeca_rec = compute_precision_recall_at_k(
        generated=data[user]["rebeca_embeddings"],
        encoded_real=user_gt_embs,
        R=user_gt_ratings
    )

    llm_prec, llm_rec = compute_precision_recall_at_k(
        generated=data[user]["llm_embeddings"],
        encoded_real=user_gt_embs,
        R=user_gt_ratings
    )

    t2_prec, t2_rec = compute_precision_recall_at_k(
        generated=data[user]["T2"],
        encoded_real=user_gt_embs,
        R=user_gt_ratings
    )

    print("Precisions: ",rebeca_prec, llm_prec, t2_prec)
    print("Recalls: ", rebeca_rec, llm_rec, t2_rec)

data[user]["rebeca_embeddings"]

Precisions:  0.125 0.0 0.2
Recalls:  1.0 0.0 1.0
Precisions:  0.7142857142857143 0.6 0.6
Recalls:  0.7142857142857143 0.42857142857142855 0.42857142857142855
Precisions:  0.25 0.42857142857142855 0.3333333333333333
Recalls:  0.6666666666666666 1.0 0.6666666666666666
Precisions:  0.4444444444444444 1.0 0.75
Recalls:  1.0 0.5 0.75
Precisions:  0.125 0.0 0.0
Recalls:  1.0 0.0 0.0
Precisions:  0.25 0.0 0.3333333333333333
Recalls:  0.6666666666666666 0.0 0.6666666666666666
Precisions:  0.42857142857142855 1.0 0.6666666666666666
Recalls:  0.6 0.8 0.4
Precisions:  0.4444444444444444 0.25 0.3333333333333333
Recalls:  1.0 0.25 0.5
Precisions:  0.5 0.3333333333333333 0.2857142857142857
Recalls:  1.0 0.5 0.5
Precisions:  0.2222222222222222 0.4 0.3333333333333333
Recalls:  1.0 1.0 0.5
Precisions:  0.25 0.3333333333333333 0.25
Recalls:  1.0 1.0 1.0
Precisions:  0.14285714285714285 0.14285714285714285 0.2
Recalls:  0.5 0.5 0.5
Precisions:  0.7142857142857143 0.4 1.0
Recalls:  1.0 0.4 0.8
Precisions:

ZeroDivisionError: division by zero

In [39]:
import torch

rebeca_precisions = []
rebeca_recalls = []

llm_precisions = []
llm_recalls = []

t2_precisions = []
t2_recalls = []

for user in range(94):
    user_gt_embs = torch.load(f"../data/flickr/evaluation/ground_truth/usrthrs_100/users/{user}/embeddings.pt")
    user_gt_ratings = torch.load(f"../data/flickr/evaluation/ground_truth/usrthrs_100/users/{user}/ratings.pt")

    if user_gt_ratings.sum() == 0:
        print(f"Skipping user {user} due to no positive ratings.")
        continue

    rebeca_prec, rebeca_rec = compute_precision_recall_at_k(
        generated=data[user]["rebeca_embeddings"],
        encoded_real=user_gt_embs,
        R=user_gt_ratings
    )
    llm_prec, llm_rec = compute_precision_recall_at_k(
        generated=data[user]["llm_embeddings"],
        encoded_real=user_gt_embs,
        R=user_gt_ratings
    )
    t2_prec, t2_rec = compute_precision_recall_at_k(
        generated=data[user]["T2"],
        encoded_real=user_gt_embs,
        R=user_gt_ratings
    )

    rebeca_precisions.append(rebeca_prec)
    rebeca_recalls.append(rebeca_rec)

    llm_precisions.append(llm_prec)
    llm_recalls.append(llm_rec)

    t2_precisions.append(t2_prec)
    t2_recalls.append(t2_rec)

# Compute means
rebeca_prec_mean = torch.tensor(rebeca_precisions).mean()
rebeca_rec_mean = torch.tensor(rebeca_recalls).mean()

llm_prec_mean = torch.tensor(llm_precisions).mean()
llm_rec_mean = torch.tensor(llm_recalls).mean()

t2_prec_mean = torch.tensor(t2_precisions).mean()
t2_rec_mean = torch.tensor(t2_recalls).mean()

# Display
print("== AVERAGE PRECISIONS ==")
print("REBECA:", rebeca_prec_mean.item())
print("LLM   :", llm_prec_mean.item())
print("T2    :", t2_prec_mean.item())

print("== AVERAGE RECALLS ==")
print("REBECA:", rebeca_rec_mean.item())
print("LLM   :", llm_rec_mean.item())
print("T2    :", t2_rec_mean.item())


Skipping user 30 due to no positive ratings.
Skipping user 35 due to no positive ratings.
Skipping user 75 due to no positive ratings.
== AVERAGE PRECISIONS ==
REBECA: 0.4227280616760254
LLM   : 0.5410780310630798
T2    : 0.5279958248138428
== AVERAGE RECALLS ==
REBECA: 0.8911128044128418
LLM   : 0.5758023858070374
T2    : 0.7311965823173523


(0.125, 1.0)

In [15]:
user_dict

{'rebeca_embeddings': tensor([[-0.1105,  0.0246,  0.2082,  ...,  0.2820,  0.7341,  1.0000],
         [ 0.6866, -0.0304, -0.1738,  ...,  0.4561, -0.4463,  0.3548],
         [-0.2670,  0.4138, -0.4932,  ..., -0.2317,  0.3979, -0.3279],
         ...,
         [-0.3196,  0.3060,  0.3436,  ..., -0.7037, -0.2401, -0.0504],
         [ 0.3215,  0.6854,  0.8581,  ...,  0.2277, -0.4329, -0.1015],
         [-0.3521,  0.4845, -0.3631,  ...,  0.5872, -0.9760, -0.0237]]),
 'llm_embeddings': tensor([[-1.0374,  1.3192,  0.0919,  ..., -0.1950, -0.9320,  0.9253],
         [-0.0185, -0.0692, -0.5141,  ..., -0.7484, -0.0904, -0.8282],
         [-0.0450,  0.4084, -0.0857,  ..., -1.2065,  0.0538, -0.7014],
         ...,
         [-0.2528,  0.8283, -0.5241,  ..., -0.1881, -0.2135, -0.7562],
         [-0.3547,  0.2377, -0.2352,  ..., -0.7045, -0.0153, -0.5422],
         [-0.5085,  0.2119,  0.0417,  ..., -0.7597, -0.9715, -0.3863]]),
 'T2': tensor([[-0.2513,  0.1295, -0.3730,  ..., -0.9222, -0.4242, -0.1011],


## Ground truth

In [4]:
image_features = torch.load("../data/flickr/processed/ip-adapters/SD15/sd15_image_embeddings.pt", weights_only=True)
filtered_ratings_df = pd.read_csv("../data/flickr/processed/filtered_ratings_df_usrthrs_100.csv")
#expanded_features = map_embeddings_to_ratings(image_features, filtered_ratings_df)
device = "cuda"

In [5]:
train_df, val_df, test_df = split_recommender_data(
    ratings_df=filtered_ratings_df,
    val_spu=10,
    test_spu=10,
    seed=42
)

Train set size: 177278
Validation set size: 928
Evaluation set size: 933


In [6]:
liked_df = test_df[test_df["score"]>=4]
test_df.to_csv("../data/flickr/evaluation/ground_truth/usrthrs_100/test_df.csv")
liked_df.to_csv("../data/flickr/evaluation/ground_truth/usrthrs_100/liked_df.csv")

In [29]:
liked_df

Unnamed: 0.1,original_index,Unnamed: 0,worker,imagePair,score,old_worker_id,image_id,worker_id
2,863,910,A14W0IW2KGR80K,farm8_7412_9119468631_6db230d056.jpg,5,0,909,0
10,4783,5803,A219DFCY05R0WJ,farm9_8232_8402674921_616367cd0f.jpg,5,6,5466,1
11,4911,5939,A219DFCY05R0WJ,farm8_7048_6935541584_449a69270b.jpg,4,6,5601,1
12,5502,6560,A219DFCY05R0WJ,farm1_109_316029216_7f07904970.jpg,4,6,6108,1
13,5575,6636,A219DFCY05R0WJ,farm9_8495_8269813182_c294c36f45.jpg,5,6,6183,1
...,...,...,...,...,...,...,...,...
923,178695,202524,A37CZZH18KQ2V2,farm8_7353_26951283693_38b1fb2a8c_b.jpg,5,208,18952,93
927,178941,202780,A37CZZH18KQ2V2,farm9_8287_7647087844_5a1c16694a.jpg,5,208,37452,93
929,178979,202820,A37CZZH18KQ2V2,farm2_1087_709243977_b9afa0f5d3.jpg,5,208,29047,93
931,179055,202900,A37CZZH18KQ2V2,farm1_766_23195362051_fd5ce4d44a_b.jpg,5,208,38296,93


In [8]:
store_eval_images(
    paths_iter=liked_df["imagePair"],
    src_dir="../data/flickr/raw/40K",
    dst_dir="../data/flickr/evaluation/ground_truth/usrthrs_100/liked/images/"
)

In [9]:
store_eval_images_per_user(
    liked_df=liked_df,
    src_dir="../data/flickr/raw/40K",
    dst_base_dir="../data/flickr/evaluation/ground_truth/usrthrs_100/liked_per_user/"
)

In [10]:
len(liked_df["worker_id"].unique())

91

In [26]:
model_id = "runwayml/stable-diffusion-v1-5"                                                                                                                                                                                                                
pipe = StableDiffusionPipelineAdapterEmbeddings.from_pretrained(model_id).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")     
pipe.safety_checker = None
device = "cuda"

C:\Users\Gabriel\.cache\huggingface\hub\models--runwayml--stable-diffusion-v1-5\snapshots\451f4fe16113bff5a5d2269ed5ad43b0592e9a14


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [12]:
sd15_liked_embeddings = compute_embeddings(
    diffusion_pipe=pipe,
    image_paths="../data/flickr/evaluation/ground_truth/usrthrs_100/liked/images/"
    )

100%|██████████| 351/351 [00:16<00:00, 20.66it/s]


In [14]:
torch.save(sd15_liked_embeddings, "../data/flickr/evaluation/ground_truth/usrthrs_100/liked/embeddings/sd15_embeddings.pth")

In [28]:
USER_DIRS = "../data/flickr/evaluation/ground_truth/usrthrs_100/liked_per_user"

for user in os.listdir(USER_DIRS):
    user_id = user.split("_")[1]
    user_dir = os.path.join(USER_DIRS, f"user_{user_id}")
    images_dir = os.path.join(user_dir, "images")
    sdxl_ipadapter_embs = []
    with torch.no_grad():
        for path in os.listdir(images_dir):
                #path = "../data/raw/FLICKR-AES-001/40K/" + unq_img_path
            impath = os.path.join(images_dir, path)
            pil_image = load_image(impath)
            image_emb = pipe.encode_image(pil_image, device="cuda", num_images_per_prompt=1)[0].squeeze()
            sdxl_ipadapter_embs.append(image_emb.cpu())
        sdxl_ipadapter_embs_tensor = torch.stack(sdxl_ipadapter_embs)
        os.makedirs(os.path.join(user_dir, "embeddings"), exist_ok=True)
        torch.save(sdxl_ipadapter_embs_tensor, f"{images_dir}/../embeddings/sd15_embeddings.pth")

In [16]:
pipe = StableDiffusionXLPipelineAdapterEmbeddings.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

C:\Users\Gabriel\.cache\huggingface\hub\models--stabilityai--stable-diffusion-xl-base-1.0\snapshots\462165984030d82259a11f4367a4eed129e94a7b


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [17]:
sdxl_liked_embeddings = compute_embeddings(
    diffusion_pipe=pipe,
    image_paths="../data/flickr/evaluation/ground_truth/usrthrs_100/liked/images/"
    )

100%|██████████| 351/351 [00:16<00:00, 21.01it/s]


In [18]:
torch.save(sdxl_liked_embeddings, "../data/flickr/evaluation/ground_truth/usrthrs_100/liked/embeddings/sdxl_embeddings.pth")

In [23]:


USER_DIRS = "../data/flickr/evaluation/ground_truth/usrthrs_100/liked_per_user"

for user in os.listdir(USER_DIRS):
    user_id = user.split("_")[1]
    user_dir = os.path.join(USER_DIRS, f"user_{user_id}")
    images_dir = os.path.join(user_dir, "images")
    sdxl_ipadapter_embs = []
    with torch.no_grad():
        for path in os.listdir(images_dir):
                #path = "../data/raw/FLICKR-AES-001/40K/" + unq_img_path
            impath = os.path.join(images_dir, path)
            pil_image = load_image(impath)
            image_emb = pipe.encode_image(pil_image, device="cuda", num_images_per_prompt=1)[0].squeeze()
            sdxl_ipadapter_embs.append(image_emb.cpu())
        sdxl_ipadapter_embs_tensor = torch.stack(sdxl_ipadapter_embs)
        os.makedirs(os.path.join(user_dir, "embeddings"))
        torch.save(sdxl_ipadapter_embs_tensor, f"{images_dir}/../embeddings/sdxl_embeddings.pth")

## Baselines

In [14]:
sd15_baseline_t0_embeddings = compute_embeddings(
    diffusion_pipe=pipe,
    image_paths="../data/flickr/evaluation/baselines/t0_prompt/images/"
    )

100%|██████████| 5010/5010 [04:42<00:00, 17.73it/s]


In [15]:
torch.save(sd15_baseline_t0_embeddings, "../data/flickr/evaluation/baselines/t0_prompt/embeddings/embeddings.pth")

In [16]:
sd15_baseline_t1_embeddings = compute_embeddings(
    diffusion_pipe=pipe,
    image_paths="../data/flickr/evaluation/baselines/t1_prompt/images/"
    )

100%|██████████| 5009/5009 [04:21<00:00, 19.18it/s]


In [17]:
torch.save(sd15_baseline_t1_embeddings, "../data/flickr/evaluation/baselines/t1_prompt/embeddings/embeddings.pth")

In [6]:
diffusion_prior_model = TransformerEmbeddingDiffusionModelv2(
    img_embed_dim=1024,
    num_users=210,    # So user embedding covers your entire user set
    n_heads=16,
    num_tokens=8,
    num_user_tokens=1,
    num_layers=8,
    dim_feedforward=1024
).to(device)


noise_scheduler = DDPMScheduler(num_train_timesteps=6000)

savepath = f"../weights/grid_search_3/sd15_nl8_heads32_dim_feedforward1024_lr0.0001_it8_ut1_adamw_reduce_on_plateau_bs64_nslinear_spu80_timesteps6000_objz0-pred.pth"

diffusion_prior_model.load_state_dict(torch.load(savepath))
diffusion_prior_model.eval()

  diffusion_prior_model.load_state_dict(torch.load(savepath))


TransformerEmbeddingDiffusionModelv2(
  (user_embedding): Embedding(210, 128)
  (score_embedding): Embedding(2, 128)
  (time_embedding): SinusoidalEmbedding()
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (output_fc): Linear(in_features=128, out_features=128, bias=True)
)

In [34]:
from evaluation import ModelEvaluator
from Datasets import FlatImageDataset
from torch.utils.data import DataLoader

In [28]:
from torcheval.metrics import FrechetInceptionDistance
import torch

# Initialize FID metric
fid_metric = FrechetInceptionDistance()

# Create synthetic identical datasets
real_images = torch.rand(20, 3, 299, 299)  # Random data
fake_images = real_images.clone()           # Exact copy of real data

# Update metric
fid_metric.update(real_images, is_real=True)
fid_metric.update(fake_images, is_real=False)

# Compute FID
score = fid_metric.compute()
print(f"FID Score (synthetic test): {score}")


FID Score (synthetic test): -0.11057186126708984


In [35]:
me = ModelEvaluator(
    diffusion_prior_model=diffusion_prior_model,
    scheduler=noise_scheduler,
    model_dir="../data/flickr/evaluation/diffusion_priors/experiment_2/model_2/",
    original_img_dir="../data/flickr/evaluation/diffusion_priors/experiment_2/original_eval_data/images/"
)

In [36]:
me.eval_fid(
    path_original="../data/flickr/evaluation/diffusion_priors/experiment_2/original_eval_data/images/",
    path_generated="../data/flickr/evaluation/diffusion_priors/experiment_2/model_2/images/",
)

100%|██████████| 41/41 [01:56<00:00,  2.84s/it]
100%|██████████| 33/33 [01:30<00:00,  2.74s/it]


tensor(76.5065)

In [45]:
me.eval_cmmd(
    path_original="../data/flickr/evaluation/diffusion_priors/experiment_2/original_eval_data/images/",
    path_generated="../data/flickr/evaluation/diffusion_priors/experiment_2/model_1/images/", 
    #path_original_emb="../data/flickr/evaluation/diffusion_priors/experiment_2/original_eval_data/embeddings/cmmd_embeddings.pth",
    #path_generated_emb="../data/flickr/evaluation/diffusion_priors/experiment_2/model_2/embeddings/cmmd_embeddings.pth"
)

Calculating embeddings for 2597 images from ../data/flickr/evaluation/diffusion_priors/experiment_2/original_eval_data/images/.


100%|██████████| 82/82 [01:32<00:00,  1.13s/it]


Calculating embeddings for 2100 images from ../data/flickr/evaluation/diffusion_priors/experiment_2/model_1/images/.


100%|██████████| 66/66 [01:23<00:00,  1.27s/it]


tensor(0.7269)