In [1]:
# Third-Party Libraries
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
prompt = """
# Inputs
# Entity Types <- {entities}
# Relationships <- {relations}
# Text <- {document}
# Output -> (subject > predicate > object)""".strip()

In [3]:
model = AutoModelForCausalLM.from_pretrained("sciphi/triplex", trust_remote_code=True).eval()

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
tokenizer = AutoTokenizer.from_pretrained("sciphi/triplex", trust_remote_code=True)

In [5]:
entity_types = ["LOCATION", "POSITION", "DATE", "CITY", "COUNTRY", "NUMBER"]
predicates = ["POPULATION", "AREA"]
text = """
San Francisco, officially the City and County of San Francisco, is a commercial, financial, and cultural center in Northern California.
With a population of 808,437 residents as of 2022, San Francisco is the fourth most populous city in the U.S. state of California behind Los Angeles, San Diego, and San Jose.
"""

In [8]:
def extract_triplets(doc) -> str:
    plain_entitites = ", ".join(entity_types)
    plain_predicates = ", ".join(predicates)
    x = tokenizer(prompt.format(entities=plain_entitites, relations=plain_predicates, document=doc), return_tensors="pt")
    return model.generate(**x, max_new_tokens=50)

In [9]:
extract_triplets(text)

tensor([[  396, 10567, 29879,    13, 29937, 14945, 28025,  3705, 11247, 29907,
          8098, 29892,   349,  3267, 22122, 29892, 20231, 29892,   315, 11937,
         29892,  4810,  3904,  5659, 29979, 29892, 28019, 13635,    13, 29937,
          6376,   800, 14587,  3705,   349,  4590, 13309,  8098, 29892,   319,
          1525, 29909,    13, 29937,  3992,  3705, 29871,    13, 22509,  8970,
         29892, 22444,   278,  4412,   322,  5127,   310,  3087,  8970, 29892,
           338,   263, 12128, 29892, 18161, 29892,   322, 16375,  4818,   297,
         14299,  8046, 29889,    13,  3047,   263,  4665,   310, 29871, 29947,
         29900, 29947, 29892, 29946, 29941, 29955, 24060,   408,   310, 29871,
         29906, 29900, 29906, 29906, 29892,  3087,  8970,   338,   278, 11582,
          1556, 14938,   681,  4272,   297,   278,   501, 29889, 29903, 29889,
          2106,   310,  8046,  5742,  4602, 10722, 29892,  3087, 16879, 29892,
           322,  3087,  5043, 29889,    13,    13, 2

In [None]:
tensor([[  396, 10567, 29879,    13, 29937, 14945, 28025,  3705, 11247, 29907,
          8098, 29892,   349,  3267, 22122, 29892, 20231, 29892,   315, 11937,
         29892,  4810,  3904,  5659, 29979, 29892, 28019, 13635,    13, 29937,
          6376,   800, 14587,  3705,   349,  4590, 13309,  8098, 29892,   319,
          1525, 29909,    13, 29937,  3992,  3705, 29871,    13, 22509,  8970,
         29892, 22444,   278,  4412,   322,  5127,   310,  3087,  8970, 29892,
           338,   263, 12128, 29892, 18161, 29892,   322, 16375,  4818,   297,
         14299,  8046, 29889,    13,  3047,   263,  4665,   310, 29871, 29947,
         29900, 29947, 29892, 29946, 29941, 29955, 24060,   408,   310, 29871,
         29906, 29900, 29906, 29906, 29892,  3087,  8970,   338,   278, 11582,
          1556, 14938,   681,  4272,   297,   278,   501, 29889, 29903, 29889,
          2106,   310,  8046,  5742,  4602, 10722, 29892,  3087, 16879, 29892,
           322,  3087,  5043, 29889,    13,    13, 29937, 10604,  1599,   313,
         16009,  1405, 24384,  1405,  1203, 29897,    13, 28956,  3126,    13,
         29912,    13,  1678,   376,   296,  1907, 29918,   392, 29918,  3626,
          2701,  1115,   518,    13,  4706, 14704, 29896,  1402,   315, 11937,
         29901, 22509,  8970,   613,    13,  4706, 14704, 29906,  1402,  4810,
          3904,  5659, 29979, 29901,  2525,  1573,  3900,   613,    13,  4706,
         14704, 29941,  1402, 11247, 29907,  8098]])