In [1]:
import sys
sys.path.append('..')

In [2]:
from utils import dump_jsonl, load_jsonl
import requests
import time
from tqdm.notebook import tqdm



In [3]:
from itertools import groupby

def get_conversation(row):
    """Render a conversation's messages as a single transcript string.

    Args:
        row (dict): Conversation record with a "messages" list; every message
            must provide "user_id" and "text".

    Returns:
        str: One line per message, in list order, formatted as
        ``U<user_id>: <text>`` followed by a trailing space and a newline.
        An empty message list yields an empty string.
    """
    transcript_lines = [
        f"U{message['user_id']}: {message['text']} \n"
        for message in row["messages"]
    ]
    return "".join(transcript_lines)

In [4]:
def has_cycle(graph_edges):
    """
    Detects if a directed graph contains a cycle.

    The search is an iterative depth-first traversal, so arbitrarily long
    paths do not hit Python's recursion limit, and it relies only on built-in
    containers (no import from another cell is required).

    Args:
        graph_edges (dict): Format {(u, v): w} where u, v are nodes and w is weight.
            Only the keys are inspected; each key (u, v) is a directed edge
            u -> v. Any iterable of (u, v) pairs is accepted as well.

    Returns:
        bool: True if a cycle exists (self-loops included), False otherwise.
    """
    # 1. Build an adjacency list; every node gets an entry, even pure sinks.
    adjacency = {}
    for u, v in graph_edges:
        adjacency.setdefault(u, []).append(v)
        adjacency.setdefault(v, [])

    # 2. Three-state marking: UNSEEN -> ON_PATH (on the current DFS path) -> DONE.
    UNSEEN, ON_PATH, DONE = 0, 1, 2
    state = dict.fromkeys(adjacency, UNSEEN)

    # 3. Start a DFS from every unvisited node (handles disconnected graphs).
    for root in adjacency:
        if state[root] != UNSEEN:
            continue

        state[root] = ON_PATH
        # Each stack frame keeps a live iterator so that a node resumes where
        # it left off once one of its children has been fully explored.
        stack = [(root, iter(adjacency[root]))]
        while stack:
            node, neighbors = stack[-1]
            for neighbor in neighbors:
                if state[neighbor] == ON_PATH:
                    # Back-edge into the current path: cycle found.
                    return True
                if state[neighbor] == UNSEEN:
                    state[neighbor] = ON_PATH
                    stack.append((neighbor, iter(adjacency[neighbor])))
                    break
            else:
                # All neighbors explored: backtrack.
                state[node] = DONE
                stack.pop()

    return False


In [5]:
from collections import defaultdict
from itertools import groupby

def reAnnotatedUserId(conv, verbose=False):
    """Merge surplus user ids in a conversation onto its two main speakers.

    Consecutive messages written by different users form an undirected edge
    between those users. The most frequent edge is taken as the anchor pair
    (userA, userB). Any other user that exchanges at least three turns with
    one anchor is re-labelled as the *other* anchor (someone alternating with
    userA is treated as userB, and vice versa).

    Args:
        conv (dict): Conversation with "messages" (list of dicts that each
            carry "user_id") and "user_ids" (collection of ids). It is
            modified in place.
        verbose (bool): Accepted for call compatibility; currently unused.

    Returns:
        dict: The same ``conv`` object. Unless an early return is taken,
        message "user_id" values are rewritten and "user_ids" is recomputed
        as the set of ids still present in the messages.
    """
    if len(conv["user_ids"]) == 2:
        return conv

    # Count speaker changes between adjacent messages, per unordered pair.
    edgeCount = defaultdict(int)
    highestWeight = -1
    mostCommonEdge = None
    messages = conv["messages"]
    for prev, curr in zip(messages, messages[1:]):
        u1 = prev["user_id"]
        u2 = curr["user_id"]
        if u1 == u2:
            continue
        edge = (min(u1, u2), max(u1, u2))
        edgeCount[edge] += 1

        if edgeCount[edge] > highestWeight:
            highestWeight = edgeCount[edge]
            mostCommonEdge = edge

    # No speaker change at all (empty or single-speaker conversation): there
    # is no anchor pair to merge onto, so leave the conversation untouched
    # instead of failing while unpacking ``mostCommonEdge``.
    if mostCommonEdge is None:
        return conv

    # NOTE(review): every edge is stored as (min, max), i.e. it always points
    # from the smaller id to the larger one, so a directed cycle cannot occur
    # and this check cannot fire for orderable ids. If the goal is to flag
    # triangles between users, an undirected cycle test is needed - confirm.
    if has_cycle(edgeCount):
        print("ERROR")

    newUserId = {}
    userA, userB = mostCommonEdge
    anchors = (userA, userB)
    for u1, u2 in edgeCount:
        # Ignore weak evidence: fewer than three alternations.
        if edgeCount[(u1, u2)] < 3:
            continue

        if u1 == userA and u2 == userB:
            continue

        # Resolve endpoints that were already merged onto an anchor.
        if u1 in newUserId:
            u1 = newUserId[u1]

        if u2 in newUserId:
            u2 = newUserId[u2]

        if userA in (u1, u2):
            other, target = (u2 if userA == u1 else u1), userB
        elif userB in (u1, u2):
            other, target = (u2 if userB == u1 else u1), userA
        else:
            continue

        # Never re-label an anchor itself: after the resolution step above the
        # opposite endpoint can already be userA/userB, and mapping one anchor
        # onto the other would collapse the whole conversation to one speaker.
        if other in anchors:
            continue
        newUserId[other] = target

    # Fallback for exactly three users linked by two edges when nothing was
    # merged above: map the non-shared user of the first edge onto the
    # non-shared user of the second edge.
    if len(edgeCount) == 2 and len(conv["user_ids"]) == 3 and len(newUserId) == 0:
        e = list(edgeCount.keys())
        e1 = set(e[0])
        e2 = set(e[1])

        commonId = e1 & e2
        newUserId[list(e1 - commonId)[0]] = list(e2 - commonId)[0]

    for m in conv["messages"]:
        if m["user_id"] in newUserId:
            m["user_id"] = newUserId[m["user_id"]]

    conv["user_ids"] = set([m["user_id"] for m in conv["messages"]])

    return conv

# reAnnotatedUserId(conv);

In [6]:
import json
# Read the user-id remapping table from disk. Later cells look ids up in it
# by their string form and convert the mapped value with int().
with open("./raw_data/re_annotated_user_ids.json") as mapping_file:
    newUserIds = json.load(mapping_file)
# newUserIds

In [9]:
# Load every annotated conversation, apply the id remapping table, merge
# surplus user ids, and attach the per-conversation summary fields.
nError = 0  # conversations that still have more than two user ids after merging
# NOTE(review): "conersations" in the path looks like a typo, but it is the
# name the recorded cell output reports loading from, so it is kept as-is.
annotated_conversations = load_jsonl("./raw_data/annotated_conersations.jsonl")
for conv in annotated_conversations:
    # Chronological order: the first message defines user A below.
    conv["messages"].sort(key=lambda x: x["date_created"], reverse=False)
    for m in conv["messages"]:
        # The remapping table is keyed by the id's string form.
        if str(m["user_id"]) in newUserIds:
            m["user_id"] = int(newUserIds[str(m["user_id"])])

    conv["user_ids"] = set([m["user_id"] for m in conv["messages"]])
    if len(conv["user_ids"]) > 2:
        conv = reAnnotatedUserId(conv, verbose=True)
        if len(conv["user_ids"]) > 2:
            print(conv["user_ids"])
            nError += 1
            # break

    conv["userA_id"] = conv["messages"][0]["user_id"]
    conv["userA_name"] = None
    # A conversation with a single speaker has no second user: store None
    # instead of raising IndexError on an empty list.
    conv["userB_id"] = next(iter(conv["user_ids"] - {conv["userA_id"]}), None)
    conv["userB_name"] = None
    # NOTE(review): "revisit" is True when exactly two speakers remain; if the
    # flag is meant to mark conversations needing another look, the condition
    # may be inverted - confirm with the consumer of this field.
    conv["revisit"] = len(conv["user_ids"]) == 2

    conv["conversations"] = get_conversation(conv)
    # conv["user_ids"] = None
    # conv["messages"] = None
nError

Loaded 1234 records from ./raw_data/annotated_conersations.jsonl


0

In [10]:
# Rebuild the remapping table from the "new_user_ids" annotations stored on
# conversations that still have more than two user ids. An annotation is a
# string of "&"-separated groups, each group a ","-separated list of ids.
# Every id in a group maps to the group's smallest member; the ids are
# strings, so "smallest" is the lexicographic minimum.
newUserIds = {}
for conv in annotated_conversations:
    if len(conv["user_ids"]) > 2 and "new_user_ids" in conv:
        for group in conv["new_user_ids"].split("&"):
            members = group.split(",")
            representative = min(members)
            newUserIds.update(
                {uid: representative for uid in members if uid != representative}
            )
newUserIds

{}

In [11]:
from IPython.display import clear_output

# Count conversations that still have more than two user ids and carry no
# manual "new_user_ids" annotation yet. The commented lines are the
# interactive annotation prompt; uncomment them to label conversations by hand.
nError = 0
for conv in annotated_conversations:
    if len(conv["user_ids"]) <= 2:
        continue

    if "new_user_ids" in conv:
        continue

    nError += 1
    # print("UIDs", conv["user_ids"])
    # print(conv["conversations"])
    # newUserId = input()
    # if newUserId!="":
    #     conv["new_user_ids"] = newUserId
    # clear_output(wait=True)
# Display the counter computed above (the previous final expression `n` is
# not defined anywhere in this notebook and fails on a fresh kernel).
nError

0

In [12]:
# lines = text.split("\n")
# newUserIds = {}
# newIds = defaultdict(set)
# for line in lines:
#     if "UIDs" in line:
#         print(newIds)
#         if "A" in newIds:
#             mn = min(list(newIds["A"]))
#             for uid in newIds["A"]:
#                 newUserIds[uid] = mn
            
#             mn = min(list(newIds["B"]))
#             for uid in newIds["B"]:
#                 newUserIds[uid] = mn
            
#         newIds = defaultdict(set)
        
#         continue

#     if ":" not in line:
#         continue
        
#     u, idx = line.split(":")[0].split(" ")
#     idx = idx[1:]
#     newIds[u].add(idx)

# if "A" in newIds:
#     mn = min(list(newIds["A"]))
#     for uid in newIds["A"]:
#         newUserIds[uid] = mn
    
#     mn = min(list(newIds["B"]))
#     for uid in newIds["B"]:
#         newUserIds[uid] = mn
    