In [None]:
from diffusers import DiffusionPipeline
import torch
from diffusers import DPMSolverMultistepScheduler
from PIL import Image
torch.set_grad_enabled(False)

import torchmetrics
from torchmetrics.functional import multimodal


# Several of the imports below are unnecessary; they were carried over from the example at https://github.com/christophschuhmann/improved-aesthetic-predictor/blob/main/sac%2Blogos%2Bava1-l14-linearMSE.pth
from PIL import Image
import io
import matplotlib.pyplot as plt
%matplotlib inline
import os
import json

from warnings import filterwarnings


# os.environ["CUDA_VISIBLE_DEVICES"] = "0"    # choose GPU if you are on a multi GPU server
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import tqdm

from os.path import join
from torch.utils.data import Dataset, DataLoader
import json

import clip
import time


from PIL import Image, ImageFile


In [None]:
# Stable Diffusion v1.5 loaded in half precision and moved to the GPU.
model_id = "runwayml/stable-diffusion-v1-5"
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to('cuda')

In [None]:
# Stand-alone seeded CUDA generator. NOTE: get_inputs() builds its own
# per-image generators, so the generation cell does not read this one.
generator = torch.Generator(device="cuda").manual_seed(0)


In [None]:
# Replace the pipeline's scheduler with DPMSolverMultistepScheduler, built from the
# current scheduler's config so its settings carry over.
# Presumably done so the 20 inference steps used in get_inputs() suffice — TODO confirm.
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)


In [None]:
def get_inputs(batch_size=1, prompt_text=None, num_inference_steps=20, device="cuda"):
    """Build keyword arguments for a seeded, batched ``pipeline(...)`` call.

    Image ``i`` in the batch gets its own ``torch.Generator`` seeded with ``i``,
    so each image's seed depends only on its position in the batch.

    Args:
        batch_size: number of images to request.
        prompt_text: prompt repeated for every image. When ``None`` (the default),
            falls back to the notebook-level ``prompt`` global so existing calls
            like ``get_inputs(batch_size=16)`` keep working.
        num_inference_steps: denoising steps passed through to the pipeline.
        device: device on which the per-image generators are created.

    Returns:
        dict with ``"prompt"``, ``"generator"`` and ``"num_inference_steps"`` keys,
        ready to be unpacked as ``pipeline(**get_inputs(...))``.
    """
    if prompt_text is None:
        # Backward-compatible fallback: `prompt` is defined in the generation cell,
        # which runs after this definition, so pass `prompt_text` explicitly when possible.
        prompt_text = prompt
    generators = [torch.Generator(device).manual_seed(seed) for seed in range(batch_size)]
    return {
        "prompt": batch_size * [prompt_text],
        "generator": generators,
        "num_inference_steps": num_inference_steps,
    }

def image_grid(imgs, rows=2, cols=2):
    """Tile PIL images row-major into a single ``rows x cols`` RGB image.

    Every cell is sized from the first image's ``(width, height)``; images are
    pasted left-to-right, top-to-bottom. Unfilled cells stay black.

    Args:
        imgs: non-empty sequence of PIL images.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new ``PIL.Image`` in RGB mode.

    Raises:
        ValueError: if ``imgs`` is empty, or holds more images than ``rows * cols``
            (the extras would otherwise be pasted off-canvas and silently dropped).
    """
    if len(imgs) == 0:
        raise ValueError("image_grid needs at least one image")
    if len(imgs) > rows * cols:
        raise ValueError(f"{len(imgs)} images do not fit in a {rows}x{cols} grid")

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

In [None]:
# Pick the GPU when present, otherwise fall back to CPU, and load CLIP ViT-L/14
# (the backbone the aesthetic head's 768-d input expects).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
clip_model, preprocess = clip.load("ViT-L/14", device=device)  # alternative backbone: RN50x64

In [None]:
class MLP(nn.Module):
    """Aesthetic-score regression head applied to CLIP image embeddings.

    A stack of Linear layers interleaved with Dropout and no activations in
    between, ending in a single output unit.

    Args:
        input_size: embedding width fed to the first Linear layer.
        xcol: name of the input column; stored only, not used by ``forward``.
        ycol: name of the target column; stored only, not used by ``forward``.
    """

    def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
        super().__init__()
        self.input_size, self.xcol, self.ycol = input_size, xcol, ycol

        # Layer order fixes the Sequential indices (Linear at 0, 2, 4, 6, 7), which
        # must line up with the state_dict keys of the loaded checkpoint.
        stack = [
            nn.Linear(input_size, 1024), nn.Dropout(0.2),
            nn.Linear(1024, 128), nn.Dropout(0.2),
            nn.Linear(128, 64), nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Map a batch of embeddings to one score per row."""
        scores = self.layers(x)
        return scores


def normalized(a, axis=-1, order=2):
    """Scale ``a`` to unit ``order``-norm along ``axis``.

    Slices whose norm is zero are divided by 1, so they stay all-zero instead of
    turning into NaN. Because the norms pass through ``np.atleast_1d`` and are then
    re-expanded at ``axis``, a 1-D input comes back with shape ``(1, n)``.
    """
    norms = np.atleast_1d(np.linalg.norm(a, ord=order, axis=axis))
    safe_norms = np.where(norms == 0, np.ones_like(norms), norms)
    return a / np.expand_dims(safe_norms, axis)


aes_model = MLP(768)  # CLIP embedding dim is 768 for CLIP ViT L 14
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints from a trusted source.
# map_location keeps loading working on whichever device the CLIP cell picked
# (including CPU-only machines), instead of assuming the checkpoint's saved device exists.
s = torch.load("sac+logos+ava1-l14-linearMSE.pth", map_location=device)   # load the model you trained previously or the model available in this repo
aes_model.load_state_dict(s)
aes_model.to(device)  # same device as clip_model, so its embeddings can be fed in directly
aes_model.eval()  # disable the Dropout layers for inference


def aesthetic_score(image_tensor, clip_encoder=None, predictor=None): # image tensors as below
    """Predict an aesthetic score for each image in a preprocessed batch.

    The CLIP image features are L2-normalised along the last dimension before
    being passed to the aesthetic head.

    Args:
        image_tensor: batch of CLIP-preprocessed images, as built with ``preprocess``
            in the scoring cell (presumably shaped (N, 3, H, W) — TODO confirm).
        clip_encoder: object exposing ``encode_image``; defaults to the global
            ``clip_model``.
        predictor: ``nn.Module`` mapping normalised embeddings to scores; defaults
            to the global ``aes_model``.

    Returns:
        The predictor's output for the batch (one row per image).
    """
    encoder = clip_model if clip_encoder is None else clip_encoder
    head = aes_model if predictor is None else predictor

    feats = encoder.encode_image(image_tensor)
    # CLIP may emit half-precision features while the head keeps fp32 weights; cast
    # to the head's dtype so the call also works outside an autocast context.
    head_dtype = next((p.dtype for p in head.parameters()), feats.dtype)
    feats = feats.to(head_dtype)
    image_emb = feats / feats.norm(dim=-1, keepdim=True)
    return head(image_emb)


In [None]:
# --- Generation settings: read by get_inputs() and by the scoring cell below ---
prompt = "A cute sloth holding a small treasure chest. A bright golden glow is coming from the chest."
n_show = 3      # number of best and of worst images to display
aes_weight = 1  # weight of the aesthetic score relative to the CLIP score in the ranking

# Timer started here is printed at the end of the scoring cell.
start_time = time.time()
images = pipeline(**get_inputs(batch_size=16)).images  # 16 images, generator seeds 0..15

In [None]:
# Score every generated image for prompt fidelity (CLIP logits) and aesthetics.
feed_images = torch.stack([preprocess(im) for im in images]).to(device)  # one host->device copy
text = clip.tokenize([prompt]).to(device)
logits_per_image, logits_per_text = clip_model(feed_images, text)
# torch.autocast("cuda") replaces the deprecated torch.cuda.amp.autocast(); it is only
# enabled on GPU, where CLIP runs in half precision and the aesthetic head in fp32.
with torch.autocast(device_type="cuda", enabled=(device == "cuda")):
    aes_scores = aesthetic_score(feed_images)
# Upcast to float32 before NumPy: half-precision values have only ~3 significant
# digits, which can create spurious ties when the two scores are summed for ranking.
x = logits_per_image.flatten().float().cpu().numpy()  # CLIP score per image
y = aes_scores.flatten().float().cpu().numpy()        # aesthetic score per image

fig, ax = plt.subplots()
ax.scatter(x, y)
for i, (x_pos, y_pos) in enumerate(zip(x, y)):
    ax.text(x_pos, y_pos, str(i))  # label each point with its image index
ax.set(xlabel='CLIP Score', ylabel='Aes. Score', title='CLIP vs. aesthetic score per image')
plt.show()
plt.close(fig)

# Rank by combined score, highest first; aes_weight scales the aesthetic term.
idx_rank = np.argsort(x + y * aes_weight)[::-1]
print(idx_rank)

best_idx = idx_rank[:n_show]
worst_idx = idx_rank[-n_show:]

for idx_set in [best_idx, worst_idx]:
    for idx in idx_set:
        print(idx, x[idx], y[idx])
    display(image_grid([images[i] for i in idx_set], rows=1, cols=n_show))

# NOTE: start_time is set in the generation cell, so this total also includes any
# idle time between running that cell and this one.
print("Total time:", time.time() - start_time)