In [107]:
import os
import torch
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import numpy as np
import random
from pathlib import Path
from scipy import interpolate
from sklearn.preprocessing import normalize

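# Goal: make x and y each a function of a single parameter t.
# NOTE: the code below uses the (rebased) pen timestamp as t; parameterizing
# by total distance traveled is not implemented here.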


def get_strokes(path):
    """Parse a stroke XML file into per-stroke coordinate and time lists.

    The second child of the XML root (``root[1]``) is read as the collection
    of strokes; every child of a stroke is read as a point carrying ``x``,
    ``y`` and ``time`` attributes.

    Timestamps are rebased so that pen-up time between strokes is removed:
    the first point of each stroke is placed .001 after the last point of the
    previous stroke (the first stroke starts at .001). Within a stroke, a
    timestamp that does not strictly increase is replaced by the previous
    timestamp plus .001, so the resulting times never repeat.

    Strokes that contain no points are skipped (they contribute neither an
    entry to the stroke list nor a boundary interval) instead of raising
    ``IndexError``.

    Args:
        path: File name or file object accepted by ``ET.parse``.

    Returns:
        A tuple ``(stroke_list, start_end_strokes)`` where ``stroke_list`` is a
        list of ``{"x": [...], "y": [...], "time": [...]}`` dicts (``y`` is
        negated) and ``start_end_strokes`` is a list of
        ``(previous_stroke_end_time, this_stroke_start_time)`` tuples, one per
        non-empty stroke.
    """
    root = ET.parse(path).getroot()
    strokes = root[1]

    stroke_list = []
    start_end_strokes = []
    min_time = None  # offset subtracted from raw timestamps; seeded by the first point seen
    last_time = 0    # rebased time of the final point of the previous stroke

    for stroke in strokes:
        x_coords, y_coords, time_list = [], [], []

        for i, point in enumerate(stroke):
            x, y, time = point.attrib["x"], point.attrib["y"], point.attrib["time"]
            x_coords.append(int(x))
            y_coords.append(-int(y))  # flip the vertical axis
            raw_time = float(time)

            if min_time is None:
                min_time = raw_time

            if i == 0:
                # No time passes between strokes: shift the offset so this
                # stroke begins .001 after the previous one ended.
                min_time += raw_time - min_time - last_time - .001
                start_end_strokes.append((last_time, raw_time - min_time))

            next_time = raw_time - min_time

            # No repeated (or decreasing) times within a stroke.
            if time_list and next_time <= time_list[-1]:
                next_time = time_list[-1] + .001

            time_list.append(next_time)

        if not time_list:
            # Empty stroke: nothing to record, and last_time must not change.
            continue

        last_time = time_list[-1]
        stroke_list.append({"x": x_coords, "y": y_coords, "time": time_list})
    return stroke_list, start_end_strokes

def convert_strokes(stroke_list):
    """Flatten a list of stroke dicts into three concatenated arrays.

    Args:
        stroke_list: Iterable of dicts with ``"x"``, ``"y"`` and ``"time"``
            sequences.

    Returns:
        ``(x, y, time)`` as NumPy arrays, each the concatenation of the
        corresponding per-stroke sequences in order.
    """
    flat = {"x": [], "y": [], "time": []}
    for stroke in stroke_list:
        for field, bucket in flat.items():
            bucket.extend(stroke[field])
    return np.array(flat["x"]), np.array(flat["y"]), np.array(flat["time"])

def process(time):
    """Compute the span between the largest and smallest value of ``time``.

    NOTE(review): the span is only bound to a local and never returned, so
    this function currently always returns ``None`` -- it appears unfinished.
    """
    latest = np.max(time)
    earliest = np.min(time)
    total_time = latest - earliest


In [113]:
def normalize(my_array):
    """Linearly rescale ``my_array`` so its minimum maps to -1 and its maximum to 1.

    Note: defining this function rebinds the name ``normalize`` that the
    notebook's import cell brings in from ``sklearn.preprocessing``.

    Args:
        my_array: Array of numbers with at least one element.

    Returns:
        Array of the same shape with values in [-1, 1]. If every element is
        equal (zero range), an all-zero float array is returned instead of the
        NaNs a 0/0 division would produce.
    """
    lowest = np.min(my_array)
    span = np.max(my_array) - lowest
    if span == 0:
        # Constant input: there is no range to scale by, so use the midpoint.
        return np.zeros(np.shape(my_array))
    return ((my_array - lowest) / span - .5) * 2

def get_gts(path, instances = 50):
    """Sample the pen trajectory in ``path`` at evenly spaced times.

    x and y are each linearly interpolated as a function of the rebased time
    returned by ``get_strokes``. ``instances`` sample times are spread evenly
    between the smallest and largest time; a sample that falls strictly inside
    one of the ``(lower, upper)`` intervals from ``get_strokes`` is moved to
    whichever endpoint is nearer (``upper`` on a tie).

    Args:
        path: Stroke XML file, passed through to ``get_strokes``.
        instances: Number of sample times to generate.

    Returns:
        ``(x, y)``: the interpolated coordinates at the sample times, each
        passed through ``normalize``.
    """
    strokes, stroke_gaps = get_strokes(path)
    xs, ys, times = convert_strokes(strokes)

    # Evenly spaced sample times across the whole recording.
    samples = np.linspace(np.min(times), np.max(times), instances)
    x_of_t = interpolate.interp1d(times, xs)
    y_of_t = interpolate.interp1d(times, ys)

    for idx in range(len(samples)):
        t = samples[idx]
        for gap_start, gap_end in stroke_gaps:
            if t < gap_start:
                # Presumably the intervals are in ascending order, so no later
                # interval can contain t -- TODO confirm.
                break
            if gap_start < t < gap_end:
                # Snap a sample between two strokes onto the closer stroke boundary.
                if abs(t - gap_start) < abs(t - gap_end):
                    samples[idx] = gap_start
                else:
                    samples[idx] = gap_end
                break
    return normalize(x_of_t(samples)), normalize(y_of_t(samples))


In [114]:
# Sample stroke file used to exercise get_gts.
# NOTE(review): this is an absolute, machine-specific path; the cell will only
# run where this file exists at exactly this location.
path = Path("/media/taylor/Data/Linux/Github/simple_hwr/data/prepare_online_data/lines-xml/a01-000u-06.xml")

# Last expression of the cell, so the returned (x, y) arrays are shown as the cell output.
get_gts(path)

(array([-1.        , -0.98820401, -0.98802347, -0.98380521, -0.90070996,
        -0.91806482, -0.85585469, -0.82920331, -0.76193384, -0.76760189,
        -0.71580077, -0.68877856, -0.65237018, -0.60204692, -0.57112971,
        -0.49682716, -0.45418593, -0.46321773, -0.4126078 , -0.38319895,
        -0.37819815, -0.3048843 , -0.13492077, -0.10410237, -0.05857741,
        -0.07008551, -0.00229394,  0.2045342 ,  0.23823755,  0.27013675,
         0.27025507,  0.3361882 ,  0.38948364,  0.35854569,  0.61867353,
         0.60765093,  0.60678239,  0.67839028,  0.73014626,  0.72818961,
         0.79927907,  0.82621772,  0.85996926,  0.87382559,  0.87029167,
         0.87664101,  0.9204167 ,  0.9712212 ,  0.99606231,  1.        ]),
 array([ 0.39154597, -0.30345296, -0.15864581,  0.39136734,  0.38231689,
        -0.20020641,  0.3481594 , -0.01983259,  0.31860651, -0.27620735,
         0.92576053, -0.21151947, -0.06019738,  0.88052814, -0.21280956,
        -0.2203119 ,  0.1991148 , -0.28453832,  0

array([-1.        , -0.33333333,  0.33333333,  1.        ])