In [None]:
import sqlite3
from pathlib import Path
from typing import List, Tuple, Union
import numpy as np
import enum

In [None]:
# Path to the database of each experiment.
# Alternative working copies (swap in if needed):
#   1: E:/HumanA/Data/HumanA_Exp1_WorkingData.db
#   2: E:/HumanA/Data/HumanA_Exp2_WorkingData.db
DATABASE_PATHS = {
    1: Path('E:/HumanA/Data/DataBase/HumanA_Exp1.db'),
    2: Path('E:/HumanA/Data/DataBase/HumanA_Exp2.db'),
}

print("For which experiment would you like to reduce the data?")
experiment = int(input())
# Unknown experiment numbers map to None instead of leaving db_path undefined.
db_path = DATABASE_PATHS.get(experiment)

# check if path exists; otherwise fall back to an (empty) in-memory database
if db_path is None or not db_path.exists():
    print(f"No database found for experiment {experiment}, using an in-memory database instead.")
    db_path = ':memory:'

# connect to database
connection = sqlite3.connect(db_path)
cr = connection.cursor()

In [None]:
# Validity labels for datapoints; the member *name* is what gets written to
# dataPoints_reduced.validDatapoint by the update functions in this notebook.
ValidityDatapoints = enum.Enum(
    "ValidityDatapoints",
    [
        ("VALID", 1),
        ("ADJUSTED", 2),
        ("INVALID", 3),
        ("IRRELEVANT", 4),
    ],
    module=__name__,
)

# Extra annotation stored with a datapoint. Member names are written to the
# database as text via `.name`, so spellings such as "WasAmbigous" must stay
# exactly as they are or stored values would no longer match.
AdditionalInfo = enum.Enum(
    "AdditionalInfo",
    [
        ("AlgorithmStartPoint", 1),
        ("EdgeCoordinatesReduced", 2),
        ("EdgeToEdge", 3),
        ("WasAmbigous", 4),
        ("FirstDPofNode", 5),
        ("LastDPofNode", 6),
        ("Updated", 7),
        ("SameNode", 8),
        ("NeighbouringNode", 9),
        ("ShortestDistance", 10),
    ],
    module=__name__,
)

In [None]:
def getParticipants():
    """Collect the distinct participant ids that appear in the ``trials`` table.

    Returns:
        tuple: one entry per participant id, in the order SQLite yields them
    """
    query = """
    SELECT DISTINCT participantId FROM trials;
    """
    cr.execute(query)
    rows = cr.fetchall()
    return tuple(row[0] for row in rows)

def getTrialNrs(participant):
    """Get all trial ids recorded for one participant.

    Args:
        participant (int): participant id as stored in ``trials.participantId``

    Returns:
        tuple: the distinct trial ids of that participant (empty if none)
    """
    # Bind the id as a parameter instead of formatting it into the SQL text, so
    # non-integer ids (e.g. 'P1') are compared as values, not parsed as SQL.
    sql_instruction = """
    SELECT DISTINCT id
    FROM trials
    WHERE participantId = ?;
    """
    cr.execute(sql_instruction, (participant,))
    return tuple(row[0] for row in cr.fetchall())

In [None]:
def getMissingElementDatapoints(trial):
    """Get all reduced datapoints of a trial that have no graph element yet.

    Args:
        trial (int): current trial id

    Returns:
        list: rows ``(DatapointId, timeStampDataPointStart, graph_element_type,
        node, AdditionalInfo)`` whose ``graph_element_type`` is NULL, ordered
        by start timestamp
    """
    # The trial id is bound as a parameter rather than formatted into the SQL.
    sql_instruction = """
    SELECT dataPoints_reduced.DatapointId, dataPoints_reduced.timeStampDataPointStart,
           dataPoints_reduced.graph_element_type, dataPoints_reduced.node,
           dataPoints_reduced.AdditionalInfo
    FROM dataPoints_reduced
    WHERE dataPoints_reduced.trialId = ? AND dataPoints_reduced.graph_element_type IS NULL
    ORDER BY dataPoints_reduced.timeStampDataPointStart ASC;
    """
    cr.execute(sql_instruction, (trial,))
    return cr.fetchall()

In [None]:
def getPreviousAndNextKnownDatapoints(trialId, datapointId):
    """Get the nodes of the datapoints directly before and after a datapoint.

    "Before"/"after" means the nearest smaller/larger ``datapointId`` in
    dataPoints_reduced within the same trial.

    Args:
        trialId (int): current trial
        datapointId (int): current datapoint

    Returns:
        tuple: ``(previousNode, nextNode)``; both are None unless a previous
        AND a next datapoint exist in this trial. Either may also be None if
        that datapoint's ``node`` column is NULL.
    """
    # ORDER BY guarantees rows[0] is the previous and rows[1] the next
    # datapoint; without it SQLite may return them in either order.
    sql_instruction = """SELECT datapointId, node
        FROM dataPoints_reduced
        WHERE TrialId = :trial
          AND (datapointId IN (SELECT MAX(datapointId)
                               FROM dataPoints_reduced
                               WHERE datapointId < :dp AND TrialId = :trial)
               OR datapointId IN (SELECT MIN(datapointId)
                                  FROM dataPoints_reduced
                                  WHERE datapointId > :dp AND TrialId = :trial))
        ORDER BY datapointId ASC"""
    cr.execute(sql_instruction, {"trial": trialId, "dp": datapointId})
    rows = cr.fetchall()
    # At the first/last datapoint of a trial only one neighbour exists.
    if len(rows) < 2:
        return None, None
    return rows[0][1], rows[1][1]

In [None]:
def validTrialInDB(trial):
    """Check whether the validity of a trial has already been decided.

    Args:
        trial (int): id of the trial in the ``trials`` table

    Returns:
        bool: False if the trial exists and its ``validFile`` is still NULL
        (undecided); True otherwise, including for unknown trial ids.
    """
    # Only existence matters, so fetch at most one row.
    sql_instruction = """SELECT 1
    FROM trials
    WHERE id = ?
    AND validFile IS NULL
    LIMIT 1"""
    cr.execute(sql_instruction, (trial,))
    return cr.fetchone() is None

In [None]:
def updateDatapointInDB(datapointId, element, additionalInfo):
    """Assign a node or an edge to a datapoint and mark it as ADJUSTED.

    Only a row whose ``validDatapoint`` is still NULL is updated; datapoints
    that were already classified are left untouched.

    Args:
        datapointId (int): id of the datapoint in dataPoints_reduced
        element (int | list | tuple): a node id, or an edge given as
            ``[edge_start, edge_end]``
        additionalInfo (str): name of an AdditionalInfo member

    Raises:
        TypeError: if element is neither a node id nor an edge sequence
    """
    validity = str(ValidityDatapoints(2).name)  # 'ADJUSTED'
    if isinstance(element, (int, np.integer)):
        sql_instruction = """UPDATE dataPoints_reduced
            SET graph_element_type = 'Node', node = ?, validDatapoint = ?,
                additionalInfo = ?
            WHERE datapointId = ? AND validDatapoint IS NULL"""
        # int() so numpy integers can be bound by sqlite3
        params = (int(element), validity, additionalInfo, datapointId)
    elif isinstance(element, (list, tuple)):
        sql_instruction = """UPDATE dataPoints_reduced
            SET graph_element_type = 'Edge', edge_start = ?, edge_end = ?,
                validDatapoint = ?, additionalInfo = ?
            WHERE datapointId = ? AND validDatapoint IS NULL"""
        params = (element[0], element[1], validity, additionalInfo, datapointId)
    else:
        raise TypeError(
            f"element must be a node id (int) or an edge [start, end], got {element!r}"
        )
    cr.execute(sql_instruction, params)



In [None]:
def updateDatapointIrrelevant(datapointId, additionalInfo):
    """Mark a datapoint as IRRELEVANT for later analysis.

    Only a row whose ``validDatapoint`` is still NULL is updated; datapoints
    that were already classified are left untouched.

    Args:
        datapointId (int): id of the datapoint in dataPoints_reduced
        additionalInfo (str): name of an AdditionalInfo member explaining why
    """
    validity = str(ValidityDatapoints(4).name)  # 'IRRELEVANT'
    sql_instruction = """UPDATE dataPoints_reduced
        SET validDatapoint = ?, additionalInfo = ?
        WHERE datapointId = ? AND validDatapoint IS NULL"""
    cr.execute(sql_instruction, (validity, additionalInfo, datapointId))

In [None]:
def datapointInReducedDP(dpId):
    """Check whether a datapoint id is already in the dataPoints_reduced table.

    Args:
        dpId (int): datapointId
    Returns:
        bool: True if the id is already in the reduced datapoints table, else False
    """
    # Existence check only: fetch at most one row instead of every column of every match.
    sql_instruction = "SELECT 1 FROM dataPoints_reduced WHERE DatapointId = ? LIMIT 1"
    cr.execute(sql_instruction, (dpId,))
    return cr.fetchone() is not None

In [None]:
def getPlaceholderDatapoints(trial, dp_id, datapointIds):
    """Get raw datapoints that can hold the intermediate nodes of a path.

    When two known nodes are not directly connected, each intermediate node
    needs its own datapoint. This returns the ``data_points`` rows with ids
    ``datapointIds[0]..datapointIds[1]`` (inclusive) of ``trial``, ordered by
    id so they line up with the path's intermediate nodes.

    Args:
        trial (int): current Trial Id
        dp_id (int): current datapointId; excluded from the duplicate check
            (the caller updates that row in place instead of inserting it)
        datapointIds (tuple(int, int)): first and last placeholder id
    Returns:
        datapoints (list(tuple)): all placeholder datapoints, or an empty list
        if any of the other ids is already in dataPoints_reduced
    """
    first_id, last_id = datapointIds[0], datapointIds[1]
    # Every id except dp_id must still be unused; checking each id (not only the
    # last one) prevents inserting an already-reduced datapoint a second time.
    already_reduced = any(
        datapointInReducedDP(candidate)
        for candidate in range(first_id, last_id + 1)
        if candidate != dp_id
    )
    if already_reduced:
        print("Ids are already in Database")
        return []

    sql_instruction = """SELECT * FROM data_points
        WHERE (id BETWEEN ? AND ?) AND trialId = ?
        ORDER BY id ASC"""
    cr.execute(sql_instruction, (first_id, last_id, trial))
    return cr.fetchall()

In [None]:
def addPlaceholderDatapoint(datapoint, node):
    """Insert a placeholder datapoint into dataPoints_reduced, holding ``node``.

    The row is stored as ADJUSTED / ShortestDistance. The ``data_points`` row
    is mapped as: [0] -> DatapointId, [1] -> TrialId,
    [2] -> timeStampDataPointStart, [3] -> timeStampDataPointEnd,
    [4..6] -> playerBodyPosition_x/_y/_z.

    Args:
        datapoint (sequence): placeholder row from ``data_points``
        node (int): the node this datapoint holds the place for
    """
    validity = str(ValidityDatapoints(2).name)
    additionalInfo = str(AdditionalInfo(10).name)
    # Bound parameters instead of str(tuple): None becomes NULL (its repr
    # `None` is not valid SQL) and text values are quoted correctly.
    values = (
        datapoint[1], datapoint[0], datapoint[2], datapoint[3],
        datapoint[4], datapoint[5], datapoint[6],
        'Node', node, validity, additionalInfo,
    )
    sql_instruction = """INSERT INTO dataPoints_reduced (TrialId, DatapointId, timeStampDataPointStart,
                            timeStampDataPointEnd, playerBodyPosition_x, playerBodyPosition_y,
                            playerBodyPosition_z, graph_element_type, node, validDatapoint, additionalInfo)
                            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"""
    cr.execute(sql_instruction, values)

In [None]:
def getNodesNeighbours(nodes):
    """Get all neighbour pairs touching one node or a collection of nodes.

    Args:
        nodes (int | iterable(int)): a single node, or the nodes of a frontier
    Returns:
        neighbours (list(tuple)): rows of node_neighbours where FirstNode or
        SecondNode is one of ``nodes``; for a collection, duplicate rows are
        removed while keeping first-seen order
    """
    if isinstance(nodes, (int, np.integer)):
        node = int(nodes)
        sql_instruction = "SELECT * FROM node_neighbours WHERE FirstNode = ? or SecondNode = ?"
        cr.execute(sql_instruction, (node, node))
        return cr.fetchall()

    nodes = tuple(nodes)
    if not nodes:
        return []
    # One "?" per node: formatting the tuple into SQL broke for 1-tuples
    # ("IN (5,)") and for lists/sets ("IN [..]").
    placeholders = ", ".join("?" * len(nodes))
    sql_instruction = (
        f"SELECT * FROM node_neighbours "
        f"WHERE FirstNode IN ({placeholders}) or SecondNode IN ({placeholders})"
    )
    cr.execute(sql_instruction, nodes + nodes)
    return list(dict.fromkeys(cr.fetchall()))

In [None]:
def findShortestPath(startNode, destinationNode, rec_depth = 0):
    """Recursively search node_neighbours for a path between two nodes.

    Each recursion level fetches every neighbour pair touching the current
    start node(s). If one of those pairs contains destinationNode the path is
    closed. Otherwise the search recurses on all nodes appearing in those pairs,
    including already-visited ones. The frontier therefore grows by one hop per
    level, which should make the first hit a minimum-hop path.

    On the way back up, each level prepends the node paired with the returned
    path's first node. Recursion stops once rec_depth exceeds 20.

    NOTE(review): assumes node_neighbours rows are exactly (FirstNode,
    SecondNode) pairs; any extra column returned by getNodesNeighbours would be
    treated as a node here. Confirm against the schema.

    Args:
        startNode (int | tuple(int)): the start node on the top-level call;
            the tuple of frontier nodes on recursive calls
        destinationNode (int): the destination node
        rec_depth (int): current recursion depth; callers keep the default
    Returns:
        path (list(int)): nodes from startNode to destinationNode, both included
            (length 2 for direct neighbours); [] if nothing was found before the
            depth limit
    """  
    
    rec_depth += 1

    # neighbour pairs touching the current start node(s)
    neighbours_startNode = getNodesNeighbours(startNode)
    path = [item for item in neighbours_startNode if destinationNode in item]
    if path != []:
        # destination is one hop away: [node paired with it, destination]
        neighbour = [node for node in path[0] if node != destinationNode]
        path = [neighbour[0],destinationNode]
    else:
        # next frontier: every node of every pair above, de-duplicated (order not preserved)
        neighbours = [node for neighbours in neighbours_startNode for node in neighbours]
        neighbours = tuple([*set(neighbours)]) 
        if rec_depth <= 20:  
            path =  findShortestPath(neighbours, destinationNode, rec_depth)
            if path != []:
                neighbouringNode = path[0]
                # always true: neighbouringNode is a node id, never a list
                if neighbouringNode != []:
                    # prepend this level's node that is paired with the path's first node
                    previousConnection = [item for item in neighbours_startNode if neighbouringNode in item]
                    previousNode = [node for node in previousConnection[0] if node != neighbouringNode]
                    if previousNode[0] not in path:
                        path.insert(0,previousNode[0])
        else:
            # depth limit reached without finding the destination
            path = []
    return list(path)

In [None]:
# For every trial whose validity is still undecided, fill in datapoints that
# have no graph element, based on the nodes of the neighbouring datapoints
# (nearest smaller / larger datapointId in the same trial).
participants = getParticipants()
# number of gaps bridged with at least one intermediate node
countEdgeToEdge = 0
for participant in participants:
    trials = getTrialNrs(participant)
    for trial in trials:
        # validTrialInDB returns False while trials.validFile is NULL, so only
        # trials with undecided validity are processed
        if not validTrialInDB(trial):
            print("Participant: " +str(participant) + " Trial: " + str(trial))
            missingElementData = getMissingElementDatapoints(trial)
            for datapoint in missingElementData:
                
                dp_id = datapoint[0]
                previousNode, nextNode = getPreviousAndNextKnownDatapoints(trial,dp_id)
                # skipped at the start/end of a trial, and presumably also when a
                # neighbouring datapoint has no node itself
                if previousNode is not None and nextNode is not None:
                    if previousNode == nextNode:
                        # same node before and after: this datapoint adds nothing
                        additionalInfo = str(AdditionalInfo(8).name)
                        updateDatapointIrrelevant(dp_id,additionalInfo)

                    elif previousNode != nextNode:
                        shortestPath = findShortestPath(previousNode,nextNode)
                        if len(shortestPath) == 2:
                            additionalInfo = str(AdditionalInfo(9).name)
                            updateDatapointIrrelevant(dp_id,additionalInfo)
                            # Start node and Destination are neighbours

                        elif len(shortestPath) > 2:                 
                            print("StartNode: " + str(previousNode) + " EndNode: " + str(nextNode) )
                            print("Shortest Path: " + str(shortestPath))
                            # one datapoint id per intermediate node (len - 2 of them),
                            # starting at the current datapoint
                            datapointIds = (dp_id, (dp_id + len(shortestPath)-3))
                            placeh_datapoints = getPlaceholderDatapoints(trial,dp_id, datapointIds)
                            # NOTE: reuses the name `datapoint`; harmless because the
                            # outer loop rebinds it on its next iteration
                            for node, datapoint in zip(shortestPath[1:-1], placeh_datapoints):

                                if datapoint[0] == dp_id:
                                    # the current datapoint already has a row: update it
                                    updateDatapointInDB(dp_id,node,str(AdditionalInfo(10).name))
                                else:
                                    # other ids get a new placeholder row
                                    addPlaceholderDatapoint(datapoint, node)

                            countEdgeToEdge += 1
                        elif len(shortestPath) == 0:
                            print("StartNode: " + str(previousNode) + " EndNode: " + str(nextNode) )
                            print("Could not find a connection between nodes with less than 20 steps")
        # commits once per participant (this line is inside the participant loop)
        connection.commit()
print("All paths fixed")
print("Total number of multiple EdgeToEdge: " + str(countEdgeToEdge))
