Merge pull request #68 from nicolomonti/main
Add clustering
huu4ontocord committed Nov 9, 2023
2 parents dc79de6 + c07f044 commit 668f1bd
Showing 5 changed files with 601 additions and 0 deletions.
71 changes: 71 additions & 0 deletions clustering/download.py
@@ -0,0 +1,71 @@
import requests
import logging

import trafilatura

from transformers import pipeline
from transformers import AutoTokenizer

import numpy as np

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

max_embedding_characters = 128 # This is a deliberately low value, as the current model is not intended for document embedding

feature_extractor_checkpoint = 'sentence-transformers/LaBSE'
tokenizer_checkpoint = 'gpt2'

feature_extractor = pipeline('feature-extraction', framework='pt', model=feature_extractor_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

def fetch_and_parse(url):
try:
response = requests.get(url, timeout=10)

response.raise_for_status()
except (requests.HTTPError, requests.ConnectionError, requests.Timeout) as error:
logging.error(f'Failed to fetch {url}: {error}')

return None, None

content = response.text

    markdown = trafilatura.extract(content, output_format='txt', include_formatting=True,
                                   include_tables=True, include_images=True, no_fallback=True,
                                   include_links=True)

return content, markdown

def embed(text):
embedding = feature_extractor(text)

return embedding

def tokenize(text):
tokens = tokenizer.encode(text)

return tokens

def process_url(url):
    content, markdown = fetch_and_parse(url)

    if content is None:  # Fetch failed; propagate the failure instead of crashing on slicing
        return None, None, None, None

    content_short = content[:max_embedding_characters]

tokens = tokenize(content)
embedding = embed(content_short)

embedding = np.array(embedding)

return content, markdown, tokens, embedding

def main():
url = 'https://huggingface.co'

    content, markdown, tokens, embedding = process_url(url)

    if embedding is None:
        return

    for current in [content, markdown, embedding.shape]:
        print(f'{"-" * 32}\n{current}')

print('-' * 32)

if __name__ == '__main__':
main()
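Editor's note: the feature-extraction pipeline above returns one vector per token, while the clustering code further down expects one fixed-size row per document in a memmap file. A minimal sketch of that glue, assuming mean pooling over tokens and the 'output/embeddings.mmap' path used by hierarchical_clustering.py; the pooling choice, the pool_and_store name, and the row-per-URL layout are illustrative assumptions, not part of this commit:

import numpy as np

def pool_and_store(urls, mmap_file='output/embeddings.mmap', embed_dim=768):
    # Hypothetical glue: one mean-pooled LaBSE vector per URL, written row by row.
    out = np.memmap(mmap_file, mode='w+', dtype=np.float32, shape=(len(urls), embed_dim))

    for row, url in enumerate(urls):
        _, _, _, embedding = process_url(url)

        if embedding is None:  # Fetch failed; leave this row zeroed
            continue

        out[row] = embedding[0].mean(axis=0)  # (seq_len, 768) -> (768,)

    out.flush()

Note that LaBSE produces 768-dimensional vectors, so a memmap written this way would need embed_dim=768 in ClusterAnalysis rather than the 1024 used in the example main() below, which matches bloom-560m instead.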
49 changes: 49 additions & 0 deletions clustering/feature_extractor.py
@@ -0,0 +1,49 @@
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM


class FeatureExtractor:
def __init__(self, device='cpu', model_id='bigscience/bloom-560m', num_decoder_blocks=8):
self.device = device

self.num_decoder_blocks = num_decoder_blocks
self.model_id = model_id

self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

self.model = AutoModelForCausalLM.from_pretrained(self.model_id)

h = self.model.transformer.h[:num_decoder_blocks] # Note that this will change for different families of models
self.model.transformer.h = h

self.model = self.model.to(device)


def encode(self, text):
tokens = self.tokenizer(text, padding=True, return_tensors='pt').to(self.device)

output = self.model(**tokens, output_hidden_states=True).hidden_states[-1]
output = output.detach().cpu().numpy()

return output


def __call__(self, text):
output = self.encode(text)

return output


def main():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using {device} device')

feature_extractor = FeatureExtractor(device=device)

output = feature_extractor('Hello world!')
print(output)


if __name__ == '__main__':
main()
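Editor's note: the block truncation in FeatureExtractor.__init__ relies on BLOOM exposing its decoder stack at model.transformer.h, as the inline comment points out. A hedged sketch of how the same idea might generalize; the attribute paths for non-BLOOM families are assumptions about current transformers layouts, not something this commit tests:

def truncate_decoder(model, num_decoder_blocks):
    # BLOOM and GPT-2 style models keep their decoder blocks in model.transformer.h;
    # LLaMA-style models keep them in model.model.layers.
    if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
        model.transformer.h = model.transformer.h[:num_decoder_blocks]
    elif hasattr(model, 'model') and hasattr(model.model, 'layers'):
        model.model.layers = model.model.layers[:num_decoder_blocks]
    else:
        raise ValueError('Unknown decoder layout; pick the block list for your model family')

    return model

Keeping only the first blocks trades representation quality for speed and memory: hidden_states[-1] then reflects the truncated stack rather than the full model.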
248 changes: 248 additions & 0 deletions clustering/hierarchical_clustering.py
@@ -0,0 +1,248 @@
import math
import random
from collections import Counter

import torch
import torch.nn as nn
from tqdm import tqdm

import numpy as np
from fast_pytorch_kmeans import KMeans

from memmap_utils import np_memmap, get_np_memmap_length


class ClusterAnalysis(nn.Module):
def __init__(
self,
mmap_file=None,
embed_dim=128,
dtype=np.float32,
):
super().__init__()

self.mmap_file = mmap_file
self.embed_dim = embed_dim

self.dtype = dtype

self.clusters = {}
self.span_to_cluster_label = {}


@staticmethod
def _cluster_one_batch(
true_k,
spans,
clusters,
span_to_cluster_label,
level,
cluster_embeddings,
min_overlap_merge_cluster,
device
):
with torch.no_grad():
embeddings = torch.from_numpy(cluster_embeddings)

km = KMeans(n_clusters=true_k, mode='cosine')
km_labels = km.fit_predict(embeddings.to(device=device, dtype=torch.float32)).tolist()

embeddings = None

if not clusters:
label_to_label = {}

for span, label in zip(spans, km_labels):
label = (label, level)

if label not in label_to_label:
label_to_label[label] = (span[0], level)

label = label_to_label[label]

                    clusters[label] = clusters.get(label, []) + [span]
span_to_cluster_label[span] = label

output = list(clusters.keys())

return output

tmp_cluster = {}

for span, label in zip(spans, km_labels):
                tmp_cluster[label] = tmp_cluster.get(label, []) + [span]

new_labels = []

            for a_cluster in tmp_cluster.values():
                need_labels = [span for span in a_cluster if span not in span_to_cluster_label or span_to_cluster_label[span][1] != level]
                cluster_labels = [span_to_cluster_label[span] for span in a_cluster if span in span_to_cluster_label and span_to_cluster_label[span][1] == level]

                if not need_labels:
                    continue

                if not cluster_labels:
                    # Seed a new label from the cluster's first span idx
                    label = (a_cluster[0][0], level)
                else:
                    most_common = Counter(cluster_labels).most_common(1)[0]

                    if most_common[1] < min_overlap_merge_cluster:
                        label = (a_cluster[0][0], level)
                    else:
                        label = most_common[0]

                new_labels.append(label)

                for span in need_labels:
                    clusters[label] = clusters.get(label, []) + [span]
                    span_to_cluster_label[span] = label

return new_labels


    def create_hierarchical_clusters(
self,
force_recluster_idxs=None,
max_level=4,
max_cluster_size=32, # Small value for debug purposes
min_overlap_merge_cluster=2,
        preferred_leaf_node_size=None,
kmeans_batch_size=250000,
use_tqdm=False,
device='cuda:0'
):
mmap_file = self.mmap_file
embed_dim = self.embed_dim
dtype = self.dtype

mmap_len = get_np_memmap_length(mmap_file, [0, embed_dim], dtype=dtype)

clusters = self.clusters
span_to_cluster_label = self.span_to_cluster_label

if force_recluster_idxs:
force_recluster_idxs = set(force_recluster_idxs)
else:
force_recluster_idxs = ()

        already_clustered = {span[0] for span in span_to_cluster_label if span[1] == 0 and span[0] not in force_recluster_idxs}

idxs = []

if force_recluster_idxs:
idxs = list(force_recluster_idxs)
force_recluster_idxs = None

idxs.extend([idx for idx in range(mmap_len) if idx not in already_clustered])

if not idxs:
return

already_clustered = list(already_clustered)

if len(already_clustered) > int(0.5 * kmeans_batch_size):
idxs.extend(random.sample(already_clustered, int(0.5 * kmeans_batch_size)))
else:
idxs.extend(already_clustered)

already_clustered = None

idxs.extend([span[0] for span in span_to_cluster_label if span[1] != 0])
idxs = list(set(idxs))
random.shuffle(idxs)

        if not preferred_leaf_node_size:
            preferred_leaf_node_size = int(max_cluster_size * 0.7)

        assert preferred_leaf_node_size <= max_cluster_size, 'preferred_leaf_node_size must not exceed max_cluster_size'

for level in range(max_level):
all_spans = [(idx, level) for idx in idxs]
len_spans = len(all_spans)

step_size = int(0.7 * kmeans_batch_size)
num_times = max(3, math.ceil(len_spans / step_size))

            if use_tqdm:
                num_times_2 = tqdm(range(num_times))
            else:
                num_times_2 = range(num_times)

            for times in num_times_2:
                # Advance the window by step_size each pass; earlier spans are mixed back in below
                rng = times * step_size
                max_rng = min(len_spans, rng + step_size)

                spans = all_spans[rng:max_rng]

                not_already_clustered = [span for span in all_spans[:rng] if span not in span_to_cluster_label]

                if len(not_already_clustered) > int(0.5 * kmeans_batch_size):
                    spans.extend(random.sample(not_already_clustered, int(0.5 * kmeans_batch_size)))
                else:
                    spans.extend(not_already_clustered)

                if not spans:
                    break

                already_clustered = [span for span in all_spans[:rng] if span in span_to_cluster_label]

if len(already_clustered) > int(0.5 * kmeans_batch_size):
spans.extend(random.sample(already_clustered, int(0.5 * kmeans_batch_size)))

else:
spans.extend(already_clustered)

embedding_idxs = [span[0] for span in spans]

                if level == 0:
                    true_k = max(1, int(len(embedding_idxs) / preferred_leaf_node_size))
                else:
                    true_k = max(1, int(len(embedding_idxs) / max_cluster_size))

cluster_embeddings = np_memmap(mmap_file, shape=[mmap_len, embed_dim], idxs=embedding_idxs, dtype=dtype)

new_labels = self._cluster_one_batch(true_k, spans, clusters, span_to_cluster_label, level, cluster_embeddings, min_overlap_merge_cluster, device)

if not new_labels:
break

                need_more = False

                if times <= num_times - 2:
                    for label in set(new_labels):
                        if label in clusters and len(clusters[label]) < preferred_leaf_node_size:
                            # Forget undersized clusters (and their span labels) so the next pass can reassign them
                            for span in clusters[label]:
                                if span_to_cluster_label.get(span) == label:
                                    del span_to_cluster_label[span]

                            del clusters[label]

                            need_more = True

                if not need_more:
                    break

idxs = [val[0][0] for key, val in clusters.items() if key[1] == level]

if len(idxs) < max_cluster_size:
break


def main():
cluster_analysis = ClusterAnalysis(
mmap_file='output/embeddings.mmap',
embed_dim=1024
)

    cluster_analysis.create_hierarchical_clusters()

print(list(cluster_analysis.clusters.keys()))


if __name__ == '__main__':
main()
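Editor's note: memmap_utils.py is among the five changed files, but its diff is not shown above. From the two call sites, get_np_memmap_length(mmap_file, [0, embed_dim], dtype=dtype) must return the number of rows on disk, and np_memmap(mmap_file, shape=..., idxs=..., dtype=...) must return the rows selected by idxs. A minimal sketch consistent with those call sites; the real helpers in this commit may differ:

import os

import numpy as np

def get_np_memmap_length(mmap_file, shape, dtype=np.float32):
    # Infer the row count of an existing memmap from its file size.
    row_bytes = int(np.prod(shape[1:])) * np.dtype(dtype).itemsize

    return os.path.getsize(mmap_file) // row_bytes

def np_memmap(mmap_file, shape, idxs=None, dtype=np.float32):
    # Open the memmap read-only and return the requested rows as a regular array.
    data = np.memmap(mmap_file, mode='r', dtype=dtype, shape=tuple(shape))

    return np.asarray(data[idxs]) if idxs is not None else np.asarray(data)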