This is the implementation for the paper *GraphLLM: Boosting Graph Reasoning Ability of Large Language Model*.
- You will likely need a single 80 GB GPU to run the experiments. Our experiments use CUDA 11.8 and PyTorch 2.0.1.
- Set up a new conda environment and install the necessary packages:

```bash
conda create -n graph_llm python=3.10 -y
conda activate graph_llm
pip install -r requirements.txt
```
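As an optional sanity check (not part of the repository), you can confirm that the installed PyTorch build matches the versions above:

```bash
# Print the torch version, the CUDA version it was built against, and whether a GPU is visible
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
```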
- To run the code, you need the checkpoint and tokenizer of LLaMA-2-7B, which you can obtain from Meta.
After downloading LLaMA-2-7B, soft-link the checkpoint folder and the tokenizer folder into the root of this repository:
```bash
ln -s /folder/of/LLaMA-2-7B/checkpoint ./LLaMA-7B-2
ln -s /folder/of/LLaMA-2-7B/tokenizer ./Llama-2-7b-hf
```
- Remember to replace the directories `/folder/of/LLaMA-2-7B/checkpoint` and `/folder/of/LLaMA-2-7B/tokenizer` with the actual directories!
- The four graph reasoning datasets are available on Google Drive or Hugging Face.
Download the archive, place the zip file in the root of this repository, and then run:

```bash
unzip dataset.zip -d ./dataset
```
- The directory structure should be:

```
.
|- LLaMA-7B-2
|  |- params.json
|  |- consolidated.00.pth
|- Llama-2-7b-hf
|  |- tokenizer.model
|- dataset
   |- sc
   |- mts
   |- sp
   |- bgm
```
Train and evaluate the model with the default settings on the graph reasoning datasets, on GPU 0 (see the note after this list for running on a different GPU):
- Substructure Counting

  ```bash
  ./scripts/sc.sh
  ```

- Maximum Triplet Sum

  ```bash
  ./scripts/mts.sh
  ```

- Shortest Path

  ```bash
  ./scripts/sp.sh
  ```

- Bipartite Graph Matching

  ```bash
  ./scripts/bgm.sh
  ```
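The scripts use GPU 0 by default. One minimal way to run on a different device is to restrict GPU visibility; this sketch assumes the scripts do not hard-code a device index internally:

```bash
# Hypothetical example: expose only GPU 1, so the script's "GPU 0" maps to physical GPU 1
CUDA_VISIBLE_DEVICES=1 ./scripts/sc.sh
```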
More hyperparameter settings can be found in `config.py`.
Hyperparameter explanation:
- `--n_encoder_layers` number of transformer layers in the textual encoder
- `--n_decoder_layers` number of transformer layers in the textual decoder
- `--n_mp_layers` number of graph transformer layers
- `--adapter_dim` hidden dimension of the textual encoder/decoder and the graph transformer
- `--adapter_len` number of prefix tokens per LLM layer
- `--rrwp` graph positional encoding dimension
- `--batch_size` batch size held in memory during training
- `--grad_steps` gradient accumulation steps; `grad_steps` × `batch_size` is the effective batch size for optimization
- `--lr` learning rate
- `--num_epochs` number of training epochs
- `--warmup_epochs` number of linear warmup epochs
- `--wd` weight decay
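As a hedged illustration of how these flags compose (the entry-point name `main.py` and the values below are assumptions, not the repository's defaults; see `scripts/*.sh` and `config.py` for the actual invocation and settings), a run with a larger effective batch size might look like:

```bash
# Hypothetical invocation: the script name and flag values are assumptions.
# With --batch_size 2 and --grad_steps 8, the effective optimization batch size is 2 x 8 = 16.
python main.py \
    --n_encoder_layers 4 \
    --n_decoder_layers 4 \
    --n_mp_layers 4 \
    --adapter_dim 768 \
    --adapter_len 10 \
    --rrwp 8 \
    --batch_size 2 \
    --grad_steps 8 \
    --lr 2e-4 \
    --num_epochs 10 \
    --warmup_epochs 1 \
    --wd 0.1
```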