In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from coco_dataset import COCOPanopticDataset
from load_data import train_loader
from pixeldecoder import PixelDecoder
from backbone import BackboneWithMultiScaleFeatures
from tokenizer import TaskTokenizer
from mlp import TaskMLP
from text_mapper import TextMapper
from contrastive_loss import ContrastiveLoss
from query_formulation import TaskConditionedQueryFormulator
from compute_loss import SetCriterion
from hungarian_matcher import HungarianMatcher
from transformer_decoder import TransformerDecoder
from predict import MaskClassPredictor

# Hyperparameters shared by the model-construction and training cells below.

# Text branch
vocab_size = 30000          # handed to TaskTokenizer and TextMapper
max_seq_len = 128           # handed to TaskTokenizer

# Shared model width / query budget
embed_dim = 256             # embedding width used by every component
num_queries = 100           # number of task-conditioned queries

# Transformer decoder
num_heads = 8
num_layers = 6

# Prediction head / criterion
num_classes = 80

# Loss configuration
temperature = 0.2           # handed to ContrastiveLoss
contrastive_weight = 0.5    # multiplier on the contrastive term of the total loss
primary_loss_weight = 1.0   # multiplier on the summed criterion losses

# COCO dataset locations (relative paths, resolved against the kernel's working directory)
train_image_dir = "datasets/coco/train2017"
train_instance_file = "datasets/coco/annotations/instances_train2017.json"
train_panoptic_file = "datasets/coco/annotations/panoptic_train2017.json"
train_panoptic_mask_dir = "datasets/coco/panoptic_train2017"

# Resize every sample to a fixed 512x512 grid and convert it to a tensor.
# NOTE(review): this single transform is passed to the dataset for images and
# (per the original comment) masks -- confirm the dataset does not interpolate or
# rescale label masks with it, since that would corrupt class/segment ids.
data_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
])

# Full training set
full_train_dataset = COCOPanopticDataset(
    image_dir=train_image_dir,
    instance_file=train_instance_file,
    panoptic_file=train_panoptic_file,
    panoptic_mask_dir=train_panoptic_mask_dir,
    transform=data_transform
)

# Debugging subset: only the first `train_subset_size` samples are used.
# The size is clamped to the dataset length so a smaller dataset does not produce
# out-of-range indices (which would otherwise surface as an IndexError only once
# the DataLoader starts fetching samples).
train_subset_size = 5000
train_dataset = torch.utils.data.Subset(
    full_train_dataset,
    range(min(train_subset_size, len(full_train_dataset))),
)

# NOTE: this rebinds `train_loader`, replacing the object of the same name that
# is imported from `load_data` at the top of the cell.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)





In [2]:
# Build the model components, the matching/loss machinery and the optimizer.
# Construction order is kept as-is so that any random parameter initialisation
# consumes the RNG in the same sequence.

# Image branch
backbone = BackboneWithMultiScaleFeatures()
pixel_decoder = PixelDecoder(input_channels=[256, 512, 1024, 2048])

# Task-text branch
tokenizer = TaskTokenizer(vocab_size, embed_dim, max_seq_len)
mlp = TaskMLP(input_dim=embed_dim, hidden_dim=embed_dim, output_dim=embed_dim)
text_mapper = TextMapper(vocab_size=vocab_size, embed_dim=embed_dim)
contrastive_loss_fn = ContrastiveLoss(temperature)
task_query_formulator = TaskConditionedQueryFormulator(
    num_queries=num_queries,
    embed_dim=embed_dim,
)

# Matching and primary (classification + mask) loss
matcher = HungarianMatcher(cost_class=1, cost_mask=1, cost_dice=1)
criterion = SetCriterion(
    matcher=matcher,
    num_classes=num_classes,
    weight_dict={'loss_ce': 1, 'loss_mask': 1, 'loss_dice': 1},
    eos_coef=0.1,
    losses=['labels', 'masks'],
)

# Decoder and prediction head
transformer_decoder = TransformerDecoder(
    embed_dim=embed_dim, num_queries=num_queries, num_classes=num_classes,
    num_heads=num_heads, num_layers=num_layers,
)
mask_class_predictor = MaskClassPredictor(embed_dim, num_queries, num_classes)

# One Adam parameter group per trainable component, all sharing lr=1e-4.
# `tokenizer` is not registered with the optimizer (its entry was commented out
# in the original list), so any parameters it owns are not updated.
optimizer = torch.optim.Adam(
    [
        {"params": component.parameters()}
        for component in (
            backbone,
            pixel_decoder,
            transformer_decoder,
            mask_class_predictor,
            mlp,
            text_mapper,
            task_query_formulator,
        )
    ],
    lr=1e-4,
)

# Main Training Loop (Single Batch for Debugging)
#
# Runs exactly one optimisation step on the first batch from `train_loader`
# and then stops (see the `break` at the bottom). Intermediate tensor shapes
# are printed to help debug the wiring between components.

# The task prompts do not depend on the batch, so they are defined once.
task_texts = ["panoptic", "instance", "semantic"]

for image_batch, mask_batch in train_loader:
    optimizer.zero_grad()

    # Step 1: Extract Multi-Scale Features
    multi_scale_features = backbone(image_batch)
    decoded_features = pixel_decoder(multi_scale_features)
    # First decoded map; presumably the 1/4-resolution one going by the name -- TODO confirm.
    image_features_1_4 = decoded_features[0]

    # Step 2: Tokenize Task Texts
    task_embeddings = tokenizer.forward(task_texts)  # [3, max_seq_len, embed_dim]
    task_embeddings = mlp(task_embeddings.mean(dim=1).unsqueeze(1)).squeeze(1)  # [3, embed_dim]

    # Step 3: Map Task Embeddings to Q_text
    # NOTE(review): `.long()` converts the MLP's output to an integer tensor, which
    # truncates the values and cannot carry gradients back to `mlp` through q_text.
    # TextMapper is built with `vocab_size`, so it may expect token ids here rather
    # than embeddings -- confirm what should be passed.
    q_text = text_mapper(
        panoptic_text=task_embeddings[0].unsqueeze(0).long(),
        instance_text=task_embeddings[1].unsqueeze(0).long(),
        semantic_text=task_embeddings[2].unsqueeze(0).long()
    )

    # Step 4: Generate Q_task
    batch_size = image_batch.size(0)
    q_task = task_query_formulator(task_embeddings.unsqueeze(1), batch_size).permute(1, 0, 2)
    print(f"------------------------------------Shape of q_task: {q_task.shape}")
    print(f"------------------------------------Shape of q_text: {q_text.shape}")

    # Step 5: Calculate Contrastive Loss between Q_text and Q_task
    contrastive_loss = contrastive_loss_fn(q_text, q_task)
    print(f"Contrastive Loss: {contrastive_loss.item()}")

    # Step 6: Integrate Image Features
    # The decoder is fed the L2-normalised queries with the first two dimensions
    # merged, i.e. [batch * num_queries, dim]. This reproduces what the snippet
    # previously pasted here from `contrastive_loss` handed to the decoder, but
    # without rebinding `q_text`/`q_task` or overwriting the notebook-level
    # `embed_dim` / `num_queries` hyperparameters. The q_text half of that
    # snippet was dropped because its result was never read.
    # NOTE(review): confirm the decoder really wants normalised, flattened queries
    # rather than the raw [batch, num_queries, dim] tensor.
    decoder_queries = F.normalize(q_task, dim=-1).flatten(0, 1)
    decoder_output = transformer_decoder(decoder_queries, multi_scale_features)

    # Print shapes before concatenation
    print(f"Shape of decoder_output: {decoder_output.shape}")
    print(f"Shape of image_features_1_4 before view: {image_features_1_4.shape}")

    # Step 7: Build the predictor input
    # Collapse the spatial dimensions into one sequence axis: [B, C, H, W] -> [B, H*W, C].
    # `flatten(2)` derives the sizes from the tensor itself instead of hard-coding
    # batch size 1 and a 128 x 128 grid, and also works on non-contiguous input.
    flattened_image_features_1_4 = image_features_1_4.flatten(2).permute(0, 2, 1)
    print(f"Shape of flattened_image_features_1_4 after view and permute: {flattened_image_features_1_4.shape}")

    # Concatenate along the sequence dimension
    combined_input = torch.cat([decoder_output, flattened_image_features_1_4], dim=1)
    print(f"Shape of combined_input after concatenation: {combined_input.shape}")

    # Step 8: Mask and Class Prediction
    mask_pred, class_pred = mask_class_predictor(combined_input)

    # Step 9: Calculate Primary Loss
    # NOTE(review): the criterion is configured with losses=['labels', 'masks'], yet
    # the target dict only carries 'labels' (filled with the first mask of the batch),
    # and the recorded run printed "Primary Loss: 0.0" -- verify the target format.
    outputs = {'pred_logits': class_pred, 'pred_masks': mask_pred}
    targets = [{'labels': mask_batch[0]}]
    primary_loss = criterion(outputs, targets)
    # Sum the individual criterion terms once and reuse the result below.
    primary_loss_total = sum(primary_loss.values())

    # Combined Loss
    total_loss = contrastive_weight * contrastive_loss + primary_loss_weight * primary_loss_total

    # Step 10: Backpropagation
    total_loss.backward()
    optimizer.step()

    print(f"Contrastive Loss: {contrastive_loss.item()}, Primary Loss: {primary_loss_total.item()}, Total Loss: {total_loss.item()}")

    # Break after one batch for debugging
    break



Panoptic mask 286903.png not found in datasets/coco/panoptic_train2017
Panoptic mask 137451.png not found in datasets/coco/panoptic_train2017
Panoptic mask 413734.png not found in datasets/coco/panoptic_train2017
Panoptic mask 251920.png not found in datasets/coco/panoptic_train2017
Panoptic mask 243134.png not found in datasets/coco/panoptic_train2017
Panoptic mask 289899.png not found in datasets/coco/panoptic_train2017
Panoptic mask 380140.png not found in datasets/coco/panoptic_train2017
Panoptic mask 277440.png not found in datasets/coco/panoptic_train2017
Panoptic mask 361351.png not found in datasets/coco/panoptic_train2017
Input shape before reshaping: torch.Size([3, 1, 256])
Shape after flattening for MLP: torch.Size([3, 256])
Shape after MLP processing: torch.Size([3, 256])
Shape after reshaping back to [batch_size, seq_len, output_dim]: torch.Size([3, 1, 256])
Input shape before reshaping: torch.Size([3, 1, 256])
Shape after flattening for MLP: torch.Size([3, 256])
Shape aft

  print(f"src_logits mean: {src_logits.mean()}, std: {src_logits.std()}")


Contrastive Loss: 5.650376319885254, Primary Loss: 0.0, Total Loss: 2.825188159942627
