## Tensor Reshaping

In [46]:
import torch

# Three stacked "rectangles", each a 2 x 4 block -> overall shape (3, 2, 4)
tensor = torch.tensor([
    [[0, 1, 1, 2],
     [1, 0, 2, 1]],
    [[2, 3, 3, 4],
     [3, 2, 4, 3]],
    [[5, 6, 6, 7],
     [6, 5, 7, 6]],
])
print(tensor.shape)
# Swap the first two axes: (3, 2, 4) -> (2, 3, 4)
tmp = torch.permute(tensor, (1, 0, 2))
print(tmp.shape)
tmp

torch.Size([3, 2, 4])
torch.Size([2, 3, 4])


tensor([[[0, 1, 1, 2],
         [2, 3, 3, 4],
         [5, 6, 6, 7]],

        [[1, 0, 2, 1],
         [3, 2, 4, 3],
         [6, 5, 7, 6]]])

In [48]:
# Merge the last two axes so each slice of tmp becomes one long row: (2, 3, 4) -> (2, 12)
result = torch.reshape(tmp, (tensor.size(1), -1))
result

tensor([[0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 6, 7],
        [1, 0, 2, 1, 3, 2, 4, 3, 6, 5, 7, 6]])

In [None]:

# # Reshape the tensor to concatenate all rectangles along the second dimension
# result = tensor.permute(1, 0, 2).reshape(tensor.shape[1], -1)

# print(result)


## Investigate GATConv

In [1]:
import torch
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data, Batch

In [124]:
# GATConv with 2 input features, 8 output channels and 4 attention heads.
# The recorded y_batch.shape further down is [6, 32], i.e. 8 * 4 features per node.
layer = GATConv(2, 8, heads=4)
# layer = GCNConv(2, 32)

In [125]:
# Original data: a 3-node graph whose edges are written as (row 0, row 1) pairs
# and then transposed into the [2, num_edges] layout.
edge_index = (
    torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=torch.long)
    .t()
    .contiguous()
)
# One 2-dimensional feature vector per node
x = torch.tensor([[-1, 2], [0, 3], [1, 4]], dtype=torch.float)

In [126]:
# Build a mini-batch of two graphs by stacking along a new leading axis
x1 = torch.tensor([[-1, 3], [0, 4], [1, 5]], dtype=torch.float)
x_batch = torch.stack((x, x1), dim=0)

# The second graph reuses the same connectivity, with node ids shifted past
# the first graph's nodes
edge_index1 = edge_index + x.size(0)
edge_index_batch = torch.stack((edge_index, edge_index1), dim=0)

print(x_batch.shape)
print(edge_index_batch.shape)


torch.Size([2, 3, 2])
torch.Size([2, 2, 4])


In [127]:
x_batch

tensor([[[-1.,  2.],
         [ 0.,  3.],
         [ 1.,  4.]],

        [[-1.,  3.],
         [ 0.,  4.],
         [ 1.,  5.]]])

In [128]:
edge_index_batch

tensor([[[0, 1, 1, 2],
         [1, 0, 2, 1]],

        [[3, 4, 4, 5],
         [4, 3, 5, 4]]])

In [129]:
# Fold the batch axis into the node axis, keeping the feature axis: (2, 3, 2) -> (6, 2)
x_b = x_batch.view(-1, x_batch.shape[-1])
print(x_b.shape)
x_b

torch.Size([6, 2])


tensor([[-1.,  2.],
        [ 0.,  3.],
        [ 1.,  4.],
        [-1.,  3.],
        [ 0.,  4.],
        [ 1.,  5.]])

In [130]:
# Hand-made collation of node features: stack every graph's nodes into one
# [total_nodes, num_features] matrix
x_batch_pyg = x_batch.reshape(-1, x_batch.size(-1))
print(x_batch_pyg.shape)
x_batch_pyg

torch.Size([6, 2])


tensor([[-1.,  2.],
        [ 0.,  3.],
        [ 1.,  4.],
        [-1.,  3.],
        [ 0.,  4.],
        [ 1.,  5.]])

In [131]:
# Alternative when the per-graph offsets are already baked in (edge_index_batch):
# (B, 2, E) -> (2, B, E) -> (2, B*E)
# edge_index_batch_pyg = torch.permute(edge_index_batch, (1, 0, 2)).reshape(2, -1)

batch_size, N_nodes = x_batch.size(0), x_batch.size(1)
# Shift graph i's node ids by i * N_nodes, then join all edge lists column-wise
edge_index_batch_pyg = torch.cat(
    tuple(edge_index + N_nodes * i for i in range(batch_size)),
    dim=1,
)

print(edge_index_batch_pyg.shape)

torch.Size([2, 8])


In [132]:
# Same mini-batch built through PyG: one Data object per graph (all sharing
# the un-shifted edge_index), collated into a single Batch
data_list = [
    Data(x=node_feats, edge_index=edge_index)
    for node_feats in x_batch
]
batch = Batch.from_data_list(data_list)

batch

DataBatch(x=[6, 2], edge_index=[2, 8], batch=[6], ptr=[3])

In [114]:
batch.x

tensor([[-1.,  2.],
        [ 0.,  3.],
        [ 1.,  4.],
        [-1.,  3.],
        [ 0.,  4.],
        [ 1.,  5.]])

In [115]:
batch.edge_index

tensor([[0, 1, 1, 2, 0, 3, 4, 4, 5, 3],
        [1, 0, 2, 1, 2, 4, 3, 5, 4, 5]])

In [116]:
# Compare the hand-collated tensors with what PyG's Batch produced.
# NOTE(review): this cell's execution count (116) predates the cells that build
# the inputs (124-132), and the recorded batch.edge_index output has 10 columns
# while edge_index_batch_pyg is [2, 8] -- the saved outputs look stale;
# re-run the notebook top to bottom to confirm the comparison.
print(torch.isclose(x_batch_pyg, batch.x).all())
print(torch.isclose(edge_index_batch_pyg, batch.edge_index).all())

tensor(True)
tensor(True)


In [60]:
y_batch = layer(batch.x, batch.edge_index)

In [61]:
y_batch.shape

torch.Size([6, 32])

In [70]:
# Undo the flattening: split the node axis back into (batch, nodes-per-graph)
y_b = y_batch.reshape(*x_batch.shape[:2], -1)
print(y_b.shape)

torch.Size([2, 3, 32])


In [None]:
import torch
import os

# Sizes of the two node groups and the width of each representation
N_rec, N_send, d_h = 50, 50, 10

# Random representations, one row per node
rec_rep = torch.randn(N_rec, d_h)
send_rep = torch.randn(N_send, d_h)

# Stack the two groups row-wise: (N_rec + N_send, d_h)
node_reps = torch.cat([rec_rep, send_rep], dim=0)

# Show the combined shape
print(node_reps.shape)

torch.Size([100, 10])


# Learn PyG

In [None]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    """GCN layer instrumented with shape printouts at every step.

    ``forward`` adds self-loops, applies a bias-free linear map, computes the
    per-edge normalisation ``deg^-1/2[row] * deg^-1/2[col]``, propagates with
    sum aggregation and finally adds a learnable bias.

    NOTE: ``update`` adds the transformed features ``x`` on top of the
    aggregated messages, so a node's own features enter the result both
    through its self-loop message and through this extra term.
    """

    def __init__(self, in_channels, out_channels):
        # Sum aggregation of incoming messages (Step 5).
        super().__init__(aggr='add')
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear weights and zero the bias."""
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: node features, shape [N, in_channels].
            edge_index: edge list, shape [2, E].

        Returns:
            Tensor of shape [N, out_channels].
        """
        print("edge_index.shape", edge_index.shape, sep="\n")

        # Step 1: add self-loops, [2, E] -> [2, E + N].
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        print("edge_index.shape (self loops added)", edge_index.shape, sep="\n")

        # Step 2: linear transform, [N, in_channels] -> [N, out_channels].
        x = self.lin(x)
        print("x.shape", x.shape, sep="\n")

        # Step 3: normalisation coefficients.
        first_row, second_row = edge_index[0], edge_index[1]

        # Node degrees counted over the second row of edge_index, shape [N];
        # this is the diagonal of the degree matrix.
        deg = degree(second_row, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Degree-0 nodes produce inf; zero them so they contribute nothing.
        inf_mask = deg_inv_sqrt == float('inf')
        deg_inv_sqrt[inf_mask] = 0
        print("degree shapes")
        print(deg.shape)
        print(deg_inv_sqrt.shape)

        # One coefficient per edge (self-loops included): [E + N].
        norm = deg_inv_sqrt[first_row] * deg_inv_sqrt[second_row]
        print("norm.shape", norm.shape, sep="\n")

        # Steps 4-5: propagate() internally calls message() with x and norm,
        # aggregates the messages per node and then calls update().
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: add the bias vector (in place).
        out += self.bias

        return out

    def message(self, x_j, norm):
        """Scale each edge's neighbour features by its normalisation weight.

        Args:
            x_j: features gathered per edge, shape [E + N, out_channels].
            norm: per-edge coefficients, shape [E + N].
        """
        # Step 4: normalise node features.
        print("x_j.shape", x_j.shape, sep="\n")

        # Column vector so the weight broadcasts across the feature axis.
        weights = norm.view(-1, 1)
        print("norm.view(-1, 1).shape", weights.shape, sep="\n")

        scaled = weights * x_j
        print("message shape", scaled.shape, sep="\n")
        return scaled

    def update(self, aggr_out, x):
        """Add the transformed node features to the aggregated messages."""
        print("aggr_out.shape", aggr_out.shape, sep="\n")
        print("x.shape", x.shape, sep="\n")
        return aggr_out + x

In [None]:
# Layer under test: 64 input features -> 128 output features
gcn = GCNConv(64, 128)

# Random features for 10 nodes plus a hand-written list of 11 edges
x = torch.randn(10, 64)
edge_index = torch.tensor([
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8],
    [1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 7],
])

# Forward pass (the layer prints intermediate shapes as it runs)
output = gcn(x, edge_index)

# Show the result and its shape
print(output)
print(output.shape)

edge_index.shape
torch.Size([2, 11])
edge_index.shape (self loops added)
torch.Size([2, 21])
x.shape
torch.Size([10, 128])
degree shapes
torch.Size([10])
torch.Size([10])
norm.shape
torch.Size([21])
x_j.shape
torch.Size([21, 128])
norm.view(-1, 1).shape
torch.Size([21, 1])
message shape
torch.Size([21, 128])
aggr_out.shape
torch.Size([10, 128])
x.shape
torch.Size([10, 128])
tensor([[ 0.5136, -0.7564, -0.1948,  ..., -0.0750,  0.2263, -0.8994],
        [ 0.5533, -0.5819, -0.0547,  ...,  0.6010, -0.3274, -0.3957],
        [ 0.3818,  0.1921,  0.9273,  ..., -0.3475, -0.1355, -0.7313],
        ...,
        [-0.0356,  0.4850, -0.4018,  ..., -0.4825,  0.5431,  1.2539],
        [-1.2525,  0.2215,  0.4838,  ..., -0.0128,  0.4759,  0.0181],
        [ 0.4090, -0.0837, -0.1797,  ...,  0.2653, -0.4348, -0.2556]],
       grad_fn=<AddBackward0>)
torch.Size([10, 128])


In [None]:
import torch
from torch_geometric.data import Data

# Edges listed one pair per row: shape [num_edges, 2]
edge_index = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=torch.long)

# A single scalar feature for each of the 3 nodes
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

# Transpose to the [2, num_edges] layout before handing the edges to Data
data = Data(edge_index=edge_index.t().contiguous(), x=x)

In [None]:
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

# ENZYMES from the TU collection, rooted at /tmp/ENZYMES, with node attributes
dataset = TUDataset(
    root='/tmp/ENZYMES',
    name='ENZYMES',
    use_node_attr=True,
)
# Shuffled mini-batches of 32 graphs
loader = DataLoader(dataset, shuffle=True, batch_size=32)

# First graph of the dataset, kept for inspection
data = dataset[0]

In [None]:
len(dataset)

600

In [None]:
data

Data(edge_index=[2, 168], x=[37, 21], y=[1])

In [None]:
# Pull the first mini-batch from a fresh iterator over the loader.
# NOTE(review): the loader was built with shuffle=True, so re-running this cell
# yields a different batch; the 1228 vs. 1024 node counts in the saved outputs
# below presumably come from different runs.
batch = next(iter(loader))

In [None]:
batch

DataBatch(edge_index=[2, 4388], x=[1228, 21], y=[32], batch=[1228], ptr=[33])

In [None]:
batch.num_graphs

32

In [None]:
batch.num_nodes

1228

In [None]:
batch.batch.shape

torch.Size([1024])

In [None]:
batch.x.shape

torch.Size([1024, 21])

In [None]:
from torch_geometric.utils import scatter


# Mean-reduce the rows of batch.x grouped by the indices in batch.batch;
# the recorded result is [32, 21], i.e. one row per graph in the batch.
x = scatter(batch.x, batch.batch, dim=0, reduce='mean')

In [None]:
x.shape

torch.Size([32, 21])