In [None]:
import json
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import os
import random
import sys
import torch

In [None]:
def get_info(config):
    """Collect the fonts, letters and cases that have a saved representation.

    Scans ``config['class_representation_dir']`` for files named
    ``<letter>_<font>_<case>.pt``. The letter is the first ``_``-separated
    field, the case is the last one, and everything in between (which may
    itself contain underscores or dots) is the font.

    Args:
        config: Mapping providing the key ``'class_representation_dir'``.

    Returns:
        Tuple ``(fonts, letters, cases)`` of sets of strings.
    """
    extension = '.pt'
    fonts = set()
    letters = set()
    cases = set()
    for file_name in os.listdir(config['class_representation_dir']):
        if not file_name.endswith(extension):
            continue
        # Strip only the trailing extension; cutting at the first dot would
        # truncate font names that contain a dot themselves.
        details = file_name[:-len(extension)].split('_')
        if len(details) < 3:
            # Not of the form letter_font_case, so it could never be resolved
            # back to a path by get_representation_path-style naming.
            continue
        letters.add(details[0])
        fonts.add('_'.join(details[1:-1]))
        cases.add(details[-1])
    return fonts, letters, cases


def visualize(graph):
    """Draw ``graph`` with each node labelled by its ``'letter'`` attribute.

    Closes the current pyplot figure, opens a new one with
    ``figsize=(10, 10)``, draws the graph and shows it.

    Args:
        graph: networkx graph whose nodes all carry a ``'letter'`` attribute.
    """
    plt.close()
    plt.figure(figsize=(10, 10))
    node_labels = {name: graph.nodes[name]['letter'] for name in graph.nodes()}
    nx.draw(graph, with_labels=True, labels=node_labels)
    plt.show()


def get_representation_path(config, letter, font, case):
    """Return the path of the ``.pt`` file for a letter/font/case combination.

    Args:
        config: Mapping providing the key ``'class_representation_dir'``.
        letter: Letter field of the file name.
        font: Font field of the file name.
        case: Case field of the file name.

    Returns:
        ``<class_representation_dir>/<letter>_<font>_<case>.pt`` joined with
        ``os.path.join``.
    """
    directory = config['class_representation_dir']
    file_name = '%s_%s_%s.pt' % (letter, font, case)
    return os.path.join(directory, file_name)


def does_combination_exist(config, letter, font, case):
    """Tell whether a representation file exists for the given combination.

    Returns:
        ``True`` if the path built by ``get_representation_path`` exists on
        disk, ``False`` otherwise.
    """
    candidate_path = get_representation_path(config, letter, font, case)
    return os.path.exists(candidate_path)


def add_representations(config, fonts, letters, cases, graph):
    """Attach randomly chosen stored representations to every node of ``graph``.

    For each node, ``config['number_of_representations']`` (font, case) pairs
    are drawn uniformly from the pairs for which a representation file of the
    node's letter exists. Each loaded tensor is stored together with its font
    and case under the node attribute ``'representations'``.

    Args:
        config: Mapping providing ``'number_of_representations'`` and the keys
            needed by ``get_representation_path``.
        fonts: Iterable of font names to choose from.
        letters: Unused; kept so the signature stays compatible with callers.
        cases: Iterable of case names to choose from.
        graph: networkx graph whose nodes carry a ``'letter'`` attribute.
            It is modified in place.

    Returns:
        The same ``graph`` object, with ``'representations'`` set on each node
        to a list of ``(values, font, case)`` tuples, where ``values`` is the
        tensor converted with ``.tolist()``.

    Raises:
        ValueError: If representations are requested for a letter that has no
            representation file for any font/case pair.
    """
    # Sets of strings iterate in a hash-dependent order that changes between
    # interpreter runs, so choosing from list(a_set) is not reproducible even
    # after random.seed(). Sorting once fixes the order (and avoids rebuilding
    # the list on every draw).
    ordered_fonts = sorted(fonts)
    ordered_cases = sorted(cases)
    wanted = config['number_of_representations']
    # letter -> list of (font, case) pairs that exist on disk
    available = {}

    for node in graph.nodes():
        letter = graph.nodes[node]['letter']
        if letter not in available:
            available[letter] = [
                (font, case)
                for font in ordered_fonts
                for case in ordered_cases
                if does_combination_exist(config, letter, font, case)
            ]
        combinations = available[letter]
        if wanted > 0 and not combinations:
            # Retrying random picks until one exists would never terminate here.
            raise ValueError(
                'no representation file found for letter %r' % (letter,)
            )

        representations = []
        for _ in range(wanted):
            font, case = random.choice(combinations)
            # NOTE: torch.load unpickles the file; only use trusted files.
            tensor = torch.load(get_representation_path(config, letter, font, case))
            # A pytorch tensor is not JSON serializable, so store plain lists.
            representations.append((tensor.tolist(), font, case))
        graph.nodes[node]['representations'] = representations
    return graph


def create_pool(config, letters):
    """Generate the pool of random trees that templates are assembled from.

    Builds ``config['trees_in_pool']`` trees with
    ``nx.random_powerlaw_tree``. The node count of each tree is drawn with
    ``random.randint(*config['nodes_per_tree'])``. Every node is renamed to
    ``'<tree index>_<original name>'`` and given a random ``'letter'``
    attribute.

    Args:
        config: Mapping providing ``'trees_in_pool'`` and ``'nodes_per_tree'``.
        letters: Iterable of letters to choose node labels from.

    Returns:
        List of networkx graphs.
    """
    # A set of strings has a hash-dependent iteration order that differs
    # between interpreter runs; sorting once keeps seeded runs reproducible.
    ordered_letters = sorted(letters)
    tree_pool = []
    for tree_index in range(config['trees_in_pool']):
        node_count = random.randint(*config['nodes_per_tree'])
        random_tree = nx.random_powerlaw_tree(node_count, tries=10000)
        # Prefix names with the tree index so nodes of different pool trees
        # never share a name and the trees can later be merged with nx.union.
        rename_mapping = {
            node_name: '%d_%s' % (tree_index, node_name)
            for node_name in random_tree.nodes
        }
        random_tree = nx.relabel_nodes(random_tree, rename_mapping, copy=False)
        for node in random_tree.nodes():
            random_tree.nodes[node]['letter'] = random.choice(ordered_letters)
        tree_pool.append(random_tree)
    return tree_pool


def get_random_combination(config, tree_pool, fonts, letters, cases):
    """Assemble one template graph from randomly selected pool trees.

    Draws ``random.randint(*config['trees_per_template'])`` distinct trees
    from ``tree_pool``. Each further tree is merged into the graph and
    connected to it by one edge between a random node of the graph so far and
    a random node of the new tree. Finally representations are attached via
    ``add_representations``.

    Args:
        config: Mapping providing ``'trees_per_template'`` plus the keys used
            by ``add_representations``.
        tree_pool: List of graphs as returned by ``create_pool``.
        fonts: Font names, forwarded to ``add_representations``.
        letters: Letters, forwarded to ``add_representations``.
        cases: Case names, forwarded to ``add_representations``.

    Returns:
        The assembled graph with a ``'representations'`` attribute per node.
    """
    tree_count = random.randint(*config['trees_per_template'])
    chosen_trees = random.sample(tree_pool, tree_count)
    # Work on a copy: when only one tree is drawn no union takes place, and
    # add_representations would otherwise write into the shared pool tree.
    graph = chosen_trees[0].copy()
    for tree in chosen_trees[1:]:
        anchor_in_graph = random.choice(list(graph.nodes()))
        anchor_in_tree = random.choice(list(tree.nodes()))
        graph = nx.union(graph, tree)
        graph.add_edge(anchor_in_graph, anchor_in_tree)
    return add_representations(config, fonts, letters, cases, graph)


def create_dataset(config):
    """Generate the template graphs and write each one to a JSON file.

    Seeds ``random`` and ``numpy.random`` with ``config['random_seed']``,
    builds the tree pool, then creates ``config['number_of_templates']``
    graphs. Graph ``i`` is written in node-link format to
    ``<config['output_dir']>/<i>.json``. A running counter is printed to
    stdout, and each graph is drawn when ``config['visualize']`` is true.

    Args:
        config: Mapping providing ``'random_seed'``, ``'output_dir'``,
            ``'number_of_templates'``, ``'visualize'`` and the keys read by
            ``get_info``, ``create_pool`` and ``get_random_combination``.
    """
    # get_info returns sets of strings, whose iteration order is hash-dependent
    # and changes between interpreter runs. Sorting them makes every later
    # random choice depend only on the seed.
    fonts, letters, cases = (sorted(names) for names in get_info(config))
    random.seed(config['random_seed'])
    np.random.seed(config['random_seed'])
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(config['output_dir'], exist_ok=True)
    tree_pool = create_pool(config, letters)
    for template_index in range(config['number_of_templates']):
        # '\r' rewrites the same line so the counter updates in place.
        sys.stdout.write('\r%d' % (template_index + 1))
        sys.stdout.flush()
        graph = get_random_combination(config, tree_pool, fonts, letters, cases)
        if config['visualize']:
            visualize(graph)
        output_path = os.path.join(config['output_dir'], '%d.json' % template_index)
        with open(output_path, 'w') as output_file:
            output_file.write(json.dumps(nx.node_link_data(graph)))

In [None]:
# Settings consumed by create_dataset and the helpers it calls.
config = {
    # Input / output locations.
    'class_representation_dir': 'representations',  # scanned for *.pt files
    'output_dir': 'templates',                      # receives <index>.json files

    # Tree pool: number of trees and (min, max) node count per tree,
    # both bounds inclusive (passed to random.randint).
    'trees_in_pool': 10,
    'nodes_per_tree': (8, 12),

    # Templates: how many to write, the inclusive (min, max) number of pool
    # trees joined per template, and representations loaded per node.
    'number_of_templates': 200,
    'trees_per_template': (3, 7),
    'number_of_representations': 10,

    # Run control.
    'random_seed': 42,
    'visualize': False,
}

create_dataset(config)

In [None]:
# Shell escape: recursively archive the `templates` directory (the 'output_dir' set in the config cell) into templates.zip.
!zip -r templates.zip templates