# In [31]:
#!/usr/bin/env python
# coding: utf-8
"""
PreemptiveNetworkClasses (restructured to mirror NetworkClasses.py)
------------------------------------------------------------------
- Mirrors the structure and naming of NetworkClasses.py (EventType/Event, Job, Queue, Server, Station,
  ArrivalProcess, SchedulingPolicy, Network).
- Adds **preemptive-resume** service: a running job can be preempted if the policy allows. We cancel
  the server's previously scheduled completion via an **event token** (eid). With Exp(μ) service,
  memorylessness implies we can re-sample when resuming; this is handled by simply starting service
  anew with a fresh sample.
- Fixes metrics and timing to match the non-preemptive file:
    * We accumulate **queue areas** and **service areas** (busy servers) for mean Lq and Ls.
    * We track sojourn times via exited_jobs list.
    * We expose run_warmup_and_measure(...) and summarize() with the same fields.
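- Policies follow the same shape: a central **SchedulingPolicy** whose decide(...) method receives
  free servers grouped by station and returns a mapping of Server -> Queue to start. Preemption is
  opt-in: the policy must set allow_preemption = True and implement best_queue_for_station(...)
  (which waiting queue would be served next) and should_preempt(...).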
"""

from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Callable, Deque, Dict, List, Optional, Tuple
from collections import deque, defaultdict
import heapq, itertools, math

import numpy as np


# ---------------- Event model ----------------

class EventType(Enum):
    """Event kinds processed by the simulator's run loop."""

    # Explicit values equal what auto() assigns (1, 2, ...) in declaration order.
    ARRIVAL = 1  # an external arrival process fires
    DEPARTURE = 2  # a scheduled service completion (may be stale after preemption)


@dataclass(order=True)
class Event:
    """
    Entry in the simulator's event heap.

    Comparison (and therefore heap order) uses only ``(time, order)``; ``order``
    is a sequence number from the network's counter that breaks ties between
    simultaneous events. ``type`` and ``payload`` never take part in comparisons.
    """

    time: float = field(compare=True)
    order: int = field(compare=True)
    type: EventType = field(compare=False)
    payload: Dict[str, Any] = field(default_factory=dict, compare=False)


# ---------------- Core domain ----------------

@dataclass
class Job:
    """
    A single customer moving through the network.

    Attributes:
        id: job number handed out by the network.
        cls: class label; networks map it to a queue id via cls_to_queue_id.
        t_arrival: time the job entered the system (sojourn times start here).
        q_arrival: time the job entered its current buffer (FCFS ordering key).
        t_service_start: time its most recent service (re)start began.
        t_departure: completion time of its most recent service.
        trace: one (station_id, queue_id, t_enter, t_start, t_leave) tuple per
            buffer visit; Queue.push fills t_start/t_leave with None.
    """

    id: int
    cls: str
    t_arrival: float
    q_arrival: float
    t_service_start: Optional[float] = field(default=None)
    t_departure: Optional[float] = field(default=None)
    trace: List[Tuple[str, str, float, Optional[float], Optional[float]]] = field(
        default_factory=list
    )


class Queue:
    """
    First-come-first-served buffer of jobs belonging to one station.

    Every insertion records a visit tuple (station_id, queue_id, t, None, None)
    on the job's trace. Swap the deque for a priority structure if needed.
    """

    def __init__(self, station_id: str, queue_id: str):
        self.station_id = station_id
        self.queue_id = queue_id
        self._buf: Deque[Job] = deque()

    def _log_visit(self, job: Job, t: float) -> None:
        """Append this buffer's visit record to the job's trace."""
        job.trace.append((self.station_id, self.queue_id, t, None, None))

    def push(self, job: Job, t: float) -> None:
        """Enqueue *job* at the tail at time *t*."""
        self._log_visit(job, t)
        self._buf.append(job)

    def push_front(self, job: Job, t: float) -> None:
        """Enqueue *job* at the head (used to requeue preempted jobs)."""
        self._log_visit(job, t)
        self._buf.appendleft(job)

    def pop(self) -> Optional[Job]:
        """Remove and return the head job, or None when empty."""
        return self._buf.popleft() if self._buf else None

    def peek(self) -> Optional[Job]:
        """Return the head job without removing it, or None when empty."""
        return self._buf[0] if self._buf else None

    def peek_n(self, n: int) -> Optional[Job]:
        """Return the job at position *n* (0 = head), or None if out of range."""
        # Deque indexing is O(n); callers only look a few positions deep.
        if 0 <= n < len(self._buf):
            return self._buf[n]
        return None

    def __len__(self) -> int:
        return len(self._buf)


class PreemptiveServer:
    """
    One server whose service can be interrupted (preemptive-resume).

    Each scheduled completion carries an integer token ("eid"). The server keeps
    only the token of its current service in ``eid_active``. Preemption clears
    it, so a completion event that later arrives with an old token is treated as
    stale and ignored by ``complete_if_match``.
    """

    def __init__(self, server_id: str, service_sampler: Callable[[Job], float]):
        self.server_id = server_id
        self.service_sampler = service_sampler  # job -> service duration
        self.busy: bool = False
        self.job: Optional[Job] = None
        self.t_busy_until: float = math.inf
        self.eid_active: Optional[int] = None  # token of the pending completion

    def is_idle(self) -> bool:
        return not self.busy

    def _release(self) -> Optional[Job]:
        """Reset to idle, invalidate the pending token and hand back the held job."""
        held = self.job
        self.job = None
        self.busy = False
        self.t_busy_until = math.inf
        self.eid_active = None
        return held

    def start_service(self, job: Job, t: float, schedule_departure: Callable[[float, Dict[str, Any]], int]) -> int:
        """
        Begin serving *job* at time *t*.

        A fresh service time is drawn from ``service_sampler`` (with exponential
        service this is how a resumed job "resumes"). ``schedule_departure`` is
        called with the completion time and a payload whose station_id the
        caller fills in; its return value becomes the active token.
        """
        self.busy, self.job = True, job
        job.t_service_start = t
        duration = self.service_sampler(job)
        self.t_busy_until = t + duration
        token = schedule_departure(self.t_busy_until, {"server": self, "station_id": None})
        self.eid_active = token
        return token

    def preempt(self) -> Optional[Job]:
        """Interrupt the running job and return it for requeueing (None if idle)."""
        if self.busy and self.job is not None:
            return self._release()
        return None

    def complete_if_match(self, eid_token: int, t: float) -> Optional[Job]:
        """
        Finish the current job if *eid_token* is the active token.

        Returns the completed job (with ``t_departure`` set to *t*), or None
        when the token is stale.
        """
        active = self.eid_active
        if active is None or active != eid_token:
            return None  # completion scheduled before a preemption
        finished = self._release()
        if finished is not None:
            finished.t_departure = t
        return finished


class Station:
    """
    A service station: a pool of preemptive servers fed by several named queues.

    ``_ql_area`` (per queue id) and ``_sl_area`` (busy-server count) hold the
    time integrals that the network accumulates between events.
    """

    def __init__(
        self,
        station_id: str,
        queues: Dict[str, Queue],
        servers: List[PreemptiveServer],
    ):
        self.station_id = station_id
        self.queues = queues  # queue_id -> Queue
        self.servers = servers

        # Time-integrated metrics (reset by the network between warm-up and measurement).
        self._ql_area: Dict[str, float] = defaultdict(float)
        self._sl_area: float = 0.0

    def free_servers(self) -> List[PreemptiveServer]:
        """Servers currently idle, in declaration order."""
        idle = []
        for server in self.servers:
            if server.is_idle():
                idle.append(server)
        return idle

    def busy_servers(self) -> List[PreemptiveServer]:
        """Servers currently serving a job, in declaration order."""
        return list(filter(lambda server: not server.is_idle(), self.servers))


# ---------------- Arrivals & Policies ----------------

class ArrivalProcess:
    """
    External arrival stream feeding one (station_id, queue_id) buffer.

    ``interarrival_sampler()`` returns the gap to the next arrival; for a
    Poisson stream of rate lambda it should draw Exp(lambda). Every job it
    creates gets class label ``job_class``.
    """

    def __init__(
        self,
        target_station: str,
        target_queue: str,
        interarrival_sampler: Callable[[], float],
        job_class: str,
    ):
        self.station_id, self.queue_id = target_station, target_queue
        self.interarrival_sampler = interarrival_sampler
        self.job_class = job_class

    def schedule_next(self, t_now: float) -> float:
        """Return the absolute time of the next arrival after *t_now*."""
        gap = float(self.interarrival_sampler())
        return t_now + gap


class SchedulingPolicy:
    """
    Base class for centralized scheduling policies.

    At each decision epoch (the simulator calls in after every event) a policy
    maps some or all idle servers, grouped by station, to the queue each one
    should serve next:

        decide(net, t, free_servers) -> {server: queue}

    The simulator pops one job from every chosen queue and starts it on the
    matching server.

    Preemption is opt-in. A subclass that sets ``allow_preemption = True``
    should also override ``best_queue_for_station`` (the queue it would serve
    next) and
        should_preempt(net, t, station: Station, server: PreemptiveServer, curr_qid: str, cand_qid: str) -> bool
    which decides whether that queue's head job displaces a running one.
    """

    allow_preemption: bool = False

    def decide(
        self,
        net: "PreemptiveNetwork",
        t: float,
        free_servers: Dict[str, List[PreemptiveServer]],
    ) -> Dict[PreemptiveServer, Queue]:
        """Return this epoch's server -> queue assignment. Subclasses must override."""
        raise NotImplementedError

    def best_queue_for_station(self, net: "PreemptiveNetwork", st: Station, t: float) -> Optional[Queue]:
        """
        Queue this policy would serve at *st* if a server were free.

        The simulator compares it against running jobs to consider preemption.
        The base implementation proposes nothing, so it never preempts.
        """
        return None

    def should_preempt(
        self,
        net: "PreemptiveNetwork",
        t: float,
        station: Station,
        server: PreemptiveServer,
        curr_qid: Optional[str],
        cand_qid: Optional[str],
    ) -> bool:
        """Whether *server* should drop its job (from *curr_qid*) for *cand_qid*. Default: never."""
        return False


# ---------- Example policies (station-local heuristics) ----------

class FCFSQueueArrivalPolicy(SchedulingPolicy):
    """
    Station-local FCFS by buffer-arrival time (non-preemptive).

    Each free server of a station is matched to the queue whose next *unclaimed*
    job has the earliest ``q_arrival``. Jobs already promised to another free
    server in the same epoch are skipped, so several servers are spread over the
    oldest jobs instead of all being pointed at one queue. Ties go to the queue
    listed first in the station's ``queues`` dict.
    """
    allow_preemption = False

    def decide(self, net, t, free_servers):
        out: Dict[PreemptiveServer, Queue] = {}
        for st_id, srvs in free_servers.items():
            st = net.stations[st_id]
            # Jobs already promised per queue in this epoch. The next unclaimed
            # job of queue q therefore sits at position used[q.queue_id].
            used: Dict[str, int] = defaultdict(int)
            for srv in srvs:
                best_q: Optional[Queue] = None
                best_t = float("inf")
                for q in st.queues.values():
                    j = q.peek_n(used[q.queue_id])
                    if j is None:
                        # Empty, or every job in it is already assigned. Do not
                        # fall back to the head job: that would re-offer a claimed job.
                        continue
                    if j.q_arrival < best_t:
                        best_t, best_q = j.q_arrival, q
                if best_q is None:
                    break  # no unclaimed jobs left at this station
                out[srv] = best_q
                used[best_q.queue_id] += 1
        return out


class LBFSStaticPriorityPolicy(SchedulingPolicy):
    """
    Static-priority policy (e.g. Last-Buffer-First-Served) keyed by (station_id, queue_id).

    Higher int = higher priority. Queues missing from ``priority_order`` get
    priority 0. Ties go to the queue listed first in the station's ``queues``
    dict. Preemption is optional via ``allow_preemption``.
    """
    def __init__(self, priority_order: Dict[Tuple[str, str], int], allow_preemption: bool = False):
        self._prio = priority_order
        self.allow_preemption = allow_preemption

    def _rank(self, st_id: str, qid: str) -> int:
        return self._prio.get((st_id, qid), 0)

    def _pick(self, st_id: str, st: Station, remaining: Dict[str, int]) -> Optional[str]:
        """Return the id of the highest-priority queue of *st* with remaining[qid] > 0, or None."""
        best_qid: Optional[str] = None
        best_r: Optional[int] = None
        for qid in st.queues:
            if remaining.get(qid, 0) <= 0:
                continue
            r = self._rank(st_id, qid)
            # Use a None sentinel rather than a large negative constant, so any
            # priority value, however small, can still be selected.
            if best_r is None or r > best_r:
                best_r, best_qid = r, qid
        return best_qid

    def decide(self, net, t, free_servers):
        out: Dict[PreemptiveServer, Queue] = {}
        for st_id, srvs in free_servers.items():
            st = net.stations[st_id]
            # Unassigned jobs per queue. Decrement as servers are matched so two
            # servers are never sent to a queue holding a single job while a
            # lower-priority queue still has work.
            remaining = {qid: len(q) for qid, q in st.queues.items()}
            for srv in srvs:
                qid = self._pick(st_id, st, remaining)
                if qid is None:
                    break  # nothing left to assign at this station
                out[srv] = st.queues[qid]
                remaining[qid] -= 1
        return out

    def best_queue_for_station(self, net, st: Station, t: float) -> Optional[Queue]:
        """Highest-priority non-empty queue of *st* (preemption candidate), or None."""
        remaining = {qid: len(q) for qid, q in st.queues.items()}
        qid = self._pick(st.station_id, st, remaining)
        return None if qid is None else st.queues[qid]

    def should_preempt(self, net, t, station, server, curr_qid, cand_qid) -> bool:
        """Preempt only when enabled and the candidate queue strictly outranks the running job's queue."""
        if not self.allow_preemption:
            return False
        if curr_qid is None or cand_qid is None:
            return False
        return self._rank(station.station_id, cand_qid) > self._rank(station.station_id, curr_qid)


class MaxWeightByQLenPolicy(SchedulingPolicy):
    """
    Preemptive MaxWeight using queue length.

    Free servers go to the station's longest waiting queue. Ties go to the
    queue listed first in the station's ``queues`` dict. Each assignment
    consumes one job from that queue's count, so several free servers are
    spread over the available work.
    """
    allow_preemption = True

    @staticmethod
    def _best_qid(st: Station, remaining: Dict[str, int]) -> Optional[str]:
        """Id of the queue with the largest positive remaining count, or None."""
        best_qid: Optional[str] = None
        best_w = 0
        for qid in st.queues:
            w = remaining.get(qid, 0)
            if w > best_w:
                best_w, best_qid = w, qid
        return best_qid

    def _best_q(self, st: Station) -> Optional[Queue]:
        """Longest non-empty waiting queue of *st*, or None if all are empty."""
        qid = self._best_qid(st, {qid: len(q) for qid, q in st.queues.items()})
        return None if qid is None else st.queues[qid]

    def decide(self, net, t, free_servers):
        out: Dict[PreemptiveServer, Queue] = {}
        for st_id, srvs in free_servers.items():
            st = net.stations[st_id]
            remaining = {qid: len(q) for qid, q in st.queues.items()}
            for srv in srvs:
                qid = self._best_qid(st, remaining)
                if qid is None:
                    break  # every job at this station is already assigned
                out[srv] = st.queues[qid]
                remaining[qid] -= 1
        return out

    def best_queue_for_station(self, net, st: Station, t: float) -> Optional[Queue]:
        return self._best_q(st)

    def should_preempt(self, net, t, station, server, curr_qid, cand_qid) -> bool:
        """
        Preempt if the candidate buffer outweighs the running job's buffer.

        The job in service is counted as part of its own buffer. After a swap it
        is pushed back there, while the candidate buffer loses one waiting job
        but gains one in service. Comparing waiting lengths only would make the
        decision flip back on the very next event whenever they differ by one,
        swapping the same two jobs at every event.
        """
        if not self.allow_preemption or cand_qid is None:
            return False
        if server.job is None:
            return False
        cand_weight = len(station.queues[cand_qid])
        curr_waiting = len(station.queues[curr_qid]) if curr_qid is not None else 0
        curr_weight = curr_waiting + 1  # + the job currently in service
        return cand_weight > curr_weight


# ---------------- Simulator (Network) ----------------

class PreemptiveNetwork:
    """
    Continuous-time event-driven simulator with preemption support.

    After every event the policy is applied. First ``policy.decide`` fills idle
    servers. Then, if ``policy.allow_preemption`` is set,
    ``policy.best_queue_for_station`` and ``policy.should_preempt`` may swap a
    waiting job in for a running one. Time-averaged queue lengths and
    busy-server counts are integrated as areas between events (see
    ``summarize``).
    """
    def __init__(
        self,
        stations: Dict[str, Station],
        arrivals: List[ArrivalProcess],
        policy: SchedulingPolicy,
        rng: Optional[np.random.Generator] = None,
    ):
        self.stations = stations
        self.arrivals = arrivals
        self.policy = policy
        self.rng = rng if rng is not None else np.random.default_rng(0)
        # Becomes True once the first arrival of every ArrivalProcess is scheduled.
        self._seeded: bool = False

        # Event list
        self.t = 0.0
        # Shared counter: heap tie-breaker for all events and departure tokens.
        self._event_seq = itertools.count(1)
        self._evq: List[Event] = []
        self._job_id_seq: int = 0

        # Metrics
        self.completed_jobs: int = 0
        self.exited_jobs: List[Job] = []
        self._measure_t0: float = 0.0

        # Quick lookups: (station_id, queue_id) -> Queue and server -> station_id
        self._queue_map: Dict[Tuple[str, str], Queue] = {}
        self._server_station: Dict[PreemptiveServer, str] = {}
        for st_id, st in self.stations.items():
            for qid, q in st.queues.items():
                self._queue_map[(st_id, qid)] = q
            for srv in st.servers:
                # setdefault: the first station listing a server wins, like a linear scan.
                self._server_station.setdefault(srv, st_id)

    # ---- scheduling helpers ----
    def schedule(self, t: float, etype: EventType, payload: Dict[str, Any]) -> None:
        """Push an event at time *t*; simultaneous events run in scheduling order."""
        heapq.heappush(self._evq, Event(time=t, order=next(self._event_seq), type=etype, payload=payload))

    def _station_of(self, srv: PreemptiveServer) -> str:
        """Return the id of the station that owns *srv*."""
        st_id = self._server_station.get(srv)
        if st_id is None:
            # Fallback for servers attached to a station after construction.
            for sid, st in self.stations.items():
                if srv in st.servers:
                    st_id = sid
                    self._server_station[srv] = sid
                    break
        if st_id is None:
            raise RuntimeError("Server not found in any station")
        return st_id

    def _start_job(self, st_id: str, srv: PreemptiveServer, job: Job) -> int:
        """Start *job* on idle server *srv* of station *st_id*; return its departure token."""
        def sched_departure(t_complete: float, payload: Dict[str, Any]) -> int:
            # The token is also the heap tie-breaker, so it is unique per event.
            eid = next(self._event_seq)
            payload = dict(payload)
            payload["eid"] = eid
            payload["station_id"] = st_id
            payload["server"] = srv
            heapq.heappush(self._evq, Event(time=t_complete, order=eid, type=EventType.DEPARTURE, payload=payload))
            return eid
        return srv.start_service(job, self.t, sched_departure)

    def _accumulate_areas(self, dt: float) -> None:
        """Integrate queue lengths and busy-server counts over an interval of length *dt*."""
        if dt <= 0:
            return
        for st in self.stations.values():
            # queue areas
            for qid, q in st.queues.items():
                st._ql_area[qid] += len(q) * dt
            # service areas
            num_busy = sum(1 for s in st.servers if not s.is_idle())
            st._sl_area += num_busy * dt

    def reset_metrics(self) -> None:
        """Start a new measurement window at the current time (e.g. after warm-up)."""
        self.exited_jobs = []
        self.completed_jobs = 0
        self._measure_t0 = self.t
        for st in self.stations.values():
            st._ql_area = defaultdict(float)
            st._sl_area = 0.0

    # ---- arrivals, departures ----

    def _on_arrival(self, ap: ArrivalProcess) -> None:
        """Create a job from *ap*, enqueue it and schedule that stream's next arrival."""
        job = Job(id=self._next_job_id(), cls=ap.job_class, t_arrival=self.t, q_arrival=self.t)
        st = self.stations[ap.station_id]
        q = st.queues[ap.queue_id]
        q.push(job, self.t)

        # schedule next external arrival
        t_next = ap.schedule_next(self.t)
        self.schedule(t_next, EventType.ARRIVAL, {"ap": ap})

    def _on_departure(self, station_id: str, server: PreemptiveServer, eid: int) -> None:
        """Handle a completion event; stale tokens (preempted services) are ignored."""
        job = server.complete_if_match(eid, self.t)
        if job is None:
            return  # stale completion due to preemption

        self.completed_jobs += 1

        # routing: by default, jobs leave the system
        self._on_job_exit(job)

    def _on_job_exit(self, job: Job) -> None:
        """Record a job leaving the network (sojourn statistics use these)."""
        self.exited_jobs.append(job)

    # ---- policy application ----

    def _dispatch_and_preempt_if_needed(self) -> None:
        """Apply the policy at the current epoch: fill idle servers, then consider preemption."""
        # 1) Dispatch to free servers
        free_by_station: Dict[str, List[PreemptiveServer]] = {
            st_id: st.free_servers() for st_id, st in self.stations.items()
        }
        assignments = self.policy.decide(self, self.t, free_by_station)
        for srv, q in assignments.items():
            # Skip assignments that no longer apply (server started or queue drained).
            if srv.is_idle() and len(q) > 0:
                self._start_job(self._station_of(srv), srv, q.pop())

        # 2) Optionally preempt running servers if policy allows
        if not getattr(self.policy, "allow_preemption", False):
            return
        for st_id, st in self.stations.items():
            cand_q = self.policy.best_queue_for_station(self, st, self.t)
            if cand_q is None:
                continue  # nothing waiting that could displace a running job
            cand_qid = cand_q.queue_id
            for srv in st.busy_servers():
                if len(cand_q) == 0:
                    # Earlier swaps drained the candidate. Preempting now would
                    # leave srv idle while its job waits in the queue.
                    break
                curr_qid = None
                if srv.job is not None:
                    # By default queue_id == job.cls; topologies override cls_to_queue_id.
                    curr_qid = self.cls_to_queue_id(st_id, srv.job.cls)
                if not self.policy.should_preempt(self, self.t, st, srv, curr_qid, cand_qid):
                    continue
                j = srv.preempt()
                if j is not None:
                    # Put the preempted job back at the FRONT of its queue (resume sooner).
                    qid = curr_qid if curr_qid is not None else self.cls_to_queue_id(st_id, j.cls)
                    st.queues[qid].push_front(j, self.t)
                # cand_q is non-empty: it was checked above and push_front only adds.
                self._start_job(st_id, srv, cand_q.pop())

    # ---- run loop ----

    def run(self, until_time: float) -> None:
        """
        Advance the simulation to *until_time*.

        Events with time <= until_time are processed in (time, order) order, and
        the policy is applied after each one. Later events stay queued for the
        next call. For a finite horizon the clock and the metric areas are
        always advanced to *until_time*, even if the event list runs dry first.
        """
        # Schedule the first external arrivals exactly once, on the first call.
        if not self._seeded:
            for ap in self.arrivals:
                self.schedule(ap.schedule_next(self.t), EventType.ARRIVAL, {"ap": ap})
            self._seeded = True

        while self._evq and self.t < until_time:
            if self._evq[0].time > until_time:
                break  # next event belongs to a later run() call

            ev = heapq.heappop(self._evq)
            # advance time & accumulate areas
            self._accumulate_areas(ev.time - self.t)
            self.t = ev.time

            if ev.type == EventType.ARRIVAL:
                self._on_arrival(ev.payload["ap"])
            elif ev.type == EventType.DEPARTURE:
                self._on_departure(ev.payload["station_id"], ev.payload["server"], ev.payload["eid"])
            else:
                raise RuntimeError(f"Unknown event type {ev.type}")

            # After each event, run scheduling decisions
            self._dispatch_and_preempt_if_needed()

        # Close the window at until_time so time averages use the full horizon.
        if self.t < until_time and math.isfinite(until_time):
            self._accumulate_areas(until_time - self.t)
            self.t = until_time

    def _next_job_id(self) -> int:
        """Return the next job id (1, 2, 3, ...)."""
        # getattr keeps subclasses that bypass __init__ working.
        self._job_id_seq = getattr(self, "_job_id_seq", 0) + 1
        return self._job_id_seq

    # ---------- Metrics / summaries ----------

    def mean_sojourn(self) -> float:
        """Mean (t_departure - t_arrival) over exited jobs; NaN if none exited."""
        if not self.exited_jobs:
            return float("nan")
        soj = [j.t_departure - j.t_arrival for j in self.exited_jobs]
        return float(np.mean(soj))

    def summarize(self) -> Dict[str, Any]:
        """
        Summary statistics for the current measurement window.

        "completed" and "throughput_per_time" count service completions
        (``completed_jobs``). Sojourn statistics use exited jobs only.
        mean_Lq_total and mean_Ls_total are time-averaged waiting jobs and busy
        servers, summed over stations.
        """
        jobs = self.exited_jobs
        window = max(1e-12, self.t - self._measure_t0)
        if not jobs:
            return {
                "completed": self.completed_jobs,
                "mean_sojourn": self.mean_sojourn(),
                "p50": float("nan"),
                "p90": float("nan"),
                "p95": float("nan"),
                "mean_Lq_total": float("nan"),
                "mean_Ls_total": float("nan"),
                "throughput_per_time": float("nan"),
            }
        soj = np.array([j.t_departure - j.t_arrival for j in jobs], dtype=float)
        p50, p90, p95 = np.percentile(soj, [50, 90, 95]).tolist()
        Lq_total = 0.0
        Ls_total = 0.0
        for st in self.stations.values():
            Lq_total += sum(st._ql_area.values()) / window
            Ls_total += st._sl_area / window
        return {
            "completed": self.completed_jobs,
            "mean_sojourn": float(np.mean(soj)),
            "p50": float(p50),
            "p90": float(p90),
            "p95": float(p95),
            "mean_Lq_total": float(Lq_total),
            "mean_Ls_total": float(Ls_total),
            "throughput_per_time": self.completed_jobs / window,
        }

    def run_warmup_and_measure(self, warmup_time: float = 1_000.0, measure_time: float = 2_000.0) -> Dict[str, Any]:
        """Run a warm-up period, reset metrics, run the measurement period and summarize."""
        self.run(until_time=self.t + warmup_time)
        self.reset_metrics()
        self.run(until_time=self.t + measure_time)
        return self.summarize()

    # ---------- Extensibility hooks (override in concrete network) ----------

    def cls_to_queue_id(self, station_id: str, cls: str) -> str:
        """Default: queue id equals class label. Override per topology if needed."""
        return cls


# ---------------- Example topology: Extended Six-Class (preemptive) ----------------

class ExtendedSixClassNetwork(PreemptiveNetwork):
    """
    Re-entrant 'extended six-class' benchmark (minimal version):
      - L stations S1..SL, each holding queues "Q1", "Q2", "Q3"
        (classes 3(i-1)+{1,2,3}) and one server
      - Exponential services whose rate depends on the queue index and,
        optionally, on the station: (mu1, mu2, mu3) everywhere by default,
        or mu_per_station[i-1] at station Si when given
      - External Poisson arrivals to S1:Q1 and S1:Q3 at rate lam each
      - Routing:
          if station < L: class c -> (station+1, same class index)
          if station == L:
             class 1 -> (S1, class 2)  (re-enter to Q2)
             class 2,3 depart
    """
    def __init__(
        self,
        *,
        L: int = 2,
        lam: float = 0.09,
        mu1: float = 2.0,
        mu2: float = 2.0,
        mu3: float = 2.0,
        policy: Optional[SchedulingPolicy] = None,
        seed: int = 123,
        mu_per_station: Optional[List[Tuple[float, float, float]]] = None,
    ):
        """
        Args:
            L: number of stations.
            lam: rate of each of the two external Poisson streams.
            mu1, mu2, mu3: service rates for queues Q1/Q2/Q3, used at every
                station unless mu_per_station is given.
            policy: scheduling policy (default: MaxWeightByQLenPolicy()).
            seed: seed of the single RNG shared by arrivals and services.
            mu_per_station: optional list of L (mu1, mu2, mu3) triples. Entry
                i-1 is used at station Si, e.g. alternating
                (1/8, 1/2, 1/4) / (1/6, 1/7, 1.0) for odd / even stations.

        Raises:
            ValueError: if mu_per_station does not have exactly L entries.
        """
        rng = np.random.default_rng(seed)

        if mu_per_station is None:
            station_rates = [(mu1, mu2, mu3)] * L
        else:
            station_rates = [tuple(r) for r in mu_per_station]
            if len(station_rates) != L:
                raise ValueError(
                    f"mu_per_station must have {L} entries (one per station), got {len(station_rates)}"
                )

        def make_service_sampler(st_id: str, rates: Tuple[float, float, float]) -> Callable[[Job], float]:
            """Exp(mu) sampler for station *st_id*; mu is chosen by the job's queue-id suffix."""
            r1, r2, r3 = rates

            def sample(job: Job) -> float:
                # self.cls_to_queue_id is looked up at call time, so subclasses can remap classes.
                qid = self.cls_to_queue_id(st_id, job.cls)
                if qid.endswith("1"):
                    mu = r1
                elif qid.endswith("2"):
                    mu = r2
                else:
                    mu = r3
                return rng.exponential(1.0 / mu)

            return sample

        # Build stations, queues, servers
        stations: Dict[str, Station] = {}
        for i in range(1, L + 1):
            st_id = f"S{i}"
            queues = {f"Q{k}": Queue(st_id, f"Q{k}") for k in (1, 2, 3)}
            # single server per station by default
            servers = [
                PreemptiveServer(
                    server_id=f"{st_id}-srv0",
                    service_sampler=make_service_sampler(st_id, station_rates[i - 1]),
                )
            ]
            stations[st_id] = Station(station_id=st_id, queues=queues, servers=servers)

        # Arrivals: to S1:Q1 and S1:Q3 at rate lam each
        def exp_interarrival(rate: float, rng=rng) -> Callable[[], float]:
            assert rate > 0.0
            return lambda: rng.exponential(1.0 / rate)

        arrivals = [
            ArrivalProcess("S1", "Q1", exp_interarrival(lam, rng), job_class="1"),
            ArrivalProcess("S1", "Q3", exp_interarrival(lam, rng), job_class="3"),
        ]

        if policy is None:
            policy = MaxWeightByQLenPolicy()

        super().__init__(stations=stations, arrivals=arrivals, policy=policy, rng=rng)

    def _on_departure(self, station_id: str, server: PreemptiveServer, eid: int) -> None:
        """Complete a service (ignoring stale tokens) and route the job per the topology."""
        job = server.complete_if_match(eid, self.t)
        if job is None:
            return  # stale event
        self.completed_jobs += 1

        # routing logic
        st_idx = int(station_id[1:])  # "S1" -> 1
        q_idx = int(self.cls_to_queue_id(station_id, job.cls)[-1])  # "Q2" -> 2
        if st_idx < len(self.stations):
            # send to next station, same class index
            next_sid = f"S{st_idx+1}"
            next_qid = f"Q{q_idx}"
            job.cls = str(q_idx)  # keep class index consistent with queue id
            job.q_arrival = self.t
            self.stations[next_sid].queues[next_qid].push(job, self.t)
        else:
            # last station: custom exits
            if q_idx == 1:
                # class 1 -> (S1, Q2)
                job.cls = "2"
                job.q_arrival = self.t
                self.stations["S1"].queues["Q2"].push(job, self.t)
            else:
                # class 2 or 3 exit
                self._on_job_exit(job)

    def cls_to_queue_id(self, station_id: str, cls: str) -> str:
        # For this network, queue id is "Q{cls}"
        return f"Q{cls}"


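# --- Captured notebook output (not executable code) ---
# Dai–Gluzman Table 4 verification (means, time units)
# λ per external stream = 0.064286 (→ ρ≈0.9 per station), odd stations μ=(1/8,1/2,1/4), even μ=(1/6,1/7,1)
#
# --- 3L=6 classes (L=2) ---
#
#
# AttributeError: 'LBFSStaticPriorityPolicy' object has no attribute 'decide'
# NOTE(review): LBFSStaticPriorityPolicy in this cell does define decide(); this traceback
# presumably comes from an earlier revision of the cell. Re-run to confirm.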