In [None]:
import nltk
from nltk.corpus import wordnet as wn

In [None]:
import pickle
import spacy
from pathlib import Path

In [None]:
import shutil
def make_to_path(path):
    """Coerce a string path to ``pathlib.Path``; return anything else unchanged.

    Args:
        path: a ``str`` (including ``str`` subclasses) or an already-built path object.

    Returns:
        ``Path(path)`` for strings, otherwise ``path`` itself (same object, not a copy).
    """
    # isinstance rather than ``type(...) ==`` so str subclasses are converted too.
    return Path(path) if isinstance(path, str) else path
def remove_if_exists(path):
    """Delete ``path`` if present: directories recursively, files and links by unlinking.

    Symlinks (including broken ones) are unlinked rather than followed, so a link
    that points at a directory removes only the link, never the target tree.

    Args:
        path: ``str`` or ``Path`` to remove. Missing paths are ignored.
    """
    path = make_to_path(path)
    if path.is_symlink():
        # Check links first: is_dir()/exists() follow links, shutil.rmtree() refuses
        # symlinks, and a broken link reports exists() == False.
        path.unlink()
    elif path.is_dir():
        shutil.rmtree(path)
    elif path.exists():
        path.unlink()
def create_dir(path):
    """Ensure ``path`` exists (creating parent directories) and return it as a Path.

    A path that already exists is returned untouched, even if it is a regular file.
    """
    target = make_to_path(path)
    if target.exists():
        return target
    target.mkdir(parents=True, exist_ok=True)
    return target

In [None]:
# WARNING: deletes everything under ./data on every run. The Tokenizer, TextGenerator and
# LangTree pickles/text files (default dirs ./data/*_data) are then rebuilt from scratch.
remove_if_exists('./data')
create_dir('./data')

In [None]:
class TextGenerator():
    """Serve raw texts from disk, one file at a time.

    Either ``path_list`` (existing text files) or ``text_list`` (in-memory strings)
    must be a list. When ``path_list`` is None, every string in ``text_list`` is
    written to ``dest_path / "{i}.txt"``, so both inputs are read back the same way
    by :meth:`get_next_text`.

    Args:
        path_list: list of paths to existing text files, or None.
        text_list: list of strings to materialise on disk, or None.
        dest_path: directory that receives the materialised ``text_list`` files.
        nlp: spaCy pipeline to keep; ``en_core_web_sm`` is loaded when None.
        encoding: encoding used both to write and to read the text files.
    """
    def __init__(self, path_list=None, text_list=None, dest_path=Path('./data/text_gen_data'), nlp=None, encoding="utf8"):
        assert isinstance(path_list, list) or isinstance(text_list, list), \
            "pass path_list or text_list as a list"
        # Set before create_data_paths() so files are written with the same encoding
        # get_next_text() later reads them with.
        self.encoding = encoding
        dest_path = create_dir(make_to_path(dest_path))
        if path_list is None:
            self.path_list = self.create_data_paths(dest_path, text_list)
        else:
            self.path_list = path_list
        self.nlp = spacy.load("en_core_web_sm") if nlp is None else nlp

    def create_data_paths(self, base_path, text_list):
        """Write each text to ``base_path / "{i}.txt"`` and return the paths in input order."""
        paths = []
        for i, text in enumerate(text_list):
            dest_path = base_path / f"{i}.txt"
            paths.append(dest_path)
            # Explicit encoding: the platform default (e.g. cp1252 on Windows) may fail on,
            # or disagree with, the encoding used when reading the file back.
            with open(dest_path, "w", encoding=self.encoding) as f:
                f.write(text)
        return paths

    def get_next_text(self):
        """Yield ``(path, full_file_content)`` for every path in ``self.path_list``."""
        for p in self.path_list:
            with open(p, 'r', encoding=self.encoding) as f:
                yield p, f.read()

In [None]:
# Small hand-written sample sentences. A few look deliberately ungrammatical
# (e.g. 'sentence illegal is'). NOTE(review): no later cell in this notebook uses
# `sentences`; the pipeline below runs on the Brown news text instead.
sentences = [
    'Hariom is playing',
    'I got milk from the store',
    'from the store, I got milk',
    'hey hey hey',
    'this sentence can be structured to a tree in a systematic manner',
    'what is he doing',
    'sentence illegal is',
    'sentence is legal',
    'This sentence is going to be long because we need to check what is going on',
    'The poop is kept on the beautiful table',
    'bread and butter is good for health',
    'wine is great with bread and butter is good for health'
]

In [None]:
# Shared spaCy English pipeline. Tokenizer, Vocab and TextGenerator accept it via their
# `nlp=` argument and load their own copy of en_core_web_sm when it is omitted.
nlp = spacy.load("en_core_web_sm")

In [None]:
class Vocab():
    """Token <-> index vocabulary built from token counts.

    Intended workflow (as used by ``Tokenizer``):
      1. call :meth:`push_tok_to_word_index_mappings` for every token;
      2. call :meth:`trim_word_index_mappings` to keep the special tokens plus the
         ``max_len`` most frequent other tokens;
      3. call :meth:`get_trimmed_tokens` to map out-of-vocabulary tokens to ``xxunk``.

    Attributes:
        itos: list, index -> token string.
        stoi: dict, token string -> index.
        stoc: dict, token string -> occurrence count.
    """
    def __init__(self, tokenizer, max_len=20000, nlp=None):
        self.tokenizer, self.max_len = tokenizer, max_len
        self.nlp = spacy.load("en_core_web_sm") if nlp is None else nlp
        self.itos, self.stoi, self.stoc = [], {}, {}
        self.special_tokens = ["xxbos", "xxeos", "xxmaj", "xxcap", "xxunk"]

    def trim_word_index_mappings(self):
        """Rebuild itos/stoi/stoc: special tokens first, then the top ``max_len`` tokens by count."""
        special = set(self.special_tokens)
        # reversed(sorted(...)) gives descending counts. Among equal counts the most
        # recently inserted token comes first. Kept as-is so index assignment doesn't change.
        ranked = [(tok, cnt)
                  for tok, cnt in reversed(sorted(self.stoc.items(), key=lambda item: item[1]))
                  if tok not in special]

        new_itos = list(self.special_tokens)
        new_stoi = {tok: i for i, tok in enumerate(new_itos)}
        # Keep an entry for every stoi key (special tokens included). Otherwise a later
        # push of e.g. "xxbos" would find it in stoi but hit KeyError on stoc.
        new_stoc = {tok: self.stoc.get(tok, 0) for tok in new_itos}
        for tok, cnt in ranked[:max(0, self.max_len)]:
            new_stoi[tok] = len(new_itos)
            new_stoc[tok] = cnt
            new_itos.append(tok)
        self.itos, self.stoi, self.stoc = new_itos, new_stoi, new_stoc

    def update_token_list(self, token_list, word_to_leaf_type):
        """Replace tokens missing from ``stoi`` with ``("xxunk", "SPCL_APPEND")``.

        ``word_to_leaf_type`` is accepted for interface compatibility but not used.
        """
        return [(tok, ltype) if tok in self.stoi else ("xxunk", "SPCL_APPEND")
                for tok, ltype in token_list]

    def push_tok_to_word_index_mappings(self, tok):
        """Add ``tok`` with count 1 at the next free index, or increment its count."""
        if tok not in self.stoi:
            ind = len(self.itos)
            self.stoi[tok], self.stoc[tok] = ind, 1
            self.itos.append(tok)
        else:
            self.stoc[tok] += 1

    def get_trimmed_tokens(self, token_list, word_to_leaf_type):
        """Alias of :meth:`update_token_list`."""
        return self.update_token_list(token_list, word_to_leaf_type)

In [None]:
class Tokenizer():
    """spaCy-based tokenizer that POS-tags texts, caches token lists on disk and builds a Vocab.

    ``__init__`` runs the whole pipeline:
      1. tag every text into ``(token.text, token.pos_)`` pairs and apply the callbacks
         (by default: wrap in xxbos/xxeos, lower-case with xxcap/xxmaj markers);
      2. pickle one token list per input text under ``dest_path``;
      3. build a :class:`Vocab` limited to ``max_len`` tokens (plus special tokens);
      4. rewrite the pickles with out-of-vocabulary tokens replaced by ``xxunk``.

    Args:
        path_list, text_list: inputs forwarded to :class:`TextGenerator`.
        dest_path: directory for the pickled token lists.
        apply_defaults: run the default callbacks (put_bos_eos, deal_with_caps) first.
        max_len: vocabulary size limit. Also used as the chunk length of
            :meth:`get_next_token_list`.
        callbacks: extra ``token_list -> token_list`` callables run after the defaults.
        nlp: spaCy pipeline to reuse; ``en_core_web_sm`` is loaded when None.
    """
    def __init__(self, path_list=None, text_list=None, dest_path=Path('./data/tokenizer_data'),
                 apply_defaults=True, max_len=2000, callbacks=None,
                 nlp=None):
        dest_path = create_dir(make_to_path(dest_path))
        self.leaf_types = {
            "ROOT_LEAF": 0,
            "ONE_CHILD_LEAF": 1,
            "TWO_CHILD_LEAF": 2,
            "BRIDGE_LEAF": 3,
            "INVARIANT": 4,
        }
        self.word_to_leaf_type = self._build_word_to_leaf_type(self.leaf_types)

        self.nlp = spacy.load("en_core_web_sm") if nlp is None else nlp
        self.text_generator = TextGenerator(path_list, text_list, nlp=self.nlp)
        self.tokens_paths = []
        # Defaults first, then user callbacks (None instead of a shared mutable [] default).
        self.callbacks = [self.put_bos_eos, self.deal_with_caps] if apply_defaults else []
        if callbacks:
            self.callbacks.extend(callbacks)
        self.tokens_len = self.prepare_initial_tokens(path_list, dest_path)
        self.max_len = max_len
        self.vocab = Vocab(self, max_len, nlp=self.nlp)
        self.prepare_vocab()
        self.update_tokens_lists()

    @staticmethod
    def _build_word_to_leaf_type(leaf_types):
        """Map spaCy coarse POS tags and the pipeline's special tags to tree leaf types."""
        one, two = leaf_types["ONE_CHILD_LEAF"], leaf_types["TWO_CHILD_LEAF"]
        bridge, invariant = leaf_types["BRIDGE_LEAF"], leaf_types["INVARIANT"]
        return {
            "PUNCT": bridge,
            "SYM": one,
            "ADJ": one,
            "ADP": one,
            "CCONJ": bridge,
            "NUM": one,
            "DET": one,
            "ADV": two,
            "X": one,
            "VERB": two,
            "NOUN": one,
            "PROPN": one,
            "PART": one,
            "INTJ": bridge,
            "PRON": one,
            # Tags spaCy can emit that the original table lacked. Without them get_ltype()
            # raises KeyError whenever such a token survives vocab trimming (recent spaCy
            # models tag auxiliaries like "is" as AUX). Mappings chosen by analogy
            # (AUX~VERB, SCONJ~CCONJ, SPACE appended in place): NOTE(review) confirm.
            "AUX": two,
            "SCONJ": bridge,
            "SPACE": invariant,
            "SPCL_APPEND": invariant,
            "SPCL_BRIDGE": bridge,
            "ROOT": leaf_types["ROOT_LEAF"],
        }

    def tokenify(self, vector):
        """Map an iterable of 0-d index tensors (e.g. a 1-D tensor) back to token strings."""
        # Look up this instance's vocab; the original read the notebook-global ``tk``.
        return [self.vocab.itos[i.item()] for i in vector]

    def put_bos_eos(self, tokens_list):
        """Wrap a token list in xxbos ... xxeos markers."""
        return [("xxbos", "SPCL_BRIDGE"), *tokens_list, ("xxeos", "SPCL_BRIDGE")]

    def deal_with_caps(self, token_list):
        """Lower-case tokens, appending "xxcap" (all caps) or "xxmaj" (capitalised) after each."""
        nlist = []
        for token in token_list:
            text = token[0]
            nlist.append((text.lower(), token[1]))
            if text.isupper():
                nlist.append(("xxcap", "SPCL_APPEND"))
            elif text[:1].isupper():  # slice: an empty token must not raise IndexError
                nlist.append(("xxmaj", "SPCL_APPEND"))
        return nlist

    def apply_callbacks(self, token_list):
        """Thread ``token_list`` through every callback in order."""
        for callback in self.callbacks:
            token_list = callback(token_list)
        return token_list

    def reset_vocab(self):
        """Replace the vocab with an empty one, reusing this tokenizer's spaCy pipeline."""
        self.vocab = Vocab(self, self.max_len, nlp=self.nlp)

    def prepare_vocab(self):
        """Count every cached token into the vocab, then trim it to ``max_len``."""
        for _, token_list in self.get_next_token_list():
            for tok, _ in token_list:
                self.vocab.push_tok_to_word_index_mappings(tok)
        self.vocab.trim_word_index_mappings()

    def prepare_initial_tokens(self, src_paths, dest_root):
        """Tag every text, apply callbacks, pickle to ``dest_root/<stem>.pkl``; return total token count.

        ``src_paths`` is unused (inputs come from ``self.text_generator``); kept for compatibility.
        """
        tot_len = 0
        for path, content in self.text_generator.get_next_text():
            # Path() so plain-string entries of path_list also work.
            dest_path = dest_root / f"{Path(path).stem}.pkl"
            self.tokens_paths.append(dest_path)
            token_list = self.apply_callbacks(self.simplified_tagged_list(content))
            tot_len += len(token_list)
            with open(dest_path, 'wb') as f:
                pickle.dump(token_list, f)
        return tot_len

    def get_next_token_list(self):
        """Yield ``(pickle_path, chunk)`` with consecutive chunks of at most ``max_len`` tokens."""
        for p in self.tokens_paths:
            with open(p, "rb") as f:
                tk_list = pickle.load(f)
            for i in range(0, len(tk_list), self.max_len):
                yield p, tk_list[i:i + self.max_len]

    def update_tokens_lists(self):
        """Rewrite every cached token list with out-of-vocabulary tokens replaced by xxunk."""
        for p in self.tokens_paths:
            with open(p, "rb") as f:
                tk_list = pickle.load(f)
            new_tk_list = self.vocab.get_trimmed_tokens(tk_list, self.word_to_leaf_type)
            with open(p, "wb") as f:
                pickle.dump(new_tk_list, f)

    def simplified_tagged_list(self, sentence):
        """Run spaCy on ``sentence`` and return ``(token.text, token.pos_)`` pairs."""
        return [(token.text, token.pos_) for token in self.nlp(sentence)]

    def get_indices_from_tokens(self, tokens):
        """Vocabulary indices for ``(token, wtype)`` pairs (tokens must be in ``stoi``)."""
        return [self.vocab.stoi[token[0]] for token in tokens]

    def get_word_tensors(self, tokens):
        """Long tensor of shape (len(tokens), 1) with token indices, on CUDA when available."""
        indices = self.get_indices_from_tokens(tokens)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.tensor(indices, dtype=torch.long, device=device).view(-1, 1)

    def get_ltype(self, token):
        """Leaf type for a POS/word-type string (KeyError for unknown tags)."""
        return self.word_to_leaf_type[token]

    def get_ltypes_dict(self):
        """Leaf-type name -> integer id mapping."""
        return self.leaf_types


In [None]:
def get_brown_text(categories='news'):
    """Return the NLTK Brown-corpus words for ``categories`` (default: the 'news' section).

    Downloads the ``brown`` corpus on first use. ``download_tag_packs`` is not called
    anywhere in this notebook, so on a fresh environment ``brown.words`` would
    otherwise raise LookupError.

    Args:
        categories: Brown category name (or list of names) passed to ``brown.words``.

    Returns:
        The word sequence returned by ``brown.words(categories=...)``.
    """
    import nltk
    from nltk.corpus import brown
    try:
        return brown.words(categories=categories)
    except LookupError:
        nltk.download('brown', quiet=True)
        return brown.words(categories=categories)

In [None]:
# Flatten the Brown 'news' words into one space-prefixed string (same text as the
# original += loop, leading space included). Built with join, which is linear, instead
# of repeated string concatenation. Wrapped in a one-element list because TextGenerator
# asserts that text_list is a list.
brown_words = get_brown_text()
btext = "".join(f" {word}" for word in brown_words)
btext_list = [btext]


In [None]:
import gc
# Best-effort garbage collection after building the large Brown-news string. CPython frees
# non-cyclic objects by refcount already, so this mainly reclaims reference cycles.
gc.collect()

In [None]:
from nltk.tag import pos_tag, map_tag
from nltk.tokenize import word_tokenize

def unsimplified_tagged_list(sentence):
    """NLTK-tokenize and POS-tag ``sentence``; returns (word, Penn Treebank tag) pairs."""
    return pos_tag(word_tokenize(sentence))
    
def tagged_list(sentence):
    """NLTK-tag ``sentence`` and map each Penn Treebank tag to the universal tagset."""
    return [
        (word, map_tag('en-ptb', 'universal', ptb_tag))
        for word, ptb_tag in pos_tag(word_tokenize(sentence))
    ]

def download_tag_packs(quiet=False):
    """Download the NLTK data packages used in this notebook, each exactly once.

    Args:
        quiet: forwarded to ``nltk.download`` to suppress its progress output.
    """
    import nltk
    # The original list fetched "punkt" twice.
    packages = (
        "punkt",
        "tagsets",
        "universal_tagset",            # map_tag('en-ptb', 'universal', ...)
        "maxent_treebank_pos_tagger",
        "maxent_ne_chunker",
        "wordnet",                     # nltk.corpus.wordnet
        "averaged_perceptron_tagger",
        "brown",                       # get_brown_text
    )
    for package in packages:
        nltk.download(package, quiet=quiet)

# get_brown_text() is defined in an earlier cell (before its first use). The identical
# duplicate that lived here only shadowed it, so it was removed to keep one definition.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [None]:
class InvalidDirectionError(Exception):
    """Exception whose message is its first positional argument (None when raised bare)."""

    def __init__(self, *args):
        # Only the first argument becomes ``message``; BaseException still records all of args.
        self.message = args[0] if args else None

    def __str__(self):
        prefix = 'InvalidDirectionError: '
        return f'{prefix}{self.message}' if self.message else prefix

In [None]:
class BaseLeaf():
    """Common state and traversal for nodes of a SentenceTree.

    A leaf holds a run of ``(token, wtype)`` pairs (``wtype`` is a word-type/POS
    string) and its leaf type from ``tokenizer.get_ltype(wtype)``. Once labelled, it
    also holds its next-token training tensors.

    Args:
        token: first token of the leaf (None for the root).
        wtype: word type of ``token``; must be a key understood by ``tokenizer.get_ltype``.
        tokenizer: object providing ``get_ltype``, ``get_ltypes_dict`` and ``get_word_tensors``.
        embedding_size: unused; kept for interface compatibility.
    """
    def __init__(self, token, wtype, tokenizer, embedding_size=50):
        self.tokenizer = tokenizer
        self.tokens, self.leaf_type = [(token, wtype)], tokenizer.get_ltype(wtype)
        self.leaf_types_dict = self.tokenizer.get_ltypes_dict()
        self.example_tokens, self.target_tokens, self.example_ltypes, self.target_ltypes = None, None, None, None

    def _is_root(self):
        """True when this leaf has the ROOT_LEAF type (roots never carry training labels)."""
        # Compare ints by value. The previous ``is not`` only worked through CPython's small-int cache.
        return self.get_self_leaf_type() == self.leaf_types_dict["ROOT_LEAF"]

    def label_tokens(self, example_tokens, target_tokens):
        """Store example/target token tensors built by the tokenizer, or (None, None) for a root."""
        if self._is_root():
            self.example_tokens, self.target_tokens = None, None
            return
        examples = self.tokenizer.get_word_tensors(example_tokens)
        targets = self.tokenizer.get_word_tensors(target_tokens)
        self.example_tokens, self.target_tokens = examples, targets

    def label_ltypes(self, example_ltypes, target_ltypes):
        """Store example/target leaf-type ids as (-1, 1) tensors, or (None, None) for a root."""
        if self._is_root():
            self.example_ltypes, self.target_ltypes = None, None
            return
        examples = torch.tensor(example_ltypes).view(-1, 1)
        targets = torch.tensor(target_ltypes).view(-1, 1)
        self.example_ltypes, self.target_ltypes = examples, targets

    def get_labelled_set(self):
        """Return ``(example_tokens, target_tokens, example_ltypes, target_ltypes)``; None entries if unlabelled."""
        return self.example_tokens, self.target_tokens, self.example_ltypes, self.target_ltypes

    def append_token(self, token, wtype):
        """Append another ``(token, wtype)`` pair to this leaf."""
        self.tokens.append((token, wtype))

    def get_self_leaf_type(self):
        return self.leaf_type

    def generate_text(self):
        """Render the subtree as text: children before, own tokens (each space-prefixed), children after."""
        before, after = self.get_ordered_children()
        s = ""
        for c in before:
            s += c.generate_text() if c is not None else ""
        for token, wtype in self.tokens:
            s += f" {token}"
        for c in after:
            s += c.generate_text() if c is not None else ""
        return s

    def get_next_leaf(self):
        """Yield the leaves of this subtree in text order (in-order traversal)."""
        before, after = self.get_ordered_children()
        for child in before:
            yield from child.get_next_leaf()
        yield self
        for child in after:
            yield from child.get_next_leaf()

    def get_leaf_type_of(self, wtype):
        """Leaf type the tokenizer assigns to ``wtype``."""
        return self.tokenizer.get_ltype(wtype)

    # --- interface implemented by concrete leaf classes ---
    def __str__(self, spacing):
        raise NotImplementedError()

    def push_token(self, token, wtype):
        raise NotImplementedError()

    def set_parent(self, leaf):
        raise NotImplementedError()

    def replace_child(self, old, new):
        raise NotImplementedError()

    def get_ordered_children(self):
        raise NotImplementedError()

    def get_ordered_children_for_model(self):
        raise NotImplementedError()

In [None]:
class TwoChildLeaf(BaseLeaf):
    """Leaf for TWO_CHILD_LEAF (or INVARIANT) tokens with a left child, a right child and a bridge.

    In text order the left child precedes this leaf's tokens; the right child and
    then the bridge follow them.
    """
    def __init__(self, token, wtype, tokenizer, creator=None):
        super().__init__(token, wtype, tokenizer)
        allowed = (self.leaf_types_dict["TWO_CHILD_LEAF"], self.leaf_types_dict["INVARIANT"])
        assert self.get_self_leaf_type() in allowed
        self.l_child = None
        self.r_child = None
        self.bridge = None
        self.parent = creator

    def get_ordered_children(self):
        """Text order: ([l_child], [r_child, bridge]), omitting missing children."""
        before = [node for node in (self.l_child,) if node is not None]
        after = [node for node in (self.r_child, self.bridge) if node is not None]
        return before, after

    def get_ordered_children_for_model(self):
        """Model order: ([l_child, r_child], [bridge]), omitting missing children."""
        before = [node for node in (self.l_child, self.r_child) if node is not None]
        after = [node for node in (self.bridge,) if node is not None]
        return before, after

    def __str__(self, spacing):
        child_spacing = f"{spacing}\t"
        pieces = [f"{spacing}type:{self.get_self_leaf_type()} tokens:{self.tokens}\n"]
        for node in (self.l_child, self.r_child, self.bridge):
            if node is None:
                pieces.append(f"{child_spacing}None\n")
            else:
                pieces.append(f"{node.__str__(child_spacing)}\n")
        return "".join(pieces)

    def set_parent(self, parent):
        self.parent = parent

    def replace_child(self, old, new):
        """Swap ``old`` for ``new`` in the left or right slot (the bridge is not searched)."""
        if self.l_child == old:
            self.l_child = new
            return
        if self.r_child == old:
            self.r_child = new
            return
        raise Exception('does not exist as child')

    def push_token(self, token, wtype):
        """Attach a token and return the node that becomes current."""
        ltype = self.get_leaf_type_of(wtype)
        kinds = self.leaf_types_dict
        if ltype == kinds["BRIDGE_LEAF"]:
            self.bridge = BridgeLeaf(token, wtype, self.tokenizer, self)
            return self.bridge
        if ltype in (kinds["TWO_CHILD_LEAF"], kinds["INVARIANT"]):
            self.append_token(token, wtype)
            return self
        if ltype == kinds["ONE_CHILD_LEAF"]:
            self.r_child = OneChildLeaf(token, wtype, self.tokenizer, self)
            return self.r_child
        raise Exception('wrong leaf type for child')
        
class OneChildLeaf(BaseLeaf):
    """Leaf for ONE_CHILD_LEAF (or INVARIANT) tokens with an optional child and bridge, both after it."""
    def __init__(self, token, wtype, tokenizer, creator=None):
        super().__init__(token, wtype, tokenizer)
        assert (self.get_self_leaf_type() == self.leaf_types_dict["ONE_CHILD_LEAF"] or
                self.get_self_leaf_type() == self.leaf_types_dict["INVARIANT"])
        self.bridge, self.child = None, None
        self.parent = creator

    def get_ordered_children(self):
        """Text order: ([], [child, bridge]), omitting missing children."""
        before, after = [], []
        if self.child is not None:
            after.append(self.child)
        if self.bridge is not None:
            after.append(self.bridge)
        return before, after

    def get_ordered_children_for_model(self):
        """Model order: ([child], [bridge]), omitting missing children."""
        before, after = [], []
        if self.child is not None:
            before.append(self.child)
        if self.bridge is not None:
            after.append(self.bridge)
        return before, after

    def __str__(self, spacing):
        new_spacing = f"{spacing}\t"
        base_str = f"{spacing}type:{self.get_self_leaf_type()} words:{self.tokens}\n"
        if self.child is not None:
            base_str += f"{self.child.__str__(new_spacing)}\n"
        else:
            base_str += f"{new_spacing}None\n"

        if self.bridge is not None:
            base_str += f"{self.bridge.__str__(new_spacing)}\n"
        else:
            base_str += f"{new_spacing}None\n"
        return base_str

    def set_parent(self, parent):
        self.parent = parent

    def replace_child(self, old, new):
        if self.child == old:
            self.child = new
        else:
            raise Exception('does not exist as child')

    def push_token(self, token, wtype):
        """Attach a token and return the node that becomes current.

        BRIDGE -> a new BridgeLeaf stored in ``self.bridge``; ONE_CHILD/INVARIANT ->
        appended to this leaf; TWO_CHILD -> either handled by a ONE_CHILD parent, or
        this leaf becomes the left child of a new TwoChildLeaf that takes its place.
        """
        ltype = self.get_leaf_type_of(wtype)
        if ltype == self.leaf_types_dict["BRIDGE_LEAF"]:
            self.bridge = BridgeLeaf(token, wtype, self.tokenizer, self)
            return self.bridge
        elif ltype == self.leaf_types_dict["ONE_CHILD_LEAF"] or ltype == self.leaf_types_dict["INVARIANT"]:
            self.append_token(token, wtype)
            return self
        elif ltype == self.leaf_types_dict["TWO_CHILD_LEAF"]:
            if (self.parent is not None and
                    self.parent.get_self_leaf_type() == self.leaf_types_dict["ONE_CHILD_LEAF"]):
                # Return the parent's result. It used to be dropped, so the caller's
                # ``current`` became None and the next push crashed.
                return self.parent.push_token(token, wtype)
            # Replace this node by new_leaf, keeping this node as new_leaf's left child.
            new_leaf = TwoChildLeaf(token, wtype, self.tokenizer, None)
            new_leaf.l_child = self
            if self.parent is not None:
                self.parent.replace_child(self, new_leaf)   # new_leaf is now the parent's child
                new_leaf.set_parent(self.parent)
            # Always re-parent so self.parent agrees with new_leaf.l_child, even at the top.
            self.set_parent(new_leaf)
            return new_leaf
        else:
            raise Exception('wrong leaf type for child')
            
class BridgeLeaf(BaseLeaf):
    """Leaf for BRIDGE_LEAF (or INVARIANT) tokens, followed by at most one child subtree."""
    def __init__(self, token, wtype, tokenizer, creator=None):
        super().__init__(token, wtype, tokenizer)
        allowed = (self.leaf_types_dict["BRIDGE_LEAF"], self.leaf_types_dict["INVARIANT"])
        assert self.get_self_leaf_type() in allowed
        self.parent = creator
        self.child = None

    def get_ordered_children(self):
        """Text order: ([], [child]), omitting a missing child."""
        return [], [node for node in (self.child,) if node is not None]

    def get_ordered_children_for_model(self):
        """Model order: identical to the text order for bridges."""
        return [], [node for node in (self.child,) if node is not None]

    def __str__(self, spacing):
        child_spacing = f"{spacing}\t"
        header = f"{spacing}type:{self.get_self_leaf_type()} words:{self.tokens}\n"
        if self.child is None:
            return header + f"{child_spacing}None\n"
        return header + f"{self.child.__str__(child_spacing)}\n"

    def set_parent(self, parent):
        self.parent = parent

    def replace_child(self, old, new):
        if self.child != old:
            raise Exception('does not exist as child')
        self.child = new

    def push_token(self, token, wtype):
        """Attach a token and return the node that becomes current."""
        ltype = self.get_leaf_type_of(wtype)
        kinds = self.leaf_types_dict
        if ltype in (kinds["BRIDGE_LEAF"], kinds["INVARIANT"]):
            self.append_token(token, wtype)
            return self
        if ltype == kinds["TWO_CHILD_LEAF"]:
            leaf_class = TwoChildLeaf
        elif ltype == kinds["ONE_CHILD_LEAF"]:
            leaf_class = OneChildLeaf
        else:
            raise Exception('wrong leaf type for child')
        self.child = leaf_class(token, wtype, self.tokenizer, self)
        return self.child


In [None]:
class RootLeaf(BaseLeaf):
    """Sentinel root of a SentenceTree: owns the single top-level child and aggregates labels.

    The root holds the placeholder token ``(None, "ROOT")`` and is never labelled itself.
    """
    def __init__(self, tokenizer):
        super().__init__(None, "ROOT", tokenizer)
        assert self.get_self_leaf_type() == self.leaf_types_dict["ROOT_LEAF"]
        self.child = None
        self.example_tokens, self.target_tokens, self.example_ltypes, self.target_ltypes = None, None, None, None
        self.tokenizer = tokenizer
        self.lbl_st_len = 0  # number of labelled pairs; set by get_labelled_set()

    def __str__(self):
        if self.child is not None:
            return self.child.__str__(spacing="")
        else:
            return 'empty'

    def get_labelled_set_len(self):
        """Number of labelled pairs (0 until :meth:`get_labelled_set` has run)."""
        return self.lbl_st_len

    def get_labelled_set(self):
        """Concatenate every leaf's labelled tensors in text order; cache and return them.

        Returns:
            ``(example_tokens, target_tokens, example_ltypes, target_ltypes)``, each with
            one row per labelled pair. Rows are empty when no leaf carries labels.
        """
        assert self.child is not None
        if (self.example_tokens is not None and self.example_ltypes is not None and
                self.target_tokens is not None and self.target_ltypes is not None):
            return self.example_tokens, self.target_tokens, self.example_ltypes, self.target_ltypes
        example_tokens_list, target_tokens_list, example_ltypes_list, target_ltypes_list = [], [], [], []
        for child in self.child.get_next_leaf():
            example_tokens, target_tokens, example_types, target_ltypes = child.get_labelled_set()
            if example_tokens is not None:
                example_tokens_list.append(example_tokens)
                target_tokens_list.append(target_tokens)
                example_ltypes_list.append(example_types)
                target_ltypes_list.append(target_ltypes)
        if example_tokens_list:
            self.example_tokens = torch.cat(example_tokens_list, dim=0)
            self.example_ltypes = torch.cat(example_ltypes_list, dim=0)
            self.target_tokens = torch.cat(target_tokens_list, dim=0)
            self.target_ltypes = torch.cat(target_ltypes_list, dim=0)
        else:
            # No leaf got a label (e.g. a single-token chunk). torch.cat([]) would
            # raise here, so return empty (0, 1) long tensors instead.
            self.example_tokens = torch.empty((0, 1), dtype=torch.long)
            self.example_ltypes = torch.empty((0, 1), dtype=torch.long)
            self.target_tokens = torch.empty((0, 1), dtype=torch.long)
            self.target_ltypes = torch.empty((0, 1), dtype=torch.long)
        self.lbl_st_len = self.example_tokens.shape[0]
        return self.example_tokens, self.target_tokens, self.example_ltypes, self.target_ltypes

    def get_node_labelled_set(self, cur, nxt):
        """Build next-token pairs for ``cur``.

        Pairs are each token -> the following token inside ``cur``, plus cur's last
        token -> nxt's first token when ``nxt`` is given. Returns four None when no
        pair exists.
        """
        example_tokens, target_tokens, example_ltypes, target_ltypes = [], [], [], []
        for i in range(len(cur.tokens) - 1):
            example_tokens.append(cur.tokens[i])
            example_ltypes.append(cur.get_self_leaf_type())
            target_tokens.append(cur.tokens[i + 1])
            target_ltypes.append(cur.get_self_leaf_type())
        if nxt is not None:
            example_tokens.append(cur.tokens[len(cur.tokens) - 1])
            example_ltypes.append(cur.get_self_leaf_type())
            target_tokens.append(nxt.tokens[0])
            target_ltypes.append(nxt.get_self_leaf_type())
        if target_tokens == []:
            return None, None, None, None
        return example_tokens, target_tokens, example_ltypes, target_ltypes

    def label_all_nodes(self):
        """Walk the leaves in text order (root first) and store each leaf's next-token labels."""
        it = iter(self.get_next_leaf())
        cur = next(it)  # always the root itself: it has no "before" children
        try:
            while True:
                nxt = next(it)
                example_tokens, target_tokens, example_ltypes, target_ltypes = self.get_node_labelled_set(cur, nxt)
                cur.label_tokens(example_tokens, target_tokens)
                cur.label_ltypes(example_ltypes, target_ltypes)
                cur = nxt
        except StopIteration:
            # Last leaf: only pairs internal to it; skip when it has a single token.
            example_tokens, target_tokens, example_ltypes, target_ltypes = self.get_node_labelled_set(cur, None)
            if example_tokens is not None and target_tokens is not None:
                cur.label_tokens(example_tokens, target_tokens)
                cur.label_ltypes(example_ltypes, target_ltypes)

    def generate_text(self):
        return self.child.generate_text() if self.child is not None else ""

    def set_parent(self, parent):
        raise Exception('cannot set parent for RootLeaf')

    def replace_child(self, old, new):
        if self.child == old:
            self.child = new
        else:
            raise Exception('does not exist as child')

    def push_token(self, token, wtype):
        """Create the top-level child for the first token and return it as the new current node."""
        ltype = self.get_leaf_type_of(wtype)
        if ltype == self.leaf_types_dict["ONE_CHILD_LEAF"]:
            child = OneChildLeaf(token, wtype, self.tokenizer, self)
        elif ltype == self.leaf_types_dict["TWO_CHILD_LEAF"]:
            child = TwoChildLeaf(token, wtype, self.tokenizer, self)
        elif ltype == self.leaf_types_dict["BRIDGE_LEAF"] or ltype == self.leaf_types_dict["INVARIANT"]:
            child = BridgeLeaf(token, wtype, self.tokenizer, self)
        else:
            raise Exception(f'invalid leaf type passed for token:{token} ltype:{ltype} and wtype:{wtype}')
        self.child = child
        return child

    def get_ordered_children(self):
        before, after = [], []
        if self.child is not None:
            after.append(self.child)
        return before, after

    def get_ordered_children_for_model(self):
        before, after = [], []
        if self.child is not None:
            after.append(self.child)
        return before, after

In [None]:
class SentenceTree():
    """Tree built by pushing a token list into a RootLeaf, then labelled for next-token training.

    Args:
        tokenizer: tokenizer used for leaf types and token -> index tensors.
        tokens_list: sequence of ``(token, wtype)`` pairs, pushed in order.
    """
    def __init__(self, tokenizer, tokens_list):
        self.root = RootLeaf(tokenizer)
        self.current = self.root  # node that receives the next pushed token
        self.model_current = self.root
        self.tokenizer, self.tokens_list = tokenizer, tokens_list
        self.push_tokens(tokens_list)
        self.root.label_all_nodes()

    def __str__(self):
        return str(self.root)

    def push_token(self, word, wtype):
        """Push one token into the current node and move ``current`` to the node it returns."""
        self.current = self.current.push_token(word, wtype)

    def get_labelled_set(self):
        return self.root.get_labelled_set()

    def push_tokens(self, token_list):
        for token, wtype in token_list:
            self.push_token(token, wtype)

    def get_beginning(self):
        """Return the RootLeaf."""
        return self.root

    def get_current(self):
        """Return the node that will receive the next pushed token."""
        return self.current  # was ``sef.current`` -> NameError

    def generate_text(self):
        return self.root.generate_text()

    def get_labelled_set_len(self):
        return self.root.get_labelled_set_len()


In [None]:
import gc
class LangTree():
    """Turn a tokenizer's cached token chunks into pickled next-token training batches.

    On construction it:
      1. builds a SentenceTree per token chunk and pickles its labelled set
         ``(example_tokens, target_tokens, example_ltypes, target_ltypes)`` under
         ``root_dest_path``;
      2. re-slices all labelled pairs into ``bs`` rows x ``(len // bs) // bptt`` blocks
         of ``bptt`` pairs, one pickle per (row, block) under ``root_dest_path / 'batches'``.

    Args:
        tokenizer: Tokenizer whose ``get_next_token_list()`` provides the chunks.
        root_dest_path: output directory for the per-chunk and batch pickles.
        bs: number of batch rows.
        bptt: labelled pairs per block.
        verbose: print the batch geometry and every written batch file (previous behaviour).
    """
    def __init__(self, tokenizer, root_dest_path=Path('./data/lang_tree_data'), bs=32, bptt=70, verbose=True):
        root_dest_path = create_dir(make_to_path(root_dest_path))
        self.tokenizer, self.root_dest_path = tokenizer, root_dest_path
        self.length = 0
        self.batch_dest = create_dir(root_dest_path / 'batches')
        self.dest_paths = self.prepare_tree_sequenced_tokens()
        self.batch_paths = self.batchify(self.dest_paths, self.batch_dest, self.length, bs, bptt, verbose=verbose)
        self.bs, self.bptt = bs, bptt

    def __len__(self):
        """Total number of labelled pairs across all chunks."""
        return self.length

    def get_one_batch_len(self):
        """Number of bptt blocks per batch row (set by :meth:`batchify`)."""
        return self.one_batch_num_blocks

    def get_num_ltypes(self):
        return len(self.tokenizer.get_ltypes_dict())

    def __str__(self):
        """Short summary. Per-chunk trees are not kept in memory, only their pickles."""
        # The previous version iterated a nonexistent ``self.roots`` (AttributeError).
        return (f"LangTree(labelled_pairs={self.length}, files={len(self.dest_paths)}, "
                f"bs={self.bs}, bptt={self.bptt}, batch_files={len(self.batch_paths)})")

    def reset_len(self):
        self.length = 0

    def get_batch_of_bptt(self, batch_num, seq_num):
        """Load the tensors pickled for batch row ``batch_num``, block ``seq_num``."""
        with open(self.batch_paths[f"{batch_num}_{seq_num}"], "rb") as f:
            return pickle.load(f)

    def prepare_tree_sequenced_tokens(self):
        """Build a SentenceTree per token chunk, pickle its labelled set, return the pickle paths."""
        dest_paths = []
        for ind, (path, token_list) in enumerate(self.tokenizer.get_next_token_list()):
            stree = SentenceTree(self.tokenizer, token_list)
            tree_tok_sequence = stree.get_labelled_set()
            self.length += stree.get_labelled_set_len()
            # The chunk index keeps names unique when one source file yields several chunks.
            dest_path = self.root_dest_path / f"{path.stem}_{ind}.pkl"
            dest_paths.append(dest_path)
            with open(dest_path, "wb") as f:
                pickle.dump(tree_tok_sequence, f)
            gc.collect()  # best-effort collection between (potentially large) trees
        return dest_paths

    def batchify(self, src_paths, dest_root, length, bs, bptt, verbose=True):
        """Write ``bs`` x ``(length // bs) // bptt`` batch pickles; return ``{"row_block": path}``.

        Labelled pairs are consumed sequentially across ``src_paths``. Row 0 gets the
        first ``(length // bs) // bptt * bptt`` pairs, row 1 the next ones, and so on;
        leftover pairs are dropped. Given the (N, 1) labelled tensors SentenceTree
        produces, each pickle holds four (bptt, 1) tensors. Also sets
        ``self.one_batch_num_blocks``.
        """
        def get_next_token_from_paths(src_paths):
            for path in src_paths:
                with open(path, "rb") as f:
                    e_toks, t_toks, e_types, t_types = pickle.load(f)
                # Yield after closing the file: the data is already in memory.
                for i in range(len(e_toks)):
                    yield e_toks[i], t_toks[i], e_types[i], t_types[i]

        one_batch_len = length // bs
        self.one_batch_num_blocks = one_batch_len // bptt
        if verbose:
            print(length, bs, one_batch_len, bptt)
        batch_paths = {}
        next_tok_iter = get_next_token_from_paths(src_paths)
        for batch_num in range(bs):
            for seq_num in range(self.one_batch_num_blocks):
                key = f"{batch_num}_{seq_num}"
                dest_path = dest_root / f"{key}.pkl"
                batch_paths[key] = dest_path
                columns = ([], [], [], [])
                for _ in range(bptt):
                    for column, value in zip(columns, next(next_tok_iter)):
                        column.append(value)
                e_toks_t, t_toks_t, e_types_t, t_types_t = (torch.cat(c, dim=0) for c in columns)
                with open(dest_path, "wb") as f:
                    pickle.dump((e_toks_t[:, None], t_toks_t[:, None], e_types_t[:, None], t_types_t[:, None]), f)
                if verbose:
                    print(f"batch: {batch_num} seq: {seq_num} wrote to " + str(dest_path))
        return batch_paths

    def get_next_labelled_set(self, num_tokens):
        """Yield the per-chunk labelled sets in slices of at most ``num_tokens`` pairs."""
        for path in self.dest_paths:
            with open(path, "rb") as f:
                example_tokens, target_tokens, example_ltypes, target_ltypes = pickle.load(f)
            for i in range(0, len(example_tokens), num_tokens):
                j = min(i + num_tokens, len(example_tokens))
                yield example_tokens[i:j], target_tokens[i:j], example_ltypes[i:j], target_ltypes[i:j]

In [None]:
# Tokenize the Brown-news text: writes token pickles under ./data/tokenizer_data and builds
# the vocabulary. Pass the already-loaded `nlp` so Tokenizer doesn't load en_core_web_sm again.
tk = Tokenizer(text_list=btext_list, nlp=nlp)

In [None]:
# Peek at the first cached token chunk: (pickle path, list of (token, POS) pairs).
# Show only a short prefix, since a full chunk holds up to `tk.max_len` pairs.
it = iter(tk.get_next_token_list())
x = next(it)
x[0], x[1][:20]

In [None]:
# Scratch check of batch geometry: 120254 // 32 = 3757 labelled pairs per batch row, and
# 3757 // 70 = 53 bptt blocks per row. NOTE(review): 120254 looks like a labelled-pair count
# hard-coded from an earlier LangTree run. Confirm it, or derive it from len(tree) after the next cell.
120254 // 32, 3757 // 70

In [None]:
# Build one SentenceTree per token chunk, pickle its labelled next-token pairs and split them
# into batch files (LangTree defaults: bs=32, bptt=70). Note: `tree` is a LangTree, not a SentenceTree.
tree = LangTree(tk)

In [None]:
# Sanity check: a sentence tree must regenerate its input tokens in their original order.
def tree_test(tree):
    """Assert that ``tree.generate_text()`` reproduces the tokens of ``tree.tokens_list``.

    Intended for a ``SentenceTree`` (anything exposing ``generate_text()`` and
    ``tokens_list``). The previous body called
    ``tree.vocab.get_batched_token_sequences()``, which no class in this notebook
    defines, so it could never run.

    Raises:
        AssertionError: if the regenerated tokens differ from the input tokens.
    """
    # Every token is rendered as " token", so the first split field is empty.
    generated = tree.generate_text().split(' ')[1:]
    expected = [token for token, _ in tree.tokens_list]
    assert generated == expected, f"tree regenerated {generated!r}, expected {expected!r}"

In [None]:
from torch.utils.data import Dataset, DataLoader

In [None]:
def get_train_valid_indices(tree, val_split=0.1):
    """Split the per-row bptt blocks of ``tree`` into train/valid index ranges.

    Args:
        tree: object with ``get_one_batch_len()`` (number of bptt blocks per batch row).
        val_split: fraction of blocks reserved for validation, in [0, 1]. It was
            previously ignored because 0.1 was hard-coded.

    Returns:
        ``([0, train_size], [train_size, total])``: half-open block ranges for TreeDataset.

    Raises:
        ValueError: if ``val_split`` is outside [0, 1].
    """
    if not 0 <= val_split <= 1:
        raise ValueError(f"val_split must be in [0, 1], got {val_split}")
    tot_size = tree.get_one_batch_len()
    val_size = int(val_split * tot_size)
    train_size = tot_size - val_size
    return [0, train_size], [train_size, tot_size]

In [None]:
# Preview the default (val_split=0.1) train/valid split of the per-row bptt blocks.
get_train_valid_indices(tree)

In [None]:
class TreeDataset(Dataset):
    """Map-style dataset over the positions ``[idx_lower, idx_upper)`` of a tree.

    Item ``idx`` is one full batch: for each of the ``tree.bs`` rows, the four
    tensors returned by ``tree.get_batch_of_bptt(row, idx_lower + idx)``,
    stacked along a new leading (row) dimension.
    """
    def __init__(self, tree, idx_lower, idx_upper):
        self.tree = tree
        self.idx_lower, self.idx_upper = idx_lower, idx_upper
        self.bs, self.bptt = self.tree.bs, self.tree.bptt

    def __len__(self):
        return self.idx_upper - self.idx_lower

    def __getitem__(self, idx):
        """Return ``(e_tok, t_tok, e_type, t_type)``, each stacked over the ``bs`` rows.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``; this
                also lets plain ``for`` iteration over the dataset terminate.
        """
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"TreeDataset index out of range (len={n})")
        pos = self.idx_lower + idx
        # Read from the tree this dataset was built with, not a notebook-global `tree`.
        rows = [self.tree.get_batch_of_bptt(row, pos) for row in range(self.bs)]
        e_toks, t_toks, e_types, t_types = zip(*rows)
        return torch.stack(e_toks), torch.stack(t_toks), torch.stack(e_types), torch.stack(t_types)

In [None]:
class TreeDataLoader():
    """Minimal sequential loader: yields ``ds[0], ds[1], ..., ds[len(ds) - 1]`` on each pass.

    NOTE: ``len()`` returns ``ds.bptt * len(ds) * ds.bs`` — a product over all
    batches — *not* the number of batches that iteration yields.
    """
    def __init__(self, tree_ds):
        self.ds = tree_ds
        self.length = len(tree_ds)

    def __iter__(self):
        return (self.ds[batch_idx] for batch_idx in range(self.length))

    def __len__(self):
        ds = self.ds
        return ds.bptt * self.length * ds.bs

In [None]:
# print("train:")
# for a,b,c,d in train_dl:
#     print(a.shape)
# print("valid:")
# for a,b,c,d in valid_dl:
#     print(a.shape)

In [None]:
# train_ind, valid_ind = get_train_valid_indices(tree, 0.1)
# train_ds = TreeDataset(tree, *train_ind)
# valid_ds = TreeDataset(tree, *valid_ind)
# train_dl, valid_dl = TreeDataLoader(train_ds), TreeDataLoader(valid_ds)

# it = iter(train_dl)
# rnn = nn.RNN(50,100, batch_first=True)
# emb = nn.Embedding(len(tree.tokenizer.vocab.itos), 50)
# e = emb(a).squeeze()
# print(e.shape)
# r = rnn(e)[0]
# print(r.shape)

In [None]:
class SerialTreeModel(nn.Module):
    """Single-layer RNN that jointly predicts next-token and next-ltype logits.

    Token and ltype embeddings (``emb_sz`` each) are concatenated and fed to an
    ``nn.RNN`` of width ``2 * emb_sz``; the RNN output is split back into two
    ``emb_sz`` halves, projected onto the token vocabulary and the ltype set.
    The hidden state ``self.h`` is carried across calls and detached each time.

    Args:
        tree: provides ``tokenizer.vocab.itos`` (token vocabulary) and
            ``get_num_ltypes()``.
        bs: batch size used to initialise the carried hidden state.
        emb_sz: width of each embedding (default 50, as originally hard-coded).
    """
    def __init__(self, tree, bs, emb_sz=50):
        super().__init__()
        n_toks = len(tree.tokenizer.vocab.itos)
        n_ltypes = tree.get_num_ltypes()
        self.emb_sz = emb_sz
        self.emb_tok = nn.Embedding(n_toks, emb_sz)
        self.emb_ltype = nn.Embedding(n_ltypes, emb_sz)
        self.rnn = nn.RNN(2 * emb_sz, 2 * emb_sz, batch_first=True)
        self.lin_tok = nn.Linear(emb_sz, n_toks)
        self.lin_ltype = nn.Linear(emb_sz, n_ltypes)
        # A (non-persistent) buffer follows the module through .to()/.cuda(). An unconditional
        # .cuda() here would put h on the GPU while the weights stay on the CPU, and the RNN
        # would then reject the mismatched devices.
        self.register_buffer("h", torch.zeros(1, bs, 2 * emb_sz), persistent=False)

    def forward(self, toks, ltypes):
        """Return ``(tok_logits, ltype_logits)`` for index tensors ``toks`` and ``ltypes``.

        Assumes batched input shaped like (bs, seq_len, 1) — TODO confirm; the
        embeddings are ``squeeze()``d, so every singleton dim is dropped.
        Updates the carried hidden state ``self.h`` as a side effect.
        """
        a1 = self.emb_tok(toks).squeeze()
        a2 = self.emb_ltype(ltypes).squeeze()
        ip_rnn = torch.cat([a1, a2], dim=-1)

        h0 = self.h
        if ip_rnn.dim() == 3 and h0.shape[1] != ip_rnn.shape[0]:
            # Batch size changed (e.g. a short final batch): restart from a zero state.
            h0 = ip_rnn.new_zeros(1, ip_rnn.shape[0], 2 * self.emb_sz)
        res, h = self.rnn(ip_rnn, h0)
        self.h = h.detach()  # truncate backprop at the batch boundary
        a1, a2 = torch.split(res, self.emb_sz, dim=-1)
        return self.lin_tok(a1), self.lin_ltype(a2)

In [None]:
# Re-create the tokenizer (same call as earlier in the notebook) before the training section.
# NOTE(review): looks like a duplicate of the earlier cell — confirm whether fresh state is required.
tk = Tokenizer(text_list=btext_list)

In [None]:
# Re-build the tree from the freshly created tokenizer.
# NOTE(review): duplicates the earlier `tree = LangTree(tk)` cell — confirm it is needed.
tree = LangTree(tk)

In [None]:
# Hold out 10% of the tree's positions for validation and wrap each span in a dataset + loader.
train_ind, valid_ind = get_train_valid_indices(tree, 0.1)
train_ds, valid_ds = [TreeDataset(tree, lower, upper) for lower, upper in (train_ind, valid_ind)]
train_dl, valid_dl = [TreeDataLoader(split_ds) for split_ds in (train_ds, valid_ds)]

In [None]:
# Size the carried hidden state from the tree's batch size, so it matches the batches
# TreeDataset stacks (tree.bs rows each) instead of relying on a hard-coded 32.
model = SerialTreeModel(tree, tree.bs)

In [None]:
# Grab the first training batch for the shape / loss sanity checks below.
x_tok, y_tok, x_type, y_type = next(iter(train_dl))

In [None]:
# Decode each row of the batch next to its targets (presumably the next tokens) as a visual check.
print(x_tok.shape)  # a bare `x_tok.shape` mid-cell is evaluated but never displayed
for i in range(x_tok.shape[0]):
    print(tree.tokenizer.tokenify(x_tok[i]))
    print(tree.tokenizer.tokenify(y_tok[i]), "\n")

In [None]:
# Sanity checks: loader "lengths", vocabulary size, and model-output vs target shapes.
# Note: len() of a TreeDataLoader is bptt * n_batches * bs, not the number of batches.
# Calling the model here also advances its carried hidden state (model.h).
print(len(valid_dl), len(train_dl))
print(len(tree.tokenizer.vocab.itos))
print(model(x_tok, x_type)[0].shape, y_tok.shape)

In [None]:
# Token cross-entropy on one batch (this model call also advances model.h).
# nn.CrossEntropyLoss wants class scores on dim 1, i.e. (N, C, L) against targets (N, L).
# The model's token logits come out (N, L, C) for batched input, so the class axis must be
# *transposed* into place — .view(N, C, L) would only reinterpret memory and mix positions
# with classes. squeeze(-1) assumes targets carry a trailing singleton dim (TODO confirm).
loss_func1 = nn.CrossEntropyLoss()
y_hat = model(x_tok, x_type)[0]
loss = loss_func1(y_hat.transpose(1, 2), y_tok.squeeze(-1))
loss

In [None]:
# Number of training epochs.
epochs = 10

In [None]:
def _tree_batch_loss(model, loss_func, x_tok, y_tok, x_ltype, y_ltype):
    """Return token loss + ltype loss for one batch.

    Assumes the model returns logits shaped (N, L, C) and that targets may carry
    a trailing singleton dim (removed with ``squeeze(-1)``).
    """
    tok_hat, ltype_hat = model(x_tok, x_ltype)
    # CrossEntropyLoss expects classes on dim 1: (N, L, C) -> (N, C, L). This must be a
    # transpose; .view(N, C, L) would reinterpret memory and scramble the logits.
    tok_loss = loss_func(tok_hat.transpose(1, 2), y_tok.squeeze(-1))
    ltype_loss = loss_func(ltype_hat.transpose(1, 2), y_ltype.squeeze(-1))
    return tok_loss + ltype_loss


def fit(epochs, model, opt, train_dl, valid_dl):
    """Train ``model`` for ``epochs`` epochs, printing the mean validation loss after each.

    Args:
        epochs: number of passes over ``train_dl``.
        model: called as ``model(tok, ltype)`` returning ``(tok_logits, ltype_logits)``.
        opt: optimizer over the model's parameters.
        train_dl, valid_dl: iterables of ``(x_tok, y_tok, x_ltype, y_ltype)`` batches.

    Returns:
        The last epoch's mean validation loss (token + ltype cross-entropy,
        averaged over validation batches) as a 0-dim tensor; NaN if no epoch
        ran or ``valid_dl`` produced no batches.
    """
    loss_func = nn.CrossEntropyLoss()
    mean_valid_loss = torch.tensor(float("nan"))
    for epoch in range(epochs):
        model.train()
        for x_tok, y_tok, x_ltype, y_ltype in train_dl:
            loss = _tree_batch_loss(model, loss_func, x_tok, y_tok, x_ltype, y_ltype)
            loss.backward()
            opt.step()
            opt.zero_grad()

        model.eval()
        with torch.no_grad():
            tot_loss, n_batches = 0., 0
            for x_tok_v, y_tok_v, x_ltype_v, y_ltype_v in valid_dl:
                tot_loss += _tree_batch_loss(model, loss_func, x_tok_v, y_tok_v, x_ltype_v, y_ltype_v)
                n_batches += 1
        # Count batches directly: len(valid_dl) on a TreeDataLoader is bptt * n_batches * bs.
        mean_valid_loss = tot_loss / n_batches if n_batches else torch.tensor(float("nan"))
        print(epoch, mean_valid_loss)
    return mean_valid_loss

In [None]:
# Train for the configured `epochs` (rather than a duplicated literal) with Adam at lr=0.01.
fit(epochs, model, torch.optim.Adam(model.parameters(), lr=0.01), train_dl, valid_dl)