In [1]:
import os
import glob
from tqdm import tqdm

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import silence_tensorflow.auto
import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

tf.debugging.enable_check_numerics()

import keras
import random
import numpy as np
import matplotlib.pyplot as plt

E0000 00:00:1738499132.209200  399033 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738499132.215984  399033 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = 1
config.gpu_options.per_process_gpu_memory_fraction = 1
session = tf.compat.v1.InteractiveSession(config=config)

I0000 00:00:1738499136.556610  399033 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3794 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3050 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


In [3]:
H = 256
W = 256
C = 3
h = 8

P = 16     
assert(H == W)
assert(H % P == 0)

D_model = 1024
D_head = 128
D_fcn = 2048   
num_layers = 4

N = (H * W) // (P * P)
BS = 16

In [4]:
FLOAT = tf.float32

In [5]:
DS_SHAPE = (512,512,3)
def load_and_validate(file_path):
    img = tf.io.read_file(file_path)
    img = tf.image.decode_image(img, channels=C, expand_animations=False)
    img = tf.divide(tf.cast(img, dtype=FLOAT), 255.0)
    is_valid = tf.reduce_all(tf.equal(tf.shape(img), tf.constant(DS_SHAPE)))

    return img, is_valid


dataset_path = "/mnt/Data/ML/datasets/portraits"
num_samples = 1000


all_files = [
    os.path.join(dataset_path, f)
    for f in os.listdir(dataset_path)
    if f.endswith((".jpg", ".png"))
]
random.shuffle(all_files)
selected_files = all_files[:num_samples]
dataset = tf.data.Dataset.from_tensor_slices(selected_files)
dataset = dataset.map(load_and_validate)
dataset = dataset.filter(lambda img, is_valid: is_valid)  # Keep valid images
dataset = dataset.map(lambda img, is_valid: img)  # remove unused feature
dataset = dataset.map(lambda img: tf.image.resize(img, (H,W))) 
print(f"Total files: {len(selected_files)}")

with tf.device('/cpu:0'):
    valid_count = dataset.reduce(tf.constant(0, dtype=tf.int32), lambda x, _: x + 1).numpy()

print(f"Valid images count: {valid_count}")
assert valid_count, "Everything's gone"

I0000 00:00:1738499136.724807  399033 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3794 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3050 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


Total files: 1000
Valid images count: 1000


In [6]:
def viz_img(img):
    img = tf.cast(img, tf.float32)
    plt.imshow(tf.squeeze(img).numpy(), cmap="gray")
    plt.colorbar()
    plt.show()


def viz_mask(mask):
    plt.imshow(tf.squeeze(mask).numpy(), cmap="gray", vmin=0, vmax=1)
    plt.colorbar()
    plt.show()

In [7]:
def prepare_sample(image):
    mask = random_visibility_mask()
    return image, mask



def random_visibility_mask():
    x1 = tf.random.uniform(shape=(), minval=0, maxval=W - 100, dtype=tf.int32)
    y1 = tf.random.uniform(shape=(), minval=0, maxval=H - 100, dtype=tf.int32)
    x2 = tf.random.uniform(shape=(), minval=x1 + 50, maxval=W + 1, dtype=tf.int32)
    y2 = tf.random.uniform(shape=(), minval=y1 + 50, maxval=H + 1, dtype=tf.int32)
    # tf.print(x1,x2,y1,y2)

    mask = tf.ones((H, W), dtype=tf.bool)
    mask = tf.tensor_scatter_nd_update(
        mask,
        indices=tf.stack(
            [
                tf.repeat(tf.range(y1, y2), x2 - x1),
                tf.tile(tf.range(x1, x2), [y2 - y1]),
            ],
            axis=-1,
        ),
        updates=tf.zeros([(y2 - y1) * (x2 - x1)], dtype=tf.bool),
    )
    return tf.expand_dims(mask, -1)  # expand channel wise

ds_masks = dataset.map(prepare_sample)
train_count = int(valid_count * 0.8)
test_count = int(valid_count * 0.1)
val_count = valid_count - train_count - test_count

train_ds = ds_masks.take(train_count).batch(BS)
test_ds = ds_masks.skip(train_count).take(test_count).batch(BS)
val_ds = ds_masks.skip(train_count + test_count).take(val_count).batch(BS)
print(train_count, test_count,val_count)

800 100 100


In [None]:

# def mask_area(mask):
#     return tf.reduce_sum(tf.cast(mask, tf.int32))


def extract_patches(image: tf.Tensor) -> tf.Tensor:
    "R^{BS x H x W x C} -> R^{BS x N x P^2 x C}"
    # print(image.dtype)

    patches: tf.Tensor = tf.image.extract_patches(
        images=image,  # Add batch dim
        sizes=[1, P, P, 1],  # Patch size
        strides=[1, P, P, 1],  # Step size
        rates=[1, 1, 1, 1],  # No dilation
        padding="VALID",
    )
    BS, H_prime, W_prime, _ = tf.unstack(tf.shape(patches))

    # Reshape patches to [BS, H' * W', P*P, C]
    patches = tf.reshape(patches, [BS, H_prime * W_prime, P * P, -1])

    return patches


def patches_to_imgs(patches: tf.Tensor) -> tf.Tensor:
    "R^{BS x N x P^2 x C} -> R^{BS x H x W x C}"
    BS = tf.shape(patches)[0]
    grid_size = H // P  # same as W // P

    patches = tf.reshape(patches, [BS, grid_size, grid_size, P, P, C])
    patches = tf.transpose(patches, perm=[0, 1, 3, 2, 4, 5])

    image = tf.reshape(patches, [BS, grid_size * P, grid_size * P, C])

    return image


def process_mask(obvmask: tf.Tensor):
    "R^{BS x H x W} -> tuple[R^{BS x N}, R^{BS x N x N}]"
    # TF does not support min pooling. the mask shown here is OBSERVATION MASK meaning 0 means missing. the inpaint mask is a negation of that

    # viz_mask(mask)
    BS = tf.shape(obvmask)[0]
    mask_pooled = tf.nn.max_pool2d(
        tf.cast(
            tf.logical_not(obvmask), dtype=tf.int8
        ),  # insane shit happaned here with mask_inverted
        ksize=[P, P],
        strides=[P, P],
        padding="VALID",
    )
    mask_pooled = tf.logical_not(tf.cast(mask_pooled, tf.bool))
    # viz_mask(mask_pooled)
    mask_pooled = tf.reshape(mask_pooled, [BS, N])
    mask_expanded = tf.expand_dims(mask_pooled, axis=1)  # (BS, 1, N)
    mask_expanded = tf.tile(mask_expanded, [1, N, 1])  # (BS, N, N)
    A = tf.where(
        mask_expanded,
        tf.constant(0.0, dtype=FLOAT),  # zero penanly
        tf.constant(-float("inf"), dtype=FLOAT),  # inf penalty
    )
    # tf.print(tf.shape(A))

    return mask_pooled, A

In [8]:
class PatchEmbedding(keras.layers.Layer):
    def __init__(self):
        super().__init__(dtype=FLOAT)  
        self.proj = keras.layers.Dense(D_model, dtype=FLOAT)  # (P² * C) -> D_model

    def build(self, input_shape):
        positions = tf.range(N, dtype=FLOAT)
        positions = tf.expand_dims(positions, 1)  # (N, 1)
        i = tf.range(D_model//2, dtype=FLOAT)
        div_term = tf.exp(
        (2.0 * i) * (-tf.math.log(10000.0) / D_model)
        )
        angles = positions * div_term  # (N, D_model//2)
        sin_terms = tf.sin(angles)
        cos_terms = tf.cos(angles)
            
        self.positional_embedding = tf.reshape(tf.stack([sin_terms, cos_terms], axis=-1), [N, D_model])
        # tf.print(self.positional_embedding)

    def call(self, patches_flat: tf.Tensor):
        # R^{BS x N x (P^2 . C)} -> R^{BS x N x D_model}
        X = self.proj(patches_flat) 
        X += self.positional_embedding 
        return X


class MultiHeadAttention(keras.layers.Layer):
    def __init__(self):
        super().__init__(dtype=FLOAT)
        # Project to h * D_head dimensions
        self.W_Q = keras.layers.Dense(h * D_head, dtype=FLOAT)
        self.W_K = keras.layers.Dense(h * D_head, dtype=FLOAT)
        self.W_V = keras.layers.Dense(h * D_head, dtype=FLOAT)
        # Project back to D_model
        self.W_O = keras.layers.Dense(D_model, dtype=FLOAT)

    def call(self, X, A):
        # X: R^{BS x N x D_model}
        # A: R^{BS x N x N}
        # returns: R^{BS x N x D_model}

        # In the standard implementation, each head has its own separate projection matrices. However, a common optimization is to project the input into h * D_head dimensions (which is D_model) with a single large projection, then split into h heads. So, if D_model = h * D_head, then using a Dense(D_model) for Q, K, V and then splitting into h heads each of D_head is equivalent to having h separate projections. This is a standard approach because it's more efficient to compute all heads in parallel with a single matrix multiplication rather than h separate ones.
        # So the optimal way is to use combined projections.
        Q = self.W_Q(X)  # (BS, N, h * D_head)
        K = self.W_K(X)  # (BS, N, h * D_head)
        V = self.W_V(X)  # (BS, N, h * D_head)
        

        Q = tf.reshape(Q, (-1, N, h, D_head))  # (BS, N, h, D_head)
        K = tf.reshape(K, (-1, N, h, D_head))
        V = tf.reshape(V, (-1, N, h, D_head))

        # Transpose for attention computation
        Q = tf.transpose(Q, [0, 2, 1, 3])  # (BS, h, N, D_head)
        K = tf.transpose(K, [0, 2, 1, 3])
        V = tf.transpose(V, [0, 2, 1, 3])
        # scaled dot-product attention
        attn_scores = tf.matmul(Q, K, transpose_b=True)  # (BS, h, N, N)
        attn_scores /= tf.math.sqrt(
            tf.cast(D_head, attn_scores.dtype)
        )  # scale by sqrt(D_head)

        A = tf.expand_dims(A, 1)  # (BS, 1, N, N)
        attn_scores += A  # Broadcast to all heads

        attn_weights = tf.nn.softmax(attn_scores, axis=-1)  # (BS, h, N, N)

        output = tf.matmul(attn_weights, V)  # (BS, h, N, D_head)
        output = tf.transpose(output, [0, 2, 1, 3])  # (BS, N, h, D_head)
        output = tf.reshape(output, (-1, N, h * D_head))  # (BS, N, h * D_head)
        output = self.W_O(output)  # (BS, N, D_model)
        return output


class TransformerBlock(keras.layers.Layer):
    def __init__(self):
        super().__init__(dtype=FLOAT)
        self.attn = MultiHeadAttention()
        self.norm1 = keras.layers.LayerNormalization(dtype=FLOAT)
        self.norm2 = keras.layers.LayerNormalization(dtype=FLOAT)
        self.ffn = keras.Sequential(
            [
                keras.layers.Dense(D_fcn, activation="gelu", dtype=FLOAT),
                keras.layers.Dense(D_model, dtype=FLOAT),
                # keras.layers.Dropout(0.1, dtype=FLOAT),
            ],
        )

    def call(self, X, A):
        "R^{N x D_model} -> R^{N x D_model}"
        X = self.norm1(X + self.attn(X, A))
        X = self.norm2(X + self.ffn(X))
        return X


class Decoder(keras.layers.Layer):
    def __init__(self):
        super().__init__(dtype=FLOAT)
        self.proj = keras.layers.Dense(P * P * C, dtype=FLOAT)

    def call(self, X):
        "R^{BS x N x D_model} -> R^{BS x N x P x P x C}"
        BS = tf.shape(X)[0]
        X = self.proj(X)
        X = tf.reshape(X, (BS, N, P, P, C))
        return X


class ImageInpaintingTransformer(keras.Model):
    def __init__(self):
        super().__init__(dtype=FLOAT)
        self.embed = PatchEmbedding()
        self.transformer_blocks = [TransformerBlock() for _ in range(num_layers)]
        self.decoder = Decoder()

    def build(self, input_shape):
        BS = input_shape[0]
        # dummy_images = tf.zeros((BS, H, W, C), dtype=FLOAT)  # THIS WASTED 40 MINUTES
        self.call(*next(iter(val_ds.take(1))))
        self.built = True

    def call(self, image, obvmask):
        image = tf.multiply(image, tf.cast(obvmask,FLOAT))
        # viz_img(image[0])
        patches = extract_patches(image)
        mask_pooled, A = process_mask(obvmask)
        # tf.print(tf.shape(A))
        # viz_img(A)

        BS = tf.shape(patches)[0]
        patches_flat = tf.reshape(patches, [BS, N, P**2 * C])
        # tf.print(tf.shape(patches_flat))
        X = self.embed(patches_flat)
        for block in self.transformer_blocks:
            X = block(X, A)
        reconstructed_patches = self.decoder(X)  # R^{BS x N x P x P x C}
        return patches_to_imgs(reconstructed_patches)

model = ImageInpaintingTransformer()
model.build((BS, H, W, C))
model.summary()

InvalidArgumentError: {{function_node __wrapped__CheckNumericsV2_device_/job:localhost/replica:0/task:0/device:GPU:0}} 

!!! Detected Infinity or NaN in output 0 of eagerly-executing op "SelectV2" (# of outputs: 1) !!!
  dtype: <dtype: 'float32'>
  shape: (16, 256, 256)
  # of -Inf elements: 305920

  Input tensors (3):
         0: tf.Tensor(
[[[ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  ...
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]]

 [[ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  ...
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]]

 [[ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  ...
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]]

 ...

 [[ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  ...
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]]

 [[ True  True  True ... False False  True]
  [ True  True  True ... False False  True]
  [ True  True  True ... False False  True]
  ...
  [ True  True  True ... False False  True]
  [ True  True  True ... False False  True]
  [ True  True  True ... False False  True]]

 [[ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  ...
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]]], shape=(16, 256, 256), dtype=bool)
         1: tf.Tensor(0.0, shape=(), dtype=float32)
         2: tf.Tensor(-inf, shape=(), dtype=float32)

 : Tensor had -Inf values [Op:CheckNumericsV2] name: 

In [9]:
# # Tests patching and depatching

# img = next(iter(dataset.take(1)))

# viz_img(img)
# patched = extract_patches(tf.expand_dims(img ,0))
# tf.print(tf.shape(patched))

# recreated_img = patches_to_imgs(patched)
# viz_img(recreated_img)

In [10]:
# model.load_weights("best_run.keras")

In [11]:
session_epochs = 0

In [None]:
@tf.function
def costfunc(y_true: tf.Tensor, y_pred: tf.Tensor, obsvmask: tf.Tensor):
    errors = tf.abs(tf.subtract(y_true, y_pred))
    inpaintmask = tf.cast(tf.logical_not(obsvmask), FLOAT)
    masked_errors = tf.multiply(errors, inpaintmask)
    sum_masked_errors = tf.reduce_sum(masked_errors)
    area = tf.reduce_sum(inpaintmask)
    masked_loss = sum_masked_errors / (area + keras.backend.epsilon())

    global_loss = tf.reduce_mean(errors)

    return masked_loss + global_loss

optimizer = keras.optimizers.Adam(learning_rate=1e-5 , clipvalue=1.0)

@tf.function
def train_step(image: tf.Tensor, mask: tf.Tensor):
    with tf.GradientTape() as tape:
        reconstructed_img = model(image, mask)  # N x P x P x C
        loss = costfunc(image, reconstructed_img, mask)
        tf.debugging.check_numerics(loss, "Loss contains NaN or Inf.") 
        # if (tf.math.is_nan(loss)):
            # raise Exception("Divergence")
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


@tf.function
def val_step(image: tf.Tensor, mask: tf.Tensor):
    reconstructed_img = model(image, mask, training = False)
    loss = costfunc(image, reconstructed_img, mask)
    return loss


epochs = 50
print("Starting training")

best_val_loss = float("inf")
best_epoch = -1
for _ in range(epochs):
    epoch_loss = 0.0
    steps = 0
    pbar = tqdm(train_ds, desc=f"Epoch {session_epochs+1}", unit="batch", total=train_count // BS)
    for image_batch, mask_batch in pbar:
        loss = train_step(image_batch, mask_batch)
        epoch_loss += loss
        steps += 1
        # Dynamically update the tqdm bar without spamming stdout
        pbar.set_postfix(loss=f"{loss:.4f}")
    train_loss = epoch_loss / steps

    val_loss_total = 0.0
    val_steps = 0
    pbar_val = tqdm(val_ds, desc=f"Epoch {session_epochs+1} Validation", unit="batch", total=val_count // BS)
    for val_image_batch, val_mask_batch in pbar_val:
        loss = val_step(val_image_batch, val_mask_batch)
        val_loss_total += loss
        val_steps += 1
        pbar_val.set_postfix(loss=f"{loss:.4f}")
    avg_val_loss = val_loss_total / val_steps
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = session_epochs + 1
        model.save("best_run.keras")
    print(f"Epoch {session_epochs+1} Summary: Train Loss = {train_loss:.4f} | Validation Loss = {avg_val_loss:.4f}")
    session_epochs+=1

In [15]:
# model.save("b4.keras")
# model.load_weights("best_run.keras")

In [29]:
def apply_obsv_mask(image: tf.Tensor, obvmask: tf.Tensor) -> tf.Tensor:
    return tf.multiply(image, tf.cast(obvmask, FLOAT))


def viz_grid(batch: tf.Tensor):
    batch_size: int = batch.shape[0]  # type: ignore
    fig, axes = plt.subplots(nrows=1, ncols=batch_size, figsize=(15, 15), dpi=300)
    if batch_size == 1:
        axes = [axes]
    for i in range(batch_size):
        # Original image
        axes[i].imshow(tf.cast(batch[i], dtype=tf.float32).numpy())  # type: ignore
        axes[i].axis("off")
    plt.tight_layout()
    plt.show()


def reconstruct(original: tf.Tensor, reconstruct: tf.Tensor, obvmask: tf.Tensor):
    return tf.add(
        tf.multiply(tf.cast(obvmask, FLOAT), original),
        tf.multiply(tf.cast(tf.logical_not(obvmask), FLOAT), reconstruct),
    )

In [None]:
# Qualitative Eval
# visualize_unbatched_dataset(test_ds, 5)


# img = tf.image.decode_image(
#     tf.io.read_file("/home/navid/Dev/PaperTex/impl/naruto")
#     , dtype=tf.float32)
# img = tf.image.resize_with_crop_or_pad(img, H, W)
# img = tf.expand_dims(img, 0)
# tf.print(tf.shape(img))
# obvmask = tf.expand_dims(random_visibility_mask(),0)
# tf.print(tf.shape(obvmask))


img, obvmask = next(iter(test_ds.take(1)))
viz_grid(img)
viz_grid(apply_obsv_mask(img, obvmask))
model_out = model(img, obvmask)
# reconstructed = reconstruct(img, model(img, obvmask), obvmask)
viz_grid(model_out)

# viz_img(model_out[0])
# viz_img(img[0])