In [3]:
import sys
import os
import logging
import re
import xlrd
from time import time as ttime

from PIL import Image

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np

from PhotonDataset import transform, PhotonDataset

# Send log records to the notebook's stdout at DEBUG level, but keep
# Pillow's own logger limited to errors.
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
logging.getLogger("PIL").setLevel(logging.ERROR)
# The logging module itself serves as the logger: logger.info(...) goes to the root logger.
logger = logging

In [4]:
# Root of the training data; every project lives in BASE_PATH/<project>/
BASE_PATH = "/mnt/ossdata/"

# Marker looked for in upper-cased sheet names to detect a combined sheet
# that is parsed by processAllInspectSheet
ALL = "ALL"

# Project directory names to leave out of processing
EXCLUDE_PROJS = []

# Week codes in annotation file names (W0..W4) -> day folder names (D0, D7, ..., D28)
WEEK2DAY = {f"W{week}": f"D{week * 7}" for week in range(5)}

# Inspection type name -> index of that inspection in the per-frame label vectors
INSPECT2ID = {
    "Stratum_corneum": 0,
    "DEJunction": 1,
    "ELCOR": 2,
}

# Reverse lookup: label index -> inspection type name
IDS2INSPECT = {inspectId: inspectName for inspectName, inspectId in INSPECT2ID.items()}

# Side length (pixels) every frame is resized to
IMAGE_SIZE = 224

# Number of frames around an annotated start/end frame that make up its key-frame window
PRECISION_WINDOWS = 3

# Number of picInfo entries per saved DataLoader chunk
DATALOAD_MARGIN = 200

# train / validation / test fractions passed to random_split
DATASPLIT_RATIO = [0.8, 0.1, 0.1]

In [5]:
# Derive the directory that holds the .tif stacks from an image path recorded in an Excel sheet.
# The Excel paths point to the merged image, but we need the separate AF and SHG files in that same directory.
def extractTIFPath(imgUrl, projName, personId, day):
    """Map an image path from an Excel sheet to the local directory holding its .tif stacks.

    ``imgUrl`` is a backslash-separated (Windows-style) path ending in a .tif
    file. The path components from ``personId`` up to, but excluding, the file
    name are appended under
    ``BASE_PATH/<projName>/<projName>-双光子原图 (two-photon raw images)/<day>/原始 (raw)/``.

    Args:
        imgUrl: path string read from the Excel sheet.
        projName: project directory name (e.g. "C230534").
        personId: subject id; must appear as a whole component of ``imgUrl``.
        day: day folder name (e.g. "D0").

    Returns:
        ``(path, fileKey)``, where fileKey is
        "<projName>_<personId>_<day>_<parent directory of the .tif>".
        Returns ``(None, None)`` (after logging) if ``imgUrl`` is not a .tif
        path or does not contain ``personId``.
    """
    subPaths = imgUrl.split("\\")
    if not subPaths[-1].endswith(".tif"):
        logger.error(f"@ERROR:extractTIFPath: It is not tif file: {projName}, {personId}, {day}, {imgUrl}")
        return None, None
    try:
        personIndex = subPaths.index(personId)
    except ValueError:
        # list.index raises ValueError when personId is not a path component.
        # `except Error` referenced an undefined name and raised NameError instead.
        logger.error(f"@ERROR:extractTIFPath:imgUrl is invalid: {projName}, {personId}, {day}, {imgUrl}")
        return None, None
    path = BASE_PATH + projName + "/" + projName + "-双光子原图/" + day + "/原始/" + "/".join(subPaths[personIndex:-1])
    fileKey = projName + "_" + personId + "_" + day + "_" + subPaths[-2]
    return path, fileKey

# Retrieve the paths of the AF and SHG files
def extractAFAndSHGFile(dir):
    """List the AF and SHG colour .tif stacks inside a directory.

    Args:
        dir: directory to scan. The name shadows the builtin but is kept for
            backward compatibility with keyword callers.

    Returns:
        Sorted list of full paths of the files whose name contains "AF_Color"
        or "SHG_Color" and ends with ".tif". Returns None (after logging) if
        ``dir`` does not exist or is not a directory.
    """
    try:
        names = os.listdir(dir)
    except (FileNotFoundError, NotADirectoryError):
        logger.error(f"@ERROR:extractAFAndSHGFile: dir does not exists: {dir}")
        return None
    # Sorted so the result does not depend on the filesystem's listing order.
    return [
        os.path.join(dir, name)
        for name in sorted(names)
        if ("AF_Color" in name or "SHG_Color" in name) and name.endswith(".tif")
    ]

['/mnt/ossdata//C230534/C230534-双光子原图/D0/原始/RD0002/三维快速扫描/095345/ZStackFast_AF_Color_20231120_095624257.tif', '/mnt/ossdata//C230534/C230534-双光子原图/D0/原始/RD0002/三维快速扫描/095345/ZStackFast_SHG_Color_20231120_095624257.tif']


In [None]:
def processSingleInspectSheet(picInfo, table, inspectId, projName, personId, day):
    """Record frame ranges from a sheet that holds a single inspection type.

    Row 0 is the header. The start-frame, end-frame and image-path columns are
    the header cells containing "起始帧" (start frame), "结束帧" (end frame)
    and "地址" (address/path). Each later row with integer frame cells sets
    ``picInfo[fileKey][inspectId] = (frame_start, frame_end)``, overwriting an
    earlier row for the same file. The row also records
    ``picInfo[fileKey]["path"]`` the first time that file is seen.

    Args:
        picInfo: dict fileKey -> {inspectId: (start, end), ..., "path": dir},
            updated in place.
        table: xlrd sheet (uses nrows, row_len and cell_value).
        inspectId: value from INSPECT2ID for this sheet.
        projName, personId, day: forwarded to extractTIFPath.
    """
    if table.nrows <= 1:
        # header only, or an empty sheet
        return

    frame_start_index = -1
    frame_end_index = -1
    image_url_index = -1
    for i in range(table.row_len(0)):
        # str(): a numeric header cell would make the `in` tests raise TypeError
        value = str(table.cell_value(0, i))
        if "起始帧" in value:
            frame_start_index = i
        if "结束帧" in value:
            frame_end_index = i
        if "地址" in value:
            image_url_index = i

    if -1 in (frame_start_index, frame_end_index, image_url_index):
        # a required column is missing, so nothing can be extracted
        return

    for i in range(1, table.nrows):
        try:
            frame_start = int(table.cell_value(i, frame_start_index))
            frame_end = int(table.cell_value(i, frame_end_index))
            # str(): a numeric path cell would otherwise crash extractTIFPath's .split
            image_url = str(table.cell_value(i, image_url_index))
        except (ValueError, TypeError, IndexError):
            # blank or non-numeric frame cell: not a data row
            continue
        urlPath, fileKey = extractTIFPath(image_url, projName, personId, day)
        if urlPath is None:
            # Log the offending source url; urlPath itself is always None here.
            logger.error(f"@szh:processSingleInspectSheet:invalid url path: {image_url}")
            continue
        entry = picInfo.setdefault(fileKey, {})
        entry[inspectId] = (frame_start, frame_end)
        if "path" not in entry:
            entry["path"] = urlPath
        elif urlPath != entry["path"]:
            logger.error(f"@ERROR:processSingleInspectSheet: tifPath duplicate: {projName}, {personId}, {day}, {fileKey}, {urlPath}, {entry['path']}")

def processAllInspectSheet(picInfo, table, projName, personId, day):
    """Record frame ranges from a combined ("ALL") sheet holding several inspection types.

    The sheet is scanned row by row as a small state machine:

    * a cell containing an INSPECT2ID name (case-insensitive) selects that
      inspection and forgets the column layout seen so far;
    * text cells containing "编号" (number), "起始帧" (start frame),
      "结束帧" (end frame) or "地址" (address/path) set the column of that field;
    * in other rows, a non-integer number/start/end cell or a whitespace-only
      path cell ends the row early. A row that yielded an inspection, both
      frames and a path is recorded.

    A repeated (file, inspection) pair is widened to (min start, max end).

    Args:
        picInfo: dict fileKey -> {inspectId: (start, end), ..., "path": dir},
            updated in place.
        table: xlrd sheet (uses nrows, row_len and cell_value).
        projName, personId, day: forwarded to extractTIFPath.
    """
    inspectId = -1
    num_index = -1
    frame_start_index = -1
    frame_end_index = -1
    image_url_index = -1
    # Upper-case the names once. INSPECT2ID keys are mixed case, so comparing
    # them unchanged against an upper-cased cell could only ever match "ELCOR".
    upperInspectNames = [(name.upper(), inspect) for name, inspect in INSPECT2ID.items()]

    for rowIndex in range(table.nrows):
        frame_start = -1
        frame_end = -1
        image_url = ''
        for colIndex in range(table.row_len(rowIndex)):
            cellValue = table.cell_value(rowIndex, colIndex)
            if cellValue == "":
                continue

            upperCell = str(cellValue).upper()
            for upperName, inspect in upperInspectNames:
                if upperName in upperCell:
                    # a new section starts: its header row must be seen again
                    inspectId = inspect
                    num_index = -1
                    frame_start_index = -1
                    frame_end_index = -1
                    image_url_index = -1
                    break

            if isinstance(cellValue, str):
                if "编号" in cellValue:
                    num_index = colIndex
                    continue
                if "起始帧" in cellValue:
                    frame_start_index = colIndex
                    continue
                if "结束帧" in cellValue:
                    frame_end_index = colIndex
                    continue
                if "地址" in cellValue:
                    image_url_index = colIndex
                    continue

            if colIndex == num_index:
                try:
                    int(cellValue)
                except (ValueError, TypeError):
                    # non-numeric row number: not a data row
                    break

            if colIndex == frame_start_index:
                try:
                    frame_start = int(cellValue)
                except (ValueError, TypeError):
                    break

            if colIndex == frame_end_index:
                try:
                    frame_end = int(cellValue)
                except (ValueError, TypeError):
                    break

            if colIndex == image_url_index:
                # str(): a numeric cell has no .strip()
                urlText = str(cellValue).strip()
                if not urlText:
                    break
                image_url = urlText

        if inspectId >= 0 and frame_start >= 0 and frame_end >= 0 and image_url:
            urlPath, fileKey = extractTIFPath(image_url, projName, personId, day)
            if urlPath is None:
                # extractTIFPath already logged the reason; don't create a picInfo[None] entry
                continue
            entry = picInfo.setdefault(fileKey, {})
            if inspectId in entry:
                prevStart, prevEnd = entry[inspectId]
                frame_start = min(frame_start, prevStart)
                frame_end = max(frame_end, prevEnd)
            entry[inspectId] = (frame_start, frame_end)
            if "path" not in entry:
                entry["path"] = urlPath
            elif urlPath != entry["path"]:
                logger.error(f"@ERROR:processAllInspectSheet: tifPath duplicate: {projName}, {personId}, {day}, {fileKey}, {urlPath}, {entry['path']}")

def processExcel(picInfo, path, fileName, projName):
    """Parse one annotation workbook and merge its frame ranges into picInfo.

    The subject id and day come from the file name. The first form is
    "<personId>-<day>.<ext>" (split on "-" and "."). Failing that, the name is
    matched against "<letters><digits><letter><digits>[letters].xls", such as
    "RD0002W1.xls". Day codes found in WEEK2DAY (W0..W4) are converted to
    D0..D28.

    Sheets are dispatched as follows:
    * upper-cased name contains ALL: processAllInspectSheet;
    * name contains an INSPECT2ID key (case-insensitive): processSingleInspectSheet;
    * any other sheet is ignored.

    Args:
        picInfo: dict updated in place (fileKey -> {inspectId: (start, end), "path": dir}).
        path: directory containing the workbook.
        fileName: workbook file name.
        projName: project directory name, used to build image paths.
    """
    logger.info(f"@szh:processExcel {path} {fileName} {projName}")
    nameParts = re.split(r"-|\.", fileName)
    if len(nameParts) == 3:
        personId, day, _ = nameParts
    else:
        match = re.search(r'([a-zA-Z]+\d+)([a-zA-Z]\d+[a-zA-Z]*)\.xls', fileName)
        if match is None:
            # An unparseable file name is logged and skipped instead of crashing on None[1].
            logger.error(f"@ERROR:processExcel: cannot parse person id and day from file name: {path}/{fileName}")
            return
        personId, day = match[1], match[2]

    day = WEEK2DAY.get(day, day)

    excel = xlrd.open_workbook(path + "/" + fileName)
    for sheetName in excel.sheet_names():
        upperSheetName = sheetName.upper()
        if ALL in upperSheetName:
            processAllInspectSheet(picInfo, excel.sheet_by_name(sheetName), projName, personId, day)
            continue

        # Upper-case both sides. INSPECT2ID keys are mixed case, so matching them
        # against the upper-cased sheet name could only ever find "ELCOR".
        inspectId = next(
            (inspect for name, inspect in INSPECT2ID.items() if name.upper() in upperSheetName),
            -1,
        )
        if inspectId == -1:
            continue

        processSingleInspectSheet(picInfo, excel.sheet_by_name(sheetName), inspectId, projName, personId, day)

In [None]:
# Walk every project under BASE_PATH and collect frame annotations into
# picInfo[fileKey] = {inspectId: (start_frame, end_frame), ..., "path": <tif directory>}.
# Directory listings are sorted so that picInfo's insertion order is the same
# on every run. That order determines the DATALOAD_MARGIN-sized chunks built from it later.
proj_dirs = sorted(os.listdir(BASE_PATH))
picInfo = {}
for proj_dir in proj_dirs:
    if not os.path.isdir(BASE_PATH + proj_dir):
        logger.error(f"@ERROR:skip root project file: {proj_dir}")
        continue

    if not proj_dir.startswith("C"):
        logger.error(f"@ERROR:skip root project directory: {proj_dir}")
        continue

    if proj_dir in EXCLUDE_PROJS:
        logger.info(f"@szh:skip exclude project {proj_dir}")
        continue

    logger.info(f"@szh: start to process proj: {proj_dir}")
    # <proj>-双光子数据 ("two-photon data") holds the annotation workbooks
    data_dir = BASE_PATH + proj_dir + "/" + proj_dir + "-双光子数据"
    if not os.path.isdir(data_dir):
        # skip this project instead of aborting the whole scan with FileNotFoundError
        logger.error(f"@ERROR:skip project without data directory: {data_dir}")
        continue

    for excelFileName in sorted(os.listdir(data_dir)):
        processExcel(picInfo, data_dir, excelFileName, proj_dir)

# Show the entry count and a short preview rather than dumping every entry.
print(len(picInfo))
for previewKey in list(picInfo)[:3]:
    print(previewKey, picInfo[previewKey])

In [10]:
def _frameLabels(labels, readPos, window):
    """Compute the per-inspection labels for the frame at 1-based position ``readPos``.

    Args:
        labels: one (start, end) frame pair per inspection id; (-1, -1) means
            the inspection is not annotated for this stack.
        readPos: 1-based frame position within the stack.
        window: number of frames around an annotated boundary that belong to
            its key-frame window (PRECISION_WINDOWS).

    Returns:
        ``(labelGroup, winLabelGroup)``, two FloatTensors of length len(labels):
        - labelGroup[i]: 1 if start <= readPos <= end, -1 if not annotated, else 0.
        - winLabelGroup[i]:
          * 1 if readPos is within ``window`` frames inside a boundary
            ([start, start+window] or [end-window, end]);
          * 0 if it is within ``window`` frames outside one
            ([start-window, start) or (end, end+window]);
          * -1 otherwise, and also when not annotated.
    """
    labelGroup, winLabelGroup = [], []
    for x, y in labels:
        if x == -1 and y == -1:
            winLabelGroup.append(-1)
        elif x <= readPos <= x + window or y - window <= readPos <= y:
            winLabelGroup.append(1)
        elif x - window <= readPos < x or y < readPos <= y + window:
            winLabelGroup.append(0)
        else:
            winLabelGroup.append(-1)

        if x <= readPos <= y:
            labelGroup.append(1)
        elif x == -1 and y == -1:
            labelGroup.append(-1)
        else:
            labelGroup.append(0)
    return torch.FloatTensor(labelGroup), torch.FloatTensor(winLabelGroup)


def _iterFrames(imgs, tifFiles):
    """Yield ``(readPos, frame)`` for every frame of the open image stacks.

    ``readPos`` is 1-based. ``frame`` combines the current frame of each stack:
    each is resized (bicubic) to IMAGE_SIZE x IMAGE_SIZE, the stacks are summed
    with saturation at 255, and the result is a uint8 array of shape
    (IMAGE_SIZE, IMAGE_SIZE, 3). Iteration stops after the last frame. It also
    stops early, with an error log, if a resized frame does not have 3 channels.

    Args:
        imgs: freshly opened PIL images with equal n_frames.
        tifFiles: their paths, used only in log messages.
    """
    readPos = 1
    while True:
        frameArrays = []
        for img in imgs:
            pixels = np.array(img.resize((IMAGE_SIZE, IMAGE_SIZE), Image.Resampling.BICUBIC).getdata(), dtype='uint8')
            if pixels.shape != (IMAGE_SIZE * IMAGE_SIZE, 3):
                logger.error(f"@szh:The frame size is not {IMAGE_SIZE} * {IMAGE_SIZE} * 3: {tifFiles}, {pixels.shape}")
                return
            frameArrays.append(pixels.reshape(IMAGE_SIZE, IMAGE_SIZE, 3))
        # Accumulate in uint16 and saturate. Summing uint8 arrays directly wraps
        # around at 256, which turns bright overlapping pixels dark.
        merged = np.clip(np.sum(frameArrays, axis=0, dtype=np.uint16), 0, 255).astype(np.uint8)
        yield readPos, merged
        # Advance only after the current frame has been yielded. Otherwise the
        # last frame is lost when seeking past the end raises EOFError.
        try:
            for img in imgs:
                img.seek(readPos)
        except EOFError:
            return
        readPos += 1


def contructDataSet(picInfo, transform, groupNums=None):
    """Turn the picInfo entries in a range into per-frame samples and labels.

    For each selected entry, the AF/SHG colour stacks in ``valueInfo['path']``
    are read frame by frame. Each frame, merged across the stacks, is passed
    through ``transform`` and paired with the labels from _frameLabels.

    Args:
        picInfo: dict fileKey -> {inspectId: (start, end), ..., "path": dir}.
        transform: callable applied to each uint8 (IMAGE_SIZE, IMAGE_SIZE, 3) array.
        groupNums: optional (first, last) half-open range of entry positions,
            in insertion order, to process. None processes every entry.

    Returns:
        ``(features, rLabels, rWinLabels)``: parallel lists with one element per frame.
    """
    features, rLabels, rWinLabels = [], [], []
    for seq, (key, valueInfo) in enumerate(picInfo.items()):
        if groupNums is not None:
            if seq >= groupNums[1]:
                break  # entries are visited in order; nothing later is in range
            if seq < groupNums[0]:
                continue

        logger.info(f"@szh: process picInfo: {key}")
        tifFiles = extractAFAndSHGFile(valueInfo['path'])
        if tifFiles is None:
            continue
        if not 1 <= len(tifFiles) <= 2:
            logger.info(f"Expected 1 or 2 AF/SHG tif files, found {len(tifFiles)}: {valueInfo['path']}")
            continue

        # One (start, end) pair per inspection id; (-1, -1) when not annotated.
        labels = [valueInfo.get(x, (-1, -1)) for x in range(len(IDS2INSPECT))]

        imgs = []
        try:
            for tifFile in tifFiles:
                imgs.append(Image.open(tifFile))
            if len(imgs) == 2 and imgs[0].n_frames != imgs[1].n_frames:
                logger.error(f"@szh:The frame numbers does not match: {tifFiles}")
                continue
            for readPos, frame in _iterFrames(imgs, tifFiles):
                labelGroup, winLabelGroup = _frameLabels(labels, readPos, PRECISION_WINDOWS)
                features.append(transform(frame))
                rLabels.append(labelGroup)
                rWinLabels.append(winLabelGroup)
        finally:
            # Release file handles on every path, including the skip above and errors.
            for img in imgs:
                img.close()

    return features, rLabels, rWinLabels

In [None]:
def saveDataLoader(picInfo, start, end, batchSize=75, seed=0):
    """Build, split and save DataLoaders for the picInfo entries at positions [start, end).

    Samples come from contructDataSet and are wrapped in PhotonDataset. They are
    then split by DATASPLIT_RATIO into train/validation/test subsets. Each
    subset's DataLoader is saved with torch.save to
    ``BASE_PATH/dataloader/{train,validation,test}-<start>-<end>.pth``.
    Prints "skip" and saves nothing when the range yields no samples.

    Args:
        picInfo: annotation dict built by the project-scanning cell.
        start, end: half-open range of picInfo entry positions (insertion order).
        batchSize: batch size of all three loaders (previously hard-coded to 75).
        seed: seed for the random split, so re-running reproduces the same
            partition. Pass None to draw from torch's global RNG instead.
    """
    features, labels, winLabels = contructDataSet(picInfo, transform, groupNums=(start, end))
    raw_dataset = PhotonDataset(features, labels, winLabels)
    if len(raw_dataset) == 0:
        print("skip")
        return

    splitKwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset, test_dataset = random_split(raw_dataset, DATASPLIT_RATIO, **splitKwargs)

    train_dataLoader = DataLoader(dataset=train_dataset, batch_size=batchSize, shuffle=True, num_workers=0)
    val_dataLoader = DataLoader(dataset=val_dataset, batch_size=batchSize, shuffle=True, num_workers=0)
    test_dataLoader = DataLoader(dataset=test_dataset, batch_size=batchSize, shuffle=True, num_workers=0)

    print("Data margin:", start, end)
    print("Training size:", len(train_dataLoader))
    print("Validation size:", len(val_dataLoader))
    print("Testing size:", len(test_dataLoader))

    outDir = os.path.join(BASE_PATH, "dataloader")
    # torch.save fails if the target directory is missing; create it on first use.
    os.makedirs(outDir, exist_ok=True)
    # NOTE(review): this pickles whole DataLoader objects, including PhotonDataset.
    # Loading them needs PhotonDataset importable and, on recent torch versions,
    # torch.load(..., weights_only=False).
    torch.save(train_dataLoader, os.path.join(outDir, "train-{}-{}.pth".format(start, end)))
    torch.save(val_dataLoader, os.path.join(outDir, "validation-{}-{}.pth".format(start, end)))
    torch.save(test_dataLoader, os.path.join(outDir, "test-{}-{}.pth".format(start, end)))

# Save the data in chunks of DATALOAD_MARGIN picInfo entries: [0, M), [M, 2M), ...
# The last chunk ends at dataNum; an empty picInfo yields the single chunk [0, 0).
dataNum = len(picInfo)
chunkBounds = [0, *range(DATALOAD_MARGIN, dataNum, DATALOAD_MARGIN), dataNum]
for chunkStart, chunkEnd in zip(chunkBounds, chunkBounds[1:]):
    saveDataLoader(picInfo, chunkStart, chunkEnd)