# GIRNet — Inference

Minimal notebook to run inference on test mixtures.

> Assumes you trained and saved the full model to `results/girnet_model.h5`.
> If you saved **weights only**, see the second code cell.

In [None]:
import numpy as np
from tensorflow.keras.models import load_model
from spektral.layers import GATConv
from GAT import GraphAttentionLayer  # custom layer used in GIRNet

MODEL_PATH = "results/girnet_model.h5"     # change if needed
X_TEST     = "CM_Xtest.npy"                # change if needed
OUT_NPY    = "Y_hat.npy"

# Load the full model (architecture + weights). Custom layers must be
# registered via custom_objects; compile=False skips restoring the
# optimizer/loss, which inference does not need.
model = load_model(
    MODEL_PATH,
    custom_objects={"GraphAttentionLayer": GraphAttentionLayer, "GATConv": GATConv},
    compile=False,
)

# Load test mixtures
X = np.load(X_TEST)
if X.ndim == 3:  # [N, C, T] -> [N, C, T, 1]
    X = X[..., None]

# Predict
Y_hat = model.predict(X, batch_size=1)
np.save(OUT_NPY, Y_hat)
print("Saved:", OUT_NPY, Y_hat.shape)
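The shape handling above assumes the `[N, C, T]` layout noted in the code comment. A slightly stricter check can fail early with a readable error instead of a cryptic one from inside the model. The sketch below is a hypothetical helper (`prepare_mixtures` is not part of GIRNet); the expected channel count and length are whatever your model was trained with.

```python
import numpy as np


def prepare_mixtures(X, channels=None, length=None):
    """Return X as a float32 array shaped [N, C, T, 1].

    Accepts [N, C, T] (adds a trailing axis) or [N, C, T, 1] (left as is).
    `channels` and `length` are optional expected sizes; if given, a
    mismatch raises ValueError before the model ever sees the data.
    """
    X = np.asarray(X, dtype=np.float32)
    if X.ndim == 3:                      # [N, C, T] -> [N, C, T, 1]
        X = X[..., np.newaxis]
    if X.ndim != 4 or X.shape[-1] != 1:
        raise ValueError(f"expected [N, C, T] or [N, C, T, 1], got {X.shape}")
    if channels is not None and X.shape[1] != channels:
        raise ValueError(f"expected {channels} channels, got {X.shape[1]}")
    if length is not None and X.shape[2] != length:
        raise ValueError(f"expected length {length}, got {X.shape[2]}")
    return X
```

For example, `X = prepare_mixtures(np.load(X_TEST), channels=4)`. For a single-input Keras model, `model.input_shape` (e.g. `(None, C, T, 1)`) is one place to read the expected sizes from; this assumes the mixture tensor is the model's only input.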

### If you only saved weights
Rebuild the model with the same hyperparameters used for training, then load the weights:

In [None]:
import numpy as np
from GAT import build_girnet_model

WEIGHTS = "GAT_weights.h5"
DIM     = 1024   # <- same as used for training (--dim)
C       = 4      # <- number of channels used in training
THRESH  = 0.6    # <- same as used for training

# Rebuild the architecture; it must match the trained model exactly
# for load_weights to succeed.
model = build_girnet_model(input_shape=DIM, channels=C, threshold=THRESH)
model.load_weights(WEIGHTS)

# Load test mixtures and add the trailing axis if needed
X = np.load("CM_Xtest.npy")
if X.ndim == 3:  # [N, C, T] -> [N, C, T, 1]
    X = X[..., None]

# Predict and save
Y_hat = model.predict(X, batch_size=1)
np.save("Y_hat.npy", Y_hat)
print("Saved: Y_hat.npy", Y_hat.shape)
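Hand-copying `DIM`, `C`, and `THRESH` is easy to get wrong. A wrong `DIM` or `C` typically makes `load_weights` fail with a shape error, while a mismatched `THRESH` may raise no error at all if it does not affect any weight shapes. One option, sketched below under the assumption that you control the training script, is to write these arguments to a small JSON file next to the weights and read them back here. The file name `girnet_config.json` and both helper functions are hypothetical.

```python
import json
from pathlib import Path

# Hypothetical sidecar file name; any path next to your weights works.
CONFIG_PATH = Path("girnet_config.json")
REQUIRED_KEYS = {"input_shape", "channels", "threshold"}


def save_build_config(path, dim, channels, threshold):
    """Write the arguments needed to rebuild the model at inference time."""
    cfg = {"input_shape": dim, "channels": channels, "threshold": threshold}
    Path(path).write_text(json.dumps(cfg, indent=2))
    return cfg


def load_build_config(path):
    """Read the arguments back as a dict of keyword arguments."""
    cfg = json.loads(Path(path).read_text())
    missing = REQUIRED_KEYS - set(cfg)
    if missing:
        raise KeyError(f"config is missing {sorted(missing)}")
    return cfg
```

At training time you would call `save_build_config(CONFIG_PATH, dim=1024, channels=4, threshold=0.6)`, and at inference `model = build_girnet_model(**load_build_config(CONFIG_PATH))`. The dictionary keys mirror the keyword arguments in the `build_girnet_model` call above.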

*(Optional)* Save the predictions as WAV files with `soundfile`, after any post-processing you choose.

In [None]:
# Optional: write the first example to one WAV file per channel
import os

import numpy as np
import soundfile as sf

sr = 22050  # must match your data's sample rate (e.g. 22050 or 44100)
os.makedirs("pred_wavs", exist_ok=True)

# Y_hat (from a cell above) is typically [N, C, T, 1]; drop the trailing
# singleton axis if present to get [C, T].
y = Y_hat[0]
if y.ndim == 3 and y.shape[-1] == 1:
    y = y[..., 0]

for ch in range(y.shape[0]):
    sf.write(f"pred_wavs/example_ch{ch}.wav", y[ch], sr)  # 1-D array -> mono file
print("Wrote WAVs to pred_wavs/")
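`soundfile` writes WAV as 16-bit PCM by default, which only covers the range [-1, 1), so predictions that exceed it are distorted when written. Peak normalization is one simple post-processing choice; it is not necessarily what the GIRNet authors used. The helper below is a hypothetical sketch that scales each channel of a `[C, T]` array so its loudest sample reaches a target peak, leaving silent channels unchanged.

```python
import numpy as np


def peak_normalize(y, peak=0.99, per_channel=True, eps=1e-9):
    """Scale y ([C, T]) so its largest |sample| equals `peak`.

    per_channel=True  -> one gain per channel (row).
    per_channel=False -> one shared gain, preserving relative levels.
    Channels whose peak is below `eps` are left unchanged instead of
    being blown up by a tiny divisor.
    """
    y = np.asarray(y, dtype=np.float32)
    if per_channel:
        max_abs = np.max(np.abs(y), axis=-1, keepdims=True)  # [C, 1]
    else:
        max_abs = np.max(np.abs(y))                          # scalar
    scale = np.where(max_abs > eps, peak / np.maximum(max_abs, eps), 1.0)
    return y * scale
```

Call `y = peak_normalize(y)` just before the `sf.write` loop. Per-channel normalization changes the relative loudness between channels; pass `per_channel=False` to apply one shared gain instead. `sf.write` also accepts a 2-D `[frames, channels]` array, so `sf.write("pred_wavs/example.wav", y.T, sr)` would write all channels into a single multichannel file.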