In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

In [1]:
# Root directories of the pre-extracted feature arrays (edit for your setup).
# The dataset class below reads:
#   path features: <path_folder>/<split>_<fold>.npy
#   rad features:  <rad_folder>/<modality>/<fold>/<split>_<fold>_<modality>.npy
path_folder, rad_folder = '/path/to/path/folder', '/path/to/rad/folder/'

In [None]:
# Modality: one of ['t1', 't2', 't1ce', 'flair']
modality = 'flair'

# Fold index; appears in the feature file names (e.g. train_<fold>.npy)
fold = 1

# Seed every RNG used later (weight init, dropout, DataLoader shuffling) so
# that Restart & Run All reproduces the same curves and checkpoint.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
class GuidanceModule_dataset(Dataset):
    """Index-aligned pairs of (path, rad) feature vectors for one split of one fold.

    Two pre-extracted ``.npy`` arrays are loaded; row ``i`` of one is paired
    with row ``i`` of the other:

    * path: ``<path_features>/<datasplit>_<fold>.npy``
    * rad:  ``<rad_features>/<modality>/<fold>/<datasplit>_<fold>_<modality>.npy``

    Args:
        path_features: directory containing the path feature files.
        rad_features: root directory of the rad feature files.
        modality: modality sub-directory and file suffix (e.g. ``'flair'``).
        fold: fold index used in the directory and file names.
        datasplit: split name used in the file names (``'train'``, ``'val'``, ``'test'``).
        dtype: dtype both arrays are cast to. Defaults to float32 so that the
            tensors collated by ``DataLoader`` match ``nn.Linear``'s default
            float32 parameters (a float64 ``.npy`` would otherwise fail at the
            first matrix multiply).

    Raises:
        ValueError: if the two arrays do not have the same number of rows.
    """

    def __init__(self, path_features, rad_features, modality, fold, datasplit, dtype=np.float32):
        path_file = os.path.join(path_features, '{}_{}.npy'.format(datasplit, fold))
        rad_file = os.path.join(rad_features, modality, str(fold),
                                '{}_{}_{}.npy'.format(datasplit, fold, modality))
        self.path_array = np.asarray(np.load(path_file), dtype=dtype)
        self.rad_array = np.asarray(np.load(rad_file), dtype=dtype)
        # Samples are paired purely by row index, so the counts must agree.
        # Checked once here (not in __len__) and not via assert, which -O strips.
        if len(self.path_array) != len(self.rad_array):
            raise ValueError(
                'path/rad feature count mismatch: {} rows in {} vs {} rows in {}'.format(
                    len(self.path_array), path_file, len(self.rad_array), rad_file))

    def __len__(self):
        """Number of paired samples in this split."""
        return len(self.path_array)

    def __getitem__(self, idx):
        """Return ``{'path': path_row, 'rad': rad_row}`` for sample ``idx``."""
        return {'path': self.path_array[idx], 'rad': self.rad_array[idx]}

In [None]:
# One dataset per split for the selected modality and fold.
split_kwargs = dict(path_features=path_folder, rad_features=rad_folder, modality=modality, fold=fold)
train_set = GuidanceModule_dataset(datasplit='train', **split_kwargs)
val_set = GuidanceModule_dataset(datasplit='val', **split_kwargs)
test_set = GuidanceModule_dataset(datasplit='test', **split_kwargs)

In [None]:
# Training batches are reshuffled every epoch; the evaluation splits keep a fixed order.
train_loader = DataLoader(train_set, shuffle=True, batch_size=50)
eval_loader_kwargs = dict(batch_size=30, shuffle=False)
val_loader = DataLoader(val_set, **eval_loader_kwargs)
test_loader = DataLoader(test_set, **eval_loader_kwargs)

In [None]:
class GuidanceModule_architecture(nn.Module):
    """Fully connected network mapping a rad feature vector to a path feature vector.

    Layer widths: ``in_features -> 512 -> 256 -> 256 -> out_features``. Every
    layer is followed by ReLU, and all but the last by dropout. The layer
    attribute names (``e1``, ``e2``, ``d1``, ``d2``, ``dropout``) are kept
    stable so state_dicts saved by earlier runs still load.

    Args:
        in_features: size of the input (rad) feature vector. Default 1024.
        out_features: size of the output (path) feature vector. Default 512.
        p_dropout: dropout probability applied after the first three layers.
            Default 0.25.
    """

    def __init__(self, in_features=1024, out_features=512, p_dropout=0.25):
        super().__init__()

        self.e1 = nn.Linear(in_features=in_features, out_features=512)
        self.e2 = nn.Linear(in_features=512, out_features=256)

        self.d1 = nn.Linear(in_features=256, out_features=256)
        self.d2 = nn.Linear(in_features=256, out_features=out_features)

        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x):
        """Map ``x`` (last dimension ``in_features``) to a tensor whose last dimension is ``out_features``."""
        x = self.dropout(F.relu(self.e1(x)))
        x = self.dropout(F.relu(self.e2(x)))
        x = self.dropout(F.relu(self.d1(x)))
        # NOTE(review): the final ReLU restricts predictions to >= 0; this
        # presumably matches non-negative target path features — confirm.
        return F.relu(self.d2(x))
    
# Build the model and move its parameters to the selected device
# (nn.Module.to returns the module itself); the bare name displays the architecture.
model = GuidanceModule_architecture().to(device)
model

In [None]:
# Loss and optimizer.
# Mean-squared error between predicted and true path features
# (nn.MSELoss defaults to reduction='mean').
criterion = nn.MSELoss()

# NOTE(review): lr=0.5 is unusually high for Nesterov SGD on an MSE target — confirm it is intentional.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.5,
    momentum=0.9,
    nesterov=True,
    weight_decay=1e-6,
)

In [None]:
# Train loop: regress path features from rad features.
# Epoch losses are sample-weighted means (not means of batch means), so a
# smaller final batch is not over-weighted; both curves are kept for plotting.

train_loss = []
val_loss = []
epochs = 150

for epoch in range(epochs):
    # Training pass (dropout active).
    model.train()
    running_train_loss = 0.0
    for data in train_loader:
        p, r = data['path'].to(device), data['rad'].to(device)
        optimizer.zero_grad()
        pred_p = model(r)
        loss_train = criterion(pred_p, p)
        loss_train.backward()
        optimizer.step()
        # criterion averages over the batch, so rescale by batch size before summing.
        running_train_loss += loss_train.item() * r.size(0)

    loss_train_avg = running_train_loss / len(train_loader.dataset)
    train_loss.append(loss_train_avg)

    # Validation pass (dropout off, no gradient tracking).
    model.eval()
    running_val_loss = 0.0
    with torch.no_grad():
        for data in val_loader:
            p, r = data['path'].to(device), data['rad'].to(device)
            loss_val = criterion(model(r), p)
            running_val_loss += loss_val.item() * r.size(0)

    loss_val_avg = running_val_loss / len(val_loader.dataset)
    val_loss.append(loss_val_avg)

    print('Epoch {} of {}, Train Loss: {:.3f}, Val Loss: {:.3f}'.format(
        epoch+1, epochs, loss_train_avg, loss_val_avg))

In [None]:
# Test-set reconstruction loss, as a sample-weighted mean so that a smaller
# final batch is not over-weighted.

test_loss = []
running_test_loss = 0.0
model.eval()
with torch.no_grad():
    for data in test_loader:
        p, r = data['path'].to(device), data['rad'].to(device)
        loss_test = criterion(model(r), p)
        # criterion averages over the batch, so rescale by batch size before summing.
        running_test_loss += loss_test.item() * r.size(0)

loss_test_avg = running_test_loss / len(test_loader.dataset)
# Kept as a one-element list because the plotting cell draws it as a single point.
test_loss.append(loss_test_avg)
loss_test_avg

In [None]:
# Learning curves: per-epoch train/val loss against epoch number (1..N),
# with the single test loss marked at the final epoch.
epoch_ids = range(1, len(train_loss) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_ids, train_loss, label='train_loss')
ax.plot(epoch_ids, val_loss, label='val_loss')
ax.plot([epochs], test_loss, 'go', ms=5, label='test_loss')
ax.axvline(x=epochs, color='g', linewidth=1)
ax.set(xlabel='epoch', ylabel='MSE loss', title='{} (fold {})'.format(modality, fold))
ax.legend()
ax.grid(linewidth=.5, linestyle='-')

# Saving the figure; create <fig_path>/<modality>/ first if it does not exist.
fig_path = '/path/to/results/'
fig_file = os.path.join(fig_path, modality, modality+'_'+str(fold)+'.png')
os.makedirs(os.path.dirname(fig_file), exist_ok=True)
fig.savefig(fig_file)

In [None]:
# Saving the model weights (state_dict only) to <save_dir>/<modality>/<modality>_<fold>.pt,
# creating the modality sub-directory first so torch.save does not fail on a fresh results folder.

save_dir = '/path/to/results/'
model_file = os.path.join(save_dir, modality, modality+'_'+str(fold)+'.pt')
os.makedirs(os.path.dirname(model_file), exist_ok=True)
torch.save(model.state_dict(), model_file)