<a href="https://colab.research.google.com/github/moksima/MemCDTedit/blob/main/GD%2BSAM%2BMat_New.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
# Cell 1: Clone the Repository Safely

import os

repo_url = "https://github.com/kyotovision-public/multimodal-material-segmentation.git"
repo_dir = "multimodal-material-segmentation"

if not os.path.exists(repo_dir):
    !git clone {repo_url}
    print(f"Cloned repository '{repo_dir}'.")
else:
    print(f"Repository '{repo_dir}' already exists. Pulling latest changes.")
    !cd {repo_dir} && git pull


Cloning into 'multimodal-material-segmentation'...
remote: Enumerating objects: 164, done.[K
remote: Counting objects: 100% (164/164), done.[K
remote: Compressing objects: 100% (120/120), done.[K
remote: Total 164 (delta 42), reused 148 (delta 33), pack-reused 0 (from 0)[K
Receiving objects: 100% (164/164), 2.75 MiB | 24.89 MiB/s, done.
Resolving deltas: 100% (42/42), done.
Cloned repository 'multimodal-material-segmentation'.


In [8]:
# Cell 2: Install Dependencies with Compatible Versions

%cd multimodal-material-segmentation

# Upgrade pip, setuptools, and wheel
!pip install --upgrade pip setuptools wheel

# Install dependencies without strict version pinning to ensure compatibility with Python 3.10
# You can customize this list based on your project's actual dependencies
!pip install absl-py aiohttp astroid async-timeout attrs cachetools certifi chardet cycler Cython future \
    google-auth google-auth-oauthlib grpcio idna idna-ssl importlib-metadata isort kiwisolver \
    lazy-object-proxy Markdown matplotlib mccabe multidict numpy transformers filelock huggingface-hub \
    safetensors tokenizers tqdm opencv-python addict pycocotools supervision timm yapf

# Additionally, install other necessary packages
!pip install torch torchvision albumentations numpy matplotlib scikit-learn opencv-python


/content/multimodal-material-segmentation/multimodal-material-segmentation


In [9]:
# Cell 3a: Install Grounding DINO and Handle Existing Clones

import os

grounding_dino_repo = "GroundingDINO"
grounding_dino_url = "https://github.com/IDEA-Research/GroundingDINO.git"

if not os.path.exists(grounding_dino_repo):
    !git clone {grounding_dino_url}
    print(f"Cloned GroundingDINO repository.")
else:
    print(f"GroundingDINO repository already exists. Pulling latest changes.")
    !cd {grounding_dino_repo} && git pull

%cd GroundingDINO

# Upgrade pip and install requirements
!pip install --upgrade pip setuptools wheel

# Install GroundingDINO in editable mode
!pip install -e .

%cd ..


Cloning into 'GroundingDINO'...
remote: Enumerating objects: 463, done.[K
remote: Counting objects:   0% (1/240)[Kremote: Counting objects:   1% (3/240)[Kremote: Counting objects:   2% (5/240)[Kremote: Counting objects:   3% (8/240)[Kremote: Counting objects:   4% (10/240)[Kremote: Counting objects:   5% (12/240)[Kremote: Counting objects:   6% (15/240)[Kremote: Counting objects:   7% (17/240)[Kremote: Counting objects:   8% (20/240)[Kremote: Counting objects:   9% (22/240)[Kremote: Counting objects:  10% (24/240)[Kremote: Counting objects:  11% (27/240)[Kremote: Counting objects:  12% (29/240)[Kremote: Counting objects:  13% (32/240)[Kremote: Counting objects:  14% (34/240)[Kremote: Counting objects:  15% (36/240)[Kremote: Counting objects:  16% (39/240)[Kremote: Counting objects:  17% (41/240)[Kremote: Counting objects:  18% (44/240)[Kremote: Counting objects:  19% (46/240)[Kremote: Counting objects:  20% (48/240)[Kremote: Counting objects: 

In [10]:
# Cell 3b: Install 'segment-anything' and Upgrade 'albumentations'

# Install segment-anything via pip
!pip install segment-anything

# Upgrade albumentations to the latest version to resolve the warning
!pip install --upgrade albumentations

print("Installed 'segment-anything' and upgraded 'albumentations' successfully.")


Installed 'segment-anything' and upgraded 'albumentations' successfully.


In [11]:
# Cell 4: Install Hugging Face Hub and Authenticate
#
# Prompts for a Hugging Face access token, logs in with it, and exports it as
# the HF_TOKEN environment variable for the rest of the session.

# The package is also part of the Cell 2 install list (as `huggingface-hub`).
!pip install huggingface_hub

from huggingface_hub import login
from getpass import getpass
import os

# Securely input your Hugging Face Token (getpass does not echo the typed value)
hf_token = getpass("Enter your Hugging Face API Token: ")
login(token=hf_token)

# Optionally, set the token as an environment variable
# NOTE(review): the token stays in the kernel as `hf_token` and in os.environ;
# avoid printing either when sharing the notebook.
os.environ['HF_TOKEN'] = hf_token

# Reached only if login() above did not raise.
print("Hugging Face login successful.")


Enter your Hugging Face API Token: ··········
The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful
Hugging Face login successful.


In [13]:
# Cell 5: Import Necessary Libraries

import os
import random
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont
import cv2
import numpy as np
from tqdm import tqdm
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import torchvision.models as models
from google.colab import drive
from segment_anything import sam_model_registry, SamPredictor
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from transformers import BertTokenizer
import seaborn as sns
import warnings

# Suppress warnings (optional)
# Only FutureWarning is silenced; other warning categories are still shown.
warnings.filterwarnings("ignore", category=FutureWarning)

# Reached only when every import above succeeded.
print("All libraries imported successfully.")


ModuleNotFoundError: No module named 'groundingdino'

In [None]:
# Cell 6: Mount Google Drive
#
# Makes Drive available under /content/drive; the dataset paths used in later
# cells live under /content/drive/MyDrive.

from google.colab import drive

# force_remount=True re-mounts even if Drive is already mounted in this session.
drive.mount('/content/drive', force_remount=True)

print("Google Drive mounted successfully.")


In [None]:
# Cell 7: Define Custom Dataset for MCubeS

from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import torch
import os

class MaterialSegmentationDataset(Dataset):
    """Image / label-mask pairs for one dataset split.

    Every name listed in ``list_file`` is resolved to one image and one mask by
    trying each directory/extension combination in order and keeping the first
    file that exists. Image and mask are resolved together, so
    ``image_paths[i]`` and ``mask_paths[i]`` always belong to the same name.
    """

    def __init__(self, list_file, images_dirs, masks_dirs, transforms=None,
                 image_extensions=['.jpg', '.jpeg', '.png'],
                 mask_extensions=['.npy', '.png', '.jpg', '.jpeg'],
                 grayscale_to_class=None):
        """
        Args:
            list_file (str): Path to the txt file containing filenames for the split.
            images_dirs (list): List of directories containing images.
            masks_dirs (list): List of directories containing masks.
            transforms (albumentations.Compose, optional): Transformations to apply.
            image_extensions (list): List of acceptable image file extensions.
            mask_extensions (list): List of acceptable mask file extensions.
            grayscale_to_class (dict, optional): Mapping from grayscale values to class indices.
                When omitted or empty, masks keep their raw values.

        Raises:
            AssertionError: If a listed name has an image but no mask, or a mask but no image.
        """
        self.transforms = transforms
        self.images_dirs = images_dirs
        self.masks_dirs = masks_dirs
        # Copy the extension lists so the shared default lists are never aliased.
        self.image_extensions = list(image_extensions)
        self.mask_extensions = list(mask_extensions)
        self.grayscale_to_class = grayscale_to_class if grayscale_to_class else {}

        # Read the list of filenames, ignoring blank lines.
        with open(list_file, 'r') as f:
            self.filenames = [line.strip() for line in f if line.strip()]

        # Resolve image and mask for each name in a single pass so the two path
        # lists cannot drift out of step with each other.
        self.image_paths = []
        self.mask_paths = []
        missing_images = []
        missing_masks = []
        unpaired = []
        for fname in self.filenames:
            image_path = self._find_file(fname, self.images_dirs, self.image_extensions)
            mask_path = self._find_file(fname, self.masks_dirs, self.mask_extensions)

            if image_path is None:
                missing_images.append(fname)
            if mask_path is None:
                missing_masks.append(fname)

            if image_path is not None and mask_path is not None:
                self.image_paths.append(image_path)
                self.mask_paths.append(mask_path)
            elif (image_path is None) != (mask_path is None):
                # Exactly one of the two files exists for this name.
                unpaired.append(fname)

        if missing_images:
            print(f"Total missing images: {len(missing_images)}")
            for msg in missing_images[:5]:  # Print first 5 missing images
                print(f"Missing image: {msg}")

        if missing_masks:
            print(f"Total missing masks: {len(missing_masks)}")
            for msg in missing_masks[:5]:  # Print first 5 missing masks
                print(f"Missing mask: {msg}")

        # Ensure that each image has a corresponding mask. Names with neither
        # file are skipped (and reported above); half-present names are an error.
        if unpaired:
            raise AssertionError("Mismatch between images and masks.")

    @staticmethod
    def _find_file(fname, directories, extensions):
        """Return the first existing ``<directory>/<fname><ext>`` path, or None."""
        for directory in directories:
            for ext in extensions:
                candidate = os.path.join(directory, fname + ext)
                if os.path.isfile(candidate):
                    return candidate
        return None

    def __len__(self):
        """Number of (image, mask) pairs that were resolved."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, map and (optionally) transform the pair at position ``idx``.

        Returns:
            tuple: ``(image, mask)``. ``mask`` is always a ``torch.LongTensor``.
            ``image`` is whatever ``transforms`` produces, or an RGB numpy
            array when no transforms are configured.

        Raises:
            ValueError: If the mask file has an unsupported extension.
            TypeError: If the transforms return a mask that is neither a numpy
                array nor a torch tensor.
        """
        # Load image (context manager releases the file handle promptly).
        with Image.open(self.image_paths[idx]) as image_file:
            image = np.array(image_file.convert("RGB"))

        # Load mask
        mask_path = self.mask_paths[idx]
        mask_path_lower = mask_path.lower()
        if mask_path_lower.endswith('.npy'):
            mask = np.load(mask_path)
            # Ensure mask is in the correct format (e.g., single channel)
            if mask.ndim == 3:
                mask = mask[:, :, 0]
        elif mask_path_lower.endswith(('.png', '.jpg', '.jpeg')):
            with Image.open(mask_path) as mask_file:
                mask = np.array(mask_file.convert("L"))  # Convert to grayscale
        else:
            raise ValueError(f"Unsupported mask format: {mask_path}")

        # Map mask values to class indices
        mask = self.map_mask(mask, self.grayscale_to_class)

        # Apply transformations
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Ensure mask is of type torch.LongTensor
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(mask).long()
        elif isinstance(mask, torch.Tensor):
            mask = mask.long()
        else:
            raise TypeError(f"Unsupported mask type: {type(mask)}")

        return image, mask

    def map_mask(self, mask, mapping):
        """Translate raw mask values into class indices.

        Args:
            mask (np.ndarray): Raw mask values.
            mapping (dict): ``{raw_value: class_index}``. An empty mapping means
                "no remapping": the raw values are returned unchanged as int64,
                so an unmapped dataset can be used to inspect the stored values.

        Returns:
            np.ndarray: int64 array with the same shape as ``mask``. Values that
            are not in a non-empty ``mapping`` are assigned to class 0 and a
            warning listing them is printed.
        """
        mask = np.asarray(mask)
        if not mapping:
            return mask.astype(np.int64)

        # Unmapped pixels default to the background class (0). A separate
        # boolean array tracks coverage so no class index doubles as a sentinel.
        mapped_mask = np.zeros(mask.shape, dtype=np.int64)
        resolved = np.zeros(mask.shape, dtype=bool)
        for grayscale_value, class_index in mapping.items():
            hits = mask == grayscale_value
            mapped_mask[hits] = class_index
            resolved |= hits

        # Check for unmapped values
        if not resolved.all():
            unique_unmapped_values = np.unique(mask[~resolved])
            print(f"Warning: Found unmapped grayscale values in mask: {unique_unmapped_values}")

        return mapped_mask


In [None]:
# Cell 8: Collect All Unique Grayscale Values from Train, Validation, and Test Sets

from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm

# Define directories containing images and masks
images_dirs = [
    "/content/drive/MyDrive/multimodal_dataset/polL_color"
    # Add more image directories if applicable
]

masks_dirs = [
    "/content/drive/MyDrive/multimodal_dataset/GT"
    # Add more mask directories if applicable
]

# Define list files
train_list = "/content/drive/MyDrive/multimodal_dataset/list_folder/train.txt"
val_list = "/content/drive/MyDrive/multimodal_dataset/list_folder/val.txt"
test_list = "/content/drive/MyDrive/multimodal_dataset/list_folder/test.txt"

# Initialize datasets for all splits; here they are only used to resolve the
# mask file paths of each split.
train_dataset = MaterialSegmentationDataset(
    list_file=train_list,
    images_dirs=images_dirs,
    masks_dirs=masks_dirs,
    transforms=None  # No transforms needed
)

val_dataset = MaterialSegmentationDataset(
    list_file=val_list,
    images_dirs=images_dirs,
    masks_dirs=masks_dirs,
    transforms=None  # No transforms needed
)

test_dataset = MaterialSegmentationDataset(
    list_file=test_list,
    images_dirs=images_dirs,
    masks_dirs=masks_dirs,
    transforms=None  # No transforms needed
)


def _raw_mask_values(mask_path):
    """Return the distinct raw values stored in one mask file as Python numbers.

    The file is read directly rather than through ``dataset[idx]``: a dataset
    built without a mapping assigns every unmapped pixel to class 0, which
    would hide the very values this survey is meant to discover. Reading the
    mask alone also avoids decoding the RGB image.
    """
    if mask_path.lower().endswith('.npy'):
        raw = np.load(mask_path)
        if raw.ndim == 3:
            raw = raw[:, :, 0]  # same single-channel convention as the dataset
    else:
        with Image.open(mask_path) as mask_file:
            raw = np.array(mask_file.convert("L"))
    return np.unique(raw).tolist()


# Collect unique grayscale values from all splits
unique_grayscale = set()

print("Collecting unique grayscale values from all splits...")
for split_name, dataset in zip(['Train', 'Validation', 'Test'], [train_dataset, val_dataset, test_dataset]):
    split_unique = set()
    for mask_path in tqdm(dataset.mask_paths, desc=f"Processing {split_name} Set", leave=False):
        split_unique.update(_raw_mask_values(mask_path))
    unique_grayscale.update(split_unique)
    print(f"Unique grayscale values in {split_name} set: {sorted(split_unique)}")

print(f"\nTotal unique grayscale values across all splits: {sorted(unique_grayscale)}")

# Define the comprehensive grayscale_to_class mapping
grayscale_to_class = {
    0: 0,     # Background
    1: 1,     # Asphalt
    2: 2,     # Concrete
    3: 3,     # Metal
    4: 4,     # Road Marking
    5: 5,     # Gravel
    6: 6,     # Fabric
    7: 7,     # Glass
    8: 8,     # Plaster
    9: 9,     # Plastic
    10: 10,   # Rubber
    11: 11,   # Sand
    12: 12,   # Ceramic
    13: 13,   # Cobblestone
    14: 14,   # Brick
    15: 15,   # Grass
    16: 16,   # Wood
    17: 17,   # Leaf
    18: 18,   # Water
    19: 19,   # Human Body
    20: 20,   # Sky
    255: 0     # Optional: Map 255 to Background
}

# Check if all grayscale values are mapped
unmapped_grayscale = unique_grayscale - set(grayscale_to_class.keys())
if len(unmapped_grayscale) > 0:
    print(f"\nError: The following grayscale values are not mapped to any class: {sorted(unmapped_grayscale)}")
    print("Please update 'grayscale_to_class' to include these values.")
else:
    print("\nAll grayscale values are successfully mapped to class indices.")

# Optionally, save the mapping for future reference
mapping_df = pd.DataFrame(list(grayscale_to_class.items()), columns=['Grayscale', 'Class_Index'])
mapping_csv_path = "/content/drive/MyDrive/multimodal_dataset/grayscale_to_class_mapping.csv"
mapping_df.to_csv(mapping_csv_path, index=False)
print(f"Grayscale to Class mapping saved to {mapping_csv_path}")


In [None]:
# Cell 9: Update Dataset Instances and DataLoaders with Complete Mapping

import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Assuming 'grayscale_to_class' has been defined in Cell 8 and includes all necessary mappings

# --- Augmentation pipelines ---------------------------------------------------
def _make_pipeline(*augmentations):
    """Build a pipeline: resize to 512x512, the given augmentations, ImageNet
    mean/std normalisation, then conversion to tensors."""
    return A.Compose(
        [
            A.Resize(height=512, width=512),
            *augmentations,
            A.Normalize(mean=(0.485, 0.456, 0.406),
                        std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ],
        additional_targets={'mask': 'mask'},
    )


# Training adds random flips and brightness/contrast jitter.
train_transforms = _make_pipeline(
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
)

# Validation and testing use the deterministic steps only.
val_transforms = _make_pipeline()


# --- Datasets (all share the directories and label mapping from Cell 8) -------
def _make_dataset(list_file, transforms):
    """Create a MaterialSegmentationDataset for one split list."""
    return MaterialSegmentationDataset(
        list_file=list_file,
        images_dirs=images_dirs,
        masks_dirs=masks_dirs,
        transforms=transforms,
        grayscale_to_class=grayscale_to_class,
    )


train_dataset = _make_dataset(train_list, train_transforms)
val_dataset = _make_dataset(val_list, val_transforms)
test_dataset = _make_dataset(test_list, val_transforms)  # no augmentation for test

# --- DataLoaders --------------------------------------------------------------
batch_size = 8
num_workers = 4  # Adjust based on your environment


def _make_loader(dataset, shuffle):
    """Wrap a dataset in a DataLoader using the shared batch/worker settings."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )


train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

# --- Sanity check: split sizes ------------------------------------------------
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")

# Visualize a Batch from the Training Loader

def visualize_batch(images, masks, batch_size=4):
    """Show up to ``batch_size`` image/mask pairs from one batch in a two-row grid.

    Args:
        images (torch.Tensor): Batch of images in (N, C, H, W) layout, normalised
            with the mean/std used by the transforms in this cell.
        masks (torch.Tensor): Batch of masks, one (H, W) map per image.
        batch_size (int): Maximum number of pairs to draw. Fewer are drawn when
            the batch is smaller (e.g. the last batch of a loader).
    """
    images = images.permute(0, 2, 3, 1).cpu().numpy()
    # Undo the normalisation so colours display correctly.
    images = (images * np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])
    images = np.clip(images, 0, 1)

    masks = masks.cpu().numpy()

    # Never index past the samples that are actually present.
    num_shown = min(batch_size, len(images), len(masks))
    if num_shown <= 0:
        print("No samples to visualize.")
        return

    plt.figure(figsize=(20, 10))
    for i in range(num_shown):
        # Top row: images.
        plt.subplot(2, num_shown, i+1)
        plt.imshow(images[i])
        plt.title("Image")
        plt.axis('off')

        # Bottom row: the matching masks.
        plt.subplot(2, num_shown, num_shown + i + 1)
        plt.imshow(masks[i], cmap='jet')
        plt.title("Mask")
        plt.axis('off')
    plt.show()

# Get a batch from the training loader
# Only the first batch is drawn; `images` and `masks` stay bound to that batch
# after the loop exits.
for images, masks in train_loader:
    visualize_batch(images, masks, batch_size=4)
    break  # Only visualize one batch


In [None]:
# Cell 10: Update Dataset Instances and DataLoaders with Complete Mapping

import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# NOTE(review): this cell rebuilds the transforms, datasets and loaders that the
# preceding cell also defines; the objects created here replace the earlier ones.

# --- Augmentation pipelines ---------------------------------------------------
def _make_pipeline(*augmentations):
    """Build a pipeline: resize to 512x512, the given augmentations, ImageNet
    mean/std normalisation, then conversion to tensors."""
    return A.Compose(
        [
            A.Resize(height=512, width=512),
            *augmentations,
            A.Normalize(mean=(0.485, 0.456, 0.406),
                        std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ],
        additional_targets={'mask': 'mask'},
    )


# Training adds random flips and brightness/contrast jitter.
train_transforms = _make_pipeline(
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
)

# Validation and testing use the deterministic steps only.
val_transforms = _make_pipeline()


# --- Datasets (require `grayscale_to_class` to be defined already) ------------
def _make_dataset(list_file, transforms):
    """Create a MaterialSegmentationDataset for one split list."""
    return MaterialSegmentationDataset(
        list_file=list_file,
        images_dirs=images_dirs,
        masks_dirs=masks_dirs,
        transforms=transforms,
        grayscale_to_class=grayscale_to_class,
    )


train_dataset = _make_dataset(train_list, train_transforms)
val_dataset = _make_dataset(val_list, val_transforms)
test_dataset = _make_dataset(test_list, val_transforms)  # no augmentation for test

# --- DataLoaders --------------------------------------------------------------
batch_size = 8
num_workers = 4  # Adjust based on your environment


def _make_loader(dataset, shuffle):
    """Wrap a dataset in a DataLoader using the shared batch/worker settings."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )


train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

# --- Sanity check: split sizes ------------------------------------------------
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")

# Visualize a Batch from the Training Loader

def visualize_batch(images, masks, batch_size=4):
    """Show up to ``batch_size`` image/mask pairs from one batch in a two-row grid.

    Args:
        images (torch.Tensor): Batch of images in (N, C, H, W) layout, normalised
            with the mean/std used by the transforms in this cell.
        masks (torch.Tensor): Batch of masks, one (H, W) map per image.
        batch_size (int): Maximum number of pairs to draw. Fewer are drawn when
            the batch is smaller (e.g. the last batch of a loader).
    """
    images = images.permute(0, 2, 3, 1).cpu().numpy()
    # Undo the normalisation so colours display correctly.
    images = (images * np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])
    images = np.clip(images, 0, 1)

    masks = masks.cpu().numpy()

    # Never index past the samples that are actually present.
    num_shown = min(batch_size, len(images), len(masks))
    if num_shown <= 0:
        print("No samples to visualize.")
        return

    plt.figure(figsize=(20, 10))
    for i in range(num_shown):
        # Top row: images.
        plt.subplot(2, num_shown, i+1)
        plt.imshow(images[i])
        plt.title("Image")
        plt.axis('off')

        # Bottom row: the matching masks.
        plt.subplot(2, num_shown, num_shown + i + 1)
        plt.imshow(masks[i], cmap='jet')
        plt.title("Mask")
        plt.axis('off')
    plt.show()

# Get a batch from the training loader
# Only the first batch is drawn; `images` and `masks` stay bound to that batch
# after the loop exits.
for images, masks in train_loader:
    visualize_batch(images, masks, batch_size=4)
    break  # Only visualize one batch


In [None]:
# Cell 11: Verify Dataset and DataLoaders

# Derive the class count from the label mapping instead of hard-coding it:
# `grayscale_to_class` produces class indices up to its largest value, so any
# smaller count would leave some labels without a class slot.
num_classes = max(grayscale_to_class.values()) + 1

# Function to collect unique labels in masks
def collect_unique_labels(dataset):
    """Return the set of label values that occur in any mask of ``dataset``.

    ``dataset`` must support ``len()`` and integer indexing, and each item must
    be an ``(image, mask)`` pair whose mask exposes ``.numpy()``.
    """
    per_sample_labels = (
        np.unique(dataset[position][1].numpy())
        for position in range(len(dataset))
    )
    # Union of the per-mask label arrays; an empty dataset yields an empty set.
    return set().union(*per_sample_labels)

# Collect unique labels from training dataset
# Each call walks the whole split (every sample is loaded and transformed), so
# the three scans below can take a while on a large dataset.
unique_values_train = collect_unique_labels(train_dataset)
print(f"Unique labels in training masks: {unique_values_train}")

# Collect unique labels from validation dataset
unique_values_val = collect_unique_labels(val_dataset)
print(f"Unique labels in validation masks: {unique_values_val}")

# Collect unique labels from test dataset
unique_values_test = collect_unique_labels(test_dataset)
print(f"Unique labels in test masks: {unique_values_test}")

# Check class distribution in training dataset
def compute_class_distribution(dataset, num_classes):
    """Count how many pixels of each class occur in ``dataset`` and print the shares.

    Args:
        dataset: Indexable collection of ``(image, mask)`` pairs; each mask must
            expose ``.numpy()`` and hold integer class indices.
        num_classes (int): Number of classes to count (indices ``0..num_classes-1``).
            Labels outside that range are ignored.

    Returns:
        np.ndarray: int64 array of length ``num_classes`` with per-class pixel counts.
    """
    class_counts = np.zeros(num_classes, dtype=np.int64)
    for idx in range(len(dataset)):
        _, mask = dataset[idx]
        labels = mask.numpy().ravel()
        # bincount needs non-negative values; out-of-range labels are dropped,
        # matching the per-class equality test this replaces.
        labels = labels[(labels >= 0) & (labels < num_classes)]
        # One histogram pass per mask instead of one full comparison per class.
        class_counts += np.bincount(labels, minlength=num_classes)

    total_pixels = int(class_counts.sum())
    for cls in range(num_classes):
        # Guard against an empty dataset (0 / 0 would print "nan%").
        percentage = (class_counts[cls] / total_pixels) * 100 if total_pixels else 0.0
        print(f"Class {cls}: {class_counts[cls]} pixels ({percentage:.2f}%)")
    return class_counts

print("\nClass distribution in training dataset:")
class_counts_train = compute_class_distribution(train_dataset, num_classes)

# Optional: bar chart of the per-class pixel counts computed above.
plt.figure(figsize=(10,6))
class_axis = list(range(num_classes))
sns.barplot(x=class_axis, y=class_counts_train)
plt.title('Class Distribution in Training Dataset')
plt.xlabel('Class')
plt.ylabel('Pixel Count')
plt.show()


In [None]:
# Cell 12: Define and Initialize the Segmentation Model

import torch
from torchvision import models

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize DeepLabV3 with ResNet-50 backbone
model = models.segmentation.deeplabv3_resnet50(weights='DEFAULT')

# Size the output layer from the label mapping: the dataset emits class indices
# up to the largest value in `grayscale_to_class`, and the loss needs one logit
# channel for every index that can appear in a target mask.
num_classes = max(grayscale_to_class.values()) + 1

# Replace the classifier with a new one (DeepLabHead); 2048 is the input
# channel count passed to the head for this backbone.
model.classifier = models.segmentation.deeplabv3.DeepLabHead(2048, num_classes)
print("Model classifier modified successfully.")

# Move model to device
model = model.to(device)
print("Model moved to device successfully.")


In [None]:
# Cell 13: Load Grounding DINO and SAM Models

import torch
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from segment_anything import sam_model_registry, SamPredictor
from transformers import BertTokenizer
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import os

def load_grounding_dino(config_path, checkpoint_path, device='cuda'):
    """
    Build a Grounding DINO model and populate it from a checkpoint.

    Args:
        config_path: Path of the model configuration file.
        checkpoint_path: Path of the checkpoint; its 'model' entry holds the weights.
        device: Device the model is moved to after loading.

    Returns:
        The model, on ``device`` and in evaluation mode.
    """
    dino = build_model(SLConfig.fromfile(config_path))

    # Deserialize on CPU first; the model is moved to the target device below.
    state = torch.load(checkpoint_path, map_location='cpu')
    # strict=False: keys that do not match are tolerated rather than raising.
    dino.load_state_dict(clean_state_dict(state['model']), strict=False)

    dino.to(device)
    dino.eval()
    return dino

def load_sam(sam_checkpoint_path, model_type="vit_b", device='cuda'):
    """
    Create a SAM predictor from a checkpoint.

    Args:
        sam_checkpoint_path: Path of the SAM checkpoint file.
        model_type: Key into ``sam_model_registry`` selecting the model variant.
        device: Device the SAM model is moved to.

    Returns:
        A ``SamPredictor`` wrapping the loaded model.
    """
    build_sam = sam_model_registry[model_type]
    sam_model = build_sam(checkpoint=sam_checkpoint_path)
    sam_model.to(device)
    return SamPredictor(sam_model)

def load_image(image_path):
    """
    Read an image and prepare a model-ready tensor from it.

    Returns:
        tuple: ``(image_pil, image_tensor)`` where ``image_pil`` is the RGB image
        as opened and ``image_tensor`` is that image resized to 512x512,
        converted to a tensor and normalised with the mean/std given below.
    """
    rgb_image = Image.open(image_path).convert("RGB")

    to_model_input = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225)),
    ])

    return rgb_image, to_model_input(rgb_image)

def get_grounding_output(model, image, caption, box_threshold, text_threshold, device='cuda'):
    """
    Perform Grounding DINO inference to obtain bounding boxes and associated phrases.

    Args:
        model: The loaded Grounding DINO model.
        image (torch.Tensor): Preprocessed image tensor of shape (C, H, W).
        caption (str): Text prompt containing multiple categories separated by commas.
        box_threshold (float): Confidence threshold to filter bounding boxes.
        text_threshold (float): Confidence threshold for text recognition.
        device (str): Device to perform computation ('cuda' or 'cpu').

    Returns:
        boxes_filt (torch.Tensor): Filtered bounding boxes with shape (num_boxes, 4),
            taken unchanged from the model's 'pred_boxes' output.
        pred_phrases (list): List of predicted phrases corresponding to each bounding box.
        scores (torch.Tensor): Confidence scores for each bounding box.

        The three results are index-aligned: entry ``i`` of each describes the
        same detection. All three are empty when nothing passes the thresholds.
    """
    image = image.to(device)
    model = model.to(device)
    model.eval()

    # Prefer the tokenizer carried by the model (upstream GroundingDINO models
    # expose one as `model.tokenizer`) so token positions match the logits;
    # fall back to loading the BERT tokenizer when it is not available.
    tokenizer = getattr(model, "tokenizer", None)
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    def _empty_result():
        """Result returned whenever no detection survives."""
        return torch.empty((0, 4)).to(device), [], torch.empty((0,)).to(device)

    with torch.no_grad():
        # Upstream GroundingDINO's forward() reads the prompts from the
        # `captions` keyword argument.
        outputs = model(image[None], captions=[caption])

    # Check if 'pred_logits' and 'pred_boxes' are in outputs
    if 'pred_logits' not in outputs or 'pred_boxes' not in outputs:
        print("Model outputs do not contain 'pred_logits' or 'pred_boxes'.")
        return _empty_result()

    logits = outputs["pred_logits"].sigmoid()[0]  # one row of token scores per query
    boxes = outputs["pred_boxes"][0]              # Shape: (num_queries, 4)

    # Filter boxes with confidence threshold
    logits_max, _ = logits.max(dim=1)
    keep = logits_max > box_threshold
    if keep.sum() == 0:
        print("No boxes above box_threshold.")
        return _empty_result()
    logits_filt = logits[keep]
    boxes_filt = boxes[keep]

    # Tokenize caption
    tokenized = tokenizer(caption, return_tensors="pt").to(device)
    input_ids = tokenized["input_ids"][0]

    # Map logits to phrases, remembering which rows produced a phrase so the
    # boxes can be selected by index afterwards (a query may be skipped).
    pred_phrases = []
    scores = []
    kept_rows = []
    for row, logit in enumerate(logits_filt):
        # Find tokens with confidence above text_threshold
        token_indices = (logit > text_threshold).nonzero(as_tuple=True)[0]
        # Ignore positions beyond the tokenized caption to avoid indexing past it.
        token_indices = token_indices[token_indices < input_ids.shape[0]]
        if len(token_indices) == 0:
            continue
        pred_phrases.append(tokenizer.decode(input_ids[token_indices]))
        scores.append(logit[token_indices].mean().item())
        kept_rows.append(row)

    if len(scores) == 0:
        print("No scores above text_threshold.")
        return _empty_result()

    boxes_filt = boxes_filt[kept_rows]
    scores = torch.tensor(scores).to(device)
    return boxes_filt, pred_phrases, scores

def segment_with_sam(image_pil, boxes, predictor):
    """
    Perform segmentation using the SAM model based on bounding boxes.

    Args:
        image_pil (PIL.Image.Image): The input image.
        boxes (torch.Tensor): Bounding boxes detected by Grounding DINO, i.e.
            rows of (cx, cy, w, h) normalised to [0, 1] relative to the image.
        predictor: The SAM predictor object.

    Returns:
        List of NumPy arrays representing segmentation masks, one (H, W) mask
        per input box. Empty when ``boxes`` is empty.
    """
    if boxes.numel() == 0:
        print("No boxes provided for segmentation.")
        return []

    image_np = np.array(image_pil)
    predictor.set_image(image_np)
    height, width = image_np.shape[:2]

    # Convert normalised centre/size boxes into absolute corner (x0, y0, x1, y1)
    # pixel coordinates, which is what SAM's box prompt expects. Work on a CPU
    # copy so the caller's tensor is left untouched.
    boxes_cxcywh = boxes.detach().cpu().float().reshape(-1, 4)
    boxes_cxcywh = boxes_cxcywh * torch.tensor([width, height, width, height], dtype=boxes_cxcywh.dtype)
    half_sizes = boxes_cxcywh[:, 2:] / 2
    boxes_xyxy = torch.cat(
        [boxes_cxcywh[:, :2] - half_sizes, boxes_cxcywh[:, :2] + half_sizes], dim=1
    )

    # Use the device the predictor's model lives on instead of a notebook global.
    sam_device = predictor.device

    # Map the boxes into the resized frame SAM works in. The batched torch API
    # is used because predict() takes numpy boxes and applies this transform
    # itself, which would resize already-transformed boxes a second time.
    transformed_boxes = predictor.transform.apply_boxes_torch(
        boxes_xyxy, (height, width)
    ).to(sam_device)

    masks_pred, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # masks_pred holds one mask channel per box; return each as a NumPy array.
    return [box_masks[0].cpu().numpy() for box_masks in masks_pred]

print("Helper functions loaded successfully.")


In [None]:
# Cell 14: Upgrade `timm` to the Latest Version
#
# `timm` is already part of the unpinned install list in Cell 2; this re-runs
# pip with --upgrade.
# NOTE(review): if timm was imported earlier in this session, the upgraded
# version presumably only takes effect after a runtime restart - confirm.

!pip install --upgrade timm

# Printed regardless of whether the pip command above succeeded.
print("Upgraded timm to the latest version.")


In [None]:
# Cell 15: Verify Helper Functions Defined Correctly
#
# Sanity check that the Cell 13 helpers exist in the kernel namespace. Each
# print shows the function object; the first undefined name raises NameError,
# which is caught and reported instead of aborting the cell.

try:
    print(load_grounding_dino)
    print(load_sam)
    print(load_image)
    print(get_grounding_output)
    print(segment_with_sam)
    print("All helper functions are defined correctly.")
except NameError as e:
    # Only the first missing helper is reported; later names are not checked.
    print(f"Error: {e}")


In [None]:
# Cell 16: Define Loss Function, Optimizer, and Scheduler with Class Weights

import torch.nn as nn
import torch.optim as optim
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
import torch

# Define number of classes from the label mapping so that every class index the
# dataset can emit has a slot in the weight vector and in the loss.
num_classes = max(grayscale_to_class.values()) + 1

# Compute class weights based on class counts in the training dataset.
# Pixel counts are accumulated one mask at a time with a histogram; collecting
# every pixel into a Python list would need one object per pixel of the
# entire training set.
pixel_counts = np.zeros(num_classes, dtype=np.int64)
for _, mask in train_dataset:
    labels = mask.numpy().ravel()
    labels = labels[(labels >= 0) & (labels < num_classes)]  # bincount needs valid indices
    pixel_counts += np.bincount(labels, minlength=num_classes)

# Get unique labels present in the training masks
unique_labels = np.flatnonzero(pixel_counts)
print(f"Unique labels in training masks: {unique_labels}")

# Initialize class_weights array with ones (classes absent from training keep weight 1)
class_weights = np.ones(num_classes, dtype=np.float32)

# 'balanced' weighting over the classes that are present:
#   weight_c = total_pixels / (number_of_present_classes * pixels_of_class_c)
# This is the formula scikit-learn's compute_class_weight(class_weight='balanced')
# applies, evaluated here from the counts directly.
if unique_labels.size > 0:
    class_weights[unique_labels] = pixel_counts.sum() / (
        unique_labels.size * pixel_counts[unique_labels]
    )

# Convert class_weights to a tensor and move to the appropriate device
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
print(f"Class Weights: {class_weights}")

# Loss function with class weights
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Optimizer
learning_rate = 1e-4
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# Learning rate scheduler: multiply the learning rate by 0.1 every 10 scheduler steps
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)


In [None]:
# Cell 17: Verify Mask Classes in All Datasets

def verify_mask_classes(dataset, num_classes):
    """Print which class labels occur in a dataset's masks and which are missing.

    Args:
        dataset: Indexable collection supporting ``len()`` whose items are
            ``(image, mask)`` pairs; each ``mask`` must provide ``.numpy()``.
        num_classes: Number of expected classes; labels ``0 .. num_classes - 1``
            are checked for presence.

    Returns:
        None. The findings are printed.
    """
    unique_classes = set()
    for idx in range(len(dataset)):
        _, mask = dataset[idx]
        # De-duplicate inside NumPy first so only the distinct labels of each mask
        # reach the Python set (instead of one set insertion per pixel), and store
        # plain Python numbers so the printed list is free of NumPy scalar reprs.
        unique_classes.update(np.unique(mask.numpy()).tolist())
    print(f"Unique classes in the dataset: {sorted(unique_classes)}")
    missing_classes = set(range(num_classes)) - unique_classes
    if missing_classes:
        print(f"Warning: The following classes are missing in the dataset: {missing_classes}")
    else:
        print("All classes are present in the dataset.")

# Run the same label check over the training, validation and test splits,
# printing a header line before each report.
for split_header, split_dataset in (
    ("Verifying training dataset:", train_dataset),
    ("\nVerifying validation dataset:", val_dataset),
    ("\nVerifying test dataset:", test_dataset),
):
    print(split_header)
    verify_mask_classes(split_dataset, num_classes=13)


In [None]:
# Cell 18: Visualize Sample Images and Masks

import matplotlib.pyplot as plt

def visualize_sample(dataset, idx):
    """Show one dataset sample as two side-by-side panels: image and mask.

    Args:
        dataset: Indexable collection whose items are ``(image, mask)`` pairs of
            objects providing ``.cpu().numpy()``.
        idx: Index of the sample to display.
    """
    image, mask = dataset[idx]

    # Move the first axis to the end (assumes a channels-first image — TODO confirm),
    # then apply x * std + mean per channel and clamp to [0, 1] for display.
    # NOTE(review): the constants look like the usual ImageNet statistics and
    # presumably mirror the dataset's normalisation transform — confirm.
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])
    image_np = image.cpu().numpy().transpose(1, 2, 0)
    image_np = np.clip(image_np * channel_std + channel_mean, 0, 1)

    mask_np = mask.cpu().numpy()

    # (title, data, extra imshow arguments) for the left and right panel
    panels = (
        ("Image", image_np, {}),
        ("Mask", mask_np, {"cmap": 'jet'}),
    )

    plt.figure(figsize=(10, 5))
    for position, (panel_title, panel_data, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(panel_data, **imshow_kwargs)
        plt.title(panel_title)
        plt.axis('off')

    plt.show()

# Preview training samples 0, 1 and 2 as a quick visual sanity check
for i in (0, 1, 2):
    visualize_sample(train_dataset, i)


In [None]:
# Cell 19: Initialize the Segmentation Model

import torch
from torchvision import models

# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Number of output classes for the new segmentation head
num_classes = 13  # Updated to match the MCubeS dataset

# DeepLabV3 with a ResNet-50 backbone, loaded with torchvision's default weights
model = models.segmentation.deeplabv3_resnet50(weights='DEFAULT')

# Swap in a fresh DeepLabHead (2048 input channels) that predicts `num_classes`
# channels. Only `model.classifier` is replaced here.
model.classifier = models.segmentation.deeplabv3.DeepLabHead(2048, num_classes)
print("Model classifier modified successfully.")

# Place the model on the selected device
model = model.to(device)
print("Model moved to device successfully.")


In [None]:
# Cell 20: Define and Initialize the SAM Predictor

import torch
from segment_anything import sam_model_registry, SamPredictor
import os

def load_sam_predictor(sam_checkpoint_path, model_type="vit_b", device='cuda'):
    """
    Build a SAM model from a checkpoint and wrap it in a ``SamPredictor``.

    Args:
        sam_checkpoint_path: Path passed as ``checkpoint`` to the model builder.
        model_type: Key into ``sam_model_registry`` selecting the builder.
        device: Device the model is moved to before the predictor is created.

    Returns:
        A ``SamPredictor`` wrapping the loaded model.
    """
    build_sam = sam_model_registry[model_type]
    sam_model = build_sam(checkpoint=sam_checkpoint_path)
    sam_model.to(device)
    return SamPredictor(sam_model)

# Location of the SAM ViT-B checkpoint file
sam_checkpoint_path = "/content/segment-anything/checkpoints/sam_vit_b_01ec64.pth"

# Only build the predictor when the checkpoint is on disk; otherwise report the
# missing file (in which case `sam_predictor` is not assigned by this cell).
if os.path.exists(sam_checkpoint_path):
    sam_predictor = load_sam_predictor(
        sam_checkpoint_path,
        model_type="vit_b",
        device=device,
    )
    print("SAM predictor initialized successfully.")
else:
    print(f"❌ SAM checkpoint file not found at {sam_checkpoint_path}. Please download it and place it in the specified directory.")


In [None]:
# Cell 21: Define the Training Loop

from torch.utils.tensorboard import SummaryWriter
import torch

# NOTE(review): `GradScaler`, `autocast`, `tqdm` and `os` are not imported in this
# cell; they must already be in scope from earlier cells — confirm.

# Initialize TensorBoard writer
writer = SummaryWriter("/content/drive/MyDrive/MaterialSegmentationOutput/tensorboard_logs")

num_epochs = 25
patience = 5  # For early stopping: epochs without validation-loss improvement

# Initialize lists to store per-epoch metrics
train_losses = []
val_losses = []
val_accuracies = []

best_val_loss = float('inf')
counter = 0  # Consecutive epochs without improvement

# Create the checkpoint directory once, up front. Both the best-model file and the
# per-epoch checkpoints are written there, so the per-epoch save must not depend
# on the "best model" branch having run first (e.g. when the first validation
# loss is NaN and never compares as an improvement).
best_model_path = "/content/drive/MyDrive/MaterialSegmentationOutput/checkpoints/best_model.pth"
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

# Initialize GradScaler for mixed precision training
scaler = GradScaler()

try:
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Training Phase
        model.train()
        running_loss = 0.0
        for images, masks in tqdm(train_loader, desc="Training", leave=False):
            images = images.to(device)
            masks = masks.to(device).long()

            optimizer.zero_grad()

            with autocast():
                outputs = model(images)['out']
                loss = criterion(outputs, masks)

            # Scaled backward pass / optimizer step for mixed precision
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Weight the batch loss by batch size so the epoch value is a per-sample mean
            running_loss += loss.item() * images.size(0)

        epoch_train_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_train_loss)
        print(f"Train Loss: {epoch_train_loss:.4f}")

        # Validation Phase
        model.eval()
        val_running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc="Validating", leave=False):
                images = images.to(device)
                masks = masks.to(device).long()

                with autocast():
                    outputs = model(images)['out']
                    loss = criterion(outputs, masks)

                val_running_loss += loss.item() * images.size(0)

                # Pixel-wise accuracy: arg-max over the class dimension vs. the mask
                preds = torch.argmax(outputs, dim=1)
                correct += (preds == masks).sum().item()
                total += masks.numel()

        epoch_val_loss = val_running_loss / len(val_loader.dataset)
        epoch_val_accuracy = correct / total
        val_losses.append(epoch_val_loss)
        val_accuracies.append(epoch_val_accuracy)

        print(f"Validation Loss: {epoch_val_loss:.4f}")
        print(f"Validation Accuracy: {epoch_val_accuracy:.4f}")

        # Log metrics to TensorBoard (epochs are logged 1-based)
        writer.add_scalar('Train/Loss', epoch_train_loss, epoch+1)
        writer.add_scalar('Validation/Loss', epoch_val_loss, epoch+1)
        writer.add_scalar('Validation/Accuracy', epoch_val_accuracy, epoch+1)

        # Step the scheduler once per epoch
        scheduler.step()

        # Early Stopping Check
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            counter = 0
            # Save the best model
            torch.save(model.state_dict(), best_model_path)
            print(f"Best model saved to {best_model_path}")
        else:
            counter += 1
            print(f"No improvement in validation loss for {counter} epoch(s).")
            if counter >= patience:
                print("Early stopping triggered.")
                break

        # Save model checkpoint for this epoch (skipped on the early-stopping epoch)
        checkpoint_path = f"/content/drive/MyDrive/MaterialSegmentationOutput/checkpoints/model_epoch_{epoch + 1}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Saved model checkpoint to {checkpoint_path}\n")

    print("Training completed.")
finally:
    # Close the TensorBoard writer even if training is interrupted or raises,
    # so the metrics logged so far are not left in an unclosed writer.
    writer.close()


In [None]:
# Cell 22: Save and Visualize Training Metrics

import pandas as pd
import matplotlib.pyplot as plt

# Gather the per-epoch metrics recorded by the training loop (epochs are 1-based)
metrics = dict(
    Epoch=list(range(1, len(train_losses) + 1)),
    Train_Loss=train_losses,
    Validation_Loss=val_losses,
    Validation_Accuracy=val_accuracies,
)

# Persist the metrics table as CSV (without the DataFrame index column)
df_metrics = pd.DataFrame(metrics)
metrics_csv_path = "/content/drive/MyDrive/MaterialSegmentationOutput/training_metrics.csv"
df_metrics.to_csv(metrics_csv_path, index=False)
print(f"Training metrics saved to {metrics_csv_path}")

# Figure 1: training vs. validation loss per epoch
loss_fig, loss_ax = plt.subplots(figsize=(10,5))
loss_ax.plot(df_metrics['Epoch'], df_metrics['Train_Loss'], label='Train Loss')
loss_ax.plot(df_metrics['Epoch'], df_metrics['Validation_Loss'], label='Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
loss_ax.grid(True)
plt.show()

# Figure 2: validation accuracy per epoch
acc_fig, acc_ax = plt.subplots(figsize=(10,5))
acc_ax.plot(df_metrics['Epoch'], df_metrics['Validation_Accuracy'], label='Validation Accuracy', color='green')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title('Validation Accuracy')
acc_ax.legend()
acc_ax.grid(True)
plt.show()


In [None]:
# Cell 23: Launch TensorBoard

# Load the `tensorboard` IPython extension so the `%tensorboard` magic used below is available
%load_ext tensorboard

# Launch TensorBoard on the log directory that the training cell's SummaryWriter writes to
%tensorboard --logdir /content/drive/MyDrive/MaterialSegmentationOutput/tensorboard_logs


In [None]:
# Cell 24: Evaluate the Model on the Test Set

def calculate_iou(pred, target, num_classes):
    """Compute the intersection-over-union of every class for one prediction.

    Args:
        pred: Array of predicted class labels (any shape; it is flattened).
        target: Array of ground-truth class labels with the same number of elements.
        num_classes: Classes ``0 .. num_classes - 1`` are scored.

    Returns:
        A list of length ``num_classes``. Entry ``c`` is ``intersection / union``
        for class ``c``, or ``nan`` when the class occurs in neither ``pred`` nor
        ``target`` (empty union). A class that is predicted but absent from
        ``target`` therefore scores 0 rather than ``nan``.
    """
    flat_pred = pred.flatten()
    flat_target = target.flatten()

    def class_iou(cls):
        in_pred = flat_pred == cls
        in_target = flat_target == cls
        union = (in_pred | in_target).sum()
        if union == 0:
            # Class absent from both maps: undefined IoU, excluded later by nan-aware means
            return float('nan')
        return (in_pred & in_target).sum() / union

    return [class_iou(cls) for cls in range(num_classes)]

def evaluate(model, dataloader, device, num_classes):
    """Evaluate a segmentation model and print per-class IoU, mean IoU and pixel accuracy.

    Args:
        model: Model whose forward pass returns a mapping with an ``'out'`` entry
            of per-class scores (arg-max is taken over dimension 1).
        dataloader: Iterable yielding ``(images, masks)`` batches.
        device: Device the batches are moved to before the forward pass.
        num_classes: Number of classes scored by ``calculate_iou``.

    Returns:
        None. Results are printed. IoU is computed per sample and then averaged
        per class over all samples, ignoring ``nan`` entries.
    """
    model.eval()
    iou_scores = []
    accuracy_total = 0
    total_pixels = 0
    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc="Evaluating", leave=False):
            images = images.to(device)
            masks = masks.to(device).long()

            outputs = model(images)['out']
            preds = torch.argmax(outputs, dim=1)

            # Pixel-wise accuracy bookkeeping
            accuracy_total += (preds == masks).sum().item()
            total_pixels += masks.numel()

            # Copy the whole batch to host memory once instead of once per sample
            preds_np = preds.cpu().numpy()
            masks_np = masks.cpu().numpy()
            for pred, mask in zip(preds_np, masks_np):
                iou_scores.append(calculate_iou(pred, mask, num_classes))

    # An empty dataloader would otherwise end in a division by zero below
    if total_pixels == 0:
        print("No samples to evaluate: the dataloader yielded no mask pixels.")
        return

    accuracy = accuracy_total / total_pixels
    # Mean IoU for each class across samples; rows are samples, columns are classes
    iou_scores = np.array(iou_scores)
    mean_ious = np.nanmean(iou_scores, axis=0)
    for cls_idx, mean_iou in enumerate(mean_ious):
        print(f"Class {cls_idx}: Mean IoU = {mean_iou:.4f}")
    # Overall Mean IoU over the classes that have a defined score
    overall_mean_iou = np.nanmean(mean_ious)
    print(f"Overall Mean IoU: {overall_mean_iou:.4f}")
    print(f"Pixel-wise Accuracy: {accuracy:.4f}")

# Number of classes scored during evaluation
num_classes = 13  # Ensure this matches your dataset

# Score the trained model on the held-out test loader
evaluate(model=model, dataloader=test_loader, device=device, num_classes=num_classes)


In [None]:
# Cell 25: Visualize Model Predictions on Test Set

def visualize_test_predictions(model, dataset, device, num_samples=5):
    """Plot random samples with their ground-truth mask, predicted mask and overlay.

    For each chosen sample a 1x4 figure is shown: the de-normalised image, the
    ground-truth mask, the predicted mask, and the prediction blended onto the image.

    Args:
        model: Model whose forward pass returns a mapping with an ``'out'`` entry.
        dataset: Indexable collection of ``(image, mask)`` pairs.
        device: Device used for the forward pass.
        num_samples: Number of random samples to show; capped at ``len(dataset)``.

    Note:
        Reads the module-level names ``num_classes`` and ``CATEGORY_COLORS``,
        which must be defined before this function is called.
    """
    model.eval()
    # random.sample raises ValueError when asked for more items than exist,
    # so never request more samples than the dataset holds.
    sample_count = min(num_samples, len(dataset))
    indices = random.sample(range(len(dataset)), sample_count)

    for idx in indices:
        image, mask = dataset[idx]
        image_input = image.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(image_input)['out'][0]
            pred_mask = torch.argmax(output, dim=0).cpu().numpy()

        # Convert tensors to NumPy arrays for visualization: move the first axis to
        # the end, apply x * std + mean per channel, and clamp to [0, 1].
        # NOTE(review): constants presumably mirror the dataset's normalisation — confirm.
        image_np = image.cpu().numpy().transpose(1, 2, 0)
        image_np = (image_np * np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])
        image_np = np.clip(image_np, 0, 1)

        mask_np = mask.cpu().numpy()

        # Colour-code the predicted classes (RGB scaled to [0, 1])
        pred_color_mask = np.zeros_like(image_np)
        for cls in range(num_classes):
            if cls in CATEGORY_COLORS:
                pred_color_mask[pred_mask == cls] = np.array(CATEGORY_COLORS[cls]) / 255.0

        # Blend the coloured prediction onto the image
        overlay_pred = (0.5 * image_np + 0.5 * pred_color_mask)

        # Plotting
        fig, axs = plt.subplots(1, 4, figsize=(20, 5))
        axs[0].imshow(image_np)
        axs[0].set_title("Original Image")
        axs[0].axis('off')

        # Use one fixed colour scale for both mask panels. Without vmin/vmax each
        # panel is normalised to its own min/max, so the same class could be drawn
        # in different colours in the ground-truth and predicted panels.
        axs[1].imshow(mask_np, cmap='jet', alpha=0.5, vmin=0, vmax=num_classes - 1)
        axs[1].set_title("Ground Truth Mask")
        axs[1].axis('off')

        axs[2].imshow(pred_mask, cmap='jet', alpha=0.5, vmin=0, vmax=num_classes - 1)
        axs[2].set_title("Predicted Mask")
        axs[2].axis('off')

        axs[3].imshow(overlay_pred)
        axs[3].set_title("Overlay Predicted Mask")
        axs[3].axis('off')

        plt.show()

# RGB colour (0-255 per channel) used to draw each class index in visualizations.
# The list position is the class index.
CATEGORY_COLORS = dict(enumerate([
    (0, 0, 0),         # 0: Background
    (128, 0, 0),       # 1
    (0, 128, 0),       # 2
    (128, 128, 0),     # 3
    (0, 0, 128),       # 4
    (128, 0, 128),     # 5
    (0, 128, 128),     # 6
    (128, 128, 128),   # 7
    (64, 0, 0),        # 8
    (192, 0, 0),       # 9
    (64, 128, 0),      # 10
    (192, 128, 0),     # 11
    (64, 0, 128),      # 12
]))

# Show predictions for five randomly chosen test samples
visualize_test_predictions(model=model, dataset=test_dataset, device=device, num_samples=5)


In [None]:
# Cell 26: Shutdown TensorBoard (Optional)
# WARNING: despite the title, the call below sends signal 9 to the process running
# this cell (os.getpid() is the current process), i.e. it kills the notebook's own
# Python process. All in-memory variables are lost when it runs.
# NOTE(review): whether the TensorBoard server started by `%tensorboard` also stops
# as a result is not evident from this cell — confirm before relying on it.

import os
os.kill(os.getpid(), 9)
