In [16]:
import csv
import os

RESULTS_DIR = '/Users/jarridr/repos/decagon/finals-new'
INIT_DATA_SET_PROPORTION = 0.2

DATA_SET_ID_IDX   = 0
EPOCH_IDX         = 1
LOSS_IDX          = 2
LATENCY_IDX       = 3
EVALUATED_ALL_IDX = 4
EDGE_TYPE_IDX     = 5
AUROC_IDX         = 7
AUPRC_IDX         = 8
APK_IDX           = 9

class TrainingJobResults:
    def __init__(
        self, 
        activeLearningPolicy, 
        edgeTypePredicted, 
        dataSetProportion, 
        aurocVals, 
        auprcVals
    ):
        self.activeLearningPolicy = activeLearningPolicy 
        self.edgeTypePredicted = edgeTypePredicted
        self.dataSetProportion = dataSetProportion
        self.aurocVals = aurocVals
        self.auprcVals = auprcVals
        
    def combine(self, other):
        self.aurocVals.extend(other.aurocVals)
        self.auprcVals.extend(other.auprcVals)
        
class DataSetInformation:
    OkayPolicies = set(['RandomMasking', 'Greedy'])
    OkayEdgeTypes = set(['Neutropenia', 'Hyperglycaemia', 'Anosmia'])
    
    def __init__(
        self,
        activeLearningPolicy, 
        edgeTypePredicted,
        dataSetProportion, 
    ):  
        if activeLearningPolicy not in DataSetInformation.OkayPolicies:
            raise ValueError
            
        if edgeTypePredicted not in DataSetInformation.OkayEdgeTypes:
            raise ValueError
        
        self.activeLearningPolicy = activeLearningPolicy 
        self.edgeTypePredicted = edgeTypePredicted
        self.dataSetProportion = dataSetProportion
        
    def __hash__(self):
        return hash((
            self.activeLearningPolicy,
            self.edgeTypePredicted,
            self.dataSetProportion
        ))

def parseAll():    
    result = {}
    
    for trainingJobResultsDict in map(parseFile, os.listdir(RESULTS_DIR)):
        for trainingJobResults in trainingJobResultsDict.values():
            activeLearnPol = trainingJobResults.activeLearningPolicy
            if activeLearnPol not in result:
                result[activeLearnPol] = {}

            edgeTypePredicted = trainingJobResults.edgeTypePredicted
            if edgeTypePredicted not in result[activeLearnPol]:
                result[activeLearnPol][edgeTypePredicted] = {}

            dataSetProportion = trainingJobResults.dataSetProportion
            if dataSetProportion not in result[activeLearnPol][edgeTypePredicted]:
                result[activeLearnPol][edgeTypePredicted][dataSetProportion] = trainingJobResults
            else:
                result[activeLearnPol][edgeTypePredicted][dataSetProportion].combine(trainingJobResults)

    return result
        

def parseFile(rawFilename: str) -> TrainingJobResults:
    result = {}
    f = open('%s/%s' % (RESULTS_DIR, rawFilename))
    reader = csv.reader(f)

    # Skip the header
    try:
        next(reader)
    except:
        import pdb; pdb.set_trace()

    for iteration in reader:
        # Only use data for the first epoch
        if iteration[EPOCH_IDX] != '1':
            continue
        
        try:
            dataSetInformation = parseDataSetId(iteration[DATA_SET_ID_IDX])
        
            if dataSetInformation not in result:
                result[dataSetInformation] = TrainingJobResults(
                    dataSetInformation.activeLearningPolicy,
                    dataSetInformation.edgeTypePredicted,
                    dataSetInformation.dataSetProportion,
                    aurocVals=[],
                    auprcVals=[]
                )
                
            if float(iteration[AUROC_IDX]) < 0 or float(iteration[AUROC_IDX]) >= 1:
                import pdb; pdb.set_trace()

            if float(iteration[AUPRC_IDX]) < 0 or float(iteration[AUPRC_IDX]) >= 1:
                import pdb; pdb.set_trace()

            result[dataSetInformation].aurocVals.append(float(iteration[AUROC_IDX]))
            result[dataSetInformation].auprcVals.append(float(iteration[AUPRC_IDX]))
            
        except:
            continue
    
    f.close()
    
    return result
        
def parseDataSetId(dataSetId: str) -> DataSetInformation:
    activeLearningPolicyStartIdx = 0
    activeLearningPolicyEndIdx = dataSetId.find('ActiveLearner')
    
    edgeTypeStartIdx = dataSetId.find('DataSet') + len('DataSet')
    edgeTypeEndIdx = dataSetId.find('AdjMtx')
    
    activeLearningPolicyIterNum = int(dataSetId[-1])
    dataSetProportion = (20 + min(80, 2 ** activeLearningPolicyIterNum)) / 100
    
    return DataSetInformation(
        dataSetId[activeLearningPolicyStartIdx:activeLearningPolicyEndIdx],
        dataSetId[edgeTypeStartIdx:edgeTypeEndIdx],
        dataSetProportion
    )
    

In [21]:
# Plot the data!
from typing import Dict, Iterable
import seaborn as sns
import pandas as pd
import numpy as np

TrainDataResults = Dict[str, Dict[str, Dict[float, TrainingJobResults]]]

def getResultsDict(
    policyName: str,
    proportionToRes: Dict[float, TrainingJobResults], 
    plotAuprc
) -> Dict[str, Iterable]:
    def metricExtractor(results: Iterable[TrainingJobResults], plotAuprc: bool):
        preAttrName = 'auprcVals' if plotAuprc else 'aurocVals' 
        metrics = list(map(lambda x: np.mean(getattr(x, preAttrName)), results))
        
        # Slice at -4 to exclude the "Vals" substring from attrName
        return preAttrName[:-4], metrics

    attrName, metrics = metricExtractor(proportionToRes.values(), plotAuprc)
    
    return {
        'LearningPolicy': [policyName for _ in proportionToRes],
        'DataSetProportion': [dataSetProportion for dataSetProportion in proportionToRes.keys()],
        attrName: metrics
    }

def trainResultsAsDF(
    trainResults: TrainDataResults, 
    edgeType: str, 
    plotAuprc: bool
) -> pd.DataFrame:
    randomResultsDict = getResultsDict(
        'RandomMasking', 
        trainResults['RandomMasking'][edgeType], 
        plotAuprc
    )
    
    greedyResultsDict = getResultsDict(
        'Greedy', 
        trainResults['Greedy'][edgeType], 
        plotAuprc
    )
    
    randResultsDf   = pd.DataFrame(randomResultsDict)
    greedyResultsDf = pd.DataFrame(greedyResultsDict)
    
    return pd.concat([randResultsDf, greedyResultsDf])

def plotData(trainResults: TrainDataResults, edgeType: str, plotAuprc: bool) -> None:
    dataFrame = trainResultsAsDF(trainResults, edgeType, plotAuprc)
    print(dataFrame.sort_values(['LearningPolicy', 'DataSetProportion']))
    
    #yKey = 'auprc' if plotAuprc else 'auroc'
    #sns.lineplot(
   #     x='DataSetProportion',
    #    y=yKey,
     #   hue='variable',
      #  data=pd.melt(dataFrame, ['LearningPolicy']),
       # estimator=None,
        #style='choice'
    #)
    

In [17]:
allData: TrainDataResults = parseAll()

In [24]:
plotData(allData, 'Hyperglycaemia', plotAuprc=True)

  LearningPolicy  DataSetProportion     auprc
2         Greedy               0.21  0.524115
3         Greedy               0.22  0.522429
0         Greedy               0.24  0.498671
1         Greedy               0.28  0.502927
4         Greedy               0.36  0.504180
7         Greedy               0.52  0.525956
5         Greedy               0.84  0.489854
6         Greedy               1.00  0.510412
1  RandomMasking               0.21  0.485395
5  RandomMasking               0.22  0.480674
3  RandomMasking               0.24  0.512411
6  RandomMasking               0.28  0.479880
7  RandomMasking               0.36  0.525609
0  RandomMasking               0.52  0.506830
2  RandomMasking               0.84  0.527471
4  RandomMasking               1.00  0.556734


In [23]:
plotData(allData, 'Anosmia', plotAuprc=False)

  LearningPolicy  DataSetProportion      auroc
3         Greedy               0.21   0.487506
2         Greedy               0.22   0.447600
0         Greedy               0.24   0.471239
1         Greedy               0.28   0.497573
4         Greedy               0.36   0.513097
7         Greedy               0.52   0.463504
5         Greedy               0.84   0.452798
6         Greedy               1.00   0.527050
5  RandomMasking               0.21  10.538297
3  RandomMasking               0.22   0.523455
6  RandomMasking               0.24   0.506028
7  RandomMasking               0.28   1.286772
2  RandomMasking               0.36   0.518148
0  RandomMasking               0.52   0.473112
1  RandomMasking               0.84   0.581148
4  RandomMasking               1.00   0.458353


In [25]:
parseAll()

{'RandomMasking': {'Anosmia': {0.52: <__main__.TrainingJobResults at 0x7fb7f10e6dd0>,
   0.84: <__main__.TrainingJobResults at 0x7fb800c28210>,
   0.36: <__main__.TrainingJobResults at 0x7fb830e61610>,
   0.22: <__main__.TrainingJobResults at 0x7fb7f0e05dd0>,
   1.0: <__main__.TrainingJobResults at 0x7fb7d935cbd0>,
   0.21: <__main__.TrainingJobResults at 0x7fb830e32110>,
   0.24: <__main__.TrainingJobResults at 0x7fb7f0b4f210>,
   0.28: <__main__.TrainingJobResults at 0x7fb7d0947910>},
  'Neutropenia': {0.28: <__main__.TrainingJobResults at 0x7fb7d0e8f690>,
   0.36: <__main__.TrainingJobResults at 0x7fb7d1462910>,
   0.24: <__main__.TrainingJobResults at 0x7fb7f1727a10>,
   0.22: <__main__.TrainingJobResults at 0x7fb7d935ca50>,
   0.21: <__main__.TrainingJobResults at 0x7fb7d93b2dd0>},
  'Hyperglycaemia': {0.52: <__main__.TrainingJobResults at 0x7fb7d0dfd790>,
   0.21: <__main__.TrainingJobResults at 0x7fb7d0bb0250>,
   0.84: <__main__.TrainingJobResults at 0x7fb830e37390>,
   0.24: <