In [None]:
# Mount Google Drive (Colab only) so the ECG plot folders under
# /content/drive/MyDrive/ that the loading cells read from are accessible.
from google.colab import drive
drive.mount('/content/drive')

In [3]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os, os.path
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,models
from tqdm import tqdm_notebook as tqdm

In [6]:
def imagetensor(imagedir, dtype=float):
  """Load every image file in `imagedir` as a grayscale array scaled to [0, 1].

  Args:
    imagedir: directory containing the images; a trailing separator is optional.
    dtype: NumPy dtype used when converting pixels. Defaults to float (float64),
      as before; pass np.float32 to halve memory for very large plots.

  Returns:
    np.ndarray of shape (n_images, height, width): one slice per file, in sorted
    filename order so row i always refers to the same file across runs.

  Raises:
    ValueError: if the directory contains no image files, or if the images do
      not all have the same size.
  """
  # os.listdir order is arbitrary, so sort for a reproducible row order; skip
  # sub-directories and hidden entries such as .ipynb_checkpoints.
  filenames = sorted(
      name for name in os.listdir(imagedir)
      if not name.startswith('.') and os.path.isfile(os.path.join(imagedir, name))
  )
  if not filenames:
    raise ValueError(f"No image files found in {imagedir!r}")

  X = None
  for i, name in enumerate(tqdm(filenames, desc=os.path.basename(os.path.normpath(imagedir)))):
    # `with` closes the file handle PIL keeps open; 'L' = 8-bit grayscale, so /255 maps into [0, 1].
    with Image.open(os.path.join(imagedir, name)) as image:
      pixels = np.array(image.convert('L'), dtype=dtype) / 255
    if X is None:
      # Size the buffer from the actual file count and first image shape rather
      # than a hard-coded (301, 1200, 15000) float64 block (~43 GB).
      X = np.empty((len(filenames),) + pixels.shape, dtype=pixels.dtype)
    elif pixels.shape != X.shape[1:]:
      raise ValueError(f"{name!r} has shape {pixels.shape}, expected {X.shape[1:]}")
    X[i] = pixels
  return X

In [None]:
# Load the AFIB ECG plots as grayscale arrays scaled to [0, 1].
ecg_afib_numpy = imagetensor(imagedir="/content/drive/MyDrive/ECG_PLOT_IMAGES/AFIB/")

In [None]:
# Load the CONTROL ECG plots as grayscale arrays scaled to [0, 1].
ecg_control_numpy = imagetensor(imagedir="/content/drive/MyDrive/ECG_PLOT_IMAGES/CONTROL/")

In [11]:
class ResidualBlock(nn.Module):
    """Basic two-conv residual block: relu(cnn2(cnn1(x)) + shortcut(x)).

    Args:
        in_channels, out_channels: channel counts of the input and output maps.
        stride: stride of the first conv and of the projection shortcut
            (stride=2 halves the spatial size).
        kernel_size, padding: used by both main-path convs. For odd kernels keep
            padding == kernel_size // 2 so the main path and shortcut shapes match.
        bias: whether the two main-path convs get a bias term. This argument
            used to be ignored (always False); False is still the default.
    """
    def __init__(self,in_channels,out_channels,stride=1,kernel_size=3,padding=1,bias=False):
        super(ResidualBlock,self).__init__()
        self.cnn1 = nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=bias),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )
        self.cnn2 = nn.Sequential(
            nn.Conv2d(out_channels,out_channels,kernel_size,1,padding,bias=bias),
            nn.BatchNorm2d(out_channels)
        )
        # 1x1 projection when the shape changes so the skip connection can be added;
        # an empty Sequential (identity) otherwise.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=stride,bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self,x):
        """Return relu(main_path(x) + shortcut(x)) with out_channels channels."""
        out = self.cnn2(self.cnn1(x))
        # Functional ReLU instead of constructing a new nn.ReLU module on every call.
        return torch.relu(out + self.shortcut(x))

In [38]:
class ResNet18(nn.Module):
    """ResNet-style binary classifier.

    Input: float tensor (N, in_channels, H, W); in_channels defaults to 1
    (grayscale). Output: (N, 1) probabilities from a final sigmoid.
    Any spatial size large enough to survive the strided blocks is accepted.
    """
    def __init__(self, in_channels=1):
        super(ResNet18,self).__init__()

        # NOTE(review): padding (3) larger than kernel_size (2) is unusual for a
        # stem conv; kept unchanged so existing checkpoints still load.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels,64,kernel_size=2,stride=2,padding=3,bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )

        # MaxPool2d(1, 1) is an identity op; kept so Sequential indices (and thus
        # state_dict keys like block2.1.*) are unchanged.
        self.block2 = nn.Sequential(
            nn.MaxPool2d(1,1),
            ResidualBlock(64,64),
            ResidualBlock(64,64,2)
        )

        self.block3 = nn.Sequential(
            ResidualBlock(64,128),
            ResidualBlock(128,128,2)
        )

        self.block4 = nn.Sequential(
            ResidualBlock(128,256),
            ResidualBlock(256,256,2)
        )
        self.block5 = nn.Sequential(
            ResidualBlock(256,512),
            ResidualBlock(512,512,2)
        )

        # Adaptive pooling always yields 1x1, so fc3 gets exactly 512 features for
        # any input size. The former fixed AvgPool2d(2) only gave 1x1 for 64x64
        # inputs (3x3 map, last row/column dropped) and broke fc3 otherwise.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # NOTE(review): there is no activation between fc3 and fc4, so together
        # they act as a single affine map 512 -> 1.
        self.fc3 = nn.Linear(512,7)
        self.fc4 = nn.Linear(7, 1)

        self.m = nn.Sigmoid()

    def forward(self,x):
        """Map (N, in_channels, H, W) images to (N, 1) sigmoid probabilities."""
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.avgpool(x)
        x = x.view(x.size(0),-1)
        x = self.fc3(x)
        x = self.fc4(x)
        out = self.m(x)
        return out

In [None]:
# Prefer the first GPU when present; fall back to CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = ResNet18().to(device)


In [None]:
%pip install torchsummary
from torchsummary import summary
summary(model, (1, 64, 64))

In [13]:
# ~~~~~~~~~~~~~~~ TRAIN MODEL ~~~~~~~~~~~~~~~
def _make_loader(X, y, batch_size, shuffle):
  """Wrap images X and 0/1 labels y in a DataLoader of float32 (image, label) batches."""
  from torch.utils.data import TensorDataset
  images = torch.as_tensor(X, dtype=torch.float32)
  if images.dim() == 3:
    # (N, H, W) grayscale arrays -> (N, 1, H, W) for the model's 1-channel stem conv.
    images = images.unsqueeze(1)
  # Labels shaped (N, 1) to match the model's (N, 1) sigmoid output.
  labels = torch.as_tensor(y, dtype=torch.float32).reshape(-1, 1)
  if len(images) == 0:
    raise ValueError("empty dataset")
  if len(images) != len(labels):
    raise ValueError(f"got {len(images)} images but {len(labels)} labels")
  return DataLoader(TensorDataset(images, labels), batch_size=batch_size, shuffle=shuffle)


def _run_epoch(model, loader, criterion, optimizer=None):
  """One pass over `loader`; trains if `optimizer` is given, else evaluates. Returns (mean loss, accuracy)."""
  training = optimizer is not None
  model.train(training)
  run_device = next(model.parameters()).device
  total_loss, correct, seen = 0.0, 0, 0
  with torch.set_grad_enabled(training):
    for xb, yb in loader:
      xb, yb = xb.to(run_device), yb.to(run_device)
      probs = model(xb)
      loss = criterion(probs, yb)
      if training:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
      total_loss += loss.item() * xb.size(0)
      correct += ((probs >= 0.5).float() == yb).sum().item()
      seen += xb.size(0)
  return total_loss / seen, correct / seen


def train_model(X_train, X_rem, y_train, y_rem, X_valid, X_test, y_valid, y_test, n_classes, model_name,
                epochs=20, batch_size=32, lr=1e-3, patience=9):
  """Train a fresh ResNet18 binary classifier with early stopping and LR-on-plateau.

  The previous body called the Keras API (compile/fit/callbacks, Adam(decay=...))
  on a PyTorch module and could not run; this is the equivalent PyTorch loop.

  Args:
    X_train, y_train: training images, (N, H, W) or (N, 1, H, W), and 0/1 labels.
    X_rem, y_rem, X_test, y_test, n_classes: unused; kept for call-site compatibility.
    X_valid, y_valid: validation set monitored for checkpointing and early stopping.
    model_name: prefix of the checkpoint file '<model_name>weights_only_checkpoint.h5'
      (a torch.save'd state_dict despite the .h5 extension, kept for compatibility).
    epochs, batch_size: defaults match the values previously passed to fit (20, 32).
    lr: Adam learning rate. NOTE(review): the old code asked for 0.1, which is very
      high for Adam; 1e-3 is Adam's default. Tune as needed.
    patience: epochs without validation-loss improvement before stopping.

  Returns:
    The model, restored to the weights of the epoch with the lowest validation loss.
  """
  model = ResNet18().to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
  # Halve the LR after 3 stagnant epochs, never below 1e-5 (as the old reduceLR intended).
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, min_lr=1e-5)
  # The model already ends in a sigmoid and emits one probability per sample, so
  # binary cross-entropy is the matching loss (CrossEntropyLoss expects class logits).
  criterion = nn.BCELoss()

  train_loader = _make_loader(X_train, y_train, batch_size, shuffle=True)
  valid_loader = _make_loader(X_valid, y_valid, batch_size, shuffle=False)
  checkpoint_path = model_name + 'weights_only_checkpoint.h5'

  best_val_loss = float('inf')
  best_state = None
  stale_epochs = 0
  for epoch in range(1, epochs + 1):
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_acc = _run_epoch(model, valid_loader, criterion)
    scheduler.step(val_loss)
    print(f"Epoch {epoch}/{epochs} - loss: {train_loss:.4f} - accuracy: {train_acc:.4f}"
          f" - val_loss: {val_loss:.4f} - val_accuracy: {val_acc:.4f}")

    if val_loss < best_val_loss:
      best_val_loss = val_loss
      stale_epochs = 0
      best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
      torch.save(best_state, checkpoint_path)
      print(f"val_loss improved to {val_loss:.4f}; saved weights to {checkpoint_path}")
    else:
      stale_epochs += 1
      if stale_epochs >= patience:
        print(f"Early stopping: no val_loss improvement for {patience} epochs.")
        break

  if best_state is not None:
    model.load_state_dict(best_state)
  return model

# ~~~~~~~~~~~~~~~ CALCULATE TEST ACCURACY ~~~~~~~~~~~~~~~
def get_test_acc(model, X_test, y_test, batch_size=32):
  """Evaluate a sigmoid-output binary classifier on a held-out set.

  Prints the mean binary cross-entropy loss and the accuracy (threshold 0.5),
  then returns the accuracy. The previous body called Keras' model.evaluate,
  which does not exist on a PyTorch nn.Module.

  Args:
    model: module mapping (N, C, H, W) float tensors to (N, 1) probabilities.
    X_test: images, (N, H, W) (a channel axis is added) or (N, C, H, W).
    y_test: 0/1 labels, N of them.
    batch_size: evaluation batch size.

  Returns:
    Test accuracy as a float in [0, 1].

  Raises:
    ValueError: if the test set is empty or images and labels differ in count.
  """
  images = torch.as_tensor(X_test, dtype=torch.float32)
  if images.dim() == 3:
    images = images.unsqueeze(1)
  labels = torch.as_tensor(y_test, dtype=torch.float32).reshape(-1, 1)
  n = len(images)
  if n == 0:
    raise ValueError("empty test set")
  if n != len(labels):
    raise ValueError(f"got {n} images but {len(labels)} labels")

  first_param = next(model.parameters(), None)
  eval_device = first_param.device if first_param is not None else torch.device("cpu")
  criterion = nn.BCELoss(reduction='sum')  # summed per batch, averaged over n below

  model.eval()
  total_loss, correct = 0.0, 0
  with torch.no_grad():
    for start in range(0, n, batch_size):
      xb = images[start:start + batch_size].to(eval_device)
      yb = labels[start:start + batch_size].to(eval_device)
      probs = model(xb)
      total_loss += criterion(probs, yb).item()
      correct += int(((probs >= 0.5).float() == yb).sum().item())

  test_loss, test_acc = total_loss / n, correct / n
  print('Test loss:', test_loss)
  print('Test accuracy:', test_acc)
  print()
  print("For reference, Attia model reports 83.3% test accuracy.")
  return test_acc

# ~~~~~~~~~~~~~~~ SAVE MODEL ~~~~~~~~~~~~~~~
def save_model(model, name):
  """Save `model`'s weights (its state_dict) to the file path `name`.

  nn.Module has no .save() method (the old call was Keras API); reload with
  model.load_state_dict(torch.load(name)) on a freshly constructed model.
  """
  torch.save(model.state_dict(), name)
  print("Saved model to ", name)