# Representing Users and Files
#### Design Document: https://docs.google.com/document/d/1F84Nj3IQ-f_36bmmsOTuOo9u65gfD1WU5LKdnw8ShcY/edit?tab=t.x97j5jy1kop1#heading=h.1vu1g9fe3ujo
#### Optimizations:
- Re-ranking topics
- More information in file embeddings

# Dependencies

In [None]:
! pip install pandas
! pip install openai
! pip install sentence-transformers
! pip install torch

# Utilities/Constants

In [None]:
# Load CSV files representing database into pandas DataFrames
import pandas as pd
import os

CSV_DIR = "./csvs"

# Load every CSV in CSV_DIR into DFS, keyed by the file name without its extension
DFS = {}
for filename in sorted(os.listdir(CSV_DIR)):  # sorted -> deterministic load/display order
    # splitext keeps dotted stems intact ("a.b.csv" -> "a.b") instead of cutting at the first dot
    name_no_ext, ext = os.path.splitext(filename)
    if ext.lower() != ".csv":
        # Skip stray files (e.g. .DS_Store) that would otherwise be read or fail noisily
        continue
    path = os.path.join(CSV_DIR, filename)
    try:
        DFS[name_no_ext] = pd.read_csv(path)
    except Exception as e:
        # Best effort: report unreadable files and keep loading the rest
        print(f"Failed to read {path}: {e}")

# Print brief summary and show first few rows for each loaded dataframe
from IPython.display import display

for name, df in DFS.items():
    print(f"{name}: {df.shape}")
    display(df.head())

# Distinct audit operations; per-operation time features are built for each of them
UNIQUE_OPS = DFS['resource_auditrecord']['operation'].unique()

In [None]:
from datetime import datetime, timezone

# Timezone-aware UTC "now" for this session; recency and time-bucket features are measured against it
NOW = datetime.now(timezone.utc)

# Directory holding the timestamped pickle caches (embeddings, mappings, users)
CACHE_DIR = "./pickle/"

In [None]:
# Cache any heavy embeddings or mappings we compute.
import os
import pickle
from datetime import datetime
import glob

def _most_recent_file(pattern):
  # Return the most recently modified file in cache_dir matching the glob pattern, or None if none found
  matches = glob.glob(os.path.join(CACHE_DIR, pattern))
  if not matches:
    return None
  return max(matches, key=os.path.getmtime)

def get_cache():
  """Load the newest cached artifacts from CACHE_DIR.

  Each artifact is chosen independently as the newest file matching its own
  pattern. The four objects may therefore come from different save runs.
  Missing artifacts fall back to an empty dict.

  NOTE: pickle.load can execute arbitrary code; only point CACHE_DIR at pickles
  this notebook wrote itself.

  Returns:
    (audit_to_file_mapping, embeddings_cache, files, users)
  """
  newest_paths = [
    _most_recent_file(pattern)
    for pattern in (
      "audit_to_file_mapping_*.pkl",
      "topic_embeddings_bge-large-zh-v1.5_*.pkl",
      "file_embeddings_*.pkl",
      "user_embeddings_*.pkl",
    )
  ]

  loaded = []
  for cache_path in newest_paths:
    artifact = {}
    if cache_path and os.path.exists(cache_path):
      with open(cache_path, "rb") as handle:
        artifact = pickle.load(handle)
    loaded.append(artifact)
  audit_to_file_mapping, embeddings_cache, files, users = loaded

  print(f"Audit to file mapping size: {len(audit_to_file_mapping)}")
  print(f"Topic embeddings cache size: {len(embeddings_cache) if embeddings_cache is not None else 0}")
  print(f"Files size: {len(files) if files is not None else 0}")
  print(f"Users size: {len(users) if users is not None else 0}")

  return audit_to_file_mapping, embeddings_cache, files, users


def save_to_cache(audit_to_file_mapping, embeddings_cache, files, users):
  """Pickle the four session artifacts into CACHE_DIR under one shared timestamp.

  File names are '<prefix>_<YYYY-mm-dd_HH-MM-SS>.pkl' (local time), matching the
  glob patterns get_cache uses to pick the newest copy of each artifact.

  Each pickle is first written to a '<name>.pkl.tmp' sibling. It is then moved
  into place with os.replace. This way an interrupted dump never leaves a
  truncated '*.pkl' that get_cache would select as "newest" and fail to load.
  """
  # Ensure cache dir exists
  os.makedirs(CACHE_DIR, exist_ok=True)
  save_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
  artifacts = [
    (f"topic_embeddings_bge-large-zh-v1.5_{save_time}.pkl", embeddings_cache),
    (f"file_embeddings_{save_time}.pkl", files),
    (f"user_embeddings_{save_time}.pkl", users),
    (f"audit_to_file_mapping_{save_time}.pkl", audit_to_file_mapping),
  ]
  for filename, obj in artifacts:
    final_path = os.path.join(CACHE_DIR, filename)
    tmp_path = final_path + ".tmp"  # ends in .tmp, so get_cache's '*.pkl' globs ignore it
    with open(tmp_path, "wb") as f:
      pickle.dump(obj, f)
    os.replace(tmp_path, final_path)

# Restore the newest cached artifacts; each one is an empty dict when no cache file exists yet
AUDIT_TO_FILE_MAPPING, EMBEDDINGS_CACHE, FILES, USERS = get_cache()

In [None]:
# Database searching utilities
from functools import cache
from tqdm import tqdm

# Getting audit ids which are associated with a file we have information for
# This way we have a much smaller list of audit logs to parse when we are
# calculating user embeddings
@cache
def get_valid_audit_ids():
  """Return the ids of audit records, resource nodes and resources worth processing.

  A resource is "valid" when at least one row of resource_resource lists it as
  its parent_id (e.g. a FILE that has child STREAMs). A resource node is valid
  when it points at a valid resource. An audit record is valid when its
  audited_id is a valid resource node.

  Cached with functools.cache: computed once per session, so later changes to
  DFS are not picked up.

  Returns:
    (valid_audit_ids, valid_resource_node_ids, valid_resource_ids), each a set.
  """
  rr, rrn, ra = DFS["resource_resource"], DFS["resource_resourcenode"], DFS["resource_auditrecord"]

  # Vectorized membership tests; the previous per-row `rr[rr['parent_id'] == id]`
  # scan was O(len(rr)^2).
  print("Get valid resources")
  has_children = rr['id'].isin(rr['parent_id'].dropna())
  valid_resource_ids = set(rr.loc[has_children, 'id'])

  print("Get valid resource nodes")
  valid_resource_node_ids = set(rrn.loc[rrn['resource_id'].isin(valid_resource_ids), 'id'])

  print("Get valid audit ids")
  valid_audit_ids = set(ra.loc[ra['audited_id'].isin(valid_resource_node_ids), 'id'])
  return valid_audit_ids, valid_resource_node_ids, valid_resource_ids

def get_stream(audit_id, file_id, streams, timestamp, operation):
  """Choose which STREAM row of a file an audit event refers to.

  A single stream is returned as-is. Otherwise the streams are ordered by
  'timestamp':
    - for content-changing operations (MODIFIED, FILE_UPLOADED, RENAMED) the
      earliest stream at/after the event is preferred, else the latest before it;
    - for all other operations the latest stream at/before the event is
      preferred, else the earliest after it.

  `audit_id` and `file_id` are not used by the selection.
  Stream timestamps are compared with >= / <= against `timestamp`, so both must be
  mutually comparable (presumably same-format ISO strings — TODO confirm).
  An empty `streams` frame raises IndexError.
  """
  if len(streams) == 1:
    return streams.iloc[0]

  ordered = streams.sort_values(by="timestamp")
  not_before = ordered[ordered['timestamp'] >= timestamp]
  not_after = ordered[ordered['timestamp'] <= timestamp]

  if operation in ['MODIFIED', 'FILE_UPLOADED', 'RENAMED']:
    return not_before.iloc[0] if len(not_before) else not_after.iloc[-1]
  return not_after.iloc[-1] if len(not_after) else not_before.iloc[0]

def get_file_hash_as_of_audit(audit_id, operation, timestamp):
  """Resolve the content hash a resource node pointed at when an audit event happened.

  Despite its name, `audit_id` is matched against resource_resourcenode.id
  (callers pass an audit record's 'audited_id').

  Returns:
    (hash_id, True) when a STREAM is found, either directly or as the child of a
    FILE chosen by get_stream.
    (file_id, False) when the resource is a FILE with no child streams.

  Raises:
    ValueError: if the node or its resource cannot be found.
    AssertionError: if the resource is neither a STREAM nor a FILE.
  """
  rrn, rr = DFS['resource_resourcenode'], DFS['resource_resource']
  resource_ids = rrn.loc[rrn['id'] == audit_id, 'resource_id']
  if len(resource_ids) == 0:
    raise ValueError(f"No resource id found. Audit id: {audit_id}")
  # Assume only one resource id per node; take the first.
  resource_id = resource_ids.iloc[0]

  # Grab the resource referenced by the audit log event
  resource = rr[rr['id'] == resource_id]
  if len(resource) == 0:
    raise ValueError(f"No resource found: Audit id: {audit_id}")
  file_or_stream = resource.iloc[0]

  if file_or_stream['resource_type'] == "STREAM":
    # A stream carries the content hash directly
    return file_or_stream['hash_id'], True

  # Must be a file: work down to the STREAM (actual data) relevant for this event
  assert file_or_stream['resource_type'] == "FILE"
  streams = rr[rr['parent_id'] == file_or_stream['id']]
  if len(streams) == 0:  # was `len(streams) is None`, which is never true
    # NOTE(review): this returns a resource id, not a hash id. Callers that key by
    # hash (e.g. AUDIT_TO_FILE_MAPPING) may conflate the two id spaces.
    return file_or_stream['id'], False
  stream = get_stream(audit_id, file_or_stream['id'], streams, timestamp, operation)
  return stream['hash_id'], True

# User/File Embeddings

In [None]:
# Load embedding model
from sentence_transformers import SentenceTransformer

# Sentence-embedding model used to embed topic label text. The same model name is
# baked into the topic-embedding cache file names used by get_cache/save_to_cache,
# so changing it here should go together with changing those patterns.
model_name = 'BAAI/bge-large-zh-v1.5'
model = SentenceTransformer(model_name)  # presumably downloads weights on first use — verify in offline envs

In [None]:
# FILE/TOPIC embeddings (get topic embeddings as we calculate file embeddings)
import numpy as np
from tqdm import tqdm

def average_embeddings(embeddings):
  """Element-wise mean of a sequence of equally-sized embedding vectors."""
  stacked = np.asarray(embeddings)
  return stacked.mean(axis=0)

def get_embedding_for_text(text: str):
  """Return the embedding of `text`, memoized in EMBEDDINGS_CACHE.

  The text is lower-cased before lookup and encoding, so case variants share one entry.
  """
  key = text.lower()
  if key not in EMBEDDINGS_CACHE:
    # normalize_embeddings=True -> unit-length vectors, so a dot product equals cosine similarity
    EMBEDDINGS_CACHE[key] = model.encode(key, normalize_embeddings=True)
  return EMBEDDINGS_CACHE[key]

def get_file_embeddings():
  """Build {hash_id: {"labels": [...], "embedding": vector}} from topic labels.

  Every resource_label row named "topic" adds {"id", "name": <label value>} to
  the labels of its hash_id. Each file's embedding is then the mean of its
  labels' text embeddings.
  """
  labels = DFS['resource_label']
  topic_rows = labels[labels['name'] == "topic"]

  files = {}
  print("Working through resource labels")
  for _, label_row in tqdm(topic_rows.iterrows(), total=len(topic_rows)):
    entry = files.setdefault(label_row['hash_id'], {"labels": [], "embedding": []})
    entry["labels"].append({"id": label_row["id"], "name": label_row["value"]})

  print("Average topic embeddings for files")
  for entry in tqdm(files.values(), total=len(files)):
    label_vectors = [get_embedding_for_text(label['name']) for label in entry['labels']]
    entry['embedding'] = average_embeddings(label_vectors)
  return files

In [None]:
# USER embeddings, for each time-weight the relevant file-embeddings. Create dict of {user_id: embedding}
import math
from tqdm import tqdm

def aggregate_file_embeddings_per_user(valid_audit_ids):
  """Collect, per user, the embeddings of the indexed files they touched.

  Only audit records whose id is in `valid_audit_ids` are scanned. For each
  record, the audited node is resolved to a hash with get_file_hash_as_of_audit.
  Side effect: every resolved record is stored as
  AUDIT_TO_FILE_MAPPING[audit_id] = <id returned by get_file_hash_as_of_audit>.
  Nodes whose resolution raises are remembered and skipped on later rows.

  Returns:
    {user_id: {"file_infos": [{"timestamp": datetime, "embedding": ...}, ...]}}.
    A user whose first record fails to resolve still gets an entry, possibly
    with an empty file_infos list.
  """
  audit_log = DFS['resource_auditrecord']
  audit_log = audit_log[audit_log['id'].isin(valid_audit_ids)]

  users = {}
  unresolvable_nodes = set()

  print("Processing audit log records:")
  for _, record in tqdm(audit_log.iterrows(), total=len(audit_log)):
    node_id = record['audited_id']
    if node_id in unresolvable_nodes:
      continue

    profile = users.setdefault(record['user_id'], {"file_infos": []})
    event_time = record["timestamp"]
    try:
      # Most relevant file version as of the audit time
      hash_id, _found_streams = get_file_hash_as_of_audit(node_id, record['operation'], event_time)
    except Exception:
      # Best effort: unresolvable nodes are skipped rather than aborting the scan
      unresolvable_nodes.add(node_id)
      continue

    # Kept for later metadata processing, even when the file has no topic embedding
    AUDIT_TO_FILE_MAPPING[record['id']] = hash_id

    if hash_id not in FILES:
      # No topic embedding for this hash (topic extraction presumably failed)
      continue

    profile['file_infos'].append({
      "timestamp": datetime.fromisoformat(event_time),
      "embedding": FILES[hash_id]["embedding"],
    })

  return users

from datetime import datetime

def get_user_embedding(user_info, tau=60*60*24*30):
  '''
  Fold a user's time-ordered file embeddings into one vector with an
  exponentially time-weighted moving average.

  For consecutive events d_t seconds apart, the new file embedding is blended
  in with weight alpha = 1 - exp(-d_t / tau), and the running embedding keeps
  weight exp(-d_t / tau). With the default tau of 30 days, an event arriving a
  month after the previous one keeps ~37% (e^-1) of the old embedding and takes
  ~63% from the new file.

  Notes:
    - The weight depends on the gap *between* events. An event with the same
      timestamp as its predecessor (d_t == 0) therefore contributes nothing.
    - No decay towards "now" is applied after the last event.
    - The result is not re-normalized.

  Args:
    user_info: dict whose 'file_infos' list is non-empty and sorted by
      'timestamp' (datetime). Each entry also has an 'embedding'.
    tau: decay time constant in seconds (default 30 days).

  Returns:
    np.ndarray with the aggregated embedding.

  Raises:
    AssertionError: if the timestamps are not in ascending order.
    IndexError: if 'file_infos' is empty.
  '''
  file_infos = user_info['file_infos']
  user_embedding = np.array(file_infos[0]['embedding'])
  embedding_time = file_infos[0]['timestamp']

  # Build up the index recursively
  for file_info in file_infos[1:]:
    event_time = file_info['timestamp']
    d_t = (event_time - embedding_time).total_seconds()
    # Check ordering *before* advancing embedding_time so the message shows both endpoints
    assert d_t >= 0, f"User file access times must be ordered: from - {embedding_time}, to - {event_time}"
    embedding_time = event_time

    # d_t and tau are both in seconds
    alpha = 1 - math.exp(-d_t / tau)
    user_embedding = alpha * np.array(file_info['embedding']) + (1 - alpha) * user_embedding

  return user_embedding

def get_user_embeddings(users):
  """Attach a time-weighted 'embedding' to every user that has file events.

  Each user's 'file_infos' is replaced by a copy sorted by timestamp before
  aggregation. Users without file events are reported and get no 'embedding' key.
  The same `users` dict is updated and returned.
  """
  # 2.1) Time weight the file embeddings
  print("Calculating user embeddings: ")
  for user_id, profile in tqdm(users.items()):
    profile['file_infos'] = sorted(profile['file_infos'], key=lambda event: event['timestamp'])
    if not profile['file_infos']:
      print(f"User has no files {user_id}")
      continue
    profile["embedding"] = get_user_embedding(profile)
  return users

In [None]:
# Number of audit records resolved to a file/hash id so far
len(AUDIT_TO_FILE_MAPPING)

In [None]:
# Peek at the first user's record (raises IndexError if USERS is empty)
USERS[list(USERS.keys())[0]]

In [None]:
# Reset the in-memory artifacts before recomputing them. EMBEDDINGS_CACHE is kept,
# so already-encoded topic texts are reused. AUDIT_TO_FILE_MAPPING starts empty and
# is refilled by aggregate_file_embeddings_per_user, which writes to this
# module-level dict.
FILES, USERS, AUDIT_TO_FILE_MAPPING, EMBEDDINGS_CACHE = None, None, dict(), EMBEDDINGS_CACHE
FILES = get_file_embeddings()

# Optimization: only process audit records on resources that have child resources
# (see get_valid_audit_ids). That function is @cache'd, so it runs once per session.
valid_audit_ids, _, _ = get_valid_audit_ids()
USERS_AGG = aggregate_file_embeddings_per_user(valid_audit_ids)

In [None]:
# Time-weight each user's file embeddings; USERS is the same dict object as USERS_AGG, now with 'embedding' keys
USERS = get_user_embeddings(USERS_AGG)

In [None]:
# Persist the recomputed artifacts as new timestamped pickles in CACHE_DIR
save_to_cache(AUDIT_TO_FILE_MAPPING, EMBEDDINGS_CACHE, FILES, USERS)

# Meta Embeddings

In [None]:
# Meta helpers
import math
from collections import Counter
from datetime import datetime, timezone, timedelta
from collections import Counter
import numpy as np

def get_interarrivals(timestamps):
  """Seconds between consecutive datetimes; asserts the input is ascending.

  Returns an empty list for zero or one timestamp.
  """
  gaps = []
  for earlier, later in zip(timestamps, timestamps[1:]):
    gap = (later - earlier).total_seconds()
    assert gap >= 0, f"not ordered: start_time: {earlier}, next_time: {later}"
    gaps.append(gap)
  return gaps

def get_mit(interarrivals):
  """Mean interarrival time, or -1 when there are no interarrivals."""
  if len(interarrivals) == 0:
    return -1
  return sum(interarrivals) / len(interarrivals)

def get_sit(interarrivals, mean):
  """Population standard deviation of interarrival times around `mean`; -1 if empty."""
  if len(interarrivals) == 0:
    return -1
  squared_total = sum(math.pow(gap - mean, 2) for gap in interarrivals)
  return math.sqrt(squared_total / len(interarrivals))

def get_fft_info(buckets, top_k=3):
  """Return the `top_k` strongest non-DC frequency components of a count series.

  Uses the real-input FFT (np.fft.rfft), which yields only non-negative
  frequencies. The full complex FFT of a real signal is conjugate-symmetric
  (|X[k]| == |X[n-k]|). Ranking its amplitudes would therefore return each peak
  twice (as +f and -f) and waste feature slots on mirror images.

  Args:
    buckets: 1-D sequence of per-bucket counts (sample spacing: one bucket).
    top_k: number of components to return (default 3).

  Returns:
    List of up to `top_k` (frequency, amplitude, phase) tuples, strongest
    first. Frequency is in cycles per bucket. The list is empty when there are
    fewer than two samples, since only the DC term would exist.
  """
  signal = np.asarray(buckets, dtype=float)
  if signal.size < 2:
    return []
  spectrum = np.fft.rfft(signal)
  freqs = np.fft.rfftfreq(signal.size, d=1)
  amplitudes = np.abs(spectrum)
  phases = np.angle(spectrum)
  # Skip index 0 (the DC component, i.e. the total count), then take the strongest bins
  strongest = np.argsort(amplitudes[1:])[::-1][:top_k] + 1
  return [(freqs[i], amplitudes[i], phases[i]) for i in strongest]

def get_burstiness(interarrivals, meta):
  """Placeholder for burstiness features; currently a no-op returning None.

  Args:
    interarrivals: gaps in seconds between consecutive events (as produced by
      get_interarrivals).
    meta: the metadata dict being filled; not modified yet.
  """
  # TODO: not implemented. One common definition (suggestion only, not what this
  # code does): Goh & Barabasi burstiness B = (sigma - mu) / (sigma + mu) over the
  # interarrival times.
  pass

def get_time_based_meta(start, end, interval, meta, operation, time_str, times):
  """Bucket event times into fixed windows and store counts plus FFT peaks.

  `times` must already be sorted ascending. The scan walks it once with a single
  cursor, which is only correct for sorted input. Times earlier than `start` are
  skipped, and times at or after the last bucket's end are ignored. Bucket i
  covers [start + i*interval, start + (i+1)*interval); buckets are produced
  while the bucket start is before `end`.

  Writes into meta['times']:
    f'{operation}_{time_str}_{i}'        -> count in bucket i
    f'{operation}_{time_str}_freq_{k}'   -> frequency of the k-th FFT peak
    f'{operation}_{time_str}_amp_{k}'    -> its amplitude
    f'{operation}_{time_str}_phase_{k}'  -> its phase
  """
  counts = []
  cursor, total = 0, len(times)
  window_start = start
  while window_start < end:
    window_end = window_start + interval
    hits = 0
    # Consume every remaining time that falls before this window's end
    while cursor < total and times[cursor] < window_end:
      if times[cursor] >= window_start:
        hits += 1
      cursor += 1
    counts.append(hits)
    window_start = window_end

  for bucket_idx, hits in enumerate(counts):
    meta['times'][f'{operation}_{time_str}_{bucket_idx}'] = hits
  for rank, (freq, amp, phase) in enumerate(get_fft_info(counts)):
    meta['times'][f'{operation}_{time_str}_freq_{rank}'] = freq
    meta['times'][f'{operation}_{time_str}_amp_{rank}'] = amp
    meta['times'][f'{operation}_{time_str}_phase_{rank}'] = phase

def get_entropy(values, normalized=True):
  """Shannon entropy (in bits) of the empirical distribution of `values`.

  With normalized=True (default) the entropy is divided by log2(k), where k is
  the number of distinct values, giving a score in [0, 1]. A single distinct
  value then scores 0.0. Empty input returns 0.0.
  """
  if not values:
    return 0.0
  tally = Counter(values)
  n = sum(tally.values())
  entropy = -sum((c / n) * math.log2(c / n) for c in tally.values())
  if not normalized:
    return entropy
  distinct = len(tally)
  return entropy / math.log2(distinct) if distinct > 1 else 0.0

def setup_metadata_from_audit_log(metadata, row_info):
  """Append one audit-log event to a file's or user's raw metadata.

  - Appends "<geolocation>-<client_ip>" to metadata['locs'].
  - Records {"timestamp": datetime, "user": user_id} under both
    metadata['times'][operation] and metadata['times']['all_ops']. The same
    event dict object is shared by both lists.
  """
  # The outer f-string uses double quotes so the nested ['...'] lookups also parse
  # on Python < 3.12 (reusing the same quote type inside needs PEP 701 / 3.12+).
  locs = metadata.setdefault('locs', [])
  locs.append(f"{row_info['geolocation']}-{row_info['client_ip']}")

  # Add timestamps for operations
  operation = row_info['operation']
  times = metadata.setdefault('times', {})
  all_operations = times.setdefault('all_ops', [])
  operation_times = times.setdefault(operation, [])
  event = {
    "timestamp": datetime.fromisoformat(row_info['timestamp']),
    "user": row_info["user_id"],
  }
  operation_times.append(event)  # per operation
  all_operations.append(event)   # all operations

def fill_metadata_stats_for_file_or_users(metadata, users, num_top=3):
  """Turn raw metadata (locs and per-operation event lists) into summary features.

  Args:
    metadata: dict built by setup_metadata_from_audit_log. It must contain a
      non-empty metadata['times']['all_ops'] list and metadata['locs'].
    users: user ids to profile (one per event for files, [user] for a user).
    num_top: how many top users / top locations to record.

  Writes into metadata (in this order):
    - 'mit', 'sit', 'recency' (seconds since NOW)
    - 'ud_entropy', 'unique_users', top_user_* keys
    - 'unique_locs', top_loc_* keys
    - per-operation and 'all_ops' time-bucket/FFT keys inside metadata['times']
  """
  event_times = sorted(event['timestamp'] for event in metadata['times']['all_ops'])
  interarrivals = get_interarrivals(event_times)
  metadata['mit'] = get_mit(interarrivals)  # mean interarrival time
  metadata['sit'] = get_sit(interarrivals, metadata['mit'])  # std interarrival time
  metadata['recency'] = (NOW - event_times[-1]).total_seconds()

  # User distribution
  metadata['ud_entropy'] = get_entropy(users)
  metadata['unique_users'] = len(set(users))
  for rank, (user, hits) in enumerate(Counter(users).most_common(num_top)):
    metadata[f'top_user_{rank}'] = user
    metadata[f'top_user_{rank}_per'] = hits / len(users)

  # Location distribution
  locs = metadata['locs']
  metadata['unique_locs'] = len(set(locs))
  for rank, (loc, hits) in enumerate(Counter(locs).most_common(num_top)):
    metadata[f'top_loc_{rank}'] = loc
    metadata[f'top_loc_{rank}_per'] = hits / len(locs)

  # (lookback span, bucket width, label) for the time-bucket features
  windows = (
    (timedelta(days=7), timedelta(hours=1), 'hour'),
    (timedelta(weeks=4), timedelta(days=1), 'day'),
    (timedelta(weeks=52), timedelta(weeks=4), 'month'),
  )
  for op in UNIQUE_OPS:
    op_events = metadata['times'].get(op) or []
    op_times = sorted(event['timestamp'] for event in op_events)
    for span, step, label in windows:
      get_time_based_meta(NOW - span, NOW, step, metadata, op, label, op_times)
  for span, step, label in windows:
    get_time_based_meta(NOW - span, NOW, step, metadata, "all_ops", label, event_times)

  get_burstiness(interarrivals, metadata)

def clean_up_metadata(metadata):
  """Drop the raw per-event data once the derived features have been computed.

  Removes metadata['locs'] and metadata['times']['all_ops'] (KeyError if either
  is missing). It also removes each non-empty raw event list stored under an
  operation name from UNIQUE_OPS.
  """
  metadata.pop('locs')
  times = metadata['times']
  times.pop('all_ops')
  for op in UNIQUE_OPS:
    if times.get(op):
      times.pop(op)

In [None]:
# USER: Extract meta features from audit log
from tqdm import tqdm

def get_user_info_from_auditlog(users, ra):
  """Accumulate raw metadata (locations, event times) per user from audit rows.

  Users not yet in `users` are added with an empty profile
  ({"file_infos": [], "embedding": []}), mirroring how the file counterpart adds
  unseen files. Before this, the profile for unknown users was built but never
  stored, so their metadata was computed and then thrown away. Rows without a
  user id are skipped.
  """
  print("Iterate Audit Log")
  for _, row_info in tqdm(ra.iterrows(), total=len(ra)):
    uid = row_info['user_id']
    if pd.isna(uid):
      # NaN keys never compare equal, so each such row would create a new entry
      continue
    user_info = users.setdefault(uid, {"file_infos": [], "embedding": []})

    # Create metadata object for user if doesn't exist
    metadata = user_info.setdefault("metadata", {})
    setup_metadata_from_audit_log(metadata, row_info)


def get_meta_features_users(users, num_top_locs=3, limit=1000000000):
  """Compute activity meta-features for every user from the audit log.

  Args:
    users: {user_id: profile} dict, updated in place and returned.
    num_top_locs: how many top locations (and top users) to record. Forwarded
      as `num_top`; previously this argument was accepted but ignored.
    limit: only the first `limit` audit rows are used.

  Returns:
    `users`, with a cleaned-up 'metadata' dict on every user that had activity.
  """
  ra = DFS['resource_auditrecord'].head(limit)

  # Delete any existing metadata so we don't mix between runs
  for user_info in users.values():
    user_info.pop('metadata', None)

  get_user_info_from_auditlog(users, ra)

  print("Update Users Metadata")
  for user, user_info in tqdm(users.items(), total=len(users)):
    metadata = user_info.get('metadata')
    if not metadata:
      continue
    if not metadata.get('locs'):
      print(f"No activity for user {user}")
      continue

    fill_metadata_stats_for_file_or_users(metadata, [user], num_top=num_top_locs)
    clean_up_metadata(metadata)

  return users

In [None]:
# FILE: Extract meta features from audit log
from collections import Counter

def get_file_info_from_auditlog(files, ra):
  """Accumulate raw metadata for every file hash referenced by the audit rows.

  Rows whose audit id is absent from AUDIT_TO_FILE_MAPPING are skipped. Hashes
  missing from `files` are added as empty dicts before their metadata is built.
  """
  print("Iterate Audit Log")
  for _, record in tqdm(ra.iterrows(), total=len(ra)):
    audit_log_id = record['id']
    if audit_log_id not in AUDIT_TO_FILE_MAPPING:
      continue
    file_info = files.setdefault(AUDIT_TO_FILE_MAPPING[audit_log_id], {})
    setup_metadata_from_audit_log(file_info.setdefault("metadata", {}), record)

def get_meta_features_files(files, num_top_locs=3, limit=1000000000):
  """Compute activity meta-features for every file hash from the audit log.

  Args:
    files: {hash_id: file_info} dict, updated in place and returned.
    num_top_locs: how many top locations (and top users) to record. Forwarded
      as `num_top`; previously this argument was accepted but ignored.
    limit: only the first `limit` audit rows are used.

  Returns:
    `files`, with a cleaned-up 'metadata' dict on every file that had activity.
  """
  ra = DFS['resource_auditrecord'].head(limit)

  # Delete any existing metadata so we don't mix between runs
  for file_info in files.values():
    file_info.pop('metadata', None)

  get_file_info_from_auditlog(files, ra)

  print("Update Files Metadata")
  for file_info in tqdm(files.values(), total=len(files)):
    metadata = file_info.get('metadata')
    if not metadata or not metadata.get('locs'):
      continue

    # One entry per event, so the user distribution reflects access frequency
    users = [op['user'] for op in metadata['times']['all_ops']]
    fill_metadata_stats_for_file_or_users(metadata, users, num_top=num_top_locs)
    clean_up_metadata(metadata)
  return files

In [None]:
# limit is far above any realistic row count, so head(limit) keeps every audit record
limit = 10000000000000
USERS = get_meta_features_users(USERS, limit=limit)
FILES = get_meta_features_files(FILES, limit=limit)

In [None]:
# Persist again so the cached FILES/USERS include the metadata features
save_to_cache(AUDIT_TO_FILE_MAPPING, EMBEDDINGS_CACHE, FILES, USERS)

In [None]:
# Explore data

from tqdm import tqdm

ra = DFS['resource_auditrecord']
rrn = DFS['resource_resourcenode']
rr = DFS['resource_resource']
rl = DFS['resource_label']
rh = DFS['resource_hash']
# rrn[rrn['id'] == 2444061]
# rr[rr['id'] == 814033]
# rr[rr['parent_id'] == 814033]
# rl[rl['hash_id'] == 78048]
# ra.head()
# ra[ra['audited_id'] == 2444061.0]
# len(rr['id'].unique())
# rr[rr['resource_type'] == "FILE"]
# rrn = dfs['resource_resourcenode']
# len(rrn['resource_id'].unique())

# Is id 340 the target of any audit record? Only a cell's last expression is
# displayed, so print it to actually see the answer.
print(340 in set(AUDIT_TO_FILE_MAPPING.values()))

# Pick one file hash to inspect (None when FILES has too few entries)
index = 4000
# users_list = list(USERS.keys())
# user_name = users_list[index]
files_list = list(FILES.keys())
file_hash = files_list[index] if index < len(files_list) else None
# USERS[user_name]['metadata']

# Count files that received metadata. Don't use `hash` as a loop variable here:
# at notebook top level it would stay bound and shadow the builtin hash().
count = sum(1 for file_info in FILES.values() if file_info.get("metadata"))

print(count)
# file_hash
# file_hash


# Training

# Quality Control

# Semantic Similarity

# Anomaly Detection