<a href="https://colab.research.google.com/github/iamnotwhale/BU-Net_Pytorch_Implementation/blob/heoneyzi/BUnet_Dataset_modified_ipynb%EC%9D%98_fixed%EC%9D%98_%EC%82%AC%EB%B3%B8%EC%9D%98_%EC%82%AC%EB%B3%B8%EC%9D%98_%EC%82%AC%EB%B3%B8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import torch

import os
import nibabel as nib
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch

In [3]:
import os
import numpy as np
import nibabel as nib
import torch
from torch.utils.data import Dataset

class Custom2DBraTSDataset(Dataset):
    def __init__(self, data_dir, modality):
        self.data_dir = data_dir
        self.modality = modality
        self.patient_ids = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]

        # Initialize lists to store slices
        self.images = []
        self.labels = []

        # Iterate through patients and load slices
        for patient_id in self.patient_ids:
            patient_path = os.path.join(self.data_dir, patient_id)

            # Load image and label volumes
            image = nib.load(os.path.join(patient_path, f'{patient_id}_{self.modality}.nii.gz')).get_fdata()
            label = nib.load(os.path.join(patient_path, f'{patient_id}_seg.nii.gz')).get_fdata()

            # Append all slices to the list
            for slice_idx in range(image.shape[2] // 2 - 20, image.shape[2] // 2 + 20):
                image_slice = image[:, :, slice_idx]
                label_slice = label[:, :, slice_idx]

                # Convert to torch tensor and add channel dimension for image
                image_tensor = torch.tensor(image_slice, dtype=torch.float32).unsqueeze(0)  # Add channel dimension
                # rgb_image_tensor = torch.cat((image_tensor, image_tensor, image_tensor), dim=0)
                label_tensor = torch.tensor(label_slice, dtype=torch.long)

                self.images.append(image_tensor)
                self.labels.append(label_tensor)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return image, label

    def __len__(self):
        return len(self.images)

In [4]:
data_dir = '/content/drive/MyDrive/BraTS_2018_Train_LGG' # put the directory of data
modalities = ['t1', 't1ce', 't2', 'flair']

In [5]:
t1_dataset = Custom2DBraTSDataset(data_dir=data_dir, modality='t1')
t2_dataset = Custom2DBraTSDataset(data_dir=data_dir, modality='t2')
t1ce_dataset = Custom2DBraTSDataset(data_dir=data_dir, modality='t1ce')
flair_dataset = Custom2DBraTSDataset(data_dir=data_dir, modality='flair')

In [7]:
from torch.utils.data import Dataset, DataLoader, ConcatDataset
# 모든 데이터셋을 하나로 합치기
combined_dataset = ConcatDataset([t1_dataset, t2_dataset, t1ce_dataset, flair_dataset])

# 데이터셋 길이 확인 및 DataLoader 설정
print(f"Combined dataset length: {len(combined_dataset)}")



Combined dataset length: 1120


In [8]:
from torch.utils.data import random_split

# 전체 데이터셋 길이
dataset_size = len(combined_dataset)

# 훈련 데이터셋과 검증 데이터셋 사이즈 결정 (예: 75% 훈련, 25% 검증)
train_size = int(0.75 * dataset_size)
val_size = dataset_size - train_size

# 데이터셋을 무작위로 훈련 및 검증 세트로 분할
train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size])

# 각 데이터셋에 대한 DataLoader 설정
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False)

# 확인용 출력
print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

Training dataset size: 840
Validation dataset size: 280


In [9]:
for images, labels in train_dataloader:
    print(images.shape, images)  # 출력: [batch_size, 1, D, H, W]
    print(labels.shape, labels)  # 출력: [batch_size, D, H, W]
    break

torch.Size([4, 1, 240, 240]) tensor([[[[-0.4231, -0.4231, -0.4231,  ..., -0.4231, -0.4231, -0.4231],
          [-0.4231, -0.4231, -0.4231,  ..., -0.4231, -0.4231, -0.4231],
          [-0.4231, -0.4231, -0.4231,  ..., -0.4231, -0.4231, -0.4231],
          ...,
          [-0.4231, -0.4231, -0.4231,  ..., -0.4231, -0.4231, -0.4231],
          [-0.4231, -0.4231, -0.4231,  ..., -0.4231, -0.4231, -0.4231],
          [-0.4231, -0.4231, -0.4231,  ..., -0.4231, -0.4231, -0.4231]]],


        [[[-0.4765, -0.4765, -0.4765,  ..., -0.4765, -0.4765, -0.4765],
          [-0.4765, -0.4765, -0.4765,  ..., -0.4765, -0.4765, -0.4765],
          [-0.4765, -0.4765, -0.4765,  ..., -0.4765, -0.4765, -0.4765],
          ...,
          [-0.4765, -0.4765, -0.4765,  ..., -0.4765, -0.4765, -0.4765],
          [-0.4765, -0.4765, -0.4765,  ..., -0.4765, -0.4765, -0.4765],
          [-0.4765, -0.4765, -0.4765,  ..., -0.4765, -0.4765, -0.4765]]],


        [[[-0.4424, -0.4424, -0.4424,  ..., -0.4424, -0.4424, -0.4424

In [10]:
import torch
import torch.nn as nn
import numpy as np
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class WC_Block(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(WC_Block, self).__init__()

        self.split_conv_x1_1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(15, 1)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
            )
        self.split_conv_x1_2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=(1, 15)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
            )
        self.split_conv_x2_1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(1, 15)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
            )
        self.split_conv_x2_2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=(15, 1)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
            )
        self.conv_sum = nn.Conv2d(2* out_channels, out_channels, 3, padding=1)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        split_conv_x1 = self.split_conv_x1_1(x)
        split_conv_x1 = self.split_conv_x1_2(split_conv_x1)
        split_conv_x2 = self.split_conv_x2_1(x)
        split_conv_x2 = self.split_conv_x2_2(split_conv_x2)
        x = torch.cat([split_conv_x1, split_conv_x2],dim=1)
        x = self.conv_sum(x)
        x = self.batch_norm(x)
        x = self.relu(x)

        return x

class BU_net(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(BU_net, self).__init__()
        self.conv1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.WC = WC_Block(512, 1024)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
        )

        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.iconv4 = DoubleConv(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.iconv3 = DoubleConv(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.iconv2 = DoubleConv(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.iconv1 = DoubleConv(128, 64)

        self.outconv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)

        WC_block = self.WC(pool4)
        bottleneck = self.bottleneck(WC_block)

        upconv4 = self.upconv4(bottleneck)
        upconv4 = F.interpolate(upconv4, size=conv4.shape[2:4])
        cat4 = torch.cat((upconv4, conv4), dim=1)
        iconv4 = self.iconv4(cat4)
        upconv3 = self.upconv3(iconv4)
        cat3 = torch.cat((upconv3, conv3), dim=1)
        iconv3 = self.iconv3(cat3)
        upconv2 = self.upconv2(iconv3)
        cat2 = torch.cat((upconv2, conv2), dim=1)
        iconv2 = self.iconv2(cat2)
        upconv1 = self.upconv1(iconv2)
        cat1 = torch.cat((upconv1, conv1), dim=1)
        iconv1 = self.iconv1(cat1)

        out = self.outconv(iconv1)
        out = self.softmax(out)
        return out

In [11]:
import torch
from torch.nn import functional as F

def compute_class_weight(target):
    n, H, W = target.size()
    class_weights = torch.zeros(n, H, W).to(target.device)
    max_value = int(target.max().item())
    for i in range(max_value + 1):
        mask = (target == i).float()
        class_weight = 1.0 / (mask.sum() + 1e-6)
        class_weights += mask * class_weight
    return class_weights

def Dice_Loss_Coefficient(pred, target, smooth=1.):
    n, c, H, W = pred.shape  # pred의 차원 정보를 가져옵니다.

    # target을 pred의 크기인 H와 W로 리사이징
    #target_resized = F.interpolate(target.float().unsqueeze(1), size=(H, W), mode='nearest').squeeze(1)
    #target_resized = target_resized.unsqueeze(1).expand(n, c, H, W)  # [n, c, H, W]로 확장

    # 클래스 가중치 계산을 위해 리사이징된 target 사용
    target = target.long()
    if target.dim() == 4:
          target = target.squeeze(1)
    target = target.expand(n, H, W)
    weights = compute_class_weight(target)
    #weights = compute_class_weight(target_resized[:, 0, :, :])  # 원래 target 대신 리사이징된 target 사용
    weights = weights.unsqueeze(1).expand(n, c, H, W)  # [n, c, H, W]로 확장
    target = target.unsqueeze(1).expand(n, c, H, W)
    intersection = (pred * target * weights).sum(dim=2).sum(dim=2)
    union = (weights * pred).sum(dim=2).sum(dim=2) + (weights * target).sum(dim=2).sum(dim=2)

    loss = (1 - ((2. * intersection + smooth) / (union + smooth)))

    return loss.mean()

class Weighted_Cross_Entropy_Loss(torch.nn.Module):

    def __init__(self):
        super(Weighted_Cross_Entropy_Loss, self).__init__()

    def forward(self, pred, target):
      n, c, H, W = pred.shape

    # 클래스 레이블 검증 (0부터 c-1까지)
      if (target.max() >= c) or (target.min() < 0):
          #print("target tensor contains out-of-range values")
          target.clamp_(0, c-1)  # 잘못된 값 조정
      if target.dim() == 4:
          target = target.squeeze(1)

      #target = F.interpolate(target.float().unsqueeze(1), size=(H, W), mode='nearest').unsqueeze(1)
      target = target.long()

      target = target.expand(n, H, W)
      weights = compute_class_weight(target)
      target = target.unsqueeze(1)
      logp = F.log_softmax(pred, dim=1)
      logp = torch.gather(logp, 1, target)
      weighted_logp = (logp * weights.unsqueeze(1)).view(n, -1)
      weighted_loss = weighted_logp.sum(1) / weights.view(n, -1).sum(1)
      weighted_loss = -weighted_loss.mean()

      return weighted_loss


class BU_Net_Loss(torch.nn.Module):
    def __init__(self):
        super(BU_Net_Loss, self).__init__()
        self.cross_entropy_loss = Weighted_Cross_Entropy_Loss()

    def forward(self, pred, target):
        wce_loss = self.cross_entropy_loss(pred, target)
        dice_loss = Dice_Loss_Coefficient(pred, target)
        total_loss = wce_loss + dice_loss
        return total_loss


In [12]:
from tqdm import tqdm

In [13]:
def train_model(trainloader, model, optimizer, device):
    model.train()
    loss_criterion = BU_Net_Loss()  # Instantiate the loss object here
    loss_criterion.to(device)       # Ensure the loss model is on the correct device
    for i, (inputs, labels) in tqdm(enumerate(trainloader), total=len(trainloader)):
        inputs = inputs.to(device)
        labels = labels.to(device=device, dtype=torch.int64)
        inputs = inputs.float()

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_criterion(outputs, labels)  # Correct usage of the loss function
        loss.backward()
        optimizer.step()

In [14]:

def accuracy_check(label, pred):
    ims = [label, pred]
    np_ims = []
    for item in ims:
        item = np.array(item)
        np_ims.append(item)
    compare = np.equal(np_ims[0], np_ims[1])
    accuracy = np.sum(compare)
    return accuracy / len(np_ims[0].flatten())

def accuracy_check_for_batch(labels, preds, batch_size):
    total_acc = 0
    for i in range(batch_size):
        total_acc += accuracy_check(labels[i], preds[i])
    return total_acc/batch_size

In [15]:
def get_loss_train(model, trainloader, loss, device):

    model.eval()
    total_acc = 0
    total_loss = 0
    for batch, (inputs, labels) in tqdm(enumerate(trainloader), total = len(trainloader)):
        with torch.no_grad():
            inputs = inputs.to(device)
            labels = labels.to(device = device, dtype = torch.long)
            inputs = inputs.float()


            outputs = model(inputs)
            loss = BU_Net_Loss()
            loss_val = loss(outputs, labels)
            outputs = np.transpose(outputs.cpu(), (0,2,3,1))
            preds = torch.argmax(outputs, dim=3).float()
            acc = accuracy_check_for_batch(labels.cpu(), preds.cpu(), inputs.size()[0])
            total_acc += acc
            total_loss += loss_val.cpu().item()
    return total_acc/(batch+1), total_loss/(batch+1)

In [16]:
def val_model(model, valloader, loss, device):

    total_val_loss = 0
    total_val_acc = 0
    n=0

    for batch, (inputs, labels) in tqdm(enumerate(valloader), total = len(valloader)):
        with torch.no_grad():

            inputs = inputs.to(device)
            labels = labels.to(device=device, dtype=torch.int64)

            outputs = model(inputs)
            loss = BU_Net_Loss()
            loss_value = loss(outputs, labels)

            outputs = np.transpose(outputs.cpu(), (0, 2, 3, 1))
            preds = torch.argmax(outputs, dim=3).float()

            acc = accuracy_check_for_batch(labels.cpu(), preds.cpu(), inputs.size()[0])
            total_val_acc += acc
            total_val_loss += loss_value.cpu().item()



    return total_val_acc/(batch+1), total_val_loss/(batch+1)

In [17]:
batch_size = 16
learning_rate = 0.01
momentum = 0.9
epochs = 5

# 모델 초기화
model = BU_net(1, 4)
print(model)

criterion = BU_Net_Loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
model = model.to(device)
history = {'train_loss':[], 'train_acc':[], 'val_loss':[], 'val_acc':[]}
val_loader = val_dataloader

BU_net(
  (conv1): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )


In [18]:
print("Training")
for epoch in range(epochs):
    train_model(train_dataloader, model, optimizer, device)
    train_acc, train_loss = get_loss_train(model, train_dataloader, BU_Net_Loss, device)
    print("epoch", epoch + 1, "train loss : ", train_loss, "train acc : ", train_acc)
    val_loader = val_dataloader
    val_acc, val_loss = val_model(model, val_loader, BU_Net_Loss, device)
    print("epoch", epoch + 1, "val loss : ", val_loss, "val acc : ", val_acc)

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'./{str(epoch)}.pth')

print('Finish Training')

Training


  0%|          | 1/210 [00:29<1:43:11, 29.63s/it]


KeyboardInterrupt: 