In [None]:
from typing import Callable, Optional, Union

import numpy as np

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import Adj, OptTensor, PairOptTensor, PairTensor
from torch_geometric.nn.models import MLP
from torch_geometric.nn.conv import GATConv, GATv2Conv

from torch_cluster import knn, knn_graph

from torch_kmeans import KMeans

from pygcn_lib.torch_centroid import ArgsUpCentroidsEdgeConv, ArgsUpCentroidsGATConv, UpdateCentroids
from pygcn_lib.torch_vertex import EdgeConvStaticGConv, MRConvStaticGConv, GINConvStaticGConv, GraphSAGEStaticGConv, StaticGraphConv, ArgsStaticGraphConv


# ---- Configuration for the StaticGraphConv smoke test ----
graphs = 10             # number of independent graphs in the synthetic batch
points_per_graph = 100  # nodes per graph
clusters = 4            # centroids per graph (used for num_centroids and batch_center below)
in_C = 32               # input feature channels
out_C = 64              # output channels of the layer under test
neighbors = 10          # k for the k-NN graph
SEED = 0

# Fall back to CPU so the notebook does not hard-fail on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Seed torch's RNGs so the random synthetic inputs built below are reproducible.
torch.manual_seed(SEED)



# Synthetic input: `graphs * points_per_graph` nodes with random features, split into
# `graphs * clusters` contiguous subgraphs of (nearly) equal size.
N = graphs * points_per_graph
subgraphs = graphs * clusters

# Allocate directly on the target device instead of creating on the CPU and copying.
x = torch.rand(N, in_C, device=device)

# Contiguous subgraph ids: each id repeated N // subgraphs times. Any remainder nodes
# (when N is not divisible by subgraphs) go to the last subgraph, so the vector stays
# sorted, which torch_cluster's knn_graph requires of its `batch` argument.
nodes_per_subgraph = N // subgraphs
remainder = N - nodes_per_subgraph * subgraphs
batch = torch.cat([
    torch.arange(subgraphs, dtype=torch.int64, device=device).repeat_interleave(nodes_per_subgraph),
    torch.full((remainder,), subgraphs - 1, dtype=torch.int64, device=device),
])
assert batch.numel() == N and bool((batch[1:] >= batch[:-1]).all())

# k-NN graph restricted to each subgraph, self-loops included; the result is built
# from `x`, so it already lives on the same device.
edge_index = knn_graph(x=x, k=neighbors, batch=batch, loop=True, flow='source_to_target')

# One random centroid feature per subgraph, plus the parent graph id of each centroid.
x_center = torch.rand(subgraphs, in_C, device=device)
batch_center = torch.arange(graphs, dtype=torch.int64, device=device).repeat_interleave(clusters)



# Centroid-update settings: GAT-based (v1), 4 concatenated heads; selected below via
# conv_centroid='gat'.
centroid_args = ArgsUpCentroidsGATConv(
    num_centroids=clusters,
    in_channels=in_C,
    out_channels_total=32,
    heads=4,
    concat=True,
    dropout=0.0,
    negative_slope=0.2,
    aggr='add',
    version='v1',
)

# Vertex graph-convolution settings (in_C -> out_C) that embed the centroid config above.
args = ArgsStaticGraphConv(
    in_channels=in_C,
    out_channels=out_C,
    args_centroid=centroid_args,
    conv_centroid='gat',
    groups=1,
    dropout=0.0,
    act='gelu',
    norm='batch_norm',
    aggr='max',
)


# Vertex-convolution variant passed to StaticGraphConv.
conv = 'edge'

layer = StaticGraphConv(args, conv).to(device)

# Single forward pass over the synthetic nodes, graph and centroids.
out = layer(
    x=x,
    batch=batch,
    edge_index=edge_index,
    x_center=x_center,
    batch_center=batch_center,
)

# Show the result tensor, then its shape, on separate lines.
print(out, out.shape, sep='\n')

In [None]:
from pygcn_lib.torch_dynamic import DynamicGraphConv, ArgsClusterKMeans
from pygcn_lib.torch_vertex import ArgsStaticGraphConv
from pygcn_lib.torch_centroid import ArgsUpCentroidsEdgeConv, ArgsUpCentroidsGATConv
import torch

# Centroid-update settings: 'edge' variant over 4 centroids, 2 -> 8 channels.
centroid_args = ArgsUpCentroidsEdgeConv(
    num_centroids=4,
    in_channels=2,
    out_channels=8,
    groups=1,
    dropout=0.0,
    act='relu',
    norm='batch_norm',
    aggr='max',
)

# Per-vertex graph-convolution settings, 2 -> 6 channels, embedding the centroid config.
gconv_args = ArgsStaticGraphConv(
    in_channels=2,
    out_channels=6,
    args_centroid=centroid_args,
    conv_centroid='edge',
    groups=1,
    dropout=0.0,
    act='gelu',
    norm='batch_norm',
    aggr='max',
)

# Dynamic (graph rebuilt per call) layer: 'mr' vertex conv, 5 neighbors with dilation 2,
# stochastic edge dropout, k-means clustering with default settings.
layer = DynamicGraphConv(
    args_gconv=gconv_args,
    conv_gconv='mr',
    neighbors=5,
    dilation=2,
    reduction=None,
    normalize_for_edges=True,
    stochastic=True,
    epsilon=0.1,
    drop_rate_neighbors=0.1,
    method_for_edges='dropout',
    args_cluster=ArgsClusterKMeans(),
)
layer = layer.to(torch.device('cuda:0'))
print(layer)

In [None]:
# Random 4-D input for the DynamicGraphConv from the previous cell; dim 1 matches its
# in_channels=2 (presumably a (B, C, H, W) layout -- TODO confirm). Named `dyn_input`
# rather than `input` to avoid shadowing the builtin for the rest of the session, and
# allocated directly on the GPU instead of being copied there.
dyn_input = torch.rand(3, 2, 10, 2, device=torch.device('cuda:0'))
output = layer(dyn_input)
print(output.shape)

In [None]:
import torch
import torch.nn as nn

in_channels = 3
kernel_size = 4

# Depthwise conv (groups == in_channels == out_channels: one filter per channel), presumably
# used as a conditional positional encoding. With padding='same' and stride 1 it should keep
# the spatial size. NOTE: an even kernel_size makes 'same' padding asymmetric; PyTorch
# handles this but may emit a one-time warning about padding a copy of the input.
pe = nn.Conv2d(
    in_channels=in_channels,
    out_channels=in_channels,
    kernel_size=kernel_size,
    stride=1,
    padding='same',
    bias=True,
    groups=in_channels,
)

# Named `pe_input` (not `input`) so the builtin is not shadowed; channel count taken from
# `in_channels` so the two cannot disagree.
pe_input = torch.rand(1, in_channels, 4, 4)
output = pe(pe_input)

# 'same' padding with stride 1 must leave (N, C, H, W) unchanged.
assert output.shape == pe_input.shape
print(output.shape)


In [None]:
import torch

# Broadcasting demo: a per-channel scale of shape (3, 1, 1) multiplies every spatial
# entry of the matching channel in a (2, 3, 2, 2) tensor.
x = torch.rand(2, 3, 2, 2)
y = torch.tensor([10, 20, 30])[:, None, None]

for shown in (x, y, x * y):
    print(shown)

In [None]:
from pygcn_lib.torch_dynamic import Grapher
import torch

# Single device for this smoke test, so the layer and its input cannot drift apart.
grapher_device = torch.device('cuda:2')

layer = Grapher(in_channels=192,
                out_channels=192*2,
                factor=1,
                dropout=0.0,
                act='gelu',
                norm='batch_norm',
                drop_path=0.0,
                clusters=4,
                neighbors=18,
                dilation=2,
                stochastic=False,
                epsilon=0.2,
                drop_rate_neighbors=None,
                method_for_edges='dilated',
                init_method='rnd',
                num_init=4,
                max_iter=50,
                tol=5e-4,
                vertex_conv='mr',
                center_conv='gat',
                use_conditional_pos=True,
                use_relative_pos=None).to(grapher_device)

# Random batch whose dim 1 matches in_channels=192; presumably (B, C, H, W) with a 14x14
# grid -- TODO confirm. Named `grapher_input` to avoid shadowing the builtin `input`, and
# allocated directly on the device instead of being copied there.
grapher_input = torch.rand(64, 192, 14, 14, device=grapher_device)
output = layer(grapher_input)

print(output)
print(output.shape)

In [2]:
from clustervig_test import IsoClusterViG_Ti_n196_c4
import torch

import os

# One GPU index for both the current CUDA device and the tensors, so the two settings
# cannot drift apart.
gpu_index = 1
torch.cuda.set_device(gpu_index)
device = torch.device('cuda', gpu_index)

model = IsoClusterViG_Ti_n196_c4().to(device)
# Random image-sized batch (64, 3, 224, 224), allocated directly on the device. Named
# `images` to avoid shadowing the builtin `input`.
images = torch.rand(64, 3, 224, 224, device=device)

output = model(images)

# The model output is indexed with [0]; presumably it returns a sequence whose first
# element holds the per-sample predictions -- TODO confirm.
print(output[0].shape)

0
tensor(1, device='cuda:1')
1
2
tensor(1, device='cuda:1')
3
4
tensor(1, device='cuda:1')
5
6
tensor(1, device='cuda:1')
7
8
tensor(1, device='cuda:1')
9
10
tensor(1, device='cuda:1')
11
12
tensor(1, device='cuda:1')
13
14
tensor(1, device='cuda:1')
15
16
tensor(1, device='cuda:1')
17
18
tensor(1, device='cuda:1')
19
20
tensor(1, device='cuda:1')
21
22
tensor(1, device='cuda:1')
23
torch.Size([64, 1000])
