**Nota**: A imagem D_6212 foi deixada para trás durante o treinamento do modelo principal, embora conste nos arquivos de dataset (trainval.csv - treino e validação; trainfull_only.csv - treino do modelo principal), pois seus chunks não haviam sido gerados. Como os dados do segundo lote praticamente não possuem argilominerais, a imagem D_6212 (que possui) foi incluída para validação na comparação entre os seguintes modelos:  
* *Completo*: primeiro lote validado nas imagens de validação do segundo lote (as testadas e tabeladas pelo Matheus);
* *Novo*: primeiro lote + imagens de treino do segundo lote (todas menos as de validação);
* *Fine-tuning*: modelo *Completo* com treinamento extra com as imagens de treino do segundo lote.

In [None]:
# cost-sensitive
### Resumo:
### * (over/under)sampling - está na lista para teste (em termos de balanceamento óptico apenas)
### * pesos de classe: já utilizo, porém testar a configuração
###### * melhor é o inverso da distribuição (já utilizo)
###### * também podem ser escolhidos arbitrariamente ou buscados (grid search)
###### * scikit-learn usa a configuração n_samples/(n_classes*n_samples_in_class), diferente da que uso (1/f = n_samples/n_samples_in_class)
### * Ensembling - entender

# [OK] WCE: corte 256 em vez de 512: algumas diferenças nos resultados, mas numericamente tão ruins quanto
# WCE: modelo antigo 1000:8x256 batch 2 (e, se der tempo, 1000:16x256 batch 1)
# WCE: modelo 256
# balanceamento pelo "Problemas_imagem" da tabela
# balanceamento de propriedades ópticas por oversampling em vez de probabilístico
# aumentar as imagens simulando tingimentos
# separar subclasses de minerais (por exemplo, calcita e calcita tingida)
# tentar os insights do penúltimo slide de baixo para cima (exceto frequência mediana, já comprovada que é similar à frequência inversa)

# Exemplo de segmentação de poros com MONAI + Aim/MLFlow

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/3d_segmentation/spleen_segmentation_3d.ipynb)

## Importando módulos

In [None]:
from monai.utils import first, set_determinism
from monai.transforms import (
    Transform,
    RandomizableTransform,
    Compose,
    EnsureType,
    AsDiscrete,
    Identityd,
    LoadImaged,
    AsChannelFirstd,
    AddChanneld,
    MapLabelValued,
    ScaleIntensityRanged,
    GaussianSmoothd,
    CenterSpatialCropd,
    RandSpatialCropSamplesd,
    RandScaleIntensityd,
    RandAdjustContrastd,
    RandAxisFlipd,
    RandRotate90d,
    FillHolesd,
    EnsureTyped,
    ToTensord,
    ToDeviced,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import Metric, DiceMetric
from monai.losses import DiceLoss, DiceCELoss, DiceFocalLoss, FocalLoss, TverskyLoss, GeneralizedWassersteinDiceLoss, GeneralizedDiceLoss
from monai.inferers import sliding_window_inference
from monai.data import SmartCacheDataset, CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import aim
from aim.pytorch import track_gradients_dists, track_params_dists
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import numpy as np
import pandas as pd
from matplotlib.colors import Normalize
from PIL import Image
import cv2
import nibabel as nib
from mpl_toolkits.mplot3d import Axes3D
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import time
from torch.utils.data import WeightedRandomSampler
from tqdm import tqdm

## Configuração do ambiente

In [None]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Report the MONAI configuration of this environment (monai.config.print_config).
print_config()

## Separando os arquivos de imagem para treino e validação

In [None]:
def get_rgb(color_hex):
    """Convert a '#rrggbb'-style color string into an (r, g, b) tuple of ints.

    The first character is skipped; characters 1-2, 3-4 and 5-to-end are each
    parsed as a base-16 integer.
    """
    red, green, blue = color_hex[1:3], color_hex[3:5], color_hex[5:]
    return tuple(int(component, 16) for component in (red, green, blue))

def generate_rgb_image(color_hex, shape):
    """Build a solid-color image filled with `color_hex`, scaled to [0, 1].

    `shape` is the 2-D spatial shape; the result has a trailing axis of size 3
    holding the color returned by `get_rgb`, divided by 255.
    """
    canvas_shape = (*shape, 3)
    canvas = np.empty(canvas_shape)
    # Every pixel is overwritten by the broadcast color, so no zero-fill is needed.
    canvas[:, :, :] = get_rgb(color_hex)
    return canvas / 255.0

def generate_rgb_map2(single_channel_map, is_channel_first = False, as_tensor = False):
    """Colorize a label map using `np.where` over whole-image color planes.

    For each label from 1 up to `element_data.index.max()`, pixels equal to that
    label receive the row's `color_hex` (scaled to [0, 1]); all other pixels stay 0.
    When `is_channel_first` is true, the first axis of `single_channel_map` is
    dropped (only index 0 is used). When `as_tensor` is true, the result is
    returned as a channel-first `torch.Tensor`.
    """
    leading_axes = int(is_channel_first)
    rgb_map = np.zeros((*single_channel_map.shape[leading_axes:], 3))
    for label in range(1, element_data.index.max() + 1):
        element = element_data[element_data.index == label]
        plane = single_channel_map[0] if is_channel_first else single_channel_map
        # Replicate the label plane over 3 channels so it can mask an RGB image.
        label_mask = np.stack([plane] * 3, axis = 2) == label
        solid_color = generate_rgb_image(element['color_hex'].values[0], plane.shape)
        rgb_map = np.where(label_mask, solid_color, rgb_map)

    if not as_tensor:
        return rgb_map
    # Move the trailing RGB axis to the front (HxWx3 -> 3xHxW).
    return torch.Tensor(np.rollaxis(rgb_map, rgb_map.ndim - 1, 0))

def generate_rgb_map(single_channel_map, is_channel_first = False, as_tensor = False):
    """Colorize a label map by writing each label's color at its pixel positions.

    For each label from 1 up to `element_data.index.max()`, the pixels equal to
    that label are painted with the row's `color_hex`. The map is built as uint8
    and returned divided by 255. When `is_channel_first` is true, only index 0 of
    the first axis is used; when `as_tensor` is true, the result is a
    channel-first `torch.Tensor`.
    """
    spatial_shape = single_channel_map.shape[int(is_channel_first):]
    rgb_map = np.zeros((*spatial_shape, 3)).astype(np.uint8)
    for label in range(1, element_data.index.max() + 1):
        element = element_data[element_data.index == label]
        plane = single_channel_map[0] if is_channel_first else single_channel_map
        hits = np.where(plane == label)
        rows, cols = hits[0], hits[1]
        rgb_map[rows, cols, :] = get_rgb(element['color_hex'].values[0])

    if as_tensor:
        # Move the trailing RGB axis to the front (HxWx3 -> 3xHxW).
        rgb_map = torch.Tensor(np.rollaxis(rgb_map, rgb_map.ndim - 1, 0))
    return rgb_map/255.0

In [None]:
def plot_colormap():
    """Show one color swatch per `element_data` row whose index is greater than 0.

    Each swatch is a 30x30 solid image of the row's `color_hex`, labelled with
    "<index>. <Element>". The subplot grid has one row per `element_data` row.
    """
    total_rows = element_data.shape[0]
    next_row = 1
    for label, element in element_data.iterrows():
        if not (label > 0):
            continue
        plt.subplot(total_rows, 1, next_row)
        plt.xticks([])
        plt.yticks([])
        caption = str(label) + '. ' + element['Element']
        plt.ylabel(caption, rotation = 'horizontal', horizontalalignment = 'right')
        plt.imshow(generate_rgb_image(element['color_hex'], (30, 30)))
        next_row += 1
    plt.show()

In [None]:
def manage_shrink(file_dicts, extension, data_is_shrank, label_is_shrank):
    """Point image/label paths at their '_shrank' variants, in place.

    For every dict in `file_dicts`, each occurrence of '_<extension>' in the
    'image' path (when `data_is_shrank`) and/or the 'label' path (when
    `label_is_shrank`) is replaced by '_<extension>_shrank'.
    Returns the same (mutated) `file_dicts` list.
    """
    targets = (('image', data_is_shrank), ('label', label_is_shrank))
    for entry in file_dicts:
        for key, is_shrank in targets:
            if is_shrank:
                entry[key] = entry[key].replace('_' + extension, '_' + extension + '_shrank')
    return file_dicts

In [None]:
def log_transform(transform):
    """Summarize the steps of a composed transform as a plain dict.

    For every item in `transform.transforms`, the entry is keyed by a name parsed
    from `str(item)` (text after the last '.', up to the first whitespace) and
    holds the item's `keys` attribute plus, under 'values', the `__dict__` of the
    last attribute that itself has a `__dict__` (or None when there is none).
    """
    summary = {}
    for step in transform.transforms:
        step_name = str(step).split('.')[-1].split()[0]
        attributes = step.__dict__
        entry = {'keys': attributes['keys'], 'values': None}
        summary[step_name] = entry
        for value in attributes.values():
            if hasattr(value, '__dict__'):
                # Later object-valued attributes overwrite earlier ones.
                entry['values'] = value.__dict__
    return summary

In [None]:
def calculate_sample_weights(data_proportions, n_files, n_groups):
    """Compute per-file sampling weights inversely proportional to group size.

    Args:
        data_proportions: dict with an 'amount_by_section_id' list of
            {'amount': int, 'group': str} entries, plus one '<group>' key per
            group holding that group's total number of files.
        n_files: total number of training files.
        n_groups: number of distinct groups.

    Returns:
        (sample_weights, num_samples). `sample_weights` has one weight
        (n_files / group total) per file, in entry order. `num_samples` is
        `n_groups` times the number of files carrying the smallest weight.
        When there is nothing to weight (no files or no entries), returns
        ([], 0) instead of failing on `min()` of an empty list.
    """
    sample_weights = []
    if n_files > 0:
        for entry in data_proportions['amount_by_section_id']:
            if entry['amount'] <= 0:
                # Nothing to add; also avoids dividing by an empty group's total.
                continue
            group_weight = n_files / data_proportions[entry['group']]
            sample_weights.extend([group_weight] * entry['amount'])

    if not sample_weights:
        return sample_weights, 0

    num_samples = n_groups * sample_weights.count(min(sample_weights))
    return sample_weights, num_samples

def get_group(section_id):
    """Return the balancing group of a section id, per the `weight_samples` flags.

    * by_section: the first character of `section_id`.
    * by_face: looked up in `data_register` ('Código' -> 'Fácies'): 'shrub' when
      the facies contains "shrub" (case-insensitive), 'GST' for 'FLTgst',
      otherwise the first three characters of the facies. Returns 'undefined'
      (after printing a warning) when the id is not registered.
    * neither flag set: the empty string.
    """
    if weight_samples['by_section']:
        return section_id[0]
    if not weight_samples['by_face']:
        return ''

    is_registered = data_register['Código'] == section_id
    if not is_registered.any():
        print('ATENÇÃO:', section_id, 'não está na lista.')
        return 'undefined'

    face = data_register.loc[is_registered, 'Fácies'].values[0]
    if 'shrub' in face.lower():
        return 'shrub'
    return 'GST' if face == 'FLTgst' else face[:3]

def balance(data_files, by = 'section'):
    """Oversample `data_files` so that every group except the majority is resampled.

    Groups are taken from the fifth-from-last component of each 'image' path
    (`by='section'`) or from each entry's 'group' value (`by='face'`).
    Uses imblearn's RandomOverSampler with sampling_strategy='not majority' and
    returns the resampled entries as a flat array.
    """
    from imblearn.over_sampling import RandomOverSampler
    assert(by in ['section', 'face'])

    samples = np.array(data_files).reshape(-1, 1)
    if by == 'section':
        groups = [entry['image'].split(os.sep)[-5] for entry in data_files]
    else:
        groups = [entry['group'] for entry in data_files]

    oversampler = RandomOverSampler(sampling_strategy = 'not majority')
    resampled = oversampler.fit_resample(samples, groups)
    return resampled[0].flatten()

In [None]:
# Run configuration. With py_mode = True the settings come from the command line
# (for use as an exported .py script); otherwise they are set in the notebook below.
py_mode = False

if py_mode:
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', dest = 'dataset_file', type = str)
    parser.add_argument('--project', dest = 'project', type = str, choices = ['qemscan', 'poreseg'], default = 'qemscan')
    parser.add_argument('--shrank-train-data', dest = 'shrank_train_data', action = 'store_true')
    parser.add_argument('--shrank-train-label', dest = 'shrank_train_label', action = 'store_true')
    parser.add_argument('--shrank-val-data', dest = 'shrank_val_data', action = 'store_true')
    parser.add_argument('--shrank-val-label', dest = 'shrank_val_label', action = 'store_true')
    parser.add_argument('--isolate-bg', dest = 'isolate_bg', action = 'store_true')
    parser.add_argument('--include-bg-loss', dest = 'include_bg_loss', action = 'store_true')
    parser.add_argument('--include-bg-metric', dest = 'include_bg_metric', action = 'store_true')
    parser.add_argument('--siliciclastics', dest = 'siliciclastics_model', action = 'store_true')
    parser.add_argument('--experiment', dest = 'experiment', type = str, default = 'test')
    parser.add_argument('--epochs', dest = 'max_epochs', type = int, default = 50)
    parser.add_argument('--val-interval', dest = 'val_interval', type = int, default = 50)
    parser.add_argument('--intensity-aug', dest = 'intensity_aug', action = 'store_true')
    parser.add_argument('--weight-samples', dest = 'weight_samples', action = 'store_true')
    # Which grouping --weight-samples balances on; only used when that flag is given.
    parser.add_argument('--weight-by', dest = 'weight_by', type = str, choices = ['section', 'face'], default = 'section')
    parser.add_argument('--oversample', dest = 'balance_by_oversampling', action = 'store_true')
    args = parser.parse_args()

    project = args.project
    train_val_data_file = args.dataset_file
    shrank = {
        'train_data': args.shrank_train_data,
        'train_label': args.shrank_train_label,
        'val_data': args.shrank_val_data,
        'val_label': args.shrank_val_label,
    }
    isolate_background = args.isolate_bg
    include_background = {
        'loss': args.include_bg_loss,
        'metric': args.include_bg_metric
    }
    siliciclastics_model = args.siliciclastics_model
    experiment = args.experiment
    max_epochs = args.max_epochs
    val_interval = args.val_interval
    intensity_aug = args.intensity_aug
    # The rest of the notebook reads weight_samples as a dict with these two
    # keys, so the single CLI flag is expanded into that shape.
    weight_samples = {
        'by_section': args.weight_samples and args.weight_by == 'section',
        'by_face': args.weight_samples and args.weight_by == 'face'
    }
    balance_by_oversampling = args.balance_by_oversampling
else:
    project = 'qemscan' # qemscan / poreseg
    train_val_data_file = os.path.join(project, 'dataset_files', 'trainval_new.csv')
    isolate_background = False
    shrank = {
        'train_data': False,
        'train_label': False,
        'val_data': False,
        'val_label': False,
    }
    include_background = { # normally: qemscan - False/False; poreseg - True/True
        'loss': False,
        'metric': False
    }
    siliciclastics_model = False
    experiment = 'new_data.val_matheus+d6212_argilo' # remember to change
    max_epochs = 75
    val_interval = 15 # None for no validation
    intensity_aug = False
    weight_samples = {
        'by_section': False,
        'by_face': False
    }
    balance_by_oversampling = False

# Validate explicitly (plain asserts are stripped when Python runs with -O).
if weight_samples['by_section'] and weight_samples['by_face']:
    raise ValueError("weight_samples: 'by_section' and 'by_face' are mutually exclusive")
if project not in ('qemscan', 'poreseg'):
    raise ValueError("project must be 'qemscan' or 'poreseg', got " + repr(project))

# First CUDA device when available, otherwise the CPU.
device = torch.device(0) if torch.cuda.is_available() else 'cpu'

# use None to train a brand new model
model_to_load = None
show_output_comparison = False
multigpu = False

In [None]:
# Data, crop and cache settings for this run.
pp_only = False # normally: qemscan - False; poreseg - True
binary = False
nifti = True
exclude_pore = False
img_size = 1000 # normally: qemscan - 1000; poreseg - 128
crop_spatial_size = (512, 512)
num_crops_per_img = 1
test_background_rate_max = 5e-3
if nifti:
    ext = 'nii.gz'
else:
    ext = 'tif'
smart_cache = True
dataset_cache_rate = 0.5
dataset_replace_rate = 1.0

# Classes kept as separate targets; the label remapping further down folds every
# other element into `others_class`.
if siliciclastics_model:
    interest_classes = ['Quartzo', 'Feldspato', 'Argilas', 'Poros', 'Outros']
else:
    interest_classes = ['Calcita (0% a 1%MgO)', 'Dolomita', 'Mg-Argilominerais', 'Poros', 'Quartzo', 'Outros']
others_class = 'Outros'

# Derived flags: poreseg forces binary mode; non-binary runs always isolate the
# background; pore exclusion only applies to non-binary runs.
binary = binary or project == 'poreseg'
isolate_background = isolate_background or not binary
exclude_pore = exclude_pore and not binary

data_dir = os.path.join(os.sep, 'petrobr', 'parceirosbr', 'smartseg', 'datasets', project, 'generated')
# 'depth' is read as text so its exact spelling in the CSV is preserved.
train_val_data = pd.read_csv(train_val_data_file, dtype = {'depth': str})
if project == 'qemscan':
    data_register = pd.read_csv(os.path.join('qemscan', 'register.csv'))

# Keep only the CSV rows flagged for training and/or validation.
useful_train_val_entries = train_val_data[train_val_data['train'].astype(bool) | train_val_data['val'].astype(bool)]

train_files = []
val_files   = []
# 'amount_by_section_id' lists the training amount of each section in order;
# every other key is a group name mapped to that group's total training amount.
data_amounts = {'amount_by_section_id': []}
dataset_log = {'train': [], 'val': []}
# Name of the per-section directory holding the image chunks, e.g. '1000x1000_nii.gz'.
chunk_dir_name = str(img_size) + 'x' + str(img_size) + '_' + ext
for i, train_val_entry in useful_train_val_entries.iterrows():
    section  = train_val_entry['section']
    depth    = train_val_entry['depth']
    to_train = bool(train_val_entry['train'])
    to_val   = bool(train_val_entry['val'])
    section_id = section + '_' + depth

    group = get_group(section_id)
    if group == 'undefined':
        continue

    if len(group) > 0 and group not in data_amounts:
        data_amounts[group] = 0

    train_val_images = sorted(
        glob.glob(os.path.join(data_dir, section, depth, chunk_dir_name, 'data',   "*." + ext)))

    train_val_labels = sorted(
        glob.glob(os.path.join(data_dir, section, depth, chunk_dir_name, 'labels', "*." + ext)))
    #-#labels_proportions = pd.read_csv(os.path.join(data_dir, section, depth,
    #-#                          str(img_size) + 'x' + str(img_size) + '_' + ext, 'proportions.csv'))
    # Images and labels are paired by sorted order.
    data_dicts = [
        {"image": image_name, "label": label_name, 'group': group}
        for image_name, label_name in zip(train_val_images, train_val_labels)
    ]

    if to_train and to_val:
        # Hold out the last 20% of the section for validation. The split uses an
        # explicit index: with val_size == 0, data_dicts[:-0] is empty and
        # data_dicts[-0:] is the whole list, which would send a small section
        # entirely to validation.
        val_size = int(0.2 * len(data_dicts))
        split_at = len(data_dicts) - val_size
        section_train_files = data_dicts[:split_at]
        section_val_files   = data_dicts[split_at:]
    elif to_val:
        section_train_files = []
        section_val_files   = data_dicts
        #-#val_proportions += list(labels_proportions.values[:, 1:])
    else:
        section_train_files = data_dicts
        section_val_files   = []

    train_files += section_train_files
    val_files   += section_val_files
    if to_train and len(group) > 0:
        data_amounts['amount_by_section_id'].append({'amount': len(section_train_files), 'group': group})
        data_amounts[group] += len(section_train_files)

    # Log the number of files that actually went to each split.
    if to_train:
        dataset_log['train'].append((section_id, len(section_train_files)))
    if to_val:
        dataset_log['val'].append((section_id, len(section_val_files)))

train_files = manage_shrink(train_files, ext, shrank['train_data'], shrank['train_label'])
val_files   = manage_shrink(val_files,   ext, shrank['val_data'],   shrank['val_label'])

# Optional group balancing (by oversampling or by sampling weights), followed by
# a printed summary of the train/validation sets.
sample_weights = None
proportions_by_group = {}
if weight_samples['by_section'] or weight_samples['by_face']:
    group_type = 'face'
    if weight_samples['by_section']:
        group_type = 'section'
    # Every key except 'amount_by_section_id' is a group.
    n_groups = len(data_amounts) - 1
    for group, amount in data_amounts.items():
        if not group.startswith('amount'):
            proportions_by_group[group] = amount/len(train_files)

    if balance_by_oversampling:
        print(f'Balanceamento por superamostragem de grupo ({group_type})...')
        train_files = balance(train_files, by = group_type)
        print('\t*', len(train_files), 'de treino no total:', len(train_files)//n_groups, 'por grupo.')
    else:
        print(f'Balanceando por pesos de grupo ({group_type})...')
        sample_weights, num_samples = calculate_sample_weights(data_amounts, len(train_files), n_groups)
        print(
            '\t* As', len(train_files),
            'serão amostradas com probabilidade inversa à frequência do grupo, por',
            num_samples,
            'vezes (quantidade equivalente a de imagens por superamostragem).'
        )

len_train_log = len(dataset_log['train'])
len_val_log   = len(dataset_log['val'])
print(
    '\nConjuntos de dados\n\tTreino:', len(train_files), f'({len_train_log})\n\t\t', dataset_log['train'],
    '\n\t\t* Proporções por grupo:', proportions_by_group,
    '\n\tValidação/Teste:', len(val_files), f'({len_val_log})\n\t\t', dataset_log['val']
)

In [None]:
def class_remapping(element_data, remap_condition, target_class, target_color):
    """Rename the rows selected by `remap_condition` to a single target class.

    Args:
        element_data: DataFrame with 'Element' and 'color_hex' columns;
            modified in place.
        remap_condition: boolean row selector accepted by `DataFrame.loc`.
        target_class: value written to 'Element' for the selected rows.
        target_color: value written to 'color_hex' for the selected rows.

    Returns:
        The same `element_data` object, after the update.
    """
    element_data.loc[remap_condition, 'Element'] = target_class
    element_data.loc[remap_condition, 'color_hex'] = target_color

    return element_data

# Load the unified label table; its first column becomes the index.
element_data = pd.read_csv(os.path.join(data_dir, '..', 'unified_labels.csv'), index_col = 0)

# Original label values are 0..N (N = number of rows in the table);
# target_labels starts as the identity mapping.
orig_labels = np.arange(element_data.shape[0] + 1)
target_labels = orig_labels.copy()
# Index of the row whose Element is 'Poros', taken before any remapping.
pore_label = element_data[element_data['Element']=='Poros'].index[0]
if binary:
    # Binary mode: label 0 is left unchanged, the pore label maps to
    # new_pore_label and every other label maps to new_pore_label - 1.
    new_pore_label = 2 if isolate_background else 1
    new_non_pore_label = new_pore_label - 1

    target_labels[np.where((orig_labels != pore_label) & (orig_labels != 0))] = new_non_pore_label
    target_labels[np.where(orig_labels == pore_label)] = new_pore_label
else:
    print('Original labels...')
    plot_colormap()

    # Siliciclastics model: first merge 'Albita' into 'Feldspato' and the clay
    # entries into 'Argilas'.
    if siliciclastics_model:
        element_data = class_remapping(element_data, element_data['Element'] == 'Albita', 'Feldspato', '#ff0000')
        element_data = class_remapping(element_data,
                                       (element_data['Element'] == 'Mg-Argilominerais') | \
                                       (element_data['Element'] == 'Caulinita') | \
                                       element_data['Element'].str.contains('Esmectita'), \
                                       'Argilas', '#00ff00')
    # Every element outside interest_classes becomes others_class, using the
    # color of the existing others_class row.
    element_data = class_remapping(element_data, ~element_data['Element'].isin(interest_classes), others_class, \
                                   element_data.loc[element_data['Element'] == others_class, 'color_hex'].values[0])

    # New label of each row = position of its element in interest_classes + 1;
    # label 0 stays 0.
    target_labels = np.array([0] + [interest_classes.index(element) + 1 for element in element_data['Element']])
    element_data = element_data.set_index(target_labels[1:])

    print('Remapped labels...')
    plot_colormap()

    # Drop duplicated rows and sort by the new labels.
    element_data = element_data.set_index([target_labels[1:]]).drop_duplicates().sort_index()

    print('Compacted labels...')
    plot_colormap()

# Class names; 'Desconhecido' is prepended when the background is included in the metric.
elements = (['Desconhecido'] if include_background['metric'] else []) + element_data['Element'].tolist()
n_classes = np.unique(target_labels).size
print('Label remapping:', orig_labels, '->', target_labels)
print('Pore label:', pore_label)
print('Classes:', n_classes)

In [None]:
def check_binary_color_dist(image_files):
    """Plot the mean per-channel intensity histogram of each class of interest.

    For every image/label pair, the pixels of each class in `interest_classes`
    are selected through the label map and a 256-bin histogram over (0, 255)
    is computed for each of the first three image channels. The histograms are
    averaged over all images and shown as a 3 x len(interest_classes) grid of
    bar plots (red, green, blue rows).

    Args:
        image_files: list of dicts with 'image' and 'label' paths readable by
            nibabel.

    Notes:
        * The average is a true mean over all images. Averaging the running
          value with each new histogram would weight the latest images
          exponentially more.
        * The progress step is at least 1, so fewer than 10 files no longer
          causes a modulo-by-zero.
    """
    class_labels = {
        class_: element_data[element_data['Element'] == class_].index[0]
        for class_ in interest_classes
    }
    hist_sums = {class_: [np.zeros(256) for _ in range(3)] for class_ in interest_classes}
    progress_step = max(1, len(image_files)//10)

    for i, image_file in enumerate(image_files):
        image = nib.load(image_file['image']).get_fdata()[:, :, :3]
        label = nib.load(image_file['label']).get_fdata()

        for class_ in interest_classes:
            im_class = image[label == class_labels[class_]]
            for channel in range(3):
                hist_sums[class_][channel] += np.histogram(im_class[:, channel], bins = 256, range = (0, 255))[0]

        if (i + 1) % progress_step == 0:
            print(i + 1, '/', len(image_files))

    n_images = max(1, len(image_files))
    hist = {
        class_: [channel_sum / n_images for channel_sum in channel_sums]
        for class_, channel_sums in hist_sums.items()
    }

    plt.figure(figsize = (16, 10))

    n_cols = len(interest_classes)
    for i, class_ in enumerate(interest_classes):
        plt.subplot(3, n_cols, i + 1)
        plt.bar(range(256), hist[class_][0], color = 'red')
        plt.title(class_)
        plt.subplot(3, n_cols, i + 1 + n_cols)
        plt.bar(range(256), hist[class_][1], color = 'green')
        plt.subplot(3, n_cols, i + 1 + 2*n_cols)
        plt.bar(range(256), hist[class_][2], color = 'blue')

    plt.show()

#check_binary_color_dist(train_files) # (160, 220), (120, 180), (80, 140)

## Experimento determinístico para reprodutibilidade

In [None]:
# Fix the random seed (0) through MONAI's set_determinism so runs are reproducible.
set_determinism(seed=0)

## Transformações

In [None]:
class SelectChannelsd(Transform):
    """Dictionary transform that keeps only the first `n_channels` entries along
    the first axis of each selected item.

    In the pipelines of this file it runs right after AsChannelFirstd, so the
    first axis is the channel axis.

    Args:
        keys: key (str) or sequence of keys of the items to slice.
        n_channels: number of leading channels to keep.
    """
    def __init__(self, keys, n_channels):
        self.keys = keys
        self.n_channels = n_channels

    def __call__(self, img_dict):
        # Honor the configured keys instead of a hard-coded 'image'; a single
        # string key is accepted as well.
        keys = [self.keys] if isinstance(self.keys, str) else self.keys
        for key in keys:
            img_dict[key] = img_dict[key][:self.n_channels]
        return img_dict
    
class RemoveLabelExcessd(Transform):
    """Keep pore labels only where the image color lies inside the given ranges.

    The new label map is `non_pore_label` everywhere, 0 where the original label
    is 0, and `pore_label` where the original label is `pore_label` AND the first
    three image channels each fall inside their `channel_ranges` interval
    (inclusive). Each call also draws a debug figure: the image next to the
    image with its non-pore pixels blacked out.

    Args:
        keys: stored for logging; the transform always reads 'image' and 'label'.
        channel_ranges: three (low, high) pairs, one per image channel.
        non_pore_label: label value assigned to non-pore pixels.
        pore_label: label value identifying pores.
    """
    def __init__(self, keys, channel_ranges, non_pore_label, pore_label):
        self.keys = keys
        self.channel_ranges = channel_ranges
        self.non_pore_label = non_pore_label
        self.pore_label = pore_label

    def __call__(self, img_dict):
        image = img_dict['image']
        final_label = np.zeros(image.shape[-2:]) + self.non_pore_label

        # Per-channel test: low <= value <= high.
        ch0t = (image[0] >= self.channel_ranges[0][0]) & (image[0] <= self.channel_ranges[0][1])
        ch1t = (image[1] >= self.channel_ranges[1][0]) & (image[1] <= self.channel_ranges[1][1])
        ch2t = (image[2] >= self.channel_ranges[2][0]) & (image[2] <= self.channel_ranges[2][1])

        final_label[img_dict['label'][0] == 0] = 0
        final_label[ch0t & ch1t & ch2t & (img_dict['label'][0] == self.pore_label)] = self.pore_label

        # Debug view, drawn before the label is overwritten so the mask below
        # still refers to the original labels.
        plt.figure(figsize = (36, 36))
        imshow = np.rollaxis(np.array(image), 0, 3)[:, :, :3]/255
        plt.subplot(1, 2, 1)
        plt.imshow(imshow)
        imshow[img_dict['label'][0] != self.pore_label] = 0
        plt.subplot(1, 2, 2)
        plt.imshow(imshow)

        img_dict['label'][0] = final_label

        return img_dict

class Erosiond(Transform):
    """Zero out label windows that contain more than one distinct value.

    Windows of up to (2*offset + 1) pixels per side, with offset = kernel_size//2,
    are centered on a grid with stride 2*offset. Each pass tests the windows on a
    copy of the label taken at the start of that pass and writes zeros into the
    live label. The pass is repeated `n_iter` times.
    """
    def __init__(self, keys, kernel_size, n_iter):
        self.keys = keys
        self.offset = kernel_size//2
        self.n_iter = n_iter
        self.i = 0

    def __call__(self, img_dict):
        reach = self.offset
        for _ in range(self.n_iter):
            snapshot = img_dict['label'].copy()
            height, width = snapshot.shape[1], snapshot.shape[2]
            for row in range(0, height, 2*reach):
                top, bottom = max(0, row-reach), min(height, row+reach+1)
                for col in range(0, width, 2*reach):
                    left, right = max(0, col-reach), min(width, col+reach+1)
                    window = snapshot[:, top:bottom, left:right]
                    if np.unique(window).size > 1:
                        img_dict['label'][:, top:bottom, left:right] = 0
        return img_dict

class Smoothd(Transform):
    """Mode filter for label tensors.

    For every pixel whose neighborhood (up to (2*offset + 1) pixels per side,
    offset = kernel_size//2, clipped at the borders) holds more than one distinct
    value, the pixel is replaced by the most frequent value of that neighborhood.
    Neighborhoods are read from a clone taken at the start of each of the
    `n_iter` passes.
    """
    def __init__(self, keys, kernel_size, n_iter):
        self.keys = keys
        self.offset = kernel_size//2
        self.n_iter = n_iter
        self.i = 0

    def __call__(self, img_dict):
        reach = self.offset
        for _ in range(self.n_iter):
            reference = img_dict['label'].clone()
            height, width = reference.shape[1], reference.shape[2]
            for row in range(height):
                top, bottom = max(0, row-reach), min(height, row+reach+1)
                for col in range(width):
                    left, right = max(0, col-reach), min(width, col+reach+1)
                    neighborhood = reference[:, top:bottom, left:right]
                    if torch.unique(neighborhood).size()[0] > 1:
                        img_dict['label'][:, row, col] = self.mode(neighborhood)
        return img_dict

    def mode(self, array):
        """Return the most frequent value of `array` as a Python scalar."""
        flattened = array.flatten()
        return torch.mode(flattened).values.item()

class RandAugmentIntensityByLabeld(RandomizableTransform):
    """Randomly shift image intensities inside the regions of selected labels.

    For each label in `labels` and each channel, an offset is drawn uniformly
    from that label's (low, high) range for the channel and added to the image
    pixels whose label-map value equals the label. The result is clipped to
    [min, max]. The transform is applied with probability `prob`.

    Args:
        keys: key (str) or sequence of keys of the images to augment.
        labels: label values to augment, aligned with `offset_ranges_per_label`.
        offset_ranges_per_label: for each label, a sequence of (low, high)
            pairs, one per image channel.
        min, max: clipping bounds applied after the offsets.
        prob: probability of applying the transform.
        label_key: key of the label map (default 'label').

    Notes:
        Assumes channel-first images and a label map with a leading channel
        axis (1 x H x W), as produced by the pipelines in this file — confirm
        if used elsewhere. Images should be floating point so the in-place
        addition of a float offset is valid.
    """
    def __init__(self, keys, labels, offset_ranges_per_label, min = 0, max = 255, prob = 1.0, label_key = 'label'):
        # Registers `prob`, which randomize() reads to decide whether to apply.
        RandomizableTransform.__init__(self, prob)
        self.keys = keys
        self.labels = labels
        self.offset_ranges_per_label = offset_ranges_per_label
        self.min = min
        self.max = max
        self.label_key = label_key

    def __call__(self, img_dict):
        self.randomize(None)
        if not self._do_transform:
            return img_dict

        keys = [self.keys] if isinstance(self.keys, str) else self.keys
        label_map = img_dict[self.label_key][0]
        for label, offset_ranges in zip(self.labels, self.offset_ranges_per_label):
            # The region comes from the label map, not from the image values.
            region = label_map == label
            for channel, (low, high) in enumerate(offset_ranges):
                offset = self.R.uniform(low = low, high = high)
                for key in keys:
                    # Index the channel first (a view), then mask, so the
                    # in-place addition reaches the stored image.
                    img_dict[key][channel][region] += offset
        for key in keys:
            img_dict[key] = np.clip(img_dict[key], self.min, self.max)
        return img_dict

In [None]:
# Training pipeline: load -> channel-first image -> channel selection -> label
# channel + remapping -> intensity scaling -> random crops -> optional intensity
# augmentation -> random flips/rotations -> type check.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),                           # load the image and label files;
        AsChannelFirstd(keys=["image"]),                               # move the image channels to the first dimension;
        SelectChannelsd(keys=["image"], n_channels = 3 if pp_only else 6),  # keep the first 3 channels when pp_only, else the first 6;
        AddChanneld(keys=["label"]),                      # add a channel dimension to the mask (originally HxW);
        # remap label values only when the mapping is not the identity;
        MapLabelValued(keys=["label"], orig_labels=orig_labels, target_labels=target_labels) \
            if not np.array_equal(orig_labels, target_labels) else Identityd(keys=["label"]),
        #EnsureChannelFirstd(keys=["image", "label"]),                 # ensure the first image dimension is the channels;
        #Orientationd(keys=["image", "label"], axcodes="RAS"),         # volume orientation;
        #Spacingd(keys=["image", "label"], pixdim=(                    # factor by which each volume dimension is reduced;
        #    1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        #RemoveLabelExcessd(keys=["image", "label"], channel_ranges = [(160, 220), (120, 180), (80, 140)],
        #                 non_pore_label = new_non_pore_label, pore_label = new_pore_label) \
        #    if binary else Identityd(keys=["image", "label"]),
        #Erosiond(keys=["label"], kernel_size = 9, n_iter = 1),
        ##RandAugmentIntensityByLabeld(keys=["image"], labels = range(1, n_classes),
        ##                 offset_ranges_per_label = [
        ##                     [(0, 0), (0, 0), (0, 0)],
        ##                     [(0, 0), (0, 0), (0, 0)],
        ##                     [(0, 0), (0, 0), (0, 0)],
        ##                     [(0, 0), (0, 0), (0, 0)],
        ##                     [(0, 0), (0, 0), (0, 0)],
        ##                     [(0, 0), (0, 0), (0, 0)]
        ##                 ],
        ##prob = 0.5),
        #FillHolesd(keys=['label'], applied_labels = new_pore_label) \
        #    if binary else FillHolesd(keys=['label'], applied_labels = None, connectivity = 1),
        ScaleIntensityRanged(                                          # normalization: [0, 255] -> [0.0, 1.0], clipped;
            keys=["image"], a_min=0, a_max=255,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        #-$RandCropByPosNegLabeld(                                        # random crops of the image for data augmentation;
        #-$    keys=["image", "label"],
        #-$    label_key="label",
        #-$    spatial_size=crop_spatial_size,
        #-$    pos=1,
        #-$    neg=1,
        #-$    num_samples=num_crops_per_img,
        #-$    image_key="image",
        #-$    image_threshold=0
        #-$),
        #CropForegroundd(keys=["image", "label"], source_key="image"),  # discard the surrounding background;
        #HistogramNormalized(keys=["image"], num_bins = 5, max = 1),
        
        # fixed-size random crops; skipped when img_size is 256 or smaller;
        RandSpatialCropSamplesd(
            keys=["image", "label"],
            roi_size=crop_spatial_size,
            num_samples=num_crops_per_img,
            random_size=False
        ) if img_size > 256 else Identityd(keys=["image", "label"]),
        
        #EnsureTyped(keys=["image", "label"]),#ToTensord(keys=['image','label']),
        #ToDeviced(keys=['image','label'], device = 'cuda:0'),
        #Smoothd(keys=["label"], kernel_size = 7, n_iter = 3),
        #EnsureTyped(keys=["image", "label"]),#ToTensord(keys=['image','label']),
        #ToDevice(keys=['image','label'], device = 'cpu'),
        
        #RandCropByLabelClassesd(
        #    keys=["image", "label"],
        #    spatial_size=crop_spatial_size,
        #    image_key='image',
        #    label_key='label',
        #    num_classes=n_classes,
        #    ratios=[1, 1, 1, 10, 1, 1, 10]
        #),
        #RandSmoothDeformd(keys=['label'], spatial_size = crop_spatial_size, rand_size = 100),
        # intensity augmentation, enabled only when intensity_aug is set;
        RandScaleIntensityd(keys=["image"], factors = 0.25, prob = 0.3) if intensity_aug else Identityd(keys=["image"]),
        RandAdjustContrastd(keys=["image"], prob = 0.3) if intensity_aug else Identityd(keys=["image"]),
        #RandGaussianSmoothd(keys=["image"], prob = 0.3),
        RandAxisFlipd(keys=["image", "label"], prob = 0.3),
        #RandZoomd(keys=["image", "label"]),
        RandRotate90d(keys=["image", "label"], prob = 0.3),
        # RandAffined(                                                 # customizable random transformations;
        #     keys=['image', 'label'],
        #     mode=('bilinear', 'nearest'),
        #     prob=1.0, spatial_size=(96, 96, 96),
        #     rotate_range=(0, 0, np.pi/15),
        #     scale_range=(0.1, 0.1, 0.1)),
        EnsureTyped(keys=["image", "label"]),                          # ensure the data is in a compatible format.
    ]
)
        
# Validation pipeline: same loading, channel selection, label remapping and
# intensity scaling as training, without cropping or random augmentation.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),                           # load the image and label files;
        AsChannelFirstd(keys=["image"]),                               # move the image channels to the first dimension;
        SelectChannelsd(keys=["image"], n_channels = 3 if pp_only else 6),  # keep the first 3 channels when pp_only, else the first 6;
        AddChanneld(keys=["label"]),                                   # add a channel dimension to the mask;
        # remap label values only when the mapping is not the identity;
        MapLabelValued(keys=["label"], orig_labels=orig_labels, target_labels=target_labels) \
            if not np.array_equal(orig_labels, target_labels) else Identityd(keys=["label"]),
        #EnsureChannelFirstd(keys=["image", "label"]),
        #Orientationd(keys=["image", "label"], axcodes="RAS"),
        #Spacingd(keys=["image", "label"], pixdim=(
        #    1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        #RemoveLabelExcessd(keys=["image", "label"], channel_ranges = [(160, 220), (120, 180), (80, 140)],
        #                 non_pore_label = new_non_pore_label, pore_label = new_pore_label) \
        #    if binary else Identityd(keys=["image", "label"]), 
        #FillHolesd(keys=['label'], applied_labels = new_pore_label) \
        #    if binary else FillHolesd(keys=['label'], applied_labels = None, connectivity = 1),
        #Erosiond(keys=["label"], kernel_size = 9, n_iter = 1),

        ScaleIntensityRanged(                                          # normalization: [0, 255] -> [0.0, 1.0], clipped;
            keys=["image"], a_min=0, a_max=255,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        #0#CenterSpatialCropd(
        #0#    keys=["image", "label"],
        #0#    roi_size=crop_spatial_size,
        #0#) if img_size > 256 else Identityd(keys=["image", "label"]),
        
        #EnsureTyped(keys=["image", "label"]),#ToTensord(keys=['image','label']),
        #ToDeviced(keys=['image','label'], device = 'cuda:0'),
        #Smoothd(keys=["label"], kernel_size = 7, n_iter = 3),
        #EnsureTyped(keys=["image", "label"]),#ToTensord(keys=['image','label']),
        #ToDevice(keys=['image','label'], device = 'cpu'),
        
        #CropForegroundd(keys=["image", "label"], source_key="image"),
        #HistogramNormalized(keys=["image"], num_bins = 5, max = 1),
        EnsureTyped(keys=["image", "label"]),                          # ensure the data is in a compatible format.
        #adaptor(RemoveLabelExcess(), {'image': 'image', 'label': 'label'})
    ]
)

## Pré-visualização de imagem e rótulo

In [None]:
def plot_dense_proportions(elements, dense_proportions):
    """Draw paired bars comparing real and predicted proportions per element.

    Args:
        elements: sequence of element names; entry ``i`` labels bar ``i``.
            The bar colour is looked up in the global ``element_data`` table
            (column ``color_hex``), except for ``'Desconhecido'``, which is
            drawn in black.
        dense_proportions: dict with keys ``'true'`` and ``'pred'``, each
            indexable by element position. Tensor values are converted to
            numpy arrays **in place** (the dict is modified), so callers can
            read numpy arrays from it afterwards.

    Returns:
        The matplotlib figure holding the bar chart.
    """
    # Create the figure before drawing so that the returned figure is the one
    # that actually contains the bars (not a fresh, empty one).
    fig = plt.figure()

    for key in ['true', 'pred']:
        values = dense_proportions[key]
        # Only tensors need the device transfer/conversion; values that are
        # already arrays (e.g. on a second call with the same dict) are kept.
        if hasattr(values, 'cpu'):
            values = values.cpu().numpy()
        dense_proportions[key] = values

    for i, element in enumerate(elements):
        bar_color = element_data[element_data['Element'] == element]['color_hex'].values[0] if element != 'Desconhecido' else 'black'
        # Negative width + align='edge' puts the real bar to the left of the
        # tick, the positive width puts the predicted bar to the right.
        plt.bar(x = element, height = dense_proportions['true'][i], color = bar_color, align = 'edge', width = -0.4, edgecolor = 'blue')
        plt.bar(x = element, height = dense_proportions['pred'][i], color = bar_color, align = 'edge', width =  0.4, edgecolor = 'red')
    plt.xticks(rotation = 'vertical')
    plt.ylabel('Proportion')
    # Empty line plots are used only as legend handles for the edge colours.
    plt.legend(
        handles = [
            plt.plot([], label = 'Real', color = 'blue')[0],
            plt.plot([], label = 'Predicted', color = 'red')[0]
        ]
    )
    plt.show()
    return fig

In [None]:
class Proportion(Metric):
    """Accumulate per-class pixel proportions of labels and predictions.

    Proportions are computed per call (one chunk) and grouped by a section id
    derived from the image path; ``aggregate`` averages them per section and
    over all sections.

    Args:
        n_classes: number of classes (minimum length of the class histogram).
        include_background: when False, pixels equal to 0 are dropped before
            counting and class 0 is removed from the aggregated vectors.
        argmax: when True, ``y`` and ``y_pred`` are treated as sequences whose
            first item is a channel-first score/one-hot map that is reduced
            with ``argmax`` over dim 0; when False they are cast to int and
            used directly as class-index maps.
    """

    def __init__(self, n_classes, include_background = True, argmax = True):
        self.n_classes = n_classes
        self.include_background = include_background
        self.argmax = argmax

        self.reset()

    def __call__(self, y_pred, y, image_path):
        """Add the class proportions of one label/prediction pair.

        The section id is built from the 5th- and 4th-to-last components of
        ``image_path`` (split on ``os.sep``). Tensors are moved to the global
        ``device`` before counting.
        """
        if self.argmax:
            y_1ch      = torch.argmax(y[0],      dim = 0).to(device)
            y_pred_1ch = torch.argmax(y_pred[0], dim = 0).to(device)
        else:
            y_1ch      = y.int().to(device)
            y_pred_1ch = y_pred.int().to(device)

        image_path_components = image_path.split(os.sep)
        section_id = image_path_components[-5] + '_' + image_path_components[-4]

        self.true = self.__calculate(y_1ch,      self.true, section_id)
        self.pred = self.__calculate(y_pred_1ch, self.pred, section_id)

    def __calculate(self, labels, output, section_id):
        """Append the proportion row of ``labels`` to ``output[section_id]``."""
        if not self.include_background:
            labels = labels[labels != 0]

        # NOTE(review): when a chunk holds only background and
        # include_background is False, ``labels`` is empty and the division
        # below is 0/0, i.e. a row of NaN that propagates into the section
        # mean. Confirm that such chunks are filtered out by the caller.
        proportions = torch.bincount(labels.flatten(), minlength = self.n_classes) / torch.numel(labels)
        proportions = proportions.reshape(1, -1)
        if section_id not in output:
            output[section_id] = proportions
        else:
            output[section_id] = torch.cat((output[section_id], proportions), dim = 0)
        return output

    def aggregate(self):
        """Return mean proportions per section and overall, plus fit quality.

        Returns:
            dict mapping each section id to ``{'true': ..., 'pred': ...}``
            (mean over the chunks of that section), and ``'total'`` to the
            mean over sections together with ``'R2'`` and ``'RMSE'`` of a
            linear regression fitted from predicted to true proportions.
        """
        first_class = int(not self.include_background)

        metrics = {'total': {}}
        section_ids = self.true.keys()
        for section_id in section_ids:
            metrics[section_id] = {
                'true': self.true[section_id].mean(dim = 0)[first_class:],
                'pred': self.pred[section_id].mean(dim = 0)[first_class:],
            }

        metrics['total']['true'] = torch.mean(
            torch.cat(tuple([metrics[section_id]['true'].reshape(1, -1) for section_id in section_ids]), dim = 0), dim = 0).cpu()
        metrics['total']['pred'] = torch.mean(
            torch.cat(tuple([metrics[section_id]['pred'].reshape(1, -1) for section_id in section_ids]), dim = 0), dim = 0).cpu()

        X, y = metrics['total']['pred'].reshape(-1, 1), metrics['total']['true'].reshape(-1, 1)
        linear_regression = LinearRegression().fit(X, y)

        metrics['total']['R2'] = linear_regression.score(X, y)
        # RMSE as sqrt(MSE): the ``squared=False`` keyword of
        # mean_squared_error is deprecated/removed in recent scikit-learn,
        # while this form gives the same value on every version.
        metrics['total']['RMSE'] = mean_squared_error(y, linear_regression.predict(X)) ** 0.5

        return metrics

    def reset(self):
        """Drop every accumulated proportion."""
        self.true = {}
        self.pred = {}

In [None]:
# Pre-computed proportions used for the training data. Assign None instead to
# have them recomputed by the cell that checks ``train_proportions is None``.
# Earlier values, kept for reference:
#   np.array([0.41365188, 0.18531045, 0.06884208, 0.04491788, 0.26241738, 0.02486032])
#   np.array([0.4019435,  0.20186418, 0.06687519, 0.04392304, 0.25787783, 0.02751627])
train_proportions = np.array(
    [
        0.4124453,
        0.18595047,
        0.06884056,
        0.04539315,
        0.2623897,
        0.02498083,
    ]
)

In [None]:
# Preview one shuffled training batch (image, label overlay and, unless
# pp_only, the extra channels), then make sure train_proportions is available.
check_ds = Dataset(data=train_files, transform=train_transforms)
# The seed is drawn at random and printed before seeding torch, presumably so
# that a given preview can be reproduced by hand.
seed = np.random.randint(10000)
print(seed)
torch.manual_seed(seed)
check_loader = DataLoader(check_ds, batch_size=1, shuffle = True)
check_data = first(check_loader)
print('check_data[\"image\"].shape = ', check_data["image"].shape)
print('check_data[\"label\"].shape = ', check_data["label"].shape)
# One figure per item along the first dimension of the batch.
for i in range(check_data["image"].shape[0]):
    image, label = (check_data["image"][i], check_data["label"][i])
    image_pp = np.rollaxis(np.array(image), 0, 3)[:, :, :3] # move the channels to the last dimension; first 3 channels
    image_px = np.rollaxis(np.array(image), 0, 3)[:, :, 3:] # channels after the third (plotted as "image PX")
    print(i, ')', f"image shape: {image.shape}, label shape: {label.shape}")
    plt.figure("check", (12, 6))
    plt.subplot(1, 3, 1)
    plt.title("image")
    plt.imshow(image_pp)
    plt.subplot(1, 3, 2)
    plt.title("label")
    #plt.imshow(np.where(np.stack([label[0], label[0], label[0]], axis = 2) != pore_label, image_pp, (image_pp + 1)/2))
    label_to_show = generate_rgb_map((label[0]))
    print(label_to_show.min(), label_to_show.mean(), label_to_show.max())
    # Blend the colour-coded label with the image (plain average).
    plt.imshow((label_to_show + image_pp)/2.0)
    if not pp_only:
        plt.subplot(1, 3, 3)
        plt.title("image PX")
        plt.imshow(image_px)
    plt.show()
print(elements)
if train_proportions is None:
    # Recompute the proportions over all training files, loaded with
    # val_transforms and without shuffling.
    print('Calculando proporções dos elementos nos dados de treino...')
    check_ds = Dataset(data=train_files, transform=val_transforms)
    check_loader = DataLoader(check_ds, batch_size=1, shuffle = False)
    # From here on train_proportions temporarily holds the metric object.
    train_proportions = Proportion(n_classes=n_classes, include_background=include_background['metric'], argmax=False)
    for i, data in enumerate(check_loader):
        # Progress is printed at random, on about 1% of the iterations.
        if np.random.uniform() < 0.01:
            print(int(100*i/(len(check_ds)//check_loader.batch_size)), '%', end = ' == ')
        # The label is passed as both prediction and target: only the 'true' proportions are used below.
        train_proportions(y_pred = data['label'], y = data['label'], image_path = data['image_meta_dict']['filename_or_obj'][0])
    train_proportions = train_proportions.aggregate()
    plot_dense_proportions(elements, train_proportions['total'])
    train_proportions = train_proportions['total']['true']
    print(train_proportions)
else:
    # Proportions already defined: show them, with zeros in place of predictions.
    print('ATENÇÃO: as proporções já estão calculadas. Usando:')
    print()
    print(train_proportions)
    plot_dense_proportions(elements, {'true': torch.Tensor(train_proportions), 'pred': torch.Tensor(len(train_proportions) * [0])})
    print('Para recalculá-las, execute a célula que atribui None à variável train_proportions.')

# Minimal pipeline for the colour-histogram check: load image and label,
# remap the label values only when the original and target label sets differ,
# and convert both entries with EnsureTyped.
hist_check_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    (
        Identityd(keys=["label"])
        if np.array_equal(orig_labels, target_labels)
        else MapLabelValued(keys=["label"], orig_labels=orig_labels, target_labels=target_labels)
    ),
    EnsureTyped(keys=["image", "label"]),
])

# Per-class RGB histogram of the training images: for every class of
# interest, accumulate a bins x bins x bins histogram of the pixels carrying
# that label, average it over the images and show it as a 3-D scatter plot
# whose marker size is proportional to the bin count.
image_files = train_files
bins = 5
#def check_binary_color_dist(image_files, bins):
check_ds = Dataset(data = image_files, transform = hist_check_transforms)
check_loader = DataLoader(check_ds, batch_size = 1, shuffle = False)

# Progress is reported about every 10% of the files; the step is at least 1
# so the modulo below cannot divide by zero when there are fewer than 10 files.
progress_step = max(1, len(image_files)//10)
rgb_range = [(0, 255), (0, 255), (0, 255)]

hist = {}
n_images = 0
for i, data in enumerate(check_loader):
    image = data['image'][0, :, :, :3]
    label = data['label'][0]
    n_images += 1

    for class_ in interest_classes:
        # Pixels of this image whose label equals the class index in element_data.
        im_class = image[label == element_data[element_data['Element']==class_].index[0]]
        class_hist = np.histogramdd(im_class, bins = bins, range = rgb_range)[0]

        # Accumulate a plain sum; it is divided by the image count after the
        # loop. (Averaging the running value with each new histogram would
        # weight the most recent images exponentially more than the first.)
        hist[class_] = hist[class_] + class_hist if class_ in hist else class_hist

    if (i + 1) % progress_step == 0:
        print(i + 1, '/', len(image_files))

# Turn the accumulated sums into the mean histogram per image.
for class_ in hist:
    hist[class_] = hist[class_] / n_images

# Bin coordinates are the same for every class.
binsR, binsG, binsB = np.meshgrid(range(bins), range(bins), range(bins))

for i, class_ in enumerate(interest_classes):
    ax = plt.figure(figsize = (6, 6)).add_subplot(projection = '3d')
    ax.set_title(class_)
    ax.set_xlabel('R')
    ax.set_ylabel('G')
    ax.set_zlabel('B')

    counts = hist[class_][binsR, binsG, binsB]
    peak = counts.max()
    # A class with no pixels at all has an all-zero histogram: draw empty
    # markers instead of dividing by zero.
    sizes = 1000*counts/peak if peak > 0 else np.zeros_like(counts)
    ax.scatter3D(binsR, binsG, binsB, s = sizes)

    plt.show()

#check_binary_color_dist(train_files, 5)

# Visual check of hand-picked chunks from the files whose image path contains
# 'A/5242.95', loaded with the validation transforms.
correct_files = [entry for entry in train_files if 'A/5242.95' in entry['image']]
check_ds = Dataset(data=correct_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1, shuffle=False)
for i, check_data in enumerate(check_loader):
    # Only the chunk indices listed here are displayed.
    if i in (3, 4, 14, 15, 27, 39, 51, 63, 75, 86, 98, 99, 110, 122, 123, 134, 135, 145, 146):
        image = check_data["image"][0]
        label = check_data["label"][0]

        # Channels go to the last dimension; the first three are shown as the
        # image, the remaining ones as "image PX".
        image_pp = np.moveaxis(np.array(image), 0, 2)
        image_px = image_pp[:, :, 3:]
        image_pp = image_pp[:, :, :3]

        print(i, ')', f"image shape: {image.shape}, label shape: {label.shape}")
        print(check_data['image_meta_dict']['filename_or_obj'][0])

        plt.figure("check", (12, 6))

        plt.subplot(1, 3, 1)
        plt.title("image")
        plt.imshow(image_pp)

        plt.subplot(1, 3, 2)
        plt.title("label")
        #plt.imshow(np.where(np.stack([label[0], label[0], label[0]], axis = 2) != pore_label, image_pp, (image_pp + 1)/2))
        if not binary:
            # Multi-class: blend the colour-coded label with the image.
            label_to_show = generate_rgb_map(label[0])
            plt.imshow((label_to_show + image_pp)/2.0)
        else:
            plt.imshow(label[0], norm = Normalize(0, target_labels.max()))

        if not pp_only:
            plt.subplot(1, 3, 3)
            plt.title("image PX")
            plt.imshow(image_px)

        plt.show()

## Definindo modelo e _loss_

In [None]:
class ConfusionMatrix(Metric):
    """Confusion matrix accumulated per section over segmentation chunks.

    Rows index the actual class and columns the predicted class. When the
    background is excluded, class 0 is dropped and the remaining classes are
    shifted down by one.

    Args:
        n_classes: total number of classes, background included.
        include_background: whether class 0 takes part in the matrix.
        reduction: ``'sum'`` returns raw counts; ``'mean'`` divides them by
            the number of calls since the last reset (the total over all
            sections); ``'norm_mean'`` additionally divides every row by its
            sum.
    """

    def __init__(self, n_classes, include_background = True, reduction = 'mean'):
        assert(reduction in ['mean', 'norm_mean', 'sum'])
        self.first_class = int(not include_background)

        self.n_classes = n_classes - self.first_class
        self.include_background = include_background
        self.reduction = reduction

        self.reset()

    def __call__(self, y_pred, y, image_path):
        """Add one label/prediction pair to the matrix of its section.

        ``y`` and ``y_pred`` are sequences whose first item is a channel-first
        map reduced with ``argmax`` over dim 0. The section id is built from
        the 5th- and 4th-to-last components of ``image_path``.

        Returns:
            The dict of accumulated (un-reduced) matrices, keyed by section id.
        """
        self.cumul_matrices += 1

        y_1ch      = torch.argmax(y[0],      dim = 0).to(device).flatten()
        y_pred_1ch = torch.argmax(y_pred[0], dim = 0).to(device).flatten()

        image_path_components = image_path.split(os.sep)
        section_id = image_path_components[-5] + '_' + image_path_components[-4]

        if section_id not in self.matrix:
            self.matrix[section_id] = torch.zeros(self.n_classes, self.n_classes).to(device)

        # Count all (actual, predicted) pairs in a single pass: each pair is
        # encoded as one flat index and histogrammed, instead of scanning the
        # whole image once per matrix cell.
        lowest, upper = self.first_class, self.first_class + self.n_classes
        # Pixels whose actual or predicted class is outside the matrix
        # (e.g. excluded background) are not counted.
        counted = (y_1ch >= lowest) & (y_1ch < upper) & (y_pred_1ch >= lowest) & (y_pred_1ch < upper)
        pair_index = (y_1ch[counted] - lowest) * self.n_classes + (y_pred_1ch[counted] - lowest)
        pair_counts = torch.bincount(pair_index, minlength = self.n_classes ** 2)
        pair_counts = pair_counts.reshape(self.n_classes, self.n_classes)
        self.matrix[section_id] += pair_counts.to(self.matrix[section_id].dtype)

        return self.matrix

    def aggregate(self):
        """Return the reduced matrices per section plus their ``'total'``.

        The accumulated state is left untouched, so the method can be called
        repeatedly and accumulation can continue afterwards.

        Returns:
            dict mapping every section id, and ``'total'`` (sum over the
            sections), to a matrix reduced according to ``self.reduction``.
        """
        # Snapshot of the section ids: 'total' must not be summed into itself.
        section_ids = list(self.matrix)

        total = torch.zeros(self.n_classes, self.n_classes).to(device)
        for section_id in section_ids:
            total += self.matrix[section_id]

        # Clones keep later accumulation from changing the returned matrices.
        result = {section_id: self.matrix[section_id].clone() for section_id in section_ids}
        result['total'] = total

        if 'mean' in self.reduction:
            for key in result:
                result[key] = result[key] / self.cumul_matrices
                if self.reduction == 'norm_mean':
                    # Row normalisation: a class that never occurs gives 0/0 (NaN) rows.
                    result[key] = result[key] / result[key].sum(dim = 1, keepdims = True)
        return result

    def reset(self):
        """Drop every accumulated matrix and the call counter."""
        self.matrix = {}
        self.cumul_matrices = 0

## Definindo Dataset e DataLoader para treino e validação

In [None]:
n_gpus = 1# if not multigpu else len(model.device_ids)

# Training dataset: SmartCacheDataset (cache with a replacement rate) when
# smart_cache is set, plain CacheDataset otherwise.
if smart_cache:
    train_ds = SmartCacheDataset(data = train_files, transform = train_transforms, cache_rate = dataset_cache_rate,
                                 replace_rate = dataset_replace_rate)
else:
    train_ds = CacheDataset(data = train_files, transform = train_transforms, cache_rate = dataset_cache_rate)

# Batches of 16 items per GPU. With sample_weights, a WeightedRandomSampler
# (with replacement, num_samples draws) picks the items and shuffle is off;
# without them the loader shuffles and uses no sampler.
train_loader = DataLoader(train_ds, batch_size=16*n_gpus, shuffle = sample_weights is None,
                         sampler = WeightedRandomSampler(sample_weights, num_samples, replacement = True) if sample_weights is not None else None)

# Validation data is fully cached (cache_rate=1.0) and loaded one item per GPU.
val_ds = CacheDataset(
    data=val_files, transform=val_transforms, cache_rate=1.0)

val_loader = DataLoader(val_ds, batch_size=n_gpus)#,
                         #sampler = torch.utils.data.distributed.DistributedSampler(val_ds))#, num_workers=4)

# Bare expression: in a notebook this displays the sampler in use.
train_loader.sampler

from monai.networks.nets import FlexibleUNet

# FlexibleUNet variant: pretrained EfficientNet-B0 backbone with a 2-D
# decoder. The input has 3 channels when pp_only is set, 6 otherwise, and
# there is one output channel per class.
UNet_metadata = {
    'in_channels': 3 if pp_only else 6,
    'out_channels': n_classes,
    'backbone': 'efficientnet-b0',
    'pretrained': True,
    'spatial_dims': 2,
    'decoder_channels': (256, 128, 64, 32, 16),
    'norm': Norm.BATCH,
    'act': 'prelu',
    'decoder_bias': True,
}
model = FlexibleUNet(**UNet_metadata)
model = model.to(device)

from monai.networks.nets import UNETR

# UNETR variant: 2-D network sized for crop_spatial_size inputs, with batch
# normalisation. The input has 3 channels when pp_only is set, 6 otherwise.
in_channels = 3 if pp_only else 6
UNet_metadata = {
    'in_channels': in_channels,
    'out_channels': n_classes,
    'img_size': crop_spatial_size,
    'spatial_dims': 2,
    'norm_name': Norm.BATCH,
}
model = UNETR(**UNet_metadata)
model = model.to(device)

In [None]:
# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
##device = torch.device("cuda:1")

#import process
import torch.multiprocessing as mp

#def gpu_process(rank, world_size):
#    os.environ['MASTER_ADDR'] = 'localhost'
#    os.environ['MASTER_PORT'] = '8891'
#    torch.distributed.init_process_group('nccl', world_size = world_size, rank = rank)

# Residual 2-D UNet: five levels (16 to 256 channels), stride 2 between
# levels, two residual units per level and batch normalisation.
UNet_metadata = dict(
    spatial_dims=2,
    in_channels=(3 if pp_only else 6),
    out_channels=n_classes, # class probability in each channel
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),#
    #channels=(8, 16, 32, 64, 128, 256, 512, 1024),
    #strides=7*(2,),#
    num_res_units=2,
    norm=Norm.BATCH,
    #kernel_size = 9,
    #up_kernel_size = 9
)

#model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True).to(device)
#model.backbone.conv1 = torch.nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False).to(device)
model = UNet(**UNet_metadata).to(device)

# Optional multi-GPU setup: spawn one process per GPU and wrap the model in
# DistributedDataParallel.
# NOTE(review): `process` is not imported in the visible code (its import and
# the gpu_process definition are commented out), so this branch would raise
# NameError unless the name is defined elsewhere - confirm before enabling.
if multigpu:
    world_size = 4
    #print('Bf spw')
    mp.spawn(process.gpu_process, args = (world_size,), nprocs = world_size, join = True)
    #print('Af spw')
    #torch.distributed.init_process_group('nccl', world_size = 4, rank = 0)
    model = torch.nn.parallel.DistributedDataParallel(model)
    #print('Af dist')
    #model = torch.nn.parallel.DataParallel(model)
# Active loss: Dice loss on softmax outputs against one-hot encoded labels.
# loss_type is the label under which the loss and metric are tracked.
# The commented blocks below are alternative losses that were tried, each
# with its experiment name.
loss_function = DiceLoss(include_background=include_background['loss'], to_onehot_y=True, softmax=True)
loss_type = "DiceLoss"

#loss_function = DiceCELoss(include_background=include_background['loss'], to_onehot_y=True, softmax=True,
#                          ce_weight = torch.cat((torch.Tensor([0]), torch.Tensor(1/train_proportions))).to(device),
#                          #ce_weight = torch.Tensor([0, 1, 1, 0, 1, 1, 0]).to(device),
#                          lambda_dice = 0)
#loss_type = "DiceCELoss"

##experiment = 'cbn-focal-instnorm
#loss_function = FocalLoss(include_background=include_background['loss'], to_onehot_y=True,
#                         weight = torch.Tensor(1/train_proportions)).to(device)
#loss_type = "FocalLoss"
#experiment = 'cbn-dicefocal-fulldata-lessfocal'
#loss_function = DiceFocalLoss(include_background=include_background['loss'], to_onehot_y=True,
#                         focal_weight = torch.Tensor(1/train_proportions), lambda_focal = 0.02).to(device)
#loss_type = "DiceFocalLoss"
###experiment = 'cbn-tversky'
##loss_function = TverskyLoss(include_background=include_background['loss'], to_onehot_y=True, softmax=True)
##loss_type = "TverskyLoss"
####max_epochs = 10
#experiment = 'cbn-gendice-fulldata-slower'
#loss_function = GeneralizedDiceLoss(include_background=include_background['loss'], to_onehot_y=True, softmax=True)
#loss_type = "GeneralizedDiceLoss"
#dist_matrix = torch.Tensor(
#    [
#        [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
#        [1.0, 0.0, 0.02, 0.3, 1.0, 0.02, 0.9],
#        [1.0, 0.02, 0.0, 0.3, 1.0, 0.02, 0.9],
#        [1.0, 0.3, 0.3, 0.0, 1.0, 0.3, 0.9],
#        [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0],
#        [1.0, 0.02, 0.02, 0.3, 1.0, 0.0, 0.9],
#        [1.0, 0.9, 0.9, 0.9, 1.0, 0.9, 0.0]
#    ]
#).to(device)
#experiment = 'cbn-gwd-fulldata'
#loss_function = GeneralizedWassersteinDiceLoss(dist_matrix)
#loss_type = "GeneralizedWassersteinDiceLoss"

# Adam optimizer with learning rate 1e-4 over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
#optimizer = torch.optim.SGD(model.parameters(), lr = 1e-4, momentum = 0.9, nesterov = True)
# Evaluation metrics: mean Dice, row-normalised confusion matrix and class proportions.
dice_metric = DiceMetric(include_background=include_background['metric'], reduction="mean")
confusion_metric = ConfusionMatrix(n_classes=n_classes, include_background=include_background['metric'], reduction='norm_mean')
proportion_metric = Proportion(n_classes=n_classes, include_background=include_background['metric'], argmax=True) #torch.nn.KLDivLoss(reduction = 'mean')

print(loss_type)

# Collect the hyper-parameters of every optimizer param group, leaving out
# the entries whose key contains 'params' (the parameter tensors themselves).
Optimizer_metadata = {}
for ind, param_group in enumerate(optimizer.param_groups):
    optim_meta_keys = list(param_group.keys())  # not used in the visible code
    Optimizer_metadata[f'param_group_{ind}'] = {key: value for (key, value) in param_group.items() if 'params' not in key}

## Treino e validação em PyTorch

In [None]:
def plot(image, title):
    """Show ``image`` on the current axes under ``title``, with the axis hidden."""
    plt.axis('off')
    plt.title(title)
    plt.imshow(image)

def tensor_to_image(tensor):
    """Convert a tensor to a numpy array ready for plotting.

    A 3-D tensor has its first dimension moved to the last position;
    tensors of any other rank are returned with their layout unchanged.
    """
    array = tensor.cpu().numpy()
    return np.moveaxis(array, 0, 2) if tensor.ndim == 3 else array

def get_comparison(input_tensor, label_tensor, pred_tensor, confidence_map, step, show = False):
    """Render a 2x2 panel: original, label, prediction and confidence map.

    In binary mode (global ``binary``) the label and prediction panels are
    the input blended with colour overlays: green where the prediction and
    the label agree on the pore class, red for false positives and yellow
    for false negatives. Otherwise the label and prediction tensors are
    averaged with the input.

    Args:
        input_tensor: channel-first input image.
        label_tensor: label map (binary mode) or image to blend with the input.
        pred_tensor: prediction, in the same form as ``label_tensor``.
        confidence_map: tensor shown in grey scale in the fourth panel.
        step: not used by the rendering; kept for interface compatibility.
        show: when True the figure is shown, otherwise it is closed.

    Returns:
        The rendered figure as an RGBA numpy array.
    """
    if binary:
        #input_label = torch.where(label_tensor.to(input_tensor.device) == new_pore_label, (input_tensor + 1)/2, input_tensor)
        #input_pred  = torch.where(pred_tensor.to(input_tensor.device)  == 1,              (input_tensor + 1)/2, input_tensor)
        label_tensor = label_tensor.to(input_tensor.device)
        pred_tensor  = pred_tensor.to(input_tensor.device)

        # Overlay colours, allocated on the input's device so that the
        # blending below also works when the input is not on the CPU.
        true_positives = torch.zeros(input_tensor.shape, device = input_tensor.device)
        true_positives[1] = 1       # green
        false_positives = torch.zeros(input_tensor.shape, device = input_tensor.device)
        false_positives[0] = 1      # red
        false_negatives = torch.zeros(input_tensor.shape, device = input_tensor.device)
        false_negatives[:2] = 1     # red + green = yellow

        predicted_pore     = pred_tensor == 1
        predicted_non_pore = pred_tensor == 0
        agrees             = pred_tensor == label_tensor

        input_label = torch.where(label_tensor == new_pore_label, (input_tensor + true_positives)/2, input_tensor)
        input_pred  = torch.where(predicted_pore & agrees,       (input_tensor + true_positives)/2,  input_tensor)
        input_pred  = torch.where(predicted_pore & ~agrees,      (input_tensor + false_positives)/2, input_pred)
        input_pred  = torch.where(predicted_non_pore & ~agrees,  (input_tensor + false_negatives)/2, input_pred)
    else:
        #input_label = label_tensor
        #input_pred  =  pred_tensor
        input_label = (input_tensor + label_tensor)/2
        input_pred  = (input_tensor +  pred_tensor)/2

    input_image = tensor_to_image(input_tensor)
    input_label = tensor_to_image(input_label)
    input_pred  = tensor_to_image(input_pred)
    input_conf  = tensor_to_image(confidence_map)

    #frames.append(plot(input_image, 'original'))
    #frames.append(plot(input_label, 'label'))
    #frames.append(plot(input_image, 'original'))
    #frames.append(plot(input_pred,  'prediction'))

    #gif.save(frames, 'foo.gif', duration = 2000)
    #Image.open('foo.gif').show()

    fig = plt.figure(figsize = (16, 16))
    plt.subplot(2, 2, 1), plt.title('ORIGINAL'),   plt.axis('off'), plt.imshow(input_image)
    plt.subplot(2, 2, 2), plt.title('LABEL'),      plt.axis('off'), plt.imshow(input_label)
    plt.subplot(2, 2, 3), plt.title('PREDICTION'),   plt.axis('off'), plt.imshow(input_pred)
    # Empty line plots serve only as legend handles for the overlay colours.
    # NOTE(review): the legend is drawn in non-binary mode as well, where no
    # such overlay is applied - confirm whether that is intended.
    plt.legend(
        handles = [
            plt.plot([], label = 'True positives', color = 'green')[0],
            plt.plot([], label = 'False positives', color = 'red')[0],
            plt.plot([], label = 'False negatives', color = 'yellow')[0]
        ],
        loc = 'best',
        fontsize = 'x-large'
    )
    plt.subplot(2, 2, 4), plt.title('CONFIDENCE'), plt.axis('off'), plt.imshow(input_conf, cmap = 'gray')
    #plt.subplot(2, 2, 4), plt.title('FALSES'), plt.axis('off'), plt.imshow(input_false, cmap = 'gray')

    # Rasterise the figure and copy its RGBA buffer through the canvas API.
    fig.canvas.draw()
    output = np.array(fig.canvas.buffer_rgba())
    if show:
        #os.makedirs('images', exist_ok = True)
        #plt.savefig(os.path.join('images', str(step) + '.png'))
        plt.show()
    else:
        plt.close()

    return output

In [None]:
def test(test_loader, best_metric = -1, best_metric_epoch = -1, validation = False):
    def is_step_last_or_multiple_of(step, total_steps, multiple_of):
        return (step % multiple_of == 0) or step == total_steps
    
    if validation:
        print('Validating...')
        prefix = 'val_'
    else:
        print('Testing...')
        prefix = 'test_'
    
    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=n_classes)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=n_classes)])
    
    model.eval()
    test_loss = 0
    step = 0
    time_secs = {'Sliding Window Inference': [], 'Aim Image Logging': [], 'Loss calculation': [], \
                 'Dice metric': [], 'Proportion metric': [], 'Confusion metric': [], \
                 'Dice aggregation': 0, 'Proportion aggregation': 0, 'Confusion aggregation': 0, \
                 'Loading': [], 'Plot': 0, 'Model saving': 0, 'Non-zero counting': [], 'To GPU': [], \
                 'Output and conf. map': [], 'RGB mapping': [], 'Metrics logging': 0, 'Total': 0}
    with torch.no_grad():
        time_secs['Total'] = time.time()
        time_secs['Loading'].append(time.time())
        for index, test_data in enumerate(tqdm(test_loader)):
            time_secs['Loading'][-1] = time.time() - time_secs['Loading'][-1]
            
            test_inputs, test_labels = (
                test_data["image"],##.to(device),
                test_data["label"],##.to(device),
            )
            
            if test_inputs.shape[0] > 1:
                test_inputs = test_inputs[:1, :, :, :]
                test_labels = test_labels[:1, :, :, :]
            
            time_secs['Non-zero counting'].append(time.time())
            background_rate = 1 - torch.count_nonzero(test_data["label"])/torch.numel(test_data["label"])
            time_secs['Non-zero counting'][-1] = time.time() - time_secs['Non-zero counting'][-1]
            if (not binary or isolate_background) and background_rate > test_background_rate_max:
                time_secs['Loading'].append(time.time())
                continue
            #test_inputs, test_labels = test_inputs[0].reshape([1, -1, 512, 512]), test_labels[0].reshape([1, -1, 512, 512])
            time_secs['To GPU'].append(time.time())
            if not multigpu:
                test_inputs = test_inputs.to(device)
                test_labels = test_labels.to(device)
            time_secs['To GPU'][-1] = time.time() - time_secs['To GPU'][-1]
            
            step += 1
            image_path = test_data['image_meta_dict']['filename_or_obj'][0]
            
            roi_size = crop_spatial_size
            sw_batch_size = 4
            time_secs['Sliding Window Inference'].append(time.time())
            test_outputs = sliding_window_inference(
                test_inputs, roi_size, sw_batch_size, model)
            time_secs['Sliding Window Inference'][-1] = time.time() - time_secs['Sliding Window Inference'][-1]

            # tracking input, label and output images with Aim
            time_secs['Output and conf. map'].append(time.time())
            output_first_channel = int(isolate_background) # if isolate_background, bg's probs (output's channel 0) are not considered
            output = torch.argmax(test_outputs[:, output_first_channel:], dim=1)[0].float()
            confidence_map = torch.max(test_outputs[:, output_first_channel:], dim=1)[0].float()
            time_secs['Output and conf. map'][-1] = time.time() - time_secs['Output and conf. map'][-1]
            
            if is_step_last_or_multiple_of(step, total_steps = len(test_loader), multiple_of = 50):
                time_secs['RGB mapping'].append(time.time())
                if binary:
                    test_labels_img = test_labels[0]/(n_classes - 1)
                    output_img = output/(n_classes - 1 - output_first_channel)
                else:
                    test_labels_img = generate_rgb_map(test_labels[0].cpu(), is_channel_first = True, as_tensor = True)
                    output += output_first_channel
                    #output = torch.where(
                    #    test_labels[0] == 0,
                    #    torch.zeros(output.shape, dtype = output.dtype).to(output.device),
                    #    output)
                    output_img = generate_rgb_map(output.cpu(), is_channel_first = output.ndim == 3, as_tensor = True)
                time_secs['RGB mapping'][-1] = time.time() - time_secs['RGB mapping'][-1]
                
                if not show_output_comparison:
                    time_secs['Aim Image Logging'].append(time.time())
                    aim_run.track(aim.Image(test_labels_img, \
                                            caption=f'Label Image: {index}'), \
                                   name='validation', context={'type':'label'})
                    aim_run.track(aim.Image(output_img, caption=f'Predicted Label: {index}'), \
                                   name = 'validation', context={'type':'prediction'})
                    aim_run.track(aim.Image(test_inputs[0, :3], \
                                            caption=f'Input Image: {index}'), \
                                   name='validation', context={'type':'input'})
                    if UNet_metadata['in_channels'] == 6:
                        aim_run.track(aim.Image(test_inputs[0, 3:], \
                                            caption=f'Input Image PX: {index}'), \
                                   name='validation', context={'type':'input_PX'})
                    aim_run.track(aim.Image(confidence_map/confidence_map.max(), \
                                            caption=f'Input Image: {index}'), \
                                   name='validation', context={'type':'confidence'})

                    if binary:
                        #diff_output_labels = output.to(torch.int8) - test_labels[0].to(torch.int8)
                        input_high_red      = test_inputs[0, :3].clone()
                        input_high_red[0]   = (input_high_red[0] + 1)/2
                        input_high_green    = test_inputs[0, :3].clone()
                        input_high_green[1] = (input_high_green[1] + 1)/2
                        input_high_blue     = test_inputs[0, :3].clone()
                        input_high_blue[2]  = (input_high_blue[2] + 1)/2

                        FN_highlighted       = torch.where(
                            (output ==  0).to(torch.bool) & (test_labels[0] == new_pore_label).to(torch.bool),
                            input_high_blue,  test_inputs[0, :3])
                        FN_TP_highlighted    = torch.where(
                            (output ==  1).to(torch.bool) & (test_labels[0] == new_pore_label).to(torch.bool),
                            input_high_green, FN_highlighted)
                        FN_TP_FP_highlighted = torch.where(
                            (output ==  1).to(torch.bool) & (test_labels[0] == new_non_pore_label).to(torch.bool),
                            input_high_red,   FN_TP_highlighted)

                        aim_run.track(aim.Image(FN_TP_FP_highlighted, \
                                                caption=f'FN (blue); TP (green); FP (red): {index}'), \
                                       name='validation', context={'type':'fn+tp+fp'})

                    time_secs['Aim Image Logging'][-1] = time.time() - time_secs['Aim Image Logging'][-1]
                else:
                    get_comparison(test_inputs[0, :3].cpu(), test_labels_img, output_img, confidence_map, step, show = True)
                    print(step, '===', image_path, 10 * '=')

            # validation loss
            time_secs['Loss calculation'].append(time.time())
            loss = loss_function(test_outputs, test_labels)
            test_loss += loss.item()
            time_secs['Loss calculation'][-1] = time.time() - time_secs['Loss calculation'][-1]

            test_outputs = [post_pred(i) for i in decollate_batch(test_outputs)]
            test_labels = [post_label(i) for i in decollate_batch(test_labels)]
            
            # compute metric for current iteration
            time_secs['Dice metric'].append(time.time())
            dice_metric(y_pred=test_outputs, y=test_labels)
            time_secs['Dice metric'][-1] = time.time() - time_secs['Dice metric'][-1]
            time_secs['Proportion metric'].append(time.time())
            proportion_metric(y_pred=test_outputs, y=test_labels, image_path=image_path)
            time_secs['Proportion metric'][-1] = time.time() - time_secs['Proportion metric'][-1]
            time_secs['Confusion metric'].append(time.time())
            confusion_metric(y_pred=test_outputs, y=test_labels, image_path=image_path)
            time_secs['Confusion metric'][-1] = time.time() - time_secs['Confusion metric'][-1]

            time_secs['Loading'].append(time.time())
        time_secs['Loading'] = time_secs['Loading'][:-1]
        
        # validation loss
        test_loss /= step
        
        #-## sparse: média dos erros calculados por imagem;
        #-## dense:  erro calculado para o conjunto de validação completo, como uma única grande imagem.
        #-##         Baseado no fato de que as proporções de cada elemento em uma imagem completa é a média das proporções
        #-##         em cada subimagem (quando todas as subimagens têm tamanhos iguais).
        #-#dense_proportions = {
        #-#    'val': test_proportions.mean(dim = 0),
        #-#    'out': output_proportions.mean(dim = 0)
        #-#}
        
        # aggregate the final mean dice result
        time_secs['Dice aggregation'] = time.time()
        metric = dice_metric.aggregate().item()
        time_secs['Dice aggregation'] = time.time() - time_secs['Dice aggregation']
        time_secs['Proportion aggregation'] = time.time()
        proportion = proportion_metric.aggregate()
        time_secs['Proportion aggregation'] = time.time() - time_secs['Proportion aggregation']
        time_secs['Confusion aggregation'] = time.time()
        confusion = confusion_metric.aggregate()
        time_secs['Confusion aggregation'] = time.time() - time_secs['Confusion aggregation']
        
        if not show_output_comparison:
            # track val
            aim_run.track(test_loss,  name=(prefix + "loss"),   context={'type':loss_type})
            aim_run.track(metric,     name=(prefix + "metric"), context={'type':loss_type})
            aim_run.track(proportion['total']['R2'], name="proportion_R2", context={'type':'Proportion'})
            aim_run.track(proportion['total']['RMSE'],  name="proportion_linear_RMSE",  context={'type':'Proportion'})

        # reset the status for next validation round
        dice_metric.reset()
        proportion_metric.reset()
        confusion_metric.reset()

        track_custom_metrics = False
        if validation:
            if metric > best_metric:
                time_secs['Model saving'] = time.time()
                track_custom_metrics = True
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_proportion = proportion
                best_confusion = confusion
                os.makedirs(os.path.join(project, 'models'), exist_ok = True)
                torch.save(model.state_dict(), os.path.join(
                    project, 'models', aim_run.name.split()[-1] + '.pth'))
                time_secs['Model saving'] = time.time() - time_secs['Model saving']
                
                best_model_log_message = f"saved new best metric model at the {epoch+1}th epoch"
                aim_run.track(aim.Text(best_model_log_message), name='best_model_log_message', epoch=epoch+1)
                print(best_model_log_message)

            message1 = f"current epoch: {epoch + 1} | " + prefix + f" loss: {test_loss:.4f} | current mean dice: {metric:.4f}"
            message2 = f"\nbest mean dice: {best_metric:.4f} "
            message3 = f"at epoch: {best_metric_epoch}"

            aim_run.track(aim.Text(message1 +"\n" + message2 + message3), name='epoch_summary', epoch=epoch+1)
        else:
            best_proportion = proportion
            best_confusion = confusion
            message1 = prefix + f" loss: {test_loss:.4f}"
            message2 = f"\nmean dice: {metric:.4f} "
            message3 = ''
        
        time_secs['Metrics logging'] = time.time()
        for section_id in proportion.keys():
            metrics_log = {
                'elements': elements,
                prefix + 'proportions': proportion[section_id]['true'].cpu().tolist(),
                'out_proportions': proportion[section_id]['pred'].cpu().tolist(),
                'color_hex': (['#000000'] if 'Desconhecido' in elements else []) + element_data['color_hex'].tolist()
            }
            
            for i, element in enumerate(elements):
                metrics_log['Pred. ' + element] = confusion[section_id][:, i].cpu().tolist()
            model_name = (aim_run.name if (model_to_load is None) else model_to_load).replace('.pth', '')
            if model_name.startswith('Run:'):
                model_name = model_name[5:]
            log_path = os.path.join(project, 'models_log', model_name, section_id)
            os.makedirs(log_path, exist_ok = True)
            pd.DataFrame(metrics_log).to_csv(os.path.join(log_path, 'log.csv'), index = False)
        time_secs['Metrics logging'] = time.time() - time_secs['Metrics logging']
        
        if aim_run is not None:
            aim_run.track(aim.Distribution(proportion['total']['true']), name=(prefix + 'proportion'), context={'type':'Proportion'})
            aim_run.track(aim.Distribution(proportion['total']['pred']), name='out_proportion', context={'type':'Proportion'})
            aim_run.track(aim.Image(confusion['total'], caption='Confusion Matrix'), name='confusion_matrix', context={'type':'ConfusionMatrix'})
        
        print(message1, message2, message3)
        time_secs['Plot'] = time.time()
        fig = plot_dense_proportions(elements, proportion['total'])
        #aim_run.track(aim.Figure(fig), name='proportions', context={'type':'Proportion'})
        time_secs['Plot'] = time.time() - time_secs['Plot']
        time_secs['Total'] = time.time() - time_secs['Total']
        
        print('Time:')
        for key in time_secs:
            if type(time_secs[key]) == list:
                time_to_show = np.sum(time_secs[key]).astype(int)
            else:
                time_to_show = time_secs[key]
            print('===', key + ':', time.strftime('%Hh%Mm%Ss', time.gmtime(time_to_show)), end = ' ')
        print()
        
        return best_metric, best_metric_epoch

# Substrings identifying the images to evaluate on.
test_images = ['5716.70', '5607.35', '5639.30']

# Keep every validation entry whose image path contains at least one of the
# substrings above.
test_files = [
    entry
    for entry in val_files
    if any(tag in entry['image'] for tag in test_images)
]

# Fully cached dataset (cache_rate = 1.0), fed to the loader one sample at a time.
test_dataset = CacheDataset(data = test_files, transform = val_transforms, cache_rate = 1.0)
test_loader = DataLoader(test_dataset, batch_size = 1)

# Evaluate on the selected subset.
test(test_loader)

In [None]:
# Training / evaluation driver cell.
#
# * When `model_to_load` is None a new model is trained for `max_epochs`
#   epochs, optionally validating every `val_interval` epochs through `test()`.
# * Otherwise the weights stored in `model_to_load` are loaded and the model is
#   evaluated once on `val_loader`.
#
# `aim_run` stays None only when a saved model is loaded AND
# `show_output_comparison` is set.
aim_run = None
if model_to_load is None or not show_output_comparison:
    # initialize a new Aim Run
    #experiment = 'sections={train=' + train_sections + ',val=' + val_sections + '}' \
    #                  + '_usePolarizedLight=' + str(not pp_only) \
    #                  + '_PoreVsNonPore=' + str(binary)
    aim_run = aim.Run(experiment = experiment, repo = os.path.join(project, 'logging'))

train_time_secs = 0
best_metric = -1
best_metric_epoch = -1
if model_to_load is None:
    epoch_loss_values = []

    # log model metadata
    aim_run['UNet_metadata'] = UNet_metadata
    # log optimizer metadata
    aim_run['Optimizer_metadata'] = Optimizer_metadata

    aim_run['dataset'] = dict(dataset_log)

    # log the run configuration
    aim_run['n_epochs'] = max_epochs
    aim_run['img_size'] = img_size
    aim_run['crop_spatial_size'] = crop_spatial_size
    aim_run['num_crops_per_img'] = num_crops_per_img
    aim_run['binary'] = binary
    aim_run['use_polarized_light'] = not pp_only
    aim_run['shrank'] = dict(shrank)
    aim_run['isolate_background'] = isolate_background
    aim_run['include_background'] = dict(include_background)
    aim_run['weight_samples'] = dict(weight_samples)
    aim_run['balance_by_oversampling'] = balance_by_oversampling

    # log the transform pipelines (stored as strings)
    log_train_transforms = log_transform(train_transforms)
    log_val_transforms   = log_transform(val_transforms)
    aim_run['train_transforms'] = dict(functions = str(list(log_train_transforms.keys())), params = str(log_train_transforms))
    aim_run['val_transforms']   = dict(functions = str(list(log_val_transforms.keys())),   params = str(log_val_transforms))

    # Number of batches per epoch, used for progress reporting only.
    n_batches = len(train_ds) // train_loader.batch_size
    # Print progress roughly ten times per epoch. The interval is clamped to 1
    # so that epochs with fewer than 10 batches do not trigger a modulo by zero.
    log_every = max(1, n_batches // 10)

    for epoch in range(max_epochs):
        start_time_secs = time.time()
        #if epoch == 5:
        #    val_interval = 100
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = (
                batch_data["image"],#.to(device),
                batch_data["label"],#.to(device),
            )

            #plt.imshow(generate_rgb_map(labels[0], is_channel_first = True))
            #plt.show()

            # Without multi-GPU the batch is moved to the target device here.
            # NOTE(review): in the multi-GPU case the model wrapper presumably
            # places the inputs itself - confirm.
            if not multigpu:
                inputs = inputs.to(device)
                labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            # labels are moved to wherever the outputs ended up
            loss = loss_function(outputs, labels.to(outputs.device))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            if step % log_every == 0:
                print(f"{step}/{n_batches}, "
                    f"train_loss: {loss.item():.4f}")
            # track batch loss metric
            aim_run.track(loss.item(), name="batch_loss", context={'type':loss_type})

        # mean loss over the batches of this epoch
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)

        # track epoch loss metric
        aim_run.track(epoch_loss, name="epoch_loss", context={'type':loss_type})

        epoch_time_secs  = time.time() - start_time_secs
        train_time_secs += epoch_time_secs

        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        print('=== Epoch time:', time.strftime('%Hh%Mm%Ss', time.gmtime(epoch_time_secs)), end = ' ')
        print('=== Total time:', time.strftime('%Hh%Mm%Ss', time.gmtime(train_time_secs)), end = ' ')
        # ETA = remaining epochs x duration of the epoch that just finished
        print('=== ETA:', time.strftime('%Hh%Mm%Ss', time.gmtime((max_epochs - (epoch + 1)) * epoch_time_secs)))

        if val_interval is not None:
            if (epoch + 1) % val_interval == 0:
                # Parameter/gradient distributions are tracked on every
                # second validation round only, because this step is slow.
                if (epoch + 1) % (val_interval * 2) == 0:
                    print('Tracking parameters and gradients...')
                    # track model params and gradients
                    track_params_dists(model,aim_run)
                    track_gradients_dists(model, aim_run)

                # test() returns the (possibly updated) best metric and its epoch
                best_metric, best_metric_epoch = test(
                    test_loader = val_loader, best_metric = best_metric, best_metric_epoch = best_metric_epoch,
                    validation = True
                )

    # Without periodic validation test() never runs, so the final weights are
    # saved here instead.
    if val_interval is None:
        if aim_run is not None:
            model_name = aim_run.name.split()[-1]
        else:
            from datetime import datetime
            model_name = experiment + '_' + str(datetime.now())

        # Make sure the output directory exists before writing into it.
        os.makedirs(os.path.join(project, 'models'), exist_ok = True)
        torch.save(model.state_dict(), os.path.join(project, 'models', model_name + '.pth'))

else:
    # Evaluate a previously saved model.
    model.load_state_dict(
        torch.load(os.path.join(project, 'models', model_to_load), map_location = device)
    )
    if not show_output_comparison:
        aim_run['model'] = model_to_load
    test(val_loader)

In [None]:
# Close the Aim run if one was opened; aim_run is None when no run was created.
if aim_run is not None:
    # finalize Aim Run
    aim_run.close()

## Aim UI diretamente no notebook

%load_ext aim
%aim up