# Baseline Diversity-based Ensemble Selection

This demo provides examples of baseline diversity-based ensemble selection on CIFAR-10 and ImageNet.

In [None]:
import os
import time
import timeit
import numpy as np

import torch
from itertools import combinations

from EnsembleBench.frameworks.pytorchUtility import (
    calAccuracy,
    calAveragePredictionVectorAccuracy,
    calNegativeSamplesSet,
    calDisagreementSamplesNoGroundTruth,
    filterModelsFixed,
)

from EnsembleBench.groupMetrics import (
    calAllDiversityMetrics,
)
from EnsembleBench.teamSelection import (
    getNTeamStatistics,
)

%load_ext autoreload
%autoreload 2

## Dataset Configurations

You can download the extracted predictions for CIFAR-10 and ImageNet from the following Google Drive folder.
https://drive.google.com/drive/folders/18rEcjSpMSy-XN2bUQ3PfsBppwb874B8q?usp=sharing

In [None]:
# Experiment configuration.
# Only the extracted prediction results are used to calculate the diversity
# scores and perform ensemble selection.

dataset = 'cifar10'
# Metric abbreviations forwarded to calAllDiversityMetrics.
# NOTE(review): presumably Cohen's kappa, Q-statistics, binary disagreement,
# Fleiss' kappa, Kohavi-Wolpert variance and generalized diversity -- confirm
# against the EnsembleBench documentation.
diversityMetricsList = ['CK', 'QS', 'BD', 'FK', 'KW', 'GD']

if dataset == 'cifar10':
    predictionDir = './cifar10/prediction'
    models = ['densenet-L190-k40', 'densenetbc-100-12', 'resnext8x64d', 'wrn-28-10-drop', 'vgg19_bn', 
              'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110']
    # NOTE(review): maxModel looks like the index (into `models`) of the best
    # single model and maxModelAcc its accuracy in percent -- confirm.
    maxModel = 0
    maxModelAcc = 96.68
    targetAcc = 96.33 # accuracy of entire ensemble
elif dataset == 'imagenet':
    predictionDir = './imagenet/prediction'
    models = np.array(['AlexNet', 'DenseNet', 'EfficientNetb0', 'ResNeXt50', 'Inception3', 'ResNet152', 'ResNet18', 'SqueezeNet', 'VGG16', 'VGG19bn'])
    maxModel = 5
    maxModelAcc = 78.25
    targetAcc = 79.82 # accuracy of entire ensemble

else:
    # ValueError is a subclass of Exception, so existing `except Exception`
    # handlers keep working; the message names the offending value.
    raise ValueError(
        "Unsupported dataset: {!r} (expected 'cifar10' or 'imagenet')".format(dataset)
    )

# file extension of the per-model prediction files inside predictionDir
suffix = '.pt'

# Perform Baseline Diversity-based Ensemble Selection

In [None]:
# load prediction vectors
# For every model, read '<predictionDir>/<model><suffix>' and collect:
#   predictionVectorsList - softmax of the stored 'predictionVectors' entry
#   labelVectorsList      - the stored 'labelVectors' entry
#   tmpAccList            - first element returned by calAccuracy for that model
labelVectorsList = list()
predictionVectorsList = list()
tmpAccList = list()
for m in models:
    predictionPath = os.path.join(predictionDir, m+suffix)
    # map_location='cpu' lets files that were saved from CUDA tensors load on
    # machines without a GPU; everything is moved to the CPU below anyway.
    # NOTE: torch.load unpickles the file -- only load prediction files that
    # come from a trusted source.
    prediction = torch.load(predictionPath, map_location='cpu')
    predictionVectors = prediction['predictionVectors']
    # softmax over the last dimension; .cpu() is a no-op for CPU tensors
    predictionVectorsList.append(torch.nn.functional.softmax(predictionVectors, dim=-1).cpu())
    labelVectors = prediction['labelVectors']
    labelVectorsList.append(labelVectors.cpu())
    # accuracy is computed on the stored (pre-softmax) prediction vectors
    tmpAccList.append(calAccuracy(predictionVectors, labelVectors)[0].cpu())
    print(tmpAccList[-1])

# summary statistics over the single-model accuracies
minAcc = np.min(tmpAccList)
avgAcc = np.mean(tmpAccList)
maxAcc = np.max(tmpAccList)

In [None]:
# Gather the samples returned by calDisagreementSamplesNoGroundTruth for the
# loaded base models, using the label vector of the first model.
# NOTE(review): presumably these are the samples on which the base models
# disagree with each other -- confirm against EnsembleBench.
rawSampleID, rawSampleTarget, rawPredictions, rawPredVectors = calDisagreementSamplesNoGroundTruth(
    predictionVectorsList,
    labelVectorsList[0],
)

# convert every returned collection into a numpy array
sampleID = np.array(rawSampleID)
sampleTarget = np.array(rawSampleTarget)
predictions = np.array(rawPredictions)
predVectors = np.array([
    np.array([np.array(vector) for vector in vectorsOfSample])
    for vectorsOfSample in rawPredVectors
])

# settings for the diversity score calculation:
# the model count is taken from the length of the first prediction row
nModels = len(predictions[0])
modelIdx = [*range(nModels)]

In [None]:
# Compute the diversity scores of every candidate team (all model subsets of
# size 2..nModels).  Results are collected in three parallel arrays:
#   teamList[i]            - tuple of model indices forming the team
#   teamSizeList[i]        - number of models in that team
#   diversityScoresList[i] - one score per entry of diversityMetricsList
np.random.seed(0)  # fixed seed so the random sub-sampling is repeatable
crossValidation = True
crossValidationTimes = 3
nRandomSamples = 100

teamSizeList = []
teamList = []
diversityScoresList = []

startTime = timeit.default_timer()
for n in range(2, nModels+1):
    for selectedModels in combinations(modelIdx, n):
        # restrict the collected samples to the models of this team
        filteredSamples = filterModelsFixed(sampleID, sampleTarget, predictions, predVectors, selectedModels)
        teamSampleID, teamSampleTarget, teamPredictions, teamPredVectors = filteredSamples

        # teams without any usable sample are skipped entirely
        if len(teamPredictions) == 0:
            print("negative sample not found")
            continue

        if not crossValidation:
            # single evaluation on all samples of the team
            tmpMetrics = np.array(calAllDiversityMetrics(teamPredictions, teamSampleTarget, diversityMetricsList))
        else:
            # average the metrics over several random draws; np.random.choice
            # samples with replacement by default, so nRandomSamples may
            # exceed the number of available samples
            sampledMetrics = []
            for _ in range(crossValidationTimes):
                randomIdx = np.random.choice(np.arange(teamPredictions.shape[0]), nRandomSamples)
                sampledMetrics.append(
                    calAllDiversityMetrics(teamPredictions[randomIdx], teamSampleTarget[randomIdx], diversityMetricsList)
                )
            tmpMetrics = np.mean(np.array(sampledMetrics), axis=0)

        teamList.append(selectedModels)
        teamSizeList.append(n)
        diversityScoresList.append(tmpMetrics)
endTime = timeit.default_timer()
print("Time: ", endTime-startTime)

# teams have different sizes, hence the object dtype for teamList
teamList = np.array(teamList, dtype=object)
teamSizeList = np.array(teamSizeList)
diversityScoresList = np.array(diversityScoresList)


In [None]:
# Baseline selection: for each diversity metric, keep every team whose score
# is greater than the mean score of all evaluated teams.
#   QMetricsThreshold[dm]    - mean score of metric dm over all teams
#   QMetrics[dm][teamName]   - score of the team for metric dm
#   teamSelectedQAllDict[dm] - names of the teams above the threshold
QMetrics = {}
teamSelectedQAllDict = {}

# column j of diversityScoresList holds the scores of the j-th metric
QMetricsThreshold = {
    dm: np.mean(diversityScoresList[..., j])
    for j, dm in enumerate(diversityMetricsList)
}

print("Diversity threshold: ", QMetricsThreshold)

for i, t in enumerate(teamList):
    # team name = concatenated model indices, e.g. (0, 3, 7) -> "037"
    # NOTE: only unambiguous while every model index is a single digit
    teamName = "".join(map(str, t))
    teamScores = diversityScoresList[i]
    for j, dm in enumerate(diversityMetricsList):
        score = teamScores[j]
        QMetrics.setdefault(dm, {})[teamName] = score
        # the threshold is rounded to 3 decimals, the score is compared as-is
        if score > round(QMetricsThreshold[dm], 3):
            teamSelectedQAllDict.setdefault(dm, set()).add(teamName)

In [None]:
# Evaluate ensemble selection results

# Calculate the team accuracy (optional): every model subset of size >= 2 is
# evaluated with calAveragePredictionVectorAccuracy.
# teamAccuracyDict: team name (concatenated model indices) -> accuracy
# NOTE: the concatenated name is only unambiguous while every model index is a
# single digit.
import timeit
teamAccuracyDict = dict()
startTime = timeit.default_timer()
allModelIdx = range(len(models))
for n in range(2, len(models)+1):
    # iterate the combinations lazily instead of materialising them in a list
    for selectedModels in combinations(allModelIdx, n):
        tmpAccuracy = calAveragePredictionVectorAccuracy(predictionVectorsList, labelVectorsList[0], modelsList=selectedModels)[0].cpu().item()
        teamName = "".join(map(str, selectedModels))
        teamAccuracyDict[teamName] = tmpAccuracy
endTime = timeit.default_timer()
print("Accuracy Calculation Time (s): ", endTime-startTime)

# statistics for different diversity metrics
for dm in diversityMetricsList:
    # A metric has no entry when no team scored above its threshold (for
    # example when the threshold is NaN); report that instead of raising
    # KeyError and aborting the remaining metrics.
    selectedTeams = teamSelectedQAllDict.get(dm)
    if not selectedTeams:
        print(dm, "no team selected")
        continue
    print(dm, getNTeamStatistics(list(selectedTeams), teamAccuracyDict,
                                 minAcc, avgAcc, maxAcc, tmpAccList))