# Register by hand

Describes the steps for finding a person in the global embedding database and associating their user_ID with the corresponding indices (idx) in that database.

In [None]:
import torch
import numpy as np
from PIL import Image
from utils import send_query, send_large_query

In [None]:
# Load the pre-computed aligned embedding database once. The search helpers
# below index into `all_embs` by row position; `all_paths` is kept alongside.
embedding_database = torch.load('ALIGNED_EMBEDDING_DATABASE.pth')
all_paths, all_embs = (
    embedding_database['paths'],
    embedding_database['embeddings'].cuda(),  # embeddings live on the GPU
)

In [None]:
def addUser(name):
    """Insert *name* into the ``user`` table from its registration record.

    Looks up ``registration_log`` rows with ``action="button-press"`` for
    *name* and copies name, email and shibboleth of the LAST returned row into
    ``user``. Its ``registration_log_ID`` is reused as ``card_ID`` (see TODO).

    Raises:
        RuntimeError: if no matching registration row exists.
    """
    # NOTE(review): values are interpolated directly into the SQL text, so a
    # name/email containing a double quote will break the statement; switch to
    # parameterized queries if send_query supports them.
    print('Adding user:', name)
    lookup_sql = """
        SELECT registration_log_ID, name, email, shibboleth 
        FROM registration_log
        WHERE action="button-press" AND name = "{name}"
    """.format(name=name)

    registrations = send_query(lookup_sql)
    if len(registrations) == 0:
        raise RuntimeError('User by name %s could not be found' % name)

    # TODO: By default everybody will get the reglogID for card_ID since the
    # table doesn't exist so far
    chosen = registrations[-1]  # no ORDER BY, so "last" is whatever the DB returns last
    insert_sql = """
        INSERT INTO user(name, email, shibboleth, card_ID)
        VALUES("{name}", "{email}", "{shibboleth}", {registration_log_ID});
    """.format(**chosen)
    send_query(insert_sql)

def getMandacskoIDs(**query_params):
    """Find aligned images recorded around this person's gate entries.

    Matches every ``aligned`` row whose timestamp lies within
    ``[t - minusoffset, t + plusoffset]`` of a ``Mandacsko_log`` entry for
    ``name`` at gate "Forgóvilla jobb (kintről) BE".

    Keyword Args:
        name: person name matched against ``Mandacsko_log.name``.
        minusoffset, plusoffset: window bounds, in the units of the
            ``timestamp`` columns.

    Returns:
        List of indices ``aligned_ID - 1`` in ``aligned_ID`` order (may
        repeat if several log windows cover the same image).
    """
    print('Images assigned from Mandacsko_log')
    template = '''
        SELECT aligned_ID FROM aligned JOIN (
            SELECT name, timestamp-{minusoffset} as start, timestamp+{plusoffset} as end 
            FROM Mandacsko_log WHERE gate = "Forgóvilla jobb (kintről) BE" AND name = "{name}") 
        ON aligned.timestamp BETWEEN start AND end
        ORDER BY aligned_ID;
    '''
    rows = send_query(template.format(**query_params), verbose=False)
    # Shift to a 0-based index; updateAligned adds the 1 back before writing.
    indices = list(map(lambda row: int(row['aligned_ID']) - 1, rows))
    print('Found %4d matches' % len(indices))
    return indices


def getRegIDs(**query_params):
    """Find aligned images recorded around this person's registration events.

    Matches every ``aligned`` row whose timestamp lies within
    ``[t - minusoffset, t + plusoffset]`` of a ``registration_log`` entry
    for ``name``.

    Keyword Args:
        name: person name matched against ``registration_log.name``.
        minusoffset, plusoffset: window bounds, in the units of the
            ``timestamp`` columns.

    Returns:
        List of indices ``aligned_ID - 1`` in ``aligned_ID`` order.
    """
    print('Images assigned from registration_log')
    template = '''
        SELECT aligned_ID FROM aligned JOIN (
            SELECT name, timestamp-{minusoffset} as start, timestamp+{plusoffset} as end 
            FROM registration_log WHERE name = "{name}") 
        ON aligned.timestamp BETWEEN start AND end
        ORDER BY aligned_ID;
    '''
    sql = template.format(**query_params)

    found = []
    for row in send_query(sql, verbose=False):
        # Shift to a 0-based index; updateAligned adds the 1 back.
        found.append(int(row['aligned_ID']) - 1)
    print('Found %4d matches' % len(found))
    return found
    
    
def getAlignedIDs(**query_params):
    """Return images already linked to the user named ``name``.

    Joins ``aligned`` with ``user`` on ``user_ID``. Only the ``name`` keyword
    is used; other keywords (e.g. the offsets) are ignored by ``str.format``.

    Returns:
        List of indices ``aligned_ID - 1`` in ``aligned_ID`` order.
    """
    print('Images assigned from aligned images')
    template = '''
        SELECT aligned_ID FROM aligned 
        JOIN user ON aligned.user_ID = user.user_ID
        WHERE user.name = "{name}"
        ORDER BY aligned_ID;
    '''
    rows = send_query(template.format(**query_params), verbose=False)
    raw_ids = [row['aligned_ID'] for row in rows]
    # Shift to a 0-based index; updateAligned adds the 1 back.
    result = [int(raw) - 1 for raw in raw_ids]
    print('Found %4d matches' % len(result))
    return result
    
    
def getKnownIDs(name, offset=1000):
    """Resolve *name* to a user_ID and collect the images already tied to it.

    If *name* is not yet in the ``user`` table it is added via ``addUser``.
    Candidate images are the union of the ``getMandacskoIDs``, ``getRegIDs``
    and ``getAlignedIDs`` matches. The two timestamp-based lookups use a
    symmetric window of *offset* around each log entry.

    Returns:
        ``(user_ID, aligned_IDs)``: the integer user id and a sorted list of
        distinct indices.

    Raises:
        RuntimeWarning: if *name* matches more than one ``user`` row.
        RuntimeError: if no registration exists for *name* (from ``addUser``)
            or the user row is still missing after the insert.
    """
    query_params = {
        'minusoffset': offset,
        'plusoffset': offset,
        'name': name,
    }

    def lookup_user():
        return send_query('SELECT user_ID FROM user WHERE name = "%s"' % name)

    result = lookup_user()
    if len(result) > 1:
        raise RuntimeWarning(
            'Name "%s" is ambigious, found %d instances' % (name, len(result)))
    if len(result) == 0:
        print('User is not yet in the _user_ table... trying to add:')
        addUser(name)
        result = lookup_user()
        if len(result) == 0:
            # Without this check result[-1] below would raise an opaque
            # IndexError, which callers catching RuntimeError would miss.
            raise RuntimeError(
                'User by name %s could not be added to the user table' % name)

    user_ID = int(result[-1]['user_ID'])

    # A set removes images found by more than one source.
    aligned_IDs = set()
    for lookup in (getMandacskoIDs, getRegIDs, getAlignedIDs):
        aligned_IDs.update(lookup(**query_params))

    return user_ID, sorted(aligned_IDs)

def updateAligned(user_ID, aligned_IDs, verbose=False):
    """Set ``aligned.user_ID`` to *user_ID* for every index in *aligned_IDs*.

    *aligned_IDs* are the indices produced by the get*IDs / search helpers
    (``aligned_ID - 1``); 1 is added back to address the ``aligned_ID``
    column. All UPDATE statements are sent in a single ``send_query`` call.

    Returns:
        When *verbose* is true, the ``COUNT(*)`` value of images now assigned
        to *user_ID* (also printed); otherwise ``None``.
    """
    statements = []
    for idx in aligned_IDs:
        # int() accepts plain ints as well as single-element tensors coming
        # from the search helpers, so the SQL always receives a bare number.
        statements.append("""
            UPDATE aligned
            SET user_ID = {user_id}
            WHERE aligned_ID = {idx};
        """.format(user_id=user_ID, idx=int(idx) + 1))

    if statements:
        # join() avoids quadratic string growth; skip the call when empty.
        send_query(''.join(statements), verbose=verbose)

    if verbose:
        total_count = send_query(
            'SELECT COUNT(*) FROM aligned WHERE user_ID = %d' % user_ID)[0]['COUNT(*)']
        print('# total images assigned to user_ID %d: %3d' % (user_ID, int(total_count)))
        return total_count

In [None]:
def getKclosest(embedding_query, k=-1, embedding_db=all_embs):
    """Return the *k* rows of *embedding_db* nearest to *embedding_query*.

    Distance is the squared difference averaged over the last dimension.
    ``embedding_query`` is broadcast against *embedding_db*; presumably
    shapes (D,) vs (N, D) — TODO confirm.

    Args:
        embedding_query: query embedding.
        k: number of neighbours to return. A negative value (the default) or
            ``None`` returns every row.
        embedding_db: rows to search. The default is bound to the module-level
            ``all_embs`` at definition time.

    Returns:
        ``(embeddings, indices, distances)`` sorted by ascending distance.
    """
    distance = ((embedding_db - embedding_query) ** 2).mean(-1)
    sorted_distance, idxs = torch.sort(distance)
    if k is None or k < 0:
        # "All rows": slicing with [:-1] would silently drop the farthest one.
        k = len(sorted_distance)
    return embedding_db[idxs[:k]], idxs[:k], sorted_distance[:k]


def getMeanof(idxs):
    """Return the element-wise mean of the ``all_embs`` rows selected by *idxs*."""
    return all_embs[idxs].mean(dim=0)
    

def ProximitySearch(startIDs, proximity):
    """Expand *startIDs* with the *proximity* nearest neighbours of each start.

    Only the start indices are expanded (one hop). Neighbours already present
    are skipped. Ctrl-C stops early and returns what was collected so far.

    Returns:
        The start indices followed by newly found neighbours in discovery
        order, all as Python ints.
    """
    starts = [int(s) for s in startIDs]
    candidates = list(starts)
    # Track membership with plain ints: 0-d tensors hash by identity, and a
    # list scan made every check O(n).
    seen = set(starts)
    try:
        for start in starts:
            _, neighbour_idxs, _ = getKclosest(all_embs[start], proximity)
            for idx in neighbour_idxs.data.tolist():
                if idx not in seen:
                    seen.add(idx)
                    candidates.append(idx)
    except KeyboardInterrupt:
        print('Manually terminated')
    return candidates

def ProximityMean(start_embedding, proximity, BFS=True, earlyExit=True):
    """Visit neighbours of a running mean that starts at *start_embedding*.

    The *proximity* rows of ``all_embs`` nearest to the current mean are
    queued. Each visited row is folded into the mean. With ``earlyExit=True``
    only the initial neighbours are visited. Otherwise the neighbours of every
    updated mean are queued as well, which may run until interrupted; Ctrl-C
    returns the indices visited so far.

    Args:
        start_embedding: initial query, counted as the first sample of the
            running mean.
        proximity: neighbours fetched per query (``k`` of ``getKclosest``).
        BFS: first-in-first-out queue (breadth first) if true, otherwise
            last-in-first-out (depth first).
        earlyExit: if true, never enqueue neighbours of updated means.

    Returns:
        Visited indices (Python ints) in visit order.
    """
    from collections import deque

    def nearest(query):
        _, idxs, _ = getKclosest(query, proximity)
        indices = idxs.data.tolist()
        if not BFS:
            # DFS pops from the right: reverse so the closest is explored first.
            indices.reverse()
        return indices

    mean_emb = start_embedding
    visited = []
    visited_set = set()
    to_visit = deque(nearest(mean_emb))
    try:
        while to_visit:
            i = to_visit.popleft() if BFS else to_visit.pop()
            if i in visited_set:
                continue
            visited.append(i)
            visited_set.add(i)
            # Incremental mean over start_embedding plus all visited rows
            # (len(visited) + 1 samples): m_n = m_{n-1} + (x - m_{n-1}) / (n + 1).
            mean_emb = mean_emb + (all_embs[i] - mean_emb) / float(len(visited) + 1)
            if not earlyExit:
                to_visit.extend(nearest(mean_emb))

            print(len(visited), i)
    except KeyboardInterrupt:
        print('Manually terminated')
    return visited


def BatchMeanSearch(startIDs, batch_size, unique=False, rounds=10):
    """Grow *startIDs* by repeatedly adding the neighbours of their mean.

    Each of *rounds* rounds takes the mean of the ``all_embs`` rows in the
    current list and appends its *batch_size* nearest neighbours. With
    ``unique=False`` (default), indices already present are appended again,
    giving them more weight in later means. With ``unique=True`` only new
    indices are added.

    Returns:
        The distinct collected indices (Python ints), ordered by distance to
        their common mean, closest first. Empty if *startIDs* is empty.
    """
    visited = [int(i) for i in startIDs]
    if not visited:
        # The mean of zero rows is NaN; there is nothing to search around.
        return []
    seen = set(visited)

    for r in range(rounds):
        print('Round %d: len %d, unique %d' % (r, len(visited), len(seen)))
        _, candidate_indices, _ = getKclosest(getMeanof(visited), batch_size)
        # .tolist() yields plain ints; 0-d tensors hash by identity, so a set
        # of them would never collapse duplicates.
        for cid in candidate_indices.data.tolist():
            if unique and cid in seen:
                continue
            visited.append(cid)
            seen.add(cid)

    distinct = list(seen)
    _, inner_idxs, _ = getKclosest(
        getMeanof(distinct), k=len(distinct), embedding_db=all_embs[distinct])
    return [distinct[i] for i in inner_idxs.data.tolist()]

# AUTOMATION

In [None]:
# Every name in registration_log (not de-duplicated), in query order.
names = list(map(lambda row: row['name'], send_query('SELECT name FROM registration_log')))

In [None]:
# NOTE(review): 102 was hard-coded in the original slice; it looks like a
# resume point after a partial run — confirm before re-running.
start_index = 102
for i, name in enumerate(names[start_index:], start=start_index):
    # Only users that are new to the `user` table get their image set
    # extended by BatchMeanSearch below.
    exists = len(send_query('SELECT * FROM user WHERE name="%s"' % name, verbose=False)) > 0

    try:
        user_ID, aligned_IDs = getKnownIDs(name, offset=5000)
    except (RuntimeError, RuntimeWarning) as e:
        # getKnownIDs signals ambiguous names with RuntimeWarning, which is
        # not a RuntimeError subclass; skip them instead of aborting the run.
        print('!!!!', e)
        continue
    print('  Name: %s \nUser ID: %d, \nnumber of associated images:%d' %
          (name, user_ID, len(aligned_IDs)))
    if len(aligned_IDs) < 1:
        print('  No associated image could be found for %s' % name)
        continue
    if not exists:
        print('  EXTENDING Image database with BatchMeanSearch')
        extended_IDs = BatchMeanSearch(aligned_IDs, len(aligned_IDs), rounds=100)
        new_count = len({int(x) for x in extended_IDs} - {int(x) for x in aligned_IDs})
        print('  Extended with %d new images' % new_count)
        print('  Update database')
        updateAligned(user_ID, extended_IDs, verbose=False)
    # i is the absolute position in `names`, so the counter matches the total.
    print('[%3d/%3d] DONE! ---------------------------------\n\n\n\n' % (i, len(names)))