# In [None]:
import caffe
import wfdb
import matplotlib.pyplot as plt
import numpy as np
import math
import lmdb

# Heart rate calculation (the plotting arguments are accepted but unused)
def peaks(x, peak_indices, fs, title, figsize=(20, 10), saveto=None):
    """Return the heart-rate series computed by ``wfdb.processing.compute_hr``.

    Args:
        x: Signal array; only its first-axis length (``x.shape[0]``) is used.
        peak_indices: Peak positions forwarded unchanged to ``compute_hr``.
        fs: Sampling frequency forwarded unchanged to ``compute_hr``.
        title: Accepted for call compatibility; not used by this function.
        figsize: Accepted for call compatibility; not used by this function.
        saveto: Accepted for call compatibility; not used by this function.

    Returns:
        The value returned by ``compute_hr`` for the given signal length,
        peak positions and sampling frequency.
    """
    signal_length = x.shape[0]
    return wfdb.processing.compute_hr(
        siglen=signal_length, peak_indices=peak_indices, fs=fs
    )

#Single beats extraction for record
def extract_beats(record, R_indices, peaks_hr, freq):
    """Cut one window per R peak out of ``record`` and resample it to 256 samples.

    The window length follows the local heart rate: one beat is taken to last
    ``60 / hr`` seconds, i.e. ``freq * 60 / hr`` samples. The window starts
    half a beat before the R peak, clamped to the start of the record.

    Args:
        record: 2-D sequence of samples; it is flattened row by row before
            the windows are sliced out.
        R_indices: Sample index of each R peak.
        peaks_hr: Heart rate (beats per minute) indexed by sample position.
            NaN entries fall back to 70 bpm.
        freq: Sampling frequency of ``record``.

    Returns:
        A ``(len(R_indices), 256)`` float array with one resampled beat per
        row.
    """
    target_len = 256
    default_hr = 70.0

    flat_record = [item for sublist in record for item in sublist]
    ecg_nparray = np.empty((len(R_indices), target_len))

    for row, r_index in enumerate(R_indices):
        hr = peaks_hr[r_index].item()
        if math.isnan(hr):
            hr = default_hr
        samples_per_beat = int(freq * (60.0 / hr))

        # Floor division keeps the slice bounds integral on Python 3 as well.
        start = max(r_index - samples_per_beat // 2, 0)
        end = start + samples_per_beat
        beat = np.array(flat_record[start:end])

        # Passing the window length as the source "sampling rate" makes the
        # resampler map every beat, whatever its length, onto target_len
        # samples so it fits the fixed-width output row.
        resampled, _ = wfdb.processing.resample_sig(
            x=beat, fs=len(beat), fs_target=target_len
        )
        ecg_nparray[row] = resampled

    # Return the fixed-size array (not the intermediate list) so callers can
    # rely on ``.shape`` and multi-dimensional indexing.
    return ecg_nparray

def read_record(rec):
    """Load one record, rescale channel 0 to [0, 1] and split it into beats.

    Args:
        rec: Record name as a string; it is appended to ``./dataset/`` to
            locate both the signal and its ``atr`` annotation.

    Returns:
        Tuple ``(ecg_beats, annotation)``: the result of ``extract_beats``
        for this record and the annotation object returned by ``wfdb.rdann``.
    """
    record_path = './dataset/' + rec

    # Load the wfdb record (first channel only) and its annotations.
    record = wfdb.rdsamp(record_path, channels=[0])
    annotation = wfdb.rdann(record_path, 'atr', summarize_labels=True)
    freq = record.fs

    # Write the normalized samples back into the record with one vectorized
    # column assignment instead of a per-sample Python loop.
    normalized = wfdb.processing.normalize(x=record.p_signals, lb=0.0, ub=1.0)
    record.p_signals[:, 0] = np.reshape(normalized, -1)

    # The annotated sample positions serve as the beat locations.
    peak_indices = annotation.sample
    peaks_hr = peaks(
        x=record.p_signals,
        peak_indices=peak_indices,
        fs=record.fs,
        title="GQRS peaks on record " + rec,
    )
    ecg_beats = extract_beats(record.p_signals, peak_indices, peaks_hr, freq)

    return ecg_beats, annotation
    
# Per-record results, kept in the same order as ``records``.
beats = []
annotations = []
records = [100,101,102,103,104,105,106,107,108,109,111,112,113,114,115,116,117,118,119,121,122,123,124,
          200,201,202,203,205,207,208,209,210,212,213,214,215,217,219,220,221,222,223,228,230,231,232,233,234]

for rec in records:
    rec_beats, annotation = read_record(str(rec))
    beats.append(rec_beats)
    annotations.append(annotation.symbol)

    # Dump every beat followed by a tab and its annotation symbol.
    # The with-block closes the file even if a write raises.
    with open("./Text_records/Normalized/Rec_" + str(rec) + ".txt", "w") as out_file:
        for beat, symbol in zip(rec_beats, annotation.symbol):
            out_file.write(str(beat) + "\t")
            out_file.write(symbol + "\n")
    #print(annotation.contained_labels)

# Saving data into a lmdb dataset
caffe.set_device(0)
caffe.set_mode_gpu()

# Annotation symbol -> integer class label. Built once instead of once per
# beat.
LABEL_CODES = {
    'N': 1,
    'L': 2,
    'R': 3,
    'a': 4,
    'V': 5,
    'F': 6,
    'J': 7,
    'A': 8,
    'S': 9,
    'E': 10,
    'j': 11,
    '/': 12,
    'Q': 13,
    '~': 14,
    '|': 16,
    '"': 22,
    '+': 28,
    '!': 31,
    '[': 32,
    ']': 33,
    'e': 34,
    'x': 37,
    'f': 38,
}

# Only the first record is converted here. np.asarray makes the 4-D indexing
# below work whether the beats arrive as a 2-D array or as a list of
# equal-length 1-D arrays.
first_record_beats = np.asarray(beats[0])

#blobinp = np.append(beats[0][: , np.newaxis, np.newaxis , :] , beats[3][: , np.newaxis, np.newaxis , :] , axis=0)
# Insert two singleton axes: (N, W) -> (N, 1, 1, W).
blobinp = first_record_beats[:, np.newaxis, np.newaxis, :]
print(blobinp.shape)

N = first_record_beats.shape[0]

X = blobinp
# A symbol that is missing from LABEL_CODES raises KeyError.
Y = np.array([LABEL_CODES[annotations[0][i]] for i in range(N)], dtype=int)

# As far as this file shows, map_size is consumed only by the disabled LMDB
# writer further down.
map_size = X.nbytes * 8 * 5
# The triple-quoted string below is a disabled LMDB writer: as a bare string
# expression it is evaluated and discarded, so nothing is written.
# It would store one caffe Datum per row of X (values in float_data, label
# from Y) under a zero-padded 8-digit key.
# NOTE(review): the inner comment mentions an encode, but txn.put receives
# the plain str_id; under Python 3 the key presumably needs to be bytes
# (e.g. str_id.encode('ascii')) -- confirm before re-enabling.
'''
env = lmdb.open('mylmdb', map_size=map_size);

N = X.shape[0];

with env.begin(write=True) as txn:
    #txn is a Transaction object
    for i in range(N):
        datum = caffe.proto.caffe_pb2.Datum()
        datum.channels = X.shape[1]
        datum.height = X.shape[2]
        datum.width = X.shape[3]
        datum.float_data.extend(X[i][0][0])
        #datum.data = X[i].tobytes()
        datum.label = int(Y[i])
        str_id = '{:08}'.format(i)

        # The encode is only essential in Python 3
        txn.put(str_id, datum.SerializeToString())
'''
# Disabled read-back check: fetches the first entry and rebuilds the array.
# NOTE(review): this reads datum.data, while the disabled writer above fills
# datum.float_data and leaves datum.data unset -- the two need to agree
# before this check is re-enabled.
#env = lmdb.open('mylmdb', readonly=True)
#with env.begin() as txn:
#    raw_datum = txn.get(b'00000000')

#datum = caffe.proto.caffe_pb2.Datum()
#datum.ParseFromString(raw_datum)

#flat_x = np.fromstring(datum.data, dtype=np.float64)

#x = flat_x.reshape(datum.channels, datum.height, datum.width)
#y = datum.label
#print(x);