## Notebook for manually labeling during a SSAL experiment

Allows a user to specify the desired database and retrieves that database's active learning table (data needing labels from an oracle). It then loops through all of the active learning selections, checking whether a label already exists in the master labeling csv. If it does, that label is used and the loop moves to the next image. If a label is not in the master csv, the image is shown on the screen and the user is asked to input a label. When all images have been processed, the newly acquired labels are saved to the master csv and the database is updated. Upon finishing, the active learning table should be cleared out and empty.

In [1]:
import psycopg2 as pg
import pandas.io.sql as psql
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
from skimage.transform import resize
from IPython.display import clear_output
import time
import sys

from psycopg2.extensions import register_adapter, AsIs
def addapt_numpy_float64(numpy_float64):
    """Return *numpy_float64* wrapped in psycopg2's ``AsIs``."""
    wrapped = AsIs(numpy_float64)
    return wrapped


def addapt_numpy_int32(numpy_int32):
    """Return *numpy_int32* wrapped in psycopg2's ``AsIs``."""
    wrapped = AsIs(numpy_int32)
    return wrapped


def addapt_numpy_int64(numpy_int64):
    """Return *numpy_int64* wrapped in psycopg2's ``AsIs``."""
    wrapped = AsIs(numpy_int64)
    return wrapped


#tell psycopg2 which adapter to use for each numpy scalar type, so values
#pulled out of numpy/pandas objects can be passed as query parameters
register_adapter(np.int64, addapt_numpy_int64)
register_adapter(np.int32, addapt_numpy_int32)
register_adapter(np.float64, addapt_numpy_float64)

#maps each accepted label key to the display title for that class
title_dict = {
    "0": "0 - Closed Forest",
    "1": "1 - Woodland",
    "2": "2 - Shrubland/Thicket",
    "3": "3 - Dwarf Shrubland",
    "4": "4 - Herbaceous Veg",
    "5": "5 - Barren",
    "6": "6 - Wetland",
    "7": "7 - Open Water",
    "8": "8 - Cultivated Land",
    "9": "9 - Urban",
    "x": "x",
    "f": "f",
    "couldnt load": "couldnt load",
}

In [2]:
#updates a database's observation with a new label, using the image name as a primary key
def handle_observation_db(cursor, official_label, image_name):
    """Write an official label onto the observation row(s) for one image.

    Executes a single parameterized UPDATE on ``cursor`` that, for rows whose
    ``image_name`` matches, sets ``official_label`` to the given value, sets
    ``official_label_source`` to 'hackathon', sets ``partition`` to 'train'
    and resets ``al_rank`` to NULL. Nothing is committed here; that is left
    to the caller.

    Args:
        cursor: database cursor exposing ``execute(sql, params)``.
        official_label: label value to store.
        image_name: image name used to select the row(s) to update.
    """
    query_params = (official_label, image_name)
    cursor.execute(
        """UPDATE observation SET official_label = %s, official_label_source = 'hackathon', partition = 'train', al_rank = NULL WHERE image_name =%s;""",
        query_params,
    )

#updates the master labeled data dataframe with a new official label
def update_local_label_df(df, image_name, official_label):
    """Return a copy of ``df`` with one extra row for a newly labeled image.

    The new row has ``image_name`` and ``official_label`` set from the
    arguments and ``citizen_label`` set to NaN. ``df`` itself is not
    modified, and the existing index values are kept (the new row gets
    index 0), so callers that need a clean index must reset it themselves.

    Args:
        df: the master labeled-data DataFrame.
        image_name: name of the image that was labeled.
        official_label: label assigned to the image.

    Returns:
        A new DataFrame consisting of ``df`` followed by the new row.
    """
    new_row = pd.DataFrame(
        columns=["image_name", "citizen_label", "official_label"],
        data=[[image_name, np.nan, official_label]],
    )

    # DataFrame.append was removed in pandas 2.0; pd.concat is the
    # supported equivalent and preserves the same index behaviour.
    return pd.concat([df, new_row])
    
#saves the local master labeled dataframe to a csv with a given name    
def save_local_label_df(df, name):
    """Write ``df`` to the csv file ``name`` and report whether it worked.

    The index of ``df`` is reset in place (old index dropped) before writing,
    so the caller's DataFrame is modified. Any exception raised while saving
    is caught and reported on stdout instead of propagating. This function
    does not wait or retry on failure; that is up to the caller.

    Args:
        df: DataFrame to save.
        name: destination path of the csv file.

    Returns:
        True if the file was written, False if any exception occurred.
    """
    try:
        df.reset_index(drop=True, inplace=True)
        df.to_csv(name)
    except PermissionError:
        print("found permission error, you probably have the file open, waiting 10 seconds")
        saved = False
    except Exception as err:
        print("could not save, found exception", err)
        saved = False
    else:
        saved = True

    return saved
        
        
    
        
        

In [1]:
#gets user input, checks it's an input we know how to handle
def userInput():
    """Read lines from stdin until the user enters a recognised response.

    Recognised responses are the strings '0' through '10', "back", and the
    empty string (pressing enter with no text). The empty string is accepted
    because the labeling loop treats it as "skip this image" and the on-screen
    key advertises "enter - pass"; without it the skip path cannot be reached.

    NOTE(review): '10' and "back" are accepted here, but no special handling
    for them is visible in this notebook, so they would be stored as labels
    like any other response - confirm that is intended.

    Returns:
        The accepted response string, exactly as typed.
    """
    #accepted responses: class labels, "back", and "" (enter) to skip
    valid_inputs = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', "back", "")

    user_input = input()

    while user_input not in valid_inputs:

        print("invalid input try again")

        user_input = input()

    return user_input

In [4]:
#get all images in ActiveLearning table
#
#Overall flow of this cell:
#  1. connect to the database and load the master labeling csv
#  2. poll the observation table until some rows have a non-NULL al_rank
#  3. for each such row, reuse the label from the csv if one exists,
#     otherwise show the image and ask the user for a label
#  4. ask for confirmation, then save the csv and commit the database updates

#connect to database
connection = pg.connect("host=blank dbname=blank user=preeti password=blank")

#path to master labeling csv
labeled_data_path = "./data/labeled_data_globe.csv"
labeled_images = pd.read_csv(labeled_data_path, index_col = 0)

#directory containing the image files referenced by image_name
image_dir = "./data/images/all_images/"

sql = """SELECT * FROM observation WHERE al_rank is NOT NULL"""

al_df = psql.read_sql(sql, connection)


#checks every 30s to see if there are images needing labels in the active learning table
while(al_df.shape[0] == 0):

    print("al queue is empty")
    time.sleep(30)
    al_df = psql.read_sql(sql, connection)


#for each image, check if label already exists in labeled_data.csv
for index, row in al_df.iterrows():

    image_name = row["image_name"]

    cursor = connection.cursor()

    #try/finally guarantees the cursor is closed on every path out of this
    #iteration, including the `continue` taken when an image fails to load
    try:

        #if label exists, update database with label, remove from active_learning queue
        if(image_name in labeled_images.image_name.values):

            print(image_name,"has a label")
            #get official label from local csv (first match if the name appears more than once)
            official_label = labeled_images[labeled_images["image_name"] == image_name].official_label.values[0]

            #set official label in database
            handle_observation_db(cursor, official_label, image_name)

            print("official label updated from local csv")

        #if image doesn't have a label yet, ask a human for one
        else:

            image_path = image_dir + image_name

            try:
                img = plt.imread(image_path)
                img = resize(img, (900,900))
            #Exception rather than a bare except, so Ctrl-C (KeyboardInterrupt)
            #still interrupts the labeling session
            except Exception:
                print("couldnt load",image_path)
                continue

            #plot image
            plt.ion()
            fig = plt.figure(figsize = (12,8))
            plt.axis("off")

            plt.imshow(img)

            #plot key (an empty line plot is used only to carry the legend text)
            plt.plot([],[],label = "Assign new label \n0 - Closed Forest\n1 - Woodland\n2 - Shrubland/Thicket\n3 - Dwarf Shrubland/Thickect\n4 - Herbaceous Veg\n5 - Barren\n6 - Wetland\n7 - Open Water\n8 - Cultivated Land\n9 - Urban\n\nenter - pass", color = "white")
            plt.legend(loc = 1, bbox_to_anchor = (1.85,1.0), fontsize = 14)

            #plot image name
            plt.text(400,875,image_name, color = "white", fontsize = 10)

            #show plot
            plt.show(block = False)

            #get user input
            user_label = userInput()

            #clear output for next image
            clear_output(wait = False)

            #release the figure; otherwise one figure per image stays alive
            #for the whole batch
            plt.close(fig)

            if(user_label != ""):
                #update local label csv file
                labeled_images = update_local_label_df(labeled_images, image_name, user_label)
                #set official label in database
                handle_observation_db(cursor, user_label, image_name)
            else:
                print("skipped",image_name)

    finally:
        cursor.close()

print("done with batch, commit changes to database and local csv file?")
ans = str(input())


#any answer other than "yes" leaves the csv unsaved and the database updates
#uncommitted; the connection is left open in that case
if(ans == "yes"):

    #retry every 10 seconds until the csv is written (e.g. file open elsewhere)
    while(save_local_label_df(labeled_images, labeled_data_path) == False):

        time.sleep(10)

    connection.commit()
    connection.close()

    print("saved local file and commited ")


al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue is empty
al queue i

KeyboardInterrupt: 