# Load n2c2 2018 track 2 data and convert to custom JSON format

In [13]:
# Load libraries
import ast
import json
import os
import pickle
from collections import Counter
from collections import OrderedDict
from collections import deque

import dspy
import tqdm
from datasets import load_dataset

In [7]:
# Load the n2c2 data
# The n2c2 2018 track 2 files have to be downloaded to a local folder first;
# set data_dir to that folder before running this cell.

data_dir = "/prj/doctoral_letters/notebooks/MIEQA/i2b22018/n2c2_2018_track2/"

# "n2c2_2018_track2_source" is the dataset configuration requested from the
# bigbio loader; data_dir tells the loader where the local files live.
dataset = load_dataset(
    "bigbio/n2c2_2018_track2",
    name="n2c2_2018_track2_source",
    data_dir=data_dir,
)

## Process the data

Input: n2c2 letter

Output: all letters merged into a single dataset. Each letter is split into lines at newline characters ("\n"). Each line is paired with the drugs it mentions and their related information, stored in an OrderedDict. Lines are merged when information related to a drug is located on another (typically neighbouring) line.

In [8]:
# Helper functions

def _relation_key(relation_name):
    """Map a relation label to its output key, e.g. 'Strength-Drug' -> 'strength'."""
    return relation_name.replace('-Drug', '').lower()


def _paragraph_spans(text, tags):
    """Split ``text`` into line spans without cutting through any tag.

    Returns a list of ``(start, end)`` character offsets, one per line. A
    newline located inside a tag's ``[start, end)`` range is not used as a
    boundary, so a tag that spans several lines stays in one paragraph.
    """
    tag_ranges = [(tag['start'], tag['end']) for tag in tags]
    safe_newlines = [
        pos for pos, char in enumerate(text)
        if char == '\n'
        and not any(start <= pos < end for start, end in tag_ranges)
    ]
    starts = [0] + [pos + 1 for pos in safe_newlines]
    ends = safe_newlines + [len(text)]
    return list(zip(starts, ends))


def _new_drug_entry(tag, tag_id, relation_keys):
    """Create the ``(start, tag_id, data)`` record that collects one drug."""
    return (tag['start'], tag_id, {
        'text': tag['text'].strip(),
        'relations': {key: [] for key in relation_keys},
    })


# Main converter function
def process_sample(sample, relation_names=None, tag_names=None):
    """Convert one annotated letter into per-line drug/relation records.

    The letter text is split into lines ("paragraphs"). Every non-empty line
    yields one output entry. For a line containing 'Drug' tags, all relations
    reachable from those drugs are followed; lines holding the related tags
    are merged into the same entry and are not emitted again later on.

    Args:
        sample: Mapping with the keys ``'text'`` (str), ``'tags'`` (list of
            dicts with ``id``, ``text``, ``tag``, ``start``, ``end``) and
            ``'relations'`` (list of dicts with ``id``, ``arg1_id``,
            ``arg2_id``, ``relation``). ``tag`` and ``relation`` are integer
            indices into ``tag_names`` / ``relation_names``.
        relation_names: Optional list of relation label names. Defaults to
            the label names of the globally loaded ``dataset``.
        tag_names: Optional list of tag label names. Defaults to the label
            names of the globally loaded ``dataset``.

    Returns:
        A list of dicts ``{'text': ..., 'drugs': OrderedDict}``. ``'text'`` is
        the line (or the merged lines in document order, joined by a newline),
        each wrapped in ``' ... '``. ``'drugs'`` maps a drug tag id to
        ``{'text': name, 'relations': {key: [(relation_id, tag_id, text)]}}``;
        lines without drugs carry a single ``'no drugs'`` placeholder.

    Raises:
        KeyError: If a relation refers to a tag id missing from ``tags``.
    """
    relations = sample['relations']
    tags = sample['tags']
    text = sample['text']

    # Fall back to the label names of the dataset loaded at notebook level
    if relation_names is None:
        relation_names = dataset['train'].features['relations'][0]['relation'].names
    if tag_names is None:
        tag_names = dataset['train'].features['tags'][0]['tag'].names
    relation_keys = [_relation_key(name) for name in relation_names]

    # Create a mapping from tag IDs to tag data
    tag_id_to_tag = {tag['id']: tag for tag in tags}

    # Index the relations by both endpoints (keeping their original order) so
    # that the relations of a tag are found without rescanning the whole list
    relations_by_arg1 = {}
    relations_by_arg2 = {}
    for relation in relations:
        relations_by_arg1.setdefault(relation['arg1_id'], []).append(relation)
        relations_by_arg2.setdefault(relation['arg2_id'], []).append(relation)

    # Split the text into paragraphs; each one is wrapped in ' ... '
    paragraph_positions = _paragraph_spans(text, tags)
    paragraphs = [' ... ' + text[start:end] + ' ... ' for start, end in paragraph_positions]

    # Map the tags to the paragraphs they overlap with
    tag_id_to_paragraphs = {}  # tag_id -> list of paragraph indices
    for tag in tags:
        tag_id_to_paragraphs[tag['id']] = [
            idx for idx, (para_start, para_end) in enumerate(paragraph_positions)
            if tag['end'] > para_start and tag['start'] < para_end
        ]

    # Map the paragraphs to the tags they contain
    paragraph_idx_to_tag_ids = {idx: [] for idx in range(len(paragraphs))}
    for tag_id, paragraph_indices in tag_id_to_paragraphs.items():
        for idx in paragraph_indices:
            paragraph_idx_to_tag_ids[idx].append(tag_id)

    processed_paragraphs = set()  # paragraphs already merged into an entry
    output = []

    # Process each paragraph
    for para_idx, paragraph in enumerate(paragraphs):
        # An empty line only consists of the two ' ... ' wrappers
        if paragraph == " ...  ... ":
            continue
        if para_idx in processed_paragraphs:
            continue

        # Find 'Drug' tags in the paragraph. This is the root tag for all relations.
        drug_tags_in_paragraph = [
            tag_id for tag_id in paragraph_idx_to_tag_ids[para_idx]
            if tag_names[tag_id_to_tag[tag_id]['tag']] == 'Drug'
        ]

        if not drug_tags_in_paragraph:
            # If no drugs are found, add a default "no drugs" entry
            drugs_dict = OrderedDict()
            drugs_dict['no drugs'] = {
                'ade': "",
                'dosage': "",
                'duration': "",
                'form': "",
                'frequency': "",
                'reason': "",
                'route': "",
                'strength': ""
            }
            output.append({'text': paragraph, 'drugs': drugs_dict})
            continue

        # Per-paragraph tracking state
        included_tags = set()
        included_relations = set()
        paragraphs_included = {para_idx}
        drugs_by_id = {}  # tag_id -> (start, tag_id, data), in insertion order
        # FIFO worklist instead of a set: popping from a set of string ids
        # depends on hash randomisation and made the result vary between runs
        pending = deque(drug_tags_in_paragraph)

        # Follow tags and relations until nothing new is reachable
        while pending:
            current_tag_id = pending.popleft()
            if current_tag_id in included_tags:
                continue
            included_tags.add(current_tag_id)

            current_tag = tag_id_to_tag[current_tag_id]
            current_tag_text = current_tag['text'].strip()

            if tag_names[current_tag['tag']] == 'Drug' and current_tag_id not in drugs_by_id:
                drugs_by_id[current_tag_id] = _new_drug_entry(current_tag, current_tag_id, relation_keys)

            # First the relations where the current tag is arg2 (the current
            # tag receives the value), then those where it is arg1 (the
            # current tag is the value of the arg2 tag)
            linked = [(rel, rel['arg1_id'], True) for rel in relations_by_arg2.get(current_tag_id, ())]
            linked += [(rel, rel['arg2_id'], False) for rel in relations_by_arg1.get(current_tag_id, ())]

            for relation, other_id, current_is_arg2 in linked:
                relation_id = relation['id']
                if relation_id in included_relations:
                    continue
                rel_type_key = relation_keys[relation['relation']]
                other_tag = tag_id_to_tag[other_id]

                # Include the paragraph(s) containing the other tag and queue
                # every tag found there.
                # NOTE: an entry already emitted for an earlier line is not
                # removed when that line is merged here.
                for idx in tag_id_to_paragraphs[other_id]:
                    if idx not in paragraphs_included:
                        paragraphs_included.add(idx)
                        processed_paragraphs.add(idx)
                        pending.extend(paragraph_idx_to_tag_ids[idx])

                # Update the drugs information
                if current_is_arg2:
                    # Only recorded when the current tag is a registered drug
                    if current_tag_id in drugs_by_id:
                        drugs_by_id[current_tag_id][2]['relations'][rel_type_key].append(
                            (relation_id, other_id, other_tag['text'].strip())
                        )
                else:
                    if other_id not in drugs_by_id:
                        drugs_by_id[other_id] = _new_drug_entry(other_tag, other_id, relation_keys)
                    drugs_by_id[other_id][2]['relations'][rel_type_key].append(
                        (relation_id, current_tag_id, current_tag_text)
                    )

                included_relations.add(relation_id)
                if tag_names[other_tag['tag']] == 'Drug':
                    pending.append(other_id)

        # Sort the drugs by their start positions
        drugs_list = sorted(drugs_by_id.values(), key=lambda drug: drug[0])

        # Handle duplicate drug names by appending a counter: (n)
        drug_name_counts = Counter(drug[2]['text'] for drug in drugs_list)
        drug_seen_counts = Counter()
        for drug in drugs_list:
            drug_name = drug[2]['text']
            if drug_name_counts[drug_name] > 1:
                drug_seen_counts[drug_name] += 1
                drug[2]['text'] = f"{drug_name} ({drug_seen_counts[drug_name]})"

        # Join the merged lines in document order; building the text while
        # merging put lines out of order whenever a later line was merged
        # before an earlier one
        merged_text = '\n'.join(paragraphs[idx] for idx in sorted(paragraphs_included))
        drugs_dict = OrderedDict((drug_id, drug_data) for _, drug_id, drug_data in drugs_list)
        output.append({'text': merged_text, 'drugs': drugs_dict})
        processed_paragraphs.update(paragraphs_included)

    # Return the final list of all drugs including their relation information
    return output

# Prepare final format
def simplify_output_structure(output):
    """Flatten one paragraph entry into a ``text`` / ``medications`` record.

    Args:
        output: Dict with a ``'text'`` key and an optional ``'drugs'`` mapping
            of drug id -> ``{'text': name, 'relations': {type: [values]}}``,
            where every value is a sequence whose third item is the related
            text.

    Returns:
        ``{'text': ..., 'medications': [...]}``. Each medication is a dict
        with the drug name under ``'medication'`` plus one key per relation
        type: ``""`` when there is no value, the bare string for a single
        value and a list of strings for several values. The list is empty
        when ``'drugs'`` is missing or holds the ``'no drugs'`` placeholder.
    """
    simplified_output = {'text': output['text']}
    drugs = output.get('drugs', {})

    medications = []
    # The 'no drugs' placeholder marks a paragraph without any medication
    if 'no drugs' not in drugs:
        for drug_info in drugs.values():
            entry = {'medication': drug_info['text']}
            for relation_type, relation_values in drug_info['relations'].items():
                if not relation_values:
                    entry[relation_type] = ""
                    continue
                # Keep only the related text (third item of each value)
                texts = [value[2] for value in relation_values]
                entry[relation_type] = texts[0] if len(texts) == 1 else texts
            medications.append(entry)

    simplified_output['medications'] = medications
    return simplified_output


In [42]:
# Process training set

# Create the output folder if it is not there yet
data_folder = "./data"
if not os.path.exists(data_folder):
    os.makedirs(data_folder, exist_ok=True)

# Convert every training letter and collect the per-paragraph entries
output = []
for letter in tqdm.tqdm(dataset['train']):
    output += process_sample(letter)

# Bring each entry into the simplified text/medications structure
output_simple = [simplify_output_structure(entry) for entry in output]

# Convert to dspy format: the paragraph text is the input, the string form of
# the remaining dict (the medications) is the expected answer
dataset_train = []
for entry in output_simple:
    paragraph_text = entry.pop('text')
    example = dspy.Example(paragraph=paragraph_text, answer=str(entry))
    dataset_train.append(example.with_inputs("paragraph"))

# Save the dataset to a pickle file in the data folder
with open('./data/n2c2_train.pkl', 'wb') as handle:
    pickle.dump(dataset_train, handle)


100%|██████████| 303/303 [00:04<00:00, 73.31it/s]


In [43]:
# Process test set

# Convert every test letter and collect the per-paragraph entries
output = []
for letter in tqdm.tqdm(dataset['test']):
    output += process_sample(letter)

# Bring each entry into the simplified text/medications structure
output_simple = [simplify_output_structure(entry) for entry in output]

# Convert to dspy format: the paragraph text is the input, the string form of
# the remaining dict (the medications) is the expected answer
dataset_test = []
for entry in output_simple:
    paragraph_text = entry.pop('text')
    example = dspy.Example(paragraph=paragraph_text, answer=str(entry))
    dataset_test.append(example.with_inputs("paragraph"))

# Save the dataset to a pickle file in the data folder
with open('./data/n2c2_test.pkl', 'wb') as handle:
    pickle.dump(dataset_test, handle)


100%|██████████| 202/202 [00:02<00:00, 87.81it/s]
