# In [None]:
import csv
import os
import random
import sys
from typing import Dict, List, Optional, Tuple

import regex
from tqdm.auto import tqdm


def _read_meddra_asc(path: str, label: str) -> Tuple[Dict[str, str], Dict[str, str]]:
    """
    Parse one '$'-delimited MedDRA .asc file.

    Returns two dictionaries keyed by the first field of each line:
    one mapping to the second field and one mapping to the third field.
    `label` is only used in the duplicate-code warning.
    """
    code_to_text: Dict[str, str] = {}
    code_to_third: Dict[str, str] = {}

    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if not line:
                # A blank line (e.g. a trailing newline) carries no record.
                continue

            fs = line.split("$")
            code, text, third = fs[0], fs[1], fs[2]

            if code in code_to_text:
                # The first text seen for a code wins; the third field below is
                # still overwritten by the latest line.
                print(f"Warning: duplicate {label} code {code} in {path}; keeping first text")
            else:
                code_to_text[code] = text

            code_to_third[code] = third

    return code_to_text, code_to_third


def build_en_dict_from_MedDRA(path2lltasc: str, path2ptasc: str) -> Optional[Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]]:
    """
    Builds lookup dictionaries from the MedDRA llt and pt files.

    Both files are read as UTF-8 text with '$'-separated fields.

    Args:
        path2lltasc: Path to the llt file (fields used: code, text, pt code).
        path2ptasc: Path to the pt file (fields used: code, text).

    Returns:
        A tuple `(llt_dict, llt_to_pt, pt_dict)` where `llt_dict` maps llt code
        to text, `llt_to_pt` maps llt code to the third field of its line, and
        `pt_dict` maps pt code to text. Returns None (after printing an error)
        if either path does not exist.
    """
    if not os.path.exists(path2ptasc):
        print("Error: File Not Found ", path2ptasc)
        return None

    if not os.path.exists(path2lltasc):
        print("Error: File Not Found ", path2lltasc)
        return None

    # The third field of the pt file is not needed by any caller.
    pt_dict, _ = _read_meddra_asc(path2ptasc, "PT")
    llt_dict, llt_to_pt = _read_meddra_asc(path2lltasc, "LLT")

    return llt_dict, llt_to_pt, pt_dict

def read_tweet_tsv_file(path: str) -> Dict[str, str]:
    """
    Read a TSV (Tab-Separated Values) file of tweets into a dictionary.

    The file is processed as plain UTF-8 text, one row per line, splitting on
    literal tab characters. No CSV quoting rules are applied, so double quotes
    inside a tweet are kept verbatim.

    Args:
        path: Path to the TSV file. The first column is used as the key and
            the second column as the value; further columns are ignored.

    Returns:
        A dictionary mapping first-column values to second-column values.
        Rows with fewer than two columns (including blank lines) are skipped.
        If a key occurs more than once, the last row wins.
    """
    data_dict: Dict[str, str] = {}
    with open(path, 'r', encoding='utf-8') as file:
        # Stream line by line instead of loading the whole file at once.
        for line in file:
            # Remove only the line terminator so that a row whose second
            # column is empty is still recognised as having two columns.
            parts = line.rstrip('\n').split('\t')
            if len(parts) >= 2:
                data_dict[parts[0]] = parts[1]
    return data_dict

def read_span_tsv_file(path: str) -> Dict[str, List[Tuple[int, int, str, str]]]:
    """
    Read a span TSV file and group its rows by the first column.

    Args:
        path: Path to a UTF-8, tab-separated file. Columns used (0-based):
            0 = key, 2 = start, 3 = end, 4 = text, 5 = meddra_llt.
            Column 1 is ignored.

    Returns:
        A dictionary mapping each key to a list of
        `(start, end, text, meddra_llt)` tuples in file order.

    Raises:
        IndexError: If a non-empty row has fewer than six columns.
        ValueError: If the start or end column is not an integer.
    """
    data_dict: Dict[str, List[Tuple[int, int, str, str]]] = {}
    # newline='' is what the csv module expects from the file object.
    with open(path, 'r', encoding='utf-8', newline='') as tsv_file:
        # QUOTE_NONE: treat double quotes as ordinary characters so a span text
        # beginning with '"' cannot swallow the following tabs and newlines.
        tsv_reader = csv.reader(tsv_file, delimiter='\t', quoting=csv.QUOTE_NONE)
        for row in tsv_reader:
            if not row:
                # csv yields an empty list for a blank line.
                continue
            span_data = (int(row[2]), int(row[3]), row[4], row[5])  # Start, end, text, meddra_llt
            data_dict.setdefault(row[0], []).append(span_data)
    return data_dict

def merge_tweets_spans(tweets: Dict[str, str], spans: Dict[str, List[Tuple[int, int, str, str]]]) -> Dict[str, Dict[str, object]]:
    """
    Combine tweet texts with their spans, keyed by tweet id.

    Every id in `tweets` gets an entry with a 'text' and a 'spans' field;
    ids without any span get an empty list. Ids that appear only in `spans`
    are not included.
    """
    return {
        identifier: {'text': content, 'spans': spans.get(identifier, [])}
        for identifier, content in tweets.items()
    }

def is_within_word(text, start, end):
    """
    Tell whether a span cuts through a word at either edge.

    - `start` is the start index of the span.
    - `end` is the end index of the span (exclusive).
    Returns True when the character just before `start` or the character at
    `end` is alphanumeric, False otherwise.
    """
    # Left edge: an alphanumeric neighbour right before the span.
    glued_on_left = start > 0 and text[start - 1].isalnum()
    # Right edge is only inspected when the left edge is clean.
    return glued_on_left or (end < len(text) and text[end].isalnum())

def adjust_span_to_word_boundary(text, start, end):
    """
    Snap a span onto whole words.

    The span is first clamped to the text, then shrunk past any
    non-alphanumeric characters at its edges, and finally grown outwards
    until both edges sit on a word boundary. Returns the new `(start, end)`.
    """
    size = len(text)

    def in_word(index):
        return text[index].isalnum()

    # Clamp both indices into [0, size].
    left = min(max(start, 0), size)
    right = min(max(end, 0), size)

    # Shrink: drop non-alphanumeric characters from the left, then the right.
    while left < right and not in_word(left):
        left += 1
    while right > left and not in_word(right - 1):
        right -= 1

    # Grow: extend each edge while it still touches an alphanumeric neighbour.
    while left > 0 and in_word(left - 1):
        left -= 1
    while right < size and in_word(right):
        right += 1

    return left, right

def generate_standoff_data(data_dict: Dict[str, Dict[str, object]], llt_id_to_txt: Dict[str, str], llt_to_pt: Dict[str, str], pt_dict: Dict[str, str], output_dir: str) -> None:
    """
    Generate standoff data files from a dictionary containing tweet text and spans.

    For every tweet id two files are written into `output_dir`:
    `<id>.txt` with the cleaned tweet text and `<id>.ann` with one
    `T<k>\\tADE <start> <end>\\t<text>` line per span, followed by `N<k>` reference
    lines for the span's MedDRA llt id and, when resolvable, its pt id.

    Args:
        data_dict: Maps tweet id to a dict with a 'text' string and a 'spans'
            list of `(start, end, text, meddra_llt)` tuples.
        llt_id_to_txt: Maps llt id to its text.
        llt_to_pt: Maps llt id to pt id.
        pt_dict: Maps pt id to its text.
        output_dir: Directory to write into; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Raw string: '\p' is not a valid Python string escape. Compiled once for all tweets.
    # Every character that is not a letter, number or punctuation becomes a single
    # space, so the text length (and therefore every span offset) is preserved.
    # A separate pass for '\p{Z}' is unnecessary: after this substitution the only
    # separator character that can remain is the plain space itself.
    non_text_chars = regex.compile(r'[^\p{L}\p{N}\p{P}]')

    for tweet_id, tweet_data in tqdm(data_dict.items()):
        text_file_path = os.path.join(output_dir, tweet_id + '.txt')
        ann_file_path = os.path.join(output_dir, tweet_id + '.ann')

        cleaned_content = non_text_chars.sub(' ', tweet_data['text'])
        with open(text_file_path, 'w', encoding='utf-8') as text_file:
            text_file.write(cleaned_content)

        t_annotations = []
        n_annotations = []
        n_counter = 1

        for t_index, span in enumerate(tweet_data['spans'], start=1):
            start, end, llt_id = span[0], span[1], span[3]
            # Snap spans that cut through a word onto whole-word boundaries.
            if is_within_word(cleaned_content, start, end):
                start, end = adjust_span_to_word_boundary(cleaned_content, start, end)
            t_annotations.append(f"T{t_index}\tADE {start} {end}\t{cleaned_content[start:end]}")

            # Only numeric llt ids that are present in the dictionary get reference lines.
            if llt_id.isdigit() and llt_id in llt_id_to_txt:
                meddra_text = llt_id_to_txt[llt_id]
                n_annotations.append(f"N{n_counter}\tReference T{t_index} meddra_llt_id:{llt_id}\t{meddra_text}")
                n_counter += 1

                if llt_id in llt_to_pt and llt_to_pt[llt_id] in pt_dict:
                    pt_id = llt_to_pt[llt_id]
                    pt_text = pt_dict[pt_id]
                    n_annotations.append(f"N{n_counter}\tReference T{t_index} meddra_pt_id:{pt_id}\t{pt_text}")
                    n_counter += 1

        # All T lines first, then all N lines.
        with open(ann_file_path, 'w', encoding='utf-8') as ann_file:
            ann_file.writelines(line + '\n' for line in t_annotations + n_annotations)

def generate_test_data(tweets: Dict[str, str], output_dir: str) -> None:
    """
    Generate text data files for the test set.

    Writes one `<tweet_id>.txt` file per tweet into `output_dir` (created if
    missing), containing the cleaned tweet text.

    Args:
        tweets: Maps tweet id to tweet text.
        output_dir: Directory to write into.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Raw string: '\p' is not a valid Python string escape. Compiled once for all tweets.
    # Every character that is not a letter, number or punctuation becomes a single
    # space. A separate pass for '\p{Z}' is unnecessary: after this substitution the
    # only separator character that can remain is the plain space itself.
    non_text_chars = regex.compile(r'[^\p{L}\p{N}\p{P}]')

    for tweet_id, tweet_text in tqdm(tweets.items()):
        text_file_path = os.path.join(output_dir, tweet_id + '.txt')
        cleaned_content = non_text_chars.sub(' ', tweet_text)

        with open(text_file_path, 'w', encoding='utf-8') as text_file:
            text_file.write(cleaned_content)

# Module-level statements: these run as soon as the file is executed or imported,
# before main() is called, and fail if the training TSV is not present.
# NOTE(review): main() assigns its own local `src` and reads the same file into a
# local `train_tweets`, so this looks like a leftover notebook cell — confirm
# whether anything still relies on these module-level names.
src = "./data/smm4h23/Task5_train_validation"
train_tweets = read_tweet_tsv_file(os.path.join(src, "Train/train_tweets.tsv"))

def main() -> None:
    """
    Creating standoff data with new splitting strategy.

    Steps:
      1. Load the MedDRA llt/pt dictionaries.
      2. Shuffle the training tweets with a fixed seed and split them 90/10
         into new train and val sets.
      3. Write standoff files for train, val and (from the Dev set) test.
      4. Write plain text files for the unlabeled test tweets ("infer").

    Exits with an error message if the MedDRA dictionaries cannot be loaded.
    """
    # Set a seed for reproducibility
    seed = 42
    random.seed(seed)

    # Load and prepare data
    meddra = build_en_dict_from_MedDRA(
        "./data/smm4h23/Task5_train_validation/MedDRA/llt.asc",
        "../../ontology_mapper/meddra_data/SMM4H23_meddra_24.0_english/MedAscii/pt.asc"
    )
    if meddra is None:
        # The loader returns None when a MedDRA file is missing; unpacking that
        # would fail with an opaque TypeError, so stop with a clear message instead.
        sys.exit("Error: MedDRA dictionaries could not be loaded; aborting.")
    llt_dict, llt_to_pt, pt_dict = meddra

    src = "./data/smm4h23/Task5_train_validation"
    dst = "../data/english/smm4h23"

    # Read train and development datasets
    train_tweets = read_tweet_tsv_file(os.path.join(src, "Train/train_tweets.tsv"))
    train_spans = read_span_tsv_file(os.path.join(src, "Train/train_spans_norm.tsv"))
    dev_tweets = read_tweet_tsv_file(os.path.join(src, "Dev/tweets.tsv"))
    dev_spans = read_span_tsv_file(os.path.join(src, "Dev/spans_norm.tsv"))

    # Merge train dataset and shuffle
    merged_train_data = merge_tweets_spans(train_tweets, train_spans)
    combined_train_data = list(merged_train_data.items())
    random.shuffle(combined_train_data)

    # Split into new train and val datasets (first 90% train, remainder val)
    split_index = int(0.9 * len(combined_train_data))
    new_train_data = dict(combined_train_data[:split_index])
    new_val_data = dict(combined_train_data[split_index:])

    # Generate standoff data for the new train and val datasets
    generate_standoff_data(new_train_data, llt_dict, llt_to_pt, pt_dict, os.path.join(dst, "train"))
    generate_standoff_data(new_val_data, llt_dict, llt_to_pt, pt_dict, os.path.join(dst, "val"))

    # Process Dev dataset as test
    merged_dev_data = merge_tweets_spans(dev_tweets, dev_spans)
    generate_standoff_data(merged_dev_data, llt_dict, llt_to_pt, pt_dict, os.path.join(dst, "test"))

    # Process test data for inference
    test_src = "./data/smm4h23/Task5_test"
    test_tweets = read_tweet_tsv_file(os.path.join(test_src, "tweets.tsv"))
    generate_test_data(test_tweets, os.path.join(dst, "infer"))

# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()