# GEARS 训练

已内置 `sys.path` 修复：
1. 将 **GEARS 根目录**（包含 `gears/` 子目录的父目录）加入 `sys.path`。
2. 将 **model 根目录**（包含 `load.py` 的目录）加入 `sys.path`。


In [1]:
# Environment check: report the interpreter and PyTorch versions and, when CUDA
# is usable, which GPU this process will run on.
import sys
import torch

print('[Info] Python', sys.version)
print('[Info] PyTorch', torch.__version__)
cuda_ok = torch.cuda.is_available()
print('[Info] CUDA available =', cuda_ok)
if cuda_ok:
    cur_dev = torch.cuda.current_device()
    print('[Info] CUDA device count =', torch.cuda.device_count())
    print('[Info] Current device =', cur_dev)
    print('[Info] Device name =', torch.cuda.get_device_name(cur_dev))

[Info] Python 3.8.18 | packaged by conda-forge | (default, Dec 23 2023, 17:21:28) 
[GCC 12.3.0]
[Info] PyTorch 2.4.1+cu118
[Info] CUDA available = True
[Info] CUDA device count = 7
[Info] Current device = 0
[Info] Device name = NVIDIA L40


将 GEARS 与 model 根目录加入 `sys.path`，以支持 `from gears...` 和 `from load import ...`。

In [2]:
import sys, os

# === Edit these to match your local checkout ===
GEARS_ROOT = "/home/mjin/scFoundation-main/GEARS"   # parent dir; must contain a `gears/` subdirectory
MODEL_ROOT = "/home/mjin/scFoundation-main/model"   # must contain `load.py`

# Validate everything *before* touching sys.path, so a bad MODEL_ROOT cannot
# leave GEARS_ROOT half-inserted. Besides existence, check the entry each root
# is documented to contain, so a wrong path fails here with a clear message
# instead of later as an opaque ImportError.
for p in (GEARS_ROOT, MODEL_ROOT):
    if not os.path.isdir(p):
        raise RuntimeError(f"目录不存在: {p}")
if not os.path.isdir(os.path.join(GEARS_ROOT, "gears")):
    raise RuntimeError(f"GEARS_ROOT 下缺少 `gears/` 子目录: {GEARS_ROOT}")
if not os.path.isfile(os.path.join(MODEL_ROOT, "load.py")):
    raise RuntimeError(f"MODEL_ROOT 下缺少 `load.py`: {MODEL_ROOT}")

# Prepend each root so these copies take precedence over anything installed.
# Both are inserted at index 0, so MODEL_ROOT ends up ahead of GEARS_ROOT.
for p in (GEARS_ROOT, MODEL_ROOT):
    if p not in sys.path:
        sys.path.insert(0, p)
print('[OK] sys.path patched. Top entries:', sys.path[:3])

[OK] sys.path patched. Top entries: ['/home/mjin/scFoundation-main/model', '/home/mjin/scFoundation-main/GEARS', '/home/mjin/anaconda3/envs/hast/lib/python38.zip']


## 0.1) 路径与模块可见性自检

In [3]:
import importlib

# Sanity check that the two roots added to sys.path actually resolve.
# Failures are printed instead of raised, so one run reports the status of
# both imports.
try:
    gears = importlib.import_module('gears')
    print('[OK] gears package at:', gears.__file__)
except Exception as err:
    print('[ERR] cannot import gears:', err)

try:
    load = importlib.import_module('load')
    print('[OK] load.py at:', load.__file__)
except Exception as err:
    print('[ERR] cannot import load.py:', err)

  from .autonotebook import tqdm as notebook_tqdm


loading embedding from scfoundation
torch.Size([3638, 1049, 512])
[OK] gears package at: /home/mjin/scFoundation-main/GEARS/gears/__init__.py
[OK] load.py at: /home/mjin/scFoundation-main/model/load.py


## 1) 参数区（在这里改你的路径与超参）

In [4]:
import torch

# ===== Parameters you must / may change =====
# --- Data locations ---
DATA_ROOT = "/home/mjin/scFoundation-main/GEARS/data"  # must contain gene2go.pkl
# Optional built-in dataset: "norman" | "adamson" | "dixit". When not None it is
# loaded by name (intended to be downloaded under DATA_ROOT); otherwise
# DATASET_PATH below is used.
DATASET_NAME = None
# Local data: a directory (holding norman.h5ad or perturb_processed.h5ad) or an .h5ad file.
DATASET_PATH = "/home/mjin/scFoundation-main/GEARS/data/"

# --- Split / DataLoader ---
SPLIT = "simulation"
SEED = 3
TRAIN_GENE_SET_SIZE = 0.75
BATCH_SIZE = 32
TEST_BATCH_SIZE = 32

# --- Training hyper-parameters ---
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
HIDDEN_SIZE = 64
ACCUMULATION_STEPS = 5
HIGHRES = 0
# Checkpoint path for a pretrained scFoundation encoder; keep None when not using one.
SINGLECELL_MODEL_PATH = None
EPOCHS = 20
LR = 1e-3
WEIGHT_DECAY = 5e-4
RESULT_DIR = "/home/mjin/scFoundation-main/GEARS/result/scfoundation-norman"

print("[Params] device=", DEVICE)

[Params] device= cuda


## 2) 导入依赖（已修复路径）

In [5]:
# GEARS data container and model wrapper; both are importable only because
# GEARS_ROOT was put on sys.path in the path-patching cell.
from gears.pertdata import PertData
from gears.gears import GEARS

# Dataset names the loading cell accepts for DATASET_NAME.
ALLOWED_NAMES = {
    "norman",
    "adamson",
    "dixit",
}
print("[Info] Imports ok.")

[Info] Imports ok.


## 3) 基本检查

In [6]:
# Fail fast: PertData.__init__ reads gene2go.pkl from DATA_ROOT, so make sure
# the file is present before constructing it.
gene2go_pkl = os.path.join(DATA_ROOT, "gene2go.pkl")
if os.path.exists(gene2go_pkl):
    print("[Check] gene2go.pkl 存在：", gene2go_pkl)
else:
    raise FileNotFoundError(f"缺少 {gene2go_pkl}，PertData.__init__ 会读取它。请把 gene2go.pkl 放到 DATA_ROOT 下。")

[Check] gene2go.pkl 存在： /home/mjin/scFoundation-main/GEARS/data/gene2go.pkl


## 4) 构造 PertData 并加载数据

In [7]:
# Validate the dataset configuration first. Constructing PertData already reads
# DATA_ROOT/gene2go.pkl (see the "read .../gene2go.pkl" line in the recorded
# output), so a bad setting should fail before that work is done.
if DATASET_NAME is not None:
    name = str(DATASET_NAME).lower()
    if name not in ALLOWED_NAMES:
        raise ValueError(f"DATASET_NAME 仅支持 {ALLOWED_NAMES}，否则请将 DATASET_NAME 设为 None 并使用 DATASET_PATH")
elif DATASET_PATH is None:
    raise ValueError("必须提供 DATASET_NAME 或 DATASET_PATH 其一")
else:
    # Accept str or os.PathLike. A path that does not exist at all is a
    # configuration error, so report it directly instead of via PertData.load.
    dataset_path_str = os.fspath(DATASET_PATH)
    if not os.path.exists(dataset_path_str):
        raise FileNotFoundError(f"DATASET_PATH 不存在: {dataset_path_str}")
    # A directory or an .h5ad file is intended to be accepted.
    # NOTE(review): per the author's note, when only norman.h5ad is present
    # PertData.load generates perturb_processed.h5ad. That happens inside gears
    # and cannot be confirmed from this notebook; the same applies to whether a
    # bare .h5ad file path is accepted.
    if os.path.isdir(dataset_path_str) and not any(
        os.path.exists(os.path.join(dataset_path_str, fn))
        for fn in ("norman.h5ad", "perturb_processed.h5ad")
    ):
        print(f"[WARN] 目录 {DATASET_PATH} 下未发现 norman.h5ad 或 perturb_processed.h5ad，若失败请检查路径")

# gi_go=False: per the author's note, this disables GI/GO filtering at load time
# and does not affect whether GEARS tries to build its GO graph (gears
# internals, not verifiable here).
pert_data = PertData(data_path=DATA_ROOT, gi_go=False)

if DATASET_NAME is not None:
    pert_data.load(data_name=name)
    print(f"[INFO] 自动下载数据集: {name}")
else:
    # Pass DATASET_PATH through exactly as configured.
    pert_data.load(data_path=DATASET_PATH)
    print(f"[INFO] 使用本地数据路径: {DATASET_PATH}")

print("[OK] 数据加载完成。")

read /home/mjin/scFoundation-main/GEARS/data/gene2go.pkl
/home/mjin/scFoundation-main/GEARS/data/


These perturbations are not in the GO graph and is thus not able to make prediction for...


(80506, 1049)


[]
Local copy of pyg dataset is detected. Loading...
Done!


[INFO] 使用本地数据路径: /home/mjin/scFoundation-main/GEARS/data/
[OK] 数据加载完成。


## 5) 划分与 DataLoader

In [8]:
# Partition the perturbations using the strategy in SPLIT (seeded for
# reproducibility).
pert_data.prepare_split(
    split=SPLIT,
    seed=SEED,
    train_gene_set_size=TRAIN_GENE_SET_SIZE,
)
# Wrap the resulting splits in DataLoaders.
pert_data.get_dataloader(
    batch_size=BATCH_SIZE,
    test_batch_size=TEST_BATCH_SIZE,
)
print("[OK] 划分 & DataLoader 完成。")

Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:6
combo_seen1:48
combo_seen2:18
unseen_single:25
Done!
Creating dataloaders....
Done!


[OK] 划分 & DataLoader 完成。


## 6) GEARS 实例化与初始化
此处会间接导入 `modules/encoders.py`，而该文件内会 `from load import *`，因此我们已将 `MODEL_ROOT` 放入 `sys.path`。

In [9]:
# Wrap the prepared PertData in a GEARS model on DEVICE. Per the markdown
# above, this path imports modules/encoders.py, which does `from load import *`;
# that is why MODEL_ROOT is on sys.path.
gears_model = GEARS(pert_data, device=DEVICE)

# 'emb' model type with finetune_method='random'. load_path stays None unless
# SINGLECELL_MODEL_PATH points at a checkpoint.
gears_model.model_initialize(
    hidden_size=HIDDEN_SIZE,
    model_type='emb',
    load_path=SINGLECELL_MODEL_PATH,
    finetune_method='random',
    accumulation_steps=ACCUMULATION_STEPS,
    highres=HIGHRES,
)
print("[OK] GEARS 初始化完成。")

Use accumulation steps: 5
Use mode: v1
Use higres: 0
No G_go
/home/mjin/scFoundation-main/GEARS/data/go.csv
[OK] GEARS 初始化完成。


## 7) 训练
（注意：训练会花一定时间，运行前请确认显卡/内存资源）

In [10]:
import time

# Ensure the output directory exists before GEARS writes into it.
os.makedirs(RESULT_DIR, exist_ok=True)
# flush=True: in the recorded run this banner appeared *after* GEARS' own
# "Start Training..." line, i.e. it was still buffered when training began
# (GEARS presumably writes to a different stream).
print("[Run] 开始训练... 输出目录:", RESULT_DIR, flush=True)

# perf_counter is monotonic, unlike time.time(), so clock adjustments during a
# long run cannot distort the reported duration.
start = time.perf_counter()
try:
    gears_model.train(
        epochs=EPOCHS,
        result_dir=RESULT_DIR,
        lr=LR,
        weight_decay=WEIGHT_DECAY,
    )
except BaseException:
    # Covers KeyboardInterrupt too (as in the recorded run): report how long it
    # ran, then re-raise so the interruption is not swallowed.
    print(f"[Abort] 训练中断，已运行 {time.perf_counter() - start:.1f}s", flush=True)
    raise
print(f"[Done] 训练完成，用时 {time.perf_counter() - start:.1f}s")

Start Training...


[Run] 开始训练... 输出目录: /home/mjin/scFoundation-main/GEARS/result/scfoundation-norman


Epoch 1 Step 1 Train Loss: 0.5088
Epoch 1 Step 1001 Train Loss: 0.3875
Pert: AHR+FEV PCC: 0.5892 MSE: 0.0228 PCC_DE: 0.9368 MSE_DE: 0.3550
Pert: AHR+KLF1 PCC: 0.3280 MSE: 0.0176 PCC_DE: 0.5509 MSE_DE: 0.2905
Pert: AHR+ctrl PCC: 0.4533 MSE: 0.0132 PCC_DE: 0.6958 MSE_DE: 0.0558
Pert: ARID1A+ctrl PCC: 0.3983 MSE: 0.0130 PCC_DE: 0.6466 MSE_DE: 0.1976
Pert: ARRDC3+ctrl PCC: 0.3440 MSE: 0.0064 PCC_DE: 0.8473 MSE_DE: 0.0065
Pert: ATL1+ctrl PCC: 0.3910 MSE: 0.0132 PCC_DE: 0.8464 MSE_DE: 0.2417
Pert: BCL2L11+BAK1 PCC: 0.0367 MSE: 0.0218 PCC_DE: 0.5275 MSE_DE: 0.0282
Pert: BPGM+ZBTB1 PCC: 0.5661 MSE: 0.0181 PCC_DE: 0.9272 MSE_DE: 0.1760
Pert: CBL+CNN1 PCC: 0.6643 MSE: 0.0228 PCC_DE: 0.9330 MSE_DE: 0.4676
Pert: CBL+UBASH3B PCC: 0.6734 MSE: 0.0182 PCC_DE: 0.9594 MSE_DE: 0.2792
Pert: CDKN1A+ctrl PCC: 0.5766 MSE: 0.0103 PCC_DE: 0.8952 MSE_DE: 0.0591
Pert: CDKN1B+CDKN1A PCC: 0.6001 MSE: 0.0224 PCC_DE: 0.9218 MSE_DE: 0.0563
Pert: CDKN1C+CDKN1A PCC: 0.5987 MSE: 0.0206 PCC_DE: 0.9384 MSE_DE: 0.0618
Pert

KeyboardInterrupt: 

## 8) 保存当前参数到文件
（提示：本单元位于训练之后；若训练被中断（如上方的 KeyboardInterrupt），本单元不会执行、参数也不会被保存。可考虑在训练前先运行本单元。）

In [None]:
import json, pathlib

# Snapshot of this run's configuration, written next to the training outputs.
# NOTE: this cell comes after training, so if training is interrupted (as in
# the recorded run) it never executes and no snapshot is written.
params = {
    'DATA_ROOT': DATA_ROOT,
    'DATASET_NAME': DATASET_NAME,
    'DATASET_PATH': DATASET_PATH,
    'SPLIT': SPLIT,
    'SEED': SEED,
    'TRAIN_GENE_SET_SIZE': TRAIN_GENE_SET_SIZE,
    'BATCH_SIZE': BATCH_SIZE,
    'TEST_BATCH_SIZE': TEST_BATCH_SIZE,
    'DEVICE': DEVICE,
    'HIDDEN_SIZE': HIDDEN_SIZE,
    'ACCUMULATION_STEPS': ACCUMULATION_STEPS,
    'HIGHRES': HIGHRES,
    'SINGLECELL_MODEL_PATH': SINGLECELL_MODEL_PATH,
    'EPOCHS': EPOCHS,
    'LR': LR,
    'WEIGHT_DECAY': WEIGHT_DECAY,
    'RESULT_DIR': RESULT_DIR,
}
cfg_path = os.path.join(RESULT_DIR, 'run_params.json')
pathlib.Path(RESULT_DIR).mkdir(parents=True, exist_ok=True)
# ensure_ascii=False writes non-ASCII characters verbatim, so the file
# encoding must be pinned. Otherwise open() uses the locale's preferred
# encoding, and a non-ASCII path/value could fail to encode or yield a
# non-UTF-8 file.
with open(cfg_path, 'w', encoding='utf-8') as f:
    json.dump(params, f, indent=2, ensure_ascii=False)
print('[OK] 参数已保存到', cfg_path)