In [2]:
!pip install torch

Collecting torch
  Downloading torch-2.9.1-cp310-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting filelock (from torch)
  Using cached filelock-3.20.0-py3-none-any.whl.metadata (2.1 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx>=2.5.1 (from torch)
  Using cached networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting fsspec>=0.8.5 (from torch)
  Using cached fsspec-2025.10.0-py3-none-any.whl.metadata (10 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.9.1-cp310-none-macosx_11_0_arm64.whl (74.5 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m74.5/74.5 MB[0m [31m3.5 MB/s[0m  [33m0:00:21[0m[0m eta [36m0:00:01[0m[36m0:00:01[0m
[?25hUsing cached fsspec-2025.10.0-py3-none-any.whl (200 kB)
Using cached networkx-3.4.2-py3-none-any.whl (1.7 MB)
Using cached sympy-1.14.0-p

In [3]:
!pip install numpy

Collecting numpy
  Using cached numpy-2.2.6-cp310-cp310-macosx_14_0_arm64.whl.metadata (62 kB)
Using cached numpy-2.2.6-cp310-cp310-macosx_14_0_arm64.whl (5.3 MB)
Installing collected packages: numpy
Successfully installed numpy-2.2.6


In [7]:
import torch
import numpy as np
from typing import List, Dict, Tuple, Optional
from collections import deque


In [8]:
class LossEfficiencyTracker:
    """Tracks Loss (Utility) and Length (Cost) to compute Value-per-Watt scores.

    Each sample keeps an exponential moving average of
    ``loss**alpha / length**beta``; ``get_probabilities`` normalizes these
    scores into a sampling distribution.
    """
    
    def __init__(self, dataset_size: int, alpha: float = 1.0, beta: float = 0.5, decay: float = 0.9):
        self.dataset_size = dataset_size
        self.alpha = alpha  # Importance of high loss
        self.beta = beta    # Importance of short length
        self.decay = decay  # Fraction of the previous score kept on each update
        
        # Initialize scores to 1.0 so all samples have equal chance initially
        self.efficiency_scores = np.ones(dataset_size, dtype=np.float32)
        
    def update_batch_outcomes(self, indices: List[int], losses: List[float], lengths: List[int]):
        """
        Update scores based on the actual training result.

        Equation: Score = Loss^alpha / Length^beta, blended into the stored
        score as ``decay * old + (1 - decay) * new``.

        Args:
            indices: dataset indices of the samples; indices outside
                ``[0, dataset_size)`` are ignored.
            losses: per-sample losses (anything ``float()`` accepts). Negative
                losses are clamped to 0; NaN/inf losses are skipped so a single
                bad step cannot poison the stored score.
            lengths: per-sample lengths (cost); values below 1 are treated as 1.
        """
        for idx, loss, length in zip(indices, losses, lengths):
            if not 0 <= idx < self.dataset_size:
                continue

            loss_value = float(loss)
            # A NaN/inf loss would permanently corrupt this sample's EMA score
            # (and the normalized probabilities), so leave the score unchanged.
            if not np.isfinite(loss_value):
                continue
            # Negative losses would give negative (or, for fractional alpha,
            # complex) scores, which are not valid sampling weights.
            loss_value = max(0.0, loss_value)

            # Avoid division by zero
            safe_len = max(1, length)

            # Calculate Value-per-Watt
            # High Loss = Good (Learn more)
            # High Length = Bad (Costs more)
            new_score = (loss_value ** self.alpha) / (safe_len ** self.beta)

            # Update with moving average to keep history stable
            self.efficiency_scores[idx] = (
                self.decay * self.efficiency_scores[idx]
                + (1 - self.decay) * new_score
            )
    
    def get_probabilities(self) -> np.ndarray:
        """Get normalized sampling probabilities.

        Returns the scores divided by their sum. Falls back to a uniform
        distribution when the total is zero, negative, or not finite (e.g. a
        float32 score overflowed to inf).
        """
        scores = self.efficiency_scores
        # Simple normalization (cheaper than softmax)
        total_score = scores.sum()
        if np.isfinite(total_score) and total_score > 0:
            return scores / total_score
        return np.ones_like(scores) / len(scores)

In [9]:
class EnergyAwareSampler:
    """
    Samples data based on 'Value-per-Watt' (Loss/Length).

    Iterating yields every dataset index exactly once per epoch. Indices are
    drawn in chunks of ``base_batch_size``. Each chunk is a weighted draw
    without replacement from the indices not yet yielded this epoch. The
    weights are the tracker's probabilities, renormalized over those
    remaining indices.
    """
    def __init__(self, dataset, energy_monitor, base_batch_size=32):
        self.dataset_size = len(dataset)
        # Stored but not read by the sampling logic in this class.
        self.energy_monitor = energy_monitor
        # Chunk size for the weighted draws (independent of the DataLoader's batch size).
        self.base_batch_size = base_batch_size
        
        # Replace Gradient tracker with Loss tracker
        self.tracker = LossEfficiencyTracker(self.dataset_size)
        
        # Internal state
        self.epoch_indices = list(range(self.dataset_size))
        self.used_indices = set()

    def __len__(self):
        """Number of indices yielded per epoch (each sample exactly once)."""
        return self.dataset_size

    def update_batch_outcomes(self, indices, losses, lengths):
        """Pass feedback (per-sample losses and lengths) from Trainer to Tracker."""
        self.tracker.update_batch_outcomes(indices, losses, lengths)

    @staticmethod
    def _weighted_draw(candidates, weights, size):
        """Draw ``size`` distinct entries of ``candidates`` proportionally to ``weights``.

        ``np.random.choice(replace=False, p=...)`` raises when ``p`` has fewer
        non-zero entries than ``size`` (or sums to zero). In that case, every
        positive-weight candidate is taken and the rest of the chunk is filled
        uniformly from the zero-weight candidates.
        """
        positive = weights > 0
        n_positive = int(np.count_nonzero(positive))
        if n_positive >= size:
            return np.random.choice(
                candidates,
                size=size,
                replace=False,
                p=weights / weights.sum(),
            )
        filler = np.random.choice(candidates[~positive], size=size - n_positive, replace=False)
        # Shuffle so positive-weight indices are not always first in ascending order.
        return np.random.permutation(np.concatenate([candidates[positive], filler]))

    def __iter__(self):
        """
        Standard PyTorch Sampler iterator.
        1. Snapshot probabilities from the tracker. Updates made during the
           epoch take effect on the next epoch.
        2. Sample without replacement, chunk by chunk, until every index has
           been yielded once (as Python ints).
        """
        # 1. Get probabilities from our Tracker (The "Brain") as float64 weights.
        probs = np.asarray(self.tracker.get_probabilities(), dtype=np.float64)
        # NaN/inf/negative entries are not valid sampling weights; treat them as zero.
        weights = np.where(np.isfinite(probs) & (probs > 0), probs, 0.0)
        
        # 2. Reset for new epoch
        self.used_indices.clear()
        available = np.ones(self.dataset_size, dtype=bool)
        
        # 3. Yield batches
        while available.any():
            # (Or get from adaptive batcher). A non-positive size would select
            # nothing and loop forever, so draw at least one index per chunk.
            batch_size = max(1, int(self.base_batch_size))

            # Remaining indices stay in ascending order, as with the original list.
            remaining = np.flatnonzero(available)
            selected = self._weighted_draw(
                remaining,
                weights[remaining],
                min(len(remaining), batch_size),
            )

            # Boolean-mask removal is O(chunk). The previous per-index
            # list.remove cost O(n) per index, making each epoch O(n^2).
            available[selected] = False

            # Yield indices for this batch
            yield from selected.tolist()