In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 4, 5)
        self.conv2 = nn.Conv2d(4, 12, 5)
        self.fc1 = nn.Linear(12*4*4, 10)

        self.layers = nn.ModuleDict({
            'conv1': self.conv1,
            'conv2': self.conv2,
            'fc1': self.fc1
        })

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.avg_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.avg_pool2d(x, 2)
        x = x.view(-1, 12*4*4)
        x = self.fc1(x)
        return x
    

model = Model(1, 10)
summary(model, input_size=(1, 1, 28, 28))

In [None]:
# Convert model to onnx
dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(model, dummy_input, "model.onnx", verbose=True)



In [None]:
import onnx

# Load an ONNX model
model_path = 'model.onnx'
model = onnx.load(model_path)

# Print a summary of the model
print('IR version:', model.ir_version)
print('Producer name:', model.producer_name)
print('Graph name:', model.graph.name)
print('Number of nodes:', len(model.graph.node))

print()

# List all nodes
for node in model.graph.node:
    print('Node:', node.name)
    print('Inputs:', node.input)
    print('Outputs:', node.output)
    print('Type:', node.op_type)
    print()


In [None]:
from graphviz import Digraph
import onnx


def format_attributes(attributes: list[onnx.AttributeProto]) -> str:
    """Format node attributes for display."""
    attr_str = ""
    for attr in attributes:
        # Convert attribute to a readable format based on type
        if attr.type == onnx.AttributeProto.INT:
            attr_value = attr.i
        elif attr.type == onnx.AttributeProto.FLOAT:
            attr_value = attr.f
        elif attr.type == onnx.AttributeProto.STRING:
            attr_value = attr.s.decode()  # Assuming byte string
        elif attr.type == onnx.AttributeProto.INTS:
            attr_value = list(attr.ints)
        elif attr.type == onnx.AttributeProto.FLOATS:
            attr_value = list(attr.floats)
        elif attr.type == onnx.AttributeProto.STRINGS:
            attr_value = [x.decode() for x in attr.strings]
        else:
            attr_value = "[Unsupported attribute type]"
        attr_str += f"{attr.name}: {attr_value}\n"
    return attr_str.strip()


def visualize_model_flow_with_attributes(onnx_model: onnx.ModelProto) -> Digraph:
    dot = Digraph(comment="Model Visualization with Attributes", format="png")

    # Create a mapping from output tensors to node names
    tensor_to_node = {}

    # Add nodes for input and output tensors
    for input_tensor in onnx_model.graph.input:
        tensor_to_node[input_tensor.name] = input_tensor.name

    # Create nodes for each operation
    for node in onnx_model.graph.node:
        for output in node.output:
            tensor_to_node[output] = node.name

    for output_tensor in onnx_model.graph.output:
        tensor_to_node[output_tensor.name] = output_tensor.name

    print(tensor_to_node)

    # Add nodes and edges
    for node in onnx_model.graph.node:
        attributes = format_attributes(node.attribute)
        label = f"{node.op_type}\n{node.name}\n{attributes}"
        dot.node(node.name, label, shape="box")
        for input_tensor in node.input:
            if input_tensor in tensor_to_node:
                # Connect the output of the previous node to the current node
                dot.edge(tensor_to_node[input_tensor], node.name)

    return dot


# Load your model
model_path = "model.onnx"
model = onnx.load(model_path)

# Visualize the model flow with attributes
dot = visualize_model_flow_with_attributes(model)
# Save and render the visualization
dot.render("output/model_flow_with_attributes_visualization", view=True)

In [None]:
model.graph.output

In [3]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from PIL import Image
import os
import tqdm

device = "mps"

def load_images(image_dir, tile_size):
    transform = transforms.Compose([
        transforms.Resize(tile_size),  # Resize to uniform size
        transforms.ToTensor()           # Convert to tensor
    ])
    dataset = ImageFolder(root=image_dir, transform=transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    all_images, _ = next(iter(loader))
    return all_images.to(device)

def tensor_to_image(tensor):
    return transforms.ToPILImage()(tensor)

def assemble_mosaic(target_image_path, small_images, tile_size=(64, 64), mosaic_size=(1024, 1024)):
    target_image = Image.open(target_image_path)
    target_image = target_image.resize(mosaic_size)
    target_tensor = transforms.ToTensor()(target_image).to(device)

    # Calculate number of tiles
    num_tiles_x = mosaic_size[0] // tile_size[0]
    num_tiles_y = mosaic_size[1] // tile_size[1]

    print(f'Number of tiles: {num_tiles_x} x {num_tiles_y}')

    # Initialize mosaic tensor
    mosaic = torch.zeros(3, mosaic_size[1], mosaic_size[0], device=device)

    for i in tqdm.tnrange(num_tiles_x):
        for j in range(num_tiles_y):
            x = i * tile_size[0]
            y = j * tile_size[1]
            region = target_tensor[:, y:y+tile_size[1], x:x+tile_size[0]]
            avg_color = region.reshape(3, -1).mean(dim=1)

            # Find the closest tile
            distances = torch.norm(small_images - avg_color[:, None, None], dim=1, p=2).mean([1, 2])
            closest_img_idx = torch.argmin(distances)
            closest_img = small_images[closest_img_idx]

            # Place tile into mosaic
            mosaic[:, y:y+tile_size[1], x:x+tile_size[0]] = closest_img

    return tensor_to_image(mosaic)

# Usage
tile_size = (32, 32)
mosaic_size = (4096, 4096)
image_dir = 'data/celeba_hq/val'
target_image_path = 'data/celeba_hq/val/male/000080.jpg'

small_images = load_images(image_dir, tile_size)
mosaic = assemble_mosaic(target_image_path, small_images, tile_size, mosaic_size)
mosaic.save('mosaic.jpg')


Number of tiles: 128 x 128


  for i in tqdm.tnrange(num_tiles_x):


  0%|          | 0/128 [00:00<?, ?it/s]