In [None]:
import sys
sys.path.append("..")
sys.path.append("../src")

from dataset import load_dataset

In [None]:
import os
import sys

#Import config file. Update config.py according to your environment
from config import path_to_data

import pandas as pd
import numpy as np
import h5py

import requests, json

from transformers import AutoTokenizer, AutoModel
from adapters import AutoAdapterModel
import torch
import torch.nn.functional as F

from tqdm import tqdm
import time

In [2]:
# Tag of the dataset to embed. It is interpolated into the parquet file names read by
# `load_dataset` and into the name of the embeddings file written at the end.
# The commented-out alternative is 'arxiv'.
dataset_source = 'semanticsscholar' #'arxiv'

In [None]:
# NOTE(review): this definition shadows the `load_dataset` imported from `dataset`
# in the first cell; only one of the two should be kept.
def load_dataset(dataset_source='semanticsscholar', years=None, data_types=None):
    """Load the metadata and/or text tables for the requested years.

    Files are looked up at
    ``<path_to_data>/<kind>/<year>/<dataset_source>_<kind>_<year>.parquet``
    with ``<kind>`` being ``metadata`` or ``text``. A year whose file does not
    exist is skipped; the printed summary lists the years actually loaded.

    Parameters
    ----------
    dataset_source : str
        Prefix of the parquet file names.
    years : int, str, iterable of int/str, or None
        A single year, or any iterable of years (list, tuple, range, ...).
        When falsy (``None``, ``0``, ``''``), every all-digit entry of
        ``<path_to_data>/metadata`` is used, in sorted order.
    data_types : list of str, optional
        Any of ``'metadata'`` and ``'text'``; defaults to both.

    Returns
    -------
    tuple
        The non-empty tables, in the order (metadata, text). A table that was
        not requested, or for which no file was found, is left out, so
        two-name unpacking only works when both tables were loaded.

    Raises
    ------
    ValueError
        If both tables were requested and their row counts differ.
    """
    if isinstance(years, str) or not hasattr(years, '__iter__'):
        # Scalar input: either one year or a falsy "discover everything" marker.
        if years:
            years = [str(years)]
        else:
            metadata_root = os.path.join(path_to_data, 'metadata')
            # Sorted so the concatenation order does not depend on the file system.
            years = sorted(name for name in os.listdir(metadata_root) if name.isdigit())
    else:
        # Any iterable of years, not only `list` (a tuple used to be stringified whole).
        years = [str(year) for year in years]

    data_types = data_types or ['metadata', 'text']
    load_metadata = 'metadata' in data_types
    load_textdata = 'text' in data_types

    requested = [
        kind
        for kind, wanted in (('metadata', load_metadata), ('text', load_textdata))
        if wanted
    ]
    frames = {'metadata': [], 'text': []}
    loaded_years = {'metadata': [], 'text': []}

    for year in years:
        for kind in requested:
            parquet_path = os.path.join(
                path_to_data, kind, year, f'{dataset_source}_{kind}_{year}.parquet'
            )
            if os.path.isfile(parquet_path):
                frames[kind].append(pd.read_parquet(parquet_path, engine="pyarrow"))
                loaded_years[kind].append(year)

    # An empty list stands in for "nothing loaded", as callers may rely on len().
    metadata = pd.concat(frames['metadata'], axis=0) if frames['metadata'] else []
    textdata = pd.concat(frames['text'], axis=0) if frames['text'] else []

    msg_parts = []
    if load_metadata:
        msg_parts.append(f"metadata loaded for years: {loaded_years['metadata']}")
    if load_textdata:
        msg_parts.append(f"text data loaded for years: {loaded_years['text']}")

    if msg_parts:
        print("; ".join(msg_parts))

    if load_metadata and load_textdata and len(metadata) != len(textdata):
        raise ValueError("Metadata and text data don't have the same length.")

    # A tuple (rather than a one-shot generator) can be unpacked, indexed and reused.
    return tuple(table for table in (metadata, textdata) if len(table) > 0)

## Load dataset

In [None]:
year = 2005

# The loader defined in this notebook is `load_dataset`; no `fetch_dataset` exists.
metadata, textdata = load_dataset(dataset_source, year)

# NOTE(review): the cells below read `data['title']` / `data['abstract']` and store
# `data` under the "metadata" key, but `data` was never assigned anywhere.
# Presumably it is the text table — confirm which table holds those columns.
data = textdata

## Load embedding model

In [None]:
# Hugging Face id of the embedding model.
# Alternative handled by the loading cell: 'sentence-transformers/all-mpnet-base-v1'
model_name = 'allenai/specter2'

# Tokenizer truncation length and number of papers per forward pass.
max_length = 512
batch_size = 128

# Number of rows that will be embedded.
Nsamples = len(data)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Pick the tokenizer/encoder pair that matches `model_name`.
if 'sentence-transformer' in model_name:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
elif 'specter2' in model_name:
    # Base checkpoint first, then the adapter named by `model_name` is loaded and activated.
    tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
    model = AutoAdapterModel.from_pretrained('allenai/specter2_base')
    model.load_adapter(model_name, source="hf", set_active=True)
else:
    # Fail loudly: falling through would either raise a NameError on `model`
    # or silently reuse a model left in the kernel by an earlier run.
    raise ValueError(
        f"Unsupported model_name {model_name!r}: expected a 'sentence-transformers/...' "
        "or 'specter2' model."
    )

model.to(device);
# Inference only: no parameter needs gradients.
for param in model.parameters():
    param.requires_grad = False

## Compute embeddings

In [None]:
# Embed every paper as "<title><sep_token><abstract>", one batch at a time.
# The per-batch arrays are collected in `embeddings` and stacked in a later cell.
embeddings = []

model.eval()
with torch.no_grad():
    # Iterate over exact batch starts so the progress bar total is right and no
    # empty trailing batch is scheduled when Nsamples is a multiple of batch_size.
    batch_starts = range(0, Nsamples, batch_size)
    with tqdm(batch_starts, desc=f'total of {Nsamples} papers', unit='batch') as pbar:
        for start in pbar:
            batch = data.iloc[start:start + batch_size]
            # A missing title or abstract would turn the concatenation into NaN,
            # which the tokenizer cannot encode; treat missing text as empty.
            batch_text = (
                batch['title'].fillna('') + tokenizer.sep_token + batch['abstract'].fillna('')
            ).to_list()
            if not batch_text:
                # Only reachable if Nsamples is larger than the current `data`.
                continue

            batch_tokens = tokenizer(
                batch_text,
                padding=True,
                truncation=True,
                return_tensors="pt",
                return_token_type_ids=False,
                max_length=max_length,
            )
            batch_tokens = {key: value.to(device) for key, value in batch_tokens.items()}

            output = model(**batch_tokens)
            # Use the pooled output when the model returns one, otherwise the
            # hidden state of the first token of each sequence.
            # NOTE(review): the original comment says "first token ... as the embedding",
            # yet `pooler_output` takes precedence here — confirm which one is intended.
            if 'pooler_output' in output.keys():
                embeddings_batch = output.pooler_output.cpu().numpy().astype(np.float32)
            else:
                embeddings_batch = output.last_hidden_state[:, 0, :].cpu().numpy().astype(np.float32)

            embeddings.append(embeddings_batch)

## Save embeddings

In [None]:
# Short model name: the part of the Hugging Face id after the last '/'.
model_nickname = model_name.split('/')[-1]
# Only `path_to_data` is imported from config (there is no `config` module object in scope).
# NOTE(review): the file name carries no year and the file is later opened with mode="w",
# so embeddings computed for different years overwrite each other — confirm this is intended.
embeddings_filepath = os.path.join(path_to_data, 'embeddings', f'embeddings_{dataset_source}_{model_nickname}.h5')

In [None]:
# Stack the per-batch arrays into one 2-D array with one row per paper.
# `np.concatenate` is available in every NumPy release (`np.concat` only since 2.0).
# The guard keeps the cell idempotent: concatenating the already-stacked 2-D array
# again would join its rows into a single 1-D vector.
if isinstance(embeddings, list):
    embeddings = np.concatenate(embeddings, axis=0)

In [None]:
# Store the paper table and its embeddings side by side in one compressed HDF5 file.
# TODO (original author): write a plain HDF5 file instead of going through pandas'
# HDFStore — presumably with h5py; confirm.
B, D = embeddings.shape  # unpacking also fails early if the array is not 2-D

# Check alignment before opening the file: mode="w" truncates any existing file.
if B != len(data):
    raise ValueError(f"Got {B} embeddings for {len(data)} rows of data.")

# HDFStore does not create missing parent directories.
output_dir = os.path.dirname(embeddings_filepath)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

with pd.HDFStore(embeddings_filepath, mode="w", complib="blosc", complevel=9) as store:
    store.put("metadata", data, format="table", data_columns=True)
    # Embeddings share the index of `data` so rows can be matched after reloading.
    store.put("embeddings", pd.DataFrame(embeddings).set_index(data.index), format="table")