In [54]:
import torch
import h5py
import json
import numpy as np
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

In [None]:
MODELS_DIR = "/data2/scratch/junhalee/extract-llama-embed/models"


def load_model(model_name, cache_dir=None):
    """
    Load a model and its tokenizer from the local Hugging Face cache.

    Both are loaded with ``local_files_only=True``, so nothing is downloaded:
    the files must already be present under ``cache_dir``.

    Args:
        model_name (str): The name of the model to load.
        cache_dir (str, optional): Directory holding the cached model files.
            Defaults to ``MODELS_DIR`` when omitted.

    Returns:
        tuple: ``(model, tokenizer)`` as returned by
        ``AutoModel.from_pretrained`` and ``AutoTokenizer.from_pretrained``.
    """
    # Resolve the default lazily so a later change to MODELS_DIR is honoured.
    if cache_dir is None:
        cache_dir = MODELS_DIR

    print(f"Loading model {model_name}...")
    model = AutoModel.from_pretrained(
        model_name,
        cache_dir=cache_dir,
        local_files_only=True
    )

    # Load tokenizer with same cache directory (AutoTokenizer, not AutoModel,
    # otherwise a second model is returned in place of the tokenizer).
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        cache_dir=cache_dir,
        local_files_only=True
    )

    return model, tokenizer


def get_embeddings(model, tokenizer, text):
    """
    Get the embeddings for a given text using the specified model and tokenizer.

    Args:
        model: The model to use for generating embeddings. It is called with
            the tokenizer output unpacked as keyword arguments.
        tokenizer: The tokenizer to use for encoding the text.
        text (str): The input text to encode.

    Returns:
        torch.Tensor: The ``last_hidden_state`` of the model output for the
        input text.
    """
    # Tokenize with padding/truncation
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )

    # Get hidden states (batch_size x seq_len x hidden_dim); gradients are
    # disabled because this is inference only.
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state

In [None]:
# NOTE: despite its name, DATA_DIR is the path to a single HDF5 file, not a directory.
DATA_DIR = "/data2/scratch/junhalee/extract-llama-embed/data/chembl_35/chembl_35_sqlite/chembl_clip_abstracts.h5"

# Explore the HDF5 file
#
# File has 82507 papers
# - Each paper is keyed by CHEMBL ID (ranges from 'CHEMBL1121361' to
#   'CHEMBL5483193'); each key leads to a Group
# - Each Group has 3 keys: 'abstract', 'compounds', and 'doi'; each key leads
#   to a Dataset whose value needs to be decoded as utf-8
#     - 'abstract' is a string of the abstract
#     - 'doi' is a string of the doi
#     - 'compounds' is a JSON string of the compounds (parsed with json.loads
#       below). Original note: the compounds are in SMILES format; the sample
#       output recorded for this cell shows entries with a 'molfile' key.

# Arbitrary paper picked for inspection.
SAMPLE_INDEX = 2000

with h5py.File(DATA_DIR, 'r') as f:
    dataset_name = list(f.keys())[SAMPLE_INDEX]
    dataset = f[dataset_name]

    # `[()]` reads the scalar dataset value, which is bytes.
    abstract = dataset['abstract'][()].decode('utf-8')
    doi = dataset['doi'][()].decode('utf-8')
    compounds = json.loads(dataset['compounds'][()].decode('utf-8'))

# Show the sample; the labels mirror the output recorded under this cell.
print(f"CHEMBL ID: {dataset_name}")
print(f"Abstract: {abstract}")
print(f"DOI: {doi}")
# Only a short preview: the full compound list is very long.
print(f"Compounds: {repr(compounds)[:200]}")

CHEMBL ID: CHEMBL1123522
Abstract: The synthesis of 1-[(2-mercaptocyclopentyl)carbonyl]-L-prolines, 1-[(2-mercaptocyclobutyl)carbonyl]-L-prolines and related benzoyl derivatives as pure isomers is described. The abilities of all the compounds to inhibit angiotensin converting enzyme (ACE) in vitro and in vivo and to lower the systolic blood pressure in renal hypertensive dogs were determined. Three of them, namely 1-[[2-(benzoylthio)cyclopentyl]carbonyl]-L-proline (10f(R,S], 1-[(2-mercaptocyclopentyl)carbonyl]-L-proline (10g(R,S], and 1-[[2-(benzoylthio)cyclobutyl]carbonyl]-L-proline (16f(R,S], were found to be as potent as captopril in reducing blood pressure. The influence of chirality and ring size on the ACE inhibition is described.
DOI: 10.1021/jm00153a017
Compounds: [{'molfile': '\n     RDKit          2D\n\n 16 17  0  0  1  0  0  0  0  0999 V2000\n    5.4167   -2.5750    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0\n    6.1292   -2.1625    0.0000 N   0  0  0  0  0  0  0  0  0  0

In [55]:
# Build a {CHEMBL ID: abstract text} mapping covering every paper in the file.
with h5py.File(DATA_DIR, 'r') as f:
    dataset_names = list(f.keys())

    # Each group's 'abstract' dataset is read as a scalar and decoded from
    # bytes; tqdm reports progress over the CHEMBL IDs.
    abstracts = {
        chembl_id: f[chembl_id]['abstract'][()].decode('utf-8')
        for chembl_id in tqdm(dataset_names)
    }

100%|██████████| 82507/82507 [00:10<00:00, 7612.57it/s]
