In [1]:
import bisect
import math
import os
import random

import pandas as ps
import psycopg2
import sqlalchemy
from trueskill import Rating, quality_1vs1, rate_1vs1

import globalConstants as gConst

In [4]:
class Outcomes:
    """Container for the three per-intensity outcomes of a single comparison.

    MultiCatTS.updateScores reads ``.slight``, ``.mod`` and ``.strong`` from
    its ``outcomes`` argument and hands each to ``updateOutcomes``, which
    treats "win" / "lose" specially and anything else as a draw; presumably
    instances of this class are what gets passed there — confirm with callers.

    Args:
        slight: outcome for the "slight" intensity.
        moderate: outcome for the "moderate" intensity (stored as ``mod``).
        strong: outcome for the "strong" intensity.
    """

    def __init__(self, slight, moderate, strong):
        self.strong = strong
        self.mod = moderate
        self.slight = slight

In [5]:
class MultiCatTS:
    """TrueSkill ratings of one image for one label at three intensities.

    Keeps one ``Rating`` each for the "strong", "moderate" and "slight"
    intensity levels, their averaged mean/sigma, and ``n`` (the number of
    comparisons this image has taken part in).  The state is mirrored in the
    table ``mt_tskill_<label>``, one row per ``image_id``.
    """

    def __init__(self, label, imgId, connection, reset):
        """Bind ``connection`` and load (or create) the row for ``imgId``.

        Args:
            label: outcome label; selects table ``mt_tskill_<label>``.
            imgId: image identifier used as the row key.
            connection: open psycopg2-style connection.
            reset: when true, the in-memory ratings start from ``Rating()``
                defaults instead of the stored values (the row is still
                read, and created if missing).
        """
        self.connection = connection
        self.loadSQLData(label, imgId, reset)

    def _setMetadata(self, label, tableName, imgId):
        """Recompute the averaged stats and record which row this object mirrors."""
        self.avgTSMean = self.calcAvgTSMean()
        self.avgTSSigma = self.calcAvgTSSigma()
        self.label = label
        self.tableName = tableName
        self.image_id = imgId

    def _rollbackQuietly(self):
        """Best-effort rollback so a failed statement does not leave the
        shared connection in an aborted transaction; errors are only printed."""
        try:
            self.connection.rollback()
        except Exception as e:
            print(str(e))

    def setValues(self, sqlData, label, tableName, imgId):
        """Initialise ratings from a stored row.

        ``sqlData`` is ordered as selected in ``loadSQLData``: (strong_mu,
        strong_sigma, mod_mu, mod_sigma, slight_mu, slight_sigma, n_sampled).
        """
        self.strongTS = Rating(mu=sqlData[0], sigma=sqlData[1])
        self.modTS = Rating(mu=sqlData[2], sigma=sqlData[3])
        self.slightTS = Rating(mu=sqlData[4], sigma=sqlData[5])
        self.n = sqlData[6]
        self._setMetadata(label, tableName, imgId)

    def setValuesRest(self, sqlData, label, tableName, imgId):
        """Initialise ratings to ``Rating()`` defaults, ignoring ``sqlData``.

        Used when the game is being reset (name kept as-is for callers).
        """
        self.strongTS = Rating()
        self.modTS = Rating()
        self.slightTS = Rating()
        self.n = 0
        self._setMetadata(label, tableName, imgId)

    def initializeMTTSTableRecord(self, tablename, imgId):
        """Insert a fresh ratings row (mu 25.0, sigma 8.333, n 0) for ``imgId``.

        Errors are printed and rolled back rather than raised.
        """
        # Table names cannot be bound parameters; ``tablename`` is built
        # internally from the label, only the values are parameterised.
        nameString = ('INSERT INTO ' + tablename
                      + ' (image_id,strong_mu,strong_sigma,mod_mu,mod_sigma,slight_mu,slight_sigma,n_sampled) ')
        valueString = 'VALUES (%s,%s,%s,%s,%s,%s,%s,%s)'
        insertString = nameString + valueString
        cur = None
        try:
            cur = self.connection.cursor()
            cur.execute(
                insertString,
                (imgId, '25.0', '8.333', '25.0', '8.333', '25.0', '8.333', '0'),
            )
            self.connection.commit()
        except Exception as e:
            print(str(e))
            self._rollbackQuietly()
        finally:
            if cur is not None:
                cur.close()

    def loadSQLData(self, label, image_id, reset):
        """Read this image's row from ``mt_tskill_<label>``, creating it if absent.

        Populates the ratings via ``setValuesRest`` (when ``reset``) or
        ``setValues``.  Errors are printed (and the transaction rolled back)
        rather than raised.

        Returns:
            The fetched rows (empty list if the query failed early).
        """
        tableName = "mt_tskill_" + str(label)
        prefix = "SELECT strong_mu, strong_sigma, mod_mu, mod_sigma, slight_mu, slight_sigma,n_sampled "
        midfix = " FROM " + tableName
        # image_id is bound as a parameter instead of being spliced into the SQL.
        selectString = prefix + midfix + " where image_id = %s"
        rows = []
        cur = None
        try:
            cur = self.connection.cursor()
            cur.execute(selectString, (image_id,))
            rows = cur.fetchall()
            if len(rows) == 0:
                self.initializeMTTSTableRecord(tableName, image_id)
                cur.execute(selectString, (image_id,))
                rows = cur.fetchall()
            if reset:
                self.setValuesRest(rows[0], label, tableName, image_id)
            else:
                self.setValues(rows[0], label, tableName, image_id)
        except Exception as e:
            print(str(e))
            self._rollbackQuietly()
        finally:
            if cur is not None:
                cur.close()
        return rows

    def calcAvgTSMean(self):
        """Return the mean of the three intensity ratings' ``mu``."""
        return (self.slightTS.mu + self.modTS.mu + self.strongTS.mu) / 3.0

    def calcAvgTSSigma(self):
        """Return the mean of the three intensity ratings' ``sigma``."""
        return (self.slightTS.sigma + self.modTS.sigma + self.strongTS.sigma) / 3.0

    def calcMultiCatGame(self, competitorTS):
        """Return a match-quality score against ``competitorTS``.

        Sums ``quality_1vs1`` over the three intensities and divides by 0.3,
        i.e. ten times their mean, rounded to 1 decimal.
        NOTE(review): 0.3 (rather than 3.0 as in calcAvgTSMean) may be a
        typo; it is left unchanged because SVImageSet derives integer
        comparison weights from this scale.
        """
        slightTieProb = quality_1vs1(self.slightTS, competitorTS.slightTS)
        modTieProb = quality_1vs1(self.modTS, competitorTS.modTS)
        strongTieProb = quality_1vs1(self.strongTS, competitorTS.strongTS)
        return round((slightTieProb + modTieProb + strongTieProb) / 0.3, 1)

    def updateOutcomes(self, outcome, TSP1, TSP2):
        """Rate one game and return ``(newTSP1, newTSP2)``.

        ``outcome`` is from player 1's perspective: "win", "lose", or
        anything else for a draw.
        """
        if outcome == "lose":
            newp2, newp1 = rate_1vs1(TSP2, TSP1)
            return (newp1, newp2)
        if outcome == "win":
            return rate_1vs1(TSP1, TSP2)
        return rate_1vs1(TSP1, TSP2, drawn=True)

    def updateSQLData(self):
        """Write the current ratings and ``n`` back to this image's row.

        Returns:
            True on commit, False on error (printed and rolled back).
        """
        preString = "UPDATE " + str(self.tableName) + " SET "
        varString = "strong_mu=%s,strong_sigma=%s,mod_mu=%s,mod_sigma=%s,slight_mu=%s,slight_sigma=%s,n_sampled=%s "
        whereString = "WHERE image_id = %s"
        updateString = preString + varString + whereString
        varArray = [
            self.strongTS.mu, self.strongTS.sigma,
            self.modTS.mu, self.modTS.sigma,
            self.slightTS.mu, self.slightTS.sigma,
            self.n, self.image_id,
        ]
        cur = None
        try:
            cur = self.connection.cursor()
            cur.execute(updateString, varArray)
            self.connection.commit()
            return True
        except Exception as e:
            print(str(e))
            self._rollbackQuietly()
            return False
        finally:
            if cur is not None:
                cur.close()

    def updateScores(self, outcomes, competitor, updateSQL=False):
        """Play one three-intensity game against ``competitor`` and update both.

        Args:
            outcomes: object with ``slight``, ``mod`` and ``strong`` outcome
                strings (from this image's perspective).
            competitor: the other image's MultiCatTS for the same label.
            updateSQL: when true, persist both objects afterwards.
        """
        self.slightTS, competitor.slightTS = self.updateOutcomes(outcomes.slight, self.slightTS, competitor.slightTS)
        self.modTS, competitor.modTS = self.updateOutcomes(outcomes.mod, self.modTS, competitor.modTS)
        self.strongTS, competitor.strongTS = self.updateOutcomes(outcomes.strong, self.strongTS, competitor.strongTS)
        # Refresh both cached averages; previously avgTSMean went stale here.
        for player in (self, competitor):
            player.avgTSMean = player.calcAvgTSMean()
            player.avgTSSigma = player.calcAvgTSSigma()
            player.n += 1
        if updateSQL:
            self.updateSQLData()
            competitor.updateSQLData()

In [6]:
class SVImage:
    """A single street-view image with one MultiCatTS rating per label.

    ``selectWeight[label]`` is the label's average rating sigma raised to the
    power 1.5, so less certain images get a larger weight.
    """

    def __init__(self, imgId, labels, http, connection, reset):
        """Create the per-label ratings and their selection weights.

        Args:
            imgId: image identifier.
            labels: labels to build ratings for.
            http: image URL/address, stored as-is.
            connection: DB connection handed to each MultiCatTS.
            reset: forwarded to each MultiCatTS.
        """
        self.connection = connection
        self.imgId = imgId
        self.http = http
        self.createTSDict(labels, imgId, reset)
        self.selectWeight = {}
        self.initializeSelectWeights(labels)

    def createTSDict(self, labels, imgId, reset):
        """Build ``tsDict`` mapping each label to its MultiCatTS."""
        self.tsDict = {}
        for lbl in labels:
            self.tsDict[lbl] = MultiCatTS(lbl, imgId, self.connection, reset)

    def getTS(self, label):
        """Return the MultiCatTS for ``label`` (KeyError if unknown)."""
        return self.tsDict[label]

    def initializeSelectWeights(self, labels):
        """Compute the selection weight of every label in ``labels``."""
        for lbl in labels:
            self.updateSelectWeights(lbl)

    def updateSelectWeights(self, label):
        """Recompute the selection weight of ``label`` from its current sigma."""
        sigma = self.tsDict[label].avgTSSigma
        self.selectWeight[label] = math.pow(sigma, 1.5)

In [7]:
class SVImageSet:
    """A set of SVImage objects with weighted random pair sampling per label.

    For every label, ``startIndeces[label]`` is a prefix-sum array
    ``[0, w0, w0 + w1, ...]`` of integer selection weights (length
    ``numImgs + 1``); image ``i`` owns the half-open bucket
    ``[startIndeces[label][i], startIndeces[label][i + 1])``.
    """

    def __init__(self, inputDF, labels, connection, reset):
        """Build one SVImage per row of ``inputDF`` and the per-label weight arrays.

        Args:
            inputDF: DataFrame with (at least) 'imgId' and 'http' columns.
            labels: labels to rate each image on.
            connection: DB connection shared by all images.
            reset: forwarded to every image's ratings.
        """
        self.connection = connection
        self.numImgs = len(inputDF['imgId'])
        self.imgArray = self.createImageArray(inputDF, labels, reset)
        self.startIndeces = self.createWeightedIndecesAllLabels(labels)
        self.labels = labels
        self.imgIds = self.getImgIds()

    @staticmethod
    def _weightToInt(weight):
        """Convert a float weight to an integer bucket width (1 decimal, scaled by 10)."""
        return int(round(weight, 1) * 10)

    def getImgIds(self):
        """Return the image ids in ``imgArray`` order."""
        return [img.imgId for img in self.imgArray]

    def getImgFromId(self, imgId):
        """Return the SVImage with id ``imgId`` (ValueError if absent)."""
        return self.imgArray[self.imgIds.index(imgId)]

    def getImgIndexFromId(self, imgId):
        """Return the position of ``imgId`` in ``imgArray`` (ValueError if absent)."""
        return self.imgIds.index(imgId)

    def createImgRecord(self, currRecord, labels, reset):
        """Build an SVImage from one DataFrame row."""
        return SVImage(currRecord['imgId'], labels, currRecord['http'], self.connection, reset)

    def createImageArray(self, inputDF, labels, reset):
        """Build one SVImage per row of ``inputDF``, in row order."""
        return [
            self.createImgRecord(inputDF.iloc[recordIndex], labels, reset)
            for recordIndex in range(self.numImgs)
        ]

    def createdWeightedIndecesOneLabel(self, label):
        """Return the prefix-sum array of selection weights for ``label``."""
        samplingIndeces = [0]
        for img in self.imgArray:
            samplingIndeces.append(samplingIndeces[-1] + self._weightToInt(img.selectWeight[label]))
        return samplingIndeces

    def createWeightedIndecesAllLabels(self, labels):
        """Return ``{label: prefix-sum array}`` for every label."""
        return {label: self.createdWeightedIndecesOneLabel(label) for label in labels}

    def updateWeightedIndecesOneLabel(self, label):
        """Rebuild the prefix-sum array for ``label`` from current weights."""
        self.startIndeces[label] = self.createdWeightedIndecesOneLabel(label)

    def isWithinRange(self, currIndex, selectNumber, indexArray):
        """Return True if ``selectNumber`` falls in bucket ``currIndex``.

        The value equal to the grand total also counts for the last bucket.
        """
        if indexArray[currIndex] <= selectNumber:
            if indexArray[currIndex + 1] > selectNumber:
                return True
            if currIndex + 2 == len(indexArray) and indexArray[currIndex + 1] == selectNumber:
                return True
        return False

    def findImgInIndex(self, randSelectWeight, indexArray):
        """Return the image index whose bucket contains ``randSelectWeight``.

        Uses binary search on the non-decreasing prefix-sum array, so runs of
        zero-weight images are skipped (the old linear walk could loop forever
        on them).  A draw equal to the grand total maps to the last image,
        because the samplers draw with ``random.randint`` (inclusive).

        Raises:
            ValueError: if ``indexArray`` describes no images.
        """
        lastImg = len(indexArray) - 2
        if lastImg < 0:
            raise ValueError("cannot sample from an empty weight index")
        # Last position whose bucket start is <= the draw.
        pos = bisect.bisect_right(indexArray, randSelectWeight) - 1
        return min(max(pos, 0), lastImg)

    def _requireCandidate(self, indexArray, excluded):
        """Raise ValueError unless some image outside ``excluded`` can be drawn.

        findImgInIndex only returns images with a positive bucket width, plus
        the last image (which receives the draw equal to the total).  Without
        this check the rejection-sampling loops below would never terminate.
        """
        lastImg = len(indexArray) - 2
        for imgIndex in range(lastImg + 1):
            if imgIndex in excluded:
                continue
            if indexArray[imgIndex + 1] > indexArray[imgIndex] or imgIndex == lastImg:
                return
        raise ValueError("no selectable image left to sample")

    def sampleImageOnce(self, label):
        """Draw one image with probability proportional to its weight for ``label``.

        Returns:
            [SVImage, index].
        """
        indexArray = self.startIndeces[label]
        randSelectWeight = random.randint(0, indexArray[-1])
        imgIndex = self.findImgInIndex(randSelectWeight, indexArray)
        return [self.imgArray[imgIndex], imgIndex]

    def updateImage(self, updatedImage, imgIndex, label):
        """Replace the image at ``imgIndex`` and refresh the weights from there on."""
        self.imgArray[imgIndex] = updatedImage
        self.updatedWeightedIndeces(imgIndex, label)

    def updatedWeightedIndeces(self, updatedImgIndex, label):
        """Recompute the prefix sums for ``label`` from ``updatedImgIndex`` onward."""
        indexArray = self.startIndeces[label]
        for imgIndex in range(updatedImgIndex, len(self.imgArray)):
            weightToAdd = self._weightToInt(self.imgArray[imgIndex].selectWeight[label])
            indexArray[imgIndex + 1] = indexArray[imgIndex] + weightToAdd

    def randomlySampleImage(self, previouslySampledIndeces, label):
        """Weighted-sample an image whose index is not in ``previouslySampledIndeces``.

        Returns:
            [SVImage, index].
        Raises:
            ValueError: if no drawable image remains.
        """
        excluded = set(previouslySampledIndeces)
        self._requireCandidate(self.startIndeces[label], excluded)
        while True:
            sampledImg, sampledIndex = self.sampleImageOnce(label)
            if sampledIndex not in excluded:
                return [sampledImg, sampledIndex]

    def createComparisonWeights(self, sampledImgIndex, label):
        """Return prefix-summed weights of every image's match quality against
        the image at ``sampledImgIndex`` for ``label``."""
        refTS = self.imgArray[sampledImgIndex].tsDict[label]
        comparisonWeights = [0]
        for candidateImg in self.imgArray[:self.numImgs]:
            quality = candidateImg.tsDict[label].calcMultiCatGame(refTS)
            comparisonWeights.append(comparisonWeights[-1] + self._weightToInt(quality))
        return comparisonWeights

    def getComparisonImage(self, sampledImgIndex, previouslySampledIndeces, label):
        """Pick an opponent for ``sampledImgIndex``, weighted by match quality.

        The sampled image itself and ``previouslySampledIndeces`` are excluded.

        Returns:
            [SVImage, index].
        Raises:
            ValueError: if no drawable opponent remains.
        """
        comparisonWeights = self.createComparisonWeights(sampledImgIndex, label)
        excluded = set(previouslySampledIndeces)
        excluded.add(sampledImgIndex)
        self._requireCandidate(comparisonWeights, excluded)
        while True:
            randSelectWeight = random.randint(0, comparisonWeights[-1])
            imgIndex = self.findImgInIndex(randSelectWeight, comparisonWeights)
            if imgIndex not in excluded:
                return [self.imgArray[imgIndex], imgIndex]

    def randomlySelectImgComparisonPair(self, previouslySelectedIndeces, label):
        """Select one (sampled, opponent) pair not yet used.

        NOTE: ``previouslySelectedIndeces`` is extended in place with both
        indices; callers rely on that to avoid reuse.

        Returns:
            [sampledIndex, compareIndex, previouslySelectedIndeces].
        """
        sampledImage, sampledIndex = self.randomlySampleImage(previouslySelectedIndeces, label)
        compareImg, compareIndex = self.getComparisonImage(sampledIndex, previouslySelectedIndeces, label)
        previouslySelectedIndeces += [sampledIndex, compareIndex]
        return [sampledIndex, compareIndex, previouslySelectedIndeces]

    def randomlySelectImgComparisonSet(self, sampleSize, label):
        """Select ``sampleSize`` disjoint pairs and randomise their order and sides.

        Returns:
            [leftSet, rightSet] of image indices.
        """
        previouslySelectedIndeces = []
        origImgs = []
        compareImgs = []
        for _ in range(sampleSize):
            sampledIndex, compIndex, _ = self.randomlySelectImgComparisonPair(previouslySelectedIndeces, label)
            origImgs.append(sampledIndex)
            compareImgs.append(compIndex)
        leftSet, rightSet = self.shuffleOrigCompareImgs(origImgs, compareImgs)
        return [leftSet, rightSet]

    def shuffleOrigCompareImgs(self, origImgs, compareImgs):
        """Shuffle pair order and randomly swap left/right within each pair."""
        leftSet, rightSet = [], []
        newOrder = list(range(len(origImgs)))
        random.shuffle(newOrder)
        for index in range(len(origImgs)):
            isLeft = random.randint(0, 1)
            if isLeft == 0:
                leftSet.append(origImgs[newOrder[index]])
                rightSet.append(compareImgs[newOrder[index]])
            else:
                leftSet.append(compareImgs[newOrder[index]])
                rightSet.append(origImgs[newOrder[index]])
        return [leftSet, rightSet]


In [8]:
class CarImageSet:
    """Image metadata joined with car annotations, used for "cars" QA questions."""

    def __init__(self, connection):
        """Load all annotated images through ``connection`` into ``carDF``.

        ``carDF`` is None if the query failed (the error is printed).
        """
        self.connection = connection
        self.carDF = self.getCarImageData()

    def _rollbackQuietly(self):
        """Best-effort rollback so a failed query does not leave the shared
        connection in an aborted transaction; errors are only printed."""
        try:
            self.connection.rollback()
        except Exception as e:
            print(str(e))

    def convertSQLToDF(self, sqlRows):
        """Build a DataFrame from ``(car_cat, image_id, view_cat, http)`` rows.

        Columns are: the integer row positions (column ``0``), then 'imgId',
        'car_cat', 'view_cat', 'http'.
        """
        psDF = ps.DataFrame(list(range(len(sqlRows))))
        psDF['imgId'] = [row[1] for row in sqlRows]
        psDF['car_cat'] = [row[0] for row in sqlRows]
        psDF['view_cat'] = [row[2] for row in sqlRows]
        psDF['http'] = [row[3] for row in sqlRows]
        return psDF

    def getCarCatByImgId(self, imgId):
        """Return the ``car_cat`` value of the first row with ``imgId``.

        Raises IndexError if ``imgId`` is not present.
        """
        carImgRecord = self.carDF[self.carDF['imgId'] == imgId]
        return carImgRecord.iloc[0]['car_cat']

    def getCarImageData(self):
        """Query image_meta joined with car_images and return it as a DataFrame.

        Returns None on error (printed and rolled back).  The cursor is always
        closed, including on success.
        """
        queryString = 'SELECT car_cat,a.image_id,view_cat,http FROM image_meta as a join car_images as b on a.image_id = b.image_id'
        cur = None
        try:
            cur = self.connection.cursor()
            cur.execute(queryString)
            rows = cur.fetchall()
            return self.convertSQLToDF(rows)
        except Exception as e:
            print(str(e))
            self._rollbackQuietly()
            return None
        finally:
            if cur is not None:
                cur.close()

    def randomlySampleCarSet(self, view_cat):
        """Pick one image with cars and one without for ``view_cat``, in random sides.

        Returns:
            dict with 'label' ('cars'), 'idLeft', 'idRight', 'httpLeft', 'httpRight'.
        Raises:
            ValueError: if either group is empty for ``view_cat``.
        """
        matchingAngles = self.carDF[self.carDF['view_cat'] == view_cat]
        # Element-wise comparison on a pandas column; `is True` would not work here.
        withCars = matchingAngles[matchingAngles['car_cat'] == True]  # noqa: E712
        noCars = matchingAngles[matchingAngles['car_cat'] == False]  # noqa: E712
        if withCars.empty or noCars.empty:
            raise ValueError(
                "need at least one image with and one without cars for view_cat=" + str(view_cat)
            )
        # Draw a single random row instead of shuffling the whole frame.
        withCarsImg = withCars.sample(n=1).iloc[0]
        noCarsImg = noCars.sample(n=1).iloc[0]
        # randomly choose which image goes on the left and right
        if random.randint(0, 1) == 0:
            leftImg, rightImg = withCarsImg, noCarsImg
        else:
            leftImg, rightImg = noCarsImg, withCarsImg
        return {
            'label': 'cars',
            'idLeft': leftImg['imgId'],
            'idRight': rightImg['imgId'],
            'httpLeft': leftImg['http'],
            'httpRight': rightImg['http'],
        }

In [2]:
class MTGame:
    """Mechanical Turk comparison game over street-view image sets.

    Owns the database connection, the "straight" and "side" SVImageSets, and
    a CarImageSet used for "cars" quality-check questions.
    """

    # Values written when resetting; they match the '25.0' / '8.333' rows
    # inserted for new images and the Rating() defaults used in reset mode.
    DEFAULT_MU = 25.0
    DEFAULT_SIGMA = 25.0 / 3.0

    def __init__(self, db, reset):
        """Connect to ``db`` and load all image sets.

        Args:
            db: PostgreSQL database name (credentials come from gConst).
            reset: if true, reset all stored scores before loading.
        """
        self.setupConnection(db, gConst.PGUSER, gConst.PGPWORD)
        if reset:
            self.resetGame(gConst.OUTCOME_LABELS)
        self.loadSQLDatasets(gConst.OUTCOME_LABELS, reset)

    def setupConnection(self, db, user, pw, port=5432):
        """Open a psycopg2 connection to localhost and remember its arguments."""
        # Kept so _reconnectIfClosed can re-open an equivalent connection.
        self._connectionArgs = {'db': db, 'user': user, 'pw': pw, 'port': port}
        self.connection = psycopg2.connect(
            user=user,
            password=pw,
            host="localhost",
            port=port,
            database=db,
        )

    def _rollbackQuietly(self):
        """Best-effort rollback after a failed statement; errors are only printed."""
        try:
            self.connection.rollback()
        except Exception as e:
            print(str(e))

    def _reconnectIfClosed(self):
        """Re-open ``self.connection`` if psycopg2 reports it closed.

        Uses a ``DATABASE_URL`` (instance attribute, else module global) with
        ``sslmode="require"`` when one is defined; otherwise reuses the
        arguments recorded by ``setupConnection``.
        NOTE(review): the SVImageSet/CarImageSet objects hold their own
        reference to the old connection and are not updated here.

        Returns:
            True if a usable connection is available, False otherwise.
        """
        if not self.connection.closed:
            return True
        try:
            databaseUrl = getattr(self, 'DATABASE_URL', None) or globals().get('DATABASE_URL')
            if databaseUrl:
                self.connection = psycopg2.connect(databaseUrl, sslmode="require")
            else:
                self.setupConnection(**self._connectionArgs)
            return True
        except Exception as e:
            print("couldn't connect to database: " + str(e))
            return False

    def loadSQLDatasets(self, labels, reset):
        """Load the straight/side image sets and the car QA set."""
        straightDF = self.loadSQLData("straight")
        sideDF = self.loadSQLData("side")
        self.straightImgSet = SVImageSet(straightDF, labels, self.connection, reset)
        self.sideImgSet = SVImageSet(sideDF, labels, self.connection, reset)
        self.carImgSet = CarImageSet(self.connection)
        self.labels = labels

    def resetMTScores(self, tableName):
        """Reset every row of ``tableName`` to default ratings and n_sampled 0.

        Returns:
            True on commit, False on error (printed and rolled back).
        """
        preString = "UPDATE " + tableName + " SET "
        varString = "strong_mu=%s,strong_sigma=%s,mod_mu=%s,mod_sigma=%s,slight_mu=%s,slight_sigma=%s,n_sampled=%s "
        whereString = "WHERE n_sampled >=0"
        updateString = preString + varString + whereString
        mu, sigma = self.DEFAULT_MU, self.DEFAULT_SIGMA
        varArray = [mu, sigma, mu, sigma, mu, sigma, 0]
        cur = None
        try:
            cur = self.connection.cursor()
            cur.execute(updateString, varArray)
            self.connection.commit()
            return True
        except Exception as e:
            print(str(e))
            self._rollbackQuietly()
            return False
        finally:
            if cur is not None:
                cur.close()

    def resetGame(self, labels):
        """Reset the score table of every label (``mt_tskill_<label>``)."""
        for label in labels:
            self.resetMTScores("mt_tskill_" + label)

    def convertSQLRowsToPandas(self, sqlRows):
        """Build a DataFrame from ``(image_id, http, view_cat)`` rows.

        Columns: integer positions (column ``0``), 'imgId', 'http', 'angleCat'.
        """
        psDF = ps.DataFrame(list(range(len(sqlRows))))
        psDF['imgId'] = [row[0] for row in sqlRows]
        psDF['http'] = [row[1] for row in sqlRows]
        psDF['angleCat'] = [row[2] for row in sqlRows]
        return psDF

    def loadSQLData(self, direction):
        """Return the image metadata DataFrame for view category ``direction``."""
        return self.loadSQLMeta(direction)

    def loadSQLMeta(self, direction):
        """Query image_meta rows with ``view_cat = direction`` as a DataFrame.

        On error the message is printed and an empty frame is returned.
        """
        rows = []
        cur = None
        try:
            cur = self.connection.cursor()
            cur.execute(
                "SELECT image_id, http, view_cat FROM image_meta where view_cat=%s",
                (str(direction),),
            )
            rows = cur.fetchall()
        except Exception as e:
            print(str(e))
            self._rollbackQuietly()
        finally:
            if cur is not None:
                cur.close()
        return self.convertSQLRowsToPandas(rows)

    def randomlySampleOneLabel(self, sampleSize, label):
        """Sample ``sampleSize`` pairs for ``label`` from both view sets.

        Returns:
            [leftSetSide, rightSetSide, leftSetStraight, rightSetStraight].
        """
        leftSetSide, rightSetSide = self.sideImgSet.randomlySelectImgComparisonSet(sampleSize, label)
        leftSetStraight, rightSetStraight = self.straightImgSet.randomlySelectImgComparisonSet(sampleSize, label)
        return [leftSetSide, rightSetSide, leftSetStraight, rightSetStraight]

    def insertCarCompare(self, randomList, imgIdLeft, imgIdRight, httpLeft, httpRight, view_cat):
        """Append one random cars/no-cars QA pair for ``view_cat`` to the lists.

        The lists are mutated in place and also returned, in argument order.
        """
        carSample = self.carImgSet.randomlySampleCarSet(view_cat)
        randomList.append('cars')
        imgIdLeft.append(carSample['idLeft'])
        imgIdRight.append(carSample['idRight'])
        httpLeft.append(carSample['httpLeft'])
        httpRight.append(carSample['httpRight'])
        return [randomList, imgIdLeft, imgIdRight, httpLeft, httpRight]

    # reformat to fit MT survey, and shuffle label order
    def createSingleMTSet(self, imageSet, leftSide, rightSide, sampleIndex, view_cat):
        """Assemble one survey page: each label once in random order, plus one cars QA pair.

        Args:
            imageSet: SVImageSet the indices in ``leftSide``/``rightSide`` refer to.
            leftSide, rightSide: per-label lists (ordered like ``self.labels``)
                of image indices; entry ``sampleIndex`` of each is used.
            sampleIndex: which sampled pair to use for every label.
            view_cat: view category for the cars QA question.

        Returns:
            dict of parallel lists 'labels', 'leftIds', 'rightIds',
            'httpLeft', 'httpRight'.
        """
        tempRandomOrder = random.sample(self.labels, len(self.labels))
        imgIdLeft, imgIdRight, httpLeft, httpRight, randomList = [], [], [], [], []
        # The cars QA pair may land in any of len(labels) + 1 positions.
        carIndex = random.randint(0, len(tempRandomOrder))
        for currIndex in range(len(tempRandomOrder) + 1):
            if currIndex == carIndex:
                randomList, imgIdLeft, imgIdRight, httpLeft, httpRight = self.insertCarCompare(
                    randomList, imgIdLeft, imgIdRight, httpLeft, httpRight, view_cat
                )
                continue
            # Positions after the cars slot are shifted by one.
            if currIndex > carIndex:
                currLabel = tempRandomOrder[currIndex - 1]
            else:
                currLabel = tempRandomOrder[currIndex]
            labelIndex = self.labels.index(currLabel)
            randomList.append(currLabel)
            imgLeft = imageSet.imgArray[leftSide[labelIndex][sampleIndex]]
            imgRight = imageSet.imgArray[rightSide[labelIndex][sampleIndex]]
            imgIdLeft.append(imgLeft.imgId)
            httpLeft.append(imgLeft.http)
            imgIdRight.append(imgRight.imgId)
            httpRight.append(imgRight.http)
        return {
            'labels': randomList,
            'leftIds': imgIdLeft,
            'rightIds': imgIdRight,
            'httpLeft': httpLeft,
            'httpRight': httpRight,
        }

    def randomlySampleAllLabels(self, sampleSize, sampleCatInt):
        """Build ``sampleSize`` survey pages from one view set.

        Args:
            sampleSize: number of pages (and pairs per label) to produce.
            sampleCatInt: 1 selects the "straight" set; any other value "side".

        Returns:
            list of page dicts as produced by ``createSingleMTSet``.
        """
        if sampleCatInt == 1:
            imgSet, viewCat = self.straightImgSet, 'straight'
        else:
            imgSet, viewCat = self.sideImgSet, 'side'
        leftSets, rightSets = [], []
        for label in self.labels:
            # createSingleMTSet indexes each label's list by the page number,
            # so every label needs sampleSize pairs (not just one).
            leftSet, rightSet = imgSet.randomlySelectImgComparisonSet(sampleSize, label)
            leftSets.append(leftSet)
            rightSets.append(rightSet)
        return [
            self.createSingleMTSet(imgSet, leftSets, rightSets, imgIndex, viewCat)
            for imgIndex in range(sampleSize)
        ]

    def updateGame(self, label, leftImageId, rightImageId, gameResults, angle):
        """Apply one vote between two images and persist both ratings.

        ``gameResults`` must expose ``slight``/``mod``/``strong`` outcomes from
        the left image's perspective; ``angle`` 'side' selects the side set,
        anything else the straight set.
        """
        imgSet = self.sideImgSet if angle == 'side' else self.straightImgSet
        leftImg = imgSet.getImgFromId(leftImageId)
        rightImg = imgSet.getImgFromId(rightImageId)
        leftImgTS = leftImg.getTS(label)
        rightImgTS = rightImg.getTS(label)
        leftImgTS.updateScores(gameResults, rightImgTS, updateSQL=True)
        leftImg.updateSelectWeights(label)
        rightImg.updateSelectWeights(label)
        imgSet.updateWeightedIndecesOneLabel(label)

    def insertMTRecord(self, assign_id, jsonRecord):
        """Insert one survey response into ``mt_votes``.

        ``jsonRecord`` holds parallel lists 'labels' (survey wording),
        'idsLeft', 'idsRight' and 'votes'; each label adds three columns
        ``l_img_<label>``, ``r_img_<label>`` and ``<label>_vote``.

        Returns:
            True on commit, False on a failed insert (rolled back), None if
            no database connection could be obtained.
        """
        # Translate all labels once; column names therefore only ever come
        # from the fixed internal label list, never from raw input.
        outcomeLabels = self.translateLabels(jsonRecord['labels'], toSurvey=False)[0]
        nameString = 'INSERT INTO mt_votes (assign_id'
        valString = 'VALUES (%s'
        valArray = [assign_id]
        for labelIndex, outcomeLabel in enumerate(outcomeLabels):
            nameString += ',l_img_' + outcomeLabel + ',r_img_' + outcomeLabel + ',' + outcomeLabel + '_vote'
            valString += ',%s,%s,%s'
            valArray += [
                jsonRecord['idsLeft'][labelIndex],
                jsonRecord['idsRight'][labelIndex],
                jsonRecord['votes'][labelIndex],
            ]
        insertString = nameString + ') ' + valString + ')'
        if not self._reconnectIfClosed():
            return
        cur = None
        try:
            cur = self.connection.cursor()
            cur.execute(insertString, valArray)
            self.connection.commit()
            return True
        except Exception as e:
            print(str(e))
            self._rollbackQuietly()
            return False
        finally:
            if cur is not None:
                cur.close()

    def insertMTToImageRecord(self, assign_id, imageid):
        """Insert an (assign_id, image_id) link into ``mt_to_image``.

        Returns:
            True on commit, False on a failed insert (rolled back), None if
            no database connection could be obtained.
        """
        insertString = 'INSERT INTO mt_to_image (assign_id,image_id) VALUES (%s,%s)'
        if not self._reconnectIfClosed():
            return
        cur = None
        try:
            cur = self.connection.cursor()
            cur.execute(insertString, (assign_id, imageid))
            self.connection.commit()
            print("sucessfully commited insertion")
            print(str(imageid))
            return True
        except Exception as e:
            print(str(e))
            self._rollbackQuietly()
            return False
        finally:
            if cur is not None:
                cur.close()

    def translateLabels(self, inputLabels, toSurvey=True):
        """Translate between internal label names and survey wording.

        Internal labels are ``self.labels`` plus the QA label 'cars', paired
        positionally with the phrases below.
        NOTE(review): assumes ``self.labels`` has exactly four entries in the
        order matching LABEL_TRANSLATION — confirm against gConst.OUTCOME_LABELS.

        Returns:
            [translatedLabels, grammar]; ``grammar`` (the verb for the survey
            sentence) is only filled when ``toSurvey`` is true.
        Raises:
            ValueError: if a label is not in the source list.
        """
        LABEL_TRANSLATION = ['a higher quality of nature', 'more relaxing', 'safer', 'more beautiful']
        LABEL_GRAMMAR = ['has', 'is', 'is', 'is']
        QA_GRAMMAR = ['has']
        QA_LABEL = ['cars']
        QA_TRANSLATION = ['more cars']
        tempLabels = self.labels + QA_LABEL
        tempTranslation = LABEL_TRANSLATION + QA_TRANSLATION
        tempGrammar = LABEL_GRAMMAR + QA_GRAMMAR
        translatedLabels = []
        grammar = []
        for label in inputLabels:
            if toSurvey:
                position = tempLabels.index(label)
                translatedLabels.append(tempTranslation[position])
                grammar.append(tempGrammar[position])
            else:
                translatedLabels.append(tempLabels[tempTranslation.index(label)])
        return [translatedLabels, grammar]
    