In [1]:
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import wfdb
import time
import random
from sklearn.preprocessing import minmax_scale
import sys
import pandas
from torch.utils.tensorboard import SummaryWriter

In [2]:
class StaffIIIDataset(Dataset):
    """Single split (train or validation) of the STAFF III ECG records.

    Record names are read from ``record_path``; the first 70% of the lines
    form the training split and the next 10% form the validation split.
    Labels come from the ``'Unnamed: 28'`` column of the annotation sheet,
    starting at row 9.

    Args:
        record_path: text file with one record name per line.
        excel_path: annotation spreadsheet readable by ``pandas.read_excel``.
        channel: column index (or indices) selected from each record.
        train: ``True`` for the training split, ``False`` for validation.
        length: number of leading samples kept from each record.
        transform: optional callable applied to the selected signal.

    Each item is ``(signal, target)`` where ``target`` is 0.9 when the
    patient's label is anything other than ``'no'`` and 0.1 otherwise.
    """

    # Directory that record names in the RECORDS file are relative to.
    DATA_ROOT = "staff-iii-database-1.0.0/"
    # Patients whose records are never served; a substitute record is used instead.
    EXCLUDED_PATIENTS = frozenset([28, 67, 78, 103])

    def __init__(self, record_path, excel_path, channel, train=True, length=10000, transform=None):

        self.train = train
        self.length = length
        self.transform = transform
        self.channel = channel

        with open(record_path) as fp:
            self.lines = fp.readlines()

        if self.train:
            self.lines = self.lines[:int(0.7*len(self.lines))]
        else:
            self.lines = self.lines[int(0.7*len(self.lines)):int(0.8*len(self.lines))]

        self.df = pandas.read_excel(excel_path)
        # Series.as_matrix() was removed in pandas 1.0; to_numpy() is its replacement.
        self.labels = self.df[u'Unnamed: 28'][9:].to_numpy()

    @staticmethod
    def _patient_id(line):
        """Parse the three-digit patient number from a line such as 'data/001a'."""
        return int(line[5:8])

    def _resolve_index(self, index):
        """Return ``index``, or a substitute index when its patient is excluded.

        The substitute is the first record of this split whose patient is not
        excluded, so it is always inside the split (a fixed index such as 105
        is out of range for the shorter validation split).
        """
        if self._patient_id(self.lines[index]) not in self.EXCLUDED_PATIENTS:
            return index
        for candidate, line in enumerate(self.lines):
            if self._patient_id(line) not in self.EXCLUDED_PATIENTS:
                return candidate
        raise IndexError("split contains no record from a usable patient")

    def __getitem__(self, index):

        index = self._resolve_index(index)
        line = self.lines[index]
        patient_id = self._patient_id(line)

        # strip() also copes with a final line that has no trailing newline
        file_name = line.strip()
        data, _ = wfdb.rdsamp(self.DATA_ROOT + str(file_name))
        data = np.array(data)
        data = data[:self.length, :]

        # extract relevant channels
        data = data[:, self.channel]

        if self.transform is not None:
            data = self.transform(data)

        # NOTE(review): assumes row `patient_id` of the sliced label column
        # belongs to that patient - confirm against the spreadsheet layout.
        # Targets are softened to 0.9 / 0.1 instead of hard 1 / 0.
        if self.labels[patient_id] != 'no':
            y = 0.9
        else:
            y = 0.1

        return data, y

    def __len__(self):
        # The split itself is the source of truth; a hard-coded record count
        # drifts out of sync with the slices taken in __init__.
        return len(self.lines)
        

In [8]:
# Number of records per mini-batch, used by both loaders.
batch_size = 100

# Both splits are built from the same record list and annotation sheet.
RECORDS_FILE = 'staff-iii-database-1.0.0/RECORDS'
ANNOTATIONS_FILE = 'staff-iii-database-1.0.0/STAFF-III-Database-Annotations.xlsx'

train_dataset = StaffIIIDataset(record_path=RECORDS_FILE,
                                excel_path=ANNOTATIONS_FILE,
                                channel=0)

# Reshuffle every epoch so batches are not always drawn in file order.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=1)

val_dataset = StaffIIIDataset(record_path=RECORDS_FILE,
                              excel_path=ANNOTATIONS_FILE,
                              channel=0,
                              train=False)

# Validation order stays fixed (no shuffling).
val_loader = DataLoader(dataset=val_dataset,
                        batch_size=batch_size,
                        num_workers=1)



In [5]:
class ConvNetQuake(nn.Module):
    """Four-stage 1-D convolutional classifier producing two raw class scores.

    Input is a float tensor of shape ``(batch, 1, samples)``. Each stage is
    convolution -> batch norm -> ReLU -> max pooling. The pooled feature map
    is brought to a fixed 32 time steps, flattened to 320 features and passed
    through three linear layers.

    ``forward`` returns the output of the last linear layer unchanged; the
    ``sigmoid`` and ``softmax`` modules are created but not applied.
    """

    def __init__(self):
        super(ConvNetQuake, self).__init__()

        self.conv1 = nn.Conv1d(in_channels=1, out_channels=3, kernel_size=151, stride=1)
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=10, kernel_size=45, stride=1)
        self.conv3 = nn.Conv1d(in_channels=10, out_channels=10, kernel_size=20, stride=1)
        self.conv4 = nn.Conv1d(in_channels=10, out_channels=10, kernel_size=10, stride=1)
        self.linear1 = nn.Linear(320, 30)
        self.linear2 = nn.Linear(30, 10)
        self.linear3 = nn.Linear(10, 2)
        self.sigmoid = nn.Sigmoid()
        self.bn1 = nn.BatchNorm1d(3)
        self.bn2 = nn.BatchNorm1d(10)
        self.bn3 = nn.BatchNorm1d(10)
        self.bn4 = nn.BatchNorm1d(10)
        self.mp1 = nn.MaxPool1d(6, stride=2, padding=2)
        self.mp2 = nn.MaxPool1d(20, stride=2, padding=9)
        self.mp3 = nn.MaxPool1d(20, stride=2, padding=9)
        self.mp4 = nn.MaxPool1d(20, stride=2, padding=9)
        self.drop1 = nn.Dropout(p=0.25)
        self.drop2 = nn.Dropout(p=0.5)
        self.softmax = nn.Softmax(dim=1)
        # linear1 consumes 10 channels x 32 steps = 320 features. Pooling to
        # exactly 32 steps keeps that true for any input length; it is the
        # identity when the feature map already has 32 steps. The layer has
        # no parameters, so existing state dicts still load.
        self.fixed_len = nn.AdaptiveMaxPool1d(32)

    def forward(self, x):
        """Return raw scores of shape ``(batch, 2)`` for ``x`` of shape ``(batch, 1, samples)``."""

        x = F.relu(self.bn1((self.conv1(x))))
        x = self.mp1(x)
        x = F.relu(self.bn2((self.conv2(x))))
        x = self.mp2(x)
        # zero-pad so the conv3 output regains the length it had before the convolution
        x = F.pad(self.conv3(x), (9, 10), "constant", 0)
        x = F.relu(self.bn3(x))
        x = self.mp3(x)
        # same length-restoring padding for conv4
        x = F.pad(self.conv4(x), (4, 5), "constant", 0)
        x = F.relu(self.bn4(x))
        x = self.mp4(x)
        x = self.fixed_len(x)
        x = self.drop1(x)
        # Flatten per sample using the actual batch dimension rather than a
        # notebook-level batch_size, so partial batches work.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.drop2(x)
        x = self.linear3(x)

        return x

In [17]:
def _prepare_batch(data, y, device):
    """Move a batch to ``device`` as float32 and add a channel axis to 2-D signals."""
    data = torch.as_tensor(data).to(device=device, dtype=torch.float32)
    if data.dim() == 2:
        # (batch, samples) -> (batch, 1, samples), the layout Conv1d expects
        data = data.unsqueeze(1)
    y = torch.as_tensor(y).to(device=device, dtype=torch.float32)
    return data, y


def _batch_accuracy(y_pred, y_true):
    """Fraction of samples whose predicted class matches the thresholded target.

    Multi-column predictions are decoded with argmax; single-value
    predictions are thresholded at 0.5, as are the targets.
    """
    if y_pred.dim() > 1 and y_pred.size(-1) > 1:
        predicted = y_pred.argmax(dim=-1)
    else:
        predicted = (y_pred.reshape(-1) > 0.5).long()
    expected = (y_true.reshape(-1) > 0.5).long()
    return (predicted == expected).float().mean().item()


def train(model, device, train_loader, optimizer, epoch, val_loader, writer, iteration,
          loss_fn=None, print_every=10):
    """Run one pass over ``train_loader`` and log metrics to ``writer``.

    Args:
        model: network to optimise; must already live on ``device``.
        device: device every batch is moved to.
        train_loader: iterable of ``(data, y)`` training batches.
        optimizer: optimiser stepped once per batch.
        epoch: current epoch number (not used by the loop itself).
        val_loader: iterable of validation batches; only its first batch is used.
        writer: object with ``add_scalar(tag, value, step)``.
        iteration: number of batches processed before this call.
        loss_fn: callable ``(prediction, target) -> loss``; defaults to the
            notebook-level ``criterion``.
        print_every: log once every this many batches.

    Returns:
        The updated batch counter, to be passed to the next call.
    """
    if loss_fn is None:
        loss_fn = criterion

    model.train()
    iteration_ = iteration

    for data, y in train_loader:

        data, y = _prepare_batch(data, y, device)

        optimizer.zero_grad()
        y_pred = model(data)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()

        # Count batches across epochs: a per-epoch batch index never reaches
        # print_every when an epoch has fewer batches than that.
        iteration_ += 1
        if iteration_ % print_every != 0:
            continue

        writer.add_scalar('Loss/train', loss.item(), iteration_)

        # validate on the first validation batch, with dropout/batch-norm in eval mode
        first_val = next(iter(val_loader), None)
        if first_val is not None:
            model.eval()
            with torch.no_grad():
                data_val, y_val = _prepare_batch(first_val[0], first_val[1], device)
                val_acc = _batch_accuracy(model(data_val), y_val)
            model.train()
            writer.add_scalar('Accuracy/val', val_acc, iteration_)

        # Accuracies are logged as fractions so batches of different size compare.
        writer.add_scalar('Accuracy/train', _batch_accuracy(y_pred.detach(), y), iteration_)

    return iteration_

In [18]:
# Column index -> lead name, used to name the TensorBoard run.
channels = {0: "v1", 
            1: "v2",
            2: "v3",
            3: "v4",
            4: "v5",
            5: "v6",
            6: "i",
            7: "ii",
            8: "iii"}

channel_1 = 0

# Use the GPU when one is present; the model and every batch must share a device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ConvNetQuake().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-4)


def criterion(logits, y):
    """Cross-entropy between two-class scores and a soft positive-class target.

    ``logits`` has shape (batch, 2); ``y`` has shape (batch,) and holds the
    target probability of class 1. nn.BCELoss cannot be used here because it
    needs probabilities in [0, 1] with the same shape as the target.
    """
    target = torch.stack([1.0 - y, y], dim=1)
    return -(target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()


# NOTE(review): absolute, machine-specific log directory - adjust before running elsewhere.
writer = SummaryWriter('/home/jenxime/mi_detection/runs_staff/runs_record_' + str(channels[channel_1]))

# Running step counter threaded through successive train() calls.
iteration = 0

for epoch in range(1, 10):
    print("Train Epoch: ", epoch)
    iteration = train(model, device, train_loader, optimizer, epoch, val_loader, writer, iteration)

# Flush pending events to disk.
writer.close()

NameError: name 'SummaryWriter' is not defined

In [None]:
# dataset size

In [2]:
# Load the record index into memory, one entry per line (newlines kept).
with open('staff-iii-database-1.0.0/RECORDS') as records_file:
    lines = list(records_file)

In [3]:
# Number of entries read from the RECORDS index.
len(lines)

520

In [4]:
# Peek at the first entry to see the raw line format.
lines[0]

'data/001a\n'

In [7]:
# Record names that are skipped when per-record lengths are collected.
bad_idx = ['data/089d']

In [36]:
# For every record, count how many samples lie in NaN-free windows.
WINDOW = 10000        # samples per window; same as the dataset's default `length`
MAX_SAMPLES = 300000  # only the first 300000 samples of a record are inspected

data_len = []
for line in lines:
    file_name = line.strip()
    if not file_name or file_name in bad_idx:
        # blank line or known-bad record: leave it out of the statistics
        continue

    data, _ = wfdb.rdsamp("staff-iii-database-1.0.0/" + str(file_name))
    signal = np.asarray(data)[:MAX_SAMPLES]

    clean_windows = 0
    # Step over complete windows only: an empty or partial slice past the end
    # of a short record contains no NaN and would otherwise count as clean.
    for start in range(0, len(signal) - WINDOW + 1, WINDOW):
        if not np.isnan(signal[start:start + WINDOW]).any():
            clean_windows += 1

    data_len.append(WINDOW * clean_windows)

In [32]:
# Average of the per-record values in data_len.
np.mean(data_len)

299210.01926782273

In [33]:
# Spread (population standard deviation) of the per-record values in data_len.
np.std(data_len)

5617.223016042727

In [35]:
# Smallest per-record value in data_len.
min(data_len)

190000