In [96]:
from transformers import BertTokenizer, BertModel
from milvus import default_server, debug_server
from pymilvus import connections, utility, DataType, FieldSchema, CollectionSchema, Collection
import torch
from tqdm import tqdm
import pandas as pd
import numpy as np
import nltk
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
import re
import ast
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/mirandadrummond/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [97]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

In [98]:
def preprocess_text(text):
    # Lowercase
    text = text.lower()
    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    # Tokenize
    tokens = word_tokenize(text)
    # Remove stopwords
    stop_words = set(stopwords.words('english'))
    filtered_tokens = [word for word in tokens if word not in stop_words]
    # Lemmatize
    lemmatizer = WordNetLemmatizer()
    lemmatized_tokens = [lemmatizer.lemmatize(word) for word in filtered_tokens]
    # Join tokens back into string
    preprocessed_text = ' '.join(lemmatized_tokens)
    return preprocessed_text

In [99]:
def generate_embeddings(text):
    """Generate embeddings for a given piece of text using BERT."""
    preprocessed_text = preprocess_text(text)
    encoded_input = tokenizer(preprocessed_text, return_tensors='pt', truncation=True, max_length=512)
    with torch.no_grad():
        output = model(**encoded_input)
    return output.last_hidden_state[:,0,:].numpy()  # Use the CLS token's embedding

In [102]:
# Prepare data
arxiv = pd.read_excel('../data/arxiv100.xlsx')
arxiv.head()

Unnamed: 0,title,abstract,label
0,The Pre-He White Dwarfs in Eclipsing Binaries....,We report the first $BV$ light curves and high...,astro-ph
1,A Possible Origin of kHZ QPOs in Low-Mass X-ra...,A possible origin of kHz QPOs in low-mass X-ra...,astro-ph
2,The effects of driving time scales on heating ...,Context. The relative importance of AC and D...,astro-ph
3,A new hard X-ray selected sample of extreme hi...,Extreme high-energy peaked BL Lac objects (EHB...,astro-ph
4,The baryon cycle of Seven Dwarfs with superbub...,"We present results from a high-resolution, cos...",astro-ph


In [25]:
tqdm.pandas()
arxiv['embeddings'] = arxiv['abstract'].progress_apply(generate_embeddings)

100%|██████████| 100000/100000 [1:26:55<00:00, 19.17it/s]


In [129]:
# convert from shape (1, 768) to  (768,)
arxiv['embeddings'] = arxiv['embeddings'].apply(lambda x: x.flatten())

TypeError: list indices must be integers or slices, not str

[2024/04/21 21:41:34.870 +02:00] [INFO] [gc/gc_tuner.go:90] ["GC Tune done"] ["previous GOGC"=200] ["heapuse "=76] ["total memory"=13359] ["next GC"=76] ["new GOGC"=200] [gc-pause=59.416µs] [gc-pause-end=1713728494870447000]
[2024/04/21 21:41:40.093 +02:00] [INFO] [datacoord/index_service.go:513] ["receive DescribeIndex request"] [traceID=b26107a4bb8019be1fd160939478ad3e] [collectionID=449243396322099465] [indexName=] [timestamp=0]
[2024/04/21 21:41:40.093 +02:00] [INFO] [datacoord/index_service.go:449] ["completeIndexInfo success"] [collectionID=449243396322099465] [indexID=449243396322099562] [totalRows=12] [indexRows=12] [pendingIndexRows=0] [state=Finished] [failReason=]
[2024/04/21 21:41:40.093 +02:00] [INFO] [datacoord/index_service.go:560] ["DescribeIndex success"] [traceID=b26107a4bb8019be1fd160939478ad3e] [collectionID=449243396322099465] [indexName=]
[2024/04/21 21:41:40.093 +02:00] [INFO] [querynodev2/services.go:1370] ["sync action"] [traceID=c7869fa1e45ef160fde0a46b1f2acdf

In [None]:
# save the df as a csv file
arxiv.to_csv('arxiv100_embedded.csv', index=False)