In [1]:
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split, KFold
import matplotlib.pyplot as plt
import torchvision.transforms as trans
# from dataset_functions.task1_dataset import OCTDataset
from dataset_functions.transforms import img_transforms

In [4]:
# Per-channel constants (they match the commonly used ImageNet RGB means), shaped
# (1, 3, 1, 1) so they would broadcast against an (N, C, H, W) batch.
a = torch.tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
print(a, a.shape, sep="\n")

tensor([[[[0.4850]],

         [[0.4560]],

         [[0.4060]]]])
torch.Size([1, 3, 1, 1])


In [2]:
# Data locations and split settings.
image_file = '../datasets/Train/Image_aug2'        # input images
gt_file = '../datasets/Train/Layer_Masks_aug2'     # ground-truth masks, same file names as the images
layer_file = '../datasets/Train/Layer_show_aug2'   # layer annotation images read by OCTDataset_layermask
# image_size = 256  # unified input image size (currently unused)
val_ratio = 0.2  # fraction of ALL images held out for validation (train_test_split test_size)
batch_size = 8

# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so that
# random_state=42 produces the same train/val split on every machine.
filelists = sorted(os.listdir(image_file))
# NOTE(review): file names look like "<id>_<k>.png" (e.g. 0035_1.png and 0035_2.png). If the
# "_<k>" variants are augmentations of the same source image, a per-file split leaks
# near-duplicates between train and val; consider splitting on "<id>" instead. TODO confirm.
train_filelists, val_filelists = train_test_split(filelists, test_size=val_ratio, random_state=42)
print("Total Nums: {}, train: {}, val: {}".format(len(filelists), len(train_filelists), len(val_filelists)))
print(val_filelists)

Total Nums: 200, train: 160, val: 40
['0062_2.png', '0025_2.png', '0065_2.png', '0009_2.png', '0045_2.png', '0071_1.png', '0057_2.png', '0037_2.png', '0064_2.png', '0088_2.png', '0066_2.png', '0035_1.png', '0072_2.png', '0049_1.png', '0030_1.png', '0012_2.png', '0097_2.png', '0052_2.png', '0090_1.png', '0067_2.png', '0018_1.png', '0074_1.png', '0068_2.png', '0019_1.png', '0085_2.png', '0063_1.png', '0035_2.png', '0027_2.png', '0080_1.png', '0059_2.png', '0019_2.png', '0036_1.png', '0032_1.png', '0093_2.png', '0033_1.png', '0028_1.png', '0055_1.png', '0066_1.png', '0013_1.png', '0060_2.png']


In [4]:
class OCTDataset(Dataset):
    """OCT images paired with integer segmentation masks.

    Items are ``(img.float(), mask.long())`` in 'train'/'val' mode and
    ``(img.float(), file_name, original_height, original_width)`` in 'test' mode.

    Args:
        image_transforms: callable invoked as ``image_transforms(image=..., mask=...)``
            in 'train'/'val' mode and ``image_transforms(image=...)`` in 'test' mode. It
            must return a dict with an ``'image'`` entry (and a ``'mask'`` entry when a
            mask was passed) that supports ``.float()`` / ``.long()``. If None, the raw
            NumPy arrays are returned as-is, and the ``.float()`` call then fails.
        image_dir: directory containing the input images.
        filelists: file names (relative to ``image_dir``) to load; defaults to every
            entry of ``image_dir``.
        gt_dir: directory of ground-truth masks stored under the same file names as the
            images. Required for 'train'/'val'; not read in 'test' mode.
        mode: 'train', 'val' or 'test'.
    """

    def __init__(self, image_transforms, image_dir, filelists=None, gt_dir=None, mode='train'):
        super().__init__()
        self.image_transforms = image_transforms
        self.image_dir = image_dir
        # `is None` rather than `== None`: `==` on a NumPy array of names compares
        # element-wise and makes the `if` raise for arrays with more than one entry.
        self.filelists = filelists if filelists is not None else os.listdir(self.image_dir)
        self.gt_dir = gt_dir
        self.mode = mode

    @staticmethod
    def _read_image(path):
        """Read ``path`` with cv2.imread, raising FileNotFoundError instead of returning None."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError('cv2.imread could not read image: {}'.format(path))
        return img

    def __getitem__(self, idx):
        img_index = self.filelists[idx]
        img = self._read_image(os.path.join(self.image_dir, img_index))
        h, w = img.shape[:2]  # original size, reported back in 'test' mode

        labelled = self.mode == 'train' or self.mode == 'val'
        if labelled:
            # Only the labelled modes need (and are given) ground truth; reading it
            # unconditionally crashed 'test' mode when gt_dir was None.
            gt_img = self._read_image(os.path.join(self.gt_dir, img_index))
            # cv2.imread yields a 3-channel array; channel 1 is used as the label plane
            # (presumably all channels are equal in these masks — TODO confirm).
            gt_img = gt_img[:, :, 1].astype(np.uint8)
            # Pixel value -> class id: 0 -> 0 (RNFL), 80 -> 1 (GCIPL), 160 -> 2 (choroid),
            # 255 -> 3 (other). The new ids never equal a value remapped later, so the
            # in-place order is safe.
            gt_img[gt_img == 80] = 1
            gt_img[gt_img == 160] = 2
            gt_img[gt_img == 255] = 3

        if self.image_transforms is not None:
            if labelled:
                out = self.image_transforms(image=img, mask=gt_img)
                img, gt_img = out['image'], out['mask']
            else:
                img = self.image_transforms(image=img)['image']

        if labelled:
            return img.float(), gt_img.long()

        if self.mode == 'test':
            return img.float(), img_index, h, w

    def __len__(self):
        return len(self.filelists)


# Build the train/val transform pipelines (only the validation one is exercised here).
img_train_transforms = img_transforms(applied_types='train')
img_val_transforms = img_transforms(New_size=(800, 800), applied_types='val')

# OCTDataset inputs: image_dir is the image folder, filelists the subset of file names to
# load (the train or val split), gt_dir the folder of segmentation ground-truth masks.
val_dataset = OCTDataset(
    image_transforms=img_val_transforms,
    image_dir=image_file,
    filelists=val_filelists,
    gt_dir=gt_file,
    mode='val',
)

# Smoke-test the first validation sample: class ids present, mask shape, mask values.
img, gt_img = val_dataset[0]
print(torch.unique(gt_img), gt_img.shape, gt_img, sep='\n')




tensor([0, 1, 2, 3])
torch.Size([800, 800])
tensor([[3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3],
        ...,
        [3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3],
        [3, 3, 3,  ..., 3, 3, 3]])


  "blur_limit and sigma_limit minimum value can not be both equal to 0. "


In [5]:
def get_layer_label(input, layer, colors=((0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0), (1, 0, 1))):
    """Build a one-hot (H, W, K) label map from colour-painted layer annotations.

    Per channel, every value of ``layer`` that equals the corresponding value of
    ``input`` is zeroed, so only the annotated differences remain. Each pixel of that
    masked array is then compared exactly with each colour in ``colors``.

    Args:
        input: (H, W, C) image array, scaled the same way as ``layer``.
        layer: (H, W, C) annotated image array, same shape as ``input``.
        colors: K colour tuples of length C to detect, in the arrays' channel order.
            The default reproduces the original five classes.

    Returns:
        float64 array of shape (H, W, K). Entry ``[..., k]`` is 1.0 where the masked
        layer pixel equals ``colors[k]`` exactly, else 0.0. (Previously H and W were
        hard-coded to 800.)
    """
    input = np.asarray(input)
    layer = np.asarray(layer)
    # Keep only the channels where the annotation differs from the raw image.
    masked_layer = layer * (layer != input)
    planes = [np.all(masked_layer == np.asarray(color), axis=-1) for color in colors]
    return np.stack(planes, axis=-1).astype(np.float64)

class OCTDataset_layermask(Dataset):
    """OCT dataset that additionally returns per-pixel layer-colour labels.

    Items in 'train'/'val' mode are ``(img, gt_img, layer_label)``:

    - ``img``: the transformed image divided by 255, as float32.
    - ``gt_img``: the transformed class-id mask, as int64.
    - ``layer_label``: the ``get_layer_label`` one-hot map (H x W x 5), as int64.

    Items in 'test' mode are ``(img, file_name, original_height, original_width)``.

    Args:
        image_transforms: callable invoked as ``image_transforms(image=..., mask=...)``
            in 'train'/'val' mode and ``image_transforms(image=...)`` in 'test' mode. It
            must return a dict with an ``'image'`` entry (and ``'mask'`` when a mask was
            passed). If None, the raw NumPy arrays are kept, and ``.float()`` then fails.
        image_path: directory containing the input images.
        layer_path: directory of layer-annotation images stored under the same file names.
        filelists: file names to load; defaults to every entry of ``image_path``.
        gt_path: directory of ground-truth masks stored under the same file names;
            only read in 'train'/'val' mode.
        mode: 'train', 'val' or 'test'.
    """

    def __init__(self, image_transforms, image_path, layer_path, filelists=None, gt_path=None, mode='train'):
        super().__init__()
        self.image_transforms = image_transforms
        self.image_path = image_path
        self.layer_path = layer_path
        # `is None` rather than `== None`: `==` on a NumPy array of names compares
        # element-wise and makes the `if` raise for arrays with more than one entry.
        self.filelists = filelists if filelists is not None else os.listdir(self.image_path)
        self.gt_path = gt_path
        self.mode = mode

    @staticmethod
    def _read_image(path):
        """Read ``path`` with cv2.imread, raising FileNotFoundError instead of returning None."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError('cv2.imread could not read image: {}'.format(path))
        return img

    def __getitem__(self, idx):
        img_index = self.filelists[idx]
        img = self._read_image(os.path.join(self.image_path, img_index))
        h, w = img.shape[:2]  # original size, reported back in 'test' mode

        labelled = self.mode == 'train' or self.mode == 'val'
        if labelled:
            gt_img = self._read_image(os.path.join(self.gt_path, img_index))
            layer_img = self._read_image(os.path.join(self.layer_path, img_index))
            # cv2.imread yields a 3-channel array; channel 0 is used as the label plane
            # (presumably all channels are equal in these masks — TODO confirm).
            gt_img = gt_img[:, :, 0].astype(np.uint8)
            # Pixel value -> class id: 0 -> 0 (RNFL), 80 -> 1 (GCIPL), 160 -> 2 (choroid),
            # 255 -> 3 (other).
            gt_img[gt_img == 80] = 1
            gt_img[gt_img == 160] = 2
            gt_img[gt_img == 255] = 3

            # NOTE(review): layer_label is built from the *untransformed* image/layer pair,
            # while img and gt_img pass through image_transforms below. If those transforms
            # resize or geometrically augment, layer_label no longer aligns with them.
            # TODO confirm.
            layer_label = get_layer_label(img / 255, layer_img / 255).astype(np.uint8)
            layer_label = torch.from_numpy(layer_label)

        if self.image_transforms is not None:
            if labelled:
                out = self.image_transforms(image=img, mask=gt_img)
                img, gt_img = out['image'], out['mask']
            else:
                img = self.image_transforms(image=img)['image']

        img = img / 255

        if labelled:
            return img.float(), gt_img.long(), layer_label.long()

        if self.mode == 'test':
            return img.float(), img_index, h, w

    def __len__(self):
        return len(self.filelists)


# Build the train/val transform pipelines (only the validation one is exercised here).
img_train_transforms = img_transforms(applied_types='train')
img_val_transforms = img_transforms(New_size=(800, 800), applied_types='val')

# Validation split with per-pixel layer-colour labels; smoke-test the first sample.
val_dataset = OCTDataset_layermask(
    image_transforms=img_val_transforms,
    image_path=image_file,
    layer_path=layer_file,
    filelists=val_filelists,
    gt_path=gt_file,
    mode='val',
)
img, gt_img, layer_label = val_dataset[0]
print(layer_label.shape, layer_label, sep='\n')

torch.Size([800, 800, 5])
tensor([[[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         ...,
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         ...,
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         ...,
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        ...,

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         ...,
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         ...,
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         ...,
        