Add ViG models [NeurIPS 2022] #1578

Open
wants to merge 15 commits into main
7 changes: 6 additions & 1 deletion .gitignore
@@ -111,11 +111,16 @@ output/
Untitled.ipynb
Testing notebook.ipynb


# MacOS
*.DS_Store

# Root dir exclusions
/*.csv
/*.yaml
/*.json
/*.jpg
/*.png
/*.zip
/*.tar.*
/*.tar.*

3 changes: 2 additions & 1 deletion tests/test_models.py
@@ -41,7 +41,7 @@
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
'eva_*', 'flexivit*', 'eva02*', 'pvig_*', 'samvit_*'
]
NUM_NON_STD = len(NON_STD_FILTERS)

@@ -405,6 +405,7 @@ def _create_fx_model(model, train=False):
'vit_large*',
'vit_base_patch8*',
'xcit_large*',
'pvig_*',
]


1 change: 1 addition & 0 deletions timm/layers/__init__.py
@@ -23,6 +23,7 @@
from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .gnn_layers import DyGraphConv2d
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
215 changes: 215 additions & 0 deletions timm/layers/gnn_layers.py
@@ -0,0 +1,215 @@
# Graph convolution layers for the Vision GNN (ViG) models
# Reference: https://github.com/lightaime/deep_gcns_torch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from .drop import DropPath


def pairwise_distance(x, y):
"""
Compute pairwise distance of a point cloud
"""
with torch.no_grad():
xy_inner = -2*torch.matmul(x, y.transpose(2, 1))
x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True)
return x_square + xy_inner + y_square.transpose(2, 1)


def dense_knn_matrix(x, y, k=16, relative_pos=None):
"""Get KNN based on the pairwise distance
"""
with torch.no_grad():
x = x.transpose(2, 1).squeeze(-1)
y = y.transpose(2, 1).squeeze(-1)
batch_size, n_points, n_dims = x.shape
dist = pairwise_distance(x.detach(), y.detach())
if relative_pos is not None:
dist += relative_pos
_, nn_idx = torch.topk(-dist, k=k)
center_idx = torch.arange(0, n_points, device=x.device).repeat(batch_size, k, 1).transpose(2, 1)
return torch.stack((nn_idx, center_idx), dim=0)
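
# Shape walk-through (a sketch, assuming a 14x14 feature map flattened to 196 nodes and k=9):
#   x, y: (B, C, 196, 1)  ->  dist: (B, 196, 196)  ->  edge_index: (2, B, 196, 9)
# edge_index[0] holds the neighbour indices into y, edge_index[1] the matching centre indices into x.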


class DenseDilated(nn.Module):
"""
Find dilated neighbor from neighbor list
"""
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
super(DenseDilated, self).__init__()
self.dilation = dilation
self.stochastic = stochastic
self.epsilon = epsilon
self.k = k

def forward(self, edge_index):
if self.stochastic:
if torch.rand(1) < self.epsilon and self.training:
num = self.k * self.dilation
randnum = torch.randperm(num)[:self.k]
edge_index = edge_index[:, :, :, randnum]
else:
edge_index = edge_index[:, :, :, ::self.dilation]
else:
edge_index = edge_index[:, :, :, ::self.dilation]
return edge_index


class DenseDilatedKnnGraph(nn.Module):
"""
Find the neighbors' indices based on dilated knn
"""
def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0):
super(DenseDilatedKnnGraph, self).__init__()
self.dilation = dilation
self.k = k
self._dilated = DenseDilated(k, dilation, stochastic, epsilon)

def forward(self, x, y=None, relative_pos=None):
x = F.normalize(x, p=2.0, dim=1)
if y is not None:
y = F.normalize(y, p=2.0, dim=1)
edge_index = dense_knn_matrix(x, y, self.k * self.dilation, relative_pos)
else:
edge_index = dense_knn_matrix(x, x, self.k * self.dilation, relative_pos)
return self._dilated(edge_index)


def batched_index_select(x, idx):
    # Gather neighbour features according to the given neighbour index.
    # x: (batch, num_dims, num_vertices_reduced, 1), idx: (batch, num_vertices, k)
    # Returns: (batch, num_dims, num_vertices, k)
batch_size, num_dims, num_vertices_reduced = x.shape[:3]
_, num_vertices, k = idx.shape
idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced
idx = idx + idx_base
idx = idx.contiguous().view(-1)

x = x.transpose(2, 1)
feature = x.contiguous().view(batch_size * num_vertices_reduced, -1)[idx, :]
feature = feature.view(batch_size, num_vertices, k, num_dims).permute(0, 3, 1, 2).contiguous()
return feature


def norm_layer(norm, nc):
    # build a 2d normalization layer from its name ('batch' or 'instance')
norm = norm.lower()
if norm == 'batch':
layer = nn.BatchNorm2d(nc, affine=True)
elif norm == 'instance':
layer = nn.InstanceNorm2d(nc, affine=False)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm)
return layer


class MRConv2d(nn.Module):
"""
Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751) for dense data type
"""
def __init__(self, in_channels, out_channels, act_layer=nn.GELU, norm=None, bias=True):
super(MRConv2d, self).__init__()
        self.nn = nn.Sequential(
            # 1x1 grouped conv over the concatenated (centre, aggregated-neighbour) channels;
            # `norm` must be 'batch' or 'instance', and the channel counts must be divisible by the 4 groups
            nn.Conv2d(in_channels * 2, out_channels, 1, bias=bias, groups=4),
            norm_layer(norm, out_channels),
            act_layer(),
        )

self.init_weights()

def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

    def forward(self, x, edge_index, y=None):
        # x_i: centre node features, x_j: neighbour features gathered via the edge index
        x_i = batched_index_select(x, edge_index[1])
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
        # max-relative aggregation: keep only the strongest relative feature over the k neighbours
        x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
        b, c, n, s = x.shape
        # interleave centre and aggregated features channel-wise before the grouped 1x1 conv
        x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)], dim=2).reshape(b, 2 * c, n, s)
        return self.nn(x)


class EdgeConv2d(nn.Module):
"""
Edge convolution layer (with activation, batch normalization) for dense data type
"""
def __init__(self, in_channels, out_channels, act_layer=nn.GELU, norm=None, bias=True):
super(EdgeConv2d, self).__init__()
self.nn = nn.Sequential(
nn.Conv2d(in_channels*2, out_channels, 1, bias=bias, groups=4),
norm_layer(norm, out_channels),
act_layer(),
)

self.init_weights()

def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def forward(self, x, edge_index, y=None):
x_i = batched_index_select(x, edge_index[1])
if y is not None:
x_j = batched_index_select(y, edge_index[0])
else:
x_j = batched_index_select(x, edge_index[0])
max_value, _ = torch.max(self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True)
return max_value
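
# Rough comparison of the two graph convolutions above:
#   MRConv2d:   out = Conv1x1(concat(x_i, max_j(x_j - x_i)))  -- aggregate over the k neighbours first, then convolve once per node
#   EdgeConv2d: out = max_j Conv1x1(concat(x_i, x_j - x_i))   -- convolve every (centre, neighbour) pair, then aggregate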


class GraphConv2d(nn.Module):
"""
Static graph convolution layer
"""
def __init__(self, in_channels, out_channels, conv='mr', act_layer=nn.GELU, norm=None, bias=True):
super(GraphConv2d, self).__init__()
if conv == 'edge':
self.gconv = EdgeConv2d(in_channels, out_channels, act_layer, norm, bias)
elif conv == 'mr':
self.gconv = MRConv2d(in_channels, out_channels, act_layer, norm, bias)
else:
raise NotImplementedError('conv:{} is not supported'.format(conv))

def forward(self, x, edge_index, y=None):
return self.gconv(x, edge_index, y)


class DyGraphConv2d(GraphConv2d):
"""
Dynamic graph convolution layer
"""
def __init__(self, in_channels, out_channels, kernel_size=9, dilation=1, conv='mr', act_layer=nn.GELU,
norm=None, bias=True, stochastic=False, epsilon=0.0, r=1):
super(DyGraphConv2d, self).__init__(in_channels, out_channels, conv, act_layer, norm, bias)
self.k = kernel_size
self.d = dilation
self.r = r
self.dilated_knn_graph = DenseDilatedKnnGraph(kernel_size, dilation, stochastic, epsilon)

def forward(self, x, relative_pos=None):
B, C, H, W = x.shape
y = None
if self.r > 1:
y = F.avg_pool2d(x, self.r, self.r)
y = y.reshape(B, C, -1, 1).contiguous()
x = x.reshape(B, C, -1, 1).contiguous()
edge_index = self.dilated_knn_graph(x, y, relative_pos)
x = super(DyGraphConv2d, self).forward(x, edge_index, y)
return x.reshape(B, -1, H, W).contiguous()
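
A minimal usage sketch of the new DyGraphConv2d layer (the shapes and hyper-parameters below are illustrative only, not taken from the PR's pvig_* configs):

import torch
from timm.layers.gnn_layers import DyGraphConv2d

# 9-NN max-relative graph conv; r=2 builds the graph against a 2x average-pooled copy of the node set
gconv = DyGraphConv2d(in_channels=192, out_channels=384, kernel_size=9, dilation=1,
                      conv='mr', norm='batch', r=2)
x = torch.randn(2, 192, 14, 14)   # (batch, channels, H, W)
out = gconv(x)                    # the k-NN graph is rebuilt from the features on every forward pass
print(out.shape)                  # torch.Size([2, 384, 14, 14])

With the default r=1 no pooled copy is built and the neighbour search compares all H*W nodes against each other; the r argument exists to shrink that candidate set by average pooling.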
1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -57,6 +57,7 @@
from .twins import *
from .vgg import *
from .visformer import *
from .vision_gnn import *
from .vision_transformer import *
from .vision_transformer_hybrid import *
from .vision_transformer_relpos import *
1 change: 1 addition & 0 deletions timm/models/layers/__init__.py
@@ -23,6 +23,7 @@
from timm.layers.filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from timm.layers.gather_excite import GatherExcite
from timm.layers.global_context import GlobalContext
from timm.layers.gnn_layers import DyGraphConv2d
from timm.layers.helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from timm.layers.inplace_abn import InplaceAbn
from timm.layers.linear import Linear