In [1]:
from imports import*
from utils import*
from models import*

In [None]:
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU
# (on a GPU machine the printout below should read "cuda").
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f'Device: {device}')

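Arguments used to build the model; presumably they must match the architecture of the saved checkpoints so that the weights can be loaded.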

In [3]:
# Hyper-parameters used to build the GNN in the evaluation cell.
# NOTE(review): 'alpha', 'lr' and 'epochs' are not read by any cell of this notebook;
# presumably they are kept so this dict mirrors the one used at training time.
args = dict(
    device=device,
    heads=2,
    num_features=1,    # one input feature per node (x_test has a trailing dim of 1)
    hidden=100,
    num_classes=6,     # one output probability per segmentation class
    dropout=0.001,
    alpha=0.1,
    lr=0.001,
    epochs=100,
)

In [4]:
# Geometry of the sliding-window decomposition (all sizes in voxels):
#   side     - edge length of the full cubic test volume
#   new_side - edge length of each cubic sub-volume on which a graph is built
#   stride   - offset between consecutive (overlapping) sub-volumes
side, new_side, stride = 512, 64, 28

We load the features and the ground-truth labels of the synthetic test volume.

In [None]:
# Both raw volumes are read with the same edge length `side`.
features_test, labels_test = (
    raw_to_tensor("CVSynth.raw", side),
    raw_to_tensor("CVSynth_Labels.raw", side),
)

We plot the first slice of the label volume so that it can be compared with the model's segmentation.

In [None]:
# First slice of the ground-truth labels; vmin/vmax pin the colour scale to the
# six class ids so it matches the prediction plots.
fig, ax = plt.subplots()
ax.imshow(labels_test[0], vmin=0, vmax=5)
fig.savefig("Labels")
plt.close(fig)

Next, we extract the overlapping sub-volumes (side `new_side`, spacing `stride`) from both volumes.

In [None]:
# Extract every (new_side)^3 window on a grid with spacing `stride` and flatten each
# window into one row of voxels (C order, last axis fastest):
#   x_test: (n_windows, new_side**3, 1)    y_test: (n_windows, new_side**3)
# The reshape of the overlapping window view already materialises a fresh array, so
# torch.from_numpy wraps it without copying; torch.tensor() would copy it a second time,
# doubling peak memory for these very large arrays.
window_shape = (new_side, new_side, new_side)

x_test = torch.from_numpy(
    view_as_windows(features_test.numpy(), window_shape, step=stride).reshape(-1, new_side**3, 1))

y_test = torch.from_numpy(
    view_as_windows(labels_test.numpy(), window_shape, step=stride).reshape(-1, new_side**3))

We create the list of edges connecting each node (voxel of a sub-volume) to its `k_neigh = 6` closest neighbours.

In [None]:
# Voxel coordinates of one sub-volume. cartesian_prod enumerates them with the last
# axis varying fastest, i.e. the same order used to flatten each window, so node n of
# every graph is voxel n of its window.
coords_1d = torch.arange(new_side)
cloud = torch.cartesian_prod(coords_1d, coords_1d, coords_1d)

k_neigh = 6  # neighbours per node passed to create_edges
edges = create_edges(k_neigh, cloud)

We build one graph per sub-volume of the test volume; all graphs share the same edge list.

In [None]:
# One graph per sub-volume; all graphs reuse the same edge list because every window
# has the same voxel layout.
data_list_test = [
    Data(x=x_test[n], edge_index=edges, y=y_test[n]) for n in range(len(x_test))
]

# shuffle=False keeps the window order, which the stitching step later relies on.
test_loader = DataLoader(data_list_test, batch_size=1, shuffle=False)

We define the evaluation helper, which runs a model over the test loader and collects the per-voxel class probabilities.

In [None]:
# NOTE(review): not read by any visible cell; the test DataLoader hard-codes batch_size=1.
batch: int = 1

In [None]:
class Eval_Test:
  """Evaluate a trained graph model on a DataLoader of sub-volume graphs.

  Args:
    model: the network to evaluate; called as ``model(x, edge_index, batch_index)``.
    device: device (e.g. 'cuda' or 'cpu') the model and every batch are moved to.
  """

  def __init__(self, model, device):
    self.model = model
    self.device = device

  def eval_function(self, data_loader):
    """Run the model over every batch and collect softmax class probabilities per node.

    Args:
      data_loader: iterable of graph batches exposing ``x``, ``edge_index`` and ``batch``
        and supporting ``.to(device)``.

    Returns:
      A CPU tensor with ``args['num_classes']`` rows. Column 0 is an all-zero placeholder
      (kept for backward compatibility; callers drop it with ``preds[:, 1:]``); the
      remaining columns hold one probability vector per node, in loader order.
    """
    # Evaluation mode (disables dropout); move the model to the target device once.
    self.model.eval()
    model = self.model.to(self.device)

    # Collect per-batch outputs and concatenate once at the end: concatenating inside
    # the loop re-copies the whole growing tensor on every iteration (quadratic cost).
    chunks = [torch.zeros((args['num_classes'], 1))]

    with torch.no_grad():
      for batch in tqdm(data_loader, desc="Iteration"):
        batch = batch.to(self.device)
        batch_index = batch.batch
        # Cast directly on the target device instead of round-tripping through a CPU LongTensor.
        edge_index = batch.edge_index.to(self.device, dtype=torch.long)

        out = model(batch.x.float(), edge_index, batch_index)
        # NOTE(review): assumes the model returns (num_nodes, num_classes); after the
        # transpose dim 0 is the class axis, so softmax normalises over classes per node.
        chunks.append(torch.nn.functional.softmax(out.transpose(1, 0).cpu(), dim=0))

        # Release per-batch memory eagerly before the next batch is loaded.
        del batch, batch_index, edge_index, out
        torch.cuda.empty_cache()
        gc.collect()

    return torch.cat(chunks, dim=1)

We test each of the 10 trained GNN models on the sub-volumes extracted from the synthetic volume.
For each model we stitch the sub-volume predictions back into 6 full-size probability volumes (one per class), summing the probabilities where sub-volumes overlap. Each voxel is then assigned the class with the highest summed probability. Finally, we compute the per-class Dice score of each model and append it to a CSV file.

In [None]:
# Number of sub-volume positions along each axis; must agree with the window grid that
# view_as_windows(..., step=stride) produced when the sub-volumes were extracted.
# Integer arithmetic avoids a round-trip through float division.
assert 0 < new_side <= side and stride > 0, "invalid sub-volume geometry"
# If the stride does not tile the volume exactly, the last voxels along each axis are
# never covered by a window: their summed probabilities stay zero and argmax would
# silently assign them class 0.
assert (side - new_side) % stride == 0, "(side - new_side) must be a multiple of stride"
steps = (side - new_side) // stride + 1

In [None]:
# Evaluate each saved model, stitch its sub-volume predictions back into a full volume,
# save a slice of the segmentation and append its per-class Dice scores to a CSV file.
n_models = 10
num_classes = args['num_classes']

for model_idx in range(n_models):
  # `model_idx` is deliberately distinct from the stitching indices below, so the figure
  # name and the CSV row carry the model index.
  model = GNN(args['heads'], args['num_features'], args['hidden'], num_classes, args['dropout']).to(device)
  # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine.
  model.load_state_dict(torch.load('GNN_200_non_augm_' + str(model_idx) + '.h5', map_location=device))
  model.eval()

  eval_obj = Eval_Test(model, device)
  preds = eval_obj.eval_function(test_loader)

  # Drop the leading zero column returned by eval_function, then split the node axis into
  # (window grid position) x (voxel inside the window).
  all_preds = preds[:, 1:].reshape(num_classes, steps, steps, steps, new_side, new_side, new_side)

  # Sum the class probabilities of overlapping windows into the full volume (all classes at once).
  summed_preds = torch.zeros(num_classes, side, side, side)
  for wi in range(steps):
    x0 = wi * stride
    for wj in range(steps):
      y0 = wj * stride
      for wk in range(steps):
        z0 = wk * stride
        summed_preds[:, x0:x0 + new_side, y0:y0 + new_side, z0:z0 + new_side] += all_preds[:, wi, wj, wk]

  # A voxel is covered by the same number of windows for every class, so the argmax of
  # the sum equals the argmax of the average.
  preds_argmax = torch.argmax(summed_preds, dim=0)

  # First slice of the predicted segmentation, on the same colour scale as the label plot.
  plt.imshow(preds_argmax[0], vmin=0, vmax=num_classes - 1)
  plt.savefig("Preds_Gnn_" + str(model_idx))
  plt.close()

  DICE = dice(preds_argmax, labels_test, average='none', num_classes=num_classes)

  # float() turns 0-d tensors (if `dice` returns tensors) into plain numbers, so the CSV
  # holds e.g. "93.1" rather than "tensor(93.1)".
  data = [model_idx] + [100 * float(DICE[c]) for c in range(num_classes)]

  # Opened in append mode: re-running the notebook adds rows to the existing file.
  # newline='' is what the csv module expects for files it writes.
  with open('Dice_Test_Synthetic_gnn_200_non_augm.csv', 'a', newline='') as f:
    csv.writer(f).writerow(data)

  # Free the large per-model tensors before the next model allocates its own.
  del preds, all_preds, summed_preds