Reimplementing OneFormer

Requirements

In [1]:
import os
import json
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from torchvision import transforms
from coco_dataset import COCOPanopticDataset

In [2]:
data_transform = transforms.Compose([
    transforms.Resize((512, 512)),  # Resize images to 512x512
    transforms.ToTensor(),          # Convert images to PyTorch tensors
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # Normalize based on ImageNet means and std
])


In [3]:
from torch.utils.data import DataLoader

# Paths to COCO data
train_image_dir = "datasets/coco/train2017"
train_instance_file = "datasets/coco/annotations/instances_train2017.json"
train_panoptic_file = "datasets/coco/annotations/panoptic_train2017.json"
train_panoptic_mask_dir = "datasets/coco/panoptic_train2017"

# Initialize dataset
train_dataset = COCOPanopticDataset(
    image_dir=train_image_dir,
    instance_file=train_instance_file,
    panoptic_file=train_panoptic_file,
    panoptic_mask_dir=train_panoptic_mask_dir,
    transform=data_transform
)

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4)


In [4]:
for images, masks in train_loader:
    print("Batch of images shape:", images.shape)  # Expected shape: [batch_size, 3, 512, 512]
    print("Batch of masks shape:", masks.shape)    # Expected shape: [batch_size, 512, 512]
    break


Panoptic mask 470036.png not found in datasets/coco/panoptic_train2017
Panoptic mask 10407.png not found in datasets/coco/panoptic_train2017
Panoptic mask 552054.png not found in datasets/coco/panoptic_train2017
Panoptic mask 324937.png not found in datasets/coco/panoptic_train2017
Panoptic mask 192817.png not found in datasets/coco/panoptic_train2017
Panoptic mask 162087.png not found in datasets/coco/panoptic_train2017
Panoptic mask 354644.png not found in datasets/coco/panoptic_train2017
Panoptic mask 326063.png not found in datasets/coco/panoptic_train2017
Batch of images shape: torch.Size([2, 3, 512, 512])
Batch of masks shape: torch.Size([2, 512, 512])
Panoptic mask 464265.png not found in datasets/coco/panoptic_train2017
Panoptic mask 290314.png not found in datasets/coco/panoptic_train2017
Panoptic mask 286925.png not found in datasets/coco/panoptic_train2017
Panoptic mask 425158.png not found in datasets/coco/panoptic_train2017
Panoptic mask 14108.png not found in datasets/coc

In [5]:
for i, (images, masks) in enumerate(train_loader):
    print(f"Batch {i+1}")
    print("Batch of images shape:", images.shape)  # Expected shape: [batch_size, 3, 512, 512]
    print("Batch of masks shape:", masks.shape)    # Expected shape: [batch_size, 512, 512]

    # Stop after 3 batches
    if i == 2:
        break


Panoptic mask 131856.png not found in datasets/coco/panoptic_train2017
Panoptic mask 118690.png not found in datasets/coco/panoptic_train2017
Panoptic mask 181906.png not found in datasets/coco/panoptic_train2017
Panoptic mask 433968.png not found in datasets/coco/panoptic_train2017
Panoptic mask 479528.png not found in datasets/coco/panoptic_train2017
Panoptic mask 153671.png not found in datasets/coco/panoptic_train2017
Panoptic mask 9813.png not found in datasets/coco/panoptic_train2017
Panoptic mask 342624.png not found in datasets/coco/panoptic_train2017
Batch 1
Batch of images shape: torch.Size([2, 3, 512, 512])
Batch of masks shape: torch.Size([2, 512, 512])
Panoptic mask 473102.png not found in datasets/coco/panoptic_train2017
Batch 2
Batch of images shape: torch.Size([2, 3, 512, 512])
Batch of masks shape: torch.Size([2, 512, 512])
Panoptic mask 202931.png not found in datasets/coco/panoptic_train2017
Batch 3
Batch of images shape: torch.Size([2, 3, 512, 512])
Batch of masks s

In [68]:
import torch.nn as nn

class PixelDecoder(nn.Module):
    def __init__(self, embed_dim=256):
        super(PixelDecoder, self).__init__()
        # Define layers for each scale with matching input channels
        self.layer1 = nn.Conv2d(2048, embed_dim, kernel_size=3, padding=1)  # For Layer4
        self.layer2 = nn.Conv2d(1024, embed_dim, kernel_size=3, padding=1)  # For Layer3
        self.layer3 = nn.Conv2d(512, embed_dim, kernel_size=3, padding=1)   # For Layer2
        self.layer4 = nn.Conv2d(256, embed_dim, kernel_size=3, padding=1)   # For Layer1

    def forward(self, features):
        # Assuming `features` is a list of feature maps from ResNet backbone
        multi_scale_features = []
        for i, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]):
            # Interpolate each feature map to the same size as the last feature map (smallest resolution)
            scaled_feature = F.interpolate(features[i], size=(features[0].shape[2], features[0].shape[3]), mode='bilinear')
            multi_scale_features.append(layer(scaled_feature))
        
        return multi_scale_features


In [69]:
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

from torch.nn import TransformerEncoder, TransformerEncoderLayer

from detectron2.modeling import build_backbone
from detectron2.config import get_cfg, CfgNode
import torch.nn.functional as F

# Initialize configuration and allow setting new attributes
cfg = get_cfg()
cfg.MODEL.ONE_FORMER = CfgNode()
cfg.MODEL.ONE_FORMER.set_new_allowed(True)

# General settings for ONE_FORMER
cfg.MODEL.ONE_FORMER.NUM_OBJECT_QUERIES = 100  # Number of queries
cfg.MODEL.ONE_FORMER.DEEP_SUPERVISION = True   # Enable/disable deep supervision
cfg.MODEL.ONE_FORMER.NO_OBJECT_WEIGHT = 0.1    # Weight for no-object class
cfg.MODEL.ONE_FORMER.CLASS_WEIGHT = 1.0        # Class weight
cfg.MODEL.ONE_FORMER.DICE_WEIGHT = 1.0         # Dice loss weight
cfg.MODEL.ONE_FORMER.MASK_WEIGHT = 1.0         # Mask weight
cfg.MODEL.ONE_FORMER.CONTRASTIVE_WEIGHT = 0.07 # Contrastive weight

# Training settings
cfg.MODEL.ONE_FORMER.TRAIN_NUM_POINTS = 12544  # Number of training points
cfg.MODEL.ONE_FORMER.OVERSAMPLE_RATIO = 3.0    # Oversample ratio
cfg.MODEL.ONE_FORMER.IMPORTANCE_SAMPLE_RATIO = 0.75  # Importance sample ratio
cfg.MODEL.ONE_FORMER.SIZE_DIVISIBILITY = 32    # Size divisibility for input

# TEXT_ENCODER settings
cfg.MODEL.TEXT_ENCODER = CfgNode()
cfg.MODEL.TEXT_ENCODER.set_new_allowed(True)
cfg.MODEL.TEXT_ENCODER.VOCAB_SIZE = 30000      # Vocabulary size
cfg.MODEL.TEXT_ENCODER.WIDTH = 256             # Embedding dimension
cfg.MODEL.TEXT_ENCODER.CONTEXT_LENGTH = 128    # Context window length
cfg.MODEL.TEXT_ENCODER.PROJ_NUM_LAYERS = 2     # Projection layers
cfg.MODEL.TEXT_ENCODER.N_CTX = 32              # Context size

# Input configuration
cfg.INPUT = CfgNode()
cfg.INPUT.set_new_allowed(True)
cfg.INPUT.TASK_SEQ_LEN = 128                   # Task sequence length
cfg.INPUT.MAX_SEQ_LEN = 512                    # Max sequence length

# Pixel mean and std for normalization
cfg.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675]
cfg.MODEL.PIXEL_STD = [1.0, 1.0, 1.0]

# Post-processing and inference settings
cfg.MODEL.TEST = CfgNode()
cfg.MODEL.TEST.set_new_allowed(True)
cfg.MODEL.TEST.OBJECT_MASK_THRESHOLD = 0.5
cfg.MODEL.TEST.OVERLAP_THRESHOLD = 0.7
cfg.MODEL.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
cfg.MODEL.TEST.PANOPTIC_ON = True
cfg.MODEL.TEST.INSTANCE_ON = True
cfg.MODEL.TEST.DETECTION_ON = False
cfg.TEST.DETECTIONS_PER_IMAGE = 100


class Encoder(nn.Module):
    def __init__(self, embed_dim=256, num_heads=8, num_layers=6):
        super(Encoder, self).__init__()
        # Build backbone with Detectron2 configuration
        self.backbone = build_backbone(cfg)
        

        # Pixel Decoder setup
        self.pixel_decoder = PixelDecoder(embed_dim=embed_dim)
        final_layer_channels = self.backbone.output_shape()['res4'].channels
        self.conv1x1 = nn.Conv2d(final_layer_channels, embed_dim, kernel_size=1)

        # Transformer Encoder Layer
        encoder_layer = TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        # Extract features using the Detectron2 backbone
        features = self.backbone(x)
        
        # Check and transform final backbone feature layer for transformer encoder
        res4_feature = features['res4']
        res4_feature_reduced = self.conv1x1(res4_feature)
        
        # Transformer expects flattened features
        B, C, H, W = res4_feature_reduced.shape
        features_flattened = res4_feature_reduced.flatten(2).permute(2, 0, 1)  # [H*W, B, embed_dim]

        # Pass through Transformer Encoder
        encoded_features = self.transformer_encoder(features_flattened)
        encoded_features = encoded_features.permute(1, 2, 0).view(B, C, H, W)
        
        return res4_feature_reduced, encoded_features






In [70]:
class Decoder(nn.Module):
    def __init__(self, embed_dim=256, num_classes=21, num_queries=100):
        super(Decoder, self).__init__()
        self.cross_attention = nn.MultiheadAttention(embed_dim, num_heads=8)
        
        # Linear layer to map queries to class logits
        self.query_to_mask = nn.Linear(embed_dim, num_classes)
        
        # Final convolution layer to map back to image dimensions
        self.conv1 = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, features, queries):
        B, C, H, W = features.shape
        features_flattened = features.flatten(2).permute(2, 0, 1)  # Shape: [H*W, batch, embed_dim]
        
        # Cross-attention between queries and flattened features
        attended_queries, _ = self.cross_attention(queries, features_flattened, features_flattened)
        
        # Map attended queries to class logits
        mask_logits = self.query_to_mask(attended_queries)  # Shape: [num_queries, batch, num_classes]
        
        # Reshape features for final convolution
        attended_features = features + features.mean(dim=[2, 3], keepdim=True)  # Add query information to features
        x = self.conv1(attended_features)
        x = self.conv2(x)
        
        return x




In [71]:
class MultiStageDecoder(nn.Module):
    def __init__(self, embed_dim=256, num_classes=21, num_queries=100, num_stages=3):
        super(MultiStageDecoder, self).__init__()
        self.num_stages = num_stages
        self.stages = nn.ModuleList([Decoder(embed_dim, num_classes) for _ in range(num_stages)])

    def forward(self, features, queries):
        output = None
        for stage in self.stages:
            output = stage(features, queries)
        return output


In [72]:
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(MLP, self).__init__()
        layers = []
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, output_dim))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)


In [77]:
import torch.nn.functional as F

def compute_loss(pred_masks, gt_masks, gt_labels):
    # Cross-entropy for semantic segmentation
    ce_loss = F.cross_entropy(pred_masks, gt_labels, ignore_index=-1)  # Add ignore_index if needed

    # Dice loss for mask overlap
    dice = dice_loss(pred_masks, gt_masks)
    
    # Query-to-mask matching loss (if implemented in your OneFormer version)
    query_loss = query_to_mask_loss(pred_masks, gt_masks, num_queries=pred_masks.shape[1])
    
    return ce_loss + dice + query_loss

def dice_loss(pred, target, smooth=1):
    # Flatten tensors to calculate overlap
    pred = pred.contiguous().view(-1)
    target = target.contiguous().view(-1)
    
    intersection = (pred * target).sum()
    dice = (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
    
    return 1 - dice

def query_to_mask_loss(pred_masks, gt_masks, num_queries):
    # Placeholder for query-to-mask matching loss
    # Implement Hungarian matching if required
    return torch.tensor(0.0)  # Temporary placeholder


In [88]:
class OneFormer(nn.Module):
    def __init__(self, num_classes=21, embed_dim=256, num_heads=8, num_layers=6, num_queries=100):
        super(OneFormer, self).__init__()
        self.encoder = Encoder(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers)
        self.num_queries = num_queries
        self.query_embed = nn.Embedding(num_queries, embed_dim)
        self.task_embeddings = nn.Embedding(3, embed_dim)
        self.task_mlp = MLP(cfg.INPUT.TASK_SEQ_LEN, embed_dim, embed_dim, 2)
        self.text_encoder = nn.Embedding(cfg.MODEL.TEXT_ENCODER.VOCAB_SIZE, embed_dim)
        self.text_projector = nn.Linear(embed_dim, embed_dim)
        self.decoder = MultiStageDecoder(embed_dim=embed_dim, num_classes=num_classes)

    def forward(self, x, task_type, masks=None):
        task_embed = self.task_embeddings(torch.tensor([task_type], device=x.device)).unsqueeze(1)
        multi_scale_features, encoded_features = self.encoder(x)
        B = x.shape[0]
        queries = self.query_embed.weight.unsqueeze(1).repeat(1, B, 1) + task_embed
        segmentation_output = self.decoder(encoded_features, queries)

        print("Segmentation Output Shape:", segmentation_output.shape)  # Debugging line
        
        if self.training:
            # Compute and return training losses
            if masks is not None:
                losses = compute_loss(segmentation_output, masks['masks'], masks['labels'])
                return losses
            else:
                raise ValueError("Missing ground truth masks for training.")
        else:
            # Perform inference (use semantic, panoptic, or instance inference as needed)
            return segmentation_output





In [89]:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context


In [92]:
# Set up the model
num_classes = 21  # Adjust this based on your specific dataset (e.g., COCO has 80 classes)
model = OneFormer(num_classes=num_classes)

# Test with a sample batch
sample_batch = torch.randn(2, 3, 512, 512)  # [batch_size, channels, height, width]
for images, targets in train_loader:
    output = model(sample_batch, task_type=0)
    if model.training:
        loss = model.compute_loss(output, targets['masks'], targets['labels'])

print("Output shape:", output.shape)  # Expected shape: [batch_size, num_classes, height, width]




Panoptic mask 251439.png not found in datasets/coco/panoptic_train2017
Panoptic mask 507686.png not found in datasets/coco/panoptic_train2017
Panoptic mask 539397.png not found in datasets/coco/panoptic_train2017
Panoptic mask 473706.png not found in datasets/coco/panoptic_train2017
Panoptic mask 84592.png not found in datasets/coco/panoptic_train2017
Panoptic mask 553192.png not found in datasets/coco/panoptic_train2017
Panoptic mask 115752.png not found in datasets/coco/panoptic_train2017
Panoptic mask 445933.png not found in datasets/coco/panoptic_train2017
Panoptic mask 564938.png not found in datasets/coco/panoptic_train2017
Panoptic mask 441619.png not found in datasets/coco/panoptic_train2017
Panoptic mask 136299.png not found in datasets/coco/panoptic_train2017
Panoptic mask 503293.png not found in datasets/coco/panoptic_train2017
Panoptic mask 123382.png not found in datasets/coco/panoptic_train2017
Panoptic mask 116819.png not found in datasets/coco/panoptic_train2017
Panopti

ValueError: Missing ground truth masks for training.