In [None]:
import os
import xml.etree.ElementTree as ET

def load_xml_files_from_folder(folder_path):
    """
    Load and parse all XML files in a specified folder.

    Files are visited in sorted filename order so the returned dictionary has
    the same ordering on every run and platform (``os.listdir`` makes no
    ordering guarantee). Files that cannot be parsed are reported on stdout
    and left out of the result, so every value is a parsed root element.

    Args:
    folder_path (str): The path to the folder containing XML files.

    Returns:
    dict: A dictionary where keys are filenames and values are the parsed XML root elements.
    """
    xml_files = {}
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.xml'):
            continue
        file_path = os.path.join(folder_path, filename)
        try:
            root = ET.parse(file_path).getroot()
        except ET.ParseError as e:
            # Storing the error text in place of a root element would make
            # consumers fail later on ``root.findall``; report and skip instead.
            print(f"Skipping '{filename}': {e}")
            continue
        xml_files[filename] = root
    return xml_files

def create_standoff_files_with_section_text(xml_files, base_output_folder, split_type, ignore_discontinuous=False):
    """
    Create .txt and .ann files for each section in each XML file, handling multiple and discontinuous start positions,
    and filtering for specific entity types, including normalization annotations.

    Only mentions of type 'AdverseReaction' are considered. For every section that has at least one accepted
    mention, two files are written to ``<base_output_folder>/<section name, alphanumerics only>/<split_type>/``:

    * ``<file>_<section>.txt`` with the section text.
    * ``<file>_<section>.ann`` with one ``T<n>\\tADE <start end[;start end...]>\\t<text>`` line per mention whose
      lower-cased, stripped text matches a ``Reaction`` that has a ``Normalization`` child, followed by
      ``N<k>\\tReference T<n> meddra_pt:<id>\\t<term>`` / ``meddra_llt`` lines for that mention. Mentions without a
      matching normalization are not written.

    Files without any ``Mention`` element are skipped. Mentions that reference a section id that is not present in
    the file are reported and skipped, because their offsets cannot be resolved against any text.

    Args:
    xml_files (dict): Maps XML filenames to their parsed root elements.
    base_output_folder (str): Folder under which the per-section output folders are created.
    split_type (str): Name of the split subfolder, e.g. 'train'.
    ignore_discontinuous (bool): If True, mentions with more than one start position are dropped.
    """
    for file_name, root in xml_files.items():
        mentions = root.findall('.//Mention')
        if not mentions:
            continue

        base_filename = os.path.splitext(file_name)[0]
        section_elements = root.findall('.//Section')
        section_id_to_name = {section.get('id'): section.get('name') for section in section_elements}
        # An empty <Section/> has ``text`` set to None; normalise to "" so it can be sliced and written.
        sections_text = {section.get('id'): section.text or "" for section in section_elements}

        # Create normalization map: lower-cased reaction string -> MedDRA PT/LLT ids and terms
        normalization_map = {}
        for reaction in root.findall('.//Reaction'):
            normalization = reaction.find('.//Normalization')
            if normalization is None:
                continue
            reaction_str = reaction.get('str').lower().strip()
            normalization_map[reaction_str] = (
                normalization.get('meddra_pt_id'),
                normalization.get('meddra_pt'),
                normalization.get('meddra_llt_id'),
                normalization.get('meddra_llt'),
            )

        sections = {}

        for mention in mentions:
            if mention.get('type') != 'AdverseReaction':
                continue

            section_id = mention.get('section')
            start_positions = mention.get('start').split(',')
            lengths = mention.get('len').split(',')

            if len(start_positions) != len(lengths):
                print(f"Warning: Mismatch in lengths and starts in mention '{mention.get('id')}'")
                continue

            if ignore_discontinuous and len(start_positions) > 1:
                continue

            if section_id not in sections_text:
                # Without the section text the offsets are meaningless; skip rather than raise KeyError.
                print(f"Warning: Unknown section '{section_id}' in mention '{mention.get('id')}'")
                continue

            section_text = sections_text[section_id]
            if section_id not in sections:
                sections[section_id] = {
                    'text': section_text,
                    'mentions': [],
                    # A section without a name attribute still needs a usable folder/file name.
                    'name': section_id_to_name.get(section_id) or "UnknownSection",
                }

            mention_ranges = [
                (int(start), int(start) + int(length))
                for start, length in zip(start_positions, lengths)
            ]
            formatted_ranges = ';'.join(f"{start} {end}" for start, end in mention_ranges)
            # Fragments of a discontinuous mention are joined with a single space.
            true_text = " ".join(section_text[start:end] for start, end in mention_ranges)
            mention_key = mention.get('str').lower().strip()
            sections[section_id]['mentions'].append(
                (mention_ranges[0][0], formatted_ranges, true_text, mention_key)
            )

        for section_id, data in sections.items():
            safe_section_name = ''.join(ch for ch in data['name'] if ch.isalnum())
            section_output_folder = os.path.join(base_output_folder, safe_section_name, split_type)
            os.makedirs(section_output_folder, exist_ok=True)

            txt_filename = os.path.join(section_output_folder, f"{base_filename}_{safe_section_name}.txt")
            ann_filename = os.path.join(section_output_folder, f"{base_filename}_{safe_section_name}.ann")

            text_bound_annotations = []
            normalization_annotations = []
            mention_counter = 1
            normalization_counter = 1

            # Order annotations by the first start offset (stable for ties).
            data['mentions'].sort(key=lambda item: item[0])

            for _, formatted_ranges, true_text, mention_key in data['mentions']:
                if mention_key not in normalization_map:
                    continue

                text_bound_annotations.append(f"T{mention_counter}\tADE {formatted_ranges}\t{true_text}\n")

                meddra_pt_id, meddra_pt, meddra_llt_id, meddra_llt = normalization_map[mention_key]
                if meddra_pt_id and meddra_pt:
                    normalization_annotations.append(
                        f"N{normalization_counter}\tReference T{mention_counter} meddra_pt:{meddra_pt_id}\t{meddra_pt}\n"
                    )
                    normalization_counter += 1
                if meddra_llt_id and meddra_llt:
                    normalization_annotations.append(
                        f"N{normalization_counter}\tReference T{mention_counter} meddra_llt:{meddra_llt_id}\t{meddra_llt}\n"
                    )
                    normalization_counter += 1
                mention_counter += 1

            # Explicit UTF-8 and no newline translation: the output must not depend on the
            # platform's default encoding or line-ending convention, otherwise non-ASCII text
            # can fail to encode and the written text can differ from what the offsets index.
            with open(txt_filename, 'w', encoding='utf-8', newline='') as txt_file:
                txt_file.write(data['text'])

            with open(ann_filename, 'w', encoding='utf-8', newline='') as ann_file:
                # Write all text-bound annotations first
                ann_file.writelines(text_bound_annotations)
                # Followed by all normalization annotations
                ann_file.writelines(normalization_annotations)

# Folder holding the XML files that are split into the train and validation sets below
folder_path = './data/TAC2017/train_xml'

# Load and parse every .xml file in that folder (filename -> parsed result)
loaded_xml_files = load_xml_files_from_folder(folder_path)

import random

# Fixed seed so the shuffle below yields the same permutation for the same input order.
# NOTE: the resulting split therefore also depends on the order in which the loader returned the files.
random.seed(42)

# Convert the dictionary items into a list of (filename, root) tuples so they can be shuffled and sliced
xml_items = list(loaded_xml_files.items())

# Shuffle the list in place
random.shuffle(xml_items)

# Index of the first validation item: the first 90% (rounded down) go to train
split_index = int(0.9 * len(xml_items))

# Split the shuffled list into train (first 90%) and validation (remaining) dictionaries
train_xml_files = dict(xml_items[:split_index])
val_xml_files = dict(xml_items[split_index:])

# Print the number of files in each set
print(f"Number of files in train set: {len(train_xml_files)}")
print(f"Number of files in validation set: {len(val_xml_files)}")

# Folder holding the XML files used as the test set (folder_path is reused from above)
folder_path = './data/TAC2017/gold_xml'

# Load and parse every .xml file in the test folder
test_xml_files = load_xml_files_from_folder(folder_path)
print(f"Number of files in test set: {len(test_xml_files)}")

# Root folder for the generated standoff (.txt/.ann) files
base_output_folder = './data/TAC2017/standoff/ADEs_only'

# Write the standoff files for each split. Discontinuous mentions are dropped from the
# train split only; the val and test splits keep them.
create_standoff_files_with_section_text(train_xml_files, base_output_folder, 'train', ignore_discontinuous=True)
create_standoff_files_with_section_text(val_xml_files, base_output_folder, 'val', ignore_discontinuous=False)
create_standoff_files_with_section_text(test_xml_files, base_output_folder, 'test', ignore_discontinuous=False)

import os
import shutil

def merge_folders(src1, src2, src3, dest):
    """Merges the 'train', 'val', and 'test' subdirectories of three source directories into a single directory at the specified destination.

    For each split, every file found anywhere below ``<src>/<split>`` is copied
    (with metadata, via ``shutil.copy2``) directly into ``<dest>/<split>``; the
    source's nested folder structure is flattened. Files whose names start with
    "." are ignored. If a file of the same name already exists at the
    destination it is left untouched and a warning is printed, so earlier
    sources take precedence over later ones.

    Args:
        src1 (str): First source directory.
        src2 (str): Second source directory.
        src3 (str): Third source directory.
        dest (str): Destination directory; it and its split subdirectories are
            created if missing.

    Raises:
        ValueError: If any source directory does not exist. All sources are
            checked before anything is created or copied, so a failed call
            never leaves a partially merged destination behind.
    """
    sources = (src1, src2, src3)
    subdirs = ['train', 'val', 'test']

    # Validate everything up front; failing half-way would leave a partial merge.
    for src in sources:
        if not os.path.isdir(src):
            raise ValueError(f"Source directory '{src}' does not exist.")

    for src in sources:
        for subdir in subdirs:
            target = os.path.join(dest, subdir)
            # The split folder is created even when this source has no such split.
            os.makedirs(target, exist_ok=True)

            src_subdir = os.path.join(src, subdir)
            if not os.path.isdir(src_subdir):
                continue

            for dirpath, _, filenames in os.walk(src_subdir):
                for filename in filenames:
                    if filename.startswith("."):
                        continue
                    src_file = os.path.join(dirpath, filename)
                    dst_file = os.path.join(target, filename)
                    if os.path.exists(dst_file):
                        print(f"Warning: file {filename} already exists at destination; skipping...")
                    else:
                        shutil.copy2(src_file, dst_file)

# Pool the three per-section standoff folders into one combined folder.
merge_folders(
    './data/TAC2017/standoff/ADEs_only/adversereactions',
    './data/TAC2017/standoff/ADEs_only/boxedwarnings',
    './data/TAC2017/standoff/ADEs_only/warningsandprecautions',
    './data/TAC2017/standoff/ADEs_only/tac_ad_bo_wa',
)