# [MNIST digits stroke sequence data](https://edwin-de-jong.github.io/blog/mnist-sequence-data/　)
今回は、MNISTの手描き文字データセットを解析して、数字の筆順をピクセル座標の系列で表現したデータセットを使用する。
ここでの筆順は必ずしも正確な筆順とは限らないが、少なくとも画像中の数字の形状に沿った系列データとなっている。

In [1]:
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jr
import diffrax
import equinox as eqx

from tools._dataset.datasets import MNISTStrokeDataset
from tools._dataset.dataloader import dataloader_ununiformed_sequence



In [2]:
key = jr.PRNGKey(12)
dataset = MNISTStrokeDataset(dataset_size=64, mode_train=True, input_format='point_sequence', noise_ratio=0.2, interpolation='cubic', key=key)
ts, ys, coeffs, labels, in_size = dataset.make_dataset()

100%|██████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:06<00:00,  9.53it/s]


In [3]:
ts[0].shape

(46,)

In [4]:
ts[0]

Array([0.01421267, 0.01521267, 0.01621267, 0.01721267, 0.01821267,
       0.01921267, 0.02021267, 0.02121267, 0.02221267, 0.02321267,
       0.02421267, 0.02521267, 0.02621267, 0.02762689, 0.05166852,
       0.07313943, 0.09461034, 0.11241484, 0.11341484, 0.11441484,
       0.11582906, 0.11682906, 0.11782906, 0.11924328, 0.1206575 ,
       0.1216575 , 0.1226575 , 0.12407172, 0.12548593, 0.12648593,
       0.12748593, 0.12848593, 0.12948593, 0.13048594, 0.13148594,
       0.13248594, 0.13348594, 0.13448595, 0.13548595, 0.13648595,
       0.13748595, 0.13848595, 0.13948596, 0.14048596, 0.14148596,
       0.1689996 ], dtype=float32)

In [5]:
ys[0].shape

(46, 3)

In [6]:
ys[0]

Array([[ 1.42126707e-02,  1.10000000e+01,  9.00000000e+00],
       [ 1.52126709e-02,  1.20000000e+01,             nan],
       [ 1.62126701e-02,  1.30000000e+01,             nan],
       [ 1.72126703e-02,  1.30000000e+01,  8.00000000e+00],
       [ 1.82126705e-02,  1.40000000e+01,  8.00000000e+00],
       [ 1.92126706e-02,  1.50000000e+01,  8.00000000e+00],
       [ 2.02126708e-02,  1.60000000e+01,  8.00000000e+00],
       [ 2.12126710e-02,  1.70000000e+01,             nan],
       [ 2.22126711e-02,  1.70000000e+01,             nan],
       [ 2.32126713e-02,  1.80000000e+01,             nan],
       [ 2.42126714e-02,  1.90000000e+01,  7.00000000e+00],
       [ 2.52126716e-02,  2.00000000e+01,  7.00000000e+00],
       [ 2.62126718e-02,  2.10000000e+01,  7.00000000e+00],
       [ 2.76268851e-02,  2.20000000e+01,  6.00000000e+00],
       [ 5.16685173e-02, -1.00000000e+00, -1.00000000e+00],
       [ 7.31394291e-02,             nan,  9.00000000e+00],
       [ 9.46103409e-02, -1.00000000e+00

In [7]:
labels.shape

(64, 1)

In [8]:
labels[0]

Array([5.], dtype=float32)

In [9]:
in_size

3

In [10]:
for i in range(4):
    print(type(coeffs), len(coeffs[i]), coeffs[i][0].shape)

<class 'tuple'> 64 (45, 3)
<class 'tuple'> 64 (45, 3)
<class 'tuple'> 64 (45, 3)
<class 'tuple'> 64 (45, 3)


In [13]:
i = 0

sample_ys = ys[i]
sample_ts = ts[i]
sample_coeffs = (coeffs[0][i], coeffs[1][i], coeffs[2][i], coeffs[3][i])

interpolation = diffrax.CubicInterpolation(sample_ts, sample_coeffs)
values = jax.vmap(interpolation.evaluate)(sample_ts)

In [15]:
for t, y, v_interp in zip(sample_ts, sample_ys, values):
    print(f"{t}: {y} --> {v_interp}")

0.014212670736014843: [ 0.01421267 11.          9.        ] --> [0. 0. 0.]
0.015212670899927616: [ 0.01521267 12.                 nan] --> [ 0.01521267 12.          2.6666675 ]
0.016212670132517815: [ 0.01621267 13.                 nan] --> [ 0.01621267 12.999999    5.3333325 ]
0.017212670296430588: [ 0.01721267 13.          8.        ] --> [ 0.01721267 13.          8.        ]
0.01821267046034336: [ 0.01821267 14.          8.        ] --> [ 0.01821267 14.          8.        ]
0.019212670624256134: [ 0.01921267 15.          8.        ] --> [ 0.01921267 15.          8.        ]
0.020212670788168907: [ 0.02021267 16.          8.        ] --> [ 0.02021267 16.          8.        ]
0.02121267095208168: [ 0.02121267 17.                 nan] --> [ 0.02121267 17.          7.890625  ]
0.022212671115994453: [ 0.02221267 17.                 nan] --> [ 0.02221267 17.          7.6354165 ]
0.023212671279907227: [ 0.02321267 18.                 nan] --> [ 0.02321267 18.          7.3125    ]
0.0242126

In [11]:
def main(ts, ys, coeffs, labels, in_size):
    steps = 1
    batch_size = 10
    loader_key = jr.PRNGKey(12)

    def test_interpolation(data):
        ts, *coeffs = data
        control = diffrax.CubicInterpolation(ts, coeffs)
        y0 = control.evaluate(ts[0])
        delta = jax.vmap(control.evaluate)(ts[:-1], ts[1:])
        def _f(yi0, delta):
            yi1 = yi0 + delta
            return yi1, yi1
        _, ys = lax.scan(_f, y0, delta)
        ys = jnp.concatenate((y0[None, :], ys), axis=0)
        return ys

    eqx.clear_caches()
    jax.clear_caches()
    
    for step, data in zip(range(steps), dataloader_ununiformed_sequence(((ts, ys, *coeffs), labels), batch_size, key=loader_key)):
        _ts, _ys, *_coeffs, _labels = data
        print(_ts.shape, _labels.shape, len(_coeffs), _coeffs[1].shape)
        interp_ys = jax.vmap(test_interpolation)((_ts, *_coeffs))
        print(_labels[0][0])
        print(_ys.shape, interp_ys.shape)
        print(_ys[0])
        print(interp_ys[0])

In [12]:

main(ts, ys, coeffs, labels, in_size)

(10, 50) (10, 1) 4 (10, 49, 3)
4.0
(10, 50, 3) (10, 50, 3)
[[ 1.92353856e-02  1.70000000e+01  9.00000000e+00]
 [ 2.02353857e-02  1.70000000e+01             nan]
 [ 2.12353859e-02  1.70000000e+01             nan]
 [ 2.22353861e-02  1.60000000e+01  1.10000000e+01]
 [ 2.32353862e-02  1.60000000e+01  1.20000000e+01]
 [ 2.42353864e-02  1.60000000e+01  1.30000000e+01]
 [ 2.52353866e-02  1.60000000e+01  1.40000000e+01]
 [ 2.62353867e-02  1.60000000e+01             nan]
 [ 2.72353869e-02  1.50000000e+01             nan]
 [ 2.82353871e-02  1.50000000e+01             nan]
 [ 2.96496004e-02  1.40000000e+01  1.70000000e+01]
 [ 3.06496006e-02  1.40000000e+01  1.80000000e+01]
 [ 3.16496007e-02  1.40000000e+01  1.90000000e+01]
 [ 3.30638140e-02  1.30000000e+01  2.00000000e+01]
 [ 3.40638123e-02  1.30000000e+01  2.10000000e+01]
 [ 3.54780257e-02             nan  2.20000000e+01]
 [ 3.64780240e-02  1.20000000e+01             nan]
 [ 3.74780223e-02  1.10000000e+01             nan]
 [ 3.84780206e-02  1.10