In [None]:
import os
import re
import json
import pandas as pd

from tqdm import tqdm
from glob import iglob
from ast import literal_eval

from os.path import join as pjoin

In [2]:
# Project root for the MIRROR data. Set the MIRROR_ROOT_DIR environment variable
# to run elsewhere; the default is the original author's machine path.
ROOT_DIR = os.environ.get("MIRROR_ROOT_DIR", "/home/kimsubin/mm_counselor/mirror")
ANNOT_DIR = pjoin(ROOT_DIR, "annot_data")

# Where every filter in this notebook reads/writes its drop lists.
POSTPROC_RESULT_DIR = pjoin(ANNOT_DIR, "utterfeat/postprocess", "results")


In [3]:
def split_list_into_chunks(lst, chunk_size=2):
    """Return consecutive slices of `lst`, each `chunk_size` long (the last one may be shorter)."""
    chunks = []
    for start in range(0, len(lst), chunk_size):
        chunks.append(lst[start:start + chunk_size])
    return chunks

def write_jsonl(save_path, json_obj):
    """
    Write an iterable of JSON-serializable entries to `save_path` as JSON Lines.

    The file is overwritten, UTF-8 encoded, and non-ASCII characters are kept
    as-is (ensure_ascii=False). Missing parent directories are created so the
    notebook also runs on a fresh checkout.

    Args:
        save_path: destination file path.
        json_obj: iterable of JSON-serializable entries; one line per entry.
    """
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as f:
        for entry in json_obj:
            json.dump(entry, f, ensure_ascii=False)
            f.write('\n')

def write_line(path, entry):
    """
    Append one JSON-serializable `entry` to `path` as a single JSON Lines record.

    The file is opened in append mode, so every call (including re-running a
    cell) adds a record rather than replacing earlier ones. Text is written as
    UTF-8 because non-ASCII characters are emitted verbatim (ensure_ascii=False)
    and must not depend on the platform's locale encoding. Missing parent
    directories are created.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'a', encoding='utf-8') as f:
        json.dump(entry, f, ensure_ascii=False)
        f.write('\n')

In [9]:
# Load the MIRROR dialogues. The 'proc_dialogue' column is stored as a Python
# literal string in the CSV, so it is parsed back into objects with literal_eval.
mirror = pd.read_csv(
    pjoin(ANNOT_DIR, "mirror.csv"),
    converters={'proc_dialogue': literal_eval},
)

---

### **Basic Filtering**

In [68]:
import re

def filter_turn_count(conversation: str) -> bool:
    """
    Keep conversations with between 4 and 20 turns (inclusive).

    Each non-blank line counts as half a turn, so N non-blank lines give N / 2
    turns (possibly fractional).

    Returns True if the conversation should be kept.
    """
    non_blank_lines = sum(1 for line in conversation.splitlines() if line.strip())
    n_turns = non_blank_lines / 2
    return 4 <= n_turns <= 20

def filter_speaker_count(conversation: str, max_speakers: int = 2) -> bool:
    """
    Keep conversations with at most `max_speakers` distinct speaker labels.

    A speaker label is the run of word characters at the start of a non-blank
    line, immediately followed by ':' (e.g. "Therapist: ..." -> "Therapist").
    Lines without such a prefix are ignored.

    Args:
        conversation: newline-separated "speaker: utterance" lines.
        max_speakers: largest number of distinct speakers allowed (default 2).

    Returns:
        True if the conversation should be kept.
    """
    speakers = set()
    for line in conversation.splitlines():
        if not line.strip():
            continue
        # Single backslash: in a raw string, `\\w` would match a literal
        # backslash, so no speaker would ever be detected.
        match = re.match(r"^(\w+):", line)
        if match:
            speakers.add(match.group(1))
    return len(speakers) <= max_speakers


def filter_repetition(conversation: str, min_unique_ratio: float = 0.5) -> bool:
    """
    Reject conversations containing an utterance with excessive word repetition.

    For each non-blank "speaker: utterance" line, the utterance is lower-cased
    and split on whitespace. If the share of distinct words is below
    `min_unique_ratio`, the conversation is rejected.

    Args:
        conversation: newline-separated "speaker: utterance" lines.
        min_unique_ratio: minimum distinct-words / total-words ratio (default 0.5).

    Returns:
        True if the conversation should be kept.
    """
    for line in conversation.splitlines():
        if not line.strip():
            continue
        _, _, utterance = line.partition(":")
        words = utterance.lower().split()
        # Skip lines without ':' and empty or whitespace-only utterances;
        # these would otherwise divide by zero below.
        if not words:
            continue
        if len(set(words)) / len(words) < min_unique_ratio:
            return False
    return True

def filter_response_failure(conversation: str) -> bool:
    """
    Reject conversations where the therapist answers with a non-response.

    Consecutive pairs of non-blank "speaker: utterance" lines are examined.
    When both utterances are non-empty and the later speaker label starts with
    "therapist" (case-insensitive), the therapist's stripped utterance is
    compared, case-insensitively, against these non-answers:
    - "I don't know" (straight or curly apostrophe);
    - "uh", "uhh", ...;
    - "hmm", "hmmm", ...
    Any of these may be followed by trailing '.', '!', '?' or '…' characters.

    Returns:
        True if the conversation should be kept (also when it has fewer than
        two non-blank lines).
    """
    turns = [line for line in conversation.splitlines() if line.strip()]
    if len(turns) < 2:
        return True  # Not enough context to analyze

    for prev_line, curr_line in zip(turns, turns[1:]):
        _, _, prev_utterance = prev_line.partition(":")
        speaker, _, curr_utterance = curr_line.partition(":")
        curr_utterance = curr_utterance.strip()
        if not (prev_utterance.strip() and curr_utterance):
            continue
        # Single backslash-free pattern: the apostrophe is a plain character
        # class, so "don't" / "don’t" match without any escape.
        if speaker.lower().strip().startswith('therapist') and re.fullmatch(
            r"(?:i don['’]t know|uh+|hmm+)[.!?…]*", curr_utterance, re.IGNORECASE
        ):
            return False
    return True


In [69]:
# Run every basic heuristic on each dialogue; record one {'idx', 'reason'}
# entry per failed check (a dialogue can therefore appear several times).
drop = []
for i, row in tqdm(mirror.iterrows(), total=len(mirror), desc="Basic Filtering"):
    conversation = row['dialogue'].strip()
    # Evaluate all checks first, in a fixed order, then log the failures.
    check_results = [
        ('Turn Count', filter_turn_count(conversation)),
        ('Repetition Result', filter_repetition(conversation)),
        ('Response Failure', filter_response_failure(conversation)),
        ('Speaker Count', filter_speaker_count(conversation)),
    ]
    drop.extend(
        {'idx': row['idx'], 'reason': reason}
        for reason, passed in check_results
        if not passed
    )


Basic Filtering:   0%|          | 0/41223 [00:00<?, ?it/s]

Basic Filtering: 100%|██████████| 41223/41223 [00:05<00:00, 7389.96it/s]


In [70]:
# Number of distinct dialogues that failed at least one basic check.
len({entry['idx'] for entry in drop})

0

In [71]:
# Total drop records; a dialogue contributes one record per basic check it fails.
len(drop)

0

---

### **Consecutive POS Filtering**

In [75]:
import nltk
from nltk import word_tokenize, pos_tag

# Models used by word_tokenize / pos_tag in filter_pos_repetition.
# NLTK >= 3.9 loads the Punkt tokenizer from 'punkt_tab' and the tagger from
# 'averaged_perceptron_tagger_eng'; the legacy names are kept for older NLTK.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger')
nltk.download('averaged_perceptron_tagger_eng')

[nltk_data] Downloading package punkt to /home/kimsubin/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/kimsubin/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /home/kimsubin/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!


True

In [None]:
def filter_pos_repetition(conversation: str, max_repeats=3) -> bool:
    """
    Reject conversations with an utterance that repeats the same POS tag too often in a row.

    Each non-blank "speaker: utterance" line with a non-empty utterance is
    tokenized with word_tokenize and tagged with pos_tag. If more than
    `max_repeats` consecutive tokens share a tag, the conversation is rejected.
    Curly quotes and square brackets break a run and are not counted.

    Returns True if the conversation should be kept.
    """
    run_breakers = frozenset(['“', '”', '’', '‘', ']', '['])
    non_blank_lines = [ln for ln in conversation.splitlines() if ln.strip()]

    for ln in non_blank_lines:
        utterance = ln.partition(":")[2].strip()
        if not utterance:
            continue  # no ':' or nothing after it

        run_length, last_tag = 0, None
        for word, tag in pos_tag(word_tokenize(utterance)):
            if word in run_breakers:
                run_length, last_tag = 1, None
                continue
            if tag != last_tag:
                run_length = 1  # a new tag starts a fresh run
            else:
                run_length += 1
                if run_length > max_repeats:
                    return False
            last_tag = tag
    return True


In [None]:
# Collect the idx of every dialogue that has an utterance with more than
# max_repeats (here 3) identical POS tags in a row.
# NOTE(review): this rebinds `drop`, discarding the basic-filter results above.
drop = []
for i, row in tqdm(mirror.iterrows(), total=len(mirror), desc="POS Filtering"):
    conversation = row['dialogue'].strip()
    pos_repetition_result = filter_pos_repetition(conversation=conversation, max_repeats=3)
    if not pos_repetition_result:
        drop += [row['idx']]

In [None]:
# Persist the POS-repetition drop list (one idx per line).
write_jsonl(pjoin(POSTPROC_RESULT_DIR, "pos", "drop.jsonl"), drop)

In [78]:
# Reload the POS drop list; `with` closes the handle, which the bare open() leaked.
with open(pjoin(POSTPROC_RESULT_DIR, "pos", 'drop.jsonl'), 'r', encoding='utf-8') as f:
    drop_results = [json.loads(line) for line in f if line.strip()]
len(drop_results)

424

In [80]:
# Append one 'repetitive pos' record per dropped dialogue to the shared drop_idx.jsonl.
# NOTE(review): write_line appends, so re-running this cell duplicates records.
for idx in drop_results:
    entry = {'type': 'repetitive pos', 'idx': idx}
    write_line(path=pjoin(POSTPROC_RESULT_DIR, "drop_idx.jsonl"), entry=entry)

---

### **Safety Filtering**

**Post-processing**

In [81]:
# Load the per-dialogue Canary safety results (one JSON object per line);
# `with` closes the file handle, which the bare open() leaked.
with open(pjoin(POSTPROC_RESULT_DIR, "canary", "results.jsonl"), 'r', encoding='utf-8') as f:
    canary_result = [json.loads(line) for line in f if line.strip()]

In [82]:
def parse_safety(safety_str):
    """
    Split a safety string "<label> <reason>" into (label, reason).

    Only the first space separates label from reason. If there is no space,
    the whole string is the label and reason is None.
    """
    if ' ' not in safety_str:
        return safety_str, None
    label, rationale = safety_str.split(" ", 1)
    return label, rationale

In [None]:
# All safety labels expected in the canary results.
# NOTE(review): not referenced by the visible cells; only "__needs_caution__"
# and "__needs_intervention__" are used for dropping below.
entire_labels = {
    "__casual__",
    "__possibly_needs_caution__",
    "__probably_needs_caution__",
    "__needs_caution__",
    "__needs_intervention__",
}

In [None]:
# Group flagged therapist utterances by label, and record the idx of each
# dialogue once per flagged utterance (de-duplicated later with set()).
therapist_safety_dic = {}
drop_candidates_t = []

for row in tqdm(canary_result, total=len(canary_result), desc="Postprocessing Canary Results"):
    idx = row['idx']

    for utt in row['therapist_safety']:
        safety, reason = parse_safety(utt['safety'])
        if safety not in ("__needs_caution__", "__needs_intervention__"):
            continue
        drop_candidates_t.append(idx)
        therapist_safety_dic.setdefault(safety, []).append(utt)

Postprocessing Canary Results:   0%|          | 0/41223 [00:00<?, ?it/s]

Postprocessing Canary Results: 100%|██████████| 41223/41223 [00:00<00:00, 294783.62it/s]


In [85]:
# (distinct dialogues with a flagged therapist utterance, total rows in canary_result)
(len(set(drop_candidates_t)), len(canary_result))

(452, 41223)

In [None]:
# Append one record per safety-flagged dialogue to the shared drop_idx.jsonl.
# NOTE(review): the 'need intervention' type also covers "__needs_caution__"
# hits, and write_line appends, so re-running this cell duplicates records.
flagged_idx = set(drop_candidates_t)
for idx in flagged_idx:
    write_line(entry={'type': 'need intervention', 'idx': idx},
               path=pjoin(POSTPROC_RESULT_DIR, "drop_idx.jsonl"))

---

### **Copy-Paste Filtering**

In [27]:
import re

def split_into_sentences(text):
    """
    Split `text` at whitespace that follows '.', '!' or '?'.

    Each piece is stripped and empty pieces are dropped; the terminal
    punctuation stays attached to its sentence.
    """
    pieces = re.split(r'(?<=[.!?])\s+', text)
    return [piece for piece in map(str.strip, pieces) if piece]


In [None]:
# Columns of `mirror` holding prompt text that must not be copied verbatim into utterances.
prompt_content = [
    'personal_info',
    'personality',
    'reason_for_seeking_counseling',
    'distorted_thought',
    'cbt_plan',
]

In [None]:
# Flag dialogues where some utterance contains, verbatim, a sentence from one
# of the prompt fields in `prompt_content`. Sentences starting with a list
# number "0." .. "10." are ignored.
NUMBERED_PREFIXES = tuple(f"{n}." for n in range(11))

drop = []
for i, row in tqdm(mirror.iterrows(), total=len(mirror), desc="Copy-Paste Filtering"):
    # Split the prompt fields once per dialogue, not once per utterance.
    # Non-string cells (e.g. missing values) contribute no sentences.
    prompt_sents = [
        s
        for k in prompt_content
        if isinstance(row[k], str)
        for s in split_into_sentences(row[k])
        if not s.startswith(NUMBERED_PREFIXES)
    ]
    # any() stops at the first copied sentence. Later cells only use set(drop),
    # so recording each dialogue once is sufficient.
    if any(s in utt['statement'] for utt in row['proc_dialogue'] for s in prompt_sents):
        drop += [row['idx']]

100%|██████████| 41223/41223 [00:52<00:00, 781.64it/s] 


In [77]:
# Persist the de-duplicated copy-paste drop list (one idx per line).
unique_copy_paste_idx = list(set(drop))
write_jsonl(pjoin(POSTPROC_RESULT_DIR, "copy-paste", "drop.jsonl"), unique_copy_paste_idx)

In [None]:
# Append one 'copy-paste' record per flagged dialogue to the shared drop_idx.jsonl.
# NOTE(review): write_line appends, so re-running this cell duplicates records.
for idx in set(drop):
    entry = {'type': 'copy-paste', 'idx': idx}
    write_line(path=pjoin(POSTPROC_RESULT_DIR, "drop_idx.jsonl"), entry=entry)