In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Sequential, Linear, ReLU
import torch_geometric
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch_geometric.transforms as T
import torch_cluster
from torch_geometric.nn import NNConv, GCNConv, GraphConv
from torch_geometric.nn import PointConv, SplineConv

from torch_geometric.nn.inits import reset

from torch_geometric.data import DataLoader
from preprocessing1 import preprocess_dataset

import sys

In [8]:
datafile='./data/train_.pt'

device = torch.device('cuda:0')
showers = preprocess_dataset(datafile)

In [14]:
def str_to_class(classname: str):
    """
    Function to get class object by its name signature
    :param classname: str
        name of the class
    :return: class object with the same name signature as classname
    """
    return getattr(sys.modules[__name__], classname)

In [24]:
shower = showers[0].to(device)

In [135]:
class EdgeConv(MessagePassing):   
    def __init__(self, nn, aggr='max', **kwargs):
        super(EdgeConv, self).__init__(aggr=aggr, **kwargs)
        self.nn = nn
        self.reset_parameters()


    def reset_parameters(self):
        reset(self.nn)


    def forward(self, x, edge_index, dist):
        print(x.shape, edge_index.shape, dist.shape)
        """"""
        x = x.unsqueeze(-1) if x.dim() == 1 else x
      
        return self.propagate(edge_index, x=x, dist=dist)


    def message(self, x_i, x_j, dist):
        print(x_i.shape, len(dist))
        return self.nn(torch.cat([x_i, x_j - x_i, dist], dim=1))

    def __repr__(self):
        return '{}(nn={})'.format(self.__class__.__name__, self.nn)      

In [191]:
class EmulsionConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.mp = torch.nn.Linear(in_channels * 2, out_channels)

    def forward(self, x, edge_index, orders, dist):
        for order in orders:
            print(dist.shape)
            print(dist.transpose(0, 1)[:, order].shape)
            
            x = self.propagate(torch.index_select(edge_index[:, order],
                                                  0,
                                                  torch.LongTensor([1, 0]).to(x.device)), x=x, 
                               dist=torch.index_select(dist.transpose(0, 1)[:, order],
                                                  0,
                                                  torch.LongTensor([1, 0]).to(x.device)))
        
        return x

    def message(self, x_j, x_i, dist):
        print(x_i.shape, len(dist))
        return self.mp(torch.cat([x_i, x_j - x_i, dist], dim=1))

    def update(self, aggr_out, x):
        return aggr_out + x

In [192]:
####------------------With EMULSIONCONV layer------------------------####

class GraphNN_KNN_v1_v0(nn.Module):
    def __init__(self, k, dim_out=10):
        super().__init__()
        self.k = k       
        self.emconv1 = EmulsionConv(self.k, self.k)
        self.output = nn.Linear(10, dim_out)

    def forward(self, data):
        x, edge_index, orders, dist = data.x, data.edge_index, data.mask, data.edge_attr
        x = self.emconv1(x=x, edge_index=edge_index, orders=orders, dist = dist)
        return self.output(x)

In [193]:
####------------------With EDGECONV layer------------------------####

class GraphNN_KNN_v0_v1(nn.Module):
    def __init__(self, k, dim_out=10):
        super().__init__()
        self.k = k       
        self.wconv1 = EdgeConv(Sequential(nn.Linear(21, 10)), 'max')       
        self.output = nn.Linear(10, dim_out)

    def forward(self, data):
        x, edge_index, orders, dist = data.x, data.edge_index, data.mask, data.edge_attr
        x1 = self.wconv1(x=x, edge_index=edge_index, dist = dist)
        return self.output(x1)

In [194]:
k = showers[0].x.shape[1]
print(k)
graph_embedder = str_to_class('GraphNN_KNN_v1_v0')(dim_out=10, k=k).to(device)

10


In [195]:
graph_embedder(shower)

torch.Size([1011744, 1])


AttributeError: 'Tensor' object has no attribute 'T'