In [3]:
# Reload edited project modules (e.g. `utils`) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

import sys
# Make the project root importable so the `from utils...` import in the next cell resolves.
# NOTE(review): the relative path '../' is resolved against the kernel's current working
# directory, so this assumes the notebook is launched from its own folder. Re-running this
# cell appends the entry again (duplicate sys.path entries are harmless but accumulate).
sys.path.append('../')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
from utils.general.dataset_variables import TripletSegmentationVariables

# ID -> class-name tables for every task, read from the dataset's category definitions.
_CATEGORY_TABLES = TripletSegmentationVariables.categories

INSTRUMENT_ID_TO_CLASS_DICT = _CATEGORY_TABLES['instrument']
VERB_ID_TO_CLASS_DICT = _CATEGORY_TABLES['verb']
TARGET_ID_TO_CLASS_DICT = _CATEGORY_TABLES['target']
VERBTARGET_ID_TO_CLASS_DICT = _CATEGORY_TABLES['verbtarget']
TRIPLET_ID_TO_CLASS_DICT = _CATEGORY_TABLES['triplet']


def _invert_id_table(id_to_class):
    """Return the class-name -> ID inverse of an ID -> class-name dict.

    Entries keep the original iteration order; if two IDs share a class name,
    the one seen last wins.
    """
    return dict(zip(id_to_class.values(), id_to_class.keys()))


# Reverse lookups (class name -> ID), used to resolve names parsed from triplet strings.
INSTRUMENT_CLASS_TO_ID_DICT = _invert_id_table(INSTRUMENT_ID_TO_CLASS_DICT)
VERB_CLASS_TO_ID_DICT = _invert_id_table(VERB_ID_TO_CLASS_DICT)
TARGET_CLASS_TO_ID_DICT = _invert_id_table(TARGET_ID_TO_CLASS_DICT)
VERBTARGET_CLASS_TO_ID_DICT = _invert_id_table(VERBTARGET_ID_TO_CLASS_DICT)
TRIPLET_CLASS_TO_ID_DICT = _invert_id_table(TRIPLET_ID_TO_CLASS_DICT)

In [23]:
from collections import defaultdict

def generate_instrument_to_task_classes(instrument_dict, verb_dict, target_dict, verbtarget_dict, triplet_dict,
                                        index_offset=1, sep=','):
    """
    Build, per instrument, the verb / target / verbtarget classes that co-occur
    with it in at least one known triplet.

    The input dictionaries map class names to IDs (ints or int-like strings;
    every ID is passed through ``int()``). Each ID has ``index_offset`` subtracted,
    so with the default of 1, 1-based dataset IDs become 0-based indices.
    A shifted ID below 0 is skipped. If the instrument's shifted ID is below 0,
    the whole triplet is skipped. Any class with a non-negative shifted ID is
    kept, whatever its name, including placeholder/"null" classes.

    Args:
        instrument_dict (dict): Instrument class name -> ID.
        verb_dict (dict): Verb class name -> ID.
        target_dict (dict): Target class name -> ID.
        verbtarget_dict (dict): "verb<sep>target" class name -> ID.
        triplet_dict (dict): "instrument<sep>verb<sep>target" class name -> ID.
            Only the keys are used.
        index_offset (int): Amount subtracted from every ID. Defaults to 1
            (1-based input IDs -> 0-based output indices); pass 0 if the IDs
            are already 0-based.
        sep (str): Separator between class names inside triplet and verbtarget
            keys. Defaults to ','.

    Returns:
        tuple[dict[int, list[int]], dict[int, list[int]], dict[int, list[int]]]:
            ``(instrument_to_verb_classes, instrument_to_target_classes,
            instrument_to_verbtarget_classes)``. Each dict maps a shifted
            instrument index to a sorted list of unique shifted class indices.
            An instrument without any valid entry for a task is absent from
            that task's dict.

    Raises:
        ValueError: If a triplet key does not split into exactly three names.
        KeyError: If a triplet references a class name missing from one of the
            name -> ID dictionaries.
    """
    # Sets de-duplicate classes shared by several triplets of the same instrument.
    verbs_by_instrument = defaultdict(set)
    targets_by_instrument = defaultdict(set)
    verbtargets_by_instrument = defaultdict(set)

    # Only the triplet names matter here; the triplet IDs themselves are not needed.
    for triplet_str in triplet_dict:
        parts = triplet_str.split(sep)
        if len(parts) != 3:
            raise ValueError(
                f"Triplet {triplet_str!r} does not split into (instrument, verb, target) on {sep!r}"
            )
        instrument_name, verb_name, target_name = parts

        try:
            instrument_id = int(instrument_dict[instrument_name]) - index_offset
            verb_id = int(verb_dict[verb_name]) - index_offset
            target_id = int(target_dict[target_name]) - index_offset
            # Verbtarget classes are keyed by the verb and target names joined with `sep`.
            verbtarget_id = int(verbtarget_dict[f"{verb_name}{sep}{target_name}"]) - index_offset
        except KeyError as err:
            raise KeyError(
                f"Triplet {triplet_str!r} references unknown class {err.args[0]!r}"
            ) from err

        # Negative shifted IDs (e.g. an input ID of 0 with the default offset) are skipped.
        if instrument_id < 0:
            continue
        if verb_id >= 0:
            verbs_by_instrument[instrument_id].add(verb_id)
        if target_id >= 0:
            targets_by_instrument[instrument_id].add(target_id)
        if verbtarget_id >= 0:
            verbtargets_by_instrument[instrument_id].add(verbtarget_id)

    # Sets -> sorted lists for deterministic, JSON-friendly output
    # (note: json.dump would turn the int keys into strings).
    return tuple(
        {instrument_id: sorted(class_ids) for instrument_id, class_ids in mapping.items()}
        for mapping in (verbs_by_instrument, targets_by_instrument, verbtargets_by_instrument)
    )



In [29]:
from IPython.display import display


In [32]:
# Derive the instrument -> allowed-class mappings from the triplet vocabulary.
inst_to_verb, inst_to_target, inst_to_verbtarget = generate_instrument_to_task_classes(
    instrument_dict=INSTRUMENT_CLASS_TO_ID_DICT,
    verb_dict=VERB_CLASS_TO_ID_DICT,
    target_dict=TARGET_CLASS_TO_ID_DICT,
    verbtarget_dict=VERBTARGET_CLASS_TO_ID_DICT,
    triplet_dict=TRIPLET_CLASS_TO_ID_DICT,
)

# Render each mapping under its label, one display() call per mapping.
for label, mapping in (
    ("Instrument to Verb:", inst_to_verb),
    ("Instrument to Target:", inst_to_target),
    ("Instrument to VerbTarget:", inst_to_verbtarget),
):
    display(label, mapping)


'Instrument to Verb:'

{0: [0, 1, 2, 8, 9],
 1: [0, 1, 2, 3, 9],
 2: [1, 2, 3, 5, 9],
 3: [2, 3, 5, 9],
 4: [4, 9],
 5: [1, 2, 6, 7, 9]}

'Instrument to Target:'

{0: [0, 1, 2, 3, 4, 8, 10, 11, 12, 13, 14],
 1: [0, 1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 13, 14],
 2: [0, 1, 2, 3, 4, 5, 8, 10, 11, 14],
 3: [0, 1, 2, 3, 5, 8, 9, 10, 11, 14],
 4: [1, 2, 3, 4, 5, 14],
 5: [0, 1, 2, 4, 6, 7, 8, 10, 14]}

'Instrument to VerbTarget:'

{0: [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  55],
 1: [0,
  1,
  2,
  6,
  9,
  12,
  14,
  15,
  17,
  19,
  20,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  55],
 2: [0,
  1,
  2,
  17,
  19,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  33,
  34,
  35,
  36,
  37,
  38,
  55],
 3: [0, 1, 2, 30, 35, 36, 39, 40, 41, 42, 43, 44, 55],
 4: [45, 46, 47, 48, 49, 55],
 5: [0, 1, 2, 17, 19, 20, 34, 50, 51, 52, 53, 54, 55]}