In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from coco_dataset import COCOPanopticDataset
from load_data import train_loader
from pixeldecoder import PixelDecoder
from backbone import BackboneWithMultiScaleFeatures
from tokenizer import TaskTokenizer
from mlp import TaskMLP
from text_mapper import TextMapper
from contrastive_loss import ContrastiveLoss
from query_formulation import TaskConditionedQueryFormulator
from compute_loss import SetCriterion
from hungarian_matcher import HungarianMatcher
from transformer_decoder import TransformerDecoder
from predict import MaskClassPredictor

# --- Hyperparameters ---------------------------------------------------------
# Text side: vocabulary size for the tokenizer / text mapper, the embedding width
# shared by every component below, and the tokenizer's maximum sequence length.
vocab_size: int = 30000
embed_dim: int = 256
max_seq_len: int = 128

# Query side: number of task-conditioned queries handed to the transformer decoder.
num_queries: int = 100

# Temperature passed to ContrastiveLoss.
temperature: float = 0.07

# Transformer decoder: attention heads per layer and number of layers.
num_heads: int = 8
num_layers: int = 6

# Number of target classes used by the decoder, predictor and criterion.
# NOTE(review): COCO panoptic has 133 categories (80 "thing" + 53 "stuff"); 80 matches
# the thing-only instance classes. Confirm which label set the dataset emits.
num_classes: int = 80

# COCO Dataset Paths
train_image_dir = "datasets/coco/train2017"
train_instance_file = "datasets/coco/annotations/instances_train2017.json"
train_panoptic_file = "datasets/coco/annotations/panoptic_train2017.json"
train_panoptic_mask_dir = "datasets/coco/panoptic_train2017"

# Data-loading configuration (previously hard-coded inline below).
SEED = 42                 # seeds torch's default RNG: DataLoader shuffle order and model init become reproducible
IMAGE_SIZE = (512, 512)   # (height, width) passed to transforms.Resize
NUM_TRAIN_SAMPLES = 5000  # size of the debugging subset; set to None to use the whole dataset
BATCH_SIZE = 1
NUM_WORKERS = 4

torch.manual_seed(SEED)

# Define transformation for images and masks.
# NOTE(review): if COCOPanopticDataset also applies this transform to the panoptic masks,
# Resize's default bilinear interpolation blends neighbouring segment-id colours and
# ToTensor rescales them to [0, 1], corrupting the ids. Confirm how the dataset uses
# `transform`; label masks normally need nearest-neighbour resizing and no value scaling.
data_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Initialize the dataset and DataLoader
full_train_dataset = COCOPanopticDataset(
    image_dir=train_image_dir,
    instance_file=train_instance_file,
    panoptic_file=train_panoptic_file,
    panoptic_mask_dir=train_panoptic_mask_dir,
    transform=data_transform
)

# Clamp the subset size so a dataset smaller than NUM_TRAIN_SAMPLES does not produce
# out-of-range indices the first time a sample is fetched.
full_train_len = len(full_train_dataset)
subset_size = full_train_len if NUM_TRAIN_SAMPLES is None else min(NUM_TRAIN_SAMPLES, full_train_len)
train_dataset = torch.utils.data.Subset(full_train_dataset, range(subset_size))

# NOTE: this intentionally rebinds the `train_loader` imported from load_data at the top
# of the notebook; the loader built here is the one the next cell iterates.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)





In [2]:
# Initialize Model Components.
# Construction order is kept fixed: each constructor may draw from torch's RNG for
# parameter initialisation, so reordering would change the initial weights.

# Vision path: backbone features feed the pixel decoder.
backbone = BackboneWithMultiScaleFeatures()
pixel_decoder = PixelDecoder(
    # NOTE(review): presumably the backbone's per-level channel widths — confirm.
    input_channels=[256, 512, 1024, 2048],
)

# Text path: tokenizer -> MLP -> text mapper, plus the contrastive loss on the mapped text.
tokenizer = TaskTokenizer(vocab_size, embed_dim, max_seq_len)
mlp = TaskMLP(
    input_dim=embed_dim,
    hidden_dim=embed_dim,
    output_dim=embed_dim,
)
text_mapper = TextMapper(vocab_size=vocab_size, embed_dim=embed_dim)
contrastive_loss_fn = ContrastiveLoss(temperature)

# Query path: builds num_queries task-conditioned queries of width embed_dim.
task_query_formulator = TaskConditionedQueryFormulator(
    num_queries=num_queries,
    embed_dim=embed_dim,
)

# Set-prediction loss: the criterion is handed the matcher plus per-term weights
# for the class (CE), mask and dice losses.
matcher = HungarianMatcher(cost_class=1, cost_mask=1, cost_dice=1)
loss_weight_dict = {'loss_ce': 1, 'loss_mask': 1, 'loss_dice': 1}
criterion = SetCriterion(
    matcher=matcher,
    num_classes=num_classes,
    weight_dict=loss_weight_dict,
    eos_coef=0.1,
    losses=['labels', 'masks'],
)

# Decoder and prediction heads.
transformer_decoder = TransformerDecoder(
    embed_dim=embed_dim,
    num_queries=num_queries,
    num_classes=num_classes,
    num_heads=num_heads,
    num_layers=num_layers
)

mask_class_predictor = MaskClassPredictor(embed_dim, num_queries, num_classes)

# Single-batch forward/loss smoke test (debugging; no optimizer step).

# The task prompts do not depend on the batch, so define them once.
task_texts = ["panoptic", "instance", "semantic"]

# Take exactly one batch from the loader instead of a `for ... break` loop.
image_batch, mask_batch = next(iter(train_loader))

# Step 1: Extract Multi-Scale Features
multi_scale_features = backbone(image_batch)
decoded_features = pixel_decoder(multi_scale_features)
image_features_1_4 = decoded_features[0]  # Select 1/4 resolution features

# Step 2: Tokenize Task Texts
# Call the module rather than `.forward()` so registered nn.Module hooks run
# (assumes TaskTokenizer is an nn.Module, like the other components).
task_embeddings = tokenizer(task_texts)  # [3, max_seq_len, embed_dim]
task_embeddings = mlp(task_embeddings.mean(dim=1).unsqueeze(1)).squeeze(1)  # [3, embed_dim]

# FIXME(review): the last saved run raised "IndexError: Dimension out of range
# (expected [-2, 1], got 2)" before the first print of this cell, i.e. inside
# text_mapper or task_query_formulator (Steps 3-4): one of them indexes dim 2
# of a 2-D tensor. Check their expected input ranks.

# Step 3: Map Task Embeddings to Q_text
# FIXME(review): `.long()` truncates the continuous MLP embeddings to integers, so
# TextMapper receives meaningless ids (mostly 0, possibly negative). If TextMapper
# embeds token ids, feed it the tokenizer's ids instead; if it expects embeddings,
# drop the cast.
q_text = text_mapper(
    panoptic_text=task_embeddings[0].unsqueeze(0).long(),
    instance_text=task_embeddings[1].unsqueeze(0).long(),
    semantic_text=task_embeddings[2].unsqueeze(0).long()
)

# Step 4: Generate Q_task
batch_size = image_batch.size(0)
q_task = task_query_formulator(task_embeddings.unsqueeze(1), batch_size).permute(1, 0, 2)

# Step 5: Contrastive loss
# FIXME(review): this step is meant to compare Q_text with Q_task, but only q_text is
# passed; confirm ContrastiveLoss's signature.
contrastive_loss = contrastive_loss_fn(q_text)
print(f"Contrastive Loss: {contrastive_loss.item()}")

# Step 6: Flatten and Integrate Image Features
# NOTE(review): interpolate(scale_factor=4) upsamples the (nominally 1/4-resolution) map
# back to input resolution before flattening, giving H*W tokens of the full image
# (262,144 for 512x512 inputs). Flattening alone does not need this upsampling.
flattened_features = F.interpolate(image_features_1_4, scale_factor=4, mode="nearest")
# view() also asserts that the pixel decoder emits exactly embed_dim channels.
flattened_features = flattened_features.view(batch_size, embed_dim, -1).permute(0, 2, 1)  # [B, H*W, embed_dim]
# FIXME(review): repeats along dim 0 by q_task.size(0). That is only the batch size if
# q_task is batch-first; if q_task is [num_queries, B, D], each query becomes its own
# sequence. Confirm the formulator's output layout.
flattened_features = flattened_features.repeat(q_task.size(0), 1, 1)

combined_input = torch.cat([q_task, flattened_features], dim=1)
print(f"Shape of q_task: {q_task.shape}")
print(f"Shape of flattened_features: {flattened_features.shape}")
print(f"Shape of combined_input before TransformerDecoder: {combined_input.shape}")

# Step 7: Pass through Transformer Decoder
# NOTE(review): a 3-D q_task is promoted to 4-D *after* combined_input was built from
# the 3-D version; confirm the rank TransformerDecoder expects for `task_queries`.
q_task = q_task.unsqueeze(1) if q_task.dim() == 3 else q_task
print(f"Shape of q_task: {q_task.shape}")
decoder_output = transformer_decoder(combined_input, task_queries=q_task)

# Step 8: Mask and Class Prediction
mask_pred, class_pred = mask_class_predictor(decoder_output)

# Step 9: Calculate SetCriterion Loss
outputs = {'pred_logits': class_pred, 'pred_masks': mask_pred}
# FIXME(review): each target stores the raw mask tensor under 'labels' and has no 'masks'
# entry, although the criterion is configured with losses=['labels', 'masks']. Confirm
# the target schema SetCriterion / HungarianMatcher expect.
targets = [{'labels': mask_batch[i]} for i in range(batch_size)]
# A standalone `matcher(outputs, targets)` call used to sit here; its result was never
# used (SetCriterion is given the matcher at construction), so it has been dropped.
losses = criterion(outputs, targets)

print(f"Losses: {losses}")



Panoptic mask 175611.png not found in datasets/coco/panoptic_train2017
Panoptic mask 150235.png not found in datasets/coco/panoptic_train2017
Panoptic mask 30156.png not found in datasets/coco/panoptic_train2017
Panoptic mask 359959.png not found in datasets/coco/panoptic_train2017
Panoptic mask 325027.png not found in datasets/coco/panoptic_train2017
Panoptic mask 180800.png not found in datasets/coco/panoptic_train2017
Panoptic mask 469982.png not found in datasets/coco/panoptic_train2017
Panoptic mask 548331.png not found in datasets/coco/panoptic_train2017
Panoptic mask 482829.png not found in datasets/coco/panoptic_train2017
Input shape before reshaping: torch.Size([3, 1, 256])
Shape after flattening for MLP: torch.Size([3, 256])
Shape after MLP processing: torch.Size([3, 256])
Shape after reshaping back to [batch_size, seq_len, output_dim]: torch.Size([3, 1, 256])
Input shape before reshaping: torch.Size([3, 1, 256])
Shape after flattening for MLP: torch.Size([3, 256])
Shape afte

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)