In [1]:
!pip install einops



In [None]:
# partial draft via chatgpt, untested

from PIL import Image
import torch
from scipy.spatial.distance import pdist, squareform
from scipy.optimize import linear_sum_assignment
from einops import rearrange

def multiplex(latents, K):
    """
    Tile a flat batch of latents into K-by-K mosaics.

    Each consecutive group of ``K * K`` entries along the first axis is laid
    out row-major into one mosaic of size ``(K*h, K*w)``.

    Parameters
    ----------
    latents: torch.Tensor
        Rank-3 tensor of shape ``(B*K*K, h, w)``.
    K: int
        Number of tiles along each side of a mosaic.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(B, K*h, K*w)``.
    """
    tile_counts = dict(k1=K, k2=K)
    pattern = '(b k1 k2) h w -> b (k1 h) (k2 w)'
    return rearrange(latents, pattern, **tile_counts)

def diffusion_generate(latents):
    """
    Run the diffusion model on a batch of multiplexed latents.

    Not implemented yet.

    Parameters
    ----------
    latents: torch.Tensor
        The multiplexed latents, as produced by ``multiplex``.

    Raises
    ------
    NotImplementedError
        Always, until a diffusion backend is wired in.
    """
    # TODO: Implement this function. Raising (instead of silently returning
    # None) makes the missing piece fail here rather than in a later step.
    raise NotImplementedError("diffusion_generate is not implemented yet")

def demultiplex(images, K):
    """
    Split K-by-K mosaics back into a flat batch of tiles (inverse of ``multiplex``).

    Parameters
    ----------
    images: torch.Tensor
        Rank-3 tensor of shape ``(B, K*h, K*w)``.
    K: int
        Number of tiles along each side of a mosaic.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(B*K*K, h, w)``, with tiles ordered row-major
        within each mosaic.
    """
    return rearrange(
        images,
        'b (k1 h) (k2 w) -> (b k1 k2) h w',
        k1=K,
        k2=K,
    )

def tsp_sort(images, num_fixed=3):
    """
    Reorder images into a short visual path while keeping a fixed prefix in place.

    The first ``num_fixed`` images keep their positions. The remaining images
    are appended greedily: at each step, the not-yet-placed image with the
    smallest cosine distance to the most recently placed one is chosen. This
    is the nearest-neighbour heuristic for the travelling-salesman problem.
    When ``num_fixed`` is 0, the path starts from image 0.

    Parameters
    ----------
    images: torch.Tensor
        Tensor whose first dimension indexes the images. All remaining
        dimensions are flattened into one feature vector per image.
    num_fixed: int, optional
        The number of leading images that are already in the correct order
        and are kept in place. Clamped to ``[0, len(images)]``. Default is 3.

    Returns
    -------
    torch.Tensor
        ``images`` indexed by the computed order (same shape as the input).
    """
    n = images.shape[0]
    if n == 0:
        return images
    num_fixed = max(0, min(int(num_fixed), n))

    # Flatten so that similarity compares whole images, then compute the full
    # pairwise cosine-distance matrix with one matrix product.
    flat = images.reshape(n, -1).float()
    unit = torch.nn.functional.normalize(flat, dim=1)
    distance_matrix = 1.0 - unit @ unit.T

    # The fixed prefix is the start of the path; with no fixed prefix, start at image 0.
    order = list(range(num_fixed)) or [0]
    remaining = list(range(len(order), n))
    while remaining:
        candidate_distances = distance_matrix[order[-1], remaining]
        nearest = int(torch.argmin(candidate_distances))
        order.append(remaining.pop(nearest))

    return images[torch.tensor(order, dtype=torch.long, device=images.device)]

def decode_latents(latents):
    """
    Decode latents back into images.

    Not implemented yet.

    Parameters
    ----------
    latents: torch.Tensor
        The latents to decode (``process_iteration`` passes the frozen frames).

    Raises
    ------
    NotImplementedError
        Always, until a decoder is wired in.
    """
    # TODO: Implement this function. Raising (instead of silently returning
    # None) makes the missing piece fail here rather than in a later step.
    raise NotImplementedError("decode_latents is not implemented yet")

def _save_decoded_frames(frozen_latents, save_path, iteration):
    """Decode ``frozen_latents`` and write each frame to ``save_path`` as a PNG."""
    import os

    decoded = decode_latents(frozen_latents)
    # NOTE(review): assumes decode_latents returns a float tensor shaped
    # (N, C, H, W) with values in [-1, 1] and C in {1, 3}; confirm once
    # decode_latents is implemented.
    decoded = ((decoded + 1) / 2).clamp(0, 1)
    frames = (
        (decoded * 255).round().to(torch.uint8)
        .permute(0, 2, 3, 1)
        .detach()
        .cpu()
        .numpy()
    )

    os.makedirs(save_path, exist_ok=True)
    for index, frame in enumerate(frames):
        if frame.shape[-1] == 1:
            frame = frame[..., 0]  # PIL expects (H, W) for single-channel images
        filename = os.path.join(save_path, f"{iteration:05d}_{index:03d}.png")
        Image.fromarray(frame).save(filename)


def process_iteration(latents, K, save_path, iteration, num_fixed=None, persist=True):
    """
    Process a single iteration: multiplex, generate images, demultiplex, TSP sort, and slide & freeze.

    The latents are assumed to already include the last K latents from the previous step.

    Parameters
    ----------
    latents: torch.Tensor
        Flat batch of latents in the ``(b k1 k2) h w`` layout accepted by ``multiplex``.
    K: int
        The grid size.
    save_path: str or None
        Directory where decoded frozen frames are written. Nothing is
        written when this is ``None``.
    iteration: int
        Iteration index, used in the saved file names.
    num_fixed: int, optional
        Number of leading entries that ``tsp_sort`` keeps in place.
        Defaults to ``K``.
    persist: bool, optional
        When False, the decode-and-save step is skipped. Default is True.

    Returns
    -------
    torch.Tensor
        The sorted entries after slide & freeze: the entries before index
        ``K**2 - K``, followed by the last ``K`` entries rolled forward by one.
    """
    if num_fixed is None:
        num_fixed = K

    latents = multiplex(latents, K)
    images = diffusion_generate(latents)
    images = demultiplex(images, K)
    images = tsp_sort(images, num_fixed=num_fixed)

    # Slide and freeze.
    # NOTE(review): the slicing assumes demultiplex returned exactly K**2
    # entries (a single mosaic); confirm for batched input.
    frozen_images = images[:K**2 - K]
    shifted_images = torch.roll(images[-K:], shifts=1, dims=0)
    images = torch.cat((frozen_images, shifted_images), dim=0)

    # Decode the frozen images and save them to disk, unless disabled.
    if persist and save_path is not None:
        _save_decoded_frames(frozen_images, save_path, iteration)

    return images

###############################################################
                       
import torch
from PIL import Image

def encode_image(image):
    """
    Encode an RGB image to a latent.

    Abstract function; not implemented yet.

    Parameters
    ----------
    image: torch.Tensor
        The RGB image to encode.

    Raises
    ------
    NotImplementedError
        Always, until an encoder is wired in.
    """
    # TODO: Implement this function. Raising (instead of silently returning
    # None) makes the missing piece fail here rather than in a later step.
    raise NotImplementedError("encode_image is not implemented yet")

def gaussian_blur(image, sigma):
    """
    Apply a Gaussian blur to an image tensor.

    Abstract function; not implemented yet.

    Parameters
    ----------
    image: torch.Tensor
        The tensor to blur.
    sigma: float
        Standard deviation of the Gaussian kernel.

    Raises
    ------
    NotImplementedError
        Always, until a blur implementation is wired in.
    """
    # TODO: Implement this function. Raising (instead of silently returning
    # None) makes the missing piece fail here rather than in a later step.
    raise NotImplementedError("gaussian_blur is not implemented yet")

def initialize_latent_grid(init_image_path, K, noise_stddev, blur_sigma):
    """
    Initialize a KxK latent grid from an initial image.

    The image is loaded as RGB and encoded to a latent. The latent is then
    tiled K times along each of its first two dimensions, perturbed with
    Gaussian noise, and blurred.

    Parameters
    ----------
    init_image_path: str
        Path to the initial image file.
    K: int
        The size of the latent grid.
    noise_stddev: float
        Standard deviation of the Gaussian noise to add.
    blur_sigma: float
        Sigma parameter for the Gaussian blur.

    Returns
    -------
    torch.Tensor
        The initialized latent grid, as returned by ``gaussian_blur``.
    """
    # numpy is not imported at module level in this file.
    import numpy as np

    # Load the initial image as RGB (encode_image expects an RGB image) and
    # release the file handle as soon as the pixels have been read.
    with Image.open(init_image_path) as pil_image:
        rgb_image = pil_image.convert("RGB")
    # (H, W, 3) float tensor with values in [0, 1].
    image = torch.from_numpy(np.array(rgb_image)).float() / 255.0

    # Encode the image to a latent
    latent = encode_image(image)

    # Tile the latent into a KxK grid.
    # NOTE(review): for a rank-2 latent (h, w) this yields (K, K, h, w), while
    # multiplex expects a rank-3 '(b k1 k2) h w' input; confirm the intended
    # layout once encode_image is implemented.
    latent_grid = latent.repeat(K, K, 1, 1)

    # Add Gaussian noise
    noise = torch.randn_like(latent_grid) * noise_stddev
    latent_grid += noise

    # Apply Gaussian blur
    latent_grid = gaussian_blur(latent_grid, blur_sigma)

    # Run a process iteration with no fixed frames and skip persistence.
    # NOTE(review): the return value is discarded, so the grid returned below
    # is the pre-iteration grid; confirm this warm-up call is intended.
    process_iteration(latent_grid, K, save_path=None, iteration=0, num_fixed=0, persist=False)

    return latent_grid
