In [1]:
!pip install flash-attn



In [6]:
import os

import h5py
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Fix the global torch RNG seed so repeated runs behave the same
torch.manual_seed(0)

# Three toy records, each with an identifier, a sequence string and a
# free-text description
data = [
    dict(uid="A001", seq="MLEVPVWIPILAFAVGLGLGLLIPHLQKPFQRF", text="This protein is involved in membrane transport."),
    dict(uid="A002", seq="MSLEQKKGADIISKILQIQNSIGKTTSPSTLKT", text="This enzyme catalyzes the hydrolysis of ATP."),
    dict(uid="A003", seq="MKMKQQGLVADLLPNIRVMKTFGHFVFNYYNDN", text="This transcription factor regulates gene expression."),
]

# Write the records to disk as CSV, one row per record and no index column
csv_file = 'sample_data.csv'
df = pd.DataFrame.from_records(data)
df.to_csv(csv_file, index=False)
print(f"Sample data saved to {csv_file}")

Sample data saved to sample_data.csv


In [None]:
# Load the model and tokenizer from the same Hub checkpoint
model_id = "microsoft/Phi-3.5-mini-instruct"

# The model is requested on the CUDA device with the dtype left to "auto";
# trust_remote_code=True permits code shipped with the checkpoint to run.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    torch_dtype="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [5]:
# Function to extract embeddings
def extract_embeddings(text, sentence_level=True):
    """Embed text using the notebook-level `tokenizer` and `model`.

    Args:
        text: A string, or a list of strings, passed straight to `tokenizer`.
            Inputs are padded to a common length and truncated to 512 tokens.
        sentence_level: If True, return one vector per input, obtained by
            averaging the last hidden layer over the non-padding tokens.
            If False, return the per-token vectors.

    Returns:
        A float32 NumPy array. With `sentence_level=True` the shape is
        (batch, hidden). Otherwise a size-1 batch axis is squeezed away,
        giving (tokens, hidden); for larger batches the array stays
        (batch, tokens, hidden) and still includes padding positions.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Upcast before pooling so the reduction is not carried out in a
    # reduced-precision dtype; the result is returned as float32 anyway.
    last_hidden_state = outputs.hidden_states[-1].float()

    if sentence_level:
        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            # No mask available: fall back to a plain average over all tokens
            embeddings = last_hidden_state.mean(dim=1)
        else:
            # Weight each position by its mask value so padding added by
            # `padding=True` does not dilute the sentence embedding.
            mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
            # clamp guards against division by zero for an all-padding row
            token_counts = mask.sum(dim=1).clamp(min=1.0)
            embeddings = (last_hidden_state * mask).sum(dim=1) / token_counts
    else:
        # Keep per-token embeddings (drops the batch axis only when it is 1)
        embeddings = last_hidden_state.squeeze(0)

    return embeddings.cpu().numpy()

# Function to process CSV and save embeddings
def process_csv_to_hdf5(csv_file, hdf5_file, sentence_level=True):
    """Embed the `text` column of a CSV and store one HDF5 dataset per `uid`.

    Args:
        csv_file: Path of a CSV file with `uid` and `text` columns.
        hdf5_file: Output path. The file is opened in 'w' mode, so an
            existing file at that path is replaced.
        sentence_level: Forwarded to `extract_embeddings`. Also selects the
            HDF5 group name: 'sentence_embeddings' when True,
            'token_embeddings' when False.

    Raises:
        KeyError: If the CSV lacks a `uid` or `text` column. This is checked
            before the output file is opened, so nothing is overwritten.
    """
    # Read the CSV file; keep uids as text so identifiers such as "007" are
    # not turned into numbers (which are not valid HDF5 dataset names).
    df = pd.read_csv(csv_file, dtype={'uid': str})

    missing = [col for col in ('uid', 'text') if col not in df.columns]
    if missing:
        raise KeyError(f"{csv_file} is missing required column(s): {missing}")

    # Group for sentence-level or token-level embeddings
    group_name = 'sentence_embeddings' if sentence_level else 'token_embeddings'

    # Create an HDF5 file
    with h5py.File(hdf5_file, 'w') as f:
        group = f.create_group(group_name)

        # Process each row in the DataFrame
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing entries"):
            # Dataset names must be strings
            uid = str(row['uid'])
            text = row['text']

            # Extract embeddings
            embedding = extract_embeddings(text, sentence_level)

            # Save embeddings to HDF5 file, one dataset per uid
            group.create_dataset(uid, data=embedding)

    print(f"Embeddings saved to {hdf5_file}")

    # Verify the contents of the HDF5 file by reopening it read-only
    with h5py.File(hdf5_file, 'r') as f:
        print(f"\nContents of the HDF5 file ({group_name}):")
        for key in f[group_name].keys():
            print(f"UID: {key}, Shape: {f[group_name][key].shape}")

In [7]:
# Write one HDF5 file per mode: sentence-level embeddings first,
# then token-level embeddings
for output_path, pooled in (
    ('sentence_embeddings.h5', True),
    ('token_embeddings.h5', False),
):
    process_csv_to_hdf5(csv_file, output_path, sentence_level=pooled)

# The sample CSV was only needed as input, so delete it again
os.remove(csv_file)
print(f"\nRemoved temporary CSV file: {csv_file}")

Processing entries: 100%|██████████| 3/3 [00:00<00:00,  3.90it/s]


Embeddings saved to sentence_embeddings.h5

Contents of the HDF5 file (sentence_embeddings):
UID: A001, Shape: (1, 3072)
UID: A002, Shape: (1, 3072)
UID: A003, Shape: (1, 3072)


Processing entries: 100%|██████████| 3/3 [00:00<00:00,  3.92it/s]

Embeddings saved to token_embeddings.h5

Contents of the HDF5 file (token_embeddings):
UID: A001, Shape: (9, 3072)
UID: A002, Shape: (15, 3072)
UID: A003, Shape: (9, 3072)

Removed temporary CSV file: sample_data.csv





In [None]:
# Example usage
csv_file = 'your_file.csv'  # Datafile with `uid` and `text` columns
hdf5_file = 'embeddings.h5'  # output path (NOTE: the calls below write to fixed filenames and do not use this)

# Save sentence-level embeddings
process_csv_to_hdf5(csv_file, 'sentence_embeddings.h5', sentence_level=True)

# Save token-level embeddings
process_csv_to_hdf5(csv_file, 'token_embeddings.h5', sentence_level=False)