In [1]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


from imports import*
from utils import*
from models import*


# In[2]:


# Pick the compute device used to evaluate the trained networks on the patch graphs.
# The GPU is used whenever PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print('Device: {}'.format(device))


# In[3]:


# Model / run settings. 'heads', 'num_features', 'hidden', 'num_classes' and
# 'dropout' are handed to the GNN constructor further down; the remaining
# entries are not read by the code visible in this script.
args = dict(
    device=device,
    heads=2,
    num_features=1,
    hidden=100,
    num_classes=6,
    dropout=0.001,
    alpha=0.1,
    lr=0.001,
    epochs=200,
)


# In[4]:


# side     -> handed to raw_to_tensor / tif_to_tensor when loading the data
# new_side -> sizes the per-patch voxel coordinate grid
# stride   -> matches the window step (28) used when cutting patches
side, new_side, stride = 512, 64, 28


# In[ ]:


# Load the four raw volumes, in sample order (CV1 .. CV4).
# NOTE(review): `side` is presumably the cubic edge length of each .raw file —
# confirm against raw_to_tensor.
(features_1_exp,
 features_2_exp,
 features_3_exp,
 features_4_exp) = [
    raw_to_tensor(raw_name, side)
    for raw_name in ("CV1_NLM8.raw", "CV2_NLM8.raw", "CV3_NLM8.raw", "CV4_NLM8.raw")
]


# In[ ]:


def _cut_patches(volume):
    """Slice `volume` into overlapping cubic patches, one flattened patch per row.

    Windows of edge `new_side` are taken every `stride` voxels along each axis
    (the configured values, instead of repeating the literals 64 and 28) and
    flattened to shape (num_patches, new_side**3, 1).

    Args:
        volume: object exposing `.numpy()` (the tensors returned by raw_to_tensor).

    Returns:
        torch.Tensor holding every patch, in the row-major order in which
        view_as_windows enumerates the window positions.
    """
    windows = view_as_windows(volume.numpy(), (new_side, new_side, new_side), step=stride)
    return torch.tensor(windows.reshape(-1, new_side ** 3, 1))


x_1_exp = _cut_patches(features_1_exp)
x_2_exp = _cut_patches(features_2_exp)
x_3_exp = _cut_patches(features_3_exp)
x_4_exp = _cut_patches(features_4_exp)


# In[ ]:


#labels_1_exp = tif_to_tensor("CV1LabelsSlice339.tif", side)

#labels_2_exp = tif_to_tensor("CV2LabelsSlice139.tif", side)

#labels_3_exp = tif_to_tensor("CV3LabelsSlice219.tif", side)

#labels_4_exp = tif_to_tensor("CV4LabelsSlice059.tif", side)


# In[ ]:


# Read the annotated slice of each sample; the file names carry the slice number
# that is later used to index the predicted volume.
labels_1_exp, labels_2_exp, labels_3_exp, labels_4_exp = [
    tif_to_tensor(tif_name, side)
    for tif_name in (
        "CV1 Labels - Slice 339.tif",
        "CV2 Labels - Slice 139.tif",
        "CV3 Labels - Slice 219.tif",
        "CV4 Labels - Slice 059.tif",
    )
]


# In[ ]:


# Edge length (in voxels) of the cubic patch whose voxels become graph nodes.
new_side = 64


# In[ ]:


# Integer (a, b, c) coordinates of every voxel in a new_side^3 patch, one row per
# voxel, with the last coordinate varying fastest.
_axis = torch.arange(new_side)
cloud = torch.cartesian_prod(_axis, _axis, _axis)


# In[ ]:


# Neighbour count handed to create_edges.
k_neigh = 6


# In[ ]:


# Edge list built once from the voxel coordinates and reused by every patch graph.
# NOTE(review): presumably a k-nearest-neighbour graph with k_neigh neighbours per
# voxel — confirm against create_edges.
edges = create_edges(k_neigh, cloud)


# In[ ]:


# NOTE(review): `batch` is not read by any code visible in this script (the
# DataLoaders hard-code batch_size=1, and eval_function uses a local loop
# variable of the same name) — confirm before relying on it.
batch = 1


# In[ ]:


class Eval_Test:
  """Runs a model over a data loader and collects per-node class probabilities.

  Attributes:
    model: network to evaluate; called as model(x, edge_index, batch_index).
    device: device the model and every batch are moved to before inference.
  """

  def __init__(self, model, device):
    self.model = model
    self.device = device

  def eval_function(self, data_loader):
    """Return the softmax class probabilities of every node in `data_loader`.

    Args:
      data_loader: iterable of batches exposing `.x`, `.edge_index`, `.batch`
        and `.to(device)`.

    Returns:
      CPU tensor of shape (num_classes, 1 + total_nodes). Column 0 is an
      all-zero placeholder kept for backward compatibility (callers strip it
      with `preds[:, 1:]`); the remaining columns follow the loader order.
      An empty loader yields the bare (6, 1) placeholder.
    """
    # Sets model to evaluation mode
    self.model.eval()

    # Use the device this object was built with (not a module-level global),
    # and move the model once instead of once per batch.
    model = self.model.to(self.device)

    # Per-batch probabilities are gathered here and concatenated once at the
    # end; concatenating inside the loop re-copies everything on every step.
    chunks = []

    for batch in tqdm(data_loader, desc="Iteration"): #remind that tqdm draws progress bars

      batch = batch.to(self.device)
      batch_index = batch.batch
      edge_index = batch.edge_index.long().to(self.device)

      with torch.no_grad():
        logits = model((batch.x).float(), edge_index, batch_index)

      # The model output is transposed to (num_classes, nodes) and normalised
      # over the class axis on the CPU.
      chunks.append(torch.nn.functional.softmax(logits.transpose(1, 0).cpu(), dim=0))

      # Release per-batch device memory before the next step.
      del batch, batch_index, edge_index, logits
      torch.cuda.empty_cache()
      gc.collect()

    if not chunks:
      return torch.zeros((6, 1)).cpu()

    placeholder = torch.zeros((chunks[0].shape[0], 1))
    return torch.cat([placeholder] + chunks, 1)


# In[ ]:


# Edge length (voxels) of the full cubic volume that predictions are stitched into.
Vol_Size = 512
# Edge length (voxels) of each cubic patch fed to the network.
MinBatch_Vol_Size = 64
# Offset (voxels) between consecutive patch origins along every axis.
Stride = 28

# Number of patch positions along one axis of the volume.
Steps = (Vol_Size - MinBatch_Vol_Size) // Stride + 1


# In[ ]:


def extract_overlap_pred(eval_obj, test_loader, vol_size=None, patch_size=None, stride=None):
  """Stitch overlapping per-patch class probabilities into one label volume.

  The probabilities of every patch are accumulated into a
  (num_classes, vol_size, vol_size, vol_size) buffer at the patch's position;
  where patches overlap their probabilities add up. The label of each voxel is
  the class with the largest accumulated probability.

  Args:
    eval_obj: object whose `eval_function(test_loader)` returns a tensor of
      shape (num_classes, 1 + num_patches * patch_size**3); column 0 is a
      placeholder and is dropped.
    test_loader: loader yielding the patches in row-major (i, j, k) order of
      their position in the volume.
    vol_size: edge length of the full volume; defaults to `Vol_Size`.
    patch_size: edge length of a patch; defaults to `MinBatch_Vol_Size`.
    stride: offset between consecutive patches; defaults to `Stride`.

  Returns:
    Integer tensor of shape (vol_size, vol_size, vol_size) with the predicted
    class index of every voxel.
  """
  # Fall back to the module-level geometry so existing two-argument calls
  # behave as before, without repeating the literals 6 / 17 / 512 / 64 here.
  vol_size = Vol_Size if vol_size is None else vol_size
  patch_size = MinBatch_Vol_Size if patch_size is None else patch_size
  stride = Stride if stride is None else stride
  steps = (vol_size - patch_size) // stride + 1

  preds = eval_obj.eval_function(test_loader)
  num_classes = preds.shape[0]

  # Drop the placeholder column, then unfold the node axis into
  # (patch i, patch j, patch k, voxel a, voxel b, voxel c).
  all_preds = preds[:, 1:].reshape(num_classes, steps, steps, steps,
                                   patch_size, patch_size, patch_size)

  summed_preds = torch.zeros((num_classes, vol_size, vol_size, vol_size),
                             dtype=all_preds.dtype)

  # All classes are accumulated at once; only the patch positions are looped.
  for i in range(steps):
    a0 = i * stride
    for j in range(steps):
      b0 = j * stride
      for k in range(steps):
        c0 = k * stride
        summed_preds[:, a0:a0 + patch_size,
                     b0:b0 + patch_size,
                     c0:c0 + patch_size] += all_preds[:, i, j, k]

  return torch.argmax(summed_preds, dim=0)


# In[ ]:


# Save one reference image per ground-truth label slice, colour-scaled to the
# class range 0..5 so the figures are comparable with the prediction images.
for figure_name, label_slice in (
    ("Labels_1", labels_1_exp),
    ("Labels_2", labels_2_exp),
    ("Labels_3", labels_3_exp),
    ("Labels_4", labels_4_exp),
):
    plt.imshow(label_slice, vmin=0, vmax=5)
    plt.savefig(figure_name)
    plt.close()


# In[ ]:


def _patch_graphs(patches, labels):
    """Wrap every patch of one sample in a Data object.

    All graphs share the module-level `edges`; `labels` (the sample's annotated
    slice) is attached unchanged to every patch as `y`.
    """
    return [
        Data(x=patches[idx], edge_index=edges, y=labels)
        for idx in range(patches.size()[0])
    ]


# One graph per patch and one patch per batch. shuffle=False keeps the patch
# order, which the stitching step relies on when it reshapes the predictions.
data_list_test_1 = _patch_graphs(x_1_exp, labels_1_exp)
test_loader_1 = DataLoader(data_list_test_1, batch_size=1, shuffle=False)

data_list_test_2 = _patch_graphs(x_2_exp, labels_2_exp)
test_loader_2 = DataLoader(data_list_test_2, batch_size=1, shuffle=False)

data_list_test_3 = _patch_graphs(x_3_exp, labels_3_exp)
test_loader_3 = DataLoader(data_list_test_3, batch_size=1, shuffle=False)

data_list_test_4 = _patch_graphs(x_4_exp, labels_4_exp)
test_loader_4 = DataLoader(data_list_test_4, batch_size=1, shuffle=False)


# In[ ]:


# Label value 6 is folded into class 0 for samples 2-4 so the targets fit the six
# classes scored below. This is done once up front instead of on every checkpoint.
# NOTE(review): sample 1 is not remapped in the original code — confirm that its
# label slice never contains the value 6.
for remapped_labels in (labels_2_exp, labels_3_exp, labels_4_exp):
  remapped_labels[remapped_labels == 6] = 0

# Per sample: (patch loader, annotated slice, index of that slice along axis 0 of
# the stitched prediction volume, file-name prefix of the saved figure).
eval_volumes = [
  (test_loader_1, labels_1_exp, 339, "Preds_1"),
  (test_loader_2, labels_2_exp, 139, "Preds_2"),
  (test_loader_3, labels_3_exp, 219, "Preds_3"),
  (test_loader_4, labels_4_exp, 59, "Preds_4"),
]

# Evaluate each of the ten saved checkpoints on all four samples.
for i in range(10):

  model = GNN(args['heads'], args['num_features'], args['hidden'], args['num_classes'], args['dropout']).to(device)
  # map_location lets checkpoints saved on a GPU load when `device` is 'cpu'.
  model.load_state_dict(torch.load('GNN_200_non_augm_' + str(i) + '.h5', map_location=device))
  model.eval()

  eval_obj = Eval_Test(model, device)

  # One CSV row per sample: checkpoint index followed by the per-class Dice in percent.
  rows = []

  for loader, label_slice, slice_index, figure_prefix in eval_volumes:

    preds = extract_overlap_pred(eval_obj, loader)
    pred_slice = preds[slice_index]

    plt.imshow(pred_slice, vmin=0, vmax=5)
    plt.savefig(figure_prefix + str(i))
    plt.close()

    scores = dice(pred_slice, label_slice, average='none', num_classes=args['num_classes'])
    rows.append([i] + [100 * scores[c].numpy() for c in range(args['num_classes'])])

  # newline='' stops the csv module from emitting blank lines between rows on
  # platforms that translate line endings.
  with open('Dice_Experimental_GNN.csv', 'a', newline='') as f:   #modified
    csv.writer(f).writerows(rows)





Device: cuda




Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]

Iteration:   0%|          | 0/4913 [00:00<?, ?it/s]