In [1]:
# Automatically reload edited Python modules (e.g. the kogito package) before each
# cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Sample usage

In [20]:
import os

from kogito.models.bart.comet_bart import COMETBART
from kogito.inference import CommonsenseInference
from kogito.models.base import KnowledgeModel
from kogito.core.knowledge import KnowledgeGraph

# Directory of the pretrained COMET-ATOMIC 2020 BART checkpoint. Set the
# COMET_BART_PATH environment variable to point at your local copy rather than
# hardcoding a machine-specific absolute path in the notebook.
MODEL_PATH = os.environ.get("COMET_BART_PATH", "comet-atomic_2020_BART")

model: KnowledgeModel = COMETBART.from_pretrained(MODEL_PATH)
csi = CommonsenseInference()
text = "Student gets a library card"
# Extract heads from `text`, match relations, and generate the graph with `model`
# (see the progress log printed below this cell).
kgraph: KnowledgeGraph = csi.infer(text, model)
# NOTE(review): the method name suggests JSON Lines output, so ".jsonl" may be a more
# accurate extension; kept as-is in case other tooling reads "kgraph.json".
kgraph.to_jsonl("kgraph.json")

INFO:kogito.core.utils:using task specific params for summarization: {'early_stopping': True, 'length_penalty': 2.0, 'max_length': 24, 'min_length': 1, 'no_repeat_ngram_size': 3, 'num_beams': 4}


Extracting heads...
Matching relations...
Generating commonsense graph...


In [15]:
# Peek at the first inferred knowledge entry only; nothing is printed if the
# graph turns out to be empty.
_no_knowledge = object()
first_knowledge = next(iter(kgraph), _no_knowledge)
if first_knowledge is not _no_knowledge:
    print(first_knowledge)

Knowledge(head="Student gets a library card", relation="Causes", tails= student gets library card, base=KnowledgeBase.CONCEPTNET)


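### Customizing processors

Inference is driven by processors. Head extractors (listed under `'head'`) pull candidate heads out of the input text, and relation matchers (under `'relation'`) match those heads to relations. Processors can be listed, removed by name, and extended with custom implementations.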

In [4]:
# Display the processors registered on `csi`, grouped by role: head extractors
# under 'head' and relation matchers under 'relation'.
csi.processors

{'head': ['sentence_extractor', 'phrase_extractor', 'noun_extractor'],
 'relation': ['simple_matcher']}

In [6]:
# Unregister a built-in processor by the name it is listed under in `csi.processors`.
csi.remove_processor("noun_extractor")

In [7]:
# Confirm that 'noun_extractor' is no longer listed among the 'head' processors.
csi.processors

{'head': ['sentence_extractor', 'phrase_extractor'],
 'relation': ['simple_matcher']}

In [12]:
from kogito.core.head import KnowledgeHeadExtractor, KnowledgeHead, KnowledgeHeadType
from typing import Optional, List
from spacy.tokens import Doc

import spacy

class NounHeadExtractor(KnowledgeHeadExtractor):
    """Knowledge head extractor that turns every noun token in the input into a head.

    Text is tagged with ``self.lang`` (presumably the spaCy ``Language`` passed to the
    base-class constructor; TODO confirm against ``KnowledgeHeadExtractor``) unless a
    pre-parsed ``Doc`` is supplied.
    """

    def extract(self, text: str, doc: Optional[Doc] = None) -> List[KnowledgeHead]:
        """Return one ``KnowledgeHead`` per token whose coarse POS tag is ``"NOUN"``.

        Args:
            text: Raw input text. It is parsed only when ``doc`` is not provided.
            doc: Optional pre-parsed spaCy ``Doc`` for ``text``, reused to avoid a
                second parse.

        Returns:
            Heads in document order. Each has ``type=KnowledgeHeadType.NOUN`` and keeps
            the originating spaCy token as ``entity``.
        """
        # Test against None explicitly: a spaCy Doc defines __len__, so an empty but
        # valid Doc is falsy and `if not doc` would needlessly re-parse `text`.
        if doc is None:
            doc = self.lang(text)

        return [
            KnowledgeHead(text=token.text, type=KnowledgeHeadType.NOUN, entity=token)
            for token in doc
            if token.pos_ == "NOUN"
        ]

# Build the custom extractor on its own spaCy pipeline and register it under the same
# name that was removed above, so it reappears in csi.processors['head'].
spacy_pipeline = spacy.load("en_core_web_sm")
noun_extractor = NounHeadExtractor("noun_extractor", spacy_pipeline)
csi.add_processor(noun_extractor)

In [13]:
# Confirm that the custom 'noun_extractor' is registered again under 'head'.
csi.processors

{'head': ['sentence_extractor', 'phrase_extractor', 'noun_extractor'],
 'relation': ['simple_matcher']}

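### Low-level API

A model can also be used directly: train it on a `KnowledgeGraph` and generate an output graph from it.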

In [None]:
# The low-level API works directly on a KnowledgeGraph. Reuse the graph inferred in the
# "Sample usage" section so this cell runs on a fresh kernel.
# NOTE(review): substitute the KnowledgeGraph you actually want to train/generate on;
# training may be expensive.
input_graph: KnowledgeGraph = kgraph

model.train(input_graph)
output_graph = model.generate(input_graph)
# model.evaluate(...)