In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pickle as pkl
import torch
from pathlib import Path
from typing import List
from PIL import Image, ImageDraw, ImageFont
import os, cv2
import json

In [None]:
class Draw:
    """Base drawable: a fixed-size rectangle that renders itself onto a PIL canvas.

    Subclasses override `_draw` (paint into a canvas already cropped to this
    element's box) and, when their size is derived from children, also
    `get_width` / `get_height`. A bare `Draw(w, h)` paints nothing, so it is
    used as a spacer. Extra keyword arguments are accepted and ignored.
    """

    maxalpha = 255  # alpha value that corresponds to opacity 1.0

    def __init__(self, width=22, height=22, **kw):
        self.width = width
        self.height = height

    def draw(self, canvas, box_x=0, box_y=0, opacity=1.0):
        """Render into `canvas` with this element's top-left corner at (box_x, box_y)."""
        right = box_x + self.get_width()
        bottom = box_y + self.get_height()
        region = canvas.crop((box_x, box_y, right, bottom))
        self._draw(region, opacity=opacity)
        canvas.paste(region, (box_x, box_y))

    def _draw(self, canvas, opacity=1.0):
        """Paint this element into `canvas`; the base implementation draws nothing."""
        return None

    def drawself(self):
        """Render onto a fresh, fully transparent RGBA canvas of this element's size."""
        transparent_white = (255, 255, 255, 0)
        out = Image.new('RGBA', (self.get_width(), self.get_height()), transparent_white)
        self.draw(out, 0, 0, 1)
        return out

    def get_width(self):
        return self.width

    def get_height(self):
        return self.height
    
        
class OverlayDraw(Draw):
    """Composite of two equally sized drawables: `overlay` painted on top of `base`.

    The overlay's opacity is additionally multiplied by `overlay_opacity`.
    Size is delegated to `base` (Draw.__init__ is intentionally not called);
    extra positional/keyword arguments are accepted and ignored.
    """
    def __init__(self, base:Draw, overlay:Draw, *args, overlay_opacity=1.0, **kw):
        self.base = base
        self.overlay = overlay
        # Both layers must occupy exactly the same box.
        assert self.base.get_width() == self.overlay.get_width()
        assert self.base.get_height() == self.overlay.get_height()
        self.overlay_opacity = overlay_opacity

    def get_width(self):
        return self.base.get_width()

    def get_height(self):
        return self.base.get_height()

    def _draw(self, canvas, opacity=1.0):
        # Paint bottom layer first, then the (possibly more transparent) top layer.
        self.base._draw(canvas, opacity=opacity)
        overlay_alpha = opacity * self.overlay_opacity
        self.overlay._draw(canvas, opacity=overlay_alpha)

        
class DrawImage(Draw):
    """A PIL image, resized to `imgsize` and alpha-composited onto the canvas.

    `imgsize` is either an int (square) or a `(width, height)` pair. It also
    fixes this element's size, overriding any `width`/`height` keyword.
    """
    def __init__(self, image, imgsize=256, **kw):
        super().__init__(**kw)
        self.image = image.convert("RGBA")
        self.imgsize = (imgsize, imgsize) if isinstance(imgsize, int) else imgsize
        self.width, self.height = self.imgsize

    def _draw(self, canvas, opacity=1.0):
        img = self.image.resize(self.imgsize)  # resize returns a copy; self.image is untouched
        # Scale the image's own alpha by `opacity` instead of overwriting it with a
        # constant: putalpha(int) would make transparent source pixels opaque.
        # For fully opaque sources this yields exactly int(opacity * maxalpha).
        maxalpha = self.maxalpha
        target = int(opacity * maxalpha)
        scaled_alpha = img.getchannel("A").point(
            lambda a: max(0, min(maxalpha, a * target // maxalpha)))
        img.putalpha(scaled_alpha)
        canvas.alpha_composite(img)
    
    
class DrawText(Draw):
    """Text centered in a fixed-size box (use '\\n' for multiple lines).

    The font is a DejaVu TrueType file looked up under
    ``<cv2 package dir>/qt/fonts``; `bold`, `italic` and `condensed` select the
    variant. `offset` shifts the text center by (dx, dy) pixels and `color` is an
    RGB triple.
    """
    def __init__(self, text, fontsize=28, width=3*256+2*3, height=100, offset=(0,0), bold=False, italic=False, condensed=False, color=(0, 0, 0), **kw):
        super().__init__(**kw)
        self.text, self.fontsize, self.width, self.height = text, fontsize, width, height

        # File names follow the pattern <Family>[-<Style>].ttf, e.g. DejaVuSans.ttf,
        # DejaVuSans-Bold.ttf, DejaVuSansCondensed-BoldOblique.ttf.
        font_name = "DejaVuSansCondensed" if condensed else "DejaVuSans"
        style = ("Bold" if bold else "") + ("Oblique" if italic else "")
        font_file = font_name + (f"-{style}" if style else "") + ".ttf"
        font_path = os.path.join(cv2.__path__[0], 'qt', 'fonts', font_file)
        # Fail with the offending path instead of printing it on every construction.
        if not os.path.isfile(font_path):
            raise FileNotFoundError(f"Font file not found: {font_path}")

        self.font = ImageFont.truetype(font_path, self.fontsize)
        # tuple() so list colors also work when the alpha component is appended in _draw.
        self.textcolor = tuple(color)
        self.offset = offset

    def _draw(self, canvas, opacity=1.0):
        txtimg = Image.new("RGBA", (self.get_width(), self.get_height()), (255,255,255,0))
        draw = ImageDraw.Draw(txtimg)
        # anchor="mm" places the middle of the text block at (x, y): text is centered in the box.
        x = self.width // 2 + self.offset[0]
        y = self.height // 2 + self.offset[1]
        fill = self.textcolor + (int(opacity * self.maxalpha),)
        draw.text((x, y), self.text, fill=fill, font=self.font, anchor="mm", align="center")
        canvas.alpha_composite(txtimg)
        
        
class DrawRow(Draw):
    """Lay out child drawables left to right, separated by `margin` pixels.

    The row's height is its tallest child's. Rows can be concatenated with `+`
    (the right operand only needs an `.items` attribute); the result keeps the
    left row's margin and extra keyword arguments.
    """
    def __init__(self, *items, margin=3, **kw):
        super().__init__(**kw)
        self.kw = kw  # forwarded to the new instance built by `+`
        self.items = items
        self.margin = margin

    def get_width(self):
        widths = [item.get_width() for item in self.items]
        # Margins only go *between* items, so an empty row has zero width.
        return sum(widths) + self.margin * max(len(widths) - 1, 0)

    def get_height(self):
        # default=0: an empty row has zero height instead of raising ValueError.
        return max((item.get_height() for item in self.items), default=0)

    def __add__(self, other):
        return type(self)(*(self.items + other.items), margin=self.margin, **self.kw)

    def draw(self, canvas, x=0, y=0, opacity=1.0):
        for item in self.items:
            item.draw(canvas, x, y, opacity=opacity)
            x += item.get_width() + self.margin
            
            
class DrawCol(DrawRow):
    """Lay out child drawables top to bottom, separated by `margin` pixels.

    Construction and `+` concatenation are inherited from DrawRow; only the
    layout axis differs. The column's width is its widest child's.
    """
    def get_width(self):
        # default=0: an empty column has zero width instead of raising ValueError.
        return max((item.get_width() for item in self.items), default=0)

    def get_height(self):
        heights = [item.get_height() for item in self.items]
        # Margins only go *between* items, so an empty column has zero height.
        return sum(heights) + self.margin * max(len(heights) - 1, 0)

    def draw(self, canvas, x=0, y=0, opacity=1.0):
        for item in self.items:
            item.draw(canvas, x, y, opacity=opacity)
            y += item.get_height() + self.margin

In [None]:
def get_image_draws(exppath, label=None, which=(0, (0,)), imgsize=256, use_seg=False):
    """Build a row of generated images (or their segmentation layouts) from one experiment.

    Args:
        exppath: experiment output directory containing ``outbatches.pkl``.
        label: currently unused; accepted so existing call sites keep working.
        which: sequence of ``(batchid, seeds)`` specs. For each spec, the examples
            ``outbatches[batchid][seed]`` for every ``seed`` in ``seeds`` are drawn,
            and consecutive specs are separated by a 20px spacer. A single bare
            ``(batchid, seeds)`` pair (such as the default) is also accepted.
        imgsize: int side length, or ``(width, height)``, each image is resized to.
        use_seg: draw ``example.seg_data`` converted to grayscale ("L") instead of
            ``example.image_data``.

    Returns:
        A ``DrawRow`` of the selected images (empty if ``which`` is empty).
    """
    # A lone (batchid, seeds) pair -- e.g. the default -- would otherwise be iterated
    # element-wise and fail to unpack; wrap it into a one-spec sequence.
    if len(which) == 2 and isinstance(which[0], int):
        which = (which,)

    # NOTE: pickle.load can execute arbitrary code -- only point this at trusted outputs.
    with open(Path(exppath) / "outbatches.pkl", "rb") as f:
        outbatches = pkl.load(f)

    imagerow = []
    for specidx, (batchid, whichseeds) in enumerate(which):
        if specidx > 0:
            imagerow.append(Draw(20, 0))  # horizontal gap between batches
        outbatch = outbatches[batchid]
        for whichseed in whichseeds:
            example = outbatch[whichseed]
            imgdata = example.seg_data.convert("L") if use_seg else example.image_data
            imagerow.append(DrawImage(imgdata, imgsize=imgsize))
    return DrawRow(*imagerow)


## About

In this notebook, we build image tables (grids) for qualitative comparisons between methods.

### Main comparison

In [None]:
# Seeds / batch ids picked for each of the four prompts; each one is a column group.
selectedseeds = (2, 1)      # generated_extradev, batch 0
selectedseeds2 = (3, 4,)    # generated_extradev, batch 2
whichbatch3 = 12
selectedseeds3 = (2,)       # generated_openair1, batch `whichbatch3`
whichbatch4 = 3
selectedseeds4 = (1, 4)     # generated_threeorange1, batch `whichbatch4`
textwidth = 128             # width of the left-hand row-label column

# All experiment outputs live under this root.
# NOTE(review): absolute, user-specific path -- consider making it configurable.
CKPT_ROOT = "/USERSPACE/lukovdg1/controlnet11/checkpoints/v5"


def _prompt_text(text, seeds):
    """Italic prompt caption spanning the columns of `seeds` (256px images, 3px margins)."""
    return DrawText(text, fontsize=22, height=128, italic=True, condensed=True,
                    width=len(seeds) * 256 + (len(seeds) - 1) * 3)


def _method_row(rowlabel, expdir, imagelabel="", condensed=False, use_seg=False,
                openair_file="generated_openair1.pkl_1"):
    """One grid row: a row label, then experiment `expdir`'s images for all four prompts.

    Column groups (and pickle loads) are in the same order as the prompt row:
    extradev (two batches), openair1, threeorange1, separated by 20px spacers.
    """
    root = f"{CKPT_ROOT}/{expdir}"
    return (
        DrawRow(DrawText(rowlabel, fontsize=24, condensed=condensed,
                         width=textwidth, height=256),)
        + get_image_draws(f"{root}/generated_extradev.pkl_1", imagelabel,
                          which=((0, selectedseeds), (2, selectedseeds2)), use_seg=use_seg)
        + DrawRow(Draw(20, 0))
        + get_image_draws(f"{root}/{openair_file}", "", use_seg=use_seg,
                          which=((whichbatch3, selectedseeds3),))
        + DrawRow(Draw(20, 0))
        + get_image_draws(f"{root}/generated_threeorange1.pkl_1", "", use_seg=use_seg,
                          which=((whichbatch4, selectedseeds4),))
    )


imagespec = DrawCol(
    DrawRow(
        DrawText("Prompt:", fontsize=24, height=128, width=textwidth),
        _prompt_text('"a photo of a {blue crystal ball}, \n {a red tennis ball} and {a gold coin} \n on a {wooden table}"', selectedseeds),
        Draw(20, 0),
        _prompt_text('"a digital painting of \n {a rabbit mage} casting \n {a fire ball} standing on \n {clouds}."', selectedseeds2),
        Draw(20, 0),
        _prompt_text('"a photo of {a doll house} \n standing in {water} \n in a {lush green forest} \n and {a fire ball}."', selectedseeds3),
        Draw(20, 0),
        _prompt_text('"a photo of {an apricot}, {a pumpkin} \n and {an orange} on {a wooden table}"', selectedseeds4),
    ),
    # Input layouts: segmentation maps from the ControlNet* run.
    _method_row("Layout:", "checkpoints_coco_global_v5_exp_1", "ControlNet*", use_seg=True),
    Draw(0, 15),
    # Baselines.
    _method_row("CtrlNet*", "checkpoints_coco_global_v5_exp_1", "ControlNet*"),
    _method_row("CAC", "checkpoints_coco_cac_v5_exp_1", "CAC"),
    _method_row("DD\nW'=0.5", "checkpoints_coco_dd_v5_exp_1", "DD:0.5"),
    _method_row("eDiff-I\nW'=0.5", "checkpoints_coco_legacy-NewEdiffipp_v5_exp_1", "Ediffi:0.5",
                condensed=True),
    Draw(0, 15),
    # CA-Redist variants; their openair1 outputs come from the second generation run (.pkl_2).
    # (A "CA-DnB W'=3.0" row from checkpoints_coco_posattn2_v5_exp_2 was previously commented out here.)
    _method_row("CA-Redist\nWₐ=0\nWₘ=6.0", "checkpoints_coco_posattn5a_v5_exp_2",
                condensed=True, openair_file="generated_openair1.pkl_2"),
    _method_row("CA-Redist\nWₐ=0.25\nWₘ=0", "checkpoints_coco_posattn5a_v5_exp_4",
                condensed=True, openair_file="generated_openair1.pkl_2"),
    _method_row("CA-Redist\nWₐ=0.25\nWₘ=4.0", "checkpoints_coco_posattn5a_v5_exp_5",
                condensed=True, openair_file="generated_openair1.pkl_2"),
)

In [None]:
# Report the overall canvas size, then render and show the full comparison grid.
print(f"Grid width: {imagespec.get_width()}, grid height: {imagespec.get_height()}")
grid_image = imagespec.drawself()
display(grid_image)
    

In [None]:
# NOTE(review): this cell repeats the previous one verbatim (re-renders the whole
# grid); it can likely be removed.
print(f"Grid width: {imagespec.get_width()}, grid height: {imagespec.get_height()}")
grid_image = imagespec.drawself()
display(grid_image)
    

In [None]:
# Table showing how eDiff-I outputs degrade across strength values W'

In [None]:
def get_image_draw(exppath, whichbatch=0, whichseed=0, imgsize=256):
    """Load one generated image from an experiment, together with the experiment's args.

    Args:
        exppath: experiment output directory containing ``outbatches.pkl`` and ``args.json``.
        whichbatch, whichseed: select ``outbatches[whichbatch][whichseed]``.
        imgsize: int side length, or ``(width, height)``, the image is resized to.

    Returns:
        ``(DrawImage, args)``: the example's ``image_data`` wrapped in a ``DrawImage``,
        and the parsed contents of ``args.json``.
    """
    exppath = Path(exppath)
    # NOTE: pickle.load can execute arbitrary code -- only point this at trusted outputs.
    with open(exppath / "outbatches.pkl", "rb") as f:
        outbatches = pkl.load(f)

    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(exppath / "args.json", encoding="utf-8") as f:
        args = json.load(f)

    example = outbatches[whichbatch][whichseed]
    return DrawImage(example.image_data, imgsize=imgsize), args


In [None]:
whichbatch = 3
whichseed = 2

# eDiff-I sweep on the threeorange1 prompt: one column per experiment (ids in
# display order), each captioned with the W' value from that run's args.json.
imgs, labeltexts = [], []
for expid in (3, 1, 2, 4, 5, 6):
    exppath = f"/USERSPACE/lukovdg1/controlnet11/checkpoints/v5/checkpoints_coco_legacy-NewEdiffipp_v5_exp_{expid}/generated_threeorange1.pkl_1"
    img, args = get_image_draw(exppath, whichbatch, whichseed)
    caption = DrawText(f"W'={args['strength']}", fontsize=26, width=img.get_width(), height=30,
                       color=(0,0,0), bold=False, italic=False)
    imgs.append(img)
    labeltexts.append(caption)

# Captions on top, images underneath.
todraw_ediffi = DrawCol(DrawRow(*labeltexts), DrawRow(*imgs))


In [None]:
# Compare the computed layout size with the size of the rendered image.
print(todraw_ediffi.get_width(), todraw_ediffi.get_height())
ediffi_canvas = todraw_ediffi.drawself()
print(ediffi_canvas.size)
display(ediffi_canvas)

In [None]:
# Table showing how DenseDiffusion outputs degrade across strength values W'

In [None]:
whichbatch = 3
whichseed = 2

# DenseDiffusion sweep on the threeorange1 prompt: one column per experiment
# (ids in display order), each captioned with the W' value from that run's args.json.
dd_exp_ids = (1, 3, 2)
imgs = []
labeltexts = []
for i in dd_exp_ids:
    img, args = get_image_draw(f"/USERSPACE/lukovdg1/controlnet11/checkpoints/v5/checkpoints_coco_dd_v5_exp_{i}/generated_threeorange1.pkl_1", whichbatch, whichseed)
    paramtext = DrawText(f"W'={args['strength']}", fontsize=26, width=img.get_width(), height=30,
                         color=(0,0,0), bold=False, italic=False)
    imgs.append(img)
    labeltexts.append(paramtext)

# Invariant check instead of a debug print: one image and one caption per experiment.
assert len(imgs) == len(labeltexts) == len(dd_exp_ids)

todraw_dd_images = DrawRow(*imgs)
todraw_dd_labels = DrawRow(*labeltexts)
# Captions on top, images underneath.
todraw_dd = DrawCol(todraw_dd_labels, todraw_dd_images)


In [None]:
# Render the DenseDiffusion sweep on its own.
dd_canvas = todraw_dd.drawself()
display(dd_canvas)

In [None]:
# Side-by-side figure: eDiff-I sweep (left) and DenseDiffusion sweep (right),
# separated by a 20px gap, each panel headed by a centered title.
panels = DrawRow(todraw_ediffi, DrawRow(Draw(20, 0)), todraw_dd)

titles = DrawRow(
    DrawText("eDiff-I", width=todraw_ediffi.get_width(), height=50, fontsize=36),
    DrawRow(Draw(20, 0)),
    DrawText("DenseDiffusion", width=todraw_dd.get_width(), height=50, fontsize=36),
)

todraw = DrawCol(titles, panels)
display(todraw.drawself())
