# SDM-D: Segmentation-Description-Matching-Distilling

This notebook implements the SDM-D framework for fruit detection and segmentation without manual annotation.

**Framework Overview:**
- **SDM**: Segmentation-Description-Matching using SAM2 and OpenCLIP
- **SDM-D**: Complete framework including knowledge distillation to smaller models

## Requirements
Install all dependencies from `requirements.txt`, and make sure SAM2 (checked out in `./sam2`) and OpenCLIP are set up correctly.
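A quick pre-flight check can catch a missing package or checkpoint before the long-running cells start. This is a minimal sketch. It assumes the import names `torch`, `open_clip`, `sam2` and `cv2` (OpenCV), and the checkpoint path used in the configuration below. Run it after the first cell so that the local `sam2` folder is already on `sys.path`.

```python
import importlib.util
import os


def check_environment(modules=("torch", "open_clip", "sam2", "cv2"),
                      checkpoint="./checkpoints/sam2_hiera_large.pt"):
    """Report which modules are importable and whether the SAM2 checkpoint exists.

    Returns a dict {module_name: bool, ..., "checkpoint": bool}.
    """
    # find_spec returns None for a top-level module that cannot be found
    report = {name: importlib.util.find_spec(name) is not None for name in modules}
    report["checkpoint"] = os.path.isfile(checkpoint)
    return report


if __name__ == "__main__":
    for item, ok in check_environment().items():
        print(f"{'OK ' if ok else 'MISSING'} {item}")
```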

## 1. Import Required Libraries

In [1]:

import os
import sys

import torch
import open_clip

# Make the local SAM2 checkout (./sam2) importable
sys.path.insert(0, os.path.join(os.getcwd(), 'sam2'))

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# SDM helper functions from the local archive/utils.py
from archive.utils import load_descriptions, create_output_folders
from archive.utils import generate_all_sam_mask, label_assignment

## 2. Configuration Parameters

Set your parameters here instead of passing command-line arguments:

In [2]:
# Configuration parameters - modify these according to your needs
class Config:
    def __init__(self):
        # Required parameters
        self.image_folder = './Images/strawberry'  # Folder of input images to segment
        self.out_folder = './output/strawberry'    # Root folder for all outputs
        self.des_file = './description/straw_des.txt'  # Text file with the descriptions (prompts) and their labels
        
        # Optional parameters
        self.sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"  # SAM2 model checkpoint path
        self.model_cfg = "sam2_hiera_l.yaml"  # SAM2 model config file
        self.enable_mask_nms = True  # Whether to apply NMS to masks
        self.mask_nms_thresh = 0.9  # Overlap threshold for mask NMS
        self.save_anns = True  # Whether to save mask visualizations with indices (mask_idx_visual/)
        self.save_json = False  # Whether to save mask metadata as JSON (json/)
        self.box_visual = False  # Whether to save bounding-box visualizations (label_box_visual/)
        self.mask_color_visual = False  # Whether to save colored mask visualizations (mask_color_visual/)

# Create configuration instance
opt = Config()

# Display current configuration
print("Current Configuration:")
print(f"Image folder: {opt.image_folder}")
print(f"Output folder: {opt.out_folder}")
print(f"Description file: {opt.des_file}")
print(f"SAM2 checkpoint: {opt.sam2_checkpoint}")
print(f"Enable mask NMS: {opt.enable_mask_nms}")
print(f"Mask color visual: {opt.mask_color_visual}")
print(f"Box visual: {opt.box_visual}")

Current Configuration:
Image folder: ./Images/strawberry
Output folder: ./output/strawberry
Description file: ./description/straw_des.txt
SAM2 checkpoint: ./checkpoints/sam2_hiera_large.pt
Enable mask NMS: True
Mask color visual: False
Box visual: False
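All paths in `Config` are relative to the notebook's working directory (`os.getcwd()`). A prefix mistake such as `./` versus `../` therefore points somewhere else without raising an error until much later. This is a minimal sketch of a fail-fast check. The helper name `missing_inputs` is hypothetical, and the check covers only the three inputs the notebook reads.

```python
import os


def missing_inputs(cfg):
    """Return (field, path) pairs for configured input paths that do not exist."""
    # Each input field is paired with the kind of filesystem object it should be
    checks = {
        "image_folder": os.path.isdir,
        "des_file": os.path.isfile,
        "sam2_checkpoint": os.path.isfile,
    }
    missing = []
    for field, exists in checks.items():
        path = getattr(cfg, field)
        if not exists(path):
            missing.append((field, path))
    return missing
```

Usage in the notebook: `for field, path in missing_inputs(opt): print(f"Missing {field}: {os.path.abspath(path)}")`. Printing the absolute path shows exactly where Python looked.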


## 3. Setup Folder Structure

In [3]:
# Unpack configuration values into local variables
image_folder = opt.image_folder
out_folder = opt.out_folder
enable_mask_nms = opt.enable_mask_nms
save_anns = opt.save_anns
save_json = opt.save_json
mask_color = opt.mask_color_visual
lable_box_visual = opt.box_visual
mask_nms_thresh = opt.mask_nms_thresh

# Output subfolder paths (the folders themselves are created below)
masks_segs_folder = os.path.join(out_folder, 'mask')
json_save_dir = os.path.join(out_folder, 'json')
label_output_dir = os.path.join(out_folder, 'labels')
mask_ids_visual_folder = os.path.join(out_folder, 'mask_idx_visual')
label_box_visual_dir = os.path.join(out_folder, 'label_box_visual')
mask_color_visual_dir = os.path.join(out_folder, 'mask_color_visual')

# Create all necessary folders
create_output_folders(out_folder)

print("✅ Output folders created successfully!")

Created folder: ./output/strawberry/mask
Created folder: ./output/strawberry/json
Created folder: ./output/strawberry/labels
Created folder: ./output/strawberry/mask_idx_visual
Created folder: ./output/strawberry/label_box_visual
Created folder: ./output/strawberry/mask_color_visual
✅ Output folders created successfully!
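The source of `create_output_folders` (in `archive.utils`) is not shown here. Judging from the printed output, it creates six subfolders under `out_folder`. The sketch below is a hypothetical stand-in with the same visible effect, not the helper itself. It uses `exist_ok=True` so that re-running the cell is harmless.

```python
import os

# Subfolder names taken from the printed output above
OUTPUT_SUBDIRS = ("mask", "json", "labels", "mask_idx_visual",
                  "label_box_visual", "mask_color_visual")


def make_output_tree(out_folder, subdirs=OUTPUT_SUBDIRS):
    """Create each subfolder under out_folder (no error if it already exists).

    Returns a dict mapping subfolder name -> full path.
    """
    paths = {}
    for name in subdirs:
        path = os.path.join(out_folder, name)
        os.makedirs(path, exist_ok=True)  # also creates out_folder if needed
        paths[name] = path
    return paths
```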


## 4. Load Descriptions and Labels

In [4]:
# Load descriptions from file
texts, labels, label_dict = load_descriptions(opt.des_file)

print("Loaded descriptions:")
for i, (text, label) in enumerate(zip(texts, labels)):
    print(f"{i}: '{text}' -> {label}")
    
print(f"\nLabel dictionary: {label_dict}")

Loaded descriptions:
0: 'a red strawberry with numerous points' -> ripe
1: 'a pale green strawberry with numerous points' -> unripe
2: 'a green veined strawberry leaf' -> leaf
3: 'a long and thin stem' -> stem
4: 'a white flower' -> flower
5: 'soil or background or something else' -> others

Label dictionary: {'ripe': 0, 'unripe': 1, 'leaf': 2, 'stem': 3, 'flower': 4, 'others': 5}
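The printed `label_dict` is consistent with giving each distinct label an integer id in order of first appearance. This is a minimal sketch of that mapping. It assumes `load_descriptions` behaves this way, which is not confirmed here, because the file format and the helper's code are not shown. Repeated labels reuse their first id, which would matter if several descriptions shared one label.

```python
def build_label_dict(labels):
    """Map each distinct label to an integer id in order of first appearance."""
    label_dict = {}
    for label in labels:
        if label not in label_dict:
            label_dict[label] = len(label_dict)
    return label_dict


labels = ["ripe", "unripe", "leaf", "stem", "flower", "others"]
print(build_label_dict(labels))
# {'ripe': 0, 'unripe': 1, 'leaf': 2, 'stem': 3, 'flower': 4, 'others': 5}
```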


## 5. Initialize OpenCLIP Model

In [5]:
# Initialize OpenCLIP model
torch.cuda.set_device(0)
clip_model, _, clip_preprocessor = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')

clip_model = clip_model.to('cuda')

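# Enter bfloat16 autocast manually; because the context is never exited,
# it stays active for the rest of the session
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# On Ampere (compute capability 8.x) or newer GPUs, allow TF32 matmuls and convolutions
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True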

print("✅ OpenCLIP model initialized successfully!")
print(f"Model device: {next(clip_model.parameters()).device}")

✅ OpenCLIP model initialized successfully!
Model device: cuda:0
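OpenCLIP handles the description-matching step. Each SAM2 mask region is embedded as an image and compared against the text embeddings of the descriptions. The sketch below shows the core comparison in NumPy, assuming the embeddings have already been computed. It follows the standard CLIP zero-shot recipe (cosine similarity scaled by 100, then a softmax) and is not necessarily the exact scoring inside `label_assignment`.

```python
import numpy as np


def match_descriptions(image_embs, text_embs, labels, temperature=100.0):
    """Assign each mask embedding the label of its most similar description.

    image_embs: (num_masks, dim) array, one row per mask crop
    text_embs:  (num_texts, dim) array, one row per description
    labels:     list of num_texts label names (descriptions may share a label)
    Returns (assigned_labels, probs) where probs has shape (num_masks, num_texts).
    """
    # L2-normalise so that the dot product equals cosine similarity
    img = image_embs / np.linalg.norm(image_embs, axis=1, keepdims=True)
    txt = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    logits = temperature * img @ txt.T
    # Softmax over descriptions (subtract the row max for numerical stability)
    logits -= logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    best = probs.argmax(axis=1)
    return [labels[i] for i in best], probs
```

The softmax does not change which description wins; it only turns similarities into confidence scores. With OpenCLIP, the text side would typically come from `clip_model.encode_text(open_clip.get_tokenizer('ViT-B-32')(texts))`. The image side would come from `clip_model.encode_image(...)` applied to `clip_preprocessor(crop)` for each mask crop, after converting both to NumPy.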


## 6. Initialize SAM2 Model

In [6]:
# Initialize SAM2 model
sam2 = build_sam2(opt.model_cfg, opt.sam2_checkpoint, device='cuda', apply_postprocessing=False)
# Prompt SAM2 with a 32 x 32 grid of points; remove small holes/islands (< 50 px) in each mask
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_side=32, min_mask_region_area=50)

print("✅ SAM2 model initialized successfully!")
print(f"Your enable_mask_nms is {opt.enable_mask_nms}!")

✅ SAM2 model initialized successfully!
Your enable_mask_nms is True!
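With the default `output_mode="binary_mask"`, `mask_generator.generate(image)` takes an RGB `uint8` array of shape H x W x 3. It returns a list of dicts with keys such as `segmentation` (an H x W boolean mask), `area`, `bbox` (`x, y, w, h`), `predicted_iou` and `stability_score`. The sketch below is a hypothetical helper, not part of the repository, that shows how such records can be inspected. Note that `min_mask_region_area` only cleans up small regions inside masks; it does not drop whole masks, so a separate area filter can still be useful.

```python
import numpy as np


def summarize_masks(records, min_area=0):
    """Keep masks with area >= min_area, sorted by predicted IoU (highest first).

    records: list of dicts as returned by SAM2AutomaticMaskGenerator.generate().
    Returns a list of (bbox, area, predicted_iou) tuples.
    """
    kept = [r for r in records if r["area"] >= min_area]
    kept.sort(key=lambda r: r["predicted_iou"], reverse=True)
    return [(r["bbox"], r["area"], round(float(r["predicted_iou"]), 3)) for r in kept]
```

Example: `records = mask_generator.generate(cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB))`, then `summarize_masks(records, min_area=50)[:5]` lists the five most confident masks.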


## 7. Generate All SAM Masks

This step processes all images and generates segmentation masks using SAM2:

In [7]:
# Generate all masks
print("🚀 Starting mask generation...")
generate_all_sam_mask(
    mask_generator, 
    image_folder, 
    masks_segs_folder, 
    json_save_dir, 
    mask_ids_visual_folder, 
    enable_mask_nms, 
    mask_nms_thresh, 
    save_anns, 
    save_json
)
print("✅ Mask generation completed!")

🚀 Starting mask generation...
Error with file label: OpenCV(4.11.0) /io/opencv/modules/imgproc/src/color.cpp:199: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

Error with file img: OpenCV(4.11.0) /io/opencv/modules/imgproc/src/color.cpp:199: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

Error with file label: OpenCV(4.11.0) /io/opencv/modules/imgproc/src/color.cpp:199: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

Error with file img: OpenCV(4.11.0) /io/opencv/modules/imgproc/src/color.cpp:199: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

Error with file label: OpenCV(4.11.0) /io/opencv/modules/imgproc/src/color.cpp:199: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

Error with file img: OpenCV(4.11.0) /io/opencv/modules/imgproc/src/color.cpp:199: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

✅ Mask generation completed!
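The `enable_mask_nms` and `mask_nms_thresh` options (0.9 here) control the removal of near-duplicate masks. The helper that applies them is not shown. The sketch below is a generic greedy mask NMS. It assumes IoU as the overlap measure and SAM2's `predicted_iou` as the score; the criterion used in `archive.utils` may differ.

```python
import numpy as np


def mask_nms(masks, scores, iou_thresh=0.9):
    """Greedy non-maximum suppression over binary masks.

    masks:  (N, H, W) boolean array
    scores: (N,) confidence per mask (e.g. SAM2's predicted_iou)
    Returns the indices of kept masks, highest score first.
    """
    order = np.argsort(scores)[::-1]            # highest score first
    flat = masks.reshape(len(masks), -1)
    areas = flat.sum(axis=1)
    keep = []
    for i in order:
        suppressed = False
        for j in keep:
            inter = np.logical_and(flat[i], flat[j]).sum()
            union = areas[i] + areas[j] - inter
            # Drop mask i if it overlaps an already-kept mask too strongly
            if union > 0 and inter / union > iou_thresh:
                suppressed = True
                break
        if not suppressed:
            keep.append(int(i))
    return keep
```

A high threshold such as 0.9 removes only almost identical masks. It keeps nested masks (for example, a strawberry and its calyx) whose IoU is lower.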


## 8. Label Assignment

This step assigns labels to the generated masks using OpenCLIP:

In [8]:
# Label assignment
print("🏷️ Starting label assignment...")
label_assignment(
    clip_preprocessor, 
    image_folder, 
    masks_segs_folder, 
    label_output_dir, 
    label_box_visual_dir, 
    mask_color_visual_dir, 
    clip_model, 
    texts, 
    labels, 
    label_dict, 
    lable_box_visual, 
    mask_color
)
print("✅ Label assignment completed!")

🏷️ Starting label assignment...


IsADirectoryError: [Errno 21] Is a directory: './Images/strawberry/train/label'
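Both failures above are consistent with the helpers receiving directories where they expect image files. `cv2.imread` returns `None` for a path it cannot decode as an image, a directory included, and `cvtColor` then fails the `!_src.empty()` assertion. The `IsADirectoryError` names `./Images/strawberry/train/label` directly. This suggests that `image_folder` contains nested split folders (apparently `train/img` and `train/label`), while the helpers expect a flat folder of images. The sketch below is a quick check that makes no assumptions about the helpers' internals. The function names are hypothetical.

```python
import os

IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")


def find_images(root):
    """Recursively collect image file paths under root (sorted)."""
    found = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.lower().endswith(IMAGE_EXTS):
                found.append(os.path.join(dirpath, name))
    return sorted(found)


def subfolders(root):
    """List immediate subdirectories of root (sorted); these are not images."""
    return sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
```

If `subfolders(image_folder)` is non-empty, you may need to point `opt.image_folder` at the folder that directly contains the images. For this layout that is presumably `./Images/strawberry/train/img`, but check this against `find_images(image_folder)`.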

## 9. Results Summary

Display information about the generated outputs:

In [None]:
# Summary of results
print("\n📊 Processing Summary:")
print("=" * 50)

# Count generated files
if os.path.exists(masks_segs_folder):
    mask_count = sum(len(files) for _, _, files in os.walk(masks_segs_folder))
    print(f"🎭 Generated masks: {mask_count}")

if os.path.exists(label_output_dir):
    label_count = sum(
        sum(1 for f in files if f.endswith('.txt'))
        for _, _, files in os.walk(label_output_dir)
    )
    print(f"🏷️ Generated label files: {label_count}")

print("\n📁 Output structure:")
print("├── mask/               # Instance segmentation masks")
print("├── labels/             # YOLO format labels")
if save_anns:
    print("├── mask_idx_visual/    # Mask visualization with indices")
if save_json:
    print("├── json/               # Mask metadata in JSON format")
if lable_box_visual:
    print("├── label_box_visual/   # Bounding box visualizations")
if mask_color:
    print("├── mask_color_visual/  # Colored mask visualizations")

print(f"\n✨ All outputs saved to: {out_folder}")
print("\n🎉 SDM processing completed successfully!")
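The `labels/` folder is described as YOLO-format. In the YOLO detection format, each line is `class cx cy w h`, with the box centre and size normalised by the image width and height. The segmentation variant instead lists the class followed by normalised polygon vertices. The sketch below converts one binary mask to a detection line, which is roughly the kind of conversion `seg2label/seg2det.py` presumably performs. It is a hypothetical helper, not that script's code.

```python
import numpy as np


def mask_to_yolo_box(mask, class_id):
    """Return 'class cx cy w h' (normalised to [0, 1]) for a binary H x W mask.

    Returns None for an empty mask.
    """
    h, w = mask.shape
    ys, xs = np.nonzero(mask)
    if len(xs) == 0:
        return None
    # +1 so that the box covers the last foreground pixel
    x0, x1 = xs.min(), xs.max() + 1
    y0, y1 = ys.min(), ys.max() + 1
    cx, cy = (x0 + x1) / 2 / w, (y0 + y1) / 2 / h
    bw, bh = (x1 - x0) / w, (y1 - y0) / h
    return f"{class_id} {cx:.6f} {cy:.6f} {bw:.6f} {bh:.6f}"


mask = np.zeros((10, 20), dtype=bool)   # H = 10, W = 20
mask[2:4, 5:9] = True
print(mask_to_yolo_box(mask, 0))        # 0 0.350000 0.300000 0.200000 0.200000
```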

## 10. Optional: Visualization and Analysis

Optionally, display a few input images for a quick visual check:

In [None]:
# Optional: quick look at a few input images
import matplotlib.pyplot as plt
import cv2

def show_sample_results(image_folder, output_folder=None, num_samples=2):
    """Display the first few input images found under image_folder.

    output_folder is currently unused; it is kept so existing calls still work.
    """
    # Collect image files recursively (directories are never passed to cv2.imread)
    processed_images = []
    for root, _, files in os.walk(image_folder):
        for file in sorted(files):
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                processed_images.append(os.path.join(root, file))

    num_samples = min(num_samples, len(processed_images))
    if num_samples == 0:
        print("No images found to display.")
        return

    # squeeze=False always returns a 2-D array of axes, even for one sample
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 5), squeeze=False)
    for i, (ax, img_path) in enumerate(zip(axes[0], processed_images[:num_samples])):
        img = cv2.imread(img_path)
        if img is None:  # unreadable or corrupt file
            ax.set_title(f"Sample {i+1}: unreadable {os.path.basename(img_path)}")
        else:
            ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            ax.set_title(f"Sample {i+1}: {os.path.basename(img_path)}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    print(f"Displayed {num_samples} sample image(s).")

# Uncomment the line below to show sample images
# show_sample_results(image_folder, out_folder)

## 11. Next Steps

After running SDM, you can:

1. **Convert labels for different tasks:**
   - For object detection: Use `seg2label/seg2det.py`
   - For semantic segmentation: Use `seg2label/seg2seman.py`
   - For specific labels: Use `seg2label/extract_label_needed.py`

2. **Knowledge Distillation (SDM-D):**
   - Train smaller models using the generated pseudo-labels
   - Use any architecture (YOLOv8, EfficientDet, etc.) as student models

3. **Fine-tuning:**
   - Fine-tune on a small set of manually labeled images (few-shot) for better performance
   - Experiment with different prompt designs in the description file
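For distillation, the pseudo-labels serve as ordinary training data for the student model. If the student is trained with the Ultralytics YOLOv8 package (one of the options above), it reads a dataset YAML listing the split folders and class names. The sketch below builds such a file from the `label_dict` loaded in step 4. The dataset paths are hypothetical placeholders. Ultralytics locates each label file by replacing `images` with `labels` in the image path, so the folders have to follow that layout.

```python
def yolo_dataset_yaml(root, train, val, label_dict):
    """Build an Ultralytics-style dataset YAML string from a {name: id} dict."""
    lines = [f"path: {root}", f"train: {train}", f"val: {val}", "names:"]
    # Class names must be listed under their integer ids
    for name, idx in sorted(label_dict.items(), key=lambda kv: kv[1]):
        lines.append(f"  {idx}: {name}")
    return "\n".join(lines) + "\n"


if __name__ == "__main__":
    # Hypothetical paths; adjust to wherever the images and pseudo-labels live
    text = yolo_dataset_yaml("./datasets/strawberry", "images/train", "images/val",
                             {"ripe": 0, "unripe": 1, "leaf": 2})
    print(text)
    # Training would then look like (not run here):
    #   from ultralytics import YOLO
    #   YOLO("yolov8n-seg.pt").train(data="strawberry.yaml", epochs=100, imgsz=640)
```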

For more details, refer to the original paper and documentation.