In [None]:
from imports import*
from utils import*
from models import*
from Train_Eval_Test import*

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# When a GPU is in use this should report cuda
print('Device: {}'.format(device))

# Hyperparameters of the model, looked up by key in the cells below
args = {
    'device': device,
    'heads': 2,
    'num_features': 1,
    # 100 for simple GNN, 50 for ViG model
    # NOTE(review): the evaluation loop at the end of this notebook builds ViG
    # models with this value -- confirm it matches the saved checkpoints.
    'hidden': 100,
    'num_classes': 6,
    'dropout': 0.001,
    'alpha': 0.1,
    'lr': 0.001,
    'epochs': 200,
}

Setup of the dataset

In [None]:
# side:     side length of the original volume
# new_side: side length of the sub-volume on which each graph is constructed
# stride:   step used when extracting the overlapping sub-volumes
side, new_side, stride = 512, 64, 28

Import synthetic test volume and extract subvolumes

In [None]:
# Load the synthetic test volume and its label volume, in that order,
# passing the side length of the original volume to the reader.
features_test, labels_test = (
    raw_to_tensor(raw_file, side)
    for raw_file in ("CVSynth.raw", "CVSynth_Labels.raw")
)

In [None]:
# Overlapping new_side^3 windows of the synthetic features, flattened to one
# row per sub-volume with a trailing feature axis of size 1.
x_test = torch.tensor(
    view_as_windows(features_test.numpy(), (new_side,) * 3, step=stride)
    .reshape(-1, new_side ** 3, 1)
)

# Matching windows of the synthetic labels, flattened to one row per sub-volume.
y_test = torch.tensor(
    view_as_windows(labels_test.numpy(), (new_side,) * 3, step=stride)
    .reshape(-1, new_side ** 3)
)

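Import the features of the four experimental volumes. The next two cells are alternatives that assign the same variables (`features_1_exp` … `features_4_exp`): run only the cell for the conditioning you want to evaluate, because running both leaves only the BAM SynthCOND volumes in memory.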

In [None]:
# Experimental volumes conditioned with a Non-Local Means filter,
# loaded in order for volumes 1 to 4.
features_1_exp, features_2_exp, features_3_exp, features_4_exp = [
    raw_to_tensor(raw_file, side)
    for raw_file in (
        "CV1_NLM8.raw",
        "CV2_NLM8.raw",
        "CV3_NLM8.raw",
        "CV4_NLM8.raw",
    )
]

In [None]:
# Experimental volumes conditioned with BAM SynthCOND, loaded in order for
# volumes 1 to 4. These are bound to the same names as the Non-Local Means
# volumes, so whichever of the two cells runs last determines what is evaluated.
features_1_exp, features_2_exp, features_3_exp, features_4_exp = [
    raw_to_tensor(raw_file, side)
    for raw_file in (
        "CV1_AI.raw",
        "CV2_AI.raw",
        "CV3_AI.raw",
        "CV4_AI.raw",
    )
]

And we extract the subvolumes

In [None]:
# Overlapping new_side^3 windows of every experimental volume, each flattened
# to one row per sub-volume with a trailing feature axis of size 1.
x_1_exp, x_2_exp, x_3_exp, x_4_exp = [
    torch.tensor(
        view_as_windows(volume.numpy(), (new_side,) * 3, step=stride)
        .reshape(-1, new_side ** 3, 1)
    )
    for volume in (features_1_exp, features_2_exp, features_3_exp, features_4_exp)
]

Then we import the four manually-labelled slices, one for each volume

In [None]:
# One manually-labelled slice per experimental volume (volumes 1 to 4);
# the slice number is part of each file name.
labels_1_exp, labels_2_exp, labels_3_exp, labels_4_exp = [
    tif_to_tensor(tif_file, side)
    for tif_file in (
        "CV1 Labels - Slice 339.tif",
        "CV2 Labels - Slice 139.tif",
        "CV3 Labels - Slice 219.tif",
        "CV4 Labels - Slice 059.tif",
    )
]

We then create a graph from each subvolume by connecting each voxel to its nearest 6 neighbors

In [None]:
# Number of neighbours each voxel is connected to
k_neigh = 6

# Integer coordinates of every voxel of a new_side^3 sub-volume, one row each
cloud = torch.cartesian_prod(
    *(torch.tensor(range(new_side)) for _ in range(3))
)

# Edge list computed once; every sub-volume graph below reuses it
edges = create_edges(k_neigh, cloud)

# Sparse adjacency over the new_side^3 nodes, built from the two edge rows
adj = SparseTensor(
    row=edges[0],
    col=edges[1],
    sparse_sizes=(new_side ** 3, new_side ** 3),
)

We create a dataset of subvolumes from the synthetic volume

In [None]:
# One graph per synthetic sub-volume: its own features and labels, shared edges
data_list_test = [
    Data(x=x_test[idx], edge_index=edges, y=y_test[idx])
    for idx in range(x_test.size(0))
]

# Sub-volumes are served one at a time and in extraction order
test_loader = DataLoader(data_list_test, batch_size=1, shuffle=False)

We create a dataset of subvolumes from every experimental volume

In [None]:
# For each experimental volume: one graph per sub-volume (shared edge list),
# each carrying that volume's manually-labelled slice as `y`, served one at a
# time and in extraction order.

data_list_test_1 = [
    Data(x=x_1_exp[idx], edge_index=edges, y=labels_1_exp)
    for idx in range(x_1_exp.size(0))
]
test_loader_1 = DataLoader(data_list_test_1, batch_size=1, shuffle=False)

data_list_test_2 = [
    Data(x=x_2_exp[idx], edge_index=edges, y=labels_2_exp)
    for idx in range(x_2_exp.size(0))
]
test_loader_2 = DataLoader(data_list_test_2, batch_size=1, shuffle=False)

data_list_test_3 = [
    Data(x=x_3_exp[idx], edge_index=edges, y=labels_3_exp)
    for idx in range(x_3_exp.size(0))
]
test_loader_3 = DataLoader(data_list_test_3, batch_size=1, shuffle=False)

data_list_test_4 = [
    Data(x=x_4_exp[idx], edge_index=edges, y=labels_4_exp)
    for idx in range(x_4_exp.size(0))
]
test_loader_4 = DataLoader(data_list_test_4, batch_size=1, shuffle=False)

The batch size is set to 1 when testing the model

In [None]:

# Batch size passed to the ViG constructor in the evaluation loop below; it
# matches the batch_size=1 used by the DataLoaders created above.
batch = 1

We plot the four manually-labelled slices to compare them with the model's segmentation

In [None]:
# Save each manually-labelled slice as an image, using the same 0-5 colour
# range as the prediction plots so the two can be compared side by side.
for label_slice, out_name in (
    (labels_1_exp, "Labels_1"),
    (labels_2_exp, "Labels_2"),
    (labels_3_exp, "Labels_3"),
    (labels_4_exp, "Labels_4"),
):
    plt.imshow(label_slice, vmin=0, vmax=5)
    plt.savefig(out_name)
    plt.close()

We test each trained GNN model (10 in total) on the subvolumes extracted from the four experimental volumes

For each experimental volume, we reconstruct 6 probability volumes (one for each class) and we assign to each voxel the highest probability class

For the experimental volumes we compute the Dice score only on the manually-labelled slices, while the synthetic volume is scored in full; all scores are appended to an external CSV file

In [None]:
# Number of sub-volume positions along each axis, used to reconstruct the
# original volume from the per-sub-volume evaluations.
# NOTE(review): the fractional part is truncated, so this presumably expects
# (side - new_side) to be a multiple of stride -- confirm.

steps = int(1 + (side - new_side) / stride)

In [None]:
def _dice_row(model_idx, preds, target, num_classes):
    """Build one CSV row of Dice scores (in percent) for a prediction/target pair.

    Args:
        model_idx: index of the evaluated model, written as the first column.
        preds: predicted class labels.
        target: ground-truth class labels, same shape as ``preds``.
        num_classes: number of classes passed to ``dice``.

    Returns:
        ``[model_idx, class 1, ..., class num_classes - 1, macro average]``.
        Class 0 is not reported individually.
    """
    # The results get their own names so the imported ``dice`` function is
    # never rebound to a tensor (which would break every later call).
    per_class = dice(preds, target, average='none', num_classes=num_classes)
    overall = dice(preds, target, average='macro', num_classes=num_classes)
    scores = [100 * per_class[c].numpy() for c in range(1, num_classes)]
    return [model_idx] + scores + [100 * overall.numpy()]


num_classes = args['num_classes']

# Label remaining voids (label 6) as Aluminium matrix (label 0) in the slices of
# volumes 2-4. This does not depend on the model, so it is done once, in place.
for label_slice in (labels_2_exp, labels_3_exp, labels_4_exp):
    label_slice[label_slice == 6] = 0

# (volume number, loader, manually-labelled slice, 0-based index of that slice).
# The label files are numbered from 1 (slices 339, 139, 219 and 59).
experimental_volumes = (
    (1, test_loader_1, labels_1_exp, 338),
    (2, test_loader_2, labels_2_exp, 138),
    (3, test_loader_3, labels_3_exp, 218),
    (4, test_loader_4, labels_4_exp, 58),
)

for i in range(10):

    model = ViG(batch, new_side, args['num_features'], args['hidden'], num_classes, args['dropout']).to(device)
    # map_location lets checkpoints saved on a GPU load when running on the CPU
    model.load_state_dict(torch.load('ViG_' + str(i) + '.h5', map_location=device))
    model.eval()

    eval_obj = Test(model, device, num_classes)

    # Synthetic volume: scored against the full label volume
    preds = extract_overlap_pred(eval_obj, test_loader, num_classes, steps, new_side, side, stride)
    rows = [_dice_row(i, preds, labels_test, num_classes)]

    # Experimental volumes: scored only on the manually-labelled slice
    for volume, loader, label_slice, slice_idx in experimental_volumes:
        preds_exp = extract_overlap_pred(eval_obj, loader, num_classes, steps, new_side, side, stride)
        pred_slice = preds_exp[slice_idx]

        plt.imshow(pred_slice, vmin=0, vmax=5)
        plt.savefig("Preds_" + str(volume) + "_" + str(i))
        plt.close()

        rows.append(_dice_row(i, pred_slice, label_slice, num_classes))

    # Five rows per model: synthetic volume first, then volumes 1 to 4.
    # newline='' stops the csv module from emitting blank rows on Windows.
    with open('Test_ViG.csv', 'a', newline='') as f:
        csv.writer(f).writerows(rows)