Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ jobs:
- name: Run tests
env:
TEST_DATA_DIR: ${{ github.workspace }}/tests/test_files
run: uv run pytest tests/ --cov=src --cov-report=html --cov-report=term-missing --cov-fail-under=80
run: uv run pytest tests/ --cov=src --cov-report=html --cov-report=term-missing --cov-fail-under=70
Comment thread
vratins marked this conversation as resolved.
Comment thread
marcuscollins marked this conversation as resolved.
Comment thread
vratins marked this conversation as resolved.
33 changes: 26 additions & 7 deletions scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,24 @@ def load_config(run_dir: Path) -> dict:
return config


def _extract_dataset_filter_config(config: dict) -> dict:
"""Extract dataset filter params from training config with fallback to defaults."""
return {
"max_com_dist": config.get("max_com_dist", 25.0),
"max_clash_fraction": config.get("max_clash_fraction", 0.05),
"clash_dist": config.get("clash_dist", 2.0),
"interface_dist_threshold": config.get("interface_dist_threshold", 4.0),
"min_water_residue_ratio": config.get("min_water_residue_ratio", 0.6),
"edia_dir": config.get("edia_dir"),
"max_protein_dist": config.get("max_protein_dist", 5.0),
"min_edia": config.get("min_edia", 0.4),
"max_bfactor_zscore": config.get("max_bfactor_zscore", 1.5),
"filter_by_distance": config.get("filter_by_distance", True),
"filter_by_edia": config.get("filter_by_edia", True),
"filter_by_bfactor": config.get("filter_by_bfactor", True),
}


def build_model_from_config(config: dict, device: torch.device) -> nn.Module:
"""
Build model architecture from training configuration.
Expand All @@ -221,8 +239,7 @@ def build_model_from_config(config: dict, device: torch.device) -> nn.Module:
- encoder_type: "gvp", "slae", or "esm"
- hidden_s, hidden_v: Hidden dimensions for scalars/vectors
- flow_layers: Number of flow layers
- For SLAE: slae_dim (default 128)
- For ESM: esm_dim (default 1536)
- For cached encoders: embedding_dim and embedding_key="embedding"
device: Device to place model on

Returns:
Expand All @@ -243,11 +260,9 @@ def build_model_from_config(config: dict, device: torch.device) -> nn.Module:
"encoder_ckpt": config.get("encoder_ckpt"),
}

# Add encoder-specific dimension (use 'or' to handle None values)
if encoder_type == "slae":
encoder_config["slae_dim"] = config.get("slae_dim") or 128
elif encoder_type == "esm":
encoder_config["esm_dim"] = config.get("esm_dim") or 1536
if encoder_type in {"slae", "esm"}:
encoder_config["embedding_key"] = "embedding"
encoder_config["embedding_dim"] = config.get("embedding_dim")
Comment thread
vratins marked this conversation as resolved.

encoder = build_encoder(encoder_config, device)

Expand Down Expand Up @@ -431,6 +446,9 @@ def main():
"geometry_cache_name", "geometry"
)

# Extract dataset filter config from training config for consistency
filter_config = _extract_dataset_filter_config(config)

dataset = ProteinWaterDataset(
pdb_list_file=args.pdb_list,
processed_dir=args.processed_dir,
Expand All @@ -439,6 +457,7 @@ def main():
include_mates=include_mates,
geometry_cache_name=geometry_cache_name,
preprocess=True,
**filter_config,
)

logger.info(f"Found {len(dataset)} PDB entries")
Expand Down
62 changes: 34 additions & 28 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,18 +222,12 @@ def parse_args():
p.add_argument("--k_pw", type=int, default=16)
p.add_argument("--k_ww", type=int, default=16)

# optional encoder-specific overrides
# optional cached-embedding override
p.add_argument(
"--slae_dim",
"--embedding_dim",
type=int,
default=None,
help="Optional SLAE embedding dimension override",
)
p.add_argument(
"--esm_dim",
type=int,
default=None,
help="Optional ESM embedding dimension override",
help="Optional cached embedding dimension override for SLAE/ESM encoders",
)

# training
Expand Down Expand Up @@ -318,7 +312,10 @@ def parse_args():
p.add_argument("--wandb_project", type=str, default="water-flow")
p.add_argument("--wandb_dir", type=str, default="/home/srivasv/wandb_logs")
p.add_argument("--device", type=str, default="cuda")
return p.parse_args()
args = p.parse_args()
if args.encoder_type == "gvp" and args.embedding_dim is not None:
p.error("--embedding_dim is only valid for cached encoders: slae or esm")
return args


def _extract_quality_config(args: argparse.Namespace) -> dict:
Expand Down Expand Up @@ -425,15 +422,18 @@ def _required_embedding_field(encoder_type: str) -> str | None:
encoder_type: Encoder identifier ('gvp', 'slae', or 'esm')

Returns:
Field name string (e.g., 'slae_embedding') or None if encoder doesn't need embeddings
Field name string (e.g., 'embedding') or None if encoder doesn't need embeddings
"""
if encoder_type == "slae":
return "slae_embedding"
if encoder_type == "esm":
return "esm_embedding"
if encoder_type in {"slae", "esm"}:
return "embedding"
return None


def _uses_cached_embeddings(encoder_type: str) -> bool:
"""Return whether the selected encoder consumes cached protein embeddings."""
return _required_embedding_field(encoder_type) is not None


def _resolve_embedding_dim(
sample_data,
encoder_type: str,
Expand All @@ -460,7 +460,15 @@ def _resolve_embedding_dim(
raise ValueError(
f"Selected encoder '{encoder_type}' requires protein.{field}, "
f"but it is missing from dataset samples. "
f"Expected cache at {field.split('_')[0]}/<cache_key>.pt under --processed_dir."
f"Expected cached embeddings in data['protein'].embedding from "
f"--processed_dir/{encoder_type}/<cache_key>.pt."
)

embedding_type = sample_data["protein"].get("embedding_type")
if embedding_type is not None and embedding_type != encoder_type:
raise ValueError(
f"Selected encoder '{encoder_type}' requires protein.embedding_type="
f"'{encoder_type}', but sample data has '{embedding_type}'."
)

inferred_dim = int(sample_data["protein"][field].shape[-1])
Expand All @@ -484,8 +492,8 @@ def resolve_encoder_config(args, sample_data, node_scalar_in: int):
Returns:
dict: Encoder configuration ready for build_encoder(), e.g.:
- GVP: {"encoder_type": "gvp", "hidden_s": 256, "hidden_v": 64, ...}
- SLAE: {"encoder_type": "slae", "slae_dim": 128, ...}
- ESM: {"encoder_type": "esm", "esm_dim": 1536, ...}
- SLAE: {"encoder_type": "slae", "embedding_key": "embedding", "embedding_dim": 128, ...}
- ESM: {"encoder_type": "esm", "embedding_key": "embedding", "embedding_dim": 1536, ...}
"""
encoder_config = {
"encoder_type": args.encoder_type,
Expand All @@ -496,13 +504,10 @@ def resolve_encoder_config(args, sample_data, node_scalar_in: int):
"encoder_ckpt": args.encoder_ckpt,
}

if args.encoder_type == "slae":
encoder_config["slae_dim"] = _resolve_embedding_dim(
sample_data, "slae", args.slae_dim
)
elif args.encoder_type == "esm":
encoder_config["esm_dim"] = _resolve_embedding_dim(
sample_data, "esm", args.esm_dim
if _uses_cached_embeddings(args.encoder_type):
encoder_config["embedding_key"] = "embedding"
encoder_config["embedding_dim"] = _resolve_embedding_dim(
sample_data, args.encoder_type, args.embedding_dim
)

return encoder_config
Expand All @@ -514,8 +519,9 @@ def log_encoder_sample_stats(sample_data: HeteroData, encoder_type: str) -> None
if field is None:
return
emb = sample_data["protein"][field]
embedding_type = sample_data["protein"].get("embedding_type", "unknown")
logger.info(
f"{field} shape={tuple(emb.shape)} "
f"{field} type={embedding_type} shape={tuple(emb.shape)} "
f"mean={emb.mean():.4f} std={emb.std():.4f} min={emb.min():.4f} max={emb.max():.4f}"
)

Expand Down Expand Up @@ -995,8 +1001,8 @@ def main():
logger.info(f"Trainable parameters: {trainable_params:,}")
logger.info(f"Total parameters: {total_params:,}")

# quick forward pass sanity check for embedding-based encoders
if args.encoder_type in {"slae", "esm"}:
# quick forward pass sanity check for cached embedding encoders
if _uses_cached_embeddings(args.encoder_type):
logger.info(f"Testing forward pass with {args.encoder_type.upper()}...")
model.eval()
batch = next(iter(train_loader)).to(device)
Expand Down
20 changes: 20 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,23 @@
"U": "SEC",
"O": "PYL",
}

# Element vocabulary for atom types in protein structures
ELEMENT_VOCAB = [
"C",
"N",
"O",
"S",
"P",
"SE",
"MG",
"ZN",
"CA",
"FE",
"NA",
"K",
"CL",
"F",
"BR",
]
ELEM_IDX = {e: i for i, e in enumerate(ELEMENT_VOCAB)}
Loading