# Representing Users and Files
#### Design Document: https://docs.google.com/document/d/1F84Nj3IQ-f_36bmmsOTuOo9u65gfD1WU5LKdnw8ShcY/edit?tab=t.x97j5jy1kop1#heading=h.1vu1g9fe3ujo
#### Optimizations:
- Re-ranking topics
- More information in file embeddings

# Dependencies

In [None]:
! pip install pandas
! pip install openai
! pip install sentence-transformers
! pip install torch
! pip install faiss

# Data

In [None]:
# Load every CSV file found in csv_dir into a pandas DataFrame.
# `dfs` maps the file name without its ".csv" extension to the loaded DataFrame.
import pandas as pd
import os

csv_dir = "./csvs"

dfs = {}
# sorted() keeps the load order (and the order of any failure messages) deterministic.
for filename in sorted(os.listdir(csv_dir)):
    # splitext only strips the final extension, so "a.v1.csv" and "a.v2.csv"
    # no longer collapse onto the same key.
    name_no_ext, ext = os.path.splitext(filename)
    if ext.lower() != ".csv":
        # Skip anything that is not a CSV (hidden files, checkpoints, subdirectories).
        continue
    path = os.path.join(csv_dir, filename)
    try:
        dfs[name_no_ext] = pd.read_csv(path)
    except Exception as e:
        # Best effort: report the unreadable file and keep loading the rest.
        print(f"Failed to read {path}: {e}")

In [None]:
# For every loaded DataFrame, print "<name>: <shape>" and render its first rows
from IPython.display import display

for name in dfs:
    df = dfs[name]
    print(f"{name}: {df.shape}")
    display(df.head())

# Topic Embeddings

In [None]:
# Load the sentence-embedding model. `model` is the object whose `.encode(...)`
# is called by get_embedding_for_text to embed topic strings.
# NOTE(review): the "zh" in the checkpoint name suggests a Chinese-language model —
# confirm this matches the language of the topic labels.
from sentence_transformers import SentenceTransformer

model_name = 'BAAI/bge-large-zh-v1.5'
model = SentenceTransformer(model_name)

In [None]:
# 0.5) Cache any embeddings we calculate (lowercased text -> embedding).
# Kept in a separate cell so re-running other cells does not wipe the cache.
embeddings_cache = dict()

In [None]:
# 1) For each file, get its topics. Create dict of {file_hash_id: {"labels": [{"id": label_id, "name": topic_text}], "embedding": mean_topic_embedding}}
# 1.1) Create topic embedding with embedding model, store in cache
# 1.2) Average those embeddings to get the file representation
import numpy as np
from tqdm import tqdm

def average_embeddings(embeddings):
  """Return the element-wise mean (along axis 0) of a sequence of embedding vectors."""
  stacked = np.asarray(embeddings)
  return stacked.mean(axis=0)

def get_embedding_for_text(text: str):
  """Return the embedding of `text`, memoized in `embeddings_cache`.

  The text is lowercased first; the lowercased form is both the cache key and
  the string handed to the model.
  """
  key = text.lower()

  try:
    return embeddings_cache[key]
  except KeyError:
    pass

  # normalize_embeddings=True is requested so the vectors suit cosine similarity
  vector = model.encode(key, normalize_embeddings=True)
  embeddings_cache[key] = vector
  return vector

def get_file_embeddings(limit=None):
  """Build a topic-based embedding for every file that has "topic" labels.

  Reads dfs['resource_label'], keeps the rows whose 'name' is "topic", groups
  their 'value' strings by 'hash_id', and sets each file's embedding to the
  mean of its topic embeddings.

  Args:
    limit: Optional maximum number of topic label rows (rows, not files) to
      process, useful for quick trial runs. None processes every row; zero or
      a negative number processes none.

  Returns:
    dict mapping hash_id -> {"labels": [{"id": ..., "name": ...}, ...],
    "embedding": mean of the label embeddings}.
  """
  rl = dfs['resource_label']
  topics = rl[rl['name'] == "topic"]

  # Truncate up front so that limit=N really means N rows. The previous
  # countdown stopped one row early, so limit=1 processed nothing.
  if limit is not None:
    topics = topics.head(max(limit, 0))

  print("Number of resources to work through: ", len(topics))
  files = {}
  # total= lets tqdm show real progress; iterrows() alone has no length.
  for _, resource_info in tqdm(topics.iterrows(), total=len(topics)):
    file_info = files.setdefault(resource_info['hash_id'], {
      "labels": [],
      "embedding": []
    })
    file_info["labels"].append({
      "id": resource_info["id"],
      "name": resource_info["value"]
    })

  print("Number of files to work through: ", len(files))
  for info in tqdm(files.values()):
    info['embedding'] = average_embeddings([get_embedding_for_text(label['name']) for label in info['labels']])
  return files

files = get_file_embeddings()

In [None]:
# 1.3) Store our results
import pickle
from datetime import datetime

def store_results():
  """Pickle the topic-embedding cache and the per-file embeddings under ./pickle.

  Writes two files:
    * pickle/topic_embeddings_bge-large-zh-v1.5 — the `embeddings_cache` dict
      (overwritten on every call).
    * pickle/file_embeddings_<timestamp> — the `files` dict, one new file per call.

  NOTE(review): the cache file name hard-codes the model name; keep it in sync
  if `model_name` changes.
  """
  import os  # local import: this cell's import list does not include os

  # open(..., "wb") does not create missing directories.
  os.makedirs("pickle", exist_ok=True)

  with open("pickle/topic_embeddings_bge-large-zh-v1.5", "wb") as f:
    pickle.dump(embeddings_cache, f)

  # str(datetime.now()) contains spaces and colons, which are awkward in shells
  # and invalid in Windows file names; use a filesystem-safe, sortable stamp.
  stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
  with open("pickle/file_embeddings_" + stamp, "wb") as f:
    pickle.dump(files, f)

In [None]:
# Scratch cell: bind short aliases for the raw tables, then show the label rows
# recorded for one file hash (the last expression is what the notebook renders).
ra, rrn, rr, rl = (
    dfs['resource_auditrecord'],
    dfs['resource_resourcenode'],
    dfs['resource_resource'],
    dfs['resource_label'],
)
# Other lookups used while debugging:
# rrn[rrn['id'] == 2444061]
# rr[rr['id'] == 814033]
rl.loc[rl['hash_id'] == 78048]

In [None]:
# 2) For each user, time-weight the relevant file-embeddings. Create dict of {user_id: embedding}
import math
from tqdm import tqdm

def get_stream(audit_id, file_id, streams, timestamp, operation):
    """Pick the stream row of a file that corresponds to an audit event.

    For operations that change the file ('MODIFIED', 'FILE_UPLOADED',
    'RENAMED') the earliest stream at or after `timestamp` is preferred;
    for every other operation the latest stream at or before it. If the
    preferred side has no candidates, the nearest stream on the other
    side is used instead.

    Args:
      audit_id, file_id: only used in the error message.
      streams: DataFrame of candidate stream rows with a 'timestamp' column.
      timestamp: time of the audit event, compared against streams['timestamp'].
      operation: audit operation name.

    Returns:
      The selected row of `streams`.

    Raises:
      ValueError: if `streams` is empty.
    """
    stream_count = len(streams)
    if stream_count == 0:
        raise ValueError(f"No streams for audit id: {audit_id}, file id: {file_id} ")

    # A single stream is always the answer, whatever its timestamp
    if stream_count == 1:
        return streams.iloc[0]

    ordered = streams.sort_values(by="timestamp")
    at_or_after = ordered[ordered['timestamp'] >= timestamp]
    at_or_before = ordered[ordered['timestamp'] <= timestamp]

    # The file was changed in some way, so the matching stream is the first one written after the event
    prefers_later = operation in ['MODIFIED', 'FILE_UPLOADED', 'RENAMED']
    if prefers_later:
        if len(at_or_after) != 0:
            return at_or_after.iloc[0]
        # Nothing newer: fall back to the most recent older stream
        return at_or_before.iloc[-1]

    if len(at_or_before) != 0:
        return at_or_before.iloc[-1]
    # Nothing older: fall back to the earliest newer stream
    return at_or_after.iloc[0]

def get_file_hash_as_of_audit(audit_id, operation, timestamp):
  """Resolve the content hash an audit event refers to.

  Looks up the resource node whose 'id' equals `audit_id` in
  dfs['resource_resourcenode'], follows its 'resource_id' into
  dfs['resource_resource'], and returns a 'hash_id':
    * if the resource is a STREAM, its own hash_id;
    * if it is a FILE, the hash_id of the child stream chosen by get_stream
      for this `operation` and `timestamp`.

  Args:
    audit_id: id of the audited resource node (callers pass the audit
      record's 'audited_id').
    operation: audit operation name, forwarded to get_stream.
    timestamp: time of the audit event, forwarded to get_stream.

  Returns:
    The hash_id of the relevant stream, or None if get_stream returns None.

  Raises:
    ValueError: if the node or resource lookup does not match exactly one
      row, if the resource is neither a STREAM nor a FILE, or (from
      get_stream) if the file has no streams.
  """
  rrn, rr = dfs['resource_resourcenode'], dfs['resource_resource']

  resource_ids = rrn.loc[rrn['id'] == audit_id, 'resource_id']
  # Report the actual match count: zero matches is just as possible as several.
  if len(resource_ids) != 1:
    raise ValueError(
      f"Expected exactly one resource node with id {audit_id}, found {len(resource_ids)}"
    )
  resource_id = resource_ids.iloc[0]

  # Grab the resource referenced by the audit log event
  resources = rr[rr['id'] == resource_id]
  if len(resources) != 1:
    raise ValueError(
      f"Expected exactly one resource with id {resource_id}, found {len(resources)}. Audit id: {audit_id}"
    )
  resource = resources.iloc[0]

  # A STREAM is the actual data, so its hash is the answer
  if resource['resource_type'] == "STREAM":
    return resource['hash_id']

  # Anything else must be a FILE. Raise explicitly instead of asserting,
  # because asserts are stripped when Python runs with -O.
  if resource['resource_type'] != "FILE":
    raise ValueError(
      f"Unsupported resource type {resource['resource_type']!r} for resource {resource_id}. Audit id: {audit_id}"
    )

  # Work our way down to the STREAM relevant for this audit log event
  streams = rr[rr['parent_id'] == resource['id']]
  stream = get_stream(audit_id, resource['id'], streams, timestamp, operation)
  if stream is None:
    # Defensive: kept in case get_stream signals "no match" with None
    return None
  return stream['hash_id']

def get_user_embedding(user, tau=60*60*24*30):
  """Fold a user's chronologically ordered file embeddings into one vector.

  Starts from the first file's embedding and blends in each later file with
  weight alpha = 1 - exp(-d_t / tau), where d_t is the number of seconds since
  the previous file (exponential decay). After a gap of exactly `tau` the
  running embedding keeps exp(-1), about 37%, of its previous value and takes
  the remaining ~63% from the new file.

  Args:
    user: dict with a 'file_infos' list. Each entry has a 'timestamp' string
      accepted by datetime.fromisoformat and an 'embedding' vector.
    tau: decay time constant in seconds. Defaults to 30 days.

  Returns:
    numpy array holding the time-weighted embedding.

  Raises:
    ValueError: if 'file_infos' is empty or the timestamps are not strictly
      increasing.
  """
  file_infos = user['file_infos']
  # Explicit raises instead of asserts: asserts vanish under `python -O`.
  if len(file_infos) == 0:
    raise ValueError("User has no file_infos to build an embedding from")

  user_embedding = np.array(file_infos[0]['embedding'])

  # Build up the embedding step by step over consecutive pairs of records
  for previous, current in zip(file_infos, file_infos[1:]):
    elapsed = datetime.fromisoformat(current['timestamp']) - datetime.fromisoformat(previous['timestamp'])
    d_t = elapsed.total_seconds()

    if d_t <= 0:
      raise ValueError("User file access times must be ordered")

    # d_t and tau are both in seconds
    alpha = 1 - math.exp(-d_t / tau)
    user_embedding = alpha * np.array(current['embedding']) + (1 - alpha) * user_embedding

  return user_embedding

def get_user_embeddings():
  """Build a time-weighted embedding for every user in the audit log.

  For each row of dfs['resource_auditrecord'] the audited resource is resolved
  to a file hash (get_file_hash_as_of_audit) and, when that hash has an entry
  in `files`, the file embedding is recorded against the row's user together
  with the operation timestamp. Each user's records are then sorted by
  timestamp and folded into one vector by get_user_embedding.

  Rows that cannot be resolved, and users whose embedding cannot be computed,
  are reported with print() and skipped.

  Returns:
    dict mapping user_id -> {"file_infos": [{"timestamp", "embedding"}, ...],
    "embedding": time-weighted vector}. Users with no usable records are
    left out.
  """
  from datetime import datetime  # local import: only needed for sorting timestamps here

  users = {}
  ar = dfs['resource_auditrecord']
  print("Processing audit log records:")
  for _, row_info in tqdm(ar.iterrows(), total=len(ar)):
    time_of_operation = row_info["timestamp"]
    try:
      hash_id = get_file_hash_as_of_audit(row_info['audited_id'], row_info['operation'], time_of_operation)
    except Exception as e:
      print(f"Exception caught: {e}")
      continue

    if hash_id not in files:
      print(f"Didn't find {hash_id} in file, must have failed topic extraction")
      continue

    # Create the user entry only once a record is usable, so a user whose
    # lookups all failed never reaches get_user_embedding with an empty list.
    user_info = users.setdefault(row_info['user_id'], {
      "file_infos": []
    })
    user_info['file_infos'].append({
      "timestamp": time_of_operation,
      "embedding": files[hash_id]["embedding"]
    })

  # 2.1) Time weight the embeddings
  print("Getting user embeddings: ")
  embedded_users = {}
  for user_id, user_info in tqdm(users.items()):
    file_infos = user_info['file_infos']
    try:
      # get_user_embedding needs chronological order, which the audit table
      # does not guarantee.
      if len(file_infos) > 1:
        file_infos.sort(key=lambda info: datetime.fromisoformat(info['timestamp']))
      user_info["embedding"] = get_user_embedding(user_info)
    except (AssertionError, ValueError, TypeError) as e:
      # e.g. two records with the same timestamp, or an unparseable timestamp.
      # Skip this user instead of losing the whole run.
      print(f"Skipping user {user_id}: {e}")
      continue
    embedded_users[user_id] = user_info
  return embedded_users

# users = get_user_embeddings()

In [153]:
# Sanity check: two records exactly 30 days (one default tau) apart, going from all-ones to all-zeros
user_1 = {
    'file_infos': [
        {'timestamp': '2025-09-28 01:29:31.77+00', 'embedding': [1, 1, 1]},
        {'timestamp': '2025-10-28 01:29:31.77+00', 'embedding': [0, 0, 0]},
    ]
}
get_user_embedding(user_1)

array([0.36787944, 0.36787944, 0.36787944])

# Meta Embeddings

# User/File Embeddings

# Semantic Similarity

# Training

# Quality Control

# Anomaly Detection