  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from core.dataloaders.focus.focus_dataloader import FoCusTestDatasetV1
from core.inference.inference_scripts import (
    FocusPersonaExtractorV1,
    FocusKnowledgeKandidateExtractorV1,
    ResponseGeneratorV1,
    BartFocusTestDatasetV1,
)

# Build the inference pipeline over the public FoCus test split.
# Every name defined here is module-level and may be reused by later cells.
test_dataset = FoCusTestDatasetV1(
    input_dataset_path="./datasets/FoCus/test_focus_public.json",
)

# Persona extractor restored from a local checkpoint directory
# (the path suggests a fine-tuned microsoft/deberta-v3-small — TODO confirm).
persona_model_name = "./results/microsoft/deberta-v3-small/checkpoint-87000/"
persona_extractor = FocusPersonaExtractorV1(
    model_name=persona_model_name,
)

# Knowledge-candidate extractor is built with its default arguments.
knowledge_extractor = FocusKnowledgeKandidateExtractorV1()
# Response generator loaded from a local model directory
# (presumably a BART-base run, judging by the directory name — TODO confirm).
response_extractor = ResponseGeneratorV1(
    model_name="./bart_base_2cx77pua",
)

# Tie the raw test split and the three components together.
bart_focus_test_dataset = BartFocusTestDatasetV1(
    initial_dataset=test_dataset,
    knowledge_kandidate_extractor=knowledge_extractor,
    focus_persona_extractor=persona_extractor,
    response_generator=response_extractor,
)

## create persona and knowledge prediction dataset

In [3]:
import json

# Load previously generated predictions for the full dataset.
# JSON text is UTF-8; pass the encoding explicitly so decoding does not depend
# on the platform's locale default.
with open("./predicts/bart_predict_full_dataset.json", "r", encoding="utf-8") as f:
    predicts = json.load(f)

In [6]:
from core.inference.inference_scripts import make_submission 

# make_submission(
#     knowledge_persona_save_path=
#     predicts=
#     response_save_path=
# )

### score persona and knowledge prediction on validation dataset

In [1]:
from core.dataloaders.focus.focus_dataloader import FoCusTestDatasetV1
from core.inference.inference_scripts import (
    FocusPersonaExtractorV1,
    FocusKnowledgeKandidateExtractorV1,
    ResponseGeneratorV1,
    BartFocusTestDatasetV1,
    BartFocusTestDatasetDictV1,
    BartRensponseTestDatasetDictV1, 
    
)
from core.utils import TextEvaluator
from core.dataloaders.focus.focus_dataloader import FoCusDatasetV1
from typing import List, TypedDict, Optional
import numpy as np
import json

class BartFocusValidDatasetDictV1(TypedDict):
    """
    One evaluated validation sample: inputs, predictions, scores and gold labels.

    knowledge: List[str] all the Wikipedia knowledge about the object that we have
    knowledge_candidates: List[str] 1 true candidate (the one that was used) and 9 false ones
    query: str the last question from the user
    predicted_persona_grouding: Optional[List[int]] predicted persona grounding.
        An array of 5 elements, where 1 - persona sentence is used, 0 - not used.
        None when it was not predicted
    predicted_persona: Optional[List[str]] predicted persona (only the used sentences).
        None when it was not predicted
    predicted_knowledge_index: Optional[int] index of the predicted knowledge candidate.
        None when it was not predicted
    predicted_knowledge_candidate: Optional[str] text of the predicted knowledge candidate.
        None when it was not predicted
    response: str generated answer to the question
    blue_score: float BLEU score of the answer
    rougeL_score: float ROUGE-L score of the answer
    charF_score: float chrF score of the answer
    persona_grouding: List[int] gold persona grounding. An array of 5 elements,
        where 1 - persona sentence is used, 0 - not used
    persona: List[str] gold persona
    knowledge_index: int gold knowledge index
    gold_response: str gold answer to the question
    """

    knowledge: List[str]
    knowledge_candidates: List[str]
    query: str
    predicted_persona_grouding: Optional[List[int]]
    predicted_persona: Optional[List[str]]
    # The evaluator stores None in the two fields below when knowledge is not
    # predicted (or the index is not available), so they must be Optional.
    predicted_knowledge_index: Optional[int]
    predicted_knowledge_candidate: Optional[str]
    response: str
    blue_score: float
    rougeL_score: float
    charF_score: float
    persona_grouding: List[int]
    knowledge_index: int
    persona: List[str]
    gold_response: str



class BartFocusValidDatasetV1:
    """Evaluate response generation on a labelled FoCus split.

    For every sample of ``initial_dataset`` the class builds a response-generator
    input from either gold or predicted persona/knowledge, generates a response,
    scores it against the gold response with ``text_evaluator`` and stores the
    per-sample record in ``self.dataset``.
    """

    def __init__(
        self,
        initial_dataset: FoCusDatasetV1,
        knowledge_kandidate_extractor: FocusKnowledgeKandidateExtractorV1,
        focus_persona_extractor: FocusPersonaExtractorV1,
        response_generator: ResponseGeneratorV1,
        text_evaluator: TextEvaluator,
    ) -> None:
        """Store the source dataset and the pipeline components.

        Args:
            initial_dataset: labelled samples to evaluate on.
            knowledge_kandidate_extractor: object whose ``extract`` returns a
                mapping with a ``"predicted_knowledge"`` entry.
            focus_persona_extractor: object whose ``extract`` returns a mapping
                with a ``"predicted_persona"`` entry.
            response_generator: object providing ``generate_response``.
            text_evaluator: object whose ``evaluate`` returns the text metrics
                for generated/original text lists.
        """
        self.initial_dataset = initial_dataset
        self.knowledge_extractor = knowledge_kandidate_extractor
        self.persona_extractor = focus_persona_extractor
        self.response_generator = response_generator
        self.text_evaluator = text_evaluator
        # Per-sample records produced by the most recent evaluate() call.
        self.dataset: List[BartFocusValidDatasetDictV1] = []
        # Per-sample metric values of the most recent evaluate() call.
        self.text_metrics = self._empty_text_metrics()

    @staticmethod
    def _empty_text_metrics() -> dict:
        """Return a fresh metric accumulator (one empty list per metric)."""
        return {
            "blue_score": [],
            "rougeL_score": [],
            "chrf_score": [],
        }

    def evaluate(
        self,
        gold_persona: bool = False,
        gold_knowledge: bool = False,
    ) -> None:
        """Generate and score a response for every sample of the dataset.

        Args:
            gold_persona: feed the gold (actually used) persona sentences to the
                response generator instead of the predicted ones.
            gold_knowledge: feed the gold knowledge candidate to the response
                generator instead of the predicted one.

        Side effects:
            Replaces ``self.dataset`` and ``self.text_metrics`` with the results
            of this run and prints progress and averaged scores.
        """
        # Start from a clean accumulator: the same object is evaluated several
        # times with different flags, and the averages must not mix runs.
        self.text_metrics = self._empty_text_metrics()
        dataset: List[BartFocusValidDatasetDictV1] = []

        print("Start evaluating")
        print(f"Use gold persona: {gold_persona}")
        print(f"Use gold knowledge: {gold_knowledge}")

        total = len(self.initial_dataset)

        for i, sample in enumerate(self.initial_dataset):  # type: ignore
            knowledge = sample["knowledge"]
            persona_sentences = sample["persona"]
            persona_grounding = sample["persona_grounding"]
            # Gold persona: only the sentences flagged as used.
            used_persona = [
                sent
                for sent, used in zip(persona_sentences, persona_grounding)
                if used == 1
            ]
            knowledge_candidates = sample["knowledge_candidates"]
            knowledge_answer_index = sample["knowledge_answer_index"]
            used_knowledge = knowledge_candidates[knowledge_answer_index]
            # The last two dialog entries are read as the query and its gold reply.
            query = sample["dialog"][-2]
            user_response = sample["dialog"][-1]

            # NOTE(review): only the "predicted_knowledge" / "predicted_persona"
            # entries of the extractor outputs are read, so the predicted index
            # and grounding are always stored as None.
            predicted_persona_grounding = None
            predicted_knowledge_index = None
            predicted_persona = None
            predicted_knowledge = None

            # Knowledge is predicted unless both inputs are gold: the persona
            # extractor below consumes the predicted knowledge.
            if not (gold_persona and gold_knowledge):
                knowledge_prediction = self.knowledge_extractor.extract(
                    persona=persona_sentences,
                    query=query,
                    knowledge_candidates=knowledge_candidates,
                )
                predicted_knowledge = knowledge_prediction["predicted_knowledge"]

            if not gold_persona:
                # NOTE(review): the persona extractor is conditioned on the
                # *predicted* knowledge even when gold_knowledge=True; kept as
                # is — confirm whether gold knowledge was intended here.
                persona_prediction = self.persona_extractor.extract(
                    persona_sentences=persona_sentences,
                    used_knowledge=predicted_knowledge,
                    query=query,
                )
                predicted_persona = persona_prediction["predicted_persona"]

            response_sample = BartRensponseTestDatasetDictV1(
                persona=used_persona if gold_persona else predicted_persona,
                knowledge_candidate=(
                    used_knowledge if gold_knowledge else predicted_knowledge
                ),
                query=query,
            )

            bot_response = self.response_generator.generate_response(
                res_sample=response_sample,
            )

            text_metrics = self.text_evaluator.evaluate(
                generated_texts=[bot_response],
                original_texts=[user_response],
            )

            # setdefault keeps a long run alive if the evaluator reports a
            # metric that was not pre-registered in the accumulator.
            for key, value in text_metrics.items():
                self.text_metrics.setdefault(key, []).append(value)

            dataset_sample = BartFocusValidDatasetDictV1(
                knowledge=knowledge,
                knowledge_candidates=knowledge_candidates,
                query=query,
                predicted_persona_grouding=predicted_persona_grounding,
                predicted_persona=predicted_persona,
                predicted_knowledge_index=predicted_knowledge_index,
                predicted_knowledge_candidate=predicted_knowledge,
                response=bot_response,
                gold_response=user_response,
                blue_score=text_metrics["blue_score"],
                rougeL_score=text_metrics["rougeL_score"],
                charF_score=text_metrics["chrf_score"],
                persona_grouding=persona_grounding,
                knowledge_index=knowledge_answer_index,
                persona=used_persona,
            )
            dataset.append(dataset_sample)
            print(f"Progress {i}/{total}")

        # Publish the records only after the whole run has finished.
        self.dataset = dataset

        print("Finish evaluating")
        print(f"Use gold persona: {gold_persona}")
        print(f"Use gold knowledge: {gold_knowledge}")
        print(f"Averaged blue score: {np.mean(self.text_metrics['blue_score'])}")
        print(f"Averaged rougeL score: {np.mean(self.text_metrics['rougeL_score'])}")
        print(f"Averaged charF score: {np.mean(self.text_metrics['chrf_score'])}")

    def __len__(self) -> int:
        """Return the number of records produced by the last evaluate() call."""
        return len(self.dataset)

    def __getitem__(self, index: int) -> BartFocusValidDatasetDictV1:
        """Return the evaluated record at ``index``."""
        return self.dataset[index]

    def save_dataset_to_json(self, path: str) -> None:
        """Write the evaluated records to ``path`` as indented JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.dataset, f, indent=2)


from core.dataloaders.focus.focus_dataloader import FoCusTestDatasetV1
from core.inference.inference_scripts import (
    FocusPersonaExtractorV1,
    FocusKnowledgeKandidateExtractorV1,
    ResponseGeneratorV1,
)

from core.utils import TextEvaluator

# Build the evaluation pipeline over the labelled FoCus validation split.
# Persona extractor restored from a local checkpoint directory
# (the path suggests a fine-tuned microsoft/deberta-v3-small — TODO confirm).
persona_model_name = "./results/microsoft/deberta-v3-small/checkpoint-87000/"
persona_extractor = FocusPersonaExtractorV1(
    model_name=persona_model_name,
)

# Knowledge-candidate extractor is built with its default arguments.
knowledge_extractor = FocusKnowledgeKandidateExtractorV1()
# Response generator loaded from a local model directory
# (presumably a BART-base run, judging by the directory name — TODO confirm).
response_extractor = ResponseGeneratorV1(
    model_name="./bart_base_2cx77pua",
)

# Validation split (with gold labels) and the text-metric evaluator.
valid_dataset = FoCusDatasetV1(
    input_dataset_path="./datasets/FoCus/valid_focus.json",
)
text_evaluator = TextEvaluator()

# Evaluation wrapper reused by the cells below with different gold/predicted flags.
bart_focus_valid_dataset = BartFocusValidDatasetV1(
    initial_dataset=valid_dataset,
    knowledge_kandidate_extractor=knowledge_extractor,
    focus_persona_extractor=persona_extractor,
    response_generator=response_extractor,
    text_evaluator=text_evaluator,
)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Configuration 1: gold persona + gold knowledge fed to the response generator.
# The per-sample records are then dumped to JSON.
bart_focus_valid_dataset.evaluate(
    gold_persona=True,
    gold_knowledge=True,
)
bart_focus_valid_dataset.save_dataset_to_json("./evaluation/gold_persona_gold_knowledge_bart_focus_valid_dataset.json")
# Scores noted down for this configuration:
# Use gold persona: True
# Use gold knowledge: True
# Averaged blue score: 29.40638251464509
# Averaged rougeL score: 0.5165242114426458
# Averaged charF score: 0.5096153992286768

In [2]:
# Configuration 2: predicted persona + gold knowledge fed to the response generator.
# The per-sample records are then dumped to JSON.
bart_focus_valid_dataset.evaluate(
    gold_persona=False,
    gold_knowledge=True,
)
bart_focus_valid_dataset.save_dataset_to_json("./evaluation/extracted_persona_gold_knowledge_bart_focus_valid_dataset.json")
# Scores noted down for this configuration:
# Use gold persona: False
# Use gold knowledge: True
# Averaged blue score: 26.57767439164382
# Averaged rougeL score: 0.4867563729015985
# Averaged charF score: 0.4815876512941998

Start evaluating
Use gold persona: False
Use gold knowledge: True


  total_n_grams[n] = tensor(sum(n_grams_counts[n].values()))
  matching_n_grams[n] = tensor(


Progress 0/5639
Progress 1/5639
Progress 2/5639
Progress 3/5639
Progress 4/5639
Progress 5/5639
Progress 6/5639
Progress 7/5639
Progress 8/5639
Progress 9/5639
Progress 10/5639
Progress 11/5639
Progress 12/5639
Progress 13/5639
Progress 14/5639
Progress 15/5639
Progress 16/5639
Progress 17/5639
Progress 18/5639
Progress 19/5639
Progress 20/5639
Progress 21/5639
Progress 22/5639
Progress 23/5639
Progress 24/5639
Progress 25/5639
Progress 26/5639
Progress 27/5639
Progress 28/5639
Progress 29/5639
Progress 30/5639
Progress 31/5639
Progress 32/5639
Progress 33/5639
Progress 34/5639
Progress 35/5639
Progress 36/5639
Progress 37/5639
Progress 38/5639
Progress 39/5639
Progress 40/5639
Progress 41/5639
Progress 42/5639
Progress 43/5639
Progress 44/5639
Progress 45/5639
Progress 46/5639
Progress 47/5639
Progress 48/5639
Progress 49/5639
Progress 50/5639
Progress 51/5639
Progress 52/5639
Progress 53/5639
Progress 54/5639
Progress 55/5639
Progress 56/5639
Progress 57/5639
Progress 58/5639
Progres

In [3]:
# Configuration 3: gold persona + predicted knowledge fed to the response generator.
# The per-sample records are then dumped to JSON.
bart_focus_valid_dataset.evaluate(
    gold_persona=True,
    gold_knowledge=False,
)
bart_focus_valid_dataset.save_dataset_to_json("./evaluation/gold_persona_extracted_knowledge_bart_focus_valid_dataset.json")
# Scores noted down for this configuration:
# Use gold persona: True
# Use gold knowledge: False
# Averaged blue score: 27.32761553556316
# Averaged rougeL score: 0.49319772668931866
# Averaged charF score: 0.48781404757945224

Start evaluating
Use gold persona: True
Use gold knowledge: False
Progress 0/5639
Progress 1/5639
Progress 2/5639
Progress 3/5639
Progress 4/5639
Progress 5/5639
Progress 6/5639
Progress 7/5639
Progress 8/5639
Progress 9/5639
Progress 10/5639
Progress 11/5639
Progress 12/5639
Progress 13/5639
Progress 14/5639
Progress 15/5639
Progress 16/5639
Progress 17/5639
Progress 18/5639
Progress 19/5639
Progress 20/5639
Progress 21/5639
Progress 22/5639
Progress 23/5639
Progress 24/5639
Progress 25/5639
Progress 26/5639
Progress 27/5639
Progress 28/5639
Progress 29/5639
Progress 30/5639
Progress 31/5639
Progress 32/5639
Progress 33/5639
Progress 34/5639
Progress 35/5639
Progress 36/5639
Progress 37/5639
Progress 38/5639
Progress 39/5639
Progress 40/5639
Progress 41/5639
Progress 42/5639
Progress 43/5639
Progress 44/5639
Progress 45/5639
Progress 46/5639
Progress 47/5639
Progress 48/5639
Progress 49/5639
Progress 50/5639
Progress 51/5639
Progress 52/5639
Progress 53/5639
Progress 54/5639
Progress 

In [4]:
# Configuration 4: predicted persona + predicted knowledge (fully automatic pipeline).
# The per-sample records are then dumped to JSON; no scores were noted down here.
bart_focus_valid_dataset.evaluate(
    gold_persona=False,
    gold_knowledge=False,
)
bart_focus_valid_dataset.save_dataset_to_json("./evaluation/extracted_persona_extracted_knowledge_bart_focus_valid_dataset.json")

Start evaluating
Use gold persona: False
Use gold knowledge: False
Progress 0/5639
Progress 1/5639
Progress 2/5639
Progress 3/5639
Progress 4/5639
Progress 5/5639
Progress 6/5639
Progress 7/5639
Progress 8/5639
Progress 9/5639
Progress 10/5639
Progress 11/5639
Progress 12/5639
Progress 13/5639
Progress 14/5639
Progress 15/5639
Progress 16/5639
Progress 17/5639
Progress 18/5639
Progress 19/5639
Progress 20/5639
Progress 21/5639
Progress 22/5639
Progress 23/5639
Progress 24/5639
Progress 25/5639
Progress 26/5639
Progress 27/5639
Progress 28/5639
Progress 29/5639
Progress 30/5639
Progress 31/5639
Progress 32/5639
Progress 33/5639
Progress 34/5639
Progress 35/5639
Progress 36/5639
Progress 37/5639
Progress 38/5639
Progress 39/5639
Progress 40/5639
Progress 41/5639
Progress 42/5639
Progress 43/5639
Progress 44/5639
Progress 45/5639
Progress 46/5639
Progress 47/5639
Progress 48/5639
Progress 49/5639
Progress 50/5639
Progress 51/5639
Progress 52/5639
Progress 53/5639
Progress 54/5639
Progress