In [1]:
from sentence_transformers import SentenceTransformer, util
from typing import List, TypedDict

from core.dataloaders.focus.focus_dataloader import FoCusTestDatasetV1
import torch
import json

# Public FoCus test split; the prediction pipeline further down iterates over it.
test_dataset = FoCusTestDatasetV1(input_dataset_path="./datasets/FoCus/test_focus_public.json")

class FocusKnowledgeKandidateExtractorDictV1(TypedDict):
    """Result of picking a single knowledge candidate for a query.

    predicted_index: position of the chosen sentence in the candidate list
    predicted_knowledge: text of the chosen candidate sentence
    """
    predicted_index: int
    predicted_knowledge: str

class FocusKnowledgeKandidateExtractorV1:
    """Selects the knowledge candidate most similar to the query.

    The query is extended with the persona sentences, the query and all
    candidates are embedded with a SentenceTransformer model, and the
    candidate with the highest cosine similarity to the query wins.
    """

    def __init__(self, model_name: str = 'all-mpnet-base-v2') -> None:
        """Load the embedding model and move it to CUDA when it is available.

        Args:
            model_name: identifier passed to ``SentenceTransformer``.
        """
        self.model_name = model_name
        self.model: SentenceTransformer = SentenceTransformer(model_name)
        target = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(target)
        self.model.to(self.device)

    def extract(
        self,
        persona: List[str],
        query: str,
        knowledge_candidates: List[str],
    ) -> FocusKnowledgeKandidateExtractorDictV1:
        """Return the best-matching knowledge candidate and its index.

        Args:
            persona: persona sentences; joined with spaces and appended to the query.
            query: the user question.
            knowledge_candidates: candidate sentences to rank.

        Returns:
            Dict with ``predicted_index`` (index into ``knowledge_candidates``)
            and ``predicted_knowledge`` (the sentence at that index).
        """
        persona_text = " ".join(persona)
        enriched_query = query + " " + persona_text

        query_embedding = self.model.encode([enriched_query], convert_to_tensor=True)
        candidate_embeddings = self.model.encode(knowledge_candidates, convert_to_tensor=True)

        # Candidates come first, so the top-1 along dim 0 ranks candidates
        # against the single query embedding.
        similarities = util.cos_sim(candidate_embeddings, query_embedding)
        best_rows = similarities.topk(1, dim=0).indices
        best_index = best_rows.flatten().tolist()[0]

        return FocusKnowledgeKandidateExtractorDictV1(
            predicted_index=best_index,
            predicted_knowledge=knowledge_candidates[best_index],
        )


from core.base_models.debertav3_models import DebertaV3PersonaClassificationV3
from transformers import DebertaV2Config
from core.hyperparameters.debertav3_hyperparameters import DebertaV3HyperparametersV1
from core.dataloaders.focus.models.debertav3_dataloaders import DebertaV3FoCusPersonaTestDatasetSampleV1, DebertaV3FoCusPersonaTestDatasetSampleDictV2
from transformers import AutoTokenizer

class FocusPersonaExtractorDictV1(TypedDict):
    """Result of classifying persona sentences for one dialog turn.

    predicted_persona: the persona sentences that were predicted as class 1
    predicted_persona_grounding: predicted class for every persona sentence,
        in the order the sentences were given
    """
    predicted_persona: List[str]
    predicted_persona_grounding: List[int]

class FocusPersonaExtractorV1:
    """Predicts which persona sentences are used in a dialog turn.

    Every persona sentence is combined with the used knowledge and the query,
    converted to model inputs through the configured sample classes, and
    scored by a DeBERTa-based classifier. Sentences predicted as class 1 are
    reported as used.
    """

    def __init__(self,
        model_name: str = 'microsoft/deberta-base',
        sample_class = DebertaV3FoCusPersonaTestDatasetSampleDictV2,
        model_sample_class = DebertaV3FoCusPersonaTestDatasetSampleV1
    ) -> None:
        """Load classifier, tokenizer and hyperparameters.

        Args:
            model_name: checkpoint name or path used for the model, its config
                and the tokenizer.
            sample_class: callable building a raw sample from
                ``persona_sentence``, ``used_knowledge`` and ``query``.
            model_sample_class: callable turning a raw sample into an object
                whose ``get_dict()`` yields the model inputs.
        """
        self.model_name = model_name
        self.model = DebertaV3PersonaClassificationV3.from_pretrained(
            model_name,
            config=DebertaV2Config.from_pretrained(
                model_name,
            ),
        )
        # Inference only: switch off dropout and similar training behaviour.
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # NOTE(review): these limits presumably have to match the settings the
        # checkpoint was trained with - confirm before changing them.
        self.hyperparameters = DebertaV3HyperparametersV1(
            train_batch_size=16,
            valid_batch_size=16,
            max_dialog_history_tokens=70,
            max_knowledge_candidates_tokens=220,
            max_persona_tokens=20,
            model_name=model_name,
            project_name="focus_persona_classification",
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.sample_class = sample_class
        self.model_sample_class = model_sample_class

    def extract(self,
        persona_sentences: List[str],
        used_knowledge: str,
        query: str,
    ) -> FocusPersonaExtractorDictV1:
        """Classify each persona sentence as used (1) or not used (0).

        Args:
            persona_sentences: persona sentences to classify.
            used_knowledge: knowledge sentence paired with every persona sentence.
            query: the user question paired with every persona sentence.

        Returns:
            Dict with ``predicted_persona`` (sentences predicted as class 1) and
            ``predicted_persona_grounding`` (one predicted class per input
            sentence, in input order). Both lists are empty for empty input.
        """
        predictions: List[int] = []
        persona_preds: List[str] = []

        # Pure inference: no autograd graph is needed, so skip building one.
        with torch.no_grad():
            for persona_sentence in persona_sentences:
                sample = self.sample_class(
                    persona_sentence=persona_sentence,
                    used_knowledge=used_knowledge,
                    query=query,
                )
                features = self.model_sample_class(
                    dataset_sample=sample,
                    tokenizer=self.tokenizer,
                    h_params=self.hyperparameters,
                ).get_dict()

                # Add a batch dimension and place the inputs on the same device
                # as the model; otherwise a CUDA model receives CPU tensors.
                model_inputs = {
                    key: torch.tensor(value).unsqueeze(0).to(self.device)
                    for key, value in features.items()
                }

                outputs = self.model(**model_inputs)
                pred = outputs.logits.argmax(dim=1).item()
                if pred == 1:
                    persona_preds.append(persona_sentence)
                predictions.append(pred)

        return FocusPersonaExtractorDictV1(
            predicted_persona=persona_preds,
            predicted_persona_grounding=predictions
        )


class BartFocusTestDatasetDictV1(TypedDict):
    """
    NOTE: a class with the same name is defined again further down in this
    cell; that later definition replaces this one at runtime.

    knowledge: List[str] all the Wikipedia knowledge about the object that we have
    query: str the last question from the user
    dialog_id: str dialog identifier
    predicted_persona_grouding: List[int] predicted persona; an array of 5 elements
        where 1 means the persona sentence is used and 0 means it is not
    predicted_persona: List[str] predicted persona (only the used sentences)
    predicted_knowledge_index: int index of the predicted knowledge
    predicted_knowledge: str the predicted knowledge
    position: int position in the dialog
    """
    knowledge: List[str]
    query: str
    dialog_id: str
    predicted_persona_grouding: List[int]
    predicted_persona: List[str]
    predicted_knowledge_index: int
    predicted_knowledge: str
    position: int


from core.base_models.bart_models import BartLMV7
from core.dataloaders.focus.models.bart_dataloaders import BartRensponseTestDatasetDictV1
from core.dataloaders.focus.models.bart_dataloaders import BartFoCusTestDatasetSampleV1
from core.hyperparameters.bart_hyperparameters import BartHyperparametersV3
from core.tokenizers.bart_tokenizers import BartFoCusTokenizerV2
from transformers import BartConfig
import torch

class ResponseGeneratorV1:
    """Generates a reply with a BART model from persona, query and knowledge."""

    def __init__(self, model_name: str = "./bart_base_2cx77pua") -> None:
        """Load hyperparameters, tokenizer and model for ``model_name``.

        Args:
            model_name: checkpoint name or path of the BART model.
        """
        self.hyperparameters = BartHyperparametersV3(model_name=model_name)

        self.tokenizer = BartFoCusTokenizerV2.from_pretrained(
            self.hyperparameters.model_name,
            hyperparameters=self.hyperparameters,
        )

        bart_config = BartConfig.from_pretrained(self.hyperparameters.model_name)
        self.model = BartLMV7.from_pretrained(
            model_name,
            config=bart_config,  # type: ignore
            hyperparameters=self.hyperparameters,
            tokenizer=self.tokenizer,
        )
        # Inference mode, on GPU when one is available.
        self.model.eval()
        target = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(target)
        self.model.to(self.device)

    def generate_response(self, sample: BartRensponseTestDatasetDictV1) -> str:
        """Generate one reply for ``sample``.

        Args:
            sample: dict-like sample passed to ``BartFoCusTestDatasetSampleV1``.

        Returns:
            The first decoded sequence, with special tokens removed.
        """
        model_sample = BartFoCusTestDatasetSampleV1(
            focus_dataset_sample=sample,
            tokenizer=self.tokenizer,
            h_params=self.hyperparameters,
        )
        features = model_sample.get_dict()

        # Every feature becomes a batch of one on the model's device.
        for name in features:
            features[name] = torch.tensor(features[name]).unsqueeze(0).to(self.device)

        output_ids = self.model.generate(
            input_ids=features["input_ids"],
            attention_mask=features["attention_mask"],
            max_length=100,
        )
        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return decoded[0]

class BartFocusTestDatasetDictV1(TypedDict):
    """
    knowledge: List[str] all the Wikipedia knowledge about the object that we have
    knowledge_candidates: List[str] 1 true example that was used and 9 false ones
    query: str the last question from the user
    dialog_id: str dialog identifier
    predicted_persona_grouding: List[int] predicted persona; an array of 5 elements
        where 1 means the persona sentence is used and 0 means it is not
    predicted_persona: List[str] predicted persona (only the used sentences)
    predicted_knowledge_index: int index of the predicted knowledge
    predicted_knowledge: str the predicted knowledge
    position: int position in the dialog
    response: str the answer to the question
    """
    knowledge: List[str]
    knowledge_candidates: List[str]
    query: str
    dialog_id: str
    predicted_persona_grouding: List[int]
    predicted_persona: List[str]
    predicted_knowledge_index: int
    predicted_knowledge: str
    position: int
    response: str

class BartFocusTestDatasetV1:
    """Runs the full prediction pipeline over a dataset and keeps the results.

    For every sample of the initial dataset the knowledge extractor picks a
    knowledge candidate, the persona extractor classifies the persona
    sentences against that knowledge, and the response generator produces a
    reply. The whole pipeline runs inside ``__init__``.
    """

    def __init__(self,
        initial_dataset: FoCusTestDatasetV1,
        knowledge_kandidate_extractor: FocusKnowledgeKandidateExtractorV1,
        focus_persona_extractor: FocusPersonaExtractorV1,
        response_generator: ResponseGeneratorV1,
    ) -> None:
        """Store the pipeline stages and build the prediction dataset.

        Args:
            initial_dataset: iterable, sized dataset of samples providing
                ``persona``, ``query``, ``knowledge``, ``position``,
                ``knowledge_candidates`` and ``dialog_id``.
            knowledge_kandidate_extractor: selects the knowledge candidate.
            focus_persona_extractor: classifies the persona sentences.
            response_generator: generates the reply.
        """
        self.initial_dataset = initial_dataset
        self.knowledge_extractor = knowledge_kandidate_extractor
        self.persona_extractor = focus_persona_extractor
        self.response_generator = response_generator
        self.dataset: List[BartFocusTestDatasetDictV1] = []

        self.dataset = self.__build_dataset()

    def __build_dataset(self) -> List[BartFocusTestDatasetDictV1]:
        """Run all three stages on every sample and collect the predictions."""
        dataset = []

        for i, sample in enumerate(self.initial_dataset):
            persona = sample["persona"]
            query = sample["query"]
            knowledge = sample["knowledge"]
            position = sample["position"]
            knowledge_candidates = sample["knowledge_candidates"]

            # Stage 1: pick the knowledge candidate closest to the query.
            knowledge_prediction = self.knowledge_extractor.extract(
                persona=persona,
                query=query,
                knowledge_candidates=knowledge_candidates,
            )
            predicted_knowledge_index = knowledge_prediction["predicted_index"]
            predicted_knowledge = knowledge_prediction["predicted_knowledge"]

            # Stage 2: persona grounding is conditioned on the predicted knowledge.
            persona_prediction = self.persona_extractor.extract(
                persona_sentences=persona,
                used_knowledge=predicted_knowledge,
                query=query,
            )
            predicted_persona_grounding = persona_prediction["predicted_persona_grounding"]
            predicted_persona = persona_prediction["predicted_persona"]

            # Stage 3: generate the reply from the predicted persona and knowledge.
            response_sample = BartRensponseTestDatasetDictV1(
                persona=predicted_persona,
                query=query,
                knowledge_candidate=predicted_knowledge,
            )
            bot_response = self.response_generator.generate_response(
                sample=response_sample
            )

            dataset_sample = BartFocusTestDatasetDictV1(
                knowledge=knowledge,
                query=query,
                dialog_id=sample["dialog_id"],
                predicted_persona_grouding=predicted_persona_grounding,
                predicted_persona=predicted_persona,
                predicted_knowledge_index=predicted_knowledge_index,
                predicted_knowledge=predicted_knowledge,
                position=position,
                response=bot_response,
                knowledge_candidates=knowledge_candidates,
            )
            dataset.append(dataset_sample)
            print(f"Progress {i}/{len(self.initial_dataset)}")

        return dataset

    def __len__(self) -> int:
        """Number of predicted samples."""
        return len(self.dataset)

    def __getitem__(self, index: int) -> BartFocusTestDatasetDictV1:
        """Return the predicted sample at ``index``."""
        return self.dataset[index]

    def save_dataset_to_json(self, path: str) -> None:
        """Write all predicted samples to ``path`` as indented JSON.

        Missing parent directories are created first, so a long prediction
        run is not followed by a failed save just because the output folder
        does not exist yet.

        Args:
            path: destination file path.
        """
        from pathlib import Path

        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.dataset, f, indent=2)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Persona classifier loaded from a local checkpoint directory.
persona_model_name = "./results/microsoft/deberta-v3-small/checkpoint-87000/"
persona_extractor = FocusPersonaExtractorV1(model_name=persona_model_name)

# Knowledge selector with its default embedding model.
knowledge_extractor = FocusKnowledgeKandidateExtractorV1()

# Response generator loaded from the local BART checkpoint.
response_extractor = ResponseGeneratorV1(model_name="./bart_base_2cx77pua")

# Constructing the dataset runs all three stages over every test sample.
bart_focus_test_dataset = BartFocusTestDatasetV1(
    initial_dataset=test_dataset,
    focus_persona_extractor=persona_extractor,
    knowledge_kandidate_extractor=knowledge_extractor,
    response_generator=response_extractor,
)
bart_focus_test_dataset.save_dataset_to_json("./predicts/bart_predict_full_dataset.json")


Progress 0/5636
Progress 1/5636
Progress 2/5636
Progress 3/5636
Progress 4/5636
Progress 5/5636
Progress 6/5636
Progress 7/5636
Progress 8/5636
Progress 9/5636
Progress 10/5636
Progress 11/5636
Progress 12/5636
Progress 13/5636
Progress 14/5636
Progress 15/5636
Progress 16/5636
Progress 17/5636
Progress 18/5636
Progress 19/5636
Progress 20/5636
Progress 21/5636
Progress 22/5636
Progress 23/5636
Progress 24/5636
Progress 25/5636
Progress 26/5636
Progress 27/5636
Progress 28/5636
Progress 29/5636
Progress 30/5636
Progress 31/5636
Progress 32/5636
Progress 33/5636
Progress 34/5636
Progress 35/5636
Progress 36/5636
Progress 37/5636
Progress 38/5636
Progress 39/5636
Progress 40/5636
Progress 41/5636
Progress 42/5636
Progress 43/5636
Progress 44/5636
Progress 45/5636
Progress 46/5636
Progress 47/5636
Progress 48/5636
Progress 49/5636
Progress 50/5636
Progress 51/5636
Progress 52/5636
Progress 53/5636
Progress 54/5636
Progress 55/5636
Progress 56/5636
Progress 57/5636
Progress 58/5636
Progres

## Create the persona and knowledge prediction dataset

In [3]:
import json

# Reload the predictions written by the pipeline cell so post-processing works from the file.
with open("./predicts/bart_predict_full_dataset.json", "r") as predictions_file:
    predicts = json.load(predictions_file)

In [6]:
# Excerpt of one record from `predicts` (the long list fields are left out):
# 'query': "I know this place, but I don't remember the name of this place.",
#  'dialog_id': '3G5ZGPU4CUD8',
#  'predicted_persona_grouding': [1, 0, 0, 0, 0],
#  'predicted_persona': ['I would like to visit France.'],
#  'predicted_knowledge_index': 2,
#  'predicted_knowledge': 'The Château de Verteuil is a historic building in Charente, France.',
#  'position': 0,
#  'response': 'This is the Château de Verteuil located in France, which you want to visit.'}

In [5]:
def convert_to_list_of_dicts(_dict):
    """Split a mapping into single-entry dicts, one per key, in key order.

    Example: {"a": 1, "b": 2} -> [{"a": 1}, {"b": 2}]
    """
    single_entry_dicts = []
    for key, value in _dict.items():
        single_entry_dicts.append({key: value})
    return single_entry_dicts

In [45]:
# Group the persona ("pg") and knowledge ("kg") grounding predictions by dialog.
knowledge_persona = {}
for sample in predicts:
    knowledge_persona.setdefault(sample["dialog_id"], []).append({
        "pg": sample["predicted_persona_grouding"],
        "kg": sample["predicted_knowledge_index"],
        "position": sample["position"],
    })

# Order the turns of every dialog by position, then drop the helper key.
for turns in knowledge_persona.values():
    turns.sort(key=lambda turn: turn["position"])
    for turn in turns:
        turn.pop("position", None)

# Final shape: a list with one {dialog_id: [turn, ...]} dict per dialog.
knowledge_persona = convert_to_list_of_dicts(knowledge_persona)

In [46]:
# Number of entries in the grouped list (one entry per dialog_id).
len(knowledge_persona)

1000

In [47]:
# Write the per-dialog persona/knowledge grounding predictions to JSON.
with open("./predicts/bart_predict_full_dataset_persona_knowledge.json", "w") as out_file:
    json.dump(knowledge_persona, out_file, indent=2)

### Build the generated-response dataset

In [48]:
# Group the generated responses by dialog.
predicted_response = {}
for sample in predicts:
    predicted_response.setdefault(sample["dialog_id"], []).append({
        "generation": sample["response"],
        "position": sample["position"],
    })

# Order the turns of every dialog by position, then drop the helper key.
for turns in predicted_response.values():
    turns.sort(key=lambda turn: turn["position"])
    for turn in turns:
        turn.pop("position", None)

# Final shape: a list with one {dialog_id: [turn, ...]} dict per dialog.
predicted_response = convert_to_list_of_dicts(predicted_response)

In [49]:
# Number of entries in the grouped list (one entry per dialog_id).
len(predicted_response)

1000

In [33]:
# Write the per-dialog generated responses to JSON.
with open("./predicts/bart_predict_full_dataset_response.json", "w") as out_file:
    json.dump(predicted_response, out_file, indent=2)

## Predict for the workshop dataset

In [2]:
# Workshop split of FoCus; the same prediction pipeline is run on it below.
workshop_dataset = FoCusTestDatasetV1(input_dataset_path="./datasets/FoCus/focus_workshop_public.json")

In [3]:
# Persona classifier loaded from a local checkpoint directory.
persona_model_name = "./results/microsoft/deberta-v3-small/checkpoint-87000/"
persona_extractor = FocusPersonaExtractorV1(model_name=persona_model_name)

# Knowledge selector with its default embedding model.
knowledge_extractor = FocusKnowledgeKandidateExtractorV1()

# Response generator loaded from the local BART checkpoint.
response_extractor = ResponseGeneratorV1(model_name="./bart_base_2cx77pua")

# Constructing the dataset runs all three stages over every workshop sample.
bart_focus_workshop_dataset = BartFocusTestDatasetV1(
    initial_dataset=workshop_dataset,
    focus_persona_extractor=persona_extractor,
    knowledge_kandidate_extractor=knowledge_extractor,
    response_generator=response_extractor,
)
bart_focus_workshop_dataset.save_dataset_to_json("./predicts/bart_predict_workshop_dataset.json")


Progress 0/5622
Progress 1/5622
Progress 2/5622
Progress 3/5622
Progress 4/5622
Progress 5/5622
Progress 6/5622
Progress 7/5622
Progress 8/5622
Progress 9/5622
Progress 10/5622
Progress 11/5622
Progress 12/5622
Progress 13/5622
Progress 14/5622
Progress 15/5622
Progress 16/5622
Progress 17/5622
Progress 18/5622
Progress 19/5622
Progress 20/5622
Progress 21/5622
Progress 22/5622
Progress 23/5622
Progress 24/5622
Progress 25/5622
Progress 26/5622
Progress 27/5622
Progress 28/5622
Progress 29/5622
Progress 30/5622
Progress 31/5622
Progress 32/5622
Progress 33/5622
Progress 34/5622
Progress 35/5622
Progress 36/5622
Progress 37/5622
Progress 38/5622
Progress 39/5622
Progress 40/5622
Progress 41/5622
Progress 42/5622
Progress 43/5622
Progress 44/5622
Progress 45/5622
Progress 46/5622
Progress 47/5622
Progress 48/5622
Progress 49/5622
Progress 50/5622
Progress 51/5622
Progress 52/5622
Progress 53/5622
Progress 54/5622
Progress 55/5622
Progress 56/5622
Progress 57/5622
Progress 58/5622
Progres

In [4]:
import json

# Reload the workshop predictions written by the pipeline cell above.
with open("./predicts/bart_predict_workshop_dataset.json", "r") as predictions_file:
    predicts = json.load(predictions_file)

In [6]:
# Group the workshop persona ("pg") and knowledge ("kg") grounding predictions by dialog.
knowledge_persona = {}
for sample in predicts:
    knowledge_persona.setdefault(sample["dialog_id"], []).append({
        "pg": sample["predicted_persona_grouding"],
        "kg": sample["predicted_knowledge_index"],
        "position": sample["position"],
    })

# Order the turns of every dialog by position, then drop the helper key.
for turns in knowledge_persona.values():
    turns.sort(key=lambda turn: turn["position"])
    for turn in turns:
        turn.pop("position", None)

# Final shape: a list with one {dialog_id: [turn, ...]} dict per dialog.
knowledge_persona = convert_to_list_of_dicts(knowledge_persona)

In [7]:
# Write the workshop persona/knowledge grounding predictions to JSON.
with open("./predicts/bart_predict_full_dataset_persona_knowledge_workshop.json", "w") as out_file:
    json.dump(knowledge_persona, out_file, indent=2)

### Build the generated-response dataset (workshop)


In [9]:
# Group the workshop generated responses by dialog.
predicted_response = {}
for sample in predicts:
    predicted_response.setdefault(sample["dialog_id"], []).append({
        "generation": sample["response"],
        "position": sample["position"],
    })

# Order the turns of every dialog by position, then drop the helper key.
for turns in predicted_response.values():
    turns.sort(key=lambda turn: turn["position"])
    for turn in turns:
        turn.pop("position", None)

# Final shape: a list with one {dialog_id: [turn, ...]} dict per dialog.
predicted_response = convert_to_list_of_dicts(predicted_response)

In [10]:
# Write the workshop generated responses to JSON.
with open("./predicts/bart_predict_full_dataset_response_workshop.json", "w") as out_file:
    json.dump(predicted_response, out_file, indent=2)