In [2]:
import torch
from torch import nn
from torchvision import transforms
from torchvision import models
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
from monai.losses import DiceLoss
from PIL import Image
from PIL import ImageFilter
from PIL import ImageOps
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from copy import copy
from copy import deepcopy
import cv2
import sys
import os
import pandas as pd
import numpy as np
import random
from pathlib import Path
import math
from SA_UNet import SA_UNet
# from evaluation import *

# Presumably lets a second copy of the Intel OpenMP runtime load without aborting;
# the author set this because the kernel died when plotting with matplotlib.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [4]:
# Load the image metadata table; the column summary is printed by .info() below.
metadata = pd.read_csv(filepath_or_buffer='../doby_meta.csv')
metadata.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7064 entries, 0 to 7063
Data columns (total 7 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   subject_id      7064 non-null   int64 
 1   study_id        7064 non-null   int64 
 2   dicom_id        7064 non-null   object
 3   DicomPath       7064 non-null   object
 4   edema_severity  7064 non-null   int64 
 5   normal          7064 non-null   int64 
 6   CHF             7064 non-null   bool  
dtypes: bool(1), int64(4), object(2)
memory usage: 338.2+ KB


In [5]:
# Keep only subjects with subject_id < 16,000,000 (presumably the MIMIC-CXR
# p10-p15 folders available locally -- TODO confirm).
# reset_index gives the filtered frame a 0..n-1 index again, so positional
# sample indices from the DataLoader map onto existing rows.
metadata = metadata[metadata['subject_id'] < 16000000].reset_index(drop=True)
metadata.info()

<class 'pandas.core.frame.DataFrame'>
Index: 4198 entries, 0 to 4197
Data columns (total 7 columns):
 #   Column          Non-Null Count  Dtype 
---  ------          --------------  ----- 
 0   subject_id      4198 non-null   int64 
 1   study_id        4198 non-null   int64 
 2   dicom_id        4198 non-null   object
 3   DicomPath       4198 non-null   object
 4   edema_severity  4198 non-null   int64 
 5   normal          4198 non-null   int64 
 6   CHF             4198 non-null   bool  
dtypes: bool(1), int64(4), object(2)
memory usage: 233.7+ KB


In [6]:
# Roots of the lung-mask dataset and of the original MIMIC-CXR-JPG images.
SEG_BASE_PATH, ORIG_BASE_PATH = (
    '../chest-x-ray-dataset-with-lung-segmentation-1.0.0/chest-x-ray-dataset-with-lung-segmentation-1.0.0',
    '../physionet.org/files/mimic-cxr-jpg/2.0.0',
)

In [7]:
# The X-ray (input) and the lung mask (target) each only need converting to a
# float tensor; ToTensor maps uint8 pixel data to a 1xHxW tensor in [0, 1].
# Two separate Compose objects are kept so either pipeline can be changed alone.
TRAIN_TRANSFORMS = transforms.Compose([transforms.ToTensor()])
TARGET_TRANSFORMS = transforms.Compose([transforms.ToTensor()])

In [8]:
class Dataset(torch.utils.data.Dataset):
    """Paired (chest X-ray, lung mask) dataset for segmentation.

    For each metadata row, the X-ray is read from
    ``orig_base_path / DicomPath`` and its mask from
    ``mask_base_path / <DicomPath minus its last 4 chars>-mask.jpg``.
    The last 4 characters of ``DicomPath`` are presumably a ``.jpg``
    extension -- confirm against the metadata file.

    Args:
        metadata: DataFrame with at least a ``DicomPath`` column. Its index
            may be non-contiguous (e.g. after filtering); it is reset so that
            ``ds[i]`` always refers to the i-th row.
        mask_base_path: root directory of the mask images.
        orig_base_path: root directory of the original X-ray images.
        transform: optional callable applied to the X-ray, which is passed as
            an HxW grayscale NumPy array read by cv2.
        target_transform: optional callable applied to the mask, which is
            passed as a grayscale ('L') PIL image.
        image_size: side length that both X-ray and mask are resized to.
        mask_threshold: mask pixels strictly above this value become 255 and
            all others 0, so JPEG compression noise does not leave
            non-binary labels. ``None`` keeps the raw grayscale mask.

    Raises (from ``__getitem__``):
        FileNotFoundError: if the X-ray or mask file does not exist.
        OSError: if cv2 cannot decode the X-ray.
    """

    def __init__(self, metadata, mask_base_path, orig_base_path, transform=None,
                 target_transform=None, image_size=224, mask_threshold=128):
        # A fresh 0..n-1 index makes positional sampler indices always valid.
        self.metadata = metadata.reset_index(drop=True)
        self.base_path = Path(orig_base_path)
        self.mask_path = Path(mask_base_path)

        self.transform = transform
        self.target_transform = target_transform
        self.image_size = image_size
        self.mask_threshold = mask_threshold

    def __getitem__(self, idx):
        detail_path = str(self.metadata['DicomPath'].iloc[idx])
        x_path = self.base_path / detail_path
        y_path = self.mask_path / (detail_path[:-4] + '-mask.jpg')
        # cv2.imread returns None instead of raising, so check up front.
        for path in (x_path, y_path):
            if not path.is_file():
                raise FileNotFoundError(f'No such file: {path}')

        size = (self.image_size, self.image_size)

        x = cv2.imread(str(x_path), cv2.IMREAD_GRAYSCALE)
        if x is None:
            raise OSError(f'cv2 could not decode image: {x_path}')
        x = cv2.resize(x, dsize=size, interpolation=cv2.INTER_AREA)

        # Load the mask fully, then close the file; force single-channel.
        with Image.open(str(y_path)) as mask_img:
            y = mask_img.convert('L')
        # Resize the mask together with the image so pixels stay aligned;
        # nearest-neighbour avoids blending label values.
        y = y.resize(size, resample=Image.NEAREST)
        if self.mask_threshold is not None:
            # Threshold on the uint8 PIL image. The earlier np.where attempt
            # produced an int64 array, which ToTensor does not rescale to [0, 1].
            threshold = self.mask_threshold
            y = y.point(lambda p: 255 if p > threshold else 0)

        if self.transform:
            x = self.transform(x)

        if self.target_transform:
            y = self.target_transform(y)

        return x, y

    def __len__(self):
        # Number of rows, regardless of missing values in any column.
        return len(self.metadata)

In [9]:
# Full dataset of (X-ray, mask) pairs; it is split into train/test below.
ds = Dataset(
    metadata=metadata,
    mask_base_path=SEG_BASE_PATH,
    orig_base_path=ORIG_BASE_PATH,
    transform=TRAIN_TRANSFORMS,
    target_transform=TARGET_TRANSFORMS,
)

In [10]:
# 80/20 train/test split. torch.manual_seed(42) seeds PyTorch's global RNG and
# returns that generator, so the split is reproducible and later random draws
# (e.g. weight initialisation, shuffling) start from the same seeded state.
ds_size = len(ds)
test_size = ds_size - int(ds_size * 0.8)
train_size = ds_size - test_size
train_ds, test_ds = random_split(ds, lengths=[train_size, test_size],
                                 generator=torch.manual_seed(42))

In [11]:
# Shuffle only the training data. Keeping the evaluation loader in a fixed order:
#  - stops it from consuming PyTorch's global RNG every epoch (a RandomSampler
#    without its own generator draws its seed from the global RNG), and
#  - keeps batch composition identical across epochs. This matters if loss_fn
#    is not a pure per-sample mean: the test loss fed to ReduceLROnPlateau
#    would otherwise vary with the random batching.
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=8, shuffle=False)

In [12]:
# Single-channel (grayscale) SA-UNet. The author annotates block_size_per=10 as
# "feature size 10%" (a DropBlock size fraction, presumably -- confirm in SA_UNet).
# Module.to() moves the model in place and returns it; the bare name below
# displays the architecture.
model = SA_UNet(channel=1, block_size_per=10).to(device)
model

SA_UNet(
  (e1): EncoderBlock(
    (convBlk): ConvBlock(
      (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (e2): EncoderBlock(
    (convBlk): ConvBlock(
      (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats

In [13]:
# Adam at lr 1e-4. ReduceLROnPlateau cuts the LR by 10x once the value passed to
# schedular.step() has not decreased for more than `patience`=5 epochs.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
schedular = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
loss_fn = DiceLoss(reduction='mean').to(device)

In [14]:
def train(model, data_loader, optimizer, epoch):
    """Run one training epoch and print a running progress line.

    Uses the notebook globals ``device`` and ``loss_fn``.

    Args:
        model: segmentation network. Its output is assumed to be per-pixel
            probabilities in [0, 1], because predictions are thresholded at
            0.5 (TODO confirm). It must have as many elements per sample as
            the target.
        data_loader: iterable of ``(batch, target)`` pairs.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in the progress line.

    Returns:
        ``(mean_loss, pixel_accuracy)``: the mean of the per-batch losses, and
        the percentage of pixels where the thresholded prediction matches the
        thresholded target. Both are 0.0 for an empty loader.
    """
    model.train()
    running_loss = 0.0
    running_acc = 0
    n_data = 0
    n_pixels = 0
    n_batches = 0

    for batch_idx, (batch, target) in enumerate(data_loader, start=1):
        batch, target = batch.to(device), target.to(device)

        optimizer.zero_grad()

        output = model(batch)
        # One row per sample; reshape (unlike view) also accepts non-contiguous outputs.
        output_flat = output.reshape(output.size(0), -1)
        target_flat = target.reshape(target.size(0), -1)

        # NOTE(review): the loss receives (batch, pixels) tensors. A loss that
        # treats dim 1 as channels (e.g. MONAI DiceLoss) will not compute a
        # per-image dice on this layout -- confirm this is intended.
        loss = loss_fn(output_flat, target_flat)
        running_loss += loss.item()

        loss.backward()
        optimizer.step()

        # Pixel-wise accuracy. Binarise the target as well: mask values are not
        # guaranteed to be exactly 0/1 (e.g. JPEG artefacts).
        with torch.no_grad():
            pred_bin = output_flat >= 0.5
            target_bin = target_flat >= 0.5
            running_acc += (pred_bin == target_bin).sum().item()

        n_data += len(batch)
        n_pixels += pred_bin.numel()
        n_batches += 1

        print(f'\rTrain Epoch: {epoch} [{n_data}/{len(data_loader.dataset)} ({100 * batch_idx / len(data_loader):.2f}%)]  Accuracy: {100*running_acc/n_pixels:.2f}%  Loss: {running_loss/batch_idx:.4f}', end='')

    mean_loss = running_loss / n_batches if n_batches else 0.0
    accuracy = 100 * running_acc / n_pixels if n_pixels else 0.0
    return mean_loss, accuracy

In [15]:
def test(model, data_loader):
    model.eval()
    test_acc = 0
    test_loss = 0
    n_data = 0
    TP, FP, TN, FN = 0, 0, 0, 0
    output_pixels = 0
    
    with torch.no_grad():
        for batch, target in data_loader:
            batch, target = batch.to(device), target.to(device)

            output = model(batch)

            output_flat = output.view(output.size(0), -1)
            target_flat = target.view(target.size(0), -1)

            loss = loss_fn(output_flat, target_flat)
            test_loss += loss.item()

            output_flat = (output_flat >= torch.FloatTensor([0.5]).to(device)).type(torch.float32)

            correct = (output_flat == target_flat).sum().item()
            test_acc += correct

            TP += ((output_flat == target_flat) & (target_flat == 1)).sum().item()
            FP += ((output_flat != target_flat) & (target_flat == 0)).sum().item()
            TN += ((output_flat == target_flat) & (target_flat == 0)).sum().item()
            FN += ((output_flat != target_flat) & (target_flat == 1)).sum().item()
            
            n_data += len(batch)
            output_pixels = (output_flat.shape[1])
            
            print(f'\rTest set: [{100*n_data/len(data_loader.dataset):.2f}%]', end='')
    
    test_acc = 100 * test_acc / (len(data_loader.dataset)*output_pixels)
    test_loss = test_loss / len(data_loader)
    
    print(f'\rTest set: Accuracy: {test_acc:.2f}%  Loss: {test_loss:.4f}')

    return test_acc, test_loss, TP, FP, TN, FN

In [16]:
def getMetric(TP, FP, TN, FN):
    """Compute sensitivity, specificity, F1 score and MCC from confusion counts.

    Any ratio whose denominator is zero is reported as 0.0 instead of raising
    ZeroDivisionError. The counts themselves are used unmodified; substituting
    a placeholder such as 0.1 for zeros would skew every metric (e.g. make
    specificity < 1 when FP == 0).

    Returns:
        ``(sensitivity, specificity, f1_score, mcc)`` as floats.
    """
    def _ratio(num, den):
        # Safe division: undefined ratios are reported as 0.0.
        return num / den if den else 0.0

    sensitivity = _ratio(TP, TP + FN)          # identical to recall
    specificity = _ratio(TN, TN + FP)
    # Equivalent to 2*precision*recall / (precision + recall), but stays
    # well-defined when precision or recall is undefined.
    f1_score = _ratio(2 * TP, 2 * TP + FP + FN)
    mcc = _ratio((TP * TN) - (FP * FN),
                 math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)))

    return sensitivity, specificity, f1_score, mcc

In [17]:
# Per-epoch history (held-out accuracy and loss), filled by the training loop.
accs, losses = [], []
best_acc = best_f1 = 0

# Best checkpoints so far: full model copies and their state_dicts.
best_acc_model = best_acc_model_state = None
best_f1_model = best_f1_model_state = None

In [18]:
# 50 epochs. Each epoch trains, evaluates on the held-out split, keeps the best
# checkpoints by accuracy and by F1, and feeds the held-out loss to the scheduler.
# NOTE(review): test_dl both selects the "best" models and drives the LR
# schedule, so its scores are not an unbiased final test estimate.
for epoch in range(1, 51):
    train(model, train_dl, optimizer, epoch)
    print()  # end the carriage-return progress line printed by train()

    acc, loss, tp, fp, tn, fn = test(model, test_dl)
    sensitivity, specificity, f1_score, mcc = getMetric(tp, fp, tn, fn)
    print(f'TP: {tp}, FP: {fp}, TN: {tn}, FN: {fn}')
    print(f'Sensitivity: {sensitivity:.4f}, Specificity: {specificity:.4f}\nF1-Score: {f1_score:.4f}, MCC: {mcc:.4f}')

    if acc > best_acc:
        best_acc, best_acc_model = acc, deepcopy(model)
        best_acc_model_state = deepcopy(model.state_dict())

    if f1_score > best_f1:
        best_f1, best_f1_model = f1_score, deepcopy(model)
        best_f1_model_state = deepcopy(model.state_dict())

    accs.append(acc)
    losses.append(loss)
    schedular.step(loss)

    print('================================================================')

RuntimeError: The size of tensor a (224) must match the size of tensor b (226) at non-singleton dimension 3

In [None]:
# Persist the per-epoch history and the best checkpoints.
accs_np = np.array(accs)
losses_np = np.array(losses)
np.save('./sa_unet_acc.npy', accs_np)
np.save('./sa_unet_loss.npy', losses_np)

# Saves both full-module pickles (reloading them requires the SA_UNet class to
# be importable) and plain state_dicts. Any checkpoint that is still None is
# skipped -- e.g. when training stopped before an epoch completed. Otherwise a
# file would be written that silently loads back as None.
for checkpoint, checkpoint_path in (
    (best_acc_model, './sa_unet_best_acc_model.pt'),
    (best_acc_model_state, './sa_unet_best_acc_model_state.pt'),
    (best_f1_model, './sa_unet_best_f1_model.pt'),
    (best_f1_model_state, './sa_unet_best_f1_model_state.pt'),
):
    if checkpoint is None:
        print(f'Skipping {checkpoint_path}: no checkpoint was recorded')
        continue
    torch.save(checkpoint, checkpoint_path)