In [None]:
import pandas as ps
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sn

In [None]:
class TSAnalyzer:
    """Hold per-label score tables and summarise, plot and export them.

    Scores for each label are stored as a pandas DataFrame in ``scoreDict``
    (keyed by label).  The methods index those frames by column name, so a
    frame loaded for label ``L`` must provide whichever of these columns the
    called method reads: ``image_id``, ``urban_cat``, ``angle``, ``lat``,
    ``lon``, ``n_L``, ``strong_mu_L`` / ``mod_mu_L`` / ``slight_mu_L`` and the
    matching ``*_sigma_L`` columns.  ``calcAvgTSOneLabel`` derives
    ``avg_mu_L``, ``avg_sigma_L`` and ``abs_mu_L``; the descriptive-stats,
    correlation and plotting methods read those derived columns, so run
    ``calcAvgTSAllLabels`` first unless the frames already contain them.

    Attributes:
        scoreDict: mapping of label -> scores DataFrame.
        URBAN_MAP: mapping of ``urban_cat`` code -> display name.
        VIEWING_ANGLES: values of the ``angle`` column that are summarised.
        OUTCOME_LABELS: labels plotted by ``plotSumstats``.
    """

    # Labels used by getCorPlot when the caller does not pass its own list.
    _DEFAULT_COR_LABELS = ('greenspace', 'beauty', 'relax', 'safe')
    # Value subtracted from avg_mu to obtain abs_mu (plotted as "intensity").
    # NOTE(review): presumably the neutral/starting mu of the scoring scale;
    # confirm against the code that produces the scores.
    _MU_REFERENCE = 25

    def __init__(self, outcomes):
        """Create an empty analyzer.

        Args:
            outcomes: sequence of labels that ``plotSumstats`` will plot.
        """
        self.scoreDict = {}
        self.URBAN_MAP = {0: 'Seattle', 1: 'Urban', 2: 'Suburb'}
        self.VIEWING_ANGLES = ['straight', 'side']
        self.OUTCOME_LABELS = outcomes

    def loadScores(self, scores, label):
        """Store ``scores`` (a DataFrame) under ``label``, replacing any previous entry."""
        self.scoreDict[label] = scores
        print("completed loading scores for label %s" % (label,))

    def getScores(self, label):
        """Return the stored DataFrame for ``label`` (raises KeyError if unknown)."""
        return self.scoreDict[label]

    # ------------------------------------------------------------------ helpers

    def _describeSubset(self, label, rowMask, catName):
        """Describe n/avg_mu/avg_sigma of ``label`` for the rows selected by ``rowMask``.

        The three columns are renamed with ``catName`` as suffix so that
        results for several categories can be merged side by side.
        """
        sourceColumns = ['n_' + label, 'avg_mu_' + label, 'avg_sigma_' + label]
        # .loc with a mask and a column list returns a new frame, so renaming
        # its columns does not touch the stored scores.
        subset = self.scoreDict[label].loc[rowMask, sourceColumns]
        subset.columns = ['n_' + catName, 'avg_mu_' + catName, 'avg_sigma_' + catName]
        return subset.describe()

    @staticmethod
    def _mergeOnIndex(frames):
        """Inner-join ``frames`` side by side on their index; None if ``frames`` is empty."""
        merged = None
        for frame in frames:
            if merged is None:
                merged = frame
            else:
                merged = ps.merge(merged, frame, left_index=True, right_index=True)
        return merged

    def _mergeLabelColumns(self, labels, prefixes):
        """Inner-join, on ``image_id``, the ``prefix + label`` columns of every label.

        Column order within each label follows ``prefixes``, followed by
        ``image_id``.  Returns None when ``labels`` is empty.
        """
        merged = None
        for label in labels:
            wanted = [prefix + label for prefix in prefixes] + ['image_id']
            # Selecting with a list yields a copy of the stored frame.
            selected = self.scoreDict[label][wanted]
            if merged is None:
                merged = selected
            else:
                merged = ps.merge(merged, selected, how='inner', on=['image_id'])
        return merged

    # -------------------------------------------------------- descriptive stats

    def getDescriptiveStatsAllLabels(self):
        """Return a list with ``describe()`` of each label's frame, minus ``urban_cat``.

        ``urban_cat`` is a category code, so its summary statistics are
        dropped.  Raises KeyError if a frame has no ``urban_cat`` column.
        """
        # DataFrame.drop returns a new frame; its result must be used.
        return [
            scores.drop(columns=['urban_cat']).describe()
            for scores in self.scoreDict.values()
        ]

    def getDescriptiveStatsOneViewingAngle(self, label, viewCat, includeLabel=False):
        """Describe n/avg_mu/avg_sigma of ``label`` for rows whose ``angle`` equals ``viewCat``.

        Args:
            label: key into ``scoreDict``.
            viewCat: value of the ``angle`` column to keep.
            includeLabel: when True, output columns are suffixed with
                ``label + viewCat`` instead of just ``viewCat``.
        """
        catName = label + viewCat if includeLabel else viewCat
        rowMask = self.scoreDict[label]['angle'] == viewCat
        return self._describeSubset(label, rowMask, catName)

    def getDescriptiveStatsByViewingAngleOneLabel(self, label, includeLabel=False):
        """Return the per-angle descriptive stats of ``label`` merged side by side.

        One group of three columns is produced for every entry of
        ``VIEWING_ANGLES``, in that order.
        """
        return self._mergeOnIndex(
            self.getDescriptiveStatsOneViewingAngle(label, viewCat, includeLabel)
            for viewCat in self.VIEWING_ANGLES
        )

    def getDescrStatsByViewingAngle(self, includeLabel=False):
        """Return ``{label: per-angle descriptive stats}`` for every loaded label."""
        return {
            label: self.getDescriptiveStatsByViewingAngleOneLabel(label, includeLabel)
            for label in self.scoreDict
        }

    def getDescrStatsOneUrbanCat(self, label, urbanCat, includeLabel=False):
        """Describe n/avg_mu/avg_sigma of ``label`` for rows whose ``urban_cat`` equals ``urbanCat``.

        Args:
            label: key into ``scoreDict``.
            urbanCat: key of ``URBAN_MAP``; its mapped name is used as the
                column suffix (prefixed by ``label`` when ``includeLabel``).
            includeLabel: see above.
        """
        urbanName = self.URBAN_MAP[urbanCat]
        catName = label + urbanName if includeLabel else urbanName
        rowMask = self.scoreDict[label]['urban_cat'] == urbanCat
        return self._describeSubset(label, rowMask, catName)

    def getDescriptiveStatsByUrbanCatOneLabel(self, label, includeLabel=False):
        """Return the per-urban-category descriptive stats of ``label`` merged side by side.

        Categories appear in the key order of ``URBAN_MAP``.
        """
        return self._mergeOnIndex(
            self.getDescrStatsOneUrbanCat(label, urbanCat, includeLabel)
            for urbanCat in self.URBAN_MAP
        )

    def getDescriptiveStatsByUrbanCat(self, columnVals, includeLabel=False):
        """Return ``{label: per-urban-category descriptive stats}`` for every loaded label.

        Args:
            columnVals: unused; kept so existing callers keep working.
            includeLabel: whether to include the label in the output column
                names of each descriptive-stats frame.
        """
        return {
            label: self.getDescriptiveStatsByUrbanCatOneLabel(label, includeLabel)
            for label in self.scoreDict
        }

    # ------------------------------------------------------------- correlations

    def getCorMatrix(self):
        """Return the per-image table that correlations are computed from.

        For every loaded label the columns ``n_``, ``avg_mu_``, ``avg_sigma_``
        and ``abs_mu_`` are selected and the labels are inner-joined on
        ``image_id``.  Despite the name, no correlation is computed here; call
        ``.corr()`` on the result.  Returns None when no scores are loaded.
        """
        return self._mergeLabelColumns(
            list(self.scoreDict),
            ('n_', 'avg_mu_', 'avg_sigma_', 'abs_mu_'),
        )

    def getCorPlot(self, saveFilename="", labels=None):
        """Draw a heatmap of the correlations between per-label score columns.

        Args:
            saveFilename: when non-empty, the figure is also saved to this path.
            labels: labels to include, in display order.  Defaults to
                greenspace, beauty, relax and safe.

        Returns:
            The merged per-image DataFrame (inner join on ``image_id``) the
            correlations were computed from.
        """
        labels = list(self._DEFAULT_COR_LABELS) if labels is None else list(labels)
        corDict = self._mergeLabelColumns(
            labels,
            ('avg_mu_', 'abs_mu_', 'avg_sigma_', 'n_'),
        )

        # Group the heatmap rows/columns by quantity first, then by label.
        # Raw strings keep the LaTeX backslashes without invalid-escape warnings.
        headings = (
            ('n ', 'n_'),
            (r'$\sigma$ ', 'avg_sigma_'),
            (r'$\mu$ ', 'avg_mu_'),
            ('intensity ', 'abs_mu_'),
        )
        orderedDict = ps.DataFrame({
            heading + label: corDict[prefix + label]
            for heading, prefix in headings
            for label in labels
        })
        corMatrix = orderedDict.corr()

        # Annotate cells with the rounded coefficient; |r| < 0.1 is left blank.
        annot = np.round(np.asarray(corMatrix), 2)
        annot[abs(annot) < 0.1] = 0.
        annot = annot.astype('str')
        annot[annot == '0.0'] = ''

        fig, ax = plt.subplots(figsize=(15, 10))
        sn.heatmap(corMatrix, cmap="PRGn", annot=annot, fmt='', ax=ax)
        if saveFilename:
            plt.savefig(saveFilename)
        plt.show()
        return corDict

    # ------------------------------------------------------------------ plotting

    def createPlot(self, axs, label, curDict, fig):
        """Scatter ``avg_mu`` against ``avg_sigma`` for one label on ``axs``.

        Marker area is twice ``n_<label>`` and marker colour is
        ``abs_mu_<label>`` (colour limits 0-15, "winter" colormap).  A
        colorbar is appended to the right of ``axs`` and a marker-size legend
        is added.

        Args:
            axs: matplotlib Axes to draw the scatter on.
            label: label whose columns are read from ``curDict``.
            curDict: scores DataFrame for ``label``.
            fig: Figure owning ``axs``; used to create the colorbar.
        """
        curPlot = axs.scatter(
            curDict['avg_mu_' + label],
            curDict['avg_sigma_' + label],
            s=curDict['n_' + label] * 2,
            c=curDict['abs_mu_' + label],
            alpha=0.5,
            label=label,
        )
        curPlot.set_clim(0, 15)
        curPlot.set_cmap("winter")

        axs.title.set_text(label)
        axs.title.set_fontsize(16)
        # Fixed limits so that all panels share the same scale.
        axs.set_xlim(9, 40)
        axs.set_ylim(1.5, 6.0)
        axs.set_xlabel(r'$\mu$', fontsize=14)
        axs.set_ylabel(r'$\sigma$', fontsize=14)
        axs.tick_params(axis='both', which='major', labelsize=12)

        divider = make_axes_locatable(axs)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(curPlot, cax=cax, orientation='vertical', label=r'$\mu$ intensity')

        # Empty scatters serve as legend handles; sizes match s = 2 * n votes.
        # NOTE(review): plt.scatter / plt.legend act on pyplot's current axes,
        # not on `axs` -- presumably the colorbar axes appended above, which
        # the bbox_to_anchor offset seems tuned for. Confirm before changing.
        l1 = plt.scatter([], [], s=10, edgecolors='none', color='gray')
        l2 = plt.scatter([], [], s=20, edgecolors='none', color='gray')
        l3 = plt.scatter([], [], s=40, edgecolors='none', color='gray')
        plt.legend(
            [l1, l2, l3], ["5", "10", "20"], ncol=3, frameon=True, fontsize=10,
            handlelength=2, bbox_to_anchor=(-11., 1), borderpad=1,
            handletextpad=1, title='n votes', scatterpoints=1,
        )

    def plotSumstats(self, saveFile=""):
        """Plot one ``createPlot`` panel per entry of ``OUTCOME_LABELS``.

        Panels are laid out two per row.  The grid has at least two rows
        (an 18x14 inch figure) and grows by one row per extra pair of labels.

        Args:
            saveFile: when non-empty, the figure is also saved to this path.
        """
        labels = list(self.OUTCOME_LABELS)
        nRows = max(2, -(-len(labels) // 2))  # ceil(len / 2), never below 2
        fig, axs = plt.subplots(nRows, 2, figsize=(18, 7 * nRows), dpi=100)
        for curIndex, curLabel in enumerate(labels):
            rowIndex, colIndex = divmod(curIndex, 2)
            self.createPlot(axs[rowIndex, colIndex], curLabel, self.scoreDict[curLabel], fig)
        if saveFile:
            plt.savefig(saveFile)
        plt.show()

    # -------------------------------------------------------------------- lookup

    def getTSOneImageOneLabel(self, imageId, label):
        """Return the rows of ``label``'s frame whose ``image_id`` equals ``imageId``."""
        curDict = self.scoreDict[label]
        return curDict[curDict['image_id'] == imageId]

    # -------------------------------------------------------------------- export

    def writeCSV(self, csvFilepath):
        """Write all labels' scores, inner-joined on ``image_id``, to a CSV file.

        The first loaded label keeps all of its columns; ``lat``, ``lon``,
        ``angle`` and ``urban_cat`` are dropped from the others before
        merging so they are not duplicated.

        Raises:
            IndexError: if no scores have been loaded.
            KeyError: if a later frame lacks one of the dropped columns.
        """
        labels = list(self.scoreDict)
        labelVals = self.scoreDict[labels[0]]
        for label in labels[1:]:
            # drop() returns a new frame, so the stored scores stay intact.
            tempDict = self.scoreDict[label].drop(columns=['lat', 'lon', 'angle', 'urban_cat'])
            labelVals = ps.merge(labelVals, tempDict, on='image_id', how='inner')
        labelVals.to_csv(csvFilepath, index=False)
        print("finished writing true skill scores to csv file at filepath %s" % (csvFilepath,))

    # --------------------------------------------------------------- aggregation

    def calcAvgTSOneLabel(self, label):
        """Add ``avg_mu_``, ``avg_sigma_`` and ``abs_mu_`` columns for ``label``.

        ``avg_mu`` is the mean of the strong/mod/slight ``mu`` columns weighted
        by ``1 / sigma`` of the same tier; ``avg_sigma`` is the plain mean of
        the three sigmas; ``abs_mu`` is ``|avg_mu - 25|``.  The stored frame is
        modified in place.  A sigma of zero yields a division by zero
        (NumPy emits a warning and produces inf/nan).
        """
        curDict = self.scoreDict[label]
        tiers = ('strong', 'mod', 'slight')
        sigmas = [curDict[tier + '_sigma_' + label] for tier in tiers]
        weights = [1 / np.asarray(sigma) for sigma in sigmas]
        weightedMus = [
            np.multiply(weight, np.asarray(curDict[tier + '_mu_' + label]))
            for weight, tier in zip(weights, tiers)
        ]
        weightedScores = np.divide(sum(weightedMus), sum(weights))

        curDict['avg_mu_' + label] = weightedScores
        curDict['avg_sigma_' + label] = sum(sigmas) / len(sigmas)
        curDict['abs_mu_' + label] = abs(weightedScores - self._MU_REFERENCE)

    def calcAvgTSAllLabels(self):
        """Run ``calcAvgTSOneLabel`` for every loaded label."""
        for label in self.scoreDict:
            self.calcAvgTSOneLabel(label)
        print("finished calculating avg ts scores for all labels")
    