# Setup

In [None]:
!pip list

## Installations

In [None]:
%pip install -r "requirements.txt"

In [None]:
%pip uninstall -y triton

## Imports

In [None]:
import pandas as pd
import gffpandas.gffpandas as gffpd
from py2neo import Graph, Node, Relationship, cypher
from bisect import bisect_left as bisect

In [None]:
from typing import Sequence, Callable
from random import Random, randrange
import Levenshtein as lev
from multiprocessing.pool import ThreadPool as Pool

In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModel

In [None]:
import transformers
transformers.__version__

In [None]:
def thread_starmap(arg_list, *, threads, progress_step=255):
  """Decorator factory that runs the decorated function over ``arg_list`` on a thread pool.

  The work happens at decoration time: applying the returned decorator to ``fn``
  calls ``fn(*args)`` for every ``args`` in ``arg_list`` and then returns ``fn``
  unchanged.

  arg_list: iterable of argument tuples, one per invocation.
  threads: number of worker threads in the pool.
  progress_step: bit mask for progress output; a line is printed whenever
    ``i & progress_step`` is zero (every 256th invocation with the default).
  """
  def progress(f):
    def with_progress(i, args):
      if not i&progress_step: print('Invocation', i)
      return f(*args)
    # The wrapper has to be returned, otherwise starmap receives None as its callable.
    return with_progress

  def inner(fn):
    with Pool(threads) as p:
      print('Running map...')
      # enumerate yields (i, args); starmap unpacks that pair into with_progress(i, args).
      p.starmap(progress(fn), enumerate(arg_list))
      print('Map complete. Destroying pool')
    print('Pool complete.')
    return fn
  return inner

## Connection

In [None]:
import os
from dotenv import load_dotenv
load_dotenv()

# Open the Neo4j connection from the environment variables loaded by load_dotenv() above.
# NOTE(review): os.getenv returns None for an unset variable, so a missing
# neo4j_connection_url makes the string concatenation below raise TypeError.
graph = Graph("neo4j+s://"+os.getenv("neo4j_connection_url"), auth=(os.getenv("neo4j_user"), os.getenv("neo4j_password")))

# Check connection
graph.run("RETURN 1")

# Main Graph

In [None]:
def add_triple_to_graph(i, triple):
  """Merge one (head, relation, tail) triple into the graph as two Entity nodes and an edge.

  i: position of the triple; only used for the progress message.
  triple: 3-item iterable unpacked as (head name, relationship type, tail name).
  """
  if (i & 255) == 0:
    print('Triple number', i)

  head_name, rel_type, tail_name = triple

  # Entity nodes are keyed on their 'name' property.
  source = Node("Entity", name=head_name)
  target = Node("Entity", name=tail_name)

  # merge() only creates a node when no Entity with that name exists yet.
  for node in (source, target):
    graph.merge(node, "Entity", "name")

  # (source)-[rel_type]->(target), merged so repeated triples do not duplicate the edge.
  graph.merge(Relationship(source, rel_type, target))

## Create Sequence and Identity nodes + relations

In [None]:
import csv
path = "./data/exons/exon_sequences.csv"
labels = ["gene_id","transcript_id","hgnc_id","gene_type","gene_name","exon_id","protein_id","sequence"]
hgnc = set()
count = 0

# Turn every CSV row into a dict keyed by `labels`; zip() stops at the shorter
# of the two, so a short row yields a dict with fewer keys.
with open(path, 'r') as f:
    exon_data = [dict(zip(labels, row)) for row in csv.reader(f)]

# print(f"number of unique hgnc ids {len(hgnc)}")
# The first row is the header.
exon_data = exon_data[1:]
print(f"pulled {len(exon_data)} sequences!")

# Keep only complete records (one value for every label).
exon_data = [record for record in exon_data if len(record) == len(labels)]
print(f"new length {len(exon_data)}")

In [None]:
# maybe make a decorator for the thread pool thingi?

# TODO: ADD GO ref ids to Identity nodes later
# NOTE(review): earlier draft. make_seq_node_with_id_node is defined again in the
# next cell, and when cells run top to bottom that later definition is the one in effect.
# NOTE(review): this MERGEs the whole (Sequence)-[:MAPS_TO]->(Identity) pattern at once;
# presumably that creates a new pair of nodes whenever the full pattern is absent —
# confirm against the Cypher MERGE documentation before reusing this query.
seq_node_id_node_query = '''
MERGE (s:Sequence {sequence: $sequence, exon_id: $exon_id, protein_id: $protein_id})-[:MAPS_TO]->(i:Identity {gene_id: $gene_id, hgnc_id: $hgnc_id, transcript_id: $transcript_id});
'''
def make_seq_node_with_id_node(i, d):
    """Run seq_node_id_node_query with the entries of ``d`` as query parameters.

    i: position of the record; only used for the progress message.
    d: dict whose items are passed as Cypher parameters.
    """
    # Progress line every 128 records.
    if not i&127: print(f"Processed {i} nodes!")
    graph.run(seq_node_id_node_query, **d)
    

In [None]:
# Each node and each relationship is MERGEd separately:
# (Sequence)-[:MAPS_TO]->(Transcript)-[:PART_OF]->(Gene).
sequence_identity_node_query = '''
MERGE (s:Sequence {sequence: $sequence, exon_id: $exon_id, protein_id: $protein_id})
MERGE (t:Transcript {transcript_id: $transcript_id})
MERGE (g:Gene {gene_id: $gene_id, hgnc_id: $hgnc_id})
MERGE (s)-[:MAPS_TO]->(t)
MERGE (t)-[:PART_OF]->(g);
'''

total = len(exon_data)
count = 0

def make_seq_node_with_id_node(i, d):
    """Merge one exon record into the graph.

    i: position of the record (unused; progress is tracked with the global ``count``).
    d: record dict; its items are passed as Cypher parameters.

    ``count`` is a plain global updated from every pool thread without a lock,
    so the progress lines are approximate.
    """
    global count
    if (count & 127) == 0:
        print(f"Processed {count}/{total} nodes!")
    graph.run(sequence_identity_node_query, **d)
    count = count + 1

In [None]:
print('Creating pool')
# 100 worker threads; each call receives (index, record) from enumerate.
exon_pool = Pool(100)
with exon_pool:
  print('Running map...')
  exon_pool.starmap(make_seq_node_with_id_node, enumerate(exon_data))
  print('Queries complete. Destroying pool')
print('Pool complete.')

## Map Ensembl ID to Go

In [None]:
import csv

path = "./data/go/ensembl_go_protein_coding.csv"
labels = ["gene_id", "go"]
go = set()

# One dict per CSV row, keyed by `labels`.
with open(path, "r") as f:
    go_data = [dict(zip(labels, row)) for row in csv.reader(f)]
    # go.add(row[1])

# The first row is the header.
go_data = go_data[1:]
print(f"pulled {len(go_data)} records!")
# print(f"unique go ids: {len(go)}")

In [None]:
# Attach a Go node to an existing Gene; a record whose gene_id matches no Gene
# produces nothing because the query starts with MATCH.
gene_go_mapper_query = '''
MATCH (s: Gene {gene_id: $gene_id})
MERGE (go: Go {go_id: $go})
MERGE (s)-[:HAS_GO]->(go);
'''

total = len(go_data)
count = 0

def map_gene_to_go(i, d):
    """Link one gene to one GO term.

    i: position of the record (unused; progress is tracked with the global ``count``).
    d: record dict; its items are passed as Cypher parameters.

    ``count`` is updated from every pool thread without a lock, so the
    progress lines are approximate.
    """
    global count
    if (count & 127) == 0:
        print(f"Processed {count}/{total} nodes!")
    graph.run(gene_go_mapper_query, **d)
    count = count + 1

In [None]:
print('Creating pool')
# 100 worker threads; each call receives (index, record) from enumerate.
go_pool = Pool(100)
with go_pool:
  print('Running map...')
  go_pool.starmap(map_gene_to_go, enumerate(go_data))
  print('Queries complete. Destroying pool')
print('Pool complete.')

In [None]:
import csv

path = "./data/go/go_triples_protein_coding.csv"
labels = ["go_id", "relation_type", "property"]

# One dict per CSV row, keyed by `labels`.
with open(path, "r") as f:
    go_triples = [dict(zip(labels, row)) for row in csv.reader(f)]

# The first row is the header.
go_triples = go_triples[1:]

# Two sample records as a sanity check.
print(go_triples[0])
print(go_triples[3])
print(f"pulled {len(go_triples)} go triples!")

In [None]:
count = 0
total = len(go_triples)
add_go_triple_query = '''
MATCH (go:Go {go_id: $go_id})
MERGE (gop:GoProperty {property: $property})
'''
# A relationship type cannot be a Cypher parameter, so it is formatted into the text.
relation = "MERGE (go)-[:{relation_type}]->(gop)"

def _quote_relation_type(name):
    """Return ``name`` as a backtick-quoted Cypher identifier (inner backticks doubled)."""
    return "`" + str(name).replace("`", "``") + "`"

def add_go_triple_to_graph(i, triple):
    """Merge one GO triple: (Go)-[relation_type]->(GoProperty).

    i: position of the triple (unused; progress is tracked with the global ``count``).
    triple: dict with "go_id", "relation_type" and "property"; all items are
      also passed as Cypher parameters.

    The relationship type comes straight from the CSV and is spliced into the
    query text, so it is backtick-quoted first. That keeps types containing
    spaces or punctuation valid and stops a crafted value from altering the query.
    """
    global count, total
    if not count&127 and count!=0: print(f"processed {count}/{total} triples!")
    r = relation.format(relation_type=_quote_relation_type(triple["relation_type"]))
    q = add_go_triple_query+r
    graph.run(q, **triple)
    count+=1

In [None]:
print('Creating pool')
# 100 worker threads; each call receives (index, triple) from enumerate.
triple_pool = Pool(100)
with triple_pool:
  print('Running map...')
  triple_pool.starmap(add_go_triple_to_graph, enumerate(go_triples))
  print('Queries complete. Destroying pool')
print('Pool complete.')

# Search Tree

## Retrieving the Genes and their Node ids

In [None]:
# Fetch every Sequence node together with its element id; each row holds a single map {g, id}.
id_map = graph.run('MATCH (g:Sequence) RETURN {g: g, id: elementId(g)}')

In [None]:
# Replace the query result with its materialised rows.
# NOTE(review): this rebinds id_map, so re-running only this cell presumably fails
# (the result of .data() has no .data method) — re-run the query cell first.
id_map = id_map.data()

In [None]:
# Only the last expression of a cell is displayed, so the length is computed but not shown.
len(id_map)
id_map[:10]

In [None]:
# Flatten each row (a dict holding one {g, id} map) into an (element id, sequence) pair.
id_pairs: list[tuple[str, str]] = []
for record in id_map:
    entry = next(iter(record.values()))
    id_pairs.append((entry['id'], entry['g']['sequence']))

print(len(id_pairs), 'pairs')

id_pairs[:5]

In [None]:
import os

os.environ["WANDB_DISABLED"] = "true"
model_name = "zhihan1996/DNABERT-2-117M"
path = "./data/dnabert2_model"
tokenizer, model = None, None
tokenizer_path = f"{path}/tokenizer/{model_name}"
model_path = f"{path}/model/{model_name}"

# NOTE(review): trust_remote_code=True executes code shipped with the model repository.
if not (os.path.isdir(tokenizer_path) and os.path.isdir(model_path)):
    # No complete local copy: load by model name, then save both parts for the next session.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.save_pretrained(tokenizer_path)
    model.save_pretrained(model_path)
else:
    # Both directories exist: load from disk.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)

# cuda_model = model.to("cuda")

## Distance function using DNABERT-2

In [None]:
def dnabert_embedding_mean(dna):
    """Return the mean-pooled embedding of ``dna`` as a detached tensor.

    dna: the sequence string handed to the global ``tokenizer``.

    The first output of the global ``model`` is averaged over its token axis.
    The whole computation runs under ``torch.no_grad()``: this is inference
    only, so building an autograd graph that is immediately discarded by
    ``detach()`` wastes memory and time.
    """
    with torch.no_grad():
        # inputs = tokenizer(dna, return_tensors = 'pt')["input_ids"].to('cuda')
        inputs = tokenizer(dna, return_tensors = 'pt')["input_ids"]
        # Per the original author's note: [1, sequence_length, 768]
        hidden_states = model(inputs)[0]

        # Mean pooling over the tokens of the single sequence in the batch.
        embedding_mean = torch.mean(hidden_states[0], dim=0)

    return embedding_mean.detach()

In [None]:
def dnabert_embedding_distance(a, b):
    """Euclidean distance between the cached embeddings of sequences ``a`` and ``b``.

    Both sequences must already be keys of the global ``memo``.

    Raises:
        KeyError: if either sequence has no cached embedding. Previously the
        lookup error was only printed and the function then failed with an
        unrelated UnboundLocalError.
    """
    missing = [seq for seq in (a, b) if seq not in memo]
    if missing:
        # Sequences can be very long; show only the start of each one.
        preview = ", ".join(repr(seq[:64]) for seq in missing)
        raise KeyError(f"no cached embedding for sequence(s) starting with: {preview}")

    return float(torch.linalg.vector_norm(memo[a]-memo[b]))

## Calculate and memoize embeddings. Store for future sessions.

In [None]:
import numpy as np

def embedding_maker_worker(i, tup):
    """Embed one sequence and store the result in the global ``memo``.

    i: position of the pair; a progress line is printed every 128 items.
    tup: (node id, sequence) pair; the id is not used.

    Also appends the sequence length to the global ``arr``.
    """
    _node_id, seq = tup
    seq_len = len(seq)
    with torch.no_grad():
        if (i & 127) == 0:
            print(f"{i}: {seq_len}")
        arr.append(seq_len)
        memo[seq] = dnabert_embedding_mean(seq)

# print(f"first seven - {arr[:8]}")
# print(f"avg seq len {np.mean(arr)}")

In [None]:
import os
import ast
import csv

path = "./data/working/dnabert2_embeddings.csv"
memo = {}
arr = []

if os.path.exists(path):
    print("Found existing dump of embeddings, loading them...")
    # newline="" is what the csv module asks for on files it reads or writes.
    with open(path, "r", newline="") as f:
        for row in csv.reader(f):
            if not row:
                # Blank line (e.g. from doubled line endings): nothing to load.
                continue
            seq, embedding = row
            memo[seq] = torch.tensor(ast.literal_eval(embedding))
else:
    print("Embeddings don't exist, creating...")
    # Create the output directory before the expensive step, so a missing
    # folder cannot throw away the computed embeddings at write time.
    os.makedirs(os.path.dirname(path), exist_ok=True)

    with Pool(200) as p:
      print('Running map...')
      p.starmap(embedding_maker_worker, enumerate(id_pairs[::-1]))
      print('Queries complete. Destroying pool')

    print("Writing Rows...")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        # One row per sequence: the sequence and its embedding as a Python list literal.
        for seq, embd in memo.items():
            writer.writerow([seq, embd.tolist()])

print("Done!")

In [None]:
# !rm "./data/working/dnabert2_embeddings.csv"

In [None]:
# Spot check: this lookup raises KeyError if the sequence was never embedded.
# Its value is not displayed because it is not the last expression of the cell.
memo["CCCAGATCTCTTCAG"]
print(f"size of id_pairs: {len(id_pairs)}")
print(f"size of memo: {len(memo)}")

In [None]:
# Smoke test of the distance function; both sequences must already be keys of memo.
print(dnabert_embedding_distance(
  "CCCAGATCTCTTCAG",
  "TTTTTATGCCTCATTCTGTGAAAATTGCTGTAGTCTCTTCCAGTTATGAAGAAG",
))

print(len(memo))

## Construction

In [None]:
# Writes one search-tree edge: stores thresh/limit on the parent (r), the level
# on the child (n), and links them. The edge is MERGEd rather than CREATEd so
# that re-running the build does not stack duplicate SEARCH_BRANCH relationships
# between the same two nodes.
query = '''
  MATCH (r:Sequence WHERE elementId(r) = $root_id)
  MATCH (n:Sequence WHERE elementId(n) = $n_id)
  SET r.thresh = $thresh
  SET r.limit = $limit
  SET n.level = $level
  MERGE (r)-[:SEARCH_BRANCH {branch_type: $branch_type}]->(n)
'''

In [None]:
def create_tree_instructions(graph, id_pairs: list[tuple[str, str]], distance: Callable[[Sequence, Sequence], float], r: Random, level = 0) -> tuple[str, str, list[dict]] | None:

  if not id_pairs: return None

  out = []



  l = len(id_pairs)

  root_idx = r.randrange(l)

  root_id, root = id_pairs.pop(root_idx)



  s = sorted(id_pairs, key=lambda g: distance(root, g[1]))

  split_pos = l//2

  if not s: limit = thresh = 0.0

  elif l//2 >= len(s): limit = thresh = distance(root, s[-1][1])

  else:
    print(s)

    distances = [distance(root, g[1]) for g in s]

    thresh = distances[l//2]

    limit = distances[-1]

    # if limit != thresh: print('EXPECTED', thresh, limit)

    split_pos = bisect(distances, thresh)

    # out.append({'id': root_id, 'limit': limit, 'thresh': thresh, 'far_subtree': (distances[split_pos], distances[-1])})



  near_result = create_tree_instructions(graph, s[:split_pos], distance, r, level=level+1)

  if near_result is not None:

    out.extend(near_result[2])

    out.append({

        "branch_type": "NEAR",

        "root_id": root_id,

        "n_id": near_result[0],

        "thresh": thresh,

        "level": level+1,

        "limit": limit,

    })



  far_result = create_tree_instructions(graph, s[split_pos:], distance, r, level=level+1)

  if far_result is not None:

    out.extend(far_result[2])

    out.append({

        "branch_type": "FAR",

        "root_id": root_id,

        "n_id": far_result[0],

        "thresh": thresh,

        "level": level+1,

        "limit": limit,

    })



  print(level, root_id, root[:64])



  return root_id, root, out

In [None]:
def edit_distance(a: Sequence, b: Sequence) -> float:  # levenshtein distance
  """Levenshtein distance between ``a`` and ``b``, returned as a float.

  Counts the minimum number of single-item insertions, deletions and
  substitutions that turn ``a`` into ``b``. The previous single-pass greedy scan
  only gave an upper bound (3 for "ab" vs "ba", where the true distance is 2).

  Uses the two-row dynamic programme: time is O(len(a) * len(b)), memory is
  O(min(len(a), len(b))). For plain strings, ``lev.distance`` from the
  Levenshtein package this notebook imports computes the same value faster.
  """
  # The distance is symmetric; iterate over the longer input so the rows stay short.
  if len(a) < len(b):
    a, b = b, a

  # previous[j] = distance between the processed prefix of a and b[:j].
  previous = list(range(len(b) + 1))
  for ai, a_item in enumerate(a, start=1):
    current = [ai]
    for bi, b_item in enumerate(b, start=1):
      current.append(min(
        previous[bi] + 1,                      # delete a_item
        current[bi-1] + 1,                     # insert b_item
        previous[bi-1] + (a_item != b_item),   # keep or substitute
      ))
    previous = current

  return float(previous[-1])

In [None]:
# seed = randrange(1<<32)

# seed = 0x509c2232

# Fixed seed so the randomly chosen roots are the same on every run.
seed = 0xf67832cb

print(f'seed = 0x{seed:08x}')

r = Random(seed)

# A copy of id_pairs is passed because create_tree_instructions pops the root from its input.
# NOTE(review): create_tree_instructions returns None for an empty id_pairs,
# which would make this unpacking fail.
# root_id, root, kwarg_list = create_tree_instructions(graph, id_pairs.copy(), edit_distance, r)
root_id, root, kwarg_list = create_tree_instructions(graph, id_pairs.copy(), dnabert_embedding_distance, r)

root_id, root

In [None]:
# Takes 1-2 seconds. *Blazingly Fast*

def _run_branch_query(kwargs):
  """Run the branch-writing query with one instruction dict as its parameters."""
  return graph.run(query, **kwargs)

print('Creating pool')
branch_pool = Pool(100)
with branch_pool:
  print('Running map...')
  branch_pool.map(_run_branch_query, kwarg_list)
  print('Queries complete. Destroying pool')
print('Pool complete.')

## Verification

In [None]:
def tree_testing(graph, id_pairs: list[tuple[str, str]], distance: Callable[[Sequence, Sequence], float], r: Random, level = 0) -> tuple[str | None, str | None, list[dict]]:

  if not id_pairs: return None, None, []

  out = []



  l = len(id_pairs)

  root_idx = r.randrange(l)

  root_id, root = id_pairs.pop(root_idx)



  s = sorted(id_pairs, key=lambda g: distance(root, g[1]))

  split_pos = l//2

  if not s:

    thresh = 0.0

    limit = 0.0

  elif l//2 >= len(s): limit = thresh = distance(root, s[-1])

  else:

    distances = [distance(root, g[1]) for g in s]

    thresh = distances[l//2]

    limit = distances[-1]

    if limit != thresh: print('EXPECTED', thresh, limit)

    split_pos = bisect(distances, thresh)

    out.append({'id': root_id, 'limit': limit, 'thresh': thresh, 'far_subtree': (distances[split_pos], distances[-1])})



  near_result = tree_testing(graph, s[:split_pos], distance, r, level=level+1)

  out.extend(near_result[2])

  far_result = tree_testing(graph, s[split_pos:], distance, r, level=level+1)

  out.extend(far_result[2])



  if level < 7: print(level, root_id, root[:64])



  return root_id, root, out



# Same seed as the build cell above.
seed = 0xf67832cb

print(f'seed = 0x{seed:08x}')

r = Random(seed)

# root_id, root, lev_properties = tree_testing(graph, id_pairs.copy(), lev.distance, r)

# NOTE(review): this rebinds root_id and root, which the build cell above also assigns.
# With the same seed and the same id_pairs the first root drawn should be the same — confirm.
root_id, root, lev_properties = tree_testing(graph, id_pairs.copy(), edit_distance, r)

root_id, root

In [None]:
# Roots whose FAR part spans two different distances (smallest != largest).
[prop for prop in lev_properties if len(set(prop['far_subtree'])) == 2]

In [None]:
len(lev_properties)

In [None]:
root_id

In [None]:
len(kwarg_list)

In [None]:
# Check if update was successful
# Collect (thresh, limit) for every parent whose two stored values differ.
branch_rows = graph.run('MATCH  (r:Sequence)-[:SEARCH_BRANCH]->(n) RETURN r, n').data()
l = [
  (row['r']['thresh'], row['r']['limit'])
  for row in branch_rows
  if row['r']['thresh'] != row['r']['limit']
]

print(len(l))

l

In [None]:
# Distinct (child, parent) id pairs among the planned branches.
unique_kwargs = set((kwarg['n_id'], kwarg['root_id']) for kwarg in kwarg_list)

len(unique_kwargs)

In [None]:
# Print the first 15 branch queries with their parameters filled in, ready to paste
# into the Neo4j console. The text is derived from `query` itself; the old
# hand-written template matched :Gene nodes and left out limit/level, so it did
# not show what is actually executed.
for kwargs in kwarg_list[:15]:
  preview = query
  for key, value in kwargs.items():
    # repr() quotes strings with single quotes and leaves numbers bare.
    preview = preview.replace(f'${key}', repr(value))
  print(preview)

## Search

In [None]:
def search(root_id, root_name, tree_thresh, tree_limit, q, qthresh, paths=1, distance=dnabert_embedding_distance) -> tuple[list[tuple[float, str, str]], int]:
  """Range search: collect the sequences under ``root_id`` within ``qthresh`` of ``q``.

  root_id: element id of the subtree root.
  root_name: sequence stored on that root.
  tree_thresh: the root's stored threshold; NEAR children are closer to the
    root than this, FAR children are at least this far. None on a leaf.
  tree_limit: the root's stored limit, i.e. the largest distance from the root
    to any node in its subtree. None on a leaf.
  q: query sequence.
  qthresh: search radius.
  paths: running count of explored branches, threaded through the recursion.
  distance: distance callable; it is forwarded to the recursive calls.

  Returns (matches, paths), where every match is (distance to q, node id, sequence).

  Changes from the previous version:
  - a leaf is reported as a (distance, id, sequence) tuple and only when it is
    within ``qthresh`` (it used to be returned as a bare string unconditionally);
  - NEAR and FAR are both searched when the query ball straddles the threshold
    (an ``elif`` used to skip FAR whenever NEAR was taken);
  - the whole-subtree shortcut now includes the root itself;
  - the caller's ``distance`` is forwarded to the recursion.
  """
  d = distance(q, root_name)
  out = [(d, root_id, root_name)] if d <= qthresh else []

  # Leaf: a node that never acted as a branch root has no thresh/limit properties.
  if tree_thresh is None or tree_limit is None:
    return out, paths

  if tree_limit <= qthresh-d:
    # Everything below this root is within `tree_limit` of it, so by the triangle
    # inequality the whole subtree lies inside the query ball: fetch it in one query.
    # (tree_limit >= 0 implies d <= qthresh here, so `out` already holds the root.)
    print('SUBTREE BATCH')
    results = graph.run(
      '''
      MATCH (r WHERE elementId(r) = $root_id)-[:SEARCH_BRANCH*1..]->(c)
      RETURN {child_id: elementId(c), child: c.sequence}
      ''',
      root_id=root_id,
    ).data()

    for result in results:
      [values] = [*result.values()]
      out.append((distance(q, values['child']), values['child_id'], values['child']))
    return out, paths

  result = graph.run(
    '''
    MATCH (r WHERE elementId(r) = $root_id)-[sb:SEARCH_BRANCH]->(c)
    RETURN {child_id: elementId(c), child: c.sequence, thresh: c.thresh, limit: c.limit, branch: sb.branch_type}
    ''',
    root_id=root_id,
  ).data()

  # branch type -> (child id, child sequence, child thresh, child limit)
  children: dict[str, tuple] = {}
  for child in result:
    [values] = [*child.values()]
    children[values['branch']] = values['child_id'], values['child'], values['thresh'], values["limit"]

  if abs(d-tree_thresh) <= qthresh:
    # The query ball straddles the threshold, so both branches have to be explored.
    paths += 1
    print(f'Paths: {paths}, Diff: {abs(d-tree_thresh)}')

  # NEAR holds nodes with distance-to-root < thresh; it can intersect the ball
  # only if d - qthresh < thresh.
  if d-qthresh < tree_thresh and 'NEAR' in children:
    terms, paths = search(*children['NEAR'], q, qthresh, paths, distance)
    out.extend(terms)

  # FAR holds nodes with distance-to-root >= thresh; it can intersect the ball
  # only if d + qthresh >= thresh. Checked independently of NEAR.
  if d+qthresh >= tree_thresh and 'FAR' in children:
    terms, paths = search(*children['FAR'], q, qthresh, paths, distance)
    out.extend(terms)

  return (out, paths)

In [None]:
# NOTE(review): in the visible notebook root_thresh and root_limit are first assigned in the
# "Searching on exon graph" section below, so on a fresh kernel run top to bottom this cell
# raises NameError. Run that assignment cell first.
q_result, paths = search(root_id, root, root_thresh, root_limit, root, 30000)

print(len(q_result), 'similar genes')

sorted(q_result)

## Searching on exon graph

In [None]:
# Hard-coded description of the search-tree root; these values must be refreshed
# by hand whenever the tree is rebuilt.
root_thresh = 1.7634133100509644  # queried from the neo4j console for ({level: 0})
root_limit = 6.264886856079102  # queried from the neo4j console for ({level: 0})
root_id, root = ('4:c81c1c99-e897-44b9-b8cc-e4b6e264a420:1956', 'GCCCCACGTGTGTGCTGAGCAGGAGCTGACCCTGGTGGGCCGCCGCCAGCCGTGCGTGCAGGCCTTAAGCCACACGGTGCCGGTGTGGAAGGCCGGCTGTGGGTGGCAGGCGTGGTGCGTGGGTCATGAGCGGAG')

In [None]:
# Query the tree with the root's own sequence and a search radius of 1.
q_result, paths = search(root_id, root, root_thresh, root_limit, root, 1)
print(len(q_result), 'similar genes')
print("results!\n-----------")
print(q_result)
print(paths)

In [None]:
def similarity_search(q):
    """Return the stored sequences found by ``search`` within radius 1 of ``q``.

    q: query sequence. Its embedding is added to the global ``memo`` so the
       distance function can look it up; an embedding that is already cached
       is reused instead of running the model again.

    Returns the match list from ``search``; the path count is discarded.
    """
    if q not in memo:
        memo[q] = dnabert_embedding_mean(q)
    q_result, _paths = search(root_id, root, root_thresh, root_limit, q, 1)
    return q_result

In [None]:
# Single-query example; the result is displayed as the cell output.
q = "CCCAGATCTCTTCAG"
# print(memo[q])
similarity_search(q)

In [None]:
similarity_search("GTCTGTGATGAGGAATGGCACCACTACGTCCTCAATGTAGAATTCCCGAGTGTGACTCTCTATGTGGATGGCACGTCCCACGAGCCCTTCTCTGTGACTGAGGATTACCCGCTCCATCCATCCCTCAGCTCGTGGTGGGGGCTTGCTGGCAAG")

In [None]:
# The root's own sequence first, then a batch of sample queries.
print(similarity_search(root))

seqs = [
    "GGGCCCCCTCATAAATGTGCCTTAATTTTCGCAGATAACAGTGGAATAGACATCATTTTGGGAGTCTTCCCCTTTGTCAGGGAGCTACTCCTTAGAGGGACAGAG",
    "AGCCTGAAGCCGCTGCTGGAGAAGCGCCGGCGCGCGCGCATCAACCAGAGCCTGAGCCAGCTTAAGGGGCTCATCCTGCCGCTGCTGGGCCGGGAG",
    "GTCTGTGATGAGGAATGGCACCACTACGTCCTCAATGTAGAATTCCCGAGTGTGACTCTCTATGTGGATGGCACGTCCCACGAGCCCTTCTCTGTGACTGAGGATTACCCGCTCCATCCATCCAAGATAGAAACTCAGCTCGTGGTGGGGGCTTGCTGGCAAG",
    "GCCCAGTTCAATCTGCTGAGCAGCACCATGGACCAGATGAGCAGCCGCGCGGCCTCGGCCAGCCCCTACACCCCAGAGCACGCCGCCAGCGTGCCCACCCACTCGCCCTACGCACAACCCAGCTCCACCTTCGACACCATGTCGCCGGCGCCTGTCATCCCCTCCAACACCGACTACCCCGGACCCCACCACTTTGAGGTCACTTTCCAGCAGTCCAGCACGGCCAAGTCAGCCACCTGGACG",
    "GCTCGCGGGACCCCTGCTCCAACGTGACCTGCAGCTTCGGCAGCACCTGTGCGCGCTCGGCCGACGGGCTGACGGCCTCGTGCCTGTGCCCCGCGACCTGCCGTGGCGCCCCCGAGGGGACCGTCTGCGGCAGCGACGGCGCCGACTACCCCGGCGAGTGCCAGCTCCTGCGCCGCGCCTGCGCCCGCCAGGAGAATGTCTTCAAGAAGTTCGACGGCCCTTGTG",
]

# One result list per query, in the same order as `seqs`.
out = [similarity_search(x) for x in seqs]

print(out)

# **Flask Server**