In [1]:
from imports import*
from utils import*
from models import*
from seed_everything import*
from Train_Eval_Test import*

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Hyperparameters of the model, gathered in one dict that the training cell reads.
args = dict(
    device=device,
    heads=2,
    num_features=1,
    hidden=100,  # 100 for simple GNN, 50 for ViG model
    num_classes=6,
    dropout=0.001,
    alpha=0.1,
    lr=0.001,
    epochs=200,
    patience=50,
)

Setup of the dataset

In [4]:
# Geometry of the sub-volume extraction:
#   side     - the side of the original volume
#   new_side - the side of the sub-volume on which we construct the graph
#   stride   - the stride we use in extracting the overlapping sub-volumes
side, new_side, stride = 512, 64, 56

Import the synthetic non-augmented dataset (train/val)

In [None]:
# Load the seven feature volumes FINAL1.raw ... FINAL7.raw, in order.
(features_1, features_2, features_3, features_4,
 features_5, features_6, features_7) = (
    raw_to_tensor("FINAL{}.raw".format(vol_idx), side) for vol_idx in range(1, 8)
)

# Load the matching label volumes LABELS1.raw ... LABELS7.raw, in order.
(labels_1, labels_2, labels_3, labels_4,
 labels_5, labels_6, labels_7) = (
    raw_to_tensor("LABELS{}.raw".format(vol_idx), side) for vol_idx in range(1, 8)
)

We compute the average occurrence of the 6 classes among the 8 volumes (train/val and test) once and we save it to a separate file

In [None]:
# Per-volume class occurrences, one result per label volume.
# NOTE(review): labels_test is not defined in any visible cell of this notebook;
# the test label volume must be loaded before this cell can run.
(occ_1, occ_2, occ_3, occ_4,
 occ_5, occ_6, occ_7, occ_test) = (
    average_labels_vol(label_volume, side)
    for label_volume in (labels_1, labels_2, labels_3, labels_4,
                         labels_5, labels_6, labels_7, labels_test)
)

In [None]:
# Mean class occurrence over ALL the volumes that are summed (7 train/val + 1 test).
# The original divided the sum of 8 vectors by 7; the constant cancels in the
# normalisation below, so the saved weights are unchanged, but `av` itself was not a mean.
occurrences = (occ_1, occ_2, occ_3, occ_4, occ_5, occ_6, occ_7, occ_test)
av = sum(np.array(occ) for occ in occurrences) / len(occurrences)

# Normalise to class frequencies, then invert so that rare classes get large weights.
# NOTE(review): a class that never occurs would get an infinite weight here.
norm_av = av / np.sum(av)
inverse_norm_av = (1 / norm_av).reshape(-1)
np.save('weights_loss.npy', inverse_norm_av)

We extract 64x64x64 overlapping sub-volumes from the original volumes with a stride of 56

In [None]:
def _to_subvolumes(volume):
    """Cut `volume` into overlapping cubes of side `new_side` (step `stride`)
    and return them as a tensor reshaped to (-1, new_side**3, 1)."""
    window_shape = (new_side, new_side, new_side)
    windows = view_as_windows(volume.numpy(), window_shape, step=stride)
    return torch.tensor(windows.reshape(-1, new_side**3, 1))


# Features of the seven volumes.
x1, x2, x3, x4, x5, x6, x7 = (
    _to_subvolumes(volume)
    for volume in (features_1, features_2, features_3, features_4,
                   features_5, features_6, features_7)
)

# Labels of the seven volumes, cut with exactly the same windows.
# NOTE(review): these keep a trailing singleton axis, unlike y_train in the
# fine-tuning cell, which is reshaped to (-1, new_side**3) — confirm this is intended.
y1, y2, y3, y4, y5, y6, y7 = (
    _to_subvolumes(volume)
    for volume in (labels_1, labels_2, labels_3, labels_4,
                   labels_5, labels_6, labels_7)
)

Import the synthetic augmented dataset (train/val), where the volumes have already size 64x64x64

In [None]:
# Index tables of the augmented dataset; the next cell reads file paths from
# their first two columns.
Train_df, Val_df = (
    pd.read_csv(csv_path) for csv_path in ('Training_Set.csv', 'Validation_Set.csv')
)

In [None]:
def _stack_from_column(frame, column):
    """Apply `fromfile` to every entry of the given column (by position) and
    stack the results into a single tensor."""
    entries = frame.iloc[:, column].to_numpy().reshape(-1)
    return torch.tensor(np.array([fromfile(entry) for entry in entries]))


# Column 0 holds the feature files, column 1 the label files.
X_Train = _stack_from_column(Train_df, 0)
Y_Train = _stack_from_column(Train_df, 1)

X_Val = _stack_from_column(Val_df, 0)
Y_Val = _stack_from_column(Val_df, 1)

# Only the features are reshaped to (-1, new_side**3, 1); the labels keep their loaded shape.
voxels_per_subvolume = new_side**3
X_Train = X_Train.reshape(-1, voxels_per_subvolume, 1)
X_Val = X_Val.reshape(-1, voxels_per_subvolume, 1)

Import the experimental sub-volume used for fine-tuning; here the stride is reduced to 28

In [5]:
# Both arrays are reshaped to a 128x128x128 cube after loading.
experimental_shape = (128, 128, 128)
subvolume = np.load('subvolume_128.npy').reshape(experimental_shape)
labels = np.load('labels_128.npy').reshape(experimental_shape)

In [6]:
# Overlapping new_side-cubes taken with a step of 28.
window_shape = (new_side, new_side, new_side)
voxels_per_subvolume = new_side**3

feature_windows = view_as_windows(subvolume, window_shape, step=28)
x_train = torch.tensor(feature_windows.reshape(-1, voxels_per_subvolume, 1))

# The labels are flattened without the trailing singleton axis.
label_windows = view_as_windows(labels, window_shape, step=28)
y_train = torch.tensor(label_windows.reshape(-1, voxels_per_subvolume))

Load the loss weights

In [7]:
# weights_loss.npy is written with np.save earlier in this notebook, so it must be
# read back with np.load: torch.load cannot parse the .npy format. The values are
# then wrapped in a tensor, which is what the loss construction further down expects.
inverse_norm_av = torch.from_numpy(np.load("weights_loss.npy"))
inverse_norm_av[0] = 0 # do not want to optimize wrt voids

In [8]:
inverse_norm_av

tensor([ 0.0000, 23.1316, 25.6455,  9.6994, 22.6797,  1.3000],
       dtype=torch.float64)

We then create a graph from each sub-volume by connecting each voxel to its 6 nearest neighbors

In [9]:
k_neigh = 6

# Integer coordinates of every voxel in a new_side^3 cube (one row per voxel).
axis_coords = torch.arange(new_side)
cloud = torch.cartesian_prod(axis_coords, axis_coords, axis_coords)

# The edge list is computed once and shared by every sub-volume graph.
edges = create_edges(k_neigh, cloud)

# Same connectivity as a sparse adjacency over the new_side^3 voxels.
n_voxels = new_side**3
adj = SparseTensor(row=edges[0], col=edges[1], sparse_sizes=(n_voxels, n_voxels))

We create the dataset for training and validation after shuffling the subvolumes

In [None]:
# non-augmented dataset for simple GNN

data_list = []

def data_list_creation(X, Y, edges):
    """Append one `Data` graph per sub-volume of X (paired with the same index
    of Y) to the module-level `data_list`, all sharing the same edge index."""
    data_list.extend(
        Data(x=X[idx], edge_index=edges, y=Y[idx]) for idx in range(X.size(0))
    )

# Add the seven volumes in order.
for volume_x, volume_y in ((x1, y1), (x2, y2), (x3, y3), (x4, y4),
                           (x5, y5), (x6, y6), (x7, y7)):
    data_list_creation(volume_x, volume_y, edges)

# Fixed seed so the shuffled order is reproducible.
random.Random(4).shuffle(data_list)

In [None]:
# augmented dataset for ViG

# Training sub-volumes first, then validation sub-volumes; the split into
# loaders happens later, after the seeded shuffle below.
train_graphs = [
    Data(x=X_Train[idx], y=Y_Train[idx], adj_t=adj) for idx in range(X_Train.size(0))
]
val_graphs = [
    Data(x=X_Val[idx], y=Y_Val[idx], adj_t=adj) for idx in range(X_Val.size(0))
]
data_list = train_graphs + val_graphs

# Fixed seed so the shuffled order is reproducible.
random.Random(4).shuffle(data_list)

In [10]:
# experimental subvolume for fine-tuning

# One graph per experimental sub-volume, all sharing the same sparse adjacency.
data_list = [
    Data(x=x_train[idx], adj_t=adj, y=y_train[idx]) for idx in range(x_train.size(0))
]

# Fixed seed so the shuffled order is reproducible.
random.Random(4).shuffle(data_list)

Since the adjacency matrix of the graph occupies a lot of memory, we set the batch size to 4 and backpropagate every 16 steps to simulate a batch size of 64

In [11]:
# Number of graphs per mini-batch; also passed to the DataLoaders and to the ViG constructor.
batch = 4

In [None]:
# First 70% of the (already shuffled) graphs for training, the rest for validation.
split_idx = int(0.7 * len(data_list))
train_graphs = data_list[:split_idx]
val_graphs = data_list[split_idx:]

# Both loaders shuffle and drop the last incomplete batch.
train_loader = DataLoader(train_graphs, batch_size=batch, shuffle=True, drop_last=True)
val_loader = DataLoader(val_graphs, batch_size=batch, shuffle=True, drop_last=True)

We use the stop variable to eliminate the loss vector from memory while training the model until the last epoch

In [13]:
# Size of the 70% training split, floor-divided by the batch size, plus one.
n_train_graphs = int(0.7 * len(data_list))
stop = n_train_graphs // batch + 1

Training for 10 runs

In [None]:
# 'ViG' folder under the current working directory; the training cell creates it if missing.
save_dir = os.path.join(os.getcwd(), 'ViG') # directory in which we save the trained models

In [None]:
def _save_curves(train_values, val_values, ylabel, title, filename):
    """Plot the training and validation curves of one run on a fresh figure,
    save it to `filename` and close it.

    A new figure per call keeps curves from different metrics and different
    runs from piling up on the same implicit pyplot axes.
    """
    fig, ax = plt.subplots()
    ax.plot(range(len(train_values)), train_values, label="Training")
    ax.plot(range(len(val_values)), val_values, label="Validation")
    ax.set_title(title)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(filename)
    plt.close(fig)  # release the figure so memory does not grow over the runs


# Classes reported in the logs: every class except class 0 (its loss weight is set to 0).
reported_classes = range(1, args['num_classes'])

for i in range(10):

    seed_everything(i) # we fix it to make it reproducible

    model = ViG(batch, new_side, args['num_features'], args['hidden'], args['num_classes'], args['dropout']).to(device)
    model.reset_parameters()

    # when fine-tuning we load the model and do not reset the parameters
    # map_location lets a checkpoint saved on the GPU be restored on a CPU-only machine
    model.load_state_dict(torch.load('ViG_' + str(i) + '.h5', map_location=device))
    model = model.float()

    losses_train = []
    losses_val = []
    accuracies_train = []
    accuracies_val = []

    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=0.001)
    loss_fn = torch.nn.CrossEntropyLoss(weight=inverse_norm_av.float().to(device))

    decayRate = 0.96
    my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)

    # Initialize variables for early stopping
    best_loss = float('inf')
    best_model_weights = None
    patience = args['patience']
    last_row = None  # [epoch, train metrics..., val metrics...] of the most recent epoch

    for epoch in range(1, 1 + args['epochs']):

        print('Training...')
        train_obj = Train(model, device, train_loader, optimizer, loss_fn, stop)
        loss = train_obj.train_function()
        losses_train.append(loss)
        my_lr_scheduler.step()

        # The hyperparameter dict has no 'classes' key (that lookup raised KeyError);
        # the class count is stored under 'num_classes'.
        # NOTE(review): assumes Eval's fourth argument is the number of classes — confirm.
        eval_obj = Eval(model, device, loss_fn, args['num_classes'])
        _, train_acc_per_class = eval_obj.eval_function(train_loader, Dice=True)
        val_loss, val_acc_per_class = eval_obj.eval_function(val_loader, Dice=True)
        losses_val.append(val_loss)

        acc_train = np.mean(np.array(train_acc_per_class), axis=0)
        accuracies_train.append(acc_train)

        acc_val = np.mean(np.array(val_acc_per_class), axis=0)
        accuracies_val.append(acc_val)

        last_row = ([epoch]
                    + [100 * train_acc_per_class[c] for c in reported_classes]
                    + [100 * val_acc_per_class[c] for c in reported_classes])
        print(last_row)

        # Early stopping
        if val_loss < best_loss:
            best_loss = val_loss
            best_model_weights = copy.deepcopy(model.state_dict()) # Deep copy here
            patience = args['patience']  # Reset patience counter
        else:
            patience -= 1
            if patience <= 0:
                break

    # Log the metrics of the last epoch that ran, whether the run ended by early
    # stopping or by exhausting all epochs (the latter used to write nothing).
    # The model is unchanged since these metrics were computed, so they are reused
    # instead of running two more evaluation passes.
    # NOTE(review): this row describes the last epoch, not the best weights restored below.
    if last_row is not None:
        # newline='' is what the csv module requires to avoid blank rows on Windows
        with open('Train_Val_Synthetic_ViG.csv', 'a', newline='') as f:
            csv.writer(f).writerow(last_row)

    # Load the best model weights (absent only if the validation loss never improved,
    # e.g. it was NaN from the first epoch; then the current weights are kept)
    if best_model_weights is not None:
        model.load_state_dict(best_model_weights)

    model_name = 'ViG_' + str(i) + '.h5'

    # Save model and weights
    os.makedirs(save_dir, exist_ok=True)
    model_path = os.path.join(save_dir, model_name)
    torch.save(model.state_dict(), model_path)

    # One figure per metric with both curves; previously the second savefig of each
    # pair overwrote the first and every curve landed on the same axes.
    _save_curves(accuracies_train, accuracies_val, "Accuracy",
                 "Accuracy on the training and validation sets", "Accuracy_ViG_" + str(i))
    _save_curves(losses_train, losses_val, "Loss",
                 "Loss on the training and validation sets", "Loss_ViG_" + str(i))