In [None]:
import html
import os
import os.path as osp
import sys
from datetime import datetime

import pandas
import torch
from diffusers import StableDiffusionPipeline
from tqdm import tqdm

In [None]:
# Short aliases for benchmark category names.
category_name_map = {'Gary Marcus et al.': 'GaryMarcus'}
# Short model keys -> model ids passed to StableDiffusionPipeline.from_pretrained.
diff_models = {
    '14': 'CompVis/stable-diffusion-v1-4',
    '15': 'runwayml/stable-diffusion-v1-5',
    '21': 'stabilityai/stable-diffusion-2-1',
}
# Device used for generation. Defaults to the original 'cuda:7'; set the
# DIFFUSION_DEVICE environment variable (e.g. "cuda:0") on machines that do not
# have this GPU index.
device = os.environ.get('DIFFUSION_DEVICE', 'cuda:7')

In [None]:
benchmark = pandas.read_csv('imagen.csv')
# The generation and README cells index these columns by name; fail here with a
# clear message instead of a KeyError in the middle of the GPU loop.
missing_columns = {'Prompt', 'Category'} - set(benchmark.columns)
assert not missing_columns, f"imagen.csv is missing columns: {sorted(missing_columns)}"
num_prompts = len(benchmark)
print(f"The current benchmark includes {num_prompts} prompts. ")

In [None]:
now = datetime.now().strftime("%H%M%S")
print(f'Start to generate images at: {now}')
model_ids = ['14', '15', '21']
num_models = len(model_ids)
README_names, README_contents = dict(), dict()
for idx in model_ids:
    # Use custom pipeline to accept prompts with long length
    model_name = diff_models[idx]
    model_name_no_slash = model_name.replace('/', '-')
    pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16, custom_pipeline="lpw_stable_diffusion")
    pipe = pipe.to(device)
    for i in range(num_prompts):
        prompt = benchmark['Prompt'][i]
        category = benchmark['Category'][i]
        root_dir = osp.join(category, str(i + 1))
        os.makedirs(root_dir, exist_ok=True)
        if category in category_name_map:
            category = category_name_map[category]
        seed = 0
        torch.manual_seed(seed)
        image = pipe(prompt).images[0]
        image.save(f'{root_dir}/{model_name_no_slash}_seed_{seed}.jpg')
now = datetime.now().strftime("%H%M%S")
print(f'Finish to generate images at: {now}')

In [None]:
for i in range(num_prompts):
    prompt = benchmark['Prompt'][i]
    category = benchmark['Category'][i]
    root_dir = osp.join(category, str(i + 1))
    with open(osp.join(root_dir, 'prompt.txt'), 'w') as fout:
        fout.write(prompt)

In [None]:
# Category directories produced by the generation cell: one per benchmark
# category (raw name or its alias in category_name_map). Restricting to these
# names skips .ipynb_checkpoints, .git and any other unrelated directory next to
# the notebook, whose non-numeric subdirectories would break the README cell.
known_categories = set(benchmark['Category']) | set(category_name_map.values())
fs = sorted(x for x in os.listdir() if osp.isdir(x) and x in known_categories)

In [None]:
def render_readme(category_dir):
    """Build the README.md markdown for one category directory.

    Each decimal-named sub-directory of ``category_dir`` holds one prompt: a
    ``prompt.txt`` file plus one ``<model>_seed_<seed>.jpg`` image per model.
    For every such prompt, the returned page contains an HTML table section with
    three rows: the prompt text, the images, and the model names under them.
    Prompt directories with no images yet are skipped.
    """
    content = [
        '# Results',
        # Trailing '\n' leaves a blank line before the HTML table (the original
        # produced the same text via implicit string concatenation).
        f'Category "{category_dir}" in the benchmark: \n',
        '<table class="center">',
    ]
    examples = sorted(
        (x for x in os.listdir(category_dir)
         if x.isdecimal() and osp.isdir(osp.join(category_dir, x))),
        key=int,
    )
    for example in examples:
        prompt_dir = osp.join(category_dir, example)
        images = sorted(x for x in os.listdir(prompt_dir) if x.endswith('.jpg'))
        if not images:
            # e.g. an interrupted generation run; avoids dividing by zero below.
            continue
        with open(osp.join(prompt_dir, 'prompt.txt')) as fin:
            prompt = fin.readline().strip()
        num_images = len(images)
        ratio = int(100 / num_images)
        content.append(
            f'\t<tr><td style="text-align:center;" colspan="{num_images}">'
            f'<b>{example}:{html.escape(prompt, quote=False)}</b></td></tr>'
        )
        content.append('\t<tr>')
        # Image paths are relative to the README, which lives in category_dir.
        content.extend(f'\t\t<td><img src="{example}/{name}"></td>' for name in images)
        content.append('\t</tr><tr>')
        for name in images:
            model_name = name.split('_seed')[0]
            content.append(f'\t\t<td width={ratio}% style="text-align:center;">{model_name}</td>')
        content.append('\t</tr>')
    content.append('</table>')
    return '\n'.join(content)


for category_dir in fs:
    with open(osp.join(category_dir, 'README.md'), 'w') as fout:
        fout.write(render_readme(category_dir))
        
    