In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import math
from utils import *

In [5]:
def load_all_data(dir, filecnt=54):
    """Load every recording ``{dir}/{pattern}/{k}.npy`` for patterns 0-9 and k = 1..filecnt.

    Args:
        dir: root directory containing one sub-directory per pattern (``0`` .. ``9``).
        filecnt: number of files per pattern; files are numbered from 1.

    Returns:
        data: list of 10 lists; ``data[p][i]`` is the array loaded from ``{dir}/{p}/{i+1}.npy``.
        t: int array of shape (10, filecnt) with ``t[p, i] == data[p][i].shape[0]``.
    """
    lengths = np.zeros((10, filecnt), int)
    data = []
    for pattern in range(10):
        # File names are 1-based on disk.
        recordings = [np.load(f"{dir}/{pattern}/{idx}.npy") for idx in range(1, filecnt + 1)]
        lengths[pattern, :] = [rec.shape[0] for rec in recordings]
        data.append(recordings)
    return data, lengths

def load_with_config(dir, config_path, load_nonvalid=False):
    """Load trimmed recordings listed in a CSV config file.

    The first CSV line is skipped (header). Each remaining row is read positionally:
      col 0: pattern id, used as the index into the returned list (0-9) and as the sub-directory;
      col 1: zero-based recording index; the file loaded is ``{dir}/{pattern}/{index+1}.npy``;
      col 2: valid flag; rows where it is falsy are skipped unless ``load_nonvalid`` is True;
      col 3, col 4: start/stop row indices, applied as ``arr[start:stop, :]``.

    Returns:
        data: list of 10 lists; ``data[p]`` holds the trimmed arrays of pattern ``p`` in CSV order.
        config_arr: the config table (header excluded) as a NumPy array, including skipped rows.
    """
    config_arr = np.array(pd.read_csv(str(config_path), header=None, skiprows=1))
    data = [list() for _ in range(10)]
    for row in config_arr:
        keep = row[2] or load_nonvalid
        if not keep:
            continue
        pattern = int(row[0])
        bucket = data[pattern]
        path = f"{dir}/{pattern}/{int(row[1]) + 1}.npy"
        bucket.append(np.load(path)[int(row[3]):int(row[4]), :])
    return data, config_arr

def match_length(d, t: int, channels: int = 2):
    """Resample every sequence in ``d`` to ``t`` timesteps and flatten into arrays.

    Each sequence is mapped onto a normalised time axis [0, 1] and linearly interpolated
    (``np.interp``) at ``t`` evenly spaced points, independently per channel.

    Args:
        d: nested sequence ``d[label][i]`` of 2-D arrays (rows = timesteps, columns = channels).
            The outer index becomes the label.
        t: number of output timesteps.
        channels: number of leading columns to resample (default 2, the original behaviour).

    Returns:
        x: float64 array of shape (N, t, channels), where N is the total number of sequences.
        y: float64 array of shape (N,), where ``y[k]`` is the outer index the k-th sequence came from.
    """
    n_samples = sum(len(group) for group in d)
    x = np.zeros((n_samples, t, channels), np.float64)
    y = np.zeros((n_samples), np.float64)
    target_timepoints = np.linspace(0, 1, t)
    row = 0
    for label, group in enumerate(d):
        for seq in group:
            # Stretch/compress the sequence's own time axis onto [0, 1] before resampling.
            origin_timepoints = np.linspace(0, 1, seq.shape[0])
            for ch in range(channels):
                x[row, :, ch] = np.interp(target_timepoints, origin_timepoints, seq[:, ch])
            y[row] = label
            row += 1
    return x, y

def apply_normalize(d):
    """Min-max normalise every sequence in ``d`` in place while preserving its aspect ratio.

    Each sequence is shifted so every channel starts at 0. All channels are then divided by the
    largest per-channel range. The widest channel therefore spans [0, 1], and the other channels
    keep their size relative to it, so a 2-D trajectory is not stretched. This equals the original
    ``((v - min) / range) * (range / max_range)`` without dividing by each channel's own range.

    Args:
        d: nested sequence ``d[group][i]`` of 2-D arrays (rows = timesteps, columns = channels),
            e.g. as returned by ``load_with_config``. Modified in place; returns None.

    Notes:
        * A constant channel maps to 0. Previously it became NaN (0/0).
        * If every channel is constant, the sequence becomes all zeros instead of all NaN.
        * Empty arrays are left untouched instead of raising in ``np.min``.
        * Non-float arrays are replaced by float copies in ``d``. Writing [0, 1] values back into
          an integer array would truncate them to 0/1.
    """
    for group in d:
        for c, seq in enumerate(group):
            if seq.size == 0:
                continue
            min_vals = seq.min(axis=0)
            scale = np.max(seq.max(axis=0) - min_vals)
            if scale == 0:
                scale = 1.0  # fully constant sequence: (v - min) is already all zeros
            normalized = (seq - min_vals) / scale
            if np.issubdtype(seq.dtype, np.floating):
                seq[...] = normalized  # keep the original array object (in-place semantics)
            else:
                group[c] = normalized
    
def plot_data(d):
    """Show columns 0 and 1 of ``d`` as two stacked line plots in one 20x4 figure."""
    plt.figure(figsize=(20,4))
    # (subplot position, column of d) for the top and bottom panels.
    for position, column in ((1, 0), (2, 1)):
        ax = plt.subplot(2, 1, position)
        ax.plot(d[:, column])
    plt.tight_layout()
    plt.show()

def plot_data2(d, save=None, cmap_name="gist_rainbow", xlim1=None, ylim1=None, xlim2=None, ylim2=None, xlim3=None, ylim3=None):
    """Plot one sequence as a column-0 vs column-1 panel plus two per-column time panels.

    Panel "a" (left) draws column 1 against column 0. Panels "b"/"c" (right) draw columns 0 and 1
    against the row index. Each panel is drawn with ``draw_gradation`` (from ``utils``), which
    receives ``cmap_name`` and the matching ``xlimN``/``ylimN``. If ``save`` is None the figure is
    shown; otherwise it is written to ``save`` and closed.
    """
    _fig, axes = plt.subplot_mosaic("abbbb;acccc", figsize=(20,4))
    n_steps = d.shape[0]
    panels = [
        ("a", d[:, 0], d[:, 1], xlim1, ylim1),
        ("b", np.arange(n_steps), d[:, 0], xlim2, ylim2),
        ("c", np.arange(n_steps), d[:, 1], xlim3, ylim3),
    ]
    for key, xs, ys, xlim, ylim in panels:
        draw_gradation(xs, ys, axes[key], cmap_name=cmap_name, xlim=xlim, ylim=ylim)

    plt.tight_layout()
    if save is not None:
        plt.savefig(save)
        plt.close()
        return
    plt.show()

def plot_data3(origin, syn, save=None, cmap_name="gist_rainbow"):
    """Plot an original and a synthetic sequence one above the other, all axes limited to [0, 1].

    Each sequence gets one column-0 vs column-1 panel (left) and two per-column time panels
    (right), drawn with ``draw_gradation`` from ``utils``. If ``save`` is None the figure is shown;
    otherwise it is written to ``save`` and closed. ``save`` now defaults to None, matching
    ``plot_data2``; the body already handled None.
    """
    fig, axes = plt.subplot_mosaic("abbbb;acccc;deeee;dffff", figsize=(20,8))
    # (trajectory panel, column-0 panel, column-1 panel, sequence, title)
    rows = (
        ("a", "b", "c", origin, "Original"),
        ("d", "e", "f", syn, "Synthetic"),
    )
    for traj_key, col0_key, col1_key, seq, _title in rows:
        draw_gradation(seq[:,0], seq[:,1], axes[traj_key], cmap_name=cmap_name, xlim=[0,1], ylim=[0,1])
        draw_gradation(np.arange(seq.shape[0]), seq[:,0], axes[col0_key], cmap_name=cmap_name, ylim=[0,1])
        draw_gradation(np.arange(seq.shape[0]), seq[:,1], axes[col1_key], cmap_name=cmap_name, ylim=[0,1])
    for traj_key, _col0, _col1, _seq, title in rows:
        axes[traj_key].set_title(title)

    plt.tight_layout()
    if save is None:
        plt.show()
    else:
        plt.savefig(save)
        plt.close(fig)  # avoid accumulating open figures when called in a loop

In [6]:
import os

# NOTE(review): machine-specific default location of the eye-writing recordings. Set the
# EYE_WRITING_ROOT environment variable on other machines instead of editing this cell.
EYE_WRITING_ROOT = os.environ.get("EYE_WRITING_ROOT", "/home/user/workspace/research/eye-writing")
DATA_DIR = f"{EYE_WRITING_ROOT}/self_data/"
CONFIG_PATH = f"{EYE_WRITING_ROOT}/load_data_config.csv"

data, config = load_with_config(DATA_DIR, CONFIG_PATH, load_nonvalid=False)
apply_normalize(data)  # in place: each sequence rescaled so its widest channel spans [0, 1]
# plot_data2(data[9][53], xlim1=[0, 1], ylim1=[0,1], ylim2=[0,1], ylim3=[0,1])

In [7]:
# target_data = data
# save_dir = "/home/user/img"
# for r in range(len(target_data)):
#     target_pattern_config = config[config[:,0] == r, :]
#     indexes = target_pattern_config[target_pattern_config[:,2]==1, 1]
#     print(f"pattern {r}, length = {len(target_data[r])}, indexes len = {indexes.shape}")
#     for c in range(len(target_data[r])):
#         plot_data2(target_data[r][c], save=f"{save_dir}/Pattern_{r}_Index_{indexes[c]}.png", cmap_name="gist_rainbow", xlim1=[0, 1], ylim1=[0,1], ylim2=[0,1], ylim3=[0,1])

In [8]:
pattern = 8  # which of the 10 pattern groups (outer index of `data`) to keep

# Resample every sequence to 200 timesteps; y holds each sequence's pattern index.
x, y = match_length(data, 200)
for arr in (x, y):
    print(arr.shape)
keep = y == pattern
x = x[keep]
print(x.shape)


(505, 200, 2)
(505,)
(50, 200, 2)


In [9]:
# target_data = x
# save_dir = "/home/user/img"
# for i in range(len(target_data)):
#     plot_data2(target_data[i], f"{save_dir}/{i}.png", "gist_rainbow",  xlim1=[0, 1], ylim1=[0,1], ylim2=[0,1], ylim3=[0,1])

In [38]:
import tensorflow as tf
import tensorflow.keras as keras
from tqdm import trange

def get_model(input_shape, output_units, rnn_units, layer_cnt):
    """Build a Keras functional model: ``layer_cnt`` stacked GRUs, then a sigmoid Dense head.

    Every GRU returns the full sequence, so the Dense head is applied to every timestep and the
    output has ``output_units`` features per step, squashed into (0, 1) by the sigmoid.
    """
    inputs = keras.layers.Input(input_shape)
    stack = [keras.layers.GRU(rnn_units, return_sequences=True) for _ in range(layer_cnt)]
    stack.append(keras.layers.Dense(output_units, activation="sigmoid"))
    hidden = inputs
    for layer in stack:
        hidden = layer(hidden)
    return keras.Model(inputs=inputs, outputs=hidden)

def init_timegan(input_shape, units=24, layers=3):
    """Create the five TimeGAN sub-networks with ``get_model``.

    Args:
        input_shape: (timesteps, features) of one real sequence.
        units: GRU width and latent feature size.
        layers: number of stacked GRUs per network.

    Returns:
        dict keyed by the parameter names of ``timegan_train`` so it can be splatted into it:
        embedder/generator take ``input_shape`` inputs and produce ``units`` latent features;
        supervisor, recovery and discriminator take (timesteps, units) latent inputs and output
        ``units``, ``input_shape[1]`` and 1 features respectively.
    """
    latent_shape = (input_shape[0], units)
    return {
        "embedder": get_model(input_shape=input_shape, output_units=units, rnn_units=units, layer_cnt=layers),
        "generator": get_model(input_shape=input_shape, output_units=units, rnn_units=units, layer_cnt=layers),
        "supervisor": get_model(input_shape=latent_shape, output_units=units, rnn_units=units, layer_cnt=layers),
        "recovery": get_model(input_shape=latent_shape, output_units=input_shape[1], rnn_units=units, layer_cnt=layers),
        "discriminator": get_model(input_shape=latent_shape, output_units=1, rnn_units=units, layer_cnt=layers),
    }
    
def _timegan_moment_loss(real, fake):
    """Moment-matching loss between real and generated batches (statistics over the batch axis).

    Mean absolute gap between per-(timestep, feature) standard deviations, plus mean absolute gap
    between per-(timestep, feature) means. This is the "moment" loss of the reference TimeGAN.
    """
    real = tf.convert_to_tensor(real, dtype=fake.dtype)
    real_mean, real_var = tf.nn.moments(real, axes=[0])
    fake_mean, fake_var = tf.nn.moments(fake, axes=[0])
    std_gap = tf.reduce_mean(tf.abs(tf.sqrt(fake_var + 1e-6) - tf.sqrt(real_var + 1e-6)))
    mean_gap = tf.reduce_mean(tf.abs(fake_mean - real_mean))
    return std_gap + mean_gap


def timegan_train(embedder, generator, supervisor, recovery, discriminator, epochs, batch_size, learning_rate, sequences=None, gamma=1.0):
    """Train the five TimeGAN networks in three phases, each running ``epochs`` mini-batch steps.

    1. Autoencoder: embedder + recovery minimise 10 * sqrt(MSE) reconstruction loss.
    2. Supervisor: supervisor predicts the embedding at t+1 from the embedding at t.
    3. Joint: per epoch, two rounds of (generator + supervisor update, embedder + recovery
       update), then one discriminator update. The discriminator update is skipped while its loss
       is already <= 0.15.

    Args:
        embedder, generator, supervisor, recovery, discriminator: Keras models (see init_timegan).
        epochs: number of iterations per phase.
        batch_size: mini-batch size. Batches are drawn without replacement and are never larger
            than the dataset.
        learning_rate: Adam learning rate used by every optimizer.
        sequences: real data, array-like of shape (N, T, F). When None, falls back to the
            notebook-global ``x`` (what earlier versions implicitly used).
        gamma: weight of the adversarial terms computed directly on generator output.

    Raises:
        ValueError: if no data is available or it is not a non-empty 3-D array.
    """
    if sequences is None:
        # Backward compatibility: the original silently trained on the global `x`.
        if "x" not in globals():
            raise ValueError("timegan_train: pass `sequences=` (no global `x` found)")
        sequences = globals()["x"]
    data = np.asarray(sequences, dtype=np.float32)
    if data.ndim != 3 or data.shape[0] == 0:
        raise ValueError(f"timegan_train: expected non-empty (N, T, F) data, got shape {data.shape}")

    mse = keras.losses.MeanSquaredError()
    bce = keras.losses.BinaryCrossentropy()
    opt_autoencoder = keras.optimizers.Adam(learning_rate=learning_rate)
    opt_supervisor = keras.optimizers.Adam(learning_rate=learning_rate)
    opt_generator = keras.optimizers.Adam(learning_rate=learning_rate)
    opt_embedder = keras.optimizers.Adam(learning_rate=learning_rate)
    opt_discriminator = keras.optimizers.Adam(learning_rate=learning_rate)

    def mini_batch():
        # Random subset without replacement.
        return data[np.random.permutation(data.shape[0])[:batch_size]]

    def noise_like(batch):
        # Uniform [0, 1) noise with the same shape as a real batch (as in the original draft).
        return np.random.uniform(size=batch.shape).astype(np.float32)

    def apply(optimizer, tape, loss, var_list):
        optimizer.apply_gradients(zip(tape.gradient(loss, var_list), var_list))

    def supervised_loss(h):
        # The supervisor output at step t should match the embedding at step t+1.
        return mse(h[:, 1:, :], supervisor(h)[:, :-1, :])

    # Phase 1: autoencoder
    print("train autoencoder")
    for _ in trange(epochs):
        batch = mini_batch()
        with tf.GradientTape() as tape:
            x_tilde = recovery(embedder(batch))
            loss = 10 * tf.sqrt(mse(batch, x_tilde))
        apply(opt_autoencoder, tape, loss, embedder.trainable_variables + recovery.trainable_variables)

    # Phase 2: supervisor. Only the supervisor influences this loss, so only its
    # variables are updated, with its own optimizer.
    print("train supervisor")
    for _ in trange(epochs):
        batch = mini_batch()
        with tf.GradientTape() as tape:
            loss = supervised_loss(embedder(batch))
        apply(opt_supervisor, tape, loss, supervisor.trainable_variables)

    # Phase 3: joint adversarial training
    print("joint train")
    for _ in trange(epochs):
        for __ in range(2):
            batch = mini_batch()
            z = noise_like(batch)

            # Generator + supervisor: supervised, adversarial (two paths) and moment losses.
            with tf.GradientTape() as tape:
                loss_s = supervised_loss(embedder(batch))
                e_hat = generator(z)
                h_hat = supervisor(e_hat)
                y_fake = discriminator(h_hat)
                y_fake_e = discriminator(e_hat)
                loss_u = bce(tf.ones_like(y_fake), y_fake)
                loss_u_e = bce(tf.ones_like(y_fake_e), y_fake_e)
                loss_v = _timegan_moment_loss(batch, recovery(h_hat))
                g_loss = loss_u + gamma * loss_u_e + 100 * tf.sqrt(loss_s) + 100 * loss_v
            apply(opt_generator, tape, g_loss, generator.trainable_variables + supervisor.trainable_variables)

            # Embedder + recovery: reconstruction plus a small supervised term.
            with tf.GradientTape() as tape:
                h = embedder(batch)
                e_loss = 10 * tf.sqrt(mse(batch, recovery(h))) + 0.1 * supervised_loss(h)
            apply(opt_embedder, tape, e_loss, embedder.trainable_variables + recovery.trainable_variables)

        # Discriminator: real embeddings -> 1, generated (with/without supervisor) -> 0.
        batch = mini_batch()
        z = noise_like(batch)
        with tf.GradientTape() as tape:
            y_real = discriminator(embedder(batch))
            e_hat = generator(z)
            y_fake = discriminator(supervisor(e_hat))
            y_fake_e = discriminator(e_hat)
            d_loss = (bce(tf.ones_like(y_real), y_real)
                      + bce(tf.zeros_like(y_fake), y_fake)
                      + gamma * bce(tf.zeros_like(y_fake_e), y_fake_e))
        if float(d_loss) > 0.15:
            apply(opt_discriminator, tape, d_loss, discriminator.trainable_variables)




In [37]:
timegan = init_timegan(input_shape=x.shape[1:])
timegan_train(epochs=100, batch_size=32, learning_rate=0.0005, **timegan)

train autoencoder


  0%|          | 0/100 [00:00<?, ?it/s]

(32, 200, 2)


  2%|▏         | 2/100 [00:01<00:43,  2.24it/s]

(32, 200, 2)
(32, 200, 2)


  4%|▍         | 4/100 [00:01<00:23,  4.13it/s]

(32, 200, 2)
(32, 200, 2)


  6%|▌         | 6/100 [00:01<00:17,  5.48it/s]

(32, 200, 2)
(32, 200, 2)


  8%|▊         | 8/100 [00:01<00:14,  6.31it/s]

(32, 200, 2)
(32, 200, 2)


 10%|█         | 10/100 [00:02<00:13,  6.72it/s]

(32, 200, 2)
(32, 200, 2)


 12%|█▏        | 12/100 [00:02<00:12,  6.96it/s]

(32, 200, 2)
(32, 200, 2)


 14%|█▍        | 14/100 [00:02<00:12,  7.05it/s]

(32, 200, 2)
(32, 200, 2)


 16%|█▌        | 16/100 [00:02<00:12,  6.93it/s]

(32, 200, 2)
(32, 200, 2)


 18%|█▊        | 18/100 [00:03<00:11,  6.85it/s]

(32, 200, 2)
(32, 200, 2)


 20%|██        | 20/100 [00:03<00:11,  6.94it/s]

(32, 200, 2)
(32, 200, 2)


 22%|██▏       | 22/100 [00:03<00:11,  7.03it/s]

(32, 200, 2)
(32, 200, 2)


 24%|██▍       | 24/100 [00:04<00:10,  7.06it/s]

(32, 200, 2)
(32, 200, 2)


 26%|██▌       | 26/100 [00:04<00:10,  7.09it/s]

(32, 200, 2)
(32, 200, 2)


 28%|██▊       | 28/100 [00:04<00:10,  7.09it/s]

(32, 200, 2)
(32, 200, 2)


 30%|███       | 30/100 [00:05<00:10,  6.75it/s]

(32, 200, 2)
(32, 200, 2)


 32%|███▏      | 32/100 [00:05<00:10,  6.78it/s]

(32, 200, 2)
(32, 200, 2)


 34%|███▍      | 34/100 [00:05<00:09,  6.93it/s]

(32, 200, 2)
(32, 200, 2)


 36%|███▌      | 36/100 [00:05<00:09,  6.99it/s]

(32, 200, 2)
(32, 200, 2)


 38%|███▊      | 38/100 [00:06<00:09,  6.73it/s]

(32, 200, 2)
(32, 200, 2)


 40%|████      | 40/100 [00:06<00:08,  6.80it/s]

(32, 200, 2)
(32, 200, 2)


 42%|████▏     | 42/100 [00:06<00:08,  6.70it/s]

(32, 200, 2)
(32, 200, 2)


 44%|████▍     | 44/100 [00:07<00:08,  6.93it/s]

(32, 200, 2)
(32, 200, 2)


 46%|████▌     | 46/100 [00:07<00:07,  7.08it/s]

(32, 200, 2)
(32, 200, 2)


 48%|████▊     | 48/100 [00:07<00:07,  7.02it/s]

(32, 200, 2)
(32, 200, 2)


 50%|█████     | 50/100 [00:07<00:07,  7.04it/s]

(32, 200, 2)
(32, 200, 2)


 52%|█████▏    | 52/100 [00:08<00:06,  7.01it/s]

(32, 200, 2)
(32, 200, 2)


 54%|█████▍    | 54/100 [00:08<00:06,  7.05it/s]

(32, 200, 2)
(32, 200, 2)


 56%|█████▌    | 56/100 [00:08<00:06,  6.93it/s]

(32, 200, 2)
(32, 200, 2)


 58%|█████▊    | 58/100 [00:09<00:06,  6.94it/s]

(32, 200, 2)
(32, 200, 2)


 60%|██████    | 60/100 [00:09<00:05,  6.80it/s]

(32, 200, 2)
(32, 200, 2)


 62%|██████▏   | 62/100 [00:09<00:05,  6.73it/s]

(32, 200, 2)
(32, 200, 2)


 64%|██████▍   | 64/100 [00:10<00:05,  6.11it/s]

(32, 200, 2)
(32, 200, 2)


 66%|██████▌   | 66/100 [00:10<00:05,  6.37it/s]

(32, 200, 2)
(32, 200, 2)


 68%|██████▊   | 68/100 [00:10<00:04,  6.71it/s]

(32, 200, 2)
(32, 200, 2)


 70%|███████   | 70/100 [00:10<00:04,  6.99it/s]

(32, 200, 2)
(32, 200, 2)


 72%|███████▏  | 72/100 [00:11<00:04,  6.90it/s]

(32, 200, 2)
(32, 200, 2)


 74%|███████▍  | 74/100 [00:11<00:03,  6.95it/s]

(32, 200, 2)
(32, 200, 2)


 76%|███████▌  | 76/100 [00:11<00:03,  7.01it/s]

(32, 200, 2)
(32, 200, 2)


 78%|███████▊  | 78/100 [00:12<00:03,  6.99it/s]

(32, 200, 2)
(32, 200, 2)


 80%|████████  | 80/100 [00:12<00:02,  6.91it/s]

(32, 200, 2)
(32, 200, 2)


 82%|████████▏ | 82/100 [00:12<00:02,  6.98it/s]

(32, 200, 2)
(32, 200, 2)


 84%|████████▍ | 84/100 [00:12<00:02,  7.05it/s]

(32, 200, 2)
(32, 200, 2)


 86%|████████▌ | 86/100 [00:13<00:01,  7.12it/s]

(32, 200, 2)
(32, 200, 2)


 88%|████████▊ | 88/100 [00:13<00:01,  7.02it/s]

(32, 200, 2)
(32, 200, 2)


 90%|█████████ | 90/100 [00:13<00:01,  7.02it/s]

(32, 200, 2)
(32, 200, 2)


 92%|█████████▏| 92/100 [00:14<00:01,  7.00it/s]

(32, 200, 2)
(32, 200, 2)


 94%|█████████▍| 94/100 [00:14<00:00,  7.01it/s]

(32, 200, 2)
(32, 200, 2)


 96%|█████████▌| 96/100 [00:14<00:00,  7.00it/s]

(32, 200, 2)
(32, 200, 2)


 98%|█████████▊| 98/100 [00:14<00:00,  6.91it/s]

(32, 200, 2)
(32, 200, 2)


100%|██████████| 100/100 [00:15<00:00,  6.58it/s]

(32, 200, 2)





In [12]:
# target_data = synth_data
# save_dir = f"/home/user/syn{pattern}"
# for i in range(len(synth_data)):
#     plot_data2(synth_data[i], f"{save_dir}/{i}.png", "gist_rainbow",   xlim1=[0, 1], ylim1=[0,1], ylim2=[0,1], ylim3=[0,1])

In [13]:
# save_dir = f"/home/user/compare_images_{pattern}"
# for i in range(len(x)):
#     plot_data3(x[i], synth_data[i], f"{save_dir}/{i}.png", "winter_r")