Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/forge/controller/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,23 @@ async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
await env_setter.set_env.call(env_vars)


def get_nccl_env_vars() -> dict[str, str]:
"""Get NCCL environment variables by detecting network interfaces."""
if "NCCL_SOCKET_IFNAME" in os.environ and "NCCL_IB_DISABLE" in os.environ:
return {}

try:
interfaces = os.listdir("/sys/class/net/")
ib_interfaces = [i for i in interfaces if i.startswith("ib")]

return {
"NCCL_SOCKET_IFNAME": ",".join(ib_interfaces) if ib_interfaces else "^lo",
"NCCL_IB_DISABLE": "0" if ib_interfaces else "1",
}
except Exception:
return {"NCCL_SOCKET_IFNAME": "^lo", "NCCL_IB_DISABLE": "1"}


class GpuManager:
"""Tracks and assigns GPU devices on a host.

Expand Down Expand Up @@ -347,11 +364,16 @@ async def get_proc_mesh(
if with_gpus:
if not addr or not port:
addr, port = await get_remote_info(host_mesh)
gpu_ids = gpu_manager.get_gpus(num_procs)
gpu_ids: list[str] = gpu_manager.get_gpus(num_procs)

# Set PyTorch distributed environment variables
env_vars["MASTER_ADDR"] = addr
env_vars["MASTER_PORT"] = port

# Get NCCL-specific environment variables
nccl_vars = await get_nccl_env_vars()
env_vars.update(nccl_vars)

# Set the PTD world size
world_size = num_procs * (num_hosts or 1)
env_vars["WORLD_SIZE"] = str(world_size)
Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/test_titan_fwd_vs_hf_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def compare_logits(
hf_val = hf_logits_cpu[pos].item()
diff_val = abs_diff[pos].item()
print(
f" {i+1}. Position {pos}: titan={titan_val:.6f}, hf={hf_val:.6f}, diff={diff_val:.6f}"
f" {i + 1}. Position {pos}: titan={titan_val:.6f}, hf={hf_val:.6f}, diff={diff_val:.6f}"
)

return metrics
Expand Down Expand Up @@ -242,12 +242,12 @@ def compare_probabilities(
zip(titan_top_k.values, titan_top_k.indices)
):
token = tokenizer.decode([token_id.item()])
print(f" {i+1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
print(f" {i + 1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")

print("\nHugging Face Top-K:")
for i, (prob, token_id) in enumerate(zip(hf_top_k.values, hf_top_k.indices)):
token = tokenizer.decode([token_id.item()])
print(f" {i+1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")
print(f" {i + 1}. '{token}' (id={token_id.item()}): {prob.item():.6f}")

# Calculate overlap in top-k predictions
titan_top_tokens = set(titan_top_k.indices.tolist())
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/datasets/test_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,10 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file):
# But should contain the same set of IDs
assert set(first_epoch_ids) == set(
range(SMALL_DATASET_SIZE)
), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_ids}"
), f"First epoch samples should be (0-{SMALL_DATASET_SIZE - 1}), got {first_epoch_ids}"
assert set(second_epoch_ids) == set(
range(SMALL_DATASET_SIZE)
), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_ids}"
), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE - 1}), got {second_epoch_ids}"

def test_epoch_tracking(self, dataset_factory, small_dataset_file):
"""Test that epoch number is correctly tracked across dataset restarts."""
Expand Down
Loading