In [49]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [50]:
from gnn_lib.data import utils, tokenization
from dataclasses import dataclass
from typing import Tuple, List, Optional

In [51]:
# BPE tokenizer used by all examples below (10k vocab, no prefix space, loaded from a local pickle).
tokenizer_config = tokenization.TokenizerConfig(
    file_path="../../data/tokenizers/bpe/wiki_bookcorpus_10k_no_prefix_space.pkl"
)
tokenizer = tokenization.BPETokenizer(tokenizer_config)

In [73]:
def get_tokens_and_doc(s: str, tokenizer: tokenization.Tokenizer) -> Tuple[List[List[str]], "Doc"]:
    """Parse ``s`` and return the sub-word token strings for each entry together with the doc.

    Args:
        s: raw input text.
        tokenizer: tokenizer used both to build the tokenization function and to map
            token ids back to strings via ``id_to_token``.

    Returns:
        ``(tokens, doc)``: ``tokens[i]`` holds the token strings for the i-th id list produced
        by the tokenization function applied to ``doc`` (presumably one list per word of
        ``doc`` — TODO confirm in gnn_lib.data.tokenization).
    """
    # the meaning of the second argument is defined in gnn_lib.data.tokenization (not visible here)
    tokenization_fn = tokenization.get_tokenization_fn(tokenizer, True)
    _, doc = utils.tokenize_words(
        s,
        return_doc=True,
        with_dep_parser=True,
        with_ner=True,
        with_pos_tags=True,
    )
    token_strings = []
    for word_token_ids in tokenization_fn(doc):
        token_strings.append(list(map(tokenizer.id_to_token, word_token_ids)))
    return token_strings, doc

In [107]:
# NOTE(review): "sentesne" looks like a deliberate misspelling; it is split into several sub-word tokens.
example_text = "this is a test sentesne!"
tokens, doc = get_tokens_and_doc(example_text, tokenizer)
tokens, doc

([['th', 'is'], ['Ġis'], ['Ġa'], ['Ġtest'], ['Ġsent', 'es', 'ne'], ['!']],
 this is a test sentesne!)

In [121]:
def _open_tikz(scale: float) -> str:
    return f"\\begin{{tikzpicture}}[auto, transform shape, show background rectangle]"

def _close_tikz() -> str:
    return r"\end{tikzpicture}"

@dataclass
class TikzElement:
    """Base class for anything emitted as one line of TikZ source.

    Subclasses must override ``to_tikz``.
    """

    def to_tikz(self) -> str:
        """Return the TikZ source line for this element."""
        raise NotImplementedError()
        
@dataclass
class String(TikzElement):
    """Raw TikZ/LaTeX snippet that is emitted verbatim."""

    value: str  # text written unchanged into the picture

    def to_tikz(self) -> str:
        verbatim = self.value
        return verbatim
    
@dataclass
class Color(TikzElement):
    """Named colour definition emitted as an xcolor ``definecolor`` command in the RGB model.

    Raises:
        TypeError: if any channel lies outside 0..255 (checked in red, green, blue order).
    """
    name: str  # colour name usable in node/edge options afterwards
    r: int
    g: int
    b: int

    def __post_init__(self):
        for channel, label in ((self.r, "red"), (self.g, "green"), (self.b, "blue")):
            if not 0 <= channel <= 255:
                raise TypeError(f"{label} value must be between 0 and 255")

    def to_tikz(self) -> str:
        rgb = ", ".join(str(channel) for channel in (self.r, self.g, self.b))
        return "\\definecolor{%s}{RGB}{%s}" % (self.name, rgb)

@dataclass
class _BaseNode(TikzElement):
    """Shared fields and ``node`` rendering for all TikZ nodes.

    Subclasses provide the placement through the ``_position`` property. ``name`` defaults
    to ``None`` only so that subclasses may declare further defaulted fields;
    ``__post_init__`` rejects a missing name.
    """
    name: str = None  # TikZ node name that edges refer to; required
    value: str = ""  # label text inside the node
    color: str = "black"  # always the first entry of the option list
    fill: str = ""  # fill colour; no fill option when empty
    shape: str = "circle"  # "circle", "rectangle", "ellipse", or "" to omit the shape option
    outline: bool = True  # adds the `draw` option

    def __post_init__(self):
        if self.name is None:
            raise TypeError("name cannot be None")

    @property
    def _position(self) -> str:
        """Placement fragment of the node command; must be overridden by subclasses."""
        raise NotImplementedError

    def to_tikz(self) -> str:
        if self.shape:
            # NOTE: assert is stripped when Python runs with -O
            assert self.shape in {"circle", "rectangle", "ellipse"}
        optional = (
            f"fill={self.fill}" if self.fill else None,
            self.shape or None,
            "draw" if self.outline else None,
        )
        option_str = "[{}]".format(",".join([self.color, *(opt for opt in optional if opt is not None)]))
        return f"\\node{option_str} ({self.name}) {self._position} {{{self.value}}};"
    
@dataclass
class Node(_BaseNode):
    """Node placed at absolute coordinates ``(x, y)``.

    ``x`` and ``y`` default to ``None`` only because dataclass fields following the
    defaulted base-class fields must have defaults; both are required.

    Raises:
        TypeError: if ``name``, ``x`` or ``y`` is missing.
    """
    x: Optional[float] = None
    y: Optional[float] = None

    def __post_init__(self):
        # run the base-class validation (name must be set); overriding without this skipped it
        super().__post_init__()
        if self.x is None:
            raise TypeError("x is None")
        if self.y is None:
            raise TypeError("y is None")

    @property
    def _position(self) -> str:
        # coordinates are written with one decimal place
        return f"at ({self.x:.1f},{self.y:.1f})"
    
@dataclass
class RelNode(_BaseNode):
    """Node placed relative to another node, rendered as ``[<direction>=<distance> of <rel>]``.

    This placement syntax needs the TikZ ``positioning`` library in the including document.
    An empty ``distance`` leaves the spacing to TikZ's ``node distance`` setting.

    Raises:
        TypeError: if ``name`` or ``rel`` is missing.
    """
    rel: Optional[_BaseNode] = None  # any node, e.g. a Node, RelNode or BetweenNode
    direction: str = "right"  # e.g. "right", "above"
    distance: str = ""  # TikZ length such as "2cm"

    def __post_init__(self):
        # run the base-class validation (name must be set); overriding without this skipped it
        super().__post_init__()
        if self.rel is None:
            raise TypeError("rel is None")

    @property
    def _position(self) -> str:
        return f"[{self.direction}={self.distance} of {self.rel.name}]"
    
    
@dataclass
class BetweenNode(_BaseNode):
    """Node placed at the midpoint between two other nodes.

    The position uses the ``calc`` coordinate syntax ``$(a)!0.5!(b)$``, which needs the
    TikZ ``calc`` library in the including document.

    Raises:
        TypeError: if ``name``, ``first`` or ``second`` is missing.
    """
    first: Optional[_BaseNode] = None
    second: Optional[_BaseNode] = None

    def __post_init__(self):
        # run the base-class validation (name must be set); overriding without this skipped it
        super().__post_init__()
        if self.first is None:
            raise TypeError("first is None")
        if self.second is None:
            raise TypeError("second is None")

    @property
    def _position(self) -> str:
        return f"at ($({self.first.name})!0.5!({self.second.name})$)"

    
@dataclass
class Edge(TikzElement):
    """Edge between two named nodes, rendered as ``draw[...] (src) edge [...] node {value} (dst);``.

    Either ``self_loop`` is set, drawing a loop on that side of the node, or the curve is
    shaped by ``bend`` and/or ``out_deg``/``in_deg``. When ``self_loop`` is set, the other
    shape options are ignored.

    Raises:
        TypeError: if ``src`` or ``dst`` is missing.
    """
    src: Optional[str] = None  # name of the source node (required)
    dst: Optional[str] = None  # name of the target node (required)
    value: str = ""  # edge label
    directed: bool = True  # adds an arrow tip ("->")
    color: str = "black"
    style: str = "line"  # "dotted" or "dashed"; any other value draws a solid line
    thickness: str = "thin"  # TikZ line-width keyword, e.g. "thin"
    # NOTE(review): accepted but not rendered by to_tikz yet — TODO implement or remove
    controls: Tuple[Tuple[float, float], ...] = ()

    # edge is either a self loop ...
    self_loop: str = ""  # "above", "below", "right" or "left"; empty for a normal edge

    # ... or a normal edge shaped by these options
    out_deg: Optional[float] = None  # leaving angle in degrees
    in_deg: Optional[float] = None  # entering angle in degrees
    bend: str = ""  # "left" or "right"

    def __post_init__(self):
        if self.src is None:
            raise TypeError("src is None")
        if self.dst is None:
            raise TypeError("dst is None")

    def to_tikz(self) -> str:
        options = [self.color, self.thickness]
        if self.directed:
            options.append("->")
        if self.style == "dotted":
            options.append("densely dotted")
        elif self.style == "dashed":
            options.append("dashed")

        edge_options = []
        if self.self_loop:
            assert self.self_loop in {"above", "below", "right", "left"}
            edge_options.append(f"loop {self.self_loop}")
        else:
            if self.bend:
                assert self.bend in {"left", "right"}
                edge_options.append(f"bend {self.bend}")
            # compare with None: 0 degrees is a valid angle and must not be dropped
            if self.out_deg is not None:
                edge_options.append(f"out={self.out_deg}")
            if self.in_deg is not None:
                edge_options.append(f"in={self.in_deg}")

        option_str = "[" + ",".join(options) + "]"
        edge_option_str = "[" + ",".join(edge_options) + "]"
        return f"\\draw{option_str} ({self.src}) edge {edge_option_str} node {{{self.value}}} ({self.dst});"


def generate_tikz_graph(elements: List[TikzElement], save_to: Optional[str] = None, scale: float = 1.0) -> str:
    """Render ``elements`` into a complete tikzpicture environment.

    Args:
        elements: rendered in the given order, one line each. TikZ can only reference
            nodes defined earlier, so nodes should precede the edges that use them.
        save_to: optional output path; ".tex" is appended if missing. The file is written
            as UTF-8 with a trailing newline.
        scale: forwarded to ``_open_tikz``.

    Returns:
        The TikZ source. It is returned whether or not it is also saved (it used to be
        ``None`` when ``save_to`` was given, contradicting the ``-> str`` annotation).
    """
    tikz_lines = [_open_tikz(scale), *(elem.to_tikz() for elem in elements), _close_tikz()]
    tikz_string = "\n".join(tikz_lines)
    if save_to is not None:
        if not save_to.endswith(".tex"):
            save_to += ".tex"
        with open(save_to, "w", encoding="utf8") as out_file:
            out_file.write(tikz_string + "\n")
    return tikz_string
            

            
def generate_fully_connected_graph_elements(tokens: List[List[str]], doc: "Doc") -> List[TikzElement]:
    """Build a row of token nodes where every token has a directed edge to every token.

    Args:
        tokens: per-word lists of token strings; they are flattened into one row.
        doc: accepted for signature parity with the other generators; not used here.

    Returns:
        All nodes (named "0", "1", ... left to right) followed by all edges in (src, dst)
        row-major order. An edge from a node to itself is drawn as a loop on its left.
    """
    flat_tokens = [token for word_tokens in tokens for token in word_tokens]
    node_style = {"fill": "gray!10", "shape": "ellipse"}

    nodes = []
    for idx, token in enumerate(flat_tokens):
        if nodes:
            # every later token sits to the right of the previous one
            nodes.append(RelNode(rel=nodes[-1], direction="right", name=str(idx), value=token, **node_style))
        else:
            # the first token anchors the layout at the origin
            nodes.append(Node(x=0, y=0, name=str(idx), value=token, **node_style))

    edges = [
        Edge(
            src=str(src),
            dst=str(dst),
            directed=True,
            color="black",
            out_deg=45,
            in_deg=135,
            bend="left",
            self_loop="left" if src == dst else "",
        )
        for src in range(len(flat_tokens))
        for dst in range(len(flat_tokens))
    ]
    return nodes + edges


def generate_word_graph_elements(
    tokens: List[List[str]], 
    doc: "Doc",
    add_dependency_edges: bool = False
) -> List[TikzElement]:
    """Build a two-level graph: a row of token nodes with one word node above each word.

    Edges added:
      * token -> token within the same word (black self loop "s", red "t" otherwise),
      * token -> its word (orange "in"),
      * word -> every word (black self loop "s", blue "w" otherwise; arcs above for words to
        the right, below for words to the left),
      * optionally word -> head (green "d") and head -> word (yellow "h").

    Args:
        tokens: per-word lists of token strings; must align one-to-one with the words of doc.
        doc: parsed document; each word must provide ``.text`` and, when
            ``add_dependency_edges`` is set, ``.head.i`` (assumed to be the head's index in doc,
            as for spaCy tokens — TODO confirm).
        add_dependency_edges: also add the dependency edges described above.

    Returns:
        Token nodes, invisible placement helper nodes, word nodes, then all edges (nodes first,
        so every edge refers to an already defined node in the TikZ output).

    Raises:
        ValueError: if tokens and doc differ in length, or a word has no tokens.
    """
    if len(tokens) != len(doc):
        # zip() would silently truncate while the word->word edges still target every word index
        raise ValueError(
            f"expected one token list per word, got {len(tokens)} token lists for {len(doc)} words"
        )

    word_nodes = []
    token_nodes = []
    dummy_nodes = []
    edges = []

    token_node_kwargs = {"fill": "gray!10", "shape": "ellipse"}
    word_node_kwargs = {"fill": "blue!10", "shape": "ellipse"}

    token_start_idx = 0

    for word_idx, (word_tokens, word) in enumerate(zip(tokens, doc)):
        if not word_tokens:
            # the placement below would index token_nodes[-0] == token_nodes[0] (or fail on an
            # empty list), silently anchoring the word above the wrong token
            raise ValueError(f"word {word_idx} ({word.text!r}) has no tokens and cannot be placed")

        for j, token in enumerate(word_tokens):
            token_node_kwargs.update({"name": f"token_{token_start_idx + j}", "value": token})
            if token_start_idx == 0 and j == 0:
                # the very first token anchors the layout at the origin
                token_node = Node(x=0, y=0, **token_node_kwargs)
            else:
                token_node = RelNode(rel=token_nodes[-1], direction="right", **token_node_kwargs)
            token_nodes.append(token_node)

            # edges to every token of the same word
            for k, _ in enumerate(word_tokens):
                edges.append(
                    Edge(
                        src=token_node_kwargs["name"],
                        dst=f"token_{token_start_idx+k}",
                        directed=True,
                        out_deg=45,
                        in_deg=135,
                        bend="left",
                        self_loop="left" if j == k else "",
                        color="black" if j == k else "red",
                        value="s" if j == k else "t"
                    )
                )
            # membership edge from the token up to its word
            edges.append(
                Edge(
                    src=token_node_kwargs["name"], 
                    dst=f"word_{word_idx}", 
                    directed=True, 
                    color="orange",
                    value="in"
                )
            )

        word_node_kwargs.update({"name": f"word_{word_idx}", "value": word.text, "distance": "2cm"})
        if len(word_tokens) % 2 == 1:
            # place above central token node if there are an odd number of tokens
            word_node = RelNode(rel=token_nodes[-(len(word_tokens) // 2 + 1)], direction="above", **word_node_kwargs)
        else:
            # place above invisible dummy node between the two central token nodes if there are an even number of tokens
            dummy_name = f"dummy_{word_idx}"
            dummy_node = BetweenNode(name=dummy_name, first=token_nodes[-(len(word_tokens) // 2 + 1)], second=token_nodes[-(len(word_tokens) // 2)], shape="", outline=False)
            dummy_nodes.append(dummy_node)
            word_node = RelNode(rel=dummy_node, direction="above", **word_node_kwargs)
        word_nodes.append(word_node)

        for to_word_idx, _ in enumerate(tokens):
            after = to_word_idx > word_idx
            edges.append(
                Edge(
                    src=word_node_kwargs["name"],
                    dst=f"word_{to_word_idx}",
                    directed=True,
                    # arcs above for edges to the right, below for edges to the left
                    out_deg=45 * after - 45 * (not after),
                    in_deg=135 * after - 135 * (not after),
                    bend="left",
                    self_loop="left" if word_idx == to_word_idx else "",
                    color="black" if word_idx == to_word_idx else "blue",
                    value="s" if word_idx == to_word_idx else "w"
                )
            )

        token_start_idx += len(word_tokens)

        if add_dependency_edges:
            dep = word.head.i
            edges.append(
                Edge(
                    src=word_node_kwargs["name"],
                    dst=f"word_{dep}",
                    directed=True,
                    out_deg=20,
                    in_deg=160,
                    bend="left",
                    self_loop="above" if dep == word_idx else "",
                    color="green",
                    value="d"
                )
            )
            edges.append(
                Edge(
                    dst=word_node_kwargs["name"],
                    src=f"word_{dep}",
                    directed=True,
                    out_deg=20,
                    in_deg=160,
                    bend="left",
                    self_loop="above" if dep == word_idx else "",
                    color="yellow",
                    value="h"
                )
            )

    return token_nodes + dummy_nodes + word_nodes + edges


def mark_node_and_edges(
    elements: List[TikzElement], 
    node_name: str, 
    edge_value: Optional[str] = None
) -> List[TikzElement]:
    """Return a copy of ``elements`` with everything except one node and its incoming edges greyed out.

    The input elements are not modified. Greyed elements are new copies; kept elements and
    non-node, non-edge elements (e.g. colour definitions) are passed through unchanged.

    Args:
        elements: elements as produced by the graph generators.
        node_name: name of the node to highlight.
        edge_value: if given (non-empty), only incoming edges with this label stay highlighted.

    Returns:
        A new list in the same order as ``elements``.
    """
    from dataclasses import replace

    def _mark(element: TikzElement) -> TikzElement:
        if isinstance(element, _BaseNode):
            if element.name == node_name:
                return element
            # keep unfilled helper nodes (e.g. the invisible placement dummies) unfilled;
            # filling them would render them as small grey boxes
            new_fill = "gray!40" if element.fill else element.fill
            return replace(element, color="black!40", fill=new_fill)
        if isinstance(element, Edge):
            value_matches = not edge_value or element.value == edge_value
            if element.dst == node_name and value_matches:
                return element
            return replace(element, color="gray!40")
        return element

    return [_mark(element) for element in elements]
            
FIGURE_DIR = "../../../latex/figures"

# NOTE(review): all three colours share the RGB triple (230, 190, 83); this looks like a
# copy-paste placeholder. TODO: set the real values for each colour.
colors = [
    Color(name="uni_red", r=230, g=190, b=83),
    Color(name="uni_light_grey", r=230, g=190, b=83),
    Color(name="uni_blue", r=230, g=190, b=83),
]

fc_elements = generate_fully_connected_graph_elements(tokens, doc)
generate_tikz_graph(colors + fc_elements, save_to=f"{FIGURE_DIR}/tikz_test_fc_graph.tex")

word_elements = generate_word_graph_elements(tokens, doc, add_dependency_edges=True)
generate_tikz_graph(colors + word_elements, save_to=f"{FIGURE_DIR}/tikz_test_word_graph.tex")

# highlight word_4 and its incoming edges
elements = mark_node_and_edges(word_elements, "word_4")
# trailing ';' keeps Jupyter from echoing the returned TikZ source
generate_tikz_graph(colors + elements, save_to=f"{FIGURE_DIR}/tikz_test_word_graph_mark.tex");