In [None]:
import os
import torch

from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot(image, n_rows, n_cols, i_plot, binary = False, show = False, title = '', no_other_class_indexes = (1, 2, 3, 4, 5)):
    """Draw ``image`` into subplot ``i_plot`` of an ``n_rows`` x ``n_cols`` pyplot grid.

    Parameters
    ----------
    image : numpy array or CPU torch tensor
        Either a 2-D map of class indexes, or a channel-first stack (C, H, W)
        whose first three channels are displayed as an RGB image.
    n_rows, n_cols, i_plot : int
        Grid layout and 1-based position, forwarded to ``plt.subplot``.
    binary : bool
        For 2-D input, draw it as a grayscale image instead of colour-coding classes.
    show : bool
        Call ``plt.show()`` after drawing.
    title : str
        Subplot title.
    no_other_class_indexes : sequence of int
        Index values of the classes Calcita, Dolomita, Mg-Argilominerais, Poros and
        Quartzo, in that order (matched positionally with the colour table below).
        Pixels matching none of them are drawn white.
    """
    plt.subplot(n_rows, n_cols, i_plot)
    plt.xticks([])
    plt.yticks([])
    plt.title(title)

    # np.asarray accepts both numpy arrays and CPU torch tensors.
    array = np.asarray(image)

    if array.ndim < 3:
        if binary:
            plt.imshow(array, cmap = 'gray')
        else:
            # RGB colour of each class, in the same order as no_other_class_indexes.
            class_colors = (
                (0x00, 0xff, 0xff),  # Calcita
                (0x00, 0x70, 0xc0),  # Dolomita
                (0xda, 0xa5, 0x20),  # Mg-Argilominerais
                (0x63, 0x63, 0x63),  # Poros
                (0xff, 0xff, 0x00),  # Quartzo
            )
            # Start fully white so any unlisted class ("Outros") stays white.
            prediction = np.full(array.shape + (3,), 0xff, dtype = float)
            for class_index, color in zip(no_other_class_indexes, class_colors):
                prediction[array == class_index] = color
            plt.imshow(prediction / 255.0)
    else:
        # Channel-first (C, H, W) -> channel-last (H, W, C); show the first three channels.
        plt.imshow(np.moveaxis(array, 0, -1)[:, :, :3])

    if show:
        plt.show()

In [None]:
# Checkpoint selection: 0 -> cluster 0, 1 -> cluster 1, None -> model trained on the full dataset.
cluster_model = None

In [None]:
model_dir = os.path.join(os.sep, 'petrobr', 'parceirosbr', 'smartseg', 'thinsection', 'qemscan', 'models')

# Checkpoint file name(s) inside model_dir for the selected cluster_model.
# Only the full-dataset option provides both an 'old' and a 'new' checkpoint; the
# comparison cells index saved_model by 'old'/'new', so they need cluster_model = None.
if cluster_model is None:
    saved_model = {
        'old': 'smartseg_thinsection_completo.pth',
        'new': 'aa30a928c34e4fbd99f46c7b.pth'
    }
elif cluster_model == 0:
    saved_model = 'smartseg_thinsection_texturafina.pth'
elif cluster_model == 1:
    saved_model = 'smartseg_thinsection_texturadensa.pth'
else:
    # Fail fast instead of leaving saved_model undefined (or stale from an earlier run).
    raise ValueError(f'cluster_model must be 0, 1 or None, got {cluster_model!r}')

# UNet input/output channel counts; they must match the saved checkpoints.
in_channels = 6
out_channels = 7

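Aqui utilizamos os recursos do MONAI para carregar, em ordem aleatória, as imagens de teste (e seus respectivos rótulos) usadas como exemplo. Este bloco pode ser substituído por qualquer rotina que gere imagens de entrada do tipo Tensor PyTorch na forma (Canais, Altura, Largura); o DataLoader acrescenta a dimensão de lote, resultando em (Lote, Canais, Altura, Largura).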

In [None]:
import glob
from monai.utils import first
from monai.data import DataLoader, Dataset
from monai.transforms import Compose, LoadImaged, AsChannelFirstd, ScaleIntensityRanged, EnsureTyped

# Preprocessing applied to every sample: read the image and label files, move the image's
# channel axis (the last axis, by default) to the front, rescale intensities from [0, 255]
# to [0, 1] with clipping, and convert everything to tensors.
transforms = Compose(
    [
        LoadImaged(keys = ['image', 'label']),
        AsChannelFirstd(keys = ['image']),
        ScaleIntensityRanged(keys = ['image'], a_min = 0, a_max = 255, b_min = 0.0, b_max = 1.0, clip = True),
        EnsureTyped(keys = ['image', 'label'])
    ]
)

# Held-out samples, given as '<dir>/<subdir>' fragments matched against each image path.
test_images = ['FL/5400.25', 'SC/6398.50t', 'SC/6390.00t', 'SC/6355.00t', 'SC/6346.60t', 'SC/6340.50t', 'SA/6292.50', 'AR/6381.90',
              'AR/6380.45', 'AR/6378.05', 'AR/6376.65', 'AR/6374.95', 'M/5219.38', 'M/5168.05', 'M/5236.95', 'FL/5400.55',
              'LB/5477.70', 'YB/4822.00', 'SL/5174.50', 'AS/5675.70', 'AS/5672.00']
# Note: these are glob patterns, not regular expressions. Label files live in a
# 'labels' directory alongside each 'data' directory.
data_paths_regex  = os.path.join(os.sep, 'petrobr', 'parceirosbr', 'smartseg', 'datasets', 'qemscan', 'generated',
                                 '*', '*', '10000x10000' + '_nii.gz', 'data', '*.nii.gz')
label_paths_regex = data_paths_regex.replace('data' + os.sep, 'labels' + os.sep)
# glob returns files in arbitrary order: sort both lists so zip() pairs each image with
# its own label (assumes image and label files have corresponding names).
data_paths  = sorted(glob.glob(data_paths_regex))
label_paths = sorted(glob.glob(label_paths_regex))
assert len(data_paths) == len(label_paths), \
    f'{len(data_paths)} image files but {len(label_paths)} label files: cannot pair them'
data_paths = [{'image': dpath, 'label': lpath} for dpath, lpath in zip(data_paths, label_paths)
              if any(image in dpath for image in test_images)]
dataset = Dataset(data = data_paths, transform = transforms)
# shuffle=True: each pass over the loader yields the test samples in a random order.
data_loader = DataLoader(dataset = dataset, shuffle = True)
#input = first(data_loader)

In [None]:
# Both checkpoints share this architecture, so the network is built once and only its weights are swapped.
model = UNet(spatial_dims = 2, in_channels = in_channels, out_channels = out_channels, channels = (16, 32, 64, 128, 256),
             strides = (2, 2, 2, 2), num_res_units = 2, norm = Norm.BATCH)
_ = model.eval()  # load_state_dict below does not change the train/eval mode

# The old/new comparison needs the dict form of saved_model (cluster_model = None).
assert isinstance(saved_model, dict), 'saved_model must map "old"/"new" to checkpoints (set cluster_model = None)'

# Read each checkpoint once, onto the CPU where the model lives, instead of once per sample.
state_dicts = {
    age: torch.load(os.path.join(model_dir, saved_model[age]), map_location = 'cpu')
    for age in ['old', 'new']
}

output = {}

# Compare both models on the first sample yielded by data_loader, plot it, then stop.
for batch in data_loader:
    with torch.no_grad():  # inference only: do not build the autograd graph
        for age, state_dict in state_dicts.items():
            model.load_state_dict(state_dict)
            # Whole-image forward pass. For very large images consider instead:
            # sliding_window_inference(inputs = batch['image'], roi_size = (512, 512), sw_batch_size = 4, predictor = model)
            logits = model(batch['image'])
            # Pick the best channel among 1.. (channel 0 is excluded), shift by +1 so the result
            # refers to the original channel numbers, and keep the first item of the batch.
            output[age] = torch.argmax(logits[:, 1:], dim = 1)[0].float() + 1

    # Title from the two directory levels matched by the '*' wildcards in the data path.
    info = batch['image_meta_dict']['filename_or_obj'][0].split(os.sep)
    im_name = info[7] + '_' + info[8]
    plt.suptitle(im_name)
    plot(batch['image'][0], 2, 2, 1, title = 'Image')
    plot(batch['label'][0], 2, 2, 2, title = 'Label', no_other_class_indexes = [4, 6, 25, 14, 15])
    plot(output['old'], 2, 2, 3, title = 'Old prediction')
    plot(output['new'], 2, 2, 4, title = 'New prediction', show = True)
    break  # only the first sample is shown

In [None]:
# Same comparison as the previous cell; data_loader shuffles, so re-iterating it
# presumably yields a different sample.
model = UNet(spatial_dims = 2, in_channels = in_channels, out_channels = out_channels, channels = (16, 32, 64, 128, 256),
             strides = (2, 2, 2, 2), num_res_units = 2, norm = Norm.BATCH)
_ = model.eval()  # load_state_dict below does not change the train/eval mode

# The old/new comparison needs the dict form of saved_model (cluster_model = None).
assert isinstance(saved_model, dict), 'saved_model must map "old"/"new" to checkpoints (set cluster_model = None)'

# Read each checkpoint once, onto the CPU where the model lives, instead of once per sample.
state_dicts = {
    age: torch.load(os.path.join(model_dir, saved_model[age]), map_location = 'cpu')
    for age in ['old', 'new']
}

output = {}

# Compare both models on the first sample yielded by data_loader, plot it, then stop.
for batch in data_loader:
    with torch.no_grad():  # inference only: do not build the autograd graph
        for age, state_dict in state_dicts.items():
            model.load_state_dict(state_dict)
            # Whole-image forward pass. For very large images consider instead:
            # sliding_window_inference(inputs = batch['image'], roi_size = (512, 512), sw_batch_size = 4, predictor = model)
            logits = model(batch['image'])
            # Pick the best channel among 1.. (channel 0 is excluded), shift by +1 so the result
            # refers to the original channel numbers, and keep the first item of the batch.
            output[age] = torch.argmax(logits[:, 1:], dim = 1)[0].float() + 1

    # Title from the two directory levels matched by the '*' wildcards in the data path.
    info = batch['image_meta_dict']['filename_or_obj'][0].split(os.sep)
    im_name = info[7] + '_' + info[8]
    plt.suptitle(im_name)
    plot(batch['image'][0], 2, 2, 1, title = 'Image')
    plot(batch['label'][0], 2, 2, 2, title = 'Label', no_other_class_indexes = [4, 6, 25, 14, 15])
    plot(output['old'], 2, 2, 3, title = 'Old prediction')
    plot(output['new'], 2, 2, 4, title = 'New prediction', show = True)
    break  # only the first sample is shown