In [1]:
from dask_jobqueue import SLURMCluster
from gensim.models import Word2Vec
import pandas as pd
import dask.dataframe as dd
from transformers import BertTokenizer, BertModel
import torch
from torch.utils.dlpack import to_dlpack, from_dlpack
from numba import jit
import numpy as np
from dask import delayed, compute
import gc

# Describe the SLURM job used to host Dask workers.
# `cores` and `memory` are totals for one job; `processes` is how many
# worker processes that job is divided into.
cluster = SLURMCluster(
    queue='caslake',
    cores=20,
    memory='100GB',
    processes=20,
    walltime='03:00:00',
    interface='ib0',
    job_extra=['--account=macs30123'],
)

# Submit one SLURM job to obtain workers
cluster.scale(jobs=1)

In [3]:
from dask.distributed import Client

# Connect a distributed client to the SLURM-backed `cluster` object
client = Client(cluster)
# Bare last expression so the notebook renders the client/cluster summary
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://172.25.0.66:36767/status,

0,1
Dashboard: http://172.25.0.66:36767/status,Workers: 10
Total threads: 10,Total memory: 37.30 GiB

0,1
Comm: tcp://172.25.0.66:35901,Workers: 10
Dashboard: http://172.25.0.66:36767/status,Total threads: 10
Started: Just now,Total memory: 37.30 GiB

0,1
Comm: tcp://172.25.2.18:43659,Total threads: 1
Dashboard: http://172.25.2.18:34687/status,Memory: 3.73 GiB
Nanny: tcp://172.25.2.18:40381,
Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-a3tfd1fo,Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-a3tfd1fo

0,1
Comm: tcp://172.25.2.18:40319,Total threads: 1
Dashboard: http://172.25.2.18:43891/status,Memory: 3.73 GiB
Nanny: tcp://172.25.2.18:42207,
Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-3a50dlx1,Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-3a50dlx1

0,1
Comm: tcp://172.25.2.18:37357,Total threads: 1
Dashboard: http://172.25.2.18:42465/status,Memory: 3.73 GiB
Nanny: tcp://172.25.2.18:40193,
Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-6jt2thno,Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-6jt2thno

0,1
Comm: tcp://172.25.2.18:39121,Total threads: 1
Dashboard: http://172.25.2.18:35559/status,Memory: 3.73 GiB
Nanny: tcp://172.25.2.18:43101,
Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-nkeg2zm4,Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-nkeg2zm4

0,1
Comm: tcp://172.25.2.18:41611,Total threads: 1
Dashboard: http://172.25.2.18:37709/status,Memory: 3.73 GiB
Nanny: tcp://172.25.2.18:46859,
Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-51i5o0to,Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-51i5o0to

0,1
Comm: tcp://172.25.2.18:38833,Total threads: 1
Dashboard: http://172.25.2.18:45003/status,Memory: 3.73 GiB
Nanny: tcp://172.25.2.18:33779,
Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-cgr28vx7,Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-cgr28vx7

0,1
Comm: tcp://172.25.2.18:33081,Total threads: 1
Dashboard: http://172.25.2.18:40595/status,Memory: 3.73 GiB
Nanny: tcp://172.25.2.18:45817,
Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-tje07nhf,Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-tje07nhf

0,1
Comm: tcp://172.25.2.18:38435,Total threads: 1
Dashboard: http://172.25.2.18:42751/status,Memory: 3.73 GiB
Nanny: tcp://172.25.2.18:34997,
Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-hyp63_sr,Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-hyp63_sr

0,1
Comm: tcp://172.25.2.18:40427,Total threads: 1
Dashboard: http://172.25.2.18:42559/status,Memory: 3.73 GiB
Nanny: tcp://172.25.2.18:40245,
Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-72or2mra,Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-72or2mra

0,1
Comm: tcp://172.25.2.18:39561,Total threads: 1
Dashboard: http://172.25.2.18:40605/status,Memory: 3.73 GiB
Nanny: tcp://172.25.2.18:39255,
Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-alklsbdj,Local directory: /scratch/local/jobs/20534540/dask-worker-space/worker-alklsbdj


In [4]:
# Patent data scraped from Google Patents.
# Only rows from position 250000 onward are kept for this run.
patent_df = pd.read_csv("all_patent_info.csv").iloc[250000:]

In [5]:
# Load the pretrained BERT tokenizer and encoder.
# Both are resolved from the same checkpoint name so they stay in sync.
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

In [None]:
# Keep only patents that have an abstract, and coerce every remaining
# abstract to a plain string
patent_df = (
    patent_df
    .dropna(subset=['abstract'])
    .assign(abstract=lambda frame: frame['abstract'].astype(str))
)

# Split the cleaned frame into 10 partitions for Dask
ddf = dd.from_pandas(patent_df, npartitions=10)
print("finish to ddf")
# Function to compute BERT embeddings
def compute_bert_embedding(texts):
    """Return mean-pooled BERT embeddings for one text or a batch of texts.

    Parameters
    ----------
    texts : str or list of str
        Text(s) to embed. Inputs are truncated to at most 512 tokens and,
        when a batch is given, padded to a common length.

    Returns
    -------
    numpy.ndarray
        One row per input text, holding the average of that text's token
        vectors from the model's last hidden state.

    Notes
    -----
    The average is weighted by the tokenizer's attention mask so padding
    positions do not dilute the embedding of shorter texts in a batch.
    For a single, unpadded text this equals a plain mean over tokens.
    Relies on the module-level ``tokenizer`` and ``model``.
    """
    inputs = tokenizer(texts, return_tensors='pt', truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    hidden = outputs.last_hidden_state
    # Broadcast the mask over the hidden dimension: 1 for real tokens, 0 for padding
    mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
    token_sum = (hidden * mask).sum(dim=1)
    # clamp guards against division by zero if a row had no attended tokens
    token_count = mask.sum(dim=1).clamp(min=1)
    embeddings = (token_sum / token_count).cpu().numpy()
    return embeddings

# Empty object-dtype Series; the apply call below passes its own
# (name, dtype) tuple as metadata rather than this variable
meta = pd.Series(dtype=object)

# Lazily embed each abstract; the new column is declared as object dtype
ddf['embeddings'] = ddf['abstract'].apply(
    compute_bert_embedding,
    meta=('embeddings', 'object'),
)

# Execute the graph and collect everything into a pandas DataFrame
result = ddf.compute()

# Coerce every stored embedding to a numpy array
result['embeddings'] = result['embeddings'].apply(np.array)

In [None]:
# Show the embeddings column of the computed result
print(result['embeddings'])

# Write the full result to CSV without the index column.
# NOTE(review): the embeddings column holds array objects, so the CSV will
# contain their text representation — confirm this is the intended format.
result.to_csv('embeddings_result1.csv', index=False)

In [None]:
# Disconnect the client first, then shut the cluster down explicitly:
# closing a client that was handed an existing cluster object does not
# stop that cluster, so the SLURM job would otherwise keep running.
client.close()
cluster.close()