In [1]:
import os
import urllib

# Download the Arxiv papers from Kaggle
# https://www.kaggle.com/datasets/Cornell-University/arxiv/
if not os.path.exists("arxiv-metadata-oai-snapshot.json.zip"):
    # Python 3 keeps urlretrieve in the urllib.request submodule, and a bare
    # `import urllib` is not guaranteed to load that submodule, so import it here.
    import urllib.request

    print("Downloading Arxiv papers from Kaggle...")
    # Download to a side file and rename only on success, so an interrupted transfer
    # does not leave a truncated zip that the exists() guard above would accept.
    partial_download = "arxiv-metadata-oai-snapshot.json.zip.part"
    # NOTE(review): this is a signed URL carrying X-Goog-Date / X-Goog-Expires
    # parameters, so it presumably stops working after that window; fetch a fresh
    # link from the Kaggle dataset page if the request is rejected.
    urllib.request.urlretrieve(
        "https://storage.googleapis.com/kaggle-data-sets/612177/7219250/compressed/arxiv-metadata-oai-snapshot.json.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gcp-kaggle-com%40kaggle-161607.iam.gserviceaccount.com%2F20231220%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20231220T030514Z&X-Goog-Expires=259200&X-Goog-SignedHeaders=host&X-Goog-Signature=aeac57aa1a74cb7973076f84aad1d609bd2b82ed4c8974cbbb29d311533fdf9c8d98791eec48c4eb63dbb18d63c30adfa3c6d4749305bd788fc28ea9119d85132706310a541f4b72315fb74903656c1fcc5a9b1934a293255dd329a3ee5e5550938f987604ebaeb7fac101ad509c76af532275b20b0be775182cca8dd548d477570609fb7a3fcce716bcc1f0c2982d041fab4d91e750feed38225a69caf52804259c3b48609b356b2a850e346b0f079097d5af0efd167a8c3e3e778091167dcb4e90aedabc7d575bf3834d81041bc0a51279b3315c377fd6d55b790f3a851403bcbd2ab9d47e121a80df941cfb816fae8eb63a177227c6fc7bd852d38d58fa4c",
        partial_download,
    )
    os.replace(partial_download, "arxiv-metadata-oai-snapshot.json.zip")

In [6]:
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
import zipfile
from tqdm import tqdm

# Convert the zipped JSON-lines snapshot to a Parquet file holding only the
# selected fields, streaming it in chunks of 100000 rows.
papers_file = "arxiv-metadata-oai-snapshot.parquet"
if not os.path.exists(papers_file):
    parquet_writer = None
    selected_fields = ['id', 'categories', 'title', 'abstract', 'update_date']
    dtypes = {key: 'str' for key in selected_fields}
    # Write to a side file and rename only after a clean finish, so a failed run
    # cannot leave a partial Parquet file that the exists() guard would accept.
    partial_file = papers_file + ".part"
    try:
        with zipfile.ZipFile("arxiv-metadata-oai-snapshot.json.zip", 'r') as z:
            with z.open("arxiv-metadata-oai-snapshot.json") as f:
                for chunk in tqdm(pd.read_json(f, lines=True, chunksize=100000, dtype=dtypes)):
                    table = pa.Table.from_pandas(chunk[selected_fields])
                    # The writer needs a schema, so create it lazily from the first chunk
                    if parquet_writer is None:
                        parquet_writer = pq.ParquetWriter(partial_file, table.schema, compression='snappy')
                    parquet_writer.write_table(table)
    finally:
        # Close the writer even when a chunk fails, so the file handle is released
        if parquet_writer is not None:
            parquet_writer.close()
    # Only reached when no exception was raised; skip the rename if nothing was written
    if parquet_writer is not None:
        os.replace(partial_file, papers_file)

24it [03:01,  7.57s/it]


In [24]:
# Read the data: load the Parquet file named by `papers_file` into a DataFrame
data = pd.read_parquet(papers_file)

In [25]:
# Extract all cs.CL (computation & language) papers with LLM in the title.
# regex=False makes both patterns literal substrings: with the default regex=True
# the "." in "cs.CL" would match any character. na=False counts missing values as
# non-matches so the combined mask stays strictly boolean.
has_llm_title = data['title'].str.contains('LLM', regex=False, na=False)
is_cs_cl = data['categories'].str.contains('cs.CL', regex=False, na=False)
llm = data[has_llm_title & is_cs_cl].copy()

In [26]:
# Set up langchain embeddings
from langchain.embeddings.cache import CacheBackedEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.storage.file_system import LocalFileStore
from gramex.config import variables

# Keep the embedding store in a 'langchain-embeddings' folder under the directory
# named by the GRAMEXDATA variable
embedding_store_dir = os.path.join(variables['GRAMEXDATA'], 'langchain-embeddings')
file_store = LocalFileStore(embedding_store_dir)
# Pair the OpenAI embedder with the file store, using the model name as the namespace
base = OpenAIEmbeddings()
cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
    base,
    file_store,
    namespace=base.model,
)

In [27]:
# Define classify() and cluster() functions
import numpy as np
from typing import List
from sklearn.cluster import BisectingKMeans
from sklearn.metrics import silhouette_score, silhouette_samples


def classify(docs: List[str], topics: List[str], **kwargs):
    """Score every document against every topic using their embeddings.

    Args:
        docs: Texts to embed as documents.
        topics: Topic labels, embedded the same way as the documents.
        **kwargs: Accepted but not used.

    Returns:
        The dot product of the document embedding matrix with the transposed
        topic embedding matrix; entry [i, j] pairs docs[i] with topics[j].
    """
    # Embed docs first, then topics
    doc_matrix, topic_matrix = (
        np.array(cached_embeddings.embed_documents(texts)) for texts in (docs, topics)
    )
    return doc_matrix.dot(topic_matrix.T)


def cluster(docs: List[str], n: int = 20, random_state=None, **kwargs):
    """Cluster documents by their embeddings with bisecting k-means.

    Args:
        docs: Texts to embed and cluster.
        n: Number of clusters to fit.
        random_state: Optional seed passed to BisectingKMeans so a clustering can
            be repeated. The default (None) leaves the estimator unseeded.
        **kwargs: Accepted but not used.

    Returns:
        A dict with:
            "label": cluster index assigned to each document.
            "score": mean silhouette coefficient over all documents.
            "scores": silhouette coefficient of each document.
            "centroid": for each cluster centre, the index of the document whose
                embedding is nearest to it (Euclidean distance).
    """
    # Cluster the documents
    cluster_model = BisectingKMeans(
        init='k-means++', n_clusters=n, n_init=10, max_iter=1000, random_state=random_state
    )
    doc_embed = np.array(cached_embeddings.embed_documents(docs))
    cluster_model.fit(doc_embed)
    labels = cluster_model.labels_
    # silhouette_score is the mean of silhouette_samples, so compute the
    # per-document values once and average them instead of running both.
    sample_scores = silhouette_samples(doc_embed, labels)
    # Find the nearest document one centre at a time; broadcasting all centres at
    # once would allocate an (n_docs, n_clusters, n_dims) temporary array.
    nearest_doc = np.array([
        np.argmin(np.linalg.norm(doc_embed - center, axis=1))
        for center in cluster_model.cluster_centers_
    ])
    return {
        "label": labels,
        "score": np.mean(sample_scores),
        "scores": sample_scores,
        "centroid": nearest_doc,
    }

In [77]:
# Create topics by clustering the paper titles, then represent each cluster by the
# titles of its 3 members with the largest 'score' value
result = cluster(llm['title'].tolist(), n=25)
llm['cluster'], llm['score'] = result['label'], result['scores']
titles_per_cluster = llm.groupby('cluster').apply(
    lambda members: members.nlargest(3, 'score')['title'].tolist()
)
clusters = titles_per_cluster.tolist()



In [128]:
import json
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

# Chat model used for all prompts in this notebook
chat_model = ChatOpenAI(model='gpt-4-1106-preview', temperature=0)
# Prompt asking for one short topic name per cluster, returned as a JSON array
subtopic_prompt = f'''Here are clusters of papers related to LLMs.
Suggest 2-4 word topic names for each cluster.
Do NOT use "LLM" in the title.
Return a JSON string array of length {len(clusters)}.

{json.dumps(clusters, indent=2)}'''
messages = [HumanMessage(content=subtopic_prompt)]

In [129]:
# Get the ChatGPT response: send `messages` to the chat model and keep its reply
subtopic_response = chat_model.invoke(messages)

In [130]:
# Extract everything inside a Markdown code fence and parse it as JSON.
# The fence is accepted with or without a language tag (```json, ```JSON or a
# plain ```); a reply that contains no fence at all is parsed whole.
import re

match = re.search(r'```(?:json)?(.*?)```', subtopic_response.content, re.DOTALL | re.IGNORECASE)
subtopics = json.loads(match.group(1) if match else subtopic_response.content)
subtopics

['Chatbot Emotional Support',
 'Agent Behavior Analysis',
 'Prompt Injection Security',
 'Programming with LLMs',
 'LLM Application Safety',
 'Hallucination Detection',
 'Enhancing Reasoning Skills',
 'Misinformation Detection',
 'Ethical LLM Frameworks',
 'LLM in Psychological Assessment',
 'Personalized Learning Models',
 'LLM Moral Judgements',
 'Text Clustering and Compression',
 'LLM Dialogue Orchestration',
 'LLM Sociability Benchmarks',
 'Cross-Lingual LLM Enhancement',
 'LLM Weight Quantization',
 'LLM Benchmark Generation',
 'Adaptive LLM Reasoning',
 'Multimodal LLM Integration',
 'LLM Prompt Efficiency',
 'LLM Knowledge Base QA',
 'Data Selection for LLMs',
 'LLM Robustness Improvement',
 'LLM Text Generation Evaluation']

In [131]:
# Create higher-level topic groups: the number of groups requested is the square
# root of the number of subtopics, truncated to an integer
size = int(len(subtopics) ** 0.5)
group_prompt = f'''Cluster these topics into {size} groups.
Return a JSON object with keys as a 2-4 word group name and values as arrays of topics.

{json.dumps(subtopics, indent=2)}'''
messages = [HumanMessage(content=group_prompt)]

In [132]:
# Get the ChatGPT response: send `messages` to the chat model and keep its reply
topic_response = chat_model.invoke(messages)

In [133]:
# Parse the JSON in the reply. A Markdown code fence is accepted with or without a
# language tag (```json, ```JSON or a plain ```); a reply with no fence is parsed whole.
match = re.search(r'```(?:json)?(.*?)```', topic_response.content, re.DOTALL | re.IGNORECASE)
topics = json.loads(match.group(1) if match else topic_response.content)
topics

{'LLM Development & Safety': ['Prompt Injection Security',
  'LLM Application Safety',
  'Ethical LLM Frameworks',
  'LLM Robustness Improvement',
  'LLM Weight Quantization',
  'Data Selection for LLMs'],
 'LLM Performance & Evaluation': ['Hallucination Detection',
  'LLM Benchmark Generation',
  'LLM Sociability Benchmarks',
  'LLM Text Generation Evaluation',
  'LLM Knowledge Base QA',
  'LLM Prompt Efficiency'],
 'LLM Use Cases & Applications': ['Chatbot Emotional Support',
  'LLM in Psychological Assessment',
  'LLM Moral Judgements',
  'LLM Dialogue Orchestration',
  'Programming with LLMs',
  'Multimodal LLM Integration'],
 'LLM Learning & Adaptation': ['Enhancing Reasoning Skills',
  'Adaptive LLM Reasoning',
  'Personalized Learning Models',
  'Cross-Lingual LLM Enhancement'],
 'Content Analysis & Management': ['Agent Behavior Analysis',
  'Misinformation Detection',
  'Text Clustering and Compression']}

In [141]:
result = {
    # Flatten the {group: [subtopic, ...]} mapping into one record per subtopic
    'topics': [
        {'topic': group_name, 'subtopic': subtopic_name}
        for group_name, group_subtopics in topics.items()
        for subtopic_name in group_subtopics
    ],
    # One record per paper: title and abstract joined by a '⬛' separator
    'docs': [
        {'chapter': 'arxiv', 'section': 'LLM', 'para': title + '⬛' + abstract}
        for title, abstract in zip(llm['title'], llm['abstract'])
    ],
}

In [143]:
# Score every doc against every subtopic, then record each (doc, topic) pair whose
# similarity exceeds the threshold as a {doc, topic, similarity} entry
min_similarity = 0.75
matches = result['matches'] = []
similarity = classify(
    [doc['para'] for doc in result['docs']],
    [entry['subtopic'] for entry in result['topics']],
)
matches.extend(
    {'doc': doc_index, 'topic': topic_index, 'similarity': score}
    for doc_index, doc_scores in enumerate(similarity)
    for topic_index, score in enumerate(doc_scores)
    if score > min_similarity
)

In [144]:
# Save as data.json: serialise `result` with 2-space indentation and write it out
with open("data.json", "w") as output_file:
    serialized = json.dumps(result, indent=2)
    output_file.write(serialized)