# [Vision GNN: An Image is Worth Graph of Nodes](https://ar5iv.labs.arxiv.org/html/2206.00272)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Literal, Optional

In [2]:
class DenseDilated(nn.Module):
    """
    Pick a dilated subset of neighbors out of a precomputed neighbor list.

    The input is indexed along its fourth axis, so it needs at least four
    dimensions; the notebook labels it
    edge_index: (2, batch_size, num_points, k)

    With ``stochastic`` enabled, a training-mode call replaces the regular
    strided selection by ``k`` randomly permuted positions whenever a uniform
    draw falls below ``epsilon``.
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super().__init__()
        self.k = k
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon

    def forward(self, edge_index):
        # The uniform draw is made whenever `stochastic` is set, and only then
        # is the training flag consulted; keeping that order keeps the global
        # RNG stream unchanged in eval mode.
        take_random = (
            self.stochastic
            and bool(torch.rand(1) < self.epsilon)
            and self.training
        )
        if not take_random:
            # Regular dilation: keep every `dilation`-th entry of the last axis.
            return edge_index[:, :, :, ::self.dilation]

        # Random dilation: k positions sampled from the first k * dilation slots.
        pool_size = self.k * self.dilation
        chosen = torch.randperm(pool_size)[:self.k]
        return edge_index[:, :, :, chosen]

In [3]:
class DenseDilatedKnnGraph(nn.Module):
    """
    Compute neighbor indices with a dense k-NN search, then dilate them.

    ``k * dilation`` candidates are requested from the k-NN routine and the
    result is thinned by a ``DenseDilated`` module built with the same
    settings.
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super().__init__()
        self.k, self.dilation = k, dilation
        self.stochastic, self.epsilon = stochastic, epsilon
        self._dilated = DenseDilated(k, dilation, stochastic, epsilon)
        # `dense_knn_matrix` is looked up in module scope; it is not defined
        # in this part of the notebook.
        self.knn = dense_knn_matrix

    def forward(self, x):
        num_candidates = self.k * self.dilation
        candidates = self.knn(x, num_candidates)
        return self._dilated(candidates)
    
class DilatedKnnGraph(nn.Module):
    """
    Compute neighbor indices sample by sample with ``knn_graph``, then dilate.

    After dropping a trailing singleton axis the input must be 3-D; it is
    unpacked as ``(B, C, N)``.
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
        super().__init__()
        self.k, self.dilation = k, dilation
        self.stochastic, self.epsilon = stochastic, epsilon
        self._dilated = DenseDilated(k, dilation, stochastic, epsilon)
        # `knn_graph` is looked up in module scope; it is not defined or
        # imported in this part of the notebook.
        self.knn = knn_graph

    def forward(self, x):
        x = x.squeeze(-1)
        _, _, num_points = x.shape
        fan_out = self.k * self.dilation

        # One k-NN query per batch element on the transposed (N, C) slice;
        # each result is reshaped to (2, N, k * dilation).
        per_sample = [
            self.knn(
                sample.contiguous().transpose(1, 0).contiguous(), fan_out
            ).view(2, num_points, fan_out)
            for sample in x
        ]

        # Stack along a new axis 1 so the batch sits right after the edge axis.
        stacked = torch.stack(per_sample, dim=1)
        return self._dilated(stacked)

In [4]:
class BasicConv(nn.Sequential):
    """
    Stack of 1x1 ``Conv2d`` layers with optional activation, norm and dropout.

    For every consecutive pair in ``channels`` a 1x1 convolution is added,
    followed (in this order) by an activation, a normalization layer and a
    ``Dropout2d`` when the corresponding option is enabled.

    Args:
        channels: layer widths, input width first. Any sliceable sequence
            works (callers in this file pass lists).
        act: activation name handed to ``act_layer``; ``None`` or ``"none"``
            (case-insensitive) disables it.
        norm: normalization name handed to ``norm_layer``; ``None`` or
            ``"none"`` (case-insensitive) disables it.
        bias: whether the convolutions carry a bias term.
        drop: dropout probability; dropout is only added when ``drop > 0``.

    ``act_layer`` and ``norm_layer`` are resolved from module scope; they are
    not defined in this part of the notebook.
    """
    def __init__(
        self, channels: tuple[int, ...],
        act: Optional[Literal["relu", "leakyrelu", "prelu"]] = "relu",
        norm: Optional[Literal["batch", "instance"]] = None,
        bias: bool = True, drop: float = 0.0
    ):
        use_act = act is not None and act.lower() != 'none'
        use_norm = norm is not None and norm.lower() != 'none'

        layers = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv2d(in_ch, out_ch, 1, bias=bias))
            if use_act:
                layers.append(act_layer(act))
            if use_norm:
                # Normalize over this layer's own output width, not the final
                # width, so stacks with differing intermediate widths are valid.
                layers.append(norm_layer(norm, out_ch))
            if drop > 0:
                layers.append(nn.Dropout2d(drop))

        super().__init__(*layers)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize conv weights (Kaiming normal) and norm affine terms."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # Norm layers built with affine=False (the InstanceNorm2d
                # default) have weight/bias set to None; skip those.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

In [5]:
class MRConv2d(nn.Module):
    """
    Max-Relative graph convolution for dense inputs
    (paper: https://arxiv.org/abs/1904.03751).

    The input features are concatenated along dim 1 with the maximum, over
    the last axis, of the difference between the two feature sets gathered
    via ``edge_index``; the result goes through a ``BasicConv``.
    """
    def __init__(
        self, in_channels: int, out_channels: int,
        act: Literal["relu", "leakyrelu", "prelu"] = "relu",
        norm: Optional[Literal["batch", "instance"]] = None,
        bias: bool = True
    ):
        super().__init__()
        # The conv sees the input and the aggregated difference side by side,
        # hence twice the input width.
        fused_width = in_channels * 2
        self.nn = BasicConv([fused_width, out_channels], act, norm, bias)

    def forward(self, x, edge_index):
        # Presumably edge_index[1] addresses the center nodes and
        # edge_index[0] their neighbors - confirm against the k-NN helper.
        gathered_ref = batched_index_select(x, edge_index[1])
        gathered_nbr = batched_index_select(x, edge_index[0])

        relative = gathered_nbr - gathered_ref
        max_relative = torch.max(relative, -1, keepdim=True)[0]

        fused = torch.cat([x, max_relative], dim=1)
        return self.nn(fused)

In [6]:
class EdgeConv2d(nn.Module):
    """
    Edge convolution for dense inputs.

    The features gathered via ``edge_index[1]`` are concatenated along dim 1
    with their difference to the features gathered via ``edge_index[0]``,
    passed through a ``BasicConv`` (which applies the optional activation and
    normalization), and reduced by a maximum over the last axis.
    """
    def __init__(
        self, in_channels: int, out_channels: int,
        act: Literal["relu", "leakyrelu", "prelu"] = "relu",
        norm: Optional[Literal["batch", "instance"]] = None,
        bias: bool = True
    ):
        super().__init__()
        # Two feature sets are concatenated before the conv, hence 2x width.
        fused_width = in_channels * 2
        self.nn = BasicConv([fused_width, out_channels], act, norm, bias)

    def forward(self, x, edge_index):
        # Presumably edge_index[1] addresses the center nodes and
        # edge_index[0] their neighbors - confirm against the k-NN helper.
        gathered_ref = batched_index_select(x, edge_index[1])
        gathered_nbr = batched_index_select(x, edge_index[0])

        edge_input = torch.cat([gathered_ref, gathered_nbr - gathered_ref], dim=1)
        edge_features = self.nn(edge_input)

        # Keep the reduced axis so the output stays 4-D like the conv output.
        return edge_features.max(dim=-1, keepdim=True).values

In [7]:
class GraphConv2d(nn.Module):
    """
    Static Graph Convolution Layer
    Ref: https://github.com/lightaime/deep_gcns_torch/blob/master/gcn_lib/dense/torch_vertex.py

    Thin dispatcher that owns exactly one graph convolution, chosen by name.

    Args:
        in_channels: input feature width, forwarded to the chosen conv.
        out_channels: output feature width, forwarded to the chosen conv.
        conv: ``"edge"`` selects ``EdgeConv2d``, ``"mr"`` selects ``MRConv2d``.
        act: activation name forwarded to the chosen conv (may be ``None``).
        norm: normalization name forwarded to the chosen conv.
        bias: bias flag forwarded to the chosen conv.

    Raises:
        NotImplementedError: if ``conv`` is not one of the supported names.
    """
    def __init__(
        self, in_channels: int, out_channels: int,
        conv: Literal["edge", "mr"] = "edge",
        act: Optional[Literal["relu", "leakyrelu", "prelu"]] = "relu",
        norm: Optional[Literal["batch", "instance"]] = None,
        bias: bool = True
    ):
        super().__init__()

        # Build only the requested variant: constructing both and discarding
        # one wastes an initialization pass and extra random draws.
        if conv == "edge":
            self.gconv = EdgeConv2d(in_channels, out_channels, act, norm, bias)
        elif conv == "mr":
            self.gconv = MRConv2d(in_channels, out_channels, act, norm, bias)
        else:
            # Fail at construction time instead of storing an exception
            # object that would only blow up later inside forward().
            raise NotImplementedError(
                f"conv {conv!r} is not supported; expected 'edge' or 'mr'"
            )

    def forward(self, x, edge_index):
        return self.gconv(x, edge_index)

In [8]:
class DynamicConv2d(GraphConv2d):
    """
    Dynamic Graph Convolution Layer

    Same convolution as ``GraphConv2d``, but when no ``edge_index`` is given
    the graph is rebuilt from the input features on every forward pass using
    a dilated k-NN module.

    Args:
        in_channels, out_channels, conv, act, norm, bias: forwarded unchanged
            to ``GraphConv2d``.
        kernel_size: number of neighbors kept per node (``k``).
        dilation: dilation factor of the neighbor selection.
        stochastic, epsilon: forwarded to the k-NN graph module.
        knn: ``'matrix'`` selects ``DenseDilatedKnnGraph``; any other value
            selects ``DilatedKnnGraph``.
    """
    def __init__(
        self, in_channels, out_channels,
        kernel_size: int = 9, dilation=1,
        conv='edge', act='relu',
        norm=None, bias=True,
        stochastic=False, epsilon=0.0, knn='matrix'
    ):
        super().__init__(in_channels, out_channels, conv, act, norm, bias)
        self.k = kernel_size
        self.d = dilation
        if knn == 'matrix':
            self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)
        else:
            self.dilated_knn_graph = DilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)

    def forward(self, x, edge_index=None):
        # Only build a graph when the caller did not supply one.
        if edge_index is None:
            edge_index = self.dilated_knn_graph(x)
        # Zero-argument super(): the previous explicit form named a class
        # (`DynConv2d`) that does not exist, so every call raised NameError.
        return super().forward(x, edge_index)
    

In [10]:
class GrapherModule(nn.Module):
    def __init__(
        self, in_channels, hidden_channels,
        k=9, dilation=1, drop_path=0.0
    ):
        super(GrapherModule, self).__init__()
        self.fc_1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels)
        )
        
        self.graph_conv = nn.Sequential(
            DynamicConv2d(in_channels, hidden_channels, k, dilation, act=None),
            nn.BatchNorm2d(hidden_channels),
            nn.GELU()
        )
        
        self.fc_2 = nn.Sequential(
            nn.Conv2d(hidden_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels)
        )
        
        self.drop_path = nn.Identity()
    
    def forward(self, x):
        
