In [5]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Sample usage

In [None]:
from kogito.models.bart.comet_bart import COMETBART
from kogito.inference import CommonsenseInference
from kogito.models.base import KnowledgeModel
from kogito.core.knowledge import KnowledgeGraph

import os

# Location of the pretrained COMET-BART checkpoint. Set the
# KOGITO_COMET_BART_PATH environment variable to point at your own copy instead
# of editing this cell; the fallback is the path originally hard-coded here.
COMET_BART_PATH = os.environ.get(
    "KOGITO_COMET_BART_PATH",
    "/Users/mismayil/Desktop/EPFL/nlplab/comet-atomic-2020/comet-atomic_2020_BART",
)

# Load the knowledge model and create the inference pipeline.
model: KnowledgeModel = COMETBART.from_pretrained(COMET_BART_PATH)
csi = CommonsenseInference()

# Run inference over a short story and save the resulting knowledge graph.
text = "Gabby always brought cookies to school. But at lunch, everyone wanted them. And she had a hard time saying no. Gabby began to hate the other students. And at lunch, she ate far away from everyone."
kgraph: KnowledgeGraph = csi.infer(
    text,
    model,
    model_args={"num_generate": 3, "batch_size": 128},
)
kgraph.to_jsonl("kgraph2.json")

### Customizing processors

In [None]:
# Display the processors currently configured on the inference pipeline.
csi.processors

In [None]:
# Drop the processor identified by "noun_phrase_extractor" from the pipeline.
csi.remove_processor("noun_phrase_extractor")

In [None]:
# Display the processors again to confirm the removal took effect.
csi.processors

In [None]:
from kogito.core.head import KnowledgeHeadExtractor, KnowledgeHead, KnowledgeHeadType
from typing import Optional, List
from spacy.tokens import Doc

import spacy

class NounHeadExtractor(KnowledgeHeadExtractor):
    """Custom head extractor that emits one knowledge head per noun token."""

    def extract(self, text: str, doc: Optional[Doc] = None) -> List[KnowledgeHead]:
        """Collect a knowledge head for every token tagged as a noun.

        Args:
            text: Raw input text; parsed with ``self.lang`` only when ``doc``
                is not supplied.
            doc: Optional already-parsed spaCy ``Doc`` for ``text``.

        Returns:
            A list of ``KnowledgeHead`` objects, one for each token whose
            ``pos_`` is ``"NOUN"``, in document order. Each head carries the
            token text, the ``NOUN_PHRASE`` head type, and the token itself
            as its entity.
        """
        # Test against None explicitly: a spaCy Doc with zero tokens is falsy,
        # so a truthiness check would treat a supplied-but-empty doc as missing
        # and parse the text a second time.
        if doc is None:
            doc = self.lang(text)

        return [
            KnowledgeHead(text=token.text, type=KnowledgeHeadType.NOUN_PHRASE, entity=token)
            for token in doc
            if token.pos_ == "NOUN"
        ]

# Instantiate the custom extractor with its identifier and the small English
# spaCy pipeline, then attach it to the inference pipeline.
noun_extractor = NounHeadExtractor(
    "noun_extractor",
    spacy.load("en_core_web_sm"),
)
csi.add_processor(noun_extractor)

In [None]:
# Display the processors once more; the custom noun extractor should now be listed.
csi.processors

In [1]:
from kogito.inference import CommonsenseInference
from kogito.core.knowledge import KnowledgeGraph
from kogito.models.bart.comet_bart import COMETBART
from kogito.models.base import KnowledgeModel

import os

# Location of the pretrained COMET-BART checkpoint. Set the
# KOGITO_COMET_BART_PATH environment variable to point at your own copy instead
# of editing this cell; the fallback is the path originally hard-coded here.
COMET_BART_PATH = os.environ.get(
    "KOGITO_COMET_BART_PATH",
    "/Users/mismayil/Desktop/EPFL/nlplab/comet-atomic-2020/comet-atomic_2020_BART",
)

# Fresh inference pipeline and knowledge model used by the cells that follow.
csi = CommonsenseInference()
model: KnowledgeModel = COMETBART.from_pretrained(COMET_BART_PATH)


INFO:kogito.core.utils:using task specific params for summarization: {'early_stopping': True, 'length_penalty': 2.0, 'max_length': 24, 'min_length': 1, 'no_repeat_ngram_size': 3, 'num_beams': 4}


### Dry-run

In [11]:
# Dry run: the recorded output shows only the head-extraction and
# relation-matching stages (presumably generation by the model is skipped).
text = "Gabby always brought cookies to school."
kgraph: KnowledgeGraph = csi.infer(
    text,
    model,
    model_args={"num_generate": 3, "batch_size": 128},
    dry_run=True,
)
kgraph.to_jsonl("kgraph_dry_run.json")

Extracting heads...
Matching relations...


### Relation subset

In [12]:
# Same dry run, but pass an explicit list of relations to work with.
text = "Gabby always brought cookies to school."
kgraph: KnowledgeGraph = csi.infer(
    text,
    model,
    model_args={"num_generate": 3, "batch_size": 128},
    dry_run=True,
    relations=["ObjectUse", "Causes"],
)
kgraph.to_jsonl("kgraph_rel_subset.json")

Extracting heads...
Matching relations...


### No Head extraction

In [13]:
# Turn head extraction off; the recorded output shows only the
# relation-matching stage running.
text = "Gabby always brought cookies to school."
kgraph: KnowledgeGraph = csi.infer(
    text,
    model,
    model_args={"num_generate": 3, "batch_size": 128},
    dry_run=True,
    extract_heads=False,
)
kgraph.to_jsonl("kgraph_no_head_extract.json")

Matching relations...


### No Relation matching and no subset of relations

In [14]:
text = "Gabby always brought cookies to school."
# With relation matching disabled and no explicit `relations` list, `infer`
# raises ValueError ("No relation found to match"). The error is the point of
# this cell, so catch and print it instead of letting it abort a
# "Restart & Run All".
try:
    kgraph: KnowledgeGraph = csi.infer(
        text,
        model,
        model_args={"num_generate": 3, "batch_size": 128},
        dry_run=True,
        match_relations=False,
    )
except ValueError as err:
    print(f"ValueError: {err}")

Extracting heads...


ValueError: No relation found to match

### No Relation matching with subset of relations

In [15]:
# Relation matching is off, but an explicit relation list is supplied, so the
# call completes; the recorded output shows only the head-extraction stage.
text = "Gabby always brought cookies to school."
kgraph: KnowledgeGraph = csi.infer(
    text,
    model,
    model_args={"num_generate": 3, "batch_size": 128},
    dry_run=True,
    match_relations=False,
    relations=["Causes", "Desires"],
)
kgraph.to_jsonl("kgraph_no_match_subset.json")

Extracting heads...


### No head extraction and no relation matching, with a subset of relations (i.e., fully manual specification)

In [16]:
# Both head extraction and relation matching are disabled; only the
# explicitly listed relations are supplied.
text = "Gabby always brought cookies to school."
kgraph: KnowledgeGraph = csi.infer(
    text,
    model,
    model_args={"num_generate": 3, "batch_size": 128},
    dry_run=True,
    extract_heads=False,
    match_relations=False,
    relations=["Causes", "Desires"],
)
kgraph.to_jsonl("kgraph_manual.json")

### Low-level API

In [None]:
# NOTE(review): `input_graph` is not defined in any cell visible in this
# notebook, so this cell fails on a fresh kernel. It must be created first
# (presumably a KnowledgeGraph — confirm against the model API).
# Train the model on the input graph, then generate from that same graph.
model.train(input_graph)
output_graph = model.generate(input_graph)
# model.evaluate(...)