# Late Fusion

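To perform late fusion we need the outputs of every model being combined: either compute them on the fly, or save each model's activations and predictions beforehand and load them here. The joint prediction is then obtained either by summing the models' activations or by majority voting over their predictions.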

The workflow is:
* Evaluate each model and save its last-layer activations and its predictions separately.

    For this step you can use the `test_segmentation` function in `train_segmentation.py`.

Sum of activations:
* Load the activation for each model to be ensembled
* Sum up the activations of each model
* Obtain the joint prediction as the class with the highest summed activation
* Calculate the IoU metric
* Plot results

Majority voting:
* Load the predictions for each model to be ensembled
* Concatenate the predictions of each model
* Obtain the joint prediction as the per-pixel mode (most frequent class) of the models' predictions
* Calculate the IoU metric
* Plot results

In [None]:
from dataloader_seismic import *
from plots import *
from fcn import *
from aux_functions import *
from train_segmentation import *
from pathlib import Path
from torchinfo import summary

In [None]:
# Locate the saved per-fold activations/predictions: they live under
# <project root>/cross_validation/saved_activations, where the project root is
# the parent of the current working directory.
root = str(Path(os.getcwd()).parent)
loading_activation_path = os.path.join(root, 'cross_validation/saved_activations')
loading_activation_path

In [None]:
# Set the number of shots and the dataset
n_few_shot = 20
dataset = 'Parihaka_NZPM'       # 'F3_netherlands' 'Parihaka_NZPM'


# Plotting colors by dataset. The colormap is later used with
# vmin=0, vmax=len(colors)-1, so presumably one color per class label.
# NOTE(review): the F3 palette lists "lightseagreen" twice (indices 2 and 7),
# so those two labels are drawn in the same color — confirm this is intended.
dataset_colors = {
    'F3_netherlands': ["gold", "lawngreen", "lightseagreen", "orange", "blue",
                       "sienna", "violet", "lightseagreen", "darkorange", "red"],
    'Parihaka_NZPM': ['steelblue', 'darkturquoise', 'lightcoral', 'mediumseagreen',
                      'lavender', 'wheat'],
}
# Fail loudly on an unknown dataset instead of leaving `colors` unset (or, on
# a notebook re-run, silently reusing the palette of a previous dataset).
try:
    colors = dataset_colors[dataset]
except KeyError:
    raise ValueError(
        f'Unknown dataset {dataset!r}; expected one of {sorted(dataset_colors)}'
    ) from None
cmap_seismic = LinearSegmentedColormap.from_list("mycmap", colors)

## Sum of Activations

In [None]:

# Late fusion by SUM OF ACTIVATIONS, evaluated on each cross-validation fold.
#
# For every fold: load the saved activations of the jigsaw, rotation and
# frame-order fine-tuned models, add them, take the index of the largest
# summed activation along dim 1 as the fused prediction, and score it against
# the test labels with evaluate_iouf1. Every 100th section is plotted.
#
# Relies on names defined earlier in the notebook: root, dataset, n_few_shot,
# loading_activation_path, colors, cmap_seismic.

# (pretask, file suffix, display label) for each fused model. The suffixes
# reproduce how each file was located before.
# NOTE(review): frame_order outputs are read from '<name>_pred.npz' while the
# others use '<name>.npz' — confirm this matches how they were saved.
late_fusion_models = [
    ('jigsaw', '.npz', 'Jigsaw'),
    ('rotation', '.npz', 'Rotation'),
    ('frame_order', '_pred.npz', 'Frame Order'),
]

# Model details (identical for every fold, so built once).
args = {
    'dataset': dataset,  # F3_netherlands  Parihaka_NZPM
    'task': 'segmentation',  # segmentation
    'batch_size': 1,
    'num_workers': 4,
    'n_channels': 1,
    'train_type': 'fine_tune',  # sup_ssl  few_shot fine_tune
}

iou_per_fold = []

for fold in range(0, 5):

    # Define test set
    test_set = SeismicDataset(root=root, dataset_name=args['dataset'], split='test',
                              task=args['task'], train_type=args['train_type'], n_few_shot=n_few_shot,
                              cross_val=True, fold=fold)

    dataloader_test = DataLoader(test_set,
                                 batch_size=args['batch_size'],
                                 shuffle=False, num_workers=args['num_workers'])

    # Load each model's saved activations. Using np.load as a context manager
    # closes the .npz archive once the array has been copied out; otherwise
    # the file handle stays open for every model of every fold.
    activations = {}
    for pretask, suffix, label in late_fusion_models:
        name_model = f'{test_set.dataset_name}_cross_{pretask}_{n_few_shot}shot_fold{fold}'
        with np.load(os.path.join(loading_activation_path, name_model) + suffix) as loaded:
            print(f'reading {label.lower()}')
            activations[pretask] = torch.from_numpy(np.array(loaded['activation'], dtype='float32'))

    # Check shapes BEFORE adding: tensors of different shapes could broadcast
    # silently, and a post-hoc shape comparison would not reliably catch it.
    shapes = {pretask: tuple(act.shape) for pretask, act in activations.items()}
    if len(set(shapes.values())) != 1:
        raise ValueError(f'Activation shapes differ across models: {shapes}')

    # Get LATE FUSION sum of activations
    late_fusion_activ_sum = activations['jigsaw'] + activations['rotation'] + activations['frame_order']

    # Get the LATE FUSION PREDICTION: index of the max summed activation along dim 1
    print('getting late fusion prediction')
    late_pred_sum = late_fusion_activ_sum.max(1)[1]

    # Predictions are looked up by loader position (batch_size is 1), so the
    # number of saved sections must match the number of test batches.
    if late_pred_sum.shape[0] != len(dataloader_test):
        raise ValueError(
            f'{late_pred_sum.shape[0]} saved sections but {len(dataloader_test)} test batches for fold {fold}'
        )

    # Name of the fused result, used in the plot title and the IoU printout.
    fusion_name = f'{test_set.dataset_name}_late_fusion_sum_{n_few_shot}shot_fold{fold}'

    # Compute IoU
    iou_epoch = []
    n_sections = len(dataloader_test.dataset.sections)
    for idx, (_, test_labels, _) in enumerate(dataloader_test):
        if idx % 30 == 0:
            print(f'{idx}/{n_sections}')
        iou_mean, _ = evaluate_iouf1(late_pred_sum[idx], test_labels, dataloader_test.dataset.classes)
        iou_epoch.append(iou_mean)

        # Show some: each model's own prediction (derived from its activations
        # the same way as the fused one), the fused prediction, and the mask.
        if idx % 100 == 0:
            panels = [(activations[pretask][idx].max(0)[1], f'{label} -{idx}')
                      for pretask, _, label in late_fusion_models]
            panels.append((late_pred_sum[idx], f'Late Fusion -{idx}'))
            panels.append((test_labels.squeeze(), 'Mask'))

            fig, ax = plt.subplots(1, len(panels), figsize=(16, 9), constrained_layout=True)
            fig.suptitle(f'Model: {fusion_name}')
            for axis, (image, title) in zip(ax, panels):
                axis.imshow(image, cmap=cmap_seismic, vmin=0, vmax=len(colors) - 1)
                axis.set_yticks([])
                axis.set_xticks([])
                axis.set_title(title)

            plt.show()

    mean_iou = np.mean(iou_epoch)
    print(f'Model: {fusion_name} - Mean IoU: {np.array(mean_iou)}')
    iou_per_fold.append(mean_iou)
    print(iou_per_fold)

In [None]:
# Display the list of per-fold mean IoU values from the sum-of-activations loop.
iou_per_fold

In [None]:
# Print each fold's mean IoU on its own line (nothing when the list is empty).
if iou_per_fold:
    print('\n'.join(str(fold_iou) for fold_iou in iou_per_fold))

## Majority voting 


In [None]:

# Late fusion by MAJORITY VOTING, evaluated on each cross-validation fold.
#
# For every fold: load the saved predictions of the jigsaw, rotation and
# frame-order fine-tuned models, concatenate them along dim 1, take the mode
# along that dim as the fused prediction, and score it against the test labels
# with evaluate_iouf1. Every 100th section is plotted.
#
# Relies on names defined earlier in the notebook: root, dataset, n_few_shot,
# loading_activation_path, colors, cmap_seismic.

# (pretask, file suffix, display label) for each fused model. The suffixes
# reproduce how each file was located before.
# NOTE(review): frame_order outputs are read from '<name>_pred.npz' while the
# others use '<name>.npz' — confirm this matches how they were saved.
late_fusion_models = [
    ('jigsaw', '.npz', 'Jigsaw'),
    ('rotation', '.npz', 'Rotation'),
    ('frame_order', '_pred.npz', 'Frame Order'),
]

# Model details (identical for every fold, so built once).
args = {
    'dataset': dataset,  # F3_netherlands  Parihaka_NZPM
    'task': 'segmentation',  # segmentation
    'batch_size': 1,
    'num_workers': 4,
    'n_channels': 1,
    'train_type': 'fine_tune',  # sup_ssl  few_shot fine_tune
}

iou_per_fold = []

for fold in range(0, 5):

    # Define test set
    test_set = SeismicDataset(root=root, dataset_name=args['dataset'], split='test',
                              task=args['task'], train_type=args['train_type'], n_few_shot=n_few_shot,
                              cross_val=True, fold=fold)

    dataloader_test = DataLoader(test_set,
                                 batch_size=args['batch_size'],
                                 shuffle=False, num_workers=args['num_workers'])

    # Load each model's saved predictions. Using np.load as a context manager
    # closes the .npz archive once the array has been copied out; otherwise
    # the file handle stays open for every model of every fold.
    preds = {}
    for pretask, suffix, label in late_fusion_models:
        name_model = f'{test_set.dataset_name}_cross_{pretask}_{n_few_shot}shot_fold{fold}'
        with np.load(os.path.join(loading_activation_path, name_model) + suffix) as loaded:
            print(f'reading {label.lower()}')
            preds[pretask] = torch.from_numpy(np.array(loaded['preds'], dtype='float32'))

    # Voting only makes sense if every model predicted the same sections at
    # the same resolution.
    shapes = {pretask: tuple(pred.shape) for pretask, pred in preds.items()}
    if len(set(shapes.values())) != 1:
        raise ValueError(f'Prediction shapes differ across models: {shapes}')

    # Get LATE FUSION by majority voting.
    # NOTE(review): concatenating along dim 1 and squeezing dim 1 afterwards
    # assumes the saved preds are (N, 1, H, W) (the plotting code also squeezes
    # dim 0 of preds[...][idx]) — confirm.
    # When all three models disagree there is no majority; torch.mode does not
    # document which value wins such a tie. Concatenation order is unchanged.
    concat = torch.cat((preds['frame_order'], preds['rotation'], preds['jigsaw']), dim=1)
    print('Obtaining Mode')
    values, _ = torch.mode(concat, dim=1, keepdim=True)
    modas = values.squeeze(1)

    # Predictions are looked up by loader position (batch_size is 1), so the
    # number of saved sections must match the number of test batches.
    if modas.shape[0] != len(dataloader_test):
        raise ValueError(
            f'{modas.shape[0]} saved sections but {len(dataloader_test)} test batches for fold {fold}'
        )

    # Name of the fused result, used in the plot title and the IoU printout.
    fusion_name = f'{test_set.dataset_name}_late_fusion_vote_{n_few_shot}shot_fold{fold}'

    # Compute IoU
    iou_epoch = []
    n_sections = len(dataloader_test.dataset.sections)
    for idx, (_, test_labels, _) in enumerate(dataloader_test):
        if idx % 30 == 0:
            print(f'{idx}/{n_sections}')
        iou_mean, _ = evaluate_iouf1(modas[idx], test_labels, dataloader_test.dataset.classes)
        iou_epoch.append(iou_mean)

        # Show some: each model's prediction, the voted prediction, and the mask.
        if idx % 100 == 0:
            panels = [(preds[pretask][idx].squeeze(0), f'{label} -{idx}')
                      for pretask, _, label in late_fusion_models]
            panels.append((modas[idx], f'Late Fusion - Majority Voting -{idx}'))
            panels.append((test_labels.squeeze(), f'Mask - {idx}'))

            fig, ax = plt.subplots(1, len(panels), figsize=(16, 9), constrained_layout=True)
            fig.suptitle(f'Model: {fusion_name}')
            for axis, (image, title) in zip(ax, panels):
                axis.imshow(image, cmap=cmap_seismic, vmin=0, vmax=len(colors) - 1)
                axis.set_yticks([])
                axis.set_xticks([])
                axis.set_title(title)

            plt.show()

    mean_iou = np.mean(iou_epoch)
    print(f'Model: {fusion_name} - Mean IoU: {np.array(mean_iou)}')
    iou_per_fold.append(mean_iou)
    print(iou_per_fold)

In [None]:
# Display the list of per-fold mean IoU values from the majority-voting loop.
iou_per_fold

In [None]:
# Print each fold's mean IoU on its own line (nothing when the list is empty).
if iou_per_fold:
    print('\n'.join(str(fold_iou) for fold_iou in iou_per_fold))