In [6]:
import math
import os
import sys
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as pat
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

from utils.mylogger import NewLogger
from utils.torchblock import Block
from utils.json_utils import Json_Config

# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Let cuDNN benchmark its convolution algorithms and keep the fastest one.
torch.backends.cudnn.benchmark = True

In [7]:
# Flag forwarded to NewLogger as `test=`; in the recorded run below it resulted
# in dir_path './log/test'.
test_mode = True

# Load the run configuration from config.json. `c` is the exported config
# object whose attributes (c.model_name, c.lr, c.batch_size, ...) are read by
# the rest of the notebook.
json_config = Json_Config('config.json')
c = json_config.export_config()

# Drop the reference to any logger left over from a previous execution of this
# cell before building a new one (presumably so stale handlers are released --
# TODO confirm against NewLogger).
logger = None
logger = NewLogger(test=test_mode, result_path='./log/', model_name=c.model_name)
# Write the configuration through the logger (see the "==== config ====" lines
# in the cell output) and record where this run's artifacts are stored.
json_config.log_and_save(logger)
logger.debug('dir_path: %s' % logger.dir_path)

2020-04-13 16:55:54 - ==== config ====
2020-04-13 16:55:54 - model_name simple_model
2020-04-13 16:55:54 - num_epochs 100
2020-04-13 16:55:54 - batch_size 10
2020-04-13 16:55:54 - loss_weight 10.0
2020-04-13 16:55:54 - lr 0.0002
2020-04-13 16:55:54 - ngpu 4
2020-04-13 16:55:54 - num_workers 4
2020-04-13 16:55:54 - uaph_path ./data/ctu/uaph.csv
2020-04-13 16:55:54 - data_path ./data/ctu/ctu_csv/
2020-04-13 16:55:54 - uaph_threshold 7.05
2020-04-13 16:55:54 - dir_path: ./log/test


In [None]:
# Use the first GPU only when CUDA is available AND the config asks for GPUs.
device = torch.device("cuda:0" if (torch.cuda.is_available() and c.ngpu > 0) else "cpu")
# NOTE: cudnn.benchmark lets cuDNN auto-tune its algorithms, which trades
# bit-exact run-to-run reproducibility for speed even when seeds are fixed.
torch.backends.cudnn.benchmark = True

# Seed every RNG the notebook has imported so that the train/val split,
# shuffling and weight initialisation are repeatable across kernel restarts.
manualSeed = 999
random.seed(manualSeed)
np.random.seed(manualSeed)  # previously left unseeded
torch.manual_seed(manualSeed)
logger.debug('seed: %s' % manualSeed)

In [None]:
# CTG Dataset

class CTGDataset(data.Dataset):
    """Paired FHR / TOCO recordings with a two-class one-hot UApH label.

    Signals are read from CSV lazily, on every item access. The UApH values
    are read once at construction time and indexed in parallel with the two
    path lists.

    Args:
        fhr_path_list (list): CSV paths of the FHR signals, one per sample.
        toco_path_list (list): CSV paths of the TOCO signals, same order.
        uaph_path (str): header-less CSV holding one UApH value per sample.
        uaph_threshold (float): samples with UApH strictly below this value
            are labelled ``[0., 1.]``; all others ``[1., 0.]``.
    """

    def __init__(self, fhr_path_list, toco_path_list, uaph_path, uaph_threshold):
        uaph_table = pd.read_csv(uaph_path, header=None)

        self.fhr_path_list = fhr_path_list
        self.toco_path_list = toco_path_list
        self.uaph = uaph_table.values.reshape([-1])
        self.uaph_threshold = uaph_threshold

    def __len__(self):
        """Number of samples, defined by the FHR path list."""
        return len(self.fhr_path_list)

    def __getitem__(self, index):
        """Return ``(fhr, toco, one_hot_label, uaph)`` for sample ``index``."""
        # FHR is loaded first, then TOCO.
        signals = [
            self.path2tensor(paths[index])
            for paths in (self.fhr_path_list, self.toco_path_list)
        ]
        fhr_data, toco_data = signals

        uaph = self.uaph[index]
        below_threshold = uaph < self.uaph_threshold
        label = torch.tensor([0., 1.]) if below_threshold else torch.tensor([1., 0.])

        return fhr_data, toco_data, label, uaph

    @staticmethod
    def path2tensor(path):
        """Load a header-less CSV as a float tensor of shape ``(1, n)``.

        Only the last 14400 values are kept; shorter files are returned whole.
        """
        flat = pd.read_csv(path, header=None).values.reshape([-1])
        tail = flat[-14400:]
        return torch.from_numpy(tail).to(torch.float).reshape(1, -1)

In [None]:
# Recordings are stored as <id>_fhr.csv / <id>_toco.csv with ids 1..552.
record_ids = range(1, 552 + 1)
fhr_path_list = [os.path.join(c.data_path, '%s_fhr.csv' % i) for i in record_ids]
toco_path_list = [os.path.join(c.data_path, '%s_toco.csv' % i) for i in record_ids]

all_dataset = CTGDataset(fhr_path_list, toco_path_list, c.uaph_path, c.uaph_threshold)
all_size = len(all_dataset)

# 80 % / 20 % train/validation split; the remainder goes to validation so the
# two sizes always add up to the full dataset.
train_size = int(all_size * 0.8)
val_size = all_size - train_size

train_dataset, val_dataset = data.random_split(all_dataset, [train_size, val_size])

dataloader = {
    'train': data.DataLoader(train_dataset, batch_size=c.batch_size,
                             shuffle=True, num_workers=c.num_workers),
    # Validation metrics are order-independent, so do not shuffle: this keeps
    # evaluation deterministic and avoids drawing from the global RNG (which
    # would otherwise perturb the training shuffle order) on every epoch.
    'val': data.DataLoader(val_dataset, batch_size=c.batch_size,
                           shuffle=False, num_workers=c.num_workers),
}

logger.debug('train_size:%s, val_size:%s' % (train_size, val_size))

In [None]:
def calc_statistics(pred, truth):
    """
    Compute binary confusion-matrix counts and the metrics derived from them.

    Class 1 is treated as the positive class. Any metric whose denominator is
    zero is reported as the sentinel value -1.

    Args:
        pred(torch.tensor): 0 or 1
        truth(torch.tensor): 0 or 1

    Returns:
        tp, fp, fn, tn, tpr, tnr, acc, prec, f1
    """

    tp = int(torch.sum((pred == 1) & (truth == 1)))
    fp = int(torch.sum((pred == 1) & (truth == 0)))
    fn = int(torch.sum((pred == 0) & (truth == 1)))
    tn = int(torch.sum((pred == 0) & (truth == 0)))

    tpr = tp / (tp + fn) if tp + fn > 0 else -1
    tnr = tn / (fp + tn) if fp + tn > 0 else -1
    acc = (tp + tn) / (tp + fp + fn + tn) if tp + fp + fn + tn > 0 else -1
    prec = tp / (tp + fp) if tp + fp > 0 else -1

    # F1 = 2PR / (P + R) = 2tp / (2tp + fp + fn). Using the count form keeps
    # the -1 sentinels of prec/tpr out of the arithmetic and yields 0 (not -1)
    # whenever there are errors but no true positives.
    f1_denominator = 2 * tp + fp + fn
    f1 = 2 * tp / f1_denominator if f1_denominator > 0 else -1

    return tp, fp, fn, tn, tpr, tnr, acc, prec, f1

In [None]:
def weights_init(m):
    """Re-initialise one module in place; intended for ``net.apply(...)``.

    The module is matched by class name:
      * names containing 'Conv'      -> weight ~ N(0.0, 0.02); bias untouched
      * names containing 'BatchNorm' -> weight ~ N(1.0, 0.02); bias set to 0
    Every other module is left unchanged.
    """
    layer_kind = m.__class__.__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class Net(nn.Module):
    """1-D CNN that maps a single-channel signal to two class probabilities.

    Input:  float tensor of shape (batch, 1, L). The first linear layer is
            sized for 450 time steps after the four pooling stages (overall
            down-sampling 4 * 2 * 2 * 2 = 32), i.e. L = 14400.
    Output: float tensor of shape (batch, 2); each row is a softmax
            distribution over the two classes.

    The layer order inside ``self.seq`` is kept fixed so that ``state_dict``
    keys stay stable.
    """

    def __init__(self):
        super().__init__()

        self.seq = nn.Sequential(
            # Initial down-sampling by 4: 14400 -> 3600 samples.
            nn.MaxPool1d(4, stride=4, padding=0, dilation=1),

            # kernel_size=15 with padding=7 keeps the length unchanged.
            nn.Conv1d(1, 10, kernel_size=15, padding=7, stride=1),
            nn.BatchNorm1d(10),
            nn.ReLU(),
            nn.Conv1d(10, 10, kernel_size=15, padding=7, stride=1),
            nn.BatchNorm1d(10),
            nn.ReLU(),
            nn.MaxPool1d(2, stride=2, padding=0, dilation=1),   # -> 1800

            nn.Conv1d(10, 20, kernel_size=15, padding=7, stride=1),
            nn.BatchNorm1d(20),
            nn.ReLU(),
            nn.MaxPool1d(2, stride=2, padding=0, dilation=1),   # -> 900

            nn.Conv1d(20, 40, kernel_size=15, padding=7, stride=1),
            nn.BatchNorm1d(40),
            nn.ReLU(),
            nn.MaxPool1d(2, stride=2, padding=0, dilation=1),   # -> 450

            nn.Flatten(),

            nn.Linear(40 * 450, 60),
            nn.Linear(60, 2),
            # Normalise across the class dimension. The dimension is given
            # explicitly because the implicit choice is deprecated and warns.
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Return class probabilities of shape (batch, 2) for input ``x``."""
        return self.seq(x)

In [None]:
# Build the model on the selected device and apply the custom initialisation.
net = Net().to(device)
net.apply(weights_init)

# Per-class loss weights: class 0 counts once, class 1 counts c.loss_weight
# times. The (2,) tensor broadcasts over the (batch, 2) one-hot targets.
weight = torch.tensor([1.0, c.loss_weight]).to(device)

# BCELoss expects probabilities, which the network's final softmax provides.
criterion = nn.BCELoss(reduction='mean', weight=weight)
optimizer = optim.Adam(net.parameters(), lr=c.lr, betas=(0.8, 0.999))

logger.debug(net)

# Smoke test: push one training batch (FHR tensor only) through the network.
# The builtin next() is used because DataLoader iterators no longer expose a
# Python-2 style .next() method.
net(next(iter(dataloader['train']))[0].to(device))

In [None]:
# Per-epoch history for each phase. The 'train' lists start with a None
# placeholder because no training happens in epoch 1 (see the `continue`
# below); this keeps the 'train' and 'val' lists index-aligned by epoch.
loss_list = {'train': [None], 'val': []}
tpr_list= {'train': [None], 'val': []}
tnr_list= {'train': [None], 'val': []}

for epoch in range(1, c.num_epochs+1):
    
    for phase in ['train', 'val']:
        
        # Sum of per-batch mean losses weighted by batch size; turned into a
        # per-sample mean after the batch loop.
        phase_loss = 0
        
        # Predicted / true class indices collected over the whole phase (CPU).
        epoch_pred_cls = torch.tensor([]).to(torch.long)
        epoch_true_cls = torch.tensor([]).to(torch.long)
        
        if phase == 'train':
            net.train()
        else:
            net.eval()
        
        # Epoch 1 only validates, which yields metrics for the freshly
        # initialised network. Both operands are plain bools, so `&` acts as
        # a logical AND here.
        if (epoch == 1) & (phase == 'train'):
            continue
        
        for i, batch in enumerate(tqdm(dataloader[phase]), 0):
            # A batch is (fhr, toco, label, uaph) as returned by CTGDataset;
            # only the FHR signal and the one-hot label are used here.
            inputs = batch[0].to(device)
            labels = batch[2].to(device)
            
            optimizer.zero_grad()
            
            # Track gradients only while training.
            with torch.set_grad_enabled(phase == 'train'):
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                
                # Argmax over dim 1 turns probabilities / one-hot labels into
                # class indices (index 1 = UApH below the threshold).
                _, pred_cls = torch.max(outputs, 1)
                _, true_cls = torch.max(labels, 1)
                
                epoch_pred_cls = torch.cat([epoch_pred_cls, pred_cls.to('cpu')])
                epoch_true_cls = torch.cat([epoch_true_cls, true_cls.to('cpu')])
                
                # criterion returns the batch mean; scale by the batch size so
                # that a smaller final batch is weighted correctly.
                phase_loss += loss.item() * inputs.size(0)
            
        # Per-sample mean loss over the phase's dataset.
        phase_loss /= len(dataloader[phase].dataset)
        
        logger.debug('epoch: {}, phase: {}, loss: {:.4f}'.format(
                     epoch, phase, phase_loss))
        tp, fp, fn, tn, tpr, tnr, acc, prec, f1 = \
            calc_statistics(epoch_pred_cls, epoch_true_cls)
        
        # Confusion matrix is logged row-wise as [tp fn] [fp tn].
        logger.debug('[tp: %s fn: %s] [fp: %s tn: %s]' % (tp, fn, fp, tn))
        logger.debug('tpr: {:.3f}, tnr: {:.3f}'.format(tpr, tnr))
        logger.debug('acc: {:.3f}, prec: {:.3f}, f1: {:.3f}'.format(acc, prec, f1))
    
        # Record this phase's results for the plots drawn after training.
        loss_list[phase] += [phase_loss]
        tpr_list[phase] += [tpr]
        tnr_list[phase] += [tnr]

In [None]:
# Loss curves. The first entry of each history is skipped (for 'train' it is
# the None placeholder), so both curves start at epoch 2.
fig, ax = plt.subplots(figsize=(10, 5))
for split in ('train', 'val'):
    ax.plot(loss_list[split][1:], label=split)
ax.legend()

# The title reports the loss of the final epoch for both phases.
ax.set_title('train: {:.4f} val: {:.4f}'.format(
    loss_list['train'][-1], loss_list['val'][-1]))
fig.savefig(os.path.join(logger.dir_path, 'loss.png'))

In [None]:
# Sensitivity (TPR) and specificity (TNR) curves. The first entry of each
# history is skipped (for 'train' it is the None placeholder).
fig, ax = plt.subplots(figsize=(10, 5))
curves = (
    ('train_tpr', tpr_list['train']),
    ('val_tpr', tpr_list['val']),
    ('train_tnr', tnr_list['train']),
    ('val_tnr', tnr_list['val']),
)
for curve_label, curve_values in curves:
    ax.plot(curve_values[1:], label=curve_label)

ax.legend()
# The title reports the values of the final epoch.
ax.set_title('train_tpr: {:.4f}, val_tpr: {:.4f}, train_tnr: {:.4f}, val_tnr: {:.4f},'.format(
    tpr_list['train'][-1], tpr_list['val'][-1],
    tnr_list['train'][-1], tnr_list['val'][-1]))
fig.savefig(os.path.join(logger.dir_path, 'tpr_tnr.png'))

In [None]:
def viewer(index, dataset, min_length):
    """
    Visualize the fetal heart rate of one sample.

    The last ``3 * min_length`` minutes of the recording are drawn as three
    stacked panels of ``min_length`` minutes each (240 samples per minute).
    Panel titles give the window in minutes relative to the end of the
    recording, e.g. ``-60 ~ -40 min``. Samples equal to 0.0 are drawn as gaps.

    Args:
        index (int): position of the sample in ``dataset``
        dataset: indexable whose items are ``(fhr, toco, label, uaph)``
        min_length (int): visualizing duration(min) in an image

    Note:
        Sets ``axes.xmargin`` / ``axes.ymargin`` in matplotlib's global
        rcParams, which also affects figures created afterwards.
    """

    # Fetch the sample once; every access re-reads the underlying CSV files.
    sample = dataset[index]
    uaph = sample[3]

    fhr = sample[0].reshape(-1)
    fhr = np.where(fhr == 0.0, np.nan, fhr)   # zeros become gaps in the plot
    total = len(fhr)

    plt.figure(figsize=(15, 12))

    mpl.rcParams['axes.xmargin'] = 0
    mpl.rcParams['axes.ymargin'] = 0

    for i in range(3):
        ax = plt.subplot(3, 1, i+1)

        # Shaded band covering 110-160 on the y axis.
        rec = pat.Rectangle(xy=(0, 110), width=min_length*240, height=50,
                            color="whitesmoke", alpha=1)
        ax.add_patch(rec)

        # One vertical grid line per minute.
        for j in range(min_length):
            ax.axvline(x=j*240, color="gray")

        start = (i - 3) * min_length
        end = (i - 2) * min_length

        # Convert the end-relative window into absolute, half-open indices.
        # Negative slice bounds would need `end*240 - 1` to keep the last
        # window (end == 0) non-empty, at the cost of dropping the final
        # sample of every window. Clamping at 0 keeps recordings shorter than
        # the requested span working.
        lo = max(total + start*240, 0)
        hi = max(total + end*240, 0)
        ax.plot(fhr[lo:hi])

        ax.set_title('UApH {:.2f}, {} ~ {} min'.format(uaph, start, end))

        ax.set_xlim([0, min_length*240])
        ax.set_ylim([40,220])

        # Hide the x tick labels.
        plt.xticks(color="None")

# viewer(0, all_dataset, 20)