In [1]:
import os
import h5py
import numpy as np
import cv2 as cv
from tqdm import tqdm
from matplotlib import pyplot as plt

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

# Device configuration.
# CUDA reads these variables when it is first initialised in this process, so
# they must be set before anything queries CUDA (torch.cuda.is_available below).
# setdefault keeps a GPU selection already exported in the environment instead
# of silently overriding it with this notebook's default index.
# NOTE: if the selected index does not exist on this machine, no GPU is visible,
# torch.cuda.is_available() returns False and `device` falls back to CPU.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "4")

# 'cuda:0' is the first *visible* device, i.e. the one chosen above.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
def load_mat(filepath, img_size):
    """Load one HDF5-backed ``.mat`` sample and resize it to a square.

    Reads ``cjdata/image``, ``cjdata/label`` and ``cjdata/tumorMask`` from
    ``filepath``.

    Args:
        filepath: path to the ``.mat`` file (opened with ``h5py``).
        img_size: side length, in pixels, of the resized image and mask.

    Returns:
        (img, label, mask):
          - ``img``: resized image divided by its maximum so the peak is 1.0
            (left unscaled, as float64, when the maximum is 0).
          - ``label``: the ``cjdata/label`` dataset as stored in the file.
          - ``mask``: resized with nearest-neighbour interpolation so it keeps
            its original discrete values.
    """
    # The context manager closes the HDF5 handle even if a read fails.
    with h5py.File(filepath, "r") as mat:
        cjdata = mat['cjdata']
        label = np.array(cjdata['label'])
        img = np.array(cjdata['image'])
        mask = np.array(cjdata['tumorMask'])

    img = cv.resize(img, (img_size, img_size))
    peak = img.max()
    # Guard an all-zero slice: 0/0 would otherwise fill the image with NaN.
    img = img / peak if peak != 0 else img.astype(np.float64)

    # Nearest-neighbour keeps the mask's label values intact; the default
    # bilinear interpolation can blend values along object borders.
    mask = cv.resize(mask, (img_size, img_size), interpolation=cv.INTER_NEAREST)
    return img, label, mask

class MriDataset(Dataset):
    """Dataset over the ``.mat`` files in a directory, split into train/validation.

    The ``.mat`` files in ``root_dir`` are sorted by filename and split
    deterministically. The first ``int(len * split_ratio)`` files belong to
    ``mode="train"``; the remainder belong to any other mode (e.g.
    ``"validation"``).

    Args:
        root_dir: directory containing the ``.mat`` files.
        img_size: side length passed to ``load_mat`` for resizing.
        split_ratio: fraction of files assigned to the training split.
        mode: ``"train"`` for the first part of the split, anything else for the rest.
    """

    def __init__(self, root_dir, img_size=256, split_ratio=0.9, mode="train"):
        self.dir = root_dir
        self.img_size = img_size
        # os.listdir order is filesystem-dependent; sorting makes the
        # train/validation split reproducible across machines and runs.
        file_list = sorted(name for name in os.listdir(root_dir) if name.endswith(".mat"))
        split_idx = int(len(file_list) * split_ratio)
        self.file_list = file_list[:split_idx] if mode == "train" else file_list[split_idx:]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img, label, mask = load_mat(os.path.join(self.dir, self.file_list[idx]), self.img_size)
        # np.newaxis prepends a channel axis to the image; label[0] takes the
        # first row of the stored label array.
        return img[np.newaxis], label[0], mask

In [3]:
class Conv_Block(nn.Module):
    """Two 3x3 conv + ReLU stages, optionally preceded by 2x2 max-pooling.

    With ``padding=1`` the convolutions keep the spatial size. When ``pool`` is
    truthy, a ``MaxPool2d(2, 2)`` runs first and halves H and W (rounding down).

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of both convolutions.
        pool: if truthy, down-sample with max-pooling before the convolutions.
    """

    def __init__(self, in_ch, out_ch, pool=None):
        super().__init__()
        stages = [nn.MaxPool2d(2, 2)] if pool else []
        stages += [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class Upconv_Block(nn.Module):
    """U-Net decoder stage: upsample, concatenate with a skip feature, then Conv_Block.

    Args:
        in_ch: channels of the incoming (coarser) decoder feature. The
            transposed convolution reduces it to ``in_ch // 2`` channels. The
            skip feature must supply the rest, so that the concatenation has
            the ``in_ch`` channels ``Conv_Block(in_ch, out_ch)`` expects.
        out_ch: channels produced by the block.
    """

    def __init__(self, in_ch, out_ch):
        super(Upconv_Block, self).__init__()

        self.upconv = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)

        self.conv = Conv_Block(in_ch, out_ch)

    def forward(self, x1, x2):
        """Upsample ``x1`` to ``x2``'s spatial size, concatenate ``[x2, x1]`` and convolve.

        Args:
            x1: coarser decoder feature to upsample.
            x2: encoder (skip) feature; its H and W define the output size.
        """
        x1 = self.upconv(x1)
        # ConvTranspose2d(kernel=2, stride=2) exactly doubles H and W. If the
        # encoder feature had an odd size before max-pooling (which rounds
        # down), the sizes differ by a pixel, so resample to the skip size.
        # Skip the call when sizes already match: bilinear resampling with
        # align_corners=True is then the identity.
        if x1.shape[2:] != x2.shape[2:]:
            x1 = nn.functional.interpolate(
                x1, size=x2.shape[2:], mode="bilinear", align_corners=True
            )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class Build_UNet(nn.Module):
    """U-Net with four 2x down-sampling stages and a mirrored decoder.

    Channel widths grow 64 -> 128 -> 256 -> 512 -> 1024 in the encoder and
    shrink back to 64 in the decoder. A final 1x1 convolution maps them to
    ``num_classes`` channels. No activation is applied, so the output holds
    raw scores at the input's spatial size.

    Args:
        input_channel: number of channels in the input image.
        num_classes: number of output channels.
    """

    def __init__(self, input_channel=3, num_classes=5):
        super().__init__()
        widths = (64, 128, 256, 512, 1024)

        # Encoder (attribute names and creation order define the state_dict).
        self.conv1 = Conv_Block(input_channel, widths[0])
        self.conv2 = Conv_Block(widths[0], widths[1], pool=True)
        self.conv3 = Conv_Block(widths[1], widths[2], pool=True)
        self.conv4 = Conv_Block(widths[2], widths[3], pool=True)
        self.conv5 = Conv_Block(widths[3], widths[4], pool=True)

        # Decoder
        self.unconv4 = Upconv_Block(widths[4], widths[3])
        self.unconv3 = Upconv_Block(widths[3], widths[2])
        self.unconv2 = Upconv_Block(widths[2], widths[1])
        self.unconv1 = Upconv_Block(widths[1], widths[0])

        self.prediction = nn.Conv2d(widths[0], num_classes, 1)

    def forward(self, x):
        # Encoder: the trailing comment is the resolution relative to the input.
        skip1 = self.conv1(x)           # 1/1
        skip2 = self.conv2(skip1)       # 1/2
        skip3 = self.conv3(skip2)       # 1/4
        skip4 = self.conv4(skip3)       # 1/8
        bottleneck = self.conv5(skip4)  # 1/16

        # Decoder: each stage is resized to its skip feature's spatial size.
        up = self.unconv4(bottleneck, skip4)  # 1/8
        up = self.unconv3(up, skip3)          # 1/4
        up = self.unconv2(up, skip2)          # 1/2
        up = self.unconv1(up, skip1)          # 1/1

        return self.prediction(up)

In [4]:
# Training hyper-parameters.
epochs = 100
batch_size = 2

data_root = "../mri_example/data"
# NOTE(review): mat_list is not used by any visible cell.
mat_list = os.listdir(data_root)
train_mriset = MriDataset(data_root, mode="train")
val_mriset = MriDataset(data_root, mode="validation")
# Reshuffle the training split every epoch so mini-batches are not always drawn
# in the same fixed file order. Validation keeps a fixed order so its loss is
# computed over identical batches each epoch.
train_loader = DataLoader(train_mriset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_mriset, batch_size=batch_size)

In [5]:
# Single-channel image in, a single-channel logit map out.
net = Build_UNet(input_channel=1, num_classes=1)
net = net.to(device)

# BCEWithLogitsLoss applies the sigmoid itself, so the network emits raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=net.parameters(), lr=1e-4)

In [6]:
def train_epoch(epochs, epoch, dataloader, model, optimizer, criterion, device):
    """Train ``model`` for one epoch over ``dataloader``.

    Args:
        epochs: total number of epochs (only used in the progress-bar label).
        epoch: 0-based index of the current epoch.
        dataloader: yields ``(image, label, mask)`` batches; the label is ignored.
        model: network to train; switched to train mode and updated in place.
        optimizer: optimizer whose ``step()`` updates ``model``.
        criterion: loss applied to the channel-squeezed predictions and the masks.
        device: device the batches are moved to.

    Returns:
        dict with ``"epoch"`` (1-based) and ``"loss"``, the mean of the
        per-batch losses (0.0 when the dataloader yields no batches).
    """
    model.train()

    summ = {"loss": 0.0}
    num_batches = 0

    with tqdm(total=len(dataloader)) as t:
        t.set_description(f'[{epoch+1}/{epochs}]')

        for num_batches, (batch_img, _batch_lab, batch_mask) in enumerate(dataloader, start=1):
            X = batch_img.type(torch.float).to(device)
            Y = batch_mask.type(torch.float).to(device)

            # Use the `model` argument rather than a notebook global, and call
            # the module (not `.forward`) so registered hooks run.
            predictions = model(X)

            # squeeze(dim=1) drops the channel axis so predictions line up with
            # the mask batch; assumes a single-output-channel model.
            loss = criterion(predictions.squeeze(dim=1), Y)
            summ["loss"] += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t.set_postfix({key: f"{val/num_batches:05.3f}" for key, val in summ.items()})
            t.update()

    # max(..., 1) keeps an empty dataloader from raising instead of reporting 0.
    denom = max(num_batches, 1)
    return {"epoch": epoch + 1, **{key: val / denom for key, val in summ.items()}}

def eval_epoch(epochs, epoch, dataloader, model, optimizer, criterion, device):
    """Evaluate ``model`` for one epoch over ``dataloader`` without updating it.

    Args:
        epochs: total number of epochs (only used in the progress-bar label).
        epoch: 0-based index of the current epoch.
        dataloader: yields ``(image, label, mask)`` batches; the label is ignored.
        model: network to evaluate; switched to eval mode.
        optimizer: unused; kept so the signature matches ``train_epoch``.
        criterion: loss applied to the channel-squeezed predictions and the masks.
        device: device the batches are moved to.

    Returns:
        dict with ``"epoch"`` (1-based) and ``"loss_val"``, the mean of the
        per-batch losses (0.0 when the dataloader yields no batches).
    """
    model.eval()

    summ = {"loss_val": 0.0}
    num_batches = 0

    # Gradients are not needed for evaluation.
    with tqdm(total=len(dataloader)) as t, torch.no_grad():
        t.set_description(f'[{epoch+1}/{epochs}]')

        for num_batches, (batch_img, _batch_lab, batch_mask) in enumerate(dataloader, start=1):
            X = batch_img.type(torch.float).to(device)
            Y = batch_mask.type(torch.float).to(device)

            # Use the `model` argument rather than a notebook global, and call
            # the module (not `.forward`) so registered hooks run.
            predictions = model(X)

            # squeeze(dim=1) drops the channel axis so predictions line up with
            # the mask batch; assumes a single-output-channel model.
            loss = criterion(predictions.squeeze(dim=1), Y)
            summ["loss_val"] += loss.item()

            t.set_postfix({key: f"{val/num_batches:05.3f}" for key, val in summ.items()})
            t.update()

    # max(..., 1) keeps an empty dataloader from raising instead of reporting 0.
    denom = max(num_batches, 1)
    return {"epoch": epoch + 1, **{key: val / denom for key, val in summ.items()}}

In [7]:
# Train for `epochs` epochs, evaluating on the validation split after each one.
for epoch in range(epochs):
    train_metrics = train_epoch(epochs, epoch, train_loader, net, optimizer, criterion, device)
    val_metrics = eval_epoch(epochs, epoch, val_loader, net, optimizer, criterion, device)
    # Both dicts carry the same "epoch" entry; merging keeps a single copy.
    metrics_summary = {**train_metrics, **val_metrics}

    metric_parts = [f"{key}: {value:05.3f}" for key, value in metrics_summary.items()]
    metrics_string = " ; ".join(metric_parts)
    print(f"[{epoch+1}/{epochs}] Performance: {metrics_string}")

print("Training Done !")

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
[1/100]: 100%|██████████| 1379/1379 [01:12<00:00, 19.11it/s, loss=0.080]
[1/100]: 100%|██████████| 154/154 [00:04<00:00, 33.81it/s, loss_val=0.063]


[1/100] Performance: epoch: 1.000 ; loss: 0.080 ; loss_val: 0.063


[2/100]: 100%|██████████| 1379/1379 [01:17<00:00, 17.88it/s, loss=0.061]
[2/100]: 100%|██████████| 154/154 [00:04<00:00, 32.19it/s, loss_val=0.058]


[2/100] Performance: epoch: 2.000 ; loss: 0.061 ; loss_val: 0.058


[3/100]: 100%|██████████| 1379/1379 [01:18<00:00, 17.61it/s, loss=0.058]
[3/100]: 100%|██████████| 154/154 [00:04<00:00, 32.62it/s, loss_val=0.056]


[3/100] Performance: epoch: 3.000 ; loss: 0.058 ; loss_val: 0.056


[4/100]: 100%|██████████| 1379/1379 [01:18<00:00, 17.49it/s, loss=0.057]
[4/100]: 100%|██████████| 154/154 [00:04<00:00, 32.61it/s, loss_val=0.056]


[4/100] Performance: epoch: 4.000 ; loss: 0.057 ; loss_val: 0.056


[5/100]: 100%|██████████| 1379/1379 [01:18<00:00, 17.48it/s, loss=0.056]
[5/100]: 100%|██████████| 154/154 [00:04<00:00, 32.83it/s, loss_val=0.053]


[5/100] Performance: epoch: 5.000 ; loss: 0.056 ; loss_val: 0.053


[6/100]:  18%|█▊        | 252/1379 [00:14<01:04, 17.50it/s, loss=0.055]


KeyboardInterrupt: 