In [None]:
import os
from random import randrange
import torch
import torch.nn as nn
import torchvision as tv
import torchio as tio
import _01_dataloader as prl_dl
import _02_autoencoder as prl_ae
import _04_predictor as prl_pred

import importlib
importlib.reload(prl_dl)
importlib.reload(prl_ae)
importlib.reload(prl_pred)

<module '_04_predictor' from '/home/fengling/Documents/prl/prl_pytorch/_04_predictor.py'>

In [None]:
def isolate_lesion(patches_batch, n_channels=4):
    """Build a per-sample mask keeping the background plus one randomly chosen lesion.

    For every sample in the batch:
    - All voxels labelled 0 in ``lesion_mask`` are kept.
    - The voxels of one lesion id, drawn uniformly at random from the non-zero ids
      present in that sample, are also kept.
    - Voxels belonging to every other lesion are zeroed.
    - A sample with no non-zero id gets an all-ones mask.

    Args:
        patches_batch: batch dict whose ``["lesion_mask"]["data"]`` tensor is indexed as
            ``[batch, channel, x, y, z]`` (presumably a single label channel — TODO confirm).
        n_channels: how many times the mask is repeated along dim 1. The default of 4
            matches the four modalities concatenated in the training loop.

    Returns:
        Tensor of shape ``[batch, n_channels * C, x, y, z]`` holding 1 where voxels are kept
        and 0 where they are blacked out.
    """
    lesion_mask_tensor = patches_batch["lesion_mask"]["data"]
    mask_tensor = torch.zeros_like(lesion_mask_tensor)
    for i in range(lesion_mask_tensor.size(0)):
        sample_mask = lesion_mask_tensor[i]
        lesion_ids = sample_mask.unique()
        # 0 is background and never a candidate. Filter it explicitly rather than
        # assuming it is the first element of the sorted unique ids.
        nonzero_ids = lesion_ids[lesion_ids != 0]
        if len(nonzero_ids) > 0:
            id_to_keep = nonzero_ids[randrange(len(nonzero_ids))]
            # Keep the background and the chosen lesion. This previously compared
            # against the literal 10 instead of id_to_keep.
            mask_tensor[i] = (sample_mask == 0) | (sample_mask == id_to_keep)
        else:
            mask_tensor[i] = 1
    return mask_tensor.repeat(1, n_channels, 1, 1, 1)

In [None]:
def get_lesion_type(patches_batch, isolation_mask_tensor):
    """Derive per-sample BCE targets and loss weights for the isolated lesion.

    The lesion-type volume (channel 0) is multiplied by the isolation mask (channel 0).
    The largest remaining code per sample is then decoded with ``process_lesion_type``.

    Args:
        patches_batch: batch dict with ``["lesion_type"]["data"]`` indexed as
            ``[batch, channel, x, y, z]``.
        isolation_mask_tensor: mask from ``isolate_lesion``, originally a
            ``[batch, 4, 24, 24, 24]`` tensor. Only channel 0 is used.

    Returns:
        ``[target_tensor, weight_tensor]``: two float tensors of shape ``[batch, 3]``.
    """
    lesion_type_tensor = patches_batch["lesion_type"]["data"][:, 0, :, :, :]
    isolation_mask = isolation_mask_tensor[:, 0, :, :, :]
    isolated_lesion_type = lesion_type_tensor * isolation_mask

    batch_size = isolated_lesion_type.size(0)
    target_tensor = torch.zeros(batch_size, 3)
    weight_tensor = torch.zeros(batch_size, 3)

    for i in range(batch_size):
        # The maximum equals the last element of the sorted unique values.
        lesion_code = int(isolated_lesion_type[i].max().item())
        # Decode once per sample (this was previously decoded twice).
        target, weight = process_lesion_type(lesion_code)
        target_tensor[i] = target
        weight_tensor[i] = weight

    return [target_tensor, weight_tensor]

In [None]:
def process_lesion_type(lesion_id): # Return tensor of [is_lesion, is_PRL, is_CVS]
    """Decode a lesion-type code into a BCE target vector and per-column loss weights.

    A non-zero code is read digit by digit in base 10. The first digit is always 1,
    for computational convenience.

    Second digit:
        - 1: non-PRL lesion
        - 2: PRL lesion

    Third digit:
        - 1: non-CVS
        - 2: possible CVS
        - 3: CVS
        - 9: no CVS data available for this subject

    Args:
        lesion_id: integer lesion-type code. 0 means "no lesion".

    Returns:
        ``[target, weight]``, two float tensors of length 3 ordered as
        ``[is_lesion, is_PRL, is_CVS]``. When the CVS label is possible or missing,
        the CVS weight is 0 so errors in that column are not counted against the network.

    Raises:
        ValueError: if a non-zero code has fewer than 3 digits.
    """
    target = torch.zeros(3)
    weight = torch.ones(3)
    if lesion_id == 0:
        return [target, weight]  # no lesion: all-zero target, full weight

    digits = [int(x) for x in str(lesion_id)]
    if len(digits) < 3:
        raise ValueError(f"lesion type code {lesion_id!r} must have at least 3 digits")
    prl_digit, cvs_digit = digits[1], digits[2]

    # Any recognised PRL or CVS annotation marks the sample as a lesion.
    if prl_digit in (1, 2) or cvs_digit in (1, 2, 3):
        target[0] = 1
    if prl_digit == 2:
        target[1] = 1
    if cvs_digit == 3:
        target[2] = 1
    # TODO: try more sophisticated handling of "possible CVS" (digit 2).
    if cvs_digit in (2, 9):
        # Same float dtype as the default weight (this branch used to return int64).
        weight = torch.tensor([1.0, 1.0, 0.0])

    return [target, weight]

In [None]:
# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model_path = "prl_autoencoder_0703.pt"

# Build the autoencoder on the chosen device, then restore the pretrained weights.
# The checkpoint is read onto the CPU first; load_state_dict copies it into the model's parameters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
prl_autoencoder = prl_ae.Autoencoder3D().to(device)
prl_autoencoder.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

# Switch to evaluation mode; the returned module is the cell's displayed output.
prl_autoencoder.eval()

cpu


Autoencoder3D(
  (encoder): Sequential(
    (0): Conv3d(4, 16, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv3d(16, 32, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
    (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (7): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU()
    (10): Conv3d(64, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (11): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU()
    (14): Conv3d(256, 1024, ker

In [None]:
# Prediction head that consumes the encoder's latent code, placed on the same device.
prl_predictor = prl_pred.Predictor3D().to(device)

In [None]:
# One Adam optimizer updates the autoencoder's encoder and the whole predictor together.
# The decoder's parameters are not included.
joint_optimizer = torch.optim.Adam(
    [*prl_autoencoder.encoder.parameters(), *prl_predictor.parameters()],
    lr=0.001,
)

In [None]:
num_epochs = 1
keys = ["t1", "flair", "epi", "phase"]
# NOTE(review): prl_autoencoder was put in eval() mode when it was loaded, so its BatchNorm
# layers use running statistics while the encoder is fine-tuned here. Call
# prl_autoencoder.train() before this loop if they should adapt — confirm which is intended.
# NOTE(review): binary_cross_entropy expects probabilities in [0, 1]; this assumes
# Predictor3D ends in a sigmoid — TODO confirm.
for epoch_index in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for patches_batch in prl_dl.patches_loader:
        joint_optimizer.zero_grad()
        isolation_mask_tensor = isolate_lesion(patches_batch)
        target_tensor, weight_tensor = get_lesion_type(patches_batch, isolation_mask_tensor)
        # Targets and weights are built on the CPU. Move them next to the model output,
        # otherwise the loss raises a device mismatch when training on CUDA.
        target_tensor = target_tensor.to(device)
        weight_tensor = weight_tensor.to(device)

        input_tensor = torch.cat([patches_batch[key]["data"] for key in keys], dim=1) * isolation_mask_tensor
        input_tensor = input_tensor.to(device)
        encoded_tensor = prl_autoencoder.get_latent(input_tensor)
        output_tensor = prl_predictor(encoded_tensor)
        # Zero weights drop columns without a usable label (e.g. CVS possible/unknown).
        # For 0/1 weights this equals the previous BCE(output * w, target * w).
        loss = nn.functional.binary_cross_entropy(output_tensor, target_tensor, weight=weight_tensor)
        loss.backward()
        joint_optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch, not the last batch's loss tensor.
    mean_loss = epoch_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch_index}: mean loss = {mean_loss:.4f}")

In [None]:
# Persist the fine-tuned weights. The predictor is trained jointly with the encoder,
# so it must be saved too, otherwise its weights are lost when the kernel stops.
# NOTE(review): the "_0625" suffix predates the "_0703" checkpoint loaded earlier in this
# notebook; confirm the intended name before overwriting an older checkpoint.
torch.save(prl_autoencoder.state_dict(), "prl_autoencoder_0625.pt")
torch.save(prl_predictor.state_dict(), "prl_predictor_0625.pt")