In [1]:
import json

# Read the raw persona annotations from disk and parse them as JSON.
with open("data/persona/persona_atomic_final_123.json") as persona_file:
    persona_atomic = json.loads(persona_file.read())

In [2]:
# For every entry, build one record per utterance: the utterance itself, a
# window of up to two utterances on either side, and its triples grouped by
# their "relationship" label.
persona_atomic_centered = {}

for persona_id, persona in persona_atomic.items():
    utterances = persona["text"]
    facts_by_index = persona["facts"]
    centered_samples = persona_atomic_centered[persona_id] = []

    for position, utterance in enumerate(utterances):
        # One bucket per known label; triples carrying any other label are dropped.
        buckets = {"rpa": [], "rpp": [], "rpf": [], "irr": []}

        # Facts are keyed by the utterance index rendered as a string.
        for head, fact in facts_by_index[str(position)].items():
            for triple in fact["triples"]:
                kg = {"head": head, "relation": triple["relation"], "tail": triple["tail"]}
                bucket = buckets.get(triple["relationship"])
                if bucket is not None:
                    bucket.append(kg)

        centered_samples.append({
            # Up to two preceding utterances (fewer near the start).
            "past": utterances[max(position - 2, 0):position],
            "present": utterance,
            # Up to two following utterances; slicing clamps at the end.
            "future": utterances[position + 1:position + 3],
            **buckets,
        })

In [None]:
# Spot-check: display every centered sample stored under key "1273".
persona_atomic_centered["1273"]

In [5]:
# Persist the centered samples as pretty-printed JSON (4-space indent).
with open("data/persona/persona_atomic_centered.json", "w") as out_file:
    out_file.write(json.dumps(persona_atomic_centered, indent=4))

In [2]:
import json

def _read_json(path):
    """Parse and return the JSON document stored at ``path``."""
    with open(path) as handle:
        return json.load(handle)


# Centered samples plus the three JSON files that define the train/val/test split.
persona_atomic_centered = _read_json("data/persona/persona_atomic_centered.json")
persona_atomic_train = _read_json("data/persona/persona_atomic_did_train_90.json")
persona_atomic_val = _read_json("data/persona/persona_atomic_did_val_15.json")
persona_atomic_test = _read_json("data/persona/persona_atomic_did_test_18.json")

In [8]:
# Flatten the centered samples into (context, head, relation, tail) records,
# one per "rpf" triple, and route each record to its split.
train_data = []
val_data = []
test_data = []

for id_, data in persona_atomic_centered.items():
    # Keys are strings (JSON object keys) while the split files are checked
    # against integers, hence the int() conversion. Any id that is in neither
    # the val nor the test collection goes to train; persona_atomic_train is
    # not consulted.
    numeric_id = int(id_)
    if numeric_id in persona_atomic_val:
        split = val_data
    elif numeric_id in persona_atomic_test:
        split = test_data
    else:
        split = train_data

    for sample in data:
        # Join past and present utterances with ".". Joining a single list
        # avoids the stray leading "." the context would otherwise get when
        # there are no past utterances (first utterance of each entry).
        context = ".".join([*sample["past"], sample["present"]])
        split.extend({"context": context, **kg} for kg in sample["rpf"])


In [10]:
# Number of records in each split, in (train, val, test) order.
len(train_data), len(val_data), len(test_data)

(747, 121, 212)

In [11]:
def write_jsonl(data, filepath):
    """Write ``data`` to ``filepath`` in JSON Lines format.

    Args:
        data: Iterable of JSON-serializable objects; each becomes one line.
        filepath: Destination path; an existing file is overwritten.

    Every record, including the last one, is terminated by a newline so the
    file can be appended to or concatenated without merging records. Empty
    ``data`` produces an empty file.
    """
    with open(filepath, "w", encoding="utf-8") as f:
        # Stream one serialized record per line rather than joining everything
        # into one big string and passing it to writelines(), which would
        # iterate that string character by character.
        f.writelines(f"{json.dumps(record)}\n" for record in data)
    
# Persist each split to its own JSON Lines file.
for split_records, split_path in (
    (train_data, "data/persona/persona_atomic_train.jsonl"),
    (val_data, "data/persona/persona_atomic_val.jsonl"),
    (test_data, "data/persona/persona_atomic_test.jsonl"),
):
    write_jsonl(split_records, split_path)


In [13]:
from kogito.inference import CommonsenseInference
from kogito.core.processors.relation import BERTRelationMatcher

# Build the inference pipeline, then swap its relation matcher: the processor
# named "simple_relation_matcher" is removed and a BERTRelationMatcher is
# registered under the name "bert_matcher".
csi = CommonsenseInference()
csi.remove_processor("simple_relation_matcher")
csi.add_processor(BERTRelationMatcher("bert_matcher"))

Downloading: 100%|██████████| 228/228 [00:00<00:00, 40.1kB/s]
Downloading: 100%|██████████| 438M/438M [00:29<00:00, 14.9MB/s] 
Downloading: 100%|██████████| 570/570 [00:00<00:00, 196kB/s]
Downloading: 100%|██████████| 440M/440M [00:29<00:00, 14.8MB/s] 
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical

In [14]:
# Show which processors the pipeline currently has registered.
csi.processors

{'head': ['sentence_extractor',
  'noun_phrase_extractor',
  'verb_phrase_extractor'],
 'relation': ['graph_relation_matcher', 'bert_matcher']}

In [17]:
# Smoke-test the inference pipeline on a single utterance: only the first
# sample of the first entry is processed (and annotated in place).
first_entry_samples = next(iter(persona_atomic_centered.values()), [])

for data_point in first_entry_samples[:1]:
    inferred = csi.infer(text=data_point["present"])
    # Keep head and relation as plain strings; tails are left empty here.
    data_point["kgraph"] = [
        {"head": str(kg.head), "relation": str(kg.relation), "tails": []}
        for kg in inferred
    ]

Extracting heads...
Matching relations...


Downloading: 100%|██████████| 232k/232k [00:00<00:00, 477kB/s]  
Downloading: 100%|██████████| 28.0/28.0 [00:00<00:00, 10.1kB/s]
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /Users/mismayil/Desktop/EPFL/F2022/OP/lightning_logs
  rank_zero_warn(


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.44it/s]


In [18]:
# Display the first centered sample stored under key "1273".
persona_atomic_centered["1273"][0]

{'past': [],
 'present': 'hey , i am in a lady motorcycle club and i love to drive fast',
 'future': ['i am married to a wife beater and have two kids',
  'well do you want me to come beat him ? i have never lost a fight'],
 'rpa': [{'head': 'PersonX drives ___ fast',
   'relation': 'xIntent',
   'tail': 'to get a thrill'},
  {'head': "PersonX loves PersonX's motorcycle",
   'relation': 'xWant',
   'tail': 'to go for a ride'},
  {'head': 'motorcycle',
   'relation': 'HasProperty',
   'tail': 'two wheels and can go fast'}],
 'rpp': [],
 'rpf': [{'head': 'motorcycle',
   'relation': 'ObjectUse',
   'tail': 'hit the road'},
  {'head': 'motorcycle', 'relation': 'HasProperty', 'tail': 'two wheels'},
  {'head': 'motorcycle',
   'relation': 'HasProperty',
   'tail': 'two wheeled vehicle'},
  {'head': 'motorcyle', 'relation': 'ObjectUse', 'tail': 'drive there'}],
 'irr': [{'head': 'PersonX drives ___ fast',
   'relation': 'isFilledBy',
   'tail': 'motorcycles'},
  {'head': 'PersonX drives ___ 