In [None]:
# Pip install commands for the third-party packages imported in the next cell
!pip install segmentation-models-pytorch
# !pip install lightning
!pip install wandb -U
!pip install monai

In [None]:
# 导入的库
import os

import IPython
import albumentations as A
import matplotlib.pyplot as plt
import monai
import numpy as np
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
# from fastai.losses import *
import torchmetrics
import wandb
from IPython.display import display
from albumentations.pytorch import ToTensorV2
from pytorch_lightning import LightningDataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.core.mixins import HyperparametersMixin
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.tuner.tuning import Tuner
from torch.utils.data import Dataset, DataLoader

# !pip install ipywidgets
# !pip install albumentations
# !pip install nibabel

# Wipe the cell output produced so far (install / import noise) before printing the banner
IPython.display.clear_output()

print("Envirionment Set Up.")

# Dataset and Augment Setting

In [None]:
# Augmentation pipelines (albumentations). Both resize to 256x256 and finish by
# converting the image/mask pair to torch tensors; only the training pipeline
# adds random geometric augmentation.
# Disabled alternatives kept for reference:
#   A.RandomBrightnessContrast(p=0.2)
#   A.Normalize(mean=(0.5,), std=(0.5,))
train_transform = A.Compose(
    [
        A.Resize(256, 256),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(),
        A.ElasticTransform(
            alpha=120,
            sigma=120 * 0.3,
            alpha_affine=120 * 0.2,
            p=0.5,
        ),
        A.RandomSizedCrop(
            min_max_height=(128, 256),
            height=256,
            width=256,
            p=0.5,
        ),
        ToTensorV2(),
    ]
)

val_transform = A.Compose(
    [
        A.Resize(256, 256),
        ToTensorV2(),
    ]
)


def adjust_window(image, window_center, window_width):
    """Clip a CT image to an intensity window.

    :param image: Input image array.
    :param window_center: Window level (WC), the centre of the kept intensity range.
    :param window_width: Window width (WW), the total size of the kept intensity range.
    :return: Copy of ``image`` clipped to
        ``[window_center - window_width / 2, window_center + window_width / 2]``.
    """
    # Floor division alone would shrink odd or fractional widths (401 -> 400).
    # Even integer widths keep an integer half-width, so integer bounds are unchanged.
    half_width, remainder = divmod(window_width, 2)
    if remainder:
        half_width = window_width / 2
    img_min = window_center - half_width
    img_max = window_center + half_width
    return np.clip(image, img_min, img_max)


class MultipleImageDataset(Dataset):
    """Dataset of paired image/label ``.npz`` files (arrays stored under ``arr_0``).

    Each item is windowed with ``adjust_window`` (centre 40, width 400), optionally
    passed through ``transform`` and min-max scaled to ``[0, 1]``.
    """

    def __init__(self, image_paths, label_paths, transform=None):
        """
        image_paths: list of image file paths
        label_paths: list of label file paths, paired with ``image_paths`` by index
        transform: callable applied as ``transform(image=..., mask=...)``

        Raises ValueError when the two lists differ in length.
        """
        if len(image_paths) != len(label_paths):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and label_paths ({len(label_paths)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` as a float32 tensor and an int64 tensor."""
        # np.load keeps the .npz archive open until it is closed; use it as a
        # context manager so file handles are not leaked.
        with np.load(self.image_paths[idx]) as archive:
            image = archive['arr_0']
        with np.load(self.label_paths[idx]) as archive:
            label = archive['arr_0']

        image = adjust_window(image, window_center=40, window_width=400)

        if self.transform:
            augmented = self.transform(image=image, mask=label)
            image = augmented['image']
            label = augmented['mask']

        # Without a transform the data is still NumPy; as_tensor handles both cases.
        image = torch.as_tensor(image).float()
        label = torch.as_tensor(label).long()

        # Min-max scale; a constant image would divide by zero, so map it to zeros.
        lowest = image.min()
        value_range = image.max() - lowest
        if value_range > 0:
            image = (image - lowest) / value_range
        else:
            image = torch.zeros_like(image)
        return image, label


######################################################################################################################
class MOADataModule(LightningDataModule):
    """Lightning data module serving paired image/mask ``.npz`` files.

    Images are read from ``<data_dir>/image_npz`` and masks from ``<data_dir>/mask_npz``.
    Files are paired by position in the sorted listings and split sequentially:
    the first 80% for training, the next 10% for validation, the rest for testing.
    """

    def __init__(self, data_dir: str, batch_size: int = 16):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Augmentation pipelines defined at notebook level
        self.train_transform = train_transform
        self.val_transform = val_transform

    @staticmethod
    def _list_npz(directory):
        """Return the sorted full paths of the ``.npz`` files in ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith('.npz')
        )

    def setup(self, stage=None):
        """Collect the file lists and split them into train/val/test subsets.

        Raises:
            ValueError: If the image and mask folders hold different numbers of files.
        """
        image_files = self._list_npz(os.path.join(self.data_dir, 'image_npz'))
        label_files = self._list_npz(os.path.join(self.data_dir, 'mask_npz'))

        # Pairing is purely positional, so a count mismatch would silently mis-pair files.
        if len(image_files) != len(label_files):
            raise ValueError(
                f"Found {len(image_files)} image files but {len(label_files)} mask files "
                f"under {self.data_dir}; they must match one-to-one."
            )

        train_end = int(0.8 * len(image_files))
        val_end = train_end + int(0.1 * len(image_files))

        self.train_image_paths = image_files[:train_end]
        self.val_image_paths = image_files[train_end:val_end]
        self.test_image_paths = image_files[val_end:]

        self.train_label_paths = label_files[:train_end]
        self.val_label_paths = label_files[train_end:val_end]
        self.test_label_paths = label_files[val_end:]

    def _make_loader(self, image_paths, label_paths, transform, shuffle, num_workers):
        """Wrap one split in a ``MultipleImageDataset`` and a ``DataLoader``."""
        dataset = MultipleImageDataset(image_paths, label_paths, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=True)

    def train_dataloader(self):
        """Shuffled loader over the training split, using the training augmentations."""
        return self._make_loader(self.train_image_paths, self.train_label_paths,
                                 self.train_transform, shuffle=True, num_workers=4)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return self._make_loader(self.val_image_paths, self.val_label_paths,
                                 self.val_transform, shuffle=False, num_workers=2)

    def test_dataloader(self):
        """Unshuffled loader over the test split (uses the validation transform)."""
        return self._make_loader(self.test_image_paths, self.test_label_paths,
                                 self.val_transform, shuffle=False, num_workers=2)


In [ ]:
# Dataset root directory; passed as data_dir to MOADataModule in later cells
data_dir = '/kaggle/input/rawniidataset/SMU_Dataset'


In [ ]:
def predict_and_log_images(num_samples=2, net=None, datamodule=None, device=None):
    """Build a figure of random test images with their true and predicted masks.

    Args:
        num_samples: Number of test samples (figure rows) to draw.
        net: Network to run. Defaults to the notebook-level ``model``.
        datamodule: Object providing ``test_dataloader()``. Defaults to the
            notebook-level ``data_module``.
        device: Device used for inference. Defaults to CUDA when available, else CPU.

    Returns:
        The matplotlib figure, already closed so the notebook does not render it.
    """
    net = model if net is None else net
    datamodule = data_module if datamodule is None else datamodule
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)
    test_dataset = datamodule.test_dataloader().dataset

    # Random sample indices
    indices = torch.randperm(len(test_dataset))[:num_samples]
    # One row per sample: image, true mask, predicted mask.
    # squeeze=False keeps `axs` two-dimensional so axs[i] also works for a single row.
    fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    # Run inference in eval mode and restore whatever mode the network was in.
    was_training = net.training
    net.eval()
    try:
        for i, idx in enumerate(indices):
            image, mask = test_dataset[int(idx)]
            image = image.unsqueeze(0).to(device)  # add the batch dimension
            mask = mask.squeeze()

            with torch.no_grad():
                prediction = torch.argmax(net(image), dim=1).cpu()

            panels = (
                (image.squeeze().cpu().numpy(), f'Original Image {i + 1}'),
                (mask.cpu().numpy(), f'True Mask {i + 1}'),
                (prediction[0].numpy(), f'Predicted Mask {i + 1}'),
            )
            for ax, (data, title) in zip(axs[i], panels):
                ax.imshow(data, cmap='gray')
                ax.set_title(title)
                ax.axis('off')
    finally:
        net.train(was_training)

    plt.tight_layout()
    plt.close(fig)  # keep the notebook from displaying the figure
    return fig


class ValidationCallback(Callback):
    """Every 10th epoch, log a prediction figure for the module being validated to W&B."""

    def on_validation_epoch_end(self, trainer, pl_module):
        print("Validation epoch ended. Executing custom actions...")  # myk
        if (trainer.current_epoch + 1) % 10 == 0:
            # Plot with the module and data module of this trainer; the notebook
            # globals may belong to a different run (e.g. inside a sweep).
            fig = predict_and_log_images(
                num_samples=2,
                net=pl_module,
                datamodule=trainer.datamodule,
                device=pl_module.device,
            )
            wandb.log({"Validation Callback Predicted Images": wandb.Image(fig)})

# (Opt) Data Pre-Check

In [ ]:
import matplotlib.pyplot as plt
import numpy as np

# Initialise the data module from the shared `data_dir` setting instead of repeating the path
data_module = MOADataModule(data_dir=data_dir, batch_size=16)

# Collect the file lists and create the train/val/test split
data_module.setup()


In [None]:
# Pull one batch from the training loader for a visual sanity check
train_loader = data_module.train_dataloader()
images, labels = next(iter(train_loader))

# Show up to four samples, but never more than the batch contains
num_images_to_show = min(4, len(images))

# One row per sample: image, mask, and mask overlaid on the image.
# squeeze=False keeps `axs` two-dimensional even when only one row is drawn.
fig, axs = plt.subplots(num_images_to_show, 3, figsize=(15, num_images_to_show * 5), squeeze=False)

for i in range(num_images_to_show):
    img = images[i].squeeze().numpy()  # assumes single-channel images and masks
    lbl = labels[i].squeeze().numpy()
    overlay = np.ma.masked_where(lbl == 0, lbl)  # mask out label 0 so only labelled pixels are tinted

    axs[i, 0].imshow(img, cmap='gray')
    axs[i, 0].set_title('Image')
    axs[i, 0].axis('off')

    axs[i, 1].imshow(lbl, cmap='gray')
    axs[i, 1].set_title('Mask')
    axs[i, 1].axis('off')

    axs[i, 2].imshow(img, cmap='gray')
    axs[i, 2].imshow(overlay, cmap='autumn', alpha=0.5)
    axs[i, 2].set_title('Overlay')
    axs[i, 2].axis('off')

    print(f"Image shape: {img.shape}")
    print(f"Label shape: {lbl.shape}")
    print(f"Image M&m Value: {img.max(), img.min()}")
    print(f"Label Unique: {np.unique(lbl)}")

plt.tight_layout()
plt.show()
print('Show Time!')

# Model Setting

移除了class DoubleConv(nn.Module):

损失函数：
你使用了 monai.losses.DiceCELoss，它是结合了 Dice Loss 和 Cross-Entropy Loss 的混合损失函数。
这种组合有助于模型更稳定地收敛，并在精确率和召回率之间取得平衡。
对于医学图像分割任务，这是一个合理的选择。
除了 Dice Loss 之外，医学图像分割中常用的损失函数还有 Focal Loss、Tversky Loss、Generalized Dice Loss 等。
你可以通过实验比较不同损失函数对模型性能的影响。

评价指标： 你使用了 torchmetrics 库计算准确率（micro 平均、macro 平均以及逐类别准确率）、Dice 系数、F1 分数、Jaccard 系数等指标。这些指标能较全面地评估分割模型的性能。

优化器和学习率调度器： 使用了 Adam 优化器和 CosineAnnealingLR 学习率调度器。Adam 是深度学习中常用的自适应优化器，余弦退火学习率能有效平稳地降低学习率并有助于模型收敛。

In [None]:
class UNetTestModel(pl.LightningModule, HyperparametersMixin):
    """``smp.Unet`` wrapped as a LightningModule for multi-class segmentation.

    Args:
        encoder_name: Encoder backbone name passed to ``smp.Unet``.
        encoder_weights: Pretrained weights identifier passed to ``smp.Unet``.
        in_channels: Number of input image channels.
        classes: Number of output classes.
        loss_fn: Loss module called as ``loss_fn(logits, labels.unsqueeze(1))``.
            ``None`` (default) builds a fresh ``monai.losses.DiceCELoss`` per model
            with the settings that were previously the default argument.
        loss_function: Free-text label of the loss, stored with the hyperparameters.
        learning_rate: Base learning rate for Adam; also the warm-up target.
        tracked_classes: Class indices whose per-class validation accuracy is logged
            as ``val_acc_<index>``. Indices outside ``[0, classes)`` are ignored.
    """

    # Number of optimizer steps over which the learning rate is ramped up linearly
    WARMUP_STEPS = 50

    def __init__(
            self,
            encoder_name='resnet50',
            encoder_weights='imagenet',
            in_channels=1,
            classes=14,
            loss_fn=None,
            loss_function='DiceCELoss',
            learning_rate=3e-3,
            tracked_classes=(4, 5, 10, 12, 13),
    ):
        super().__init__()
        # The loss is an nn.Module attribute of this model, so keep it out of hparams.
        self.save_hyperparameters(ignore=['loss_fn'])
        ###################### model #########################
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
        )
        ###################### loss ##########################
        # Built here rather than as a default argument, which would be created once
        # at definition time and shared by every instance.
        if loss_fn is None:
            loss_fn = monai.losses.DiceCELoss(softmax=True, lambda_dice=0.85, lambda_ce=0.15,
                                              to_onehot_y=True, include_background=True)
        self.loss_fn = loss_fn
        ###################### metrics ######################
        # All metrics ignore label 0 (ignore_index=0).
        self.val_accuracy_MACRO = torchmetrics.classification.Accuracy(task="multiclass", num_classes=classes,
                                                                       average='macro', ignore_index=0)
        self.val_accuracy_micro = torchmetrics.classification.Accuracy(task="multiclass", num_classes=classes,
                                                                       average='micro', ignore_index=0)
        self.val_accuracy_classwise = torchmetrics.classification.Accuracy(task="multiclass", num_classes=classes,
                                                                           average='none',
                                                                           ignore_index=0)
        self.Dice = torchmetrics.classification.Dice(multiclass=True, num_classes=classes, average='micro',
                                                     ignore_index=0)
        self.F1 = torchmetrics.classification.MulticlassF1Score(num_classes=classes, average="micro", ignore_index=0)
        self.Jaccard = torchmetrics.classification.MulticlassJaccardIndex(num_classes=classes, average="micro",
                                                                          ignore_index=0)
        # Drop indices that do not exist for this number of classes
        self._tracked_classes = tuple(c for c in tracked_classes if 0 <= c < classes)

    def forward(self, x):
        """Return the raw class logits produced by the U-Net."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        images, labels = batch
        outputs = self.forward(images)

        # The loss is given labels with an explicit channel dimension
        loss = self.loss_fn(outputs, labels.unsqueeze(1))
        self.log('train_loss', loss, on_step=True, on_epoch=False, logger=True, prog_bar=True)

        return loss

    def _scalar_val_metrics(self):
        """Return ``(log name, metric)`` pairs for the scalar validation metrics."""
        return (
            ('val_accuracy_micro', self.val_accuracy_micro),
            ('val_accuracy_MACRO', self.val_accuracy_MACRO),
            ('val_F1', self.F1),
            ('val_Dice', self.Dice),
            ('val_Jaccard', self.Jaccard),
        )

    def validation_step(self, batch, batch_idx):
        """Accumulate the validation loss and metrics for one batch.

        Values are logged once per epoch. Metric objects are passed to ``self.log``
        so that Lightning computes and resets them at the end of the epoch.
        """
        images, labels = batch
        outputs = self.forward(images)
        loss = self.loss_fn(outputs, labels.unsqueeze(1))
        preds = torch.argmax(outputs, dim=1)

        self.log('val_loss', loss, on_step=False, on_epoch=True, logger=True, prog_bar=True)

        for name, metric in self._scalar_val_metrics():
            metric.update(preds, labels)
            self.log(name, metric, on_step=False, on_epoch=True, logger=True, prog_bar=True)

        # Vector-valued metric: accumulated here, logged in on_validation_epoch_end
        self.val_accuracy_classwise.update(preds, labels)

    def on_validation_epoch_end(self):
        """Log the per-class accuracy of the tracked classes, then reset that metric."""
        per_class = self.val_accuracy_classwise.compute()
        for class_idx in self._tracked_classes:
            self.log(f'val_acc_{class_idx}', per_class[class_idx], logger=True, prog_bar=True)
        self.val_accuracy_classwise.reset()

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure, **kwargs):
        """Apply linear learning-rate warm-up for the first steps, then step the optimizer."""
        if self.trainer.global_step < self.WARMUP_STEPS:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.WARMUP_STEPS)
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.learning_rate
        optimizer.step(closure=optimizer_closure)

    def configure_optimizers(self):
        """Adam with a per-epoch ``CosineAnnealingLR`` schedule."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        # NOTE(review): T_max=10 is much shorter than the 150/250-epoch runs in this
        # notebook, so the cosine schedule presumably cycles up again — confirm intended.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0.000001, last_epoch=-1)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',  # step the scheduler once per epoch
                'frequency': 1,
            }
        }

# (Opt) Hyper Params Check

In [None]:
from torchinfo import summary

# Build the model with its default arguments and print a layer summary for a (1, 1, 256, 256) input
model = UNetTestModel()
print(summary(model, input_size=(1, 1, 256, 256)))


In [ ]:
# Show the hyperparameters the model stored via save_hyperparameters()
print(model.hparams)
# def print_model_details(model, indent=0):
#     for name, child in model.named_children():
#         print(" " * indent, name, child)
#         print_model_details(child, indent+4)

# print_model_details(model)

In [None]:
# !pip install torchviz

# from torchviz import make_dot

# x = torch.randn(1, 1, 256, 256)  # 生成一个随机输入
# y = model(x)
# make_dot(y, params=dict(list(model.named_parameters()))).render("unet_model", format="png")


# (Opt) Pre-Train Test

In [None]:
# Smoke test: fit a fresh model with fast_dev_run=True and batch size 1
data_module = MOADataModule(data_dir=data_dir, batch_size=1)
# availability test
model = UNetTestModel()
# NOTE: lr_monitor is created here but not passed to the trainer below (callbacks are commented out)
lr_monitor = LearningRateMonitor(logging_interval='step')
############################################## fastrun ###################################################
trainer = pl.Trainer(max_epochs=20,
                     fast_dev_run=True,
                     #                      callbacks=[lr_monitor, ValidationCallback()],
                     #                      check_val_every_n_epoch=10, 
                     )
trainer.fit(model, datamodule=data_module)

In [None]:
# Learnability test setup: the trainer below is fitted in a later cell
data_module = MOADataModule(data_dir=data_dir, batch_size=1)
# learnability test
wandb_logger_test = WandbLogger()
# # Initialise the Trainer with overfit_batches to overfit a small part of the data
# lr_monitor = LearningRateMonitor(logging_interval='step')
test_trainer = pl.Trainer(overfit_batches=1,
                          logger=wandb_logger_test,
                          #                           callbacks=[lr_monitor, ValidationCallback()], 
                          check_val_every_n_epoch=1)

# # Alternatively, overfit on 10 batches
# test_trainer = pl.Trainer(overfit_batches=10, logger=wandb_logger_test)


### Enter your W&B API key

In [ ]:
# W&B logger for the 'lightning_logs' project.
# NOTE(review): `wandb_logger` is reassigned in later cells before any Trainer receives it — confirm this cell is still needed.
wandb_logger = WandbLogger(project='lightning_logs')

In [ ]:
from kaggle_secrets import UserSecretsClient

# Read the W&B API key from the Kaggle secret named "wandb_key" instead of hard-coding it
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_key")

# Authenticate, then start the run that the following training cell logs to
wandb.login(key=secret_value_0)
# notes = f'steeldefectdetection:{path}'
wandb.init(
    project="lightning_logs",
    name="run11",
    entity="team-mykcs"
)

In [ ]:
# Run training with the overfit-test trainer (overfit_batches=1) configured above
test_trainer.fit(model, datamodule=data_module)


In [ ]:
# Finish the W&B run started for the overfit test
wandb.finish()

# Sweeps

In [None]:
# W&B sweep definition: random search that minimises `val_loss`
# over the learning rate and the batch size.
sweep_config = dict(
    method='random',
    metric=dict(name='val_loss', goal='minimize'),
    parameters=dict(
        learning_rate=dict(min=0.0001, max=0.1),
        batch_size=dict(values=[16, 32, 64]),
    ),
)


In [ ]:
# Register the sweep; the agent below runs `sweep_train` once per trial
sweep_id = wandb.sweep(sweep_config, project="test-sweeps11")


def sweep_train():
    """Train one sweep trial using the hyperparameters W&B assigned to this run.

    The logger and the learning-rate monitor are created per trial. A single
    WandbLogger shared across trials would keep pointing at the first trial's run.
    """
    with wandb.init() as run:
        config = wandb.config
        model = UNetTestModel(
            learning_rate=config.learning_rate,
        )
        data_module = MOADataModule(data_dir=data_dir, batch_size=config.batch_size)
        trainer = pl.Trainer(
            max_epochs=150,
            callbacks=[LearningRateMonitor(logging_interval='step'), ValidationCallback()],
            # Bind the logger to this trial's run explicitly
            logger=WandbLogger(experiment=run),
        )
        trainer.fit(model, datamodule=data_module)


wandb.agent(sweep_id, function=sweep_train, count=10)


# Lightning Style Formal Test

In [ ]:
# Full training setup: fresh data module, model, W&B logger and a 250-epoch Trainer
lr_monitor = LearningRateMonitor(logging_interval='step')
# Assumes the data module class is already defined (the original note says LiTSDataModule; this code uses MOADataModule)
data_module = MOADataModule(data_dir=data_dir, batch_size=32)
# Initialise the model and the trainer
model = UNetTestModel()

# wandb_logger = WandbLogger(project="SMU MOA", name="ResUNetPP50_monaiDiceCELoss_Max150")
wandb_logger = WandbLogger(project="UNet_Compare", name="ResUNet_Max250_DiceCELoss_Baseline")
# wandb_logger = WandbLogger()


trainer = Trainer(max_epochs=250,
                  #                      fast_dev_run=True, 
                  logger=wandb_logger,
                  callbacks=[lr_monitor, ValidationCallback()],
                  #                   callbacks=[lr_monitor],
                  log_every_n_steps=1,
                  check_val_every_n_epoch=1,
                  #                   precision='16-mixed',
                  )

In [ ]:
# 创建 Tuner 对象并运行学习率查找
# tuner = Tuner(trainer)
# lr_finder = tuner.lr_find(model, datamodule=data_module)

# # 可视化找到的学习率
# fig = lr_finder.plot(suggest=True)
# display(fig)

# # 将建议的学习率设置给模型
# new_lr = lr_finder.suggestion()
# model.hparams.learning_rate = new_lr


# 注意：在此处，你不需要手动更新 DataLoader 的批量大小
# 因为 tuner.scale_batch_size 方法已经更新了 LiTSDataModule 中的 batch_size
# 你可以检查更新后的批量大小
# tuner.scale_batch_size(model, datamodule=data_module, mode="power")
# print(f"Updated batch size: {data_module.batch_size}")


In [ ]:
# Train with the Trainer configured above. The learning-rate finder in the previous
# cell is commented out, so the model keeps the learning rate it was constructed with.
trainer.fit(model, datamodule=data_module)

In [ ]:
##################################################### EVA ###################################################
# Make sure the model is in evaluation mode
model.eval()
print("eval activated")

In [None]:
def predict_and_log_images(num_samples=2, net=None, datamodule=None, device=None):
    """Build a figure of random test images with their true and predicted masks.

    Args:
        num_samples: Number of test samples (figure rows) to draw.
        net: Network to run. Defaults to the notebook-level ``model``.
        datamodule: Object providing ``test_dataloader()``. Defaults to the
            notebook-level ``data_module``.
        device: Device used for inference. Defaults to CUDA when available, else CPU.

    Returns:
        The matplotlib figure, already closed so the notebook does not render it.
    """
    net = model if net is None else net
    datamodule = data_module if datamodule is None else datamodule
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)
    test_dataset = datamodule.test_dataloader().dataset

    # Random sample indices
    indices = torch.randperm(len(test_dataset))[:num_samples]
    # One row per sample: image, true mask, predicted mask.
    # squeeze=False keeps `axs` two-dimensional so axs[i] also works for a single row.
    fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    # Run inference in eval mode and restore whatever mode the network was in.
    was_training = net.training
    net.eval()
    try:
        for i, idx in enumerate(indices):
            image, mask = test_dataset[int(idx)]
            image = image.unsqueeze(0).to(device)  # add the batch dimension
            mask = mask.squeeze()

            with torch.no_grad():
                prediction = torch.argmax(net(image), dim=1).cpu()

            panels = (
                (image.squeeze().cpu().numpy(), f'Original Image {i + 1}', 'gray'),
                (mask.cpu().numpy(), f'True Mask {i + 1}', 'viridis'),
                (prediction[0].numpy(), f'Predicted Mask {i + 1}', 'viridis'),
            )
            for ax, (data, title, cmap) in zip(axs[i], panels):
                ax.imshow(data, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')
    finally:
        net.train(was_training)

    plt.tight_layout()
    plt.close(fig)  # keep the notebook from displaying the figure
    return fig


In [ ]:
# Repeat ten times, logging a freshly sampled prediction figure to W&B each time
for _ in range(10):
    fig = predict_and_log_images(num_samples=2)
    wandb.log({"Predicted Images": wandb.Image(fig)})


In [None]:
def predict_and_log_images(num_samples=2, net=None, datamodule=None, device=None):
    """Build a figure of random test images with their true and predicted masks.

    Args:
        num_samples: Number of test samples (figure rows) to draw.
        net: Network to run. Defaults to the notebook-level ``model``.
        datamodule: Object providing ``test_dataloader()``. Defaults to the
            notebook-level ``data_module``.
        device: Device used for inference. Defaults to CUDA when available, else CPU.

    Returns:
        The matplotlib figure, already closed so the notebook does not render it.
    """
    net = model if net is None else net
    datamodule = data_module if datamodule is None else datamodule
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)
    test_dataset = datamodule.test_dataloader().dataset

    # Random sample indices
    indices = torch.randperm(len(test_dataset))[:num_samples]
    # One row per sample: image, true mask, predicted mask.
    # squeeze=False keeps `axs` two-dimensional so axs[i] also works for a single row.
    fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    # Run inference in eval mode and restore whatever mode the network was in.
    was_training = net.training
    net.eval()
    try:
        for i, idx in enumerate(indices):
            image, mask = test_dataset[int(idx)]
            image = image.unsqueeze(0).to(device)  # add the batch dimension
            mask = mask.squeeze()

            with torch.no_grad():
                prediction = torch.argmax(net(image), dim=1).cpu()

            panels = (
                (image.squeeze().cpu().numpy(), f'Original Image {i + 1}', 'gray'),
                (mask.cpu().numpy(), f'True Mask {i + 1}', 'viridis'),
                (prediction[0].numpy(), f'Predicted Mask {i + 1}', 'viridis'),
            )
            for ax, (data, title, cmap) in zip(axs[i], panels):
                ax.imshow(data, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')
    finally:
        net.train(was_training)

    plt.tight_layout()
    plt.close(fig)  # keep the notebook from displaying the figure
    return fig


# Repeat ten times, logging a freshly sampled prediction figure to W&B each time
for _ in range(10):
    fig = predict_and_log_images(num_samples=2)
    wandb.log({"Predicted Images": wandb.Image(fig)})

In [None]:
# Run a validation pass over the validation split. No model is passed here.
# NOTE(review): presumably Lightning reuses the model from the earlier fit call — confirm.
trainer.validate(dataloaders=data_module.val_dataloader())

In [None]:
import os

# Checkpoint directory to inspect
directory_path = '/kaggle/working/UNet_Compare/eau6br38/checkpoints/'

# Report a missing directory first; otherwise print one entry per line
if not os.path.exists(directory_path):
    print(f"Directory {directory_path} does not exist.")
else:
    files = os.listdir(directory_path)
    print("Files in directory:")
    for file in files:
        print(file)


In [None]:
# Create a new Artifact of type 'model' with the given name and attach the checkpoint file
artifact = wandb.Artifact('ResUNet_Max250_DiceCELoss_Baseline', type='model')
artifact.add_file('/kaggle/working/UNet_Compare/eau6br38/checkpoints/epoch=249-step=8250.ckpt')

# Log the Artifact to W&B
wandb.log_artifact(artifact)

In [None]:
# Download version v0 of the model artifact; the result is stored in `artifact_dir`.
# NOTE(review): this path uses entity "mykcs" while wandb.init earlier uses "team-mykcs" — confirm which applies.
artifact_dir = WandbLogger.download_artifact(artifact="mykcs/UNet_Compare/ResUNet_Max250_DiceCELoss_Baseline:v0")

In [None]:
# Restore the model from the checkpoint file at a hard-coded path under /kaggle/working/artifacts
model_ckpt = UNetTestModel.load_from_checkpoint(
    '/kaggle/working/artifacts/ResUNet_Max250_DiceCELoss_Baseline:v0/epoch=249-step=8250.ckpt')
# Make sure the model is in evaluation mode
model_ckpt.eval()
print("model_ckpt eval activated")




In [ ]:
def predict_and_log_images(num_samples=2, net=None, datamodule=None, device=None):
    """Build a figure of random test images with their true and predicted masks.

    Args:
        num_samples: Number of test samples (figure rows) to draw.
        net: Network to run. Defaults to the checkpoint-restored ``model_ckpt``.
        datamodule: Object providing ``test_dataloader()``. Defaults to the
            notebook-level ``data_module``.
        device: Device used for inference. Defaults to CUDA when available, else CPU.

    Returns:
        The matplotlib figure, already closed so the notebook does not render it.
    """
    net = model_ckpt if net is None else net
    datamodule = data_module if datamodule is None else datamodule
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(device)
    test_dataset = datamodule.test_dataloader().dataset

    # Random sample indices
    indices = torch.randperm(len(test_dataset))[:num_samples]
    # One row per sample: image, true mask, predicted mask.
    # squeeze=False keeps `axs` two-dimensional so axs[i] also works for a single row.
    fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    # Run inference in eval mode and restore whatever mode the network was in.
    was_training = net.training
    net.eval()
    try:
        for i, idx in enumerate(indices):
            image, mask = test_dataset[int(idx)]
            image = image.unsqueeze(0).to(device)  # add the batch dimension
            mask = mask.squeeze()

            with torch.no_grad():
                prediction = torch.argmax(net(image), dim=1).cpu()

            panels = (
                (image.squeeze().cpu().numpy(), f'Original Image {i + 1}', 'gray'),
                (mask.cpu().numpy(), f'True Mask {i + 1}', 'viridis'),
                (prediction[0].numpy(), f'Predicted Mask {i + 1}', 'viridis'),
            )
            for ax, (data, title, cmap) in zip(axs[i], panels):
                ax.imshow(data, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')
    finally:
        net.train(was_training)

    plt.tight_layout()
    plt.close(fig)  # keep the notebook from displaying the figure
    return fig

# Repeat ten times, logging a freshly sampled prediction figure to W&B each time
for _ in range(10):
    fig = predict_and_log_images(num_samples=2)
    wandb.log({"Predicted Images by Chcekpoint": wandb.Image(fig)})

# Finally we go finetune on aug dataset

In [None]:
# Log in to W&B before the fine-tuning section; no key is passed in this call
wandb.login()

In [None]:
############################# SET KEY VALUE RIGHT (arr_0 to image or mask)################################
# Augmentation pipelines for the fine-tuning stage. Same steps as the earlier
# pipelines, but the elastic transform and random-sized crop use p=0.2.
# Disabled alternatives kept for reference:
#   A.RandomBrightnessContrast(p=0.2)
#   A.Normalize(mean=(0.5,), std=(0.5,))
train_transform_for_aug = A.Compose(
    [
        A.Resize(256, 256),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(),
        A.ElasticTransform(
            alpha=120,
            sigma=120 * 0.3,
            alpha_affine=120 * 0.2,
            p=0.2,
        ),
        A.RandomSizedCrop(
            min_max_height=(128, 256),
            height=256,
            width=256,
            p=0.2,
        ),
        ToTensorV2(),
    ]
)

val_transform_for_aug = A.Compose(
    [
        A.Resize(256, 256),
        ToTensorV2(),
    ]
)


def adjust_window(image, window_center, window_width):
    """Clip a CT image to an intensity window.

    :param image: Input image array.
    :param window_center: Window level (WC), the centre of the kept intensity range.
    :param window_width: Window width (WW), the total size of the kept intensity range.
    :return: Copy of ``image`` clipped to
        ``[window_center - window_width / 2, window_center + window_width / 2]``.
    """
    # Floor division alone would shrink odd or fractional widths (401 -> 400).
    # Even integer widths keep an integer half-width, so integer bounds are unchanged.
    half_width, remainder = divmod(window_width, 2)
    if remainder:
        half_width = window_width / 2
    img_min = window_center - half_width
    img_max = window_center + half_width
    return np.clip(image, img_min, img_max)


class MultipleImageDataset(Dataset):
    """Dataset of paired ``.npz`` files: images under key ``image``, labels under ``mask``.

    Each item is windowed with ``adjust_window`` (centre 40, width 400), optionally
    passed through ``transform`` and min-max scaled to ``[0, 1]``.
    """

    def __init__(self, image_paths, label_paths, transform=None):
        """
        image_paths: list of image file paths
        label_paths: list of label file paths, paired with ``image_paths`` by index
        transform: callable applied as ``transform(image=..., mask=...)``

        Raises ValueError when the two lists differ in length.
        """
        if len(image_paths) != len(label_paths):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and label_paths ({len(label_paths)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` as a float32 tensor and an int64 tensor."""
        # np.load keeps the .npz archive open until it is closed; use it as a
        # context manager so file handles are not leaked.
        with np.load(self.image_paths[idx]) as archive:
            image = archive['image']
        with np.load(self.label_paths[idx]) as archive:
            label = archive['mask']

        image = adjust_window(image, window_center=40, window_width=400)

        if self.transform:
            augmented = self.transform(image=image, mask=label)
            image = augmented['image']
            label = augmented['mask']

        # Without a transform the data is still NumPy; as_tensor handles both cases.
        image = torch.as_tensor(image).float()
        label = torch.as_tensor(label).long()

        # Min-max scale; a constant image would divide by zero, so map it to zeros.
        lowest = image.min()
        value_range = image.max() - lowest
        if value_range > 0:
            image = (image - lowest) / value_range
        else:
            image = torch.zeros_like(image)
        return image, label


######################################################################################################################
class MOADataModule(LightningDataModule):
    """Lightning data module serving paired image/mask ``.npz`` files for fine-tuning.

    Images are read from ``<data_dir>/image_npz`` and masks from ``<data_dir>/mask_npz``.
    Files are paired by position in the sorted listings and split sequentially:
    the first 90% for training, the next 8% for validation, the rest for testing.
    """

    def __init__(self, data_dir: str, batch_size: int = 16):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Fine-tuning augmentation pipelines defined at notebook level
        self.train_transform = train_transform_for_aug
        self.val_transform = val_transform_for_aug

    @staticmethod
    def _list_npz(directory):
        """Return the sorted full paths of the ``.npz`` files in ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith('.npz')
        )

    def setup(self, stage=None):
        """Collect the file lists and split them into train/val/test subsets.

        Raises:
            ValueError: If the image and mask folders hold different numbers of files.
        """
        image_files = self._list_npz(os.path.join(self.data_dir, 'image_npz'))
        label_files = self._list_npz(os.path.join(self.data_dir, 'mask_npz'))

        # Pairing is purely positional, so a count mismatch would silently mis-pair files.
        if len(image_files) != len(label_files):
            raise ValueError(
                f"Found {len(image_files)} image files but {len(label_files)} mask files "
                f"under {self.data_dir}; they must match one-to-one."
            )

        train_end = int(0.9 * len(image_files))
        val_end = train_end + int(0.08 * len(image_files))

        self.train_image_paths = image_files[:train_end]
        self.val_image_paths = image_files[train_end:val_end]
        self.test_image_paths = image_files[val_end:]

        self.train_label_paths = label_files[:train_end]
        self.val_label_paths = label_files[train_end:val_end]
        self.test_label_paths = label_files[val_end:]

    def _make_loader(self, image_paths, label_paths, transform, shuffle, num_workers):
        """Wrap one split in a ``MultipleImageDataset`` and a ``DataLoader``."""
        dataset = MultipleImageDataset(image_paths, label_paths, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=True)

    def train_dataloader(self):
        """Shuffled loader over the training split, using the training augmentations."""
        return self._make_loader(self.train_image_paths, self.train_label_paths,
                                 self.train_transform, shuffle=True, num_workers=4)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return self._make_loader(self.val_image_paths, self.val_label_paths,
                                 self.val_transform, shuffle=False, num_workers=2)

    def test_dataloader(self):
        """Unshuffled loader over the test split (uses the validation transform)."""
        return self._make_loader(self.test_image_paths, self.test_label_paths,
                                 self.val_transform, shuffle=False, num_workers=2)


######################################################################################################################
class UNetTestModel(pl.LightningModule, HyperparametersMixin):
    """LightningModule wrapping an ``smp.Unet`` for multi-class segmentation.

    The module owns the network, the loss and the torchmetrics objects used in
    validation. It applies a linear learning-rate warm-up inside
    ``optimizer_step`` and trains with Adam plus a cosine-annealing schedule.
    """

    # Class indices whose individual validation accuracy is logged as ``val_acc_<idx>``.
    CLASSWISE_LOG_INDICES = (4, 5, 10, 12, 13)

    def __init__(
            self,
            encoder_name='resnet50',
            encoder_weights='imagenet',
            in_channels=1,
            classes=14,
            # Alternative losses that were tried:
            #         loss_fn=monai.losses.FocalLoss(use_softmax=True, to_onehot_y=True, include_background=False),
            #         loss_fn=monai.losses.DiceCELoss(softmax=True, lambda_dice=0.85, lambda_ce=0.15, to_onehot_y=True, include_background=True),
            #         loss_fn=monai.losses.DiceLoss( to_onehot_y=True),
            # NOTE(review): this default is evaluated once, when the class is defined, so
            # every instance created without an explicit ``loss_fn`` shares the same module.
            loss_fn=monai.losses.DiceFocalLoss(softmax=True, lambda_dice=0.6, gamma=2.3, lambda_focal=0.4,
                                               to_onehot_y=True),
            loss_function='DiceCELoss',
            learning_rate=3e-3,
            warmup_steps=100,
    ):
        """Build the network, loss and validation metrics.

        :param encoder_name: Encoder backbone name forwarded to ``smp.Unet``.
        :param encoder_weights: Pretrained-weights identifier forwarded to ``smp.Unet``.
        :param in_channels: Number of input channels forwarded to ``smp.Unet``.
        :param classes: Number of output classes; also sizes every metric.
        :param loss_fn: Callable invoked as ``loss_fn(outputs, labels.unsqueeze(1))``.
        :param loss_function: Free-form label that is only recorded in the saved
            hyperparameters; it does not select the loss.
        :param learning_rate: Base learning rate for Adam and target of the warm-up.
        :param warmup_steps: Number of optimizer steps over which the learning
            rate is ramped linearly up to ``learning_rate``; ``0`` disables warm-up.
        """
        super().__init__()
        # NOTE(review): this also stores ``loss_fn`` (a module object) in hparams;
        # Lightning presumably warns about that — consider ``ignore=['loss_fn']``.
        self.save_hyperparameters()
        ###################### model #########################
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            #             decoder_attention_type='scse',
        )
        ###################### loss ##########################
        self.loss_fn = loss_fn
        ###################### metrics ######################
        # Every metric ignores label 0 (presumably the background class — confirm).
        #         self.train_accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=classes, average='micro', ignore_index=0)
        self.val_accuracy_MACRO = torchmetrics.classification.Accuracy(task="multiclass", num_classes=classes,
                                                                       average='macro', ignore_index=0)
        self.val_accuracy_micro = torchmetrics.classification.Accuracy(task="multiclass", num_classes=classes,
                                                                       average='micro', ignore_index=0)
        self.val_accuracy_classwise = torchmetrics.classification.Accuracy(task="multiclass", num_classes=classes,
                                                                           average='none', ignore_index=0)
        self.Dice = torchmetrics.classification.Dice(multiclass=True, num_classes=classes, average='micro',
                                                     ignore_index=0)
        self.F1 = torchmetrics.classification.MulticlassF1Score(num_classes=classes, average="micro", ignore_index=0)
        self.Jaccard = torchmetrics.classification.MulticlassJaccardIndex(num_classes=classes, average="micro",
                                                                          ignore_index=0)

    def forward(self, x):
        """Run the wrapped U-Net on ``x`` and return its raw output."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        :param batch: ``(images, labels)`` pair; labels get a channel axis added
            at dim 1 before being passed to the loss (assumes labels arrive
            without a channel axis — TODO confirm against the dataset).
        :param batch_idx: Index of the batch; unused.
        :return: The loss tensor to backpropagate.
        """
        images, labels = batch
        outputs = self.forward(images)
        loss = self.loss_fn(outputs, labels.unsqueeze(1))
        self.log('train_loss', loss, on_step=True, on_epoch=False, logger=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss and metrics for one validation batch and log them per step.

        Predictions are the argmax over dim 1 of the network output. Besides the
        aggregate metrics, the per-class accuracy is logged for every index in
        ``CLASSWISE_LOG_INDICES`` that exists for the configured class count.

        :param batch: ``(images, labels)`` pair.
        :param batch_idx: Index of the batch; unused.
        """
        images, labels = batch
        outputs = self.forward(images)
        loss = self.loss_fn(outputs, labels.unsqueeze(1))
        preds = torch.argmax(outputs, dim=1)

        acc_micro = self.val_accuracy_micro(preds, labels)
        acc_MACRO = self.val_accuracy_MACRO(preds, labels)
        Dice = self.Dice(preds, labels)
        F1 = self.F1(preds, labels)
        Jaccard = self.Jaccard(preds, labels)
        acc = self.val_accuracy_classwise(preds, labels)

        log_kwargs = dict(on_step=True, on_epoch=False, logger=True, prog_bar=True)
        scalar_logs = (
            ('val_loss', loss),
            ('val_accuracy_micro', acc_micro),
            ('val_accuracy_MACRO', acc_MACRO),
            ('val_F1', F1),
            ('val_Dice', Dice),
            ('val_Jaccard', Jaccard),
        )
        for name, value in scalar_logs:
            self.log(name, value, **log_kwargs)

        # Skip indices beyond the metric length so a model built with fewer
        # classes does not fail with an IndexError here.
        for class_idx in self.CLASSWISE_LOG_INDICES:
            if class_idx < acc.shape[0]:
                self.log(f'val_acc_{class_idx}', acc[class_idx], **log_kwargs)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure, **kwargs):
        """Apply a linear learning-rate warm-up, then perform the optimizer step.

        While ``trainer.global_step`` is below ``hparams.warmup_steps`` the learning
        rate of every parameter group is overwritten with
        ``(global_step + 1) / warmup_steps * hparams.learning_rate``.

        NOTE(review): the overwrite happens on every early step regardless of who
        set the learning rate before, so it presumably also overrides the values
        probed by ``Tuner.lr_find`` and the scheduler during warm-up — confirm.
        """
        warmup_steps = self.hparams.warmup_steps
        current_step = self.trainer.global_step
        if current_step < warmup_steps:
            lr_scale = min(1.0, float(current_step + 1) / warmup_steps)
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.learning_rate
        # Run the actual parameter update.
        optimizer.step(closure=optimizer_closure)

    def configure_optimizers(self):
        """Return Adam together with a per-epoch cosine-annealing schedule.

        :return: Lightning optimizer/scheduler configuration dictionary.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        #         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9, verbose=True)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=0.000001, last_epoch=-1)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',  # step the scheduler in units of epochs
                'frequency': 1,  # once every epoch
            }
        }

################################################# SET UP ################################################
# Create the data module used for fine-tuning from the checkpoint
data_module_ckpt = MOADataModule(
    data_dir='/kaggle/input/aug-dataset-for-fine-tune/AUG_Dataset',
    batch_size=1,
)

# Populate the train/val/test path lists
data_module_ckpt.setup()


################################################# CALL BACK #############################################
def predict_and_log_images(num_samples=2, model=None, data_module=None):
    """Plot random validation samples next to their true and predicted masks.

    Each row of the figure shows the input image, the ground-truth mask and
    the argmax prediction of the model for one randomly chosen validation item.

    :param num_samples: Number of rows (validation samples) to draw.
    :param model: Model to run; defaults to the notebook-level ``model_ckpt``.
    :param data_module: Data module providing ``val_dataloader()``; defaults to
        the notebook-level ``data_module_ckpt``.
    :return: The matplotlib figure, already closed so the notebook does not
        display it automatically.
    """
    # Fall back to the notebook globals so existing callers keep working.
    model = model_ckpt if model is None else model
    data_module = data_module_ckpt if data_module is None else data_module

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    dataset = data_module.val_dataloader().dataset

    # Random sample indices as plain ints
    indices = torch.randperm(len(dataset))[:num_samples].tolist()
    # squeeze=False keeps ``axs`` two-dimensional even for a single row, so the
    # ``axs[i, j]`` indexing below also works when num_samples == 1.
    fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    # Inference must run in eval mode (BatchNorm/Dropout); restore the previous
    # mode afterwards so calling this between training steps has no side effect.
    was_training = model.training
    model.eval()
    try:
        for i, idx in enumerate(indices):
            image, mask = dataset[idx]
            image = image.unsqueeze(0).to(device)  # add the batch axis and move to the device
            mask = mask.squeeze()  # drop singleton axes, if any

            with torch.no_grad():
                prediction = torch.argmax(model(image), dim=1).cpu()  # predicted class per pixel

            # Input image (assumes a single-channel image after squeeze — TODO confirm)
            axs[i, 0].imshow(image.squeeze().cpu().numpy(), cmap='gray')
            axs[i, 0].set_title(f'Original Image {i + 1}')
            axs[i, 0].axis('off')

            # Ground truth
            axs[i, 1].imshow(mask.cpu().numpy(), cmap='viridis')
            axs[i, 1].set_title(f'True Mask {i + 1}')
            axs[i, 1].axis('off')

            # Predicted mask
            axs[i, 2].imshow(prediction[0].numpy(), cmap='viridis')
            axs[i, 2].set_title(f'Predicted Mask {i + 1}')
            axs[i, 2].axis('off')
    finally:
        model.train(was_training)

    fig.tight_layout()
    plt.close(fig)  # keep the notebook from rendering the figure on its own
    return fig


class ValidationCallback(Callback):
    """Callback that sends a figure of sample validation predictions to W&B every second epoch."""

    def on_validation_epoch_end(self, trainer, pl_module):
        """Log prediction samples when the 1-based epoch number is even.

        :param trainer: Trainer whose ``current_epoch`` decides whether to log.
        :param pl_module: Module being validated; not used here.
        """
        epoch_number = trainer.current_epoch + 1
        if epoch_number % 2 != 0:
            return

        # Render two validation samples and record the figure with wandb.
        sample_figure = predict_and_log_images(num_samples=2)
        wandb.log({"Validation Callback Predicted Images": wandb.Image(sample_figure)})

In [None]:
# Download the baseline checkpoint artifact from Weights & Biases and keep the returned value.
artifact_dir = WandbLogger.download_artifact(
    artifact="southern/UNet Compare/ResUNet_Max250_DiceCELoss_Baseline:v0",
)

In [None]:
# Restore the baseline model from the checkpoint file inside the directory returned
# by the artifact download, instead of repeating that location as an absolute path.
model_ckpt = UNetTestModel.load_from_checkpoint(
    os.path.join(artifact_dir, 'epoch=249-step=8250.ckpt'))

In [None]:
# Create the data module and populate its split lists
data_module_ckpt = MOADataModule(
    data_dir='/kaggle/input/aug-dataset-for-fine-tune/AUG_Dataset',
    batch_size=1,
)
data_module_ckpt.setup()

# Build the training and validation loaders
train_loader_ckpt = data_module_ckpt.train_dataloader()
valid_loader_ckpt = data_module_ckpt.val_dataloader()

# Number of batches in each loader
(len(train_loader_ckpt), len(valid_loader_ckpt))

In [None]:
# Create the data module with a larger batch so several samples can be previewed
data_module_ckpt = MOADataModule(data_dir='/kaggle/input/aug-dataset-for-fine-tune/AUG_Dataset', batch_size=16)

# Populate the split lists
data_module_ckpt.setup()

# Training loader
train_loader_ckpt = data_module_ckpt.train_dataloader()

# Draw one batch from the loader
images, labels = next(iter(train_loader_ckpt))

# Number of samples to display; capped by the batch so indexing never runs past it
num_images_to_show = min(4, len(images))

# One row per sample: image, mask, overlay. squeeze=False keeps ``axs`` 2-D
# even when only a single row is drawn.
fig, axs = plt.subplots(num_images_to_show, 3, figsize=(15, num_images_to_show * 5), squeeze=False)

for i in range(num_images_to_show):
    img = images[i].squeeze().numpy()  # assumes image and mask are single-channel
    lbl = labels[i].squeeze().numpy()
    overlay = np.ma.masked_where(lbl == 0, lbl)  # hide label 0 in the overlay

    axs[i, 0].imshow(img, cmap='gray')
    axs[i, 0].set_title('Image')
    axs[i, 0].axis('off')

    axs[i, 1].imshow(lbl, cmap='gray')
    axs[i, 1].set_title('Mask')
    axs[i, 1].axis('off')

    axs[i, 2].imshow(img, cmap='gray')
    axs[i, 2].imshow(overlay, cmap='autumn', alpha=0.5)
    axs[i, 2].set_title('Overlay')
    axs[i, 2].axis('off')

    print(f"Image shape: {img.shape}")
    print(f"Label shape: {lbl.shape}")
    print(f"Image M&m Value: {img.max(), img.min()}")
    print(f"Label Unique: {np.unique(lbl)}")

plt.tight_layout()
plt.show()
print('Show Time!')

In [None]:
lr_monitor = LearningRateMonitor(logging_interval='step')

data_module_ckpt = MOADataModule(data_dir='/kaggle/input/aug-dataset-for-fine-tune/AUG_Dataset', batch_size=1)

# wandb_logger = WandbLogger(project="SMU MOA", name="ResUNetPP50_monaiDiceCELoss_Max150")
wandb_logger = WandbLogger(project="UNet Compare", name="ResUNet_Max100_DiceFocalLoss_ckpt_bs&lrFinderUsed")
# wandb_logger = WandbLogger()


trainer_ckpt = Trainer(max_epochs=100,
                       #                      fast_dev_run=True,
                       logger=wandb_logger,
                       callbacks=[lr_monitor, ValidationCallback()],
                       #                   callbacks=[lr_monitor],
                       log_every_n_steps=1,
                       check_val_every_n_epoch=1,
                       #                   precision='16-mixed',
                       )

tuner = Tuner(trainer_ckpt)

# Scale the batch size first: the learning-rate search below must run with the
# batch size that training will actually use. The tuner writes the result to
# ``data_module_ckpt.batch_size``, so no loader has to be rebuilt by hand.
tuner.scale_batch_size(model_ckpt, datamodule=data_module_ckpt, mode="power")
print(f"Updated batch size: {data_module_ckpt.batch_size}")

# Run the learning-rate finder with the scaled batch size
lr_finder_ckpt = tuner.lr_find(model_ckpt, datamodule=data_module_ckpt)

new_lr = None
if lr_finder_ckpt is not None:
    # Show the loss-vs-learning-rate curve with the suggested point
    fig = lr_finder_ckpt.plot(suggest=True)
    display(fig)
    new_lr = lr_finder_ckpt.suggestion()

# The finder may fail to produce a suggestion; keep the current rate in that
# case instead of storing ``None`` as the learning rate.
if new_lr is not None:
    model_ckpt.hparams.learning_rate = new_lr
else:
    print(f"No learning-rate suggestion; keeping {model_ckpt.hparams.learning_rate}")

# Train with the tuned batch size and learning rate
trainer_ckpt.fit(model_ckpt, datamodule=data_module_ckpt)

In [None]:
# End the current Weights & Biases run.
wandb.finish()