In [10]:
import pandas as pd
import json
import plotly.express as px
from umap import UMAP
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
%matplotlib inline

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    """Average token embeddings along dim 1, weighting each token by its mask value.

    Args:
        model_output: Indexable model output whose first element holds the
            token embeddings (assumed shape (batch, seq_len, hidden) -- TODO confirm).
        attention_mask: Mask with one fewer trailing dimension than the token
            embeddings; positions with value 0 do not contribute to the average.

    Returns:
        Tensor of masked sums over dim 1 divided by the per-row mask total.
        The divisor is clamped to at least 1e-9, so an all-zero mask yields
        zeros rather than a division by zero.
    """
    hidden_states = model_output[0]
    # Broadcast the mask across the embedding dimension so it can be applied
    # element-wise to every feature of every token.
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_sum = (hidden_states * mask).sum(dim=1)
    mask_total = mask.sum(dim=1).clamp(min=1e-9)
    return masked_sum / mask_total

# Load the tokenizer and the model from the HuggingFace Hub; both come from
# the same checkpoint, so the identifier is spelled out only once.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)

def get_embedding(inp_text):
    """Embed a single text as an L2-normalized, mean-pooled vector.

    The text is tokenized as a batch of one (with truncation enabled), run
    through the module-level ``model`` without gradient tracking, mean-pooled
    using the attention mask, and L2-normalized along dim 1.

    Args:
        inp_text: The text to embed (presumably a ``str`` -- verify against callers).

    Returns:
        NumPy array holding the first (only) row of the normalized embeddings,
        moved to the CPU.
    """
    batch = tokenizer(
        [inp_text],
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(**batch)

    pooled = mean_pooling(outputs, batch["attention_mask"])
    unit_vectors = F.normalize(pooled, p=2, dim=1)

    first_row = unit_vectors[0]
    return first_row.detach().cpu().numpy()



embeddings = []

# Read the paper metadata. The context manager closes the file handle as soon
# as parsing finishes, and the explicit UTF-8 encoding keeps non-ASCII text
# (e.g. accented author names) intact regardless of the platform default.
with open("neurips2023data.json", encoding="utf-8") as json_file:
    js = json.load(json_file)

# One row per entry of the JSON object; the keys of `js` are not used here.
df = pd.DataFrame(js.values())

df

Unnamed: 0,title,presentation_time,authors,abstract,oral,spotlight,oral_presentation_time
0,Language Models Meet World Models: Embodied Ex...,"We, Dec 13, 15:00 -- Poster Session 4","[Jiannan Xiang, Tianhua Tao, Yi Gu, Tianmin Sh...",While large language models (LMs) have shown r...,False,False,
1,SAMRS: Scaling-up Remote Sensing Segmentation ...,"Th, Dec 14, 15:00 -- Poster Session 6","[Di Wang, Jing Zhang, Bo Du, Minqiang Xu, Lin ...",The success of the Segment Anything Model (SAM...,False,False,
2,On Learning Necessary and Sufficient Causal Gr...,"Th, Dec 14, 08:45 -- Poster Session 5","[Hengrui Cai, Yixin Wang, Michael Jordan, Rui ...",The causal revolution has stimulated interest ...,False,True,
3,Dual Self-Awareness Value Decomposition Framew...,"Tu, Dec 12, 15:15 -- Poster Session 2","[Zhiwei Xu, Bin Zhang, dapeng li, Guangchong Z...",Value decomposition methods have gained popula...,False,False,
4,Two-Stage Predict+Optimize for MILPs with Unkn...,"Tu, Dec 12, 15:15 -- Poster Session 2","[Xinyi Hu, Jasper Lee, Jimmy Lee]",Consider the setting of constrained optimizati...,False,False,
...,...,...,...,...,...,...,...
3579,Efficient Bayesian Learning Curve Extrapolatio...,"Tu, Dec 12, 08:45 -- Poster Session 1","[Steven Adriaensen, Herilalaina Rakotoarison, ...",Learning curve extrapolation aims to predict m...,False,False,
3580,DICES Dataset: Diversity in Conversational AI ...,"Th, Dec 14, 15:00 -- Poster Session 6","[Lora Aroyo, Alex Taylor, Mark Díaz, Christoph...",Machine learning approaches often require trai...,False,False,
3581,Phase diagram of early training dynamics in de...,"Tu, Dec 12, 15:15 -- Poster Session 2","[Dayal Singh Kalra, Maissam Barkeshli]",We systematically analyze optimization dynamic...,False,False,
3582,Mitigating the Popularity Bias of Graph Colla...,"We, Dec 13, 08:45 -- Poster Session 3","[Yifei Zhang, Hao Zhu, yankai Chen, Zixing Son...",Graph-based Collaborative Filtering (GCF) is w...,False,True,


In [9]:
# Inspect the column names of the loaded paper table.
df.columns

Index(['title', 'presentation_time', 'authors', 'abstract', 'oral',
       'spotlight', 'oral_presentation_time'],
      dtype='object')

In [None]:
# Rebuild the list from scratch on every run of this cell. Appending to a list
# created in another cell would add a second copy of every embedding on a
# re-run, leaving the projection with more rows than `df`.
# Iterating `js.values()` follows the same order used to build `df`.
embeddings = [
    get_embedding(row['abstract'])
    for row in tqdm(js.values())
]

# Project the embeddings to 2-D; the fixed random_state keeps the layout
# reproducible between runs.
manifold = UMAP(n_components=2, init="random", random_state=0)
projections = manifold.fit_transform(embeddings)

In [20]:
# Split the 2-D projection into its two plot axes.
umap_x = projections[:, 0]
umap_y = projections[:, 1]

# Scatter plot over the rows of `df`; hovering a point shows the row's title
# and authors.
fig = px.scatter(
    df,
    x=umap_x,
    y=umap_y,
    hover_name="title",
    hover_data=["authors"],
    width=800,
    height=800,
    # title="Neurips 2023 papers"
)
# fig.show()

In [21]:
# Save the figure to "file.html" (path is relative to the working directory).
fig.write_html("file.html")