In [None]:
import os
from PIL import Image, ImageChops
import matplotlib.pyplot as plt
import random

# Fetch the "Where's Waldo" dataset from Kaggle and unpack it under dataset/.
# Make sure you have already created a Kaggle API token and updated kaggle.json
!kaggle datasets download -d residentmario/wheres-waldo -p dataset/
# -o overwrites existing files, so re-running this cell does not stop at a prompt.
!unzip -o dataset/wheres-waldo.zip -d dataset/
# NOTE(review): ASSET_PATH is not defined in the visible cells; presumably it must
# exist as a Python variable or shell environment variable before this runs — confirm.
# NOTE(review): later cells read dataset/mondrian/png, which nothing here downloads;
# that data presumably has to be provided separately — confirm.
!curl $ASSET_PATH/flickr256/vis_model.pth --output ml-mdm/vis_model_256x256.pth # Apple's Matryoshka Diffusion Model

In [None]:
# Remove the white border from the images
def trim(im):
    """Crop away the uniform border surrounding an image.

    The border colour is sampled from the top-left pixel. Every pixel whose
    difference from that colour does not exceed the threshold applied below
    is treated as border.

    Args:
        im: PIL image to crop.

    Returns:
        The image cropped to the bounding box of its non-border content. If
        no such content is found (e.g. the image is a single flat colour),
        ``im`` is returned unchanged rather than ``None`` so callers can
        always save or display the result.
    """
    # Solid canvas filled with the corner colour, used as the "border" reference.
    bg = Image.new(im.mode, im.size, im.getpixel((0, 0)))
    diff = ImageChops.difference(im, bg)
    # ImageChops.add computes (a + b) / scale + offset, i.e. diff - 100 here,
    # clipped at zero: small deviations from the border colour are discarded.
    diff = ImageChops.add(diff, diff, 2.0, -100)
    bbox = diff.getbbox()
    if bbox:
        return im.crop(bbox)
    # Nothing stands out from the border colour: there is nothing to trim.
    return im

# Trim every Mondrian PNG and write the result into the "trimmed" folder.
# Image.save does not create missing directories, so make sure the target exists.
os.makedirs('dataset/mondrian/trimmed', exist_ok=True)

for image_file in os.listdir('dataset/mondrian/png'):
    if not image_file.endswith('.png'):
        continue
    # The context manager closes the underlying file handle after each image.
    with Image.open(os.path.join('dataset/mondrian/png', image_file)) as im:
        trimmed = trim(im)
        if trimmed is None:
            # trim() found nothing to crop; keep the image as it is.
            trimmed = im
        trimmed.save(os.path.join('dataset/mondrian/trimmed', image_file))

In [None]:
def create_image_grid(mondrian_folder, waldo_folder, grid_size=(3, 6),
                      output_path='image_grid.png', figsize=(20, 10)):
    """Plot a grid of random Mondrian and Waldo images side by side.

    The left ``cols // 2`` columns show ``.png`` files sampled from
    ``mondrian_folder``; the remaining columns show ``.jpg`` files sampled
    from ``waldo_folder``. Cells for which no image is available stay blank.
    The figure is saved to ``output_path`` and then displayed.

    Args:
        mondrian_folder: Directory containing the Mondrian ``.png`` images.
        waldo_folder: Directory containing the Waldo ``.jpg`` images.
        grid_size: ``(rows, cols)`` of the grid.
        output_path: File the rendered figure is written to.
        figsize: Figure size passed to ``plt.subplots``.
    """
    n_rows, n_cols = grid_size[0], grid_size[1]
    left_cols = n_cols // 2

    # Sort the listings: os.listdir order is arbitrary, and a stable order
    # makes the sample reproducible when the random module is seeded.
    mondrian_files = sorted(f for f in os.listdir(mondrian_folder) if f.endswith('.png'))
    waldo_files = sorted(f for f in os.listdir(waldo_folder) if f.endswith('.jpg'))

    mondrian_count = n_rows * left_cols
    waldo_count = n_rows * (n_cols - left_cols)

    # Never request more files than exist; the iterators simply run dry and
    # the corresponding cells are left empty.
    selected_mondrian = iter(random.sample(mondrian_files, min(mondrian_count, len(mondrian_files))))
    selected_other = iter(random.sample(waldo_files, min(waldo_count, len(waldo_files))))

    # squeeze=False keeps `axes` two-dimensional even for a single row or
    # column, so axes[row, col] is valid for every grid size.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    for row in range(n_rows):
        for col in range(n_cols):
            ax = axes[row, col]
            if col < left_cols:
                img_file = next(selected_mondrian, None)
                folder = mondrian_folder
            else:
                img_file = next(selected_other, None)
                folder = waldo_folder

            if img_file:
                # imshow copies the pixel data, so the file can be closed right away.
                with Image.open(os.path.join(folder, img_file)) as img:
                    ax.imshow(img)
            ax.axis('off')

    fig.tight_layout()
    fig.savefig(output_path)
    plt.show()

# Render the comparison grid from the trimmed Mondrian images and the 256px Waldo crops.
create_image_grid(
    mondrian_folder='dataset/mondrian/trimmed',
    waldo_folder='dataset/waldo/256/waldo',
)