In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!cd "/content/drive/MyDrive/Knowledge-grounded Task-oriented Dialogue system"

In [3]:
!pip install torch==1.13.1
!pip install tqdm==4.62.3
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch==1.13.1
  Downloading torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl (887.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m887.5/887.5 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu11==11.7.99
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl (849 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.3/849.3 kB[0m [31m69.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cublas-cu11==11.10.3.66
  Downloading nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl (317.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.1/317.1 MB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-nvrtc-cu11==11.7.99
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl (21.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━

In [4]:
import os
import json
import torch
import random
import copy

import pandas as pd

from tqdm import tqdm
from collections import defaultdict

In [5]:
#scripts/dataset_walker
class DatasetWalker(object):
    def __init__(self, dataset, dataroot, labels=False, labels_file=None, incl_knowledge=False):
        path = os.path.join(os.path.abspath(dataroot))
            
        if dataset not in ['train', 'val']:
            raise ValueError('Wrong dataset name: %s' % (dataset))

        logs_file = os.path.join(path, dataset, 'logs.json')
        with open(logs_file, 'r') as f:
            self.logs = json.load(f)

        self.labels = None

        if labels is True:
            if labels_file is None:
                labels_file = os.path.join(path, dataset, 'labels.json')

            with open(labels_file, 'r') as f:
                self.labels = json.load(f)

        self._incl_knowledge = incl_knowledge
        if self._incl_knowledge is True:
            # knowledge_reader 수정
            #self._knowledge = knowledge_reader(dataroot)
            self._knowledge = KnowledgeReader(dataroot)

    def __iter__(self):
        if self.labels is not None:
            for log, label in zip(self.logs, self.labels):
                if self._incl_knowledge is True and label['target'] is True:
                    for idx, snippet in enumerate(label['knowledge']):
                        domain = snippet['domain']
                        entity_id = snippet['entity_id']
                        doc_type = snippet['doc_type']
                        doc_id = snippet['doc_id']

                        if doc_type == 'review':
                            sent_id = snippet['sent_id']                            
                            sent = self._knowledge.get_review_sent(domain, entity_id, doc_id, sent_id)
                            label['knowledge'][idx]['sent'] = sent
                            
                        elif doc_type == 'faq':
                            doc = self._knowledge.get_faq_doc(domain, entity_id, doc_id)
                            question = doc['question']
                            answer = doc['answer']

                            label['knowledge'][idx]['question'] = question
                            label['knowledge'][idx]['answer'] = answer
                
                yield(log, label)
        else:
            for log in self.logs:
                yield(log, None)

    def __len__(self, ):
        return len(self.logs)

In [6]:
#scripts/knowledge_reader
class KnowledgeReader(object):
    def __init__(self, dataroot, knowledge_file='knowledge.json'):
        path = os.path.join(os.path.abspath(dataroot))

        with open(os.path.join(path, knowledge_file), 'r') as f:
            self.knowledge = json.load(f)

    def get_domain_list(self):
        return list(self.knowledge.keys())

    def get_entity_list(self, domain):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name")

        entity_ids = []
        for entity_id in self.knowledge[domain].keys():
            entity_ids.append(int(entity_id))

        result = []
        for entity_id in sorted(entity_ids):
            entity_name = self.knowledge[domain][str(entity_id)]['name']
            result.append({'id': entity_id, 'name': entity_name})

        return result

    def get_entity_name(self, domain, entity_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        result = self.knowledge[domain][str(entity_id)]['name'] or None

        return result

    def get_faq_doc_ids(self, domain, entity_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)
        
        result = []

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        entity_obj = self.knowledge[domain][str(entity_id)]
        for doc_id, doc_obj in entity_obj['faqs'].items():
            result.append(doc_id)

        return result

    def get_faq_doc(self, domain, entity_id, doc_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        entity_name = self.get_entity_name(domain, entity_id)

        if str(doc_id) not in self.knowledge[domain][str(entity_id)]['faqs']:
            raise ValueError("invalid doc id: %s" % str(doc_id))

        doc_obj = self.knowledge[domain][str(entity_id)]['faqs'][str(doc_id)]
        result = {'domain': domain, 'entity_id': entity_id, 'entity_name': entity_name, 'doc_id': doc_id, 'question': doc_obj['question'], 'answer': doc_obj['answer']}

        return result

    def get_review_doc_ids(self, domain, entity_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        result = []
        
        entity_obj = self.knowledge[domain][str(entity_id)]
        for doc_id, doc_obj in entity_obj['reviews'].items():
            result.append(doc_id)

        return result

    def get_review_doc(self, domain, entity_id, doc_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        entity_name = self.get_entity_name(domain, entity_id)

        if str(doc_id) not in self.knowledge[domain][str(entity_id)]['reviews']:
            raise ValueError("invalid doc id: %s" % str(doc_id))
        
        doc_obj = self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)]
        
        result = {'domain': domain, 'entity_id': entity_id, 'entity_name': entity_name, 'doc_id': doc_id, 'sentences': doc_obj['sentences']}
        if 'traveler_type' in doc_obj:
            result['traveler_type'] = doc_obj['traveler_type']
        
        if 'dishes' in doc_obj:
            result['dishes'] = doc_obj['dishes']

        if 'drinks' in doc_obj:
            result['drinks'] = doc_obj['drinks']

        return result
    
    def get_review_sent(self, domain, entity_id, doc_id, sent_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))
        
        if str(doc_id) not in self.knowledge[domain][str(entity_id)]['reviews']:
            raise ValueError("invalid doc id: %s" % str(doc_id))

        if str(sent_id) not in self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)]['sentences']:
            raise ValueError("invalid sentence id: %s" % str(sent_id))

        result = self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)]['sentences'][str(sent_id)]

        return result

In [7]:
task = "selection"
dataroot = '/content/drive/MyDrive/Knowledge-grounded Task-oriented Dialogue system/Knowledge-grounded-ToD/data'
negative_sample_method = 'all'
knowledge_file = 'knowledge.json'
eval_only = False

class BaseDataset(torch.utils.data.Dataset):
    global dstc11
    global dstc11_val 

    df = pd.DataFrame({"history" : 'aaa', "knowledge_keys": 'bbb', "knowledge" : 'ccc', "candidates_keys": 'ddd', 'candidates' : 'eee'}, index = [0])
    dstc11 = pd.DataFrame(df, columns = ["history", "knowledge_keys", "knowledge", "candidates_keys", "candidates"])
    dstc11_val = pd.DataFrame(df, columns = ["history", "knowledge_keys", "knowledge", "candidates_keys", "candidates"])

    #df = pd.DataFrame({"knowledge" : 'aaa'}, index = [0])
    #knowledge = pd.DataFrame(df, columns = ['knowledge'])

    def __init__(self, split_type, labels=True, labels_file=None):
        self.dataroot = dataroot
        self.split_type = split_type
        self.task = task
        self.negative_sample_method = negative_sample_method

        self.dataset_walker = DatasetWalker(split_type, labels=labels, dataroot=self.dataroot, labels_file=labels_file)
        self.dialogs = self._prepare_conversations()
        self.knowledge_reader = KnowledgeReader(self.dataroot, knowledge_file)
        self.snippets = self._prepare_knowledge()
        self._create_examples()

    def _prepare_conversations(self):
        """ Tokenize and encode the dialog data """
        dialogs = []
        for i, (log, label) in enumerate(tqdm(self.dataset_walker, disable=False)):
            dialog = {}
            dialog["id"] = i
            dialog["log"] = log
            dialog["label"] = label
            dialogs.append(dialog)
        return dialogs

    def _prepare_knowledge(self):
        """ Tokenize and encode the knowledge snippets """
        self.knowledge_docs = self._get_snippet_list()

        snippets = defaultdict(dict)
        for snippet_id, snippet in enumerate(self.knowledge_docs):
            key = "{}__{}__{}".format(snippet["domain"], str(snippet["entity_id"]) or "", snippet["doc_id"])
            knowledge = self._knowledge_to_string(snippet["doc"], name=snippet["entity_name"] or "")
            snippets[key] = knowledge

        return snippets

    def _get_snippet_list(self):
        """ Get all knowledge snippets in the dataset """
        result = []
        i = 0
        for domain in self.knowledge_reader.get_domain_list():
            for entity_id in self.knowledge_reader.knowledge[domain].keys():
                for review_doc_id in self.knowledge_reader.get_review_doc_ids(domain, entity_id):
                    review_doc = self.knowledge_reader.get_review_doc(domain, entity_id, review_doc_id)
                    for review_sent_id, review_sent in review_doc['sentences'].items():
                        #i += 1
                        #knowledge.loc[i] = [used_knowledge]
                        result.append(
                            {'domain': domain, 'entity_id': entity_id, 'entity_name': review_doc['entity_name'],
                             'doc_id': f"{review_doc_id}-{review_sent_id}",
                             'doc': {'body': review_sent}})
                for faq_doc_id in self.knowledge_reader.get_faq_doc_ids(domain, entity_id):
                    #i += 1
                    faq_doc = self.knowledge_reader.get_faq_doc(domain, entity_id, faq_doc_id)
                    result.append({'domain': domain, 'entity_id': entity_id, 'entity_name': faq_doc['entity_name'],
                                   'doc_id': faq_doc_id,
                                   'doc': {'body': f"{faq_doc['question']} {faq_doc['answer']}"}})

        return result

    def _knowledge_to_string(self, doc, name=""):
        """ Convert a knowledge snippet to a string """
        doc_body = f"{name.title()}: {doc['body']}"
        return doc_body

    def _create_examples(self):
        """ Creating examples for model training and evaluation """
        self.examples = []
        idx = 0
        for dialog in tqdm(self.dialogs, disable=False, desc='creating examples'):
            dialog_id = dialog["id"]
            label = dialog["label"]
            dialog = dialog["log"]

            if label is None:
                # So we create dummy target here
                label = {"target": False}

            target = label["target"]

            if not target and self.task != "detection":
                # we only care about non-knowledge-seeking turns in turn detection task
                continue

            history = [turn["text"] for turn in dialog[-3:]]
            
            gt_resp = label.get("response", "")

            if target:
                knowledge_keys = []
                knowledge_candidates = defaultdict(lambda: 0)
                used_knowledge = []
                candidates_text = []
                knowledge_prefix_visited = set()

                if "knowledge" not in label:
                    raise ValueError("Please run entity matching before running knowledge selection")

                label_knowledge = label["knowledge"]

                for knowledge in label_knowledge:
                    if not (self.task == 'selection' and eval_only):
                        if knowledge['doc_type'] == 'review':
                            knowledge_key = f"{knowledge['domain']}__{knowledge['entity_id']}__{knowledge['doc_id']}-{knowledge['sent_id']}"
                        else:
                            knowledge_key = f"{knowledge['domain']}__{knowledge['entity_id']}__{knowledge['doc_id']}"

                    # find snippets with same entity as candidates
                    prefix = "{}__{}".format(knowledge["domain"], knowledge["entity_id"])
                    if prefix not in knowledge_prefix_visited:
                        knowledge_prefix_visited.add(prefix)
                        _knowledge_candidates = [
                            cand
                            for cand in self.snippets.keys()
                            if "__".join(cand.split("__")[:-1]) == prefix
                        ]

                        for _knowledge_cand_idx, _knowledge_cand in enumerate(_knowledge_candidates):
                            knowledge_candidates[_knowledge_cand] = 1

                    if self.split_type == "train" and self.negative_sample_method == "oracle":
                        # if there's not enough candidates during training, we just skip this example
                        if len(knowledge_candidates) < 2 or len(knowledge_candidates) <= len(label["knowledge"]): #n_candidates : 2
                            continue

                    if not (self.task == 'selection' and eval_only):
                        used_knowledge.append(self.snippets[knowledge_key])
                        knowledge_keys.append(knowledge_key)
                knowledge_candidates = [k for k, v in knowledge_candidates.items()]
                for k in knowledge_candidates : 
                  candidates_text.append(self.snippets[k])

            else:
                knowledge_candidates = None
                used_knowledge = []
                knowledge_keys = []

            self.examples.append({
                "history": history,
                "knowledge": used_knowledge,
                "knowledge_keys": knowledge_keys,
                "candidates": knowledge_candidates,
                "candidates_text" : candidates_text,
                "response_text": gt_resp,
                "label": label,
                "knowledge_seeking": target,
                "dialog_id": dialog_id
            })

            dstc11.loc[idx] = [history, knowledge_keys, used_knowledge, knowledge_candidates, candidates_text]
            if self.split_type == 'val' : 
              dstc11_val.loc[idx] = [history, knowledge_keys, used_knowledge, knowledge_candidates, candidates_text]
            idx += 1

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.examples)

In [8]:
class DSTC11_Dataset(BaseDataset):
    def __init__(self, split_type, labels=True, labels_file=None):
        super(DSTC11_Dataset, self).__init__(split_type, labels, labels_file)

    def __getitem__(self, index):
        example = self.examples[index]
        instance, _ = self.build_input_from_segments(
            example["knowledge"],
            example["history"],
        )
        print(instance)
        return instance

class DSTC11_EvalDataset(BaseDataset):
    def __init__(self, split_type, labels=True, labels_file=None):
        super(DSTC11_EvalDataset, self).__init__(split_type, labels, labels_file)

    def __getitem__(self, index):
        example = self.examples[index]
        return example

    def collate_fn(self, batch):
        return batch

In [9]:
train_dataset = DSTC11_Dataset(split_type = 'train')
valid_dataset = DSTC11_Dataset(split_type = 'val')

100%|██████████| 28431/28431 [00:00<00:00, 1040506.23it/s]
creating examples: 100%|██████████| 28431/28431 [01:17<00:00, 364.56it/s]
100%|██████████| 4173/4173 [00:00<00:00, 974273.90it/s]
creating examples: 100%|██████████| 4173/4173 [00:10<00:00, 383.25it/s]


In [10]:
neg_idx = [[] for _ in range(len(dstc11))]
neg_text = [[] for _ in range(len(dstc11))]
for idx, cand_keys in enumerate(dstc11['candidates_keys']) :
  for i, key in enumerate(cand_keys):
    if key not in dstc11['knowledge_keys'][idx] :
      neg_idx[idx].append(key)
      neg_text[idx].append(dstc11['candidates'][idx][i])
dstc11['neg_keys'] = neg_idx
dstc11['neg_samples'] = neg_text
dstc11

Unnamed: 0,history,knowledge_keys,knowledge,candidates_keys,candidates,neg_keys,neg_samples
0,"[Do either of them have a 3 star rating?, Yes,...","[hotel__20__9-4, hotel__20__6-4, hotel__20__4-2]",[Hobsons House: I also saw some hairs in the b...,"[hotel__20__0-0, hotel__20__0-1, hotel__20__0-...",[Hobsons House: I was very please with my rece...,"[hotel__20__0-0, hotel__20__0-1, hotel__20__0-...",[Hobsons House: I was very please with my rece...
1,[I'm also looking for a restaurant by the name...,"[restaurant__19250__0-3, restaurant__19250__0-4]",[Maharajah Tandoori Restaurant: First thing is...,"[restaurant__19250__0-0, restaurant__19250__0-...",[Maharajah Tandoori Restaurant: My husband and...,"[restaurant__19250__0-0, restaurant__19250__0-...",[Maharajah Tandoori Restaurant: My husband and...
2,[I want it moderately priced and I don't care ...,"[hotel__7__6-2, hotel__7__6-3, hotel__7__3-0, ...",[Ashley Hotel: This place definitely delivered...,"[hotel__7__0-0, hotel__7__0-1, hotel__7__0-2, ...",[Ashley Hotel: I enjoyed my breakfast choices ...,"[hotel__7__0-0, hotel__7__0-1, hotel__7__0-2, ...",[Ashley Hotel: I enjoyed my breakfast choices ...
3,[I'm looking for information on the cambridge ...,"[hotel__28__3-4, hotel__28__6-4]",[The Cambridge Belfry: One of the best things ...,"[hotel__28__0-0, hotel__28__0-1, hotel__28__1-...",[The Cambridge Belfry: My girlfriend and I enj...,"[hotel__28__0-0, hotel__28__0-1, hotel__28__1-...",[The Cambridge Belfry: My girlfriend and I enj...
4,"[How about a gastropub restaurant?, I have two...","[restaurant__19188__0-3, restaurant__19188__1-...",[Backstreet Bistro: It's a nice location and t...,"[restaurant__19188__0-0, restaurant__19188__0-...",[Backstreet Bistro: My friends and I went into...,"[restaurant__19188__0-0, restaurant__19188__0-...",[Backstreet Bistro: My friends and I went into...
...,...,...,...,...,...,...,...
14763,"[In the mid-range would be fine., Rajmahal is ...","[restaurant__19274__3-1, restaurant__19274__3-...",[Rajmahal: She was very nice and accommodating...,"[restaurant__19274__0-0, restaurant__19274__0-...",[Rajmahal: I visited Rajmahal recently by myse...,"[restaurant__19274__0-0, restaurant__19274__0-...",[Rajmahal: I visited Rajmahal recently by myse...
14764,[It should be in the west and have a star rati...,"[hotel__17__0-2, hotel__17__6-1]",[Finches Bed And Breakfast: But where the room...,"[hotel__17__0-0, hotel__17__0-1, hotel__17__0-...",[Finches Bed And Breakfast: I'm conflicted abo...,"[hotel__17__0-0, hotel__17__0-1, hotel__17__0-...",[Finches Bed And Breakfast: I'm conflicted abo...
14765,[I need the reservation for 2 people for 14:30...,"[restaurant__19182__0-2, restaurant__19182__2-...",[The Golden Curry: The wait staff is courteous...,"[restaurant__19182__0-0, restaurant__19182__0-...",[The Golden Curry: If you're looking for a pla...,"[restaurant__19182__0-0, restaurant__19182__0-...",[The Golden Curry: If you're looking for a pla...
14766,"[On Tuesday please., I've booked your room for...","[hotel__19__7-8, hotel__19__2-2, hotel__19__2-3]",[Hamilton Lodge: The bed really needed a new m...,"[hotel__19__0-0, hotel__19__0-1, hotel__19__0-...","[Hamilton Lodge: I was here on business., Hami...","[hotel__19__0-0, hotel__19__0-1, hotel__19__0-...","[Hamilton Lodge: I was here on business., Hami..."


In [11]:
idx = 0
for i in dstc11['history']:
  dstc11['history'].loc[idx] = '[SEP]'.join(s for s in i)
  idx += 1

print(dstc11.head())

                                             history  \
0  Do either of them have a 3 star rating?[SEP]Ye...   
1  I'm also looking for a restaurant by the name ...   
2  I want it moderately priced and I don't care w...   
3  I'm looking for information on the cambridge b...   
4  How about a gastropub restaurant?[SEP]I have t...   

                                      knowledge_keys  \
0   [hotel__20__9-4, hotel__20__6-4, hotel__20__4-2]   
1   [restaurant__19250__0-3, restaurant__19250__0-4]   
2  [hotel__7__6-2, hotel__7__6-3, hotel__7__3-0, ...   
3                   [hotel__28__3-4, hotel__28__6-4]   
4  [restaurant__19188__0-3, restaurant__19188__1-...   

                                           knowledge  \
0  [Hobsons House: I also saw some hairs in the b...   
1  [Maharajah Tandoori Restaurant: First thing is...   
2  [Ashley Hotel: This place definitely delivered...   
3  [The Cambridge Belfry: One of the best things ...   
4  [Backstreet Bistro: It's a nice location an

In [12]:
idx = 0
for i in dstc11_val['history']:
  dstc11_val['history'].loc[idx] = '[SEP]'.join(s for s in i)
  idx += 1

In [13]:
history = list(dstc11['history'])
history_val = list(dstc11_val['history'])
len(history_val)

2129

In [None]:
cand_keys = []
querys = []
target = [] # knowledge면 1, 아니면 0
for idx, know_key in enumerate(dstc11['knowledge_keys']) :
  for i, k in enumerate(know_key) :
    cand_keys.append(dstc11['knowledge'][idx][i])
    querys.append(idx)
    target.append(1)
  
  for _ in range(len(know_key)) :
    i = random.randint(0, len(dstc11['neg_keys'][idx])-1)
    cand_keys.append(dstc11['neg_samples'][idx][i])
    querys.append(idx)
    target.append(0)
dpr_dataset = pd.DataFrame({'candidates' : cand_keys, 'target' : target, 'query_id' : querys })

In [None]:
dstc11_val

In [16]:
tmp = 0
for i in range(len(knowledge_val)) :
  tmp += len(knowledge_val[i])
print(tmp)

202137


In [15]:
knowledge = dpr_dataset['candidates'].tolist()
knowledge_val = dstc11_val['candidates'].tolist()
len(knowledge_val)

2129

In [None]:
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
from torch.utils.data import TensorDataset

Qmodel_name = "facebook/dpr-question_encoder-single-nq-base"
Cmodel_name = "facebook/dpr-ctx_encoder-single-nq-base"

q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(Qmodel_name)
c_tokenizer = DPRContextEncoderTokenizer.from_pretrained(Cmodel_name)

q_encoder = DPRQuestionEncoder.from_pretrained(Qmodel_name)
c_encoder = DPRContextEncoder.from_pretrained(Cmodel_name)

q_encoder.train()
c_encoder.train()

In [None]:
train_query_encodings = q_tokenizer(history, padding = True, truncation = True, return_tensors = 'pt')
train_context_encodings = c_tokenizer(knowledge, padding = True, truncation = True, return_tensors = 'pt')

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class DPR(nn.Module):


  def __init__(self, query_model, passage_model, query_tokenizer, passage_tokenizer, 
              dense_size, freeze_params = 0.0, batch_size = 2):
    
    '''
    :query_model : The model that encodes queries to dense representation
    :passage_model : The model that encodes passages to dense representation
    :query_tokenizer : tokenizer for queries
    :passage_tokenizer : tokenizer for passages
    :passage_dict : dictionary of passages with their unique id
    :questions : A list of tuples with question and their correct passage id
    :dense_size : the dimension to which the DPR has to encode
    :freeze_params : the percentage of the parameters to be frozen
    :batch_size : the batch size for training
    :sample_size: the sample size for negative sampling
    '''
    super(DPR, self).__init__()
    self.query_model = query_model
    self.query_tokenizer = query_tokenizer
    self.passage_model = passage_model
    self.passage_tokenizer = passage_tokenizer
    self.freeze_params = freeze_params
    #self.sample_size = sample_size
    self.batch_size = batch_size

    self.passage_to_dense = nn.Sequential(nn.Linear(768, dense_size * 2),
                                          nn.ReLU(),
                                          nn.Linear(dense_size * 2, dense_size),
                                          nn.GELU())
    
    self.query_to_dense = nn.Sequential(nn.Linear(768, dense_size * 2),
                                          nn.ReLU(),
                                          nn.Linear(dense_size * 2, dense_size),
                                          nn.GELU())
    self.log_softmax = nn.LogSoftmax(dim = 1)
    self.freeze_layers()


  # Freeze the first self.freeze_params % layers
  def freeze_layers(self):
    num_query_layers = sum(1 for _ in self.query_model.parameters())
    num_passage_layers = sum(1 for _ in self.passage_model.parameters())

    for parameters in list(self.query_model.parameters())[:int(self.freeze_params * num_query_layers)]:
      parameters.requires_grad = False

    for parameters in list(self.query_model.parameters())[int(self.freeze_params * num_query_layers):]:
      parameters.requires_grad = True

    for parameters in list(self.passage_model.parameters())[:int(self.freeze_params * num_passage_layers)]:
      parameters.requires_grad = False

    for parameters in list(self.passage_model.parameters())[int(self.freeze_params * num_passage_layers):]:
      parameters.requires_grad = True

  def get_passage_vectors(self, passage):
    p_vector = self.passage_model(input_ids = passage.input_ids, 
                                  attention_mask = passage.attention_mask)
    p_vector = self.query_to_dense(p_vector.pooler_output)
    return p_vector

  def get_query_vector(self, query):
    q_vector = self.query_model(input_ids = query.input_ids, 
                                attention_mask = query.attention_mask)
    q_vector = self.query_to_dense(q_vector.pooler_output)
    return q_vector

  def dot_product(self, q_vector, p_vector):
    q_vector = q_vector.unsqueeze(1)
    sim = torch.matmul(q_vector, torch.transpose(p_vector, -2, -1))
    return sim

  def forward(self, context_input_ids, context_attention_mask, query_input_ids, query_attention_mask):
    dense_passage = self.passage_model(input_ids = context_input_ids, attention_mask = context_attention_mask)
    dense_query = self.query_model(input_ids = query_input_ids, attention_mask = query_attention_mask)
    dense_passage = dense_passage['pooler_output']
    dense_query = dense_query['pooler_output']
    dense_passage = self.passage_to_dense(dense_passage)
    dense_query = self.query_to_dense(dense_query)
    similarity_score = self.dot_product(dense_query, dense_passage)
    similarity_score = similarity_score.squeeze(1)
    logits = self.log_softmax(similarity_score)
    return logits

In [None]:
dpr_model = DPR(query_model = q_encoder, 
                passage_model = c_encoder, 
                query_tokenizer = q_tokenizer, 
                passage_tokenizer = c_tokenizer, 
                dense_size = 64,
                freeze_params = 0.3,
                batch_size = 4)

dpr_model.train()

sum(p.numel() for p in dpr_model.parameters()), sum(p.numel() for p in dpr_model.parameters() if p.requires_grad == True)

(217996672, 124251520)

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [None]:
criterion = nn.NLLLoss()
optimizer = torch.optim.AdamW(dpr_model.parameters(), lr = 5e-5)

In [None]:
batch_size = 32
num_questions = len(train_context_encodings.input_ids)
num_questions

113488

In [None]:
def get_batch_train(num):
  true = []
  context_input_ids_tensor = []
  context_attention_mask_tensor = []
  optimizer.zero_grad()

  for i in range(batch_size) : 
    idx = num + i
    query_idx = dpr_dataset['query_id'][idx]
    context_input_ids_tensor.append(train_context_encodings.input_ids[idx])
    context_attention_mask_tensor.append(train_context_encodings.attention_mask[idx])
    query_input_ids = train_query_encodings.input_ids[query_idx]
    query_attention_mask = train_query_encodings.attention_mask[query_idx]

    target = dpr_dataset['target'][idx]
    if target == 1 :
      true.append(1)
    else :
      true.append(0)
  
  context_input_ids_tensor = torch.stack(context_input_ids_tensor)
  context_attention_mask_tensor = torch.stack(context_attention_mask_tensor)
  query_input_ids = query_input_ids.unsqueeze(0)
  query_attention_mask = query_attention_mask.unsqueeze(0)

  return context_input_ids_tensor, context_attention_mask_tensor, query_input_ids, query_attention_mask, true

In [None]:
batch_loss = 0
num_train_epochs = 5

q_encoder.zero_grad()
c_encoder.zero_grad()
torch.cuda.empty_cache()

#train_iterator = trange(int(num_train_epochs), desc = "Epoch")
for idx in range(num_train_epochs):
  print(f"Epoch : {idx}")
  print(f"len : {num_questions // batch_size}")
  for i, batch in enumerate(range(num_questions // batch_size)) :
    q_encoder.train()
    c_encoder.train()
    if torch.cuda.is_available() :
      context_input_ids_tensor, context_attention_mask_tensor, query_input_ids, query_attention_mask, true = get_batch_train(batch * batch_size)
      pred = dpr_model(context_input_ids_tensor, context_attention_mask_tensor, query_input_ids, query_attention_mask)
      
      true = torch.tensor([0])
      loss = criterion(pred, true)
      loss.backward()
      batch_loss += loss.item()
      optimizer.step()

      q_encoder.zero_grad()
      c_encoder.zero_grad()
      torch.cuda.empty_cache()

      if i%1000 == 0:
        print(f"Batch : {int(i/1000)}  Loss : {batch_loss/1000}")
        batch_loss = 0

  print("Evaluation")     
  with torch.no_grad():
    c_encoder.eval()
    q_encoder.eval()

    # top ~ acc 구하기
    top_5 = 0
    top_20 = 0
    top_100 = 0

    val_num = len(history_val)
    for i in range(len(history_val)) :
        #print(history_val[i], dstc11['knowledge'][i])
        val_query_encodings = q_tokenizer(history_val[i], padding = True, truncation = True, return_tensors = 'pt')
        q_emb = q_encoder(input_ids = val_query_encodings.input_ids, 
                          attention_mask = val_query_encodings.attention_mask).pooler_output.to(device)

        val_context_encodings = c_tokenizer(dstc11['candidates'][i], padding = True, truncation = True, return_tensors = 'pt')
        c_emb = c_encoder(input_ids = val_context_encodings.input_ids,
                          attention_mask = val_context_encodings.attention_mask).pooler_output.to(device)
        #c_emb = torch.Tensor(c_emb).squeeze()

        sim = torch.matmul(q_emb, torch.transpose(c_emb, 0, 1))
        rank = torch.argsort(sim, dim=1, descending=True).squeeze()

        if i in rank[0:5]: 
            top_5 += 1
        if i in rank[0:20]: 
            top_20 += 1
        if i in rank[0:100]: 
            top_100 += 1

    print('top-5 acc: ', top_5/val_num * 100)
    print('top-20 acc: ', top_20/val_num * 100)
    print('top-100 acc: ', top_100/val_num * 100)

Epoch : 0
len : 3546
Batch : 0  Loss : 0.0033298861980438233
Batch : 1  Loss : 3.4686846764087678
Batch : 2  Loss : 3.470310158729553
Batch : 3  Loss : 3.467512248516083
Evaluation
top-5 acc:  0.18788163457022078
top-20 acc:  0.7515265382808831
top-100 acc:  4.3212775951150775
Epoch : 1
len : 3546
Batch : 0  Loss : 1.892682545185089
Batch : 1  Loss : 3.466170468568802
Batch : 2  Loss : 3.46631063580513
Batch : 3  Loss : 3.4662934935092924
Evaluation
top-5 acc:  0.3287928604978863
top-20 acc:  1.0803193987787694
top-100 acc:  4.274307186472522
Epoch : 2
len : 3546
Batch : 0  Loss : 1.8943127226829528
Batch : 1  Loss : 3.4669907686710357
Batch : 2  Loss : 3.465662769794464
Batch : 3  Loss : 3.4661018614768984
Evaluation
top-5 acc:  0.18788163457022078
top-20 acc:  0.9394081728511038
top-100 acc:  4.274307186472522
Epoch : 3
len : 3546
Batch : 0  Loss : 1.892113411426544
Batch : 1  Loss : 3.466515021085739
Batch : 2  Loss : 3.4658370172977446
Batch : 3  Loss : 3.4662423129081725
Evaluatio

___________________________________________

In [None]:
with torch.no_grad():
    c_encoder.eval()
    q_encoder.eval()

    # top ~ acc 구하기
    top_5 = 0
    top_20 = 0
    top_100 = 0

    val_num = len(history_val)
    for i in range(len(history_val)) :
        #print(history_val[i], dstc11['knowledge'][i])
        val_query_encodings = q_tokenizer(history_val[i], padding = True, truncation = True, return_tensors = 'pt')
        q_emb = q_encoder(input_ids = val_query_encodings.input_ids, 
                          attention_mask = val_query_encodings.attention_mask).pooler_output.to(device)

        val_context_encodings = c_tokenizer(dstc11['candidates'][i], padding = True, truncation = True, return_tensors = 'pt')
        c_emb = c_encoder(input_ids = val_context_encodings.input_ids,
                          attention_mask = val_context_encodings.attention_mask).pooler_output.to(device)
        c_emb = torch.Tensor(c_emb).squeeze()

        sim = torch.matmul(q_emb, torch.transpose(c_emb, 0, 1))
        rank = torch.argsort(sim, dim=1, descending=True).squeeze()
        print(rank)

        if i in rank[0:5]: 
            top_5 += 1
        if i in rank[0:20]: 
            top_20 += 1
        if i in rank[0:100]: 
            top_100 += 1

    print('top-5 acc: ', top_5/val_num * 100)
    print('top-20 acc: ', top_20/val_num * 100)
    print('top-100 acc: ', top_100/val_num * 100)

tensor([84, 78, 56, 70, 16, 89, 65, 69, 90, 45, 39, 88, 25, 67, 55, 59, 58, 12,
        66, 71, 28, 53, 75, 80, 29, 64, 63, 61, 60, 23, 11, 68, 72, 62,  1,  7,
        10, 81, 57, 24, 74, 77, 83, 82, 54, 44, 38, 37, 76,  2,  9, 33, 85, 47,
        18, 34,  8, 46, 43, 32, 51, 31,  0, 19, 86, 73, 49, 13,  3, 17, 15, 22,
        20, 41, 21, 87,  4, 42, 79, 27, 14,  5, 36, 50, 40, 52, 26,  6, 30, 48,
        35], device='cuda:0')
tensor([28, 22, 15, 30,  5,  9, 12, 18, 11, 61, 48, 26, 20,  4, 55, 39, 27, 40,
        60, 44,  7, 58, 57, 34,  0, 25, 37,  1, 53, 50,  8, 13, 16,  6, 67, 17,
        41, 56, 49, 43, 29, 10,  3, 36, 31, 63, 23, 72, 14, 64, 46, 75,  2, 33,
        62, 47, 24, 42, 45, 54, 52, 51, 32, 19, 59, 68, 21, 77, 66, 35, 65, 73,
        76, 38, 69, 71, 74, 70], device='cuda:0')
tensor([86, 58, 84, 54, 87, 43, 29, 85, 66, 62,  7, 56, 49, 90, 61, 33, 68,  9,
        35, 44, 60,  3, 72, 36,  1, 45, 10, 26, 28, 83, 48, 34, 55, 42, 19, 52,
        70, 21, 14, 63, 67, 31, 16, 25, 

KeyboardInterrupt: ignored