In [9]:

import matplotlib.pyplot as plt
from config import Config
from patient_data_reader import PatientReader
import os
import time
import numpy as np
import pandas as pd


import torch

In [6]:
# --- Model / data shape ---
N_HIDDEN = 200      # units in the hidden (recurrent) layer
MAX_LENGTH = 300    # cap on the number of visits kept per sequence

# --- Optimisation ---
GRAD_CLIP = 100     # gradients above this value are clipped
# NOTE(review): a "number of training sequences in each batch" constant was
# announced here, but no batch-size value is defined in this cell.

# --- Training schedule ---
EPOCH_SIZE = 100    # how often the output is checked
num_epochs = 6      # number of epochs to train the net

In [14]:
def prepare_data(seqs, labels, vocabsize, maxlen=None):
    """Build padded multi-hot input, mask and label arrays from visit sequences.

    Every sequence is padded to the length of the longest sequence in the
    batch (after optional truncation). Arrays are sample-major: the first
    axis indexes sequences, the second indexes visits.

    Args:
        seqs: list of sequences; each sequence is a list of visits and each
            visit is an iterable of integer codes. A code ``c`` sets column
            ``c - 1`` of ``x``, so codes are presumably 1-based
            (``1..vocabsize``) -- TODO confirm; a code of 0 would wrap around
            to the last column.
        labels: one entry per sequence, written into the first ``len(seq)``
            positions of the matching row of ``y``. It must be sliceable
            whenever its sequence has to be truncated.
        vocabsize: size of the last axis of ``x``.
        maxlen: if given, sequences with ``maxlen`` or more visits keep only
            their first ``maxlen`` visits (labels are sliced the same way).

    Returns:
        A 5-tuple ``(x, x_mask, y, lengths, eventLengths)``:

        * ``x``: int64 array ``(n_samples, padded_len, vocabsize)`` holding a
          1 for every code present in a visit.
        * ``x_mask``: float array ``(n_samples, padded_len)``, 1 for real
          visits and 0 for padding.
        * ``y``: float array ``(n_samples, padded_len)`` with the labels;
          padded positions keep the initial value 1.
        * ``lengths``: number of visits per sequence, after truncation.
        * ``eventLengths``: total number of codes per sequence, counted
          before any truncation.

        If there is nothing to process, all five elements are ``None`` so
        callers can still unpack the result.
    """
    # Total codes per sequence; deliberately computed on the untruncated input.
    eventLengths = [sum(len(visit) for visit in seq) for seq in seqs]

    lengths = [len(seq) for seq in seqs]

    if maxlen is not None:
        kept_seqs, kept_lengths, kept_labels = [], [], []
        for length, seq, label in zip(lengths, seqs, labels):
            if length < maxlen:
                # Short enough: keep untouched (the label is not sliced, so
                # it does not need to support slicing in this case).
                kept_seqs.append(seq)
                kept_lengths.append(length)
                kept_labels.append(label)
            else:
                kept_seqs.append(seq[:maxlen])
                kept_lengths.append(maxlen)
                kept_labels.append(label[:maxlen])
        seqs, lengths, labels = kept_seqs, kept_lengths, kept_labels

    # Same arity as the normal return; also guards np.max() against an
    # empty list when maxlen is None.
    if len(lengths) < 1:
        return None, None, None, None, None

    n_samples = len(seqs)
    padded_len = np.max(lengths)

    # Allocate with the final dtype directly instead of zeros().astype(),
    # which would create a temporary float64 copy of the (large) tensor.
    x = np.zeros((n_samples, padded_len, vocabsize), dtype='int64')
    x_mask = np.zeros((n_samples, padded_len), dtype=float)
    y = np.ones((n_samples, padded_len), dtype=float)

    for idx, seq in enumerate(seqs):
        x_mask[idx, :lengths[idx]] = 1
        for visit_pos, visit in enumerate(seq):
            for code in visit:
                x[idx, visit_pos, code - 1] = 1

    for idx, label in enumerate(labels):
        y[idx, :lengths[idx]] = label

    return x, x_mask, y, lengths, eventLengths

In [15]:
FLAGS = Config()
data_sets = PatientReader(FLAGS)

# Vocabulary size, defined once instead of repeating the literal in every
# prepare_data call below.
N_VOCAB = 619

# --- training split ---
X_raw_data, Y_raw_data = data_sets.get_data_from_type("train")
(trainingAdmiSeqs, trainingMask, trainingLabels,
 trainingLengths, ltr) = prepare_data(X_raw_data, Y_raw_data,
                                      vocabsize=N_VOCAB,
                                      maxlen=MAX_LENGTH)
# NOTE: this rebinds MAX_LENGTH to the padded length of the training tensor,
# so the validation/test calls below are capped at that length rather than at
# the value MAX_LENGTH had before this line.
Num_Samples, MAX_LENGTH, N_VOCAB = trainingAdmiSeqs.shape

# --- validation split ---
X_valid_data, Y_valid_data = data_sets.get_data_from_type("valid")
(validAdmiSeqs, validMask, validLabels,
 validLengths, lval) = prepare_data(X_valid_data, Y_valid_data,
                                    vocabsize=N_VOCAB,
                                    maxlen=MAX_LENGTH)

# --- test split ---
X_test_data, Y_test_data = data_sets.get_data_from_type("test")
(test_admiSeqs, test_mask, test_labels,
 testLengths, ltes) = prepare_data(X_test_data, Y_test_data,
                                   vocabsize=N_VOCAB,
                                   maxlen=MAX_LENGTH)

# Total number of visits kept across the three splits (after truncation).
alllength = sum(trainingLengths) + sum(validLengths) + sum(testLengths)
print(alllength)
# Total number of codes across the three splits (counted before truncation).
eventNum = sum(ltr) + sum(lval) + sum(ltes)
print(eventNum)

 [*] load resource\vocab.pkl
 [*] load resource/X_train.pkl
 [*] load resource/Y_train.pkl
 [*] load resource/X_valid.pkl
 [*] load resource/Y_valid.pkl
 [*] load resource/X_test.pkl
 [*] load resource/Y_test.pkl
vocabulary size: 619
number of training documents: 2000
number of validation documents: 500
number of testing documents: 500
239887
685482


In [18]:
# Echo the raw training sequences obtained from get_data_from_type("train").
# NOTE(review): this displays the entire nested list; consider showing only a
# small slice to keep the saved notebook output short.
X_raw_data

[[[406, 273, 485],
  [1, 488, 490, 486],
  [426],
  [442, 307, 203],
  [490, 472, 486, 12],
  [439],
  [411],
  [380],
  [406, 489, 440, 485, 406, 489, 440, 485],
  [488, 469, 487, 486],
  [406, 475, 440, 479],
  [359, 444],
  [400],
  [80, 428, 444],
  [48],
  [458],
  [473],
  [368, 484, 472, 485],
  [458, 424],
  [473],
  [472],
  [481, 402, 139],
  [490],
  [487, 486, 411],
  [444],
  [423, 446],
  [427],
  [127],
  [490, 362, 489, 470],
  [417],
  [380],
  [475, 311],
  [411],
  [469, 440],
  [427],
  [391],
  [1, 488, 486, 459, 462, 1, 488, 486, 459, 462],
  [488, 469, 481, 440, 488, 469, 481, 440],
  [359, 444, 364, 131],
  [489, 446, 478, 468, 350],
  [406, 487, 348, 485],
  [406, 488, 440, 479],
  [476, 490],
  [490],
  [395, 473, 470],
  [380],
  [424],
  [482, 489, 475, 470],
  [434, 472, 305],
  [385, 426, 335, 424, 414, 471],
  [411],
  [481],
  [481, 489],
  [487, 490, 470],
  [472],
  [488],
  [488, 469, 487, 486, 488, 469, 487, 486],
  [385],
  [434],
  [400],
  [488, 7