In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pysida.lib import get_deformation_from_pair, Pair, Defor

In [2]:
class Datasets:
    """Collection of named pandas tables that all carry an ``id`` column.

    ``data`` maps a table name (for example ``'points'``) to a DataFrame.
    Iterating over the collection groups the rows of every table by ``id``
    and turns each group into one ``OutputClass`` instance.
    """

    # Callable used by __iter__ to build each yielded item; set by subclasses.
    OutputClass = None
    # Class-level placeholder only; __init__ gives every instance its own value.
    data = None

    def __init__(self, data=None):
        """Create the collection.

        Parameters
        ----------
        data : dict, optional
            Mapping of table name to DataFrame. The mapping is stored as-is
            (not copied), so later changes are visible to whoever else holds
            it. When omitted, an empty dict is used.
        """
        self.data = {}
        if data is not None:
            self.data = data

    def __iter__(self):
        """Yield one ``OutputClass`` instance per unique id in ``data['points']``.

        For every id, the matching rows of each table are selected and each
        column is passed to ``OutputClass`` as a keyword argument holding the
        column values (after going through ``adapt_data``). Nothing is yielded
        when ``data`` is empty.
        """
        if not self.data:
            return
        unique_ids = self.data['points']['id'].unique()
        for pair_id in unique_ids:
            return_data = {}
            for name in self.data:
                table = self.data[name]
                pair_data = table[table['id'] == pair_id]
                for column in pair_data:
                    return_data[column] = pair_data[column].values
            # Replace the per-row id array by the scalar id of this group.
            return_data['id'] = pair_id
            yield self.OutputClass(**self.adapt_data(return_data))

    def adapt_data(self, data):
        """Hook for subclasses to reshape the keyword dict before construction.

        The base implementation returns ``data`` unchanged.
        """
        return data

    def save(self, filepath):
        """Write every table to the HDF file ``filepath`` under its own name.

        The first table is written with mode ``'w'`` (replacing an existing
        file); the remaining tables are appended to that file.
        """
        mode = 'w'
        for name in self.data:
            self.data[name].to_hdf(filepath, key=name, mode=mode)
            mode = 'a'

    def load(self, filepath, read_names=()):
        """Read tables from the HDF file ``filepath`` into ``data``.

        Parameters
        ----------
        filepath : str
            Path of the HDF file.
        read_names : sequence of str, optional
            Names of the tables to read. When empty, every key found in the
            file is read.
        """
        if not read_names:
            with pd.HDFStore(filepath, mode='r') as store:
                read_names = list(store.keys())
        for name in read_names:
            # HDFStore.keys() reports keys with a leading '/', e.g. '/points'.
            # Store the table under the bare name so that lookups such as
            # self.data['points'] keep working after a load.
            self.data[name.lstrip('/')] = pd.read_hdf(filepath, key=name)


class Pairs(Datasets):
    """Datasets holding three tables per pair: 'points', 'dates' and 'meshes'.

    Iteration yields ``Pair`` objects built from the rows that share an id.
    """

    OutputClass = Pair

    def _next_pair_id(self):
        """Return the id for the next appended pair (0 for an empty collection).

        'dates' is consulted first because it receives exactly one row per
        appended pair, so its last id stays correct even when a pair added no
        rows to 'points' or 'meshes'.
        """
        for name in ('dates', 'points', 'meshes'):
            frame = self.data.get(name)
            if frame is not None and len(frame):
                return int(frame['id'].iloc[-1]) + 1
        return 0

    def append(self, x0, y0, x1, y1, d0, d1, t, a, p, g):
        """Append one pair, giving it the next free id.

        Parameters
        ----------
        x0, y0, x1, y1 : array-like
            Per-point values, all of the same length; stored as float32 in
            the 'points' table.
        d0, d1 : scalar
            Stored as a single row in the 'dates' table.
        t : array-like with at least three columns
            Columns 0, 1 and 2 are stored as integers 't0', 't1', 't2' in the
            'meshes' table.
        a, p, g : array-like
            Per-mesh-row values with as many entries as ``t`` has rows;
            stored as float32 in the 'meshes' table.
        """
        pair_id = self._next_pair_id()
        t = np.asarray(t)

        new_frames = {
            'points': pd.DataFrame({
                'id': np.full(len(x0), pair_id, dtype=int),
                'x0': np.asarray(x0, dtype=np.float32),
                'y0': np.asarray(y0, dtype=np.float32),
                'x1': np.asarray(x1, dtype=np.float32),
                'y1': np.asarray(y1, dtype=np.float32),
            }),
            'dates': pd.DataFrame({'id': [pair_id], 'd0': [d0], 'd1': [d1]}),
            'meshes': pd.DataFrame({
                'id': np.full(len(a), pair_id, dtype=int),
                'a': np.asarray(a, dtype=np.float32),
                'p': np.asarray(p, dtype=np.float32),
                'g': np.asarray(g, dtype=np.float32),
                't0': t[:, 0].astype(int),
                't1': t[:, 1].astype(int),
                't2': t[:, 2].astype(int),
            }),
        }

        # Decide per table whether to create or extend it, so a data dict that
        # already holds other tables (but not these) does not raise KeyError.
        for name, frame in new_frames.items():
            if name in self.data:
                self.data[name] = pd.concat([self.data[name], frame], ignore_index=True)
            else:
                self.data[name] = frame

    def adapt_data(self, data):
        """Convert the flat column dict of one pair into ``Pair`` keyword arguments.

        The 't0', 't1', 't2' columns are merged into one array 't' with three
        columns, and 'd0' / 'd1' are reduced to the first stored value
        converted to a Python ``datetime``. ``data`` is modified in place and
        returned.
        """
        data['t'] = np.column_stack((data['t0'], data['t1'], data['t2']))
        del data['t0'], data['t1'], data['t2']
        data['d0'] = pd.Timestamp(data['d0'][0]).to_pydatetime()
        data['d1'] = pd.Timestamp(data['d1'][0]).to_pydatetime()
        return data

class Defors(Datasets):
    """Datasets variant that keeps a 'defor' table and yields ``Defor`` objects."""

    OutputClass = Defor

    def append(self, pid, e1, e2, e3, ux, uy, vx, vy):
        """Store the rows for id ``pid`` in the 'defor' table.

        All array arguments are stored as float32 columns of the same name.
        Rows already stored for ``pid`` are dropped first, so appending the
        same id again replaces its previous rows; the new rows are placed at
        the end of the table.
        """
        columns = {'id': (pid * np.ones(len(e1))).astype(int).astype(int)}
        for label, values in (('e1', e1), ('e2', e2), ('e3', e3),
                              ('ux', ux), ('uy', uy),
                              ('vx', vx), ('vy', vy)):
            columns[label] = values.astype(np.float32)
        fresh_rows = pd.DataFrame(columns)

        if 'defor' not in self.data:
            self.data['defor'] = fresh_rows
            return

        existing = self.data['defor']
        # Filtering is a no-op when pid has not been stored yet.
        kept_rows = existing[existing.id != pid]
        self.data['defor'] = pd.concat([kept_rows, fresh_rows], ignore_index=True)


In [3]:
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only
# load archives that come from a trusted source.
# Open the archive in a `with` block so the .npz file handle is closed as
# soon as the 'pairs' array has been read into memory.
with np.load('../test/nex10_pairs.npz', allow_pickle=True) as npz_file:
    np_pairs = npz_file['pairs']
len(np_pairs)

524

In [4]:
# Copy every loaded record into a Pairs collection, then write all of its
# tables to the HDF file.
pairs = Pairs()
for npp in np_pairs:
    pairs.append(
        x0=npp.x0, y0=npp.y0,
        x1=npp.x1, y1=npp.y1,
        d0=npp.d0, d1=npp.d1,
        t=npp.t, a=npp.a, p=npp.p, g=npp.g,
    )

pairs.save('../test/data10.hdf5')

In [5]:
# Defors receives the very same dict object that `pairs` holds, so the
# 'defor' table appended below also becomes visible through pairs.data.
defors = Defors(pairs.data)
for p in pairs:
    d = get_deformation_from_pair(p)
    defors.append(p.id, d.e1, d.e2, d.e3, d.ux, d.uy, d.vx, d.vy)

# save() writes its first table with mode='w', so this replaces the file
# written earlier with every table currently in the shared dict.
defors.save('../test/data10.hdf5')

In [6]:
# Display the accumulated 'points' table: one row per point, tagged with the
# id of the pair it was appended with.
pairs.data['points']

Unnamed: 0,id,x0,y0,x1,y1
0,0,-1.061115e+06,9.766455e+05,-1.063044e+06,9.771837e+05
1,0,-1.055900e+06,9.865853e+05,-1.057807e+06,9.871111e+05
2,0,-1.024585e+06,8.101170e+05,-1.027000e+06,8.107462e+05
3,0,-9.882484e+05,9.528082e+05,-9.902786e+05,9.532601e+05
4,0,-1.052271e+06,8.847198e+05,-1.054428e+06,8.854434e+05
...,...,...,...,...,...
274430,523,-5.422636e+05,1.290151e+06,-5.422516e+05,1.290167e+06
274431,523,-5.244672e+05,7.983267e+05,-5.286411e+05,8.013352e+05
274432,523,-4.048753e+05,1.400487e+06,-4.033238e+05,1.408767e+06
274433,523,-4.738903e+05,8.561779e+05,-4.768609e+05,8.598497e+05
