# vAGI L-KAN train tren Modal.com (GPU A100 40GB)

Notebook nay chay train tu xa tren Modal, dung GPU A100 40GB.

## Tham khao chinh
- Modal quickstart + CLI: https://modal.com/docs/guide
- Modal Python SDK (`App`, `function`, `Image`, `Volume`): https://modal.com/docs/reference/modal.App
- Modal GPU docs (A100): https://modal.com/docs/guide/gpu
- Modal image from registry: https://modal.com/docs/reference/modal.Image#from_registry
- Modal volume docs: https://modal.com/docs/reference/modal.Volume

## Output
- Checkpoint duoc luu trong Modal Volume va co cell de tai ve local.


In [None]:
import getpass
import os
import pathlib
import shlex
import subprocess
import sys

def run(cmd: str, check: bool = True, env=None):
    """Run *cmd* through the shell, echo its captured output, and return the result.

    The command line is printed first (prefixed with ``$``). stdout and stderr are
    captured as text and printed only after the process exits, stdout first.

    Args:
        cmd: Shell command line, executed with ``shell=True``.
        check: When true, a non-zero exit status raises ``RuntimeError``.
        env: Optional environment mapping for the child; ``None`` inherits ours.

    Returns:
        The ``subprocess.CompletedProcess`` of the finished command.

    Raises:
        RuntimeError: If *check* is true and the command exits non-zero.
    """
    print(f"$ {cmd}")
    result = subprocess.run(cmd, shell=True, text=True, capture_output=True, env=env)
    # Replay captured streams in a fixed order: stdout, then stderr.
    for captured in (result.stdout, result.stderr):
        if captured:
            print(captured)
    failed = result.returncode != 0
    if failed and check:
        raise RuntimeError(f"Command failed ({result.returncode}): {cmd}")
    return result

# Show which interpreter the notebook kernel runs (pip installs target this one).
print(f"Python: {sys.version}")


In [None]:
# Run configuration: repository/branch to train from, Modal volume name used by
# the local `modal volume` commands, checkpoint file name, and training
# hyper-parameters forwarded to the Modal script as CLI flags.
REPO_URL = "https://github.com/vietrix/vagi.git"
BRANCH = "main"
VOLUME_NAME = "vagi-lkan-models"
CHECKPOINT_NAME = "lkan-genesis-a10040gb.safetensors"
TRAIN_STEPS = 5_000
BATCH_SIZE = 32
SEQ_LEN = 64

# Modal credentials: read from the environment first, otherwise prompted for
# (getpass keeps the values out of the notebook output).
token_id = os.environ.get("MODAL_TOKEN_ID", "").strip()
token_secret = os.environ.get("MODAL_TOKEN_SECRET", "").strip()
if not token_id:
    token_id = getpass.getpass("MODAL_TOKEN_ID: ").strip()
if not token_secret:
    token_secret = getpass.getpass("MODAL_TOKEN_SECRET: ").strip()

if token_id and token_secret:
    # Child processes (`modal ...` via the shell) inherit these variables.
    os.environ["MODAL_TOKEN_ID"] = token_id
    os.environ["MODAL_TOKEN_SECRET"] = token_secret
elif token_id or token_secret:
    # A half-filled pair can never authenticate; fail here with a clear message
    # instead of later inside `modal token info` / `modal run`.
    raise ValueError(
        "Provide both MODAL_TOKEN_ID and MODAL_TOKEN_SECRET, or leave both empty "
        "to use the Modal CLI's existing configuration."
    )
else:
    # Both left blank: remove any empty values rather than exporting "" so the
    # Modal CLI can presumably fall back to its own saved credentials
    # (e.g. from `modal token set`) — confirm against the Modal config docs.
    os.environ.pop("MODAL_TOKEN_ID", None)
    os.environ.pop("MODAL_TOKEN_SECRET", None)

print("Repo:", REPO_URL)
print("Branch:", BRANCH)
print("Volume:", VOLUME_NAME)
print("Checkpoint:", CHECKPOINT_NAME)
print("Train steps:", TRAIN_STEPS, "batch:", BATCH_SIZE, "seq_len:", SEQ_LEN)
print(
    "MODAL_TOKEN_ID prefix:",
    token_id[:6] + "***" if token_id else "(not set, using Modal CLI config)",
)


In [None]:
# Install/upgrade the Modal client into this kernel's interpreter, then confirm
# the CLI works and show which token it is using. Each step raises on failure.
setup_commands = (
    f"{sys.executable} -m pip install -U pip modal",
    "modal --version",
    "modal token info",
)
for setup_cmd in setup_commands:
    run(setup_cmd)


In [None]:
# Source of the Modal app that the next cell writes to modal_train_lkan.py and
# that `modal run` launches. It is a RAW string (r'''...'''), so every backslash
# below lands in the generated file verbatim: regexes must use single
# backslashes (r"\s"), exactly as in a normal .py file.
#
# What the generated script does:
#   - builds a CUDA 12.4 devel image with Rust installed via rustup;
#   - `train` (remote, A100-40GB): clones the repo, enables the candle-core
#     "cuda" feature in kernel/Cargo.toml, patches constants/dims/device
#     selection in kernel/src/bin/train_lkan.rs, downloads tinyshakespeare if
#     data/input.txt is missing, runs `cargo run ... --bin train_lkan`, and
#     copies the produced checkpoint into the Modal Volume mounted at /models;
#   - `main` (local entrypoint) forwards CLI flags to `train.remote(...)`.
# NOTE: the volume name is hard-coded as "vagi-lkan-models" inside the script;
# the notebook's VOLUME_NAME only feeds the local `modal volume` commands.
modal_script = r'''
import os
import pathlib
import re
import shlex
import shutil
import subprocess
import urllib.request

import modal

APP_NAME = "vagi-lkan-train"
MOUNT_MODELS = "/models"
DEFAULT_VOLUME = "vagi-lkan-models"

volume = modal.Volume.from_name(DEFAULT_VOLUME, create_if_missing=True)

image = (
    modal.Image.from_registry("nvidia/cuda:12.4.1-devel-ubuntu22.04", add_python="3.11")
    .apt_install("git", "curl", "build-essential", "pkg-config", "libssl-dev", "ca-certificates")
    .run_commands("curl https://sh.rustup.rs -sSf | sh -s -- -y --profile minimal")
)

app = modal.App(APP_NAME)

def run_cmd(cmd: str, cwd: "pathlib.Path | None" = None):
    # Echo and run a shell command; raises CalledProcessError on failure.
    # The annotation is quoted: `X | None` is evaluated at def time and fails on
    # Python < 3.10, and this module is also imported locally by `modal run`.
    print(f"$ {cmd}")
    subprocess.run(cmd, shell=True, check=True, cwd=str(cwd) if cwd else None)

def replace_const(text: str, name: str, value: str, as_str: bool = False) -> str:
    # Rewrite the first `const NAME: TYPE = ...;` in Rust source, keeping the
    # declared TYPE so the patched constant still type-checks where it is used.
    # With as_str=True the new value is emitted as a quoted string literal.
    pattern = rf"const {re.escape(name)}:\s*([^=]+?)\s*=\s*[^;]+;"
    literal = f"\"{value}\"" if as_str else value
    updated, n = re.subn(
        pattern,
        lambda m: f"const {name}: {m.group(1)} = {literal};",
        text,
        count=1,
    )
    if n == 0:
        raise RuntimeError(f"Cannot find const {name} in train_lkan.rs")
    return updated

def patch_cargo_for_cuda(repo_root: pathlib.Path):
    # Enable the "cuda" feature on kernel/Cargo.toml's `candle-core = ...` line,
    # keeping its version. Only the inline one-line form is recognised.
    cargo_toml = repo_root / "kernel" / "Cargo.toml"
    text = cargo_toml.read_text(encoding="utf-8")
    m = re.search(r"^candle-core\s*=\s*(.+)$", text, flags=re.MULTILINE)
    if not m:
        raise RuntimeError("Cannot find candle-core dependency")
    current = m.group(0)
    if 'features = ["cuda"]' in current:
        print("candle-core already has cuda feature")
        return
    # Prefer an explicit `version = "..."`; otherwise take the first quoted string
    # (covers the short form `candle-core = "x.y.z"`).
    vm = re.search(r"version\s*=\s*\"([^\"]+)\"", current) or re.search(r"\"([^\"]+)\"", current)
    if not vm:
        raise RuntimeError(f"Cannot parse candle-core version from: {current}")
    version = vm.group(1)
    new_line = f"candle-core = {{ version = \"{version}\", features = [\"cuda\"] }}"
    cargo_toml.write_text(text.replace(current, new_line), encoding="utf-8")
    print("patched", cargo_toml)

def patch_train_binary(repo_root: pathlib.Path, train_steps: int, batch_size: int, seq_len: int):
    # Patch train_lkan.rs: output path, step/batch/seq constants, 192 -> 128 dims,
    # and swap `Device::Cpu` for CUDA device 0 with a CPU fallback.
    train_rs = repo_root / "kernel" / "src" / "bin" / "train_lkan.rs"
    src = train_rs.read_text(encoding="utf-8")

    src = replace_const(src, "OUTPUT_PATH", "models/lkan-genesis.safetensors", as_str=True)
    src = replace_const(src, "TRAIN_STEPS", str(train_steps))
    src = replace_const(src, "BATCH_SIZE", str(batch_size))
    src = replace_const(src, "SEQ_LEN", str(seq_len))

    src = src.replace("hidden_dim: 192,", "hidden_dim: 128,")
    src = src.replace("in_dim: 192,", "in_dim: 128,")
    src = src.replace("out_dim: 192,", "out_dim: 128,")

    if "Device::new_cuda(0)" not in src:
        cpu_line = "let device = Device::Cpu;"
        device_block = """let device = match Device::new_cuda(0) {
        Ok(dev) => {
            println!(\"using CUDA device 0\");
            dev
        }
        Err(err) => {
            println!(\"CUDA unavailable ({err}), fallback to CPU\");
            Device::Cpu
        }
    };"""
        if cpu_line not in src:
            raise RuntimeError("Cannot patch device block in train_lkan.rs")
        src = src.replace(cpu_line, device_block)

    train_rs.write_text(src, encoding="utf-8")
    print("patched", train_rs)

def ensure_dataset(repo_root: pathlib.Path):
    # Download tinyshakespeare into data/input.txt unless a file is already there.
    data_path = repo_root / "data" / "input.txt"
    data_path.parent.mkdir(parents=True, exist_ok=True)
    if data_path.exists():
        return data_path
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    print("downloading", url)
    urllib.request.urlretrieve(url, data_path)
    return data_path

@app.function(
    image=image,
    gpu="A100-40GB",
    cpu=8.0,
    memory=32768,
    timeout=24 * 60 * 60,
    volumes={MOUNT_MODELS: volume},
)
def train(
    repo_url: str,
    branch: str,
    train_steps: int = 5000,
    batch_size: int = 32,
    seq_len: int = 64,
    checkpoint_name: str = "lkan-genesis-a10040gb.safetensors",
):
    os.environ["PATH"] = f"/root/.cargo/bin:/usr/local/cuda/bin:{os.environ.get('PATH', '')}"
    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    # Prepend rather than overwrite, so library paths already set by the base
    # image or the runtime stay visible.
    ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    os.environ["LD_LIBRARY_PATH"] = f"/usr/local/cuda/lib64:{ld_path}" if ld_path else "/usr/local/cuda/lib64"

    run_cmd("nvidia-smi")
    run_cmd("rustc --version")
    run_cmd("cargo --version")

    workdir = pathlib.Path("/root/work")
    repo_root = workdir / "vagi"
    if repo_root.exists():
        shutil.rmtree(repo_root)
    workdir.mkdir(parents=True, exist_ok=True)
    run_cmd(
        f"git clone --depth 1 --branch {shlex.quote(branch)} "
        f"{shlex.quote(repo_url)} {shlex.quote(str(repo_root))}"
    )

    patch_cargo_for_cuda(repo_root)
    patch_train_binary(repo_root, train_steps=train_steps, batch_size=batch_size, seq_len=seq_len)
    ensure_dataset(repo_root)

    # OUTPUT_PATH is relative to repo_root; create its directory up front in case
    # train_lkan does not create it itself (not visible from here).
    (repo_root / "models").mkdir(parents=True, exist_ok=True)
    run_cmd("cargo run -p vagi-kernel --release --bin train_lkan", cwd=repo_root)

    candidates = [
        repo_root / "models" / "lkan-genesis.safetensors",
        repo_root / "models" / "lkan-gen2.safetensors",
    ]
    produced = next((p for p in candidates if p.exists()), None)
    if produced is None:
        raise RuntimeError("Checkpoint not found after training")

    target = pathlib.Path(MOUNT_MODELS) / checkpoint_name
    shutil.copy2(produced, target)
    volume.commit()

    return {
        "checkpoint": str(target),
        "size_mb": round(target.stat().st_size / (1024 * 1024), 2),
        "gpu": "A100-40GB",
    }

@app.local_entrypoint()
def main(
    repo_url: str = "https://github.com/vietrix/vagi.git",
    branch: str = "main",
    train_steps: int = 5000,
    batch_size: int = 32,
    seq_len: int = 64,
    checkpoint_name: str = "lkan-genesis-a10040gb.safetensors",
):
    result = train.remote(
        repo_url=repo_url,
        branch=branch,
        train_steps=train_steps,
        batch_size=batch_size,
        seq_len=seq_len,
        checkpoint_name=checkpoint_name,
    )
    print("TRAIN RESULT:", result)
    print("Saved to Modal Volume:", DEFAULT_VOLUME)
'''


In [None]:
# Materialise the embedded Modal app as a file next to the notebook so that
# `modal run modal_train_lkan.py` can find it, then echo a short preview.
script_path = pathlib.Path("modal_train_lkan.py")
script_text = modal_script.strip() + "\n"
script_path.write_text(script_text, encoding="utf-8")
print("Wrote:", script_path.resolve())
preview = script_path.read_text(encoding="utf-8")[:1000]
print(preview)


In [None]:
# Launch the training job on Modal. The notebook's settings are passed as CLI
# flags (shell-quoted) to the script's local entrypoint.
cmd = (
    "modal run modal_train_lkan.py "
    f"--repo-url {shlex.quote(REPO_URL)} "
    f"--branch {shlex.quote(BRANCH)} "
    f"--train-steps {TRAIN_STEPS} "
    f"--batch-size {BATCH_SIZE} "
    f"--seq-len {SEQ_LEN} "
    f"--checkpoint-name {shlex.quote(CHECKPOINT_NAME)}"
)
print(f"$ {cmd}")
# Stream output line by line instead of using run(), which captures everything
# and prints it only after the process exits. For a job that can take hours,
# that would show no progress in the notebook until the end.
# PYTHONUNBUFFERED asks the (Python-based) Modal CLI to flush promptly into the pipe.
stream_env = {**os.environ, "PYTHONUNBUFFERED": "1"}
with subprocess.Popen(
    cmd,
    shell=True,
    text=True,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,  # interleave stderr so errors appear in order
    env=stream_env,
) as proc:
    for line in proc.stdout:
        print(line, end="")
# Popen.__exit__ waits for the process, so returncode is set here.
if proc.returncode != 0:
    raise RuntimeError(f"Command failed ({proc.returncode}): {cmd}")


In [None]:
# List the volume root, then copy the checkpoint into the current directory.
run(f"modal volume ls {shlex.quote(VOLUME_NAME)} /")
# --force overwrites an existing local copy. Without it, the Modal CLI will not
# overwrite local files, and with check=False that failure is easy to miss: a
# re-run after retraining would silently keep the stale checkpoint.
# check=False keeps the download best-effort; the next cell reports whether the
# file actually arrived.
run(
    f"modal volume get --force {shlex.quote(VOLUME_NAME)} "
    f"{shlex.quote('/' + CHECKPOINT_NAME)} {shlex.quote(CHECKPOINT_NAME)}",
    check=False,
)


In [None]:
# Report whether the checkpoint was downloaded next to the notebook, with its size.
local_ckpt = pathlib.Path(CHECKPOINT_NAME)
if not local_ckpt.exists():
    print("Checkpoint chua co local file. Ban co the tai tu Modal Volume tren dashboard.")
else:
    print("Downloaded:", local_ckpt.resolve())
    size_mb = round(local_ckpt.stat().st_size / (1024 * 1024), 2)
    print("Size (MB):", size_mb)


In [None]:
# Closing hints: which settings to change before re-running the training cell.
print(
    "\n".join(
        [
            "Done.",
            "Neu can train lau hon, tang TRAIN_STEPS roi chay lai cell modal run.",
            "Neu can checkpoint ten khac, doi CHECKPOINT_NAME.",
        ]
    )
)
