In [None]:
from imports import*
from utils import*
from models import*
from seed_everything import*

In [None]:
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Hyperparameters of the model
args = dict(
    device=device,
    heads=2,
    num_features=1,
    hidden=100,  # 50 for ViG model
    num_classes=6,
    dropout=0.001,
    alpha=0.1,
    lr=0.001,
    epochs=200,
)

We define the `Train` and `Eval` classes, which wrap the training loop and the evaluation loop.

In [3]:
# Initial, hard-coded value of the global `stop`; it is reassigned further down
# from the size of the training split and the batch size.
stop = 766

class Train:
  """Runs one training epoch with gradient accumulation.

  Gradients are accumulated over `accumulation_steps` consecutive mini-batches
  before each optimizer update, so the effective batch size is the loader's
  batch size multiplied by `accumulation_steps`.

  Args:
    model: the network to train; called as model(x, edge_index, batch_index).
    device: device on which the model and every batch are placed.
    data_loader: iterable of batches exposing `.x`, `.y`, `.edge_index`, `.batch`.
    optimizer: optimizer that updates the parameters of `model`.
    loss_f: callable loss_f(output, target) returning a scalar tensor.
    accumulation_steps: number of mini-batches per optimizer update (default 8).
  """

  def __init__(self, model, device, data_loader, optimizer, loss_f, accumulation_steps=8):

    self.model = model
    self.device = device
    self.data_loader = data_loader
    self.optimizer = optimizer
    self.loss_f = loss_f
    self.accumulation_steps = accumulation_steps

  def train_function(self):
    """Train for one pass over `data_loader`.

    Returns:
      The loss of the last mini-batch as a Python float, or 0.0 when the
      loader yields no batch.
    """

    # Sets model to train mode
    self.model.train()

    model = self.model.to(self.device)

    # Start the epoch without gradients left over from any previous call
    self.optimizer.zero_grad()

    last_loss = None   # detached loss of the most recent step
    pending = False    # True while accumulated gradients have not been applied yet

    for step, batch in enumerate(tqdm(self.data_loader, desc="Iteration"), start=1):

      batch = batch.to(self.device)
      batch_index = batch.batch
      # Cast on the device the batch already lives on (no round-trip through the CPU)
      edge_index = batch.edge_index.long()

      out = model(batch.x.float(), edge_index, batch_index)  # forward pass
      loss_ = self.loss_f(out, batch.y.to(torch.int64))

      # NOTE: the loss is not divided by accumulation_steps, so the gradients of
      # the accumulated mini-batches are summed rather than averaged.
      loss_.backward()
      pending = True

      # Only the detached scalar is kept; the autograd graph can be freed right away
      last_loss = loss_.detach()

      if step % self.accumulation_steps == 0:
        self.optimizer.step()
        self.optimizer.zero_grad()
        pending = False

      # Drop the references first so that collecting and emptying the cache can release them
      del batch, edge_index, batch_index, out, loss_
      gc.collect()
      torch.cuda.empty_cache()

    # Apply the gradients of a trailing, incomplete accumulation group instead of
    # letting them leak into the next epoch
    if pending:
      self.optimizer.step()
      self.optimizer.zero_grad()

    return last_loss.item() if last_loss is not None else 0.0
  
class Eval:
  """Evaluates a model with the per-class Dice score.

  Args:
    model: the network to evaluate; called as model(x, edge_index, batch_index).
    device: device on which the model and every batch are placed.
  """

  def __init__(self, model, device):

    self.model = model
    self.device = device

  def eval_function(self, data_loader):
    """Compute the Dice score of every class, averaged over the batches of `data_loader`.

    The number of classes is read from the module-level `args['num_classes']`.

    Returns:
      A NumPy array holding, for each class, the mean over batches of the
      per-batch Dice score.
    """

    # Sets model to eval mode
    self.model.eval()

    # Move the model once, not once per batch
    model = self.model.to(self.device)

    dice_per_class = []

    for batch in tqdm(data_loader, desc="Iteration"):

      batch = batch.to(self.device)
      batch_index = batch.batch
      # Cast on the device the batch already lives on (no round-trip through the CPU)
      edge_index = batch.edge_index.long()

      with torch.no_grad():
        logits = model(batch.x.float(), edge_index, batch_index)
        y_pred = torch.argmax(logits, dim=1)
        y_true = batch.y.view(y_pred.shape)

      scores = dice(y_pred, y_true, average='none', num_classes=args['num_classes'])
      dice_per_class.append(scores.detach().cpu().numpy())

      # Drop the references first so that collecting and emptying the cache can release them
      del batch, batch_index, edge_index, logits, y_pred, y_true, scores
      gc.collect()
      torch.cuda.empty_cache()

    mean_dice_per_class = np.mean(np.array(dice_per_class), axis=0)

    return mean_dice_per_class

In [None]:
side = 512 # Side length of each original volume (passed to raw_to_tensor and average_labels_vol)
new_side = 64 # Side length of the sub-volumes on which the graphs are built
stride = 56 # Step between consecutive sub-volume windows; with 64-wide windows, neighbouring windows overlap by 8

We import the features and the labels of the 8 synthetic volumes (train/val and test)

In [4]:
# Feature volumes: seven used for training/validation, one for testing
(features_1, features_2, features_3, features_4,
 features_5, features_6, features_7) = [
    raw_to_tensor(file_name, side)
    for file_name in ("FINAL1.raw", "FINAL2.raw", "FINAL3.raw", "FINAL4.raw",
                      "FINAL5.raw", "FINAL6.raw", "FINAL7.raw")
]
features_test = raw_to_tensor("CVSynth.raw", side)

# Label volumes, in the same order as the feature volumes
(labels_1, labels_2, labels_3, labels_4,
 labels_5, labels_6, labels_7) = [
    raw_to_tensor(file_name, side)
    for file_name in ("LABELS1.raw", "LABELS2.raw", "LABELS3.raw", "LABELS4.raw",
                      "LABELS5.raw", "LABELS6.raw", "LABELS7.raw")
]
labels_test = raw_to_tensor("CVSynth_Labels.raw", side)

We compute the average occurrence of the 6 classes among the 8 volumes (train/val and test)

In [5]:
# Class occurrence of every label volume, computed in the order 1..7 and then test
(occ_1, occ_2, occ_3, occ_4,
 occ_5, occ_6, occ_7, occ_test) = [
    average_labels_vol(label_volume, side)
    for label_volume in (labels_1, labels_2, labels_3, labels_4,
                         labels_5, labels_6, labels_7, labels_test)
]

In [6]:
# Mean class occurrence over all 8 volumes (7 train/val + 1 test).
# The divisor is the number of volumes actually summed; the previous hard-coded 7
# made `av` larger than the true mean (the normalised weights below are unaffected,
# because the constant factor cancels out).
all_occurrences = [np.array(occ) for occ in (occ_1, occ_2, occ_3, occ_4,
                                             occ_5, occ_6, occ_7, occ_test)]
av = sum(all_occurrences) / len(all_occurrences)

# Normalise the occurrences so that they sum to 1, then invert them so that
# rarer classes get a larger value
norm_av = av / np.sum(av)
inverse_norm_av = torch.tensor((1 / norm_av).reshape(-1))

We extract 64x64x64 overlapping sub-volumes from the original volumes with a stride of 56. 

In [7]:
def _flattened_windows(volume, *trailing_dims):
    """Cut `volume` into overlapping cubic windows and flatten each window.

    The windows have side `new_side` and are taken every `stride` positions along
    each axis. Each window is flattened to `new_side**3` entries, optionally
    followed by `trailing_dims` (e.g. `1` to add a feature axis).

    Returns:
        A tensor of shape (number_of_windows, new_side**3, *trailing_dims).
    """
    windows = view_as_windows(volume.numpy(), (new_side, new_side, new_side), step=stride)
    return torch.tensor(windows.reshape(-1, new_side**3, *trailing_dims))


# Features: one value per voxel, kept on a trailing axis of size 1
x1 = _flattened_windows(features_1, 1)
x2 = _flattened_windows(features_2, 1)
x3 = _flattened_windows(features_3, 1)
x4 = _flattened_windows(features_4, 1)
x5 = _flattened_windows(features_5, 1)
x6 = _flattened_windows(features_6, 1)
x7 = _flattened_windows(features_7, 1)

# Labels: one value per voxel, no trailing axis
y1 = _flattened_windows(labels_1)
y2 = _flattened_windows(labels_2)
y3 = _flattened_windows(labels_3)
y4 = _flattened_windows(labels_4)
y5 = _flattened_windows(labels_5)
y6 = _flattened_windows(labels_6)
y7 = _flattened_windows(labels_7)

We then create a graph from each sub-volume by connecting each voxel to its 6 nearest neighbors.

In [None]:
# Number of neighbours each voxel is connected to
k_neigh = 6

# Integer coordinates 0 .. new_side-1 along one axis; their 3-fold Cartesian
# product lists the coordinates of every voxel of a sub-volume
axis_coords = torch.arange(new_side)
cloud = torch.cartesian_prod(axis_coords, axis_coords, axis_coords)

edges = create_edges(k_neigh, cloud)

We create the dataset for training and validation after shuffling the subvolumes from the first 7 volumes.

In [8]:
data_list = []

def data_list_creation(X, Y, edges):
  """Append to the global `data_list` one `Data` object per entry along the first axis of X.

  Entry `idx` pairs `X[idx]` with `Y[idx]`; every object shares the same `edges`.
  """
  data_list.extend(
      Data(x=X[idx], edge_index=edges, y=Y[idx]) for idx in range(X.size(0))
  )

# Sub-volumes of the seven train/val volumes, appended in volume order
for sub_x, sub_y in ((x1, y1), (x2, y2), (x3, y3), (x4, y4),
                     (x5, y5), (x6, y6), (x7, y7)):
  data_list_creation(sub_x, sub_y, edges)

# Shuffle with a dedicated, fixed-seed generator so the order is reproducible
random.Random(4).shuffle(data_list)

Since the adjacency matrix of the graph occupies a lot of memory, we set the batch size to 4 and accumulate gradients over 8 steps before each optimizer update, which simulates an effective batch size of 32 (4 × 8).

In [None]:
# Number of graphs per mini-batch; passed as `batch_size` to the DataLoaders
batch = 4

In [None]:
# First 60% of the shuffled graphs for training, remaining 40% for validation
split_index = int(0.6 * len(data_list))
train_graphs = data_list[:split_index]
val_graphs = data_list[split_index:]

# Both loaders reshuffle on every pass and drop a trailing incomplete batch
train_loader = DataLoader(train_graphs, batch_size=batch, shuffle=True, drop_last=True)
val_loader = DataLoader(val_graphs, batch_size=batch, shuffle=True, drop_last=True)

We derive the `stop` variable from the number of training batches: it identifies the last step of an epoch, so that the loss tensor can be freed from memory after every training step except that last one.

In [None]:
# Number of graphs in the training split (first 60% of data_list)
train_size = len(data_list[:int(0.6*len(data_list))])

# train_loader is built with drop_last=True, so an epoch always has
# train_size // batch steps, whether or not train_size is a multiple of batch.
# `stop` is chosen so that `stop - 2` is the zero-based index of the last step.
# The previous `train_size / batch` branch produced a float that pointed one
# step too early when the split was an exact multiple of the batch size.
stop = train_size // batch + 1

In [10]:
save_dir = os.path.join(os.getcwd(), 'GNN_non_augm') # directory, under the current working directory, in which we save the trained models

Training for 10 runs

In [None]:
for i in range(10):

    seed_everything(i) # one fixed seed per run, to make each run reproducible

    model = GNN(args['heads'], args['num_features'], args['hidden'], args['num_classes'], args['dropout']).to(device)
    model.reset_parameters()
    model = model.float()
    losses = []

    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=0.001)
    # Cross-entropy weighted per class with the inverse normalised class occurrences
    loss_fn = torch.nn.CrossEntropyLoss(weight=inverse_norm_av.type(torch.FloatTensor).to(device))

    # The learning rate is multiplied by decayRate after every epoch
    decayRate = 0.96
    my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)

    # The trainer only stores references, so one instance serves every epoch of the run
    train_obj = Train(model, device, train_loader, optimizer, loss_fn)

    for epoch in range(1, 1 + args['epochs']):

        print('Training...')
        loss = train_obj.train_function()
        losses.append(loss)
        my_lr_scheduler.step()

        # Evaluate only once, after the last epoch of the run
        if epoch == args['epochs']:

            eval_obj = Eval(model, device)
            # Despite the names, these hold the per-class Dice scores returned by Eval
            train_acc_per_class = eval_obj.eval_function(train_loader)
            val_acc_per_class = eval_obj.eval_function(val_loader)

            # Row layout: run index, per-class train scores (%), per-class validation scores (%)
            data = [i, *(100 * train_acc_per_class), *(100 * val_acc_per_class)]

            # newline='' lets the csv module control line endings (avoids blank rows on some platforms)
            with open('Train_Val_Synthetic_GNN.csv', 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(data)

    model_name = 'GNN_200_non_augm_' + str(i) + '.h5'

    # Save model and weights
    os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, model_name)
    torch.save(model.state_dict(), model_path)

    # Draw each run on its own figure so that the saved image only contains this
    # run's curve, and close the figure afterwards to release it
    fig, ax = plt.subplots()
    ax.plot(range(len(losses)), losses)
    ax.set_title("Loss on the training set")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    fig.savefig("Loss_GNN_200_non_augm_" + str(i))
    plt.close(fig)