# In [None]:
import os
import random
from tqdm.auto import tqdm

def _read_meddra_asc(path: str, duplicate_marker: str) -> "tuple[dict[str, str], dict[str, str]]":
    """Parse one '$'-delimited MedDRA ``.asc`` file.

    Returns ``(names, links)`` where ``names`` maps the first field of a line
    (the code) to its second field (the text), keeping the first text seen for
    a code, and ``links`` maps the code to the third field of the last line
    seen for that code. ``duplicate_marker`` is printed each time a code is
    seen again.
    """
    names = {}
    links = {}
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            fields = line.strip().split("$")
            if len(fields) < 3:
                # Blank or truncated line (e.g. a trailing empty line):
                # nothing to index, and indexing it would raise IndexError.
                continue
            code, text, linked_code = fields[0], fields[1], fields[2]
            if code in names:
                print(duplicate_marker)
            else:
                names[code] = text
            links[code] = linked_code
    return names, links


def build_en_dict_from_MedDRA(path2lltasc: str, path2ptasc: str) -> "tuple[dict[str, str], dict[str, str], dict[str, str]] | None":
    """
    Builds lookup dictionaries from MedDRA llt and pt files.

    Args:
        path2lltasc: Path to the '$'-delimited LLT file.
        path2ptasc: Path to the '$'-delimited PT file.

    Returns:
        ``(llt_dict, llt_to_pt, pt_dict)`` where ``llt_dict`` maps an LLT code
        to its text, ``llt_to_pt`` maps an LLT code to the third field of its
        line, and ``pt_dict`` maps a PT code to its text.
        ``None`` is returned (after printing an error) if either path does
        not exist.
    """
    # Validate both inputs before doing any parsing work.
    for path in (path2ptasc, path2lltasc):
        if not os.path.exists(path):
            print("Error: File Not Found ", path)
            return None

    # "0" / "1" are the duplicate markers printed for the PT and LLT file.
    pt_dict, _ = _read_meddra_asc(path2ptasc, "0")
    llt_dict, llt_to_pt = _read_meddra_asc(path2lltasc, "1")

    return llt_dict, llt_to_pt, pt_dict

def is_whitespace_string(s):
    """Return True when every character of ``s`` is whitespace.

    An empty string has no non-whitespace character and therefore yields True.
    """
    return all(ch.isspace() for ch in s)

def tuple_list_to_string_pairs(tuple_list):
    """Format each 2-item entry of ``tuple_list`` as ``"<first> <second>"``.

    Returns a new list with one string per entry, in the original order.
    """
    formatted = []
    for first, second in tuple_list:
        formatted.append("{} {}".format(first, second))
    return formatted

def spans_overlap(span1, span2):
    """Return True when the later start lies strictly before the earlier end.

    Each span is indexed as ``span[0]`` (start) and ``span[1]`` (end); spans
    that merely touch (one ends where the other starts) do not overlap.
    """
    latest_start = span2[0] if span2[0] > span1[0] else span1[0]
    earliest_end = span2[1] if span2[1] < span1[1] else span1[1]
    return latest_start < earliest_end

def _load_normalizations(meddra_file_path, llt_dict):
    """Map each 'T<n>' id to the known LLT codes listed on its 'TT<n>' line."""
    normalization_data = {}
    with open(meddra_file_path, 'r') as meddra_file:
        for line in meddra_file:
            if not line.startswith('TT'):
                continue
            fields = line.strip().split('\t')
            # Only well-formed "id <TAB> info <TAB> text" lines are usable;
            # anything else is skipped explicitly instead of being swallowed
            # by a bare except.
            if len(fields) != 3:
                continue
            norm_id, norm_infos = fields[0], fields[1]
            tokens = norm_infos.replace(" + ", "/").split()
            if not tokens:
                continue
            # The first token holds one code, or several joined by "/"
            # (after " + " has been rewritten to "/").
            candidates = tokens[0].split("/")
            adr_id = norm_id.replace('TT', 'T')
            normalization_data[adr_id] = [code for code in candidates if code in llt_dict]
    return normalization_data


def _merge_whitespace_separated_spans(spans, txt_content):
    """Sort spans by start and fuse neighbours separated only by whitespace."""
    sorted_spans = sorted(spans, key=lambda x: x[0])
    merged_spans = []
    start_span, end_span = sorted_spans[0]
    for next_start, next_end in sorted_spans[1:]:
        if is_whitespace_string(txt_content[end_span:next_start]):
            end_span = next_end
        else:
            merged_spans.append((start_span, end_span))
            start_span, end_span = next_start, next_end
    merged_spans.append((start_span, end_span))
    return merged_spans


def process_files(files,
                  path_to_ann,
                  path_to_txt,
                  path_to_meddra,
                  output_folder,
                  split_type,
                  consider_discontinuous,
                  llt_dict,
                  llt_to_pt,
                  pt_dict):
    """Rewrite the 'ADR' annotations of each file as ADE + MedDRA reference lines.

    For every name in ``files`` this reads ``<name>.txt`` from ``path_to_txt``,
    ``<name>.ann`` from ``path_to_meddra`` (the 'TT' normalization lines) and
    ``<name>.ann`` from ``path_to_ann`` (the 'T' annotation lines), then writes
    ``<name>.ann`` and an unchanged copy of ``<name>.txt`` to ``output_folder``.

    Args:
        files: File base names (without extension) to process.
        path_to_ann: Directory holding the source ``.ann`` files.
        path_to_txt: Directory holding the source ``.txt`` files.
        path_to_meddra: Directory holding the ``.ann`` files with 'TT' lines.
        output_folder: Destination directory; created if missing.
        split_type: Label used in the summary and as key into
            ``consider_discontinuous``.
        consider_discontinuous: Mapping ``split_type -> bool``; when False,
            annotations that still have more than one span after merging are
            dropped and counted.
        llt_dict: LLT code -> text.
        llt_to_pt: LLT code -> PT code.
        pt_dict: PT code -> text.

    Raises:
        KeyError: If a PT code from ``llt_to_pt`` is missing from ``pt_dict``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_folder, exist_ok=True)

    continuous_counter = 0
    discontinuous_counter = 0
    overlap_counter = 0

    for file_name in tqdm(files):

        ADE_counter = 0
        NORM_counter = 0
        continuous_ADE_spans = []

        ann_file_path = os.path.join(path_to_ann, file_name + '.ann')
        txt_file_path = os.path.join(path_to_txt, file_name + '.txt')
        meddra_file_path = os.path.join(path_to_meddra, file_name + '.ann')

        with open(txt_file_path, 'r') as file:
            txt_content = file.read()

        normalization_data = _load_normalizations(meddra_file_path, llt_dict)

        T_lines = []
        N_lines = []

        with open(ann_file_path, 'r') as ann_file:
            for line in ann_file:
                if not line.startswith('T'):
                    continue
                line_split = line.strip().split('\t')
                if len(line_split) < 3:
                    continue

                ann_number, ann_info = line_split[0], line_split[1]
                # "<type> <start> <end>[;<start> <end>...]"
                ann_type, _, span_field = ann_info.partition(' ')
                if ann_type != 'ADR':
                    continue

                spans = []
                for part in span_field.split(";"):
                    bounds = part.split()
                    spans.append((int(bounds[0]), int(bounds[1])))
                merged_spans = _merge_whitespace_separated_spans(spans, txt_content)

                if len(merged_spans) > 1 and not consider_discontinuous[split_type]:
                    discontinuous_counter += 1
                    continue

                continuous_ADE_spans.extend(merged_spans)
                if len(merged_spans) == 1:
                    continuous_counter += 1

                ADE_counter += 1
                span_text = ";".join(tuple_list_to_string_pairs(merged_spans))
                mention = " ".join(txt_content[s:e] for s, e in merged_spans)
                T_lines.append(f'T{ADE_counter}\tADE {span_text}\t{mention}\n')

                # One memory per annotation, so that several LLTs sharing the
                # same PT produce a single PT reference line.
                seen_pts = set()
                for norm_info in normalization_data.get(ann_number, ()):
                    NORM_counter += 1
                    N_lines.append(
                        f'N{NORM_counter}\tReference T{ADE_counter} meddra_llt_id:{norm_info}\t{llt_dict[norm_info]}\n'
                    )
                    if norm_info not in llt_to_pt:
                        continue
                    norm_info_pt = llt_to_pt[norm_info]
                    if norm_info_pt in seen_pts:
                        continue
                    seen_pts.add(norm_info_pt)
                    NORM_counter += 1
                    N_lines.append(
                        f'N{NORM_counter}\tReference T{ADE_counter} meddra_pt_id:{norm_info_pt}\t{pt_dict[norm_info_pt]}\n'
                    )

        adr_lines = T_lines + N_lines

        # Check for overlaps among the kept ADE spans (every unordered pair).
        span_count = len(continuous_ADE_spans)
        if any(
            spans_overlap(continuous_ADE_spans[i], continuous_ADE_spans[j])
            for i in range(span_count)
            for j in range(i + 1, span_count)
        ):
            overlap_counter += 1
            print(f"Overlap found in {file_name}.ann")

        output_ann_path = os.path.join(output_folder, file_name + '.ann')
        with open(output_ann_path, 'w') as file:
            file.writelines(adr_lines)

        output_txt_path = os.path.join(output_folder, file_name + '.txt')
        with open(output_txt_path, 'w') as file:
            file.write(txt_content)

    print(f"{split_type} - Discontinuous Counter: {discontinuous_counter} (Ignored if applicable)")
    print(f"{split_type} - Continuous Counter: {continuous_counter}")
    print(f"{split_type} - Overlap Counter: {overlap_counter}")

def split_data(path_to_ann,
               path_to_txt,
               path_to_meddra,
               output_base,
               train_split,
               val_split,
               test_split,
               seed,
               consider_discontinuous,
               llt_dict,
               llt_to_pt,
               pt_dict):
    """Shuffle the documents and process them as train / val / test splits.

    Documents are the base names that have both a ``.ann`` file in
    ``path_to_ann`` and a ``.txt`` file in ``path_to_txt``. Each split is
    written by ``process_files`` to ``<output_base>/<split name>``.

    Args:
        path_to_ann: Directory with the source ``.ann`` files.
        path_to_txt: Directory with the source ``.txt`` files.
        path_to_meddra: Directory passed through to ``process_files``.
        output_base: Parent directory of the 'train', 'val' and 'test' outputs.
        train_split: Fraction of documents for the train split.
        val_split: Fraction of documents for the val split.
        test_split: Not used in the computation; the test split receives every
            document left over after train and val.
        seed: Seed applied to the global ``random`` generator.
        consider_discontinuous: Passed through to ``process_files``.
        llt_dict: Passed through to ``process_files``.
        llt_to_pt: Passed through to ``process_files``.
        pt_dict: Passed through to ``process_files``.
    """
    random.seed(seed)

    ann_files = {os.path.splitext(name)[0] for name in os.listdir(path_to_ann) if name.endswith('.ann')}
    txt_files = {os.path.splitext(name)[0] for name in os.listdir(path_to_txt) if name.endswith('.txt')}

    # Sort before shuffling: iteration order of a set of strings depends on
    # per-process hash randomization, so shuffling list(set) with a fixed seed
    # would still yield a different split on every run.
    common_files = sorted(ann_files & txt_files)
    random.shuffle(common_files)

    total = len(common_files)
    train_end = int(total * train_split)
    val_end = train_end + int(total * val_split)

    partitions = (
        ('train', common_files[:train_end]),
        ('val', common_files[train_end:val_end]),
        ('test', common_files[val_end:]),
    )
    for split_type, split_files in partitions:
        process_files(split_files,
                      path_to_ann,
                      path_to_txt,
                      path_to_meddra,
                      os.path.join(output_base, split_type),
                      split_type,
                      consider_discontinuous,
                      llt_dict,
                      llt_to_pt,
                      pt_dict)


# Load the MedDRA lookup tables (LLT text, LLT -> PT, PT text).
llt_dict, llt_to_pt, pt_dict = build_en_dict_from_MedDRA(
    path2lltasc="../../ontology_mapper/meddra_data/CADEC_meddra_16_0_english/MedAscii/llt.asc",
    path2ptasc="../../ontology_mapper/meddra_data/CADEC_meddra_16_0_english/MedAscii/pt.asc",
)

# Split the corpus 70/15/15 with a fixed seed; discontinuous annotations are
# dropped in every split.
split_data(
    path_to_ann='./data/cadec/original',
    path_to_txt='./data/cadec/text',
    path_to_meddra='./data/cadec/meddra',
    output_base='./data/cadec/standoff',
    train_split=0.7,
    val_split=0.15,
    test_split=0.15,
    seed=42,
    consider_discontinuous=dict.fromkeys(('train', 'val', 'test'), False),
    llt_dict=llt_dict,
    llt_to_pt=llt_to_pt,
    pt_dict=pt_dict,
)