In [2]:
!pip install einops
!pip install nibabel 
!pip install SimpleITK

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
import os
import random
import pandas as pd
import csv
import sys
import time
import logging
from tqdm import tqdm
import argparse
import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from torchvision import transforms
import torchvision.transforms.functional as ff
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import nibabel as nib
import SimpleITK as sitk
from einops import rearrange

In [4]:
device = torch.device("cuda"if torch.cuda.is_available() else "cpu")

In [5]:
! nvidia-smi

Tue Mar 28 13:15:55 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    24W / 300W |      2MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
device

device(type='cuda')

In [7]:
def read_files(root):
  train_path_list = os.listdir(root)[1:]
  train_path_list.sort()

  list_img = []
  list_target = []
  for i in range(len(train_path_list)):
    path = os.path.join(root, train_path_list[i])
    patient_path_list = os.listdir(path)
    patient_path_list.sort()

    for i in range(len(patient_path_list)):
      if 'frame' in patient_path_list[i] and 'gt' not in patient_path_list[i]:
        list_img.append(patient_path_list[i])
      elif 'frame' in patient_path_list[i] and 'gt' in patient_path_list[i]:
        list_target.append(patient_path_list[i])

  new_patient_path = train_path_list * 2
  new_patient_path.sort()
  return list_img, list_target, new_patient_path

In [9]:
list_img, list_target, new_patient_path = read_files('/content/drive/MyDrive/database/training')

with open('/content/drive/MyDrive/database/train_name/image.csv', 'w') as f:
  csv_writer = csv.writer(f)
  csv_writer.writerow(['path'])
  for i in range(len(list_img)):
    
    path = os.path.join('/content/drive/MyDrive/database/training', new_patient_path[i], list_img[i])
    csv_writer.writerow([path])
  f.close()

with open('/content/drive/MyDrive/database/train_name/target.csv', 'w') as f:
  csv_writer = csv.writer(f)
  csv_writer.writerow(['path'])
  for i in range(len(list_target)):
    
    path = os.path.join('/content/drive/MyDrive/database/training', new_patient_path[i], list_img[i])
    csv_writer.writerow([path])
  f.close()

In [10]:
list_img, list_target, new_patient_path = read_files('/content/drive/MyDrive/database/testing')

with open('/content/drive/MyDrive/database/test_name/img.csv', 'w') as f:
  csv_writer = csv.writer(f)
  csv_writer.writerow(['path'])
  for i in range(len(list_img)):
    
    path = os.path.join('/content/drive/MyDrive/database/testing', new_patient_path[i], list_img[i])
    csv_writer.writerow([path])
  f.close()

with open('/content/drive/MyDrive/database/test_name/target.csv', 'w') as f:
  csv_writer = csv.writer(f)
  csv_writer.writerow(['path'])
  for i in range(len(list_target)):
    
    path = os.path.join('/content/drive/MyDrive/database/testing', new_patient_path[i], list_img[i])
    csv_writer.writerow([path])
  f.close()

In [6]:
img_path_img_train = pd.read_csv('/content/drive/MyDrive/database/train_name/image.csv')
target_path_target_train = pd.read_csv('/content/drive/MyDrive/database/train_name/target.csv')

img_path_img_test = pd.read_csv('/content/drive/MyDrive/database/test_name/img.csv')
target_path_target_test = pd.read_csv('/content/drive/MyDrive/database/test_name/target.csv')

In [7]:
def read_files(img_path, target_path):
  img_num = []
  img_list = []
  target_list = []
  for i in range(len(img_path)):
    img = sitk.ReadImage(img_path['path'][i])
    img = sitk.GetArrayFromImage(img)
    img = img[:, None, :, :]
    for i in range(img.shape[0]):
      img_i = img[i]
      img_list.append(img_i)
    img_num.append(img.shape[0])
    
    target = sitk.ReadImage(target_path['path'][i])
    target = sitk.GetArrayFromImage(target)
    target = target[:, None, :, :]
    for i in range(target.shape[0]):
      target_i = target[i]
      target_list.append(target_i)
  
  return img_list, target_list, img_num

In [8]:
img_list_train, target_list_train, img_num_train = read_files(img_path_img_train, target_path_target_train)
img_list_test, target_list_test, img_num_test = read_files(img_path_img_train, target_path_target_train)

In [9]:
class ACDC(Dataset):

    def __init__(self, img_list, target_list, img_num, crop_size = 256, padding_size = 100):

      self.img_list = img_list
      self.target_list = target_list
      self.img_num = img_num
      self.crop_size = crop_size
      self.padding_size = padding_size
        
    def __len__(self):
      
      return sum(self.img_num)

    def __getitem__(self, idx):
      
      img = torch.tensor(self.img_list[idx], dtype=torch.float32)
      target = torch.tensor(self.target_list[idx], dtype=torch.float32)

      img, target = self.center_crop(img, target, self.crop_size, self.padding_size)
      
      return img, target
    
    def center_crop(self, img, target, crop_size, padding_size):
      img = F.pad(img, pad=(self.padding_size, self.padding_size, self.padding_size, self.padding_size), mode='constant', value=0)
      target = F.pad(target, pad=(self.padding_size, self.padding_size, self.padding_size, self.padding_size), mode='constant', value=0)

      img = ff.center_crop(img, crop_size)
      target = ff.center_crop(target, crop_size)
      

      return img, target

In [10]:
ACDC_train = ACDC(img_list_train, target_list_train, img_num_train)
ACDC_test = ACDC(img_list_test, target_list_test, img_num_test)

In [12]:
train_data = DataLoader(ACDC_train, batch_size=2, shuffle=True)
test_data = DataLoader(ACDC_test, batch_size=2, shuffle=True)

In [13]:
class conv_block(nn.Module):
    """
    Convolution Block 
    """
    def __init__(self, in_ch, out_ch):
        super(conv_block, self).__init__()
        
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True))

    def forward(self, x):

        x = self.conv(x)
        return x


class up_conv(nn.Module):
    """
    Up Convolution Block
    """
    def __init__(self, in_ch, out_ch):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x


class Attention_block(nn.Module):
    """
    Attention Block
    """

    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()

        self.W_g = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        out = x * psi
        return out


class AttU_Net(nn.Module):
    def __init__(self, img_ch=1, output_ch=4):
        super(AttU_Net, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(img_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        self.Up5 = up_conv(filters[4], filters[3])
        self.Att5 = Attention_block(F_g=filters[3], F_l=filters[3], F_int=filters[2])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Att4 = Attention_block(F_g=filters[2], F_l=filters[2], F_int=filters[1])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Att3 = Attention_block(F_g=filters[1], F_l=filters[1], F_int=filters[0])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Att2 = Attention_block(F_g=filters[0], F_l=filters[0], F_int=32)
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv2d(filters[0], output_ch, kernel_size=1, stride=1, padding=0)


    def forward(self, x):

        e1 = self.Conv1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)

        d5 = self.Up5(e5)

        x4 = self.Att5(g=d5, x=e4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=e3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=e2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=e1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        out = self.Conv(d2)

        return out


class Attention(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention, self).__init__()

        self.q = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.k = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, g, x):
      
      q = self.q(g)
      k = self.k(x)

      B, C, H, W = q.shape
      q = rearrange(q, 'b c h w -> b (h w) c')
      k = rearrange(k, 'b c h w -> b (h w) c')
      x = rearrange(x, 'b c h w -> b (h w) c')

      attn = torch.matmul(q, k.transpose(-1, -2)) * ((H * W) ** -0.5)
      attn = self.softmax(attn)

      out = torch.matmul(attn, x)
      out = rearrange(out, 'b (h w) c -> b c h w', h=H, w=W)

      return out


class ReAttU_Net(nn.Module):
    def __init__(self, img_ch=1, output_ch=4):
        super(ReAttU_Net, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(img_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        self.Up5 = up_conv(filters[4], filters[3])
        self.Att5 = Attention(F_g=filters[3], F_l=filters[3], F_int=filters[2])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Att4 = Attention_block(F_g=filters[2], F_l=filters[2], F_int=filters[1])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Att3 = Attention_block(F_g=filters[1], F_l=filters[1], F_int=filters[0])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Att2 = Attention_block(F_g=filters[0], F_l=filters[0], F_int=32)
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv2d(filters[0], output_ch, kernel_size=1, stride=1, padding=0)


    def forward(self, x):

        e1 = self.Conv1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)

        d5 = self.Up5(e5)

        x4 = self.Att5(g=d5, x=e4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=e3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=e2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=e1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        out = self.Conv(d2)

        return out


In [19]:
model = ReAttU_Net(img_ch=1, output_ch=10)
data = torch.rand(size=[1, 1, 224, 224])
print(model(data).shape)

torch.Size([1, 10, 224, 224])


In [20]:
model1 = AttU_Net(img_ch=1, output_ch=10)
data = torch.rand(size=[1, 1, 224, 224])
print(model1(data).shape)

torch.Size([1, 10, 224, 224])


In [14]:
def make_one_hot(input, num_classes):

  shape = np.array(input.shape)
  shape[1] = num_classes
  shape = tuple(shape) 
  result = torch.zeros(shape)
  result = result.scatter_(1, input, 1)

  return result

class DiceLoss(nn.Module):
  def __init__(self):
    super(DiceLoss, self).__init__()


  def forward(self, pred, target):
    target = make_one_hot(target, 4)
    num = target.size(1)
    smooth = 1

    m1 = pred.view(num, -1)
    m2 = target.view(num, -1)

    intersection = (m1 * m2)

    score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
    dice_loss = 1 - score.sum() / num

    return dice_loss

def dice_coeff(pred, target):
  target = make_one_hot(target, 4)
  smooth = 1.
  num = pred.size(1)
  m1 = pred.view(num, -1)
  m2 = target.view(num, -1)
  intersection = (m1 * m2).sum()
  return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)

In [18]:
model1 = AttU_Net()
model1 = model1.to(device)
model2 = ReAttU_Net().to(device)

In [25]:
def train(model):
  model.train()
  optimizer = optim.Adam(model.parameters(), lr=1e-4)
  loss_function = DiceLoss()

  train_loss = 0

  for epoch in range(100):
    for idx, (image, label) in enumerate(train_data):

      image = image.to(device)
      label = label.to(device)
      print(image.is_cuda,label.is_cuda)

      pred = model(image)
      print(pred.is_cuda)
      loss = loss_function(pred, label)
      
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      train_loss += loss.item()
    
    dice_coeff = dice_coeff(pred, label)

    print('Epoch %d/%d, Loss=%f, Dice_score=%f'.format(epoch, 100, train_loss / len(train_data), dice_coeff))

In [26]:
train(model1)

True True
True


RuntimeError: ignored

In [None]:

for idx, (image, label) in enumerate(train_data):
  print(image.shape)
  break


torch.Size([100, 1, 256, 256])


In [None]:
def get_logger(logdir):
  if not os.path.exists(logdir):
    os.makedirs(logdir)
  logname = f'run-{time.strftime('%Y-%m-%d-%H-%M')}.log'
  log_file = os.path.join(logdir, logname)

  logger = logging.getLogger('train')
  logger.setLevel(logging.INFO)
  formatter = logging.Formatter('%(asctime)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

  stream_handler = logging.StreamHandler(sys.stdout)
  stream_handler.setFormatter(formatter)
  logger.addHandler(stream_handler)


  file_handler = logging.FileHandler(log_file)
  file_handler.setFormatter(formatter)
  logger.addHandler(file_handler)

  return logger

In [None]:
def get_argparser():
  parser = argparse.ArgumentParser()

  # 日志
  parser.add_argument('--tensorboard events', type=str, default='/content/drive/MyDrive/events', help='path to tensorboard events')
  parser.add_argument('--logs', type=str, default='/content/drive/MyDrive/logs', help='path to train logs')

  # 数据
  parser.add_argument('--data_root', type=str, default='/content/drive/MyDrive/database/training', help='path to Dataset')
  parser.add_argument('--crop_size', type=int, default=154, help='Crop size of dataset')

  #模型
  parser.add_argument('--model', type=str, default='AttU_Net', choices=['AttU_Net', 'ReAttU_Net'], help='model name')

  # 训练参数
  parser.add_argument('--batchsize', type=int, default=100, help='batch_size (default: 100)')
  parser.add_argument('--total epoch', type=int, default=500, help='epoch number (default: 500)')
  parser.add_argument('--optimizer', type=str, default='adam', choice=['sgd', 'adam'], help='choose optimizer')
  parser.add_argument('lr', type=float, default=0.01, help='learning rate (default: 0.01)')
  parser.add_argument('lr_policy', type=str, default='poly', choices=['poly', 'step', 'multi_step', 'exponential', 'cosine', 'lambda', 'onecycle'], help='learning rate scheduler policy')
  parser.add_argument("--weight_decay", type=float, default=1e-4, help='weight decay (default: 1e-4)')
  parser.add_argument("--random_seed", type=int, default=1, help='random seed (default: 1)')

  #损失函数
  parser.add_argument("--loss_type", type=str, default='cross_entropy', help='loss type (default: False)')

  # 显卡
  parser.add_argument('--gpu_id', type=str, default='0', help='GPU ID')

  # 权重
  parser.add_argument('--ckpt', default='/content/drive/MyDrive/weights', type=str, help='restore from ckeckpoint')
  parser.add_argument('--resume', action='store_true', default=False)

  return parser

opts = get_argparser().parse_args()

# 日志
train_logger = get_logger(opts.logs)

writer = SummaryWriter(log_dir=opts.tensorboard_events)

def _log_stats_train(results, epoch):
  tag_value = {'train_loss': results['train_loss'],
          'dice_score': results['dice_score']}
  for tag, value in tag_value.items():
    writer.add_scalar(tag, value, epoch)

os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_logger.info(f'Conf | Device {device}')

def get_dataset(opts):
  root = os.path.join(opts.data_root)
  

In [None]:
if __name__ == '__main__':
  main()