In [132]:
# Embedding creation
import torch
from transformers import AutoTokenizer, AutoModel

# DB connection
from google.cloud.sql.connector import Connector
import sqlalchemy
from sqlalchemy import create_engine, text

# Utils
from tqdm import tqdm

# Data processing
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pandas as pd
import os

In [93]:
class DatabaseInterface:
    """Small helper around a SQLAlchemy engine backed by the Cloud SQL Python Connector.

    Every DB-API connection is opened through ``Connector.connect`` with the
    ``pg8000`` driver and handed to a SQLAlchemy engine via ``creator``.
    Call :meth:`close` (or use the instance in a ``with`` block) when finished
    so the engine's pooled connections and the connector are released.
    """

    def __init__(self, instance_connection_name, db_user, db_pass, db_name):
        """Store connection settings, start a connector and build the engine.

        :param instance_connection_name: Cloud SQL instance name, e.g. ``project:region:instance``
        :param db_user: database user name
        :param db_pass: database password
        :param db_name: database to connect to
        """
        self.instance_connection_name = instance_connection_name
        self.db_user = db_user
        self.db_pass = db_pass
        self.db_name = db_name
        self.connector = Connector()
        self.pool = self.create_pool()

    def get_conn(self):
        """Open a new pg8000 connection to the instance (used as the engine's ``creator``)."""
        conn = self.connector.connect(
            self.instance_connection_name,
            "pg8000",
            user=self.db_user,
            password=self.db_pass,
            db=self.db_name
        )
        return conn

    def create_pool(self):
        """Build a SQLAlchemy engine whose connections come from :meth:`get_conn`."""
        return create_engine(
            "postgresql+pg8000://",
            creator=self.get_conn,
        )

    def run_query(self, query, fetch=True, params=None):
        """Execute a raw SQL string on a fresh pooled connection and commit it.

        :param query: SQL text to execute; may contain ``:name`` placeholders
        :param fetch: when True, return all result rows; pass False for
            statements that produce no rows (DDL, plain INSERT/UPDATE, ...)
        :param params: optional mapping of values for the ``:name`` placeholders.
            Prefer this over formatting values into ``query``.
        :return: list of rows when ``fetch`` is True, otherwise None
        """
        statement = text(query)
        with self.pool.connect() as connection:
            if params is None:
                result = connection.execute(statement)
            else:
                result = connection.execute(statement, params)
            # Consume the rows before committing, while the transaction that
            # produced them is still open.
            rows = result.fetchall() if fetch else None
            connection.commit()
        return rows

    def insert_data_from_dataframe(self, dataframe, table_name):
        """Append the rows of ``dataframe`` to ``table_name`` (index not written).

        :param dataframe: pandas DataFrame whose columns match the table
        :param table_name: destination table name
        """
        dataframe.to_sql(
            table_name,
            self.pool,
            if_exists='append',
            index=False,
            method='multi'
        )

    def create_table(self, table_name, columns):
        """
        Creates a table with the given name and columns.

        The table and column names are interpolated directly into the SQL
        (identifiers cannot be bound as parameters), so only pass trusted names.

        :param table_name: The name of the table
        :param columns: A dictionary where keys are column names and values are SQL data types
        """
        cols = ', '.join(f'{col} {dtype}' for col, dtype in columns.items())
        create_table_query = f'CREATE TABLE {table_name} ({cols});'
        self.run_query(create_table_query, fetch=False)

    def drop_table(self, table_name):
        """
        Drops the table with the given name.

        The name is interpolated directly into the SQL, so only pass trusted names.

        :param table_name: The name of the table
        """
        drop_table_query = f'DROP TABLE {table_name};'
        # run_query returns None when fetch=False, so there is nothing to report.
        self.run_query(drop_table_query, fetch=False)

    def close(self):
        """Dispose of the engine's pooled connections and shut down the connector."""
        self.pool.dispose()
        self.connector.close()

    def __enter__(self):
        """Allow ``with DatabaseInterface(...) as db:`` usage."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Release resources on leaving the ``with`` block; exceptions propagate."""
        self.close()
        return False

In [94]:
project_id = "steam-378309"
region = "europe-west3"
instance_name = "legalm"

DB_NAME = "pubmed"
DB_USER = "postgres"
# Never hardcode credentials in a notebook: read the password from the
# environment (e.g. `export DB_PASS=...` before starting Jupyter).
# NOTE(review): a password was previously hardcoded in this cell — treat it as
# exposed and rotate it.
DB_PASS = os.environ.get("DB_PASS", "")
if not DB_PASS:
    print("Warning: DB_PASS is not set; database login will likely fail.")
DB_PORT = "5432"

INSTANCE_CONNECTION_NAME = f"{project_id}:{region}:{instance_name}"
print(f"Your instance connection name is: {INSTANCE_CONNECTION_NAME}")

Your instance connection name is: steam-378309:europe-west3:legalm


In [95]:
# Build the connector + engine once; later cells reuse this handle.
db_interface = DatabaseInterface(
    instance_connection_name=INSTANCE_CONNECTION_NAME,
    db_user=DB_USER,
    db_pass=DB_PASS,
    db_name=DB_NAME,
)

In [96]:
# Lists user tables, i.e. everything outside the pg_catalog and
# information_schema system schemas.
table_query = (
    "\n"
    "        SELECT tablename \n"
    "        FROM pg_catalog.pg_tables \n"
    "        WHERE schemaname != 'pg_catalog' \n"
    "        AND schemaname != 'information_schema';\n"
    "        "
)

In [97]:
# Pulls every row of pm_abstracts; the chunking cell unpacks each row into
# four positional fields, so the column order of the table matters.
abstracts_query = (
    "\n"
    "                  SELECT * FROM pm_abstracts;\n"
    "                  "
)

In [99]:
# Raw DB rows from pm_abstracts (a list of result rows).
all_abstracts = db_interface.run_query(query=abstracts_query, fetch=True)

In [133]:
# Split abstracts on "." first and on newlines second, aiming for pieces of up
# to 500 characters (measured with len) with no overlap between pieces.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=0,
    separators=[".", "\n"],
    length_function=len,
)

In [135]:
# One record per text piece: the abstract's pmid/title plus the piece itself
# (stored under the "abstract" key, which later cells rely on).
chunked = []
for pmid, title, abstract, _ in tqdm(all_abstracts):
    chunked.extend(
        {"pmid": pmid, "title": title, "abstract": doc.page_content}
        for doc in text_splitter.create_documents([abstract])
    )

100%|██████████| 17076/17076 [00:01<00:00, 12339.00it/s]


In [136]:
model_name = "bert-base-uncased"

# Tokenizer and encoder are loaded from the same checkpoint so their
# vocabularies match.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)

In [141]:
# Collect the text of every chunk. This uses its own name instead of reusing
# `all_abstracts`: that name holds the raw DB rows the chunking cell unpacks,
# and rebinding it to a list of strings would break a re-run of that cell.
chunk_texts = [item["abstract"] for item in chunked]

# Tokenize all chunks at once (padded to the longest chunk, truncated to the
# model's maximum length).
all_tokenized_inputs = tokenizer(chunk_texts, return_tensors="pt", padding=True, truncation=True)

# Output rows are collected here by the embedding loop.
data = []
batch_size = 256

In [142]:
# Embed the chunks batch by batch, keeping the hidden state at position 0 of
# each sequence as the chunk embedding.
# `data` is (re)initialised here so re-running this cell, e.g. after an
# interrupt, starts over instead of appending duplicates of earlier rows.
data = []

# The per-batch trim below drops trailing columns, which is only correct when
# the tokenizer pads on the right.
assert tokenizer.padding_side == "right", "padding trim assumes right-side padding"

for start in tqdm(range(0, len(chunked), batch_size)):
    stop = start + batch_size
    batch_records = chunked[start:stop]

    # Slices of the pre-tokenized corpus for this batch.
    batch_input_ids = all_tokenized_inputs["input_ids"][start:stop]
    batch_attention_mask = all_tokenized_inputs["attention_mask"][start:stop]

    # The corpus was padded to its single longest chunk; cut each batch down to
    # its own longest sequence so the model does not process columns that are
    # padding in every row of the batch.
    seq_len = int(batch_attention_mask.sum(dim=1).max())
    batch_input_ids = batch_input_ids[:, :seq_len]
    batch_attention_mask = batch_attention_mask[:, :seq_len]

    with torch.no_grad():
        outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
        batch_embeddings = outputs.last_hidden_state[:, 0, :].numpy()

    for record, embedding in zip(batch_records, batch_embeddings):
        data.append([record["pmid"], record["title"], record["abstract"], embedding])


  0%|          | 1/277 [02:31<11:34:50, 151.05s/it]


KeyboardInterrupt: 

In [119]:
# Create a pandas DataFrame from the list
# One row per text chunk: source identifiers, the chunk text and its
# position-0 hidden-state vector.
df = pd.DataFrame(
    data,
    columns=["pmid", "title", "abstract", "embedding"],
)

In [124]:
# Sanity-check the assembled frame before it is used downstream (replaces a
# leftover scratch cell). These fail loudly if the embedding loop was
# interrupted or `df` is stale relative to `chunked`.
assert list(df.columns) == ["pmid", "title", "abstract", "embedding"]
assert len(df) == len(chunked), "df is incomplete: embedding loop did not cover every chunk"
assert df["embedding"].map(len).nunique() == 1, "embeddings have inconsistent lengths"
df.head()