# In [1]:
import spacy
import csv
import pyprind
from torchtext import data
from torchtext.vocab import GloVe
import numpy as np
import torch
from torchtext.data import Iterator, BucketIterator
from spacy.symbols import ORTH

class dataProcesser():
    """Convert a JSON question-answering dump into a passage/question/answer CSV.

    Constructing an instance performs the whole conversion: the JSON file is
    read through ``data.TabularDataset``, every paragraph is indexed by its
    title, and the first ``n`` samples are written to the destination CSV.
    """

    def __init__(self, src, des, n):
        """Load ``src``, build the title index and write ``des``.

        Args:
            src: Path of the JSON file; it must provide the keys ``context``,
                ``question`` and ``supporting_facts``.
            des: Path of the CSV file to create (opened in ``'w'`` mode).
            n: Number of leading samples to export.
        """
        self.src = src
        self.des = des
        self.n = n
        self.nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'tagger'])
        # Maps a paragraph title to that paragraph's list of sentences.
        self.dics = {}
        self.sepToken = ' <sep> '
        CONTEXT = data.Field()
        ANSWER = data.Field()
        QUESTION = data.Field()

        # Mapping: {source column name: (attribute name on the example, Field)}
        fields = {
            'context': ('Context', CONTEXT),
            'question': ('Question', QUESTION),
            'supporting_facts': ('Answer', ANSWER),
        }
        dataset = data.TabularDataset(path=src, format='json', fields=fields)
        # Only the first parsed example is used; its attributes are indexed
        # per sample below, so the file is assumed to hold column-wise lists
        # -- TODO confirm against the actual input file.
        dataset = dataset.examples[0]

        # A title seen more than once keeps the sentences of its last occurrence.
        for paragraphs in dataset.Context:
            for title, sentences in paragraphs:
                self.dics[title] = sentences
        self.go(dataset)

    def getAnswer(self, ans):
        """Return the supporting sentences named by ``ans`` as one string.

        Args:
            ans: Iterable of ``(title, sent_id)`` pairs.

        Returns:
            The referenced sentences, each followed by the separator token.
            Pairs with an unknown title or an out-of-range ``sent_id`` are
            skipped; negative ids count as out of range rather than being
            interpreted as indexing from the end.
        """
        picked = []
        for title, sent_id in ans:
            sentences = self.dics.get(title)
            if sentences is not None and 0 <= sent_id < len(sentences):
                picked.append(sentences[sent_id] + self.sepToken)
        return ''.join(picked)

    def getContext(self, text2DimList):
        """Flatten the paragraphs of one sample into a single passage string.

        Args:
            text2DimList: Sequence of paragraphs whose element ``[1]`` is the
                paragraph's list of sentences.

        Returns:
            All sentences joined by the separator token. The separator is
            also placed between paragraphs, so the last sentence of one
            paragraph never fuses with the first sentence of the next.
        """
        return self.sepToken.join(
            sentence
            for paragraph in text2DimList
            for sentence in paragraph[1]
        )

    def go(self, dataset):
        """Write the first ``self.n`` samples of ``dataset`` to ``self.des``.

        The CSV starts with the header ``passage,question,answer``; all three
        columns are lower-cased. Raises ``IndexError`` if ``self.n`` exceeds
        the number of samples available.
        """
        pbar = pyprind.ProgBar(self.n)
        with open(self.des, 'w', encoding="utf-8", newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['passage', 'question', 'answer'])
            for idx in range(self.n):
                passage = self.getContext(dataset.Context[idx]).lower()
                question = dataset.Question[idx].lower()
                answer = self.getAnswer(dataset.Answer[idx]).lower()
                writer.writerow([passage, question, answer])
                pbar.update()
        print('write down')
        
        
class getHotpotData():
    """Load the passage/question/answer CSV files into torchtext objects.

    After construction the instance exposes the datasets (``train``, ``dev``),
    the three fields with their vocabularies (``PASSAGE``, ``QUESTION``,
    ``ANSWER``), the batch iterators (``train_iter``, ``dev_iter``) and
    ``block_size``.
    """

    # Marks that the plain tokenizer pads with spaces so each becomes a token.
    _PUNCTUATION = ('(', ')', ',', '.')

    def __init__(self, args, trainPath, devPath,):
        """Build datasets, vocabularies and iterators.

        Args:
            args: Object providing ``batch_size`` and ``gpu`` (passed to the
                iterators as ``device``).
            trainPath: Path of the training CSV.
            devPath: Path of the development CSV.
        """
        self.nlp = spacy.load('en_core_web_sm')
        # Register "<sep>" as a tokenizer special case so it stays one token.
        special_case = [{ORTH: "<sep>"}]
        self.nlp.tokenizer.add_special_case("<sep>", special_case)

        self.trainpath = trainPath
        self.devpath = devPath

        # Field needs a callable that takes the raw text. Questions and
        # answers go through spaCy, passages through the plain tokenizer.
        self.ANSWER = data.Field(tokenize=self._tokenize_spacy)
        self.QUESTION = data.Field(tokenize=self._tokenize_spacy)
        self.PASSAGE = data.Field(tokenize=self._tokenize_plain)

        fields = {
            'passage': ('Passage', self.PASSAGE),
            'question': ('Question', self.QUESTION),
            'answer': ('Answer', self.ANSWER),
        }

        self.train = data.TabularDataset(path=self.trainpath, format='csv', fields=fields)
        # NOTE: passages stay flat token lists; spilter() can cut them at "<sep>".
        self.dev = data.TabularDataset(path=self.devpath, format='csv', fields=fields)

        self.PASSAGE.build_vocab(self.train, self.dev, vectors=GloVe(name='6B', dim=300))
        self.QUESTION.build_vocab(self.train)
        self.ANSWER.build_vocab(self.train)

        self.train_iter = data.BucketIterator(dataset=self.train, batch_size=args.batch_size, shuffle=True, sort_within_batch=False, repeat=False, device=args.gpu)
        self.dev_iter = data.BucketIterator(dataset=self.dev, batch_size=args.batch_size, shuffle=True, sort_within_batch=False, repeat=False, device=args.gpu)
        self.calculate_block_size(args.batch_size)
        print('load hotpot data done')

    def _tokenize_spacy(self, text):
        """Tokenize ``text`` with the spaCy pipeline held in ``self.nlp``."""
        return [str(token) for token in self.nlp(text)]

    def _tokenize_plain(self, text):
        """Split ``text`` on whitespace after isolating punctuation marks."""
        for mark in self._PUNCTUATION:
            text = text.replace(mark, ' ' + mark + ' ')
        # split() without an argument collapses runs of whitespace, so the
        # padding added above never yields empty-string tokens.
        return text.split()

    def tokenizer(self, text, mode):
        """Tokenize ``text`` with the strategy selected by ``mode``.

        Args:
            text: String to tokenize.
            mode: ``'spacy'`` selects the spaCy tokenizer; any other value
                selects the whitespace/punctuation tokenizer.

        Returns:
            List of token strings.
        """
        if mode == 'spacy':
            return self._tokenize_spacy(text)
        return self._tokenize_plain(text)

    def spilter(self, x, tk):
        """Split the sequence ``x`` into segments at every item equal to ``tk``.

        The delimiter items are not part of the result. The final segment
        runs to the end of ``x`` and is empty when ``x`` is empty or ends
        with ``tk``.
        """
        res = []
        start = 0
        for pos, item in enumerate(x):
            if item == tk:
                res.append(x[start:pos])
                start = pos + 1
        res.append(x[start:])
        return res

    def calculate_block_size(self, B):
        """Derive ``self.block_size`` from the training passage lengths.

        Computes ``int((2 * (std * sqrt(2 * ln(B)) + mean)) ** (1/3))`` where
        ``mean`` and ``std`` are taken over the token counts of the training
        passages.

        Args:
            B: Batch size.

        Raises:
            ValueError: If the training set has no examples.
        """
        data_lengths = [len(e.Passage) for e in self.train.examples]
        if not data_lengths:
            raise ValueError('cannot compute block size: training set has no examples')

        mean = np.mean(data_lengths)
        std = np.std(data_lengths)

        self.block_size = int((2 * (std * ((2 * np.log(B)) ** (1/2)) + mean)) ** (1/3))
    