Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Message Tree Topic Modeling Pipeline (#1650)
This is a PR to add topic modeling and k-hop message passing, including a much faster sparse implementation and sentence-transformer embedding aggregation. Using message passing with the k-hop adjacency matrix to aggregate the embedding features into cluster features, like a GCN, seems to result in much better topic clusters. I also added loading tools for the exported message trees and a new util requirements.txt, and refactored cosine_similarity in similarity_functions.py to instead compute the cosine similarity kernel. The cos_sim and embed_data functions were ported over from one of @kenhktsui's filter/cluster notebooks (Scalable Agglomerative Clustering.ipynb). BERTopic: https://maartengr.github.io/BERTopic/ This is still a WIP, but I tested it locally and it works, and I wanted to get feedback. There are a couple more cleanup tasks, like typing, docstrings, moving globals like ADJACENCY_THRESHOLD and MODEL_NAME to config, allowing for more customizability of the topic_model, etc. Please let me know if you notice any errors or have any suggestions. :)
- Loading branch information
1 parent
d048248
commit 87e02e2
Showing
4 changed files
with
376 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import json | ||
from collections import defaultdict | ||
from typing import List | ||
|
||
import pandas as pd | ||
|
||
|
||
def load_jsonl(filepaths):
    """Parse every JSON Lines file in ``filepaths`` into one flat list.

    Parameters:
        filepaths: Iterable of paths to files holding one JSON document per line.

    Returns:
        list: The decoded objects, in file order and then line order.

    Raises:
        OSError: If a file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    for filepath in filepaths:
        # Decode as UTF-8 explicitly instead of relying on the platform's
        # locale encoding, which would corrupt non-ASCII message text.
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                # Blank lines (e.g. a trailing newline) carry no record.
                if line.strip():
                    data.append(json.loads(line))
    return data
|
||
|
||
def separate_qa_helper(node, depth, msg_dict):
    """Walk a message tree depth-first, sorting message texts by author role.

    Parameters:
        node: Mapping for one message; may hold "text", "role" and "replies".
        depth: Counter handed down the recursion, incremented per level.
            It does not influence the collected output.
        msg_dict: Mapping with list values under "user_messages" and
            "assistant_messages"; it is appended to in place.
    """
    # Which output list each role feeds; other roles are ignored.
    destination = {"prompter": "user_messages", "assistant": "assistant_messages"}

    if "text" in node:
        role = node["role"]
        if role in destination:
            msg_dict[destination[role]].append(str(node["text"]))

    # Descend even when this node carried no text of its own.
    for child in node.get("replies", []) if "replies" in node else []:
        separate_qa_helper(child, depth + 1, msg_dict)
|
||
|
||
def store_qa_data_separate(trees, data):
    """Collect user and assistant messages from every tree that has a "prompt".

    Parameters:
        trees: Iterable of exported tree mappings.
        data: Mapping passed through to ``separate_qa_helper``, which appends
            the message texts to it in place.

    Returns:
        tuple: ``(data, message_list)`` where ``message_list`` holds the
        entries of ``trees`` that have no "prompt" key, in input order.
    """
    without_prompt = []
    for position, tree in enumerate(trees):
        if "prompt" not in tree:
            without_prompt.append(tree)
            continue
        separate_qa_helper(tree["prompt"], position, data)
    return data, without_prompt
|
||
|
||
def group_qa_helper(node, depth, msg_pairs):
    """Walk a message tree and record each prompter message with its direct replies.

    For every node whose role is "prompter" and which has text, one
    ``{"instruct": <node text>, "answer": <reply text>}`` dict is appended per
    direct reply. The walk then recurses into all replies.

    Parameters:
        node: Mapping for one message; may hold "text", "role" and "replies".
        depth: Counter handed down the recursion, incremented per level.
            It does not influence the collected output.
        msg_pairs: List that the pair dicts are appended to in place.
    """
    replies = node.get("replies", [])

    if "text" in node and node.get("role") == "prompter":
        instruction = str(node["text"])
        for reply in replies:
            # The node itself is guarded with `"text" in node`; apply the same
            # guard to replies so a text-less reply is skipped, not a KeyError.
            if "text" in reply:
                msg_pairs.append({"instruct": instruction, "answer": str(reply["text"])})

    for reply in replies:
        group_qa_helper(reply, depth + 1, msg_pairs)
|
||
|
||
def store_qa_data_paired(trees, data: List):
    """Collect instruct/answer pairs from every tree that has a "prompt".

    Parameters:
        trees: Iterable of exported tree mappings.
        data: List passed through to ``group_qa_helper``, which appends the
            pair dicts to it in place.

    Returns:
        tuple: ``(data, message_list)`` where ``message_list`` holds the
        entries of ``trees`` that have no "prompt" key, in input order.
    """
    without_prompt = []
    for position, tree in enumerate(trees):
        if "prompt" not in tree:
            without_prompt.append(tree)
            continue
        group_qa_helper(tree["prompt"], position, data)
    return data, without_prompt
|
||
|
||
def load_data(filepaths: List[str], paired=False):
    """Load exported message trees and flatten them into a table of sentences.

    Parameters:
        filepaths: Paths of the JSON Lines exports to read.
        paired: If truthy, each sentence is "<instruct> <answer>" for one
            prompter/reply pair. Otherwise the sentences are all user messages
            followed by all assistant messages.

    Returns:
        tuple: ``(data, message_list)`` where ``data`` is a DataFrame with
        columns "id" (running index) and "query" (sentence), and
        ``message_list`` holds the loaded entries that had no "prompt" key.
    """
    trees = load_jsonl(filepaths)

    if not paired:
        collected, message_list = store_qa_data_separate(trees, defaultdict(list))
        sents = collected["user_messages"] + collected["assistant_messages"]
    else:
        pairs, message_list = store_qa_data_paired(trees, [])
        sents = [f"{qa['instruct']} {qa['answer']}" for qa in pairs]

    rows = list(enumerate(sents))
    frame = pd.DataFrame(rows, columns=["id", "query"])
    return frame, message_list
106 changes: 106 additions & 0 deletions
106
backend/oasst_backend/utils/message_tree_topic_modeling.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import argparse | ||
|
||
from bertopic import BERTopic | ||
from bertopic.representation import MaximalMarginalRelevance | ||
from bertopic.vectorizers import ClassTfidfTransformer | ||
from exported_tree_loading import load_data | ||
from sentence_transformers import SentenceTransformer | ||
from similarity_functions import compute_cos_sim_kernel, embed_data, k_hop_message_passing_sparse | ||
from sklearn.feature_extraction.text import CountVectorizer | ||
|
||
|
||
def _str_to_bool(value):
    """Convert a command-line token such as "true"/"False"/"1"/"no" to a bool."""
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    if token in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def argument_parsing():
    """Parse the command-line options of the topic modeling script.

    Boolean options take an explicit value (e.g. ``--pair_qa False``).
    ``type=bool`` cannot be used for them because ``bool("False")`` is True:
    every non-empty string would enable the option.

    Returns:
        argparse.Namespace: The parsed arguments, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Process some arguments.")
    parser.add_argument("--model_name", type=str, default="all-MiniLM-L6-v2")
    parser.add_argument("--cores", type=int, default=1)
    parser.add_argument("--pair_qa", type=_str_to_bool, default=True)
    parser.add_argument("--use_gpu", type=_str_to_bool, default=False)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--k", type=int, default=2)
    parser.add_argument("--threshold", type=float, default=0.65)
    parser.add_argument("--exported_tree_path", nargs="+", help="<Required> Set flag", required=True)
    parser.add_argument("--min_topic_size", type=int, default=10)
    parser.add_argument("--diversity", type=float, default=0.2)
    parser.add_argument("--reduce_frequent_words", type=_str_to_bool, default=False)
    parser.add_argument("--reduce_outliers_strategy", type=str, default="c-tf-idf")

    args = parser.parse_args()
    return args
|
||
|
||
def load_topic_model(args):
    """Build an unfitted BERTopic model configured from the parsed arguments.

    Parameters:
        args: Namespace providing ``model_name``, ``diversity``,
            ``min_topic_size`` and ``reduce_frequent_words``.

    Returns:
        BERTopic: Model using an English stop-word CountVectorizer, a class
        TF-IDF transformer, Maximal Marginal Relevance representation and a
        SentenceTransformer embedding model, with ``nr_topics="auto"``.
    """
    vectorizer_model = CountVectorizer(stop_words="english")
    # Honour the command-line option instead of a hard-coded False.
    ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=args.reduce_frequent_words)
    # Read the model name from args: the MODEL_NAME global only exists when
    # this module is run as a script, so importing callers hit a NameError.
    model = SentenceTransformer(args.model_name)
    representation_model = MaximalMarginalRelevance(diversity=args.diversity)
    topic_model = BERTopic(
        nr_topics="auto",
        min_topic_size=args.min_topic_size,
        representation_model=representation_model,
        vectorizer_model=vectorizer_model,
        ctfidf_model=ctfidf_model,
        embedding_model=model,
    )
    return topic_model
|
||
|
||
def fit_topic_model(topic_model, data, embeddings, key="query"):
    """Fit ``topic_model`` on the documents in ``data[key]`` with the given embeddings.

    Parameters:
        topic_model: Object exposing ``fit_transform(documents, embeddings)``.
        data: Table whose ``key`` column holds the documents.
        embeddings: Embeddings passed as the second positional argument.
        key: Name of the document column.

    Returns:
        tuple: The ``(topics, probs)`` pair returned by ``fit_transform``.
    """
    documents = data[key].to_list()
    assignments, probabilities = topic_model.fit_transform(documents, embeddings)
    return assignments, probabilities
|
||
|
||
def get_topic_info(topic_model):
    """Return whatever ``topic_model.get_topic_info()`` produces."""
    topic_info = topic_model.get_topic_info()
    return topic_info
|
||
|
||
def reduce_topics(topic_model, data, nr_topics, key="query"):
    """Ask ``topic_model`` to reduce its topics to ``nr_topics`` and return the model.

    Parameters:
        topic_model: Object exposing ``reduce_topics(documents, nr_topics)``.
        data: Table whose ``key`` column holds the documents.
        nr_topics: Value passed as the second positional argument.
        key: Name of the document column.

    Returns:
        The same ``topic_model`` object that was passed in.
    """
    documents = data[key].to_list()
    topic_model.reduce_topics(documents, nr_topics)
    return topic_model
|
||
|
||
def get_representative_docs(topic_model):
    """Return whatever ``topic_model.get_representative_docs()`` produces."""
    representative_docs = topic_model.get_representative_docs()
    return representative_docs
|
||
|
||
def reduce_outliers(topic_model, data, topics, probs, key="query", strategy="c-tf-idf"):
    """Reassign outlier documents using ``topic_model.reduce_outliers``.

    Parameters:
        topic_model: Object exposing ``reduce_outliers``.
        data: Table whose ``key`` column holds the documents.
        topics: Topic assignments to refine.
        probs: Passed as ``probabilities``; only used by "distributions".
        key: Name of the document column.
        strategy: One of "c-tf-idf" (called with ``threshold=0.1``),
            "embeddings" or "distributions".

    Returns:
        The new topic assignments returned by ``topic_model.reduce_outliers``.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
    """
    if strategy == "c-tf-idf":
        extra = {"threshold": 0.1}
    elif strategy == "embeddings":
        extra = {}
    elif strategy == "distributions":
        extra = {"probabilities": probs}
    else:
        raise ValueError("Invalid strategy")

    # Pass strategy by keyword: the third positional parameter is not the
    # strategy in every BERTopic release, so a positional value can bind to
    # the wrong parameter.
    new_topics = topic_model.reduce_outliers(data[key].to_list(), topics, strategy=strategy, **extra)
    return new_topics
|
||
|
||
def compute_hierarchical_topic_tree(topic_model, data, key="query"):
    """Compute the topic hierarchy for the documents and its tree rendering.

    Parameters:
        topic_model: Object exposing ``hierarchical_topics(documents)`` and
            ``get_topic_tree(hierarchy)``.
        data: Table whose ``key`` column holds the documents.
        key: Name of the document column.

    Returns:
        tuple: ``(hierarchical_topics, tree)`` as returned by those two calls.
    """
    documents = data[key].to_list()
    hierarchy = topic_model.hierarchical_topics(documents)
    return hierarchy, topic_model.get_topic_tree(hierarchy)
|
||
|
||
if __name__ == "__main__":
    # Main entry point: run topic modeling on a list of exported message trees.
    # Example usage:
    #   python message_tree_topic_modeling.py --exported_tree_path 2023-02-06_oasst_prod.jsonl 2023-02-07_oasst_prod.jsonl
    args = argument_parsing()
    # Also exposed as a module-level global for code that looks up MODEL_NAME.
    MODEL_NAME = args.model_name
    data, message_list = load_data(args.exported_tree_path, args.pair_qa)
    # Forward --batch_size; it used to be parsed and then silently ignored.
    embs = embed_data(
        data,
        model_name=MODEL_NAME,
        cores=args.cores,
        gpu=args.use_gpu,
        batch_size=args.batch_size,
    )
    adj_matrix = compute_cos_sim_kernel(embs, args.threshold)
    print(adj_matrix.shape)
    print(embs.shape)
    A_k, agg_features = k_hop_message_passing_sparse(adj_matrix, embs, args.k)
    print(A_k.shape)
    topic_model = load_topic_model(args)
    # The graph-aggregated features are used as the document embeddings.
    topics, probs = fit_topic_model(topic_model, data, agg_features)
    freq = get_topic_info(topic_model)
    rep_docs = get_representative_docs(topic_model)
    print(freq)
    for k, v in rep_docs.items():
        print(k)
        print(v)
        print("\n\n\n")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,197 @@ | ||
from typing import List | ||
import math | ||
|
||
import numpy as np | ||
import scipy.sparse as sp | ||
import torch | ||
import torch.nn.functional as F | ||
from pandas import DataFrame | ||
from sentence_transformers import SentenceTransformer | ||
from torch import Tensor | ||
from tqdm import tqdm | ||
|
||
ADJACENCY_THRESHOLD = 0.65 | ||
|
||
def cosine_similarity(a: List[float], b: List[float]):
    """Return the cosine similarity of two vectors.

    This is the dot product of ``a`` and ``b`` divided by the product of
    their Euclidean norms.

    Raises:
        ZeroDivisionError: If either vector has a norm of zero.
    """
    magnitude_a, magnitude_b = np.linalg.norm(a), np.linalg.norm(b)
    if not (magnitude_a != 0 and magnitude_b != 0):
        raise ZeroDivisionError("One of the vectors has a norm of zero.")
    denominator = magnitude_a * magnitude_b
    return np.dot(a, b) / denominator
|
||
def embed_data(
    data: DataFrame,
    key: str = "query",
    model_name: str = "all-MiniLM-L6-v2",
    cores: int = 1,
    gpu: bool = False,
    batch_size: int = 128,
):
    """
    Embed the text in ``data[key]`` with a SentenceTransformer model.

    Each distinct sentence is encoded once and the result is mapped back to
    every row, so duplicated sentences share one embedding.

    Parameters:
        data (DataFrame): Table holding the sentences.
        key (str): Name of the column with the text to embed.
        model_name (str): Name passed to ``SentenceTransformer``.
        cores (int): 1 encodes in-process; any other value starts a
            multi-process pool and splits the work into ``cores`` chunks.
        gpu (bool): In multi-process mode, pass ``None`` as the device list
            (use all CUDA devices) instead of ``cores`` CPU workers.
        batch_size (int): Batch size passed to the encoder.
    Returns:
        numpy array: One embedding per row of ``data``, in row order.
    """
    print("Embedding data")
    model = SentenceTransformer(model_name)
    print("Model loaded")

    sentences = data[key].tolist()
    unique_sentences = data[key].unique()
    print("Unique sentences", len(unique_sentences))

    if cores == 1:
        embeddings = model.encode(unique_sentences, show_progress_bar=True, batch_size=batch_size)
    else:
        devices = ["cpu"] * cores
        if gpu:
            devices = None  # use all CUDA devices

        # Start the multi-process pool on multiple devices
        print("Multi-process pool starting")
        pool = model.start_multi_process_pool(devices)
        print("Multi-process pool started")

        chunk_size = math.ceil(len(unique_sentences) / cores)

        try:
            # Compute the embeddings using the multi-process pool
            embeddings = model.encode_multi_process(
                unique_sentences, pool, batch_size=batch_size, chunk_size=chunk_size
            )
        finally:
            # Always shut the pool down, even if encoding raised, so the
            # worker processes are not left running.
            model.stop_multi_process_pool(pool)

    print("Embeddings computed")

    # Expand the per-unique-sentence embeddings back to one per input row.
    mapping = dict(zip(unique_sentences, embeddings))
    embeddings = np.array([mapping[sentence] for sentence in sentences])

    return embeddings


def euclidean_distance(a: List[float], b: List[float]):
    """Compute euclidean distance (norm of the difference of two vectors.)"""
    return np.linalg.norm(np.subtract(a, b))
|
||
|
||
def cos_sim(a: Tensor, b: Tensor):
    """
    Pairwise cosine similarity between the rows of ``a`` and the rows of ``b``.

    Non-tensor inputs are converted through ``numpy``; 1-D inputs are treated
    as a single row. Rows are L2-normalized and then multiplied.
    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """

    def _as_rows(values):
        # Coerce to a tensor and promote a single vector to a 1-row matrix.
        if not isinstance(values, torch.Tensor):
            values = torch.tensor(np.array(values))
        if values.dim() == 1:
            values = values.unsqueeze(0)
        return values

    left = torch.nn.functional.normalize(_as_rows(a), p=2, dim=1)
    right = torch.nn.functional.normalize(_as_rows(b), p=2, dim=1)
    return torch.mm(left, right.transpose(0, 1))
|
||
|
||
def cos_sim_torch(embs_a: Tensor, embs_b: Tensor) -> Tensor:
    """
    Pairwise cosine similarity between the rows of ``embs_a`` and ``embs_b``,
    computed with torch.nn.functional.cosine_similarity over broadcast inputs.

    Non-tensor inputs are converted through ``numpy``; 1-D inputs are treated
    as a single row.
    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """
    prepared = []
    for embs in (embs_a, embs_b):
        if not isinstance(embs, torch.Tensor):
            embs = torch.tensor(np.array(embs))
        if embs.dim() == 1:
            embs = embs.unsqueeze(0)
        prepared.append(embs)

    rows, cols = prepared
    # (n, 1, d) against (1, m, d) broadcasts to every (i, j) pair.
    return F.cosine_similarity(rows.unsqueeze(1), cols.unsqueeze(0), dim=2)
|
||
|
||
def gaussian_kernel_torch(embs_a, embs_b, sigma=1.0):
    """
    Gaussian (RBF) kernel matrix between two sets of embeddings, using PyTorch.

    Entry (i, j) is exp(-d(i, j)^2 / (2 * sigma^2)), where d is the pairwise
    distance returned by ``torch.cdist``.
    :param embs_a: Tensor of shape (batch_size_a, embedding_dim) containing the first set of embeddings.
    :param embs_b: Tensor of shape (batch_size_b, embedding_dim) containing the second set of embeddings.
    :param sigma: Width of the Gaussian kernel.
    :return: Tensor of shape (batch_size_a, batch_size_b) containing the Gaussian kernel matrix.
    """
    first = embs_a if isinstance(embs_a, torch.Tensor) else torch.tensor(embs_a)
    second = embs_b if isinstance(embs_b, torch.Tensor) else torch.tensor(embs_b)

    # Pairwise distances, then the kernel applied to their squares.
    distances = torch.cdist(first, second)
    squared = distances**2
    return torch.exp(-squared / (2 * sigma**2))
|
||
|
||
def compute_cos_sim_kernel(embs, threshold=0.65, kernel_type="cosine"):
    """
    Build a binary adjacency matrix by thresholding a pairwise kernel of ``embs``.

    Parameters:
        embs: Embeddings compared against themselves.
        threshold (float): Entries strictly greater than this become 1.0,
            all others 0.0.
        kernel_type (str): "cosine" (cos_sim_torch) or "gaussian"
            (gaussian_kernel_torch).
    Returns:
        adj_matrix (numpy array): float32 matrix of zeros and ones.
    Raises:
        ValueError: If ``kernel_type`` is not a supported name.
    """
    # Reject unknown kernels up front; otherwise no branch assigns the kernel
    # matrix and the failure is an opaque UnboundLocalError.
    if kernel_type == "gaussian":
        kernel = gaussian_kernel_torch
    elif kernel_type == "cosine":
        kernel = cos_sim_torch
    else:
        raise ValueError(f"Unknown kernel_type {kernel_type!r}; expected 'cosine' or 'gaussian'.")

    A = kernel(embs, embs)
    # The boolean mask already is the adjacency; cast it to float32.
    adj_matrix = (A > threshold).numpy().astype(np.float32)
    return adj_matrix
|
||
|
||
def k_hop_message_passing(A, node_features, k):
    """
    Compute the k-th power of the adjacency matrix and features aggregated
    over 1..k hops (dense implementation).

    The aggregated features are ``node_features`` plus the sum over
    i = 1..k of ``A^i @ node_features``.
    Parameters:
        A (numpy array): The adjacency matrix of the graph.
        node_features (numpy array): The feature matrix of the nodes.
        k (int): The number of hops for message passing.
    Returns:
        A_k (numpy array): ``A`` raised to the k-th matrix power.
        agg_features (numpy array): The aggregated feature matrix.
    """

    print("Compute the k-hop adjacency matrix")
    A_k = np.linalg.matrix_power(A, k)

    print("Aggregate the messages from the k-hop neighborhood:")
    # Work on a copy so the caller's feature matrix is left untouched.
    agg_features = node_features.copy()

    for hop in tqdm(range(1, k + 1)):
        hop_adjacency = np.linalg.matrix_power(A, hop)
        agg_features += np.matmul(hop_adjacency, node_features)

    return A_k, agg_features
|
||
|
||
def k_hop_message_passing_sparse(A, node_features, k):
    """
    Sparse message passing: iteratively grow an adjacency matrix and aggregate
    node features through it.

    Each of the ``k`` iterations computes ``message = A_k @ node_features``,
    sets ``agg_features = A_k @ agg_features + message`` and then updates
    ``A_k`` to ``A_k + A_k @ A``.
    NOTE(review): the returned A_k is the accumulated matrix after k such
    updates, not the plain k-th power of A; confirm this is the intended
    meaning of "k-hop".
    Parameters:
        A (numpy array or scipy sparse matrix): The adjacency matrix of the graph.
        node_features (numpy array or scipy sparse matrix): The feature matrix of the nodes.
        k (int): The number of message passing iterations.
    Returns:
        A_k (numpy array): The accumulated adjacency matrix, as a dense array.
        agg_features (numpy array): The aggregated feature matrix, as a dense array.
    """

    # Work in CSR form; inputs that are already sparse are used as given.
    adjacency = A if sp.issparse(A) else sp.csr_matrix(A)
    features = node_features if sp.issparse(node_features) else sp.csr_matrix(node_features)

    # Copies keep the caller's matrices unmodified by the updates below.
    A_k = adjacency.copy()
    agg_features = features.copy()

    for _ in tqdm(range(k)):
        # Messages from the current neighborhood, taken from the raw features.
        message = A_k.dot(features)
        # Propagate the running aggregate and add the fresh messages.
        agg_features = A_k.dot(agg_features) + message
        # Extend the neighborhood by one more step along A.
        A_k += A_k.dot(adjacency)

    return A_k.toarray(), agg_features.toarray()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
pandas | ||
sentence-transformers | ||
bertopic | ||
scipy |