In [1]:
import torch
import h5py
import json
import numpy as np
import accelerate
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

import sys
import os
sys.path.append(os.getcwd() + "/../")

from src.utils.constants import CHEMBL_DATA_FILE, MODELS_DIR, LLAMA_3P3_70B_MODEL_DIR, LLAMA_3P3_70B_MODEL_NAME

In [6]:
# Read every paper's abstract and canonical SMILES strings from the ChEMBL HDF5 file.
with h5py.File(CHEMBL_DATA_FILE, 'r') as f:
    paper_ids = list(f.keys())
    abstracts, canon_SMILES_lists = [], []

    for paper_name in tqdm(paper_ids):
        paper_group = f[paper_name]

        # Each entry is read with [()] and decoded from UTF-8 bytes.
        abstracts.append(paper_group['abstract'][()].decode('utf-8'))

        # 'compounds' holds JSON; keep only the canonical SMILES of each record.
        compound_records = json.loads(paper_group['compounds'][()].decode('utf-8'))
        canon_SMILES_lists.append(
            [record['canonical_smiles'] for record in compound_records]
        )

# Pair each paper's SMILES list with its abstract (both lists share the key order).
data = list(zip(canon_SMILES_lists, abstracts))

100%|██████████| 82507/82507 [01:36<00:00, 858.40it/s] 


In [7]:
# Spot-check one (SMILES list, abstract) pair from `data`.
data[1]

(['COc1ccc2[nH]c3ccc4cc[n+](CCO)cc4c3c2c1.[I-]',
  'COc1ccc2[nH]c3ccc4cc[n+](CCO)cc4c3c2c1.[I-]',
  'COc1ccc2[nH]c3ccc4cc[n+](CCO)cc4c3c2c1.[I-]',
  'COc1ccc2[nH]c3ccc4cc[n+](CCO)cc4c3c2c1.[I-]',
  'COc1ccc2c(c1)c1c3c[n+](CCN4CCCCC4)ccc3ccc1n2C.[Cl-]',
  'COc1ccc2c(c1)c1c3c[n+](CCN4CCCCC4)ccc3ccc1n2C.[Cl-]',
  'COc1ccc2c(c1)c1c3c[n+](CCN4CCCCC4)ccc3ccc1n2C.[Cl-]',
  'COc1ccc2c(c1)c1c3ccncc3ccc1n2C',
  'COc1ccc2c(c1)c1c3ccncc3ccc1n2C',
  'COc1ccc2c(c1)c1c3ccncc3ccc1n2C',
  'COc1ccc2c(c1)c1c3ccncc3ccc1n2C',
  'C[n+]1ccc2c(ccc3[nH]c4ccc(O)cc4c32)c1.[I-]',
  'C[n+]1ccc2c(ccc3[nH]c4ccc(O)cc4c32)c1.[I-]',
  'C[n+]1ccc2c(ccc3[nH]c4ccc(O)cc4c32)c1.[I-]',
  'Oc1ccc2[nH]c3ccc4ccncc4c3c2c1',
  'Oc1ccc2[nH]c3ccc4ccncc4c3c2c1',
  'Oc1ccc2[nH]c3ccc4ccncc4c3c2c1',
  'Oc1ccc2[nH]c3ccc4ccncc4c3c2c1',
  'COc1ccc2c(c1)c1c3c[n+](C)ccc3ccc1n2C.[I-]',
  'COc1ccc2c(c1)c1c3c[n+](C)ccc3ccc1n2C.[I-]',
  'COc1ccc2c(c1)c1c3c[n+](C)ccc3ccc1n2C.[I-]',
  'COc1ccc2c(c1)c1c3c[n+](C)ccc3ccc1n2C.[I-]',
  'COc1ccc2[nH]c3

In [3]:
def load_model(model_name, cache_dir):
    """
    Load a Hugging Face model and its tokenizer, downloading them if needed.

    If ``cache_dir`` already exists it is treated as a local checkpoint
    directory and loaded strictly offline. Otherwise the model is fetched
    from the Hugging Face Hub by name and cached under ``cache_dir``.

    Args:
        model_name (str): Hub identifier of the model to load.
        cache_dir (str): Local directory holding (or that will hold) the
            model files.

    Returns:
        tuple: ``(model, tokenizer)`` as returned by
        ``AutoModel.from_pretrained`` and ``AutoTokenizer.from_pretrained``.
    """
    if os.path.exists(cache_dir):
        print(f"Model {model_name} already cached in {cache_dir}.")
        source = cache_dir
        # Everything is on disk already, so forbid any network access.
        source_kwargs = {"local_files_only": True}
    else:
        print(f"Downloading model {model_name} to {cache_dir}...")
        source = model_name
        # local_files_only must NOT be set here: it would block the very
        # download this branch exists to perform.
        source_kwargs = {"cache_dir": cache_dir}

    model = AutoModel.from_pretrained(
        source,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        **source_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(
        source,
        **source_kwargs
    )

    return model, tokenizer


def get_embeddings(model, tokenizer, texts):
    """
    Compute one mean-pooled embedding per input string.

    The texts are tokenized as a padded batch, run through the model without
    gradients, and the last hidden state is averaged over the sequence
    dimension using the attention mask so padding positions are ignored.

    Args:
        model: The model to use for generating embeddings. Its output must
            expose ``last_hidden_state``.
        tokenizer: The tokenizer to use for encoding the text. If it has no
            padding token, its ``pad_token`` is set to its ``eos_token``
            (this mutates the tokenizer).
        texts (list): A list of strings to encode.

    Returns:
        torch.Tensor: The pooled embeddings, one row per input text.
    """
    # Batched padding needs a pad token; some tokenizers (commonly the Llama
    # family) define none, so fall back to the end-of-sequence token.
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Tokenize the input text
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )

    # Move the inputs to the device of the model's first parameter
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Forward pass to get the per-token hidden states
    with torch.no_grad():
        token_states = model(**inputs).last_hidden_state

    # Masked mean over the sequence dimension: padding positions contribute
    # zero to the sum and are excluded from the count.
    mask_expanded = inputs["attention_mask"].unsqueeze(-1).expand(token_states.size()).float()
    sum_embeddings = torch.sum(token_states * mask_expanded, 1)
    # Clamp so a row with an all-zero mask cannot divide by zero.
    sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)

    return sum_embeddings / sum_mask

In [4]:
# Model identifier and local model directory for Llama 3.3 70B, both taken
# from the project constants imported at the top of the notebook.
model_name = LLAMA_3P3_70B_MODEL_NAME
cache_dir = LLAMA_3P3_70B_MODEL_DIR

# load_model returns the (model, tokenizer) pair used by the cells below.
model, tokenizer = load_model(model_name, cache_dir)

Model meta-llama/Llama-3.3-70B-Instruct already cached in /data2/scratch/junhalee/extract-llama-embed/models/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b/.


Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


In [5]:
# Report the device on which each named parameter currently lives.
for layer_name, weights in model.named_parameters():
    print(f"Layer: {layer_name}, Device: {weights.device}")

Layer: embed_tokens.weight, Device: cuda:1
Layer: layers.0.self_attn.q_proj.weight, Device: cuda:1
Layer: layers.0.self_attn.k_proj.weight, Device: cuda:1
Layer: layers.0.self_attn.v_proj.weight, Device: cuda:1
Layer: layers.0.self_attn.o_proj.weight, Device: cuda:1
Layer: layers.0.mlp.gate_proj.weight, Device: cuda:1
Layer: layers.0.mlp.up_proj.weight, Device: cuda:1
Layer: layers.0.mlp.down_proj.weight, Device: cuda:1
Layer: layers.0.input_layernorm.weight, Device: cuda:1
Layer: layers.0.post_attention_layernorm.weight, Device: cuda:1
Layer: layers.1.self_attn.q_proj.weight, Device: cuda:1
Layer: layers.1.self_attn.k_proj.weight, Device: cuda:1
Layer: layers.1.self_attn.v_proj.weight, Device: cuda:1
Layer: layers.1.self_attn.o_proj.weight, Device: cuda:1
Layer: layers.1.mlp.gate_proj.weight, Device: cuda:1
Layer: layers.1.mlp.up_proj.weight, Device: cuda:1
Layer: layers.1.mlp.down_proj.weight, Device: cuda:1
Layer: layers.1.input_layernorm.weight, Device: cuda:1
Layer: layers.1.post_

In [None]:
# Do not call model.to("cuda:0") here. The model was loaded with
# device_map="auto", so its layers are already dispatched across devices and
# some parameters are offloaded; Accelerate refuses to move such a model, and
# the full set of weights is not expected to fit on a single GPU anyway.
# Show which devices the dispatch actually uses instead.
sorted(set(model.hf_device_map.values()), key=str)

In [None]:
# Get embeddings for the abstracts in batches
batch_size = 32
embedding_batches = []
for i in tqdm(range(0, len(abstracts), batch_size)):
    batch = abstracts[i:i + batch_size]
    batch_embeddings = get_embeddings(model, tokenizer, batch)
    # Move each finished batch to the CPU right away so results do not pile
    # up on the model's device for the whole run.
    embedding_batches.append(batch_embeddings.cpu())

# One row per abstract, in the same order as `abstracts`.
embeddings = torch.cat(embedding_batches, dim=0)

In [None]:
# Explore the HDF5 file
"""
File has 82507 papers.
- Each paper is keyed by its CHEMBL ID (ranging from 'CHEMBL1121361' to 'CHEMBL5483193'); indexing the file with that key returns a Group.
- Each paper Group has 3 keys: 'abstract', 'compounds', and 'doi'; each key leads to a Dataset whose value must be decoded as UTF-8.
    - 'abstract' is a string containing the abstract
    - 'doi' is a string containing the DOI
    - 'compounds' is a JSON-encoded string holding a list of compound records; each record has a 'canonical_smiles' entry (SMILES format)
"""

with h5py.File(CHEMBL_DATA_FILE, 'r') as f:
    # Use the paper at position 2000 in key order as a worked example.
    dataset_name = list(f.keys())[2000]
    dataset = f[dataset_name]

    def _read_text(field):
        """Return the named entry of the example paper decoded as UTF-8."""
        return dataset[field][()].decode('utf-8')

    abstract = _read_text('abstract')
    doi = _read_text('doi')
    # 'compounds' is stored as JSON text, so parse it after decoding.
    compounds = json.loads(_read_text('compounds'))

In [None]:
with h5py.File(CHEMBL_DATA_FILE, 'r') as f:
    dataset_names = list(f.keys())

    # Map each CHEMBL ID to its abstract. This lives under its own name so it
    # does not rebind the list `abstracts`, which the batching cell slices
    # (slicing a dict would fail).
    abstracts_by_id = {}
    for dataset_name in tqdm(dataset_names):
        dataset = f[dataset_name]
        abstract = dataset['abstract'][()].decode('utf-8')
        abstracts_by_id[dataset_name] = abstract

In [None]:
with h5py.File(CHEMBL_DATA_FILE, 'r') as f:
    paper_ids = list(f.keys())

    # Decode the abstract of every paper, keeping the file's key order.
    abstracts = [
        f[paper_id]['abstract'][()].decode('utf-8')
        for paper_id in tqdm(paper_ids)
    ]

In [None]:
# Preview a few abstracts only; displaying the whole list would put every
# abstract into the notebook output.
abstracts[:5]