In [1]:
import os 
import timm
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from lightly.data.dataset import LightlyDataset



In [2]:
from multissl.models import MSRGBInstanceModule
from multissl.plotting.rgb_batch import rgb_visualize_batch
from multissl.data.instance_segmentation_dataset import COCOInstanceSegmentationDataset, get_instance_transforms,instance_segmentation_collate_fn

# Segmentation Head Loading

In [3]:

# Experiment configuration shared by the cells below.
args = dict(
    checkpoint_path="../checkpoints_convnext_tiny/last.ckpt",
    num_classes=2,
    class_names=["Background", "Solarpanel"],
    freeze_backbone=False,
    batch_size=8,
    img_size=448,
)

# The pretrained "tiny" checkpoint uses hierarchical fusion: MS and RGB features
# are fused with attention at every layer, so the module is configured to match.
pl_model = MSRGBInstanceModule(
    num_classes=args["num_classes"],              # binary segmentation (background, foreground)
    rgb_in_channels=3,
    ms_in_channels=5,                             # adjust to the number of bands in your MS data
    model_size='tiny',                            # one of 'tiny', 'small', 'base', 'large'
    fusion_strategy='hierarchical',               # 'early', 'late', 'hierarchical', 'progressive'
    fusion_type='attention',                      # 'concat', 'add', 'attention'
    lr=1e-4,
    weight_decay=1e-4,
    pretrained_backbone=args["checkpoint_path"],  # path to pretrained weights, if available
    freeze_backbone=args["freeze_backbone"],
)

Loading checkpoint from ../checkpoints_convnext_tiny/last.ckpt
Unexpected keys: ['projection_head.layers.0.weight', 'projection_head.layers.1.weight', 'projection_head.layers.1.bias', 'projection_head.layers.1.running_mean', 'projection_head.layers.1.running_var', 'projection_head.layers.1.num_batches_tracked', 'projection_head.layers.3.weight', 'projection_head.layers.4.weight', 'projection_head.layers.4.bias', 'projection_head.layers.4.running_mean', 'projection_head.layers.4.running_var', 'projection_head.layers.4.num_batches_tracked', 'projection_head.layers.6.weight', 'projection_head.layers.7.running_mean', 'projection_head.layers.7.running_var', 'projection_head.layers.7.num_batches_tracked', 'prediction_head.layers.0.weight', 'prediction_head.layers.1.weight', 'prediction_head.layers.1.bias', 'prediction_head.layers.1.running_mean', 'prediction_head.layers.1.running_var', 'prediction_head.layers.1.num_batches_tracked', 'prediction_head.layers.3.weight', 'prediction_head.layers.

In [4]:
# Locations of the COCO-style solar panel data: images, annotation JSON and
# the instance directory handed to the dataset below.
image_path = "../dataset/solarcoco/imgs"
mask_path = "../dataset/solarcoco/annotations/frame_000003.JSON"
instance_path = "../dataset/solarcoco/instance"

# Build the training dataset with augmentation enabled at the configured image size.
dataset = COCOInstanceSegmentationDataset(
    coco_json_path=mask_path,
    img_dir=image_path,
    instance_dir=instance_path,
    transform=get_instance_transforms(img_size=args["img_size"], augment=True),
)

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


In [5]:

import torch
from torch.utils.data import Sampler
import numpy as np


class RepeatingBatchSampler(Sampler):
    """
    Index sampler that repeats dataset indices so that one pass yields at least
    ``batch_size`` indices.

    Useful when working with small datasets or when you want to apply heavy
    augmentation to a small set of samples.

    Despite the name, iteration yields individual indices (not lists of
    indices), so the instance is meant to be passed as ``sampler=`` to a
    ``DataLoader``, which then groups the indices into batches.

    Length of one pass:
        * ``dataset_size >= batch_size``: every index appears exactly once
          (``dataset_size`` indices in total).
        * ``dataset_size < batch_size``: every index appears
          ``batch_size // dataset_size`` times and the first
          ``batch_size % dataset_size`` indices of the (possibly shuffled) base
          order appear once more, giving exactly ``batch_size`` indices.
    """
    def __init__(self, dataset_size, batch_size, shuffle=True):
        """
        Args:
            dataset_size: Number of samples in the original dataset (must be > 0)
            batch_size: Desired batch size (must be > 0)
            shuffle: Whether to shuffle the data or access sequentially

        Raises:
            ValueError: If ``dataset_size`` or ``batch_size`` is not positive.
        """
        # Fail early with a clear message instead of a ZeroDivisionError below
        # (or a late failure inside torch.randperm for negative sizes).
        if dataset_size <= 0:
            raise ValueError(f"dataset_size must be positive, got {dataset_size}")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Number of copies needed for each dataset item
        self.copies_per_item = max(1, batch_size // dataset_size)

        # Extra samples needed beyond perfect division (never negative)
        self.extra_samples = max(0, batch_size - dataset_size * self.copies_per_item)

        # Total number of indices generated per pass
        self.total_samples = dataset_size * self.copies_per_item + self.extra_samples

    def __iter__(self):
        # Base order of the dataset indices for this pass
        if self.shuffle:
            base_indices = torch.randperm(self.dataset_size).tolist()
        else:
            base_indices = list(range(self.dataset_size))

        # Repeat each index the required number of times, keeping copies adjacent
        repeated_indices = [
            idx for idx in base_indices for _ in range(self.copies_per_item)
        ]

        # Top up with the leading base indices (a no-op when extra_samples == 0)
        repeated_indices.extend(base_indices[:self.extra_samples])

        # Mix the copies. The torch RNG is used here as well, so that
        # torch.manual_seed alone makes the whole order reproducible.
        if self.shuffle:
            order = torch.randperm(len(repeated_indices)).tolist()
            repeated_indices = [repeated_indices[pos] for pos in order]

        return iter(repeated_indices)

    def __len__(self):
        return self.total_samples

In [6]:


import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from matplotlib.colors import ListedColormap
import random

def _to_numpy(value):
    """Return ``value`` as a NumPy array if it is a torch tensor; otherwise return it unchanged.

    Tensors are detached and copied to host memory first, because
    ``Tensor.numpy()`` raises for CUDA tensors and for tensors that require grad.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return value


def plot_batch_item(batch_item, figsize=(15, 5), show_masks=True, show_boxes=True, 
                   mask_alpha=0.3, box_linewidth=4):
    """
    Plot a single item from a batch containing:
    - rgb: RGB image tensor (3, H, W)
    - boxes: List of bounding boxes in format [confidence, cx, cy, w, h] (normalized)
    - instance_masks: List of instance segmentation masks
    - instance_classes: Class labels for each instance
    - mask: Semantic segmentation mask

    Tensors may live on any device; they are moved to the CPU before plotting.

    Args:
        batch_item: Dictionary with keys 'rgb', 'boxes', 'instance_masks', 'instance_classes', 'mask'
            (all five keys are read, whichever panels are shown)
        figsize: Figure size tuple
        show_masks: Whether to add a third panel with the instance masks overlaid on the image
        show_boxes: Whether to draw bounding boxes on the first panel
        mask_alpha: Transparency for mask overlay
        box_linewidth: Line width for bounding boxes

    Returns:
        (fig, axes): the matplotlib figure and its array of axes
        (three panels when ``show_masks`` is true, otherwise two).
    """

    # Extract data
    rgb_tensor = batch_item['rgb']
    boxes = batch_item['boxes']
    instance_masks = batch_item['instance_masks']
    instance_classes = batch_item['instance_classes']
    semantic_mask = _to_numpy(batch_item['mask'])

    # Convert a (3, H, W) tensor to an (H, W, 3) array; non-tensor input is used as given
    if isinstance(rgb_tensor, torch.Tensor):
        rgb_img = _to_numpy(rgb_tensor.permute(1, 2, 0))
    else:
        rgb_img = rgb_tensor

    # Ensure RGB values are in [0, 1] range
    rgb_img = np.clip(rgb_img, 0, 1)

    # Image dimensions, needed to de-normalize the box coordinates
    height, width = rgb_img.shape[:2]

    # There are always at least two panels, so `axes` is always an array
    n_plots = 3 if show_masks else 2
    fig, axes = plt.subplots(1, n_plots, figsize=figsize)

    # Plot 1: Original RGB image with bounding boxes
    axes[0].imshow(rgb_img)
    axes[0].set_title('RGB Image with Bounding Boxes')
    axes[0].axis('off')

    if show_boxes:
        # One color per instance
        colors = plt.cm.Set3(np.linspace(0, 1, len(boxes)))

        for i, (box, class_id) in enumerate(zip(boxes, instance_classes)):
            box = _to_numpy(box)
            if isinstance(class_id, torch.Tensor):
                class_id = class_id.item()

            # Box layout is (confidence, cx, cy, w, h) with normalized coordinates
            confidence = box[0]
            cx, cy, w, h = box[1:]

            # Convert normalized coordinates to pixel coordinates
            box_width = w * width
            box_height = h * height

            # Convert center coordinates to top-left coordinates
            x_left = cx * width - box_width / 2
            y_top = cy * height - box_height / 2

            rect = patches.Rectangle(
                (x_left, y_top), box_width, box_height,
                linewidth=box_linewidth, edgecolor=colors[i], 
                facecolor='none', alpha=0.8
            )
            axes[0].add_patch(rect)

            # Label slightly above the top-left corner of the box
            axes[0].text(x_left, y_top - 5, f'Class {class_id} ({confidence:.2f})',
                        color=colors[i], fontsize=10, fontweight='bold',
                        bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7))

    # Plot 2: Semantic segmentation mask
    axes[1].imshow(semantic_mask, cmap='viridis')
    axes[1].set_title('Semantic Segmentation Mask')
    axes[1].axis('off')

    # Plot 3: Instance masks overlay (if requested)
    if show_masks:
        overlay_img = rgb_img.copy()

        # One color per instance
        colors = plt.cm.Set3(np.linspace(0, 1, len(instance_masks)))

        for i, mask in enumerate(instance_masks):
            # Pixels equal to 1 belong to the instance
            mask = _to_numpy(mask) == 1

            # Solid-color image that is non-zero only inside the instance
            color_mask = np.zeros_like(rgb_img)
            color_mask[mask] = colors[i][:3]  # Use RGB components

            # Alpha-blend the color into the image, only where the mask is set
            mask_area = mask[..., np.newaxis]
            overlay_img = overlay_img * (1 - mask_area * mask_alpha) + color_mask * mask_area * mask_alpha

        axes[2].imshow(np.clip(overlay_img, 0, 1))
        axes[2].set_title('RGB Image with Instance Masks')
        axes[2].axis('off')

    plt.tight_layout()
    return fig, axes

def plot_batch(batch, max_items=4, **kwargs):
    """
    Plot one batch item, or the first few items of a sequence of batch items.

    Args:
        batch: Either a single batch item (a dict) or an indexable sequence of them
        max_items: Upper bound on how many items of a sequence are plotted
        **kwargs: Forwarded unchanged to plot_batch_item

    Returns:
        The (fig, axes) pair from plot_batch_item when `batch` is a dict;
        None when a sequence is plotted (each figure is shown immediately).
    """
    if not isinstance(batch, dict):
        # Sequence of items: plot and show them one at a time
        for item_idx in range(min(len(batch), max_items)):
            print(f"\nPlotting batch item {item_idx}:")
            plot_batch_item(batch[item_idx], **kwargs)
            plt.show()
        return None

    # Single item: hand the figure back to the caller
    return plot_batch_item(batch, **kwargs)


In [7]:
# Repeat dataset indices so that a pass over the sampler fills at least one batch.
sampler = RepeatingBatchSampler(
    dataset_size=len(dataset),
    batch_size=args["batch_size"],
    shuffle=True,
)
collate_fn = instance_segmentation_collate_fn

# Training dataloader; incomplete trailing batches are dropped.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args["batch_size"],
    sampler=sampler,
    collate_fn=collate_fn,
    num_workers=0,  # use 0 for a single image to avoid worker overhead
    drop_last=True,
)

In [8]:
# Create progress bar callback
from pytorch_lightning.callbacks import RichProgressBar
class LossProgressBar(RichProgressBar):
    """Rich progress bar that also records 'train_total_loss' at the end of each training epoch."""

    def __init__(self):
        super().__init__()
        # One float is appended per finished training epoch.
        self.losses = []

    def on_train_epoch_end(self, trainer, pl_module):
        """Run the parent hook, then store the epoch's 'train_total_loss' (0 if it was not logged)."""
        super().on_train_epoch_end(trainer, pl_module)
        epoch_loss = trainer.callback_metrics.get('train_total_loss', 0)
        self.losses.append(float(epoch_loss))
        
progress_bar = LossProgressBar()

# Keep only the checkpoint with the lowest training loss.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filename='pasiphae-upernet-{epoch:02d}-{train_total_loss:.4f}',
    monitor='train_total_loss',
    mode='min',
    save_top_k=1,
)

# Stop once the training loss has not improved for `patience` checks.
# NOTE(review): patience=50 exceeds max_epochs=10 below, so with these settings
# early stopping presumably never triggers - confirm this is intended.
early_stop_callback = pl.callbacks.EarlyStopping(
    monitor='train_total_loss',
    mode='min',
    patience=50,
)

# Single-device CUDA trainer that logs every step.
trainer = pl.Trainer(
    accelerator="cuda",
    devices=1,
    max_epochs=10,
    logger=True,
    log_every_n_steps=1,
    callbacks=[progress_bar, checkpoint_callback, early_stop_callback],
)

# NOTE(review): the recorded output of this cell ends with
# "AttributeError: 'list' object has no attribute 'device'". The cause is not
# visible in this notebook; possibly the collated batch contains lists where the
# module expects tensors - TODO confirm against the collate function and the module.
print("Starting training...")
trainer.fit(pl_model, dataloader)

print("Training complete!")




GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
C:\Users\judoj\mambaforge\envs\lightly\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
C:\Users\judoj\mambaforge\envs\lightly\lib\site-packages\pytorch_lightning\trainer\configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Starting training...


Output()

C:\Users\judoj\mambaforge\envs\lightly\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


AttributeError: 'list' object has no attribute 'device'

In [None]:
# Plot loss curve
# Per-epoch training losses collected by the progress-bar callback.
losses = progress_bar.losses

plt.figure(figsize=(10, 5))
plt.plot(losses)
# Title and axis labels are set in one call on the current axes.
plt.gca().set(title='Training Loss', xlabel='Epoch', ylabel='Loss')
plt.grid(True)
plt.show()


In [None]:
# Apply the model to an unlabeled image folder.
# Images are resized to the training resolution and converted to tensors.
rgb_transform = transforms.Compose([
    transforms.Resize(
        (args["img_size"], args["img_size"]),
        interpolation=transforms.InterpolationMode.BILINEAR,
    ),
    transforms.ToTensor(),
])

# Single source of truth for the unlabeled image folder.
solar_folder = "../dataset/pvhawk"

# The dataset and its loader get separate names so `non_labeled` is always the
# DataLoader the inference cell iterates over.
non_labeled_dataset = LightlyDataset(input_dir=solar_folder, transform=rgb_transform)

# batch_size=1 and shuffle=False are the DataLoader defaults; they are spelled
# out because the inference loop indexes element 0 of every batch.
non_labeled = torch.utils.data.DataLoader(non_labeled_dataset, batch_size=1, shuffle=False)

In [None]:
# Process all images
from tqdm import tqdm

def _first_or_none(batched):
    """Return element 0 of ``batched``, or None when it is None or empty."""
    if batched is None or len(batched) == 0:
        return None
    return batched[0]


# Prefer the GPU, but fall back to the CPU so the cell also runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pl_model.to(device)
pl_model.eval()

all_results = []

# Iterate the DataLoader itself (not iter(...)) so tqdm can read its length
# and display a total / ETA.
for rgb, _target, img_name in tqdm(non_labeled):
    rgb = rgb.to(device)

    # Forward pass and instance post-processing both run without autograd, so
    # every tensor converted to NumPy below is free of gradient tracking.
    with torch.no_grad():
        raw_outputs = pl_model(rgb=rgb)
        # Semantic prediction: arg-max over dim 1 of the semantic logits
        pred_mask = torch.argmax(raw_outputs["sem_logits"], dim=1).cpu().numpy()
        outputs = pl_model.predict_instances(
            raw_outputs,
            confidence_threshold=0.5
        )

    # Instance masks, if available (one array per entry of outputs['instance_masks'])
    instance_masks = None
    if 'instance_masks' in outputs:
        instance_masks = [m.cpu().numpy() for m in outputs['instance_masks']]

    # Instance boxes, if available
    instance_boxes = outputs.get('boxes', None)
    if instance_boxes is not None:
        instance_boxes = instance_boxes.cpu().numpy()

    # Predicted class per instance, if class scores are available
    instance_class_ids = None
    cls_scores = outputs.get('cls_scores', None)
    if cls_scores is not None:
        # Scores with more than two dims are treated as [B, N, C]; otherwise the class axis is dim 1
        class_dim = 2 if cls_scores.dim() > 2 else 1
        instance_class_ids = torch.argmax(cls_scores, dim=class_dim).cpu().numpy()

    # Center heatmap, if available
    center_heatmap = None
    if 'center_heatmap' in outputs:
        center_heatmap = outputs['center_heatmap'].cpu().numpy()

    # Input image back on the CPU without the batch dimension, clamped to [0, 1] for display
    orig_img = torch.clamp(rgb.cpu().squeeze(0), 0, 1)

    # Default collation of file names can yield a tuple as well as a list
    filename = img_name[0] if isinstance(img_name, (list, tuple)) else img_name

    # Store everything for the first (only) image of the batch; missing or
    # empty predictions are stored as None instead of raising IndexError.
    all_results.append({
        'image': orig_img.permute(1, 2, 0).numpy(),  # HWC format for visualization
        'mask': pred_mask[0],  # Semantic segmentation mask
        'instance_masks': _first_or_none(instance_masks),  # Instance masks
        'instance_boxes': _first_or_none(instance_boxes),  # Bounding boxes
        'instance_classes': _first_or_none(instance_class_ids),  # Class IDs
        'center_heatmap': _first_or_none(center_heatmap),  # Center points
        'filename': filename
    })

In [None]:
import random

# Upper bound on how many results are visualized.
num_samples = 10

# Draw a random subset (without replacement) when there are more results than
# the bound; otherwise use every result as is.
random_samples = (
    random.sample(all_results, num_samples)
    if len(all_results) > num_samples
    else all_results
)

In [None]:
# Visualize the results
def visualize_segmentation(samples, save_path='segmentation_results.png'):
    """
    Visualize segmentation results: one row per sample, showing the original image,
    the semantic segmentation mask and, when present in at least one sample, the
    instance masks, the bounding boxes and the center heatmap.

    Args:
        samples: List of dictionaries containing results for each sample. Keys read:
            'image' and 'mask' (required), 'instance_masks', 'instance_boxes',
            'instance_classes', 'center_heatmap' (optional, may be None).
        save_path: Kept for backward compatibility. Writing the figure to disk is
            currently disabled, so nothing is saved and this value is unused.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    from matplotlib.patches import Rectangle

    # plt.subplots cannot create a grid with zero rows
    if not samples:
        print("No samples to visualize.")
        return

    # Optional columns are added when at least one sample provides the data
    has_instances = any(sample.get('instance_masks') is not None for sample in samples)
    has_boxes = any(sample.get('instance_boxes') is not None for sample in samples)
    has_centers = any(sample.get('center_heatmap') is not None for sample in samples)

    # Original image + semantic mask, plus one column per optional output
    num_cols = 2 + int(has_instances) + int(has_boxes) + int(has_centers)

    # squeeze=False always returns a 2-D axes array, even for a single sample
    fig, axes = plt.subplots(len(samples), num_cols,
                             figsize=(4*num_cols, 4*len(samples)), squeeze=False)

    # Hide every frame up front so panels a sample does not fill stay blank
    for ax in axes.ravel():
        ax.axis('off')

    # Distinct colors for instance visualization (reused cyclically).
    # plt.get_cmap replaces plt.cm.get_cmap, which newer Matplotlib releases removed.
    num_colors = 20  # Arbitrary, could be increased
    cmap = plt.get_cmap('tab20', num_colors)
    colors = [cmap(i) for i in range(num_colors)]

    # Process each sample
    for i, sample in enumerate(samples):
        col_idx = 0

        # Original image
        axes[i][col_idx].imshow(sample['image'])
        col_idx += 1

        # Semantic segmentation mask
        sem_mask = sample['mask']
        num_classes = len(np.unique(sem_mask))

        # Use a more visually distinct colormap when more than two classes are present
        if num_classes <= 2:  # Binary mask
            axes[i][col_idx].imshow(sem_mask, cmap='viridis')
        else:  # Multi-class mask
            axes[i][col_idx].imshow(sem_mask, cmap='nipy_spectral')

        axes[i][col_idx].set_title(f"Semantic Mask")
        col_idx += 1

        # Instance segmentation visualization
        if has_instances and sample.get('instance_masks') is not None:
            instance_masks = sample['instance_masks']
            num_instances = instance_masks.shape[0]

            # Each instance is painted in its own color on a black canvas
            instance_vis = np.zeros((*sem_mask.shape, 3), dtype=np.float32)

            for inst_idx in range(num_instances):
                color = colors[inst_idx % num_colors][:3]

                # Threshold by dtype: floating-point masks (e.g. probabilities) at 0.5,
                # integer/boolean masks at 0. Checking the dtype instead of the Python
                # type of .max() also covers float32 arrays.
                inst_mask = np.asarray(instance_masks[inst_idx])
                if np.issubdtype(inst_mask.dtype, np.floating):
                    mask = inst_mask > 0.5
                else:
                    mask = inst_mask > 0

                # Skip empty masks
                if not np.any(mask):
                    continue

                # Paint this instance with its color
                for c in range(3):
                    instance_vis[:, :, c] = np.where(mask, 
                                                   color[c], 
                                                   instance_vis[:, :, c])

            # Overlay instances on the image with alpha blending
            overlay = sample['image'].copy()
            alpha = 0.7  # Weight of the instance color in the blend

            # Only blend where at least one instance was painted
            any_mask = np.any(instance_vis > 0, axis=2)
            for c in range(3):
                overlay[:, :, c] = np.where(any_mask, 
                                          overlay[:, :, c] * (1 - alpha) + instance_vis[:, :, c] * alpha,
                                          overlay[:, :, c])

            axes[i][col_idx].imshow(overlay)
            axes[i][col_idx].set_title(f"Instance Masks ({num_instances} instances)")
            col_idx += 1

        # Bounding box visualization
        if has_boxes and sample.get('instance_boxes') is not None:
            boxes = sample['instance_boxes']
            box_vis = sample['image']
            axes[i][col_idx].imshow(box_vis)

            # Boxes are unpacked as corner coordinates (x1, y1, x2, y2)
            for box_idx, box in enumerate(boxes):
                x1, y1, x2, y2 = box

                # If boxes are normalized to 0-1, scale to image dimensions
                if x1 <= 1.0 and y1 <= 1.0 and x2 <= 1.0 and y2 <= 1.0:
                    h, w = box_vis.shape[:2]
                    x1, y1, x2, y2 = x1 * w, y1 * h, x2 * w, y2 * h

                color = colors[box_idx % num_colors]

                rect = Rectangle((x1, y1), x2-x1, y2-y1, 
                                linewidth=2, edgecolor=color, facecolor='none')
                axes[i][col_idx].add_patch(rect)

                # Add class label if available
                if sample.get('instance_classes') is not None:
                    class_id = sample['instance_classes'][box_idx]
                    axes[i][col_idx].text(x1, y1-5, f"Class {class_id}", 
                                        color=color, fontsize=10, weight='bold')

            axes[i][col_idx].set_title(f"Bounding Boxes")
            col_idx += 1

        # Center heatmap visualization
        if has_centers and sample.get('center_heatmap') is not None:
            heatmap = sample['center_heatmap']

            # Strip leading singleton dimensions ([1, H, W] and [1, 1, H, W] -> [H, W])
            while heatmap.ndim > 2 and heatmap.shape[0] == 1:
                heatmap = heatmap[0]

            # Min-max normalize for display; a constant heatmap is shown as is
            # to avoid dividing by zero
            value_range = heatmap.max() - heatmap.min()
            if value_range > 0:
                heatmap_norm = (heatmap - heatmap.min()) / value_range
            else:
                heatmap_norm = heatmap

            # Show heatmap with a red-yellow colormap
            axes[i][col_idx].imshow(heatmap_norm, cmap='hot')
            axes[i][col_idx].set_title("Center Heatmap")
            col_idx += 1

    plt.tight_layout()
    # Saving to `save_path` is intentionally disabled, so no "saved" message is printed.
    plt.show()

# Render the randomly selected inference results, one row per sample.
visualize_segmentation(random_samples)

In [None]:


output_path = "model_trained_on_single_image.pth"
# The notebook never defines `model`; the trained LightningModule is `pl_model`,
# so its weights are the ones saved here.
torch.save(pl_model.state_dict(), output_path)
print(f"Model saved to {output_path}")
