In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Data Loading and Initial Preparation

In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import re
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification
import torch
from nltk.corpus import stopwords
import nltk

# Download NLTK stopwords
nltk.download('stopwords')

# Data Loading
def read_csv_in_chunks(file_path, chunk_size=100000):
    chunks = pd.read_csv(file_path, chunksize=chunk_size)
    return pd.concat(chunks, ignore_index=True)

# Load the datasets
df = pd.read_csv('/content/drive/MyDrive/cleaned_data/sampled_pubmed_data.csv')
print(f"Initial shape of the dataset: {df.shape}")

# Load your fine-tuned models and tokenizer
tokenizer = AutoTokenizer.from_pretrained("/content/drive/MyDrive/tokenizer")
ner_model = AutoModelForTokenClassification.from_pretrained("/content/drive/MyDrive/trained_ner_model")
relation_model = AutoModelForSequenceClassification.from_pretrained("/content/drive/MyDrive/trained_relation_model")

# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ner_model.to(device)
relation_model.to(device)

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


Initial shape of the dataset: (6918932, 5)


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [None]:
# Ensure 'Year' column exists and is in the correct format
if 'Year' not in df.columns:
    df['Year'] = pd.to_datetime(df['Date']).dt.year

# Stratified sampling
def stratified_sample(df, frac=0.1, random_state=42):
    return df.groupby('Year', group_keys=False).apply(lambda x: x.sample(frac=frac, random_state=random_state))

df_subset = stratified_sample(df, frac=0.1)
df_subset.to_csv('/content/drive/MyDrive/cleaned_data/sample_data.csv')
print(f"Shape of 10% subset: {df_subset.shape}")

# Print distribution of years in the subset
print(df_subset['Year'].value_counts().sort_index())

Shape of 10% subset: (691892, 7)
Year
2019    112128
2020    138753
2021    147977
2022    144040
2023    148994
Name: count, dtype: int64


## Text Cleaning and Preprocessing

In [None]:
import spacy
nlp = spacy.load("en_core_sci_sm")
print("Model loaded successfully!")

Model loaded successfully!


# Data Cleaning and NER

In [None]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from tqdm import tqdm
import re
import spacy
from scispacy.abbreviation import AbbreviationDetector
import os


# Load spaCy model for abbreviation expansion
nlp = spacy.load("en_core_sci_sm")
if "abbreviation_detector" not in nlp.pipe_names:
    nlp.add_pipe("abbreviation_detector")

# Set up device and model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("/content/drive/MyDrive/tokenizer")
ner_model = AutoModelForTokenClassification.from_pretrained("/content/drive/MyDrive/trained_ner_model").to(device)

# Ensure the model is in evaluation mode
ner_model.eval()

def preprocess_text(text):
    if pd.isna(text):
        return ""

    # Basic cleaning
    text = re.sub(r'[\[\]\?\.]', '', text)  # Remove specific punctuation
    text = re.sub(r'<[^>]+>', '', text)  # Remove HTML tags
    text = re.sub(r'\s+', ' ', text).strip()  # Standardize whitespace

    # Abbreviation expansion
    doc = nlp(text)
    expanded_text = []
    for token in doc:
        if token._.is_abbreviation:
            expanded_text.append(token._.long_form.text)
        else:
            expanded_text.append(token.text)

    return " ".join(expanded_text)

def clean_and_process_text(text):
    if pd.isna(text):
        return []

    # Clean and preprocess text
    text = preprocess_text(text)

    # Check if text is less than 2 words
    if len(text.split()) < 2:
        return []

    # Tokenize and get NER predictions
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = ner_model(**inputs)

    predictions = torch.argmax(outputs.logits, dim=2)[0].cpu().numpy()

    # Decode tokens and labels
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    id2label = ner_model.config.id2label

    # Reconstruct words and assign labels
    words = []
    labels = []
    current_word = ""
    current_label = "O"

    for token, pred in zip(tokens, predictions):
        if token in ['[CLS]', '[SEP]', '[PAD]']:
            if current_word:
                words.append(current_word)
                labels.append(current_label)
            current_word = ""
            current_label = "O"
            continue
        if token.startswith('##'):
            current_word += token[2:]
        else:
            if current_word:
                words.append(current_word)
                labels.append(current_label)
            current_word = token
            current_label = id2label[pred]

    if current_word:
        words.append(current_word)
        labels.append(current_label)

    return list(zip(words, labels))

def process_chunk(chunk, column_name):
    try:
        chunk[f'processed_{column_name}'] = chunk[column_name].apply(clean_and_process_text)
        # Remove rows where processed text is empty (less than 2 words or None)
        chunk = chunk[chunk[f'processed_{column_name}'].apply(len) > 0]
        return chunk
    except Exception as e:
        print(f"Error processing chunk: {str(e)}")
        return None

def process_column(df, column_name, chunk_size, temp_folder):
    os.makedirs(temp_folder, exist_ok=True)

    # Find the last processed chunk
    processed_chunks = [int(f.split('_')[-1].split('.')[0]) for f in os.listdir(temp_folder) if f.startswith('processed_chunk_')]
    start_chunk = max(processed_chunks) + chunk_size if processed_chunks else 0

    for i in tqdm(range(start_chunk, len(df), chunk_size), desc=f"Processing {column_name}"):
        chunk = df.iloc[i:i+chunk_size]
        chunk_file = f'{temp_folder}/processed_chunk_{i}.csv'

        processed_chunk = process_chunk(chunk, column_name)

        if processed_chunk is not None and not processed_chunk.empty:
            try:
                processed_chunk.to_csv(chunk_file, index=False)
                print(f"Processed and saved chunk {i} to {i+chunk_size}")
            except Exception as e:
                print(f"Error saving chunk {i} to {i+chunk_size}: {str(e)}")

def combine_chunks(temp_folder):
    processed_chunks = []
    for filename in os.listdir(temp_folder):
        if filename.startswith('processed_chunk_'):
            try:
                chunk = pd.read_csv(f'{temp_folder}/{filename}')
                processed_chunks.append(chunk)
            except Exception as e:
                print(f"Error reading {filename}: {str(e)}")

    if processed_chunks:
        final_df = pd.concat(processed_chunks, ignore_index=True)
        return final_df
    else:
        print(f"No processed chunks found in {temp_folder}")
        return None

In [None]:
# Main processing
if __name__ == '__main__':
    # Load your data
    df = pd.read_csv('/content/drive/MyDrive/cleaned_data/sample_data.csv')

    temp_folder_titles = '/content/drive/MyDrive/temp_chunks_titles'
    temp_folder_abstracts = '/content/drive/MyDrive/temp_chunks_abstracts'

    # Process titles (will skip if already completed)
    process_column(df, 'Title', 5000, temp_folder_titles)

    # Process abstracts (will resume from where it left off)
    process_column(df, 'Abstract', 3000, temp_folder_abstracts)

    # Combine processed titles
    processed_titles = combine_chunks(temp_folder_titles)

    # Combine processed abstracts
    processed_abstracts = combine_chunks(temp_folder_abstracts)

    # Combine titles and abstracts
    if processed_titles is not None and processed_abstracts is not None:
        combined_df = pd.merge(processed_titles, processed_abstracts, on=df.columns.drop(['Title', 'Abstract']).tolist())

        # Save the combined processed data
        combined_df.to_csv('/content/drive/MyDrive/cleaned_data/final_processed_data.csv', index=False)
        print("Processing completed. Combined data saved to /content/drive/MyDrive/cleaned_data/final_processed_data.csv")
    else:
        print("Error: Could not combine processed titles and abstracts.")

In [None]:
# Delete the columns that are created during data cleaning and NER
import pandas as pd

# Read the CSV file
df_final_processed_data = pd.read_csv('/content/drive/MyDrive/final_processed_data.csv')

# Drop the specified columns
columns_to_drop = ['Unnamed: 0', 'Unnamed: 0.1', 'Title_y', 'Abstract_y']
df_final_processed_data = df_final_processed_data.drop(columns=columns_to_drop, errors='ignore')

# Rename columns
df_final_processed_data = df_final_processed_data.rename(columns={'Title_x': 'Title', 'Abstract_x': 'Abstract'})

# Save the modified DataFrame back to a CSV file
df_final_processed_data.to_csv('/content/drive/MyDrive/processed_pubmed_data_with_entities.csv', index=False)

print("Duplicated and index columns have been removed, and 'Title_x' and 'Abstract_x' have been renamed.")
print("The cleaned data has been saved to 'processed_pubmed_data_with_entities.csv'.")

In [None]:
import pandas as pd
df_final_processed_data = pd.read_csv('/content/drive/MyDrive/processed_pubmed_data_with_entities.csv')
df_final_processed_data.info()
# Print distribution of years after data cleaning and preprocessing
print(df_final_processed_data['Year'].value_counts().sort_index())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 585249 entries, 0 to 585248
Data columns (total 8 columns):
 #   Column              Non-Null Count   Dtype 
---  ------              --------------   ----- 
 0   Date                585249 non-null  object
 1   Title               585249 non-null  object
 2   Abstract            585249 non-null  object
 3   MeshHeading         392552 non-null  object
 4   Keywords            434748 non-null  object
 5   Year                585249 non-null  int64 
 6   processed_Title     585249 non-null  object
 7   processed_Abstract  585249 non-null  object
dtypes: int64(1), object(7)
memory usage: 35.7+ MB
Year
2019     94028
2020    115855
2021    124417
2022    123783
2023    127166
Name: count, dtype: int64


## Further Data Preprocessing - Match Label with Entity Name

In [None]:
import re
import pandas as pd
from typing import List, Tuple, Union
import ast

def enhanced_process_ner_results(ner_results: Union[str, List[Tuple[str, str]]]) -> List[Tuple[str, str]]:
    entities = []
    current_entity = []
    current_label = None
    label_map = {
        'LABEL_1': 'CHEMICAL', 'LABEL_2': 'CHEMICAL',
        'LABEL_3': 'GENE-Y', 'LABEL_4': 'GENE-Y',
        'LABEL_5': 'GENE-N', 'LABEL_6': 'GENE-N'
    }

    def get_entity_type(label: str) -> str:
        return label_map.get(label, 'UNKNOWN')

    def clean_entity_text(text: str) -> str:
        text = re.sub(r'\s*-\s*', '-', text)
        text = re.sub(r'\s+([,.;:!?)])', r'\1', text)
        text = re.sub(r'(\()\s+', r'\1', text)
        return text.strip()

    def finalize_entity():
        if current_entity:
            entity_text = ' '.join(token for token, _ in current_entity)
            entity_text = clean_entity_text(entity_text)
            entity_label = get_entity_type(current_entity[0][1])
            entities.append((entity_text, entity_label))

    # Convert string representation of list to actual list if necessary
    if isinstance(ner_results, str):
        try:
            ner_results = ast.literal_eval(ner_results)
        except:
            print(f"Error parsing: {ner_results[:100]}...")  # Print first 100 chars for debugging
            return []

    # Ensure ner_results is a list of tuples
    if not isinstance(ner_results, list) or not all(isinstance(item, tuple) and len(item) == 2 for item in ner_results):
        print(f"Invalid format: {ner_results[:5]}...")  # Print first 5 items for debugging
        return []

    for token, label in ner_results:
        if label == 'LABEL_0':
            finalize_entity()
            current_entity = []
            current_label = None
        elif label in label_map:
            if not current_entity or get_entity_type(label) != current_label:
                finalize_entity()
                current_entity = [(token, label)]
                current_label = get_entity_type(label)
            else:
                current_entity.append((token, label))

    finalize_entity()
    return entities

def process_dataframe_column(df: pd.DataFrame, column_name: str) -> pd.Series:
    return df[f'processed_{column_name}'].apply(enhanced_process_ner_results)

def process_dataset(df: pd.DataFrame) -> pd.DataFrame:
    for column in ['title', 'abstract']:
        if f'processed_{column}' in df.columns:
            df[f'{column}_entities'] = process_dataframe_column(df, column)
    return df

In [None]:
df_final_processed_data = process_dataset(df_final_processed_data)
# Save the modified DataFrame back to a CSV file
df_final_processed_data.to_csv('/content/drive/MyDrive/cleaned_data/cleaned_final_processed_data.csv', index=False)

Processing Title...
Error parsing: [('Epigallocatechin', 'LABEL_1'), ('gallate', 'LABEL_1'), ('diminishes', 'LABEL_0'), ('cigarette', '...
Finished processing Title.
Processing Abstract...
Error parsing: [('Entamoeba', 'LABEL_0'), ('histolytica', 'LABEL_0'), ('intestinal', 'LABEL_0'), ('parasite', 'LABE...
Error parsing: [('Kidney', 'LABEL_0'), ('care', 'LABEL_0'), ('United', 'LABEL_0'), ('States', 'LABEL_0'), ('highly'...
Finished processing Abstract.


In [4]:
import pandas as pd
df_final_processed_data = pd.read_csv('/content/drive/MyDrive/cleaned_data/cleaned_final_processed_data.csv')

In [5]:
df_final_processed_data.head(5)

Unnamed: 0,article_id,Date,Title,Abstract,MeshHeading,Keywords,Year,processed_Title,processed_Abstract,title_entities,abstract_entities
0,0,2019-03-01,Cellular Proteostasis During Influenza A Virus...,"In order to efficiently replicate, viruses req...","Humans; Influenza A virus; Influenza, Human; M...",influenza A virus (IAV); protein aggregation; ...,2019,"[('Cellular', 'LABEL_0'), ('Proteostasis', 'LA...","[('In', 'LABEL_0'), ('order', 'LABEL_0'), ('to...",[],[]
1,1,2019-07-01,Next-generation sequencing with comprehensive ...,Familial adenomatous polyposis (FAP) is an aut...,Adenomatous Polyposis Coli; Adenomatous Polypo...,APC; Colorectal cancer; Familial adenomatous p...,2019,"[('Next', 'LABEL_0'), ('-', 'LABEL_0'), ('gene...","[('Familial', 'LABEL_0'), ('adenomatous', 'LAB...","[('APC', 'GENE-Y')]","[('APC', 'GENE-Y'), ('APC', 'GENE-Y')]"
2,2,2019-12-01,[THE HEALTH SAVING TECHNOLOGIES AT A PEDAGOGIC...,The article investigates health saving technol...,Biomedical Technology; Educational Technology;...,,2019,"[('THE', 'LABEL_0'), ('HEALTH', 'LABEL_0'), ('...","[('The', 'LABEL_0'), ('article', 'LABEL_0'), (...",[],[]
3,3,2019-11-01,"An overview of colistin resistance, mobilized ...","Colistin, also known as polymyxin E, is an ant...",,Enterobacteriaceae; animals; colistin alternat...,2019,"[('An', 'LABEL_0'), ('overview', 'LABEL_0'), (...","[('Colistin', 'LABEL_1'), (',', 'LABEL_0'), ('...","[('colistin', 'CHEMICAL'), ('colistin', 'CHEMI...","[('Colistin', 'CHEMICAL'), ('polymyxin E', 'CH..."
4,4,2019-10-01,Decoupling Filamentous Phage Uptake and Energy...,Filamentous phages are nonlytic viruses that s...,Bacterial Proteins; Bacteriophages; Escherichi...,Tol-Pal system; Ton system; bacteriophage; bac...,2019,"[('Decoupling', 'LABEL_0'), ('Filamentous', 'L...","[('Filamentous', 'LABEL_0'), ('phages', 'LABEL...","[('TolQRA', 'GENE-N')]","[('Tol', 'GENE-N'), ('Tol system', 'GENE-N'), ..."


# Relationship Extraction

In [None]:
import pandas as pd
import numpy as np

df_final_processed_data = pd.read_csv('/content/drive/MyDrive/cleaned_data/processed_pubmed_data_with_entities.csv')

def index_dataset(df):
    # Create a unique identifier for each row
    df['article_id'] = np.arange(len(df))

    # Set this identifier as the index
    df.set_index('article_id', inplace=True)

    return df

# Apply indexing to your dataset
df_final_processed_data = index_dataset(df_final_processed_data)

In [None]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm, tqdm as tqdm_notebook
import torch
import os
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
import ast

# Define relationship types
RELATIONSHIP_TYPES = [
    'INHIBITOR', 'DIRECT-REGULATOR', 'SUBSTRATE', 'ACTIVATOR', 'INDIRECT-UPREGULATOR',
    'INDIRECT-DOWNREGULATOR', 'ANTAGONIST', 'PRODUCT-OF', 'PART-OF', 'AGONIST',
    'AGONIST-ACTIVATOR', 'SUBSTRATE_PRODUCT-OF', 'AGONIST-INHIBITOR'
]

# Load the tokenizer and model
tokenizer_path = '/content/drive/MyDrive/tokenizer'
model_path = '/content/drive/MyDrive/trained_relation_model'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=len(RELATIONSHIP_TYPES))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

class DrugTargetDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item['text']
        entity1 = item['entity1']
        entity2 = item['entity2']

        combined_text = f"{text} [SEP] {entity1} [SEP] {entity2}"
        encoding = self.tokenizer.encode_plus(
            combined_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'entity1': entity1,
            'entity2': entity2,
            'entity1_type': item['entity1_type'],
            'entity2_type': item['entity2_type'],
            'article_id': item['article_id']
        }

# Data cleaning and preprocessing
def clean_relationships(df):
    print(f"Initial shape: {df.shape}")

    # Ensure article_id is string
    df['article_id'] = df['article_id'].astype(str)
    print(f"After ensuring article_id is string shape: {df.shape}")

    # Standardize entity names
    df['entity1'] = df.apply(lambda row: standardize_entity(row['entity1'], row['entity1_type']), axis=1)
    df['entity2'] = df.apply(lambda row: standardize_entity(row['entity2'], row['entity2_type']), axis=1)
    print(f"After standardizing entities shape: {df.shape}")

    # Remove rows with blank entities
    df = df.dropna(subset=['entity1', 'entity2'])
    print(f"After removing blank entities shape: {df.shape}")

    # Create a directional pair (maintain original order)
    df['entity_pair'] = df.apply(lambda row: (row['entity1'], row['entity2']), axis=1)

    # Group by article_id and sort within each group by confidence
    df_sorted = df.sort_values(['article_id', 'confidence'], ascending=[True, False])

    # Drop duplicates within each article_id group, preserving directionality
    df_cleaned = df_sorted.drop_duplicates(subset=['article_id', 'entity_pair', 'relationship'], keep='first')
    print(f"After removing duplicates shape: {df_cleaned.shape}")

    # Drop temporary columns
    df_cleaned = df_cleaned.drop(columns=['entity_pair'])

    return df_cleaned

def standardize_entity(entity, entity_type):
    if pd.isna(entity) or entity.strip() == '':
        return None

    entity = str(entity).strip()

    if entity_type == 'CHEMICAL':
        return entity.lower()  # Changed to lowercase for CHEMICAL entities
    else:
        return entity  # Other entity types remain unchanged

def safe_eval(entity_string):
    try:
        return ast.literal_eval(entity_string)
    except (ValueError, SyntaxError):
        print(f"Error parsing entity string: {entity_string}")
        return []

def process_chunk(chunk, temp_folder, chunk_id):
    data = []
    for _, row in chunk.iterrows():
        article_id = row.name  # This gets the index, which is your article_id
        text = f"{row['Title']} {row['Abstract']}"
        entities = safe_eval(row['title_entities']) + safe_eval(row['abstract_entities'])

        for i, e1 in enumerate(entities):
            for e2 in entities[i+1:]:
                if e1[1] in ['CHEMICAL', 'GENE-Y', 'GENE-N'] and e2[1] in ['CHEMICAL', 'GENE-Y', 'GENE-N']:
                    data.append({
                        'article_id': str(article_id),  # Convert to string to ensure consistency
                        'text': text,
                        'entity1': e1[0],
                        'entity1_type': e1[1],
                        'entity2': e2[0],
                        'entity2_type': e2[1]
                    })

    print(f"Chunk {chunk_id}: Found {len(data)} valid entity pairs to process")

    dataset = DrugTargetDataset(data, tokenizer, max_length=256)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

    predictions = extract_relationships(model, dataloader, device)

    cleaned_predictions = clean_relationships(predictions)

    chunk_file = f'{temp_folder}/processed_chunk_{chunk_id}.csv'
    cleaned_predictions.to_csv(chunk_file, index=False)
    print(f"Processed and saved chunk {chunk_id} with {len(cleaned_predictions)} relationships")
    return len(cleaned_predictions)

def extract_relationships(model, dataloader, device):
    model.eval()
    all_predictions = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting relationships", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=1)
            predicted_relationships = torch.argmax(probabilities, dim=1)

            for i, pred in enumerate(predicted_relationships):
                relationship = RELATIONSHIP_TYPES[pred.item()]
                all_predictions.append({
                    'article_id': batch['article_id'][i],
                    'entity1': batch['entity1'][i],
                    'entity1_type': batch['entity1_type'][i],
                    'entity2': batch['entity2'][i],
                    'entity2_type': batch['entity2_type'][i],
                    'relationship': relationship,
                    'confidence': probabilities[i][pred].item()
                })

    return pd.DataFrame(all_predictions)

# Modify the combine_chunks function to handle potential errors
def combine_chunks(temp_folder):
    all_predictions = []
    for filename in os.listdir(temp_folder):
        if filename.startswith('processed_chunk_'):
            try:
                chunk_df = pd.read_csv(f'{temp_folder}/{filename}')
                all_predictions.append(chunk_df)
            except Exception as e:
                print(f"Error reading file {filename}: {str(e)}")
                continue

    if not all_predictions:
        raise ValueError("No valid chunks found. Please check your data and processing steps.")

    combined_df = pd.concat(all_predictions, ignore_index=True)
    combined_df.to_csv('/content/drive/MyDrive/cleaned_data/concat_data.csv')
    final_cleaned_df = clean_relationships(combined_df)

    print(f"Combined and cleaned {len(all_predictions)} chunks. Total relationships: {len(final_cleaned_df)}")
    return final_cleaned_df

def process_full_dataset(df, chunk_size, temp_folder):
    os.makedirs(temp_folder, exist_ok=True)
    progress_file = os.path.join(temp_folder, 'progress.json')

    # Find the highest numbered chunk file
    existing_chunks = [int(f.split('_')[-1].split('.')[0]) for f in os.listdir(temp_folder) if f.startswith('processed_chunk_')]

    if existing_chunks:
        start_chunk = max(existing_chunks) + 1
        print(f"Found {len(existing_chunks)} existing processed chunks.")
    else:
        start_chunk = 0

    total_relationships = 0
    if os.path.exists(progress_file):
        try:
            with open(progress_file, 'r') as f:
                file_content = f.read().strip()
                if file_content:
                    progress = json.loads(file_content)
                    total_relationships = progress.get('total_relationships', 0)
                else:
                    print("Progress file is empty. Starting with 0 total relationships.")
        except json.JSONDecodeError:
            print("Error reading progress file. Starting with 0 total relationships.")
        except Exception as e:
            print(f"Unexpected error reading progress file: {str(e)}. Starting with 0 total relationships.")

    print(f"Starting processing from chunk {start_chunk}")
    print(f"Total relationships so far: {total_relationships}")

    total_chunks = (len(df) + chunk_size - 1) // chunk_size
    remaining_chunks = max(0, total_chunks - start_chunk)

    if remaining_chunks > 0:
        pbar = tqdm_notebook(total=remaining_chunks, desc="Processing chunks")
        for chunk_id in range(start_chunk, total_chunks):
            try:
                start_row = chunk_id * chunk_size
                end_row = min((chunk_id + 1) * chunk_size, len(df))
                chunk = df.iloc[start_row:end_row]

                chunk_relationships = process_chunk(chunk, temp_folder, chunk_id)
                total_relationships += chunk_relationships

                # Update progress
                pbar.update(1)
                pbar.set_postfix({"Total Relationships": total_relationships})

                # Save progress after each chunk
                progress = {
                    'last_processed_chunk': chunk_id,
                    'total_relationships': total_relationships
                }
                with open(progress_file, 'w') as f:
                    json.dump(progress, f)

            except Exception as e:
                print(f"\nAn error occurred while processing chunk {chunk_id}: {str(e)}")
                print("Continuing with the next chunk...")

        pbar.close()
    else:
        print("All chunks have been processed. Moving to the next step.")

    print("\nFinished processing all chunks.")
    return total_relationships

# Main execution
chunk_size = 100  # Adjust based on your memory constraints
temp_folder = '/content/drive/MyDrive/relationship_extraction_temp'

print(f"Starting relationship extraction on dataset with {len(df_final_processed_data)} rows")
total_relationships = process_full_dataset(df_final_processed_data, chunk_size, temp_folder)

print("\nCombining all processed chunks...")
final_results = combine_chunks(temp_folder)

# Save the final results
output_path = '/content/drive/MyDrive/cleaned_data/final_relationship_results.csv'
final_results.to_csv(output_path, index=False)
print(f"Relationship extraction completed. Results saved to {output_path}")
print(f"Total relationships extracted: {len(final_results)}")

# Analysis of cleaning impact
print("\nAnalysis of cleaning impact:")
print(f"Number of unique articles: {final_results['article_id'].nunique()}")
print("\nTop 10 most frequent relationships:")
print(final_results['relationship'].value_counts().head(10))

## Further Data Cleaning on Relationship Extraction Results

In [None]:
import pandas as pd
import re

# Define a set of common stopwords
stopwords = set(['and', 'or', 'the', 'a', 'an', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'])

def is_valid_entity(entity):
    # Remove leading/trailing whitespace
    entity = entity.strip().lower()

    # Check if entity is empty, consists only of punctuation, or is a stopword
    if not entity or re.match(r'^[-(),/.>{}\']+$', entity) or entity in stopwords:
        print(f"Removed invalid entity: '{entity}'")
        return False

    # Check if entity consists only of numbers
    if re.match(r'^\d+$', entity):
        print(f"Removed number-only entity: '{entity}'")
        return False

    # Punctuations that shouldn't appear at the start
    invalid_start = r'^[-\]\)-,/.>\'%&]'

    # Punctuations that shouldn't appear at the end
    invalid_end = r'[-,(/.>\'{$%]$'

    # Check if entity starts with invalid punctuation
    if re.match(invalid_start, entity):
        print(f"Removed entity starting with invalid punctuation: '{entity}'")
        return False

    # Check if entity ends with invalid punctuation
    if re.search(invalid_end, entity):
        print(f"Removed entity ending with invalid punctuation: '{entity}'")
        return False

    return True

# Load the data
final_results = pd.read_csv('/content/drive/MyDrive/cleaned_data/final_relationship_results.csv')

# Filter rows based on entity validity
cleaned_final_results = final_results[final_results['entity1'].apply(is_valid_entity) & final_results['entity2'].apply(is_valid_entity)]

# Save the filtered data
cleaned_final_results.to_csv('/content/drive/MyDrive/cleaned_data/pubmed_data_with_relationships.csv', index=False)

# Print some examples of removed entities
print("\nExamples of removed entities:")
for idx, row in final_results.iterrows():
    if not is_valid_entity(row['entity1']):
        print(f"Removed entity1: '{row['entity1']}'")
    if not is_valid_entity(row['entity2']):
        print(f"Removed entity2: '{row['entity2']}'")
    if idx > 20:  # Limit to first 20 rows for brevity
        break

print(f"\nData filtering completed. {len(final_results) - len(cleaned_final_results)} rows removed.")
print(f"Filtered data saved to 'pubmed_data_with_relationships.csv' with {len(cleaned_final_results)} rows.")

# Check for any remaining entities starting with punctuation
remaining_invalid_start = cleaned_final_results[cleaned_final_results['entity1'].str.match(r'^[-\]\)-,/.>\'%&]') |
                                                cleaned_final_results['entity2'].str.match(r'^[-\]\)-,/.>\'%&]')]
if not remaining_invalid_start.empty:
    print("\nWarning: Some entities still start with invalid punctuation:")
    print(remaining_invalid_start[['entity1', 'entity2']].head())

In [None]:
cleaned_final_results.head(5)

Unnamed: 0,article_id,entity1,entity1_type,entity2,entity2_type,relationship,confidence
0,1,APC,GENE-Y,APC,GENE-Y,AGONIST,0.592381
1,10,ion channel,GENE-N,G protein-coupled receptors,GENE-N,AGONIST,0.567602
2,10,ion channel,GENE-N,GPCR,GENE-N,AGONIST,0.567602
3,10,ion channel,GENE-N,nuclear receptor,GENE-N,AGONIST,0.567602
4,10,G protein-coupled receptors,GENE-N,GPCR,GENE-N,AGONIST,0.567602


*Next Step: Novel Targets Identification and Polypharmacology Analysis*