In [1]:
import os
import torch
import numpy as np
from datasets import ADNI
from torch_geometric.data import DataLoader
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import ChebConv

# Build the dataset from the project-local `datasets.ADNI` class, rooted at
# ./data/imaging/, then wrap it in a shuffling loader that yields one graph
# per mini-batch.
data = ADNI(
    root='./data/imaging/',
)
loader = DataLoader(
    data,
    batch_size=1,
    shuffle=True,
)

Processing...
Done!


In [2]:
# --- Hyperparameters ------------------------------------------------------
# NOTE(review): `batch_size` is not consumed anywhere in the visible code;
# the loader built in the first cell hard-codes batch_size=1.
batch_size = 5
# Number of full passes over the loader during training.
num_epochs = 50
# Output channel counts of the first and second ChebConv layers.
filters = [5, 10]
# Third positional argument (K) passed to the first and second ChebConv layers.
khops = [4, 2]
# Width of the hidden fully connected layer in front of the 2-way output.
fc_size = 112

class Net(torch.nn.Module):
    """Graph classifier: two ChebConv layers followed by a two-layer dense head.

    Layer widths and Chebyshev orders are read from the module-level
    ``filters``, ``khops`` and ``fc_size`` settings. The head flattens all
    node features of a graph into one vector, so every graph fed to the
    network must contain exactly ``num_nodes`` nodes.

    Args:
        num_features: Size of the per-node input feature vector. When left as
            ``None`` it is taken from the module-level ``data`` object at
            construction time (the original behaviour).
        num_nodes: Number of nodes in every input graph. Defaults to 86, the
            value that was previously hard-coded in two places.
    """

    def __init__(self, num_features=None, num_nodes=86):
        super(Net, self).__init__()
        if num_features is None:
            # Resolved once, here, so later rebinding of the global `data`
            # name (e.g. by a loop variable) cannot affect an existing model.
            num_features = data.num_features
        # Length of one graph's flattened node-feature vector after conv2;
        # shared by the Linear layer below and the reshape in forward().
        self.flat_size = num_nodes * filters[1]
        self.conv1 = ChebConv(num_features, filters[0], khops[0])
        self.conv2 = ChebConv(filters[0], filters[1], khops[1])
        self.fc = Linear(self.flat_size, fc_size)
        self.logits = Linear(fc_size, 2)

    def forward(self, data):
        """Return per-graph log-probabilities over the two output classes.

        Args:
            data: A batch object exposing ``x``, ``edge_index`` and
                ``num_graphs``.

        Returns:
            Tensor with one row per graph and two columns, produced by
            ``log_softmax`` along dim 1.
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Collapse the node dimension: one row per graph in the batch.
        x = x.view(data.num_graphs, self.flat_size)
        x = F.relu(self.fc(x))
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.logits(x)
        return F.log_softmax(x, dim=1)

# Train the model with plain SGD on every batch yielded by `loader`.
# NOTE(review): the saved output of this cell is
# "AttributeError: 'int' object has no attribute 'dim'" with no traceback, so
# the failing line is unknown; it may originate in the dataset's fields
# (e.g. a value stored as a Python int instead of a tensor) -- confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=5e-4)

# Report about ten times over the run. The max() guard keeps the interval at
# least 1: round(num_epochs / 10) is 0 for num_epochs <= 5, which previously
# made the modulo below raise ZeroDivisionError.
log_every = max(1, round(num_epochs / 10))

model.train()
for epoch in range(num_epochs):
    epoch_loss, graphs_seen = 0.0, 0
    # The loop variable is named `batch` so it does not overwrite the
    # module-level dataset object `data`.
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()
        # nll_loss averages over the batch; re-weight by the number of graphs
        # so the epoch figure is a true per-graph mean.
        epoch_loss += loss.item() * batch.num_graphs
        graphs_seen += batch.num_graphs
    if epoch % log_every == 0:
        # Mean loss over the whole epoch rather than the last mini-batch only;
        # NaN when the loader produced no batches.
        mean_loss = epoch_loss / graphs_seen if graphs_seen else float('nan')
        print('{:.3} (epoch {}/{})'.format(mean_loss, epoch, num_epochs))
print('Done')

AttributeError: 'int' object has no attribute 'dim'

In [None]:
# Collect predicted and true class labels for every graph yielded by `loader`
# and report the fraction predicted correctly.
# NOTE: `loader` is the same loader the training loop iterates over, so the
# figure printed here is accuracy on the training data.
model.eval()
predictions, target = [], []
# Run on whatever device the model's parameters live on, so inputs and
# weights are always on the same device.
eval_device = next(model.parameters()).device
with torch.no_grad():  # inference only: no autograd bookkeeping needed
    for batch in loader:
        batch = batch.to(eval_device)
        pred = model(batch)
        # Move to host memory before converting; .numpy() fails on CUDA tensors.
        predictions.extend(torch.argmax(pred, dim=1).cpu().numpy())
        target.extend(batch.y.cpu().numpy())

if target:
    accuracy = float(np.mean(np.asarray(predictions) == np.asarray(target)))
    print('Accuracy: {:.2f}'.format(accuracy))