In [None]:
import os
import re
import cv2
import glob
import nrrd
import shutil
import datetime
import warnings
import numpy as np
import pandas as pd
import nibabel as nib
import tifffile as tif
import matplotlib.pyplot as plt

from tqdm import tqdm
from netCDF4 import Dataset
from skimage.transform import resize as sk_resize
from matplotlib.colors import Normalize
from numpy.lib.stride_tricks import sliding_window_view

In [None]:
def plot(title, image):
    """Draw `image` on the current pyplot axes under the heading `title` and show the figure."""
    plt.gca().set_title(title)
    plt.imshow(image)
    plt.show()

In [None]:
def gen_nc_file_name(group, value):
    """Return the .nc file name of a sample.

    Files of the 'Siliciclastics' group are named '<value>.nc'; files of every other group
    are named '<group>_<value>.nc'.
    """
    stem = value if group == 'Siliciclastics' else group + '_' + value
    return stem + '.nc'

In [None]:
def remap(image, unified_labels, type):
    """Turn a label image into an RGB image using the colours of the unified label table.

    Parameters:
        image: label array; it is stacked three times along axis 2 to form the RGB output.
        unified_labels: DataFrame indexed by label value with 'type' and 'color_hex' columns.
        type: only rows whose 'type' column equals this value are used.

    Returns:
        Float array in which every pixel whose label appears in the selected rows holds that
        row's colour scaled to [0, 1]. Pixels with any other value keep the stacked input value.
    """
    # Work in float: with an integer `image` the [0, 1] colours would be truncated to 0.
    rgb = np.stack((image, image, image), axis = 2).astype(np.float64)

    # Read the colour from the rows of the requested type only. Label values can repeat across
    # types, in which case a lookup on the whole table returns several rows instead of one colour.
    labels_of_type = unified_labels[unified_labels['type'] == type]
    for label_value, color in labels_of_type['color_hex'].items():
        channels = [int(color[1:3], 16), int(color[3:5], 16), int(color[5: ], 16)]
        rgb[np.where(image == label_value)] = np.array([channels]) / 255.0

    return rgb

def imshow_mini(image, axis, title, unified_labels, remap_colors = False, cmap = None):
    """Show a 1/10-scale thumbnail of `image` on `axis`, titled `title`, without axis ticks.

    When `remap_colors` is set the image is first coloured by `remap`, using the 'ec' rows of
    `unified_labels` if the title contains 'E.C.' and the 'qemscan' rows otherwise.
    """
    axis.set_title(title)
    axis.set_xticks([])
    axis.set_yticks([])

    if remap_colors:
        label_type = 'ec' if 'E.C.' in title else 'qemscan'
        shown = remap(image, unified_labels, type = label_type)
    else:
        shown = image

    thumbnail = cv2.resize(shown, (0, 0), fx = 1/10, fy = 1/10)
    axis.imshow(thumbnail, cmap = cmap)

In [None]:
def plot_proportions(proportions, unified_labels, title = ''):
    """Draw a bar chart of element proportions, save it and show it.

    Parameters:
        proportions: mapping with an `.index` of element names (e.g. a pandas Series) whose
            values are the bar heights.
        unified_labels: label table; the 'color_hex' of the 'qemscan' row matching each element
            is used as the bar colour.
        title: chart title, also used as the file name ('proportions/<title>.png').

    'Desconhecido' and any element missing from the table are drawn in black. White bars get a
    black edge so they remain visible.
    """
    os.makedirs('proportions', exist_ok = True)
    unified_minerals = unified_labels[unified_labels['type'] == 'qemscan']

    for element in proportions.index:
        matches = unified_minerals.loc[unified_minerals['Element'] == element, 'color_hex'].values
        # Reduce the lookup to one plain string: `matches` is an array and is empty for
        # 'Desconhecido', and truth-testing an empty array is deprecated / an error in NumPy.
        if element == 'Desconhecido' or len(matches) == 0:
            bar_color = 'black'
        else:
            bar_color = str(matches[0])

        plt.bar(x = element, height = proportions[element],
                color = bar_color,
                edgecolor = 'black' if bar_color == '#ffffff' else None)
    plt.title(title)
    plt.xticks(rotation = 'vertical')
    plt.ylabel('Proportion')
    plt.savefig(os.path.join('proportions', title + '.png'))
    plt.show()

In [None]:
def get_ec_color(ec_name):
    """Return a random colour as a '#rrggbb' string.

    Three values are drawn from numpy's global random state, one per channel. `ec_name` is
    accepted but not used.
    """
    channels = np.random.randint(256, size = 3)
    # Two lowercase hex digits per channel, zero padded.
    return '#' + ''.join('{:02x}'.format(int(channel)) for channel in channels)

In [None]:
def get_ec_label(key):
    """Normalise a variable name into an EC class label.

    Steps, in order: drop one trailing 's' from every whitespace-separated word; if the second
    word is '-', keep only the last word; restore the accents of 'atico', 'oide' and 'icula';
    capitalise the result; strip a trailing '_<digits>' suffix.
    """
    singular_words = [word[:-1] if word.endswith('s') else word for word in key.split()]
    label = ' '.join(singular_words)

    parts = label.split()
    if len(parts) > 1 and parts[1] == '-':
        label = parts[-1]

    for plain, accented in (('atico', 'ático'), ('oide', 'óide'), ('icula', 'ícula')):
        label = label.replace(plain, accented)
    label = label.capitalize()

    # The match can only start at the last underscore, so cutting at its start drops the suffix.
    numbered_tail = re.search('_[0-9]*$', label)
    if numbered_tail:
        label = label[:numbered_tail.start()]
    return label

def correct(element_name, is_ec = False):
    """Return the canonical spelling of an element name.

    Two names whose accented letters were lost in the source data are mapped back to their
    proper spelling. Any other name is passed through `get_ec_label` when `is_ec` is set and
    returned unchanged otherwise.
    """
    known_repairs = (
        ('�xido de Tit�nio', 'Óxido de Titânio'),
        ('Zirc�o', 'Zircão'),
    )
    for broken, repaired in known_repairs:
        if element_name == broken:
            return repaired

    if is_ec:
        return get_ec_label(element_name)
    return element_name

def unify_labels(include_qemscan, include_ec, data_dir, groups, from_nc, unified_labels_file, initial_labels = None):
    """Collect every label used by the dataset into one table and write it to a CSV file.

    Parameters:
        include_qemscan: collect the labels declared by QEMSCAN variables (type 'qemscan').
        include_ec: collect one label per EC variable (type 'ec'), each with a random colour.
        data_dir: dataset root; files are read from '<data_dir>/<group>/<nc file name>'.
        groups: mapping from group name to an iterable of depths.
        from_nc: read the labels from .nc files. Otherwise the table holds only 'Poros'.
        unified_labels_file: path of the CSV file to write.
        initial_labels: optional mapping from label type to the first index of that type
            (1 when absent). The mapping is not modified.

    Raises:
        Exception: a QEMSCAN element already uses the colour reserved for pores.
        AssertionError: an element name or a colour appears more than once in the table.
    """
    pore_new_color_hex = '#636363'

    elements = {'Element': [], 'type': [], 'color_hex': []}
    if from_nc:
        for group in groups:
            for depth in groups[group]:
                nc_path = os.path.join(data_dir, group, gen_nc_file_name(group, depth))
                print('Label unification: reading from', nc_path)
                data = Dataset(nc_path, 'r')
                # try/finally so the file is released even if one of its variables is malformed.
                try:
                    for key in data.variables.keys():
                        if any(key.startswith(coord) for coord in ['c_', 'x_', 'y_', 'z_']):
                            continue
                        is_qemscan = ('QEMSCAN' in key) or (('Q' in key) and ('Transformed' in key))
                        is_ec = all(light not in key for light in ['PP', 'PX']) \
                            and (key not in ['c', 'x', 'y', 'z']) and (not key.startswith('SOI'))\
                            and not is_qemscan

                        if include_qemscan and is_qemscan:
                            # Each entry is 'name,label,color_hex[,...]'; the first entry is skipped.
                            for color_info in data.variables[key].labels[1:]:
                                color_info = color_info.split(',')
                                element_name, color_hex = color_info[0], color_info[2]
                                # Compare the colour field itself (not the whole split list), so
                                # that a clash with the colour reserved for pores is detected.
                                if color_hex == pore_new_color_hex:
                                    raise Exception('Element ' + element_name + ' is represented by the color hex ' + color_hex + \
                                                    ', which would be used to replace pore\'s zero color hex.')
                                if color_hex == '#000000':
                                    color_hex = pore_new_color_hex
                                elements['Element'].append(correct(element_name))
                                elements['type'].append('qemscan')
                                elements['color_hex'].append(color_hex)

                        elif include_ec and is_ec:
                            ec_name = get_ec_label(key)
                            if ec_name not in elements['Element']:
                                elements['Element'].append(ec_name)
                                elements['type'].append('ec')
                                elements['color_hex'].append(get_ec_color(ec_name))
                finally:
                    data.close()
    else:
        elements['Element'].append('Poros')
        # Every column needs an entry or the DataFrame cannot be built.
        # NOTE(review): 'qemscan' is assumed because the .nc path files pores under that type - confirm.
        elements['type'].append('qemscan')
        elements['color_hex'].append(pore_new_color_hex)

    unified_labels = pd.DataFrame(elements)
    unified_labels = unified_labels.drop_duplicates()

    unified_labels = unified_labels.sort_values(by = ['type', 'Element'])

    # Number the rows type by type; the table is sorted by type, so the ranges line up with it.
    start_indexes = {} if initial_labels is None else dict(initial_labels)
    indexes = []
    for label_type in unified_labels['type'].unique():
        start_index = start_indexes.setdefault(label_type, 1)
        n_labels = unified_labels[unified_labels['type'] == label_type].shape[0]
        indexes += list(range(start_index, n_labels + start_index))

    unified_labels.index = indexes

    print(unified_labels)

    repeated_elements = unified_labels.loc[unified_labels['Element'].duplicated(), 'Element']
    assert repeated_elements.empty,\
        'Element ' + ', '.join(map(str, repeated_elements)) + ' appears more than once.'
    repeated_colors = unified_labels.loc[unified_labels['color_hex'].duplicated(), 'color_hex']
    assert repeated_colors.empty,\
        'Color ' + ', '.join(map(str, repeated_colors)) + ' appears more than once.'

    unified_labels.to_csv(unified_labels_file)

    print('Labels unified successfully!')
        
def convert_labels_to_unified(segmented_data, unified_labels, element = 'all', is_ec = False, split_instances = False):
    """Rewrite the labels of a segmentation variable as indexes of the unified label table.

    Parameters:
        segmented_data: variable whose first item (`segmented_data[0]`) is the label array and
            whose `labels` attribute lists 'name,label,...' strings (the first entry is skipped).
        unified_labels: table indexed by unified label with an 'Element' column.
        element: 'all' to translate every declared label through its own name; otherwise the
            element name assigned to every non-zero pixel.
        is_ec: forwarded to `correct` when normalising element names.
        split_instances: with a single `element`, return the raw label array untouched.

    Returns:
        Float array of unified labels (0 where nothing matched), or the raw label array when
        `split_instances` is set for a single element. For a single element the result has its
        singleton axes removed.

    Raises:
        ValueError: pixels must be relabelled but the element is missing from, or repeated in,
            `unified_labels`.
    """
    def unified_index_of(name):
        # Resolve to one scalar so the assignment never depends on broadcasting a pandas Index.
        corrected = correct(name, is_ec)
        matches = unified_labels[unified_labels['Element'] == corrected].index
        if len(matches) != 1:
            raise ValueError('Element ' + str(corrected) + ' has ' + str(len(matches)) +
                             ' entries in the unified labels (expected exactly one).')
        return matches[0]

    label = np.ma.getdata(segmented_data[0])

    if element == 'all':
        adapted_label = np.zeros(label.shape)
        for color_info in segmented_data.labels[1:]:
            element_name, elem_label = color_info.split(',')[:2]
            pixels = np.where(label == int(elem_label))
            # Labels that are declared but never drawn need no lookup.
            if pixels[0].size > 0:
                adapted_label[pixels] = unified_index_of(element_name)
        return adapted_label

    if split_instances:
        return label

    # Squeeze the input together with the output: indexing a squeezed output with the
    # coordinates of an unsqueezed label fails as soon as the label has a singleton axis.
    label = np.squeeze(label)
    adapted_label = np.zeros(label.shape)
    pixels = np.where(label != 0)
    if pixels[0].size > 0:
        adapted_label[pixels] = unified_index_of(element)

    return adapted_label


def calculate_minerals_proportions(qemscan, unified_labels):
    """Return the fraction of `qemscan` pixels taken by each mineral.

    The result is a dict: 'Desconhecido' maps to the fraction of pixels equal to 0, and the
    'Element' of each 'qemscan' row of `unified_labels` maps to the fraction of pixels equal
    to that row's index.
    """
    minerals = unified_labels[unified_labels['type'] == 'qemscan']
    total_pixels = qemscan.size

    qemscan_proportions = {'Desconhecido': np.count_nonzero(qemscan == 0) / total_pixels}
    for label_value, mineral_name in zip(minerals.index, minerals['Element']):
        qemscan_proportions[mineral_name] = np.count_nonzero(qemscan == label_value) / total_pixels

    return qemscan_proportions

def calculate_general_proportions(proportions_list, weights = None):
    """Combine several proportion records into a single pandas Series.

    Parameters:
        proportions_list: list of records (dicts or Series) mapping element name to proportion.
        weights: optional sequence with one weight per record.

    Returns:
        The per-element mean when `weights` is None; otherwise the weighted sum of the records
        divided by the sum of the weights.
    """
    proportions_table = pd.DataFrame(proportions_list)
    if weights is None:
        return proportions_table.mean()

    result = pd.Series(0.0, index = proportions_table.columns)
    for weight, proportion in zip(weights, proportions_list):
        # Wrap each record so plain dicts, which the unweighted path already accepts, can be
        # scaled too; Series records go through unchanged.
        result = result + weight * pd.Series(proportion)
    return result / np.sum(weights)

In [None]:
# single_label means "every segment of the variable belongs to one and the same class".
# Examples:
# * QEMSCAN is a single variable holding one segment per mineral phase, so each segment belongs
#   to a different class and single_label = False.
# * An EC class holds many segments, but all of them belong to that class, so single_label = True.
def process_labels(data, variable_name, unified_labels, single_label = False, is_ec = False, split_instances = False):
    """Convert the segmentation stored in `data.variables[variable_name]` to unified labels.

    With `single_label` the variable name itself is used as the element name of every segment;
    otherwise every label declared by the variable is translated ('all').
    """
    target_element = variable_name if single_label else 'all'
    return convert_labels_to_unified(data.variables[variable_name], unified_labels, element = target_element,
                                     is_ec = is_ec, split_instances = split_instances)

In [None]:
def load_and_preprocess(group, depth, data_dir, from_nc, bg_thresh, show, unified_labels, missing_nodes_allowed = None, crop_soi_area = True, split_ec_instances = False):
    """Load one sample, optionally preview it, restrict it to its SOI and return its arrays.

    With `from_nc` the sample is read from '<data_dir>/<group>/<nc file name>':
      * variables containing 'PP' / 'PX' in their name are concatenated channel-wise into the image;
      * the QEMSCAN variable and every remaining (EC) variable are converted to unified labels
        through `process_labels`;
      * the variable starting with 'SOI' is the area of interest (all ones when absent).
    Otherwise '*.nrrd' files are read from '<data_dir>/<group>/<depth>'.

    Parameters:
        group, depth, data_dir: locate the sample on disk.
        from_nc: read a .nc file (True) or a folder of .nrrd files (False).
        bg_thresh: accepted but not used here.
        show: draw a two-row preview (before and after the SOI restriction).
        unified_labels: unified label table (see `unify_labels`).
        missing_nodes_allowed: nodes among 'pp', 'px', 'qemscan', 'ec' that may be absent from
            the .nc file; None means no node may be absent.
        crop_soi_area: zero everything outside the SOI and crop to its bounding box.
        split_ec_instances: keep one channel per EC class holding the raw instance labels
            instead of merging all EC classes into one 2-D label image.

    Returns:
        Tuple (image, qemscan, ec, soi). Absent arrays are None; the others are cast to uint8
        when their maximum is <= 255 and to uint16 otherwise.

    Raises:
        AssertionError: a node is absent from the .nc file and not listed in missing_nodes_allowed.
    """
    image   = None
    qemscan = None
    ec      = None
    # Defined up-front so the .nrrd branch, which never sets them, does not hit a NameError below.
    has_pp = has_px = has_qemscan = has_ec = False

    if from_nc:
        nc_path = os.path.join(data_dir, group, gen_nc_file_name(group, depth))

        print('Opening data file', nc_path)
        data = Dataset(nc_path, 'r')
        # try/finally so the file is released even when a node is missing or malformed.
        try:
            pp_variable_name  = None
            px_variable_name  = None
            qs_variable_name  = None
            soi_variable_name = None
            ec_variable_names = set()
            for key in data.variables.keys():
                if any(key.startswith(coord) for coord in ['c_', 'x_', 'y_', 'z_']):
                    continue
                if 'PP' in key:
                    pp_variable_name = key
                elif 'PX' in key:
                    px_variable_name = key
                elif ('QEMSCAN' in key) or (('Q' in key) and ('Transformed' in key)):
                    qs_variable_name = key
                elif key.startswith('SOI'):
                    soi_variable_name = key
                elif key not in ['c', 'x', 'y', 'z']:
                    ec_variable_names.add(key)

            has_pp      = pp_variable_name  is not None
            has_px      = px_variable_name  is not None
            has_qemscan = qs_variable_name  is not None
            has_soi     = soi_variable_name is not None
            has_ec      = len(ec_variable_names) > 0

            # Pair each node NAME with its presence flag (in that order) and treat None as
            # "nothing may be missing"; otherwise the check can never fail.
            allowed_missing = () if missing_nodes_allowed is None else missing_nodes_allowed
            for node, has_node in zip(['pp', 'px', 'qemscan', 'ec'], [has_pp, has_px, has_qemscan, has_ec]):
                assert has_node or node in allowed_missing,\
                    os.path.basename(nc_path) + ' has no ' + node.upper() + ' node. Repair it or include \'' + node + '\' in missing_nodes_allowed.'

            # The variables are masked arrays; np.ma.getdata() returns the plain numpy data, without the mask.
            if has_pp:
                image = np.ma.getdata(data.variables[pp_variable_name][0])

            if has_px:
                px_image = np.ma.getdata(data.variables[px_variable_name][0])
                image = px_image if image is None else np.concatenate((image, px_image), axis = 2)

            if has_qemscan:
                qemscan = process_labels(data, qs_variable_name, unified_labels)

            if has_soi:
                soi = np.ma.getdata(data.variables[soi_variable_name][0])
            else:
                soi = np.ones(image.shape[:2] if image is not None else qemscan.shape)

            if has_ec:
                n_ec_classes = unified_labels[unified_labels['type'] == 'ec'].shape[0]
                for ec_label in sorted(ec_variable_names):
                    eseg = process_labels(data, ec_label, unified_labels, single_label = True, is_ec = True, split_instances = split_ec_instances)
                    if not split_ec_instances:
                        if ec is None:
                            ec = np.zeros((*eseg.shape[:2],))
                        ec[eseg != 0] = eseg[eseg != 0]
                    else:
                        ec_index = unified_labels[unified_labels['Element'] == correct(ec_label, is_ec = True)].index[0] - 1
                        if ec is None:
                            ec = np.zeros((*eseg.shape[:2], n_ec_classes))
                        ec[:, :, ec_index] = eseg
        finally:
            data.close()
    else:
        arrays_paths = glob.glob(os.path.join(data_dir, group, depth, '*.nrrd'))
        print('Opening pore data at depth', depth)

        for array_path in arrays_paths:
            filename = os.path.basename(array_path)
            if filename.startswith('BUZ'):
                image = np.moveaxis(nrrd.read(array_path)[0], 0, 2)[:, :, :, 0]
            elif filename.startswith('SOI'):
                soi = nrrd.read(array_path)[0][:, :, 0]
            elif filename.startswith('Seg') and 'LabelMap' not in filename:
                # NOTE(review): `label` is computed but never used or returned - confirm whether it
                # was meant to become the `qemscan` output of this branch.
                label = (~(nrrd.read(array_path)[0] - 1).astype(bool)).astype(np.uint8)[:, :, 0]

    channels_per_light = 3
    has_image = image is not None

    mask = soi.astype(bool)

    def show_row(row_axes, title_prefix):
        """Draw the available nodes on one preview row; return the first unused column.

        Reads `image`, `qemscan` and `ec` as they are at call time, so the second row shows
        the arrays after the SOI restriction.
        """
        col = 0
        first_exib_channel = 0
        if has_pp:
            imshow_mini(image[:, :, first_exib_channel:first_exib_channel+channels_per_light], row_axes[col], title_prefix + 'PP', unified_labels)
            first_exib_channel += channels_per_light
            col += 1
        if has_px:
            imshow_mini(image[:, :, first_exib_channel:first_exib_channel+channels_per_light], row_axes[col], title_prefix + 'PX', unified_labels)
            col += 1
        if has_qemscan:
            imshow_mini(qemscan, row_axes[col], title_prefix + 'QEMSCAN', unified_labels, remap_colors = from_nc)
            col += 1
        if has_ec:
            imshow_mini(ec, row_axes[col], title_prefix + 'E.C.', unified_labels, remap_colors = from_nc)
            col += 1
        return col

    if show:
        n_cols = 1 + int(has_pp) + int(has_px) + int(has_qemscan) + int(has_ec)
        # squeeze = False keeps `axes` two-dimensional even when there is a single column.
        axes = plt.subplots(2, n_cols, squeeze = False)[1]
        soi_col = show_row(axes[0], '')
        imshow_mini(soi, axes[0][soi_col], 'SOI', unified_labels, cmap = 'gray')

    if crop_soi_area:
        nonzero_rows, nonzero_cols = np.where(mask != 0)
        rows = slice(nonzero_rows.min(), nonzero_rows.max() + 1)
        cols = slice(nonzero_cols.min(), nonzero_cols.max() + 1)
        outside_soi = np.where(mask == 0)

        if has_image:
            image[outside_soi] = 0
            image = image[rows, cols]
        if has_qemscan:
            qemscan[outside_soi] = 0
            qemscan = qemscan[rows, cols]
        if has_ec:
            # EC labels get the same treatment as the other arrays so that all returned
            # arrays share the extent of the cropped SOI.
            ec[outside_soi] = 0
            ec = ec[rows, cols]
        soi = soi[rows, cols]

    if show:
        spare_col = show_row(axes[1], 'useful area:\n')
        axes[1][spare_col].axis('off')

    plt.show()

    output = []
    for array in [image, qemscan, ec, soi]:
        if array is not None:
            array = array.astype(np.uint8 if array.max() <= 255 else np.uint16)
        output.append(array)

    return tuple(output)

In [None]:
def write(image, output_dir, extension, index, channel_first, as_volume, final_size, is_segment, prefix = None, suffix = None,
         compact_rgb = False):
    """Write one image patch to '<output_dir>/<prefix><index as 4 digits><suffix>.<extension>'.

    Parameters:
        image: 2-D (H, W) or 3-D (H, W, C) array.
        extension: 'nii.gz', 'tif' or 'png'; anything else is written with `nrrd.write`.
        index: patch number used in the file name; a shape-change notice is printed for index 0.
        channel_first: move the channel axis to the front (a 2-D image becomes (1, H, W)).
        as_volume: for 3-channel images, append a trailing axis of length 1.
        final_size: when not None, resize the two spatial axes to final_size x final_size and
            cast to uint16.
        is_segment: resize with nearest-neighbour interpolation (order 0) instead of order 1.
        compact_rgb: for 3-channel images, view each pixel as one ('R', 'G', 'B') record.

    If writing a PNG raises, the patch is written as '.seg.nrrd' instead and a notice is printed.
    """
    def nrrd_header(array):
        # 3-channel images flag their channel axis as 'RGB-color'; other images need no header.
        if n_channels != 3:
            return None
        channel_dim = 0 if channel_first else 2
        header = {'kinds': array.ndim * ['domain']}
        header['kinds'][channel_dim] = 'RGB-color'
        return header

    orig_shape = image.shape
    n_channels = image.shape[-1] if image.ndim == 3 else 1

    if final_size is not None:
        # Resize while the array is still (H, W[, C]): sk_resize treats a two-element output shape
        # as the size of the first two axes, so it must run before the channel axis moves to the front.
        # if is_segment: interpolate by nearest neighbors (order 0); else: interpolate by smoothing (order 1)
        image = sk_resize(image, (final_size, final_size), preserve_range = True, order = int(not is_segment)).astype(np.uint16)

    if channel_first:
        if image.ndim == 2:
            image = image.reshape(1, image.shape[0], image.shape[1])
        else:
            image = np.rollaxis(image, image.ndim - 1)

    if n_channels == 3:
        if compact_rgb:
            image = image.view(dtype = [('R', 'u1'), ('G', 'u1'), ('B', 'u1')])
        if as_volume:
            image = image.reshape(*image.shape, 1)

    if index == 0 and orig_shape != image.shape:
        print('** Shape transformed from', orig_shape, 'to', image.shape, '**')

    if prefix is None:
        prefix = ''
    if suffix is None:
        suffix = ''

    base_path = os.path.join(output_dir, prefix + '{:04d}'.format(index) + suffix + '.')
    output_path = base_path + extension
    if extension == 'nii.gz':
        nib.save(nib.Nifti1Image(image, affine = None), output_path)
    elif extension == 'tif':
        tif.imwrite(output_path, image)
    elif extension == 'png':
        try:
            cv2.imwrite(output_path, image if n_channels == 1 else cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        except Exception as error:
            # Fall back to NRRD when OpenCV cannot encode the array, and say so instead of
            # silently producing a file with another extension.
            output_path = base_path + 'seg.nrrd'
            print('PNG writing failed (' + str(error) + '); writing', output_path, 'instead.')
            nrrd.write(output_path, image, header = nrrd_header(image))
    else:
        nrrd.write(output_path, image, header = nrrd_header(image))

In [None]:
def mark_void_pixels(data, label):
    """Circle the void (all-zero) pixels of `data` and `label` on copies of both and plot them.

    Red marks pixels that are zero in every channel of `data`, green marks pixels that are zero
    in `label`, and yellow marks pixels that are zero in both. The inputs are left unmodified.
    """
    print('Computing void coords...')
    data_is_void  = (data == 0).all(axis = 2)
    label_is_void = label == 0
    void_masks = [data_is_void, label_is_void, data_is_void & label_is_void]

    print('Zipping void coords...')
    void_centers = []
    for void_mask in void_masks:
        rows, cols = np.where(void_mask)[:2]
        # np.where yields (row, col) while cv2.circle expects its centre as (x, y) = (col, row),
        # given as plain Python ints.
        void_centers.append([(int(col), int(row)) for row, col in zip(rows, cols)])

    print('Creating images copies...')
    marked_data = data.copy()
    marked_label = label.copy()
    print('Marking...')
    for centers, color in zip(void_centers, [(255, 0, 0), (0, 255, 0), (255, 255, 0)]):
        print(len(centers), 'void pixels to mark with color', color)
        for center in tqdm(centers):
            marked_data  = cv2.circle(marked_data,  center, 1000, color, 3)
            marked_label = cv2.circle(marked_label, center, 1000, color, 3)

    ax = plt.subplots(1, 2)[1]
    plt.suptitle('Red: data voids\nGreen: label voids\nYellow: common voids')
    # No colour remapping is requested, so no unified label table is needed.
    imshow_mini(marked_data[:, :, :3],  ax[0], '', None)
    imshow_mini(marked_label,           ax[1], '', None)
    plt.show()

In [None]:
def get_greatest_contour(mask):
    """Return the external contour of `mask` that has the most points.

    Some instances appear as one main island with a few residual islands around it. Only the
    greatest one is kept because otherwise each island would be written in YOLO format as a
    different instance. Returns None when no contour is found.
    """
    found, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if len(found) == 1:
        return found[0]

    point_counts = [candidate.shape[0] for candidate in found]
    most_points = max(point_counts, default = 0)
    # The first contour reaching the highest count wins; nothing qualifies without points.
    return found[point_counts.index(most_points)] if most_points > 0 else None

def save_anot_yolo_format(chunk, output_dir, index, prefix = None, suffix = None, save_bbox = False, save_seg = False):
    """Write the instances of an EC patch as YOLO annotations.

    Parameters:
        chunk: array whose last axis indexes the class; within a class channel every distinct
            non-zero value is one instance.
        output_dir: the annotation folders 'ec_yolo_bboxes' and 'ec_yolo_seg' are created next
            to this directory (in its parent).
        index: patch number; the file is named '<prefix><index as 4 digits><suffix>.txt'.
        save_bbox: write one 'class x_center y_center width height' line per instance.
        save_seg: write one 'class x1 y1 x2 y2 ...' line per instance.

    All coordinates are divided by the patch width / height.
    """
    if prefix is None:
        prefix = ''
    if suffix is None:
        suffix = ''

    out_filename = prefix + '{:04d}'.format(index) + suffix + '.' + 'txt'
    bbox_dir = os.path.join(os.path.dirname(output_dir), 'ec_yolo_bboxes')
    seg_dir  = os.path.join(os.path.dirname(output_dir), 'ec_yolo_seg')

    bbox_file = None
    seg_file = None
    # try/finally so the annotation files are closed even if an instance fails to process.
    try:
        if save_bbox:
            os.makedirs(bbox_dir, exist_ok = True)
            bbox_file = open(os.path.join(bbox_dir, out_filename), 'w')
        if save_seg:
            os.makedirs(seg_dir, exist_ok = True)
            seg_file = open(os.path.join(seg_dir, out_filename), 'w')

        height, width = chunk.shape[:2]
        for i in range(chunk.shape[-1]):
            class_masks = chunk[:, :, i]
            for instance in tqdm(np.unique(class_masks)):
                if instance == 0:
                    continue

                mask = np.where(class_masks == instance, 1, 0).astype(np.uint8)

                # Some instances appear as one main island and some residual islands around. Getting only the
                # greatest because otherwise each island would be written in YOLO format as a different instance.
                contour = get_greatest_contour(mask)

                if save_bbox:
                    x, y, w, h = cv2.boundingRect(contour)
                    bbox_file.write('{:d} {:.6f} {:.6f} {:.6f} {:.6f}\n'.format(i, (x+w/2)/width, (y+h/2)/height, w/width, h/height))
                if save_seg:
                    # Contour points are (x, y), hence the division by (width, height).
                    segment_yolo = (np.squeeze(contour)/chunk.shape[:2][::-1]).ravel().astype('str')
                    seg_file.write('{:d} '.format(i)+' '.join(segment_yolo)+'\n')
    finally:
        for handle in (bbox_file, seg_file):
            if handle is not None:
                handle.close()

In [None]:
def gen_dataset(groups, data_dir, from_nc = True, do_qemscan_unification = False, do_ec_unification = False,
                unified_labels_file = 'unified_labels.csv', initial_labels = None, ds_image_size = 32,
                extension = 'nii.gz', channel_first = False, preserve_channels = False, as_volume = False, show = False,
                save_nodes = None, calc_props = True, bg_thresh = 10, shrank = False, max_zero_rate_thresh = 1,
                missing_nodes_allowed = None, final_size = None, single_output_dir = False, save_randomized = False,
               compact_rgb = False, crop_soi_area = True, split_ec_instances = False, yolo=False, yolo_seg=False, occlusion_percentage=None):
    accepted_extensions = ['nii.gz', 'tif', 'png', 'seg.nrrd', 'nrrd']
    accepted_nodes = ['pp', 'px', 'qemscan', 'ec']
    
    #ds_image_size in format (x, y) (i.e., index 0 is x)
    
    save = save_nodes is not None and len(save_nodes) > 0
    full_size = ds_image_size is None
    if save:
        if type(extension) == str:
            extension = (extension,) * len(save_nodes)
        assert all(e in accepted_extensions for e in extension),   'Supported formats: ' + str(accepted_extensions)
        assert all(node in accepted_nodes for node in save_nodes), 'Compatible nodes: '  + str(accepted_nodes)
    assert do_qemscan_unification or do_ec_unification or show or save or calc_props, 'Nothing to do.'
    assert full_size or type(ds_image_size) == int or len(ds_image_size) == 2
    assert not show or not split_ec_instances, 'For now, splitted ECs cannot be exhibited. Use split_ec_instances = False when using show = True.'
    assert occlusion_percentage is None or type(occlusion_percentage) in [int, float, list] or type(occlusion_percentage) 
    
    ds_path = os.path.join(data_dir, 'generated')
    os.makedirs(ds_path, exist_ok = True)
    
    if do_qemscan_unification or do_ec_unification:
        unify_labels(do_qemscan_unification, do_ec_unification, data_dir, groups, from_nc, unified_labels_file, initial_labels)
    unified_labels = pd.read_csv(unified_labels_file, index_col = 0)
    
    if occlusion_percentage is not None:
        n_ec_classes = unified_labels[unified_labels['type'] == 'ec'].shape[0]
        if type(occlusion_percentage) in [int, float]:
            occlusion_percentage = n_ec_classes * [occlusion_percentage]
        assert len(occlusion_percentage) == n_ec_classes, f'occlusion_percentage must be None, a single value or a list containing ' + \
            f'one value per EC class. Found {n_ec_classes} EC classes but {len(occlusion_percentage)} occlusion percentages.'
    
    dataset_proportions = []
    proportion_weights = []
    
    if type(ds_image_size) == int:
        ds_image_size = (ds_image_size, ds_image_size)
    
    size_infix = str(ds_image_size[0]) + 'x' + str(ds_image_size[1]) if not full_size else 'WxH'
    final_size_suffix = ('_as_' + str(final_size) + 'x' + str(final_size)) if final_size is not None else ''
    channel_suffix = '_cf' if channel_first else ''
    channel_suffix += '_ch' if preserve_channels else ''
    vol_suffix = '_vol' if as_volume else ''
    shrink_suffix = '_shrank' if shrank else ''
    datetime_suffix = datetime.datetime.now().strftime("_%d.%m.%Y-%H.%M.%S") if single_output_dir else ''
    
    delete_output_dir = save
    for group in groups:
        for depth in groups[group]:
            data, qemscan, ec, soi = load_and_preprocess(group, depth, data_dir, from_nc, bg_thresh, show, unified_labels,
                                                         missing_nodes_allowed, crop_soi_area, split_ec_instances)
            
            if not save and not calc_props:
                continue

            h, w = soi.shape
            if full_size:
                ds_image_size = (w, h)
                
            if not single_output_dir:
                general_output_dir = os.path.join(ds_path, group, depth, size_infix \
                                              + final_size_suffix + '_' + str(extension) + channel_suffix + vol_suffix + shrink_suffix)
            else:
                general_output_dir = os.path.join(ds_path, size_infix \
                                              + final_size_suffix + '_' + str(extension) + channel_suffix + vol_suffix + shrink_suffix \
                                              + datetime_suffix)
            
            if delete_output_dir:
                shutil.rmtree(general_output_dir, ignore_errors = True)
                delete_output_dir = not single_output_dir

            minerals_proportions = []
            file_prefixes = {}
            file_suffixes = {}
            view = {}
            random_prefixes = np.arange(100000) if save_randomized else None
            
            chunk_indexes = {
                'y': np.arange(0, h - ds_image_size[1], ds_image_size[1]) if not full_size else [0],
                'x': np.arange(0, w - ds_image_size[0], ds_image_size[0]) if not full_size else [0],
            }
            n_chunks_x = (w//ds_image_size[0])
            n_chunks_y = (h//ds_image_size[1])
            n_chunks = n_chunks_x * n_chunks_y
            crop_w = ds_image_size[0] * n_chunks_x
            crop_h = ds_image_size[1] * n_chunks_y
            
            soi = soi[(h-crop_h)//2 : (h+crop_h)//2, (w-crop_w)//2 : (w+crop_w)//2]
            
            sets = []
            i_ext = 0
            print('*** set_type: central shape [discarding boards not suitable in image.shape/ds_image_size] --> (N, H, W, C[, D]) ***')
            for set_type, image in [('data', data), ('qemscan', qemscan), ('ec', ec)]:
                if image is None or (set_type != 'data' and set_type not in save_nodes) or \
                    (set_type == 'data' and 'pp' not in save_nodes and 'px' not in save_nodes):
                    continue

                output_dir = os.path.join(general_output_dir, set_type)
                os.makedirs(output_dir, exist_ok = True)

                image = image[(h-crop_h)//2 : (h+crop_h)//2, (w-crop_w)//2 : (w+crop_w)//2]
                
                if show:
                    view[set_type] = image.copy()
                print(set_type + ':', image.shape, end = ' --> ')

                window_shape = (ds_image_size[1], ds_image_size[0])
                if image.ndim == 3:
                    window_shape += (image.shape[-1],)

                print((n_chunks,) + window_shape)
                
                sets.append((set_type, image, output_dir, extension[i_ext]))
                i_ext += 1

            i_chunk = 0
            discarded = 0
            for yi in chunk_indexes['y']:
                for xi in chunk_indexes['x']:
                    i_chunk += 1
                    if n_chunks >= 10 and (i_chunk % (n_chunks//10)) == 0:
                        print(i_chunk, '/', n_chunks, '(' + str(round(100*i_chunk/n_chunks)) + '%)')

                    chunk_soi = soi[yi : yi + ds_image_size[1], xi : xi + ds_image_size[0]]
                    zero_rate = 1 - np.count_nonzero(chunk_soi)/chunk_soi.size

                    if zero_rate <= max_zero_rate_thresh:
                        if save:
                            prefix = None
                            suffix  = None
                            if save_randomized:
                                prefix = np.random.choice(random_prefixes)
                                random_prefixes = np.delete(random_prefixes, np.where(random_prefixes == prefix))
                                prefix = str(prefix) + '.'
                            if single_output_dir:
                                suffix = '_' + group + '_' + str(depth)
                            
                            for set_type, image, output_dir, set_extension in sets:
                                chunk = image[yi : yi + ds_image_size[1], xi : xi + ds_image_size[0]]
                                
                                if occlusion_percentage and set_type == 'ec':
                                    for i in range(image.shape[-1]):                                        
                                        border_pixels = np.concatenate((chunk[0,                   0 : ds_image_size[0]-1, i], 
                                                                        chunk[ds_image_size[1]-1,     0 : ds_image_size[0]-1, i],
                                                                        chunk[0 : ds_image_size[1]-1, 0,                   i],
                                                                        chunk[0 : ds_image_size[1]-1, ds_image_size[0]-1,     i]))
                                        cropped_elements = np.unique(border_pixels)[1:]
                                        
                                        for j in cropped_elements:
                                            partial = (chunk[:,:,i]==j).sum()
                                            total = (image[:,:,i]==j).sum()
                                            if partial/total < occlusion_percentage[i]:
                                                chunk[chunk[:,:,i]==j] = 0
                                        
                                        # for j in cropped_elements:
                                        #     chunk[:,:,i][chunk[:,:,i]==j] = 0

                                if set_type == 'qemscan' and calc_props:
                                    minerals_proportions.append(calculate_minerals_proportions(chunk, unified_labels))
                                
                                if set_type == 'data' and chunk.shape[-1] > 3:
                                    if 'pp' not in save_nodes:
                                        chunk = chunk[:, :, 3:]
                                    if 'px' not in save_nodes:
                                        chunk = chunk[:, :, :3]
                                
                                if set_type == 'ec':
                                    if yolo or yolo_seg:
                                        save_anot_yolo_format(chunk, output_dir, i_chunk, prefix = prefix, suffix = suffix,
                                                             save_bbox = yolo, save_seg = yolo_seg)

                                write(chunk, output_dir, set_extension, i_chunk, channel_first, as_volume, final_size,
                                      is_segment = (set_type != 'data'), prefix = prefix, suffix = suffix,
                                      compact_rgb = compact_rgb)

                                if show:
                                    view[set_type][yi : yi + ds_image_size[1], xi : xi + ds_image_size[0]] = \
                                        (view[set_type][yi : yi + ds_image_size[1], xi : xi + ds_image_size[0]] + view[set_type].max() + 1)//2
                    
                    else:
                        discarded += 1
    
            n_valid = n_chunks - discarded
            print('Done.', n_valid, 'valid,', discarded, 'discarded (background).')
            
            if show:
                n_cols = len(view.keys())
                ax = plt.subplots(1, n_cols)[1]
                for i, set_type in enumerate(view.keys()):
                    if set_type == 'data':
                        view[set_type] = view[set_type][:, :, :3]
                    imshow_mini(view[set_type], ax[i] if n_cols > 1 else ax, '', unified_labels)
                plt.show()
            
            if calc_props:
                group_proportions = calculate_general_proportions(minerals_proportions)
                proportion_weights.append(n_valid)

                plot_proportions(group_proportions, unified_labels, gen_nc_file_name(group, depth))
                dataset_proportions.append(group_proportions)
    
    if calc_props:
        plot_proportions(calculate_general_proportions(dataset_proportions, weights = proportion_weights), unified_labels, 'DATASET PROPORTIONS ' + list(groups.keys())[0] + '-' + list(groups.keys())[-1]) 

In [None]:
# Full catalogue of annotated images: each key is a group name and each value is the
# list of image identifiers that belong to that group.
# NOTE(review): later cells assign `groups` again, so the cell executed last decides
# which selection is actually used.
groups = {
    'A': ['5228.45', '5230.5', '5232.1', '5235.5', '5242.95', '5246.95', '5247.7'],
    'B': ['5129.8', '5163.8', '5187.6', '5208', '5218.2'],
    'C': ['5565', '5705', '5732.5', '5758.5', '5766.4', '5813.2', '5823.3', '5838.2'],
    'D': ['6148.6', '6150', '6166', '6178.8', '6189', '6199.5', '6202', '6212', '6224', '6234', '6277', '6293.5', '6295'] # 6212 was mistakenly left out before (when the main model was trained)
        + ['6164.5'], # its PP has low resolution and its QEMSCAN is wrong. Use it only when just the PX is needed
    'F': ['4938.3', '4941.8'] + ['4964.45'], # the last one has no valid QEMSCAN
    'G': ['5300.5', '5304'] + ['5268.5'], # the last one has no valid QEMSCAN
    'H': ['5615.8', '5629.45', '5630.8'],
    'I': ['5573.7', '5595.1', '5646.9'],
    'K': ['5427', '5368'],
    'L': ['5350.1', '5439.05', '5486.75'],
    'M': ['5229.75', '5235.75'] + ['5168.05', '5168.05t', '5219.38', '5236.95'], # the last ones have no valid QEMSCAN; every image except 5229.75 belongs to batch 2
    'P': ['5389.75'],
    'Q': ['5503.25', '5505.05'],
    'R': ['5182'],
    
    # new data
    'AR': ['6374.95', '6375.65', '6376.00', '6376.65', '6376.95', '6378.05', '6379.05', '6379.80', '6380.15', '6380.45', '6380.80', '6381.90'],
    'AS': ['5672.00', '5675.70'],
    'FL': ['5400.25', '5401.00', '5404.65'], # '5400.55' has low resolution
    'LB': ['5472.20', '5477.70', '5481.65'],
    'MR': ['5561.45'], # has no valid QEMSCAN
    'SA': ['6292.50', '6312.00'],
    'SC': ['6328.00', '6328.00t', '6340.50', '6340.50t', '6342.20', '6342.20t', '6344.90', '6344.90t', '6346.60', '6346.60t', '6353.00', '6353.00t', '6355.00', '6355.00t', '6363.00', '6363.00t', '6364.50', '6364.50t', '6374.00t', '6390.00', '6390.00t', '6398.50', '6398.50t'],
    'SL': ['5174.50'],
    'YB': ['4822.00', '4846.80'],
    
    #'Siliciclastics': ['01', '02', '03', '04', '05', '06', '07']
    
    # Bug: 5607.35 (no SOI), 5609.20 (no SOI), 5608.85 (no SOI), 5608.60 (no SOI), 5609.45 (no SOI)
    #'Processed_Poro': ['5611.80', '5769.20', '5623.10', '5765.50', '5631.55', '5628.25', '5610.65', '5617.05', '5662.30', '5795.05', '5603.45', '5623.65', '5749.30', '5790.00', '5627.25', '5720.40', '5621.05', '5604.00', '5759.00', '5612.10', '5643.60', '5629.90', '5711.10', '5610.45', '5631.10', '5811.50', '5804.70', '5607.95', '5881.20', '5636.40', '5690.00', '5862.00', '5634.55', '5603.10', '5637.45', '5844.80', '5651.10', '5684.20', '5780.15', '5851.10', '5613.20', '5622.05', '5613.50', '5633.80', '5707.30', '5616.15', '5611.05', '5666.40', '5635.50', '5620.15', '5705.10', '5602.45', '5634.90', '5632.30', '5871.00', '5632.00', '5636.65', '5627.95', '5714.40', '5832.70', '5610.35', '5726.30', '5790.65', '5621.75', '5612.45', '5619.00', '5800.05', '5739.50', '5639.30', '5602.05', '5716.70', '5775.25', '5802.00', '5754.00', '5618.35', '5785.15', '5614.55', '5632.95', '5699.50', '5822.10', '5635.75', '5607.35', '5609.20', '5608.85', '5608.60', '5609.45']
}

In [None]:
# Reduced selection of images per group. Running this cell rebinds `groups`,
# replacing whatever an earlier cell assigned.
# NOTE(review): several trailing comments below were carried over from the full
# catalogue and no longer match the trimmed lists; they are flagged individually.
groups = {
    'A': ['5230.5'],
    'B': ['5218.2'],
    'C': ['5565', '5705'],
    'D': ['6166'] # original note: 6212 was mistakenly left out before (when the main model was trained). NOTE(review): '6212' is not part of this selection
        + ['6164.5'], # its PP has low resolution and its QEMSCAN is wrong. Use it only when just the PX is needed
    'F': ['4941.8'], # original note: "the last one has no valid QEMSCAN". NOTE(review): in the full catalogue that note refers to '4964.45', which is not selected here
    'G': ['5268.5'], # the last one has no valid QEMSCAN (the full catalogue attaches this note to '5268.5')
    'H': ['5630.8'],
    'I': ['5573.7', '5595.1', '5646.9'],
    'K': ['5427', '5368'],
    'L': ['5350.1', '5439.05', '5486.75'],
    'M': ['5229.75'], # original note: the last ones have no valid QEMSCAN; every image except 5229.75 belongs to batch 2. NOTE(review): only 5229.75 is selected here
    'Q': ['5503.25'],
    
    # new data
    'SA': ['6312.00'],
    'SC': ['6328.00t', '6340.50t', '6342.20t', '6344.90t', '6346.60', '6346.60t', '6353.00', '6353.00t'],
    
    #'Siliciclastics': ['01', '02', '03', '04', '05', '06', '07']
    
    # Bug: 5607.35 (no SOI), 5609.20 (no SOI), 5608.85 (no SOI), 5608.60 (no SOI), 5609.45 (no SOI)
    #'Processed_Poro': ['5611.80', '5769.20', '5623.10', '5765.50', '5631.55', '5628.25', '5610.65', '5617.05', '5662.30', '5795.05', '5603.45', '5623.65', '5749.30', '5790.00', '5627.25', '5720.40', '5621.05', '5604.00', '5759.00', '5612.10', '5643.60', '5629.90', '5711.10', '5610.45', '5631.10', '5811.50', '5804.70', '5607.95', '5881.20', '5636.40', '5690.00', '5862.00', '5634.55', '5603.10', '5637.45', '5844.80', '5651.10', '5684.20', '5780.15', '5851.10', '5613.20', '5622.05', '5613.50', '5633.80', '5707.30', '5616.15', '5611.05', '5666.40', '5635.50', '5620.15', '5705.10', '5602.45', '5634.90', '5632.30', '5871.00', '5632.00', '5636.65', '5627.95', '5714.40', '5832.70', '5610.35', '5726.30', '5790.65', '5621.75', '5612.45', '5619.00', '5800.05', '5739.50', '5639.30', '5602.05', '5716.70', '5775.25', '5802.00', '5754.00', '5618.35', '5785.15', '5614.55', '5632.95', '5699.50', '5822.10', '5635.75', '5607.35', '5609.20', '5608.85', '5608.60', '5609.45']
}

In [None]:
# Extended selection of images per group. Most entries are written as
# `[first list] + [second list]`; the first lists match the entries of the reduced
# selection defined in the previous cell. Running this cell rebinds `groups`.
groups = {
    'A': ['5230.5'] + ['5228.45', '5232.1', '5246.95'],
    'B': ['5129.8', '5163.8', '5187.6', '5208', '5218.2'],
    'C': ['5565', '5705', '5732.5', '5758.5', '5766.4', '5813.2', '5823.3', '5838.2'],
    'D': ['6150', '6166', '6178.8', '6199.5', '6202', '6212', '6224', '6234', '6277', '6293.5', '6295'] # 6212 was mistakenly left out before (when the main model was trained)
        + ['6164.5'] + ['6189'],
    'F': ['4941.8'] + ['4938.3', '4964.45'], # the last one has no valid QEMSCAN
    'G': ['5268.5'] + ['5300.5', '5304.00'], # original note: "the last one has no valid QEMSCAN". NOTE(review): the full catalogue attaches that note to '5268.5' (listed first here) and spells the last image '5304', not '5304.00' - confirm which spelling matches the file name
    'H': ['5615.8', '5629.45', '5630.8'],
    'I': ['5573.7', '5595.1', '5646.9'],
    'K': ['5427', '5368'],
    'L': ['5350.1', '5439.05', '5486.75'],
    'M': ['5229.75'] + ['5235.75', '5236.95'], # original note: the last ones have no valid QEMSCAN; every image except 5229.75 belongs to batch 2. NOTE(review): in the full catalogue '5235.75' is not in the "no valid QEMSCAN" list - confirm
    'Q': ['5503.25'],
    
    # new data
    'AS': ['5672.00', '5675.70'],
    'SA': ['6312.00'] + ['6292.50'],
    'SC': ['6328.00t', '6340.50t', '6342.20t', '6344.90t', '6346.60t', '6353.00t'], # + ['6346.60t', '6353.00t']
    
    'AR': ['6376.65'],
    'FL': ['5400.25', '5401.00', '5404.65'],
    'LB': ['5472.20', '5477.70', '5481.65'],
    'SL': ['5174.50'],
    'YB': ['4822.00', '4846.80']
    
    #'Siliciclastics': ['01', '02', '03', '04', '05', '06', '07']
    
    # Bug: 5607.35 (no SOI), 5609.20 (no SOI), 5608.85 (no SOI), 5608.60 (no SOI), 5609.45 (no SOI)
    #'Processed_Poro': ['5611.80', '5769.20', '5623.10', '5765.50', '5631.55', '5628.25', '5610.65', '5617.05', '5662.30', '5795.05', '5603.45', '5623.65', '5749.30', '5790.00', '5627.25', '5720.40', '5621.05', '5604.00', '5759.00', '5612.10', '5643.60', '5629.90', '5711.10', '5610.45', '5631.10', '5811.50', '5804.70', '5607.95', '5881.20', '5636.40', '5690.00', '5862.00', '5634.55', '5603.10', '5637.45', '5844.80', '5651.10', '5684.20', '5780.15', '5851.10', '5613.20', '5622.05', '5613.50', '5633.80', '5707.30', '5616.15', '5611.05', '5666.40', '5635.50', '5620.15', '5705.10', '5602.45', '5634.90', '5632.30', '5871.00', '5632.00', '5636.65', '5627.95', '5714.40', '5832.70', '5610.35', '5726.30', '5790.65', '5621.75', '5612.45', '5619.00', '5800.05', '5739.50', '5639.30', '5602.05', '5716.70', '5775.25', '5802.00', '5754.00', '5618.35', '5785.15', '5614.55', '5632.95', '5699.50', '5822.10', '5635.75', '5607.35', '5609.20', '5608.85', '5608.60', '5609.45']
}

In [None]:
# groups = {
#     'Siliciclastics': [f'numero_{i}' for i in range(5, 25, 2)] + ['01', '02', '03', '04', '05', '06', '07'] # numero_1 and numero_3 have no QEMSCAN
# }

# Selection of images per group used by the cells below (rebinds `groups`).
groups = {
    'D': ['6189'],
    'A': ['5228.45', '5232.1', '5246.95'],
    'AR': ['6376.65'],
    'F': ['4938.3', '4964.45'],
    'FL': ['5400.25', '5401.00', '5404.65'],
    'G': ['5300.5', '5304.00'],
    'LB': ['5472.20', '5477.70', '5481.65'],
    'M': ['5235.75', '5236.95'],
    #'Q': ['5505.05'],
    'SA': ['6292.50'],
    'SL': ['5174.50'],
    'YB': ['4822.00', '4846.80']
}

# Total number of selected images. The count is bound to its own name so the
# built-in `sum` is not overwritten for the rest of the kernel session.
total_images = sum(len(images) for images in groups.values())
total_images

In [None]:
project = 'qemscan'

# Root directory shared by the cluster-hosted datasets.
_cluster_datasets_dir = os.path.join(os.sep, 'petrobr', 'parceirosbr', 'smartseg', 'datasets')

# Per-project settings: project name -> (dataset directory, images stored as .nc files?).
_project_settings = {
    'qemscan': ('D:\\Annotated', True), # cluster location: os.path.join(os.sep, 'petrobr', 'parceirosbr', 'smartseg', 'datasets', 'qemscan')
    'poreseg': (os.path.join(_cluster_datasets_dir, 'poreseg'), False),
    'elementos_construtores': (os.path.join(_cluster_datasets_dir, 'elementos_construtores'), True),
}

# Fail early with a clear message instead of leaving `dataset_dir` as None and
# `from_nc` undefined, which would only surface later as a TypeError/NameError.
if project not in _project_settings:
    raise ValueError('Unknown project ' + repr(project) + '; expected one of ' + str(sorted(_project_settings)))

dataset_dir, from_nc = _project_settings[project]

# WARNING: in the file 'monailabel_labels.csv' the mineral indexes start at 50. This is because the tested app
# had 50 classes (the surplus ones being reserved for the feature that adds new classes dynamically). In the future,
# consider setting the initial index in code.

unified_labels_file = os.path.join(dataset_dir, 'labels.csv') #os.path.join(dataset_dir, 'unified_labels.csv')

In [None]:
'''
@groups (dict(str, list(str))): dicionário em que as chaves descrevem os grupos (e.g. poços) de imagens a partir das quais o
dataset será gerado e os valores são listas de imagens contidas em cada grupo. As imagens de cada grupo devem estar em um
diretório com o nome do grupo e devem ser nomeadas conforme consta nas listas;

@data_dir (str): diretório contendo os dados descritos em @groups;

@from_nc (bool): Default: True; Booleano que indica se as imagens descritas em @groups têm o formato .nc. Caso falso,
espera-se que se apresentem em formato .nrrd, em que cada nó (lâmina, segmentação, SOI, etc.) seja um arquivo distinto;

@do_qemscan_unification (bool): Default: False; Caso verdadeiro, percorre o dataset extraindo todas as ocorrências distintas de
minerais, atribuindo um rótulo numérico único para cada e salvando a informação no arquivo descrito em @unified_labels_file.
Visa mitigar o fato de que o mesmo mineral consta com diferentes rótulos em diferentes imagens;

@do_ec_unification (bool): Default: False; Similar ao @do_qemscan_unification, mas referente a elementos construtores ao invés
de minerais;

@unified_labels_file (str): Default: 'unified_labels.csv'; Arquivo no qual será registrada a unificação de rótulos descrita em
@do_qemscan_unification e @do_ec_unification. Será usado na geração do dataset. Cria o arquivo, caso não exista;

@sequential_label_indexes (bool): Default: False; Caso verdadeiro e @do_qemscan_unification e @do_ec_unification sejam
verdadeiros, os elementos construtores e minerais são listados em @unified_labels_file com índices sequenciais, a partir de 1.
Caso contrário, cada um dos tipos (elementos construtores / minerais) tem a própria enumeração, iniciada em 1;

@ds_image_size (int): Default: 32; tamanho lateral dos recortes;

@extension (str): Default: 'nii.gz'. Extensão dos arquivos contendo os recortes;

@channel_first (bool): Default: False; Os recortes são salvos no formato CHW(D), caso verdadeiro, ou HWC(D), caso contrário;

@preserve_channels: [deprecated]

@as_volume (bool): Default: False; Caso verdadeiro, adiciona uma quarta dimensão ao recorte, representando profundidade 1;

@show (bool): Default: False; Caso verdadeiro, exibe os nós PP, PX, rótulos (QEMSCAN e elementos construtores) e SOI da imagem,
bem como sua área útil ao considerar o SOI como máscara;

@save_nodes (list(str)): Default: None; Lista dos nós que serão incluídos nos recortes salvos. Suporta:
    'pp';
    'px';
    'qemscan';
    'ec' [elementos construtores];
    [] ou None, caso não seja necessário salvar (indicado em casos de visualização (@show = True));

@calc_props (bool): Default: True; Caso verdadeiro, calcula e exibe as proporções de cada fase mineral na área útil de cada
imagem e, por fim, no dataset completo;

@bg_thresh (int): [deprecated]

@shrank (bool): Default: False; adiciona o sufixo '_shrank' nos diretórios gerados para o dataset final, para identificação
caso se esteja trabalhando com imagens com o efeito shrink ativado;

@max_zero_rate_thresh (float): Default: 1; proporção (0.0 - 1.0) de área não-útil aceitável para que o recorte seja salvo;

** Parâmetros úteis para o projeto elementos_construtores com MONAI Label:

@missing_nodes_allowed (bool): Default: False; Caso verdadeiro, deixa de ser obrigatório que as imagens tenham os nós
especificados. Útil para construção de um dataset a ser rotulado. Suporta a nomenclatura de nós adotada em @save_nodes;

@final_size (int): Default: None; Se especificado, as laterais dos recortes serão redimensionadas para este valor antes que
ele seja salvo;

@single_output_dir (bool): Default: False; Caso verdadeiro, salva todos os recortes em um mesmo diretório. Caso contrário,
salva-os separados por grupo (vide @groups);

@save_randomized (bool): Default: False; Caso verdadeiro, adiciona um prefixo aleatório aos arquivos contendo os recortes
salvos. Útil para treinamento via MONAI Label, visto que o módulo separa os conjuntos de treino e validação alfabeticamente.

@compact_rgb (bool): Default: False; Caso verdadeiro, salva a informação RGB do recorte não como 3 canais distintos, mas como
1 único canal do tipo RGB, formato suportado pelo visualizador do 3D Slicer.

@crop_soi_area (bool): Default: True; Caso verdadeiro, descarta a área não-útil ao redor do SOI antes de gerar os recortes.
Caso contrário, considera as imagens completamente.

@split_ec_instances (bool): Default: False; Caso falso, os elementos construtores são separados apenas por classe. Todos os
elementos construtores da imagem são salvos em um único arquivo, e todas as instâncias de um mesmo elemento têm o mesmo valor
(correspondente ao seu índice em @unified_labels_file). Caso contrário, o arquivo final é salvo no formato (H, W, N), sendo
N a quantidade de elementos construtores em @unified_labels_file. O elemento de índice k é salvo no canal k-1 desta nova
dimensão, e cada instância tem um valor diferente.
'''

'''
# ** For mineralogy
gen_dataset(groups, dataset_dir, from_nc = from_nc, do_qemscan_unification = True, do_ec_unification = False,
            unified_labels_file = unified_labels_file, extension = 'nii.gz', channel_first = False, as_volume = False,
            initial_labels = None,
            ds_image_size = 1000, show = True, save_nodes = [], calc_props = False, bg_thresh = 0, max_zero_rate_thresh = 0.3,
            missing_nodes_allowed = ['ec'], final_size = None, single_output_dir = False, save_randomized = False,
            compact_rgb = False, crop_soi_area = True, split_ec_instances = False, yolo = False)


# ** For EC's
gen_dataset(groups, os.path.join(dataset_dir, 'elementos_construtores', 'dataset_sdumont'), from_nc = from_nc, do_qemscan_unification = False, do_ec_unification = False,
            unified_labels_file = unified_labels_file, extension = 'seg.nrrd', channel_first = False, as_volume = False,
            initial_labels = None,
            ds_image_size = 2048, show = False, save_nodes = ['px', 'ec'], calc_props = False, bg_thresh = 0, max_zero_rate_thresh = 1,
            missing_nodes_allowed = ['pp', 'qemscan', 'ec'], final_size = None, single_output_dir = True, save_randomized = True,
            compact_rgb = False, crop_soi_area = False, split_ec_instances = True, yolo = False)
'''

#groups = {
#    'SC': ['6353.00t']
#}

# ** For EC's w/ yolo
# Generates the dataset for the selected `groups`, writing PNG chunks together with
# YOLO-format annotations (bounding boxes and segmentations).
gen_dataset(
    groups,
    dataset_dir,

    # Input data and label unification
    from_nc = from_nc,
    do_qemscan_unification = False,
    do_ec_unification = False,
    unified_labels_file = unified_labels_file,
    initial_labels = None,
    missing_nodes_allowed = ['pp', 'qemscan', 'ec'],

    # Chunk geometry and filtering
    ds_image_size = 2048,
    final_size = None,
    crop_soi_area = False,
    bg_thresh = 0,
    max_zero_rate_thresh = 1,

    # Output format
    extension = 'png',
    channel_first = False,
    as_volume = False,
    compact_rgb = False,
    save_nodes = ['px', 'ec'],
    single_output_dir = False,
    save_randomized = True,

    # Visualisation and statistics
    show = False,
    calc_props = False,

    # Building-element (EC) annotations
    split_ec_instances = True,
    yolo = True,
    yolo_seg = True,
    # One threshold per EC channel: an instance touching the chunk border is erased
    # from the chunk when the fraction of it inside the chunk is below the threshold.
    occlusion_percentage = [0.5, 0.1, 0.8, 0.5, 0.1, 0.3, 0.5, 0.1, 0.1, 0.3, 0.5, 0.5, 0.5, 0.8],
)

In [None]:
import torch

# Load the trained model checkpoint from disk.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
# When CUDA is unavailable, the tensors are remapped to the CPU so that a checkpoint
# saved on a GPU machine can still be opened; with CUDA available, the devices stored
# in the checkpoint are kept (map_location = None).
m = torch.load('C:\\Users\\LTrace\\Desktop\\slicerltrace\\src\\ltrace\\ltrace\\' + \
    'assets\\trained_models\\ThinSectionEnv\\petrobras_complete_u_net.pth',
    map_location = None if torch.cuda.is_available() else 'cpu')

In [None]:
# Show the top-level keys of the object returned by torch.load in the previous cell.
m.keys()