In [18]:
# Scratch check: str() of a nested int list is a plain, deterministic string.
# Presumably this verifies it can serve as a cache key (FoCusTF_IDF builds its
# cache key from str(query)) — TODO confirm intent.
str([[123, 123], [1, 2]])

'[[123, 123], [1, 2]]'

In [19]:
from torch.utils.data import Dataset
from typing import List, Dict, TypedDict
import json
from transformers import AutoTokenizer, BartTokenizer

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from itertools import chain

class TF_IDF:
	"""Retrieve the corpus sentences most similar to a query via TF-IDF cosine similarity.

	Sentences are lists of token ids. Each id is treated as one "word" by joining
	the ids with spaces before vectorizing.

	Example usage:
		corpus = [
			[1, 2, 3, 4],
			[1, 2, 3],
			[1, 2],
		]

		tf_idf = TF_IDF(corpus)

		similar_sentences = tf_idf.top_similar([[1, 2, 3]], top_k=3)
		# similar_sentences == [[1, 2, 3], [1, 2, 3, 4], [1, 2]]
	"""
	def __init__(self,
			corpus: List[List[int]],
		) -> None:
		"""
		Args:
			corpus (List[List[int]]): tokenized corpus (one list of token ids per sentence).

		Raises:
			ValueError: raised by TfidfVectorizer when the corpus yields an empty
				vocabulary (e.g. an empty corpus).
		"""
		self.vectorizer = TfidfVectorizer(
			# every run of digits (a token id) is a term; the default pattern would
			# drop single-character terms such as the ids 0-9
			token_pattern=r"(?u)\b\d+\b",
		)
		new_corpus = self.__encode_sentences(corpus)

		self.X = self.vectorizer.fit_transform(new_corpus)
		self.corpus = corpus

	def __encode_sentence(self, sentence: List[int]) -> str:
		"""Render a list of token ids as a space-separated string."""
		return " ".join(map(str, sentence))

	def __encode_sentences(self, sentences: List[List[int]]) -> List[str]:
		"""Render every sentence with ``__encode_sentence``."""
		return [self.__encode_sentence(sentence) for sentence in sentences]

	def top_similar(self,
			query: List[List[int]] = None,
			top_k: int = 1,
		) -> List[List[int]]:
		"""Return the ``top_k`` corpus sentences most similar to ``query``.

		Args:
			query: one or more tokenized sentences. With several sentences, each
				corpus sentence is scored by its best (maximum) cosine similarity
				to any of them.
			top_k: number of corpus sentences to return.

		Returns:
			Corpus sentences ordered from most to least similar. Empty when
			``query`` is None/empty or ``top_k`` is not positive.
		"""
		if query is None or len(query) == 0 or top_k <= 0:
			return []

		encoded_query = self.vectorizer.transform(self.__encode_sentences(query))

		# cosine_similarity returns shape (n_corpus, n_query). Reduce over the query
		# axis so each score index maps back to exactly one corpus sentence;
		# flattening would mix corpus and query indices when n_query > 1.
		similarity = cosine_similarity(self.X, encoded_query).max(axis=1)

		# reversed ascending argsort -> indices in descending similarity order
		top_indices = np.argsort(similarity)[::-1][:top_k]
		return [self.corpus[i] for i in top_indices.tolist()]

class FoCusTF_IDF(TF_IDF):
	"""TF_IDF that memoizes ``top_similar`` results per (query, top_k) pair.

	The cache lives on the instance and is never evicted, so it grows with the
	number of distinct (query, top_k) pairs this object has seen.
	"""
	def __init__(self,
		*args,
		**kwargs,
	) -> None:
		# accept ``corpus`` positionally as well as by keyword
		super().__init__(*args, **kwargs)

		# (str(query), top_k) -> list of similar corpus sentences
		self.cached_similar = {}

	def top_similar(self,
			query: List[List[int]] = None,
			top_k: int = 1
		) -> List[List[int]]:
		"""Cached version of ``TF_IDF.top_similar``.

		Args:
			query: tokenized query sentence(s).
			top_k: number of corpus sentences to return.

		Returns:
			A new list holding the most similar corpus sentences.
		"""
		# str() of a nested int list is deterministic, so it is a hashable stand-in
		# for the unhashable query. top_k must be part of the key, otherwise a call
		# with a different top_k would get a result of the wrong size.
		cache_key = (str(query), top_k)

		if cache_key not in self.cached_similar:
			self.cached_similar[cache_key] = super().top_similar(
				query=query,
				top_k=top_k,
			)

		# hand out a copy so callers mutating the result cannot corrupt the cache
		return list(self.cached_similar[cache_key])
class FoCusDatasetSampleDictV1(TypedDict):
	"""Plain-dict form of one FoCus utterance sample (as produced by FoCusDatasetSampleV1.get_dict)."""
	# persona sentences of the dialog
	persona: List[str]
	# knowledge candidate sentences attached to this utterance
	knowledge_candidates: List[str]
	# persona grounding values, converted to int
	persona_grounding: List[int]
	# dialog utterances as text (they are tokenized as strings downstream)
	dialog: List[str]
	# presumably the index of the answer within knowledge_candidates — TODO confirm
	knowledge_answer_index: int
	# knowledge sentences of the whole dialog
	knowledge: List[str]

class FoCusDatasetSampleV1:
	"""One utterance-level FoCus sample; ``get_dict`` exports it as a plain dict."""

	# The slot order is also the key order of the dict returned by get_dict().
	__slots__ = (
		'persona',
		'knowledge_candidates',
		'persona_grounding',
		'dialog',
		'knowledge_answer_index',
		"knowledge"
	)

	def __init__(self,
			persona: List[str],
			knowledge_candidates: List[str],
			persona_grounding: List[int],
			dialog: List[str],
			knowledge: List[str],
			knowledge_answer_index: int,
		) -> None:
		# values listed in exactly the same order as __slots__
		field_values = (
			persona,
			knowledge_candidates,
			persona_grounding,
			dialog,
			knowledge_answer_index,
			knowledge,
		)
		for field_name, field_value in zip(self.__slots__, field_values):
			setattr(self, field_name, field_value)

	def get_dict(self) -> FoCusDatasetSampleDictV1:
		"""Return every field as a dict keyed by its slot name."""
		return {field_name: getattr(self, field_name) for field_name in self.__slots__}

class BartFoCusDatasetSampleHyperparametersV1:
	def __init__(self,
			dialog_history_length: int = 1,
			context_length: int = 1,
			knowledge_length: int = 1,
			max_persona_tokens: int = 200,
			max_dialog_history_tokens: int = 200,
			max_knowledge_tokens: int = 200,
			max_bot_response_tokens: int = 150,
			dialog_bos_token: str = '<dialog>',
			dialog_eos_token: str = '</dialog>',
		) -> None:
		"""
		Args:
			dialog_history_length (int): number of dialog pairs (counting back from
				the end) used to generate the response. Must be >= 1.
			context_length (int): number of dialog sentences used as the query when
				selecting similar sentences from the knowledge field. Must be >= 1.
			knowledge_length (int): number of knowledge sentences fed to the model.
				Must be >= 1.
			max_persona_tokens (int): token budget for the persona segment.
			max_dialog_history_tokens (int): token budget for the dialog history segment.
			max_knowledge_tokens (int): token budget for the knowledge segment.
			max_bot_response_tokens (int): token budget for the bot response segment.
			dialog_bos_token (str): special token opening the bot response.
			dialog_eos_token (str): special token closing the bot response.

		Raises:
			ValueError: if any of the three length arguments is < 1. A value of 0
				would turn the `x[-n:]` style slices into "take everything".
		"""
		for name, value in (
			('dialog_history_length', dialog_history_length),
			('context_length', context_length),
			('knowledge_length', knowledge_length),
		):
			if value < 1:
				raise ValueError(f'{name} must be >= 1, got {value}')

		# previously these three were hard-coded to 1, silently ignoring the arguments
		self.dialog_history_length = dialog_history_length
		self.context_length = context_length
		self.knowledge_length = knowledge_length

		self.max_persona_tokens = max_persona_tokens
		self.max_dialog_history_tokens = max_dialog_history_tokens
		self.max_knowledge_tokens = max_knowledge_tokens
		self.max_bot_response_tokens = max_bot_response_tokens

		self.dialog_bos_token = dialog_bos_token
		self.dialog_eos_token = dialog_eos_token

class BartFoCusTokenizerV1(BartTokenizer):
	"""BartTokenizer that can register the FoCus dialog special tokens when loaded."""
	def __init__(self,
			*args,
			**kwargs
		) -> None:
		# forward positional arguments too (e.g. vocab/merges files); dropping
		# them made direct construction fail
		super().__init__(*args, **kwargs)

	@classmethod
	def from_pretrained(cls,
			*args,
			hyperparameters: BartFoCusDatasetSampleHyperparametersV1 = None,
			**kwargs
		):
		"""Load a pretrained tokenizer and optionally add the dialog special tokens.

		Args:
			*args, **kwargs: forwarded to ``BartTokenizer.from_pretrained``.
			hyperparameters: if given, its ``dialog_bos_token`` and
				``dialog_eos_token`` are registered as additional special tokens.

		Returns:
			The loaded tokenizer, an instance of ``cls`` (still a BartTokenizer).

		NOTE(review): adding tokens grows the vocabulary; a model used with this
		tokenizer presumably needs ``resize_token_embeddings(len(tokenizer))``.
		"""
		# go through super() so the result is an instance of this class rather
		# than a plain BartTokenizer
		tokenizer = super().from_pretrained(*args, **kwargs)

		if hyperparameters is not None:
			tokens = [
				hyperparameters.dialog_bos_token,
				hyperparameters.dialog_eos_token,
			]

			tokenizer.add_special_tokens({'additional_special_tokens': tokens})

		return tokenizer

# Model-ready features of one BART training sample: the token ids and the
# matching attention mask.
BartFoCusDatasetSampleDictV1 = TypedDict(
	'BartFoCusDatasetSampleDictV1',
	{
		'input_ids': List[int],
		'attention_mask': List[int],
	},
)

class BartFoCusDatasetSampleV1:
	"""
	Builds one BART input sequence from a FoCus sample:

	[BOS][persona][SEP][knowledge][SEP][dialog][:-1][SEP]<dialog>[dialog][-1]</dialog>
	- [dialog] - the selected dialog pairs
	- persona - all persona sentences
	- knowledge - the knowledge sentences most similar to the dialog context
	- [dialog][:-1] - all selected dialog utterances except the bot response
	- <dialog>[dialog][-1]</dialog> - the bot response
	"""
	def __init__(self,
			focus_dataset_sample: FoCusDatasetSampleDictV1,
			tokenizer: BartFoCusTokenizerV1,
			h_params: BartFoCusDatasetSampleHyperparametersV1,
		) -> None:
		self.focus_dataset_sample = focus_dataset_sample
		self.tokenizer = tokenizer
		self.h_params = h_params

		self.bos_token_id = self.tokenizer.bos_token_id
		self.pad_token_id = self.tokenizer.pad_token_id
		self.unk_token_id = self.tokenizer.unk_token_id
		self.sep_token_id = self.tokenizer.sep_token_id
		self.cls_token_id = self.tokenizer.cls_token_id

		# NOTE(review): tokens missing from the vocabulary presumably map to the unk
		# id here, so the tokenizer should already contain these special tokens
		# (e.g. loaded via BartFoCusTokenizerV1.from_pretrained with hyperparameters).
		self.dialog_bos = self.__get_token_id(h_params.dialog_bos_token)
		self.dialog_eos = self.__get_token_id(h_params.dialog_eos_token)

	def __get_token_id(self, token: str) -> int:
		"""Look up the vocabulary id of a single token."""
		return self.tokenizer.convert_tokens_to_ids(token)

	def __flat_list(self, list_of_lists: List[List]) -> List:
		"""Concatenate a list of lists into one list."""
		return list(chain.from_iterable(list_of_lists))

	def __encode(self, texts: List[str]) -> List[List[int]]:
		"""Tokenize each text without BOS/EOS tokens; one id list per text."""
		# batch-encoding an empty batch fails in the tokenizer, so short-circuit it
		if not texts:
			return []
		return self.tokenizer.batch_encode_plus(
			texts,
			add_special_tokens=False
		)['input_ids']

	def __keep_last(self, tokens: List[int], limit: int) -> List[int]:
		"""Return at most ``limit`` tokens, keeping the end of ``tokens``."""
		return tokens[max(len(tokens) - limit, 0):]

	def __select_knowledge(self,
			knowledge: List[str],
			query_context: List[List[int]],
			top_k: int,
		) -> List[List[int]]:
		"""Return the ``top_k`` tokenized knowledge sentences most TF-IDF-similar to the query."""
		# fitting TF-IDF on an empty corpus raises, and an empty query has nothing to match
		if not knowledge or not query_context:
			return []

		# the knowledge comes from the sample's dialog, so the index is fitted per sample
		tf_idf = FoCusTF_IDF(corpus=self.__encode(knowledge))
		return tf_idf.top_similar(
			query=query_context,
			top_k=top_k,
		)

	def get_dict(self) -> BartFoCusDatasetSampleDictV1:
		"""Build the model inputs for this sample.

		Returns:
			dict with ``input_ids`` (the sequence described in the class docstring)
			and ``attention_mask`` (all ones, same length as ``input_ids``).
		"""
		h_params = self.h_params

		persona = self.focus_dataset_sample['persona']
		dialog = self.focus_dataset_sample['dialog']
		knowledge = self.focus_dataset_sample['knowledge']

		encoded_persona = self.__encode(persona)

		# the last `dialog_history_length` pairs; the final utterance is the bot response
		dialog_history = dialog[-2 * h_params.dialog_history_length:]
		encoded_history = self.__encode(dialog_history[:-1])
		encoded_response = self.__encode(dialog_history[-1:])

		# context from which knowledge is selected: the last `context_length` utterances
		query_context = encoded_history[-h_params.context_length:]
		most_similar_knowledge = self.__select_knowledge(
			knowledge=knowledge,
			query_context=query_context,
			# previously unused: top_k silently defaulted to 1
			top_k=h_params.knowledge_length,
		)

		flat_persona = self.__flat_list(encoded_persona)[:h_params.max_persona_tokens]
		flat_knowledge = self.__flat_list(most_similar_knowledge)[:h_params.max_knowledge_tokens]
		# keep the most recent history tokens: they sit right before the response
		flat_dialog_history = self.__keep_last(
			self.__flat_list(encoded_history),
			h_params.max_dialog_history_tokens,
		)
		flat_bot_response = self.__flat_list(encoded_response)[:h_params.max_bot_response_tokens]

		# [BOS][persona][SEP][knowledge][SEP][dialog][:-1][SEP]<dialog>[dialog][-1]</dialog>
		input_sequence = [
			self.bos_token_id,
			*flat_persona,
			self.sep_token_id,
			*flat_knowledge,
			self.sep_token_id,
			*flat_dialog_history,
			self.sep_token_id,
			self.dialog_bos,
			*flat_bot_response,
			self.dialog_eos
		]

		attention_mask = [1] * len(input_sequence)

		return {
			'input_ids': input_sequence,
			'attention_mask': attention_mask,
		}
		

class FoCusDatasetV1:
	"""Flattens a FoCus JSON file into one sample dict per utterance."""
	def __init__(self,
			input_dataset_path: str = None,
		) -> None:
		"""
		Args:
			input_dataset_path: path to the FoCus JSON file (required).

		Raises:
			ValueError: if ``input_dataset_path`` is None.
		"""
		# explicit check instead of `assert`, which is stripped under `python -O`
		if input_dataset_path is None:
			raise ValueError('input_dataset_path is None')

		self.input_dataset_path: str = input_dataset_path
		self.dataset: List[FoCusDatasetSampleDictV1] = []

		self.__build_dataset()

	def __build_dataset(self) -> None:
		"""Read the JSON file and populate ``self.dataset``."""
		initial_dataset = self.__read_dataset(self.input_dataset_path)
		self.dataset = self.__create_initial_dataset(initial_dataset=initial_dataset)

	def __create_initial_dataset(self, initial_dataset: Dict = None) -> List[FoCusDatasetSampleDictV1]:
		"""Emit one sample per utterance of every dialog under ``initial_dataset['data']``.

		Raises:
			KeyError: if an utterance has no key containing 'dialog'.
		"""
		dataset = []

		for dialog_set in initial_dataset['data']:
			# persona and knowledge are shared by every utterance of the dialog
			persona = dialog_set['persona']
			knowledge = dialog_set['knowledge']

			for utterance in dialog_set['utterance']:
				persona_grounding = list(map(int, utterance['persona_grounding']))
				knowledge_candidates = utterance['knowledge_candidates']
				knowledge_answer_index = utterance['knowledge_answer_index']

				# the key holding the dialog turns is found by substring; its exact
				# name presumably varies between utterances — TODO confirm
				dialog_key = next((key for key in utterance if 'dialog' in key), None)
				if dialog_key is None:
					raise KeyError(f"utterance has no 'dialog' key: {sorted(utterance)}")
				dialog = utterance[dialog_key]

				data_sample = FoCusDatasetSampleV1(
					persona=persona,
					knowledge_candidates=knowledge_candidates,
					persona_grounding=persona_grounding,
					dialog=dialog,
					knowledge_answer_index=knowledge_answer_index,
					knowledge=knowledge,
				)
				dataset.append(data_sample.get_dict())

		return dataset

	def __read_dataset(self, input_path: str) -> Dict:
		"""Load the JSON document at ``input_path`` (a dict with a 'data' list)."""
		# JSON is UTF-8; don't depend on the platform's default encoding
		with open(input_path, 'r', encoding='utf-8') as f:
			dataset = json.load(f)
		return dataset

	def __len__(self) -> int:
		return len(self.dataset)

	def __getitem__(self, index: int) -> FoCusDatasetSampleDictV1:
		return self.dataset[index]

class PytorchFoCusDatasetV1(Dataset):
	"""torch Dataset over FoCusDatasetV1 that yields BART sample builders."""
	def __init__(self,
			dataset: FoCusDatasetV1,
			model_name: str = 'facebook/bart-large',
			hyperparameters: BartFoCusDatasetSampleHyperparametersV1 = None,
		) -> None:
		"""
		Args:
			dataset: the flattened FoCus dataset.
			model_name: name or path of the pretrained BART tokenizer to load.
			hyperparameters: sample-building hyperparameters; defaults are used when None.
		"""
		self.dataset = dataset
		if hyperparameters is None:
			hyperparameters = BartFoCusDatasetSampleHyperparametersV1()
		self.bart_hyperparameters = hyperparameters
		# loaded with the hyperparameters so the dialog special tokens are registered
		self.bart_tokenizer = BartFoCusTokenizerV1.from_pretrained(
			model_name,
			hyperparameters=self.bart_hyperparameters
		)

	def __len__(self) -> int:
		return len(self.dataset)

	def __getitem__(self, index: int) -> BartFoCusDatasetSampleV1:
		"""Return the sample builder for ``index``; call ``get_dict()`` for model inputs.

		NOTE(review): the returned object is not a dict/tensor, so a DataLoader
		presumably needs a custom collate_fn that calls ``get_dict()``.
		"""
		return BartFoCusDatasetSampleV1(
			focus_dataset_sample=self.dataset[index],
			tokenizer=self.bart_tokenizer,
			h_params=self.bart_hyperparameters,
		)
