In [1]:
import torch
import h5py
import json
import numpy as np
from tqdm import tqdm

import sys
import os
sys.path.append(os.getcwd() + "/../")
from src.utils.constants import CHEMBL_DATA_FILE

In [2]:
batch_size = 32
embeddings_file_path = "/data2/scratch/junhalee/extract-llama-embed/data/chembl_35/tests/sample_large.h5"

In [3]:
def generate_sample_embedding():
    # generate a batch_size x 8192 torch tensor with random values between -1 and 1
    tensor = torch.randn(batch_size, 8192) * 2 - 1
    return tensor

embeddings = []
for i in range(0, 82507, batch_size):
    batch_embeddings = generate_sample_embedding()
    if i + batch_size > 82507:
        batch_embeddings = batch_embeddings[:82507 - i]
    embeddings.append(batch_embeddings)

embeddings = torch.cat(embeddings, dim=0)

In [4]:
def get_chemBL_data(chemBL_file_path):
    """
    Get the abstracts from the ChemBL data file.
    """
    with h5py.File(chemBL_file_path, 'r') as f:
        paper_ids = list(f.keys())

        abstracts = []
        canon_SMILES_lists = []
        for paper_name in tqdm(paper_ids):
            h5_dataset = f[paper_name]
            abstract = h5_dataset['abstract'][()].decode('utf-8')
            abstracts.append(abstract)
            compounds_list = json.loads(h5_dataset['compounds'][()].decode('utf-8'))
            canon_SMILES = [compound['canonical_smiles'] for compound in compounds_list]
            canon_SMILES_lists.append(canon_SMILES)

    return abstracts, canon_SMILES_lists

abstracts, canon_SMILES_lists = get_chemBL_data(CHEMBL_DATA_FILE)

100%|██████████| 82507/82507 [01:29<00:00, 918.29it/s] 


In [5]:
canon_SMILES_lists = canon_SMILES_lists[:embeddings.shape[0]]

print(len(canon_SMILES_lists))
print(embeddings.shape)

json_canon_SMILES = json.dumps(canon_SMILES_lists)

82507
torch.Size([82507, 8192])


In [6]:
with h5py.File(embeddings_file_path, 'w') as f:
    f.create_dataset('embeddings', data=embeddings.cpu().numpy())
    f.create_dataset('canon_SMILES_json', data=json_canon_SMILES.encode('utf-8'))

In [7]:
# Try decoding the h5 file
with h5py.File(embeddings_file_path, 'r') as f:
    read_embeddings = f['embeddings'][()]
    read_canon_SMILES_lists = json.loads(f['canon_SMILES_json'][()].decode('utf-8'))

print(len(read_canon_SMILES_lists))
print(embeddings.shape)

82507
torch.Size([82507, 8192])
