In [1]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import cohen_kappa_score
import torchvision.transforms as transforms
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ExponentialLR

from sklearn.model_selection import train_test_split
import monai
from PIL import Image
from monai.losses.dice import DiceLoss
from torchvision.transforms.functional import to_pil_image,affine
from monai.transforms import Rand2DElastic

MLflow support for Python 3.6 is deprecated and will be dropped in an upcoming release. At that point, existing Python 3.6 workflows that use MLflow will continue to work without modification, but Python 3.6 users will no longer get access to the latest MLflow features and bugfixes. We recommend that you upgrade to Python 3.7 or newer.


In [2]:
### 设置参数
images_file = '../GOALS2022-Train/Train/Image'  # 训练图像路径
gt_file = '../GOALS2022-Train/Train/Layer_Masks'
image_size = 800 # 输入图像统一尺寸
image_size2 = 1120
val_ratio = 0.3  # 训练/验证图像划分比例
batch_size = 4 # 批大小
iters = 10000 # 训练迭代次数
optimizer_type = 'adam' # 优化器, 可自行使用其他优化器，如SGD, RMSprop,...
num_workers = 8 # 数据加载处理器个数
init_lr = 1e-3 # 初始学习率

summary_dir = './logs'
torch.backends.cudnn.benchmark = True
print('cuda',torch.cuda.is_available())
print('gpu number',torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(torch.cuda.get_device_name(i))
summaryWriter = SummaryWriter(summary_dir)

cuda True
gpu number 1
NVIDIA RTX A6000


In [3]:
### 从数据文件夹中加载眼底图像，提取相应的金标准，生成训练样本
class OCTDataset(Dataset):
    def __init__(self, image_file, gt_path=None, filelists=None,  mode='train'):
        super(OCTDataset, self).__init__()
        self.mode = mode
        self.image_path = image_file
        image_idxs = os.listdir(self.image_path) # 0001.png,
        self.gt_path = gt_path
        self.file_list = [image_idxs[i] for i in range(len(image_idxs))]        
        if filelists is not None:
            self.file_list = [item for item in self.file_list if item in filelists] 
    
    def transform(self,img, mask):
        (d, t, sc, sh) = transforms.RandomAffine.get_params(degrees=(-20, 20), translate=(0.2, 0.2),
                                                            scale_ranges=(0.8, 1.2), shears=(-20, 20),
                                                            img_size=img.shape)
        img = affine(to_pil_image(img), angle=d, translate=t, scale=sc, shear=sh)
        mask = affine(to_pil_image(mask), angle=d, translate=t, scale=sc, shear=sh)

        return (np.array(img), np.array(mask))
   
    def __getitem__(self, idx):
        real_index = self.file_list[idx]
        img_path = os.path.join(self.image_path, real_index)
        img = cv2.imread(img_path,cv2.IMREAD_GRAYSCALE) 
        #img = img[:,:,np.newaxis]
        #print(img.shape)
        h,w = img.shape # (800, 1100, 3)     
        img = cv2.resize(img,(image_size2, image_size))
        #img = img[:,:,np.newaxis]
        #print(img.shape)
        
        if self.mode == 'train' or self.mode == 'val':
            gt_tmp_path = os.path.join(self.gt_path, real_index)
            gt_img = cv2.imread(gt_tmp_path,cv2.IMREAD_GRAYSCALE)

            ### 像素值为0的是RNFL(类别 0)，像素值为80的是GCIPL(类别 1)，像素值为160的是脉络膜(类别 2)，像素值为255的是其他（类别3）。
            gt_img[gt_img == 0] = 3
            gt_img[gt_img == 80] = 1
            gt_img[gt_img == 160] = 2
            gt_img[gt_img == 255] = 0
            
            gt_img = cv2.resize(gt_img,(image_size2, image_size),interpolation = cv2.INTER_NEAREST)
            # gt_img = gt_img[:,:,1]
            #print('gt shape', gt_img.shape)
            
        
        
        if self.mode == 'train':
            img, gt_img = self.transform(img, gt_img)
            #print('aug = ',img.shape)
            #print('aug = ',gt_img.shape)
            deform = Rand2DElastic(
                prob=0.4,
                spacing=(30, 30),
                magnitude_range=(5, 6),
                rotate_range=(np.pi / 4,),
                scale_range=(0.2, 0.2),
                translate_range=(100, 100),
                padding_mode="zeros",
                device="cpu")
            
            deform.set_random_state(seed=23)
            img = deform(img[np.newaxis,:,:], (image_size, image_size2), mode="bilinear")
            deform.set_random_state(seed=23)
            gt_img = deform(gt_img[np.newaxis,:,:], (image_size, image_size2), mode="nearest")
            #print('aug = ',img.shape)
            #print('aug = ',gt_img.shape)
            
            
        if self.mode == 'train' or self.mode == 'val':
            #gt_img = gt_img[:,:,np.newaxis]
            #gt_img = gt_img.transpose(2,0,1)
            #print('gt shape', gt_img.shape)
            if gt_img.shape[0] != 1:
                gt_img = gt_img[np.newaxis,:,:]
            gt_img = torch.from_numpy(gt_img)
            #print('gt shape', gt_img.shape)
            
        #print(img.shape)
        if img.shape[0] != 1: 
            img = img[np.newaxis,:,:]
        img = torch.from_numpy(img)
        
        
        # print(img.shape)
        # img = img_re.astype(np.float32)
        
        
        if self.mode == 'test':
            ### 在测试过程中，加载数据返回眼底图像，数据名称，原始图像的高度和宽度
            return img, real_index, h, w
        
        if self.mode == 'train' or self.mode == 'val':
            ###在训练过程中，加载数据返回眼底图像及其相应的金标准           
            return img, gt_img

    def __len__(self):
        return len(self.file_list)

In [4]:
model = monai.networks.nets.VNet(in_channels=1, out_channels=4,spatial_dims=2)
test_file = '../GOALS2022-Validation/GOALS2022-Validation/Image'  # 测试图像路径
best_model_path = "/home/liyihao/OVH/home/yihao/GOALS/src/Ensemble/v8_model623_1.0667138594488792.pth"
model.load_state_dict(torch.load(best_model_path))
model.cuda()
model.eval()

VNet(
  (in_tr): InputTransition(
    (act_function): ELU(alpha=1.0, inplace=True)
    (conv_block): Convolution(
      (conv): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (adn): ADN(
        (N): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (down_tr32): DownTransition(
    (down_conv): Conv2d(16, 32, kernel_size=(2, 2), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act_function1): ELU(alpha=1.0, inplace=True)
    (act_function2): ELU(alpha=1.0, inplace=True)
    (ops): Sequential(
      (0): LUConv(
        (act_function): ELU(alpha=1.0, inplace=True)
        (conv_block): Convolution(
          (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
          (adn): ADN(
            (N): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        

In [7]:
model2 = monai.networks.nets.VNet(in_channels=1, out_channels=4,spatial_dims=2)
#test_file = '../GOALS2022-Validation/GOALS2022-Validation/Image'  # 测试图像路径
best_model_path = "/home/liyihao/OVH/home/yihao/GOALS/src/Ensemble/v12_model805_1.0609809725631845.pth"
model2.load_state_dict(torch.load(best_model_path))
model2.cuda()
model2.eval()

VNet(
  (in_tr): InputTransition(
    (act_function): ELU(alpha=1.0, inplace=True)
    (conv_block): Convolution(
      (conv): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (adn): ADN(
        (N): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (down_tr32): DownTransition(
    (down_conv): Conv2d(16, 32, kernel_size=(2, 2), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act_function1): ELU(alpha=1.0, inplace=True)
    (act_function2): ELU(alpha=1.0, inplace=True)
    (ops): Sequential(
      (0): LUConv(
        (act_function): ELU(alpha=1.0, inplace=True)
        (conv_block): Convolution(
          (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
          (adn): ADN(
            (N): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        

In [8]:
model3 = monai.networks.nets.VNet(in_channels=1, out_channels=4,spatial_dims=2)
#test_file = '../GOALS2022-Validation/GOALS2022-Validation/Image'  # 测试图像路径
best_model_path = "/home/liyihao/OVH/home/yihao/GOALS/src/Ensemble/v10_model665_1.0598934205860422.pth"
model3.load_state_dict(torch.load(best_model_path))
model3.cuda()
model3.eval()

VNet(
  (in_tr): InputTransition(
    (act_function): ELU(alpha=1.0, inplace=True)
    (conv_block): Convolution(
      (conv): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (adn): ADN(
        (N): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (down_tr32): DownTransition(
    (down_conv): Conv2d(16, 32, kernel_size=(2, 2), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act_function1): ELU(alpha=1.0, inplace=True)
    (act_function2): ELU(alpha=1.0, inplace=True)
    (ops): Sequential(
      (0): LUConv(
        (act_function): ELU(alpha=1.0, inplace=True)
        (conv_block): Convolution(
          (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
          (adn): ADN(
            (N): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        

In [9]:
model4 = monai.networks.nets.VNet(in_channels=1, out_channels=4,spatial_dims=2)
#test_file = '../GOALS2022-Validation/GOALS2022-Validation/Image'  # 测试图像路径
best_model_path = "/home/liyihao/OVH/home/yihao/GOALS/src/Ensemble/v11_model891_1.0616440244186967.pth"
model4.load_state_dict(torch.load(best_model_path))
model4.cuda()
model4.eval()

VNet(
  (in_tr): InputTransition(
    (act_function): ELU(alpha=1.0, inplace=True)
    (conv_block): Convolution(
      (conv): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (adn): ADN(
        (N): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (down_tr32): DownTransition(
    (down_conv): Conv2d(16, 32, kernel_size=(2, 2), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act_function1): ELU(alpha=1.0, inplace=True)
    (act_function2): ELU(alpha=1.0, inplace=True)
    (ops): Sequential(
      (0): LUConv(
        (act_function): ELU(alpha=1.0, inplace=True)
        (conv_block): Convolution(
          (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
          (adn): ADN(
            (N): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        

In [11]:
model5 = monai.networks.nets.VNet(in_channels=1, out_channels=4,spatial_dims=2)
#test_file = '../GOALS2022-Validation/GOALS2022-Validation/Image'  # 测试图像路径
best_model_path = "/home/liyihao/OVH/home/yihao/GOALS/src/Ensemble/v6_model563_1.072693606666316.pth"
model5.load_state_dict(torch.load(best_model_path))
model5.cuda()
model5.eval()

VNet(
  (in_tr): InputTransition(
    (act_function): ELU(alpha=1.0, inplace=True)
    (conv_block): Convolution(
      (conv): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
      (adn): ADN(
        (N): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (down_tr32): DownTransition(
    (down_conv): Conv2d(16, 32, kernel_size=(2, 2), stride=(2, 2), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act_function1): ELU(alpha=1.0, inplace=True)
    (act_function2): ELU(alpha=1.0, inplace=True)
    (ops): Sequential(
      (0): LUConv(
        (act_function): ELU(alpha=1.0, inplace=True)
        (conv_block): Convolution(
          (conv): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
          (adn): ADN(
            (N): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        

In [12]:


test_dataset = OCTDataset(image_file = test_file, 
                            mode='test')
                            
img, real_index, h, w = test_dataset.__getitem__(0)
print(img.shape)
print(real_index, h, w)

for img, idx, h, w in test_dataset:
    # print(idx)
    img = img.unsqueeze(0).float()
    img = img.cuda()
    #print(img.shape)
    logits = model(img)
    logits2 = model2(img)
    logits3 = model3(img)
    logits4 = model4(img)
    logits5 = model5(img)
    logits = logits.detach().cpu().numpy()
    logits2 = logits2.detach().cpu().numpy()
    logits3 = logits3.detach().cpu().numpy()
    logits4 = logits4.detach().cpu().numpy()
    logits5 = logits5.detach().cpu().numpy()
    
    #logits_ensm = (logits + logits2 + logits3 + logits4 + logits5)/5
    logits_ensm = (0.3*logits + 0.2*logits5 + 0.2*logits2 + 0.15*logits3 + 0.15*logits4)
    
    #print(logits.shape)
    #print(logits)
    #print(abc)
    pred_img = logits_ensm.argmax(1)
    pred_gray = np.squeeze(pred_img, axis=0)
    pred_gray = pred_gray.astype('float32')
    #print(np.unique(pred_gray))
    # print(pred_gray.shape)
    pred_gray[pred_gray == 0.] = 255
    pred_gray[pred_gray == 1.] = 80
    pred_gray[pred_gray == 2.] = 160
    pred_gray[pred_gray == 3.] = 0
    
    #print(pred_gray)
    pred_ = cv2.resize(pred_gray, (w, h),interpolation = cv2.INTER_NEAREST)
    #print(pred_gray.shape)
    print(np.unique(pred_))
    #print(abc)
    filepath = './VNET_ensemble5/Layer_Segmentations/'
    folder = os.path.exists(filepath)
    if not folder:
        # 判断是否存在文件夹如果不存在则创建为文件夹
        os.makedirs(filepath)
    cv2.imwrite('./VNET_ensemble5/Layer_Segmentations/'+idx, pred_)


torch.Size([1, 800, 1120])
0200.png 800 1100
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80. 160. 255.]
[  0.  80