In [6]:
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Callable, Deque, Dict, List, Optional, Tuple
from collections import deque
import heapq
import math
import numpy as np
from collections import defaultdict

# -------- Events --------

class EventType(Enum):
    ARRIVAL = auto()
    DEPARTURE = auto()

@dataclass(order=True)
class Event:
    time: float
    order: int
    type: EventType = field(compare=False)
    payload: Dict[str, Any] = field(compare=False, default_factory=dict)
    is_valid: bool = field(compare=False, default=True) # NEW: Flag for cancellation


# -------- Core domain --------

@dataclass
class Job:
    id: int
    cls: str
    t_arrival: float
    q_arrival: float
    t_service_start: Optional[float] = None
    t_departure: Optional[float] = None
    # optional: path trace [(station_id, queue_id, t_enter, t_start, t_leave)]
    trace: List[Tuple[str, str, float, Optional[float], Optional[float]]] = field(default_factory=list)


class Queue:
    """
    FIFO queue by default. Replace with a priority queue if you need.
    """
    def __init__(self, station_id: str, queue_id: str):
        self.station_id = station_id
        self.queue_id = queue_id
        self._buf: Deque[Job] = deque()

    def push(self, job: Job, t: float) -> None:
        job.trace.append((self.station_id, self.queue_id, t, None, None))
        self._buf.append(job)

    def pop(self) -> Optional[Job]:
        if not self._buf:
            return None
        return self._buf.popleft()

    def __len__(self) -> int:
        return len(self._buf)

    def peek(self) -> Optional[Job]:
        return self._buf[0] if self._buf else None

    def peek_n(self, n: int) -> Optional[Job]:
        # n = 0 -> head, 1 -> second, ...
        if n < 0 or n >= len(self._buf):
            return None
        # deque supports O(n) indexing; n is tiny (<= #free servers), so this is fine
        return self._buf[n]


class Server:
    """
    Single server (non-preemptive). Station owns multiple of these.
    """
    #service sampler is a function that maps a job to a float.
    #def svc1(job):  # e.g., class-specific service rates
    #return rng.exponential(1/ (1.5 if job.cls=="A" else 1.0))
    def __init__(self, server_id: str, service_sampler: Callable[[Job], float]):
        self.server_id = server_id
        self.service_sampler = service_sampler
        self.busy: bool = False
        self.job: Optional[Job] = None
        self.t_busy_until: float = math.inf
        self.departure_event: Optional[Event] = None # NEW: Link to the departure event

    def start_service(self, job: Job, t: float) -> float:
        """Start service and return departure time."""
        assert not self.busy
        self.busy = True
        self.job = job
        job.t_service_start = t
        # mark trace start time
        if job.trace and job.trace[-1][3] is None:
            st, qid, t_enter, _, t_leave = job.trace[-1]
            job.trace[-1] = (st, qid, t_enter, t, t_leave)

        s = max(0.0, float(self.service_sampler(job)))
        dep_time = t + s
        self.t_busy_until = dep_time
        return dep_time

    def complete(self, t: float):
        assert self.busy and self.job is not None
        job = self.job
        # mark trace leave time for this stage
        if job.trace and job.trace[-1][4] is None:
            st, qid, t_enter, t_start, _ = job.trace[-1]
            job.trace[-1] = (st, qid, t_enter, t_start, t)
        self.busy = False
        self.job = None
        self.t_busy_until = math.inf
        return job

class Station:
    """
    Station with c servers and multiple queues.
    """
    def __init__(
        self,
        station_id: str,
        queues: Dict[str, Queue],
        servers: List[Server],
    ):
        self.station_id = station_id
        self.queues = queues         # queue_id -> Queue
        self.servers = servers       # list of Server

    def free_servers(self) -> List[Server]:
        return [s for s in self.servers if not s.busy]

    def total_queue_len(self) -> int:
        return sum(len(q) for q in self.queues.values())

    def queue_lengths(self) -> Dict[str, int]:
        return {qid: len(q) for qid, q in self.queues.items()}


# -------- Arrivals and Routing --------

class ArrivalProcess:
    """
    External arrival process targeting (station_id, queue_id).
    For exponential inter-arrivals: sampler() returns Exp(λ).
    """
    def __init__(
        self,
        target_station: str,
        target_queue: str,
        interarrival_sampler: Callable[[], float],
        job_class: str,
    ):
        self.station_id = target_station
        self.queue_id = target_queue
        self.interarrival_sampler = interarrival_sampler
        self.job_class = job_class
        self.next_time: float = 0.0

    def schedule_next(self, t_now: float) -> float:
        self.next_time = t_now + max(0.0, float(self.interarrival_sampler()))
        return self.next_time


# -------- Policy interface --------

class SchedulingPolicy:
    """
    Centralized policy: at a decision epoch (e.g., after arrivals or a departure),
    assign jobs to some/all free servers across the network.

    Must return a mapping: (server) -> (queue) from which to take the next job.
    The simulator will pop a job from that queue and start it on that server.
    """
    def decide(
        self,
        net: NetworkLike,                  # duck-typed: see Network below
        t: float,
        free_servers: List[Tuple[Station, Server]],
    ) -> Dict[Server, Queue]:
        raise NotImplementedError

# -------- Network (orchestrator) --------

class Network:
    def __init__(
        self,
        stations: Dict[str, Station],
        arrivals: List[ArrivalProcess],
        router: Router,
        policy: SchedulingPolicy,
        rng: Optional[np.random.Generator] = None,
    ):
        self.stations = stations
        self.arrivals = arrivals
        self.router = router
        self.policy = policy
        self.rng = rng or np.random.default_rng(0)

        self.t: float = 0.0
        self._event_q: List[Event] = []
        self._eid: int = 0
        self._job_id_seq: int = 0
        self.completed_jobs: int = 0
        self.sum_sojourn: float = 0.0
        self.exited_jobs: list = []
        self._seeded: bool = False

    def schedule(self, time: float, etype: EventType, payload: Dict[str, Any]) -> Event:
        self._eid += 1
        ev = Event(time=time, order=self._eid, type=etype, payload=payload)
        heapq.heappush(self._event_q, ev)
        return ev

    def run(self, until_time: Optional[float] = None, until_jobs: Optional[int] = None) -> None:
        if not self._seeded:
            for ap in self.arrivals:
                t_next = ap.schedule_next(self.t)
                self.schedule(t_next, EventType.ARRIVAL, {"ap": ap})
            self._seeded = True

        while self._event_q:
            if until_time is not None and self._event_q[0].time > until_time:
                break

            ev = heapq.heappop(self._event_q)
            
            if not ev.is_valid:
                continue

            if until_jobs is not None and self.completed_jobs >= until_jobs:
                heapq.heappush(self._event_q, ev)
                break

            dt = ev.time - self.t
            if dt > 0:
                for st in self.stations.values():
                    if not hasattr(st, "_ql_area"): st._ql_area = {qid: 0.0 for qid in st.queues}
                    if not hasattr(st, "_sl_area"): st._sl_area = 0.0
                    for qid, q in st.queues.items():
                        st._ql_area[qid] += len(q) * dt
                    num_busy_servers = sum(1 for srv in st.servers if srv.busy)
                    st._sl_area += num_busy_servers * dt
            
            self.t = ev.time

            # Use the event's order as its ID for departure validation
            current_event_id = ev.order

            if ev.type == EventType.ARRIVAL:
                self._on_arrival(ev.payload["ap"])
            elif ev.type == EventType.DEPARTURE:
                self._on_departure(ev.payload["station_id"], ev.payload["server_idx"], current_event_id)
            
            self._decision_epoch()

    def _on_arrival(self, ap: ArrivalProcess) -> None:
        job = Job(id=self._next_job_id(), cls=ap.job_class, t_arrival=self.t, q_arrival=self.t)
        st = self.stations[ap.station_id]
        q = st.queues[ap.queue_id]
        q.push(job, self.t)
        t_next = ap.schedule_next(self.t)
        self.schedule(t_next, EventType.ARRIVAL, {"ap": ap})
    
    def _on_departure(self, station_id: str, server_idx: int, event_id: int) -> None:
        st = self.stations[station_id]
        srv = st.servers[server_idx]
        
        # Stale event check: if the server's active departure event doesn't match this one, ignore it
        if not srv.busy or not srv.departure_event or srv.departure_event.order != event_id:
             return
        
        job = srv.complete(self.t)
        srv.departure_event = None
        
        nxt = self.router.route(job, station_id, self.t)
        if nxt is None:
            job.t_departure = self.t
            self.completed_jobs += 1
            self.sum_sojourn += (job.t_departure - job.t_arrival)
            self.exited_jobs.append(job) 
        else:
            next_st_id, next_q_id = nxt
            # Find the queue a job belongs to based on its class
            # This is a simplification and may need adjustment for complex routings
            preempted_job_q_id = f"Q{job.cls}"
            if preempted_job_q_id in self.stations[next_st_id].queues:
                job.q_arrival = self.t
                self.stations[next_st_id].queues[preempted_job_q_id].push(job, self.t)
            else: # Fallback for more complex cases
                self.stations[next_st_id].queues[next_q_id].push(job, self.t)

    def _decision_epoch(self) -> None:
        assignments = self.policy.decide(self, self.t, self.stations)

        for st_id, st in self.stations.items():
            for i, srv in enumerate(st.servers):
                ideal_queue = assignments.get(srv)
                ideal_job = ideal_queue.peek() if ideal_queue else None

                if srv.busy:
                    current_job = srv.job
                    if not ideal_job or (current_job and ideal_job.id == current_job.id):
                        continue
                    
                    preempted_job = srv.complete(self.t)
                    if srv.departure_event:
                        srv.departure_event.is_valid = False
                    srv.departure_event = None
                    
                    source_queue_id = f"Q{preempted_job.cls}"
                    if source_queue_id in st.queues:
                         st.queues[source_queue_id].push(preempted_job, self.t)
                    else:
                        print(f"Warning: Could not find queue for preempted job class {preempted_job.cls}")

                if not srv.busy and ideal_queue and len(ideal_queue) > 0:
                    job_to_start = ideal_queue.pop()
                    dep_time = srv.start_service(job_to_start, self.t)
                    dep_event = self.schedule(dep_time, EventType.DEPARTURE, {"station_id": st_id, "server_idx": i})
                    srv.departure_event = dep_event

    def _next_job_id(self) -> int:
        self._job_id_seq += 1
        return self._job_id_seq

    def queue_length(self, station_id: str, queue_id: str) -> int:
        return len(self.stations[station_id].queues[queue_id])

    def station_queue_lengths(self, station_id: str) -> Dict[str, int]:
        st = self.stations[station_id]
        return {qid: len(q) for qid, q in st.queues.items()}

    def total_queue_lengths(self) -> Dict[Tuple[str, str], int]:
        out = {}
        for sid, st in self.stations.items():
            for qid, q in st.queues.items():
                out[(sid, qid)] = len(q)
        return out

    def mean_sojourn(self) -> float:
        return self.sum_sojourn / self.completed_jobs if self.completed_jobs else float("nan")
# -------- Utility samplers (Exp arrivals/services) --------

def exp_interarrival(rate: float, rng: np.random.Generator) -> Callable[[], float]:
    assert rate > 0
    return lambda: rng.exponential(1.0 / rate)

def exp_service(mu: float, rng: np.random.Generator) -> Callable[[Job], float]:
    assert mu > 0
    return lambda job: rng.exponential(1.0 / mu)


In [7]:
class LBFSPolicy(SchedulingPolicy):
    """
    Last-Buffer First-Serve (LBFS) Policy for a Preemptive System.
    At each decision epoch, this policy identifies the non-empty queue with the
    highest class index at each station. It then assigns all servers at that
    station to this single highest-priority queue.
    """
    def decide(self, net: Network, t: float, stations: Dict[str, Station]) -> Dict[Server, Queue]:
        assignments: Dict[Server, Queue] = {}
        for st_id, st in stations.items():
            best_q: Optional[Queue] = None
            max_cls_id = -1

            # Find the single highest-priority non-empty queue for the station
            for q in st.queues.values():
                if len(q) > 0:
                    try:
                        cls_id = int(q.queue_id.replace("Q", ""))
                        if cls_id > max_cls_id:
                            max_cls_id = cls_id
                            best_q = q
                    except ValueError:
                        continue
            
            # If a priority queue was found, assign all servers at this station to it
            if best_q:
                for srv in st.servers:
                    assignments[srv] = best_q
                        
        return assignments

In [8]:
class FCFSPolicy(SchedulingPolicy):
    """
    First-Come, First-Serve (FCFS) Policy - System-Wide for a Preemptive System.
    At each decision epoch, for each station, this policy finds all available
    jobs (at the head of each queue), sorts them by their original network
    entry time (t_arrival), and assigns servers to the oldest available jobs.
    """
    def decide(self, net: Network, t: float, stations: Dict[str, Station]) -> Dict[Server, Queue]:
        assignments: Dict[Server, Queue] = {}
        for st_id, st in stations.items():
            # Find all available jobs at the station (job, and its queue)
            available_jobs = []
            for q in st.queues.values():
                if len(q) > 0:
                    job = q.peek()
                    available_jobs.append((job, q))
            
            # Sort the jobs by their original system entry time (oldest first)
            available_jobs.sort(key=lambda item: item[0].t_arrival)
            
            # Assign servers to the highest-priority (oldest) jobs
            for i, srv in enumerate(st.servers):
                if i < len(available_jobs):
                    # Assign this server to the i-th oldest job's queue
                    job, queue = available_jobs[i]
                    assignments[srv] = queue
        return assignments

In [9]:
class ExtendedSixClassNetwork(Network):
    """
    "Extended Six-Class Queueing Network" from Figure 4
    of Dai and Gluzman (2022).
    """
    def __init__(self,
                 policy: SchedulingPolicy,
                 *,
                 L: int, # Number of stations
                 seed: int = 0,
                 ):
        if not L >= 2:
            raise ValueError("L must be an integer >= 2 for this network.")
        rng = np.random.default_rng(seed)
        lam = 9.0 / 140.0
        mu_rates = { 
            1: 1.0 / 8.0, 2: 1.0 / 2.0, 3: 1.0 / 4.0,
            4: 1.0 / 6.0, 5: 1.0 / 7.0, 6: 1.0 / 1.0,
        }

        def service_sampler(job: Job) -> float:
            class_idx = int(job.cls)
            key = (class_idx - 1) % 6 + 1
            rate = mu_rates[key]
            return rng.exponential(1.0 / rate)

        stations: Dict[str, Station] = {}
        for i in range(1, L + 1):
            sid = f"S{i}"
            station_queues: Dict[str, Queue] = {}
            for k in range(1, 4):
                class_id = 3 * (i - 1) + k
                qid = f"Q{class_id}"
                station_queues[qid] = Queue(sid, qid)
            station_servers = [Server(f"{sid}-s0", service_sampler)]
            stations[sid] = Station(sid, station_queues, station_servers)

        arrival_sampler = lambda: rng.exponential(1.0 / lam)
        ap1 = ArrivalProcess("S1", "Q1", arrival_sampler, job_class="1")
        ap3 = ArrivalProcess("S1", "Q3", arrival_sampler, job_class="3")

        class _Router:
            def route(self, job: Job, station_id: str, t: float) -> Optional[Tuple[str, str]]:
                station_num = int(station_id.replace("S", ""))
                class_num = int(job.cls)
                if station_num < L:
                    next_class = class_num + 3
                    next_station = station_num + 1
                    job.cls = str(next_class)
                    return (f"S{next_station}", f"Q{next_class}")
                elif station_num == L:
                    if class_num == 3 * (L - 1) + 1:
                        job.cls = "2"
                        return ("S1", "Q2")
                    else:
                        return None
                return None

        super().__init__(
            stations=stations,
            arrivals=[ap1, ap3],
            router=_Router(),
            policy=policy,
            rng=rng,
        )
        self._params = dict(L=L, lam=lam, mu_rates=mu_rates)

    # --- Metrics and Experiment Helpers ---

    def run_and_get_batch_means_stats(
        self,
        warmup_time: float,
        num_batches: int,
        batch_duration: float
    ) -> Dict[str, Any]:
        """
        Runs a simulation with warmup and uses the batch means method to
        get a stable estimate of the mean number of jobs in the system.
        """
        print(f"Running warmup for {warmup_time:.0f} time units...")
        self.run(until_time=warmup_time)
        print("Warmup complete. Starting batch means measurement...")

        batch_means = []
        
        for i in range(num_batches):
            # Reset the area counters for queues and servers at the start of the batch
            for st in self.stations.values():
                if hasattr(st, "_ql_area"):
                    st._ql_area = {qid: 0.0 for qid in st.queues}
                if hasattr(st, "_sl_area"): # NEW
                    st._sl_area = 0.0       # NEW

            t_batch_start = self.t
            self.run(until_time=t_batch_start + batch_duration)
            
            # Calculate the mean jobs for this batch (queues + servers)
            total_area_this_batch = 0
            for st in self.stations.values():
                if hasattr(st, "_ql_area"):
                    total_area_this_batch += sum(st._ql_area.values())
                if hasattr(st, "_sl_area"): # NEW
                    total_area_this_batch += st._sl_area # NEW
            
            mean_jobs_this_batch = total_area_this_batch / batch_duration
            batch_means.append(mean_jobs_this_batch)
            
            # Optional: uncomment to see progress
            # print(f"Batch {i+1}/{num_batches} complete. Mean jobs: {mean_jobs_this_batch:.3f}")

        # Calculate statistics over the batch means
        mean_of_means = np.mean(batch_means)
        std_of_means = np.std(batch_means, ddof=1)
        
        # 95% CI half-width using z=1.96 (appropriate for num_batches >= 30)
        ci_half_width = 1.96 * (std_of_means / np.sqrt(num_batches))

        print("Measurement complete.")
        return {
            "mean_jobs_in_system": mean_of_means,
            "ci_half_width": ci_half_width,
            "std_dev_of_batch_means": std_of_means,
            "num_batches": num_batches,
        }

In [10]:
def run6Class(NUM_STATIONS, seed):
    
    print(f"--- Simulating the {3*NUM_STATIONS}-Class Network with LBFS Policy (using Batch Means) ---")
    
    lbfs_policy = LBFSPolicy()
    network = ExtendedSixClassNetwork(policy=lbfs_policy, L=NUM_STATIONS, seed=seed)
    results = network.run_and_get_batch_means_stats(
        warmup_time=10000.0,
        num_batches=50,
        batch_duration=100000.0
    )

    print("\n--- Final Simulation Results ---")
    print(f"Mean number of jobs in system: {results['mean_jobs_in_system']:.3f}")
    print(f"95% Confidence Interval: +/- {results['ci_half_width']:.3f}")
    print(f"Result: {results['mean_jobs_in_system']:.3f} ± {results['ci_half_width']:.3f}")
    print("\n")
    print(f"--- Simulating the {3*NUM_STATIONS}-Class Network with FIFO Policy (using Batch Means) ---")
    
    fifo_policy = FCFSPolicy()
    network = ExtendedSixClassNetwork(policy=fifo_policy, L=NUM_STATIONS, seed=seed)
    results = network.run_and_get_batch_means_stats(
        warmup_time=10000.0,
        num_batches=50,
        batch_duration=100000.0
    )

    print("\n--- Final Simulation Results ---")
    print(f"Mean number of jobs in system: {results['mean_jobs_in_system']:.3f}")
    print(f"95% Confidence Interval: +/- {results['ci_half_width']:.3f}")
    print(f"Result: {results['mean_jobs_in_system']:.3f} ± {results['ci_half_width']:.3f}")

run6Class(2,5)

--- Simulating the 6-Class Network with LBFS Policy (using Batch Means) ---
Running warmup for 10000 time units...
Warmup complete. Starting batch means measurement...
Measurement complete.

--- Final Simulation Results ---
Mean number of jobs in system: 14.284
95% Confidence Interval: +/- 0.608
Result: 14.284 ± 0.608


--- Simulating the 6-Class Network with FIFO Policy (using Batch Means) ---
Running warmup for 10000 time units...
Warmup complete. Starting batch means measurement...
Measurement complete.

--- Final Simulation Results ---
Mean number of jobs in system: 18.496
95% Confidence Interval: +/- 0.755
Result: 18.496 ± 0.755
