In [36]:
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader



In [37]:
import os
os.listdir('/kaggle/input/ptb-xl-sorted-images/sorted_images')

['Other',
 'Conduction',
 'Arrhythmia',
 'Normal',
 'MI',
 'Morphology',
 'Hypertrophy']

In [38]:
class ECGDataset(Dataset):
    def __init__(self, file_paths, labels, transform = None, mode = 'random'):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform
        self.mode = mode
    def __len__(self):
        return len(self.file_paths)
    def apply_thresholds(self, img):
        lottery = np.random.randint(3)
        if lottery == 0:
            _, th_binary = cv2.threshold(img, 128, 255, cv2.THRESH_BINARY)
            return th_binary
        if lottery == 1:
            _, th_otsu = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            return th_otsu
        if lottery == 2:
            th_adaptive = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 2)
            return th_adaptive
    def __getitem__(self, idx):
        img = cv2.imread(self.file_paths[idx], cv2.IMREAD_GRAYSCALE)
        label = self.labels[idx]
        imgs = [img, cv2.medianBlur(img, 3)]
        variants = []
        for base_img in imgs:
            variants.append(base_img)
            variants.append(self.apply_thresholds(base_img))
        if self.mode == 'random':
            chosen_img = variants[np.random.randint(len(variants))]
            final_img = chosen_img
        else:
            final_img = np.stack(variants, axis = 0)
        if self.transform:
            final_img = self.transform(final_img)
        else:
            type(final_img)
            if final_img.ndim == 2:
                final_img = torch.from_numpy(final_img).float().unsqueeze(0) / 255.0
            elif final_img.ndim == 3:
                final_img = torch.from_numpy(final_img).float().unsqueeze(1) / 255.0
        return final_img, label

In [39]:
data_root = '/kaggle/input/ptb-xl-sorted-images/sorted_images'
print(os.listdir(data_root))

['Other', 'Conduction', 'Arrhythmia', 'Normal', 'MI', 'Morphology', 'Hypertrophy']


In [40]:
for name in os.listdir(data_root):
    print([name])

['Other']
['Conduction']
['Arrhythmia']
['Normal']
['MI']
['Morphology']
['Hypertrophy']


In [41]:
import os
from collections import Counter

def get_image_paths_and_labels(data_root):
    folder_names = sorted([name for name in os.listdir(data_root) if name != 'Morphology'])
    label_names = ['Normal', 'Abnormal']
    label_to_idx = {label: idx for idx, label in enumerate(label_names)}

    normal_paths = []
    normal_labels = []
    abnormal_paths =  []
    abnormal_labels = []

    for label in folder_names:
        label_folder = os.path.join(data_root, label)
        print(label)
        print(label_folder)
        if label == 'Normal':
            for fname in os.listdir(label_folder):
                normal_paths.append(os.path.join(label_folder, fname))
                normal_labels.append(label_to_idx["Normal"])
        else:
            for fname in os.listdir(label_folder):
                abnormal_paths.append(os.path.join(label_folder, fname))
                abnormal_labels.append(label_to_idx['Abnormal'])

    print("Normal Images count: ", len(normal_paths))
    print("Abnormal Images Count", len(abnormal_paths))
    print("Normal Labels count: ", len(normal_paths))
    print("Abnormal Labels Count", len(abnormal_paths))
    file_paths = normal_paths + abnormal_paths
    labels = normal_labels + abnormal_labels
    print("File Paths Count: ", len(file_paths))
    print("File Labels Count", len(labels))
    print(set(labels))
    return file_paths, labels, label_to_idx

data_root = '/kaggle/input/ptb-xl-sorted-images/sorted_images'
file_paths, labels, label_to_idx = get_image_paths_and_labels(data_root)


Arrhythmia
/kaggle/input/ptb-xl-sorted-images/sorted_images/Arrhythmia
Conduction
/kaggle/input/ptb-xl-sorted-images/sorted_images/Conduction
Hypertrophy
/kaggle/input/ptb-xl-sorted-images/sorted_images/Hypertrophy
MI
/kaggle/input/ptb-xl-sorted-images/sorted_images/MI
Normal
/kaggle/input/ptb-xl-sorted-images/sorted_images/Normal
Other
/kaggle/input/ptb-xl-sorted-images/sorted_images/Other
Normal Images count:  8637
Abnormal Images Count 8260
Normal Labels count:  8637
Abnormal Labels Count 8260
File Paths Count:  16897
File Labels Count 16897
{0, 1}


In [42]:
# import os
# from collections import Counter

# def get_image_paths_and_labels(data_root):
#     label_names = sorted([name for name in os.listdir(data_root) if name not in ['Hypertrophy', 'Other', 'Morphology']])
#     label_names.append('Merged_hyper_other')
#     label_to_idx = {label: idx for idx, label in enumerate(label_names)}
#     print(label_names)
    
#     file_paths = []
#     labels = []
#     hyper_other_files = []
#     hyper_other_labels = []
    
#     for label in label_names:
#         print(label)
#         label_folder = os.path.join(data_root, label)
#         # print(label_folder)
#         # print('mho_start_if')
#         if label == 'Merged_hyper_other':
#             # print('mho_l_start')
#             # Fill merged class with images from Hypertrophy and Other folders
#             for merge_label in ['Hypertrophy', 'Other']:
#                 merge_folder = os.path.join(data_root, merge_label)
#                 if os.path.isdir(merge_folder):
#                     for fname in os.listdir(merge_folder):
#                         # print(fname)
#                         if fname.lower().endswith(('png', 'jpg', 'jpeg')):
#                             hyper_other_files.append(os.path.join(merge_folder, fname))
#                             hyper_other_labels.append(label_to_idx['Merged_hyper_other'])
#         elif os.path.isdir(label_folder):
#             for fname in os.listdir(label_folder):
#                 if fname.lower().endswith(('png', 'jpg', 'jpeg')):
#                     file_paths.append(os.path.join(label_folder, fname))
#                     labels.append(label_to_idx[label])
#     # print(hyper_other_files)
#     # print(hyper_other_labels)
#     file_paths += hyper_other_files
#     labels += hyper_other_labels
            
#     # print(set(file_paths))
#     label_count = Counter(labels)
#     print(label_to_idx)
#     print(f"final labels: {set(labels)}")
#     print(f"Count: {label_count}")
#     return file_paths, labels, label_to_idx

# data_root = '/kaggle/input/ptb-xl-sorted-images/sorted_images'
# file_paths, labels, label_to_idx = get_image_paths_and_labels(data_root)

# dataset = ECGDataset(file_paths, labels)
# loader = DataLoader(dataset, batch_size=32, shuffle=True)

In [43]:
import torch.nn as nn
import torch.nn.functional as F

class SimpleECGCNN(nn.Module):
    def __init__(self, num_classes, resize_shape = (256, 256)):
        super(SimpleECGCNN, self).__init__()
        self.resize_shape = resize_shape
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        out_h, out_w = self.resize_shape[0] // 8, self.resize_shape[1] // 8
        self.fc1 = nn.Linear(128 * out_h * out_w, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = F.interpolate(x, size = self.resize_shape, mode = 'bilinear', align_corners = False)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

num_classes = 2
resize_shape = (256, 256)
model = SimpleECGCNN(num_classes, resize_shape)

In [44]:
import torch.optim as optim
from torch.utils.data import random_split

In [45]:
from sklearn.model_selection import train_test_split
train_paths, val_paths, train_labels, val_labels = train_test_split(
    file_paths, labels, test_size=0.2, stratify=labels, random_state=42
)

In [46]:
resize_shape = (256, 256)
batch_size = 256

train_dataset = ECGDataset(train_paths, train_labels)
val_dataset = ECGDataset(val_paths, val_labels)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


In [47]:
num_classes = len(label_to_idx)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleECGCNN(num_classes, resize_shape).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 1e-4)

In [48]:
import torch
torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device(type='cuda')

In [49]:
a=torch.cuda.FloatTensor()

  a=torch.cuda.FloatTensor()


In [None]:
import time

num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs} | ")
    model.train()
    running_loss = 0.0
    correct = total = 0

    epoch_start = time.time()

    for batch_idx, (imgs, lbls) in enumerate(train_loader):
        batch_start = time.time()
        # print(f"Batch {batch_idx+1}/{len(train_loader)} | ")

        imgs, lbls = imgs.to(device), lbls.to(device)
        data_load_time = time.time() - batch_start
        # print(f"DataLoad: {data_load_time:.4f}s | ")
        
        optimizer.zero_grad()

        forward_start = time.time()
        outputs = model(imgs)
        forward_time = time.time() - forward_start
        # print(f"Forward: {forward_time:.4f}s | ")

        loss_calc_start = time.time()
        loss = criterion(outputs, lbls)
        loss_calc_time = time.time() - loss_calc_start
        # print(f"Loss: {loss_calc_time:.4f}s | ")

        backward_start = time.time()
        loss.backward()
        backward_time = time.time() - backward_start
        # print(f"Backward: {backward_time:.4f}s | ")

        optimizer_step_start = time.time()
        optimizer.step()
        optimizer_step_time = time.time() - optimizer_step_start
        # print(f"OptStep: {optimizer_step_time:.4f}s")

        running_loss += loss.item() * imgs.size(0)
        _, predicted = outputs.max(1)
        total += lbls.size(0)
        correct += predicted.eq(lbls).sum().item()

        print(
            f"Epoch {epoch + 1} / {num_epochs} |"
            f"Batch {batch_idx+1}/{len(train_loader)} | "
            f"DataLoad: {data_load_time:.4f}s | "
            f"Forward: {forward_time:.4f}s | "
            f"Loss: {loss_calc_time:.4f}s | "
            f"Backward: {backward_time:.4f}s | "
            f"OptStep: {optimizer_step_time:.4f}s"
        )
    train_loss = running_loss / total
    train_acc = correct / total

    epoch_train_time = time.time() - epoch_start

    model.eval()
    val_running_loss = 0.0
    val_correct = val_total = 0
    with torch.no_grad():
        for val_batch_idx, (imgs, lbls) in enumerate(val_loader):
            batch_start = time.time()
            imgs, lbls = imgs.to(device), lbls.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, lbls)
            val_running_loss += loss.item() * imgs.size(0)
            _, predicted = outputs.max(1)
            val_total += lbls.size(0)
            val_correct += predicted.eq(lbls).sum().item()
            print(f"Val Batch {val_batch_idx+1}/{len(val_loader)} | Time: {time.time() - batch_start:.4f}s")
    val_loss = val_running_loss / val_total
    val_acc = val_correct / val_total

    print(
        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}"
    )

Epoch 1/10 | 
Epoch 1 / 10 |Batch 1/53 | DataLoad: 0.6421s | Forward: 1.0188s | Loss: 0.0516s | Backward: 0.4192s | OptStep: 0.1695s
Epoch 1 / 10 |Batch 2/53 | DataLoad: 0.5788s | Forward: 0.0047s | Loss: 0.0020s | Backward: 0.0016s | OptStep: 0.0010s
Epoch 1 / 10 |Batch 3/53 | DataLoad: 0.5654s | Forward: 0.0038s | Loss: 0.0001s | Backward: 0.0016s | OptStep: 0.0010s
Epoch 1 / 10 |Batch 4/53 | DataLoad: 0.5905s | Forward: 0.0021s | Loss: 0.0001s | Backward: 0.0031s | OptStep: 0.0009s
Epoch 1 / 10 |Batch 5/53 | DataLoad: 0.5327s | Forward: 0.0017s | Loss: 0.0001s | Backward: 0.0015s | OptStep: 0.0010s
Epoch 1 / 10 |Batch 6/53 | DataLoad: 0.5563s | Forward: 0.0014s | Loss: 0.0001s | Backward: 0.0017s | OptStep: 0.0009s
Epoch 1 / 10 |Batch 7/53 | DataLoad: 0.5664s | Forward: 0.0013s | Loss: 0.0001s | Backward: 0.0015s | OptStep: 0.0010s
Epoch 1 / 10 |Batch 8/53 | DataLoad: 0.5484s | Forward: 0.0022s | Loss: 0.0002s | Backward: 0.0020s | OptStep: 0.0010s
Epoch 1 / 10 |Batch 9/53 | DataLoa

In [None]:
torch.save(model.state_dict(), 'ecgcnn_2.pth')

In [None]:
model.eval()