In [None]:
import torch
from torch_geometric.loader import DataLoader
from _04_mnist_digits.graph_dataset import GraphDataset # Needed for loading pickled dataset
from torchvision import datasets, transforms

In [2]:
# Transform that converts PIL images to tensors when samples are fetched
# through the dataset's __getitem__.
transform = transforms.ToTensor()

# Load both MNIST splits from the same root with the same transform.
# No `download=True` is passed, so the files are expected to already
# exist under the root directory.
mnist_kwargs = {"root": '/data', "transform": transform}
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)


In [3]:
# Pull the raw image and label tensors out of each split.
images_train, labels_train = train_dataset.data, train_dataset.targets
images_test, labels_test = test_dataset.data, test_dataset.targets

# Show the four shapes, one per line: images first, then labels.
print(
    *(t.shape for t in (images_train, images_test, labels_train, labels_test)),
    sep="\n",
)

torch.Size([60000, 28, 28])
torch.Size([10000, 28, 28])
torch.Size([60000])
torch.Size([10000])


In [4]:
# Merge the train and test splits into one pool along the sample axis
# (torch.cat concatenates along dim 0 by default).
images = torch.cat([images_train, images_test])
labels = torch.cat([labels_train, labels_test])

# Show the merged shapes, one per line.
print(*(merged.shape for merged in (images, labels)), sep="\n")

torch.Size([70000, 28, 28])
torch.Size([70000])


In [6]:
def find_centroid(segmented_image):
    """Compute the centroid of every segment in a label map.

    Parameters
    ----------
    segmented_image : numpy.ndarray
        Label map in which each distinct value marks one segment.

    Returns
    -------
    dict
        Maps each distinct label to a NumPy array holding the mean index of
        that label's pixels along every axis (``(row, col)`` for a 2-D map).

    Notes
    -----
    Uses the module-level name ``np``, so NumPy must be imported as ``np``
    before the first call.
    """
    # argwhere yields one row of indices per matching pixel; averaging those
    # rows over axis 0 gives the segment's centroid.
    return {
        segment_id: np.argwhere(segmented_image == segment_id).mean(axis=0)
        for segment_id in np.unique(segmented_image)
    }

In [8]:
from shared.img_to_graph import img_to_graph
import skimage.segmentation as segm
import skimage.graph as g
import numpy as np

# Build a region-adjacency graph for one image:
#   X            - per-node mean intensity, scaled by 1/255
#   centroids    - per-node (row, col) centroid of the segment
#   edge_index   - (2, num_edges) node-row indices of adjacent segments
#   edge_weights - per-edge weight derived from centroid distance
img = images[0]

# Accept non-tensor inputs as well; isinstance also admits Tensor subclasses.
if not isinstance(img, torch.Tensor):
    img = torch.as_tensor(img)

# Replicate a single-channel image to three channels, giving (H, W, 3).
if img.ndim == 2:
    img = img.unsqueeze(-1).repeat(1, 1, 3)
elif img.ndim == 3 and img.shape[2] == 1:
    img = img.repeat(1, 1, 3)

quickshift_params = {"kernel_size": 3, "max_dist": 4, "ratio": 0.4}
rag_params = {"mode": "similarity"}

img_np = img.numpy()
segmented_image = segm.quickshift(img_np, **quickshift_params)
graph: g.RAG = g.rag_mean_color(img_np, segmented_image, **rag_params)

# graph.nodes iterates in insertion order, which need not match the segment
# labels. Fix one explicit label -> row mapping and use it for X, centroids
# and edge_index alike, so that all three refer to the same node numbering.
# (When the labels are already 0..K-1 the mapping is the identity.)
node_ids = sorted(graph.nodes)
node_pos = {node: idx for idx, node in enumerate(node_ids)}

edges = list(graph.edges)
# reshape(-1, 2) keeps the (2, num_edges) layout even when there are no edges.
edge_index = torch.tensor(
    [[node_pos[n1], node_pos[n2]] for n1, n2 in edges], dtype=torch.long
).reshape(-1, 2).t().contiguous()  # (2, num_edges)

X = torch.zeros(len(node_ids))  # (mean color of each node)
centroids = torch.zeros((len(node_ids), 2))  # (row, col) coordinates

centroids_dict = find_centroid(segmented_image)

for node_idx, node in enumerate(node_ids):
    node_attr = graph.nodes[node]["mean color"]
    X[node_idx] = torch.mean(torch.tensor(node_attr))
    centroids[node_idx, :] = torch.tensor(centroids_dict[node])

# Euclidean distance between the centroids of each edge's endpoints.
src, dst = edge_index
edge_weights = torch.sqrt(
    torch.sum((centroids[src] - centroids[dst]) ** 2, dim=1))  # (num_edges,)

# Normalizations
# Divide by sqrt(2) * height, i.e. the diagonal of a square image.
# NOTE(review): this assumes H == W - confirm for non-square inputs.
edge_weights = edge_weights / (torch.sqrt(torch.tensor(2.)) * img.shape[0])
X = X / 255  # Scale mean colors by 1/255 (assumes 8-bit intensities)
# -log turns short distances into large weights and long ones into small.
edge_weights = -torch.log(edge_weights)
# Rescale so the largest weight is 1; skip when there is nothing to scale.
if edge_weights.numel() > 0:
    max_weight = edge_weights.max()
    if max_weight > 0:
        edge_weights = edge_weights / max_weight


In [12]:
from pprint import pprint

# Pretty-print the attribute dictionary stored on every node of the graph,
# in the graph's own node iteration order.
for node, node_attrs in graph.nodes(data=True):
    pprint(node_attrs)

{'labels': [np.int64(20)],
 'mean color': array([1.11695906, 1.11695906, 1.11695906]),
 'pixel count': 342,
 'total color': array([382., 382., 382.])}
{'labels': [np.int64(18)],
 'mean color': array([0.5443787, 0.5443787, 0.5443787]),
 'pixel count': 169,
 'total color': array([92., 92., 92.])}
{'labels': [np.int64(0)],
 'mean color': array([131., 131., 131.]),
 'pixel count': 2,
 'total color': array([262., 262., 262.])}
{'labels': [np.int64(3)],
 'mean color': array([171., 171., 171.]),
 'pixel count': 3,
 'total color': array([513., 513., 513.])}
{'labels': [np.int64(7)],
 'mean color': array([249.60606061, 249.60606061, 249.60606061]),
 'pixel count': 33,
 'total color': array([8237., 8237., 8237.])}
{'labels': [np.int64(1)],
 'mean color': array([127., 127., 127.]),
 'pixel count': 1,
 'total color': array([127., 127., 127.])}
{'labels': [np.int64(9)],
 'mean color': array([100.5, 100.5, 100.5]),
 'pixel count': 2,
 'total color': array([201., 201., 201.])}
{'labels': [np.int64(2)