In [1]:
%load_ext autoreload
%autoreload 2

import sys
# Make project-local packages importable from this notebook's directory
# (presumably '../' provides the `utils` package imported below).
sys.path.extend(['../', '../resnet_model'])

import pickle
import os
import numpy as np
import torch
import torch.nn.functional as F
import pickle
import json
import os
from PIL import Image
from pycocotools import mask as coco_mask

from utils.general.dataset_variables import TripletSegmentationVariables 

# Instrument id -> instrument class name. Later cells look names up with
# str(1-based class index) as the key.
INSTRUMENT_ID_TO_INSTRUMENT_CLASS_DICT = TripletSegmentationVariables.categories[
    'instrument_direct_pred'
]

# How many leading entries of each instance's logit vector are kept; trailing
# entries (the snare class, per the conversion cell) are dropped before softmax.
num_used_classes = 6

In [2]:
# First-stage Mask2Former results (mmdet results pickle) to convert into soft labels.
pickle_file_path = '../results/mask2former_cholec_final_predict_on_triplet_segmentation/mask2former_test_triplet_segmentation_v2_results/mask2former_test_triplet_segmentation_v2_dataset.pkl'

# Every second-stage soft-label artifact lives under one directory, so the JSON
# and the mask PNGs it references cannot drift apart when the location changes.
soft_labels_dir = '../../datasets/my_triplet_seg_datasets/triplet_segmentation_dataset_v2_second_stage/test_soft_labels'
output_json_path = f'{soft_labels_dir}/logits_from_first_stage.json'
output_mask_dir = f'{soft_labels_dir}/predicted_instance_masks'

# Softmax temperature: class logits are divided by this before softmax, so
# values > 1 produce softer (less peaked) class distributions.
temperature = 2

In [3]:
# Create every output location before anything is written. The JSON's directory is
# created explicitly too: it is only implicitly covered while it happens to be an
# ancestor of the mask directory.
os.makedirs(output_mask_dir, exist_ok=True)
json_output_dir = os.path.dirname(output_json_path)
if json_output_dir:  # dirname is '' for a bare file name in the working directory
    os.makedirs(json_output_dir, exist_ok=True)

In [4]:
# Load the first-stage results. NOTE: pickle.load can execute arbitrary code, so
# only point pickle_file_path at result files produced by your own pipeline.
with open(pickle_file_path, mode='rb') as results_file:
    mmdet_results = pickle.load(results_file)

In [5]:
# Convert the first-stage predictions (mmdet results) into second-stage soft labels:
# for every instance above the score threshold, save its binary mask as a PNG and
# record the temperature-scaled class softmax in a JSON keyed by image id.

# Instances with a confidence score below this value are discarded.
score_threshold = 0.5


def _to_numpy(value):
    """Return `value` as a NumPy array if it is a torch tensor; otherwise return it unchanged."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return value


def _decode_instance_mask(mask):
    """Decode one predicted mask into a uint8 image with 0 (background) / 255 (instance).

    Accepts either a COCO RLE dict (decoded with pycocotools) or an already dense
    binary/boolean array, e.g. when the results stored masks as tensors.
    """
    binary = coco_mask.decode(mask) if isinstance(mask, dict) else np.asarray(mask)
    return binary.astype(np.uint8) * 255


def _class_softmax(logit, num_classes, temperature):
    """Temperature-scaled softmax over the first `num_classes` logits.

    Trailing logits (the snare class) are dropped before normalisation, so the
    returned probabilities sum to 1 over the kept classes only.
    """
    kept_logits = torch.tensor(logit[:num_classes])
    return F.softmax(kept_logits / temperature, dim=0).numpy()


final_json = {}

for result in mmdet_results:
    img_filename = os.path.basename(result['img_path'])
    img_id = os.path.splitext(img_filename)[0]
    # Keys are bare file stems: two images sharing a file name in different folders
    # would silently overwrite each other's JSON entry and mask PNGs.
    if img_id in final_json:
        raise ValueError(
            f"Duplicate image id {img_id!r} (from {result['img_path']!r}); "
            "its JSON entry and mask files would overwrite an earlier image's."
        )

    pred_instances = result['pred_instances']
    masks = _to_numpy(pred_instances['masks'])
    scores = _to_numpy(pred_instances['scores'])
    logits = _to_numpy(pred_instances['logits'])

    image_instances = []
    for i, (score, logit) in enumerate(zip(scores, logits)):
        if score < score_threshold:
            continue

        softmax = _class_softmax(logit, num_used_classes, temperature)
        predicted_class = int(np.argmax(softmax))
        # The class-name dict is keyed by 1-based string ids.
        predicted_class_name = INSTRUMENT_ID_TO_INSTRUMENT_CLASS_DICT[str(predicted_class + 1)]

        # `i` indexes the unfiltered prediction list, so instance ids may have gaps
        # where low-score instances were skipped; filename and id stay in sync.
        mask_filename = f"{img_id}_instance_{i}.png"
        mask_path = os.path.join(output_mask_dir, mask_filename)
        Image.fromarray(_decode_instance_mask(masks[i])).save(mask_path)

        image_instances.append({
            "instance_id": i,
            "predicted_class_name": predicted_class_name,
            "softmax": softmax.tolist(),
            "mask_path": mask_filename,
        })

    # Images with no instance above the threshold are recorded as null.
    final_json[img_id] = image_instances or None


# Save
with open(output_json_path, 'w') as f:
    json.dump(final_json, f, indent=4)

print(f"✅ Saved softmax JSON to: {output_json_path}")
print(f"✅ Masks saved to: {output_mask_dir}")
            
        

✅ Saved softmax JSON to: ../../datasets/my_triplet_seg_datasets/triplet_segmentation_dataset_v2_second_stage/test_soft_labels/logits_from_first_stage.json
✅ Masks saved to: ../../datasets/my_triplet_seg_datasets/triplet_segmentation_dataset_v2_second_stage/test_soft_labels/predicted_instance_masks
