## Read Data

In [None]:
from torchtext.data import get_tokenizer
import os
import numpy as np
from transformers import AutoTokenizer
from datasets import load_dataset,load_from_disk
from loguru import logger
import tqdm

class Options:
    """Mutable bag of project-wide settings.

    Only ``name`` is set by the constructor; every other setting is attached
    afterwards as a plain attribute (e.g. ``options.base_path = ...``).
    Read settings as attributes. Do not add a ``name()`` accessor method: the
    instance attribute set in ``__init__`` would shadow it, making it unreachable.
    """

    def __init__(self, name) -> None:
        self.name = name

    def __repr__(self) -> str:
        # Show every attached setting, which is handy when inspecting in a notebook.
        attrs = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({attrs})"

# Project-global configuration shared by the helpers below.
options = Options("Model")
options.base_path = "/home/yang/github/fuzzys2s/"

# Special-token ids, assigned in order:
#   SOS = 0  start of sentence
#   EOS = 1  end of sentence
#   PAD = 2  padding token
#   UNK = 3  unknown token (stands in for low-frequency words)
options.SOS, options.EOS, options.PAD, options.UNK = range(4)

class Vocab:
    """Word <-> index vocabulary with per-word occurrence counts.

    Every vocabulary is seeded with the four special tokens ``<sos>``,
    ``<eos>``, ``<pad>`` and ``<unk>`` at the ids configured on ``options``
    (SOS, EOS, PAD, UNK). Each special token starts with a count of 1.
    New words receive consecutive ids starting at 4.
    """

    def __init__(self, name):
        self.name = name
        # Single source of truth for the special tokens, so the three lookup
        # tables below cannot disagree with each other.
        specials = (
            ("<sos>", options.SOS),
            ("<eos>", options.EOS),
            ("<pad>", options.PAD),
            ("<unk>", options.UNK),
        )
        self.word2index = {word: idx for word, idx in specials}
        self.word2count = {word: 1 for word, _ in specials}
        self.index2word = {idx: word for word, idx in specials}
        # Next free id: all four specials (SOS, EOS, PAD and UNK) are counted.
        self.n_words = len(specials)
        self.feature_max = []  # max value of feature
        self.feature_min = []  # min value of feature
        self.line_max = 0  # max length of sentence

    def addTokens(self, tokens):
        """Add every token in ``tokens`` (in order) via :meth:`addWord`."""
        for word in tokens:
            self.addWord(word)

    def addWord(self, word):
        """Count ``word``, assigning it the next free id the first time it is seen."""
        if word in self.word2index:
            self.word2count[word] += 1
            return
        self.word2index[word] = self.n_words
        self.word2count[word] = 1
        self.index2word[self.n_words] = word
        self.n_words += 1

def get_base_tokenizer(name):
    """Return the ``tokenize`` method of a Hugging Face tokenizer, cached on disk.

    If ``<base_path>/output/<name>/`` exists, the tokenizer is loaded from it.
    Otherwise it is fetched with ``AutoTokenizer.from_pretrained(name)`` and
    saved to that directory for the next call.

    Args:
        name: tokenizer/model identifier passed to ``from_pretrained``.

    Returns:
        The bound ``tokenizer.tokenize`` callable.
    """
    # os.path.join works whether or not base_path ends with a slash; the
    # trailing "" keeps the directory-style trailing separator.
    tokenizer_path = os.path.join(options.base_path, "output", name, "")
    if os.path.exists(tokenizer_path):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(name)
        tokenizer.save_pretrained(tokenizer_path)
    return tokenizer.tokenize

def read_line_pair(src_path, tgt_path, encoding="utf-8"):
    """Read two line-aligned text files into a list of ``[src, tgt]`` pairs.

    Line ``i`` of ``src_path`` is paired with line ``i`` of ``tgt_path``.
    Lines are returned exactly as read, including their trailing newline.

    Args:
        src_path: path of the source-side file.
        tgt_path: path of the target-side file.
        encoding: text encoding of both files. Defaults to UTF-8 instead of
            the platform-dependent locale default.

    Returns:
        ``[[src_line, tgt_line], ...]`` in file order.

    Raises:
        ValueError: if the two files have a different number of lines.
    """
    # ``with`` guarantees both handles are closed even if reading fails.
    with open(src_path, "r", encoding=encoding) as src_fd, \
            open(tgt_path, "r", encoding=encoding) as tgt_fd:
        src_lines = src_fd.readlines()
        tgt_lines = tgt_fd.readlines()
    # Misaligned files would otherwise crash with IndexError or silently drop
    # target lines; fail loudly with a useful message instead.
    if len(src_lines) != len(tgt_lines):
        raise ValueError(
            f"{src_path} has {len(src_lines)} lines but "
            f"{tgt_path} has {len(tgt_lines)}"
        )
    return [[src, tgt] for src, tgt in zip(src_lines, tgt_lines)]

def read_dataset(name, subpath=""):
    """Load a Hugging Face dataset, caching it under ``<base_path>/output/``.

    On the first call the dataset is downloaded with ``load_dataset`` and
    written to ``<base_path>/output/<name>/<subpath>/``. Later calls load it
    from there with ``load_from_disk``.

    Args:
        name: dataset identifier passed to ``load_dataset``.
        subpath: dataset configuration name (e.g. ``'en-fr'``). An empty string
            loads the dataset's default configuration.

    Returns:
        The loaded dataset object.
    """
    dataset_path = os.path.join(options.base_path, "output", name, subpath, "")
    if os.path.exists(dataset_path):
        dataset = load_from_disk(dataset_path)
    else:
        dataset = load_dataset(name, subpath) if subpath else load_dataset(name)
        dataset.save_to_disk(dataset_path)
    logger.info("{} - {} done", name, subpath)
    # Callers index the result (e.g. dataset['train']), so it must be returned.
    return dataset

def read_hearthstone_data(tokenizer, vocab):
    """Tokenize the Hearthstone train/dev/test line pairs and grow ``vocab``.

    Args:
        tokenizer: callable applied to every source line and target line.
        vocab: Vocab that receives each pair's source tokens, then its
            target tokens.

    Returns:
        List of ``[src_tokens, tgt_tokens]`` pairs: train first, then dev,
        then test.
    """
    logger.info("read raw data")
    hs_dir = options.base_path + 'doc/hearthstone/'
    raw_pairs = []
    for split in ('train', 'dev', 'test'):
        raw_pairs.extend(
            read_line_pair(hs_dir + split + '_hs.in', hs_dir + split + '_hs.out')
        )
    tokenized = []
    for src_line, tgt_line in raw_pairs:
        # Tokenize both sides first, then register source before target tokens.
        pair = [tokenizer(src_line), tokenizer(tgt_line)]
        for tokens in pair:
            vocab.addTokens(tokens)
        tokenized.append(pair)
    return tokenized

def read_euconst_data(tokenizer, vocab):
    """Tokenize the English->French pairs of the opus_euconst ``train`` split.

    Args:
        tokenizer: callable applied to each English and French sentence.
        vocab: Vocab that receives every source token and every target token.

    Returns:
        List of ``[src_tokens, tgt_tokens]`` pairs, in dataset order.
    """
    logger.info("read opus_euconst data")
    dataset = read_dataset('opus_euconst', 'en-fr')
    logger.info("read raw tokens")
    src_lang, tgt_lang = 'en', 'fr'
    train_raw_data = dataset['train']
    logger.info("dataset:opus_euconst, total: {}", train_raw_data.num_rows)
    train_data = []
    # Iterate the split directly instead of preallocating via numpy and
    # stepping a manual iterator by index.
    for example in train_raw_data:
        translation = example['translation']
        src = tokenizer(translation[src_lang])
        tgt = tokenizer(translation[tgt_lang])
        vocab.addTokens(src)
        vocab.addTokens(tgt)
        train_data.append([src, tgt])
    return train_data

def read_xlsum_data(tokenizer, vocab):
    """Tokenize the French GEM/xlsum (text, target) pairs from every split.

    Args:
        tokenizer: callable applied to each ``text`` and ``target`` string.
        vocab: Vocab that receives every source token and every target token,
            as in the other ``read_*_data`` readers.

    Returns:
        List of ``[src_tokens, tgt_tokens]`` pairs: train first, then
        validation, then test, each in dataset order.
    """
    logger.info("read GEM/xlsum data")
    dataset = read_dataset('GEM/xlsum', 'french')
    logger.info("read raw tokens")
    data = []
    # (split name, progress-bar label), in the order the pairs are returned.
    for split, desc in (
        ('train', 'read train data'),
        ('validation', 'read valid data'),
        ('test', 'read test data'),
    ):
        # The file does ``import tqdm`` (the module), so the progress bar
        # class is ``tqdm.tqdm``; calling the module itself raises TypeError.
        for example in tqdm.tqdm(dataset[split], desc=desc):
            src = tokenizer(example['text'])
            tgt = tokenizer(example['target'])
            vocab.addTokens(src)
            vocab.addTokens(tgt)
            data.append([src, tgt])
    return data
# NOTE(review): `tokenizer` is not defined anywhere in this cell, so the call
# below raised NameError. Keep a tokenizer defined by an earlier cell if there
# is one. Otherwise fall back to torchtext's "basic_english" tokenizer, which
# uses the otherwise-unused `get_tokenizer` import. Swap in
# get_base_tokenizer(<model name>) if a subword tokenizer was intended.
if "tokenizer" not in globals():
    tokenizer = get_tokenizer("basic_english")

hs_vocab = Vocab('hs')
euconst_vocab = Vocab('euconst')
xlsum_vocab = Vocab('xlsum')

hs_data = read_hearthstone_data(tokenizer, hs_vocab)






## Plot Graph