# Using 2D CNNs and Image Preprocessing to read ECG graphs and perform PVC diagnosis

By Lee Lip Tong

In [None]:
import cv2
import math
import numpy as np
import pickle
import os
import pytesseract
import os
import cv2
from PIL import Image
import pandas as pd 
import keras
from keras.models import Sequential
from keras.layers import Reshape
from keras.layers import Dense, Activation, Flatten, Conv2D, Dropout,MaxPool2D, ELU, BatchNormalization
import pickle
from keras.utils import to_categorical

## Data Preprocessing
Functions for data preprocessing

In [None]:
#process image takes an image and splits it into slices which are positioned around each blue line
def _crop_beat_window(graph, leftcrop, rightcrop, right, width):
    """Cut the columns ``leftcrop:rightcrop`` out of ``graph``.

    A window lying fully inside ``[0, right]`` is returned as a plain slice.
    A window that sticks out past column 0 or past ``right`` is pasted onto a
    white (255) canvas ``width`` pixels wide, shifted so that the trace keeps
    its horizontal position relative to the window's left edge.

    The pasted region is clamped to what actually fits, so a source image
    narrower than ``right`` no longer raises a shape-mismatch error.
    """
    if leftcrop >= 0 and rightcrop <= right:
        return graph[:, leftcrop:rightcrop]

    canvas = np.full((graph.shape[0], width), 255, np.uint8)
    src_lo = max(leftcrop, 0)
    src_hi = min(rightcrop, right)
    patch = graph[:, src_lo:src_hi]
    dst_lo = src_lo - leftcrop
    usable = max(0, min(patch.shape[1], width - dst_lo))
    canvas[:, dst_lo:dst_lo + usable] = patch[:, :usable]
    return canvas


def _slice_and_save(filepath, image_root, text_root):
    """Split one ECG image into per-pulse slices and save them.

    For every pulse line found by ``get_pulse_lines`` a window around it is
    cut from the binarised trace, resized to 128x128 and written to
    ``<image_root>/<name>/<index>.png``.  The per-image metadata is pickled
    to ``<text_root>/<name>.txt``.  ``<name>`` is the file name of
    ``filepath`` without directory and extension.

    Returns:
        (pulselines, labels, filenamelist, idlist) -- four lists of equal
        length: pulse-line x positions, OCR labels, the image name repeated
        once per pulse, and the slice indices.

    Raises:
        IOError: if ``filepath`` cannot be read as an image.
    """
    img = cv2.imread(filepath)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('Could not read image: {}'.format(filepath))

    graph = isolate_graph(img)
    graph = cv2.blur(graph, (3, 3))
    _, graph = cv2.threshold(graph, 240, 255, cv2.THRESH_BINARY)
    pulselines = get_pulse_lines(img)
    print(len(pulselines))
    labels = get_pulse_strings(img, pulselines)

    # The grid geometry comes from the hard-coded values; the horizontal-line
    # extraction that used to run here only fed an unused variable.
    left, right, hnum = get_precalculated_values()
    hinterval = (right - left) / (hnum - 1)
    # Window extent in grid intervals: 0.4 before and 0.6 after the marker,
    # with 0.2 per interval (same factor pxtotime uses; presumably seconds).
    lrcrop = 0.4 / 0.2
    rrcrop = 0.6 / 0.2
    sliced_image_width = int((lrcrop + rrcrop) * hinterval)

    # Derive the name from the path itself rather than fixed slice offsets,
    # so any directory prefix or extension length works.
    filename = os.path.splitext(os.path.basename(filepath))[0]
    filenamelist = [filename] * len(pulselines)
    idlist = list(range(len(pulselines)))

    slice_dir = '{}/{}'.format(image_root, filename)
    os.makedirs(slice_dir, exist_ok=True)
    os.makedirs(text_root, exist_ok=True)
    with open('{}/{}.txt'.format(text_root, filename), 'wb') as file:
        pickle.dump([pulselines, labels, filenamelist, idlist], file)

    for i, x in enumerate(pulselines):
        leftcrop = int(x - lrcrop * hinterval)
        rightcrop = int(x + rrcrop * hinterval)
        cropimage = _crop_beat_window(
            graph, leftcrop, rightcrop, right, sliced_image_width)
        cropimage = cv2.resize(cropimage, (128, 128))
        cv2.imwrite('{}/{}.png'.format(slice_dir, i), cropimage)

    return pulselines, labels, filenamelist, idlist


def process_image(filepath):
    """Slice a training image; see ``_slice_and_save``.

    Slices go to ``train_processed/<name>/`` and the metadata pickle to
    ``text_processed/<name>.txt``.
    """
    return _slice_and_save(filepath, 'train_processed', 'text_processed')


def process_test_image(filepath):
    """Slice a test image; see ``_slice_and_save``.

    Slices go to ``test_processed/<name>/`` and the metadata pickle to
    ``test_text_processed/<name>.txt``.
    """
    return _slice_and_save(filepath, 'test_processed', 'test_text_processed')
        
        
#Extraction function for horizontal grid lines
def extract_horizontal_lines(image):
    """Find horizontal grid lines and summarise their vertical positions.

    Pixels whose three channels fall inside [148..195, 148..195, 255] are
    kept, blurred and edge-detected; every Hough segment found is redrawn as
    a 15 px thick full-width band at the segment's start row.  The bands are
    edge-detected again and the sorted start rows of the resulting segments
    are averaged in consecutive pairs.

    Returns:
        (first, last, count): the first and last pair average and the number
        of pairs.
    """
    height, width = image.shape[0], image.shape[1]

    grid_mask = cv2.inRange(
        image,
        np.array([148, 148, 255], dtype="uint8"),
        np.array([195, 195, 255], dtype="uint8"),
    )
    grid_only = cv2.bitwise_and(image, image, mask=grid_mask)
    blurred = cv2.blur(cv2.cvtColor(grid_only, cv2.COLOR_BGR2GRAY), (12, 12))
    segments = cv2.HoughLinesP(
        cv2.Canny(blurred, 80, 120), 1, math.pi / 2, 2, None, 300, 20)

    # Redraw each detection as a thick band spanning the whole width.
    canvas = np.zeros((height, width, 1), np.uint8)
    for group in segments:
        for segment in group:
            row = segment[1]
            cv2.line(canvas, (0, row), (width, row), (255), 15)

    segments = cv2.HoughLinesP(
        cv2.Canny(canvas, 80, 120), 1, math.pi / 2, 2, None, 300, 20)

    # Presumably each band contributes two edges, hence the pairwise average.
    edge_rows = sorted(segment[1] for group in segments for segment in group)
    centres = [
        sum(edge_rows[k:k + 2]) // 2 for k in range(0, len(edge_rows), 2)
    ]

    return centres[0], centres[-1], len(centres)

def get_precalculated_values():
    """Return the hard-coded grid geometry as ``(left, right, count)``.

    Callers in this file use the tuple as the first grid position, the last
    grid position and the number of grid lines along the horizontal axis.
    """
    left_px, right_px, line_count = 33, 7487, 101
    return left_px, right_px, line_count

#Extraction function for vertical grid lines
def extract_vertical_lines(image):
    """Find vertical grid lines and summarise their horizontal positions.

    Pixels whose three channels fall inside [148..195, 148..195, 255] are
    kept, blurred and edge-detected; every Hough segment found is redrawn as
    a 5 px thick full-height band at the segment's start column.  The bands
    are edge-detected again and the sorted start columns of the resulting
    segments are averaged in consecutive pairs.

    Returns:
        (first, last, count): the first and last pair average and the number
        of pairs.
    """
    height, width = image.shape[0], image.shape[1]

    grid_mask = cv2.inRange(
        image,
        np.array([148, 148, 255], dtype="uint8"),
        np.array([195, 195, 255], dtype="uint8"),
    )
    grid_only = cv2.bitwise_and(image, image, mask=grid_mask)
    blurred = cv2.blur(cv2.cvtColor(grid_only, cv2.COLOR_BGR2GRAY), (5, 5))
    segments = cv2.HoughLinesP(
        cv2.Canny(blurred, 80, 120), 1, math.pi / 2, 2, None, 300, 20)

    # Redraw each detection as a band spanning the whole height.
    canvas = np.zeros((height, width, 1), np.uint8)
    for group in segments:
        for segment in group:
            column = segment[0]
            cv2.line(canvas, (column, 0), (column, height), (255), 5)

    segments = cv2.HoughLinesP(
        cv2.Canny(canvas, 80, 120), 1, math.pi / 2, 2, None, 300, 20)

    # Presumably each band contributes two edges, hence the pairwise average.
    edge_columns = sorted(
        segment[0] for group in segments for segment in group)
    centres = [
        sum(edge_columns[k:k + 2]) // 2
        for k in range(0, len(edge_columns), 2)
    ]

    return centres[0], centres[-1], len(centres)

#Removes gridlines and any text from the graph
def isolate_graph(image):
    """Return a single-channel image with the dark trace in black on white.

    Pixels whose three channels are all in 0..70 become 0; everything else
    (grid lines, coloured markers, background) becomes 255.

    ``cv2.inRange`` already yields a strict 0/255 mask, so the former
    ``bitwise_and`` (result unused) and ``threshold`` (no-op on a 0/255 mask)
    steps were dropped; the output is unchanged.
    """
    lower = np.array([0, 0, 0], dtype="uint8")
    upper = np.array([70, 70, 70], dtype="uint8")
    dark_pixels = cv2.inRange(image, lower, upper)
    return cv2.bitwise_not(dark_pixels)

#Gets the pixel location of blue/green pulse lines
def get_pulse_lines(img):
    """Return the x positions of the coloured pulse marker lines in ``img``.

    Pixels inside the colour range [146..255, 113..255, 146..148] are kept
    and edge-detected.  Each Hough segment is redrawn as a 1 px line from the
    top of a blank canvas, the canvas is edge-detected again, and the sorted
    start columns of the resulting segments are averaged in consecutive
    pairs.

    Returns:
        A list of ints, one per detected marker, in ascending order.  The
        list is empty when no segment is detected (``HoughLinesP`` returns
        ``None`` in that case, which previously crashed the loop).
    """
    lower = np.array([146, 113, 146], dtype="uint8")
    upper = np.array([255, 255, 148], dtype="uint8")
    mask = cv2.inRange(img, lower, upper)
    marker_only = cv2.bitwise_and(img, img, mask=mask)
    gray = cv2.cvtColor(marker_only, cv2.COLOR_BGR2GRAY)

    segments = cv2.HoughLinesP(
        cv2.Canny(gray, 80, 120), 1, math.pi / 2, 2, None, 30, 1)
    if segments is None:
        return []

    canvas = np.zeros((img.shape[0], img.shape[1], 1), np.uint8)
    for group in segments:
        for seg in group:
            # NOTE(review): the end point's y is the image *width*
            # (shape[1]); the line is clipped to the canvas, so a vertical
            # segment still spans the full height.  Kept as-is to avoid
            # changing detection results -- confirm whether shape[0] was
            # intended.
            # Plain ints are passed because some OpenCV builds reject NumPy
            # scalars as point coordinates.
            cv2.line(canvas, (int(seg[0]), 0),
                     (int(seg[2]), img.shape[1]), (255), 1)

    segments = cv2.HoughLinesP(
        cv2.Canny(canvas, 80, 120), 1, math.pi / 2, 2, None, 100, 20)
    if segments is None:
        return []

    # The second canvas that used to be drawn here was never read.
    edge_columns = sorted(
        int(seg[0]) for group in segments for seg in group)
    return [
        sum(edge_columns[k:k + 2]) // 2
        for k in range(0, len(edge_columns), 2)
    ]

#Gets the letters within the pulse strings
def get_pulse_strings(image, pulselines):
    """OCR the single-character label printed above each pulse line.

    The top 90 rows of ``image`` are masked to the colour range
    [0..255, 0..128, 0], inverted and binarised.  For every x position in
    ``pulselines`` a 30 px wide window (rows 10..80) around it is
    edge-detected and passed to Tesseract, restricted to the characters
    ``NVP-+RA``.

    Returns:
        A list of stripped strings, one per entry of ``pulselines``.
    """
    header = image[0:90, 0:image.shape[1]]
    lower = np.array([0, 0, 0], dtype="uint8")
    upper = np.array([255, 128, 0], dtype="uint8")
    mask = cv2.inRange(header, lower, upper)
    masked = cv2.bitwise_and(header, header, mask=mask)
    gray = cv2.bitwise_not(cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY))
    _, gray = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY)

    # NOTE(review): "-psm" is the legacy spelling of the page-segmentation
    # flag; newer Tesseract releases reportedly expect "--psm" -- confirm
    # against the installed version.
    ocr_config = '-c tessedit_char_whitelist=NVP-+RA -psm 10'

    text = []
    for line in pulselines:
        # Clamp the left bound: a negative start would be read as an offset
        # from the right edge and produce an empty window.
        window_left = max(int(line) - 15, 0)
        glyph = cv2.Canny(gray[10:80, window_left:line + 15], 100, 200)
        raw = pytesseract.image_to_string(glyph, lang='eng', config=ocr_config)
        # Strip surrounding whitespace so exact comparisons such as
        # ``label == 'V'`` do not depend on how the OCR wrapper pads output.
        text.append(raw.strip())
    return text

#converts pixel distance to time
def pxtotime(px):
    """Convert a horizontal pixel position to a time value.

    The distance of ``px`` from the first grid position is expressed in grid
    intervals (using the values from ``get_precalculated_values``) and
    multiplied by 0.2 per interval.  The result is rounded to three decimals.
    """
    first_line, last_line, line_count = get_precalculated_values()
    spacing = (last_line - first_line) / (line_count - 1)
    return round((px - first_line) / spacing * 0.2, 3)

### Preprocess data

Run this once to preprocess training and testing data

In [None]:
# Slice every training image and collect the per-pulse metadata:
# pulselist = pulse x positions, pulsetext = OCR labels,
# pulseloc = source image names, pulseid = slice indices.
#
# os.listdir returns entries in arbitrary order; sorting keeps the sample
# order (and therefore the trailing hold-out slice taken during training)
# reproducible between runs and machines.
traindir = "train/"
trainlist = sorted(os.listdir(traindir))
pulselist = []
pulsetext = []
pulseloc = []
pulseid = []
for i, file in enumerate(trainlist):
    plist, ptext, ploc, pid = process_image(traindir + file)
    pulselist.extend(plist)
    pulsetext.extend(ptext)
    pulseloc.extend(ploc)
    pulseid.extend(pid)
    print('Processed Images for {}, {} / {}'.format(file, i+1,len(trainlist)))

with open('data_processed.txt', 'wb') as out:
    pickle.dump([pulselist, pulsetext, pulseloc, pulseid], out)


# Same procedure for the test images (the variable names are reused).
traindir = "test/"
trainlist = sorted(os.listdir(traindir))
pulselist = []
pulsetext = []
pulseloc = []
pulseid = []
for i, file in enumerate(trainlist):
    print('opening {}, {} / {}'.format(file, i+1,len(trainlist)))
    plist, ptext, ploc, pid = process_test_image(traindir + file)
    pulselist.extend(plist)
    pulsetext.extend(ptext)
    pulseloc.extend(ploc)
    pulseid.extend(pid)
    print('Processed Images for {}, {} / {}'.format(file, i+1,len(trainlist)))

with open('test_data_processed.txt', 'wb') as out:
    pickle.dump([pulselist, pulsetext, pulseloc, pulseid], out)

## Model: Deep 2D CNN for classifying ECG beats
This deep layered CNN is an adaptation of VGG16 as prescribed in this paper: https://arxiv.org/pdf/1804.06812.pdf
Some code was adapted from: https://github.com/ankur219/ECG-Arrhythmia-classification/blob/master/model.py

In [None]:
IMAGE_SIZE = [128, 128]

# Three convolutional stages (64, 128, 256 filters; two 3x3 convolutions
# each, followed by 2x2 max pooling), then a 2048-unit dense layer with
# dropout and a 2-way softmax output.
# NOTE(review): only the 64-filter convolutions are followed by an ELU
# activation; the 128- and 256-filter ones go straight into batch
# normalisation -- confirm this is intentional.
model = Sequential([
    # Stage 1: 64 filters
    Conv2D(64, (3, 3), strides=(1, 1), input_shape=IMAGE_SIZE + [3],
           kernel_initializer='glorot_uniform'),
    ELU(),
    BatchNormalization(),
    Conv2D(64, (3, 3), strides=(1, 1), kernel_initializer='glorot_uniform'),
    ELU(),
    BatchNormalization(),
    MaxPool2D(pool_size=(2, 2), strides=(2, 2)),

    # Stage 2: 128 filters
    Conv2D(128, (3, 3), strides=(1, 1), kernel_initializer='glorot_uniform'),
    BatchNormalization(),
    Conv2D(128, (3, 3), strides=(1, 1), kernel_initializer='glorot_uniform'),
    BatchNormalization(),
    MaxPool2D(pool_size=(2, 2), strides=(2, 2)),

    # Stage 3: 256 filters
    Conv2D(256, (3, 3), strides=(1, 1), kernel_initializer='glorot_uniform'),
    BatchNormalization(),
    Conv2D(256, (3, 3), strides=(1, 1), kernel_initializer='glorot_uniform'),
    BatchNormalization(),
    MaxPool2D(pool_size=(2, 2), strides=(2, 2)),

    # Classifier head
    Flatten(),
    Dense(2048),
    ELU(),
    BatchNormalization(),
    Dropout(0.5),
    Dense(2, activation='softmax'),
])

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

## Training the model

We train the model for 100 epochs (note that the cell below sets `epochs=50`; adjust it to train for longer) and observe five measures: accuracy, precision, recall, PPV and specificity.

PPV (positive predictive value) and specificity are the most important, and we decide whether to continue or stop training based on their scores.

I attained around 98.7% PPV after training for 80 epochs

In [None]:
epochs = 50
batch_size = 64
# Number of trailing samples kept out of training and used for validation.
holdout_size = 1025


def _safe_ratio(numerator, denominator):
    """Return numerator/denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


# data = [pulse x positions, OCR labels, image names, slice indices].
# NOTE: pickle.load can execute arbitrary code; only load files written by
# the preprocessing cell.
with open('data_processed.txt', 'rb') as file:
    data = pickle.load(file)

train_data = data[0]
print(len(train_data))

# Validation set: the last `holdout_size` samples (the previous `[-1025:-1]`
# slice silently dropped the final sample).
test_name = data[2][-holdout_size:]
test_number = data[3][-holdout_size:]
test_label = data[1][-holdout_size:]
# Class 1 = label 'V', class 0 = everything else.
ytestlabels = [int(x == 'V') for x in test_label]
y_test_binary = to_categorical(ytestlabels, num_classes=2)
test_image = np.array([
    cv2.imread('train_processed/{}/{}.png'.format(name, number))
    for name, number in zip(test_name, test_number)
])

# Only whole batches taken from the non-held-out part are trained on.
number_of_batches = len(data[0][:-holdout_size]) // batch_size


for epoch in range(epochs):

    loss = None  # stays None when there is not even one full batch
    for index in range(number_of_batches):
        start = index * batch_size
        end = start + batch_size
        image_number = data[3][start:end]
        image_name = data[2][start:end]
        ylabels = [int(x == 'V') for x in data[1][start:end]]
        y_binary = to_categorical(ylabels, num_classes=2)
        # Slices are read from disk per batch instead of being held in RAM.
        image_batch = np.array([
            cv2.imread('train_processed/{}/{}.png'.format(name, number))
            for name, number in zip(image_name, image_number)
        ])
        loss = model.train_on_batch(image_batch, y_binary)

    print('Loss for Epoch {}: {}'.format(epoch, loss))
    yhat = model.predict(test_image)
    Vtruth = y_test_binary[:, 1]
    # The network outputs softmax probabilities; take the most likely class.
    # Comparing the raw probability with 0 or 1 only ever matched saturated
    # outputs and left most samples uncounted.
    Vresult = np.argmax(yhat, axis=1)

    truepositives = int(np.sum((Vtruth == 1) & (Vresult == 1)))
    falsenegative = int(np.sum((Vtruth == 1) & (Vresult == 0)))
    falsepositives = int(np.sum((Vtruth == 0) & (Vresult == 1)))
    truenegatives = int(np.sum((Vtruth == 0) & (Vresult == 0)))

    accuracy = _safe_ratio(
        truepositives + truenegatives,
        truepositives + truenegatives + falsepositives + falsenegative)
    # Precision is TP / (TP + FP), i.e. the same quantity as PPV; it was
    # previously divided by TP + TN.
    precision = _safe_ratio(truepositives, truepositives + falsepositives)
    recall = _safe_ratio(truepositives, truepositives + falsenegative)
    ppv = _safe_ratio(truepositives, truepositives + falsepositives)
    specifity = _safe_ratio(truenegatives, falsepositives + truenegatives)
    print('Validation Accuracy: {}, Precision :{}, Recall: {}, PPV:{}, specifity:{}'.format(accuracy,precision,recall,ppv,specifity))
    # Checkpoint the weights every second epoch.
    if epoch % 2 == 0:
        model.save_weights('epoch s {}.h5'.format(epoch))
    

## Predictions using trained model

After training we then feed the testing data into our model to check its predictions

In [None]:
# testdata = [pulse x positions, OCR labels, image names, slice indices].
# NOTE: pickle.load can execute arbitrary code; only load files written by
# the preprocessing cell.
with open('test_data_processed.txt', 'rb') as file:
    testdata = pickle.load(file)

train_data = testdata[0]
print(len(train_data))
test_name = testdata[2]
test_number = testdata[3]
test_label = testdata[1]
# Class 1 = label 'V', class 0 = everything else.
ytestlabels = [int(x == 'V') for x in test_label]
y_test_binary = to_categorical(ytestlabels, num_classes=2)
test_image = np.array([
    cv2.imread('test_processed/{}/{}.png'.format(name, number))
    for name, number in zip(test_name, test_number)
])

yhattest = model.predict(test_image)

# Predicted class per slice (1 = V beat).  Casting the class-1 probability
# to int truncated every value below 1.0 to 0, so only fully saturated
# outputs were ever reported as V beats.
predictions = np.argmax(yhattest, axis=1).astype(int)

with open('predictions.txt', 'wb') as file:
    pickle.dump(predictions, file)

In [None]:
# Reload the pickled test metadata so the following cells can run without
# repeating the prediction cell's loading step.
with open('test_data_processed.txt', 'rb') as handle:
    testdata = pickle.load(handle)

In [None]:
# tplist = pulse x positions, tploc = source image names, tpid = slice ids.
tplist, tploc, tpid = testdata[0], testdata[2], testdata[3]

# Header row, followed by one row per slice predicted as class 1.
csvData = [['filename', 'Location of V beat']]

for i, pulse_px in enumerate(tplist):
    if predictions[i] != 1:
        continue
    csvData.append(['{}.png'.format(tploc[i]), pxtotime(pulse_px)])

### Export to CSV

In [None]:
import csv

# newline='' lets the csv module control line endings itself; without it
# extra blank rows appear on platforms that translate '\n' to '\r\n'.
# The `with` block closes the file, so no explicit close() is needed.
with open('Vbeat Predictions.csv', 'w', newline='') as csvFile:
    writer = csv.writer(csvFile)
    writer.writerows(csvData)