In [None]:
!pip install sentence-transformers

In [None]:
import json
from datetime import datetime
import textwrap
import re
from sentence_transformers import SentenceTransformer, util
import hashlib
import os
import glob

# Sentence-embedding model shared by every call to get_message_embedding().
model = SentenceTransformer('all-mpnet-base-v2')
# Minimum cosine similarity a memory must reach (against the previous message)
# to be included in the transcript; passed to filter_memories_by_similarity().
SIMILARITY_THRESHOLD = 0.15
# Embedding cache: SHA-256 hex digest of a text -> embedding returned by model.encode().
EMBEDDING_TRACKER_DICT = {}
# Memory strings that passed the similarity filter. Used as an insertion-ordered
# set: the keys are the memories and every value is True.
ALL_RELEVANT_MEMORIES = {}
# (line_code, text) tuples for the conversation currently being rendered;
# parse_chat_transcripts() copies and clears it after storing a conversation.
GLOBAL_LINES = []


def format_timestamp(timestamp):
    """Convert a POSIX timestamp to a human-readable UTC date and time.

    Args:
        timestamp: Seconds since the Unix epoch (int or float).

    Returns:
        A string of the form ``'YYYY-MM-DD HH:MM:SS UTC'``.
    """
    # Local import: the file only imports `datetime` from the datetime module.
    from datetime import timezone

    # datetime.utcfromtimestamp() is deprecated since Python 3.12; the
    # timezone-aware equivalent below renders exactly the same string.
    return datetime.fromtimestamp(timestamp, tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')


def is_only_whitespace(s):
    """Return True when *s* is empty or made up solely of whitespace characters."""
    # The anchored pattern only matches when nothing but whitespace
    # (spaces, tabs, newlines, ...) sits between start and end of the string.
    return re.match(r'^\s*$', s) is not None


def _append_wrapped(text, line_code, indent, width):
    """Wrap *text* to *width* and append every resulting line to GLOBAL_LINES.

    Each line is indented by ``indent + 4`` spaces and tagged with *line_code*.
    ``textwrap.fill`` collapses internal whitespace (newlines included), so a
    caller that wants to keep line breaks has to split the text beforehand.
    Text that wraps to nothing (empty / whitespace-only) appends no line.
    """
    pad = ' ' * (indent + 4)
    wrapped = textwrap.fill(text, width=width,
                            initial_indent=pad,
                            subsequent_indent=pad)
    for line in wrapped.splitlines():
        GLOBAL_LINES.append((line_code, line))


def pretty_print_message(message, indent=0, width=80, cutoff_date=None, prev_message=None):
    """Render a single message into GLOBAL_LINES with wrapping and newline preservation.

    Nothing is printed; ``(line_code, text)`` tuples are appended to the
    module-level ``GLOBAL_LINES``. The header line (timestamp, author role,
    model slug) is always tagged ``'[N]'``. Body lines are tagged:

    * ``'[P]'`` - user messages, messages carrying a browsing ``result`` or a
      ``model_set_context`` (memories);
    * ``'[C]'`` - assistant messages without ``text`` content whose recipient
      is ``"all"``;
    * ``'[N]'`` - everything else.

    Args:
        message: Message dict from the conversation mapping; a falsy value
            makes the call a no-op (after the indent check).
        indent: Number of spaces the header is indented; the body uses
            ``indent + 4``.
        width: Maximum line width used for wrapping.
        cutoff_date: ``datetime`` lower bound for memories. Memories are only
            processed when this is set and the message has a
            ``model_set_context``.
        prev_message: The parent message; its text parts are the query used to
            select relevant memories.

    Raises:
        ValueError: If ``indent + 4`` exceeds ``width``.
    """
    if indent+4 > width:
        raise ValueError("Indent cannot exceed the width.")

    if not message:
        return

    # `or {}` / `or ""` also cover keys that are present but null in the JSON.
    content = message.get('content') or {}
    author_role = (message.get('author') or {}).get('role') or 'unknown'
    content_parts = content.get('parts') or []
    content_searchText = content.get('text') or ""
    content_browsingResult = content.get('result') or ""
    content_memories = content.get('model_set_context')
    model_slug = (message.get('metadata') or {}).get('model_slug')
    recipient = message.get('recipient', "")  # will be "bio" for new memory creation

    if prev_message:
        prev_message_content_parts = (prev_message.get('content') or {}).get('parts') or []
    else:
        prev_message_content_parts = []

    # Skip empty system messages
    if author_role == 'system' and not any(content_parts):
        return

    timestamp = message.get('create_time')

    # Determine the line code based on the author role and content kind
    if author_role == 'user' or content_browsingResult or content_memories:
        line_code = '[P]'
    elif author_role == 'assistant' and not content_searchText and recipient == "all":
        line_code = '[C]'
    else:
        line_code = '[N]'

    # Prepare the header (timestamp, author role, and model slug)
    if timestamp:
        prefix = f"[{format_timestamp(timestamp)}] ({author_role.capitalize()})"
    else:
        prefix = f"({author_role.capitalize()})"
    if model_slug:
        prefix += f" - Model: {model_slug}"
    prefix += ':'
    GLOBAL_LINES.append(('[N]', ' ' * indent + prefix))

    # Wrap and append the message content, respecting newlines
    for part in content_parts:
        if isinstance(part, str):
            # Wrap each source line separately so the author's line breaks
            # survive; whitespace-only lines produce no output.
            for paragraph in part.splitlines():
                if is_only_whitespace(paragraph):
                    continue
                _append_wrapped(paragraph, line_code, indent, width)
        else:
            # Non-text parts (e.g. dict payloads) are shown via their str() form.
            _append_wrapped(str(part), line_code, indent, width)

    if content_memories and cutoff_date:
        filtered_memories = split_and_filter_memories(content_memories, cutoff_date)
        # Only string parts can be joined; skip structured parts of the
        # previous message instead of raising TypeError.
        subject_message = " ".join(
            part for part in prev_message_content_parts if isinstance(part, str)
        ).strip()
        relevant_memories = filter_memories_by_similarity(subject_message, filtered_memories, threshold=SIMILARITY_THRESHOLD)
        if relevant_memories:
            memory_info = f"Memories with similarity above {SIMILARITY_THRESHOLD} to last message:"
        else:
            memory_info = f"[Note from collator.ipynb - No memories found with similarity above {SIMILARITY_THRESHOLD} and more recent than {cutoff_date}.]"
        _append_wrapped(memory_info, '[N]', indent, width)
        for memory in relevant_memories:
            ALL_RELEVANT_MEMORIES[memory] = True
            _append_wrapped(memory, line_code, indent, width)

    if content_searchText:
        _append_wrapped(content_searchText, line_code, indent, width)
    if content_browsingResult:
        _append_wrapped(content_browsingResult, line_code, indent, width)


def split_and_filter_memories(memories, cutoff_date):
    """Split a memories string into entries and keep those dated on/after the cutoff.

    A new entry starts at every line that begins with ``<number>. [YYYY-MM-DD]``;
    following lines without that prefix stay attached to the entry above them.
    The first ``[YYYY-MM-DD]`` stamp found in an entry is taken as its date.

    Args:
        memories: The raw memories text. ``None`` or ``""`` yields ``[]``.
        cutoff_date: Naive ``datetime``; entries dated earlier are dropped.
            ``None`` disables the date filter (dated entries are all kept).

    Returns:
        The surviving entries, in their original order. Entries without a
        parseable date stamp are always dropped.
    """
    if not memories:
        return []

    # `\d+` (rather than exactly two digits) so entries numbered 1-9 and
    # 100+ are recognised as entry starts as well.
    memory_entries = re.split(r'\n(?=\d+\. \[\d{4}-\d{2}-\d{2}\])', memories)

    filtered_memories = []
    for entry in memory_entries:
        match = re.search(r'\[(\d{4}-\d{2}-\d{2})\]', entry)
        if not match:
            continue
        try:
            entry_date = datetime.strptime(match.group(1), '%Y-%m-%d')
        except ValueError:
            # Looks like a date but is not one (e.g. month 13): treat the
            # entry like an undated one instead of aborting the whole run.
            continue
        if cutoff_date is None or entry_date >= cutoff_date:
            filtered_memories.append(entry)

    return filtered_memories


def compute_similarity(embedding1, embedding2):
    """Return the cosine similarity between two embeddings as a plain Python number."""
    similarity = util.pytorch_cos_sim(embedding1, embedding2)
    # .item() unwraps a one-element tensor, so each argument is expected to be
    # a single embedding rather than a batch.
    return similarity.item()


def get_message_embedding(message_text):
    """Return ``(embedding, sha256_hex)`` for *message_text*, using the module cache.

    The SHA-256 hex digest of the UTF-8 encoded text is the cache key in
    ``EMBEDDING_TRACKER_DICT``; the model is only invoked for texts that have
    not been embedded before.
    """
    digest = hashlib.sha256(message_text.encode('utf-8')).hexdigest()
    if digest not in EMBEDDING_TRACKER_DICT:
        # Cache miss: encode once and remember the tensor for later calls.
        EMBEDDING_TRACKER_DICT[digest] = model.encode(message_text, convert_to_tensor=True)
    return EMBEDDING_TRACKER_DICT[digest], digest


def filter_memories_by_similarity(user_message, memories, threshold=0.7):
    """Return the memories whose similarity to *user_message* reaches *threshold*.

    Args:
        user_message: Text the memories are compared against.
        memories: Iterable of memory strings; ``None`` or empty yields ``[]``.
        threshold: Minimum cosine similarity (inclusive) for a memory to be kept.

    Returns:
        A list of the matching memories, in input order.
    """
    if not memories:
        # Nothing to compare, so don't spend a model call on the user message.
        return []

    # get_message_embedding() already stores every embedding it computes in
    # EMBEDDING_TRACKER_DICT, so no extra cache bookkeeping is needed here.
    user_embedding, _ = get_message_embedding(user_message)

    relevant_memories = []
    for memory in memories:
        memory_embedding, _ = get_message_embedding(memory)
        if compute_similarity(user_embedding, memory_embedding) >= threshold:
            relevant_memories.append(memory)
    return relevant_memories


def traverse_conversation(node_id, mapping, indent=0, width=80, first_call=True, second_call=False, cutoff_date=None, prev_message=None):
    """Walk the conversation tree depth-first, rendering every message on the way.

    Each node's message (if any) is handed to ``pretty_print_message`` before
    its children are visited; the node's own message is passed down as
    ``prev_message``. The first two levels of the tree are rendered at the
    same indent, every deeper level is shifted right by four more spaces.

    Returns:
        The node's own ``create_time`` when truthy, otherwise the first truthy
        value returned by one of its children (all children are still
        traversed). ``None`` when the node is missing or nothing has a
        timestamp.
    """
    node = mapping.get(node_id)
    if not node:
        print(f"Node with ID {node_id} not found.")
        return None

    found_timestamp = None
    message = node.get('message')
    if message:
        pretty_print_message(message, indent, width, cutoff_date, prev_message)
        found_timestamp = message.get('create_time')

    # Work out how the children are called: the first and second levels keep
    # the current indent, all later levels indent by a further 4 columns.
    if first_call:
        child_indent, child_is_second = indent, True
    elif second_call:
        child_indent, child_is_second = indent, False
    else:
        child_indent, child_is_second = indent + 4, False

    for child_id in node.get('children', []):
        child_timestamp = traverse_conversation(
            child_id, mapping, child_indent, width,
            first_call=False, second_call=child_is_second,
            cutoff_date=cutoff_date, prev_message=message,
        )
        if not found_timestamp:
            found_timestamp = child_timestamp

    return found_timestamp


def find_root_node(mapping):
    """Return the id of the first node that no other node lists as a child.

    Nodes are checked in the mapping's iteration order; ``None`` is returned
    when every node is somebody's child (or the mapping is empty).
    """
    referenced_ids = {
        child_id
        for node in mapping.values()
        for child_id in node.get('children', [])
    }
    return next((node_id for node_id in mapping if node_id not in referenced_ids), None)

def parse_chat_transcripts(directory_path, output_file2, width=80, cutoff_date=None):
    """Render every ``*.json`` transcript in a directory into one numbered text file.

    For each file the first conversation is rendered via
    ``traverse_conversation`` (which fills ``GLOBAL_LINES``). Conversations
    that produced lines and have a timestamp are kept, keyed by
    ``(sha256 of the rendered text, first timestamp)`` so identical duplicates
    collapse into one. The kept conversations are written in ascending
    timestamp order, each line prefixed with a right-aligned running line
    number and its line code.

    Args:
        directory_path: Directory that is searched (non-recursively) for
            ``*.json`` files.
        output_file2: Path of the text file to (over)write.
        width: Wrapping width forwarded to the renderer.
        cutoff_date: Memory cutoff ``datetime`` forwarded to the renderer.
    """
    conversations = {}

    for file_path in glob.glob(os.path.join(directory_path, '*.json')):
        # JSON is UTF-8 by specification; don't depend on the platform default.
        with open(file_path, 'r', encoding='utf-8') as f:
            chat_data = json.load(f)

        # Usually a list of conversations (only the first is used); a single
        # conversation object is accepted as well.
        if isinstance(chat_data, list):
            if not chat_data:
                print(f"{file_path}: no conversation found, skipping.")
                continue
            conversation = chat_data[0]
        else:
            conversation = chat_data

        mapping = conversation.get('mapping', {})
        root_id = find_root_node(mapping)

        # Start every file with an empty buffer so leftovers from an earlier
        # (failed or discarded) conversation cannot leak into this one.
        GLOBAL_LINES.clear()

        print(file_path)
        first_timestamp = traverse_conversation(root_id, mapping, width=width, cutoff_date=cutoff_date)
        print(first_timestamp)
        print()

        if GLOBAL_LINES and first_timestamp:
            rendered = "".join(line for _, line in GLOBAL_LINES)
            # backslashreplace keeps hashing from failing on lone surrogates
            # that JSON escapes can introduce; valid text hashes as before.
            sha256_hash = hashlib.sha256(rendered.encode('utf-8', errors='backslashreplace')).hexdigest()
            conversations[(sha256_hash, first_timestamp)] = list(GLOBAL_LINES)  # store a copy

        # Always reset, including for conversations that were not stored.
        GLOBAL_LINES.clear()

    # Oldest conversation first
    ordered = sorted(conversations.items(), key=lambda item: item[0][1])

    # Width of the line-number column
    total_lines = sum(len(lines) for _, lines in ordered)
    number_width = len(str(total_lines))

    # errors='backslashreplace' guarantees every line can be written, which
    # replaces the former bare `except:` that retried the same failing write.
    with open(output_file2, 'w', encoding='utf-8', errors='backslashreplace') as output_f:
        current_line = 1
        for _, lines in ordered:
            for line_code, line in lines:
                output_f.write(f"{current_line:>{number_width}} {line_code} | {line}\n")
                current_line += 1


# Run configuration: directory holding the exported *.json transcripts, the
# collated output file, and the earliest memory date (inclusive) to consider.
directory_path = '/content/json_files'
output_file2 = '/content/collated-unredacted.txt'
cutoff_date_str = '2020-07-01'
cutoff_date = datetime.strptime(cutoff_date_str, '%Y-%m-%d')
# Render all transcripts into output_file2, wrapping lines at 340 columns.
parse_chat_transcripts(directory_path, output_file2, width=340, cutoff_date=cutoff_date)

# List every memory that passed the similarity filter during the run,
# each followed by a separator line.
for k, v in ALL_RELEVANT_MEMORIES.items():
    print(k)
    print("--------")
