Symbol dynamics predictor

This notebook learns a deterministic rule from the challenge API using:
- Period detection for shape and color (KMP minimal period)
- Variable-order Markov (context) model as fallback

It minimizes API calls by caching past observations/predictions in a local log and only observing when needed. All intermediate steps are small and inspectable, and results are saved to disk.

Install requirements. This ensures the notebook runs in clean environments.

In [1]:
import sys, subprocess
def pip_install(pkgs):
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q"] + pkgs)
pip_install(["requests", "pandas", "matplotlib"])

Configuration: API base URL, token, and log/output file paths. The base URL and token are provided by the organizers. Update if needed.

In [2]:
import os, json, time, datetime as dt
import requests
import pandas as pd
from collections import Counter, defaultdict
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

BASE_URL = "https://challenge.gcp.katana-labs.com"
TOKEN = "f7ce2595765d1d3ce463a3825ddd67e3616bbe4bbe812fde17957d8fab7eb1f2"

LOG_JSONL = "symbol_dynamics_log.jsonl"
PRED_CSV = "predictions_log.csv"
STATS_CSV = "symbol_stats.csv"
STATS_PNG = "symbol_stats.png"

session = requests.Session()
session.headers.update({"Content-Type": "application/json"})
TIMEOUT = 30

Symbol utilities: define the 9-symbol vocabulary and helpers for encoding/decoding, and for splitting shape/color.

In [3]:
shapes = ["circle", "triangle", "square"]
colors = ["red", "green", "blue"]
all_symbols = [f"{s}_{c}" for s in shapes for c in colors]
symbol_to_id = {s: i for i, s in enumerate(all_symbols)}
id_to_symbol = {i: s for s, i in symbol_to_id.items()}

def split_symbol(sym: str) -> Tuple[str, str]:
    s, c = sym.split("_")
    return s, c

def shape_id(sym: str) -> int:
    s, _ = split_symbol(sym)
    return shapes.index(s)

def color_id(sym: str) -> int:
    _, c = split_symbol(sym)
    return colors.index(c)

def encode_symbol(sym: str) -> int:
    return symbol_to_id[sym]

def decode_symbol(idx: int) -> str:
    return id_to_symbol[idx]

Minimal period estimation using the KMP prefix function. This also works when the sample ends partway through a period: the sample length does not need to be a multiple of the period.

In [4]:
def kmp_prefix_function(arr: List[int]) -> List[int]:
    pi = [0] * len(arr)
    for i in range(1, len(arr)):
        j = pi[i-1]
        while j > 0 and arr[i] != arr[j]:
            j = pi[j-1]
        if arr[i] == arr[j]:
            j += 1
        pi[i] = j
    return pi

def minimal_period_prefix(arr: List[int]) -> Optional[int]:
    if len(arr) < 2:
        return None
    pi = kmp_prefix_function(arr)
    p = len(arr) - pi[-1]
    if p < len(arr):
        # verify the candidate period p covers all positions
        for i in range(p, len(arr)):
            if arr[i] != arr[i - p]:
                return None
        return p
    return None

@dataclass
class PeriodModel:
    shape_period: Optional[int]
    color_period: Optional[int]
    shape_pattern: Optional[List[int]]
    color_pattern: Optional[List[int]]

    def predict(self, t: int) -> Optional[str]:
        if self.shape_period is None or self.color_period is None:
            return None
        s = shapes[self.shape_pattern[t % self.shape_period]]
        c = colors[self.color_pattern[t % self.color_period]]
        return f"{s}_{c}"

def fit_period_model(symbols: List[str]) -> PeriodModel:
    shp = [shape_id(s) for s in symbols]
    col = [color_id(s) for s in symbols]
    sp = minimal_period_prefix(shp)
    cp = minimal_period_prefix(col)
    shape_pattern = shp[:sp] if sp else None
    color_pattern = col[:cp] if cp else None
    return PeriodModel(sp, cp, shape_pattern, color_pattern)
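
As a quick check of the period detection, here is a minimal self-contained sketch on toy integer sequences. It re-implements the prefix function so it runs on its own. The `min_repeats` argument is a hypothetical guard that is not part of the notebook: the plain check `p < len(arr)` accepts any overlap, so a sample whose last element merely equals its first is reported as having period `len(arr) - 1`.

```python
from typing import List, Optional

def prefix_function(arr: List[int]) -> List[int]:
    """pi[i] = length of the longest proper prefix of arr[:i+1] that is also its suffix."""
    pi = [0] * len(arr)
    for i in range(1, len(arr)):
        j = pi[i - 1]
        while j > 0 and arr[i] != arr[j]:
            j = pi[j - 1]
        if arr[i] == arr[j]:
            j += 1
        pi[i] = j
    return pi

def minimal_period(arr: List[int], min_repeats: int = 1) -> Optional[int]:
    """Smallest p with arr[i] == arr[i - p] for all i >= p, or None.

    min_repeats is a hypothetical guard (not in the notebook): the sample must
    span at least that many full periods before p is trusted.
    """
    if len(arr) < 2:
        return None
    p = len(arr) - prefix_function(arr)[-1]
    if p < len(arr) and len(arr) >= min_repeats * p:
        return p
    return None

partial = [0, 1, 2, 0, 1, 2, 0, 1]   # period 3, cut off mid-cycle
weak = [0, 1, 2, 0]                  # only one element of overlap

print(minimal_period(partial))               # 3
print(minimal_period(weak))                  # 3, on very thin evidence
print(minimal_period(weak, min_repeats=2))   # None
```

Requiring two full repetitions trades a slightly longer bootstrap for fewer spurious periods; whether that trade is worthwhile depends on how long the true period is.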

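Variable-order context model (n-gram with backoff). For each context of length 1 to max_order it stores a single next symbol, taking the majority vote when the observed continuation is not unique. Prediction tries the longest matching context first and backs off to shorter ones. This is a compact fallback for a finite-order rule that pure periodicity does not capture.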

In [5]:
class ContextModel:
    def __init__(self, max_order: int = 30):
        self.max_order = max_order
        self.counts: List[Dict[Tuple[int, ...], Counter]] = [defaultdict(Counter) for _ in range(self.max_order+1)]
        self.det_next: List[Dict[Tuple[int, ...], int]] = [dict() for _ in range(self.max_order+1)]

    def fit(self, seq: List[int]):
        self.counts = [defaultdict(Counter) for _ in range(self.max_order+1)]
        self.det_next = [dict() for _ in range(self.max_order+1)]
        n = len(seq)
        for i in range(n-1):
            up = min(self.max_order, i+1)
            for k in range(1, up+1):
                ctx = tuple(seq[i-k+1:i+1])
                nxt = seq[i+1]
                self.counts[k][ctx][nxt] += 1
        self._refresh()

    def _refresh(self):
        for k in range(1, self.max_order+1):
            d = {}
            for ctx, cnt in self.counts[k].items():
                d[ctx] = cnt.most_common(1)[0][0]
            self.det_next[k] = d

    def predict_next_id(self, history_ids: List[int]) -> Optional[int]:
        L = min(self.max_order, len(history_ids))
        for k in range(L, 0, -1):
            ctx = tuple(history_ids[-k:])
            nxt = self.det_next[k].get(ctx)
            if nxt is not None:
                return nxt
        return None

def fit_models(observed_symbols: List[str], max_order: int = 30) -> Tuple[PeriodModel, ContextModel]:
    period_model = fit_period_model(observed_symbols)
    ctx = ContextModel(max_order=max_order)
    ctx.fit([encode_symbol(s) for s in observed_symbols])
    return period_model, ctx

def period_available(pm: PeriodModel) -> bool:
    return pm.shape_period is not None and pm.color_period is not None
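
The backoff behaviour can be illustrated with a stripped-down sketch. It uses a single dictionary keyed by context tuples instead of one dictionary per order, which is a simplification for illustration and not the notebook's `ContextModel`. The toy sequence and the helper names `fit_contexts` and `predict_next` are hypothetical.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Optional, Tuple

def fit_contexts(seq: List[int], max_order: int) -> Dict[Tuple[int, ...], Counter]:
    """Count next-symbol frequencies for every context of length 1..max_order."""
    counts: Dict[Tuple[int, ...], Counter] = defaultdict(Counter)
    for i in range(len(seq) - 1):
        for k in range(1, min(max_order, i + 1) + 1):
            counts[tuple(seq[i - k + 1:i + 1])][seq[i + 1]] += 1
    return counts

def predict_next(counts: Dict[Tuple[int, ...], Counter],
                 history: List[int], max_order: int) -> Optional[int]:
    """Try the longest matching context first, then back off to shorter ones."""
    for k in range(min(max_order, len(history)), 0, -1):
        ctx = tuple(history[-k:])
        if ctx in counts:  # membership test does not create a defaultdict entry
            return counts[ctx].most_common(1)[0][0]
    return None

# Toy rule of order 2: two 0s are always followed by a 1, and a 1 by a 0.
seq = [0, 0, 1, 0, 0, 1, 0, 0, 1]
counts = fit_contexts(seq, max_order=3)

print(predict_next(counts, [0, 0], 3))      # 1
print(predict_next(counts, [7, 0, 0], 3))   # 1, after backing off from the unseen (7, 0, 0)
print(predict_next(counts, [7], 3))         # None, no context matches
```

A context of length 1 is ambiguous here (a 0 is followed by either 0 or 1), which is why the longer context is tried first.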

Logging helpers: we store every API interaction in a JSONL file. We also provide utilities to load the log and reconstruct observed and virtual histories, plus the latest server position. This minimizes re-observation across runs and API calls overall.

In [6]:
def append_log(entry: Dict):
    entry = dict(entry)
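    # datetime.utcnow() is deprecated; use a timezone-aware UTC timestamp and keep the "Z" suffix
    entry["ts"] = dt.datetime.now(dt.timezone.utc).isoformat().replace("+00:00", "Z")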
    with open(LOG_JSONL, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")

def load_log() -> List[Dict]:
    if not os.path.exists(LOG_JSONL):
        return []
    entries = []
    with open(LOG_JSONL, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entries.append(json.loads(line))
            except json.JSONDecodeError:
                continue  # skip corrupt or partially written lines
    return entries

def reconstruct_histories(entries: List[Dict]) -> Tuple[List[str], List[str], int]:
    observed_history: List[str] = []
    virtual_history: List[str] = []
    latest_pos = 0
    for e in entries:
        latest_pos = e.get("position", latest_pos)
        if e.get("type") == "observe":
            syms = e.get("symbols", [])
            observed_history.extend(syms)
            virtual_history.extend(syms)
        elif e.get("type") == "predict":
            preds = e.get("predictions", [])
            # Extend virtual history with our past predictions to keep time indexing consistent
            virtual_history.extend(preds)
    return observed_history, virtual_history, latest_pos

def ensure_pred_csv_header():
    if not os.path.exists(PRED_CSV):
        pd.DataFrame(columns=["timestamp","position","streak","correct","predictions"]).to_csv(PRED_CSV, index=False)

ensure_pred_csv_header()
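
To make the difference between the observed and virtual histories concrete, here is a small in-memory sketch that mirrors the reconstruction logic without touching the disk. The log entries and the helper name `rebuild` are made up for illustration; only the field names follow the notebook.

```python
from typing import Dict, List, Tuple

def rebuild(entries: List[Dict]) -> Tuple[List[str], List[str], int]:
    """Replay log entries into (observed, virtual, latest position)."""
    observed: List[str] = []
    virtual: List[str] = []
    pos = 0
    for e in entries:
        pos = e.get("position", pos)
        if e.get("type") == "observe":
            observed.extend(e.get("symbols", []))
            virtual.extend(e.get("symbols", []))
        elif e.get("type") == "predict":
            # predictions are unconfirmed, so they only enter the virtual timeline
            virtual.extend(e.get("predictions", []))
    return observed, virtual, pos

# Hypothetical log contents
log = [
    {"type": "observe", "symbols": ["circle_red", "square_blue"], "position": 2},
    {"type": "predict", "predictions": ["circle_red"], "correct": 1, "streak": 1, "position": 3},
]

observed, virtual, pos = rebuild(log)
print(len(observed), len(virtual), pos)   # 2 3 3
```

The virtual history is longer than the observed one by exactly the number of predicted symbols, which keeps its length in step with the server position in this example.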

API wrappers: small helpers that call /observe and /predict and append every response to the JSONL log, so later runs can reuse it instead of calling the API again. We always trust the server-reported position to stay in sync across sessions and runs.

In [7]:
def _url(path: str) -> str:
    return BASE_URL.rstrip("/") + path

def api_observe() -> Tuple[List[str], int]:
    r = session.post(_url("/observe"), json={"token": TOKEN}, timeout=TIMEOUT)
    r.raise_for_status()
    data = r.json()
    syms = data.get("symbols", [])
    pos = data.get("position", 0)
    append_log({"type": "observe", "symbols": syms, "position": pos})
    return syms, pos

def api_predict(preds: List[str]) -> Tuple[int, int, int]:
    payload = {"token": TOKEN, "predictions": preds}
    r = session.post(_url("/predict"), json=payload, timeout=TIMEOUT)
    r.raise_for_status()
    data = r.json()
    correct = int(data.get("correct", 0))
    streak = int(data.get("streak", 0))
    pos = int(data.get("position", 0))
    append_log({"type": "predict", "predictions": preds, "correct": correct, "streak": streak, "position": pos})
    # CSV for quick tracking
    row = {"timestamp": dt.datetime.now(dt.timezone.utc).isoformat().replace("+00:00", "Z"), "position": pos, "streak": streak, "correct": correct, "predictions": json.dumps(preds)}
    pd.DataFrame([row]).to_csv(PRED_CSV, mode="a", header=False, index=False)
    return correct, streak, pos

Load the existing log and reconstruct histories. We reuse past observations to train the models and past predictions to keep a virtual timeline, avoiding unnecessary observation calls. If the log is empty or too short, we make just enough observation calls to reach the bootstrap target.

In [8]:
entries = load_log()
observed_history, virtual_history, current_pos = reconstruct_histories(entries)
# Minimal bootstrap target (in symbols) if we have no data
MIN_BOOTSTRAP = 60
if len(observed_history) < MIN_BOOTSTRAP:
    # Ceiling division; assumes each /observe call returns 10 symbols
    needed = (MIN_BOOTSTRAP - len(observed_history) + 9) // 10
    for _ in range(needed):
        syms, current_pos = api_observe()
        observed_history.extend(syms)
        virtual_history.extend(syms)
    # Refresh entries if needed later
    entries = load_log()
    observed_history, virtual_history, current_pos = reconstruct_histories(entries)
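
The number of bootstrap calls is a ceiling division. The sketch below spells out the arithmetic; `calls_needed` is a hypothetical helper, and the batch size of 10 symbols per call is an assumption taken from the constant in the cell above, not a documented API guarantee.

```python
def calls_needed(have: int, target: int = 60, per_call: int = 10) -> int:
    """Observation calls required to reach `target` symbols, rounding up."""
    missing = max(0, target - have)
    return (missing + per_call - 1) // per_call

print(calls_needed(0))    # 6: an empty log needs six full batches
print(calls_needed(35))   # 3: 25 symbols missing, rounded up to three batches
print(calls_needed(60))   # 0: target already met
```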


Fit the period model and the context model. Both are retrained from observed history only, to avoid compounding errors from unconfirmed predictions. The period model predicts from the absolute position t; the context model uses recent history and serves as a fallback when periodicity is not fully detected or the period is too long to capture from the bootstrap data. Note that indexing the fitted pattern by t assumes the observed history is contiguous and starts at position 0.

In [9]:
period_model, context_model = fit_models(observed_history, max_order=30)
periodic = period_available(period_model)
period_info = {
    "shape_period": period_model.shape_period,
    "color_period": period_model.color_period,
    "periodic": periodic
}
with open("period_model_info.json", "w", encoding="utf-8") as f:
    json.dump(period_info, f, indent=2)
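
Fitting shape and color separately means the combined symbol sequence repeats with the least common multiple of the two periods. The sketch below illustrates this with made-up patterns (shape period 2, color period 3); the patterns are hypothetical and not taken from the challenge data.

```python
from math import gcd

shapes = ["circle", "triangle", "square"]
colors = ["red", "green", "blue"]

shape_pattern = [0, 1]       # hypothetical shape cycle, period 2
color_pattern = [0, 1, 2]    # hypothetical color cycle, period 3

def predict_at(t: int) -> str:
    """Index each component pattern by absolute position t, as PeriodModel.predict does."""
    s = shapes[shape_pattern[t % len(shape_pattern)]]
    c = colors[color_pattern[t % len(color_pattern)]]
    return f"{s}_{c}"

sp, cp = len(shape_pattern), len(color_pattern)
joint_period = sp * cp // gcd(sp, cp)   # lcm(2, 3) = 6

sequence = [predict_at(t) for t in range(2 * joint_period)]
print(joint_period)    # 6
print(sequence[:6])
```

One possible benefit of the factored model is that two short component periods can be recognised from a shorter sample than the joint symbol period would need.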

Save quick statistics about observed data for inspection: counts and a simple plot. This helps verify we are seeing structured data and stores artifacts for later review without re-running the API calls.

In [10]:
import matplotlib.pyplot as plt

sym_counts = Counter(observed_history)
df_counts = pd.DataFrame({"symbol": list(sym_counts.keys()), "count": list(sym_counts.values())}).sort_values("symbol")
df_counts.to_csv(STATS_CSV, index=False)

plt.figure(figsize=(9, 3))
plt.bar(df_counts["symbol"], df_counts["count"], color="steelblue")
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig(STATS_PNG, dpi=150)
plt.close()  # ensure no inline display remains after saving

Batch prediction function: combines the periodic predictor (when available) with the context-model fallback. The periodic predictor indexes by absolute server position, which is exact when the process is periodic; the context fallback reads its context from a virtual history. Each predicted symbol is appended to that virtual history, so the rollout stays continuous across rounds without extra API calls. We observe again only if the streak falls below a threshold, to re-sync and refit the models.

In [11]:
def predict_symbol_at(t_abs: int, virt_hist: List[str]) -> str:
    if period_available(period_model):
        ps = period_model.predict(t_abs)
        if ps is not None:
            return ps
    # context fallback
    hid = [encode_symbol(s) for s in virt_hist]
    nxt_id = context_model.predict_next_id(hid)
    if nxt_id is not None:
        return decode_symbol(nxt_id)
    # global frequency fallback
    if virt_hist:
        mc = Counter(virt_hist).most_common(1)[0][0]
        return mc
    return "circle_red"

def predict_batch(start_t_abs: int, virt_hist: List[str], n: int = 10) -> List[str]:
    preds = []
    for i in range(n):
        sym = predict_symbol_at(start_t_abs + i, virt_hist)
        preds.append(sym)
        virt_hist.append(sym)  # autoregressive rollout: feed the prediction back as context (mutates virt_hist)
    return preds
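
The rollout feeds each prediction back as context for the next step and appends to the caller's list in place. The sketch below shows both effects with a toy rule; `rollout` and `alternate` are hypothetical names used only for this illustration.

```python
from typing import Callable, List

def rollout(history: List[str], next_symbol: Callable[[List[str]], str], n: int) -> List[str]:
    """Predict n symbols, appending each one to `history` so it conditions the next step."""
    preds = []
    for _ in range(n):
        sym = next_symbol(history)
        preds.append(sym)
        history.append(sym)   # in-place update, visible to the caller
    return preds

def alternate(history: List[str]) -> str:
    """Hypothetical toy rule: alternate between two symbols based on the last one."""
    return "circle_red" if history and history[-1] == "square_blue" else "square_blue"

hist = ["circle_red"]
scratch = list(hist)                 # copy when the caller's history must stay untouched
dry_run = rollout(scratch, alternate, 3)
print(dry_run, len(hist))            # hist still has 1 element

preds = rollout(hist, alternate, 3)
print(preds, len(hist))              # hist now has 4 elements
```

Passing a copy is a cheap way to preview a batch without committing unconfirmed symbols to the virtual history.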

Play loop with minimal calls: we predict a few batches and only observe if our streak drops below a threshold. All interactions are logged to disk. You can re-run this cell later to continue from the same state without re-observing, thanks to the log reuse. Adjust rounds and thresholds if desired.

Note: If the process is periodic in both shape and color, predictions should achieve perfect streaks after a short bootstrap, making extra observations unnecessary.

In [12]:
# Parameters for minimal-call prediction loop
ROUNDS = 5
OBSERVE_ON_LOW_STREAK = True
LOW_STREAK_THRESHOLD = 6

# Ensure virtual history is aligned with last known server position
entries = load_log()
observed_history, virtual_history, current_pos = reconstruct_histories(entries)

for r in range(ROUNDS):
    preds = predict_batch(start_t_abs=current_pos, virt_hist=virtual_history, n=10)
    correct, streak, current_pos = api_predict(preds)
    # Optionally observe to re-sync/update models if performance drops
    if OBSERVE_ON_LOW_STREAK and streak < LOW_STREAK_THRESHOLD:
        syms, current_pos = api_observe()
        observed_history.extend(syms)
        # Update models with new ground-truth observations only
        period_model, context_model = fit_models(observed_history, max_order=30)
        periodic = period_available(period_model)
        # keep virtual history consistent with time: extend with observed true symbols as well
        virtual_history.extend(syms)
    # Save brief round summary alongside JSONL/CSV (already appended)
    with open("last_run_summary.txt", "w", encoding="utf-8") as f:
        f.write(f"round={r+1}, streak={streak}, correct={correct}, position={current_pos}, periodic={periodic}\n")
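
The loop can be dry-run offline before spending real API calls. The sketch below uses a hypothetical `FakeServer` that serves a fixed periodic pattern; its scoring (a plain count of matching predictions) is an assumption for illustration and does not reproduce the real API's `streak` field or response format.

```python
from typing import List, Tuple

class FakeServer:
    """Hypothetical stand-in for the challenge API; the scoring rule is assumed."""

    def __init__(self, pattern: List[str]):
        self.pattern = pattern
        self.position = 0

    def _take(self, n: int) -> List[str]:
        out = [self.pattern[(self.position + i) % len(self.pattern)] for i in range(n)]
        self.position += n
        return out

    def observe(self, n: int = 10) -> Tuple[List[str], int]:
        syms = self._take(n)
        return syms, self.position

    def predict(self, preds: List[str]) -> Tuple[int, int]:
        truth = self._take(len(preds))
        correct = sum(p == t for p, t in zip(preds, truth))
        return correct, self.position

server = FakeServer(["circle_red", "triangle_green", "square_blue"])
seen, pos = server.observe(6)     # bootstrap from position 0
period = 3                        # known here by construction

# Index the observed pattern by absolute position, as the period model does
preds = [seen[(pos + i) % period] for i in range(10)]
correct, pos = server.predict(preds)
print(correct, pos)               # 10 16
```

Because the observations start at position 0 and are contiguous, indexing by absolute position lines up with the server; a gap in the observed history would break that alignment.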

