# SMSNet Ensemble Method

Note: Move this notebook to the parent directory before running it; otherwise the `src` imports below will not resolve.

In [None]:
from    src.network         import  build_network
from    src.parameters      import  MSDNetParameters, TUNetParameters, TUNet3PlusParameters
from    src.seg_utils       import  train_val_split, train_segmentation
from    src.tiled_dataset   import  TiledDataset
import  torch
import  torch.nn        as      nn
import  torch.optim     as      optim
from    torchvision     import  transforms
from    src.utils           import  create_directory

import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.data.dataloader import default_collate
from dlsia.core.train_scripts import segmentation_metrics
import yaml

In [None]:
# Open the YAML file for all parameters. Later cells read the Tiled URIs/API
# keys, 'save_path', 'uid', 'mask_idx' and the 'model_parameters' section from it.
# The path is relative to the current working directory.
yaml_path = 'example_yamls/example_smsnet_ensemble.yaml'
with open(yaml_path, 'r') as file:
    # Load parameters. safe_load only builds plain Python objects (dicts,
    # lists, scalars), never arbitrary Python classes.
    parameters = yaml.safe_load(file)

In [None]:
# Hyperparameters used below: qlty_* tiling settings, batch sizes/shuffle flags,
# criterion/optimizer names, learning rate, class weights and epoch count.
model_parameters = parameters['model_parameters']
# Bare expression: display the hyperparameters as the cell output.
model_parameters

In [None]:
# Training dataset: data and masks are both read from Tiled (a URI and API key
# for each). The qlty_* values from model_parameters configure how images are
# cut into patches (window, step, border).
# NOTE(review): exact tiling semantics live in TiledDataset — confirm there.
dataset = TiledDataset(
        data_tiled_uri=parameters['data_tiled_uri'],
        data_tiled_api_key=parameters['data_tiled_api_key'],
        mask_tiled_uri=parameters['mask_tiled_uri'],
        mask_tiled_api_key=parameters['mask_tiled_api_key'],
        qlty_window=model_parameters['qlty_window'],
        qlty_step=model_parameters['qlty_step'],
        qlty_border=model_parameters['qlty_border'],
        transform=transforms.ToTensor()
        )

In [None]:
# Display the Tiled client used for the masks.
dataset.mask_client

In [None]:
# Display metadata of `mask_client_one_up` (presumably the parent node of the
# mask client, judging by the name — confirm in TiledDataset).
dataset.mask_client_one_up.metadata

In [None]:
import matplotlib.pyplot as plt
# Quick visual check of the first three dataset items.
# NOTE(review): custom_collate below expects items to be 4D tensors or
# (data, mask) tuples; imshow needs a 2D/3D image, so this loop may need to
# index into each item (e.g. a single patch/channel) — confirm.
for n in range(3):
    plt.imshow(dataset[n])
    plt.colorbar()
    plt.show()

In [None]:
# Display the mask client again.
dataset.mask_client

In [None]:
def custom_collate(batch, verbose=False):
    '''
    Collate dataset items that are already stacks of patches.

    Items that are 4D tensors, or tuples whose first element is a 4D tensor
    (e.g. ``(data, mask)``), are concatenated along dim 0 instead of being
    stacked. The batch therefore stays 4D, with the patches of all items
    merged along the first axis. Anything else is delegated to torch's
    ``default_collate``.

    Parameters
    ----------
    batch : list
        Items returned by the dataset for one DataLoader batch.
    verbose : bool, optional
        Print type/shape diagnostics. Off by default because the DataLoader
        calls this function once per batch.

    Returns
    -------
    torch.Tensor, tuple of torch.Tensor, or whatever default_collate returns
    '''
    elem = batch[0]
    if verbose:
        print(f'elem type: {type(elem)}')

    # Only index into tuples, so that non-indexable items (e.g. ints) still
    # reach the default_collate fallback instead of raising here.
    first = elem[0] if isinstance(elem, tuple) and len(elem) > 0 else None
    if isinstance(first, torch.Tensor) and first.ndim == 4:
        data, mask = zip(*batch)
        # Concatenate on the first dim without introducing another dim -> stay 4D.
        concated_data = torch.cat(data, dim=0)
        concated_mask = torch.cat(mask, dim=0)
        if verbose:
            print(f'concated_data shape: {concated_data.shape}')
            print(f'concated_mask shape: {concated_mask.shape}')
        return concated_data, concated_mask

    if isinstance(elem, torch.Tensor) and elem.ndim == 4:
        # Concatenate on the first dim without introducing another dim -> stay 4D.
        concated_data = torch.cat(batch, dim=0)
        if verbose:
            print(f'batch size: {len(batch)}')
            print(f'concated_data shape: {concated_data.shape}')
        return concated_data

    # Fall back to `default_collate` as suggested by the PyTorch documentation.
    return default_collate(batch)

def train_val_split(dataset, parameters):
    '''
    Split ``dataset`` into training and validation DataLoaders.

    This notebook-local version shadows the ``train_val_split`` imported from
    ``src.seg_utils`` so that both loaders use ``custom_collate``.

    The split uses torch's ``random_split``, so it differs on every call. It
    does not take class balance into account; a stratified split would need
    a custom sampler.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        The dataset to split (a TiledDataset in this notebook).
    parameters : dict
        Must provide 'batch_size_train', 'shuffle_train', 'batch_size_val',
        'shuffle_val' and 'val_pct' (fraction of items used for validation).

    Returns
    -------
    tuple
        ``(train_loader, val_loader)``. ``val_loader`` is None when the
        dataset holds a single item. Otherwise both splits contain at least
        one item.

    Raises
    ------
    ValueError
        If ``dataset`` is empty.
    '''
    # DataLoader parameters (training data is reshuffled each pass when
    # 'shuffle_train' is True).
    train_loader_params = {'batch_size': parameters['batch_size_train'],
                           'shuffle': parameters['shuffle_train']}
    val_loader_params = {'batch_size': parameters['batch_size_val'],
                         'shuffle': parameters['shuffle_val']}

    n_items = len(dataset)
    val_size = int(parameters['val_pct'] * n_items)
    print(f'length of dataset: {n_items}')
    print(f'length of val_size: {val_size}')

    if n_items == 0:
        raise ValueError('Cannot split an empty dataset into train/val sets.')
    if n_items == 1:
        # Nothing to hold out: train on the single item and skip validation.
        train_loader = DataLoader(dataset, **train_loader_params, collate_fn=custom_collate)
        return train_loader, None

    # Keep at least one item on each side of the split. Without this, a tiny
    # val_pct gives an empty validation set and val_pct >= 1 gives an empty
    # training set.
    val_size = min(max(val_size, 1), n_items - 1)
    train_size = n_items - val_size
    train_data, val_data = random_split(dataset, [train_size, val_size])
    print(f'train_data size: {len(train_data)}')
    train_loader = DataLoader(train_data, **train_loader_params, collate_fn=custom_collate)
    val_loader = DataLoader(val_data, **val_loader_params, collate_fn=custom_collate)
    return train_loader, val_loader

In [None]:
# Build train/val loaders with the notebook-local train_val_split defined in
# the previous cell (it shadows the src.seg_utils import and uses custom_collate).
train_loader, val_loader = train_val_split(dataset, model_parameters)

In [None]:
from dlsia.core.networks import sms3d, smsnet
from dlsia.core import helpers

def construct_2dsms_ensembler(n_networks,
                              in_channels,
                              out_channels,
                              layers,
                              alpha=0.0,
                              gamma=0.0,
                              hidden_channels=None,
                              dilation_choices=None,
                              P_IL=0.995,
                              P_LO=0.995,
                              P_IO=True,
                              parameter_bounds=None,
                              max_trial=100,
                              network_type="Regression",
                              parameter_counts_only=False
                              ):
    """
    Build an ensemble of randomly wired 2D SMSNets.

    Each member is produced by ``smsnet.random_SMS_network``. When
    ``parameter_bounds`` is given, candidates are resampled until the
    parameter count lies strictly between ``min(parameter_bounds)`` and
    ``max(parameter_bounds)``.

    Parameters
    ----------
    n_networks : int
        Number of networks to generate.
    in_channels, out_channels : int
        Input / output channels of every network.
    layers : int
        Number of layers. Also used as 'LL_max_degree' in the layer probabilities.
    alpha, gamma : float
        Passed as 'LL_alpha' / 'LL_gamma' in ``layer_probabilities``.
    hidden_channels : list, optional
        Forwarded as ``hidden_out_channels``. Defaults to ``[3 * out_channels]``.
    dilation_choices : list, optional
        Candidate dilations. Defaults to ``[1, 2, 3, 4]``.
    P_IL, P_LO, P_IO
        Passed as the 'IL', 'LO' and 'IO' entries of ``layer_probabilities``.
    parameter_bounds : sequence of numbers, optional
        Exclusive bounds on each network's parameter count.
    max_trial : int
        Maximum number of candidates sampled per network when
        ``parameter_bounds`` is set.
    network_type : str
        Forwarded to ``random_SMS_network`` (e.g. "Regression" or "Classification").
    parameter_counts_only : bool
        Return parameter counts instead of networks. Cannot be combined with
        ``parameter_bounds``.

    Returns
    -------
    list
        ``n_networks`` networks, or their parameter counts.

    Raises
    ------
    RuntimeError
        If no candidate within ``parameter_bounds`` is found after
        ``max_trial`` attempts.
    """
    if parameter_counts_only:
        assert parameter_bounds is None, \
            "parameter_counts_only cannot be combined with parameter_bounds"

    if hidden_channels is None:
        hidden_channels = [3 * out_channels]
    if dilation_choices is None:
        # Built per call to avoid sharing a mutable default argument.
        dilation_choices = [1, 2, 3, 4]

    layer_probabilities = {
        'LL_alpha': alpha,
        'LL_gamma': gamma,
        'LL_max_degree': layers,
        'LL_min_degree': 1,
        'IL': P_IL,
        'LO': P_LO,
        'IO': P_IO,
    }

    # Without bounds the first candidate is always accepted.
    n_trials = max(1, max_trial) if parameter_bounds is not None else 1

    networks = []
    for _ in range(n_networks):
        for _trial in range(n_trials):
            this_net = smsnet.random_SMS_network(in_channels=in_channels,
                                                 out_channels=out_channels,
                                                 layers=layers,
                                                 dilation_choices=dilation_choices,
                                                 hidden_out_channels=hidden_channels,
                                                 layer_probabilities=layer_probabilities,
                                                 sizing_settings=None,
                                                 dilation_mode="Edges",
                                                 network_type=network_type,
                                                 )
            pcount = helpers.count_parameters(this_net)
            if parameter_bounds is None or \
                    min(parameter_bounds) < pcount < max(parameter_bounds):
                break
        else:
            # Give up instead of resampling forever when the bounds are unreachable.
            raise RuntimeError(
                f"Could not generate network, check bounds: no candidate with "
                f"a parameter count inside {parameter_bounds} after {n_trials} trials"
            )
        networks.append(pcount if parameter_counts_only else this_net)
    return networks


In [None]:
# Generate the ensemble: three randomly wired 2D SMSNets, 1 input channel -> 3 output channels.
# NOTE(review): out_channels presumably must equal the number of mask classes — confirm.
net_ensemble = construct_2dsms_ensembler(
                              n_networks=3,
                              in_channels=1,
                              out_channels=3,
                              layers=5,# user choice; suggested range roughly 5 to 20
                              alpha = 0.0, # keep as is
                              gamma = 0.0, # keep as is
                              hidden_channels = None, # None -> [3 * out_channels]; otherwise forwarded as hidden_out_channels (suggested sizes 3 to 20)
                              dilation_choices = [1,2,3,4,5],
                              parameter_bounds = None, # leave as is: no parameter-count filtering
                              max_trial=10,
                              network_type="Classification", # "Classification" for the segmentation task
                              parameter_counts_only = False, # leave as is: return networks, not parameter counts
                           )

In [None]:
# Sanity check: number of networks generated.
len(net_ensemble)

In [None]:
# Define criterion and optimizer.
# The loss class is looked up by name in torch.nn (e.g. 'CrossEntropyLoss').
# YAML yields plain lists, but torch loss modules only accept a Tensor (or
# None) for `weight`, so convert the class weights first.
criterion_class = getattr(nn, model_parameters['criterion'])
class_weights = model_parameters['weights']
if class_weights is not None and not isinstance(class_weights, torch.Tensor):
    class_weights = torch.tensor(class_weights, dtype=torch.float32)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# ignore_index=-1: target pixels labelled -1 do not contribute to the loss.
# .to(device) moves the weight buffer onto the same device as the network.
criterion = criterion_class(weight=class_weights,
                            ignore_index=-1,
                            size_average=None
                            ).to(device)

# Train each ensemble member independently, then save its parameters.
for idx, net in enumerate(net_ensemble):
    # Fresh optimizer per network, bound to that network's parameters.
    optimizer_class = getattr(optim, model_parameters['optimizer'])
    optimizer = optimizer_class(net.parameters(), lr=model_parameters['learning_rate'])
    net, results = train_segmentation(
        net,
        train_loader,
        val_loader,
        model_parameters['num_epochs'],
        criterion,
        optimizer,
        device,
        savepath=parameters['save_path'],
        saveevery=None,
        scheduler=None,
        show=0,
        use_amp=False,
        clip_value=None
    )
    # Save network parameters as <save_path>/<uid>_SMSNet<idx>.pt
    model_params_path = f"{parameters['save_path']}/{parameters['uid']}_SMSNet{idx}.pt"
    net.save_network_parameters(model_params_path)


# Inference

In [None]:
from dlsia.core.networks.baggins import model_baggin
from dlsia.core.networks.smsnet import SMSNetwork_from_file
from    qlty.qlty2D         import  NCYXQuilt

In [None]:
# Inference dataset: data only (no masks). qlty_* tiling settings match those
# used for training.
dataset = TiledDataset(
        data_tiled_uri=parameters['data_tiled_uri'],
        mask_idx=parameters['mask_idx'], # Kept for quick inference for now; to be removed once the app provides this.
        data_tiled_api_key=parameters['data_tiled_api_key'],
        qlty_window=model_parameters['qlty_window'],
        qlty_step=model_parameters['qlty_step'],
        qlty_border=model_parameters['qlty_border'],
        transform=transforms.ToTensor()
        )

# Inference DataLoader. Batches must stay in dataset order: segment()
# stitches the patch predictions back together by position, so shuffling
# would scramble the reconstructed images. The config flag is therefore ignored.
if model_parameters.get('shuffle_inference'):
    print('Warning: ignoring shuffle_inference=True; inference batches must '
          'stay in order so patches can be stitched back together.')
inference_loader_params = {'batch_size': model_parameters['batch_size_inference'],
                           'shuffle': False}
inference_loader = DataLoader(dataset, **inference_loader_params, collate_fn=custom_collate)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [None]:
# Reload the trained ensemble members. The file list is built from the same
# f"{save_path}/{uid}_SMSNet{idx}.pt" pattern the training loop saves with,
# rather than hard-coding the files of one particular run.
import glob

model_prefix = f"{parameters['save_path']}/{parameters['uid']}_SMSNet"
# glob.escape keeps any wildcard characters in save_path/uid literal.
filenames = sorted(glob.glob(glob.escape(model_prefix) + '*.pt'))
if not filenames:
    raise FileNotFoundError(f'No saved SMSNet models match {model_prefix}*.pt')
list_of_smsnet = [SMSNetwork_from_file(network) for network in filenames]

In [None]:
def segment(net, device, inference_loader, qlty_object):
    """
    Run ``net`` over every batch of ``inference_loader`` and stitch the result.

    Parameters
    ----------
    net : callable
        Called as ``net(batch, device=device, return_std=True)``. It must
        return ``(mean_map, std_map)``; only the mean map is used.
    device : torch.device
        Device the input batches are moved to before calling ``net``.
    inference_loader : iterable of torch.Tensor
        Batches of patches, in the order ``qlty_object`` expects to stitch them.
    qlty_object : object with a ``stitch`` method
        Called as ``qlty_object.stitch(patches)`` and must return
        ``(stitched, weights)``.

    Returns
    -------
    numpy.ndarray
        Per-pixel labels: argmax of the stitched result over dim 1.

    Raises
    ------
    ValueError
        If ``inference_loader`` yields no batches.
    """
    patch_preds = []  # per-batch predictions, kept in loader order
    with torch.no_grad():
        for batch in inference_loader:
            # Necessary data recasting
            batch = batch.type(torch.FloatTensor).to(device)
            mean_map, _ = net(batch, device=device, return_std=True)
            # Move each result to host memory right away so predictions for the
            # whole dataset do not accumulate in device memory.
            patch_preds.append(mean_map.cpu())

    if not patch_preds:
        raise ValueError('inference_loader produced no batches to segment.')

    stitched_result, _ = qlty_object.stitch(torch.concat(patch_preds))
    # Argmax over the class dim turns scores into label predictions.
    seg = torch.argmax(stitched_result.cpu(), dim=1).numpy()
    print(f'Result array shape: {seg.shape}')
    print(f'Result array type: {type(seg)}')
    return seg

In [None]:
# Combine the trained SMSNets into a single ensemble predictor.
ensemble = model_baggin(models=list_of_smsnet, model_type='classification')
# The ensemble can also be called directly on a tensor with device=... and return_std=True to get (mean_map, std_map).
# Quilt sized to the last two (Y, X) dims of the raw data, with the same
# window/step/border as the dataset tiling, so patch predictions can be
# stitched back into full images.
qlty_object = NCYXQuilt(X=dataset.data_client.shape[-1], 
                        Y=dataset.data_client.shape[-2],
                        window = (model_parameters['qlty_window'], model_parameters['qlty_window']),
                        step = (model_parameters['qlty_step'], model_parameters['qlty_step']),
                        border = (model_parameters['qlty_border'], model_parameters['qlty_border'])
                            )
seg = segment(ensemble, device, inference_loader, qlty_object)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Show the predicted labels for the first entry along seg's leading axis
# (presumably the first image — confirm against the quilt's N dimension).
plt.imshow(seg[0])
plt.show()

In [None]:
# Distinct class labels present in the prediction.
np.unique(seg)