In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append('..')

import os
import cv2
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import seed_everything
from tqdm import tqdm

from cldm.model import create_model, load_state_dict
from dataset import PhotoSketchDataset
from inference import run_sampler
from share import *

logging improved.


# Get dataset

In [5]:
# Held-out split whose images/prompts serve as the reference set for evaluation.
dataset = PhotoSketchDataset(data_dir="./data/sketch", split="test")

In [24]:
def _load_rgb_tensor(path):
    """Read an image file into a float32 (H, W, 3) RGB tensor with values in [0, 255].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (``cv2.imread`` returns None
            instead of raising, which previously surfaced as an opaque TypeError).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # cv2 loads channels in BGR order; CLIP and Inception (FID) expect RGB.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return torch.from_numpy(rgb).float()


# Reference ("gold") test images, loaded with the same loader as the generations
# so both sides of every metric share channel order, dtype and value range.
gold_images = [_load_rgb_tensor(dataset.images[i]) for i in range(len(dataset))]
NUM_IMGS = len(gold_images)


def get_generations(gen_dir, num_gens):
    """Load the first ``num_gens`` generated test images from ``gen_dir``.

    ``gen_dir`` should contain all the generated test images, named
    ``image_000.jpg``, ``image_001.jpg``, ...

    Returns:
        A list of float32 (H, W, 3) RGB tensors with values in [0, 255].

    Raises:
        FileNotFoundError: if any expected image is missing or unreadable.
    """
    return [
        _load_rgb_tensor(os.path.join(gen_dir, f"image_{i:03d}.jpg"))
        for i in range(num_gens)
    ]

In [9]:
from ast import literal_eval

# One text prompt per test image, in dataset order.
prompts = []
for i in range(NUM_IMGS):
    prompt_file = dataset.prompts[i]
    # Explicit UTF-8 so captions decode the same way regardless of the machine's locale.
    with open(prompt_file, encoding="utf-8") as f:
        contents = f.read().strip()
    # The file holds a Python literal (presumably a list of captions -- TODO confirm);
    # literal_eval parses it without executing code, and only the first entry is used.
    prompt = literal_eval(contents)[0]
    prompts.append(prompt)

In [15]:
from torchmetrics.image.fid import FrechetInceptionDistance

def eval_fid(gold_images, generated_images):
    """Print and return the FID between ``gold_images`` and ``generated_images``.

    Args:
        gold_images: sequence of (H, W, C) tensors with pixel values in [0, 255]
            (the format produced by the image loaders in this notebook).
        generated_images: sequence of (H, W, C) tensors in the same format.

    Returns:
        The FID as a Python float.

    NOTE(review): FID is strongly biased for small sample sizes; treat scores from
    a few dozen images as relative comparisons only.
    """

    def to_unit_nchw(images):
        # normalize=True makes torchmetrics expect float images in [0, 1] (it maps
        # them to uint8 internally via x * 255). Feeding [0, 255] values overflowed
        # that conversion, so rescale here and move channels first (N, C, H, W).
        return torch.stack(list(images)).permute(0, 3, 1, 2).float() / 255.0

    fid = FrechetInceptionDistance(normalize=True)
    fid.update(to_unit_nchw(gold_images), real=True)
    fid.update(to_unit_nchw(generated_images), real=False)

    score = float(fid.compute())
    print(f"FID: {score}")
    return score

In [32]:
from torchmetrics.functional.multimodal import clip_score
from functools import partial

# CLIP scorer bound to the ViT-B/16 checkpoint used for every evaluation below.
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")

def calculate_clip_score(images, prompts):
    """Return the CLIP score of ``images`` against ``prompts``, rounded to 4 decimals.

    Args:
        images: sequence of (H, W, C) images (tensors or arrays), or an (N, H, W, C)
            tensor, with pixel values already in [0, 255]. The images loaded in this
            notebook are already in that range, so no ``* 255`` rescaling is applied.
        prompts: list of caption strings, one per image.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("calculate_clip_score needs at least one image")
    # clip_score documents its image input as [C, H, W] per image; convert from HWC.
    # A list (rather than a stacked batch) also tolerates images of differing sizes.
    chw_images = [torch.as_tensor(img).permute(2, 0, 1) for img in images]
    score = clip_score_fn(chw_images, prompts).detach()
    return round(float(score), 4)

def eval_clip(generated_images, prompts):
    """Print and return the CLIP score of ``generated_images`` against ``prompts``.

    Args:
        generated_images: images in the format accepted by ``calculate_clip_score``
            (here, the tensors returned by ``get_generations``).
        prompts: one caption per image, in the same order as the images.

    Returns:
        The CLIP score as a float rounded to 4 decimal places.

    Raises:
        ValueError: if the number of images and prompts differ.
    """
    if len(generated_images) != len(prompts):
        raise ValueError(
            f"Got {len(generated_images)} images but {len(prompts)} prompts"
        )
    score = calculate_clip_score(generated_images, prompts)
    print(f"CLIP score: {score}")
    return score
# Earlier recorded result (experiment not noted -- TODO record which): CLIP score: 35.7038

In [None]:
import fnmatch
import os

# Hyper-parameter sweep to evaluate; each name is a sub-directory of gen_dir.
experiments = [
    'lr=1e-05_bs=2',
    'lr=1e-05_bs=4',
    'lr=5e-05_bs=2',
    'lr=5e-05_bs=4',
    'lr=5e-06_bs=2',
    'lr=5e-06_bs=4',
]

# Root holding one folder of generated test images per experiment.
# Override with the GEN_DIR environment variable when running on another machine.
gen_dir = os.environ.get('GEN_DIR', '/raid/lingo/alexisro/ControlNet/project/generations')
# Score only the first 20 test prompts/generations (never more than there are prompts).
num_imgs = min(20, len(prompts))

for exp_idx, exp in enumerate(experiments):
    print("\n====================================================================================")
    print(f"Evaluating experiment {exp_idx + 1}/{len(experiments)}:", exp)
    gen_path = os.path.join(gen_dir, exp)
    # Skip (rather than abort the whole sweep) when an experiment has no generations yet.
    if not os.path.isdir(gen_path):
        print("Skipping, directory not found:", gen_path)
        continue
    print("Reading images from:", gen_path)
    generated_images = get_generations(gen_path, num_imgs)
    eval_clip(generated_images, prompts[:num_imgs])


Running inference for experiment 0/6: lr=1e-05_bs=2
Reading images from: /raid/lingo/alexisro/ControlNet/project/generations/lr=1e-05_bs=2
