In [5]:
from transformers import DebertaV2ForSequenceClassification, DebertaV2Tokenizer

# Special tokens that delimit dialogue turns (narrative) and the three parts of
# a fact (head / relation / tail) when the model input is assembled.
NARRATIVE_SEP_TOKEN = "<n_sep>"
FACT_SEP_TOKEN = "<f_sep>"

# Natural-language phrase for each commonsense relation label. The fact's raw
# relation name is swapped for its phrase before tokenization.
RELATION_VERBALIZER = {
    "AtLocation": "located or found at/in/on",
    "CapableOf": "is/are capable of",
    "Causes": "causes",
    "CausesDesire": "makes someone want",
    "CreatedBy": "is created by",
    "Desires": "desires",
    "HasA": "has, possesses or contains",
    "HasFirstSubevent": "begins with the event/action",
    "HasLastSubevent": "ends with the event/action",
    "HasPrerequisite": "to do this, one requires",
    "HasProperty": "can be characterized by being/having",
    "HasSubEvent": "includes the event/action",
    "HinderedBy": "can be hindered by",
    "InstanceOf": "is an example/instance of",
    "isAfter": "happens after",
    "isBefore": "happens before",
    "isFilledBy": "___ can be filled by",
    "MadeOf": "is made of",
    "MadeUpOf": "made (up) of",
    "MotivatedByGoal": "is a step towards accomplishing the goal",
    "NotDesires": "do(es) not desire",
    "ObjectUse": "used for",
    "UsedFor": "used for",
    "oEffect": "as a result, PersonY or others will",
    "oReact": "as a result, PersonY or others feels",
    "oWant": "as a result, PersonY or others wants",
    "PartOf": "is a part of",
    "ReceivesAction": "can receive or be affected by the action",
    "xAttr": "PersonX is seen as",
    "xEffect": "as a result, PersonX will",
    "xIntent": "because PersonX wants",
    "xNeed": "but before, PersonX needs",
    "xReact": "as a result, PersonX feels",
    "xReason": "because",
    "xWant": "as a result, PersonX wants",
}

# Fine-tuned ComFact DeBERTa checkpoint; the tokenizer is loaded from the same directory.
# NOTE(review): relative path, resolved against the notebook's working directory.
tokenizer_path = model_path = "ComFact_DeBERTa/deberta-large-nlu-fact_full/checkpoint-236560"
tokenizer = DebertaV2Tokenizer.from_pretrained(tokenizer_path)
model = DebertaV2ForSequenceClassification.from_pretrained(model_path)

narrative_sep_id = tokenizer.convert_tokens_to_ids(NARRATIVE_SEP_TOKEN)
fact_sep_id = tokenizer.convert_tokens_to_ids(FACT_SEP_TOKEN)

# convert_tokens_to_ids does not raise for a token the tokenizer doesn't know; it
# returns the unknown-token id. That would silently corrupt every input built from
# these ids, so fail loudly. Also check that the ids are covered by the model's
# embedding matrix.
_embedding_rows = model.get_input_embeddings().num_embeddings
for _sep_token, _sep_id in ((NARRATIVE_SEP_TOKEN, narrative_sep_id), (FACT_SEP_TOKEN, fact_sep_id)):
    if _sep_id is None or _sep_id == tokenizer.unk_token_id:
        raise ValueError(f"{_sep_token!r} is not a known token of the tokenizer at {tokenizer_path!r}")
    if _sep_id >= _embedding_rows:
        raise ValueError(
            f"{_sep_token!r} has id {_sep_id}, outside the model's {_embedding_rows}-row embedding matrix"
        )

# Dialogue turns forming the narrative context, in conversation order
# (one string per turn).
context = [
    "hey , i am in a lady motorcycle club and i love to drive fast",
    "i am married to a wife beater and have two kids",
    "well do you want me to come beat him ? i have never lost a fight",
    "then we can go shopping ! i love shopping . i am a lifestyle shop blogger .",
    "well there you go lol and your kids would enjoy checking my tatts i have got 12",
    "i am very attractive . i was a cheerleader in high school . maybe we can go on a date",
    "u like women too ? did not know that",
    "got to get away from my husband i live in florida . celebration florida come meet me",
    "i just drove 20 mins this morning at 208 mph i can get there fast",
    "i will leave my kids never liked them lets do this !",
    "sounds like a plan i will be there soon you can hop on my bike",
    "the we can ride off into the sunset just like lovers in a novel",
    "well then pack your bags",
    "yay i am so excited i think i will burn the house down before i leave .",
]
# Candidate commonsense fact to score against the context. To try another fact,
# edit raw_fact, e.g. head "PersonX drives ___ fast", relation "xIntent",
# tail "to get a thrill".
raw_fact = {"head": "PersonX drives ___ fast", "relation": "oWant", "tail": "to call the police"}

# Model-ready copy (raw_fact is left untouched, so this cell is idempotent):
# the relation label is replaced by its verbalized phrase, and head and tail are
# lowercased. The relation phrase keeps its own casing.
# NOTE(review): confirm this casing matches the preprocessing used to train the checkpoint.
fact = {
    "head": raw_fact["head"].lower(),
    "relation": RELATION_VERBALIZER[raw_fact["relation"]],
    "tail": raw_fact["tail"].lower(),
}

# Tokenize every context turn and every fact component separately, with no
# special tokens. Separators and the tokenizer's special tokens are added afterwards.
context_ids = [
    tokenizer.convert_tokens_to_ids(tokenizer.tokenize(turn))
    for turn in context
]
fact_ids = [
    tokenizer.convert_tokens_to_ids(tokenizer.tokenize(fact[part]))
    for part in ("head", "relation", "tail")
]

In [2]:
# Show the vocabulary ids assigned to the narrative and fact separator tokens.
# NOTE(review): the execution counts show this cell (In [2]) ran before the cell
# defining these names (In [5]); re-run the notebook top to bottom.
narrative_sep_id, fact_sep_id

(128001, 128002)

In [6]:
# Token ids of the fact's head, verbalized relation and tail (one list each).
fact_ids

[[604, 982, 5328, 5179, 616, 616, 1274],
 [401, 12590, 2087, 1654],
 [264, 350, 266, 14713]]

In [14]:
from itertools import chain

def join_with_sep(segments, sep_id):
    """Concatenate lists of token ids, putting ``sep_id`` between consecutive segments.

    No separator is added before the first segment or after the last one, and an
    empty ``segments`` yields an empty list.

    Args:
        segments: iterable of lists of token ids.
        sep_id: token id inserted between adjacent segments.

    Returns:
        A new flat list of token ids.
    """
    joined = []
    for position, segment in enumerate(segments):
        if position:  # every segment except the first is preceded by a separator
            joined.append(sep_id)
        joined.extend(segment)
    return joined


context_ids_with_sep = join_with_sep(context_ids, narrative_sep_id)
fact_ids_with_sep = join_with_sep(fact_ids, fact_sep_id)

In [15]:
# Wrap both segments in the tokenizer's sentence-pair template
# ([CLS] context [SEP] fact [SEP]) to get the final model input.
input_ids = tokenizer.build_inputs_with_special_tokens(
    token_ids_0=context_ids_with_sep,
    token_ids_1=fact_ids_with_sep,
)

In [16]:
# Full model input: separator-joined context and fact wrapped with the tokenizer's special tokens.
input_ids

[1,
 11187,
 366,
 584,
 481,
 267,
 266,
 4396,
 8209,
 1788,
 263,
 584,
 472,
 264,
 1168,
 1274,
 128001,
 584,
 481,
 2410,
 264,
 266,
 1553,
 56596,
 263,
 286,
 375,
 978,
 128001,
 371,
 333,
 274,
 409,
 351,
 264,
 488,
 2584,
 417,
 1102,
 584,
 286,
 518,
 1125,
 266,
 1801,
 128001,
 393,
 301,
 295,
 424,
 2017,
 1084,
 584,
 472,
 2017,
 323,
 584,
 481,
 266,
 3444,
 1638,
 8874,
 323,
 128001,
 371,
 343,
 274,
 424,
 8878,
 263,
 290,
 978,
 338,
 929,
 4155,
 312,
 33276,
 297,
 268,
 584,
 286,
 519,
 621,
 128001,
 584,
 481,
 379,
 3851,
 323,
 584,
 284,
 266,
 43701,
 267,
 459,
 563,
 323,
 1461,
 301,
 295,
 424,
 277,
 266,
 1043,
 128001,
 3636,
 334,
 694,
 461,
 1102,
 464,
 298,
 391,
 272,
 128001,
 519,
 264,
 350,
 557,
 292,
 312,
 1745,
 584,
 685,
 267,
 29983,
 323,
 4630,
 29983,
 488,
 957,
 351,
 128001,
 584,
 348,
 5066,
 602,
 12178,
 291,
 1066,
 288,
 23299,
 9353,
 584,
 295,
 350,
 343,
 1274,
 128001,
 584,
 296,
 1021,
 312,
 978,
 518

In [19]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # disable dropout; from_pretrained normally does this already, kept explicit

# Pure inference: skip building the autograd graph (saves memory; the logits
# no longer carry a grad_fn). A batch of one unpadded sequence needs no attention mask.
with torch.no_grad():
    output = model(torch.tensor([input_ids], device=device))

In [20]:
# Raw classifier output for the single input; `logits` holds one score per class.
output

SequenceClassifierOutput(loss=None, logits=tensor([[-0.9805,  1.0133]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [22]:
# Turn the logits into class probabilities (softmax over the last, class axis).
output.logits.softmax(dim=-1)

tensor([[0.1199, 0.8801]], device='cuda:0', grad_fn=<SoftmaxBackward0>)