<a href="https://colab.research.google.com/github/moksima/MemCDTedit/blob/main/GD%2BSAM%2BMat.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Cell 1: Install Correct PyTorch Version Compatible with Colab CUDA

!pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
!pip install transformers
!pip install --upgrade pillow
!pip install opencv-python
!pip install wget
!pip install addict numpy pycocotools supervision timm yapf albumentations tqdm
!pip install -U albumentations
!pip install segment-anything


Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu113
Collecting wget
  Using cached wget-3.2.zip (10 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9656 sha256=5dbe5e7dc8a731497a28d3ff1c1460fdad7e7c1634a1f0c961dcba2ceffe5281
  Stored in directory: /root/.cache/pip/wheels/8b/f1/7f/5c94f0a7a505ca1c81cd1d9208ae2064675d97582078e6c769
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2
Collecting albumentations
  Downloading albumentations-1.4.21-py3-none-any.whl.metadata (31 kB)
Collecting albucore==0.0.20 (from albumentations)
  Downloading albucore-0.0.20-py3-none-any.whl.metadata (5.3 kB)
Collecting simsimd>=5.9.2 (from albucore==0.0.20->albumentations)
  Downloading simsimd-5.9.9-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (54 kB)
[2K  

In [3]:
# Cell 2: Import Necessary Libraries

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import random
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np
from tqdm import tqdm
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import torchvision.models as models
from google.colab import drive
from segment_anything import sam_model_registry, SamPredictor
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from transformers import BertTokenizer
import seaborn as sns
import warnings

# Suppress warnings (optional)
warnings.filterwarnings("ignore", category=FutureWarning)


ModuleNotFoundError: No module named 'groundingdino'

In [None]:
# Cell 3: Mount Google Drive
# `drive` comes from `google.colab` (imported in Cell 2). The dataset and
# checkpoint paths used later in this notebook live under /content/drive.

drive.mount('/content/drive', force_remount=True)


In [None]:
# Cell 4: Install Hugging Face Hub and Authenticate

!pip install huggingface_hub

from huggingface_hub import login, hf_hub_download
from getpass import getpass
import os

# Securely input your Hugging Face Token
# (getpass keeps the token out of the notebook source and the cell output)
hf_token = getpass("Enter your Hugging Face API Token: ")
login(token=hf_token)

# Optionally, set the token as an environment variable
# NOTE(review): this exposes the token to every subprocess started from the kernel.
os.environ['HF_TOKEN'] = hf_token



In [None]:
# Cell 4b: Verify CUDA / GPU Availability
# Single consolidated environment report (the cell previously ran the same
# checks twice with a repeated import).

import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Device count: {torch.cuda.device_count()}")
# The device name can only be queried when a GPU is actually visible.
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")


In [None]:
# Cell 5: Define Custom Dataset for Fine-Tuning

class MaterialSegmentationDataset(Dataset):
    """
    Image / mask pairs for material segmentation, addressed by base filename.

    Each non-empty line of `list_file` is a base filename. For every name the
    first existing `<dir>/<name><ext>` is looked up among `images_dirs` and,
    separately, among `masks_dirs`. Only names for which BOTH an image and a
    mask are found become samples, so `image_paths[i]` and `mask_paths[i]`
    always refer to the same base filename.
    """

    def __init__(self, list_file, images_dirs, masks_dirs, transforms=None,
                 image_extensions=['.jpg', '.jpeg', '.png'],
                 mask_extensions=['.npy', '.png', '.jpg', '.jpeg']):
        """
        Args:
            list_file (str): Path to the txt file containing filenames for the split.
            images_dirs (list): List of directories containing images.
            masks_dirs (list): List of directories containing masks.
            transforms (albumentations.Compose, optional): Transformations to apply.
            image_extensions (list): List of acceptable image file extensions.
            mask_extensions (list): List of acceptable mask file extensions.
        """
        self.transforms = transforms
        self.images_dirs = images_dirs
        self.masks_dirs = masks_dirs
        # Copy so the shared default lists can never be mutated through an instance.
        self.image_extensions = list(image_extensions)
        self.mask_extensions = list(mask_extensions)

        # Read the list of filenames, ignoring blank lines
        with open(list_file, 'r') as f:
            self.filenames = [line.strip() for line in f if line.strip()]

        # Resolve image and mask together per filename so the two lists stay
        # index-aligned even when some files are missing.
        self.image_paths = []
        self.mask_paths = []
        missing_images = []
        missing_masks = []
        for fname in self.filenames:
            image_path = self._find_file(fname, self.images_dirs, self.image_extensions)
            mask_path = self._find_file(fname, self.masks_dirs, self.mask_extensions)
            if image_path is None:
                missing_images.append(fname)
            if mask_path is None:
                missing_masks.append(fname)
            if image_path is not None and mask_path is not None:
                self.image_paths.append(image_path)
                self.mask_paths.append(mask_path)

        if missing_images:
            print(f"Total missing images: {len(missing_images)}")
            for msg in missing_images[:5]:  # Print first 5 missing images
                print(f"Missing image: {msg}")

        if missing_masks:
            print(f"Total missing masks: {len(missing_masks)}")
            for msg in missing_masks[:5]:  # Print first 5 missing masks
                print(f"Missing mask: {msg}")

        skipped = len(self.filenames) - len(self.image_paths)
        if skipped:
            print(f"Skipped {skipped} entries with a missing image or mask. "
                  f"Using {len(self.image_paths)} pairs.")

        assert len(self.image_paths) == len(self.mask_paths), "Mismatch between images and masks."

        # Define your grayscale to class index mapping
        self.grayscale_to_class = {
            0: 0,     # Background
            50: 1,    # Metal object
            100: 2,   # Wooden door
            150: 3,   # Concrete wall
            200: 4,   # Glass window
            250: 5,   # Soil
            255: 6,   # Asphalt (or appropriate class)
            # Add mappings for other classes as needed
        }

    @staticmethod
    def _find_file(fname, directories, extensions):
        """Return the first existing `<directory>/<fname><ext>` path, or None if there is none."""
        for directory in directories:
            for ext in extensions:
                candidate = os.path.join(directory, fname + ext)
                if os.path.isfile(candidate):
                    return candidate
        return None

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load, map and (optionally) transform the pair at `idx`.

        Returns:
            (image, mask): `mask` is a torch.LongTensor of class indices. `image`
            is whatever the transform pipeline produces, or an RGB numpy array
            when no transforms are set.
        """
        # Load image
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        image = np.array(image)

        # Load mask
        mask_path = self.mask_paths[idx]
        if mask_path.endswith('.npy'):
            mask = np.load(mask_path)
            # Ensure mask is in the correct format (e.g., single channel)
            if mask.ndim == 3:
                mask = mask[:, :, 0]
        elif mask_path.endswith(('.png', '.jpg', '.jpeg')):
            mask = Image.open(mask_path).convert("L")  # Convert to grayscale
            mask = np.array(mask)
        else:
            raise ValueError(f"Unsupported mask format: {mask_path}")

        # Map mask values to class indices
        mask = self.map_mask(mask, self.grayscale_to_class)

        # Give the augmentation pipeline a compact uint8 mask when the class ids fit.
        # NOTE(review): OpenCV-backed albumentations ops are not expected to accept
        # int64 arrays - confirm against the installed versions.
        if mask.size and mask.min() >= 0 and mask.max() <= np.iinfo(np.uint8).max:
            mask = mask.astype(np.uint8)

        # Apply transformations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Ensure mask is of type torch.LongTensor
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(mask).long()
        elif isinstance(mask, torch.Tensor):
            mask = mask.long()
        else:
            raise TypeError(f"Unsupported mask type: {type(mask)}")

        return image, mask

    def map_mask(self, mask, mapping):
        """
        Translate raw mask values into class indices.

        Args:
            mask (np.ndarray): Raw mask values.
            mapping (dict): raw value -> class index.

        Returns:
            np.ndarray (int64) with the shape of `mask`. Values absent from
            `mapping` are reported with a warning and assigned class 0.
        """
        mapped_mask = np.full_like(mask, fill_value=-1, dtype=np.int64)  # -1 marks "not mapped yet"
        for grayscale_value, class_index in mapping.items():
            mapped_mask[mask == grayscale_value] = class_index

        # Check for unmapped values
        unmapped_pixels = (mapped_mask == -1)
        if np.any(unmapped_pixels):
            unique_unmapped_values = np.unique(mask[unmapped_pixels])
            print(f"Warning: Found unmapped grayscale values in mask: {unique_unmapped_values}")
            # Assign unmapped pixels to background class or any other class
            mapped_mask[unmapped_pixels] = 0  # Assign to background class

        return mapped_mask


In [None]:
# Cell 6: Define Data Transformations

# Define transformations for training
# Pipeline: resize to 512x512, random flips / brightness-contrast jitter,
# normalisation with the mean/std below, then conversion to a tensor.
# NOTE(review): 'mask' already looks like a default albumentations target name,
# so `additional_targets={'mask': 'mask'}` is presumably redundant - confirm.
train_transforms = A.Compose([
    A.Resize(height=512, width=512),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)),
    ToTensorV2()
], additional_targets={'mask': 'mask'})

# Define transformations for validation and testing
# Same resize / normalise / to-tensor steps as training, without the random augmentations.
val_transforms = A.Compose([
    A.Resize(height=512, width=512),
    A.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)),
    ToTensorV2()
], additional_targets={'mask': 'mask'})


In [None]:
# Cell 7: Define Directory Paths and List Files

# Define directories containing images
# Directories are searched in order; the first one containing a matching file wins.
images_dirs = [
    "/content/drive/MyDrive/multimodal_dataset/GT",
    "/content/drive/MyDrive/multimodal_dataset/polL_color"
    # Add more image directories if applicable
]

# Define directories containing masks
# NOTE(review): these directory names look like NIR / polarisation modalities
# rather than label masks - confirm they hold the class-coded masks the dataset expects.
masks_dirs = [
    "/content/drive/MyDrive/multimodal_dataset/NIR_warped_mask",
    "/content/drive/MyDrive/multimodal_dataset/polL_aolp_cos",
    "/content/drive/MyDrive/multimodal_dataset/polL_aolp_sin",
    "/content/drive/MyDrive/multimodal_dataset/polL_dolp"
    # Add more mask directories if applicable
]

# Paths to list files
# One text file per split; each line is a base filename.
train_list = "/content/drive/MyDrive/multimodal_dataset/list_folder/train.txt"
val_list = "/content/drive/MyDrive/multimodal_dataset/list_folder/val.txt"
test_list = "/content/drive/MyDrive/multimodal_dataset/list_folder/test.txt"


In [None]:
# Cell 8: Create Dataset Instances

# Use the transformations defined in Cell 6

# Create Dataset instances
# One dataset per split. Only the training split gets the random augmentations;
# validation and test share the deterministic `val_transforms`.
train_dataset = MaterialSegmentationDataset(
    list_file=train_list,
    images_dirs=images_dirs,
    masks_dirs=masks_dirs,
    transforms=train_transforms
)

val_dataset = MaterialSegmentationDataset(
    list_file=val_list,
    images_dirs=images_dirs,
    masks_dirs=masks_dirs,
    transforms=val_transforms
)

test_dataset = MaterialSegmentationDataset(
    list_file=test_list,
    images_dirs=images_dirs,
    masks_dirs=masks_dirs,
    transforms=val_transforms  # Typically, no augmentation for test
)

# Verify Dataset Lengths
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")

# Visualize a Batch from the Training Loader
def visualize_batch(images, masks, batch_size=4):
    """
    Plot up to `batch_size` images (top row) with their masks (bottom row).

    Args:
        images (torch.Tensor): Normalised image batch laid out as (N, C, H, W).
        masks (torch.Tensor): Mask batch; `masks[i]` is drawn under `images[i]`.
        batch_size (int): Maximum number of samples to draw. Fewer are drawn
            when the batch is smaller (e.g. the last batch of an epoch).
    """
    # Channels-last for matplotlib, then undo the mean/std normalisation.
    images = images.permute(0, 2, 3, 1).cpu().numpy()
    images = (images * np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])
    images = np.clip(images, 0, 1)

    masks = masks.cpu().numpy()

    # Never index past the end of a short batch.
    n_show = min(batch_size, len(images), len(masks))
    if n_show <= 0:
        return

    plt.figure(figsize=(20, 10))
    for i in range(n_show):
        plt.subplot(2, n_show, i + 1)
        plt.imshow(images[i])
        plt.title("Image")
        plt.axis('off')

        plt.subplot(2, n_show, n_show + i + 1)
        plt.imshow(masks[i], cmap='gray')
        plt.title("Mask")
        plt.axis('off')
    plt.show()

# Get a batch from the training loader
# A throwaway, unshuffled DataLoader is built here only to preview the first four samples.
for images, masks in DataLoader(train_dataset, batch_size=4):
    visualize_batch(images, masks, batch_size=4)
    break  # Only visualize one batch


In [None]:
# Cell 9: Define DataLoaders

# Loader settings shared by all three splits.
batch_size = 8
num_workers = 4  # Adjust based on your environment

# Only the training loader shuffles; validation and test keep file-list order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=False  # Temporarily set to False for debugging
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=False  # Temporarily set to False for debugging
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=False  # Temporarily set to False for debugging
)


In [None]:
# Cell 10: Verify Dataset and DataLoaders

# Report the number of (image, mask) pairs available in each split.
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")

# Function to visualize image and mask
def visualize_batch(images, masks, batch_size=4):
    """
    Plot up to `batch_size` images (top row) with their masks (bottom row).

    Args:
        images (torch.Tensor): Normalised image batch laid out as (N, C, H, W).
        masks (torch.Tensor): Mask batch; `masks[i]` is drawn under `images[i]`.
        batch_size (int): Maximum number of samples to draw. Fewer are drawn
            when the batch is smaller (e.g. the last batch of an epoch).
    """
    # Channels-last for matplotlib, then undo the mean/std normalisation.
    images = images.permute(0, 2, 3, 1).cpu().numpy()
    images = (images * np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])
    images = np.clip(images, 0, 1)

    masks = masks.cpu().numpy()

    # Never index past the end of a short batch.
    n_show = min(batch_size, len(images), len(masks))
    if n_show <= 0:
        return

    plt.figure(figsize=(20, 10))
    for i in range(n_show):
        plt.subplot(2, n_show, i + 1)
        plt.imshow(images[i])
        plt.title("Image")
        plt.axis('off')

        plt.subplot(2, n_show, n_show + i + 1)
        plt.imshow(masks[i], cmap='gray')
        plt.title("Mask")
        plt.axis('off')
    plt.show()

# Check mask values function
def check_mask_values(dataloader, num_classes):
    """
    Scan a loader and report whether every mask value lies in [0, num_classes).

    Stops at the first offending batch and prints its min/max; otherwise prints
    a single confirmation line once the whole loader has been consumed.
    """
    for _, batch_masks in dataloader:
        values = batch_masks.cpu().numpy()
        min_value, max_value = values.min(), values.max()
        out_of_range = min_value < 0 or max_value >= num_classes
        if out_of_range:
            print(f"Invalid mask values found. Min value: {min_value}, Max value: {max_value}")
            break
    else:
        # Reached only when the loop finished without hitting `break`.
        print("All mask values are within the valid range.")

# Define number of classes (update based on your mapping)
# 7 matches the class indices 0..6 in MaterialSegmentationDataset.grayscale_to_class.
num_classes = 7  # Updated to match your grayscale_to_class mapping

# Check mask values in training and validation loaders
# Each call iterates the loader until the first invalid batch (or to the end).
print("Checking mask values in training loader:")
check_mask_values(train_loader, num_classes)

print("Checking mask values in validation loader:")
check_mask_values(val_loader, num_classes)

# Get a batch from the training loader and visualize
for images, masks in train_loader:
    visualize_batch(images, masks, batch_size=4)
    break  # Only visualize one batch


In [None]:
# Cell 11: Define and Initialize the Segmentation Model

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Verify CUDA device
if device.type == 'cuda':
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")

# Using DeepLabV3 with a ResNet backbone
from torchvision import models
model = models.segmentation.deeplabv3_resnet50(weights='DEFAULT')
num_classes = 7  # Updated to match your mapping

# Modify the classifier to match the number of classes.
# Failures are reported and then re-raised: continuing with an unmodified head
# (or a model on the wrong device) would only fail later, less clearly.
try:
    model.classifier = models.segmentation.deeplabv3.DeepLabHead(2048, num_classes)
    print("Model classifier modified successfully.")
except Exception as e:
    print(f"Error modifying model classifier: {e}")
    raise

# Move model to device
try:
    model = model.to(device)
    print("Model moved to device successfully.")
except Exception as e:
    print(f"Error moving model to device: {e}")
    raise


In [None]:
# Cell 12: Install Grounding DINO via pip
# Installs straight from the upstream GitHub repository (no version or commit is pinned).

!pip install git+https://github.com/IDEA-Research/GroundingDINO.git


In [None]:
# Cell 13a: Verify Mask Values Are Within Valid Range

def check_mask_values(dataloader, num_classes):
    """
    Verify that every mask value produced by `dataloader` lies in [0, num_classes).

    Prints the min/max of the first offending batch and stops there, or prints
    a confirmation once all batches have been checked.
    """
    invalid_values = False
    for images, masks in dataloader:
        masks = masks.cpu().numpy()
        min_value = masks.min()
        max_value = masks.max()
        if min_value < 0 or max_value >= num_classes:
            print(f"Invalid mask values found. Min value: {min_value}, Max value: {max_value}")
            invalid_values = True
            break
    if not invalid_values:
        print("All mask values are within the valid range.")


In [None]:
# Cell 13b: Define Helper Functions to Load Models

import torch
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from segment_anything import sam_model_registry, SamPredictor
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from transformers import BertTokenizer

def load_grounding_dino(config_path, checkpoint_path, device='cuda'):
    """
    Build Grounding DINO from a config file and restore its weights.

    The checkpoint is read onto the CPU first; its 'model' entry is cleaned with
    `clean_state_dict` and loaded non-strictly. The model is then moved to
    `device`, switched to eval mode and returned.
    """
    dino = build_model(SLConfig.fromfile(config_path))

    state = torch.load(checkpoint_path, map_location='cpu')['model']
    dino.load_state_dict(clean_state_dict(state), strict=False)

    dino.to(device)
    dino.eval()
    return dino

def load_sam(sam_checkpoint_path, model_type="vit_b", device='cuda'):
    """
    Instantiate the SAM variant named by `model_type` from a checkpoint, move it
    to `device`, and return a `SamPredictor` wrapping it.
    """
    build_sam = sam_model_registry[model_type]
    sam_model = build_sam(checkpoint=sam_checkpoint_path)
    sam_model.to(device)
    return SamPredictor(sam_model)

def load_image(image_path):
    """
    Read an image as RGB and build its model-ready tensor.

    Returns:
        (image_pil, image_tensor): the untouched RGB PIL image, and the same image
        resized to 512x512, converted to a tensor and normalised.
    """
    to_model_input = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])

    rgb_image = Image.open(image_path).convert("RGB")
    return rgb_image, to_model_input(rgb_image)

def get_grounding_output_updated(model, image, caption, box_threshold, text_threshold, device='cuda'):
    """
    Perform Grounding DINO inference to obtain bounding boxes and associated phrases.

    Args:
        model: The loaded Grounding DINO model.
        image (torch.Tensor): Preprocessed image tensor of shape (C, H, W).
        caption (str): Text prompt containing multiple categories separated by commas.
        box_threshold (float): Confidence threshold to filter bounding boxes.
        text_threshold (float): Confidence threshold for text recognition.
        device (str): Device to perform computation ('cuda' or 'cpu').

    Returns:
        boxes_filt (torch.Tensor): Filtered bounding boxes with shape (num_boxes, 4),
            taken unchanged from the model's "pred_boxes"; row i belongs to
            pred_phrases[i] and scores[i].
        pred_phrases (list): List of predicted phrases corresponding to each bounding box.
        scores (torch.Tensor): Confidence scores for each bounding box.
    """
    def _empty_result():
        return torch.empty((0, 4)).to(device), [], torch.empty((0,)).to(device)

    image = image.to(device)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        # The prompt is passed as `captions=`, matching the other definition of
        # this helper later in the notebook.
        outputs = model(image[None], captions=[caption])

    # Check if 'pred_logits' and 'pred_boxes' are in outputs
    if 'pred_logits' not in outputs or 'pred_boxes' not in outputs:
        print("Model outputs do not contain 'pred_logits' or 'pred_boxes'.")
        return _empty_result()

    logits = outputs["pred_logits"].sigmoid()[0]  # one row of per-token scores per query
    boxes = outputs["pred_boxes"][0]              # one box (4 values) per query

    # Filter boxes with confidence threshold
    keep = logits.max(dim=1)[0] > box_threshold
    if keep.sum() == 0:
        print("No boxes above box_threshold.")
        return _empty_result()
    logits_filt = logits[keep]
    boxes_filt = boxes[keep]

    # Prefer the tokenizer the model carries; only fall back to loading one.
    tokenizer = getattr(model, "tokenizer", None)
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    input_ids = tokenizer(caption)["input_ids"]

    # Map logits to phrases, remembering WHICH boxes produced a phrase so that
    # boxes, phrases and scores stay aligned when a box in the middle is skipped.
    pred_phrases = []
    scores = []
    kept_rows = []
    for row, logit in enumerate(logits_filt):
        # Find tokens with confidence above text_threshold
        token_indices = (logit > text_threshold).nonzero(as_tuple=True)[0]
        # Ignore positions beyond the tokenized caption.
        token_indices = token_indices[token_indices < len(input_ids)]
        if len(token_indices) == 0:
            continue
        pred_phrases.append(tokenizer.decode([input_ids[i] for i in token_indices.tolist()]))
        scores.append(logit[token_indices].mean().item())
        kept_rows.append(row)

    if len(scores) == 0:
        print("No scores above text_threshold.")
        return _empty_result()

    boxes_filt = boxes_filt[kept_rows]
    scores = torch.tensor(scores).to(device)
    return boxes_filt, pred_phrases, scores

def segment_with_sam(image_pil, boxes, predictor, box_format="cxcywh_norm"):
    """
    Perform segmentation using the SAM model based on bounding boxes.

    Args:
        image_pil (PIL.Image.Image): The input image.
        boxes (torch.Tensor): Bounding boxes detected by Grounding DINO, shape (num_boxes, 4).
        predictor: The SAM predictor object.
        box_format (str): Layout of `boxes`.
            "cxcywh_norm" (default): (cx, cy, w, h) normalised to [0, 1], the
            layout used by Grounding DINO's reference inference code for
            "pred_boxes". The boxes are scaled by the image size and converted.
            "xyxy": absolute-pixel (x0, y0, x1, y1), used as is.

    Returns:
        List of NumPy arrays representing segmentation masks (one per box).

    Raises:
        ValueError: If `box_format` is not one of the supported values.
    """
    if boxes.numel() == 0:
        print("No boxes provided for segmentation.")
        return []

    image_np = np.array(image_pil)
    predictor.set_image(image_np)
    height, width = image_np.shape[:2]

    # `predictor.predict` works on NumPy inputs in original-image pixel
    # coordinates and does its own resizing, so the boxes are only converted here.
    boxes_np = boxes.detach().cpu().float().numpy().reshape(-1, 4)
    if box_format == "cxcywh_norm":
        scaled = boxes_np * np.array([width, height, width, height], dtype=np.float32)
        centers, half_sizes = scaled[:, :2], scaled[:, 2:] / 2.0
        boxes_xyxy = np.concatenate([centers - half_sizes, centers + half_sizes], axis=1)
    elif box_format == "xyxy":
        boxes_xyxy = boxes_np
    else:
        raise ValueError(f"Unsupported box_format: {box_format}")

    masks = []
    for box_xyxy in boxes_xyxy:
        # Perform prediction (a single mask per box)
        masks_pred, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=box_xyxy,
            multimask_output=False,
        )
        masks.append(np.asarray(masks_pred[0]))

    return masks

# Confirmation output: printing each function object shows that the definitions
# above were executed in this kernel session.
print("Helper functions defined successfully.")
print(f"Function load_grounding_dino: {load_grounding_dino}")
print(f"Function load_sam: {load_sam}")
print(f"Function load_image: {load_image}")
print(f"Function get_grounding_output_updated: {get_grounding_output_updated}")
print(f"Function segment_with_sam: {segment_with_sam}")


In [None]:
# Cell 14: Update timm to Latest Version to Avoid FutureWarnings
# NOTE(review): upgrading a package after it has been imported presumably needs a
# runtime restart to take effect - confirm.

!pip install --upgrade timm


In [None]:
# Cell 15: Verify Helper Functions Defined Correctly

# Referencing each helper raises NameError if its defining cell has not been run.
try:
    print(load_grounding_dino)
    print(load_sam)
    print(load_image)
    # The helper defined in this notebook is named `get_grounding_output_updated`.
    print(get_grounding_output_updated)
    print(segment_with_sam)
    print("All helper functions are defined correctly.")
except NameError as e:
    print(f"Error: {e}")


In [None]:
# Cell 16: Define Loss Function, Optimizer, and Scheduler

# Loss function
# Cross-entropy loss for the segmentation task.
criterion = nn.CrossEntropyLoss()

# Optimizer
# Optimises all parameters of `model` (the DeepLabV3 network built in Cell 11).
learning_rate = 1e-4
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# Learning rate scheduler
# Configured with step_size=10 and gamma=0.1 (LR scaled by 0.1 every 10 scheduler steps).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)


In [None]:
# Cell 17: Download SAM Checkpoint (Re-execute if necessary)

import os
import wget

# Source URL and local destination directory for the SAM ViT-B checkpoint.
sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
sam_checkpoint_dir = "/content/segment-anything/checkpoints"

# Create the directory if it doesn't exist
os.makedirs(sam_checkpoint_dir, exist_ok=True)

# Download the SAM checkpoint
# Shell `wget` with the Python variables interpolated via {…}. The download runs
# every time this cell is executed and `-O` overwrites any existing file.
!wget -O {sam_checkpoint_dir}/sam_vit_b_01ec64.pth {sam_checkpoint_url}

# Verify SAM Checkpoint
sam_checkpoint_path = "/content/segment-anything/checkpoints/sam_vit_b_01ec64.pth"
print(f"SAM Checkpoint Exists: {os.path.exists(sam_checkpoint_path)}")

# If not, list the contents of the directory to see available files
if not os.path.exists(sam_checkpoint_path):
    print("Contents of the SAM Checkpoint Directory:")
    print(os.listdir("/content/segment-anything/checkpoints"))


In [None]:
# Cell 18: Clone Grounding DINO Repository
# The pip line installs the package; the clone additionally fetches the repository
# sources into the current working directory.
# NOTE(review): `git clone` is expected to fail when the target directory already
# exists, so re-running this cell will report an error - confirm.

!pip install git+https://github.com/IDEA-Research/GroundingDINO.git

!git clone https://github.com/IDEA-Research/GroundingDINO.git


In [None]:
# Cell 19: Verify Grounding DINO Checkpoint Path and Filename

import os

# Expected location of the Grounding DINO weights on the mounted Drive.
grounding_dino_checkpoint_dir = "/content/drive/MyDrive/GroundingDINO/checkpoint"
grounding_dino_checkpoint = "/content/drive/MyDrive/GroundingDINO/checkpoint/groundingdino_swint_ogc.pth"

# Check if the checkpoint file exists
checkpoint_exists = os.path.exists(grounding_dino_checkpoint)
print(f"Grounding DINO Checkpoint Exists: {checkpoint_exists}")

# List the contents of the checkpoint directory to verify filenames.
# The directory itself may be missing (e.g. Drive not mounted); report that
# instead of letting the diagnostic crash with FileNotFoundError.
if not checkpoint_exists:
    if os.path.isdir(grounding_dino_checkpoint_dir):
        print(f"Contents of {grounding_dino_checkpoint_dir}:")
        print(os.listdir(grounding_dino_checkpoint_dir))
    else:
        print(f"Checkpoint directory not found: {grounding_dino_checkpoint_dir}")


In [None]:
# Cell 32: Monitor GPU Utilization
# Shell call to `nvidia-smi` to show the current GPU status.

!nvidia-smi


In [None]:
# Cell 20: Verify SAM Checkpoint Path

# Expected location of the SAM ViT-B checkpoint (downloaded in Cell 17).
sam_checkpoint_path = "/content/segment-anything/checkpoints/sam_vit_b_01ec64.pth"

# Check if the SAM checkpoint exists
sam_checkpoint_exists = os.path.exists(sam_checkpoint_path)
print(f"SAM Checkpoint Exists: {sam_checkpoint_exists}")

# List the contents of the SAM checkpoint directory if missing.
# Guard the listing: if the download cell was never run the directory does not
# exist either, and `os.listdir` would raise FileNotFoundError.
if not sam_checkpoint_exists:
    if os.path.isdir("/content/segment-anything/checkpoints/"):
        print(f"Contents of /content/segment-anything/checkpoints/:")
        print(os.listdir("/content/segment-anything/checkpoints/"))
    else:
        print("SAM checkpoint directory not found: /content/segment-anything/checkpoints/")


In [None]:
# Cell 21: Load Grounding DINO and SAM Models

import os
import torch

# Define paths (update these paths based on your Google Drive structure)
grounding_dino_config = "/content/drive/MyDrive/GroundingDINO/config/GroundingDINO_SwinT_OGC.py"
grounding_dino_checkpoint = "/content/drive/MyDrive/GroundingDINO/checkpoint/groundingdino_swint_ogc.pth"
sam_checkpoint = "/content/segment-anything/checkpoints/sam_vit_b_01ec64.pth"  # Ensure SAM is downloaded here
sam_model_type = "vit_b"  # Options: "vit_b", "vit_l", "vit_h"

# Print the checkpoint paths to confirm
print(f"Grounding DINO Config Path: {grounding_dino_config}")
print(f"Grounding DINO Checkpoint Path: {grounding_dino_checkpoint}")
print(f"SAM Checkpoint Path: {sam_checkpoint}")

# Determine device once, up front: both loaders below need it, and the SAM
# branch must not depend on the Grounding DINO branch having run.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Check if Grounding DINO checkpoint exists
if not os.path.exists(grounding_dino_checkpoint):
    print("❌ Grounding DINO checkpoint file not found. Please ensure it is downloaded correctly.")
else:
    # Load Grounding DINO model
    print("Loading Grounding DINO model...")
    grounding_dino_model = load_grounding_dino(
        config_path=grounding_dino_config,
        checkpoint_path=grounding_dino_checkpoint,
        device=device
    )
    print("Grounding DINO model loaded successfully.")

# Check if SAM checkpoint exists
if not os.path.exists(sam_checkpoint):
    print("❌ SAM checkpoint file not found. Please ensure it is downloaded correctly.")
else:
    # Load SAM predictor
    print("Loading SAM model and initializing predictor...")
    sam_predictor = load_sam(
        sam_checkpoint_path=sam_checkpoint,
        model_type=sam_model_type,
        device=device
    )
    print("SAM predictor initialized successfully.")


In [None]:
# Cell 21b


def get_grounding_output_updated(model, image, caption, box_threshold, text_threshold, device='cuda'):
    """
    Perform Grounding DINO inference to obtain bounding boxes and associated phrases.

    Args:
        model: The loaded Grounding DINO model (its `tokenizer` attribute is used
            to tokenize and decode the caption).
        image (torch.Tensor): Preprocessed image tensor of shape (C, H, W).
        caption (str): Text prompt containing multiple categories separated by commas.
        box_threshold (float): Confidence threshold to filter bounding boxes.
        text_threshold (float): Confidence threshold for text recognition.
        device (str): Device to perform computation ('cuda' or 'cpu').

    Returns:
        boxes_filt (torch.Tensor): Filtered bounding boxes with shape (num_boxes, 4),
            taken unchanged from the model's "pred_boxes"; row i belongs to
            pred_phrases[i] and scores[i].
        pred_phrases (list): List of predicted phrases corresponding to each bounding box.
        scores (torch.Tensor): Confidence scores for each bounding box.
    """
    image = image.to(device)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    logits = outputs["pred_logits"].sigmoid()[0]  # one row of per-token scores per query
    boxes = outputs["pred_boxes"][0]              # one box (4 values) per query

    # Filter boxes with confidence threshold
    keep = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[keep]
    boxes_filt = boxes[keep]

    # Tokenize caption
    tokenizer = model.tokenizer
    input_ids = tokenizer(caption)["input_ids"]

    # Map logits to phrases, remembering WHICH boxes produced a phrase so that
    # boxes, phrases and scores stay aligned when a box in the middle is skipped.
    pred_phrases = []
    scores = []
    kept_rows = []
    for row, logit in enumerate(logits_filt):
        # Find tokens with confidence above text_threshold
        token_indices = (logit > text_threshold).nonzero(as_tuple=True)[0]
        # Ignore positions beyond the tokenized caption.
        token_indices = token_indices[token_indices < len(input_ids)]
        if len(token_indices) == 0:
            continue
        pred_phrases.append(tokenizer.decode([input_ids[i] for i in token_indices.tolist()]))
        scores.append(logit[token_indices].mean().item())
        kept_rows.append(row)

    if len(scores) == 0:
        return torch.empty((0, 4)).to(device), [], torch.empty((0,)).to(device)

    boxes_filt = boxes_filt[kept_rows]
    scores = torch.tensor(scores).to(device)
    return boxes_filt, pred_phrases, scores


In [None]:
# Cell 22: Process Image with Grounding DINO and SAM

from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import seaborn as sns
import numpy as np
import os
import warnings
import traceback

# Suppress FutureWarnings (the filter stays installed for the rest of the session)
warnings.filterwarnings("ignore", category=FutureWarning)

# Define the image path
image_path = "/content/drive/MyDrive/multimodal_dataset/GT/outscene1208_10_0000000000.png"

# Verify if the image exists
if not os.path.exists(image_path):
    print(f"❌ Image file not found at {image_path}. Please check the path.")
else:
    # Load the image (PIL image for display/SAM, tensor for Grounding DINO)
    image_pil, image_tensor = load_image(image_path)

    # Move image tensor to the correct device
    image_tensor = image_tensor.to(device)

    # Display the image
    plt.figure(figsize=(8,8))
    plt.imshow(image_pil)
    plt.axis('off')
    plt.title("Input Image")
    plt.show()

    # Updated prompts without 'a' and unnecessary words
    prompts = [
        "concrete building",
        "metal door",
        "brick house",
        "glass window",
        "gravel pavement",
        "wooden tree trunk",
        "rubber motorcycle tire",
        "cobblestone pathway",
        "plastic bicycle frame",
        "road marking lines"
    ]

    # Combine prompts into a single caption
    caption = ', '.join(prompts)

    # Adjust thresholds as needed
    box_threshold = 0.2
    text_threshold = 0.1

    print(f"\nProcessing caption: '{caption}'")

    try:
        boxes_filt, pred_phrases, scores_filt = get_grounding_output_updated(
            model=grounding_dino_model,
            image=image_tensor,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=device
        )

        print(f"Detected {len(boxes_filt)} objects")
        for idx, (box, phrase) in enumerate(zip(boxes_filt, pred_phrases), 1):
            print(f"{idx}. {phrase} - Box coordinates: {box.cpu().numpy()}")

        # Use these boxes with SAM for segmentation
        if boxes_filt.shape[0] > 0:
            # get_grounding_output_updated returns the model's raw `pred_boxes`.
            # Grounding DINO emits those as normalized (cx, cy, w, h); a SAM box
            # prompt must be absolute-pixel (x_min, y_min, x_max, y_max) in the
            # frame of the image given to the predictor (image_pil), so convert.
            # NOTE(review): confirm the box format if the detector is replaced.
            img_w, img_h = image_pil.size
            scale = torch.tensor(
                [img_w, img_h, img_w, img_h],
                dtype=boxes_filt.dtype,
                device=boxes_filt.device,
            )
            boxes_px = boxes_filt * scale
            half_wh = boxes_px[:, 2:] / 2
            boxes_xyxy = torch.cat(
                [boxes_px[:, :2] - half_wh, boxes_px[:, :2] + half_wh], dim=1
            )

            masks = segment_with_sam(
                image_pil=image_pil,
                boxes=boxes_xyxy,
                predictor=sam_predictor
            )

            # Ensure output directory exists
            output_dir = "/content/drive/MyDrive/multimodal_dataset/output"
            os.makedirs(output_dir, exist_ok=True)

            # Visualization of the masks and save them
            for idx, (mask, phrase) in enumerate(zip(masks, pred_phrases), 1):
                fig = plt.figure(figsize=(6,6))
                plt.imshow(image_pil)
                plt.imshow(mask, alpha=0.5, cmap='jet')
                plt.title(f"Mask {idx}: {phrase}")
                plt.axis('off')

                # Save the figure (must happen before show(), which may clear it)
                mask_path = os.path.join(output_dir, f"mask_{idx}.png")
                plt.savefig(mask_path)
                print(f"Saved Mask {idx} to {mask_path}")

                plt.show()
                # One figure per detection: release it so they do not pile up
                plt.close(fig)
        else:
            print("No objects detected for segmentation.")

        # **Visualize Raw Scores**
        if scores_filt.numel() > 0:
            sns.set(style="whitegrid")
            plt.figure(figsize=(10,6))
            sns.histplot(scores_filt.cpu().numpy().flatten(), bins=50, kde=True)
            plt.title("Distribution of Filtered Scores")
            plt.xlabel("Score")
            plt.ylabel("Frequency")
            plt.show()
        else:
            print("No raw scores available to visualize.")

    except Exception as e:
        # Best-effort demo cell: report the failure with a traceback instead of aborting
        print(f"Error processing caption: {e}")
        traceback.print_exc()


In [None]:
# Cell 23: Example Usage During Inference

from PIL import Image
import torchvision.transforms as transforms
import traceback

# Source image on the mounted Drive (edit to match your folder layout)
image_path = "/content/drive/MyDrive/multimodal_dataset/GT/outscene1208_10_0000000000.png"

# Read the file and force three RGB channels
image_pil = Image.open(image_path).convert("RGB")

# Preprocessing pipeline: fixed 512x512 resize -> tensor -> per-channel normalization
preprocess_steps = [
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
]
preprocess = transforms.Compose(preprocess_steps)
image_tensor = preprocess(image_pil)
image_tensor = image_tensor.to(device)

# Comma-separated categories passed to the detector as one caption
prompt = "sky, ground, building, tree, person"

# Run the detector; any failure is reported with a traceback rather than raised
try:
    grounding_args = dict(
        model=grounding_dino_model,
        image=image_tensor,
        caption=prompt,  # the function's keyword is 'caption', not 'prompt'
        box_threshold=0.2,
        text_threshold=0.1,
        device=device,
    )
    boxes_filt, pred_phrases, scores_filt = get_grounding_output_updated(**grounding_args)

    print(f"Detected {len(boxes_filt)} objects for prompt '{prompt}':")
    for idx, (box, phrase) in enumerate(zip(boxes_filt, pred_phrases), start=1):
        box_values = box.cpu().numpy()
        print(f"{idx}. {phrase} - Box coordinates: {box_values}")

    # The boxes can now be handed to SAM or used for further processing

except Exception as exc:
    print(f"Error processing caption: {exc}")
    traceback.print_exc()


In [None]:
# Cell 24: Initialize TensorBoard (Optional)

# Load the TensorBoard notebook extension (provides the %tensorboard magic used later)
%load_ext tensorboard

# Define training output directory
# All training artifacts (checkpoints, TensorBoard logs, CSVs) are written below this folder.
training_output_dir = "/content/drive/MyDrive/MaterialSegmentationOutput"
model_checkpoint_dir = os.path.join(training_output_dir, "checkpoints")
os.makedirs(model_checkpoint_dir, exist_ok=True)

# Initialize TensorBoard writer
# NOTE(review): SummaryWriter is not imported in this cell — presumably imported
# in an earlier cell; confirm. The writer is never closed/flushed in the cells shown.
writer = SummaryWriter(log_dir=os.path.join(training_output_dir, "tensorboard_logs"))


In [None]:
# Cell 25: Training and Validation Loop

from tqdm import tqdm

# Training/validation loop with mixed precision, TensorBoard logging, early
# stopping on validation loss, and a checkpoint per epoch.
# Relies on names defined in earlier cells: model, train_loader, val_loader,
# optimizer, criterion, scheduler, device, writer, model_checkpoint_dir,
# GradScaler and autocast.
num_epochs = 25  # Adjust based on your requirements
patience = 5  # For early stopping

# Initialize lists to store metrics
train_losses = []
val_losses = []
val_accuracies = []

# Early-stopping state: best validation loss so far and epochs since it improved
best_val_loss = float('inf')
counter = 0

# Initialize GradScaler for mixed precision training
scaler = GradScaler()

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    # Training Phase
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(train_loader, desc="Training", leave=False):
        images = images.to(device)
        masks = masks.to(device).long()

        optimizer.zero_grad()

        # Forward pass and loss under autocast (mixed precision)
        with autocast():
            outputs = model(images)['out']
            loss = criterion(outputs, masks)

        # Scaled backward pass, optimizer step through the scaler, then update the scale
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight the batch loss by batch size so the epoch figure is per-sample
        # (assumes criterion returns a batch mean — TODO confirm)
        running_loss += loss.item() * images.size(0)

    epoch_train_loss = running_loss / len(train_loader.dataset)
    train_losses.append(epoch_train_loss)
    print(f"Train Loss: {epoch_train_loss:.4f}")

    # Validation Phase
    model.eval()
    val_running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Validating", leave=False):
            images = images.to(device)
            masks = masks.to(device).long()

            with autocast():
                outputs = model(images)['out']
                loss = criterion(outputs, masks)

            val_running_loss += loss.item() * images.size(0)

            # Calculate accuracy
            # Element-wise: every mask element counts once (total uses masks.numel())
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == masks).sum().item()
            total += masks.numel()

    epoch_val_loss = val_running_loss / len(val_loader.dataset)
    epoch_val_accuracy = correct / total
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_accuracy)

    print(f"Validation Loss: {epoch_val_loss:.4f}")
    print(f"Validation Accuracy: {epoch_val_accuracy:.4f}")

    # Log metrics to TensorBoard
    # Steps are logged 1-based (epoch+1)
    writer.add_scalar('Train/Loss', epoch_train_loss, epoch+1)
    writer.add_scalar('Validation/Loss', epoch_val_loss, epoch+1)
    writer.add_scalar('Validation/Accuracy', epoch_val_accuracy, epoch+1)

    # Step the scheduler
    # Called once per epoch, after that epoch's optimizer steps
    scheduler.step()

    # Early Stopping Check
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        counter = 0
        # Save the best model
        best_model_path = os.path.join(model_checkpoint_dir, "best_model.pth")
        torch.save(model.state_dict(), best_model_path)
        print(f"Best model saved to {best_model_path}")
    else:
        counter += 1
        print(f"No improvement in validation loss for {counter} epoch(s).")
        if counter >= patience:
            print("Early stopping triggered.")
            # Leaves the loop before the per-epoch checkpoint below is written
            break

    # Save model checkpoint
    checkpoint_path = os.path.join(model_checkpoint_dir, f"model_epoch_{epoch + 1}.pth")
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Saved model checkpoint to {checkpoint_path}\n")


In [None]:
# Cell 26: Save and Visualize Training Metrics

# Assemble the per-epoch histories into one table; the epoch column is derived
# from how many training losses were recorded.
epochs_completed = len(train_losses)
metrics = {
    "Epoch": list(range(1, epochs_completed + 1)),
    "Train_Loss": train_losses,
    "Validation_Loss": val_losses,
    "Validation_Accuracy": val_accuracies
}

# Persist the table as CSV next to the other training outputs
df_metrics = pd.DataFrame(metrics)
metrics_csv_path = os.path.join(training_output_dir, "training_metrics.csv")
df_metrics.to_csv(metrics_csv_path, index=False)
print(f"Training metrics saved to {metrics_csv_path}")

# Loss curves: training and validation on shared axes
epoch_axis = df_metrics['Epoch']
plt.figure(figsize=(10,5))
for column, legend_label in (('Train_Loss', 'Train Loss'),
                             ('Validation_Loss', 'Validation Loss')):
    plt.plot(epoch_axis, df_metrics[column], label=legend_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.grid(True)
plt.show()

# Accuracy curve: validation only
plt.figure(figsize=(10,5))
accuracy_series = df_metrics['Validation_Accuracy']
plt.plot(epoch_axis, accuracy_series, label='Validation Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Validation Accuracy')
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# Cell 27: Define CATEGORY_COLORS for Visualization

# RGB color (0-255 per channel) for each material category.
# Insertion order matters: the visualization helpers use a category's position
# in this mapping as its class index, so "background" is class 0.
CATEGORY_COLORS = dict(
    background=(0, 0, 0),
    metal_object=(255, 0, 0),
    wooden_door=(0, 255, 0),
    concrete_wall=(0, 0, 255),
    glass_window=(255, 255, 0),
    soil=(255, 165, 0),
    bricks=(128, 0, 128),
    plastic=(0, 255, 255),
    tiles=(255, 192, 203),
)


In [None]:
# Cell 28: Visualize Segmentation Results

def visualize_predictions(model, dataset, device, num_samples=5):
    """
    Plot model predictions next to the ground truth for randomly chosen samples.

    For every selected sample a 1x4 figure is shown: the de-normalized input
    image, the ground-truth mask, the predicted mask, and the prediction
    colored with CATEGORY_COLORS and blended 50/50 over the image.

    Args:
        model: Segmentation model; its forward pass must return a mapping with
            an 'out' entry holding per-class scores for each image.
        dataset: Indexable dataset yielding (image, mask) tensor pairs. Images
            are assumed to be channel-first and normalized with the mean/std
            that is undone below (TODO confirm against the dataset transform).
        device: Device the image is moved to for inference.
        num_samples (int): How many samples to draw. Capped at len(dataset),
            because random.sample raises ValueError when asked for more items
            than are available.

    Returns:
        None. Figures are displayed as a side effect.
    """
    model.eval()
    sample_count = min(num_samples, len(dataset))
    indices = random.sample(range(len(dataset)), sample_count)

    # Normalization statistics to undo before display
    norm_mean = np.array([0.485, 0.456, 0.406])
    norm_std = np.array([0.229, 0.224, 0.225])

    for idx in indices:
        image, mask = dataset[idx]
        image_input = image.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(image_input)['out'][0]
            pred_mask = torch.argmax(output, dim=0).cpu().numpy()

        # Channel-first tensor -> HxWxC array in [0, 1] for imshow
        image_np = image.cpu().numpy().transpose(1, 2, 0)
        image_np = np.clip(image_np * norm_std + norm_mean, 0, 1)

        mask_np = mask.cpu().numpy()

        # Color the prediction: a category's position in CATEGORY_COLORS is its
        # class index (background is 0)
        pred_color_mask = np.zeros_like(image_np)
        for class_num, color in enumerate(CATEGORY_COLORS.values()):
            pred_color_mask[pred_mask == class_num] = np.array(color) / 255.0

        # Blend the colored prediction over the image
        overlay_pred = 0.5 * image_np + 0.5 * pred_color_mask

        # Plotting
        fig, axs = plt.subplots(1, 4, figsize=(20, 5))
        axs[0].imshow(image_np)
        axs[0].set_title("Original Image")
        axs[0].axis('off')

        axs[1].imshow(mask_np, cmap='jet', alpha=0.5)
        axs[1].set_title("Ground Truth Mask")
        axs[1].axis('off')

        axs[2].imshow(pred_mask, cmap='jet', alpha=0.5)
        axs[2].set_title("Predicted Mask")
        axs[2].axis('off')

        axs[3].imshow(overlay_pred)
        axs[3].set_title("Overlay Predicted Mask")
        axs[3].axis('off')

        plt.show()
        # One figure per sample: release it so figures do not accumulate
        plt.close(fig)

# Visualize predictions on validation dataset
# Shows 5 samples; which ones is decided by random.sample inside the function,
# so the selection changes between runs unless the `random` module is seeded.
visualize_predictions(model, val_dataset, device, num_samples=5)


In [None]:
# Cell 29: Example Usage During Inference

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Define the image path (update based on your Google Drive structure)
image_path = "/content/drive/MyDrive/multimodal_dataset/GT/outscene1208_10_0000000000.png"  # Update as needed

# Verify if the image exists
if not os.path.exists(image_path):
    print(f"❌ Image file not found at {image_path}. Please check the path.")
else:
    # Load the image (PIL image for display/SAM, tensor for Grounding DINO)
    image_pil, image_tensor = load_image(image_path)

    # Define your prompt
    prompt = "background"  # Update based on your detection needs

    # Get grounding outputs.
    # Uses the same entry point and keywords as the other inference cells
    # (`get_grounding_output_updated` with `caption=` and `device=`); the helper
    # moves the image and model to `device` itself.
    boxes_filt, pred_phrases, scores_filt = get_grounding_output_updated(
        model=grounding_dino_model,  # Ensure 'grounding_dino_model' is correctly defined
        image=image_tensor,
        caption=prompt,
        box_threshold=0.2,
        text_threshold=0.1,
        device=device
    )

    print(f"Detected {len(boxes_filt)} objects for prompt '{prompt}':")
    for idx, (box, phrase) in enumerate(zip(boxes_filt, pred_phrases), 1):
        print(f"{idx}. {phrase} - Box coordinates: {box.cpu().numpy()}")

    # Use these boxes with SAM for segmentation
    if len(boxes_filt) > 0:
        # The helper returns the model's raw `pred_boxes`. Grounding DINO emits
        # those as normalized (cx, cy, w, h); a SAM box prompt must be
        # absolute-pixel (x_min, y_min, x_max, y_max) in the frame of image_pil.
        # NOTE(review): confirm the box format if the detector is replaced.
        img_w, img_h = image_pil.size
        scale = torch.tensor(
            [img_w, img_h, img_w, img_h],
            dtype=boxes_filt.dtype,
            device=boxes_filt.device,
        )
        boxes_px = boxes_filt * scale
        half_wh = boxes_px[:, 2:] / 2
        boxes_xyxy = torch.cat(
            [boxes_px[:, :2] - half_wh, boxes_px[:, :2] + half_wh], dim=1
        )

        masks = segment_with_sam(
            image_pil=image_pil,
            boxes=boxes_xyxy,
            predictor=sam_predictor
        )

        # Visualization of the masks
        for idx, mask in enumerate(masks, 1):
            fig = plt.figure(figsize=(6,6))
            plt.imshow(image_pil)
            plt.imshow(mask, alpha=0.5, cmap='jet')
            plt.title(f"Mask {idx}")
            plt.axis('off')
            plt.show()
            # One figure per mask: release it so figures do not accumulate
            plt.close(fig)
    else:
        print("No objects detected for segmentation.")


In [None]:
# Cell 30: Evaluate Model on Test Set

def evaluate_model(model, dataloader, device, num_classes):
    """
    Print accuracy and per-class IoU of a segmentation model over a dataloader.

    Intersections and unions are accumulated over the whole dataloader before
    dividing, so each class IoU is a dataset-level figure rather than an
    average of per-image IoUs.

    Args:
        model: Segmentation model; its forward pass must return a mapping with
            an 'out' entry holding per-class scores.
        dataloader: Iterable of (images, masks) batches.
        device: Device the batches are moved to.
        num_classes (int): Number of class indices (0 .. num_classes - 1) to score.

    Returns:
        None. Results are printed. If the dataloader yields nothing, a notice
        is printed instead of dividing by zero.
    """
    model.eval()
    total_correct = 0
    total_pixels = 0
    intersection = np.zeros(num_classes)
    union = np.zeros(num_classes)

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Evaluating", leave=False):
            images = images.to(device)
            masks = masks.to(device).long()

            outputs = model(images)['out']
            preds = torch.argmax(outputs, dim=1)

            # Calculate element-wise accuracy
            total_correct += (preds == masks).sum().item()
            total_pixels += masks.numel()

            # Accumulate intersection/union counts for each class
            for cls in range(num_classes):
                pred_inds = (preds == cls)
                target_inds = (masks == cls)
                intersection[cls] += (pred_inds & target_inds).sum().item()
                union[cls] += (pred_inds | target_inds).sum().item()

    # Nothing was evaluated: avoid ZeroDivisionError on the accuracy below
    if total_pixels == 0:
        print("No samples to evaluate.")
        return

    accuracy = total_correct / total_pixels

    # A class that never appears in predictions or targets has union == 0.
    # Report NaN for it (ignored by nanmean) without a divide-by-zero warning.
    iou = np.full(num_classes, np.nan)
    present = union > 0
    iou[present] = intersection[present] / union[present]
    mean_iou = np.nanmean(iou)

    print(f"Test Accuracy: {accuracy:.4f}")
    for cls in range(num_classes):
        print(f"Class {cls} IoU: {iou[cls]:.4f}")
    print(f"Mean IoU: {mean_iou:.4f}")

# Define number of classes
# 9 equals the number of entries in CATEGORY_COLORS (background + 8 materials).
num_classes = 9  # Adjust based on your dataset

# Evaluate on test set
# Prints accuracy, per-class IoU and mean IoU; nothing is returned.
evaluate_model(model, test_loader, device, num_classes)


In [None]:
# Cell 31: Visualize Model Predictions on Test Set

def visualize_test_predictions(model, dataset, device, num_samples=5):
    """
    Plot model predictions for randomly chosen samples of a (test) dataset.

    This is the same visualization as `visualize_predictions` (defined in the
    earlier "Visualize Segmentation Results" cell) and simply delegates to it,
    so the two views cannot drift apart. That cell must have been run first.

    Args:
        model: Segmentation model; its forward pass must return a mapping with
            an 'out' entry holding per-class scores.
        dataset: Indexable dataset yielding (image, mask) tensor pairs.
        device: Device the image is moved to for inference.
        num_samples (int): How many samples to draw at random.

    Returns:
        None. Figures are displayed as a side effect.
    """
    visualize_predictions(model, dataset, device, num_samples=num_samples)

# Visualize predictions on test dataset
# Shows 5 randomly selected test samples (selection varies unless `random` is seeded).
visualize_test_predictions(model, test_dataset, device, num_samples=5)


In [None]:
# Cell 32: Monitor GPU Utilization

# Shell escape: run the `nvidia-smi` command and show its GPU status report.
# This is a one-off snapshot taken when the cell runs, not continuous monitoring.
!nvidia-smi


In [None]:
# Cell 33: Save Confidence Scores (Optional)

# Training/validation loop that additionally records a confidence score per
# sample of the last validation batch in every epoch.
# NOTE(review): this cell re-runs the full training loop and reuses num_epochs,
# patience, scaler, the metric lists, best_val_loss and counter from earlier
# cells without resetting them — confirm that continuing from that state is
# intended (a counter already at `patience` stops after the first epoch).

# Initialize a list to store confidence data before the training loop
confidence_records = []

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    # Training Phase
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(train_loader, desc="Training", leave=False):
        images = images.to(device)
        masks = masks.to(device).long()

        optimizer.zero_grad()

        with autocast():
            outputs = model(images)['out']
            loss = criterion(outputs, masks)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * images.size(0)

    epoch_train_loss = running_loss / len(train_loader.dataset)
    train_losses.append(epoch_train_loss)
    print(f"Train Loss: {epoch_train_loss:.4f}")

    # Validation Phase
    model.eval()
    val_running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Validating", leave=False):
            images = images.to(device)
            masks = masks.to(device).long()

            with autocast():
                outputs = model(images)['out']
                loss = criterion(outputs, masks)

            val_running_loss += loss.item() * images.size(0)

            # Calculate accuracy
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == masks).sum().item()
            total += masks.numel()

    epoch_val_loss = val_running_loss / len(val_loader.dataset)
    epoch_val_accuracy = correct / total
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_accuracy)

    print(f"Validation Loss: {epoch_val_loss:.4f}")
    print(f"Validation Accuracy: {epoch_val_accuracy:.4f}")

    # Record confidence scores for the samples of the LAST validation batch.
    # `outputs` holds per-class scores for every output element, so the maximum
    # softmax probability is still an array per sample; average it down to ONE
    # scalar per sample (calling float() on the whole array raises TypeError).
    # Softmax is computed in float32 because `outputs` came from an autocast region.
    # NOTE(review): "Sample" is the position inside that batch; it only identifies
    # the same image across epochs if val_loader is not shuffled — confirm.
    element_confidence = torch.softmax(outputs.float(), dim=1).max(dim=1)[0]
    confidence_scores = (
        element_confidence.reshape(element_confidence.size(0), -1)
        .mean(dim=1)
        .cpu()
        .numpy()
    )
    for idx, conf in enumerate(confidence_scores):
        confidence_records.append({
            "Epoch": epoch + 1,
            "Sample": idx,
            "Confidence": float(conf)
        })

    # Log metrics to TensorBoard
    writer.add_scalar('Train/Loss', epoch_train_loss, epoch+1)
    writer.add_scalar('Validation/Loss', epoch_val_loss, epoch+1)
    writer.add_scalar('Validation/Accuracy', epoch_val_accuracy, epoch+1)

    # Step the scheduler
    scheduler.step()

    # Early Stopping Check
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        counter = 0
        # Save the best model
        best_model_path = os.path.join(model_checkpoint_dir, "best_model.pth")
        torch.save(model.state_dict(), best_model_path)
        print(f"Best model saved to {best_model_path}")
    else:
        counter += 1
        print(f"No improvement in validation loss for {counter} epoch(s).")
        if counter >= patience:
            print("Early stopping triggered.")
            # Leaves the loop before the per-epoch checkpoint below is written
            break

    # Save model checkpoint
    checkpoint_path = os.path.join(model_checkpoint_dir, f"model_epoch_{epoch + 1}.pth")
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Saved model checkpoint to {checkpoint_path}\n")

# After Training, Save Confidence Records to CSV
df_confidence = pd.DataFrame(confidence_records)
confidence_csv_path = os.path.join(training_output_dir, "confidence_scores.csv")
df_confidence.to_csv(confidence_csv_path, index=False)
print(f"Confidence scores saved to {confidence_csv_path}")


In [None]:
# Cell 34: Visualize Confidence Scores (Optional)

import pandas as pd
import matplotlib.pyplot as plt

# Read back the confidence records written by the training cell
confidence_csv_path = os.path.join(training_output_dir, "confidence_scores.csv")
df_confidence = pd.read_csv(confidence_csv_path)

# One line per sample index, tracing its confidence across epochs
plt.figure(figsize=(10,5))
sample_ids = df_confidence['Sample'].unique()
for sample in sample_ids:
    belongs_to_sample = df_confidence['Sample'] == sample
    sample_data = df_confidence[belongs_to_sample]
    plt.plot(sample_data['Epoch'], sample_data['Confidence'], label=f'Sample {sample}')
plt.xlabel('Epoch')
plt.ylabel('Confidence')
plt.title('Confidence Scores Over Epochs')
# Legend is anchored outside the axes, to the right of the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True)
plt.show()


In [None]:
# Cell 35: Setup TensorBoard

# Launch TensorBoard
# Points at the same folder the SummaryWriter logs to
# (<training_output_dir>/tensorboard_logs). Needs `%load_ext tensorboard` to have run.
%tensorboard --logdir /content/drive/MyDrive/MaterialSegmentationOutput/tensorboard_logs


In [None]:
# Cell 36: Integrate SAM with Fine-Tuned Grounding DINO

# Define SAM checkpoint and model type
# NOTE(review): `sam_predictor` (and `segment_with_sam` below) are used by the
# earlier Grounding DINO + SAM cells, so this cell has to be executed before
# them on a fresh kernel — consider moving it up.
sam_checkpoint = "/content/segment-anything/checkpoints/sam_vit_b_01ec64.pth"  # Update path if different
sam_model_type = "vit_b"  # Options: "vit_b", "vit_l", "vit_h"

# Load SAM model
# The registry maps the model-type string to a constructor that takes the checkpoint path;
# sam_model_type must correspond to the checkpoint file chosen above.
sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
sam.to(device)
# Predictor wrapper used for box-prompted mask prediction
sam_predictor = SamPredictor(sam)

def segment_with_sam(image_pil, boxes, predictor):
    """
    Perform segmentation using the SAM model based on bounding boxes.

    Args:
        image_pil (PIL.Image.Image): The input image (anything `np.array`
            turns into an HxWxC array works).
        boxes (torch.Tensor): Boxes of shape (num_boxes, 4) given as
            (x_min, y_min, x_max, y_max) in pixel coordinates of `image_pil`.
            Raw Grounding DINO predictions are normalized (cx, cy, w, h) and
            must be converted by the caller first.
        predictor: The SAM predictor object (`set_image` / `predict` API).

    Returns:
        List of NumPy arrays, one segmentation mask per input box, in the same
        order as `boxes`. Empty list when `boxes` is empty.
    """
    image_np = np.array(image_pil)
    predictor.set_image(image_np)
    masks = []

    for box in boxes:
        # SAM's box prompt is a length-4 array in XYXY format, so the corner
        # coordinates are passed through unchanged (no XYWH conversion).
        sam_box = torch.as_tensor(box).detach().cpu().numpy()

        # Perform prediction: one mask per box
        masks_pred, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=sam_box,
            multimask_output=False,
        )

        # `predict` returns NumPy arrays; only convert when a tensor comes back
        # (e.g. from a torch-based predictor variant).
        mask = masks_pred[0]
        if torch.is_tensor(mask):
            mask = mask.detach().cpu().numpy()
        masks.append(np.asarray(mask))

    return masks


In [None]:
# Cell 37: Evaluate Model with IoU and Accuracy Metrics

def calculate_iou(pred, target, num_classes):
    """
    Compute the intersection-over-union of two label maps for every class.

    Args:
        pred: Array of predicted class indices.
        target: Array of ground-truth class indices, same number of elements.
        num_classes (int): Classes 0 .. num_classes - 1 are scored.

    Returns:
        list: One IoU per class. A class that occurs in neither array
        (union of zero) gets NaN so it can be skipped when averaging.
    """
    flat_pred = pred.flatten()
    flat_target = target.flatten()

    def _class_iou(label):
        in_pred = flat_pred == label
        in_target = flat_target == label
        overlap = (in_pred & in_target).sum()
        combined = (in_pred | in_target).sum()
        if combined == 0:
            return float('nan')
        return overlap / combined

    return [_class_iou(label) for label in range(num_classes)]

def evaluate(model, dataloader, device, num_classes):
    """
    Print the mean IoU per class and overall for a segmentation model.

    IoU is computed per image with `calculate_iou` and then averaged over
    images with NaN entries ignored, so a class contributes only for images
    in which it occurs in the prediction or the target.

    Args:
        model: Segmentation model; its forward pass must return a mapping with
            an 'out' entry holding per-class scores.
        dataloader: Iterable of (images, masks) batches.
        device: Device the batches are moved to.
        num_classes (int): Number of class indices to score.

    Returns:
        None. Results are printed. If the dataloader yields nothing, a notice
        is printed instead.
    """
    model.eval()
    iou_scores = []
    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device).long()

            outputs = model(images)['out']
            preds = torch.argmax(outputs, dim=1)
            for pred, mask in zip(preds, masks):
                iou = calculate_iou(pred.cpu().numpy(), mask.cpu().numpy(), num_classes)
                iou_scores.append(iou)

    # With no samples, nanmean over the empty array returns a scalar that the
    # per-class loop below cannot iterate over; stop early instead.
    if not iou_scores:
        print("No samples to evaluate.")
        return

    # Calculate mean IoU for each class (rows: images, columns: classes)
    iou_scores = np.array(iou_scores)
    mean_ious = np.nanmean(iou_scores, axis=0)
    for cls_idx, mean_iou in enumerate(mean_ious):
        print(f"Class {cls_idx}: Mean IoU = {mean_iou:.4f}")
    # Overall Mean IoU
    overall_mean_iou = np.nanmean(mean_ious)
    print(f"Overall Mean IoU: {overall_mean_iou:.4f}")

# Define number of classes
# 9 equals the number of entries in CATEGORY_COLORS (background + 8 materials).
num_classes = 9  # Adjust based on your dataset

# Evaluate on validation set
# Prints per-class and overall mean IoU; nothing is returned.
evaluate(model, val_loader, device, num_classes)

# Evaluate on test set
evaluate(model, test_loader, device, num_classes)


In [None]:
# Cell 38: Implement Early Stopping and Save Best Model

# Training/validation loop with early stopping; same procedure as the earlier
# training cell, with the early-stopping state reset here.
# NOTE(review): only patience/best_val_loss/counter are re-initialized. The cell
# reuses num_epochs, scaler, model, optimizer and scheduler as left by earlier
# cells, and keeps appending to the existing train_losses / val_losses /
# val_accuracies lists — confirm that continuing training is intended.
patience = 5
best_val_loss = float('inf')
counter = 0

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    # Training Phase
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(train_loader, desc="Training", leave=False):
        images = images.to(device)
        masks = masks.to(device).long()

        optimizer.zero_grad()

        # Forward pass and loss under autocast (mixed precision)
        with autocast():
            outputs = model(images)['out']
            loss = criterion(outputs, masks)

        # Scaled backward pass, optimizer step through the scaler, then update the scale
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight the batch loss by batch size so the epoch figure is per-sample
        # (assumes criterion returns a batch mean — TODO confirm)
        running_loss += loss.item() * images.size(0)

    epoch_train_loss = running_loss / len(train_loader.dataset)
    train_losses.append(epoch_train_loss)
    print(f"Train Loss: {epoch_train_loss:.4f}")

    # Validation Phase
    model.eval()
    val_running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc="Validating", leave=False):
            images = images.to(device)
            masks = masks.to(device).long()

            with autocast():
                outputs = model(images)['out']
                loss = criterion(outputs, masks)

            val_running_loss += loss.item() * images.size(0)

            # Calculate accuracy
            # Element-wise: every mask element counts once (total uses masks.numel())
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == masks).sum().item()
            total += masks.numel()

    epoch_val_loss = val_running_loss / len(val_loader.dataset)
    epoch_val_accuracy = correct / total
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_accuracy)

    print(f"Validation Loss: {epoch_val_loss:.4f}")
    print(f"Validation Accuracy: {epoch_val_accuracy:.4f}")

    # Log metrics to TensorBoard
    # Steps restart at 1 here, so they overlap with steps logged by an earlier run
    # of the training loop on the same writer.
    writer.add_scalar('Train/Loss', epoch_train_loss, epoch+1)
    writer.add_scalar('Validation/Loss', epoch_val_loss, epoch+1)
    writer.add_scalar('Validation/Accuracy', epoch_val_accuracy, epoch+1)

    # Step the scheduler
    scheduler.step()

    # Early Stopping Check
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        counter = 0
        # Save the best model
        # Overwrites any best_model.pth written by an earlier training cell
        best_model_path = os.path.join(model_checkpoint_dir, "best_model.pth")
        torch.save(model.state_dict(), best_model_path)
        print(f"Best model saved to {best_model_path}")
    else:
        counter += 1
        print(f"No improvement in validation loss for {counter} epoch(s).")
        if counter >= patience:
            print("Early stopping triggered.")
            # Leaves the loop before the per-epoch checkpoint below is written
            break

    # Save model checkpoint
    checkpoint_path = os.path.join(model_checkpoint_dir, f"model_epoch_{epoch + 1}.pth")
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Saved model checkpoint to {checkpoint_path}\n")


In [None]:
# Cell 39: Integrate TensorBoard Logs

# Launch TensorBoard
# Same command and log directory as the earlier "Setup TensorBoard" cell; running
# it again just shows the logs written since then.
%tensorboard --logdir /content/drive/MyDrive/MaterialSegmentationOutput/tensorboard_logs


In [None]:
# Cell 40: Final Evaluation on Test Set

# Load the best model (lowest validation loss seen during training)
best_model_path = os.path.join(model_checkpoint_dir, "best_model.pth")
# map_location remaps the saved tensors onto the current device, so a checkpoint
# written on a GPU runtime still loads on a CPU-only (or different-GPU) runtime.
best_state_dict = torch.load(best_model_path, map_location=device)
model.load_state_dict(best_state_dict)
model.to(device)
model.eval()
print(f"Loaded best model from {best_model_path}")

# Evaluate on test set (prints per-class and overall mean IoU)
evaluate(model, test_loader, device, num_classes)

# Visualize some test predictions
visualize_test_predictions(model, test_dataset, device, num_samples=5)


In [None]:
# Cell 41: Save the Trained Model

# Save the weights currently held by `model` (state_dict only, not the full
# module) next to the other checkpoints.
# NOTE(review): if the preceding cell was run, `model` holds the reloaded best
# checkpoint, so this file duplicates best_model.pth — confirm that is intended.
final_model_path = os.path.join(model_checkpoint_dir, "final_model.pth")
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved to {final_model_path}")


In [None]:
# Cell 42: Shutdown TensorBoard (Optional)

# WARNING: this does not target TensorBoard specifically. It sends signal 9 to
# the current process (os.getpid()), i.e. it kills the notebook kernel itself,
# so every in-memory variable (model, loaders, metrics) is lost and the runtime
# has to be restarted. Run it only when the session is finished; with
# "Run All" it terminates the kernel as the last step.
import os
os.kill(os.getpid(), 9)
