In [5]:
class UNetTestModel(pl.LightningModule, HyperparametersMixin):
    def __init__(
        self,
        encoder_name='mit_b2',
        encoder_weights='imagenet',
        in_channels=3,
        classes=14,
#         loss_fn=monai.losses.FocalLoss(use_softmax=True, to_onehot_y=True, include_background=False),
#         loss_fn=DiceCELossWithKL(softmax=True, lambda_dice=0.85, lambda_ce=0.15, lambda_kl=2.0, to_onehot_y=True, include_background=True),
        loss_fn=monai.losses.DiceCELoss(softmax=True, lambda_dice=0.85, lambda_ce=0.15, to_onehot_y=True),

#         loss_fn=monai.losses.DiceLoss( to_onehot_y=True),
        loss_function='DiceCELossWithKL',
        learning_rate=3e-3,
        ):
        super().__init__()
        self.save_hyperparameters()
        ###################### model #########################
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
#             decoder_attention_type='scse',
        )
        ###################### loss ##########################
        self.loss_fn = loss_fn
        ###################### metrics ######################
#         self.train_accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=classes, average='micro', ignore_index=0)
#         self.val_accuracy_MACRO = torchmetrics.classification.Accuracy(task="multiclass", num_classes=14, average='macro', ignore_index=0)
        self.val_accuracy = torchmetrics.classification.Accuracy(task="multiclass", num_classes=14, average='macro', ignore_index=0)
        self.val_accuracy_classwise = torchmetrics.classification.Accuracy(task="multiclass", num_classes=14, average='none', ignore_index=0)
        self.Dice = torchmetrics.classification.Dice(multiclass=True, num_classes=14, average='macro', ignore_index=0)
        self.F1 = torchmetrics.classification.MulticlassF1Score(num_classes=14, average="macro", ignore_index=0)
        self.Jaccard = torchmetrics.classification.MulticlassJaccardIndex(num_classes=14, average="macro", ignore_index=0)


    # 定义前向传播
    def forward(self, x):
        x = x.repeat(1,3,1,1)
        return self.model(x)


    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self.forward(images)
        loss = self.loss_fn(outputs, labels.unsqueeze(1))
        self.log('train_loss', loss, on_step=True, on_epoch=False, logger=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self.forward(images)
        loss = self.loss_fn(outputs, labels.unsqueeze(1))
#         outputs = F.softmax(outputs, dim=1)  # 将logits转换为概率
#         preds = torch.argmax(outputs, dim=1)
#         labels = (F.one_hot(labels, num_classes=14)).permute(0, 3, 1, 2)
#         print(f"Images shape: {images.shape}")
#         print(f"Labels shape: {labels.shape}")
#         print(f"Outputs shape: {outputs.shape}")
#         print(f"Labels with unsqueezed dim shape: {labels.unsqueeze(1).shape}")
#         print(f"Predictions shape: {preds.shape}")
#         print(labels_one_hot.shape)
# loss = self.loss_fn(outputs, labels.unsqueeze(1))
# Images shape: torch.Size([1, 1, 256, 256])
# Labels shape: torch.Size([1, 256, 256])
# Outputs shape: torch.Size([1, 14, 256, 256])
# Labels with unsqueezed dim shape: torch.Size([1, 1, 256, 256])
# Predictions shape: torch.Size([1, 256, 256])
#         acc_micro = self.val_accuracy_micro(outputs, labels)
        accuracy = self.val_accuracy(outputs, labels)
        Dice = self.Dice(outputs, labels)
        F1 = self.F1(outputs, labels)
        Jaccard = self.Jaccard(outputs, labels)
        acc = self.val_accuracy_classwise(outputs, labels)
        
        self.log('val_loss', loss, on_step=True, on_epoch=False, logger=True, prog_bar=True)
#         self.log('val_accuracy_micro', acc_micro, on_step=True, on_epoch=False, logger=True, prog_bar=True)
        self.log('val_accuracy', accuracy, on_step=True, on_epoch=False, logger=True, prog_bar=True)
        self.log('val_F1', F1, on_step=True, on_epoch=False, logger=True, prog_bar=True)
        self.log('val_Dice', Dice, on_step=True, on_epoch=False, logger=True, prog_bar=True)
        self.log('val_Jaccard', Jaccard, on_step=True, on_epoch=False, logger=True, prog_bar=True)
        
        self.log('val_acc_4', acc[4], on_step=True, on_epoch=False, logger=True, prog_bar=True)
        self.log('val_acc_5', acc[5], on_step=True, on_epoch=False, logger=True, prog_bar=True)
        self.log('val_acc_10', acc[10], on_step=True, on_epoch=False, logger=True, prog_bar=True)
        self.log('val_acc_12', acc[12], on_step=True, on_epoch=False, logger=True, prog_bar=True)
        self.log('val_acc_13', acc[13], on_step=True, on_epoch=False, logger=True, prog_bar=True)

        # 返回loss和其他可能需要的信息，但不在这里记录loss
        return {"loss": loss, "accuracy": accuracy}
    # def on_validation_epoch_end(self, outputs):
    #     # outputs 是一个列表，包含每个验证批次的输出
    #     avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
    #     avg_accuracy = torch.stack([x['accuracy'] for x in outputs]).mean()
    #     avg_dice = torch.stack([x['Dice'] for x in outputs]).mean()
    #     avg_f1 = torch.stack([x['F1'] for x in outputs]).mean()
    #     avg_jaccard = torch.stack([x['Jaccard'] for x in outputs]).mean()

    #     # 将平均损失和平均精度记录到日志中
    #     self.log('avg_val_loss', avg_loss, on_epoch=True, prog_bar=True, logger=True)
    #     self.log('avg_val_accuracy', avg_accuracy, on_epoch=True, prog_bar=True, logger=True)
    #     self.log('avg_val_dice', avg_dice, on_epoch=True, prog_bar=True, logger=True)
    #     self.log('avg_val_f1', avg_f1, on_epoch=True, prog_bar=True, logger=True)
    #     self.log('avg_val_jaccard', avg_jaccard, on_epoch=True, prog_bar=True, logger=True)

    #     # 你可以选择记录其他指标或进行其他的后处理
    #     # 例如，你可能想要记录每个类别的平均精度
    #     classwise_accuracy = torch.stack([x['classwise_acc'] for x in outputs]).mean(dim=0)
    #     for i, acc in enumerate(classwise_accuracy):
    #         self.log(f'avg_val_acc_class_{i}', acc, on_epoch=True, prog_bar=True, logger=True)


    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure, **kwargs):
        if self.trainer.global_step < 50:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / 50)
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * self.hparams.learning_rate
        # 调用优化器的step方法来更新模型参数
        optimizer.step(closure=optimizer_closure)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
#         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9, verbose=True)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=0.000001, last_epoch=-1)
#         print(self.hparams.learning_rate)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch', 
                'frequency': 1, 
            }
        }

# (Opt) Hyper Params Check

In [None]:
# from torchinfo import summary
model = UNetTestModel()
# summary(model, input_size=(1,1,256,256))
# summary(model1, input_size=(1,1,256,256), verbose=2)

# print(model.hparams)
# def print_model_details(model, indent=0):
#     for name, child in model.named_children():
#         print(" " * indent, name, child)
#         print_model_details(child, indent+4)

def print_model_details(model, indent=0, max_depth=1):
    for name, child in model.named_children():
        print(" " * indent + f"({name}): {child.__class__.__name__}")
        if indent < max_depth * 4:  # 增加缩进条件来控制输出深度
            print_model_details(child, indent + 4, max_depth)

print_model_details(model, max_depth=4)


In [13]:
from torchviz import make_dot
from graphviz import Digraph

def add_graph_node(dot, var, seen):
    if var not in seen:
        if torch.is_tensor(var):
            dot.node(str(id(var)), str(var.size()), fillcolor='orange')
        elif hasattr(var, 'variable'):
            u = var.variable
            name = str(id(u)) + ' ' + str(u.size())
            dot.node(name, name, fillcolor='lightblue')
        else:
            dot.node(str(id(var)), str(type(var).__name__))
        seen.add(var)
        if hasattr(var, 'next_functions'):
            for u in var.next_functions:
                if u[0] is not None:
                    dot.edge(str(id(u[0])), str(id(var)))
                    add_graph_node(dot, u[0], seen)
        if hasattr(var, 'saved_tensors'):
            for t in var.saved_tensors:
                dot.edge(str(id(t)), str(id(var)))
                add_graph_node(dot, t, seen)
                
x = torch.randn(1, 1, 256, 256)  # 生成一个随机输入
y = model(x)
dot = Digraph(comment='UNet Architecture', format='png')
add_graph_node(dot, y.grad_fn, set())

dot.render('unet_model_very_simplified')


dot: graph is too large for cairo-renderer bitmaps. Scaling by 0.224889 to fit


'unet_model_very_simplified.png'

# (Opt) Pre-Train Test

In [33]:
data_module = MOADataModule(data_dir=data_dir, batch_size=1)
# avail test
model = UNetTestModel()
lr_monitor = LearningRateMonitor(logging_interval='step')
############################################## fastrun ###################################################
trainer = pl.Trainer(max_epochs=20, 
                     fast_dev_run=True, 
#                      callbacks=[lr_monitor, ValidationCallback()],
#                      check_val_every_n_epoch=10, 
                    )
trainer.fit(model, datamodule=data_module)

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
data_module = MOADataModule(data_dir=data_dir, batch_size=1)
# learnability test
wandb_logger_test = WandbLogger()
# # 初始化Trainer，设置overfit_batches来过拟合一小部分数据
# lr_monitor = LearningRateMonitor(logging_interval='step')
test_trainer = pl.Trainer(overfit_batches=1, 
                          logger=wandb_logger_test, 
#                           callbacks=[lr_monitor, ValidationCallback()], 
                          check_val_every_n_epoch=1)

# # 或者，使用10个批次的数据来过拟合
# test_trainer = pl.Trainer(overfit_batches=10, logger=wandb_logger_test)

# 运行训练
test_trainer.fit(model, datamodule=data_module)
wandb.finish()

# Sweeps

In [None]:
sweep_config = {
  'method': 'random', 
  'metric': {
    'name': 'val_loss',
    'goal': 'minimize'}, 
  'parameters': {
    'learning_rate': {
      'min': 0.0001,
      'max': 0.1},
    'batch_size': {
      'values': [16, 32, 64]}
    }
}
wandb_logger = WandbLogger()
sweep_id = wandb.sweep(sweep_config, project="test sweeps2")
lr_monitor = LearningRateMonitor(logging_interval='step')

def sweep_train():
    with wandb.init() as run:
        config = wandb.config
        model = UNetTestModel(
            learning_rate=config.learning_rate,
#             batch_size=config.batch_size,
            # 其他参数...
        )
        trainer = pl.Trainer(
            max_epochs=150,
            callbacks=[lr_monitor, ValidationCallback()],
            logger=wandb_logger,
#             check_val_every_n_epoch=10,
            # 其他设置...
        )
        data_module = MOADataModule(data_dir=data_dir, batch_size=config.batch_size)
        trainer.fit(model, datamodule=data_module)

wandb.agent(sweep_id, function=sweep_train, count=10)


# Lightning Style Formal Test

In [6]:
lr_monitor = LearningRateMonitor(logging_interval='step')
# 假设你已经定义了 LiTSDataModule
data_module = MOADataModule(data_dir=data_dir, batch_size=32)
# 初始化模型和训练器
model = UNetTestModel()

# wandb_logger = WandbLogger(project="SMU MOA", name="ResUNetPP50_monaiDiceCELoss_Max150")
wandb_logger = WandbLogger(project="UNet Compare", name="MViTUNet_Max120_DiceCELoss_basicAUG")
# wandb_logger = WandbLogger()

# 设置ModelCheckpoint以每20轮保存一次模型
checkpoint_callback = ModelCheckpoint(
    dirpath='my_model/',  # 模型保存路径
    filename='model-{epoch:02d}',  # 文件名包含 epoch
    save_top_k=-1,  # 设置为 -1 以保存所有检查点
    every_n_epochs=20,  # 每20轮保存一次
    save_on_train_epoch_end=True  # 确保在训练轮结束时保存
)

trainer = Trainer(max_epochs=120, 
#                      fast_dev_run=True, 
                  logger=wandb_logger, 
                  callbacks=[lr_monitor, ValidationCallback(), checkpoint_callback],
#                   callbacks=[lr_monitor],
                  log_every_n_steps=1,
                  check_val_every_n_epoch=1,
#                   precision='16-mixed',
                  )


# tuner = Tuner(trainer)
# lr_finder = tuner.lr_find(model, datamodule=data_module)

# fig = lr_finder.plot(suggest=True)
# display(fig)

# new_lr = lr_finder.suggestion()
# model.hparams.learning_rate = new_lr


# tuner.scale_batch_size(model, datamodule=data_module, mode="power")
# print(f"Updated batch size: {data_module.batch_size}")

# 使用更新后的学习率继续训练
trainer.fit(model, datamodule=data_module)

# ##################################################### EVA ###################################################
# 确保模型处于评估模式
model.eval()
print("eval activated")

# 循环十次，每次都记录图像
for _ in range(10):
    fig = predict_and_log_images(num_samples=2)
    wandb.log({"Predicted Images": wandb.Image(fig)})

/opt/conda/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'loss_fn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss_fn'])`.
Downloading: "https://github.com/qubvel/segmentation_models.pytorch/releases/download/v0.0.2/mit_b2.pth" to /root/.cache/torch/hub/checkpoints/mit_b2.pth
100%|██████████| 94.3M/94.3M [00:01<00:00, 72.9MB/s]
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

wandb: Network error (ReadTimeout), entering retry loop.
wandb: Network error (ReadTimeout), entering retry loop.


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

wandb: Network error (ReadTimeout), entering retry loop.


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

NameError: name '确保模型处于评估模式' is not defined

In [None]:
wandb.finish()

In [7]:
# def predict_and_log_images(num_samples=2):
    
#     # 假设 test_loader 和 model 已经定义好了，并且 model 已经移动到了适当的设备
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model.to(device)
#     test_loader = data_module.test_dataloader()
    
#     # 生成随机索引
#     indices = torch.randperm(len(test_loader.dataset))[:num_samples]
#     # 调整subplot的大小
#     fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))  # 每个样本显示3张图（原图、真实掩码、预测掩码）

#     for i, idx in enumerate(indices):
#         image, mask = test_loader.dataset[idx]
#         image = image.unsqueeze(0).to(device)  # 添加batch维度并移动到设备
#         mask = mask.squeeze()  # 移除batch维度（如果有的话）

#         with torch.no_grad():
#             pred = model(image)
#             prediction = torch.argmax(pred, dim=1).cpu()  # 获取预测类别并移回CPU

#         # 显示原始图像
#         axs[i, 0].imshow(image.squeeze().cpu().numpy(), cmap='gray')
#         axs[i, 0].set_title(f'Original Image {i+1}')
#         axs[i, 0].axis('off')

#         # 显示Ground Truth
#         axs[i, 1].imshow(mask.cpu().numpy(), cmap='viridis')
#         axs[i, 1].set_title(f'True Mask {i+1}')
#         axs[i, 1].axis('off')

#         # 显示预测掩码
#         axs[i, 2].imshow(prediction[0].numpy(), cmap='viridis')
#         axs[i, 2].set_title(f'Predicted Mask {i+1}')
#         axs[i, 2].axis('off')

#     plt.tight_layout()
#     plt.close(fig)  # 防止在notebook中显示图像
#     return fig
# 循环十次，每次都记录图像
# 确保模型处于评估模式
model.eval()
print("eval activated")
for _ in range(10):
    fig = predict_and_log_images(num_samples=2)
    wandb.log({"Predicted Images 2": wandb.Image(fig)})
print("images logged")

eval activated
images logged


In [8]:
trainer.validate(model, datamodule=data_module)

Validation: |          | 0/? [00:00<?, ?it/s]

[{}]

# Following Process (Continual Training)

**if you just finished up training above**

In [9]:
import os

def list_files(startpath):
    for root, dirs, files in os.walk(startpath):
        level = root.replace(startpath, '').count(os.sep)
        indent = ' ' * 4 * (level)
        print(f'{indent}{os.path.basename(root)}/')
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            print(f'{subindent}{f}')

# 指定你的目录路径
directory_path = '/kaggle/working/UNet Compare/'
list_files(directory_path)

/
ygf7rtvn/
    checkpoints/
        epoch=119-step=3960.ckpt


In [10]:
# 创建一个新的Artifact，指定其类型为'model'和Artifact的名称
artifact = wandb.Artifact('MViTUNet_Max120_DiceCELoss_basicAUG', type='model')
artifact.add_file('/kaggle/working/UNet Compare/ygf7rtvn/checkpoints/epoch=119-step=3960.ckpt')

# 保存Artifact到wandb
wandb.log_artifact(artifact)

<Artifact MViTUNet_Max120_DiceCELoss_basicAUG>

**if you want to continual training from wandb**

In [None]:
wandb.login()
artifact_dir = WandbLogger.download_artifact(artifact="southern/UNet Compare/ResUNet_Max120_DiceCELossWithKL_basicAug:v0")
model_ckpt = UNetTestModel.load_from_checkpoint('/kaggle/working/artifacts/ResUNet_Max120_DiceCELossWithKL_basicAug:v0/epoch=119-step=3960.ckpt')

In [None]:
# 定义数据增强
train_transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.2),
    A.VerticalFlip(p=0.2),
    A.RandomRotate90(),
#     A.RandomBrightnessContrast(p=0.2),
    A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.3, alpha_affine=120 * 0.2),
    A.RandomSizedCrop(min_max_height=(128, 256), height=256, width=256, p=0.5),
    # A.Normalize(mean=(0.5,), std=(0.5,)),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(256, 256),
    # A.Normalize(mean=(0.5,), std=(0.5,)),
    ToTensorV2(),
])

def adjust_window(image, window_center, window_width):
    """
    调整CT图像的窗宽窗位。
    :param image: 输入的图像数组。
    :param window_center: 窗位（WC）。
    :param window_width: 窗宽（WW）。
    :return: 调整窗宽窗位后的图像。
    """
    img_min = window_center - window_width // 2
    img_max = window_center + window_width // 2
    windowed_img = np.clip(image, img_min, img_max)
    # print(windowed_img.dtype) # NOW its float64
    return windowed_img

class MultipleImageDataset(Dataset):
    def __init__(self, image_paths, label_paths, transform=None):
        """
        image_paths: 图像文件路径列表
        label_paths: 标签文件路径列表
        transform: 应用于图像和标签的转换操作
        """
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.transform = transform

    def __len__(self):
        # 假设图像和标签列表长度相等
        return len(self.image_paths)

    def __getitem__(self, idx): # dataloader获取每个数据都会用到这个函数，所以你应当在这里实现你需要的
                                # 预处理等等步骤
        image = (np.load(self.image_paths[idx]))['arr_0'] 
#         print(image.shape)
        label = (np.load(self.label_paths[idx]))['arr_0']
#         print(label.shape)

        image = adjust_window(image, window_center=40, window_width=400)

        if self.transform:
#             image = image.astype(np.float32)
            augmented = self.transform(image=image, mask=label)
            image = augmented['image']
#             print("image aug")
            label = augmented['mask']
            image = image.float()
            label = label.long()

        label = label.long()
        image = (image - image.min()) / (image.max() - image.min())
        return image.float(), label.long()
######################################################################################################################
class MOADataModule(LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int = 16):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_transform = train_transform
        self.val_transform = val_transform

    def setup(self, stage=None):
        image_dir = os.path.join(self.data_dir, 'image_npz')  # 注意这里路径的更正
        label_dir = os.path.join(self.data_dir, 'mask_npz')

        # 读取文件路径
        image_files = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.npz')])
        label_files = sorted([os.path.join(label_dir, f) for f in os.listdir(label_dir) if f.endswith('.npz')])

        # 划分训练集、验证集、测试集
        train_size = int(0.8 * len(image_files))
        val_size = int(0.1 * len(image_files))

        self.train_image_paths = image_files[:train_size]
        self.val_image_paths = image_files[train_size:train_size + val_size]
        self.test_image_paths = image_files[train_size + val_size:]

        self.train_label_paths = label_files[:train_size]
        self.val_label_paths = label_files[train_size:train_size + val_size]
        self.test_label_paths = label_files[train_size + val_size:]

    def train_dataloader(self):
        train_dataset = MultipleImageDataset(self.train_image_paths, self.train_label_paths, transform=self.train_transform)
        return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    def val_dataloader(self):
        val_dataset = MultipleImageDataset(self.val_image_paths, self.val_label_paths, transform=self.val_transform)
        return DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2, pin_memory=True)

    def test_dataloader(self):
        test_dataset = MultipleImageDataset(self.test_image_paths, self.test_label_paths, transform=self.val_transform)
        return DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2, pin_memory=True)

# 定义数据集和数据加载器
# WARNING: 下面的数据目录是针对特定Kaggle数据集的示例路径。
#          请根据您的实际数据集位置修改此路径。
data_dir = '/kaggle/input/aug-dataset-for-fine-tune/AUG_dataset'


def draw_bounding_boxes(ax, mask, class_ids):
    """ 在给定的轴上绘制边界框，突出显示特定类别 """
    for class_id in class_ids:
        positions = np.argwhere(mask == class_id)
        if positions.size > 0:
            xmin, xmax = positions[:, 0].min(), positions[:, 0].max()
            ymin, ymax = positions[:, 1].min(), positions[:, 1].max()
            rect = Rectangle((ymin, xmin), ymax - ymin, xmax - xmin, linewidth=2, edgecolor='red', facecolor='none')
            ax.add_patch(rect)
            ax.text(ymin, xmin, f'Class {class_id}', color='red', fontsize=12, va='top', ha='left')

def predict_and_log_images(num_samples=2, model=model_ckpt, data_module=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    test_loader = data_module.test_dataloader()

    # 生成随机索引
    indices = torch.randperm(len(test_loader.dataset))[:num_samples]
    special_classes = [4, 5, 10, 12, 13]

    fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))
    cmap = plt.get_cmap('tab20')
    for i, idx in enumerate(indices):
        image, mask = test_loader.dataset[idx]
        class_labels = np.unique(mask)
        colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
        image = image.unsqueeze(0).to(device)
        mask = mask.squeeze()

        with torch.no_grad():
            pred = model(image)
            prediction = torch.argmax(pred, dim=1).cpu().squeeze()

        axs[i, 0].imshow(image.squeeze().cpu().numpy(), cmap='gray')
        axs[i, 0].set_title(f'Original Image {i+1}')
        axs[i, 0].axis('off')

        axs[i, 1].imshow(mask.cpu().numpy(), cmap='tab20')
        axs[i, 1].set_title(f'True Mask {i+1}')
        axs[i, 1].axis('off')

        axs[i, 2].imshow(prediction.numpy(), cmap='tab20')
        axs[i, 2].set_title(f'Predicted Mask {i+1}')
        axs[i, 2].axis('off')

        # 在真实和预测掩码上绘制边界框
        draw_bounding_boxes(axs[i, 1], mask.cpu().numpy(), special_classes)
        draw_bounding_boxes(axs[i, 2], prediction.numpy(), special_classes)

    legend_elements = [Patch(facecolor=colors[i], label=f'Class {class_labels[i]}') for i in range(len(class_labels))]
    fig.legend(handles=legend_elements, loc='upper center', ncol=len(class_labels), title="Classes")
    plt.tight_layout()
    plt.close(fig)  # 防止在notebook中显示图像

    return fig


class ValidationCallback(Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        # 每10个epoch执行一次
        if (trainer.current_epoch + 1) % 1 == 0:
            fig = predict_and_log_images(num_samples=2,model=model_ckpt,data_module=data_module)
            # 用wandb记录图像，或进行其他操作
            wandb.log({"Validation Callback Predicted Images": wandb.Image(fig)})

In [None]:

#  初始化数据模块
data_module = MOADataModule(data_dir='/kaggle/input/aug-dataset-for-fine-tune/AUG_dataset', batch_size=16)
# 设置数据模块（准备数据）
data_module.setup()
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_ckpt.to(device)
model_ckpt.eval()  # 设置为评估模式

# 获取测试数据
test_loader = data_module.test_dataloader()
special_classes = [4, 5, 10, 12, 13]
# 取一个批次的数据进行演示
for images, masks in test_loader:
    images, masks = images.to(device), masks.to(device)
    with torch.no_grad():
        predictions = model_ckpt(images)  # 进行预测
    predictions = torch.argmax(predictions, dim=1)  # 转换成类别标签
    break  # 这里我们只处理一个批次作为示例

# 将数据转移到CPU并转换为numpy
images = images.cpu().numpy()
masks = masks.cpu().numpy()
predictions = predictions.cpu().numpy()

# 显示图像、真实掩码和预测掩码
fig, axs = plt.subplots(len(images), 3, figsize=(20, 5 * len(images)))  # 根据批次大小设置子图
class_labels = np.unique(masks)  # 获取类别标签
cmap = plt.get_cmap('tab20')  # 获取颜色映射
colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]

for i, (img, mask, pred) in enumerate(zip(images, masks, predictions)):
    axs[i, 0].imshow(img[0], cmap='gray')  # 假设图片是单通道的
    axs[i, 0].set_title('Original Image')
    axs[i, 0].axis('off')

    axs[i, 1].imshow(mask, cmap=cmap)  # 假设掩码是单通道的
    axs[i, 1].set_title('True Mask')
    axs[i, 1].axis('off')

    axs[i, 2].imshow(pred, cmap=cmap)
    axs[i, 2].set_title('Predicted Mask')
    axs[i, 2].axis('off')
    
    draw_bounding_boxes(axs[i, 1], mask, special_classes)
    draw_bounding_boxes(axs[i, 2], mask, special_classes)

# 创建图例
legend_elements = [Patch(facecolor=colors[i], label=f'Class {class_labels[i]}') for i in range(len(class_labels))]
fig.legend(handles=legend_elements, loc='upper center', bbox_to_anchor=(0.5, 1), ncol=len(class_labels), title="Classes")

plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # 调整子图布局，留出底部空间给图例

plt.show()


In [None]:

# wandb_logger = WandbLogger(project="SMU MOA", name="ResUNetPP50_monaiDiceCELoss_Max150")
wandb_logger = WandbLogger(project="UNet Compare", name="Continual_ResUNet_Max120+120_DiceCELosswithKL_HeavyAug_FreqAug")
# wandb_logger = WandbLogger()

# 设置ModelCheckpoint以每20轮保存一次模型
checkpoint_callback = ModelCheckpoint(
    dirpath='my_model/',  # 模型保存路径
    filename='model-{epoch:02d}',  # 文件名包含 epoch
    save_top_k=-1,  # 设置为 -1 以保存所有检查点
    every_n_epochs=20,  # 每20轮保存一次
    save_on_train_epoch_end=True  # 确保在训练轮结束时保存
)

lr_monitor = LearningRateMonitor(logging_interval='step')
# 假设你已经定义了 LiTSDataModule
data_module = MOADataModule(data_dir=data_dir, batch_size=16)
# 初始化模型和训练器
model = UNetTestModel()




trainer = Trainer(max_epochs=120, 
#                      fast_dev_run=True, 
                  logger=wandb_logger, 
                  callbacks=[lr_monitor, checkpoint_callback, ValidationCallback()],
#                   callbacks=[lr_monitor],
                  log_every_n_steps=1,
                  check_val_every_n_epoch=1,
#                   precision='16-mixed',
                  )

# 使用更新后的学习率继续训练
trainer.fit(model, datamodule=data_module)

#  初始化数据模块
data_module = MOADataModule(data_dir='/kaggle/input/aug-dataset-for-fine-tune/AUG_dataset', batch_size=16)
# 设置数据模块（准备数据）
data_module.setup()
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_ckpt.to(device)
model_ckpt.eval()  # 设置为评估模式

# 获取测试数据
test_loader = data_module.test_dataloader()
special_classes = [4, 5, 10, 12, 13]
# 取一个批次的数据进行演示
for images, masks in test_loader:
    images, masks = images.to(device), masks.to(device)
    with torch.no_grad():
        predictions = model_ckpt(images)  # 进行预测
    predictions = torch.argmax(predictions, dim=1)  # 转换成类别标签
    break  # 这里我们只处理一个批次作为示例

# 将数据转移到CPU并转换为numpy
images = images.cpu().numpy()
masks = masks.cpu().numpy()
predictions = predictions.cpu().numpy()

# 显示图像、真实掩码和预测掩码
fig, axs = plt.subplots(len(images), 3, figsize=(20, 5 * len(images)))  # 根据批次大小设置子图
class_labels = np.unique(masks)  # 获取类别标签
cmap = plt.get_cmap('tab20')  # 获取颜色映射
colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]

for i, (img, mask, pred) in enumerate(zip(images, masks, predictions)):
    axs[i, 0].imshow(img[0], cmap='gray')  # 假设图片是单通道的
    axs[i, 0].set_title('Original Image')
    axs[i, 0].axis('off')

    axs[i, 1].imshow(mask, cmap=cmap)  # 假设掩码是单通道的
    axs[i, 1].set_title('True Mask')
    axs[i, 1].axis('off')

    axs[i, 2].imshow(pred, cmap=cmap)
    axs[i, 2].set_title('Predicted Mask')
    axs[i, 2].axis('off')
    
    draw_bounding_boxes(axs[i, 1], mask, special_classes)
    draw_bounding_boxes(axs[i, 2], mask, special_classes)

# 创建图例
legend_elements = [Patch(facecolor=colors[i], label=f'Class {class_labels[i]}') for i in range(len(class_labels))]
fig.legend(handles=legend_elements, loc='upper center', bbox_to_anchor=(0.5, 1), ncol=len(class_labels), title="Classes")

plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # 调整子图布局，留出底部空间给图例
plt.show()
wandb.log({"Predicted Images After Continual Train": wandb.Image(fig)})


In [None]:
wandb.finish()