In [1]:
%load_ext autoreload
%autoreload 2

In [110]:
from dataclasses import dataclass
from typing import Tuple, List, Optional, Generator, Union, Callable, Set
import os

from nsc.data import utils, tokenization
from spell_checking import DATA_DIR, BASE_DIR

In [3]:
# Build the BPE tokenizer from the tokenizer file stored under DATA_DIR; it is
# passed to get_tokens_and_doc when the sample sentence is tokenized.
tokenizer = tokenization.BPETokenizer(
    tokenization.TokenizerConfig(file_path=os.path.join(DATA_DIR, "tokenizers/bpe/wiki_bookcorpus_10k_no_prefix_space.pkl"))
)

In [4]:
def get_tokens_and_doc(s: str, tokenizer: tokenization.Tokenizer) -> Tuple[List[List[str]], "Doc"]:
    """Tokenize ``s`` and return the nested token strings together with the parsed doc.

    The string is processed by ``utils.tokenize_words`` with dependency parsing,
    NER and POS tagging enabled. The resulting doc is fed to the function returned
    by ``tokenization.get_tokenization_fn`` and every token id it yields is mapped
    back to its string via ``tokenizer.id_to_token``.

    :param s: text to tokenize
    :param tokenizer: tokenizer providing the id -> token mapping
    :return: tuple of (nested lists of token strings, doc)
    """
    tokenize_doc = tokenization.get_tokenization_fn(tokenizer, True)
    _, parsed_doc = utils.tokenize_words(
        s,
        return_doc=True,
        with_dep_parser=True,
        with_ner=True,
        with_pos_tags=True
    )
    token_strings = []
    for group_ids in tokenize_doc(parsed_doc):
        token_strings.append([tokenizer.id_to_token(token_id) for token_id in group_ids])
    return token_strings, parsed_doc

In [74]:
# Sample sentence (with misspellings) whose tokens and doc are reused by the
# figure-generation cells further down.
tokens, doc = get_tokens_and_doc("Corect this sentesne!", tokenizer)
tokens, doc

([['C', 'ore', 'ct'], ['Ġthis'], ['Ġsent', 'es', 'ne'], ['!']],
 Corect this sentesne!)

In [194]:
def _open_tikz(on_grid: bool, bg_rectangle: bool) -> str:
    options = ["auto", "transform shape", "anchor=center"]
    if bg_rectangle:
        options.append("show background rectangle")
    if on_grid:
        options.append("on grid")
    return f"\\begin{{tikzpicture}}[{', '.join(options)}]"

def _close_tikz() -> str:
    return r"\end{tikzpicture}"


@dataclass
class TikzElement:
    """Base class of every object that can be rendered as TikZ source."""

    def to_tikz(self) -> str:
        """Return the TikZ source of this element; subclasses must override this."""
        raise NotImplementedError

    # The annotation uses string forward references: while the class body is being
    # executed the name ``TikzElement`` is not bound yet, so the bare name raises a
    # NameError on a fresh kernel.
    def apply(self, fn: Callable[["TikzElement"], "TikzElement"]) -> "TikzElement":
        """Return the result of calling ``fn`` on this element.

        Container elements override this to apply ``fn`` to their children instead.
        """
        return fn(self)
        

class SaveableTikzElement(TikzElement):
    """A TikZ element that can write its rendered source to a ``.tex`` file."""

    def save(self, save_to: str) -> None:
        """Write ``to_tikz()`` followed by a newline to ``save_to``.

        ``.tex`` is appended to the path when it does not already end with it.
        The file is opened for writing in UTF-8, replacing previous content.
        """
        target = save_to if save_to.endswith(".tex") else save_to + ".tex"
        with open(target, "w", encoding="utf8") as tex_file:
            tex_file.write(self.to_tikz() + "\n")
        

class TikzPicture(SaveableTikzElement):
    """An ordered list of elements rendered inside one ``tikzpicture`` environment."""

    def __init__(self, on_grid: bool = False, show_bg_rectangle: bool = False) -> None:
        self.on_grid = on_grid
        self.show_bg_rectangle = show_bg_rectangle
        self._elements = []

    def prepend(self, element: Union[TikzElement, List[TikzElement]]) -> "TikzPicture":
        """Put one element, or a list of elements, in front of the existing ones."""
        if not isinstance(element, TikzElement):
            self._elements = element + self._elements
        else:
            self._elements.insert(0, element)
        return self

    def append(self, element: Union[TikzElement, List[TikzElement]]) -> "TikzPicture":
        """Add one element, or a list of elements, after the existing ones."""
        if not isinstance(element, TikzElement):
            self._elements.extend(element)
        else:
            self._elements.append(element)
        return self

    def to_tikz(self) -> str:
        """Render the environment: opening line, one tab-prefixed entry per element, closing line."""
        lines = [_open_tikz(self.on_grid, self.show_bg_rectangle)]
        for child in self._elements:
            lines.append("\t" + child.to_tikz())
        lines.append(_close_tikz())
        return "\n".join(lines)

    def apply(self, fn: Callable[[TikzElement], TikzElement]) -> "TikzPicture":
        """Replace every child by the result of its own ``apply(fn)`` and return this picture."""
        replaced = []
        for child in self._elements:
            replaced.append(child.apply(fn))
        self._elements = replaced
        return self

    @property
    def elements(self) -> List[TikzElement]:
        """The list of top-level elements held by this picture."""
        return self._elements
    
    
@dataclass
class Background(TikzElement):
    """A group of elements rendered inside an ``on background layer`` scope."""
    _elements: List[TikzElement] = None

    def __post_init__(self):
        # ``None`` is only a placeholder default. Elements handed to the constructor
        # are kept (previously they were silently thrown away) and copied into a
        # fresh list so later prepend/append calls never modify the caller's list.
        self._elements = [] if self._elements is None else list(self._elements)

    def prepend(self, element: Union[TikzElement, List[TikzElement]]) -> "Background":
        """Put one element, or a list of elements, in front of the existing ones."""
        if isinstance(element, TikzElement):
            self._elements.insert(0, element)
        else:
            self._elements = element + self._elements
        return self

    def append(self, element: Union[TikzElement, List[TikzElement]]) -> "Background":
        """Add one element, or a list of elements, after the existing ones."""
        if isinstance(element, TikzElement):
            self._elements.append(element)
        else:
            self._elements.extend(element)
        return self

    def to_tikz(self) -> str:
        """Render the scope: begin line, one tab-prefixed entry per element, end line."""
        bg_lines = [r"\begin{scope}[on background layer]"] + ["\t" + elem.to_tikz() for elem in self._elements] + [r"\end{scope}"]
        return "\n".join(bg_lines)

    def apply(self, fn: Callable[[TikzElement], TikzElement]) -> "Background":
        """Replace every child by the result of its own ``apply(fn)`` and return this scope."""
        self._elements = [elem.apply(fn) for elem in self._elements]
        return self

    @property
    def elements(self) -> List[TikzElement]:
        """The list of elements held by this scope."""
        return self._elements
    
        
@dataclass
class String(TikzElement):
    """Literal TikZ source that is emitted exactly as given."""
    value: str

    def to_tikz(self) -> str:
        """Return ``value`` unchanged."""
        return self.value
    
@dataclass
class Color(TikzElement):
    """A named colour, rendered as a ``\\definecolor`` statement in the RGB model."""
    name: str
    r: int
    g: int
    b: int

    def __post_init__(self):
        """Reject channels outside 0..255 (checked in the order red, green, blue)."""
        channels = (("red", self.r), ("green", self.g), ("blue", self.b))
        for label, channel in channels:
            if not (0 <= channel <= 255):
                raise TypeError(f"{label} value must be between 0 and 255")

    def to_tikz(self) -> str:
        """Return the ``\\definecolor`` line for this colour."""
        return "\\definecolor{{{}}}{{RGB}}{{{}, {}, {}}}".format(self.name, self.r, self.g, self.b)

@dataclass
class _BaseNode(TikzElement):
    """Shared fields and rendering of all ``\\node`` elements.

    Subclasses supply the placement part of the statement through ``_position``.
    """
    name: str = None
    value: str = ""
    color: str = "black"
    fill: str = ""
    shape: str = "circle"
    style: str = "line"
    outline: bool = True
    options: Tuple[str] = ()

    def __post_init__(self):
        # name only has a default so that subclasses can add defaulted fields
        if self.name is None:
            raise TypeError("name cannot be None")

    @property
    def _position(self) -> str:
        """Placement fragment inserted between the node name and its content."""
        raise NotImplementedError

    def to_tikz(self) -> str:
        """Render the node as ``\\node[options] (name) position {value};``."""
        candidates = [
            self.color,
            f"fill={self.fill}" if self.fill else None,
            self.shape if self.shape else None,
            # "line" is the default look and needs no explicit option
            self.style if self.style != "line" else None,
            "draw" if self.outline else None,
        ]
        node_options = [option for option in candidates if option is not None]
        node_options.extend(self.options)
        return "\\node[" + ",".join(node_options) + f"] ({self.name}) {self._position} {{{self.value}}};"
    
@dataclass
class Node(_BaseNode):
    """A node placed at absolute coordinates ``(x, y)``."""
    x: float = None
    y: float = None

    def __post_init__(self):
        super().__post_init__()
        # both coordinates are mandatory even though the fields carry a default
        for coordinate in ("x", "y"):
            if getattr(self, coordinate) is None:
                raise TypeError(f"{coordinate} is None")

    @property
    def _position(self) -> str:
        """``at (x,y)`` with both coordinates formatted to one decimal place."""
        return "at ({:.1f},{:.1f})".format(self.x, self.y)
    
@dataclass
class NoPosNode(_BaseNode):
    """A node rendered without any placement information."""

    @property
    def _position(self) -> str:
        """Always the empty string."""
        return ""
    
@dataclass
class RelNode(_BaseNode):
    """A node placed relative to another node (``<direction>=<distance> of <rel>``)."""
    rel: Node = None
    direction: str = "right"
    distance: str = ""
    anchor: Optional[str] = None

    def __post_init__(self):
        super().__post_init__()
        if self.rel is None:
            raise TypeError("rel is None")

    @property
    def _position(self) -> str:
        """Bracketed placement option referring to ``rel`` by name.

        When ``anchor`` is set it is used twice: as the anchor of ``rel`` the
        distance is measured from, and as this node's own ``anchor=`` option.
        """
        if self.anchor is None:
            rel_suffix = ""
            own_anchor = ""
        else:
            rel_suffix = f".{self.anchor}"
            own_anchor = f"anchor={self.anchor}"
        placement = f"{self.direction}={self.distance} of {self.rel.name}{rel_suffix}"
        return f"[{placement}, {own_anchor}]"
    
    
@dataclass
class BetweenNode(_BaseNode):
    """A node placed on the line between two other nodes.

    ``ratio`` is inserted between the two ``!`` of the coordinate expression;
    ``anchors`` optionally names the anchor used on ``first`` and on ``second``.
    """
    first: Node = None
    second: Node = None
    ratio: float = 0.5
    anchors: Optional[Tuple[str, str]] = None

    def __post_init__(self):
        super().__post_init__()
        if self.first is None:
            raise TypeError("first is None")
        if self.second is None:
            raise TypeError("second is None")

    @property
    def _position(self) -> str:
        """``at ($(first)!ratio!(second)$)``, with ``.anchor`` suffixes when anchors are given."""
        if self.anchors is None:
            first_suffix, second_suffix = "", ""
        else:
            first_suffix = f".{self.anchors[0]}"
            second_suffix = f".{self.anchors[1]}"
        start = f"({self.first.name}{first_suffix})"
        end = f"({self.second.name}{second_suffix})"
        return f"at (${start}!{self.ratio}!{end}$)"

    
@dataclass
class Edge(TikzElement):
    """A drawn connection between two named nodes.

    ``src`` and ``dst`` may be node objects or plain node names. An edge is
    rendered either as a self loop (``self_loop`` set to a side) or as a regular
    edge that may carry ``bend``, ``out_deg`` and ``in_deg``; those three are
    ignored for self loops.

    Raises ``TypeError`` on construction when ``src`` or ``dst`` is missing.
    """
    src: Union[str, _BaseNode] = None
    dst: Union[str, _BaseNode] = None
    value: str = ""
    direction: str = "undirected"
    color: str = "black"
    style: str = "line"
    thickness: str = "thin"
    controls: Tuple[Tuple[float, float]] = ()

    # edge is either self loop
    self_loop: str = ""

    # or not
    out_deg: Optional[float] = None
    in_deg: Optional[float] = None
    bend: str = ""

    def __post_init__(self):
        if self.src is None:
            raise TypeError("src is None")
        if self.dst is None:
            raise TypeError("dst is None")

    @property
    def src_name(self) -> str:
        """Name of the source node, whether ``src`` is a node or already a name."""
        return self.src if isinstance(self.src, str) else self.src.name

    @property
    def dst_name(self) -> str:
        """Name of the destination node, whether ``dst`` is a node or already a name."""
        return self.dst if isinstance(self.dst, str) else self.dst.name

    def to_tikz(self) -> str:
        """Render the edge as ``\\draw[...] (src) edge [...] node {value} (dst);``.

        Raises ``NotImplementedError`` when ``controls`` is non-empty and
        ``AssertionError`` for unknown ``direction``, ``self_loop`` or ``bend`` values.
        """
        options = [self.color, self.thickness]

        # "line" is the default look and needs no explicit option
        if self.style != "line":
            options.append(self.style)

        assert self.direction in {"directed", "undirected", "bidirected"}
        if self.direction == "directed":
            options.append("->")
        elif self.direction == "bidirected":
            options.append("<->")

        if self.controls:
            raise NotImplementedError("controls for edges not yet implemented")

        edge_options = []
        if self.self_loop:
            assert self.self_loop in {"above", "below", "right", "left"}
            edge_options.append(f"loop {self.self_loop}")
        else:
            if self.bend:
                assert self.bend in {"left", "right"}
                edge_options.append(f"bend {self.bend}")
            # Compare against None rather than testing truthiness: an angle of 0 is
            # a valid value and must not be dropped.
            if self.out_deg is not None:
                edge_options.append(f"out={self.out_deg}")
            if self.in_deg is not None:
                edge_options.append(f"in={self.in_deg}")

        option_str = "[" + ",".join(options) + "]"
        edge_option_str = "[" + ",".join(edge_options) + "]"

        return f"\\draw{option_str} ({self.src_name}) edge {edge_option_str} node {{{self.value}}} ({self.dst_name});"

    
@dataclass
class Box(TikzElement):
    """An empty node drawn around a set of other nodes via the ``fit`` option.

    ``label`` is a ``(text, position)`` pair. ``padding`` is given as
    ``(left, right, top, bottom)`` and its values are written out in ``cm``.
    """
    fit: List[_BaseNode] = None
    color: str = "black"
    style: str = "line"
    shape: str = "rectangle"
    label: Tuple[str, str] = None
    padding: Tuple[float, float, float, float] = None

    def __post_init__(self) -> None:
        if not self.fit:
            raise TypeError("fit is None or empty")

    def to_tikz(self) -> str:
        """Render the box as a ``\\node[..., draw, fit=(a) (b) ...] {};`` statement."""
        box_options = [self.shape, self.color]
        if self.style != "line":
            box_options.append(self.style)

        if self.label is not None:
            text, where = self.label
            box_options.append(f"label={where}:{{{text}}}")

        if self.padding is not None:
            left, right, top, bottom = self.padding
            # The separation is symmetric, so uneven padding is expressed as the
            # mean separation plus a shift towards the larger side.
            xsep = (left + right) / 2
            ysep = (top + bottom) / 2
            box_options.extend([
                f"inner xsep={xsep}cm",
                f"inner ysep={ysep}cm",
                f"xshift={xsep - left}cm",
                f"yshift={ysep - bottom}cm",
            ])

        fitted = " ".join(f"({n.name})" for n in self.fit)
        return "\\node[" + ", ".join(box_options) + f", draw, fit={fitted}] {{}};"
    
@dataclass
class _BaseMatrix(_BaseNode):
    """A node whose content is a grid of other elements (TikZ ``matrix`` option).

    Rows are added with ``add_row``; every row must have as many cells as the
    previous one.
    """
    _rows: List[List[TikzElement]] = None

    def __post_init__(self):
        # Run the parent validation chain first, as the other node classes do.
        super().__post_init__()
        # Rows handed to the constructor go through add_row so they are kept and
        # checked instead of being silently discarded.
        given_rows = self._rows
        self._rows = []
        for row in (given_rows or []):
            self.add_row(row)

    def add_row(self, row: List[TikzElement]) -> None:
        """Append a row; raises ``AssertionError`` if its length differs from the last row."""
        if len(self._rows):
            assert len(row) == len(self._rows[-1]), "all rows must contain the same number of elements"
        self._rows.append(row)

    def to_tikz(self) -> str:
        """Render the node with the rows as content: cells joined by ``&``, rows ended by ``\\\\``."""
        rows = []
        for row in self._rows:
            row_str = r" & ".join(col.to_tikz() for col in row) + r" \\"
            rows.append(row_str)
        # Add the option only once: appending it unconditionally made a second
        # render emit "matrix" twice.
        if "matrix" not in self.options:
            self.options = (*self.options, "matrix")
        self.value = "\n".join(rows)
        return super().to_tikz()
    
@dataclass
class RelMatrix(_BaseMatrix, RelNode):
    """A matrix node that is placed relative to another node, like ``RelNode``."""

In [207]:
def generate_fully_connected_graph(tokens: List[List[str]], doc: "Doc") -> TikzPicture:
    """Build a picture with one node per token and a directed edge between every pair.

    Token nodes are named ``token_<i>`` over the flattened token list and laid out
    in a row: the first at the origin, every other one to the right of its
    predecessor. Edges (including one self loop per node) are placed in a
    ``Background`` element appended after the nodes.

    :param tokens: token strings, nested per word
    :param doc: the parsed doc; ``doc[k].whitespace_`` is read to detect word breaks
    :return: a new ``TikzPicture`` holding the nodes followed by the background edges
    """
    nodes = []
    edges = []

    flat_tokens = [(word_idx, i, t) for word_idx, word_tokens in enumerate(tokens) for i, t in enumerate(word_tokens)]

    node_kwargs = {"fill": "uni_light_gray", "shape": "ellipse"}

    for i, (word_idx, in_word_idx, token) in enumerate(flat_tokens):
        if word_idx > 0 and in_word_idx == 0 and doc[word_idx - 1].whitespace_ == " ":
            # The first token of a word preceded by a space has its first character
            # swapped for an escaped '#'. The backslash is written as "\\" because
            # "\#" is an invalid escape sequence.
            token = "\\#" + token[1:]
        node_kwargs.update({"name": f"token_{i}", "value": token})
        if i == 0:
            node = Node(x=0, y=0, **node_kwargs)
        else:
            node = RelNode(rel=nodes[-1], direction="right", **node_kwargs)
        nodes.append(node)
        for j in range(len(flat_tokens)):
            edges.append(
                Edge(
                    src=node_kwargs["name"],
                    dst=f"token_{j}",
                    direction="directed",
                    color="black!50",
                    out_deg=30,
                    in_deg=150,
                    bend="left",
                    self_loop="left" if i == j else "",
                    style="line",
                    thickness="very thin"
                )
            )

    bg = Background()
    bg.append(edges)
    return TikzPicture().append(nodes).append(bg)


def generate_word_graph(
    tokens: List[List[str]],
    doc: "Doc",
    add_dependency_edges: bool = False
) -> TikzPicture:
    """Build a two-level picture: a row of token nodes with one word node above each word.

    Per word the function creates

    * token nodes ``token_<i>``, fully connected among the tokens of the same word,
    * an edge from every token to its word node ``word_<k>``,
    * the word node itself, centred above the word's tokens, with a small
      rectangle ``feature_word_<k>`` above it,
    * edges from the word node to every word node (self loop included), and
    * optionally two edges between the word and the word at index ``word.head.i``.

    :param tokens: token strings, nested per word
    :param doc: the parsed doc, iterated in parallel with ``tokens``
    :param add_dependency_edges: also draw the edges to and from ``word.head.i``
    :return: a new ``TikzPicture`` with token nodes, helper nodes, word nodes and
        a background holding all edges
    """
    word_nodes = []
    token_nodes = []
    dummy_nodes = []
    edges = []

    token_node_kwargs = {"fill": "uni_light_gray", "shape": "ellipse"}
    word_node_kwargs = {"fill": "uni_blue", "shape": "ellipse", "color": "white"}

    # index of the current word's first token within the flattened token sequence
    token_start_idx = 0

    for word_idx, (word_tokens, word) in enumerate(zip(tokens, doc)):
        for j, token in enumerate(word_tokens):
            if j == 0 and word_idx > 0 and doc[word_idx - 1].whitespace_ == " ":
                # The first token of a word preceded by a space has its first
                # character swapped for an escaped '#'. The backslash is written
                # as "\\" because "\#" is an invalid escape sequence.
                token = "\\#" + token[1:]
            token_node_kwargs.update({"name": f"token_{token_start_idx + j}", "value": token})
            if token_start_idx == 0 and j == 0:
                token_node = Node(x=0, y=0, **token_node_kwargs)
            else:
                token_node = RelNode(rel=token_nodes[-1], direction="right", **token_node_kwargs)
            token_nodes.append(token_node)

            # edges to all tokens of the same word; the self loop is coloured differently
            for k in range(len(word_tokens)):
                edges.append(
                    Edge(
                        src=token_node_kwargs["name"],
                        dst=f"token_{token_start_idx+k}",
                        direction="directed",
                        out_deg=45,
                        in_deg=135,
                        bend="left",
                        self_loop="left" if j == k else "",
                        color="uni_red" if j == k else "black",
                        style="line",
                        thickness="very thin"
                    )
                )
            # edge from the token up to its word node
            edges.append(
                Edge(
                    src=token_node_kwargs["name"],
                    dst=f"word_{word_idx}",
                    direction="directed",
                    color="uni_green",
                    style="line",
                    thickness="very thin"
                )
            )

        word_node_kwargs.update({"name": f"word_{word_idx}", "value": word.text, "distance": "3.5cm", "anchor": "center"})
        if len(word_tokens) % 2 == 1:
            # place above central token node if there are an odd number of tokens
            word_node = RelNode(rel=token_nodes[-(len(word_tokens) // 2 + 1)], direction="above", **word_node_kwargs)
        else:
            # place above invisible dummy node between the two central token nodes if there are an even number of tokens
            dummy_name = f"dummy_{word_idx}"
            dummy_node = BetweenNode(name=dummy_name, first=token_nodes[-(len(word_tokens) // 2 + 1)],
                                     second=token_nodes[-(len(word_tokens) // 2)], shape="", outline=False)
            dummy_nodes.append(dummy_node)
            word_node = RelNode(rel=dummy_node, direction="above", **word_node_kwargs)

        word_nodes.append(word_node)
        feature_rect = RelNode(
            name=f"feature_{word_node.name}", rel=word_node, direction="above", shape="rectangle",
            distance="0.25cm", fill=word_node_kwargs["fill"], options=("minimum width=1cm", "minimum height=0.25cm")
        )
        word_nodes.append(feature_rect)

        # edges to all word nodes; the self loop is coloured differently
        for to_word_idx in range(len(tokens)):
            edges.append(
                Edge(
                    src=word_node_kwargs["name"],
                    dst=f"word_{to_word_idx}",
                    direction="directed",
                    out_deg=30,
                    in_deg=150,
                    bend="left",
                    self_loop="left" if word_idx == to_word_idx else "",
                    color="uni_red" if word_idx == to_word_idx else "uni_blue",
                    style="line",
                    thickness="very thin"
                )
            )

        token_start_idx += len(word_tokens)

        if add_dependency_edges:
            dep = word.head.i
            # one edge in each direction between the word and the word at index dep
            edges.append(
                Edge(
                    src=word_node_kwargs["name"],
                    dst=f"word_{dep}",
                    direction="directed",
                    out_deg=20,
                    in_deg=160,
                    bend="left",
                    self_loop="below" if dep == word_idx else "",
                    color="uni_orange",
                    style="line",
                    thickness="very thin"
                )
            )
            edges.append(
                Edge(
                    dst=word_node_kwargs["name"],
                    src=f"word_{dep}",
                    direction="directed",
                    out_deg=20,
                    in_deg=160,
                    bend="left",
                    self_loop="below" if dep == word_idx else "",
                    color="uni_yellow",
                    style="line",
                    thickness="very thin"
                )
            )
    bg = Background()
    bg.append(edges)
    return TikzPicture().append(token_nodes).append(dummy_nodes).append(word_nodes).append(bg)

In [216]:
def mark_node_and_edges(
    pic: TikzPicture,
    node_names: Set[str],
) -> TikzPicture:
    """Highlight the given nodes and their incoming edges by greying out everything else.

    The elements of ``pic`` are modified in place:

    * nodes whose name is neither in ``node_names`` nor ``feature_<name>`` for one
      of those names are recoloured grey (``BetweenNode`` helpers are left alone),
    * edges whose destination is in ``node_names`` become thick solid lines,
    * all other edges are recoloured grey.

    :param pic: picture to restyle
    :param node_names: names of the nodes to keep highlighted
    :return: whatever ``pic.apply`` returns (the picture itself for ``TikzPicture``)
    """
    # Built once up front; it does not depend on the element being visited.
    feature_node_names = set(f"feature_{node_name}" for node_name in node_names)

    def _mark(element: TikzElement) -> TikzElement:
        if isinstance(element, _BaseNode):
            if not isinstance(element, BetweenNode) and element.name not in node_names and element.name not in feature_node_names:
                element.color = "black!40"
                element.fill = "uni_dark_gray!40"
        elif isinstance(element, Edge):
            if element.dst_name in node_names:
                element.thickness = "thick"
                element.style = "line"
            else:
                element.color = "uni_dark_gray!40"

        return element

    return pic.apply(_mark)


def add_encoder_box(pic: TikzPicture, word_graph: bool = False) -> TikzPicture:
    """Append a dashed box labelled "Encoder" that fits all top-level nodes of ``pic``.

    The padding tuple passed to ``Box`` differs between word graphs and plain
    token graphs. ``pic`` is modified in place and returned.
    """
    padding = (0.75, 0.25, 2, 1) if word_graph else (0.75, 0.25, 2.5, 2.5)
    nodes_to_fit = [element for element in pic.elements if isinstance(element, _BaseNode)]
    encoder_box = Box(
        fit=nodes_to_fit,
        style="dashed",
        label=("Encoder", "left"),
        padding=padding
    )
    pic.append(encoder_box)
    return pic


def add_transformer_feature_head(pic: TikzPicture, tokens: List[List[str]], doc: "Doc") -> TikzPicture:
    """Add a per-word feature head on top of the token nodes of ``pic``.

    The top-level nodes whose name starts with ``token`` are grouped word by word
    according to the lengths of the lists in ``tokens``. For every word the
    function adds

    * an invisible helper node between the word's first and last token,
    * a matrix node ``mat_<i>`` above it with the cells ``agg_<i>``, ``plus_<i>``
      and ``feat_<i>``,
    * a word node ``word_<i>`` above the matrix and a node ``func_<i>`` between them,
    * edges from every token of the word to ``agg_<i>`` and from the matrix to
      the word node.

    A dashed box labelled "Word features" is fitted around all word and matrix
    nodes. ``pic`` is modified in place and returned.
    """
    token_nodes = [node for node in pic.elements if isinstance(node, _BaseNode) and node.name.startswith("token")]

    word_node_groups = []
    running_node_idx = 0
    for word_tokens in tokens:
        word_node_groups.append(token_nodes[running_node_idx:running_node_idx+len(word_tokens)])
        running_node_idx += len(word_tokens)

    edges = []
    dummy_nodes = []
    matrix_nodes = []
    word_nodes = []
    func_nodes = []
    for i, (word_node_group, word) in enumerate(zip(word_node_groups, doc)):
        dummy_node = BetweenNode(name=f"dummy_{i}", first=word_node_group[0], second=word_node_group[-1], shape="", outline=False)
        dummy_nodes.append(dummy_node)
        matrix_node = RelMatrix(name=f"mat_{i}", rel=dummy_node, direction="above", anchor="center", distance="4", style="densely dotted", shape="rectangle")
        agg_node = NoPosNode(name=f"agg_{i}", value="Agg.", shape="rectangle")
        feat_node = NoPosNode(name=f"feat_{i}", fill="uni_blue", shape="rectangle", options=("minimum width=1cm", "minimum height=0.25cm"))
        # raw string: "\m" and "\V" are invalid escape sequences in a normal literal
        plus_node = NoPosNode(name=f"plus_{i}", value=r"$\mathbin\Vert$", color="black", shape="", outline=False)
        # the three cells are rendered as part of the matrix, not as separate elements
        matrix_node.add_row([agg_node, plus_node, feat_node])
        matrix_nodes.append(matrix_node)
        word_node = RelNode(name=f"word_{i}", fill="uni_blue", shape="ellipse", color="white", rel=matrix_node, direction="above",
                            anchor="center", distance="2.5", value=word.text)
        word_nodes.append(word_node)
        func_node = BetweenNode(name=f"func_{i}", value="$f$", anchors=("north", "south"), fill="uni_light_blue!80", first=matrix_node, second=word_node)
        func_nodes.append(func_node)
        for node in word_node_group:
            edges.append(
                Edge(src=node, dst=agg_node, direction="undirected", thickness="thick")
            )
        edges.append(
            Edge(src=matrix_node, dst=word_node, direction="directed", thickness="thick")
        )

    head_box = Box(
        fit=word_nodes + matrix_nodes,
        style="dashed",
        label=("Word features", "left"),
        padding=(0.5, 0.5, 0.5, 0.5)
    )
    pic.append(dummy_nodes + matrix_nodes + word_nodes + [head_box] + edges + func_nodes)
    return pic


def add_sequence_classification_head(
    pic: TikzPicture,
    dim_edge: Callable[[TikzElement], bool],
    dim_node: Callable[[TikzElement], bool],
    word_level: bool = False,
    distance: int = 4
) -> TikzPicture:
    """Add a single aggregation + classifier head that consumes all input nodes.

    Input nodes are the top-level nodes of ``pic`` whose name starts with ``word``
    (``word_level=True``) or ``token``. Existing edges selected by ``dim_edge`` are
    greyed out; nodes selected by ``dim_node`` are greyed out only when
    ``word_level`` is set. The head consists of an "Aggregation" node above the
    inputs, an output node above that, a node ``clf`` between the two and a
    dashed "Classifier" box; its edges go into a ``Background`` element.

    ``pic`` is modified in place and returned.
    """
    name_prefix = "word" if word_level else "token"
    inputs = [
        element for element in pic.elements
        if isinstance(element, _BaseNode) and element.name.startswith(name_prefix)
    ]

    def _dim(element: TikzElement) -> TikzElement:
        if isinstance(element, Edge) and dim_edge(element):
            element.color = "uni_dark_gray!40"
        if word_level and isinstance(element, _BaseNode) and dim_node(element):
            element.color = "black!40"
            element.fill = "uni_dark_gray!40"
        return element

    pic.apply(_dim)

    # invisible helper midway between the first and last input; the head is centred above it
    centre = BetweenNode(name="dummy", first=inputs[0], second=inputs[-1], shape="", outline=False)
    aggregator = RelNode(
        name="agg", value="Aggregation", rel=centre, direction="above",
        anchor="center", distance=str(distance), shape="rectangle"
    )
    connections = [
        Edge(src=input_node, dst=aggregator, direction="undirected", thickness="thick")
        for input_node in inputs
    ]
    label_node = RelNode(
        name="out", value=r"$\hat{l}$", fill="uni_green", rel=aggregator,
        direction="above", distance="2.5", anchor="center", shape="diamond"
    )
    classifier = BetweenNode(
        name="clf", value="$f$", fill="uni_yellow", first=aggregator,
        second=label_node, shape="circle", anchors=("north", "south")
    )
    connections.append(
        Edge(src=aggregator, dst=label_node, direction="directed", thickness="thick")
    )
    classifier_box = Box(
        fit=[aggregator, label_node],
        style="dashed",
        label=("Classifier", "left"),
        padding=(0.5, 0.5, 0.5, 0.5)
    )
    background = Background()
    background.append(connections)
    pic.append([centre, aggregator, label_node, classifier, classifier_box, background])
    return pic


def add_word_classification_head(
    pic: TikzPicture,
    tokens: List[List[str]],
    doc: "Doc",
    dim_edge: Callable[[TikzElement], bool],
    dim_node: Callable[[TikzElement], bool],
    word_level: bool = False,
    distance: int = 4
) -> TikzPicture:
    """Add one aggregation + classifier head per whitespace-separated group of words.

    Input nodes are the top-level nodes of ``pic`` whose name starts with ``word``
    (``word_level=True``) or ``token``. They are split into groups: a new group
    starts after every word whose ``whitespace_`` is a single space. Each group
    gets an ``agg_<i>`` node above it, an ``out_<i>`` node above that and a
    ``clf_<i>`` node between the two. A dashed "Classifier" box is fitted around
    all of them and the edges go into a ``Background`` element.

    Existing edges selected by ``dim_edge`` are greyed out; nodes selected by
    ``dim_node`` are greyed out only when ``word_level`` is set.

    ``pic`` is modified in place and returned.
    """
    aggregation_nodes = [node for node in pic.elements if isinstance(node, _BaseNode) and node.name.startswith("word" if word_level else "token")]
    def _dim_edges(element: TikzElement) -> TikzElement:
        if isinstance(element, Edge) and dim_edge(element):
            element.color = "uni_dark_gray!40"
        if word_level:
            if isinstance(element, _BaseNode) and dim_node(element):
                element.color = "black!40"
                element.fill = "uni_dark_gray!40"
        return element

    pic.apply(_dim_edges)

    aggregation_node_groups = [[]]
    running_node_idx = 0
    for word, word_tokens in zip(doc, tokens):
        # in word-level mode every word contributes one node, otherwise one per token
        offset = 1 if word_level else len(word_tokens)
        aggregation_node_groups[-1].extend(aggregation_nodes[running_node_idx:running_node_idx+offset])
        running_node_idx += offset
        if word.whitespace_ == " ":
            aggregation_node_groups.append([])
    # A last word followed by whitespace leaves an empty trailing group, which
    # would fail on group[0] below; groups without nodes have nothing to classify.
    aggregation_node_groups = [group for group in aggregation_node_groups if group]

    edges = []
    dummy_nodes = []
    aggregated_nodes = []
    classification_nodes = []
    output_nodes = []
    for i, aggregation_node_group in enumerate(aggregation_node_groups):
        dummy_node = BetweenNode(name=f"dummy_{i}", first=aggregation_node_group[0], second=aggregation_node_group[-1], shape="", outline=False)
        dummy_nodes.append(dummy_node)
        agg_node = RelNode(name=f"agg_{i}", value="Agg.",
                           rel=dummy_node, direction="above", anchor="center", distance=str(distance), shape="rectangle")
        aggregated_nodes.append(agg_node)
        output_node = RelNode(
            name=f"out_{i}", value=f"$\\hat{{l_{{{i+1}}}}}$", fill="uni_green", rel=agg_node,
            direction="above", distance="2.5", anchor="center", shape="diamond"
        )
        output_nodes.append(output_node)
        # indexed like its sibling nodes so every classifier node has its own name
        clf_node = BetweenNode(name=f"clf_{i}", value="$f$", fill="uni_yellow", first=agg_node,
                                second=output_node, shape="circle", anchors=("north", "south"))
        classification_nodes.append(clf_node)
        for node in aggregation_node_group:
            edges.append(
                Edge(src=node, dst=agg_node, direction="undirected", thickness="thick")
            )
        edges.append(
            Edge(src=agg_node, dst=output_node, direction="directed", thickness="thick")
        )

    head_box = Box(
        fit=aggregated_nodes + output_nodes,
        style="dashed",
        label=("Classifier", "left"),
        padding=(0.5, 0.5, 0.5, 0.5)
    )
    bg = Background()
    bg.append(edges)
    pic.append(dummy_nodes + aggregated_nodes + output_nodes + classification_nodes + [head_box, bg])
    return pic

In [217]:
# Named RGB colours referenced by the node and edge styles in the figure helpers.
# The list is prepended to every picture before it is saved, so the colour
# definitions come first inside the generated tikzpicture.
colors = [
    Color("uni_red", 193, 0, 42),
    Color("uni_light_gray", 224, 225, 221),
    Color("uni_medium_gray", 178, 180, 179),
    Color("uni_dark_gray", 154, 155, 156),
    Color("uni_blue", 0, 74, 153),
    Color("uni_green", 115, 150, 0),
    Color("uni_orange", 233, 131, 0),
    Color("uni_yellow", 239, 189, 71),
    Color("uni_light_blue", 167, 193, 227)
]

### Generate LaTeX figures

In [218]:
# LaTeX project directory (two levels above BASE_DIR) and the folder inside it
# that receives the generated .tex figures.
LATEX_DIR = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "latex"))
LATEX_FIGURE_DIR = os.path.join(LATEX_DIR, "figures")

import subprocess
import sys

def compile_pdf():
    """Run ``pdflatex`` on ``thesis_main.tex`` inside ``LATEX_DIR`` and print its exit code.

    The output of pdflatex (stdout and stderr) is discarded.
    """
    command = ["pdflatex", "-synctex=1", "-interaction=nonstopmode", "thesis_main.tex"]
    exit_code = subprocess.call(
        command,
        cwd=LATEX_DIR,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL
    )
    print(exit_code)

#### Fully connected graph

In [219]:
# Plain fully connected token graph, with the colour definitions placed first.
pic = generate_fully_connected_graph(tokens, doc)
pic.prepend(colors)
pic.save(os.path.join(LATEX_FIGURE_DIR, "tikz_test_fc_graph.tex"))

# mark_node_and_edges restyles the elements of `pic` in place, so the unmarked
# figure has to be saved before this call.
pic = mark_node_and_edges(pic, {"token_4"})
pic.save(os.path.join(LATEX_FIGURE_DIR, "tikz_test_fc_graph_mark.tex"))

#### Word graph

In [220]:
# Word graph without dependency edges, with the colour definitions placed first.
pic = generate_word_graph(tokens, doc, add_dependency_edges=False)
pic.prepend(colors)
pic.save(os.path.join(LATEX_FIGURE_DIR, "tikz_test_word_graph.tex"))

# Same picture object, restyled in place to highlight one token node and one word node.
pic = mark_node_and_edges(pic, {"token_4", "word_0"})
pic.save(os.path.join(LATEX_FIGURE_DIR, "tikz_test_word_graph_mark.tex"))

#### Sequence classification

In [221]:
# Variant 1: token graph, classification head aggregating the token nodes directly.
pic = generate_fully_connected_graph(tokens, doc)
pic = add_encoder_box(pic, word_graph=False)
pic = add_sequence_classification_head(
    pic,
    dim_edge=lambda e: e.dst_name.startswith("token"), 
    dim_node=lambda e: False, 
)
pic.prepend(colors)
pic.save(os.path.join(LATEX_FIGURE_DIR, "tikz_test_token_graph_sequence_classification.tex"))

# Variant 2: token graph plus word feature head; word_level=True makes the
# classification head aggregate the word_* nodes added by the feature head.
pic = generate_fully_connected_graph(tokens, doc)
pic = add_encoder_box(pic, word_graph=False)
pic = add_transformer_feature_head(pic, tokens, doc)
pic = add_sequence_classification_head(
    pic, 
    dim_edge=lambda e: e.dst_name.startswith("token"), 
    dim_node=lambda e: False, 
    word_level=True, 
    distance=2
)
pic.prepend(colors)
pic.save(os.path.join(LATEX_FIGURE_DIR, "tikz_test_token_graph_sequence_classification_with_features.tex"))

# Variant 3: word graph; the head aggregates the word nodes and the token nodes are dimmed.
pic = generate_word_graph(tokens, doc, add_dependency_edges=False)
pic = add_encoder_box(pic, word_graph=True)
pic = add_sequence_classification_head(
    pic,
    dim_edge=lambda e: e.dst_name.startswith("token") or e.dst_name.startswith("word"), 
    dim_node=lambda e: e.name.startswith("token"), 
    word_level=True
)
pic.prepend(colors)
pic.save(os.path.join(LATEX_FIGURE_DIR, "tikz_test_word_graph_sequence_classification.tex"))

#### Word classification

In [222]:
# Variant 1: token graph, one head per whitespace-separated group of token nodes.
# NOTE(review): dim_node is only consulted when word_level=True, so the callback
# passed here has no effect — confirm that is intended.
pic = generate_fully_connected_graph(tokens, doc)
pic = add_encoder_box(pic, word_graph=False)
pic = add_word_classification_head(
    pic, tokens, doc, 
    dim_edge=lambda e: e.dst_name.startswith("token"), 
    dim_node=lambda e: e.name.startswith("token"), 
    word_level=False
)
pic.prepend(colors)
pic.save(os.path.join(LATEX_FIGURE_DIR, "tikz_test_token_graph_word_classification.tex"))

# Variant 2: token graph plus word feature head; the heads aggregate the word_*
# nodes added by the feature head.
pic = generate_fully_connected_graph(tokens, doc)
pic = add_encoder_box(pic, word_graph=False)
pic = add_transformer_feature_head(pic, tokens, doc)
pic = add_word_classification_head(
    pic, tokens, doc,
    dim_edge=lambda e: e.dst_name.startswith("token"), 
    dim_node=lambda e: False, 
    word_level=True, 
    distance=2
)
pic.prepend(colors)
pic.save(os.path.join(LATEX_FIGURE_DIR, "tikz_test_token_graph_word_classification_with_features.tex"))

# Variant 3: word graph; the heads aggregate the word nodes and the token nodes are dimmed.
pic = generate_word_graph(tokens, doc, add_dependency_edges=False)
pic = add_encoder_box(pic, word_graph=True)
pic = add_word_classification_head(
    pic, tokens, doc, 
    dim_edge=lambda e: e.dst_name.startswith("token") or e.dst_name.startswith("word"), 
    dim_node=lambda e: e.name.startswith("token"), 
    word_level=True
)
pic.prepend(colors)
pic.save(os.path.join(LATEX_FIGURE_DIR, "tikz_test_word_graph_word_classification.tex"))

#### Compile all pdf

In [223]:
# Prints the exit code returned by pdflatex.
compile_pdf()

0


### Tokenization repair

In [11]:
# generate tokenization repair approach figures
from nsc.utils import tokenization_repair

# NOTE(review): this rebinds the LATEX_DIR defined in the figure section above.
# compile_pdf reads the global at call time, so after this cell has run it will
# execute pdflatex in this directory instead.
LATEX_DIR = "../../../../masters_project/tokenization-repair-transformer/visualizations"

# input sequence with misplaced whitespace and the correctly spaced target
sequence = "re pairthi s"
correct = "repair this"

def generate_encoder_only(s: str, c: str, sample_eo_node: int = 7) -> TikzPicture:
    """Build the TikZ picture for the encoder-only tokenization repair approach.

    For every character of ``s`` an input node is drawn (left to right) and,
    directly below it, an output node holding the whitespace operation that
    ``tokenization_repair.get_whitespace_operations(s, c)`` returns for that
    position. Each character/operation pair is framed by a dotted box, and
    every input character gets a thin dotted edge to the single output node
    selected by ``sample_eo_node``.

    Args:
        s: Input character sequence; one input node is created per character.
        c: Second argument handed to ``get_whitespace_operations``
            (presumably the correctly tokenized version of ``s`` -- confirm
            against that helper).
        sample_eo_node: Index of the output node all input characters are
            connected to. Defaults to 7, the previously hard-coded value.

    Returns:
        A ``TikzPicture`` containing the nodes, boxes and edges.

    Raises:
        ValueError: If ``sample_eo_node`` is not a valid index into ``s``
            (this includes an empty ``s``). Without this check the edges
            would reference a node name that is never created.
        AssertionError: If the number of whitespace operations differs from
            ``len(s)``.
    """
    if not 0 <= sample_eo_node < len(s):
        raise ValueError(
            f"sample_eo_node must be in [0, {len(s)}), got {sample_eo_node}"
        )

    repair_tokens = tokenization_repair.get_whitespace_operations(s, c)
    assert len(s) == len(repair_tokens)

    char_node_kwargs = {"shape": "ellipse", "fill": "uni_light_gray"}
    repair_node_kwargs = {"shape": "diamond", "distance": "2cm"}

    # (fill, color) handed to the output node, keyed by the whitespace
    # operation value.
    rt_to_color = {
        0: ("uni_orange!50", "black"),
        1: ("uni_blue!50", "black"),
        2: ("uni_red!50", "black")
    }

    # The sample output node is referenced by name because it does not exist
    # yet while the nodes left of it are being created.
    sample_node_name = f"eo_rt_{sample_eo_node}"

    # Per-position increment of the edge angles: an edge starting exactly
    # above the sample node uses out=-90 / in=90, edges starting further away
    # are rotated by `angle_step` per position of distance.
    angle_step = 90 / len(s)

    char_nodes = []
    repair_token_nodes = []
    edges = []
    boxes = []

    for i, (char, rt) in enumerate(zip(s, repair_tokens)):
        name_and_value = {"name": f"eo_{i}", "value": f"'{char}'"}
        if i == 0:
            # The first character anchors the picture at the origin ...
            char_node = Node(x=0, y=0, **char_node_kwargs, **name_and_value)
        else:
            # ... every further character is placed relative to its predecessor.
            char_node = RelNode(
                rel=char_nodes[-1],
                direction="right",
                **char_node_kwargs,
                **name_and_value
            )
        char_nodes.append(char_node)

        fill_rt, color_rt = rt_to_color[rt]
        repair_node = RelNode(
            name=f"eo_rt_{i}",
            rel=char_node,
            direction="below",
            value=str(rt),
            fill=fill_rt,
            color=color_rt,
            **repair_node_kwargs
        )
        repair_token_nodes.append(repair_node)

        # Only the edges into the sample output node are drawn; edges between
        # neighbouring characters and from a character to its own output node
        # are not part of this figure.
        pos_diff = i - sample_eo_node
        edges.append(
            Edge(
                src=char_node,
                dst=sample_node_name,
                thickness="thin",
                style="densely dotted",
                direction="directed",
                in_deg=90 - pos_diff * angle_step,
                out_deg=-90 - pos_diff * angle_step
            )
        )

        boxes.append(
            Box(fit=[char_node, repair_node], style="densely dotted", shape="rectangle")
        )

    char_box = Box(
        fit=char_nodes,
        style="dashed",
        label=("Input", "left"),
        padding=(0.25, 0.25, 0.25, 0.25)
    )
    rt_box = Box(
        fit=repair_token_nodes,
        style="dashed",
        label=("Output", "left"),
        padding=(0.25, 0.25, 0.25, 0.25)
    )
    eo_box = Box(
        fit=char_nodes + repair_token_nodes,
        style="line",
        label=("Encoder only", "above"),
        padding=(1.5, 0.5, 0.5, 0.5)
    )
    return (
        TikzPicture(False)
        .append(char_nodes)
        .append(repair_token_nodes)
        .append(boxes)
        .append(edges)
        .append(char_box)
        .append(rt_box)
        .append(eo_box)
    )
    
    
def generate_nmt(s: str, c: str, sample_output_node: int = 7) -> TikzPicture:
    """Build the TikZ picture for the encoder-decoder (NMT) tokenization repair approach.

    The characters of ``s`` form the input row. Below it, the output row
    consists of a ``<bos>`` node, one node per character of ``c`` and a final
    ``<eos>`` node, chained by numbered edges ("1.", "2.", ...). Thin dotted
    edges lead into the output node at index ``sample_output_node`` from every
    input character, from ``<bos>`` and from every output character left of
    it. Output characters right of the sample node use the lighter
    ``!40`` colors.

    Args:
        s: Input character sequence; must not be empty.
        c: Output character sequence.
        sample_output_node: Index into ``c`` of the output node that receives
            the dotted edges. Defaults to 7, the previously hard-coded value.

    Returns:
        A ``TikzPicture`` containing the nodes, edges and boxes.

    Raises:
        ValueError: If ``s`` is empty or ``sample_output_node`` is not a valid
            index into ``c``. Without this check the edges would reference a
            node name that is never created.
    """
    if not s:
        raise ValueError("s must not be empty")
    if not 0 <= sample_output_node < len(c):
        raise ValueError(
            f"sample_output_node must be in [0, {len(c)}), got {sample_output_node}"
        )

    # The sample output node is referenced by name because edges pointing to
    # it are created before the node itself.
    sample_node_name = f"nmt_output_{sample_output_node}"

    # Per-position increment of the input edge angles (out=-90 / in=90 for the
    # input character whose index equals the sample index).
    angle_step = 90 / len(s)

    base_kwargs = {"fill": "uni_light_gray", "shape": "ellipse"}

    def output_style(index: int) -> dict:
        """Return color/fill for the output character at ``index``."""
        up_to_sample = index <= sample_output_node
        return {
            "color": "black" if up_to_sample else "black!40",
            "fill": "uni_light_gray" if up_to_sample else "uni_dark_gray!40",
        }

    input_nodes = []
    edges = []

    # --- input row -------------------------------------------------------
    for i, char in enumerate(s):
        name_and_value = {"name": f"nmt_input_{i}", "value": f"'{char}'"}
        if i == 0:
            char_node = Node(x=0, y=0, **base_kwargs, **name_and_value)
        else:
            char_node = RelNode(
                rel=input_nodes[-1],
                direction="right",
                **base_kwargs,
                **name_and_value
            )
        input_nodes.append(char_node)

        # Only the edges into the sample output node are drawn; edges between
        # neighbouring input characters are not part of this figure.
        pos_diff = i - sample_output_node
        edges.append(
            Edge(
                src=char_node,
                dst=sample_node_name,
                thickness="thin",
                style="densely dotted",
                direction="directed",
                in_deg=90 - pos_diff * angle_step,
                out_deg=-90 - pos_diff * angle_step
            )
        )

    # --- output row ------------------------------------------------------
    bos_node = RelNode(
        rel=input_nodes[0],
        direction="below",
        distance="2cm",
        name="nmt_bos",
        value="<bos>",
        **base_kwargs
    )
    output_nodes = [bos_node]
    edges.append(
        Edge(
            src=bos_node,
            dst=sample_node_name,
            direction="directed",
            style="densely dotted",
            thickness="thin",
            out_deg=-20,
            in_deg=200
        )
    )

    for i, char in enumerate(c):
        previous_node = output_nodes[-1]
        char_node = RelNode(
            rel=previous_node,
            direction="right",
            shape="ellipse",
            name=f"nmt_output_{i}",
            value=f"'{char}'",
            **output_style(i)
        )
        output_nodes.append(char_node)

        # Numbered edge from the previous output node to this one.
        edges.append(
            Edge(
                src=previous_node,
                dst=char_node,
                direction="directed",
                style="line",
                value=f"{i + 1}.",
                color="black" if i <= sample_output_node else "uni_dark_gray!40"
            )
        )

        # Output characters left of the sample node also feed into it.
        if i < sample_output_node:
            edges.append(
                Edge(
                    src=char_node,
                    dst=sample_node_name,
                    direction="directed",
                    style="densely dotted",
                    thickness="thin",
                    out_deg=-20,
                    in_deg=200
                )
            )

    # <eos> uses the same color/fill as the last output character.
    eos_node = RelNode(
        rel=output_nodes[-1],
        direction="right",
        shape="ellipse",
        name="nmt_eos",
        value="<eos>",
        **output_style(len(c) - 1)
    )
    output_nodes.append(eos_node)

    edges.append(
        Edge(
            src=output_nodes[-2],
            dst=eos_node,
            direction="directed",
            value=f"{len(c) + 1}.",
            color="uni_dark_gray!40"
        )
    )

    # `output_nodes` already holds <bos> and <eos> exactly once at this point;
    # adding them again would emit both nodes twice in the picture.

    input_box = Box(
        fit=input_nodes,
        style="dashed",
        label=("Input", "left"),
        padding=(0.25, 0.25, 0.25, 0.25)
    )
    output_box = Box(
        fit=output_nodes,
        style="dashed",
        label=("Output", "left"),
        padding=(0.25, 0.25, 0.25, 0.25)
    )
    nmt_box = Box(
        fit=input_nodes + output_nodes,
        style="line",
        label=("Encoder Decoder", "above"),
        padding=(1.5, 0.5, 0.5, 0.5)
    )
    return (
        TikzPicture(False)
        .append(input_nodes)
        .append(output_nodes)
        .append(edges)
        .append(input_box)
        .append(output_box)
        .append(nmt_box)
    )
    
# Encoder-only approach figure: build, prepend the shared colors, save as .tex.
eo_pic = generate_encoder_only(s=sequence, c=correct)
eo_pic.prepend(colors)
eo_pic.save(
    os.path.join(LATEX_DIR, "approach_eo.tex")
)

# Encoder-decoder (NMT) approach figure, handled the same way.
nmt_pic = generate_nmt(s=sequence, c=correct)
nmt_pic.prepend(colors)
nmt_pic.save(
    os.path.join(LATEX_DIR, "approach_nmt.tex")
)