# Create database and populate with neuron explainer data

In [1]:
import sqlite3
import msgspec
import json
import os
import numpy as np
import umap

  from .autonotebook import tqdm as notebook_tqdm


## DB connect

DB schema created with [DB Designer](https://erd.dbdesigner.net/)

![DB schema graphic](ChinaGraph2024-gpt2small-db-schema.png)

In [4]:
db_conn = sqlite3.connect('./ChinaGraph2024.db')

Before proceeding, please initialise the `neurons` and `activations` tables in the database using the `./ChinaGraph2024-gpt2small-init.session.sql` query

## importing data

### functions for inserting db rows

In [4]:
def insert_neuron(
  cursor,
  layer_index,
  neuron_index,
  explanations,
  activations
):
  neuron_id = f"L{layer_index:02d}N{neuron_index:04d}"
  cursor.execute("""
    INSERT INTO neurons (
      id,
      layer_index,
      neuron_index,
      explanation_text,
      explanation_ev_correlation_score,
      explanation_rsquared_score,
      explanation_absolute_dev_explained_score,
      activation_mean,
      activation_variance,
      activation_skewness,
      activation_kurtosis
    )
    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
    """,
    (
      neuron_id,
      layer_index,
      neuron_index,
      explanations['scored_explanations'][0]['explanation'],
      explanations['scored_explanations'][0]['scored_simulation']['ev_correlation_score'],
      explanations['scored_explanations'][0]['scored_simulation']['rsquared_score'],
      explanations['scored_explanations'][0]['scored_simulation']['absolute_dev_explained_score'],
      activations['mean'],
      activations['variance'],
      activations['skewness'],
      activations['kurtosis']
    )
  )

In [5]:
def insert_activation(
  cursor,
  layer_index,
  neuron_index,
  category,
  encoded_tokens_str,
  encoded_activation_values_str,
):
  neuron_id = f"L{layer_index:02d}N{neuron_index:04d}"
  cursor.execute("""
    INSERT INTO activations (
      neuron_id,
      category,
      tokens,
      activation_values
    )
    VALUES (?, ?, ?, ?);
    """,
    (
      neuron_id,
      category,
      encoded_tokens_str,
      encoded_activation_values_str
    )
  )

In [32]:
def compact_json_encode(obj):
  return json.dumps(obj, separators=(',', ':'))

### start import

In [16]:
json_decoder = msgspec.json.Decoder()

for layer_index in range(12):
  for neuron_index in range(3072):
    explanations_file_path = os.path.join("data", "gpt2-small", "explanations", str(layer_index), str(neuron_index) + ".jsonl")
    activations_file_path = os.path.join("data", "gpt2-small", "activations", str(layer_index), str(neuron_index) + ".json")
    
    # Load both json files
    with open(explanations_file_path, 'rb') as f:
      explanations = json_decoder.decode(f.read())

    with open(activations_file_path, 'rb') as f:
      activations = json_decoder.decode(f.read())
        
    cursor = db_conn.cursor()
    insert_neuron(cursor, layer_index, neuron_index, explanations, activations)
    
    for record in activations['random_sample']:
      insert_activation(cursor, layer_index, neuron_index, 'random', compact_json_encode(record['tokens']), compact_json_encode(record['activations']))
    
    for record in activations['most_positive_activation_records']:
      insert_activation(cursor, layer_index, neuron_index, 'top', compact_json_encode(record['tokens']), compact_json_encode(record['activations']))
    
    db_conn.commit()
    
    print(f"L{layer_index:02d}N{neuron_index:04d} inserted")

L00N0000 inserted
L00N0001 inserted
L00N0002 inserted
L00N0003 inserted
L00N0004 inserted
L00N0005 inserted
L00N0006 inserted
L00N0007 inserted
L00N0008 inserted
L00N0009 inserted
L00N0010 inserted
L00N0011 inserted
L00N0012 inserted
L00N0013 inserted
L00N0014 inserted
L00N0015 inserted
L00N0016 inserted
L00N0017 inserted
L00N0018 inserted
L00N0019 inserted
L00N0020 inserted
L00N0021 inserted
L00N0022 inserted
L00N0023 inserted
L00N0024 inserted
L00N0025 inserted
L00N0026 inserted
L00N0027 inserted
L00N0028 inserted
L00N0029 inserted
L00N0030 inserted
L00N0031 inserted
L00N0032 inserted
L00N0033 inserted
L00N0034 inserted
L00N0035 inserted
L00N0036 inserted
L00N0037 inserted
L00N0038 inserted
L00N0039 inserted
L00N0040 inserted
L00N0041 inserted
L00N0042 inserted
L00N0043 inserted
L00N0044 inserted
L00N0045 inserted
L00N0046 inserted
L00N0047 inserted
L00N0048 inserted
L00N0049 inserted
L00N0050 inserted
L00N0051 inserted
L00N0052 inserted
L00N0053 inserted
L00N0054 inserted
L00N0055 i

### get embeddings for neuron explanations 

In [99]:
def update_explanation_embedding_by_id(cursor, id, embedding_json_string):
  cursor.execute("""
    UPDATE neurons
    SET explanation_embedding = ? 
    WHERE id = ?;
    """,
    (
      embedding_json_string,
      id,
    )
  )

In [100]:
def round_floats(o, ndigits=8):
    if isinstance(o, float):
        return round(o, ndigits)
    if isinstance(o, dict):
        return {k: round_floats(v, ndigits) for k, v in o.items()}
    if isinstance(o, (list, tuple)):
        return [round_floats(x, ndigits) for x in o]
    return o

In [103]:
from sentence_transformers import SentenceTransformer

cursor = db_conn.cursor()
model = SentenceTransformer('all-mpnet-base-v2')

for layer_index in range(12):
  neurons = cursor.execute(
    'SELECT id, explanation_text FROM neurons WHERE layer_index = ?', 
    (layer_index,)
  ).fetchall()
  
  explanations = [n[1] for n in neurons]
  
  print(f"Calculating embeddings for explanations in L{layer_index}")
  embeddings = model.encode(explanations, normalize_embeddings=True)
  
  print(f"Saving embeddings for explanations to DB for L{layer_index}")
  for neuron, embedding in zip(neurons, embeddings):
    update_explanation_embedding_by_id(cursor, neuron[0], compact_json_encode(round_floats(embedding.tolist(), 8)))
  db_conn.commit()

Calculating embeddings for explanations in L0
Saving embeddings for explanations to DB for L0
Calculating embeddings for explanations in L1
Saving embeddings for explanations to DB for L1
Calculating embeddings for explanations in L2
Saving embeddings for explanations to DB for L2
Calculating embeddings for explanations in L3
Saving embeddings for explanations to DB for L3
Calculating embeddings for explanations in L4
Saving embeddings for explanations to DB for L4
Calculating embeddings for explanations in L5
Saving embeddings for explanations to DB for L5
Calculating embeddings for explanations in L6
Saving embeddings for explanations to DB for L6
Calculating embeddings for explanations in L7
Saving embeddings for explanations to DB for L7
Calculating embeddings for explanations in L8
Saving embeddings for explanations to DB for L8
Calculating embeddings for explanations in L9
Saving embeddings for explanations to DB for L9
Calculating embeddings for explanations in L10
Saving embedd