In [1]:
import numpy as np 
import pandas as pd
import cv2
import os
import random
import matplotlib.pyplot as plt

In [2]:
# load image from npy file
images = np.load('tranformed_cavallo.npy')
print(images.shape)

(2623, 120, 120, 3)


In [3]:
def divide_image(image, dim):
    parts = []
    height, width, _ = image.shape
    subpart_width = width // dim
    subpart_height = height // dim
    
    for i in range(dim):
        for j in range(dim):
            part = image[i*subpart_height:(i+1)*subpart_height, j*subpart_width:(j+1)*subpart_width]
            parts.append(part)
    
    return parts

In [4]:
def generate_combinations(parts, num_combinations):
    combinations = []
    original_positions = []
    indices = list(range(len(parts)))
    
    for _ in range(num_combinations):
        temp = []
        random.shuffle(indices)
        combination = [parts[i] for i in indices]
        #further divide each part into 4x4 subpart
        for i in range(len(combination)):
            temp += divide_image(combination[i], 4)

        combinations.append(temp)
        original_positions.append(indices.copy())
    
    return combinations, original_positions

In [5]:
# parts = divide_image(images[1], 3)
# combinations, original_positions = generate_combinations(parts, 100)

In [88]:
def divide_image(image):
    subparts = []
    height, width, _ = image.shape
    subpart_width = width // 3
    subpart_height = height // 3
    
    for i in range(3):
        for j in range(3):
            part = image[i*subpart_height:(i+1)*subpart_height, j*subpart_width:(j+1)*subpart_width]
            for k in range(4):
                for l in range(4):
                    subpart = part[k*subpart_height//4:(k+1)*subpart_height//4, l*subpart_width//4:(l+1)*subpart_width//4]
                    subparts.append(subpart)
    
    return subparts

In [7]:
import torch
from torch_geometric.data import Data

In [8]:
def create_graph(subparts):
    # Define the number of nodes and edges
    num_nodes = 144
    internal_edges = []  # List to store internal edges within each 4x4 grid
    part_adjacent_edges = []  # List to store edges between adjacent parts
    
    # Define internal edges between consecutive subparts in each 4x4 grid (tile)
    for part_idx in range(9):  # There are 9 parts in the 3x3 grid
        # Calculate the starting index for the current part
        start_idx = part_idx * 16

        # Connect each node in the 4x4 grid
        for row in range(4):
            for col in range(4):
                current_index = start_idx + row * 4 + col

                # Connect to the right neighbor (if exists)
                if col < 3:
                    right_neighbor_index = current_index + 1
                    internal_edges.append((current_index, right_neighbor_index))
                    internal_edges.append((right_neighbor_index, current_index))

                # Connect to the bottom neighbor (if exists)
                if row < 3:
                    bottom_neighbor_index = current_index + 4
                    internal_edges.append((current_index, bottom_neighbor_index))
                    internal_edges.append((bottom_neighbor_index, current_index))

    # Define edges between adjacent parts (horizontally and vertically)
    # Horizontal connections
    for part_row in range(3):  # 3 rows of parts
        for part_col in range(2):  # 2 columns of adjacent parts
            # Calculate the start index of each adjacent part
            part1_start_idx = (part_row * 3 + part_col) * 16
            part2_start_idx = (part_row * 3 + part_col + 1) * 16
            
            # Connect each node in the rightmost column of part1 with the leftmost column of part2
            for row in range(4):
                part1_node_index = part1_start_idx + row * 4 + 3  # Rightmost column in part1
                part2_node_index = part2_start_idx + row * 4  # Leftmost column in part2
                
                part_adjacent_edges.append((part1_node_index, part2_node_index))
                part_adjacent_edges.append((part2_node_index, part1_node_index))

    # Vertical connections
    for part_col in range(3):  # 3 columns of parts
        for part_row in range(2):  # 2 rows of adjacent parts
            # Calculate the start index of each adjacent part
            part1_start_idx = (part_row * 3 + part_col) * 16
            part2_start_idx = ((part_row + 1) * 3 + part_col) * 16
            
            # Connect each node in the bottommost row of part1 with the topmost row of part2
            for col in range(4):
                part1_node_index = part1_start_idx + (3 * 4) + col  # Bottommost row in part1
                part2_node_index = part2_start_idx + col  # Topmost row in part2
                
                part_adjacent_edges.append((part1_node_index, part2_node_index))
                part_adjacent_edges.append((part2_node_index, part1_node_index))

    # Combine internal edges and part adjacent edges into one edge list
    # Create an attribute list to label each edge type
    edges = internal_edges + part_adjacent_edges
    edge_type = [0] * len(internal_edges) + [1] * len(part_adjacent_edges)  # 0 for internal, 1 for adjacent
    
    # Convert edges list to a tensor and transpose it
    edge_index = torch.tensor(edges).T
    edge_attr = torch.tensor(edge_type)  # Attribute tensor for edge types

    # Convert subparts list to a tensor of shape `(num_nodes, C, H, W)`
    # Assuming each subpart is a tensor with shape (C, H, W)
    subparts = [torch.tensor(subpart).view(-1) for subpart in subparts]
    node_features = torch.stack(subparts)  # Stack list of tensors into one tensor

    # Create the Data object for PyTorch Geometric
    graph_data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)

    return graph_data

In [9]:
import networkx as nx

In [10]:
def create_data(imgs):
    for i in imgs:
        parts = divide_image(i, 3)
        combinations, _ = generate_combinations(parts, 1)
        for comb in combinations:
            graph_data = create_graph(comb)
            yield graph_data

In [11]:
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

In [12]:
class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(
            dim=-1
        )  # product of a pair of nodes on each edge

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

In [13]:
@torch.no_grad()
def eval_link_predictor(model, data):

    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()

    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())

In [14]:
def filter_adjacent_edges(graph_data):
    # Access edge_index and edge_attr tensors
    edge_index = graph_data.edge_index
    edge_attr = graph_data.edge_attr

    # Lists to store the indices of edges with attribute `1`
    adjacent_edge_indices = []

    # Iterate through each edge
    for i in range(edge_index.shape[1]):
        # Check the edge attribute
        if edge_attr[i].item() == 1:
            adjacent_edge_indices.append(i)

    # Filter edge_index and edge_attr tensors to include only edges with attribute `1`
    filtered_edge_index = edge_index[:, adjacent_edge_indices]
    filtered_edge_attr = edge_attr[adjacent_edge_indices]

    # all idx other tham filtered_edge_index
    all_idx = list(range(edge_index.shape[1]))
    for idx in adjacent_edge_indices:
        all_idx.remove(idx)

    graph_data_edge_index = edge_index[:, all_idx]
    graph_data_edge_attr = edge_attr[all_idx]

    graph_data = Data(
        x=graph_data.x,
        edge_index=graph_data_edge_index,
        edge_attr= graph_data_edge_attr,
    )

    graph_data_adjacent = Data(
        x=graph_data.x,
        edge_index=filtered_edge_index,
        edge_attr=filtered_edge_attr
    )
    
    return graph_data, graph_data_adjacent

In [54]:
# def combine_edges(graph_data, split_data):
#     """
#     Combines internal edges from `graph_data` with the edges from `split_data`.
#     """
#     # Get internal edges from the original graph data
#     internal_edge_index = graph_data.edge_index
#     internal_edge_attr = graph_data.edge_attr

#     # Get edges and attributes from the split data (train/val/test)
#     split_edge_index = split_data.edge_index
#     split_edge_attr = split_data.edge_attr

#     # Combine edge indices and attributes
#     combined_edge_index = torch.cat((internal_edge_index, split_edge_index), dim=1)
#     combined_edge_attr = torch.cat((internal_edge_attr, split_edge_attr))

#     # Create a new Data object with the combined edges
#     combined_data = Data(
#         x=graph_data.x,  # Use the same node features
#         num_nodes = graph_data.x.shape[0],
#         edge_index=combined_edge_index,
#         edge_attr=combined_edge_attr
#     )

#     return combined_data

def combine_edges(graph_data, split_data):
    """
    Combines internal edges from `graph_data` with the edges from `split_data`, including labels.
    """
    # Combine edge indices and attributes
    combined_edge_index = torch.cat((graph_data.edge_index, split_data.edge_index), dim=1)
    combined_edge_attr = torch.cat((graph_data.edge_attr, split_data.edge_attr))

    # Check if edge_label and edge_label_index are present and combine them if they are
    if hasattr(graph_data, 'edge_label') and hasattr(split_data, 'edge_label'):
        combined_edge_label = torch.cat((graph_data.edge_label, split_data.edge_label))
        combined_edge_label_index = torch.cat((graph_data.edge_label_index, split_data.edge_label_index),dim=1)

    # Create a new Data object
    combined_data = Data(
        x=graph_data.x,  # Use the same node features
        num_nodes = graph_data.x.shape[0],
        edge_index=combined_edge_index,
        edge_attr=combined_edge_attr,
        edge_label=combined_edge_label if 'combined_edge_label' in locals() else None,
        edge_label_index=combined_edge_label_index if 'combined_edge_label_index' in locals() else None
    )

    return combined_data


In [55]:
# def train_link_predictor(
#     model, train_loader, val_loader, optimizer, criterion, n_epochs=100
# ):

#     for epoch in range(1, n_epochs + 1):
#         model.train()
#         total_loss = 0

#         # Iterate through the DataLoader
#         for batch in train_loader:
#             # Zero out the gradients
#             optimizer.zero_grad()

#             # batch.edge_index = batch.edge_index.T

#             # Encode the batch's node features and edge indices
#             z = model.encode(batch.x.float(), batch.edge_index)

#             # Sample negative edges
#             neg_edge_index = negative_sampling(
#                 edge_index=batch.edge_index, num_nodes=batch.num_nodes,
#                 num_neg_samples=batch.edge_label_index.size(1),
#                 method='sparse'
#             )

#             # Combine positive and negative edge indices and labels
#             edge_label_index = torch.cat(
#                 [batch.edge_label_index, neg_edge_index],
#                 dim=-1,
#             )
#             edge_label = torch.cat([
#                 batch.edge_label,
#                 batch.edge_label.new_zeros(neg_edge_index.size(1))
#             ], dim=0)

#             # Decode and compute the loss
#             out = model.decode(z, edge_label_index).view(-1)
#             loss = criterion(out, edge_label)
#             loss.backward()
#             optimizer.step()

#             total_loss += loss.item()

#             print("hi")

#         # Compute the average training loss for the epoch
#         avg_train_loss = total_loss / len(train_loader)

#         # Evaluate the model on validation data
#         val_auc = eval_link_predictor(model, val_loader)

#         if epoch % 10 == 0:
#             print(f"Epoch: {epoch:03d}, Avg Train Loss: {avg_train_loss:.3f}, Val AUC: {val_auc:.3f}")

#     return model


In [96]:
import torch_geometric.transforms as T

train_data = []
val_data = []
test_data = []

for i in range(images.shape[0]):

    parts = divide_image(images[i])

    graph_data = create_graph(parts)


    graph_data, graph_data_adjacent = filter_adjacent_edges(graph_data)

    split = T.RandomLinkSplit(
        num_val=0.33,
        num_test=0.33,
        is_undirected=True,
        add_negative_train_samples=True,
        neg_sampling_ratio=1.0,
    )
    split_2 = T.RandomLinkSplit(
        num_val=0,
        num_test=0,
        is_undirected=True,
        add_negative_train_samples=True,
        neg_sampling_ratio=1.0,
    )

    train_data_adjacent, val_data_adjacent, test_data_adjacent = split(graph_data_adjacent)
    # print(train_data_adjacent)
    train_graph_data, val_graph_data, test_graph_data = split_2(graph_data)
    # print(train_graph_data)

    # Combine internal edges with training data adjacent edges
    train_data_combined = combine_edges(train_graph_data, train_data_adjacent)
    train_data.append(train_data_combined)



    # Combine internal edges with validation data adjacent edges
    val_data_combined = combine_edges(val_graph_data, val_data_adjacent)
    val_data.append(val_data_combined)
    # Combine internal edges with testing data adjacent edges
    test_data_combined = combine_edges(test_graph_data, test_data_adjacent)
    test_data.append(test_data_combined)

Data(x=[144, 300], edge_index=[2, 36], edge_attr=[36], edge_label=[30], edge_label_index=[2, 30])
Data(x=[144, 300], edge_index=[2, 432], edge_attr=[432], edge_label=[0], edge_label_index=[2, 0])


In [90]:
print(val_data_combined)

Data(x=[144, 300], edge_index=[2, 468], edge_attr=[468], num_nodes=144, edge_label=[30], edge_label_index=[2, 30])


In [79]:
train_data[0].x.float


<function Tensor.float>

In [91]:
from torch_geometric.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=1)



In [92]:
for batch in train_dataloader:
    print(batch)

    break

DataBatch(x=[144, 300], edge_index=[2, 468], edge_attr=[468], num_nodes=144, edge_label=[468], edge_label_index=[2, 468], batch=[144], ptr=[2])


In [93]:
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import negative_sampling


def train_link_predictor(
    model, train_data, val_data, optimizer, criterion, n_epochs=100
):

    for epoch in range(1, n_epochs + 1):
        for idx in range(len(train_data)):
            model.train()
            optimizer.zero_grad()
            data = train_data[idx]
            z = model.encode(data.x.float(), data.edge_index)

            # sampling training negatives for every training epoch
            neg_edge_index = negative_sampling(
                edge_index=data.edge_index, num_nodes=data.num_nodes,
                num_neg_samples=data.edge_label_index.size(1), method='sparse')

            edge_label_index = torch.cat(
                [data.edge_label_index, neg_edge_index],
                dim=-1,
            )
            edge_label = torch.cat([
                data.edge_label,
                data.edge_label.new_zeros(neg_edge_index.size(1))
            ], dim=0)

            out = model.decode(z, edge_label_index).view(-1)
            loss = criterion(out, edge_label)
            loss.backward()
            optimizer.step()

            val_auc = eval_link_predictor(model, val_data[idx])
            
        print(f"Epoch: {epoch:03d}, Train Loss: {loss:.3f}, Val AUC: {val_auc:.3f}")

    return model


@torch.no_grad()
def eval_link_predictor(model, data):

    model.eval()
    z = model.encode(data.x.float(), data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()

    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())

In [94]:
model = Net(300, 128, 64)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
criterion = torch.nn.BCEWithLogitsLoss()
model = train_link_predictor(model, train_data, val_data, optimizer, criterion)
test_auc = eval_link_predictor(model, test_data)
print(f"Test: {test_auc:.3f}")

Epoch: 001, Train Loss: 0.696, Val AUC: 0.520
Epoch: 002, Train Loss: 0.695, Val AUC: 0.500
Epoch: 003, Train Loss: 0.694, Val AUC: 0.500
Epoch: 004, Train Loss: 0.693, Val AUC: 0.500
Epoch: 005, Train Loss: 0.693, Val AUC: 0.500
Epoch: 006, Train Loss: 0.693, Val AUC: 0.500
Epoch: 007, Train Loss: 0.693, Val AUC: 0.500
Epoch: 008, Train Loss: 0.693, Val AUC: 0.500
Epoch: 009, Train Loss: 0.693, Val AUC: 0.500
Epoch: 010, Train Loss: 0.694, Val AUC: 0.500
Epoch: 011, Train Loss: 0.693, Val AUC: 0.500
Epoch: 012, Train Loss: 0.704, Val AUC: 0.500
Epoch: 013, Train Loss: 0.693, Val AUC: 0.500


KeyboardInterrupt: 