# Customized Mechanical Turk API for R21 Survey #
**Author** Andrew Larkin <br>
**Principal Investigator** Perry Hystad <br>
**Date** April 30, 2020 <br>
**Summary** This class is a custom wrapper that abstracts interactions with the Amazon Mechanical Turk API exposed through boto3.

In [35]:
# library created by Amazon for automated mechanical turk processes
import boto3
import pandas as ps
import itertools
import time

In [1]:
class MTWrapper:
    """Thin wrapper around the boto3 Mechanical Turk client for the survey.

    The wrapper creates one HIT per US state (restricted by a locale
    qualification requirement) and collects assignment metadata for
    existing HITs into a pandas DataFrame.

    ``recordsToSampleByState`` may be a pandas DataFrame or a dict of lists
    with 'State' and 'Count' columns; it is always accessed by position.
    """

    # column order of the per-assignment table built by processHIT
    _HIT_INFO_COLUMNS = ['mt_assign_id','mt_work_id','mt_accept','mt_submit','mt_elapsed',
                         'mt_locale','mt_region','mt_hit_id']

    def __init__(self,accKey,accSecret,surveyFilepath=None,realSession=False):
        """Store credentials, connect to Mechanical Turk and set survey defaults.

        accKey / accSecret: AWS access key id and secret access key.
        surveyFilepath: path of the file whose contents are sent as the HIT
            question; required only for creating HITs.
        realSession: connect to the production endpoint when True, otherwise
            to the sandbox endpoint.
        """
        self.ACCESS_KEY = accKey
        self.ACCESS_SECRET = accSecret
        self.SURVEY_FILEPATH = surveyFilepath
        self.MTURK_REAL = 'https://mturk-requester.us-east-1.amazonaws.com'
        self.MTURK_SANDBOX = 'https://mturk-requester-sandbox.us-east-1.amazonaws.com'
        self.mtSession = self.setupConnection(realSession)
        # a state is never sampled beyond MAX_TOTAL_PER_STATE records in total,
        # and a single HIT never requests more than MIN_REQUEST_PER_STATE
        self.MAX_TOTAL_PER_STATE = 500
        self.MIN_REQUEST_PER_STATE = 200
        self.stateIndex = 0
        self.setSurveyMeta()
        self.createRegionDict()

    def setSurveyMeta(self):
        """Define the HIT properties shared by every HIT this class creates."""
        self.surveyMeta = {
            'Title':'Compare Google Images (2 minutes)',
            'Description':'Compare and rank Google street view pictures',
            'Keywords':'image, picture,environment, health, ranking, compare',
            'Reward':'0.20',
            # the three values below are passed to create_hit as seconds
            'Lifetime':172800,
            'AssignDuration':600,
            # NOTE(review): 1440 seconds is 24 minutes; 1440 is also the number of
            # minutes in a day - confirm the intended auto-approval delay
            'ApprovalDelay':1440
        }

    def createRegionDict(self):
        """Build self.regionMap, mapping a state code (or 'Other') to its region."""
        statesByRegion = {
            'NorthEast': ['CT','ME','NH','RI','VT','NJ','NY','PA','MA'],
            'MidWest': ['IN','IL','MI','OH','WI','IA','KS','MN','NE','ND','SD','MO'],
            'South': ['DE','DC','FL','GA','MD','NC','SC','VA','WV','AL','KY','MS','TN',
                      'AR','LA','OK','TX'],
            'West': ['AZ','CO','ID','NM','MT','UT','NV','WY','AK','CA','HI','OR','WA','Other'],
        }
        self.regionMap = {
            state: region
            for region, states in statesByRegion.items()
            for state in states
        }

    def setupConnection(self,realSession):
        """Connect to the production or sandbox endpoint and return the client."""
        if realSession:
            self.connectToMTReal()
        else:
            self.connectToMTSandbox()
        # the connect methods store the client on self.mturk and return nothing
        return self.mturk

    def setupSurveysByState(self,numRecordsByState):
        """Compute the per-state request sizes and the matching locale requirements."""
        self.calcNumRecordsToRequestByState(numRecordsByState)
        self.createQRs()

    def calcNumRecordsToRequestByState(self,numRecordsByState):
        """Add a 'sampSize' column holding the number of records to request per state.

        numRecordsByState must provide a 'Count' column with the number of
        records already collected per state. The argument is modified in
        place and stored as self.recordsToSampleByState.
        """
        numRecordsByState['sampSize'] = [
            min(self.MIN_REQUEST_PER_STATE, max(0,(self.MAX_TOTAL_PER_STATE - count)))
            for count in numRecordsByState['Count']
        ]
        self.recordsToSampleByState = numRecordsByState

    def createQRs(self):
        """Build one locale qualification requirement list per state, in state order."""
        self.stateQRs = [
            [{
                'QualificationTypeId': '00000000000000000071',
                'Comparator': 'In',
                'LocaleValues': [
                {
                    'Country': 'US',
                    'Subdivision': curState
                }],
            }]
            for curState in self.recordsToSampleByState['State']
        ]

    def hitNextState(self):
        """Create a HIT for the state at self.stateIndex, then advance the index."""
        if(self.stateIndex >= len(self.recordsToSampleByState['sampSize'])):
            print("""
                state index is at the end of the state list.  
                If you'd like to cycle through the states again, reset the index to 0
            """)
            return
        self.createHit()
        self.stateIndex+=1

    def _createHitAtIndex(self,index):
        """Create a HIT for the state at position index, if one is still needed.

        Nothing is created (and a message is printed) when no survey file
        was configured or when the state already has enough records.
        """
        if(self.SURVEY_FILEPATH is None):
            print("couldn't create HIT: survey filepath was not defined on class initialization")
            return
        # positional access works for both a DataFrame column and a plain list
        sampSize = int(list(self.recordsToSampleByState['sampSize'])[index])
        if(sampSize < 1):
            print(
                " couldn't create hit: n has already been reach for QR: %s " % (self.stateQRs[index],)
                )
            # MaxAssignments below 1 is not a valid request, so stop here
            return
        with open(file=self.SURVEY_FILEPATH,mode='r') as surveyFile:
            question = surveyFile.read()
        newHit = self.mturk.create_hit(
            Title = self.surveyMeta['Title'],
            Description = self.surveyMeta['Description'],
            Keywords = self.surveyMeta['Keywords'],
            Reward = self.surveyMeta['Reward'],
            MaxAssignments = sampSize,
            LifetimeInSeconds = self.surveyMeta['Lifetime'],
            AssignmentDurationInSeconds = self.surveyMeta['AssignDuration'],
            AutoApprovalDelayInSeconds = self.surveyMeta['ApprovalDelay'],
            Question = question,
            QualificationRequirements = self.stateQRs[index]
        )
        print("State Index: %i" % (index))
        print("state: " + str(list(self.recordsToSampleByState['State'])[index]))
        print("A new HIT has been created: " + str(newHit['HIT']['HITId']))

    def createHit(self):
        """Create a HIT for the state at self.stateIndex."""
        self._createHitAtIndex(self.stateIndex)

    def createHitForState(self,stateName):
        """Create a HIT for the named state without touching self.stateIndex.

        Raises ValueError when stateName is not in the 'State' column.
        """
        stateNames = list(self.recordsToSampleByState['State'])
        self._createHitAtIndex(stateNames.index(stateName))

    def getAssignmentsForHit(self,HITIdVal):
        """Return every assignment of a HIT, following pagination tokens."""
        assignList = []
        requestArgs = {'HITId': HITIdVal}
        while True:
            mTurkResponse = self.mturk.list_assignments_for_hit(**requestArgs)
            pageAssignments = mTurkResponse.get('Assignments', [])
            # extend a fresh list so no page is ever added twice
            assignList.extend(pageAssignments)
            nextToken = mTurkResponse.get('NextToken')
            if not pageAssignments or not nextToken:
                break
            requestArgs['NextToken'] = nextToken
            # short pause between page requests
            time.sleep(0.1)
        return(assignList)

    def processSingleAssignment(self,assign):
        """Return [assignment id, worker id, accept time, submit time, elapsed seconds]."""
        acceptTime = assign['AcceptTime']
        submitTime = assign['SubmitTime']
        # total_seconds() keeps whole days, which timedelta.seconds discards
        processTime = int((submitTime - acceptTime).total_seconds())
        return([assign['AssignmentId'],assign['WorkerId'],acceptTime,submitTime,processTime])

    def processAssignmentsForHit(self,assignments):
        """Split assignments into five parallel lists (ids, workers, accept, submit, elapsed)."""
        assignId, workerId, acceptTime,submitTime,workTime = [],[],[],[],[]
        columns = (assignId, workerId, acceptTime, submitTime, workTime)
        for assignment in assignments:
            for column, value in zip(columns, self.processSingleAssignment(assignment)):
                column.append(value)
        return([assignId,workerId,acceptTime,submitTime,workTime])

    def getHITLocale(self,hit):
        """Return the state code a HIT is restricted to.

        Returns "NA" when the first qualification requirement carries no
        locale information, and 'Other' when it lists more than one locale.
        """
        requirements = hit.get('QualificationRequirements') or []
        if(len(requirements) == 0):
            return "NA"
        localeValues = requirements[0].get('LocaleValues') or []
        if(len(localeValues) == 0):
            return "NA"
        if(len(localeValues) > 1):
            return 'Other'
        return(localeValues[0].get('Subdivision', "NA"))

    def processHIT(self,hit):
        """Return a DataFrame with one row per assignment of the given HIT."""
        hitIdVal = hit['HITId']
        assignments = self.getAssignmentsForHit(hitIdVal)
        locale = self.getHITLocale(hit)
        # locales without a region (e.g. "NA") get region 'NA' instead of failing
        region = self.regionMap.get(locale, 'NA')
        assignId,workerId,acceptTime,submitTime,workTime = self.processAssignmentsForHit(assignments)
        nRows = len(assignId)
        locales = list(itertools.repeat(locale,nRows))
        regions = list(itertools.repeat(region,nRows))
        hitIds = list(itertools.repeat(hitIdVal,nRows))
        df = ps.DataFrame(list(zip(assignId,workerId,acceptTime,submitTime,workTime,locales,regions,hitIds)),
                         columns = self._HIT_INFO_COLUMNS)
        return(df)

    def processHitsMeta(self,hits,maxHits):
        """Process at most maxHits HITs and store the combined table in self.hitInfo."""
        nHits = len(hits)
        print("nHits %i" % (nHits))
        frames = []
        for hit in hits[:max(0,maxHits)]:
            tempDF = self.processHIT(hit)
            # HITs without assignments contribute no rows
            if(len(tempDF) > 0):
                frames.append(tempDF)
        if frames:
            self.hitInfo = ps.concat(frames)
        else:
            self.hitInfo = ps.DataFrame(columns = self._HIT_INFO_COLUMNS)

    def setInfoAllHits(self,nHits=99999999):
        """Download the HIT list (up to about nHits) and build self.hitInfo from it."""
        hits = []
        requestArgs = {}
        while(len(hits) < nHits):
            hitMeta = self.mturk.list_hits(**requestArgs)
            pageHits = hitMeta.get('HITs', [])
            hits.extend(pageHits)
            nextToken = hitMeta.get('NextToken')
            if not pageHits or not nextToken:
                break
            requestArgs['NextToken'] = nextToken
            # short pause between page requests
            time.sleep(0.1)
        print("got hit meta")
        self.processHitsMeta(hits,nHits)
        self.hitInfo.drop_duplicates(inplace=True,keep="first")

    def getHitInfoAllHits(self):
        """Return the table built by setInfoAllHits / processHitsMeta."""
        return self.hitInfo

    def _connect(self,endpointUrl):
        """Create the boto3 'mturk' client for endpointUrl and store it on self.mturk."""
        self.mturk = boto3.client(
            'mturk',
            aws_access_key_id = self.ACCESS_KEY,
            aws_secret_access_key = self.ACCESS_SECRET,
            region_name='us-east-1',
            endpoint_url = endpointUrl
        )

    def connectToMTReal(self):
        """Connect self.mturk to the production endpoint."""
        self._connect(self.MTURK_REAL)

    def connectToMTSandbox(self):
        """Connect self.mturk to the sandbox endpoint."""
        self._connect(self.MTURK_SANDBOX)

    def getBalance(self):
        """Return the 'AvailableBalance' value reported for the account."""
        return self.mturk.get_account_balance()['AvailableBalance']

    def setStateIndex(self,newStateIndex):
        """Set the position of the state the next hitNextState call will use."""
        self.stateIndex = newStateIndex
        print("sucessfully set state index to: " + str(self.stateIndex))

    def getStateIndex(self):
        """Return the position of the state the next hitNextState call will use."""
        return self.stateIndex

    def getCountsByState(self):
        """Return the per-state table stored by calcNumRecordsToRequestByState."""
        return self.recordsToSampleByState