In [5]:
# tests/test_block_roundtrip.py
import numpy as np
import tensorflow as tf
from model.block_tf import BlockEmbeddingTF, BlockFoldTF

def run_test():
    # small synthetic image
    B = 1
    H = 18
    W = 18
    C = 5
    img = np.random.rand(B, H, W, C).astype(np.float32)
    images = tf.convert_to_tensor(img)

    patch_h = 3
    patch_w = 3
    stride = 3

    be = BlockEmbeddingTF(patch_h=patch_h, patch_w=patch_w, stride_h=stride, stride_w=stride, padding='SAME')
    patches, info = be.extract(images)
    print("patches shape:", patches.shape, "info:", info)

    bf = BlockFoldTF()
    recon = bf.fold(patches, info, orig_H=H, orig_W=W)
    recon_np = recon.numpy()

    # compare only the central region that is within valid reconstruction (tolerance)
    diff = np.abs(recon_np - img)
    print("max abs diff:", diff.max())
    assert diff.max() < 1e-5 or diff.max() < 1e-3, "Round-trip error too large"
    print("Round-trip OK")

if __name__ == '__main__':
    run_test()


patches shape: (1, 36, 3, 3, 5) info: {'new_h': 6, 'new_w': 6, 'ph': 3, 'pw': 3, 'sh': 3, 'sw': 3, 'padding': 'SAME'}
max abs diff: 0.0
Round-trip OK


In [6]:
# tests/test_attention_tf.py
import numpy as np
import tensorflow as tf
from model.attention_tf import AttentionBlock as AttentionTF

def demo():
    tf.random.set_seed(0)
    B = 2
    psize = 3
    pstride = 3
    H = psize * pstride
    W = H
    C = 8
    x = tf.random.normal((B, H, W, C), dtype=tf.float32)

    # create attention module
    att = AttentionTF(embed_dim=C, patch_size=psize, patch_stride=pstride, proj_ratio=4, attn_drop=0.0)
    # build by calling once
    dummy = att(x, block_idx=tf.constant([0,1], dtype=tf.int32), match_vec=tf.constant([0,0,0,0], dtype=tf.float32), return_attn=False)

    # Case 1: match_vec all zeros -> AFB behavior (self suppressed)
    num_blocks_global = 10
    match_vec = tf.zeros([num_blocks_global], dtype=tf.float32)
    block_idx = tf.constant([0, 1], dtype=tf.int32)  # pretend these are indices into match_vec
    out_afb, attn_afb = att(x, block_idx=block_idx, match_vec=match_vec, return_attn=True, training=False)
    print("attn_afb shape:", attn_afb.shape)  # [B, N, N]
    # inspect diagonal mean
    diag_afb = tf.linalg.diag_part(attn_afb)  # [B, N]
    print("AFB diagonal mean:", tf.reduce_mean(diag_afb).numpy())

    # Case 2: match_vec has ones -> BFB behavior (for those batch samples)
    match_vec2 = tf.ones([num_blocks_global], dtype=tf.float32)
    out_bfb, attn_bfb = att(x, block_idx=block_idx, match_vec=match_vec2, return_attn=True, training=False)
    diag_bfb = tf.linalg.diag_part(attn_bfb)
    print("BFB diagonal mean:", tf.reduce_mean(diag_bfb).numpy())

    # Expect diag_afb mean << diag_bfb mean
    print("diag_afb mean:", float(tf.reduce_mean(diag_afb)))
    print("diag_bfb mean:", float(tf.reduce_mean(diag_bfb)))

if __name__ == '__main__':
    demo()


attn_afb shape: (2, 9, 9)
AFB diagonal mean: 0.0
BFB diagonal mean: 0.8201668
diag_afb mean: 0.0
diag_bfb mean: 0.820166826248169


In [7]:
# tests/test_gtb_net.py
import numpy as np
import tensorflow as tf
from model.gtblock_tf import NetTF

def run_test():
    B = 4
    psize = 3
    pstride = 3
    H = psize * pstride
    W = H
    in_chans = 8
    embed_dim = 16

    # create random blocks in channels-last
    x = tf.random.normal((B, H, W, in_chans), dtype=tf.float32)
    # dummy match_vec and block_idx
    num_blocks_global = 100
    match_vec = tf.zeros([num_blocks_global], dtype=tf.float32)
    block_idx = tf.constant([0,1,2,3], dtype=tf.int32)

    net = NetTF(in_chans=in_chans, embed_dim=embed_dim, patch_size=psize, patch_stride=pstride, mlp_ratio=2.0, proj_ratio=4)
    out = net(x, block_idx=block_idx, match_vec=match_vec, training=True)
    print("in shape:", x.shape, "out shape:", out.shape)
    assert out.shape == x.shape
    print("GTB/Net forward OK")

if __name__ == '__main__':
    run_test()


in shape: (4, 9, 9, 8) out shape: (4, 9, 9, 8)
GTB/Net forward OK


In [8]:
# tests/test_block_search_integration.py
import numpy as np, tensorflow as tf
from model.block_tf import BlockEmbeddingTF, BlockFoldTF
from model.block_search_tf import BlockSearchTF

def run_test():
    # synthetic image
    H = 18; W = 18; C = 4
    rng = np.random.RandomState(2)
    img = rng.rand(H, W, C).astype(np.float32)
    img_tf = tf.expand_dims(tf.convert_to_tensor(img), axis=0)  # [1,H,W,C]

    # extractor params (no-overlap assumption: stride == patch size)
    ph = 3; pw = 3; sh = 3; sw = 3
    be = BlockEmbeddingTF(patch_h=ph, patch_w=pw, stride_h=sh, stride_w=sw, padding='SAME')
    patches_orig, info = be.extract(img_tf)  # [1, N, ph, pw, C]
    N = patches_orig.shape[1]
    block_query = tf.reshape(patches_orig[0], (N, -1))  # [N, L]

    # simulate search_matrix == original reconstructions
    search_matrix = patches_orig[0].numpy()  # [N, ph, pw, C]
    # wrap as batched:
    search_batched = np.expand_dims(search_matrix, axis=0)  # [1,N,ph,pw,C]

    bf = BlockFoldTF()
    # fold back:
    recon = bf.fold(tf.convert_to_tensor(search_batched), info, orig_H=H, orig_W=W)  # [1,H,W,C]

    # run BlockSearchTF
    bs = BlockSearchTF(block_embedding=be, block_query=block_query)
    match_vec = bs.compute_match_vec_from_batched_search_matrix(search_batched, info, orig_H=H, orig_W=W)
    print("match_vec sum (should be N):", int(tf.reduce_sum(match_vec).numpy()), "N:", N)
    assert int(tf.reduce_sum(match_vec).numpy()) == int(N)
    print("BlockSearch integration test PASSED")

if __name__ == '__main__':
    run_test()


match_vec sum (should be N): 36 N: 36
BlockSearch integration test PASSED
