In [1]:
import os 
import timm
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

In [2]:
from multissl.models import MSRGBConvNeXtUPerNetMixed
from multissl.data.seg_transforms import JointTransform, ValidationJointTransform
from multissl.data.semantic_partial_dataset import MixedSupervisionSegmentationDataset, mixed_supervision_collate_fn

In [8]:

args = {"checkpoint_path":"../checkpoints_convnext_tiny/last.ckpt",
        "num_classes": 62, #61+ 0 background
        "freeze_backbone": True,
        "batch_size": 4,
        "img_size": 512,
        "model_size": "tiny",
        "rgb_in_channels": 3, #pretrained backbone setting
        "ms_in_channels": 5,#pretrained backbone setting
        "model_size":"tiny", # Can be 'tiny', 'small', 'base', 'large' #pretrained backbone setting
        "fusion_strategy": "hierarchical", # 'early', 'late', 'hierarchical', 'progressive' #pretrained backbone setting
        "fusion_type": 'attention',  # 'concat', 'add', 'attention' #pretrained backbone setting
        "learning_rate": 1e3,
        "weight_decay": 1e4,

        "fully_labeled_dir_train" : "../dataset/TreeAI/12_RGB_SemSegm_640_fL/train",
        "partially_labeled_dir_train": "../dataset/TreeAI/34_RGB_SemSegm_640_pL/train",
        "fully_labeled_dir_val" : "../dataset/TreeAI/12_RGB_SemSegm_640_fL/val",
        "partially_labeled_dir_val" :  "../dataset/TreeAI/34_RGB_SemSegm_640_pL/val",
        "ignore_index" : 255, # probably just background class too :(
        "balance_supervision" :  True,
        "partial_label_ratio":  0.5  # W"
       }
# pretrained tiny has hierarchical fusion: at every layer MS +RGB is fused with attention

pl_model =  MSRGBConvNeXtUPerNetMixed(
        num_classes=args["num_classes"],  # Binary segmentation (background, foreground)
        rgb_in_channels=args["rgb_in_channels"],
        ms_in_channels=args["ms_in_channels"],
        model_size=args['model_size'], 
        fusion_strategy=args['fusion_strategy'], 
        fusion_type=args['fusion_type'],  # 'concat', 'add', 'attention'
        learning_rate=args["learning_rate"],
        weight_decay=args["weight_decay"],
        pretrained_backbone=args["checkpoint_path"],  # Path to pretrained weights if available
        freeze_backbone = args["freeze_backbone"]
    )

Loading checkpoint from ../checkpoints_convnext_tiny/last.ckpt
Unexpected keys: ['projection_head.layers.0.weight', 'projection_head.layers.1.weight', 'projection_head.layers.1.bias', 'projection_head.layers.1.running_mean', 'projection_head.layers.1.running_var', 'projection_head.layers.1.num_batches_tracked', 'projection_head.layers.3.weight', 'projection_head.layers.4.weight', 'projection_head.layers.4.bias', 'projection_head.layers.4.running_mean', 'projection_head.layers.4.running_var', 'projection_head.layers.4.num_batches_tracked', 'projection_head.layers.6.weight', 'projection_head.layers.7.running_mean', 'projection_head.layers.7.running_var', 'projection_head.layers.7.num_batches_tracked', 'prediction_head.layers.0.weight', 'prediction_head.layers.1.weight', 'prediction_head.layers.1.bias', 'prediction_head.layers.1.running_mean', 'prediction_head.layers.1.running_var', 'prediction_head.layers.1.num_batches_tracked', 'prediction_head.layers.3.weight', 'prediction_head.layers.

In [9]:
train_transform = JointTransform(img_size = args["img_size"], strong=True)
val_transform = ValidationJointTransform(img_size = 640)

treeai_dataset_train = MixedSupervisionSegmentationDataset(
        fully_labeled_dir = args["fully_labeled_dir_train"],
        partially_labeled_dir = args["partially_labeled_dir_train"],
        img_size =  args["img_size"],
        transform = train_transform,
        ignore_index= args["ignore_index"],
        num_classes =  args["num_classes"],
        balance_supervision = args["balance_supervision"],
        partial_label_ratio = args["partial_label_ratio"] # W
    
)

treeai_dataset_val = MixedSupervisionSegmentationDataset(
        fully_labeled_dir = args["fully_labeled_dir_val"],
        partially_labeled_dir = args["partially_labeled_dir_val"],
        img_size = 640,
        transform = val_transform,
        ignore_index= args["ignore_index"],
        num_classes =  args["num_classes"],
        balance_supervision = args["balance_supervision"],
        partial_label_ratio = args["partial_label_ratio"]
)
train_loader = DataLoader(treeai_dataset_train,
                         collate_fn =mixed_supervision_collate_fn )
val_loader = DataLoader(treeai_dataset_val,
                       collate_fn =mixed_supervision_collate_fn )

Found 2218 fully labeled samples
Found 3126 partially labeled samples
Final dataset size: 5344
Dataset composition:
  - Fully supervised: 2672 (50.0%)
  - Partially supervised: 2672 (50.0%)
Found 634 fully labeled samples
Found 892 partially labeled samples
Final dataset size: 1526
Dataset composition:
  - Fully supervised: 763 (50.0%)
  - Partially supervised: 763 (50.0%)


In [None]:
# Create callbacks
from pytorch_lightning.callbacks import RichProgressBar, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger


checkpoint_callback = ModelCheckpoint(
    dirpath="treeai_checkpoints",
    filename="",
    save_top_k=3,
    monitor"val_loss",
    mode="min",
    save_last=True
)

lr_monitor = LearningRateMonitor(logging_interval="epoch")
# Create logger (you can use either WandbLogger or TensorBoardLogger)
# Comment out if you don't want to use wandb

wandb_logger = WandbLogger(project="TreeAI-Segmentation", log_model=False)

# Create trainer
trainer = pl.Trainer(
    max_epochs=args["epochs"],
    accelerator="cuda",  # Uses GPU if available
    devices=1,
    callbacks=[checkpoint_callback, lr_monitor, RichProgressBar()],
    logger=wandb_logger,  # Comment out if not using wandb
    check_val_every_n_epoch=5  # Only validate every 5 epochs
)
