In [5]:
# =====================================================
# CPU-ONLY IMAGE CAPTION INFERENCE (TRAINING-MATCHED TOKENIZER)
# =====================================================

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import TextVectorization
import random

print("TensorFlow version:", tf.__version__)
print("GPUs visible:", tf.config.list_physical_devices("GPU"))

# =====================================================
# CONFIG (MUST MATCH TRAINING)
# =====================================================
MODEL_PATH = "/home/kavir/image_project/model_checkpoints/best_caption_model.keras"
CAPTIONS_FILE = "/home/kavir/image_project/captions.txt"
TEST_IMAGE_DIR = "/home/kavir/image_project/testImages"

IMAGE_SIZE = (224, 224)
SEQ_LENGTH = 20
VOCAB_SIZE = 5000
EMBED_DIM = 256

# =====================================================
# IMAGE LOADER (PNG/JPG SAFE)
# =====================================================
def load_image(path):
    img = tf.io.read_file(path)
    img = tf.io.decode_image(img, channels=3, expand_animations=False)  # jpg/png safe
    img = tf.image.resize(img, IMAGE_SIZE)
    img = tf.cast(img, tf.float32) / 255.0
    return img

# =====================================================
# CUSTOM TRANSFORMER LAYERS (MUST MATCH TRAINING)
# =====================================================
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, num_heads=4, ff_dim=256, **kwargs):
        super().__init__(**kwargs)
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            layers.Dense(ff_dim, activation="relu"),
            layers.Dense(embed_dim)
        ])
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        attn_output = self.att(inputs, inputs)
        x = self.norm1(inputs + attn_output)
        return self.norm2(x + self.ffn(x))


class TransformerDecoder(layers.Layer):
    def __init__(self, vocab_size, embed_dim, num_heads=4, ff_dim=512, **kwargs):
        super().__init__(**kwargs)
        self.embed = layers.Embedding(vocab_size, embed_dim)
        self.att1 = layers.MultiHeadAttention(num_heads, embed_dim)
        self.att2 = layers.MultiHeadAttention(num_heads, embed_dim)
        self.ffn = tf.keras.Sequential([
            layers.Dense(ff_dim, activation="relu"),
            layers.Dense(embed_dim)
        ])
        self.norm1 = layers.LayerNormalization()
        self.norm2 = layers.LayerNormalization()
        self.norm3 = layers.LayerNormalization()
        self.out = layers.Dense(vocab_size)

    def call(self, x, enc_output):
        x = self.embed(x)
        x = self.norm1(x + self.att1(x, x))
        x = self.norm2(x + self.att2(x, enc_output))
        x = self.norm3(x + self.ffn(x))
        return self.out(x)

# =====================================================
# LOAD MODEL
# =====================================================
print("\nLoading model...")
model = tf.keras.models.load_model(
    MODEL_PATH,
    custom_objects={
        "TransformerEncoder": TransformerEncoder,
        "TransformerDecoder": TransformerDecoder
    },
    compile=False
)
print("✓ Model loaded")

# =====================================================
# REBUILD TOKENIZER (MATCH TRAINING EXACTLY)
# =====================================================
print("\nRebuilding tokenizer (MATCH TRAINING)...")

tokenizer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_sequence_length=SEQ_LENGTH,
    standardize="lower_and_strip_punctuation"
)

# 1) Build captions dict exactly like training
captions_dict = {}
malformed_count = 0

with open(CAPTIONS_FILE, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()

        if not line or line.startswith("image,caption"):
            continue

        parts = line.split(",", 1)
        if len(parts) < 2:
            malformed_count += 1
            continue

        img = parts[0].strip()
        cap = parts[1].strip()

        if not img:
            malformed_count += 1
            continue
        if len(cap) == 0:
            continue

        captions_dict.setdefault(img, []).append(cap)

print(f"✓ Total images with captions: {len(captions_dict)}")
print(f"✗ Malformed lines skipped: {malformed_count}")

# 2) Build caption_seqs exactly like training
caption_seqs = []
for img, caps in captions_dict.items():  # dict insertion order preserved
    for c in caps:
        caption_seqs.append(str("sos " + c + " eos"))

# 3) Adapt tokenizer exactly like training
adapt_dataset = tf.data.Dataset.from_tensor_slices(caption_seqs)
tokenizer.adapt(adapt_dataset)

vocab = tokenizer.get_vocabulary()
vocab_set = set(vocab)

print("✓ Tokenizer rebuilt (training-matched)")
print("✓ Vocabulary size:", len(vocab))

FALLBACK_WORD = "a" if "a" in vocab_set else vocab[1]

# =====================================================
# CAPTION GENERATION
# =====================================================
def generate_caption(image_path, max_length=20, min_length=3, debug_steps=5):
    img = load_image(image_path)
    img = tf.expand_dims(img, 0)

    caption = "sos"

    for i in range(max_length):
        seq = tokenizer([caption])
        preds = model([img, seq], training=False)

        next_id = int(tf.argmax(preds[0, -1]).numpy())
        next_word = vocab[next_id]

        # Debug: show first few predicted words
        if i < debug_steps:
            print(f"  step {i+1}: id={next_id} word={next_word}")

        # Prevent early stop / unknown too early
        if i < min_length and next_word in ("eos", "[UNK]"):
            next_word = FALLBACK_WORD

        if next_word == "eos":
            break

        if next_word == "[UNK]":
            next_word = FALLBACK_WORD

        caption += " " + next_word

    result = caption.replace("sos ", "").strip()

    if not result or result == "sos":
        return "(no caption generated)"
    return result

# =====================================================
# TEST
# =====================================================
print("\nRunning sanity test...")

if not os.path.exists(TEST_IMAGE_DIR):
    raise FileNotFoundError(f"Test image folder not found: {TEST_IMAGE_DIR}")

images = [
    os.path.join(TEST_IMAGE_DIR, f)
    for f in os.listdir(TEST_IMAGE_DIR)
    if f.lower().endswith((".jpg", ".jpeg", ".png"))
]

if len(images) == 0:
    raise RuntimeError(f"No .jpg/.jpeg/.png images found in: {TEST_IMAGE_DIR}")

for img in random.sample(images, min(5, len(images))):
    print("\nImage:", os.path.basename(img))
    print("Caption:", generate_caption(img))

print("\n✓ Inference finished (CPU only)")


TensorFlow version: 2.20.0
GPUs visible: []

Loading model...
✓ Model loaded

Rebuilding tokenizer (MATCH TRAINING)...
✓ Total images with captions: 8091
✗ Malformed lines skipped: 0


2026-01-27 20:13:58.926119: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


✓ Tokenizer rebuilt (training-matched)
✓ Vocabulary size: 5000

Running sanity test...

Image: image.png
  step 1: id=0 word=
  step 2: id=0 word=
  step 3: id=0 word=
  step 4: id=0 word=
  step 5: id=0 word=
Caption: (no caption generated)

✓ Inference finished (CPU only)
