Note: This repository contains the complete architecture implementation with comprehensive test suite. Trained model weights are not included - users should train on their own datasets.
A knowledge-graph-enhanced transformer for code understanding, based on the GraphMERT architecture. (Currently working on getting some training data.)
What is GraphMERT? It's CodeBERT (RoBERTa trained on code) enhanced with graph structure. It learns from both the syntax of code (tokens) and its semantics (knowledge graph relations like "function X calls function Y").
- Leafy Chain Graphs: Innovative data structure linking code tokens to knowledge graph triples
- H-GAT Layer: Hierarchical Graph Attention for fusing text with graph structure
- Attention Decay Mask: Graph-distance-aware attention mechanism
- Dual Training: MLM (Masked Language Modeling) + MNM (Masked Node Modeling)
- Built on CodeBERT: Leverages existing code understanding, adds graph reasoning
Code Input → AST Parser → Knowledge Graph Triples
↓
Leafy Chain Graph
↓
[CodeBERT + H-GAT + Decay Mask]
↓
Graph-Enhanced Representations
- Base Model: CodeBERT (microsoft/codebert-base)
- Graph Layer: H-GAT fuses token embeddings with KG relation embeddings
- Attention Mask: Exponential decay based on graph distance (see the sketch below)
- Training: 60% MLM (predict masked tokens) + 40% MNM (predict masked relations)
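As a rough illustration of the decay mask, here is a standalone sketch of the formula the test suite validates (λ^GELU(√distance − p)); decay_lambda = 0.6 and offset_p = 1.0 are illustrative values, not necessarily the repository's defaults:

```python
import torch
import torch.nn.functional as F

def attention_decay_mask(distances: torch.Tensor,
                         decay_lambda: float = 0.6,
                         offset_p: float = 1.0) -> torch.Tensor:
    """Multiplicative attention weights: decay_lambda ** GELU(sqrt(d) - p),
    given a [seq_len, seq_len] matrix of graph distances."""
    exponent = F.gelu(torch.sqrt(distances.float()) - offset_p)
    return decay_lambda ** exponent

# Tokens 0 and 1 are one hop apart in the graph, token 2 is two hops away.
d = torch.tensor([[0., 1., 2.],
                  [1., 0., 1.],
                  [2., 1., 0.]])
print(attention_decay_mask(d))  # weights shrink as graph distance grows
```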
When using CodeBERT-base (recommended):
- Hidden size: 768 (from CodeBERT)
- Layers: 12 (from CodeBERT)
- Attention heads: 12 (from CodeBERT)
- Total parameters: ~125M (CodeBERT) + H-GAT layer
Paper's medical model (trained from scratch):
- Hidden size: 512
- Layers: 12
- Attention heads: 8
- Total parameters: ~80M
Note: This implementation uses pretrained CodeBERT, so it inherits CodeBERT's architecture (768 hidden size).
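If you want to confirm the inherited dimensions, an optional check with the transformers library:

```python
from transformers import AutoConfig

# CodeBERT-base uses the RoBERTa-base architecture, which this implementation inherits.
cfg = AutoConfig.from_pretrained("microsoft/codebert-base")
print(cfg.hidden_size, cfg.num_hidden_layers, cfg.num_attention_heads)  # 768 12 12
```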
# Clone the repository
git clone https://github.com/humanjesse/graphmert-codebert-base.git
cd graphmert-codebert-base
# Create virtual environment (recommended)
python -m venv venv
source venv/bin/activate # Linux/Mac (Windows: venv\Scripts\activate)
# Install dependencies
pip install -r requirements.txt
# Run comprehensive test suite (validates all 10 critical components)
python test_fixes.py

Requirements:
- Python 3.8+
- PyTorch 2.0+
- transformers, torch-geometric, networkx
- GPU recommended for training (CPU works but is slow)
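A quick standalone sanity check of the core dependencies (the repository also ships test_installation.py for a fuller test):

```python
import networkx
import torch
import torch_geometric
import transformers

print("torch", torch.__version__)
print("transformers", transformers.__version__)
print("torch-geometric", torch_geometric.__version__)
print("networkx", networkx.__version__)
print("CUDA available:", torch.cuda.is_available())  # GPU recommended for training
```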
python examples/quick_start.py

This will:
- Parse sample code and extract knowledge graph triples
- Build leafy chain graphs
- Initialize GraphMERT from CodeBERT
- Show graph-enhanced vs. standard encoding
Prepare your code dataset (one sample per line or blank-line separated):
# Your data file: data/my_code.txt
def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)

class Stack:
    def __init__(self):
        self.items = []

Train the model:
python train.py \
--data_path examples/sample_data.txt \
--num_epochs 25 \
--batch_size 32 \
--output_dir ./checkpoints

Use the trained model for inference:

from graphmert import GraphMERTModel
from transformers import RobertaTokenizer
# Load trained model
model = GraphMERTModel.from_pretrained("./checkpoints/checkpoint-epoch-25")
tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
# Encode code
code = "def hello(name): print(name)"
inputs = tokenizer(code, return_tensors="pt")
outputs = model(**inputs)
embeddings = outputs.last_hidden_state  # Graph-enhanced representations

Project structure:

graphmert/
├── graphmert/
│ ├── models/
│ │ ├── graphmert.py # Main GraphMERT model
│ │ ├── h_gat.py # Hierarchical Graph Attention layer
│ │ └── attention_mask.py # Graph-aware attention decay
│ ├── data/
│ │ ├── leafy_chain.py # Leafy chain graph data structure
│ │ ├── code_parser.py # AST parsing to extract triples
│ │ └── graph_builder.py # Build graphs from code
│ └── training/
│ ├── losses.py # MLM + MNM loss functions
│ └── trainer.py # Training pipeline
├── examples/
│ ├── quick_start.py # Demo script
│ └── sample_data.txt # Example code samples
├── configs/
│ └── default.yaml # Training configuration
├── train.py # Main training script
├── test_installation.py # Installation test
├── README.md # This file
└── ARCHITECTURE.md # Detailed architecture guide
## How It Works
### 1. Extract Knowledge Graph Triples
```python
Code: def hello(name): print(name)
Triples extracted:
(hello, parameter_of, name)
(hello, calls, print)
(print, uses, name)
```
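Triple extraction works by walking the Python AST. As a rough, self-contained illustration (independent of the repository's code_parser.py, whose API may differ), a minimal ast.NodeVisitor that emits parameter_of and calls triples:

```python
import ast

def extract_triples(code: str):
    """Toy extractor: (function, parameter_of, arg) and (function, calls, callee) triples."""
    triples = []

    class Visitor(ast.NodeVisitor):
        def visit_FunctionDef(self, node):
            for arg in node.args.args:
                triples.append((node.name, "parameter_of", arg.arg))
            for child in ast.walk(node):
                if isinstance(child, ast.Call) and isinstance(child.func, ast.Name):
                    triples.append((node.name, "calls", child.func.id))
            self.generic_visit(node)

    Visitor().visit(ast.parse(code))
    return triples

print(extract_triples("def hello(name): print(name)"))
# [('hello', 'parameter_of', 'name'), ('hello', 'calls', 'print')]
```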
### 2. Build the Leafy Chain Graph

Roots (tokens): ["def", "hello", "(", "name", ")", ":", ...]
Leaves (triples): [(hello, parameter_of, name), (hello, calls, print), ...]
Edges: token "hello" → connected to triples 0 and 1
       token "name" → connected to triples 0 and 2
### 3. Graph-Enhanced Encoding

# Standard CodeBERT: Only sees tokens
embedding = encoder(["def", "hello", "(", "name", ...])

# GraphMERT: Sees tokens + their semantic relations
embedding = encoder(
    tokens=["def", "hello", "(", "name", ...],
    graph=[(hello, parameter_of, name), (hello, calls, print), ...]
)
# Result: "hello" embedding now includes information about its parameters and what it calls

### 4. Dual Training (MLM + MNM)

The model is trained with two objectives:
MLM (Masked Language Modeling): predict masked code tokens (standard BERT objective):

Input: def [MASK](name): print(name)
Target: "hello"

MNM (Masked Node Modeling): predict masked graph relations (novel GraphMERT objective):

Input graph: (hello, [MASK], name)
Target: "parameter_of"
L = 0.6 * L_MLM + 0.4 * L_MNM
This teaches the model to understand BOTH code syntax AND semantic structure.
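A minimal sketch of the combined objective, assuming the common convention that positions excluded from the loss are labeled -100 (the repository's losses.py may differ in details):

```python
import torch.nn.functional as F

def combined_loss(mlm_logits, mlm_labels, mnm_logits, mnm_labels, lambda_mlm=0.6):
    """L = lambda * L_MLM + (1 - lambda) * L_MNM, computed over masked positions only."""
    l_mlm = F.cross_entropy(mlm_logits.view(-1, mlm_logits.size(-1)),
                            mlm_labels.view(-1), ignore_index=-100)
    l_mnm = F.cross_entropy(mnm_logits.view(-1, mnm_logits.size(-1)),
                            mnm_labels.view(-1), ignore_index=-100)
    return lambda_mlm * l_mlm + (1.0 - lambda_mlm) * l_mnm
```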
Edit configs/default.yaml to customize:
model:
  base_model: "microsoft/codebert-base"
  hidden_size: 512
  num_layers: 12
  num_attention_heads: 8

training:
  num_epochs: 25
  batch_size: 32
  learning_rate: 0.0004
  lambda_mlm: 0.6  # 60% MLM, 40% MNM

You can also set these options on the command line:

python train.py \
--data_path <path-to-code-samples> \
--output_dir ./checkpoints \
--num_epochs 25 \
--batch_size 32 \
--learning_rate 4e-4 \
--lambda_mlm 0.6 \
--use_wandb  # Optional: log to Weights & Biases

Training tips:

- Start small: Test on 1,000 samples before training on full dataset
- GPU recommended: Training on CPU is very slow (~100x slower)
- Adjust batch size: Reduce if you get OOM errors
- Use gradient accumulation: If you need a larger effective batch size (see the sketch after this list)
- Monitor both losses: MLM and MNM should both decrease during training
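Gradient accumulation is a generic PyTorch pattern rather than anything repo-specific; a minimal runnable sketch with a toy model (swap in the GraphMERT model and your own dataloader):

```python
import torch
from torch import nn

model = nn.Linear(10, 2)                  # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
dataloader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]

accumulation_steps = 4                    # effective batch size = 8 * 4 = 32

optimizer.zero_grad()
for step, (x, y) in enumerate(dataloader):
    loss = nn.functional.cross_entropy(model(x), y) / accumulation_steps  # average over accumulated steps
    loss.backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```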
Currently supports Python. To add JavaScript/Java/etc:
# In graphmert/data/code_parser.py
class JavaScriptParser:
    def parse(self, code):
        # Use a JS AST parser
        # Extract triples
        return triples

To add new relation types:

# In graphmert/data/code_parser.py
class PythonTripleExtractor(ast.NodeVisitor):
    def visit_YourNode(self, node):
        self.triples.append(Triple(
            head=...,
            relation="your_new_relation",
            tail=...
        ))

To use a different base model:

# Try RoBERTa, GraphCodeBERT, or other compatible models
model = GraphMERTModel.from_codebert(
    codebert_model_name="roberta-base"  # or "huggingface/CodeBERTa-small-v1"
)

The repository includes test_fixes.py - a comprehensive 1,283-line test suite validating:
- ✅ Hidden size matches CodeBERT (768)
- ✅ Attention decay formula (λ^GELU(√distance - p))
- ✅ H-GAT has no cross-token attention leakage
- ✅ Floyd-Warshall multi-hop distance computation (see the sketch after this list)
- ✅ Span masking with geometric distribution
- ✅ MNM (Masked Node Modeling) loss
- ✅ Combined MLM+MNM loss (μ=1)
- ✅ End-to-end forward pass with graphs
- ✅ Decay mask integration
- ✅ Shared relation embeddings
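Multi-hop distances of this kind can be computed with any all-pairs shortest-path routine; a small illustration using networkx (already a dependency), with hypothetical node names:

```python
import networkx as nx

# Tokens and triples as nodes of one graph; edges link tokens to the triples they appear in.
g = nx.Graph()
g.add_edges_from([("hello", "triple_0"), ("name", "triple_0"), ("hello", "triple_1")])

dist = nx.floyd_warshall_numpy(g, nodelist=["hello", "name", "triple_0", "triple_1"])
print(dist)  # e.g. "hello" -> "name" is 2 hops, via the shared triple
```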
Run tests: python test_fixes.py
Q: Installation fails?
- Ensure Python 3.8+ and install PyTorch first: pip install torch

Q: No graph connections?
- Run python examples/quick_start.py to verify parsing
- Ensure code has functions/classes (not just plain statements)

Q: Out of memory?
- Reduce --batch_size (try 16 or 8)
- Reduce max_seq_len in config
If you use GraphMERT in your research, please cite:
@article{graphmert2024,
  title={GraphMERT: A Graph-Enhanced Transformer for Code Understanding},
  year={2024}
}

MIT License (or specify your license)
Contributions welcome! Please:
- Fork the repository
- Create a feature branch
- Submit a pull request
- Based on the GraphMERT paper (2024)
- Built on CodeBERT by Microsoft
- Inspired by Graph Attention Networks