In [1]:
import numpy as np
import torch
from PIL import Image
from timm.data.transforms_factory import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import matplotlib.pyplot as plt
import random
import torch.nn as nn
import math
import warnings
warnings.filterwarnings('ignore')

In [2]:

def preprocess_image(image_path, input_size=(300, 300), mean=(0.485, 0.456, 0.406, 0.0), std=(0.229, 0.224, 0.225, 1.0)):
    """Load an image as RGBA and convert it into a batched tensor.

    Args:
        image_path: Location of the image, passed straight to ``Image.open``.
        input_size: Target size forwarded to timm's ``create_transform``.
        mean: Per-channel mean forwarded to ``create_transform``. It has four
            entries because the image is converted to RGBA.
        std: Per-channel std forwarded to ``create_transform`` (four entries,
            same reason).

    Returns:
        The transformed image with a leading batch dimension added
        (``unsqueeze(0)``), moved to the GPU when CUDA is available.
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically; convert() hands back an independent in-memory image.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGBA')

    # Evaluation-mode transform (is_training=False); all options are passed
    # through to timm unchanged.
    transform = create_transform(
        input_size=input_size,
        is_training=False,
        use_prefetcher=False,
        no_aug=False,
        interpolation='bilinear',
        mean=mean,
        std=std,
        crop_pct=None,
        tf_preprocessing=False,
    )

    # Add the batch axis expected by the downstream modules.
    tensor_image = transform(image).unsqueeze(0)

    if torch.cuda.is_available():
        tensor_image = tensor_image.cuda()

    return tensor_image

# Preprocess the sample image once; the later cells reuse this batched RGBA tensor.
preprocessed_image = preprocess_image('test2.png')


In [3]:
# Sanity check: a leading batch axis, four channels (RGBA), then the spatial size.
preprocessed_image.shape

torch.Size([1, 4, 300, 300])

In [4]:
# Disabled 2x2 variant, parked inside a bare string literal so it never runs
# (the notebook merely echoes the text back as the cell output).
# NOTE(review): as written it could not simply be re-enabled - the default
# num_tiles=9 fails its own `!= 4` check, and H // 3, W // 3 would produce
# unequal tiles instead of four equal quadrants.
''' for 4 tiles 
def cut_into_tiles_tensor(image_tensor, num_tiles=9):

    if num_tiles != 4:
        raise ValueError("This function currently supports only 4 tiles.")
    
    _, C, H, W = image_tensor.shape  
    tile_height, tile_width = H // 3, W // 3
    
    tiles = {
        'tl': image_tensor[:, :, :tile_height, :tile_width],
        'tr': image_tensor[:, :, :tile_height, tile_width:],
        'bl': image_tensor[:, :, tile_height:, :tile_width],
        'br': image_tensor[:, :, tile_height:, tile_width:]
    }
    
    return tiles

tiles = cut_into_tiles_tensor(preprocessed_image, 4)

'''

' for 4 tiles \ndef cut_into_tiles_tensor(image_tensor, num_tiles=9):\n\n    if num_tiles != 4:\n        raise ValueError("This function currently supports only 4 tiles.")\n    \n    _, C, H, W = image_tensor.shape  \n    tile_height, tile_width = H // 3, W // 3\n    \n    tiles = {\n        \'tl\': image_tensor[:, :, :tile_height, :tile_width],\n        \'tr\': image_tensor[:, :, :tile_height, tile_width:],\n        \'bl\': image_tensor[:, :, tile_height:, :tile_width],\n        \'br\': image_tensor[:, :, tile_height:, tile_width:]\n    }\n    \n    return tiles\n\ntiles = cut_into_tiles_tensor(preprocessed_image, 4)\n\n'

In [5]:
def cut_into_tiles_tensor(image_tensor, num_tiles=9):
    """Split a batched image tensor into an equal grid of named tiles.

    Args:
        image_tensor: 4-D tensor; the last two axes are treated as height and
            width and are the only ones that get sliced.
        num_tiles: Grid size. 9 (default) gives a 3x3 grid, 4 gives a 2x2 grid.

    Returns:
        Dict mapping a two-letter position code to a view of ``image_tensor``,
        in row-major order. Rows are 't'/'m'/'b' (top/middle/bottom) and
        columns 'l'/'c'/'r' (left/centre/right); the 2x2 grid uses only
        't'/'b' and 'l'/'r'.

    Raises:
        ValueError: If ``num_tiles`` is not 4 or 9, or if height/width are not
            divisible by the grid side.
    """
    # Row/column name fragments per supported layout; keys are built as row + column.
    layouts = {
        4: (('t', 'b'), ('l', 'r')),
        9: (('t', 'm', 'b'), ('l', 'c', 'r')),
    }
    if num_tiles not in layouts:
        raise ValueError("This function currently supports only 4 or 9 tiles.")

    row_names, col_names = layouts[num_tiles]
    side = len(row_names)

    _, _, H, W = image_tensor.shape
    # Equal-sized tiles are required so they can be re-assembled in any order.
    if H % side != 0 or W % side != 0:
        raise ValueError(f"Image dimensions must be divisible by {side}.")

    tile_height, tile_width = H // side, W // side

    tiles = {}
    for row, row_name in enumerate(row_names):
        top = row * tile_height
        for col, col_name in enumerate(col_names):
            left = col * tile_width
            tiles[row_name + col_name] = image_tensor[
                :, :, top:top + tile_height, left:left + tile_width
            ]

    return tiles

# Split the preprocessed image into the default 3x3 grid of tiles.
tiles = cut_into_tiles_tensor(preprocessed_image)


In [6]:
def create_permuted_images_tensor(tiles, num_permutations):
    """Re-assemble the tiles into randomly shuffled square images.

    Args:
        tiles: Dict of equally sized tile tensors; its length must be a
            perfect square so the tiles fill an n x n grid.
        num_permutations: How many shuffled images to produce.

    Returns:
        All shuffled images stacked along a new leading axis.

    Raises:
        ValueError: If the number of tiles is not a perfect square.
    """
    tile_count = len(tiles)
    grid_side = int(math.sqrt(tile_count))
    if grid_side * grid_side != tile_count:
        raise ValueError("Number of tiles must be a perfect square.")

    # The same list is shuffled in place every round, so each permutation
    # starts from the previous one's order.
    order = list(tiles.keys())
    shuffled_images = []

    for _ in range(num_permutations):
        random.shuffle(order)

        grid_rows = []
        for start in range(0, tile_count, grid_side):
            row_keys = order[start:start + grid_side]
            # dim=3 is the width axis: lay the tiles of one row side by side.
            grid_rows.append(torch.cat([tiles[key] for key in row_keys], dim=3))

        # dim=2 is the height axis: stack the finished rows top to bottom.
        shuffled_images.append(torch.cat(grid_rows, dim=2))

    return torch.stack(shuffled_images)

# Build a single shuffled image. The result is stacked, so index [0] selects it.
# No random seed is set in the cells shown, so the tile order differs between runs.
permuted_images_tensors = create_permuted_images_tensor(tiles, 1)

In [7]:

from vig import Grapher

# Node count handed to Grapher as `n`; 225 = 15 x 15.
# NOTE(review): the tensor fed in below is the full 300x300 permuted image,
# not a 15x15 map - confirm which node count Grapher expects here.
num_patches = 225

grapher_module = Grapher(
    in_channels=4,         # Input is RGBA, so four channels
    kernel_size=3,         # NOTE(review): the printed module contains only 1x1 convs, so this is presumably not a conv kernel size - confirm
    dilation=1,            # Standard dilation
    conv='edge',           # Selects the edge-convolution variant (printed below as EdgeConv2d)
    act='relu',            # A common activation function
    norm=None,             # Depends on whether you want to use normalization
    bias=True,             # Typically, biases are used
    stochastic=False,      # If stochastic depth is not used
    epsilon=0.0,           # NOTE(review): presumably only relevant when stochastic=True - confirm
    r=1,                   # Downsampling rate
    n=num_patches,         # Number of nodes
    drop_path=0.0,         # Drop path rate for stochastic depth
    relative_pos=True,     # Use relative positions (the run log prints "using relative_pos")
    # groups = 1           # Left commented out; the printed EdgeConv2d conv reports groups=4
)

# Keep the module on the same device as the image it will process.
grapher_module = grapher_module.to(permuted_images_tensors[0].device)
print(grapher_module)

# [0] selects the first (only) permuted image, which already carries a batch axis.
graph_output = grapher_module(permuted_images_tensors[0])



using relative_pos
Grapher(
  (fc1): Sequential(
    (0): Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (graph_conv): DyGraphConv2d(
    (gconv): EdgeConv2d(
      (nn): BasicConv(
        (0): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1), groups=4)
        (1): ReLU()
      )
    )
    (dilated_knn_graph): DenseDilatedKnnGraph(
      (_dilated): DenseDilated()
    )
  )
  (fc2): Sequential(
    (0): Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (drop_path): Identity()
)
hi


: 

In [None]:

# Adaptive average pooling to a fixed 14x14 spatial grid, whatever the input size.
pool = nn.AdaptiveAvgPool2d((14, 14))

# Pooled Grapher output; the later cells treat each of its spatial positions as a graph node.
pooled_output = pool(graph_output)

print(pooled_output.shape) 

In [None]:
from gcn_lib import DenseDilatedKnnGraph

# KNN graph builder configured with k=20, no dilation and no stochastic sampling.
dense_dilated_knn_graph = DenseDilatedKnnGraph(k=20, dilation=1, stochastic=False, epsilon=0.0)

num_points = pooled_output.shape[2] * pooled_output.shape[3]  # Spatial positions in the pooled map: 14 * 14 = 196 with the pooling above
pooled_output_reshaped = pooled_output.reshape(1, 4, num_points, 1)  # Flatten the grid into one node axis: [1, 4, num_points, 1] (batch 1 and 4 channels are hard-coded)

# Use the module to find edges
edge_index = dense_dilated_knn_graph(pooled_output_reshaped)

# edge_index now contains the indices of edges based on dilated KNN
edge_index.shape

In [None]:
# Neighbours per node. Read it from the KNN output instead of hard-coding it:
# the graph above is built with k=20, so the former literal 9 did not match
# and skewed every size derived from k.
# NOTE(review): assumes the last axis of edge_index is the neighbour axis
# (layout (2, batch, num_points, k)) - confirm against gcn_lib.
k = edge_index.shape[-1]

# Collapse everything after the leading axis of size 2 into one edge axis: (2, num_edges).
edge_index_flat = edge_index.view(2, -1)

edge_index_flat.shape

In [None]:
# Number of graph nodes. Use the node count that was fed to the KNN module
# rather than inferring it as num_edges // k: with k set to 9 while the graph
# is built with 20 neighbours, the division gave 435 instead of 196.
N = num_points
print(N)

# Dense N x N adjacency matrix, filled in by the next cell.
adjacency_matrix = torch.zeros((N, N))


In [None]:
# Row 0 holds the source node of every edge, row 1 the target node.
# Mark all edges with one indexed assignment instead of a Python loop over
# every edge; the indices are moved to the matrix's device first so the
# assignment also works when the edges were computed on the GPU.
source_nodes, target_nodes = edge_index_flat.to(adjacency_matrix.device)
adjacency_matrix[source_nodes, target_nodes] = 1

edge_list = edge_index_flat.t().tolist()  # List of [source, target] pairs, one per edge

In [None]:
# Take the sizes from the pooled map itself so they cannot drift from the pooling cell.
num_nodes = pooled_output.shape[2] * pooled_output.shape[3]
feature_dimension = pooled_output.shape[1]

# Flatten the spatial grid into one node axis, then drop the batch axis (batch size 1).
graph_features = pooled_output.reshape(1, feature_dimension, num_nodes)
graph_features = graph_features.squeeze(0)  # Shape [feature_dimension, num_nodes], i.e. [4, 196] for the 14x14 map

node_features = {node: [] for node in range(num_nodes)}  # Dictionary to store features for each node

# edge_list rows are [source, target], the same column order the adjacency
# cell uses. The loop previously unpacked them as (target, source), which
# filed each feature under the wrong end of the edge.
for source_node, target_node in edge_list:
    # Each target node collects the feature vectors of the nodes pointing at it.
    node_features[target_node].append(graph_features[:, source_node])

In [None]:
# Echo every [source, target] pair; this prints the complete edge list.
edge_list

In [None]:


# First channel of the first permuted image, used as the backdrop.
# The tensor may live on the GPU (preprocessing moves it there when CUDA is
# available), and matplotlib cannot read CUDA tensors, so copy it to host memory.
image_array = permuted_images_tensors[0][0][0].detach().cpu().numpy()

# Side length of the node grid. Node ids index the 14x14 pooled map (ids
# 0..195), so ids must be decoded with 14. The previous 15-wide decoding and
# 15-cell extent put every node after the first row in the wrong cell.
grid_size = 14

# Hard-coded sample edges: 20 [source, target] pairs that all end at node 5.
edge_nine = [
[5, 5],
[93, 5],
[191, 5],
[107, 5],
[123, 5],
[19, 5],
[121, 5],
[137, 5],
[158, 5],
[94, 5],
[177, 5],
[32, 5],
[163, 5],
[95, 5],
[124, 5],
[18, 5],
[30, 5],
[17, 5],
[72, 5],
[52, 5]]

# Display the image as a background, stretched over the node grid so that one
# grid cell corresponds to one node.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_array, extent=[0, grid_size, grid_size, 0])

ax.set_xlim(0, grid_size)
# Descending limits keep (0, 0) at the top-left, matching image coordinates.
ax.set_ylim(grid_size, 0)
ax.set_aspect('equal')

ax.set_xticks(np.arange(0, grid_size + 1, 1))
ax.set_yticks(np.arange(0, grid_size + 1, 1))

# Overlay the connections on top of the image
for conn in edge_nine:
    # Row-major node id -> (column, row) on the grid.
    x1, y1 = conn[0] % grid_size, conn[0] // grid_size
    x2, y2 = conn[1] % grid_size, conn[1] // grid_size

    # +0.5 moves each marker from the cell corner to the cell centre.
    ax.plot([x1 + 0.5, x2 + 0.5], [y1 + 0.5, y2 + 0.5], marker='o', markersize=5, linestyle='-', color='red', linewidth=1)

ax.grid(True)
plt.show()
