In [1]:
import os
import torch
import numpy as np
from scipy.sparse import coo_matrix
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_tar)

adj = np.loadtxt('../cogd/model/graph/mci_connectome.csv',delimiter=',')
a = coo_matrix(adj)

class ASet(InMemoryDataset):

    def __init__(self, root, train=True, transform=None):
        super(ASet, self).__init__(root, transform)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.data, self.slices = torch.load(path)

    @property
    def raw_file_names(self):
        return ['train1.csv', 'test1.csv']

    @property
    def processed_file_names(self):
        return ['training.pt', 'test.pt']

    def download(self):
        pass
        #print('no download avail..')

    def process(self):
        #connectome
        adj = np.loadtxt('../cogd/model/graph/mci_connectome.csv',delimiter=',')
        a = coo_matrix(adj)

        #load
        train_path = "../cogd/model/data/imaging/kfold/train1.csv"
        test_path  = "../cogd/model/data/imaging/kfold/test1.csv"
        train = np.loadtxt(train_path,delimiter=',',skiprows=1)
        test = np.loadtxt(test_path,delimiter=',',skiprows=1)
        data_list = []
        
        for m in train:
            edge_index = torch.tensor([a.row,a.col], dtype=torch.long) # [2, num_edges] coo format
            edge_attr  = torch.tensor(np.transpose([a.data]), dtype=torch.float) # [num_edges, num_edge_features]
            node_feat = np.reshape(m[2:], (86,2), order='F')
            x = torch.tensor(node_feat, dtype=torch.float) # [num_nodes, num_node_features]
            y = torch.tensor([m[0]-1], dtype=torch.long)
            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
            data_list.append(data)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        

In [17]:
data = ASet(root='./data/adni/')

from torch_geometric.data import DataLoader
loader = DataLoader(data, batch_size=3, shuffle=True)

Processing...
Done!


In [20]:
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import ChebConv

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = ChebConv(data.num_features, 5, 4)
        self.conv2 = ChebConv(5, 10, 2)
        self.dense = Linear(10, 112)
        self.dense2 = Linear(112, 2)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        print('input:',x.shape)
        x = self.conv1(x, edge_index)
        print('conv1 ->', x.shape)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        print('conv2 ->', x.shape)
        x = F.relu(self.dense(x))
        print('dense-relu ->', x.shape)
        #x = torch.sum(x, dim=0)
        print(x.shape)

        # goal is to return [batch_size, data.num_features]
        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        print(out.shape, data.y.shape)
        #sum of the output will be divided by the number of elements in the output
        loss = F.binary_cross_entropy_with_logits(out, data.y, reduction='elementwise_mean')
#         loss.backward()
#         optimizer.step()

input: torch.Size([258, 2])
conv1 -> torch.Size([258, 5])
conv2 -> torch.Size([258, 10])
dense-relu -> torch.Size([258, 112])
torch.Size([258, 112])
torch.Size([258, 112]) torch.Size([3])


TypeError: binary_cross_entropy_with_logits() got an unexpected keyword argument 'reduction'

In [None]:
model.eval()
_, pred = model(data).max(dim=1)
correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))