In [1]:
import json
import glob

paths = glob.glob("../../www/data/policies/2021/policies.json")

policies = []
for json_path in paths:
  with open(json_path) as f:
    policies += json.load(f)


In [2]:
import pprint

def create_entry(prompt, policy):
  return '{prompt} TOPIC: {topic}, POLITICAL PARTY: {party}, POLICY DESCRIPTION: {policy_description}'.format(
    prompt=prompt,
    topic=policy["topic"],
    party=policy["party"],
    policy_description=policy["title"]["EN"]
  )

prompt = "clustering:"
corpus = [create_entry(prompt, policy) for policy in policies]
pprint.pp(corpus)

['clustering: TOPIC: economy, POLITICAL PARTY: Conservative, POLICY '
 'DESCRIPTION: Implement tax on digital services',
 'clustering: TOPIC: economy, POLITICAL PARTY: Conservative, POLICY '
 'DESCRIPTION: Allow tech companies to issue flow-through shares',
 'clustering: TOPIC: science, POLITICAL PARTY: Conservative, POLICY '
 'DESCRIPTION: Create the Canadian Advanced Research Agency',
 'clustering: TOPIC: housing, POLITICAL PARTY: Conservative, POLICY '
 'DESCRIPTION: Build 145,000 additional homes in 3&nbsp;years',
 'clustering: TOPIC: housing, POLITICAL PARTY: Liberal, POLICY DESCRIPTION: '
 'Build 120,000 additional homes, repair 130,000 homes in 4 years',
 'clustering: TOPIC: housing, POLITICAL PARTY: Liberal, POLICY DESCRIPTION: '
 'Home Buyers’ Bill of Rights',
 'clustering: TOPIC: housing, POLITICAL PARTY: Liberal, POLICY DESCRIPTION: '
 'Ban foreign investors not living in or moving to Canada from buying homes '
 'for 2 years',
 'clustering: TOPIC: housing, POLITICAL PARTY: C

In [3]:
import llama_cpp
import pprint
import numpy as np

bert = llama_cpp.Llama(
  model_path="/Users/jahfer/src/models/nomic-embed-text-v1.5.f32.gguf",
  embedding=True,
  n_gpu_layers=0,
  n_ctx=8192,
  n_batch=8192,
  rope_scaling_type=llama_cpp.LLAMA_ROPE_SCALING_TYPE_YARN,
  rope_freq_scale=0.75,
)

embed_chunks = np.array_split(corpus, len(corpus) // 5)

embeddings = []
for chunk in embed_chunks:
  embeddings += [e['embedding'] for e in bert.create_embedding(chunk.tolist())['data']]

llama_model_loader: loaded meta data with 22 key-value pairs and 112 tensors from /Users/jahfer/src/models/nomic-embed-text-v1.5.f32.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = nomic-bert
llama_model_loader: - kv   1:                               general.name str              = nomic-embed-text-v1.5
llama_model_loader: - kv   2:                     nomic-bert.block_count u32              = 12
llama_model_loader: - kv   3:                  nomic-bert.context_length u32              = 2048
llama_model_loader: - kv   4:                nomic-bert.embedding_length u32              = 768
llama_model_loader: - kv   5:             nomic-bert.feed_forward_length u32              = 3072
llama_model_loader: - kv   6:            nomic-bert.attention.head_count u32              = 12
llama_model_loader: - kv   7:    nomic-b

In [12]:
from sklearn.metrics.pairwise import cosine_similarity

pairs = list(zip(policies, embeddings))

top_k = []
for [policy, vec] in pairs:
  all_cos_for_policy = []
  for [embedded_policy, other_vec] in pairs:
    cos = cosine_similarity([vec], [other_vec])[0][0]
    if (cos > 0.999) or (cos < 0.9): continue
    all_cos_for_policy.append([embedded_policy, cos])
  sorted_cos = sorted(all_cos_for_policy, reverse=True, key=lambda cos: cos[1])
  top_k.append([policy["id"], [x[0]["id"] for x in sorted_cos][:5]])

In [15]:
pprint.pp(top_k)

[[1, [60]],
 [2, []],
 [3, [242, 63, 64]],
 [4, [5, 8, 118]],
 [5, [4, 118, 87]],
 [6, [226, 67, 68]],
 [7, [22, 69, 120]],
 [8, [228, 4, 6]],
 [9, [12, 11, 13]],
 [10, [13, 70, 9]],
 [11, [12, 9, 13]],
 [12, [9, 11, 13]],
 [13, [11, 10, 9]],
 [14, [31]],
 [15, [16, 212]],
 [16, [15, 17, 212]],
 [17, [16, 241]],
 [18, [93, 92]],
 [19, []],
 [20, [239, 166]],
 [21, [259]],
 [22, [7, 120, 69]],
 [23, [255, 227]],
 [24, []],
 [25, [76]],
 [26, [72, 128, 139]],
 [27, [78, 80, 221]],
 [28, [74, 130]],
 [29, [43]],
 [30, [31, 135, 36]],
 [31, [57, 30, 251]],
 [32, [236, 56, 103]],
 [33, [189, 190, 96]],
 [34, [191]],
 [35, [232, 91]],
 [36, [251, 141, 138]],
 [37, [83, 143, 104]],
 [38, [37, 104, 83]],
 [39, [45, 49, 209]],
 [40, [46, 145, 155]],
 [41, [55, 86, 39]],
 [42, [54, 39, 86]],
 [43, [29]],
 [44, []],
 [45, [39, 252, 49]],
 [46, [145, 40, 53]],
 [47, [146, 48, 210]],
 [48, [47, 147, 210]],
 [49, [45, 149, 39]],
 [50, [252, 45, 54]],
 [51, [209, 45, 149]],
 [52, [150, 50, 51]],
 [53