# LLMs can Schedule — arXiv:2408.06993
### Implementation in Google Colab (T4 GPU)

**Paper:** *LLMs can Schedule* — Abgaryan, Harutyunyan, Cazenave (Université Paris Dauphine - PSL)

**What this notebook does:**
1. ✅ Install dependencies (OR-Tools, unsloth, trl)
2. ✅ Generate a JSSP dataset with OR-Tools CP-SAT labels
3. ✅ Fine-tune **Phi-3-Mini-4K-Instruct** (3.8B) with LoRA via unsloth (fits T4 16 GB); the 4K variant stands in for the 128K model, see Configuration
4. ✅ Evaluate with the paper's **sampling method** (generate s=10, keep best valid)
5. ✅ Benchmark against ft06 and la01 with makespan gap reporting

| | Your previous setup | This paper |
|---|---|---|
| **Model** | gpt2-medium 345M | Phi-3-Mini 3.8B |
| **Labels** | EFT heuristic | OR-Tools CP-SAT (provably feasible) |
| **Format** | compact `J0:M0@0-5` | verbose natural language |
| **Inference** | single greedy decode | sample s=10, pick best valid |
| **Validation** | none | regex parse + constraint check |

**Expected result:** ~8–13% gap from optimal on 6×6–10×5 benchmarks (paper reports 8.92% avg gap with s=10).

> **Runtime estimate:** dataset gen ~5 min · training ~90 min estimated (141.6 min in the run logged below) · eval ~10 min on T4
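
The gap figures quoted above are relative makespan gaps. A minimal sketch of that calculation, assuming the usual definition (makespan minus reference, divided by reference); `optimality_gap` is a hypothetical helper name, not something taken from the paper:

```python
def optimality_gap(makespan: int, reference: int) -> float:
    """Relative gap in percent between a schedule's makespan and a reference value."""
    if reference <= 0:
        raise ValueError("reference makespan must be positive")
    return (makespan - reference) / reference * 100.0

# ft06 has a known optimum of 55; compare a hypothetical schedule of length 60 against it.
print(f"{optimality_gap(60, 55):.2f}%")
```

The benchmark cell later in the notebook computes the same quantity inline as `(ms - optimal) / optimal * 100`.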

## 0 · Check GPU

In [1]:
import subprocess, sys

result = subprocess.run(["nvidia-smi", "--query-gpu=name,memory.total",
                         "--format=csv,noheader"], capture_output=True, text=True)
if result.returncode == 0:
    print("✓ GPU:", result.stdout.strip())
else:
    print("✗ No GPU found — go to Runtime → Change runtime type → T4 GPU")
    sys.exit(1)

✓ GPU: Tesla T4, 15360 MiB


## 1 · Install dependencies

> Takes ~3 minutes on a fresh Colab session.

In [2]:
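%%capture install_log
# %%capture hides all output of this cell, including the print below;
# run install_log.show() in a later cell to inspect the pip log.
!pip install ortools==9.10.4067 unsloth trl transformers datasets accelerate bitsandbytes peft
print("✓ All packages installed")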

## 2 · Configuration

Edit values here to adjust training size, model, etc.

In [3]:
import os
# Disable torch.compile / Dynamo — required for Phi-3 on older CUDA (T4/Turing)
# The 128k model uses LongRoPE which has data-dependent control flow that
# Dynamo cannot trace. Disabling compile has <5% throughput impact on T4.
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ── Dataset ─────────────────────────────────────────────────────────────────
N_SAMPLES       = 3000    # paper uses 120k; 3k generates in ~5 min for a quick run
MIN_SIZE        = 2       # min jobs/machines per dimension
MAX_SIZE        = 8       # paper goes up to 20; 8 fits T4 context budget
DUR_MIN         = 5       # min operation duration
DUR_MAX         = 500     # max operation duration  (paper: 5–500)
ORTOOLS_TIME    = 60      # seconds per problem for OR-Tools (paper: 300)
DATASET_PATH    = "jssp_dataset.json"

# ── Training ─────────────────────────────────────────────────────────────────
# Use the 4k variant: same architecture and size as the 128k model, but without
# the LongRoPE context scaling that causes Dynamo graph-break errors on T4 + Torch 2.10.
# Sequences here stay well under 4096 tokens, so the shorter window is not a constraint.
MODEL_NAME      = "microsoft/Phi-3-mini-4k-instruct"
OUTPUT_DIR      = "/content/jssp_phi3_lora"
MAX_SEQ_LEN     = 2048    # safe for T4; our prompts are ~800-1200 tokens
EPOCHS          = 3
LR              = 2e-4
LORA_R          = 16
BATCH_SIZE      = 2
GRAD_ACCUM      = 8       # effective batch = 2 × 8 = 16
MAX_TRAIN       = N_SAMPLES

# ── Evaluation ───────────────────────────────────────────────────────────────
N_INFERENCE_SAMPLES = 10  # paper's sampling parameter s
TEMPERATURE         = 0.8
TOP_P               = 0.95
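MAX_NEW_TOKENS      = 512  # caps solution length; may truncate larger instances (ft06 needs 36 solution lines, la01 needs 50)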

print("Config loaded ✓")
print(f"Model  : {MODEL_NAME}")
print(f"Dynamo : disabled (TORCHDYNAMO_DISABLE=1)")

Config loaded ✓
Model  : microsoft/Phi-3-mini-4k-instruct
Dynamo : disabled (TORCHDYNAMO_DISABLE=1)


## 3 · Core library

OR-Tools solver, NL formatters, parser, validator — run once.

In [4]:
import os, re, json, time
import numpy as np
from typing import List, Tuple, Dict, Optional

# ─────────────────────────────────────────────────────────────────────────────
# 3a. EFT fallback solver (used if OR-Tools is unavailable or returns no solution)
# ─────────────────────────────────────────────────────────────────────────────

def _eft_fallback(jobs):
    n_jobs = len(jobs)
    n_mach = max(m for job in jobs for m, _ in job) + 1
    job_free, mach_free = [0]*n_jobs, [0]*n_mach
    op_idx, start_times, done = [0]*n_jobs, {}, 0
    total = sum(len(j) for j in jobs)
    while done < total:
        best = None
        for j in range(n_jobs):
            k = op_idx[j]
            if k < len(jobs[j]):
                m, d = jobs[j][k]
                s = max(job_free[j], mach_free[m])
                if best is None or s+d < best[0]:
                    best = (s+d, j, k, m, d, s)
        _, j, k, m, d, s = best
        start_times[(j,k)] = s
        job_free[j] = mach_free[m] = s+d
        op_idx[j] += 1; done += 1
    ms = max(start_times[(j,len(job)-1)] + job[-1][1] for j,job in enumerate(jobs))
    return {"start_times": start_times, "makespan": ms, "optimal": False}


# ─────────────────────────────────────────────────────────────────────────────
# 3b. OR-Tools CP-SAT solver  (paper §4.4)
# ─────────────────────────────────────────────────────────────────────────────

def solve_ortools(jobs, time_limit=60):
    """
    Solve JSSP with Google OR-Tools CP-SAT.
    Paper config: max_time=300s, num_workers=42, AUTOMATIC_SEARCH.
    Returns {start_times, makespan, optimal} or falls back to EFT.
    """
    try:
        from ortools.sat.python import cp_model
    except ImportError:
        return _eft_fallback(jobs)

    model    = cp_model.CpModel()
    n_jobs   = len(jobs)
    n_mach   = max(m for job in jobs for m,_ in job) + 1
    horizon  = sum(d for job in jobs for _,d in job)

    tasks, m2iv = {}, {m: [] for m in range(n_mach)}
    for j, job in enumerate(jobs):
        for k, (m, d) in enumerate(job):
            s  = model.NewIntVar(0, horizon, f"s_{j}_{k}")
            e  = model.NewIntVar(0, horizon, f"e_{j}_{k}")
            iv = model.NewIntervalVar(s, d, e, f"iv_{j}_{k}")
            tasks[(j,k)] = (s, e, m, d)
            m2iv[m].append(iv)

    for j, job in enumerate(jobs):
        for k in range(len(job)-1):
            model.Add(tasks[(j,k+1)][0] >= tasks[(j,k)][1])
    for m in range(n_mach):
        model.AddNoOverlap(m2iv[m])

    ms_var = model.NewIntVar(0, horizon, "ms")
    model.AddMaxEquality(ms_var, [tasks[(j,len(job)-1)][1] for j,job in enumerate(jobs)])
    model.Minimize(ms_var)

    solver = cp_model.CpSolver()
    solver.parameters.max_time_in_seconds = time_limit
    solver.parameters.num_search_workers  = min(42, os.cpu_count() or 2)
    solver.parameters.search_branching    = cp_model.AUTOMATIC_SEARCH

    status = solver.Solve(model)
    if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
        return _eft_fallback(jobs)

    return {
        "start_times": {(j,k): solver.Value(tasks[(j,k)][0])
                        for j in range(n_jobs) for k in range(len(jobs[j]))},
        "makespan": solver.Value(ms_var),
        "optimal":  status == cp_model.OPTIMAL,
    }


# ─────────────────────────────────────────────────────────────────────────────
# 3c. Natural language formatters  (paper Listings 2, 3, 4)
# ─────────────────────────────────────────────────────────────────────────────

_INSTRS = [
    "Optimize schedule for {nj} Jobs across {nm} Machines to minimize makespan. "
    "Each job involves a series of Operations needing specific machines and times. "
    "Operations are processed in order, without interruption, on a single Machine at a time.",
    "Find an optimal schedule for {nj} Jobs on {nm} Machines that minimizes the total "
    "completion time (makespan). Each Job has a fixed sequence of Operations, each "
    "requiring a specific Machine and duration. A Machine can handle only one Job at a time.",
    "Schedule {nj} Jobs across {nm} Machines to minimize makespan. "
    "Jobs must follow their operation order. No machine overlap allowed.",
]

def _instr(nj, nm, rng=None):
    t = _INSTRS[int(rng.integers(len(_INSTRS))) if rng else 0]
    return t.format(nj=nj, nm=nm)

def format_job_centric(jobs, rng=None):
    """Paper Listing 2."""
    nj = len(jobs); nm = max(m for job in jobs for m,_ in job)+1
    lines = [_instr(nj, nm, rng), "", "Problem:"]
    for j, job in enumerate(jobs):
        lines.append(f"\n Job {j} consists of the following Operations:")
        for k,(m,d) in enumerate(job):
            lines.append(f"  Operation {k} on Machine {m} duration {d} mins.")
    return "\n".join(lines)

def format_machine_centric(jobs, rng=None):
    """Paper Listing 3."""
    nj = len(jobs); nm = max(m for job in jobs for m,_ in job)+1
    mops = {m: [] for m in range(nm)}
    for j, job in enumerate(jobs):
        for k,(m,d) in enumerate(job): mops[m].append((j,k,d))
    lines = [_instr(nj, nm, rng), "", "Problem:"]
    for m in range(nm):
        lines.append(f"\n Machine {m} is used for the following Operations:")
        for j,k,d in mops[m]:
            lines.append(f"  Job {j} Operation {k} duration {d} mins.")
    return "\n".join(lines)

def format_solution(jobs, result):
    """Paper Listing 4 — sorted by start time."""
    st, ms = result["start_times"], result["makespan"]
    rows = sorted((st[(j,k)], j, k, m, d)
                  for j,job in enumerate(jobs)
                  for k,(m,d) in enumerate(job))
    lines = ["Solution:\n"]
    for s,j,k,m,d in rows:
        lines.append(f" Job {j} Operation {k} on Machine {m} : {s} + {d} -> {s+d}")
    last_k = max(len(job)-1 for job in jobs)
    lines.append(f"\nMakespan: {ms}, as it is the maximum end completion time of Operation {last_k}")
    return "\n".join(lines)

def build_prompt(problem, solution=None):
    """Phi-3 chat template used during training and inference."""
    if solution:
        return f"<|user|>\n{problem.strip()}<|end|>\n<|assistant|>\n{solution.strip()}<|end|>"
    return f"<|user|>\n{problem.strip()}<|end|>\n<|assistant|>\n"


# ─────────────────────────────────────────────────────────────────────────────
# 3d. Output parser + validator  (paper §6.1)
# ─────────────────────────────────────────────────────────────────────────────

_SOL_RE = re.compile(
    r"Job\s+(\d+)\s+Operation\s+(\d+)\s+on\s+Machine\s+(\d+)"
    r"\s*:\s*(\d+)\s*\+\s*(\d+)\s*->\s*(\d+)"
)

def parse_output(text, jobs):
    """Extract start times from LLM output via regex."""
    expected = {(j,k) for j,job in enumerate(jobs) for k in range(len(job))}
    st = {}
    for m in _SOL_RE.finditer(text):
        j,k,machine,start,dur,end = (int(x) for x in m.groups())
        if j < len(jobs) and k < len(jobs[j]) and jobs[j][k] == (machine,dur):
            st[(j,k)] = start
    return st if set(st.keys()) == expected else None

def validate(jobs, st):
    """Check precedence + no machine overlap. Returns (valid, makespan)."""
    nm = max(m for job in jobs for m,_ in job)+1
    for j,job in enumerate(jobs):
        for k in range(1, len(job)):
            if st[(j,k)] < st[(j,k-1)] + job[k-1][1]:
                return False, 0
    for m in range(nm):
        segs = sorted((st[(j,k)], st[(j,k)]+d)
                      for j,job in enumerate(jobs)
                      for k,(mach,d) in enumerate(job) if mach==m)
        for i in range(1, len(segs)):
            if segs[i][0] < segs[i-1][1]: return False, 0
    ms = max(st[(j,len(job)-1)]+job[-1][1] for j,job in enumerate(jobs))
    return True, ms


# ─────────────────────────────────────────────────────────────────────────────
# 3e. Benchmark instances
# ─────────────────────────────────────────────────────────────────────────────

FT06 = dict(name="ft06", optimal=55, jobs=[
    [(2,1),(0,3),(1,6),(3,7),(5,3),(4,6)],
    [(1,8),(2,5),(4,10),(5,10),(0,10),(3,4)],
    [(2,5),(3,4),(5,8),(0,9),(1,1),(4,7)],
    [(1,5),(0,5),(2,5),(3,3),(4,8),(5,9)],
    [(2,9),(1,3),(4,5),(5,4),(0,3),(3,1)],
    [(1,3),(3,3),(5,9),(0,10),(4,4),(2,1)],
])
LA01 = dict(name="la01", optimal=666, jobs=[
    [(1,21),(0,53),(4,95),(3,55),(2,34)],
    [(0,21),(3,52),(4,16),(2,26),(1,71)],
    [(3,39),(4,98),(1,42),(2,31),(0,12)],
    [(1,77),(0,55),(4,79),(2,66),(3,77)],
    [(0,83),(3,34),(2,64),(1,19),(4,37)],
    [(1,54),(2,43),(4,79),(0,92),(3,62)],
    [(3,69),(4,77),(1,87),(2,87),(0,93)],
    [(2,38),(0,60),(1,41),(3,24),(4,83)],
    [(3,17),(1,49),(4,25),(0,44),(2,98)],
    [(4,77),(3,79),(2,43),(1,75),(0,96)],
])
BENCHMARKS = [FT06, LA01]

print("✓ Core library loaded")

✓ Core library loaded
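
The solution line format and its regex can be exercised in isolation. The sketch below copies the pattern used by `_SOL_RE` and adds one extra check that the notebook does not perform: that the stated end equals start plus duration. `parse_line` and `arithmetic_ok` are hypothetical helper names.

```python
import re

# Same pattern shape as the notebook's _SOL_RE:
# "Job j Operation k on Machine m : start + dur -> end"
SOL_RE = re.compile(
    r"Job\s+(\d+)\s+Operation\s+(\d+)\s+on\s+Machine\s+(\d+)"
    r"\s*:\s*(\d+)\s*\+\s*(\d+)\s*->\s*(\d+)"
)

def parse_line(line: str):
    """Return (job, op, machine, start, dur, end), or None if the line does not match."""
    m = SOL_RE.search(line)
    return tuple(int(x) for x in m.groups()) if m else None

def arithmetic_ok(fields) -> bool:
    """Extra check (not in the notebook): the stated end must equal start + duration."""
    _, _, _, start, dur, end = fields
    return start + dur == end

good = parse_line(" Job 0 Operation 0 on Machine 2 : 0 + 1 -> 1")
bad = parse_line(" Job 1 Operation 0 on Machine 1 : 0 + 8 -> 9")   # wrong sum on purpose
print(good, arithmetic_ok(good))
print(bad, arithmetic_ok(bad))
```

The notebook's `validate` recomputes end times from the instance durations, so a wrong sum in the text does not make a schedule invalid; the extra check is only useful for measuring how often the model's arithmetic drifts.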


## 4 · Pipeline sanity check

Verifies OR-Tools → format → parse → validate without a GPU.

In [5]:
print("=" * 55)
print("Sanity check: OR-Tools → format → parse → validate")
print("=" * 55)

jobs = FT06["jobs"]
t0 = time.time()
result = solve_ortools(jobs, time_limit=30)
elapsed = time.time() - t0

assert result is not None
status = "OPTIMAL ✓" if result["optimal"] else "NOT PROVEN OPTIMAL (time limit or EFT fallback)"
print(f"\nft06 solved in {elapsed:.2f}s  makespan={result['makespan']}  {status}")
print(f"Known optimal = 55  |  gap = {(result['makespan']-55)/55*100:.1f}%")

problem_nl  = format_job_centric(jobs)
solution_nl = format_solution(jobs, result)

print("\nProblem excerpt:")
for line in problem_nl.split("\n")[3:7]:
    print(" ", line)

print("\nSolution excerpt:")
for line in solution_nl.split("\n")[2:6]:
    print(" ", line)

st = parse_output(solution_nl, jobs)
assert st is not None, "Parse failed!"
ok, ms = validate(jobs, st)
assert ok, "Validation failed!"
print(f"\nParse + validate ✓  makespan={ms}")

prompt = build_prompt(problem_nl, solution_nl)
print(f"\nTraining prompt length: {len(prompt)} chars / ~{len(prompt)//4} tokens")
print("\n✓ All checks passed — ready to generate dataset")

Sanity check: OR-Tools → format → parse → validate

ft06 solved in 0.53s  makespan=55  OPTIMAL ✓
Known optimal = 55  |  gap = 0.0%

Problem excerpt:
  
   Job 0 consists of the following Operations:
    Operation 0 on Machine 2 duration 1 mins.
    Operation 1 on Machine 0 duration 3 mins.

Solution excerpt:
   Job 0 Operation 0 on Machine 2 : 0 + 1 -> 1
   Job 1 Operation 0 on Machine 1 : 0 + 8 -> 8
   Job 0 Operation 1 on Machine 0 : 1 + 3 -> 4
   Job 2 Operation 0 on Machine 2 : 1 + 5 -> 6

Parse + validate ✓  makespan=55

Training prompt length: 3912 chars / ~978 tokens

✓ All checks passed — ready to generate dataset
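
To see what the validator rejects, here is a self-contained sketch on a hypothetical 2×2 instance. It applies the same two checks as `validate` (job precedence and machine no-overlap) but is rewritten compactly for illustration; `is_valid` and the three schedules are made up for this example.

```python
# Hypothetical 2-job, 2-machine instance: jobs[j] = [(machine, duration), ...]
jobs = [
    [(0, 3), (1, 2)],   # Job 0: M0 for 3, then M1 for 2
    [(0, 2), (1, 4)],   # Job 1: M0 for 2, then M1 for 4
]

def is_valid(jobs, st):
    """Check job precedence and machine no-overlap for start times st[(j, k)]."""
    # Precedence: each operation starts no earlier than the end of its predecessor.
    for j, job in enumerate(jobs):
        for k in range(1, len(job)):
            if st[(j, k)] < st[(j, k - 1)] + job[k - 1][1]:
                return False
    # No overlap: intervals on the same machine must not intersect.
    by_machine = {}
    for j, job in enumerate(jobs):
        for k, (m, d) in enumerate(job):
            by_machine.setdefault(m, []).append((st[(j, k)], st[(j, k)] + d))
    for segs in by_machine.values():
        segs.sort()
        if any(b[0] < a[1] for a, b in zip(segs, segs[1:])):
            return False
    return True

feasible = {(0, 0): 2, (0, 1): 6, (1, 0): 0, (1, 1): 2}
overlap = {(0, 0): 0, (0, 1): 3, (1, 0): 1, (1, 1): 5}      # both jobs occupy M0 at t=1
precedence = {(0, 0): 0, (0, 1): 2, (1, 0): 3, (1, 1): 5}   # Job 0 op 1 starts before op 0 ends

print(is_valid(jobs, feasible), is_valid(jobs, overlap), is_valid(jobs, precedence))
```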


## 5 · Generate training dataset

OR-Tools solves each random JSSP instance and stores the feasible schedule as the training label.
This mirrors how the paper builds its 120k dataset; the default here is 3k for a fast run.

In [6]:
from tqdm.auto import tqdm

def generate_dataset(n_samples=N_SAMPLES, out_path=DATASET_PATH,
                     min_size=MIN_SIZE, max_size=MAX_SIZE,
                     dur_min=DUR_MIN, dur_max=DUR_MAX,
                     ortools_time=ORTOOLS_TIME, seed=42):

    rng = np.random.default_rng(seed)
    data, skipped = [], 0

    pbar = tqdm(total=n_samples, desc="Generating")
    while len(data) < n_samples:
        nj = int(rng.integers(min_size, max_size+1))
        nm = int(rng.integers(min_size, max_size+1))

        jobs = [[(int(m), int(d))
                 for m, d in zip(rng.permutation(nm).tolist(),
                                 rng.integers(dur_min, dur_max+1, size=nm).tolist())]
                for _ in range(nj)]

        result = solve_ortools(jobs, time_limit=ortools_time)
        if result is None:   # defensive only: solve_ortools falls back to EFT and never returns None
            skipped += 1
            continue

        use_job = bool(rng.integers(2))
        problem  = format_job_centric(jobs, rng) if use_job else format_machine_centric(jobs, rng)
        solution = format_solution(jobs, result)

        data.append({
            "text":         build_prompt(problem, solution),
            "makespan":     result["makespan"],
            "optimal":      result["optimal"],
            "num_jobs":     nj,
            "num_machines": nm,
        })
        pbar.update(1)

    pbar.close()
    with open(out_path, "w") as f:
        json.dump(data, f)

    opt_pct  = 100 * sum(d["optimal"] for d in data) / len(data)
    avg_ms   = np.mean([d["makespan"] for d in data])
    avg_size = np.mean([d["num_jobs"]*d["num_machines"] for d in data])
    print(f"\n✓ Saved {len(data)} samples → {out_path}")
    print(f"  Optimal: {opt_pct:.1f}%  |  Avg makespan: {avg_ms:.0f}  |  Avg problem size: {avg_size:.1f} ops")
    print(f"  Skipped: {skipped} (solver timeout/infeasible)")
    return data

dataset = generate_dataset()

✓ Saved 3000 samples → jssp_dataset.json
  Optimal: 100.0%  |  Avg makespan: 2140  |  Avg problem size: 25.1 ops
  Skipped: 0 (solver timeout/infeasible)
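
The generator above draws each job as a random permutation of the machines, so every job visits every machine exactly once, with durations drawn uniformly from the configured range. A stdlib-only sketch of that sampling step, using `random` instead of NumPy purely to keep the example dependency-free; `random_instance` is a hypothetical name:

```python
import random

def random_instance(n_jobs, n_machines, dur_min=5, dur_max=500, seed=0):
    """Build jobs[j] = [(machine, duration), ...] with exactly one visit per machine."""
    rng = random.Random(seed)
    jobs = []
    for _ in range(n_jobs):
        order = list(range(n_machines))
        rng.shuffle(order)   # random machine routing for this job
        jobs.append([(m, rng.randint(dur_min, dur_max)) for m in order])
    return jobs

inst = random_instance(n_jobs=3, n_machines=4)
for j, job in enumerate(inst):
    print(f"Job {j}: {job}")
```

With both dimensions drawn uniformly from 2 to 8, the expected number of operations per instance is 5 × 5 = 25, which is consistent with the logged average of 25.1 ops.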


## 6 · Fine-tune Phi-3-Mini with LoRA

Loads `Phi-3-Mini-4K-Instruct` in 4-bit via unsloth, attaches LoRA adapters, and trains.

**T4 memory budget:** ~13 GB used with batch=2, grad_accum=8, seq_len=2048.

In [7]:
from unsloth import FastLanguageModel
from datasets import Dataset
from trl import SFTTrainer
from transformers import TrainingArguments
import torch

# T4 = Turing arch → supports fp16 but NOT bf16 (needs Ampere+)
IS_BF16 = torch.cuda.is_bf16_supported()
print(f"GPU bf16 support: {IS_BF16}  → using {'bf16' if IS_BF16 else 'fp16'}")

# ── Load base model ──────────────────────────────────────────────────────────
# Phi-3-mini-4k: same architecture as the 128k model but with standard RoPE (no LongRoPE),
# which avoids the Dynamo data-dependent-branching crash on T4 + Torch 2.10.
print("Loading Phi-3-Mini-4K-Instruct (4-bit) …")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = MODEL_NAME,
    max_seq_length = MAX_SEQ_LEN,
    dtype          = torch.bfloat16 if IS_BF16 else torch.float16,
    load_in_4bit   = True,
)
print(f"✓ Base model loaded  |  {sum(p.numel() for p in model.parameters())/1e9:.2f}B params")

# ── Attach LoRA adapters ─────────────────────────────────────────────────────
model = FastLanguageModel.get_peft_model(
    model,
    r              = LORA_R,
    target_modules = ["q_proj","k_proj","v_proj","o_proj",
                      "gate_proj","up_proj","down_proj"],
    lora_alpha     = LORA_R,
    lora_dropout   = 0.0,         # 0.0 enables unsloth fast path (dropout>0 causes perf hit)
    bias           = "none",
    use_gradient_checkpointing = "unsloth",
    random_state   = 42,
)
model.print_trainable_parameters()

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
GPU bf16 support: False  → using fp16
Loading Phi-3-Mini-4K-Instruct (4-bit) …
==((====))==  Unsloth 2026.2.1: Fast Mistral patching. Transformers: 4.57.6.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.563 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.6.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.34. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


✓ Base model loaded  |  2.01B params


Unsloth 2026.2.1 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


trainable params: 29,884,416 || all params: 3,850,963,968 || trainable%: 0.7760
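
The trainable-parameter count in the log can be reproduced by hand. A LoRA adapter of rank `r` on a linear layer with `d_in` inputs and `d_out` outputs adds `r * (d_in + d_out)` weights. The sketch below assumes a hidden size of 3072 and an MLP width of 8192 for Phi-3-mini (values not stated in this notebook), and that the seven projections listed in `target_modules` exist as separate layers after unsloth's patching.

```python
# Assumed Phi-3-mini dimensions (not stated in the notebook).
HIDDEN, MLP, LAYERS, R = 3072, 8192, 32, 16

def lora_params(d_in: int, d_out: int, r: int = R) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r), i.e. r * (d_in + d_out) weights."""
    return r * (d_in + d_out)

per_layer = (
    4 * lora_params(HIDDEN, HIDDEN)   # q_proj, k_proj, v_proj, o_proj
    + 2 * lora_params(HIDDEN, MLP)    # gate_proj, up_proj
    + lora_params(MLP, HIDDEN)        # down_proj
)
total = per_layer * LAYERS
print(f"{total:,} trainable parameters")
```

Under these assumptions the total agrees with the `trainable params` figure printed above.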


In [8]:
# ── Prepare dataset ──────────────────────────────────────────────────────────
with open(DATASET_PATH) as f:
    raw = json.load(f)
raw = raw[:MAX_TRAIN]

np.random.default_rng(0).shuffle(raw)
hf_dataset = Dataset.from_list(raw)
print(f"Training on {len(hf_dataset)} samples  (max_seq_len={MAX_SEQ_LEN})")

# ── Train ────────────────────────────────────────────────────────────────────
trainer = SFTTrainer(
    model              = model,
    tokenizer          = tokenizer,
    train_dataset      = hf_dataset,
    dataset_text_field = "text",
    max_seq_length     = MAX_SEQ_LEN,
    dataset_num_proc   = 2,
    args = TrainingArguments(
        per_device_train_batch_size = BATCH_SIZE,
        gradient_accumulation_steps = GRAD_ACCUM,
        warmup_steps                = 50,
        num_train_epochs            = EPOCHS,
        learning_rate               = LR,
        bf16                        = IS_BF16,   # True on A100, False on T4
        fp16                        = not IS_BF16,  # True on T4
        logging_steps               = 20,
        optim                       = "adamw_8bit",
        weight_decay                = 0.01,
        lr_scheduler_type           = "cosine",
        output_dir                  = OUTPUT_DIR,
        save_steps                  = 500,
        report_to                   = "none",
    ),
)

print("\nStarting training …")
trainer_stats = trainer.train()
print(f"\n✓ Training complete")
print(f"  Total steps: {trainer_stats.global_step}")
print(f"  Final loss:  {trainer_stats.training_loss:.4f}")
print(f"  Time:        {trainer_stats.metrics['train_runtime']/60:.1f} min")

model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"\n✓ Model saved → {OUTPUT_DIR}")

Training on 3000 samples  (max_seq_len=2048)


Starting training …


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 3,000 | Num Epochs = 3 | Total steps = 564
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 8 x 1) = 16
 "-____-"     Trainable parameters = 29,884,416 of 3,850,963,968 (0.78% trained)


Step,Training Loss
20,0.5547
40,0.3406
60,0.2448
80,0.2448
100,0.2363
120,0.2335
140,0.2431
160,0.2303
180,0.2349
200,0.2473



✓ Training complete
  Total steps: 564
  Final loss:  0.2487
  Time:        141.6 min

✓ Model saved → /content/jssp_phi3_lora
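
The step count in the log follows from the configuration. A short worked calculation, assuming the trainer keeps the final partial accumulation step of each epoch (rounding up), which is what the logged total implies:

```python
import math

n_samples, batch_size, grad_accum, epochs = 3000, 2, 8, 3

effective_batch = batch_size * grad_accum                  # sequences per optimizer step
steps_per_epoch = math.ceil(n_samples / effective_batch)   # assumes the partial last step is kept
total_steps = steps_per_epoch * epochs
print(effective_batch, steps_per_epoch, total_steps)
```

This reproduces the 564 total steps reported above. At the logged 141.6 min, that works out to roughly 15 s per optimizer step on the T4.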


## 7 · Sampling-based inference

The paper's key contribution: generate **s** candidates, validate each, return the best valid makespan.
This is what pushes the LLM from ~20% gap (greedy) down to ~8.92% (s=10).

In [9]:
FastLanguageModel.for_inference(model)  # enable unsloth's fast kernel

def solve_with_llm(jobs, n_samples=N_INFERENCE_SAMPLES,
                   temperature=TEMPERATURE, top_p=TOP_P,
                   max_new_tokens=MAX_NEW_TOKENS, fmt="job"):
    """
    Paper sampling method (§6.2):
      1. Generate s candidate solutions with temperature sampling.
      2. Parse + validate each with regex + constraint checker.
      3. Return the one with the lowest valid makespan.
    """
    problem  = format_job_centric(jobs) if fmt == "job" else format_machine_centric(jobs)
    prompt   = build_prompt(problem)
    inputs   = tokenizer(prompt, return_tensors="pt",
                         truncation=True, max_length=MAX_SEQ_LEN).to("cuda")

    best_ms, best_result, valid_count = None, None, 0

    for _ in range(n_samples):
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens = max_new_tokens,
                do_sample      = True,
                temperature    = temperature,
                top_p          = top_p,
                pad_token_id   = tokenizer.eos_token_id,
            )
        text = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:],
                                skip_special_tokens=True)
        st = parse_output(text, jobs)
        if st is None:
            continue
        ok, ms = validate(jobs, st)
        if ok:
            valid_count += 1
            if best_ms is None or ms < best_ms:
                best_ms, best_result = ms, {"makespan": ms, "start_times": st, "text": text}

    return best_result, valid_count

print("✓ Inference function ready")

✓ Inference function ready
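
One way to build intuition for why sampling helps: if each candidate were valid independently with probability `p`, the chance that at least one of `s` candidates is valid would be `1 - (1 - p)**s`. The independence assumption is an illustration only and does not come from the paper; samples drawn from the same prompt are correlated in practice.

```python
def p_at_least_one_valid(p: float, s: int) -> float:
    """Probability that at least one of s independent samples is valid."""
    if not 0.0 <= p <= 1.0 or s < 0:
        raise ValueError("need 0 <= p <= 1 and s >= 0")
    return 1.0 - (1.0 - p) ** s

# Hypothetical per-sample validity rate of 30%.
for s in (1, 5, 10):
    print(s, round(p_at_least_one_valid(0.3, s), 3))
```

The same model also shows the limit of the method: when `p` is zero for an instance, no number of samples produces a valid schedule.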


## 8 · Benchmark evaluation

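Evaluate on **ft06** (6×6, optimal=55) and **la01** (10×5, optimal=666), the same benchmarks as the paper.

In the run recorded below, neither instance yielded a valid schedule (0/10 each). A plausible cause is the `MAX_NEW_TOKENS = 512` cap: ft06 needs 36 solution lines and la01 needs 50, and `parse_output` rejects any candidate that is missing an operation.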

In [10]:
all_results = []

for bench in BENCHMARKS:
    name, optimal, jobs = bench["name"], bench["optimal"], bench["jobs"]
    nj = len(jobs); nm = max(m for job in jobs for m,_ in job)+1
    print(f"\n{'='*55}")
    print(f"  {name.upper()}  ({nj}×{nm})   optimal makespan = {optimal}")
    print(f"{'='*55}")

    t0 = time.time()
    result, valid_count = solve_with_llm(jobs, n_samples=N_INFERENCE_SAMPLES)
    elapsed = time.time() - t0

    if result:
        ms  = result["makespan"]
        gap = (ms - optimal) / optimal * 100
        print(f"  Best makespan : {ms}")
        print(f"  Optimality gap: {gap:.2f}%")
        print(f"  Valid / Total : {valid_count} / {N_INFERENCE_SAMPLES}")
        print(f"  Time          : {elapsed:.1f}s")
        print(f"\n  Schedule (sorted by start time):")
        for line in result["text"].split("\n")[1:8]:
            print(f"    {line}")
    else:
        ms, gap = None, None
        print(f"  No valid solution found in {N_INFERENCE_SAMPLES} samples")
        print(f"  Time: {elapsed:.1f}s")

    all_results.append(dict(instance=name, optimal=optimal,
                            makespan=ms, gap_pct=gap,
                            valid_solutions=valid_count,
                            n_samples=N_INFERENCE_SAMPLES,
                            elapsed=elapsed))

# Summary
print(f"\n{'='*55}")
print("RESULTS SUMMARY")
print(f"{'='*55}")
valid = [r for r in all_results if r["gap_pct"] is not None]
if valid:
    avg_gap = np.mean([r["gap_pct"] for r in valid])
    print(f"\n  Average gap from optimal : {avg_gap:.2f}%")
    print(f"  Paper reports            : ~8.92%  (Phi-3, s=10, 10×10 test set)")
for r in all_results:
    g = f"{r['gap_pct']:.2f}%" if r["gap_pct"] is not None else "no solution"
    print(f"  {r['instance']:8s}  makespan={r['makespan']}  gap={g}  "
          f"valid={r['valid_solutions']}/{r['n_samples']}")


  FT06  (6×6)   optimal makespan = 55
  No valid solution found in 10 samples
  Time: 338.1s

  LA01  (10×5)   optimal makespan = 666
  No valid solution found in 10 samples
  Time: 342.0s

RESULTS SUMMARY
  ft06      makespan=None  gap=no solution  valid=0/10
  la01      makespan=None  gap=no solution  valid=0/10
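
When every sample is rejected, it helps to know whether the candidates were malformed or simply incomplete. The sketch below counts how many expected operations appear in a candidate text, reusing the line pattern from `_SOL_RE`. `coverage` is a hypothetical helper, and the instance and truncated text are made up for illustration.

```python
import re

SOL_RE = re.compile(
    r"Job\s+(\d+)\s+Operation\s+(\d+)\s+on\s+Machine\s+(\d+)"
    r"\s*:\s*(\d+)\s*\+\s*(\d+)\s*->\s*(\d+)"
)

def coverage(text, jobs):
    """Return (found, expected, missing) operation counts/keys for a candidate output."""
    expected = {(j, k) for j, job in enumerate(jobs) for k in range(len(job))}
    found = set()
    for m in SOL_RE.finditer(text):
        j, k, machine, _start, dur, _end = (int(x) for x in m.groups())
        # Count a line only if it names a real operation with the right machine and duration.
        if (j, k) in expected and tuple(jobs[j][k]) == (machine, dur):
            found.add((j, k))
    return len(found), len(expected), sorted(expected - found)

# Made-up 2x2 instance and a candidate that was cut off partway through its third line.
jobs = [[(0, 3), (1, 2)], [(1, 4), (0, 1)]]
truncated = (
    "Solution:\n\n"
    " Job 0 Operation 0 on Machine 0 : 0 + 3 -> 3\n"
    " Job 1 Operation 0 on Machine 1 : 0 + 4 -> 4\n"
    " Job 1 Operation 1 on Machine 0 : 4 + 1 ->"
)
print(coverage(truncated, jobs))
```

A found count well below the expected total on every sample would point to truncation by `max_new_tokens` rather than to constraint violations.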


## 9 · Save results & download model

In [11]:
import json

# Save benchmark results
with open("/content/jssp_results.json", "w") as f:
    json.dump(all_results, f, indent=2)
print("✓ Results saved → /content/jssp_results.json")

# Download helper
from google.colab import files

print("\nDownload options:")
print("  Run the next cell to download the results JSON.")
print("  The LoRA adapter is in:", OUTPUT_DIR)

✓ Results saved → /content/jssp_results.json

Download options:
  Run the next cell to download the results JSON.
  The LoRA adapter is in: /content/jssp_phi3_lora


In [16]:
# Optional: download results
# from google.colab import files
# files.download("/content/jssp_results.json")

# Optional: zip and download the LoRA adapter weights
# import shutil
# shutil.make_archive("/content/jssp_phi3_lora", "zip", OUTPUT_DIR)
# files.download("/content/jssp_phi3_lora.zip")
print("Uncomment lines above to download files")

Uncomment lines above to download files


## 10 · Quick interactive test (optional)

Try the model on a custom problem.

In [14]:
# Define your own JSSP instance here
# Format: jobs[j] = [(machine_id, duration), ...]  — each job visits all machines in order

custom_jobs = [
    [(0, 3), (1, 2), (2, 5)],   # Job 0: M0(3 min) → M1(2 min) → M2(5 min)
    [(1, 4), (0, 6), (2, 2)],   # Job 1: M1(4 min) → M0(6 min) → M2(2 min)
    [(2, 3), (1, 5), (0, 4)],   # Job 2: M2(3 min) → M1(5 min) → M0(4 min)
]

print("Problem:")
print(format_job_centric(custom_jobs))

print("\nSolving with LLM (s=5 samples) …")
result, valid_count = solve_with_llm(custom_jobs, n_samples=5)

if result:
    print(f"\n✓ Valid schedule found  makespan={result['makespan']}  ({valid_count}/5 valid candidates)")
    print("\nSchedule:")
    for line in result["text"].split("\n")[:10]:
        print(" ", line)

    # Compare with OR-Tools optimal
    ref = solve_ortools(custom_jobs, time_limit=10)
    if ref:
        gap = (result["makespan"] - ref["makespan"]) / ref["makespan"] * 100
        print(f"\nOR-Tools optimal: {ref['makespan']}  |  LLM gap: {gap:.1f}%")
else:
    print("No valid solution found — try increasing n_samples or check the model trained correctly")

Problem:
Optimize schedule for 3 Jobs across 3 Machines to minimize makespan. Each job involves a series of Operations needing specific machines and times. Operations are processed in order, without interruption, on a single Machine at a time.

Problem:

 Job 0 consists of the following Operations:
  Operation 0 on Machine 0 duration 3 mins.
  Operation 1 on Machine 1 duration 2 mins.
  Operation 2 on Machine 2 duration 5 mins.

 Job 1 consists of the following Operations:
  Operation 0 on Machine 1 duration 4 mins.
  Operation 1 on Machine 0 duration 6 mins.
  Operation 2 on Machine 2 duration 2 mins.

 Job 2 consists of the following Operations:
  Operation 0 on Machine 2 duration 3 mins.
  Operation 1 on Machine 1 duration 5 mins.
  Operation 2 on Machine 0 duration 4 mins.

Solving with LLM (s=5 samples) …

✓ Valid schedule found  makespan=15  (5/5 valid candidates)

Schedule:
  Solution:
  
   Job 0 Operation 0 on Machine 0 : 0 + 3 -> 3
   Job 1 Operation 0 on Machine 1 : 0 + 4 ->