In [2]:
%%capture
!pip install widgetsnbextension
!pip install ipywidgets
!pip install voila
!jupyter nbextension enable --py widgetsnbextension --sys-prefix
!jupyter serverextension enable voila --sys-prefix
!pip install rudalle
!pip install scikit-learn==0.13

In [4]:
import warnings 
# Silence every warning category for the rest of the session.
warnings.filterwarnings('ignore')

In [5]:
# Choices offered by the prompt-building dropdowns. The strings are used
# verbatim in the generated prompt, so spelling and case are kept as-is.
collections = [
    'Ceramics',
]

techniques = [
    'glazing',
    'coating',
    'incising',
    'tin-glazed',
    'transfer-printed',
    'mounted',
    'applied-work',
    'silver staining',
    'acid etching',
    'pot metal',
    'cerre eglomise',
    'painted and stained',
    'painted',
    'printing',
    'dust-pressed',
]

artists = [
    'Rie Lucie',
    'Johann Kandler',
    'Hunt Martin',
    'Clarke Casper',
    'Stanley Lane',
    'Guy Green',
    'Keith Murray',
    'Tapio Wirkkala',
]

materials = [
    'hard paste porcelain',
    'earthenware',
    'stoneware',
    'metal',
    'porcelain',
    'enamel',
    'bone china',
    'glaze',
    'glass',
    'tenmoku',
    'stained glass',
    'resin',
]

subjects = [
    'bowl',
    'cup',
    'Jesus Christ',
    'dish',
    'Figure',
    'Saucer',
    'Bottle',
    'Ewer',
    'Pot',
    'Plate',
    'Islamic',
    'Vase',
    'Tile',
    'Turkish',
    'Architectural',
    'Teapot',
    'Teacup',
    'Jug',
]

In [6]:
import ipywidgets as widgets
from IPython.display import display, clear_output, Image
import os

In [7]:
# Shared 400px x 50px layout reused by the sliders and buttons in this notebook.
layout = widgets.Layout(width='400px', height='50px')

In [8]:
import numpy as np
# Checkpoint selector offering 20, 40, ..., 200.
checkpoint = widgets.Dropdown(
    options=list(range(20, 201, 20)),
    description='Model Checkpoint',
)

In [9]:
def _level_toggle(label):
    """Build a five-level ToggleButtons control captioned with `label`."""
    return widgets.ToggleButtons(
        options=['Ultra-Low', 'Low', 'Medium', 'High', 'Ultra-High'],
        description=label,
        style={'description_width': 'initial'},
    )


# Both controls share the same five levels; their values are forwarded to generate().
confidence = _level_toggle('Confidence')
variability = _level_toggle('Variability')

In [10]:
# Slider (1-9, default 2) whose value is passed to generate() as image_amount.
img_amount = widgets.IntSlider(
    value=2,
    min=1,
    max=9,
    step=1,
    description='Image amount',
    style={'description_width': 'initial'},
    layout=layout,
)

In [11]:
# Slider (1-9, default 2) whose value is passed to generate() as num_filtered.
filtered = widgets.IntSlider(
    value=2,
    min=1,
    max=9,
    step=1,
    description='Filtered Images',
    style={'description_width': 'initial'},
    layout=layout,
)

In [12]:
# Give both sliders the same light-blue handle.
for _slider in (img_amount, filtered):
    _slider.style.handle_color = 'lightblue'

In [13]:
# Render the two sliders, one after the other.
display(img_amount, filtered)

IntSlider(value=2, description='Image amount', layout=Layout(height='50px', width='400px'), max=9, min=1, styl…

IntSlider(value=2, description='Filtered Images', layout=Layout(height='50px', width='400px'), max=9, min=1, s…

In [14]:
# Collection selector; value=None means nothing is selected initially.
collection = widgets.Dropdown(options = sorted(collections), description = 'Collection', value = None)


In [15]:
def _blank_dropdown(choices, label):
    """Build a Dropdown over the sorted `choices` with nothing selected initially."""
    return widgets.Dropdown(options=sorted(choices), description=label, value=None)


# One selector per prompt component; construct_prompt() reads their values.
subject = _blank_dropdown(subjects, 'Subject')
material = _blank_dropdown(materials, 'Materials')
technique = _blank_dropdown(techniques, 'Techniques')
artist = _blank_dropdown(artists, 'Artists')

In [16]:
# Render the checkpoint selector.
display(checkpoint)


Dropdown(description='Model Checkpoint', options=(20, 40, 60, 80, 100, 120, 140, 160, 180, 200), value=20)

In [17]:
# Render the confidence toggle buttons.
display(confidence)

ToggleButtons(description='Confidence', options=('Ultra-Low', 'Low', 'Medium', 'High', 'Ultra-High'), style=To…

In [18]:
# Render the variability toggle buttons.
display(variability)

ToggleButtons(description='Variability', options=('Ultra-Low', 'Low', 'Medium', 'High', 'Ultra-High'), style=T…

In [19]:
# Instruction text printed ahead of the prompt-building dropdown cells that follow.
print("Below you can construct the inputs for the model. Feel free to leave some of them blank if you prefer!")

Below you can construct the inputs for the model. Feel free to leave some of them blank if you prefer!


In [27]:
# Render the collection selector.
display(collection)

Dropdown(description='Collection', options=('Ceramics',), value=None)

In [21]:

# Render the remaining prompt selectors in order: subject, artist, technique, material.
display(subject, artist, technique, material)

Dropdown(description='Subject', options=('Architectural', 'Bottle', 'Ewer', 'Figure', 'Islamic', 'Jesus Christ…

Dropdown(description='Artists', options=('Clarke Casper', 'Guy Green', 'Hunt Martin', 'Johann Kandler', 'Keith…

Dropdown(description='Techniques', options=('acid etching', 'applied-work', 'cerre eglomise', 'coating', 'dust…

Dropdown(description='Materials', options=('bone china', 'earthenware', 'enamel', 'glass', 'glaze', 'hard past…

In [28]:
# Button that triggers image generation once a click handler is attached.
button_run = widgets.Button(
    description='Run model',
    tooltip='Send',
    style={'description_width': 'initial'},
    layout=layout,
)

In [29]:
# Colour the 'Run model' button light blue.
button_run.style.button_color = 'lightblue'

In [30]:
# Output area that is displayed together with the 'Run model' button.
output = widgets.Output()

In [31]:
%%capture

from rudalle import get_rudalle_model, get_vae
import torch
from model.functions import generate, get_closest_training_images_by_clip

# Prefer the first CUDA GPU; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Request half precision only when running on a GPU; on the CPU fallback the
# model is loaded in full precision instead.
model = get_rudalle_model("Malevich", pretrained=True, fp16=device.type == "cuda", device=device)

In [32]:
%%capture
# Place the VAE on the same device as the model instead of hard-coding "cuda",
# so the notebook still runs when `device` has fallen back to the CPU.
vae = get_vae().to(device)

In [33]:
def construct_prompt():
    """Assemble the text prompt from the current dropdown selections.

    Each selected component contributes one fragment, in the order
    collection, artist, subject, material, technique; unselected (None)
    components are skipped. The prompt without its suffix is printed, and
    the returned string has '.jpg' appended.

    Returns:
        str: the assembled prompt ending in '.jpg'.
    """
    # (current selection, template for its fragment), in prompt order.
    fragments = (
        (collection.value, '{}, '),
        (artist.value, 'by {}, '),
        (subject.value, '{}, '),
        (material.value, 'made of {}, '),
        (technique.value, 'with technique {} '),
    )
    text = ''.join(
        template.format(choice)
        for choice, template in fragments
        if choice is not None
    )

    print(text)
    return text + '.jpg'

In [34]:
import os

@output.capture(clear_output=True)
def on_button_clicked(event):
    """Generate images for the current widget selections and display them.

    The decorator binds this handler to the Output widget that `output`
    refers to when this cell runs, so its prints and images keep going to
    that widget even if the name `output` is reassigned later in the
    notebook. `clear_output=True` empties the widget *before* the handler
    runs, so the status message printed below stays visible.

    Args:
        event: the Button instance passed by `Button.on_click` (unused).
    """
    prompt = construct_prompt()
    print(f'Running model with prompt: {prompt}')

    model_path = f'../VA-design-generator/checkpoints/lookingglass_dalle_{checkpoint.value}00.pt'

    # One directory per (prompt, checkpoint, confidence, variability) combination;
    # makedirs also creates the parent 'output/' directory when it is missing.
    filepath = f'output/{prompt}:{checkpoint.value}:{confidence.value}:{variability.value}'
    os.makedirs(filepath, exist_ok=True)

    # map_location lets a checkpoint saved on a GPU be loaded on whatever
    # device the notebook is actually using.
    model.load_state_dict(torch.load(model_path, map_location=device))
    filenames = generate(vae, model, prompt, confidence = confidence.value, variability = variability.value, rurealesrgan_multiplier="x1", output_filepath=filepath, num_filtered = filtered.value, image_amount = img_amount.value)
    print(f'Images saved in {filepath}')

    for image in filenames:
        display(Image(image))
            

In [35]:
# Call on_button_clicked each time the 'Run model' button is pressed.
button_run.on_click(on_button_clicked)

In [36]:
# Render the 'Run model' button followed by its output area.
display(button_run, output)

Button(description='Run model', layout=Layout(height='50px', width='400px'), style=ButtonStyle(button_color='l…

Output()

In [38]:
# Stack the 'Run model' button above the output area in one container widget.
vbox_result = widgets.VBox([button_run, output])

In [39]:
# Button that triggers the training-image lookup once a click handler is attached.
button_run_2 = widgets.Button(
    description='Explore training images',
    tooltip='Send',
    style={'description_width': 'initial'},
    layout=layout,
)

In [40]:
# Colour the 'Explore training images' button light blue.
button_run_2.style.button_color = 'lightblue'

In [41]:
# Fresh output area for the 'Explore training images' button.
# NOTE(review): this reuses the name `output`, which was already bound to the
# output area shown with the 'Run model' button; any code that looks the name
# up after this cell runs gets this new widget. A distinct name would be
# safer — confirm intent.
output = widgets.Output()

In [63]:
def on_button_clicked_2(event):
    """Show the training image that best matches the current prompt.

    Builds the prompt from the dropdowns, asks
    `get_closest_training_images_by_clip` for the best-matching filename in
    the ceramics training directory, and displays that image scaled to 30%
    of its width and height inside `output`.

    Args:
        event: the Button instance passed by `Button.on_click` (unused).
    """
    from PIL import Image

    image_dir = '../VA-design-generator/images-labelled/ceramics'
    with output:
        # Remove the previous result (or traceback) so each click starts from
        # an empty output area, as the 'Run model' handler does.
        clear_output()
        prompt = construct_prompt()

        img_filename = get_closest_training_images_by_clip(artist.value, prompt, image_dir)
        img = Image.open(f'{image_dir}/{img_filename}')
        display(img.resize((int(img.width*0.3), int(img.height*0.3))))

        

In [64]:
# Call on_button_clicked_2 each time the 'Explore training images' button is pressed.
button_run_2.on_click(on_button_clicked_2)

In [67]:


def get_closest_training_images_by_clip(artist, prompt, directory):
    """Return the filename in `directory` whose image CLIP scores highest for `prompt`.

    Only files whose name contains `artist` are scored; when `artist` is None
    (no artist selected) every file in the directory is scored.

    NOTE: this definition shadows the function of the same name imported from
    `model.functions` earlier in the notebook.

    Args:
        artist: substring that candidate filenames must contain, or None.
        prompt: text compared against each candidate image.
        directory: directory holding the candidate image files.

    Returns:
        str: the name (not the full path) of the best-scoring file.

    Raises:
        ValueError: if no file in `directory` matches `artist`.
    """
    to_search = [f for f in os.listdir(directory) if artist is None or artist in f]
    if not to_search:
        raise ValueError(f'No training images found in {directory!r} for artist {artist!r}')

    # Heavy third-party imports and the CLIP download/load are deferred until
    # we know there is at least one image to score.
    import torch
    from transformers import CLIPProcessor, CLIPModel
    from PIL import Image
    from tqdm import tqdm

    model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
    processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

    # Start below any real score so the first candidate always becomes the
    # current best, and track the filename itself rather than an index.
    best_score = float('-inf')
    best_filename = to_search[0]
    for f in tqdm(to_search):
        # The context manager closes the file once the processor has read the pixels.
        with Image.open(f'{directory}/{f}') as image:
            inputs = processor(text=[prompt], images = image, return_tensors = 'pt', padding=True)
        # Inference only: skip gradient bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
        s = outputs.logits_per_image.item()
        if s > best_score:
            best_score = s
            best_filename = f
    return best_filename

    

In [68]:
# Render the 'Explore training images' button followed by its output area.
display(button_run_2, output)

Button(description='Explore training images', layout=Layout(height='50px', width='400px'), style=ButtonStyle(b…

Output(outputs=({'traceback': ['\x1b[0;31m--------------------------------------------------------------------…

In [45]:
# Stack the 'Explore training images' button above the output area in one container widget.
vbox_result_2 = widgets.VBox([button_run_2, output])