In [1]:
!pip install nltk



In [2]:
from IPython.display import display
import pandas as pd
import os
import ast
import csv
from nltk.tokenize import sent_tokenize
import nltk
nltk.download("punkt_tab")

[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/jcanodeb/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [3]:
files = ['train','dev','test']

for file in files:
    labels = pd.read_csv(f'Datasets/chemprot/Original/{file}.csv', sep=',')
    subset = labels[['id', 'passages', 'entities', 'relations']]
    display(subset.head())

Unnamed: 0,id,passages,entities,relations
0,0,"[{'id': '1', 'type': 'title and abstract', 'te...","[{'id': '2', 'type': 'CHEMICAL', 'text': ['met...",[]
1,7,"[{'id': '8', 'type': 'title and abstract', 'te...","[{'id': '9', 'type': 'CHEMICAL', 'text': ['gef...","[{'id': '38', 'type': 'Downregulator', 'arg1_i..."
2,55,"[{'id': '56', 'type': 'title and abstract', 't...","[{'id': '57', 'type': 'CHEMICAL', 'text': ['ch...",[]
3,71,"[{'id': '72', 'type': 'title and abstract', 't...","[{'id': '73', 'type': 'CHEMICAL', 'text': ['ir...",[]
4,106,"[{'id': '107', 'type': 'title and abstract', '...","[{'id': '108', 'type': 'CHEMICAL', 'text': ['B...","[{'id': '127', 'type': 'Downregulator', 'arg1_..."


Unnamed: 0,id,passages,entities,relations
0,0,"[{'id': '1', 'type': 'title and abstract', 'te...","[{'id': '2', 'type': 'CHEMICAL', 'text': ['DF'...","[{'id': '58', 'type': 'Regulator', 'arg1_id': ..."
1,68,"[{'id': '69', 'type': 'title and abstract', 't...","[{'id': '70', 'type': 'CHEMICAL', 'text': ['lo...","[{'id': '104', 'type': 'Downregulator', 'arg1_..."
2,109,"[{'id': '110', 'type': 'title and abstract', '...","[{'id': '111', 'type': 'CHEMICAL', 'text': ['n...","[{'id': '146', 'type': 'Regulator', 'arg1_id':..."
3,175,"[{'id': '176', 'type': 'title and abstract', '...","[{'id': '177', 'type': 'CHEMICAL', 'text': ['s...","[{'id': '207', 'type': 'Regulator', 'arg1_id':..."
4,219,"[{'id': '220', 'type': 'title and abstract', '...","[{'id': '221', 'type': 'CHEMICAL', 'text': ['k...","[{'id': '252', 'type': 'Not', 'arg1_id': '246'..."


Unnamed: 0,id,passages,entities,relations
0,0,"[{'id': '1', 'type': 'title and abstract', 'te...","[{'id': '2', 'type': 'CHEMICAL', 'text': ['alp...","[{'id': '65', 'type': 'Regulator', 'arg1_id': ..."
1,86,"[{'id': '87', 'type': 'title and abstract', 't...","[{'id': '88', 'type': 'CHEMICAL', 'text': ['ed...","[{'id': '94', 'type': 'Downregulator', 'arg1_i..."
2,95,"[{'id': '96', 'type': 'title and abstract', 't...","[{'id': '97', 'type': 'GENE-Y', 'text': ['thro...",[]
3,106,"[{'id': '107', 'type': 'title and abstract', '...","[{'id': '108', 'type': 'CHEMICAL', 'text': ['p...","[{'id': '126', 'type': 'Downregulator', 'arg1_..."
4,132,"[{'id': '133', 'type': 'title and abstract', '...","[{'id': '134', 'type': 'CHEMICAL', 'text': ['L...","[{'id': '170', 'type': 'Regulator', 'arg1_id':..."


In [4]:
def mark_entities_in_sentence(sentence, ent1_text, ent2_text):
    if ent1_text == ent2_text:
        first = sentence.find(ent1_text)
        second = sentence.find(ent2_text, first + 1)
        if first != -1 and second != -1:
            return (
                sentence[:first] + "<e1>" + ent1_text + "</e1>" +
                sentence[first + len(ent1_text):second] + "<e2>" + ent2_text + "</e2>" +
                sentence[second + len(ent2_text):]
            )
    else:
        sentence = sentence.replace(ent1_text, f"<e1>{ent1_text}</e1>", 1)
        sentence = sentence.replace(ent2_text, f"<e2>{ent2_text}</e2>", 1)
    return sentence

In [5]:
def extract_relation_samples_from_row(row):
    results = []
    try:
        passages = ast.literal_eval(row["passages"])
        entities = {e["id"]: {"text": " ".join(e["text"]), "type": e["type"]} for e in ast.literal_eval(row["entities"])}
        relations = ast.literal_eval(row["relations"])
    except Exception as e:
        print(f"Error parsing row {row.get('id', '[unknown]')}: {e}")
        return results

    for rel in relations:
        ent1_id = rel["arg1_id"]
        ent2_id = rel["arg2_id"]
        relation = rel["type"]

        if ent1_id in entities and ent2_id in entities:
            ent1_text = entities[ent1_id]["text"]
            ent2_text = entities[ent2_id]["text"]

            for passage in passages:
                text_content = passage.get("text", "")
                if isinstance(text_content, list):
                    passage_text = " ".join(text_content)
                else:
                    passage_text = text_content
                for sentence in sent_tokenize(passage_text):
                    if ent1_text in sentence and ent2_text in sentence:
                        marked = mark_entities_in_sentence(sentence, ent1_text, ent2_text)
                        results.append([marked, relation])
                        break
    return results

In [6]:
def convert_csv_to_relation_format(input_path, output_path):
    df = pd.read_csv(input_path)
    all_rows = []

    for _, row in df.iterrows():
        samples = extract_relation_samples_from_row(row)
        all_rows.extend(samples)

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_ALL)
        writer.writerow(["text", "relation"])
        writer.writerows(all_rows)

    print(f"File: {output_path} - {len(all_rows)} rows")

In [7]:
if __name__ == "__main__":
    base_input = "Datasets/chemprot/Original"
    base_output = "Datasets/chemprot/Processed"

    files = {
        "train.csv": "chemprot_train.csv",
        "dev.csv": "chemprot_dev.csv",
        "test.csv": "chemprot_test.csv"
    }

    for in_file, out_file in files.items():
        input_path = os.path.join(base_input, in_file)
        output_path = os.path.join(base_output, out_file)
        convert_csv_to_relation_format(input_path, output_path)

File: Datasets/chemprot/Processed/chemprot_train.csv - 6417 rows
File: Datasets/chemprot/Processed/chemprot_dev.csv - 3550 rows
File: Datasets/chemprot/Processed/chemprot_test.csv - 5721 rows
