## 准备数据集

**下载数据集并解压**

In [None]:
!wget https://dataset-bj.cdn.bcebos.com/%E5%8C%BB%E7%96%97%E6%AF%94%E8%B5%9B/GOALS2022-Train.zip
!wget https://dataset-bj.cdn.bcebos.com/%E5%8C%BB%E7%96%97%E6%AF%94%E8%B5%9B/GOALS2022-Validation.zip

!unzip -oq GOALS2022-Train.zip -d GOALS2022-Train
!unzip -oq GOALS2022-Validation.zip -d GOALS2022-Validation

In [None]:
!rm GOALS2022-Train.zip
!rm GOALS2022-Validation.zip

**图像另存为单通道灰度图，标签数值重映射至0-3**

In [None]:
# Per the section heading: saves images as single-channel grayscale and remaps label
# values to 0-3. preprocess.py itself is not shown here — presumably it edits the files
# under GOALS2022-Train in place (TODO confirm).
!python preprocess.py

**增加水平翻转的验证集样本**（将所列图像及其标签做水平翻转，以 `_` 前缀另存到训练目录）

In [None]:
import cv2
import os

# Horizontally flip the listed samples (image and layer mask together) and save the
# flipped copies next to the originals with a '_' filename prefix.
# NOTE: run this before the cls2 / cls3 cells so their dataset copies include the flips.
image_dir = 'GOALS2022-Train/Train/Image'
mask_dir = 'GOALS2022-Train/Train/Layer_Masks'

name_list = ['0076.png', '0063.png', '0054.png', '0075.png', '0065.png',
             '0066.png', '0011.png', '0090.png', '0068.png', '0003.png']
for file_name in name_list:
    image_path = os.path.join(image_dir, file_name)
    label_path = os.path.join(mask_dir, file_name)
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread reports a missing/unreadable file by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f'cannot read image: {image_path}')
    if label is None:
        raise FileNotFoundError(f'cannot read label: {label_path}')

    # flipCode=1 mirrors left-right; image and mask get the same flip so they stay aligned.
    image = cv2.flip(image, flipCode=1)
    label = cv2.flip(label, flipCode=1)

    flipped_image_path = os.path.join(image_dir, '_' + file_name)
    flipped_label_path = os.path.join(mask_dir, '_' + file_name)
    # cv2.imwrite returns False instead of raising when it cannot write.
    if not cv2.imwrite(flipped_image_path, image):
        raise OSError(f'failed to write image: {flipped_image_path}')
    if not cv2.imwrite(flipped_label_path, label):
        raise OSError(f'failed to write label: {flipped_label_path}')

**生成脉络膜区域（cls 3）**

In [None]:
import cv2
import numpy as np
import glob
import shutil
from tqdm import tqdm

# Build a binary choroid (cls 3) dataset: a copy of GOALS2022-Train whose layer masks
# are 1 where the label is 3 and 0 elsewhere.
cls3_src, cls3_dst, cls3_target = 'GOALS2022-Train', 'GOALS2022-Train_cls3', 3

# Always start from a fresh copy. `cp -r src dst` into an existing dst nests the source
# inside it and keeps the already-binarized masks, so a re-run would map every pixel
# to 0 (no pixel equals 3 any more).
shutil.rmtree(cls3_dst, ignore_errors=True)
shutil.copytree(cls3_src, cls3_dst)

mask_list = sorted(glob.glob(f'{cls3_dst}/Train/Layer_Masks/*.png'))
if not mask_list:
    raise FileNotFoundError(f'no masks found under {cls3_dst}/Train/Layer_Masks')
for path in tqdm(mask_list):
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f'cannot read mask: {path}')
    binary = (mask == cls3_target).astype(np.uint8)
    if not cv2.imwrite(path, binary):
        raise OSError(f'failed to write mask: {path}')

**生成GCIPL区域（cls 2）**

In [None]:
import cv2
import numpy as np
import glob
import shutil
from tqdm import tqdm

# Build a binary GCIPL (cls 2) dataset: a copy of GOALS2022-Train whose layer masks
# are 1 where the label is 2 and 0 elsewhere.
cls2_src, cls2_dst, cls2_target = 'GOALS2022-Train', 'GOALS2022-Train_cls2', 2

# Always start from a fresh copy. `cp -r src dst` into an existing dst nests the source
# inside it and keeps the already-binarized masks, so a re-run would map every pixel
# to 0 (no pixel equals 2 any more).
shutil.rmtree(cls2_dst, ignore_errors=True)
shutil.copytree(cls2_src, cls2_dst)

mask_list = sorted(glob.glob(f'{cls2_dst}/Train/Layer_Masks/*.png'))
if not mask_list:
    raise FileNotFoundError(f'no masks found under {cls2_dst}/Train/Layer_Masks')
for path in tqdm(mask_list):
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f'cannot read mask: {path}')
    binary = (mask == cls2_target).astype(np.uint8)
    if not cv2.imwrite(path, binary):
        raise OSError(f'failed to write mask: {path}')

**验证标签合法性**

In [None]:
import cv2
import numpy as np

# Spot-check one sample in each dataset: the remapped multi-class masks must only
# contain 0-3, and the binary cls2 / cls3 masks must only contain 0/1.
label_checks = [
    ('GOALS2022-Train/Train/Layer_Masks/0003.png', {0, 1, 2, 3}),
    ('GOALS2022-Train/Train/Layer_Masks/_0003.png', {0, 1, 2, 3}),
    ('GOALS2022-Train_cls2/Train/Layer_Masks/0003.png', {0, 1}),
    ('GOALS2022-Train_cls3/Train/Layer_Masks/0003.png', {0, 1}),
]
for label_path, allowed_values in label_checks:
    # Read single-channel. cv2.imread returns None for a missing file, and
    # np.unique(None) would just print [None] instead of failing.
    label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
    if label is None:
        raise FileNotFoundError(f'cannot read label: {label_path}')
    values = np.unique(label)
    print(label_path, values)
    unexpected = set(values.tolist()) - allowed_values
    assert not unexpected, f'{label_path}: unexpected label values {sorted(unexpected)}'

## 图像分类

In [None]:
# PaddleX 2.1.0, installed right before the classification training/inference scripts.
# NOTE(review): the segmentation section below installs paddlex==1.3.11 into what is
# presumably the same environment, replacing this version — re-running the
# classification cells after that would need 2.1.0 reinstalled.
!pip install paddlex==2.1.0

**模型训练**

In [None]:
# Train the image classification model (train_cls.py is not shown here).
!python train_cls.py

**模型预测**

In [None]:
# Run classification inference (infer_cls.py is not shown here).
!python infer_cls.py

## 图像分割

### 训练阶段

**统计均值/标准差（PaddleX 数据集分析，hold-out 训练划分）**

In [None]:
# PaddleX 1.3.11 for the segmentation section; the next cell imports it in-kernel and
# uses `pdx.datasets.analysis.Seg`.
# NOTE(review): this replaces the paddlex==2.1.0 installed above if both land in the same
# environment. `%pip` would target the kernel's environment (which the in-kernel import
# relies on), while `!pip` targets whichever pip is first on PATH — confirm they match.
!pip install paddlex==1.3.11

In [None]:
import paddlex as pdx

# Dataset analysis (PaddleX 1.x) over the hold-out segmentation training split.
seg_analysis_kwargs = dict(
    data_dir='GOALS2022-Train/Train',
    file_list='split_lists/seg_holdout/train.txt',
    label_list='split_lists/seg_holdout/labels.txt',
)
train_analysis = pdx.datasets.analysis.Seg(**seg_analysis_kwargs)
train_analysis.analysis()

In [None]:
# Train the multi-class segmentation model (script not shown; the name suggests it
# segments classes 1-3 of the remapped GOALS2022-Train masks — TODO confirm).
!python train_seg_cls123.py

In [None]:
# Train the choroid (cls 3) segmentation model (script not shown; presumably it uses the
# binary GOALS2022-Train_cls3 dataset built earlier — TODO confirm).
!python train_seg_cls3.py

**统计均值/标准差（GCIPL 数据集，分别按整图与仅标注区域统计）**

In [None]:
import cv2
import glob
import os
import numpy as np


def get_file_list(dataset_root, train_file_path):
    """Read an image/label file list and resolve its paths against ``dataset_root``.

    Each non-blank line of ``train_file_path`` holds an image path and a label path
    separated by whitespace, both relative to ``dataset_root``. Extra columns are
    ignored; blank lines are skipped.

    Args:
        dataset_root: directory the listed paths are relative to.
        train_file_path: path of the list file.

    Returns:
        A list of ``[image_path, label_path]`` pairs in file order.

    Raises:
        ValueError: if a non-blank line has fewer than two whitespace-separated fields.
    """
    file_list = []
    with open(train_file_path, mode='r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            items = line.split()
            if not items:
                # Tolerate blank lines instead of failing with an IndexError.
                continue
            if len(items) < 2:
                raise ValueError(
                    f'{train_file_path}:{line_no}: expected "<image> <label>", '
                    f'got {line.strip()!r}')
            image_path = os.path.join(dataset_root, items[0])
            label_path = os.path.join(dataset_root, items[1])
            file_list.append([image_path, label_path])
    return file_list


def get_mean_std(dataset_root, train_file_path, only_gt=False):
    """Compute the grayscale intensity mean and std of a dataset, scaled by its value range.

    Statistics are pooled over every selected pixel of every listed image. The old
    approach averaged per-image means and stds, which has two problems:
    - It drops the between-image variance.
    - With ``only_gt``, it weights images equally however many foreground pixels they have.

    Args:
        dataset_root: directory the paths in ``train_file_path`` are relative to.
        train_file_path: file list with one ``<image> <label>`` pair per line.
        only_gt: if True, only pixels whose label is non-zero contribute to the
            mean/std.

    Returns:
        ``(mean, std)`` as Python floats (population std), each divided by
        ``max_val - min_val``. The range is taken over the whole images, starting
        from ``max_val=0, min_val=255``. When ``only_gt`` is True, unlabelled pixels
        count as 0 for the range, so ``min_val`` is 0 whenever any image has background.

    Raises:
        ValueError: if the list is empty, no pixel is selected, or the range is zero.
        FileNotFoundError: if an image or label file cannot be read.
    """
    file_list = get_file_list(dataset_root, train_file_path)
    if not file_list:
        raise ValueError(f'no samples listed in {train_file_path}')

    max_val, min_val = 0, 255
    pixel_sum, pixel_sq_sum, pixel_count = 0.0, 0.0, 0

    for image_path, label_path in file_list:
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f'cannot read image: {image_path}')

        if only_gt:
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if label is None:
                raise FileNotFoundError(f'cannot read label: {label_path}')
            foreground = label != 0
            selected = image[foreground]
            # Range over the image with background zeroed (built without mutating `image`).
            range_source = np.where(foreground, image, 0)
        else:
            selected = image.ravel()
            range_source = image

        selected = selected.astype(np.float64)
        pixel_sum += selected.sum()
        pixel_sq_sum += np.square(selected).sum()
        pixel_count += selected.size
        max_val = max(max_val, int(range_source.max()))
        min_val = min(min_val, int(range_source.min()))

    if pixel_count == 0:
        raise ValueError('no pixels selected (no non-zero label pixels in any sample)')
    value_range = max_val - min_val
    if value_range <= 0:
        raise ValueError(f'degenerate intensity range: min={min_val}, max={max_val}')

    mean = pixel_sum / pixel_count
    # E[x^2] - E[x]^2 can dip slightly below zero from rounding; clamp before sqrt.
    variance = max(pixel_sq_sum / pixel_count - mean * mean, 0.0)
    std = variance ** 0.5

    return float(mean / value_range), float(std / value_range)


# Print the GCIPL (cls 2) dataset statistics twice: first over whole images, then
# restricted to labelled (non-zero) pixels.
for use_gt_only in (False, True):
    print(get_mean_std(
        dataset_root='GOALS2022-Train_cls2/Train',
        train_file_path='split_lists/seg_cv/full.txt',
        only_gt=use_gt_only))

In [None]:
# Train the GCIPL (cls 2) segmentation model (script not shown; presumably it uses the
# binary GOALS2022-Train_cls2 dataset and the mean/std printed above — TODO confirm).
!python train_seg_cls2.py

### 预测阶段

In [2]:
# Run inference with each segmentation model, then merge their outputs. The scripts are
# not shown here; presumably infer_seg_merge.py combines the cls123 / cls2 / cls3
# predictions and writes the final masks under results/ (zipped below) — TODO confirm.
!python infer_seg_cls123.py
!python infer_seg_cls2.py
!python infer_seg_cls3.py
!python infer_seg_merge.py

**结果文件打包**

In [None]:
!zip -rq results.zip results