In [1]:
%matplotlib inline
%timeit

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import os
import time
import pickle
import datetime
import re
import pyActigraphy

import sys
sys.path.append('/home/ngrav/project/')
from wearables.scripts import utils as wearutils
from wearables.scripts import data as weardata
from wearables.scripts import train as weartrain
from wearables.scripts import eval_ as weareval
from wearables.scripts import model as wearmodels
from wearables.scripts import DTW as weardtw

import torch
import torch.nn as nn
import torch.nn.functional as F

from scipy.spatial.distance import pdist, squareform
import fastdtw
import umap

plt.rc('font', size = 9)
plt.rc('font', family='sans serif')
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['legend.frameon']=False
plt.rcParams['axes.grid']=False
plt.rcParams['legend.markerscale']=1
plt.rcParams['savefig.dpi'] = 600
sns.set_style("ticks")

In [2]:
# load data
mfp = '/home/ngrav/project/wearables/model_zoo'
bst_modelpkl = os.path.join(mfp, '213-itv52_InceptionTime_GA5.pkl')
bst_trainer = os.path.join(mfp, 'trainer_itv52_InceptionTime_GA5.pkl')

pfp = '/home/ngrav/project/wearables/results/'
md = pd.read_csv(os.path.join(pfp, 'md_v52_220124.csv'), index_col=0)

In [5]:
def merge_out2md(md, bst_trainerfp, bst_modelfp, return_embeds=True, out_file=None, verbose=False):
    def loadpkl(file):
        with open(file, 'rb') as f:
            data = pickle.load(f)
            f.close()
        return data
    trainer = loadpkl(bst_trainerfp)
    if verbose:
        total_t = time.time()
    dt = pd.DataFrame()
    if return_embeds:
        embeds = pd.DataFrame()
    for split in ['train', 'test']: # omit val since test_pids contain val_pids (val is a subset of test)
        if verbose:
            tic = time.time()
            print('Starting inference for {} set...'.format(split))

        evaluation = weareval.eval_trained(trainer, split=split,
                                           modelpkl=bst_modelfp,
                                           two_outputs=True)
        dt = dt.append(pd.DataFrame({
            'y':evaluation.y.numpy(), 'yhat':evaluation.yhat.numpy(), 
            'split':[split]*evaluation.y.shape[0],
            'error':(evaluation.yhat - evaluation.y).numpy()
        }, index=evaluation.id))
        if return_embeds:
            embeds = embeds.append(pd.DataFrame(evaluation.out2.numpy(), index=evaluation.id))
        if verbose:
            print('  inference for {} set done in {:.0f}-s\t{:.2f}-min elapsed'.format(split, time.time()-tic, (time.time()-total_t)/60))
    md = md.merge(dt, left_index=True, right_index=True, how='left')
    if out_file is not None:
        md.to_csv(out_file)
    if return_embeds:
        if out_file is not None:
            embeds.to_csv(os.path.join(os.path.split(out_file)[0], 'embeds_v522_220124.csv'))
        return md, embeds
    else:
        return md

In [6]:
md, embeds = merge_out2md(md, bst_trainer, bst_modelpkl, out_file=os.path.join(pfp, 'md_v522_220124.csv'), verbose=True)

Starting inference for train set...


  return F.conv1d(input, weight, bias, self.stride,
  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


  inference for train set done in 555-s	9.25-min elapsed
Starting inference for test set...


  return F.mse_loss(input, target, reduction=self.reduction)


  inference for test set done in 235-s	13.16-min elapsed


In [7]:
md.head()

Unnamed: 0_level_0,record_id,age_enroll,marital,gestage_by,insur,ethnicity,race,bmi_1vis,prior_ptb_all,fullterm_births,...,visit_num,PQSI,KPAS,EpworthSS,Edinburgh,Pre-term birth,y,yhat,split,error
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1024_10,1024,23.0,0.0,0.0,5.0,0.0,1.0,21.378954,0.0,1.0,...,1,4.0,4.571429,0.0,0.0,False,10.0,15.463044,train,5.463044
2180_12,2180,34.0,1.0,1.0,3.0,0.0,0.0,31.73264,0.0,0.0,...,1,10.0,9.05368,5.0,6.0,False,12.0,11.77713,train,-0.22287
2148_11,2148,34.0,1.0,2.0,3.0,0.0,0.0,22.406605,0.0,1.0,...,1,11.0,10.266234,9.0,7.0,False,11.0,18.814798,train,7.814798
1062_8,1062,31.0,0.0,1.0,5.0,0.0,1.0,40.10627,0.0,0.0,...,1,8.0,9.493074,3.0,5.0,False,8.0,14.42829,train,6.42829
1306_7,1306,25.0,0.0,3.0,5.0,0.0,1.0,24.55567,0.0,0.0,...,1,10.0,8.88961,9.0,9.0,False,7.0,16.641426,train,9.641426


In [8]:
md.index.duplicated().sum()

0

In [9]:
md.shape

(2463, 130)

In [10]:
mdchk = pd.read_csv(os.path.join(pfp, 'md_v522_220124.csv'), index_col=0)

In [11]:
mdchk

Unnamed: 0_level_0,record_id,age_enroll,marital,gestage_by,insur,ethnicity,race,bmi_1vis,prior_ptb_all,fullterm_births,...,visit_num,PQSI,KPAS,EpworthSS,Edinburgh,Pre-term birth,y,yhat,split,error
unique_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1024_10,1024,23.0,0.0,0.0,5.0,0.0,1.0,21.378954,0.0,1.0,...,1,4.0,4.571429,0.0,0.0,False,10.0,15.463044,train,5.463044
2180_12,2180,34.0,1.0,1.0,3.0,0.0,0.0,31.732640,0.0,0.0,...,1,10.0,9.053680,5.0,6.0,False,12.0,11.777130,train,-0.222870
2148_11,2148,34.0,1.0,2.0,3.0,0.0,0.0,22.406605,0.0,1.0,...,1,11.0,10.266234,9.0,7.0,False,11.0,18.814798,train,7.814798
1062_8,1062,31.0,0.0,1.0,5.0,0.0,1.0,40.106270,0.0,0.0,...,1,8.0,9.493074,3.0,5.0,False,8.0,14.428290,train,6.428290
1306_7,1306,25.0,0.0,3.0,5.0,0.0,1.0,24.555670,0.0,0.0,...,1,10.0,8.889610,9.0,9.0,False,7.0,16.641426,train,9.641426
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2344_32,2344,26.0,1.0,3.0,3.0,0.0,0.0,25.969303,0.0,0.0,...,3,11.0,11.025541,2.0,0.0,False,32.0,25.704956,train,-6.295044
2167_33,2167,31.0,1.0,2.0,3.0,0.0,0.0,29.672672,0.0,1.0,...,3,6.0,8.044805,11.0,0.0,False,33.0,20.061865,train,-12.938135
1484_33,1484,28.0,1.0,2.0,1.0,0.0,1.0,30.598898,1.0,0.0,...,3,7.0,12.809957,11.0,2.0,True,33.0,24.141296,test,-8.858704
2330_34,2330,32.0,1.0,2.0,3.0,0.0,0.0,24.046602,0.0,1.0,...,3,8.0,10.271429,6.0,0.0,False,34.0,22.923199,test,-11.076801
