In [1]:
# tests/test_block_roundtrip.py
import numpy as np
import tensorflow as tf
from model.block_tf import BlockEmbeddingTF, BlockFoldTF

def run_test(atol: float = 1e-3, seed: int = 0):
    """Check that extracting blocks from an image and folding them back reproduces it.

    A random float32 image of shape (1, 18, 18, 5) is split into 3x3 blocks
    with stride 3 by ``BlockEmbeddingTF.extract`` and reassembled by
    ``BlockFoldTF.fold``. The reconstruction is compared element-wise against
    the input over the full tensor.

    Args:
        atol: Maximum allowed absolute difference between any input element
            and its reconstruction. The default matches the tolerance the
            test effectively enforced before.
        seed: Seed for the NumPy generator that produces the synthetic image,
            so a failure can be reproduced.

    Raises:
        AssertionError: If the reconstruction has a different shape from the
            input, or if the largest absolute difference is not below ``atol``.
    """
    # small synthetic image (seeded so that a failing run is reproducible)
    B = 1
    H = 18
    W = 18
    C = 5
    rng = np.random.default_rng(seed)
    img = rng.random((B, H, W, C)).astype(np.float32)
    images = tf.convert_to_tensor(img)

    patch_h = 3
    patch_w = 3
    stride = 3

    be = BlockEmbeddingTF(patch_h=patch_h, patch_w=patch_w, stride_h=stride, stride_w=stride, padding='SAME')
    patches, info = be.extract(images)
    print("patches shape:", patches.shape, "info:", info)

    bf = BlockFoldTF()
    recon = bf.fold(patches, info, orig_H=H, orig_W=W)
    recon_np = recon.numpy()

    # Check the shape explicitly: the subtraction below would otherwise
    # broadcast silently and could hide a mismatched reconstruction.
    assert recon_np.shape == img.shape, (
        f"Reconstructed shape {recon_np.shape} differs from input shape {img.shape}"
    )

    # Compare the whole image. H and W are multiples of the patch size and the
    # stride equals the patch size, so no border region needs to be excluded
    # (NOTE(review): assumes 'SAME' adds no padding in this case - confirm).
    diff = np.abs(recon_np - img)
    max_diff = diff.max()
    print("max abs diff:", max_diff)
    assert max_diff < atol, "Round-trip error too large"
    print("Round-trip OK")

# Run the block round-trip test when this code is executed directly
# (as a script, or as a notebook cell where __name__ is '__main__').
if __name__ == '__main__':
    run_test()


2025-11-13 11:03:45.100339: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-13 11:03:45.234007: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-13 11:03:48.205359: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


patches shape: (1, 36, 3, 3, 5) info: {'new_h': 6, 'new_w': 6, 'ph': 3, 'pw': 3, 'sh': 3, 'sw': 3, 'padding': 'SAME'}
max abs diff: 0.0
Round-trip OK


W0000 00:00:1763012030.204683   20983 gpu_device.cc:2342] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [2]:
# tests/test_attention_tf.py
import numpy as np
import tensorflow as tf
from model.attention_tf import AttentionBlock as AttentionTF

def demo():
    """Compare the attention diagonal under an all-zeros and an all-ones match vector.

    Builds an ``AttentionTF`` layer on a random (2, 9, 9, 8) input and runs it
    twice with ``return_attn=True``: once with ``match_vec`` all zeros (the
    "AFB" case, where self-attention is expected to be suppressed) and once
    with ``match_vec`` all ones (the "BFB" case). The mean of the attention
    diagonal is printed for both and the AFB mean is required to be smaller.

    Raises:
        AssertionError: If the AFB diagonal mean is not below the BFB one.
    """
    tf.random.set_seed(0)
    B = 2
    psize = 3
    pstride = 3
    H = psize * pstride
    W = H
    C = 8
    x = tf.random.normal((B, H, W, C), dtype=tf.float32)

    # create attention module
    att = AttentionTF(embed_dim=C, patch_size=psize, patch_stride=pstride, proj_ratio=4, attn_drop=0.0)
    # build by calling once; the result itself is not needed
    att(x, block_idx=tf.constant([0,1], dtype=tf.int32), match_vec=tf.constant([0,0,0,0], dtype=tf.float32), return_attn=False)

    num_blocks_global = 10
    block_idx = tf.constant([0, 1], dtype=tf.int32)  # pretend these are indices into match_vec

    # Case 1: match_vec all zeros -> AFB behavior (self suppressed)
    match_vec = tf.zeros([num_blocks_global], dtype=tf.float32)
    _, attn_afb = att(x, block_idx=block_idx, match_vec=match_vec, return_attn=True, training=False)
    print("attn_afb shape:", attn_afb.shape)  # [B, N, N]
    # inspect diagonal mean
    diag_afb = tf.linalg.diag_part(attn_afb)  # [B, N]
    afb_mean = tf.reduce_mean(diag_afb)
    print("AFB diagonal mean:", afb_mean.numpy())

    # Case 2: match_vec has ones -> BFB behavior (for those batch samples)
    match_vec2 = tf.ones([num_blocks_global], dtype=tf.float32)
    _, attn_bfb = att(x, block_idx=block_idx, match_vec=match_vec2, return_attn=True, training=False)
    diag_bfb = tf.linalg.diag_part(attn_bfb)
    bfb_mean = tf.reduce_mean(diag_bfb)
    print("BFB diagonal mean:", bfb_mean.numpy())

    # Expect diag_afb mean << diag_bfb mean
    print("diag_afb mean:", float(afb_mean))
    print("diag_bfb mean:", float(bfb_mean))
    # Enforce the expectation so a regression in the masking fails the cell
    # instead of only changing the printed numbers.
    assert float(afb_mean) < float(bfb_mean), (
        "AFB diagonal attention should be lower than BFB diagonal attention"
    )

# Run the attention demo when this code is executed directly
# (as a script, or as a notebook cell where __name__ is '__main__').
if __name__ == '__main__':
    demo()


W0000 00:00:1763033796.713866   25298 gpu_device.cc:2342] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


attn_afb shape: (2, 9, 9)
AFB diagonal mean: 0.0
BFB diagonal mean: 0.8623292
diag_afb mean: 0.0
diag_bfb mean: 0.8623291850090027


In [3]:
# tests/test_gtb_net.py
import numpy as np
import tensorflow as tf
from model.gtblock_tf import NetTF

def run_test():
    """Smoke-test one NetTF forward pass: the output must keep the input's shape.

    Feeds a random channels-last batch through a freshly constructed ``NetTF``
    in training mode, prints the input and output shapes, and asserts that
    they are equal.
    """
    # Block geometry: each spatial side is patch size times patch stride.
    patch_size, patch_stride = 3, 3
    side = patch_size * patch_stride
    batch, channels_in, channels_embed = 4, 8, 16

    # create random blocks in channels-last
    blocks = tf.random.normal((batch, side, side, channels_in), dtype=tf.float32)

    # dummy match_vec and block_idx
    global_block_count = 100
    match_flags = tf.zeros([global_block_count], dtype=tf.float32)
    indices = tf.constant([0, 1, 2, 3], dtype=tf.int32)

    model = NetTF(
        in_chans=channels_in,
        embed_dim=channels_embed,
        patch_size=patch_size,
        patch_stride=patch_stride,
        mlp_ratio=2.0,
        proj_ratio=4,
    )
    result = model(blocks, block_idx=indices, match_vec=match_flags, training=True)

    print("in shape:", blocks.shape, "out shape:", result.shape)
    assert result.shape == blocks.shape
    print("GTB/Net forward OK")

# Run the NetTF forward-pass test when this code is executed directly
# (as a script, or as a notebook cell where __name__ is '__main__').
if __name__ == '__main__':
    run_test()


in shape: (4, 9, 9, 8) out shape: (4, 9, 9, 8)
GTB/Net forward OK
