In [1]:
import os
import random
import glob
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
from mpl_toolkits.axes_grid1 import ImageGrid

In [2]:
# Fix the global RNG so the random.sample / random.choice draws below follow a fixed sequence.
random.seed(a=42)

In [3]:
def plot_img_grid(img_list, x_titles=None, y_titles=None, suptitle=None, supylabel=None,
                  save_file=None, nrows_ncols=(5, 5)):
    """Plot images on a tight grid with optional column titles and row labels.

    Images fill the grid row by row: ``img_list[r * ncols + c]`` is drawn at row ``r``,
    column ``c``. Surplus images are ignored; surplus axes stay empty.

    Args:
        img_list: Sequence of images accepted by ``Axes.imshow``.
        x_titles: Column titles; entry ``c`` is drawn above column ``c`` of the first row.
        y_titles: Row labels; entry ``r`` becomes the y-label of the first axis in row ``r``.
        suptitle: Optional figure title.
        supylabel: Optional figure-level y-label.
        save_file: If given, the figure is written to this path (missing parent
            directories are created) and then closed; otherwise it is shown.
        nrows_ncols: Grid shape ``(nrows, ncols)``; defaults to 5x5.
    """
    nrows, ncols = nrows_ncols
    # None defaults avoid shared mutable default lists.
    x_titles = x_titles or []
    y_titles = y_titles or []

    fig = plt.figure(figsize=(10, 10), dpi=150)
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
                     nrows_ncols=(nrows, ncols),
                     axes_pad=0.02,  # pad between axes in inches
                     share_all=True,
                     aspect=False)

    for i, (ax, im) in enumerate(zip(grid, img_list)):
        row, col = divmod(i, ncols)
        ax.imshow(im)
        # Images need no axis scale; removing the ticks also removes their labels.
        ax.set_xticks([])
        ax.set_yticks([])

        if row == 0 and col < len(x_titles):
            ax.set_title(x_titles[col])
        # Only the first column carries the row label.
        if col == 0 and row < len(y_titles):
            ax.set_ylabel(y_titles[row], fontsize=14)

    # Hard-coded label offsets; retune if figsize or grid shape changes.
    if suptitle:
        fig.suptitle(suptitle, fontsize=14, y=0.91)
    if supylabel:
        fig.supylabel(supylabel, fontsize=14, x=0.1)

    if save_file:
        Path(save_file).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_file, bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

### Originalbilder

In [4]:
# Column headers (one per class) for the original-image grid.
x_titles = [
    'Healthy', 'Viral', 'Bacterial', 'COVID-19', 'Fungal',
]
# Lower-case method names; not passed to plot_img_grid in this section.
y_titles = [
    'dreambooth', 'gan', 'unconditional', 'finetuning', 'lora',
]

In [5]:
src_folder = '/data/DS/Projekte/covid19/data/segmentation_test/train_per_class'
# glob returns files in filesystem order, so sort before sampling to make the seeded
# draw reproducible across machines. The draw order (C, NB, V, P, B) is unchanged so
# the RNG is consumed in the same sequence as before.
c_imgs = random.sample(sorted(glob.glob(f'{src_folder}/C/*')), 5)
h_imgs = random.sample(sorted(glob.glob(f'{src_folder}/NB/*')), 5)
v_imgs = random.sample(sorted(glob.glob(f'{src_folder}/V/*')), 5)
f_imgs = random.sample(sorted(glob.glob(f'{src_folder}/P/*')), 5)
b_imgs = random.sample(sorted(glob.glob(f'{src_folder}/B/*')), 5)

# Interleave row-wise so each grid row holds one image per class, in column order H, V, B, C, F.
file_list = [path for row in zip(h_imgs, v_imgs, b_imgs, c_imgs, f_imgs) for path in row]
img_list = [Image.open(path).resize((256, 256)) for path in file_list]

In [6]:
plot_img_grid(
    img_list,
    x_titles=x_titles,
    save_file='export/image_collections/original.pdf',
)

### Alle Methoden, alle Klassen

In [7]:
src_folder = '/data/DS/Projekte/covid19/notebooks/Synth_Paper/images/assessment'
# Generation methods (sub-folder names) and their display titles, index-aligned.
methods = ['sd_dreambooth', 'sd_finetuning', 'gan', 'sd_lora', 'unconditional']
y_titles = ['DreamBooth', 'Fine-tuning', 'GAN', 'LoRA', 'Unconditional']
# Class sub-folder codes and their display titles, index-aligned.
labels = ['H', 'V', 'B', 'C', 'F']
x_titles = ['Healthy', 'Viral', 'Bacterial', 'COVID-19', 'Fungal']

In [8]:
# One random image per (method, class): grid rows follow `methods`, columns follow `labels`.
# Sorting the glob result makes the seeded choice independent of filesystem listing order.
file_list = [
    random.choice(sorted(glob.glob(f'{src_folder}/{method}/{label}/*')))
    for method in methods
    for label in labels
]
img_list = [Image.open(path).resize((256, 256)) for path in file_list]

In [9]:
plot_img_grid(
    img_list,
    x_titles=x_titles,
    y_titles=y_titles,
    save_file='export/image_collections/all_classes_all_methods.pdf',
)

### Eine Klasse, alle Methoden

In [12]:
src_folder = '/data/DS/Projekte/covid19/notebooks/Synth_Paper/images/assessment'
methods = ['sd_dreambooth', 'sd_finetuning', 'gan', 'sd_lora', 'unconditional']
y_titles = ['DreamBooth', 'Fine-tuning', 'GAN', 'LoRA', 'Unconditional']
# NOTE: in this section x_titles holds the class sub-folder codes (iterated as `label`
# in the loop below), not display titles; the readable class name comes from `suptitles`.
x_titles = ['H', 'V', 'B', 'C', 'F']
suptitles = [
    'Healthy/No Pneumonia',
    'Viral Pneumonia',
    'Bacterial Pneumonia',
    'COVID-19 Pneumonia',
    'Fungal Pneumonia',
]

In [13]:
# One grid per class: row r shows five random samples from methods[r].
for suptitle, label in zip(suptitles, x_titles):
    file_list = []
    for method in methods:
        # Sort first: glob order is filesystem-dependent, which would make the
        # seeded sample differ between machines.
        candidates = sorted(glob.glob(f'{src_folder}/{method}/{label}/*'))
        file_list.extend(random.sample(candidates, 5))

    img_list = [Image.open(path).resize((256, 256)) for path in file_list]
    plot_img_grid(img_list, y_titles=y_titles, suptitle=suptitle,
                  save_file=f'export/image_collections/single_class_all_methods_{label}.pdf')

### Eine Methode, alle Klassen

In [14]:
src_folder = '/data/DS/Projekte/covid19/notebooks/Synth_Paper/images/assessment'
# Generation methods (sub-folder names) and their display names, index-aligned.
methods = ['sd_dreambooth', 'sd_finetuning', 'gan', 'sd_lora', 'unconditional']
y_titles = ['DreamBooth', 'Fine-tuning', 'GAN', 'LoRA', 'Unconditional']
# Class sub-folder codes and their column titles, index-aligned.
labels = ['H', 'V', 'B', 'C', 'F']
x_titles = ['Healthy', 'Viral', 'Bacterial', 'COVID-19', 'Fungal']

In [15]:
# One grid per method: column c shows five random samples of class labels[c].
for method, method_name in zip(methods, y_titles):
    # Sort first: glob order is filesystem-dependent, which would make the seeded
    # sample differ between machines. Draw order (class by class) is unchanged.
    per_class = [
        random.sample(sorted(glob.glob(f'{src_folder}/{method}/{label}/*')), 5)
        for label in labels
    ]
    # per_class is class-major; transpose so each grid row holds one sample of every class.
    grid_files = [path for row in zip(*per_class) for path in row]
    img_list = [Image.open(path).resize((256, 256)) for path in grid_files]
    plot_img_grid(img_list, x_titles=x_titles, supylabel=method_name,
                  save_file=f'export/image_collections/single_method_all_classes_{method}.pdf')

### FID progression for LoRA

Je 4 Beispielbilder für 4 Checkpoints (4×4-Raster: Spalten = Checkpoints, Zeilen = Beispielbilder).

In [None]:
# (training iteration, FID) for each checkpoint shown as a column in the grid below.
iterations, fids = (list(column) for column in zip(
    ('500', 236.28),
    ('2000', 166.55),
    ('7500', 208.46),
    ('15000', 190.50),
))

In [None]:
# Fixed (presumably hand-picked) class-F samples from four LoRA checkpoints, listed row
# by row for the 4x4 grid below; within each row the checkpoints go 500, 2000, 7500, 15000.
file_list = [
    # row 1
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-500/image-1.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-2000/image-0.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-7500/image-0.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-15000/image-0.png',
    # row 2
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-500/image-8.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-2000/image-4.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-7500/image-4.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-15000/image-1.png',
    # row 3
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-500/image-11.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-2000/image-7.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-7500/image-7.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-15000/image-2.png',
    # row 4
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-500/image-44.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-2000/image-1.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-7500/image-10.png',
    'images/sd_lora/1e-5/sd_lora_scale1/F/checkpoint-15000/image-3.png',
]

img_list = list(map(Image.open, file_list))

In [None]:
# 4x4 grid: columns are checkpoints (500 → 15000 iterations), rows are sample images.
fig = plt.figure(figsize=(10, 10), dpi=150)

grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(4, 4),  # 4 sample rows x 4 checkpoint columns
                 axes_pad=0.02,  # pad between axes in inches
                 share_all=True,
                 aspect=False)

for i, (ax, im) in enumerate(zip(grid, img_list)):
    ax.imshow(im)
    # Images need no axis scale; removing the ticks also removes their labels.
    ax.set_xticks([])
    ax.set_yticks([])

    # First row only: label each checkpoint column with its iteration and FID.
    if i < 4:
        # :.2f keeps trailing zeros so all FIDs show two decimals (190.50, not 190.5).
        ax.set_title(f'iteration {iterations[i]}\nFID {fids[i]:.2f}')

out_file = Path('export/image_collections/fid_progression.pdf')
out_file.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_file, bbox_inches='tight')
plt.close(fig)