# In [None]:
import caffe
import wfdb
import matplotlib.pyplot as plt
import numpy as np
import math
import lmdb

#Heart rate calculation and record plotting
def peaks(x, peak_indices, fs, title, figsize=(20, 10), saveto=None):
    """Compute the heart rate from peak positions and plot the signal with its peaks.

    Parameters
    ----------
    x : numpy array
        Signal to plot, indexable by ``peak_indices``. Presumably the
        (n_samples, 1) ``p_signals`` array of a wfdb record -- TODO confirm
        for other callers.
    peak_indices : array of int
        Sample indices of the peaks to mark.
    fs : number
        Sampling frequency, forwarded to ``wfdb.processing.compute_hr``.
    title : str
        Figure title.
    figsize : tuple, optional
        Matplotlib figure size in inches.
    saveto : str or None, optional
        When given, the figure is also written to this path before it is shown.

    Returns
    -------
    The per-sample heart-rate array returned by ``wfdb.processing.compute_hr``.
    """
    hrs = wfdb.processing.compute_hr(siglen=x.shape[0], peak_indices=peak_indices, fs=fs)
    fig, ax_left = plt.subplots(figsize=figsize)

    ax_left.plot(x, color='#3979f0', label='Signal')
    # Markers only, no connecting line. Combining an 'rx' format string with
    # marker/color keywords makes matplotlib warn about redundant definitions.
    ax_left.plot(peak_indices, x[peak_indices], linestyle='none', marker='x',
                 color='#8b0000', label='Peak', markersize=10)

    ax_left.set_title(title)
    # The signal is plotted against its sample index, not against milliseconds.
    ax_left.set_xlabel('Sample')
    ax_left.set_ylabel('ECG (mV)', color='#3979f0')
    ax_left.tick_params('y', colors='#3979f0')
    ax_left.legend()

    if saveto is not None:
        fig.savefig(saveto)
    plt.show()
    return hrs

#Single beats extraction for record
def extract_beats(record, R_indices, peaks_hr, freq, beat_len=250, default_hr=70.0):
    """Cut one heart-rate-sized window around each peak and resample it.

    For every index in ``R_indices`` a window of ``freq * 60 / hr`` samples
    is taken from the flattened record. The window is roughly centred on that
    index and clamped at the start of the record. Each window is resampled to
    ``beat_len`` samples, so every beat has the same classifier input size.

    Parameters
    ----------
    record : array-like
        Signal samples, flattened row by row. Presumably the single-channel
        (n_samples, 1) ``p_signals`` array -- TODO confirm for other callers.
    R_indices : sequence of int
        Sample index of each peak.
    peaks_hr : array-like
        Heart rate per sample in beats per minute, indexed by ``R_indices``.
    freq : number
        Sampling frequency in Hz.
    beat_len : int, optional
        Number of samples in each output beat (default 250).
    default_hr : float, optional
        Heart rate used where ``peaks_hr`` is NaN, infinite or non-positive
        (default 70.0).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(R_indices), beat_len)`` with one beat per row.
    """
    # Row-major flatten: same order as concatenating the rows, but done in
    # numpy instead of building a Python list of every sample.
    flat_record = np.asarray(record).ravel()
    ecg_nparray = np.empty((len(R_indices), beat_len))

    for row, r_index in enumerate(R_indices):
        hr = float(peaks_hr[r_index])
        # NaN (no RR interval known yet), inf (coincident peaks) or <= 0 would
        # give an empty or meaningless window; fall back to a nominal rate.
        if not np.isfinite(hr) or hr <= 0:
            hr = default_hr
        samples_per_beat = max(1, int(freq * (60.0 / hr)))

        # `//` keeps the index an integer; `/` yields a float on Python 3,
        # which cannot be used as a slice bound.
        start = max(0, r_index - samples_per_beat // 2)
        beat = flat_record[start:start + samples_per_beat]

        # Passing the window length as the source rate makes resample_sig
        # produce beat_len output samples, one fixed-size row per beat.
        resampled, _ = wfdb.processing.resample_sig(x=beat, fs=len(beat), fs_target=beat_len)
        ecg_nparray[row] = resampled

    return ecg_nparray

def read_record(rec, t0=0, tf=30000, data_dir='./dataset/'):
    """Load a record's first channel and its annotations, normalise it and cut beats.

    Parameters
    ----------
    rec : str
        Record name inside ``data_dir`` (e.g. ``'105'``).
    t0, tf : int, optional
        Sample range to read, passed as ``sampfrom`` / ``sampto``.
    data_dir : str, optional
        Directory containing the record and its ``atr`` annotation file.

    Returns
    -------
    (ecg_beats, annotation)
        ``ecg_beats`` is the array built by ``extract_beats`` (one row per
        annotation position); ``annotation`` is the object returned by
        ``wfdb.rdann``.
    """
    import os

    path = os.path.join(data_dir, rec)
    record = wfdb.rdsamp(path, sampfrom=t0, sampto=tf, channels=[0])
    annotation = wfdb.rdann(path, 'atr', sampfrom=t0, sampto=tf, summarize_labels=True)

    # Normalise to [0, 1] and write the result back into the record in one
    # vectorised step (only channel 0 was read, so there is a single column).
    sig = wfdb.processing.normalize(x=record.p_signals, lb=0.0, ub=1.0)
    record.p_signals[:, 0] = np.ravel(sig)

    # Every reference annotation position (including non-beat symbols) is
    # used as a peak; no QRS detector is run here.
    peak_indices = annotation.sample
    peaks_hr = peaks(x=record.p_signals, peak_indices=peak_indices, fs=record.fs,
                     title="Annotated peaks on record " + rec)
    ecg_beats = extract_beats(record.p_signals, peak_indices, peaks_hr, record.fs)

    return ecg_beats, annotation
    
# Sample window read from every record.
t0 = 0
tf = 600000

# One beats array per loaded record. `annotation` keeps whatever the last
# read_record call returned.
beats = []
for rec_id in range(105, 106):
    rec_beats, annotation = read_record(str(rec_id), t0, tf)
    beats.append(rec_beats)
    # To inspect the extracted beats of this record:
    # for beat in rec_beats:
    #     plt.plot(beat)
    # plt.title("Plotting ECG beats for record " + str(rec_id))
    # plt.show()


# Saving data into an LMDB dataset
caffe.set_device(0)
caffe.set_mode_gpu()

# Beats of the first loaded record as a 4-D blob (n_beats, 1, 1, beat_len).
# To export several records, concatenate their blobs along axis 0 here.
blobinp = beats[0][:, np.newaxis, np.newaxis, :]
print(blobinp.shape)

N = beats[0].shape[0]
X = blobinp

# Annotation symbol -> integer class label; other symbols have no class.
LABEL_MAP = {
    'N': 0,
    'V': 1,
    'Q': 2,
    '+': 3,
    '~': 4,
    '|': 5,
}

# `annotation` belongs to the LAST record read above while the beats come
# from beats[0]; they only line up when both refer to the same record.
if len(annotation.symbol) != N:
    raise ValueError(
        'Got %d beats but %d annotation symbols; beats[0] and `annotation` '
        'must come from the same record' % (N, len(annotation.symbol)))
unknown_symbols = sorted(set(annotation.symbol) - set(LABEL_MAP))
if unknown_symbols:
    raise ValueError('No class label defined for annotation symbol(s): %s'
                     % ', '.join(unknown_symbols))
Y = np.array([LABEL_MAP[symbol] for symbol in annotation.symbol], dtype=int)

# LMDB needs its maximum size up front; reserve a generous multiple of the
# raw array size.
map_size = X.nbytes * 8 * 5

env = lmdb.open('mylmdb', map_size=map_size)

N = X.shape[0]
channels, height, width = X.shape[1], X.shape[2], X.shape[3]

try:
    # One write transaction for all beats.
    with env.begin(write=True) as txn:
        for i in range(N):
            datum = caffe.proto.caffe_pb2.Datum()
            datum.channels = channels
            datum.height = height
            datum.width = width
            # Store the beat as repeated floats (float_data), not raw bytes (data).
            datum.float_data.extend(X[i][0][0])
            datum.label = int(Y[i])
            # Zero-padded 8-digit keys so the keys sort in beat order. Keys must
            # be bytes on Python 3; .encode('ascii') also works on a Python 2 str.
            str_id = '{:08}'.format(i).encode('ascii')
            txn.put(str_id, datum.SerializeToString())
finally:
    # Close the environment even if writing fails.
    env.close()

#env = lmdb.open('mylmdb', readonly=True)
#with env.begin() as txn:
#    raw_datum = txn.get(b'00000000')

#datum = caffe.proto.caffe_pb2.Datum()
#datum.ParseFromString(raw_datum)

# NOTE: the writer above fills datum.float_data (datum.data is left empty), so read float_data:
#flat_x = np.array(datum.float_data, dtype=np.float64)

#x = flat_x.reshape(datum.channels, datum.height, datum.width)
#y = datum.label
#print(x);