In [None]:
import torch
import matplotlib.pyplot as plt
import os
from esd_diffusers import FineTunedModel, StableDiffuser
from IPython.utils import io
import matplotlib.image as mpimg
import pandas as pd
import numpy as np
import io
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

## Generate and save images (one per prompt)

In [None]:
def generate_images(model_path, prompts_path, save_path, tune_method, n_steps=10, device='cuda'):
    """Generate images with a fine-tuned diffuser and save them as PNG files.

    For each row of the prompts CSV, the default torch generator is seeded with
    ``row.evaluation_seed`` and the diffuser runs on ``row.Prompt`` with the
    fine-tuned weights active. The image(s) for that row are written to
    ``save_path/<case_number>_<k>.png``. Here ``case_number`` comes from that
    row and ``k`` indexes the images produced for the row.

    Args:
        model_path: Path to a state dict loadable by ``FineTunedModel.load_state_dict``.
        prompts_path: CSV file with ``Prompt``, ``evaluation_seed`` and
            ``case_number`` columns.
        save_path: Output directory. It is created if it does not exist.
        tune_method: Passed as ``train_method`` to ``FineTunedModel`` (e.g. ``"xattn"``).
        n_steps: Number of diffusion steps per prompt (default 10, as before).
        device: Device the diffuser is moved to (default ``'cuda'``, as before).
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path)
    diffuser = StableDiffuser(scheduler='DDIM').to(device)
    finetuner = FineTunedModel(diffuser, train_method=tune_method)
    finetuner.load_state_dict(state_dict)

    prompts_df = pd.read_csv(prompts_path)
    # Create the output directory up front so a bad path fails before any expensive generation.
    os.makedirs(save_path, exist_ok=True)

    for _, row in prompts_df.iterrows():
        # torch.manual_seed reseeds the default generator and returns it, so each prompt is reproducible.
        generator = torch.manual_seed(int(row.evaluation_seed))
        with finetuner:
            images = diffuser(
                row.Prompt,
                n_steps=n_steps,
                generator=generator,
            )
        # Save per row so each file carries its own row's case_number
        # (previously every file used the last row's value) and finished work
        # survives a crash later in the loop.
        _save_row_images(images, save_path, row.case_number)


def _save_row_images(images, save_path, case_number):
    """Save the image(s) generated for one prompt as ``<case_number>_<k>.png``."""
    images_np = np.asarray(images)
    if images_np.size == 0:
        return
    # Collapse any leading nesting of the diffuser's output into one batch axis.
    # Reshaping (rather than squeezing) keeps the batch axis even when a single
    # image was produced. Assumes each image is H x W x C, e.g. RGB — TODO confirm.
    images_np = images_np.reshape(-1, *images_np.shape[-3:])
    for image_idx, image_np in enumerate(images_np):
        Image.fromarray(image_np).save(os.path.join(save_path, f'{case_number}_{image_idx}.png'))

## Generate images with the Van Gogh–erased model

In [None]:
# Inputs for the Van Gogh run. These are absolute paths; adjust them for your environment.
prompts_path = '/home/lu.kev/Kevins Dataset.csv'
van_gogh_model_path = '/home/lu.kev/models/van_gogh_xattn_200.pt'

generate_images(
    model_path=van_gogh_model_path,
    prompts_path=prompts_path,
    save_path='ESD_images/Van_Gogh/',
    tune_method="xattn",
)

In [None]:
def show_image_grid(image_dir, num_rows=20, num_cols=10, fig_size=(20, 40)):
    """Display the ``.png`` images in ``image_dir`` as a ``num_rows`` x ``num_cols`` grid.

    Files are shown in sorted filename order so the layout is reproducible.
    If there are fewer images than grid cells, the remaining cells are left
    blank. If there are more, only the first ``num_rows * num_cols`` are shown.

    Args:
        image_dir: Directory to read ``.png`` files from.
        num_rows: Number of grid rows.
        num_cols: Number of grid columns.
        fig_size: Matplotlib ``figsize`` for the whole grid.
    """
    # os.listdir returns entries in arbitrary order; sort so the grid is deterministic.
    image_files = sorted(
        os.path.join(image_dir, name)
        for name in os.listdir(image_dir)
        if name.endswith('.png')
    )

    # squeeze=False always yields a 2-D array of Axes, so .flat works even for a 1x1 grid.
    _, axes = plt.subplots(num_rows, num_cols, figsize=fig_size, squeeze=False)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')  # Hide axis ticks/labels; also blanks cells with no image.
        if i < len(image_files):
            ax.imshow(mpimg.imread(image_files[i]))

    plt.tight_layout()
    plt.show()

## Display grids of the images saved on Discovery cloud

In [None]:
# NOTE(review): no earlier cell in this notebook writes ESD_images/Dog/; it must already exist.
show_image_grid(image_dir='ESD_images/Dog/', num_rows=20, num_cols=10, fig_size=(20, 40))

In [None]:
# Grid of the images written by the Van Gogh generation cell above.
show_image_grid(image_dir='ESD_images/Van_Gogh/', num_rows=20, num_cols=10, fig_size=(20, 40))