In [1]:
import pandas as pd
import numpy as np
import sqlite3
from tqdm import tqdm
import pickle
from pandarallel import pandarallel
from time import time
from tokenizers import Tokenizer
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from scipy.stats import linregress
from vocabulary_functions import get_mutated, get_parents, set_difference, set_intersection, load_tokenizers, calc_agreement, calc_dice_idx_only

In [2]:
pandarallel.initialize(nb_workers=20, progress_bar=True)

INFO: Pandarallel will run on 20 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [31]:
# 'dataset': {'uniref50', 'uniref90'}
# 'is_pretokenizer': {True, False}
# 'subs_matrix': {'blosum45', 'blosum62', 'pam70', 'pam250'}
# 'mutation_cutoff': {0.7, 0.8, 0.9}
# 'min_mutation_freq': {0, 0.05,. 0.005}
# 'min_mutation_len': {3}
# 'max_mutation_len': {12}
# 'vocab_size': list=[800, 1600, 3200, 6400, 12800, 25600, 51200]

vocab_sizes = [800, 1600, 3200, 6400, 12800, 25600]
uniref_id = "50"

tokenizer_opts_list = [
    {
        'is_mut': True,
        'dataset': f'uniref{uniref_id}',
        'is_pretokenizer': False,
        'subs_matrix': 'blosum62',
        'mutation_cutoff': 0.7,
        'min_mutation_freq': 0.05,
        'min_mutation_len': 3,
        'max_mutation_len': 12,
        'vocab_size': vocab_sizes
    },
    {
        'is_mut': True,
        'dataset': f'uniref{uniref_id}',
        'is_pretokenizer': False,
        'subs_matrix': 'pam70',
        'mutation_cutoff': 0.7,
        'min_mutation_freq': 0.05,
        'min_mutation_len': 3,
        'max_mutation_len': 12,
        'vocab_size': vocab_sizes
    },
    {
        'is_mut': True,
        'dataset': f'uniref{uniref_id}',
        'is_pretokenizer': True,
        'subs_matrix': 'blosum62',
        'mutation_cutoff': 0.7,
        'min_mutation_freq': 0.05,
        'min_mutation_len': 3,
        'max_mutation_len': 12,
        'vocab_size': vocab_sizes
    },
    {
        'is_mut': True,
        'dataset': f'uniref{uniref_id}',
        'is_pretokenizer': True,
        'subs_matrix': 'pam70',
        'mutation_cutoff': 0.7,
        'min_mutation_freq': 0.05,
        'min_mutation_len': 3,
        'max_mutation_len': 12,
        'vocab_size': vocab_sizes
    },
]

In [32]:
tokenizer_list = load_tokenizers(tokenizer_opts_list, 'hf')
inner_vocab_list = load_tokenizers(tokenizer_opts_list, 'vocab')

vocab_list = {}
for name, tokenizer in tokenizer_list.items():
    vocab_list[name] = list(set([token for token, idx in tokenizer.get_vocab().items()]))

In [130]:
methods = [method_name[:-len(str(vocab_sizes[0]))-1] for method_name in list(tokenizer_list.keys())[::len(vocab_sizes)]]
methods2names = {mn:mn.replace('mut', 'evo').replace('std', '').replace('blosum', 'BLOSUM').replace('pam', 'PAM').replace('pre', 'Pre') for mn in methods}
methods2names = {k: ' '.join(v.split()[:-2]) if 'evoBPE' in v else v for k, v in methods2names.items()}
methods2names

{'mutBPE blosum62 0.7 0.05': 'evoBPE BLOSUM62',
 'mutBPE pam70 0.7 0.05': 'evoBPE PAM70',
 'mutBPE pre blosum62 0.7 0.05': 'evoBPE Pre BLOSUM62',
 'mutBPE pre pam70 0.7 0.05': 'evoBPE Pre PAM70'}

In [117]:
vocab_lineage_list = {}
for k, v in inner_vocab_list.items():
    template_dict = {
        
    }
    vocab_lineage_list[k] = {token:{
                                'frequency': -1,
                                'order': -1,
                                'parent_pair': [],
                                'parent_mutation': "",
                                'parent_mutation_similarity': -1,
                                'partner_pair_self': False,
                                'partner_pair_left': [],
                                'partner_pair_right': [],
                                'child_pair': [],
                                'child_mutation': []
                            } for token in v.keys()}

for method_name, vocab in tqdm(inner_vocab_list.items()):
    for token, inner_vocab_elements in vocab.items():
        vocab_lineage_list[method_name][token]['frequency'] = inner_vocab_elements['frequency']
        vocab_lineage_list[method_name][token]['order'] = inner_vocab_elements['order']
        vocab_lineage_list[method_name][token]['parent_pair'] = inner_vocab_elements['pair'] if 'pair' in inner_vocab_elements else []
        vocab_lineage_list[method_name][token]['parent_mutation'] = inner_vocab_elements['parent'] if 'parent' in inner_vocab_elements else ""
        vocab_lineage_list[method_name][token]['parent_mutation_similarity'] = inner_vocab_elements['similarity'] if 'similarity' in inner_vocab_elements else -1

        if 'pair' in inner_vocab_elements:
            if inner_vocab_elements['pair'][0] == inner_vocab_elements['pair'][1]:
                vocab_lineage_list[method_name][inner_vocab_elements['pair'][0]]['partner_pair_self'] = True
                vocab_lineage_list[method_name][inner_vocab_elements['pair'][0]]['child_pair'].append(token)
            else:
                vocab_lineage_list[method_name][inner_vocab_elements['pair'][0]]['partner_pair_right'].append(inner_vocab_elements['pair'][1])
                vocab_lineage_list[method_name][inner_vocab_elements['pair'][1]]['partner_pair_left'].append(inner_vocab_elements['pair'][0])
                vocab_lineage_list[method_name][inner_vocab_elements['pair'][0]]['child_pair'].append(token)
                vocab_lineage_list[method_name][inner_vocab_elements['pair'][1]]['child_pair'].append(token)
        if 'parent' in inner_vocab_elements:
                vocab_lineage_list[method_name][inner_vocab_elements['parent']]['child_mutation'].append(token)

100%|██████████| 24/24 [00:00<00:00, 84.75it/s] 


In [None]:
method_name = 'mutBPE blosum62 0.7 0.05 3200'

In [None]:
import networkx as nx
from pyvis.network import Network
import random

def create_vocabulary_graph(vocabulary):
    # Create a directed graph
    G = nx.DiGraph()
    
    # Add all tokens as nodes
    for token, info in vocabulary.items():
        if len(token) > 2:
            # Add node with attributes
            G.add_node(token, 
                    frequency=info['frequency'],
                    order=info['order'],
                    title=f"Token: {token}\nFreq: {info['frequency']}\nOrder: {info['order']}")
    
    # Add edges for parent pairs
    for token, info in vocabulary.items():
        if info['parent_pair']:
            # Add edges from both parents to the token
            if info['parent_pair'][0] in G.nodes and token in G.nodes:  # Check if parent exists in vocabulary
                G.add_edge(info['parent_pair'][0], token, 
                             relationship="parent_left",
                             title="Parent Left",
                             color='blue')
            if info['parent_pair'][1] in G.nodes and token in G.nodes:  # Check if parent exists in vocabulary
                G.add_edge(info['parent_pair'][1], token, 
                            relationship="parent_right",
                            title="Parent Right",
                            color='blue')
    
    # Add edges for parent mutations
    for token, info in vocabulary.items():
        if info['parent_mutation']:
            if info['parent_mutation'] in G.nodes and token in G.nodes:  # Check if parent mutation exists
                G.add_edge(info['parent_mutation'], token,
                          relationship="parent_mutation",
                          title=f"Parent Mutation\nSimilarity: {info['parent_mutation_similarity']:.2f}",
                          color='green')
    
    # # Add edges for partner pairs
    # for token, info in vocabulary.items():
    #     # Left partners
    #     for partner in info['partner_pair_left']:
    #         if partner in vocabulary:
    #             child_token = next((child for child in info['child_pair'] 
    #                               if all(p in child for p in [token, partner])), None)
    #             if child_token:
    #                 G.add_edge(token, child_token,
    #                          relationship="partner_pair_left",
    #                          title="Partner Pair (Left)",
    #                          color='red')
    #                 G.add_edge(partner, child_token,
    #                          relationship="partner_pair_right",
    #                          title="Partner Pair (Right)",
    #                          color='red')
        
    #     # Right partners
    #     for partner in info['partner_pair_right']:
    #         if partner in vocabulary:
    #             child_token = next((child for child in info['child_pair'] 
    #                               if all(p in child for p in [token, partner])), None)
    #             if child_token:
    #                 G.add_edge(partner, child_token,
    #                          relationship="partner_pair_left",
    #                          title="Partner Pair (Left)",
    #                          color='red')
    #                 G.add_edge(token, child_token,
    #                          relationship="partner_pair_right",
    #                          title="Partner Pair (Right)",
    #                          color='red')
    
    # # Add edges for child mutations
    # for token, info in vocabulary.items():
    #     for child in info['child_mutation']:
    #         if child in vocabulary:
    #             G.add_edge(token, child,
    #                       relationship="child_mutation",
    #                       title="Child Mutation",
    #                       color='purple')
    
    return G

def visualize_graph(G, output_file='vocabulary_graph.html'):
    # Create a Pyvis network
    net = Network(height='750px', width='100%', bgcolor='#ffffff', 
                 font_color='#000000')
    
    # Configure physics
    net.force_atlas_2based()
    net.show_buttons(filter_=['physics'])
    
    # Add nodes with size based on frequency and color based on order
    for node in G.nodes(data=True):
        token = node[0]
        attrs = node[1]
        
        # Calculate node size based on frequency (log scale)
        size = 10 + (attrs.get('frequency', 1) ** 0.5) / 10
        
        # Calculate color based on order (gradient from green to red)
        order = attrs.get('order', 0)
        max_order = max(attrs.get('order', 0) for _, attrs in G.nodes(data=True))
        color_value = int(255 * (order / max_order))
        color = f'rgb({color_value},100,{255-color_value})'
        
        net.add_node(token, 
                    title=attrs.get('title', ''),
                    size=size,
                    color=color)
    
    # Add edges with their properties
    for edge in G.edges(data=True):
        net.add_edge(edge[0], edge[1],
                    title=edge[2].get('title', ''),
                    color=edge[2].get('color', '#888888'))
    
    # Save the graph
    net.save_graph(output_file)
    return net

In [212]:
# Example usage:
vocabulary = {
    'TLL': {
        'frequency': 5437,
        'order': 189,
        'parent_pair': ['T', 'LL'],
        'parent_mutation': 'SLL',
        'parent_mutation_similarity': 0.75,
        'partner_pair_self': False,
        'partner_pair_left': [],
        'partner_pair_right': [],
        'child_pair': [],
        'child_mutation': []
    },
    'EKPY': {
        'frequency': 3063,
        'order': 954,
        'parent_pair': ['EK', 'PY'],
        'parent_mutation': '',
        'parent_mutation_similarity': -1,
        'partner_pair_self': False,
        'partner_pair_left': ['HTG', 'HSG'],
        'partner_pair_right': [],
        'child_pair': ['HTGEKPY', 'HSGEKPY'],
        'child_mutation': ['ERPY', 'EKPF', 'ERPF']
    }
}

method_name = 'mutBPE blosum62 0.7 0.05 1600'
vocabulary = vocab_lineage_list[method_name]

# Create and visualize the graph
G = create_vocabulary_graph(vocabulary)

In [202]:
net = visualize_graph(G, method_name+'.html')

In [218]:
'P' in G.nodes

False

In [213]:
G.edges[('P', 'PR')]

KeyError: "The edge ('P', 'PR') is not in the graph."

In [195]:
G.edges[('LLL', 'LLV')]

{'relationship': 'parent_mutation',
 'title': 'Parent Mutation\nSimilarity: 0.75',
 'color': 'green'}

In [196]:
G.edges[('LLV', 'LLL')]

KeyError: "The edge ('LLV', 'LLL') is not in the graph."