In [4]:
import torch
# Sanity-check the environment before loading any model: report the installed
# PyTorch build and whether PyTorch can see a CUDA device.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

PyTorch version: 2.7.1+cu118
CUDA available: True


In [5]:
from transformers import AutoTokenizer, AutoModel
import torch

# Select the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Hugging Face checkpoint used for both the tokenizer and the encoder.
# (A previously tried alternative was "neuml/pubmedbert-base-embeddings".)
checkpoint = "pritamdeka/S-PubMedBert-MS-MARCO"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load the encoder weights and place them on the selected device in one step.
model = AutoModel.from_pretrained(checkpoint).to(device)

def encode_text(title, abstract):
    """Embed a title/abstract pair into a single vector with the loaded model.

    The two fields are joined with a space, tokenized (truncated to
    ``512 - margin`` tokens), run through ``model`` on ``device`` without
    gradient tracking, and the token embeddings in ``last_hidden_state`` are
    averaged over the sequence dimension.

    Args:
        title: Title text. ``None`` is treated as empty text.
        abstract: Abstract text. ``None`` is treated as empty text.

    Returns:
        A NumPy array (on CPU) holding the mean-pooled embedding. For a single
        text this is expected to be a 1-D vector of ``model.config.hidden_size``
        elements.
    """
    margin = 12
    max_length = 512 - margin  # Stay `margin` tokens below the 512-token limit

    # A missing field must not be embedded as the literal word "None", which is
    # what f-string interpolation of None would produce.
    title = "" if title is None else title
    abstract = "" if abstract is None else abstract
    text = f"{title} {abstract}".strip()

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)

    # Move inputs to the same device as the model
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Inference only: no gradients needed
    with torch.no_grad():
        outputs = model(**inputs)

    # Only one sequence is tokenized per call, so a plain mean over dim=1
    # averages that sequence's tokens. Move the result back to CPU for NumPy.
    embeddings = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
    return embeddings

# Width of the encoder's hidden states; the Milvus vector field is declared with
# this dimension when the collection schema is built.
embedding_dim = model.config.hidden_size
print(f"Embedding dimension: {embedding_dim}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Embedding dimension: 768


In [6]:
from pymilvus import MilvusClient
# Address of the Milvus server this notebook talks to (local instance, port 19530).
REMOTE_URI = "http://127.0.0.1:19530"
# Client handle reused by every later cell.
milvus_client = MilvusClient(uri=REMOTE_URI)

# Name of the collection that is created, filled and inspected below.
collection_name = "pmc_trec_2016"

In [32]:
# Show which collections already exist on the Milvus server.
milvus_client.list_collections()

[]

In [None]:
# Destructive: uncomment to drop the collection before rebuilding it from scratch.
# milvus_client.drop_collection(collection_name="pmc_trec_2016")

In [33]:
# Build-time parameters for the IVF_FLAT index: number of clusters.
ivf_build_params = {"nlist": 64}

# Describe the index to build on the "vector" field: an IVF_FLAT index named
# "vector_index" that compares vectors by cosine similarity.
index_params = milvus_client.prepare_index_params()
index_params.add_index(
    field_name="vector",
    index_name="vector_index",
    index_type="IVF_FLAT",
    metric_type="COSINE",
    params=ivf_build_params,
)

In [34]:
from pymilvus import FieldSchema, CollectionSchema, MilvusClient, DataType

# Collection layout:
#   id     - string primary key (up to 20 characters)
#   vector - float embedding with `embedding_dim` dimensions
#   doc    - JSON payload stored alongside the vector
fields = [
    FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=20),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
    FieldSchema(name="doc", dtype=DataType.JSON),
]
schema = CollectionSchema(fields)

# Create the collection together with its vector index.
# - The similarity metric is taken from `index_params`; the `metric_type`
#   keyword of create_collection only applies to the schema-less quick-setup
#   path, so it is not passed here.
# - Passing `index_params` makes create_collection build the index itself, so
#   no separate create_index call is needed afterwards.
milvus_client.create_collection(
    collection_name=collection_name,
    schema=schema,
    consistency_level="Bounded",
    index_params=index_params,
)


In [35]:
import os
import json
from tqdm import tqdm

# Ingest every .jsonl shard in `data_folder` into the Milvus collection.
# Each non-empty line is parsed as JSON and inserted as one row; rows are
# presumably already shaped like the collection schema (id / vector / doc) —
# confirm against whatever produced the shards.
data_folder = 'pmc_shards'
# Sort the listing so shards are always processed in the same order;
# os.listdir returns entries in arbitrary order.
data_subfiles = sorted(
    f for f in os.listdir(data_folder) if os.path.isfile(os.path.join(data_folder, f))
)
batch_size = 10000  # Maximum number of rows sent per insert call
print(data_subfiles)

for file_name in data_subfiles:
    print(f"Found file: {file_name}")
    if not file_name.endswith('.jsonl'):
        continue

    file_path = os.path.join(data_folder, file_name)
    try:
        print(f"Processing file: {file_path}")

        # Rows parsed from the current shard that have not been inserted yet
        documents = []
        with open(file_path, 'r', encoding='utf-8') as file:
            for line_num, line in tqdm(enumerate(file), desc=f"Processing {file_name}"):
                line = line.strip()
                if not line:  # Skip empty lines
                    continue
                try:
                    documents.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"Error parsing line {line_num}: {e}")
                    continue

                # Flush on the number of buffered rows, not on the line number,
                # so blank or unparsable lines can neither skip a flush nor
                # change the batch size.
                if len(documents) >= batch_size:
                    milvus_client.insert(collection_name=collection_name, data=documents)
                    documents = []  # Start a fresh batch after insertion

        # Insert whatever is left over after the last full batch
        if documents:
            milvus_client.insert(collection_name=collection_name, data=documents)

    except Exception as e:
        # Best effort: report the failure and move on to the next shard.
        # Rows inserted before the failure stay in the collection.
        print(f"Error processing file {file_path}: {e}")
        continue

['shard-00001.jsonl', 'shard-00002.jsonl', 'shard-00003.jsonl', 'shard-00004.jsonl', 'shard-00005.jsonl', 'shard-00006.jsonl', 'shard-00007.jsonl', 'shard-00008.jsonl', 'shard-00009.jsonl', 'shard-00010.jsonl', 'shard-00011.jsonl']
Found file: shard-00001.jsonl
Processing file: pmc_shards\shard-00001.jsonl


Processing shard-00001.jsonl: 100000it [01:43, 970.65it/s]


Found file: shard-00002.jsonl
Processing file: pmc_shards\shard-00002.jsonl


Processing shard-00002.jsonl: 100000it [01:23, 1195.76it/s]


Found file: shard-00003.jsonl
Processing file: pmc_shards\shard-00003.jsonl


Processing shard-00003.jsonl: 100000it [01:37, 1020.51it/s]


Found file: shard-00004.jsonl
Processing file: pmc_shards\shard-00004.jsonl


Processing shard-00004.jsonl: 100000it [01:35, 1046.15it/s]


Found file: shard-00005.jsonl
Processing file: pmc_shards\shard-00005.jsonl


Processing shard-00005.jsonl: 100000it [01:25, 1173.54it/s]


Found file: shard-00006.jsonl
Processing file: pmc_shards\shard-00006.jsonl


Processing shard-00006.jsonl: 100000it [01:36, 1032.30it/s]


Found file: shard-00007.jsonl
Processing file: pmc_shards\shard-00007.jsonl


Processing shard-00007.jsonl: 100000it [01:46, 941.36it/s]


Found file: shard-00008.jsonl
Processing file: pmc_shards\shard-00008.jsonl


Processing shard-00008.jsonl: 100000it [01:47, 929.53it/s]


Found file: shard-00009.jsonl
Processing file: pmc_shards\shard-00009.jsonl


Processing shard-00009.jsonl: 100000it [01:42, 979.30it/s]


Found file: shard-00010.jsonl
Processing file: pmc_shards\shard-00010.jsonl


Processing shard-00010.jsonl: 100000it [01:52, 886.24it/s]


Found file: shard-00011.jsonl
Processing file: pmc_shards\shard-00011.jsonl


Processing shard-00011.jsonl: 99505it [01:50, 899.67it/s] 


In [37]:
# Fetch and print the collection's description (schema, ids, consistency level, ...)
collection_info = milvus_client.describe_collection(collection_name=collection_name)
print(f"Collection info: {collection_info}")

# Fetch and print the collection statistics. Despite its name, `num_entities`
# holds the whole stats mapping returned by get_collection_stats (the recorded
# output shows a 'row_count' entry), not a bare integer.
num_entities = milvus_client.get_collection_stats(collection_name=collection_name)
print(f"Collection stats: {num_entities}")

Collection info: {'collection_name': 'pmc_trec_2016', 'auto_id': False, 'num_shards': 1, 'description': '', 'fields': [{'field_id': 100, 'name': 'id', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 20}, 'is_primary': True}, {'field_id': 101, 'name': 'vector', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 768}}, {'field_id': 102, 'name': 'doc', 'description': '', 'type': <DataType.JSON: 23>, 'params': {}}], 'functions': [], 'aliases': [], 'collection_id': 461238685143259453, 'consistency_level': 2, 'properties': {}, 'num_partitions': 1, 'enable_dynamic_field': False, 'created_timestamp': 461273312432488451, 'update_timestamp': 461273312432488451}
Collection stats: {'row_count': 1099505}
