In [1]:
from torchtext import data
from torchtext.data import Dataset, Iterator, Field
import numpy as np
from collections import defaultdict, Counter
import pdb

In [2]:
tok_fun = lambda s: s.split()

src_field = data.Field(init_token=None, eos_token="<EOS>",
                           pad_token="<PAD>", tokenize=tok_fun,
                           batch_first=True, lower=True,
                           unk_token="<UNK>",
                           include_lengths=True)

trg_field = data.Field(init_token="<BOS>", eos_token="<EOS>",
                           pad_token="<PAD>", tokenize=tok_fun,
                           unk_token="<UNK>",
                           batch_first=True, lower=True,
                           include_lengths=True)
edge_org_field = data.Field(init_token=None, eos_token=None,
                           pad_token=None ,tokenize=tok_fun,
                           unk_token=None,
                           batch_first=True, lower=True,
                           include_lengths=True)
edge_trg_field = data.Field(init_token=None, eos_token=None,
                           pad_token=None, tokenize=tok_fun,
                           unk_token=None,
                           batch_first=True, lower=True,
                           include_lengths=True)
positional_en_field = data.Field(init_token=None, eos_token=None,
                           pad_token=None, tokenize=tok_fun,
                           unk_token=None,
                           batch_first=True, lower=True,
                           include_lengths=True)



In [3]:
print(src_field)

<torchtext.data.field.Field object at 0x7f178eacdf10>


In [12]:
class GraphTranslationDataset(data.Dataset):
    """Defines a dataset for machine translation with a graph reprsentation on the input and levi graph transformations."""

    @staticmethod
    def sort_key(ex):
        return data.interleave_keys(len(ex.src), len(ex.trg))

    def __init__(self, src_file, trg_file, fields, **kwargs):
        """Create a TranslationDataset given paths and fields.
        Arguments:
            path: Common prefix of paths to the data files for both languages.
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
        """
        if not isinstance(fields[0], (tuple, list)):
            fields = [('src', fields[0]), ('trg', fields[1]),\
                     ('edge_org', fields[2]), ('edge_trg', fields[3]), ('positional_en',fields[4])]

        examples = []
        source_words,origins,targets=self.read_conllu(src_file)
        pes=self.gen_pes(source_words,origins,targets)
        
        target_words=self.read_text_file(trg_file)
        assert len(source_words)==len(target_words),"Mismatch of source and tagret sentences"
        print(pes)
        print(targets)
        print(origins)
        for i in range(len(source_words)):
                src_line, trg_line = " ".join(source_words[i]),target_words[i]
                src_line, trg_line = src_line.strip(), trg_line.strip()
                
                if src_line != '' and trg_line != '':
                    examples.append(data.Example.fromlist(
                        [src_line, trg_line," ".join(origins[i])," ".join(targets[i]),\
                        " ".join(pes[i])],\
                        fields))
        super(GraphTranslationDataset, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, exts, fields, path=None, root='.data',
               train='train', validation='val', test='test', **kwargs):
        """Create dataset objects for splits of a TranslationDataset.
        Arguments:
            exts: A tuple containing the extension to path for each language.
            fields: A tuple containing the fields that will be used for data
                in each language.
            path (str): Common prefix of the splits' file paths, or None to use
                the result of cls.download(root).
            root: Root dataset storage directory. Default is '.data'.
            train: The prefix of the train data. Default: 'train'.
            validation: The prefix of the validation data. Default: 'val'.
            test: The prefix of the test data. Default: 'test'.
            Remaining keyword arguments: Passed to the splits method of
                Dataset.
        """
        if path is None:
            path = cls.download(root)

        train_data = None if train is None else cls(
            os.path.join(path, train), exts, fields, **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), exts, fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), exts, fields, **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)
    def read_conllu(self,path):
        """
        creates three lists: one with the sentences, and two that represent the edges fot he graph
        Argmunets:
            path: path to a file with sentences in the ConLL-U standard
        """
        f=open(path,'r')
        lines=f.readlines()
        f.close()
        words=[]
        origins=[]
        targets=[]
        edges=[]
        temp_words=[]
        temp_origins=[]
        temp_targets=[]
        temp_edges=[]
        for line in lines:
            if line=='\n'or line=='':
                words.append(temp_words)
                origins.append(np.array(temp_origins))
                targets.append(np.array(temp_targets))
                edges.append(temp_edges)
        
                temp_words=[]
                temp_origins=[]
                temp_targets=[]
                temp_edges=[]
    
            else:
                splits=line.split('\t')
                temp_words.append(splits[1])
                temp_origins.append(int(splits[0]))
                temp_targets.append(int(splits[6]))
                temp_edges.append("<"+splits[7]+">")
        for i in range(len(words)):
            new_origins=origins[i]-1
            edges_positions=np.arange(len(words[i]),2*len(words[i]))
            new_targets=edges_positions.copy()
            
            edge_targets=targets[i]-1
            root_pos=np.argmin(edge_targets)
            edge_targets = np.delete(edge_targets, [root_pos])
            edge_origins = np.delete(edges_positions,[root_pos])
            origins[i] = [str(num) for num in list(np.concatenate((new_origins,edge_origins)))]
            targets[i] = [str(num) for num in list(np.concatenate((new_targets,edge_targets)))]
            assert len(targets[i])==len(origins[i])
            words[i]=words[i]+edges[i]
            
        return words,origins,targets
    
    def read_text_file(self,path):
        """Read a text file 
        Argmunets:
            path: path to a normal txt file
        """
        f=open(path,'r')
        lines=f.readlines()
        f.close()
        return lines
    def gen_pe(self,words,org,trg,root_kw="<root>"):
        """Calculates the min distance to the root to each node using BFS
        Argmunets:
            words: all the words of the sentence
            org: a list with the origin of each edge
            trg: a list with the target of each edge
            root_kw: the keyword of the roo tag in the sentence
        """
        start=None
        for ind,word in enumerate(words):
            if word==root_kw:
                start=ind
                continue
        assert start!=None,"sentence does not have a <root> tag"
        visited=[start]
        distance_queue=[1]
        distances=['0']*len(words)
        while len(visited)!=0:
            for index,node in enumerate(trg):
                if str(node)==str(visited[0]):
                    distances[int(org[index])]=str(distance_queue[0])
                    visited.append(org[index])
                    distance_queue.append(distance_queue[0]+1)
            visited.pop(0)
            distance_queue.pop(0)
        return distances
    def gen_pes(self,source_words,orgs,trgs,root_kw="<root>"):
        """
        Generates the positional embeddings for all the sentences in the dataset
        argmunets:
            source_words: a list of sentences 
            orgs: a list of lists of ede origins
            trgs: a list of lists of the edge targets
            root_kw:the keyword of the root tag in the senteces
        """
        pes=[]
        for i in range(len(source_words)):
            pes.append(self.gen_pe(source_words[i],orgs[i],trgs[i],root_kw))
        return pes
    
        

In [13]:
DEFAULT_UNK_ID = lambda: 0
UNK_TOKEN = "<UNK>"
PAD_TOKEN = "<PAD>"
BOS_TOKEN = "<BOS>"
EOS_TOKEN = "<EOS>"
class Vocabulary:
    """ Vocabulary represents mapping between tokens and indices. """

    def __init__(self, tokens = None, file: str = None) -> None:
        """
        Create vocabulary from list of tokens or file.

        Special tokens are added if not already in file or list.
        File format: token with index i is in line i.

        :param tokens: list of tokens
        :param file: file to load vocabulary from
        """
        # don't rename stoi and itos since needed for torchtext
        # warning: stoi grows with unknown tokens, don't use for saving or size

        # special symbols
        self.specials = ["<UNK>", "<PAD>", "<BOS>", "<EOS>"]

        self.stoi = defaultdict(DEFAULT_UNK_ID)
        self.itos = []
        if tokens is not None:
            self._from_list(tokens)
        elif file is not None:
            self._from_file(file)

    def _from_list(self, tokens = None) -> None:
        """
        Make vocabulary from list of tokens.
        Tokens are assumed to be unique and pre-selected.
        Special symbols are added if not in list.

        :param tokens: list of tokens
        """
        self.add_tokens(tokens=self.specials+tokens)
        assert len(self.stoi) == len(self.itos)

    def _from_file(self, file: str) -> None:
        """
        Make vocabulary from contents of file.
        File format: token with index i is in line i.

        :param file: path to file where the vocabulary is loaded from
        """
        tokens = []
        with open(file, "r") as open_file:
            for line in open_file:
                tokens.append(line.strip("\n"))
        self._from_list(tokens)

    def __str__(self) -> str:
        return self.stoi.__str__()

    def to_file(self, file: str) -> None:
        """
        Save the vocabulary to a file, by writing token with index i in line i.

        :param file: path to file where the vocabulary is written
        """
        with open(file, "w") as open_file:
            for t in self.itos:
                open_file.write("{}\n".format(t))

    def add_tokens(self, tokens) -> None:
        """
        Add list of tokens to vocabulary

        :param tokens: list of tokens to add to the vocabulary
        """
        for t in tokens:
            new_index = len(self.itos)
            # add to vocab if not already there
            if t not in self.itos:
                self.itos.append(t)
                self.stoi[t] = new_index

    def is_unk(self, token: str) -> bool:
        """
        Check whether a token is covered by the vocabulary

        :param token:
        :return: True if covered, False otherwise
        """
        return self.stoi[token] == DEFAULT_UNK_ID()

    def __len__(self) -> int:
        return len(self.itos)

    def array_to_sentence(self, array: np.array, cut_at_eos=True,
                          skip_pad=True):
        """
        Converts an array of IDs to a sentence, optionally cutting the result
        off at the end-of-sequence token.

        :param array: 1D array containing indices
        :param cut_at_eos: cut the decoded sentences at the first <eos>
        :param skip_pad: skip generated <pad> tokens
        :return: list of strings (tokens)
        """
        sentence = []
        for i in array:
            s = self.itos[i]
            if cut_at_eos and s == EOS_TOKEN:
                break
            if skip_pad and s == PAD_TOKEN:
                continue
            sentence.append(s)
        return sentence

    def arrays_to_sentences(self, arrays: np.array, cut_at_eos=True,
                            skip_pad=True):
        """
        Convert multiple arrays containing sequences of token IDs to their
        sentences, optionally cutting them off at the end-of-sequence token.

        :param arrays: 2D array containing indices
        :param cut_at_eos: cut the decoded sentences at the first <eos>
        :param skip_pad: skip generated <pad> tokens
        :return: list of list of strings (tokens)
        """
        sentences = []
        for array in arrays:
            sentences.append(
                self.array_to_sentence(array=array, cut_at_eos=cut_at_eos,
                                       skip_pad=skip_pad))
        return sentences


def build_vocab(field: str, max_size: int, min_freq: int, dataset: Dataset,
                vocab_file: str = None) -> Vocabulary:
    """
    Builds vocabulary for a torchtext `field` from given`dataset` or
    `vocab_file`.

    :param field: attribute e.g. "src"
    :param max_size: maximum size of vocabulary
    :param min_freq: minimum frequency for an item to be included
    :param dataset: dataset to load data for field from
    :param vocab_file: file to store the vocabulary,
        if not None, load vocabulary from here
    :return: Vocabulary created from either `dataset` or `vocab_file`
    """

    if vocab_file is not None:
        # load it from file
        vocab = Vocabulary(file=vocab_file)
    else:
        # create newly
        def filter_min(counter: Counter, min_freq: int):
            """ Filter counter by min frequency """
            filtered_counter = Counter({t: c for t, c in counter.items()
                                        if c >= min_freq})
            return filtered_counter

        def sort_and_cut(counter: Counter, limit: int):
            """ Cut counter to most frequent,
            sorted numerically and alphabetically"""
            # sort by frequency, then alphabetically
            tokens_and_frequencies = sorted(counter.items(),
                                            key=lambda tup: tup[0])
            tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
            vocab_tokens = [i[0] for i in tokens_and_frequencies[:limit]]
            return vocab_tokens

        tokens = []
        for i in dataset.examples:
            if field == "src":
                tokens.extend(i.src)
            elif field == "trg":
                tokens.extend(i.trg)
                
            ### add edge_tokens
            elif field == "edge_org":
                tokens.extend(i.edge_org)
            elif field == "edge_trg":
                tokens.extend(i.edge_trg)
            elif field =="positional_en":
                tokens.extend(i.positional_en)
                

        counter = Counter(tokens)
        if min_freq > -1:
            counter = filter_min(counter, min_freq)
        vocab_tokens = sort_and_cut(counter, max_size)
        assert len(vocab_tokens) <= max_size

        vocab = Vocabulary(tokens=vocab_tokens)
        assert len(vocab) <= max_size + len(vocab.specials)
        assert vocab.itos[DEFAULT_UNK_ID()] == UNK_TOKEN

    # check for all except for UNK token whether they are OOVs
    for s in vocab.specials[1:]:
        assert not vocab.is_unk(s)

    return vocab

In [14]:
english_sentences=["I am happy to be here","This is the test sentence number 2","the last test sentence is  this one"]
spanish_sentences=["estoy feliz de estar aquí","ésta es la oración de prueba número 2","La última ouración de prueba es ésta"]
edge_indexes=[[[0,1,2,3,4,5],[0,1,2,3,4,5]],[[0,1,2,3,4],[0,1,2,3,4]],[[0,1,2,3],[0,1,2,3]]]
dataset=GraphTranslationDataset(\
                           "/home/hec44/Documents/joeynmt/test/data/toy/dev.conll",\
                           "/home/hec44/Documents/joeynmt/test/data/toy/dev.de",\
                           fields=(src_field, trg_field,edge_org_field,edge_trg_field,positional_en_field))

[['3', '3', '1', '5', '5', '3', '3', '2', '2', '0', '4', '4', '2', '2'], ['1', '3', '3', '3', '0', '2', '2', '2'], ['3', '3', '3', '3', '3', '3', '1', '3', '2', '2', '2', '2', '2', '2', '0', '2'], ['3', '3', '1', '5', '3', '7', '7', '5', '5', '5', '5', '7', '11', '11', '11', '9', '3', '2', '2', '0', '4', '2', '6', '6', '4', '4', '4', '4', '6', '10', '10', '10', '8', '2'], ['3', '1', '5', '3', '5', '5', '5', '3', '5', '7', '5', '9', '9', '9', '7', '7', '3', '5', '3', '5', '7', '7', '7', '7', '5', '9', '7', '13', '11', '9', '9', '9', '7', '11', '9', '11', '13', '13', '11', '15', '15', '13', '3', '2', '0', '4', '2', '4', '4', '4', '2', '4', '6', '4', '8', '8', '8', '6', '6', '2', '4', '2', '4', '6', '6', '6', '6', '4', '8', '6', '12', '10', '8', '8', '8', '6', '10', '8', '10', '12', '12', '10', '14', '14', '12', '2'], ['3', '3', '3', '3', '3', '1', '3', '2', '2', '2', '2', '2', '0', '2'], ['3', '3', '3', '1', '5', '5', '5', '3', '7', '5', '7', '11', '11', '9', '3', '5', '5', '3', '3', '2'

In [15]:
"""
        source_words,origins,targets=self.read_conllu(src_file)
        pes=self.gen_pes(source_words,origins,targets)
        
        target_words=self.read_text_file(trg_file)
"""

src_vocab = build_vocab(field="src", min_freq=0,
                            max_size=99,
                            dataset=dataset, vocab_file=None)
trg_vocab = build_vocab(field="trg", min_freq=0,
                            max_size=99,
                            dataset=dataset, vocab_file=None)
edge_org_vocab = build_vocab(field="edge_org", min_freq=0,
                            max_size=99,
                            dataset=dataset, vocab_file=None)
edge_trg_vocab = build_vocab(field="edge_trg", min_freq=0,
                            max_size=99,
                            dataset=dataset, vocab_file=None)
positional_en_vocab = build_vocab(field="positional_en", min_freq=0,
                            max_size=99,
                            dataset=dataset, vocab_file=None)
src_field.vocab = src_vocab
trg_field.vocab = trg_vocab
edge_org_field.vocab = edge_org_vocab
edge_trg_field.vocab = edge_trg_vocab
positional_en_field.vocab = positional_en_vocab

In [16]:
next(iter(dataset.positional_en))

['3', '3', '1', '5', '5', '3', '3', '2', '2', '0', '4', '4', '2', '2']

In [17]:
data_iter = data.BucketIterator(
            repeat=False, sort=False, dataset=dataset,
            batch_size=3, batch_size_fn=None,
            train=True, sort_within_batch=True,
            sort_key=lambda x: len(x.src), shuffle=True)



In [18]:
batch=next(iter(data_iter))



In [24]:
batch.positional_en

(tensor([[ 5,  5,  5,  5,  5, 15,  5,  4,  4,  4,  4,  4, 14,  4],
         [ 5, 15,  7,  5,  5,  4, 14,  6,  4,  4,  0,  0,  0,  0],
         [15,  5,  5,  5, 14,  4,  4,  4,  0,  0,  0,  0,  0,  0]]),
 tensor([14, 10,  8]))

In [25]:
hasattr(batch, "positional_en")

True

In [None]:

fields = [('edge_org', data.Field(sequential=False)), ('edge_trg', data.Field(sequential=False))]
data.Example.fromlist([[0,1,2,3,4,5],\
                       [1,2,3,4,5,0]],\
                      fields)


In [None]:
ex=data.Field(sequential=True)

In [None]:
ex.vocab

In [None]:
f=open('/home/hec44/Documents/joeynmt/test/data/toy/dev.conll','r')
lines=f.readlines()
f.close()
words=[]
origins=[]
targets=[]
edges=[]
temp_words=[]
temp_origins=[]
temp_targets=[]
temp_edges=[]
for line in lines:
    if line=='\n'or line=='':
        words.append(temp_words)
        origins.append(np.array(temp_origins))
        targets.append(np.array(temp_targets))
        edges.append(temp_edges)
        
        temp_words=[]
        temp_origins=[]
        temp_targets=[]
        temp_edges=[]
    
    else:
        splits=line.split('\t')
        temp_words.append(splits[1])
        temp_origins.append(int(splits[0]))
        temp_targets.append(int(splits[6]))
        temp_edges.append("<"+splits[7]+">")
        
    
      

In [None]:
for i in range(len(words)):
    new_origins=origins[i]-1
    edges_positions=np.arange(len(words[i]),2*len(words[i]))
    new_targets=edges_positions.copy()
    
    edge_targets=targets[i]-1
    root_pos=np.argmin(edge_targets)
    edge_targets = np.delete(edge_targets, [root_pos])
    edge_origins = np.delete(edges_positions,[root_pos])
    origins[i]=list(np.concatenate((new_origins,edge_origins)))
    targets[i]=list(np.concatenate((new_targets,edge_targets)))
    words[i]=words[i]+edges[i]
    


In [None]:
words[0]

In [None]:
origins[0]  

In [None]:
targets[0]

In [None]:
f=open('/home/hec44/Documents/joeynmt/test/data/toy/dev.conll','r')
lines=f.readlines()
f.close()
lines

In [None]:
def gen_pe(words,org,trg,root_kw="<root>"):
    start=None
    for ind,word in enumerate(words):
        if word==root_kw:
            start=ind
            continue
    assert start!=None,"sentence does not have a <root> tag"
    visited=[start]
    distance_queue=[1]
    distances=[0]*len(words)
    while len(visited)!=0:
        for index,node in enumerate(trg):
            if node==visited[0]:
                distances[org[index]]=distance_queue[0]
                visited.append(org[index])
                distance_queue.append(distance_queue[0]+1)
        visited.pop(0)
        distance_queue.pop(0)
    return distances
gen_pe(words[0],origins[0],targets[0])

In [None]:
test={}
test['a']={'b':0}
test['c']={'d':1}

In [None]:
test['a']['b']['d'].get('dfcdf',None)