In [None]:
import torch
from torch.utils.data import DataLoader, RandomSampler
import os
import time
import gc
import sys
import logging
import argparse
import numpy as np
import pickle
import SimpleITK as sitk
import yaml

# Raise the interpreter's recursion limit to 10,000 frames.
# NOTE(review): presumably needed by a recursive helper imported below — confirm which one.
sys.setrecursionlimit(10_000)
# Optional: make the project's ``func`` package importable from another location.
# sys.path.append("..")
# sys.path.append("/home/wangc/now/pure/ALinAirway/func")  # adjust the path to the actual setup

from func.load_dataset import airway_dataset
from func.model_arch_e0_org_channel import SegAirwayModel
from func.loss_func import (
    dice_loss_weights,
    dice_accuracy,
    dice_loss_power_weights,
)


from func.model_run import semantic_segment_crop_and_cat
from func.post_process import post_process, add_broken_parts_to_the_result
from func.detect_tree import tree_detection
from func.ulti import get_df_of_line_of_centerline, load_obj
from func.eval_use_func import (
    load_many_CT_img,
    get_metrics,
    get_the_skeleton_and_center_nearby_dict,
)


def update_dataset_paths(dataset, old_prefix, new_prefix):
    """Re-root the ``"image"`` and ``"label"`` paths of every dataset entry.

    For each entry, a leading ``old_prefix`` is removed from the path (if the
    path starts with it) and ``new_prefix`` is prepended. Any other keys of
    the entries are left untouched.

    Args:
        dataset: mapping whose values are dicts holding ``"image"`` and
            ``"label"`` path strings; the entries are modified in place.
        old_prefix: prefix to strip from the start of each path.
        new_prefix: prefix to prepend to each (stripped) path.

    Returns:
        The same ``dataset`` object that was passed in.

    Raises:
        KeyError: if an entry has no ``"image"`` or ``"label"`` key.
    """

    def _reroot(path):
        # Strip only a *leading* old_prefix. str.replace() would also delete
        # any later occurrence of old_prefix inside the path and corrupt it.
        if path.startswith(old_prefix):
            path = path[len(old_prefix):]
        return new_prefix + path

    for entry in dataset.values():
        entry["image"] = _reroot(entry["image"])
        entry["label"] = _reroot(entry["label"])
    return dataset


def load_config(config_path):
    """Read a YAML configuration file and return its parsed contents.

    Args:
        config_path: path to the YAML file.

    Returns:
        The object produced by ``yaml.safe_load`` for the document. The
        callers in this notebook index it like a nested ``dict``.
    """
    # Decode explicitly as UTF-8 instead of the locale-dependent default, so
    # the same config file parses identically on every machine.
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)



# Shared YAML config, resolved relative to the current working directory.
config = load_config("../config.yaml")
# Image / label locations for the EXACT09 and LIDC datasets.
exact09_img_path, lidc_img_path = (
    config["exact09"]["img_path"],
    config["lidc"]["img_path"],
)
exact09_label_path, lidc_label_path = (
    config["exact09"]["label_path"],
    config["lidc"]["label_path"],
)




# ---- Training configuration ----
# NOTE(review): need_resume, learning_rate, max_epoch,
# freq_switch_of_train_mode_high_low_generation and model_save_freq are not
# used in the cells shown here; presumably they are consumed by training code
# elsewhere.
need_resume = True
learning_rate = 1e-5
max_epoch = 50
freq_switch_of_train_mode_high_low_generation = 1
num_samples_of_each_epoch = 20_000  # upper bound on sampler draws per pass
batch_size = 1

# Passed to airway_dataset.set_para: file format, crop size and CT window
# bounds (Hounsfield units, per the names).
train_file_format = ".nii.gz"
crop_size = (32, 128, 128)
windowMin_CT_img_HU = -1000
windowMax_CT_img_HU = 600

# Checkpoint frequency and DataLoader worker count.
model_save_freq = 1
num_workers = 4



# Serialized index of training cases, read back with load_obj.
# NOTE(review): hard-coded absolute path; consider moving it into config.yaml
# next to the other dataset paths.
data_info_path = '/home/wangc/now/pure/saved_objs/for_128_objs/data_dict_org.pkl'
dataset_info_org = load_obj(data_info_path)

# Optionally re-root the stored image/label paths (presumably for when the data
# lives under a different directory than when the index was built).
if config["is_change_prefix"]["is_change"]:
    old_prefix, new_prefix = (
        config["is_change_prefix"]["old_prefix"],
        config["is_change_prefix"]["new_prefix"],
    )
    dataset_info_org = update_dataset_paths(
        dataset_info_org, old_prefix=old_prefix, new_prefix=new_prefix
    )

# Wrap the index in the project dataset and configure file format, crop size,
# HU window, tensor output and transforms.
train_dataset_org = airway_dataset(dataset_info_org)
train_dataset_org.set_para(
    file_format=train_file_format,
    crop_size=crop_size,
    windowMin=windowMin_CT_img_HU,
    windowMax=windowMax_CT_img_HU,
    need_tensor_output=True,
    need_transform=True,
)



In [None]:
# Draw training crops at random *with* replacement.
# NOTE(review): min(...) caps the draws at len(train_dataset_org), although
# replacement=True would allow num_samples_of_each_epoch draws even for a
# smaller dataset; confirm the cap is intended.
sampler_of_airways_org = RandomSampler(
    train_dataset_org,
    num_samples=min(num_samples_of_each_epoch, len(train_dataset_org)),
    replacement=True,
)
dataset_loader = DataLoader(
    train_dataset_org,
    batch_size=batch_size,
    sampler=sampler_of_airways_org,
    num_workers=num_workers,
    pin_memory=True,
    # Persistent workers are valid whenever worker processes exist
    # (num_workers > 0). The previous `> 1` needlessly disabled them for a
    # single worker.
    persistent_workers=(num_workers > 0),
)

len_dataset_loader = len(dataset_loader)
# Debug pass over the loader: print each label batch's shape and whether it
# contains any voxel equal to 1.0.
# NOTE(review): presumably 1.0 marks airway foreground in the labels — confirm.
for ith_batch, batch in enumerate(dataset_loader):
    img_input = batch["image"].float()
    groundtruth_foreground = batch["label"].float()
    print(groundtruth_foreground.shape)
    # Same result as `1.0 in groundtruth_foreground`, written explicitly.
    print((groundtruth_foreground == 1.0).any().item())
    
