**Aim:** take the output (logits) of the second-stage model and
1. Arrange it into a 100-dimensional triplet vector:
    1. First, identify the instrument (e.g. the grasper).
    2. Collect all instrument instances detected in that image.
    3. Take an element-wise max over those instances.
2. Pass the result through a sigmoid activation function.
3. Run the classification algorithm on it.

## Logits of second model and classification results

Assuming the first-stage model's confidence is 1 for every instance (which is exactly what I did), the first stage is not needed to compute the second-stage scores: we can use the second-stage logits directly.

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')
sys.path.append('../resnet_model')

import pickle
import os
import numpy as np
import json

from utils.general.dataset_variables import TripletSegmentationVariables 
from utils.general.read_files import read_from_json
from utils.general.save_files import save_to_json

# Lookup tables taken from the dataset definition. The *_TO_ID / *_NAME_TO_ID
# dicts are inverses built by swapping each (key, value) pair of the
# corresponding id -> class table.
INSTRUMENT_ID_TO_CLASS_DICT = TripletSegmentationVariables.categories['instrument']
INSTRUMENT_CLASS_TO_ID_DICT = dict(
    (class_name, class_id) for class_id, class_name in INSTRUMENT_ID_TO_CLASS_DICT.items()
)
TRIPLET_ID_TO_CLASS_DICT = TripletSegmentationVariables.categories['triplet']
TRIPLET_NAME_TO_ID_DICT = dict(
    (class_name, class_id) for class_id, class_name in TRIPLET_ID_TO_CLASS_DICT.items()
)
VERBTARGET_DICT = TripletSegmentationVariables.categories['verbtarget']

# Per-instrument class lists (presumably the verb-target / triplet classes each
# instrument can take part in -- confirm against TripletSegmentationVariables).
INSTRUMENT_TO_VERBTARGET_CLASSES = TripletSegmentationVariables.instrument_to_verbtarget_classes
INSTRUMENT_TO_TRIPLET_CLASSES = TripletSegmentationVariables.instrument_to_triplet_classes


## My three-task model

In [2]:
# Logits written by the three-task model's second stage.
second_stage_logits_file_path = (
    '../resnet_model/work_dirs/threetask_resnet_fpn_parallel_decoders/'
    'results_logits.json'
)
second_stage_logits = read_from_json(second_stage_logits_file_path)

In [3]:
triplet_scores = dict()  # NOTE(review): not referenced by any other cell shown here

In [4]:
def sigmoid(x):
    """Element-wise logistic sigmoid, 1 / (1 + exp(-x)).

    Args:
        x: scalar or numpy array of logits; ``-inf`` / ``+inf`` are allowed.

    Returns:
        Values in [0, 1] with the same shape as ``x``; ``-inf`` maps to 0.0
        and ``+inf`` maps to 1.0.
    """
    # For finite x below about -709, exp(-x) overflows to inf. The expression
    # still yields the correct limit 0.0, so the overflow is silenced instead of
    # emitting a RuntimeWarning (or raising under a stricter errstate).
    with np.errstate(over='ignore'):
        return 1 / (1 + np.exp(-x))

In [5]:
def generate_sigmoid_score_for_detected_objects_from_logit(logits_for_image_name_tool_instance, tool_id, num_triplets=100):
    """Scatter one instrument instance's verb-target logits into a triplet vector and apply a sigmoid.

    Args:
        logits_for_image_name_tool_instance: sequence of logits for one detected
            instrument instance. Only the first
            ``len(INSTRUMENT_TO_VERBTARGET_CLASSES[tool_id])`` entries are used;
            any trailing entries (presumably ``-inf`` padding -- TODO confirm) are dropped.
        tool_id: instrument index used to look up INSTRUMENT_TO_VERBTARGET_CLASSES
            and INSTRUMENT_TO_TRIPLET_CLASSES.
        num_triplets: length of the returned vector (default 100, as used in this notebook).

    Returns:
        np.ndarray of shape (num_triplets,). Positions listed in
        ``INSTRUMENT_TO_TRIPLET_CLASSES[tool_id]`` hold sigmoid(logit); every
        other position is sigmoid(-inf) == 0.0.

    Raises:
        AssertionError: if a kept logit is infinite, or if the number of kept
            logits differs from the number of triplet indices for this instrument.
    """
    num_verbtargets = len(INSTRUMENT_TO_VERBTARGET_CLASSES[tool_id])
    # Drop the trailing entries; the kept logits must all be finite.
    used_logits = logits_for_image_name_tool_instance[0:num_verbtargets]
    assert not np.isinf(used_logits).any(), "Infinity values were not removed correctly!"

    triplet_indices_for_logits = INSTRUMENT_TO_TRIPLET_CLASSES[tool_id]
    # Each kept logit must have exactly one target triplet position.
    assert len(used_logits) == len(triplet_indices_for_logits), "Mismatch in logit length!"

    # -inf everywhere else, so non-applicable triplets get a sigmoid score of exactly 0.
    triplet_logits = np.full(num_triplets, -np.inf)
    for triplet_idx, logit in zip(triplet_indices_for_logits, used_logits):
        # Indices are used as-is; no 1-based -> 0-based shift happens here.
        # NOTE(review): confirm INSTRUMENT_TO_TRIPLET_CLASSES stores 0-based ids.
        triplet_logits[triplet_idx] = logit

    return sigmoid(triplet_logits)
    

In [6]:
def get_scores_from_my_second_stage_logits(second_stage_logits_file_path):
    """Load second-stage logits from JSON and turn them into per-image triplet scores.

    Each key of the JSON file is split on ',' into exactly five fields. Only the
    first two are used: the image name and the instrument class name. For every
    detected instrument instance, the ``'logits_verbtarget'`` entry is mapped to
    a triplet-score vector by
    ``generate_sigmoid_score_for_detected_objects_from_logit``. Vectors that
    belong to the same image are merged with an element-wise max.

    Args:
        second_stage_logits_file_path: path to the second-stage logits JSON file.

    Returns:
        dict mapping image name (with ``'t50_'`` removed) to
        ``{'triplet_prediction': list_of_floats}``. This format was chosen to
        match the way Rendezvous results are stored.
    """
    second_stage_logits = read_from_json(second_stage_logits_file_path)

    per_image_scores = {}
    for instance_key, logit_dict in second_stage_logits.items():
        image_name, tool_name, _, _, _ = instance_key.split(',')
        # NOTE(review): str.replace drops every occurrence of 't50_', not only a prefix.
        image_name = image_name.replace('t50_', '')
        # The `- 1` turns the dataset instrument id into the index used by the
        # instrument lookup tables (presumably dataset ids start at 1).
        tool_id = int(INSTRUMENT_CLASS_TO_ID_DICT[tool_name]) - 1

        instance_scores = generate_sigmoid_score_for_detected_objects_from_logit(
            logit_dict['logits_verbtarget'], tool_id)

        if image_name in per_image_scores:
            # Several instrument instances in one image: keep the highest score per triplet.
            per_image_scores[image_name] = np.maximum(per_image_scores[image_name], instance_scores)
        else:
            per_image_scores[image_name] = instance_scores

    # Convert the ndarrays to plain lists so the result is JSON-serialisable.
    return {
        image_name: {'triplet_prediction': np.asarray(scores).tolist()}
        for image_name, scores in per_image_scores.items()
    }
    

In [7]:
# Build per-image triplet scores from the second-stage logits file.
second_stage_logits_file_path = (
    '../resnet_model/work_dirs/threetask_resnet_fpn_parallel_decoders/'
    'results_logits.json'
)
triplet_sigmoid_for_all_results = get_scores_from_my_second_stage_logits(
    second_stage_logits_file_path=second_stage_logits_file_path,
)

In [8]:
# Persist the per-image triplet scores in the same work directory as the logits.
save_path_for_sigmoid_scores = (
    '../resnet_model/work_dirs/threetask_resnet_fpn_parallel_decoders/'
    'results_triplet_sigmoid_scores.json'
)
save_to_json(json_file_path=save_path_for_sigmoid_scores,
             data=triplet_sigmoid_for_all_results)