In [1]:
from transformers import pipeline

In [2]:
# mREBEL is a seq2seq model, so it is served through the generic translation
# pipeline; model and tokenizer come from the same checkpoint.
triplet_extractor = pipeline(
    task="translation_xx_to_yy",
    model="Babelscape/mrebel-large",
    tokenizer="Babelscape/mrebel-large",
)

In [9]:
# Generation settings unpacked into the triplet_extractor call.
gen_kwargs = {
    "max_length": 1024,
    "length_penalty": 0,
    "num_beams": 5,
    # Number of sequences returned per input; every returned sequence is
    # decoded downstream. NOTE(review): transformers expects this to be
    # <= num_beams for beam search — confirm against generate() docs.
    "num_return_sequences": 2,
}

In [10]:
text = """Chief executive Carol Tomé said 2023 was a "difficult and disappointing year", and UPS was investing in artificial intelligence (AI) as it pushes to become more efficient. """
# Generated sequences start with the "tp_XX" token (id 250058 in the raw
# output); look the id up from the tokenizer instead of hard-coding it.
tripled_extracted = triplet_extractor(
    text,
    decoder_start_token_id=triplet_extractor.tokenizer.convert_tokens_to_ids("tp_XX"),
    src_lang="en_XX",
    tgt_lang="<triplet>",
    # Return raw token ids rather than text: the ids are decoded later with
    # special tokens kept, because the triplet parser depends on them.
    return_tensors=True,
    return_text=False,
    **gen_kwargs,
)

In [11]:
# One dict per returned sequence, each holding generated token ids (no text).
tripled_extracted

[{'translation_token_ids': tensor([250058, 250054,  61171,   8352,    446,      6, 250061, 127623,      6,
          250064, 143889,      2,      1,      1])},
 {'translation_token_ids': tensor([250058, 250054,  61171,   8352,    446,      6, 250061, 127873, 159354,
               6, 250070,  19069,  34658,      2])}]

In [13]:
# Decode every returned sequence rather than hard-coded indices [0, 1], so
# this cell stays correct if num_return_sequences changes. Special tokens are
# deliberately NOT skipped: extract_triplets_typed relies on the <triplet>
# and entity-type markers.
tripled_extracted_decoded = triplet_extractor.tokenizer.batch_decode(
    [output["translation_token_ids"] for output in tripled_extracted]
)

In [14]:
# Decoded sequences, still containing tp_XX, <triplet>, type tags, </s> and <pad>.
tripled_extracted_decoded

['tp_XX<triplet> Carol Tomé <per> UPS <org> employer</s><pad><pad>',
 'tp_XX<triplet> Carol Tomé <per> Chief executive <concept> position held</s>']

In [15]:
def extract_triplets_typed(text):
    """Parse a decoded mREBEL sequence into a list of typed triplets.

    The expected token stream looks like::

        <triplet> HEAD <head_type> TAIL <tail_type> RELATION

    Further ``<head_type> TAIL <tail_type> RELATION`` groups may follow and
    reuse the same HEAD; a new ``<triplet>`` (or ``<relation>``) token starts
    a new HEAD.

    Args:
        text: A string from ``tokenizer.batch_decode`` with special tokens
            kept -- the ``<...>`` markers drive the parser. The substrings
            ``<s>``, ``<pad>``, ``</s>``, ``tp_XX`` and ``__en__`` are removed
            before tokenizing on whitespace.

    Returns:
        A list of dicts with keys ``head``, ``head_type``, ``type`` (the
        relation), ``tail`` and ``tail_type``. A trailing tail that never got a
        relation (e.g. a sequence cut off at ``max_length``) is dropped rather
        than paired with the previous group's relation.
    """
    triplets = []
    current = 'x'  # 't' = reading head, 's' = reading tail, 'o' = reading relation
    subject, relation, object_, object_type, subject_type = '', '', '', '', ''

    def _emit():
        # Record the group currently held in the enclosing parser state.
        triplets.append({'head': subject.strip(), 'head_type': subject_type, 'type': relation.strip(),
                         'tail': object_.strip(), 'tail_type': object_type})

    cleaned = text.strip()
    for marker in ("<s>", "<pad>", "</s>", "tp_XX", "__en__"):
        cleaned = cleaned.replace(marker, "")

    for token in cleaned.split():
        if token in ("<triplet>", "<relation>"):
            current = 't'
            if relation != '':
                _emit()
                relation = ''
            subject = ''
        elif token.startswith("<") and token.endswith(">"):
            if current in ('t', 'o'):
                # Head-type tag: closes any finished group and starts a new tail.
                current = 's'
                if relation != '':
                    _emit()
                    # Reset so a following incomplete group (or a new
                    # <triplet>) cannot re-emit this relation.
                    relation = ''
                object_ = ''
                subject_type = token[1:-1]
            else:
                # Tail-type tag: the relation text follows.
                current = 'o'
                object_type = token[1:-1]
                relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token

    if all((subject, relation, object_, object_type, subject_type)):
        _emit()
    return triplets

In [16]:
# Parse only the first returned sequence; the other returned beam is in index 1.
extract_triplets_typed(tripled_extracted_decoded[0])

[{'head': 'Carol Tomé',
  'head_type': 'per',
  'type': 'employer',
  'tail': 'UPS',
  'tail_type': 'org'}]