In [1]:
%%capture
!pip install widgetsnbextension
!pip install ipywidgets
!pip install voila
!jupyter nbextension enable --py widgetsnbextension --sys-prefix
!jupyter serverextension enable voila --sys-prefix
!pip install rudalle
!pip install scikit-learn==0.13

Below are some of the Victoria & Albert Collections. Each column is specific to the collection, and you can explore images of V&A objects with the explore button below. 

To make the AI Art, you can mix and match across collections, styles, artists and subjects to come up with your own creation! 

Feel free to leave some fields blank if you like, but we recommend selecting at least 3!

In [2]:
import warnings 
warnings.filterwarnings('ignore')

In [3]:
time_period = ['Any', '16th century', '17th century', '18th century', '19th century', '20th century']


In [4]:
ceramic_subjects = ['Dish', 'Bowl', 'Cup', 'Jug', 'Teapot', 'Figure', 'Tile', 'Vase']

ceramic_materials = ['Paint', 'Earthenware', 'Porcelain','Enamels', 'Steel', 'Glaze', 'Stoneware', 'Gold metal', 'Glass', 'Fritware', 'Marble']

ceramic_styles = ['Ottoman', 'Armorial', 'Kakiemon', 'Chinese export', 'Modernist', 'Iznik', 'Baroque', 'Arts and Crafts']

ceramic_artists = ['Delft', 'Rie Lucie', 'Johann Kandler', 'Hunt Martin', 'IZNIK', 'Wedgwood', 'Keith Murray', 'Susie Cooper', 'William de Morgan']

In [5]:

fashion_subjects = ['Textiles', 'Belt', 'Bow', 'Costume', 'Suit', 'Coat', 'Hat', 'Man', 'Jacket','Flannel', 'Waistcoat',
 'Trousers', 'Designed', 'Furnishing', 'Evening', 'Shirt','Ensemble'
'Necklace', 'Cocktail', 'Dress', 'Belt', 'Woman']

fashion_styles = ['Arts and Crafts', 'High Fashion', 'East Asian', 'Theatrical', 'Modernist']

fashion_materials = ['Cotton','Organza', 'Beads', 'Crystals', 'Diamente', 'Silk', 'Glass-Beads', 'Pearls', 'Sequins', 'Thread', 'Rhinestone', 'Taffeta', 'Wool', 'Leather',
'Straw', 'Suede', 'Linen thread', 'Rayon', 'Metal', 'Gazar', 'Plastic', 'Chiffon', 'Synthetic Fibre', 'Snakeskin']

fashion_artists = ['Barbara Brown', 'William Morris','Balenciaga', 'Versace', 'Vivienne Westwood',  'Yves Saint Laurent', 'Yoruba Women' ]

In [6]:

furniture_subjects = ['Table', 'Chair', 'Metalwork', 'Furniture', 'Candlestick']

furniture_styles = ['Rococo', 'Neoclassical', 'Baroque', 'Art Nouveau', 'Art Deco']

furniture_materials = ['Wood', 'Steel', 'Mahogany', 'Metal', 'Glass', 'Silver']

furniture_artists = ['Paul Storr', 'Ashbee Robert', 'Garrard Robert', 'Hester Bateman', 'Joseph Wilmore']

In [7]:
import ipywidgets as widgets
from IPython.display import display, clear_output, Image
import os

In [8]:
layout = widgets.Layout(width='400px', height='50px')

In [9]:
import numpy as np
checkpoint = widgets.Dropdown(
            options=[200,500,900], description = 'Model Checkpoint',
        )

In [11]:
img_amount = widgets.IntSlider(value=2,min=1,max=9,step=1,description='Image amount',style= {'description_width': 'initial'}, layout = layout)

In [12]:
filtered = widgets.IntSlider(value=2,min=1,max=9,step=1,description='Filtered Images',style= {'description_width': 'initial'}, layout = layout)

In [13]:
img_amount.style.handle_color = 'lightblue'
filtered.style.handle_color = 'lightblue'

In [14]:
display(img_amount)
display(filtered)

IntSlider(value=2, description='Image amount', layout=Layout(height='50px', width='400px'), max=9, min=1, styl…

IntSlider(value=2, description='Filtered Images', layout=Layout(height='50px', width='400px'), max=9, min=1, s…

In [15]:
century = widgets.Dropdown(options = time_period, description = 'Select a Time Period', value = 'Any', style= {'description_width': 'initial'})

In [16]:
display(century)

Dropdown(description='Select a Time Period', options=('Any', '16th century', '17th century', '18th century', '…

In [17]:
collection = widgets.Dropdown(options = ['Ceramics', 'Fashion', 'Furniture & Metalwork'], description = 'Select a Collection', style= {'description_width': 'initial'})

In [18]:
from IPython.display import display
from ipywidgets import widgets


In [19]:
display(collection)

Dropdown(description='Select a Collection', options=('Ceramics', 'Fashion & Textiles', 'Furniture & Metalwork'…

In [20]:
ceramic_artist = widgets.Dropdown(options = ceramic_artists, description = 'Artist / Maker',style= {'description_width': 'initial'},)
fashion_artist = widgets.Dropdown(options = fashion_artists)
furniture_artist = widgets.Dropdown(options = furniture_artists)

In [21]:
display(widgets.HBox((ceramic_artist, fashion_artist, furniture_artist)))


HBox(children=(Dropdown(description='Artist / Maker', options=('Any', 'Delft', 'Rie Lucie', 'Johann Kandler', …

In [22]:
ceramic_subject = widgets.SelectMultiple(options = ceramic_subjects, description = 'Subject',style= {'description_width': 'initial'},)
fashion_subject = widgets.SelectMultiple(options = fashion_subjects)
furniture_subject = widgets.SelectMultiple(options = furniture_subjects)

In [23]:
display(widgets.HBox((ceramic_subject, fashion_subject, furniture_subject)))


HBox(children=(Dropdown(description='Subject', options=('Any', 'Dish', 'Bowl', 'Pot', 'Cup', 'Plate', 'Jug', '…

In [24]:
ceramic_material = widgets.SelectMultiple(options = ceramic_materials, description = 'Material',style= {'description_width': 'initial'},)
fashion_material = widgets.SelectMultiple(options = fashion_materials)
furniture_material = widgets.SelectMultiple(options = furniture_materials)

In [25]:
display(widgets.HBox((ceramic_material, fashion_material, furniture_material)))


HBox(children=(SelectMultiple(description='Material', options=('Any', 'Paint', 'China', 'Earthenware', 'Porcel…

In [26]:
ceramic_style = widgets.SelectMultiple(options = ceramic_styles, description = 'Style',style= {'description_width': 'initial'},)
fashion_style = widgets.SelectMultiple(options = fashion_styles)
furniture_style = widgets.SelectMultiple(options = furniture_styles)

In [27]:
display(widgets.HBox((ceramic_style, fashion_style, furniture_style)))


HBox(children=(SelectMultiple(description='Style', options=('Any', 'Ottoman', 'Armorial', 'Kakiemon', 'Chinese…

In [29]:
display(checkpoint)


Dropdown(description='Model Checkpoint', options=(100, 200, 300, 400, 500, 600, 700, 800, 900), value=100)

In [30]:
display(confidence)

ToggleButtons(description='Confidence', options=('Ultra-Low', 'Low', 'Medium', 'High', 'Ultra-High'), style=To…

In [31]:
display(variability)

ToggleButtons(description='Variability', options=('Ultra-Low', 'Low', 'Medium', 'High', 'Ultra-High'), style=T…

In [32]:
button_run = widgets.Button(description = 'Run model', tooltip='Send',
                style={'description_width': 'initial'}, layout = layout)

In [33]:
button_run.style.button_color = 'lightblue'

In [34]:
output = widgets.Output()

In [35]:
%%capture

from rudalle import get_rudalle_model, get_vae
import torch
from model.functions import generate, get_closest_training_images_by_clip

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = get_rudalle_model("Malevich", pretrained=True, fp16=True, device=device)

In [36]:
%%capture
vae = get_vae().to("cuda")

In [38]:
def construct_prompt():
    prompt = ''
    if collection.value is not None:
        prompt += f'{collection.value}, '
    if century.value is not None:
        prompt += f'{century.value}, '
    styles = ceramic_style.value + fashion_style.value + furniture_style.value
    for st in styles:
        prompt += f'of style {st}, '
    prompt += 'by '

    if ceramic_artist.value != 'Any': 
        prompt += f'{ceramic_artist.value}, '
    if fashion_artist.value != 'Any': 
        prompt += f'{fashion_artist.value}, '
    if furniture_artist.value != 'Any': 
        prompt += f'{furniture_artist.value}, '

    if ceramic_subject.value != 'Any': 
        prompt += f'{ceramic_subject.value}, '
    if fashion_subject.value != 'Any': 
        prompt += f'{fashion_subject.value}, '
    if furniture_subject.value != 'Any': 
        prompt += f'{furniture_subject.value}, '

    materials = ceramic_material.value + fashion_material.value + furniture_material.value
    for m in materials:
        prompt += f'{m}, '

    print(prompt)
    prompt += '.jpg'
    return prompt

Ceramics, Any, by 


In [40]:
filenames = None

In [67]:
button_prompt = widgets.Button(description = 'Construct prompt', tooltip='Send',
                style={'description_width': 'initial'}, layout = layout)

button_prompt.style.button_color = 'lightgreen'

In [76]:
def on_button_prompt_clicked(event):
    with output:
        prompt = construct_prompt() 
        print(f'Constructed prompt: {prompt}')
        return prompt

In [77]:
prompt = button_prompt.on_click(on_button_prompt_clicked)

In [81]:
display(button_prompt, output)

Button(description='Construct prompt', layout=Layout(height='50px', width='400px'), style=ButtonStyle(button_c…

Output(outputs=({'name': 'stdout', 'text': 'Ceramics, Any, by \nConstructed prompt: Ceramics, Any, by .jpg\nCe…

In [41]:
import os

def on_button_clicked(event):
    with output:
        clear_output()
        model_path = os.path.join(f'../VA-design-generator/checkpoints/lookingglass_dalle_{checkpoint.value}00.pt')
        if not os.path.exists('output/'):
            os.mkdir('output/')
        filepath = f'output/{prompt}:{checkpoint.value}'
        if not os.path.exists(filepath):
            os.mkdir(filepath)
        model.load_state_dict(torch.load(model_path))
        filenames = generate(vae, model, prompt, confidence = 'Low', variability = 'Ultra-High', rurealesrgan_multiplier="x1", output_filepath=filepath, num_filtered = filtered.value, image_amount = img_amount.value)
        print(f'Images saved in {filepath}')
        for image in filenames:
            img = Image(image)
            display(img)
        return filenames


In [42]:
filenames = button_run.on_click(on_button_clicked)

In [43]:
if filenames:
    for image in filenames:
        img = Image(image)
        display(img)
        

In [44]:
display(button_run, output)

Button(description='Run model', layout=Layout(height='50px', width='400px'), style=ButtonStyle(button_color='l…

Output()

In [45]:
vbox_result = widgets.VBox([button_run, output])

In [46]:
button_run_2 = widgets.Button(description = 'Explore training images', tooltip='Send',
                style={'description_width': 'initial'}, layout = layout)

In [47]:
button_run_2.style.button_color = 'lightblue'

In [48]:
output = widgets.Output()

In [49]:
def on_button_clicked_2(event):
    from PIL import Image
    with output:
        prompt = construct_prompt()
        if collection.value == 'Ceramics':
            img_filename = get_closest_training_images_by_clip(ceramic_artist.value, prompt, '../VA-design-generator/images-labelled/ceramics')
            img = Image.open(f'../VA-design-generator/images-labelled/ceramics/{img_filename}')
        elif collection.value == 'Fashion & Textiles':
            img_filename = get_closest_training_images_by_clip(fashion_artist.value, prompt, '../VA-design-generator/images-labelled/fashion')
            img = Image.open(f'../VA-design-generator/images-labelled/fashion/{img_filename}')
        elif collection.value == 'Furniture & Metalwork':
            img_filename = get_closest_training_images_by_clip(furniture_artist.value, prompt, '../VA-design-generator/images-labelled/furniture')
            img = Image.open(f'../VA-design-generator/images-labelled/furniture/{img_filename}')
        display(img.resize((int(img.width*0.3), int(img.height*0.3))))

        

In [50]:
button_run_2.on_click(on_button_clicked_2)

In [51]:


def get_closest_training_images_by_clip(artist, prompt, directory):
    from transformers import CLIPProcessor, CLIPModel
    from PIL import Image
    from tqdm import tqdm
    model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
    processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
    filenames = os.listdir(directory)
    score = 0
    to_search = []
    for f in filenames:
        if artist in f:
            to_search.append(f)
    for i, f in enumerate(tqdm(to_search)):
        image = Image.open(f'{directory}/{f}')
        inputs = processor(text=[prompt], images = image, return_tensors = 'pt', padding=True)
        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image 
        s = logits_per_image.item()
        if s > score:
            score = s
            index = i
    return filenames[index]

    

In [52]:
display(button_run_2, output)

Button(description='Explore training images', layout=Layout(height='50px', width='400px'), style=ButtonStyle(b…

Output()

In [53]:
vbox_result_2 = widgets.VBox([button_run_2, output])