In [3]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from PIL import Image
import os
import tqdm

device = "mps"

def load_images(image_dir, tile_size):
    transform = transforms.Compose([
        transforms.Resize(tile_size),  # Resize to uniform size
        transforms.ToTensor()           # Convert to tensor
    ])
    dataset = ImageFolder(root=image_dir, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    all_images, _ = next(iter(loader))
    return all_images.to(device)

def tensor_to_image(tensor):
    return transforms.ToPILImage()(tensor)

def assemble_mosaic(target_image_path, small_images, tile_size=(64, 64), mosaic_size=(1024, 1024)):
    target_image = Image.open(target_image_path)
    target_image = target_image.resize(mosaic_size)
    target_tensor = transforms.ToTensor()(target_image).to(device)

    # Calculate number of tiles
    num_tiles_x = mosaic_size[0] // tile_size[0]
    num_tiles_y = mosaic_size[1] // tile_size[1]

    print(f'Number of tiles: {num_tiles_x} x {num_tiles_y}')

    # Initialize mosaic tensor
    mosaic = torch.zeros(3, mosaic_size[1], mosaic_size[0], device=device)

    for i in tqdm.tnrange(num_tiles_x):
        for j in range(num_tiles_y):
            x = i * tile_size[0]
            y = j * tile_size[1]
            region = target_tensor[:, y:y+tile_size[1], x:x+tile_size[0]]
            avg_color = region.reshape(3, -1).mean(dim=1)

            # Find the closest tile
            distances = torch.norm(small_images - avg_color[:, None, None], dim=1, p=2).mean([1, 2])
            closest_img_idx = torch.argmin(distances)
            closest_img = small_images[closest_img_idx]

            # Place tile into mosaic
            mosaic[:, y:y+tile_size[1], x:x+tile_size[0]] = closest_img

    return tensor_to_image(mosaic)

# Usage
tile_size = (32, 32)
mosaic_size = (4096, 4096)
image_dir = 'data/celeba_hq/val'
target_image_path = 'data/celeba_hq/val/male/000080.jpg'

small_images = load_images(image_dir, tile_size)
mosaic = assemble_mosaic(target_image_path, small_images, tile_size, mosaic_size)
mosaic.save('mosaic.jpg')


Number of tiles: 128 x 128


  for i in tqdm.tnrange(num_tiles_x):


  0%|          | 0/128 [00:00<?, ?it/s]