In [11]:
# @title imports

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from skimage.segmentation import slic, mark_boundaries

import torchvision
import torchvision.models as models

import math

import os

from torch.utils.data import random_split, DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

# from torchvision.datasets import MNIST, CIFAR100
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

import pickle

import matplotlib.pyplot as plt

from skimage.segmentation import slic
from skimage.measure import regionprops
from skimage import filters, graph
import torch
import numpy as np
from torch_geometric.utils.convert import from_networkx
from torchvision.transforms import Resize

import networkx as nx

# Code

In [12]:
import os
import pickle

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch_geometric.transforms import AddLaplacianEigenvectorPE

from skimage.segmentation import slic
from skimage.measure import regionprops
from skimage import filters, graph
import torch
import numpy as np
from torch_geometric.utils.convert import from_networkx

import os
import pickle

import torch
from torchvision.transforms import Resize
from torch.nn.utils.rnn import pad_sequence
from torch_geometric.data import Data

def extract_patches(img, seg, reg):
    """Crop one image patch and one boolean mask per segmented region.

    Args:
        img: Channel-first image indexed as ``img[:, rows, cols]``.
        seg: 2-D label map indexed as ``seg[rows, cols]``.
        reg: Sequence of region objects exposing ``bbox``, ``label`` and
            ``centroid`` (e.g. the result of ``regionprops(seg)``).

    Returns:
        A tuple ``(imgs, masks, coords)`` of three lists, one entry per region
        and in the same order as ``reg``: the bounding-box crop of ``img``,
        the boolean mask of the region inside that crop, and the centroid.
    """
    imgs, masks, coords = [], [], []
    for region in reg:
        # bbox is (min_row, min_col, max_row, max_col) with exclusive maxima.
        row_min, col_min, row_max, col_max = region.bbox
        imgs.append(img[:, row_min:row_max, col_min:col_max])
        # Compare against the region's own label rather than its position in
        # `reg` + 1: the two only coincide when labels run 1..N without gaps,
        # and otherwise the mask would belong to a different region than the crop.
        masks.append(seg[row_min:row_max, col_min:col_max] == region.label)
        coords.append(region.centroid)

    return imgs, masks, coords

def random_split(data, ratios, dataset_name):
    if len(ratios) == 2:
        train_ratio, test_ratio = ratios
        val_ratio = 0
    elif len(ratios) == 3:
        train_ratio, val_ratio, test_ratio = ratios
    else:
        raise ValueError("ratios must be of length 2 or 3")
    
    save_path = os.sep.join([".", "data", "split", dataset_name, f"{train_ratio}-{val_ratio}-{test_ratio}.pkl"])
    if os.path.isfile(save_path):
        idx_train, idx_val, idx_test = pickle.load(open(save_path, "rb"))
        print("Loaded existing data split")
        
    else:
        n = len(data)
        t = [int(train_ratio*n), int((train_ratio+val_ratio)*n)]
        
        idx = torch.randperm(n)
        idx_train, idx_val, idx_test = idx[:t[0]], idx[t[0]:t[1]], idx[t[1]:]
        
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        pickle.dump((idx_train, idx_val, idx_test), open(save_path, "wb"))
        print("Created new data split")
    
    train_data, val_data, test_data = [], [], []
    for i in idx_train:
        train_data.append(data[i])
    for i in idx_val:
        val_data.append(data[i])
    for i in idx_test:
        test_data.append(data[i])
    
    return train_data, val_data, test_data

def resize_stack_slic_graph_patches(data, size):
    """Mask, normalise, resize and stack the patches of every graph in ``data``.

    For each graph, every patch is multiplied by its mask and divided by the
    mask's pixel count, resized to ``size`` and concatenated along dim 0 into a
    single tensor stored back on ``g.imgs``. Graphs are modified in place.

    Args:
        data: Iterable of graph objects with ``imgs`` and ``masks`` lists.
        size: Target size handed to ``Resize``.

    Returns:
        The same ``data`` object.
    """
    resize = Resize(size)

    for graph_idx, g in enumerate(data):
        for patch_idx, patch in enumerate(g.imgs):
            patch_mask = g.masks[patch_idx]
            masked_patch = torch.Tensor(patch * patch_mask)
            # Number of pixels belonging to the superpixel inside this crop.
            mask_area = torch.sum(torch.Tensor(patch_mask), dim=(0, 1))
            g.imgs[patch_idx] = (masked_patch / mask_area).unsqueeze(0)

        g.imgs = torch.cat([resize(patch) for patch in g.imgs], dim=0)

        # A 3-D result has no channel axis; insert one at dim 1.
        if g.imgs.dim() == 3:
            g.imgs = g.imgs.unsqueeze(1)

        if graph_idx % 1000 == 0:
            print("Processed graph", graph_idx)

    return data

def collate_slic_graph_patches(batch):
    """Collate graphs with differing patch counts into padded batch tensors.

    Args:
        batch: Sequence of objects with tensor attributes ``imgs`` and
            ``centroid`` (first dimension = number of patches) and a label ``y``.

    Returns:
        ``(imgs, coords, mask, y)`` where ``imgs`` and ``coords`` are zero-padded
        to the longest graph in the batch, ``mask`` is a boolean tensor that is
        ``True`` at padded positions, and ``y`` is a ``torch.long`` label tensor.
    """
    patch_counts = torch.tensor([len(g.imgs) for g in batch])
    longest = torch.max(patch_counts)

    # Position p of row b is padding exactly when p >= patch_counts[b].
    positions = torch.arange(longest).unsqueeze(0)
    mask = positions >= patch_counts.unsqueeze(1)

    padded_imgs = pad_sequence([g.imgs for g in batch], batch_first=True)
    padded_coords = pad_sequence([g.centroid for g in batch], batch_first=True)
    labels = torch.tensor([g.y for g in batch], dtype=torch.long)

    return padded_imgs, padded_coords, mask, labels

def generate_all_edges(h, w):
    """Build the directed 8-neighbour edge list of an ``h`` x ``w`` pixel grid.

    Pixel ``(x, y)`` (column ``x``, row ``y``) gets node id ``y * w + x``.
    Every pixel is linked to each of its up-to-eight neighbours that lies
    inside the grid, so each adjacency appears once per direction.

    Args:
        h: Grid height (number of rows).
        w: Grid width (number of columns).

    Returns:
        Integer tensor of shape ``(2, num_edges)``; row 0 holds source node
        ids and row 1 the matching target node ids.
    """
    # (dx, dy) steps: left, top-left, top, top-right, right, bottom-right,
    # bottom, bottom-left.
    offsets = (
        (-1, 0), (-1, -1), (0, -1), (1, -1),
        (1, 0), (1, 1), (0, 1), (-1, 1),
    )

    # One (x, y) entry per pixel; with 'ij' indexing over (x, y), x is the slow axis.
    xs, ys = torch.meshgrid(torch.arange(w), torch.arange(h), indexing='ij')
    xs, ys = xs.reshape(-1), ys.reshape(-1)

    sources, targets = [], []
    for step_x, step_y in offsets:
        nbr_x, nbr_y = xs + step_x, ys + step_y

        # Keep only neighbours that fall inside the image.
        in_bounds = (nbr_x >= 0) & (nbr_x < w) & (nbr_y >= 0) & (nbr_y < h)

        sources.append(ys[in_bounds] * w + xs[in_bounds])
        targets.append(nbr_y[in_bounds] * w + nbr_x[in_bounds])

    return torch.stack([torch.cat(sources), torch.cat(targets)], dim=0)

def image_to_pygraph(data):
    """Turn an ``(image, label)`` pair into a pixel-level graph.

    Each pixel becomes a node whose feature vector is its channel values;
    edges come from ``generate_all_edges`` and ``coords`` lists the
    ``(row, col)`` position of every node in the same row-major order.

    Args:
        data: Tuple ``(img, y)`` with ``img`` of shape ``(c, h, w)``.

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``y`` and ``coords``.
    """
    image, label = data
    channels, height, width = image.shape

    # (c, h*w) -> (h*w, c): one row of channel values per pixel.
    node_features = image.reshape(channels, -1).transpose(0, 1)
    pixel_coords = torch.cartesian_prod(torch.arange(height), torch.arange(width))

    return Data(
        x=node_features,
        edge_index=generate_all_edges(height, width),
        y=label,
        coords=pixel_coords,
    )

def image_to_SLIC_graph(data, n_segments=14*14, compactness=0.5, save_img=False):
    """Convert an ``(image, label)`` pair into a superpixel region-adjacency graph.

    The image is segmented with ``slic``; a boundary RAG is built from the
    segmentation and a Sobel edge map and converted with ``from_networkx``.
    Per-node patches, masks and centroids from ``extract_patches`` are attached.

    Args:
        data: Tuple ``(img, y)``; ``img`` must be a ``torch.Tensor`` of shape
            ``(1, h, w)`` or ``(3, h, w)``.
        n_segments: Passed to ``slic``.
        compactness: Passed to ``slic``.
        save_img: When true, also store the numpy image, the segmentation and
            the edge map on the graph.

    Returns:
        The graph with ``centroid``, ``imgs``, ``masks`` and ``y`` set.
    """
    img, y = data

    assert type(img) == torch.Tensor and len(img.shape) == 3 and img.shape[0] in (1, 3)

    if img.shape[0] == 3:
        # Colour image: channels-last for skimage, channel mean for the edge map.
        img_np = np.array(img.permute(1, 2, 0))
        channel_axis = -1
        gray = np.mean(img_np, axis=2)
    else:
        # Single channel: drop the channel axis entirely.
        img_np = np.array(img.squeeze(0))
        channel_axis = None
        gray = img_np

    seg = slic(img_np, n_segments=n_segments, compactness=compactness, channel_axis=channel_axis)
    reg = regionprops(seg)

    edge_boundary = filters.sobel(gray)
    g = from_networkx(graph.rag_boundary(seg, edge_boundary))
    if save_img:
        g.img = img_np
        g.seg = seg
        g.edge_boundary = edge_boundary

    imgs, masks, coords = extract_patches(img, seg, reg)
    # Each node carries its segment label; label - 1 is used as the index into the
    # per-region lists (assumes labels start at 1 without gaps — TODO confirm).
    region_slots = [label[0] - 1 for label in g.labels]
    g.centroid = torch.Tensor([coords[slot] for slot in region_slots])
    g.imgs = [imgs[slot] for slot in region_slots]
    g.masks = [masks[slot] for slot in region_slots]
    g.y = y

    return g

class ImageClassificationDataset(Dataset):
    """Base container for an image-classification dataset convertible to graphs.

    ``self.data`` holds the samples. Helper methods replace it with pixel or
    SLIC graphs (cached as pickles under ``<root>/graph``), split it into
    train/val/test lists, and build data loaders over those lists.
    """

    def __init__(self, root, *args, **kwargs) -> None:
        super().__init__()
        self.root = root
        self.data = []
        self.train_data, self.val_data, self.test_data = [], [], []
        self.train_loader, self.val_loader, self.test_loader = None, None, None

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    @staticmethod
    def _load_pickle(path):
        """Read one pickled object from ``path``, closing the file afterwards."""
        # NOTE: pickle can execute arbitrary code; only load caches written by this class.
        with open(path, "rb") as f:
            return pickle.load(f)

    @staticmethod
    def _dump_pickle(obj, path):
        """Pickle ``obj`` to ``path``, creating parent directories as needed."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(obj, f)

    def to_pixel_graphs(self, laplacian_pe=False, **kwargs):
        """Replace ``self.data`` with pixel graphs, using ``<root>/graph/pixel.pkl`` as cache.

        Args:
            laplacian_pe: When true, compute Laplacian eigenvector positional
                encodings and append them to each graph's ``x``.
            **kwargs: Forwarded to ``AddLaplacianEigenvectorPE``.
        """
        g_path = os.sep.join([self.root, "graph", "pixel.pkl"])
        if os.path.isfile(g_path):
            self.data = self._load_pickle(g_path)
            print("Loaded existing pixel graphs")

        else:
            data = []
            for idx, datum in enumerate(self.data):
                data.append(image_to_pygraph(datum))
                if idx % 1000 == 0:
                    print("Processed image", idx)
            self._dump_pickle(data, g_path)
            print("Converted image dataset to pixel graphs")

            self.data = data

        if laplacian_pe: # not recommended; use in the model unless overhead is too high
            transform = AddLaplacianEigenvectorPE(**kwargs)
            for idx, datum in enumerate(self.data):
                pe = transform(datum).laplacian_eigenvector_pe
                datum.x = torch.cat((datum.x, pe), dim=-1)
                if idx % 1000 == 0:
                    print("Processed image", idx)
            print("Precomputed Laplacian PE")

    def _convert_to_slic(self, n_segments, compactness):
        """Return SLIC graphs for ``self.data``, cached per ``(n_segments, compactness)``."""
        g_path = os.sep.join([self.root, "graph", f"SLIC_{n_segments}_{compactness}.pkl"])
        if os.path.isfile(g_path):
            data = self._load_pickle(g_path)
            print("Loaded existing SLIC graphs")

        else:
            data = []
            for idx, datum in enumerate(self.data):
                data.append(image_to_SLIC_graph(datum, n_segments, compactness))
                if idx % 1000 == 0:
                    print("Processed image", idx)
            self._dump_pickle(data, g_path)
            print("Converted image dataset to SLIC graphs")

        return data

    def to_slic_graphs(self, n_segments, compactness, resize_stack_patches=None):
        """Replace ``self.data`` with SLIC graphs.

        Args:
            n_segments: SLIC segment count.
            compactness: SLIC compactness.
            resize_stack_patches: If not ``None``, patches are additionally
                resized to this size and stacked; that result has its own cache file.
        """
        if resize_stack_patches is not None:
            g_path = os.sep.join([self.root, "graph", f"SLIC_{n_segments}_{compactness}_resized_stacked_{resize_stack_patches}.pkl"])
            if os.path.isfile(g_path):
                self.data = self._load_pickle(g_path)
                print("Loaded existing resized and stacked graphs")

            else:
                self.data = self._convert_to_slic(n_segments, compactness)

                print("Resizing and stacking patches")
                self.data = resize_stack_slic_graph_patches(self.data, resize_stack_patches)
                self._dump_pickle(self.data, g_path)

        else:
            self.data = self._convert_to_slic(n_segments, compactness)

    def splits(self, ratios):
        """Return ``(train, val, test)`` lists, computing them on first use.

        Args:
            ratios: Fractions of length 2 or 3 that must sum to 1.

        Raises:
            AssertionError: If the ratios do not sum to 1 (within tolerance).
        """
        # Imported locally so the method does not rely on another cell's imports.
        import math

        # Exact `== 1` can fail from float rounding (e.g. 0.7 + 0.2 + 0.1), so
        # compare with a small absolute tolerance instead.
        assert math.isclose(sum(ratios), 1.0, rel_tol=0.0, abs_tol=1e-9), "ratios must sum to 1"

        if self.train_data == []:
            # Key the cached split by the concrete class name so different
            # datasets never share a split file; for the CIFAR100 subclass
            # this is still "CIFAR100".
            self.train_data, self.val_data, self.test_data = random_split(self.data, ratios, type(self).__name__)
        return self.train_data, self.val_data, self.test_data

    def loaders(self, batch_size, shuffles=(True, False, False), num_workers=0, graph_loader=False, collate_fn=None):
        """Return ``(train, val, test)`` loaders over the current splits.

        Loaders are built on the first call and cached; later calls return the
        cached loaders regardless of the arguments passed.

        Args:
            batch_size: Batch size for all three loaders.
            shuffles: Per-split shuffle flags in train/val/test order.
            num_workers: Worker count for all three loaders.
            graph_loader: Use the PyG loader instead of the torch ``DataLoader``.
            collate_fn: Passed to each loader as ``collate_fn``.
        """
        LoaderClass = PyGDataLoader if graph_loader else DataLoader
        if self.train_loader is None:
            split_sets = (self.train_data, self.val_data, self.test_data)
            self.train_loader, self.val_loader, self.test_loader = (
                LoaderClass(
                    split,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    num_workers=num_workers,
                    collate_fn=collate_fn,
                )
                for split, shuffle in zip(split_sets, shuffles[:3])
            )
        return self.train_loader, self.val_loader, self.test_loader
    
class CIFAR100(ImageClassificationDataset):
    """Dataset wrapper whose ``data`` is a ``torchvision.datasets.CIFAR100`` instance."""

    def __init__(self, root, transform=None, download=True) -> None:
        super().__init__(root)

        # Replace the empty default list with the torchvision dataset object.
        cifar = datasets.CIFAR100(root, transform=transform, download=download)
        self.data = cifar


class MNIST(ImageClassificationDataset):
    """Dataset wrapper whose ``data`` is a ``torchvision.datasets.MNIST`` instance."""

    def __init__(self, root, transform=None, download=True) -> None:
        super().__init__(root)

        # Replace the empty default list with the torchvision dataset object.
        mnist = datasets.MNIST(root, transform=transform, download=download)
        self.data = mnist
    

In [13]:
# Preprocessing pipeline handed to the dataset: resize to 32x32, convert to a
# tensor, then normalize.
# NOTE(review): Normalize(0, 1) uses mean 0 and std 1, which presumably leaves
# values unchanged — confirm this is intended rather than dataset statistics.
transform = Compose([
    Resize((32, 32)),
    ToTensor(),
    Normalize(0, 1),
])
# Build the CIFAR100 wrapper; graph caches are written under this root by the
# to_*_graphs methods.
dataset = CIFAR100(root="../data/image/CIFAR100", download=True, transform=transform)

Files already downloaded and verified


In [16]:
# SLIC parameters for the superpixel graph conversion.
# NOTE(review): resize_stack_patches is assigned here, but the call below passes
# resize_stack_patches=None, so the (7, 7) value is not used by this cell and
# patches stay as raw per-node lists.
resize_stack_patches = (7, 7)
n_segments = 16
compactness = 15
dataset.to_slic_graphs(resize_stack_patches=None, n_segments=n_segments, compactness=compactness)

Loaded existing SLIC graphs


In [81]:
# Scratch example: a random 2-channel 3x3 integer image and a random 0/1 mask.
# No seed is set, so the values differ on every run.
img = np.random.randint(10, size=(2, 3, 3))
mask = np.random.randint(2, size=(3, 3))
# Display the mask as booleans next to the image.
mask.astype(bool), img

(array([[ True, False, False],
        [ True, False,  True],
        [False, False,  True]]),
 array([[[6, 7, 2],
         [2, 5, 7],
         [3, 6, 7]],
 
        [[2, 0, 3],
         [7, 8, 5],
         [8, 9, 2]]]))

In [90]:
# Add a leading channel axis to the 2-D mask and repeat it once per image channel.
mask.reshape(1, mask.shape[0], mask.shape[1]).repeat(img.shape[0], axis=0)

array([[[1, 0, 0],
        [1, 0, 1],
        [0, 0, 1]],

       [[1, 0, 0],
        [1, 0, 1],
        [0, 0, 1]]])

In [95]:
# Masked arrays hide entries where the mask argument is True, so the expanded
# mask is inverted to keep only the pixels selected by `mask`.
masked_img = np.ma.masked_array(img, ~mask.reshape(1, mask.shape[0], mask.shape[1]).repeat(img.shape[0], axis=0).astype(bool))
# Per-channel statistics over the two spatial axes. Only the last expression
# (max) is displayed as the cell output; the others are evaluated and discarded.
masked_img.mean(axis=(-1,-2))
masked_img.std(axis=(-1,-2))
masked_img.min(axis=(-1,-2))
masked_img.max(axis=(-1,-2))

masked_array(data=[7, 7],
             mask=[False, False],
       fill_value=999999)

In [110]:
def lrgb_statistics(img, mask):
    """Compute mean, std, min and max of the pixels selected by ``mask``.

    Args:
        img: 2-D ``(h, w)`` or 3-D ``(c, h, w)`` array-like image.
        mask: 2-D ``(h, w)`` mask; truthy entries mark pixels to include.

    Returns:
        A 1-D ``torch`` tensor. For a 2-D image: ``[mean, std, min, max]``.
        For a 3-D image the statistics are taken per channel over the last two
        axes and laid out as ``[means..., stds..., mins..., maxs...]``
        (length ``4 * c``).

    Raises:
        NotImplementedError: If ``img`` is neither 2-D nor 3-D.
    """
    # Coerce to a boolean numpy array so `~` below is a logical NOT. On a 0/1
    # integer mask `~` is a bitwise NOT (giving -1/-2, both truthy), which
    # would hide every pixel.
    mask = np.asarray(mask, dtype=bool)

    if len(img.shape) == 3:
        # Repeat the 2-D mask once per channel so it matches the image shape.
        expanded_mask = mask.reshape(1, mask.shape[0], mask.shape[1]).repeat(img.shape[0], axis=0)
        # masked_array hides entries where its mask is True, hence the inversion.
        masked_img = np.ma.masked_array(img, ~expanded_mask)
        axes = (-1, -2)
        res = torch.tensor(np.concatenate([
            masked_img.mean(axis=axes),
            masked_img.std(axis=axes),
            masked_img.min(axis=axes),
            masked_img.max(axis=axes),
        ]))
        return res

    elif len(img.shape) == 2:
        masked_img = np.ma.masked_array(img, ~mask)
        res = torch.tensor(np.array([
            masked_img.mean(),
            masked_img.std(),
            masked_img.min(),
            masked_img.max(),
        ]))
        return res

    else:
        raise NotImplementedError()

# Smoke test on the first channel of the first patch of graph b[0].
# NOTE(review): `b` is not defined in any of the cells shown here, so this
# relies on kernel state from elsewhere — confirm where `b` comes from.
lrgb_statistics(b[0].imgs[0][0,:,:], b[0].masks[0])

tensor([0.8540, 0.2255, 0.0863, 1.0000], dtype=torch.float64)

In [112]:
def lrgb_statistics(img, mask):
    """Compute mean, std, min and max of the pixels selected by ``mask``.

    Note: a function with this name is also defined in an earlier cell; when
    both cells run, this later definition is the one in effect.

    Args:
        img: 2-D ``(h, w)`` or 3-D ``(c, h, w)`` array-like image.
        mask: 2-D ``(h, w)`` mask; truthy entries mark pixels to include.

    Returns:
        A 1-D ``torch`` tensor. For a 2-D image: ``[mean, std, min, max]``.
        For a 3-D image the statistics are taken per channel over the last two
        axes and laid out as ``[means..., stds..., mins..., maxs...]``
        (length ``4 * c``).

    Raises:
        NotImplementedError: If ``img`` is neither 2-D nor 3-D.
    """
    # Coerce to a boolean numpy array so `~` below is a logical NOT. On a 0/1
    # integer mask `~` is a bitwise NOT (giving -1/-2, both truthy), which
    # would hide every pixel.
    mask = np.asarray(mask, dtype=bool)

    if len(img.shape) == 3:
        # Repeat the 2-D mask once per channel so it matches the image shape.
        expanded_mask = mask.reshape(1, mask.shape[0], mask.shape[1]).repeat(img.shape[0], axis=0)
        # masked_array hides entries where its mask is True, hence the inversion.
        masked_img = np.ma.masked_array(img, ~expanded_mask)
        axes = (-1, -2)
        res = torch.tensor(np.concatenate([
            masked_img.mean(axis=axes),
            masked_img.std(axis=axes),
            masked_img.min(axis=axes),
            masked_img.max(axis=axes),
        ]))
        return res

    elif len(img.shape) == 2:
        masked_img = np.ma.masked_array(img, ~mask)
        res = torch.tensor(np.array([
            masked_img.mean(),
            masked_img.std(),
            masked_img.min(),
            masked_img.max(),
        ]))
        return res

    else:
        raise NotImplementedError()

def slic_graph_patches_to_lrgb_stats(data):
    """Attach per-patch summary statistics to every graph as node features.

    Each patch is reduced by ``lrgb_statistics`` to its mean, std, min and max
    (4c values for a c-channel patch); the per-patch vectors are stacked into
    ``g.x``. Graphs are modified in place.

    Args:
        data: Iterable of graph objects with ``imgs`` and ``masks`` lists.

    Returns:
        The same ``data`` object.
    """
    for graph_idx, g in enumerate(data):
        patch_stats = [
            lrgb_statistics(patch, g.masks[patch_idx])
            for patch_idx, patch in enumerate(g.imgs)
        ]
        g.x = torch.stack(patch_stats, dim=0)

        if graph_idx % 1000 == 0:
            print("Processed graph", graph_idx)

    return data
    
# Compute statistics features for every graph in `b` (mutates the graphs in place).
# NOTE(review): `b` is not defined in any of the cells shown here — confirm its origin.
slic_graph_patches_to_lrgb_stats(b)

Processed graph 0


[Data(edge_index=[2, 46], labels=[12, 1], weight=[46], count=[46], num_nodes=12, centroid=[12, 2], imgs=[12], masks=[12], y=19, batch=[1], x=[12, 12]),
 Data(edge_index=[2, 42], labels=[11, 1], weight=[42], count=[42], num_nodes=11, centroid=[11, 2], imgs=[11], masks=[11], y=29, x=[11, 12]),
 Data(edge_index=[2, 52], labels=[12, 1], weight=[52], count=[52], num_nodes=12, centroid=[12, 2], imgs=[12], masks=[12], y=0, x=[12, 12]),
 Data(edge_index=[2, 36], labels=[10, 1], weight=[36], count=[36], num_nodes=10, centroid=[10, 2], imgs=[10], masks=[10], y=11, x=[10, 12]),
 Data(edge_index=[2, 54], labels=[13, 1], weight=[54], count=[54], num_nodes=13, centroid=[13, 2], imgs=[13], masks=[13], y=1, x=[13, 12]),
 Data(edge_index=[2, 32], labels=[9, 1], weight=[32], count=[32], num_nodes=9, centroid=[9, 2], imgs=[9], masks=[9], y=86, x=[9, 12]),
 Data(edge_index=[2, 42], labels=[11, 1], weight=[42], count=[42], num_nodes=11, centroid=[11, 2], imgs=[11], masks=[11], y=90, x=[11, 12]),
 Data(edge