In [None]:
from transformers import pipeline

# Seq2seq triplet extractor built from the Babelscape/rebel-large checkpoint;
# the model and its tokenizer come from the same checkpoint. Later cells ask
# for generated token ids and decode them with `triplet_extractor.tokenizer`
# themselves, because the <triplet>/<subj>/<obj> special tokens are needed to
# parse the output into triplets.
triplet_extractor = pipeline(
    task='text2text-generation',
    model='Babelscape/rebel-large',
    tokenizer='Babelscape/rebel-large',
)

# Function to parse the generated text and extract the triplets
def extract_triplets(text):
    """Parse REBEL's linearized generation into a list of triplet dicts.

    REBEL emits sequences shaped like
        <triplet> HEAD <subj> TAIL <obj> RELATION [<subj> TAIL <obj> RELATION ...]
    The marker names are misleading: the span after ``<triplet>`` is the head
    entity, the span after ``<subj>`` is the tail entity, and the span after
    ``<obj>`` is the relation label. One head may be followed by several
    ``<subj> ... <obj> ...`` pairs.

    Args:
        text: Decoded generator output. ``<s>``, ``</s>`` and ``<pad>`` are
            stripped; the remaining marker tokens must be whitespace-separated
            from the surrounding words.

    Returns:
        List of ``{'head': str, 'type': str, 'tail': str}`` dicts, in the order
        they appear in ``text``. Words inside each span are re-joined with
        single spaces.
    """
    triplets = []
    subject, relation, object_ = '', '', ''
    # 'x' = no marker seen yet (tokens ignored), 't' = reading head,
    # 's' = reading tail, 'o' = reading relation.
    current = 'x'

    def _emit():
        # Reads the enclosing variables at call time.
        triplets.append({'head': subject.strip(), 'type': relation.strip(), 'tail': object_.strip()})

    cleaned = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in cleaned.split():
        if token == "<triplet>":
            # A new head begins: flush the pending (head, relation, tail) first.
            if relation != '':
                _emit()
                relation = ''
            subject = ''
            current = 't'
        elif token == "<subj>":
            # Another tail for the same head: flush the previous pair first.
            if relation != '':
                _emit()
                # Reset so a truncated/malformed sequence (tail without a
                # following <obj>) cannot be paired with this stale relation.
                relation = ''
            object_ = ''
            current = 's'
        elif token == "<obj>":
            relation = ''
            current = 'o'
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token

    # The last triplet has no trailing marker to trigger a flush.
    if subject != '' and relation != '' and object_ != '':
        _emit()
    return triplets

config.json:   0%|          | 0.00/1.42k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.23k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/123 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/344 [00:00<?, ?B/s]

Device set to use cuda:0


<s><triplet> Punta Cana <subj> La Altagracia Province <obj> located in the administrative territorial entity <subj> Dominican Republic <obj> country <triplet> Higuey <subj> La Altagracia Province <obj> located in the administrative territorial entity <subj> Dominican Republic <obj> country <triplet> La Altagracia Province <subj> Dominican Republic <obj> country <triplet> Dominican Republic <subj> La Altagracia Province <obj> contains administrative territorial entity</s>


In [7]:
# Sample narrative passage (three paragraphs, several named characters, places
# and objects) fed to the REBEL triplet extractor in this cell.
text = """
You stand in the dusty archives of Elderwood Library. The librarian, Elara, pushes her spectacles up nervously. "Captain Rylan stole the Moonlight Amulet from our vault last night!" She points to a faded map on the table. "He fled to the Cursed Catacombs beneath the city. I suspect he's working for Lady Morana, the vampire lord."

As you examine the map, a hooded figure ( Brother Thaddeus ) emerges from the shadows. "Take this blessed dagger," he whispers, pressing the cold steel into your hand. "Rylan fears silver. But beware – the amulet corrupts its bearer. It once belonged to King Aldric, who forged it with dragon's blood."

Near the library entrance, you spot Rylan's lieutenant arguing with a suspicious merchant. A torn letter at their feet reveals: "...deliver the amulet to Morana before the blood moon..."
"""

# Request raw token ids rather than text, then decode them with the pipeline's
# own tokenizer so the <triplet>/<subj>/<obj> markers are kept in the string.
generation = triplet_extractor(text, return_tensors=True, return_text=False)
generated_ids = generation[0]["generated_token_ids"]
extracted_text = triplet_extractor.tokenizer.batch_decode([generated_ids])
print(extracted_text[0])

<s><triplet> Elara <subj> Elderwood Library <obj> employer</s>


In [5]:
# Parse the single decoded REBEL sequence into {'head', 'type', 'tail'} dicts.
# NOTE(review): a later cell redefines `extract_triplets` (spaCy version, raw
# text in / tuples out); once that cell has run, re-running this one calls the
# wrong parser.
rebel_output = extracted_text[0]
extracted_triplets = extract_triplets(rebel_output)
print(extracted_triplets)

[{'head': 'Elara', 'type': 'employer', 'tail': 'Elderwood Library'}]


In [12]:
!python -m spacy download en_core_web_sm

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m25.7 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hInstalling collected packages: en-core-web-sm
Successfully installed en-core-web-sm-3.8.0
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


In [13]:
import spacy
from collections import defaultdict

# Small English spaCy pipeline (downloaded by the `spacy download` cell);
# provides the sentence split, entities and dependency parse used below.
nlp = spacy.load(name="en_core_web_sm")

def extract_triplets(text, location_labels=("LOC",)):
    """Extract (location, relation, object) tuples from raw text with spaCy.

    NOTE(review): this redefines ``extract_triplets`` from the REBEL cell with a
    different contract (raw text in, tuples out instead of decoded REBEL output
    in, dicts out). Any cell that calls ``extract_triplets`` after this cell has
    run gets this version.

    ``text`` is parsed with the module-level ``nlp`` pipeline. Sentences are
    scanned in order. The first entity in a sentence whose label is in
    ``location_labels`` becomes the current location, which carries over to
    later sentences until another location entity replaces it. For every
    nominal subject (active or passive), the subject's head token is used as
    the relation and the head's first ``dobj``/``attr`` child as the object.
    A tuple is emitted only when a location has been seen and an object exists.

    Note that the first tuple element is the current location, not the
    grammatical subject.

    Args:
        text: Raw text to parse.
        location_labels: Entity labels treated as locations. The default
            ``("LOC",)`` matches the original behavior; spaCy's English
            (OntoNotes) labels also include ``GPE`` and ``FAC``, which callers
            may want to add.

    Returns:
        List of ``(location_text, relation_text, object_text)`` string tuples.
    """
    doc = nlp(text)
    triplets = []
    current_loc = None

    for sent in doc.sents:
        # Location tracking: remember the first location entity in the sentence.
        loc_entities = [ent for ent in sent.ents if ent.label_ in location_labels]
        if loc_entities:
            current_loc = loc_entities[0].text

        # Relation extraction: subject -> its governing token -> that token's object.
        for token in sent:
            # spaCy's English models label passive subjects "nsubjpass"; the UD
            # spelling "nsubj:pass" is kept for pipelines trained on UD schemes.
            if token.dep_ in ("nsubj", "nsubjpass", "nsubj:pass"):
                relation = token.head.text
                object_ = next((child for child in token.head.children
                                if child.dep_ in ("dobj", "attr")), None)
                if object_ is not None and current_loc:
                    triplets.append((current_loc, relation, object_.text))

    return triplets

In [29]:
# Step through the spaCy extractor logic on one sentence, keeping the
# intermediate `current_loc` visible for inspection.
# NOTE(review): this duplicates the body of the spaCy `extract_triplets`;
# prefer calling the function once the rules are settled.
doc = nlp("Captain Rylan stole the Moonlight Amulet from our vault last night!")
triplets = []
current_loc = None
for sent in doc.sents:
    loc_entities = [ent for ent in sent.ents if ent.label_ == "LOC"]
    if loc_entities:
        current_loc = loc_entities[0].text
    for token in sent:
        # English spaCy models label passive subjects "nsubjpass"; the UD
        # spelling "nsubj:pass" alone never matches them.
        if token.dep_ in ("nsubj", "nsubjpass", "nsubj:pass"):
            subject = token.text  # kept for inspection; not used in the tuple
            relation = token.head.text
            object_ = next((child for child in token.head.children
                            if child.dep_ in ("dobj", "attr")), None)
            if object_ is not None and current_loc:
                triplets.append((current_loc, relation, object_.text))

# If no LOC entity is found, current_loc stays None and nothing is emitted,
# regardless of any subject/verb/object found in the sentence.
current_loc, triplets

(None, [])