In [None]:
# terminology:
# obs_mask:  1 = observed,   0 = hole
# hole_mask: 1 = hole,       0 = not hole (observed)


# Type safety in this ecosystem is weak, so names carry their meaning/dtype as underscore
# suffixes (e.g. obs_mask_bool vs obs_mask_float) — expect a lot of underscores.
# I will try to minimize usage of tf.cast


# we place emphasis on determinism

# Change only one thing per experiment (optimizer / dataset / hyperparams / overall architecture).


from datetime import datetime
import time,pytz
import os, glob, subprocess, sys, uuid
from tqdm import tqdm

import psycopg2
from psycopg2.extras import execute_values
from psycopg2.extras import Json

# Experiment-tracking database connection.
# Credentials come from the environment instead of being hard-coded in the notebook
# (the password that used to be committed here should be rotated). Non-secret settings
# keep their previous values as defaults; the password falls back to an interactive prompt.
import getpass

conn = psycopg2.connect(
    dbname=os.environ.get("MLDB_NAME", "mldb"),
    user=os.environ.get("MLDB_USER", "navid"),
    password=os.environ.get("MLDB_PASSWORD") or getpass.getpass("mldb password: "),
    host=os.environ.get("MLDB_HOST", "pg.network.navidmafi.com"),
    port=os.environ.get("MLDB_PORT", "443"),
)
# Every INSERT below is committed immediately.
conn.autocommit = True
cur = conn.cursor()

# These are set before `import tensorflow` below so TF sees them when it loads.
# NOTE: PYTHONHASHSEED only affects Python interpreters started after this point
# (e.g. subprocesses); the already-running kernel's str hashing is not re-seeded.
os.environ.update(
    {
        "TF_CPP_MIN_LOG_LEVEL": "2",
        "PYTHONHASHSEED": "42",
        "TF_DETERMINISTIC_OPS": "1",
        "TF_FORCE_GPU_ALLOW_GROWTH": "1",
    }
)
import tensorflow as tf


# TF1-style session with GPU memory options: allow growth, and an upper bound on the
# fraction of GPU memory this process may claim.
# NOTE(review): in TF2 eager mode ops do not run through this session; confirm these
# options actually take effect (growth is also requested via TF_FORCE_GPU_ALLOW_GROWTH).
config = tf.compat.v1.ConfigProto()
gpu_opts = config.gpu_options
gpu_opts.per_process_gpu_memory_fraction = 0.9
gpu_opts.allow_growth = True
session = tf.compat.v1.InteractiveSession(config=config)

import keras
import random
import numpy as np
import matplotlib.pyplot as plt

GLOBAL_RANDOM_SEED = 42

# Seed every RNG the notebook touches and force deterministic, single-threaded TF ops.
random.seed(GLOBAL_RANDOM_SEED)
np.random.seed(GLOBAL_RANDOM_SEED)
tf.random.set_seed(GLOBAL_RANDOM_SEED)
tf.config.experimental.enable_op_determinism()
tf.config.threading.set_intra_op_parallelism_threads(1)
tf.config.threading.set_inter_op_parallelism_threads(1)


class _DistinctSeededGlorotUniform(keras.initializers.Initializer):
    """Glorot-uniform initializer that is reproducible yet differs on every call.

    A Keras initializer constructed with an integer seed returns the *same* values on
    every call. Sharing one such instance across many layers therefore gives every
    same-shaped kernel identical starting weights. This wrapper uses `seed + k` for its
    k-th call. Weights stay deterministic, because layers create them in a fixed order,
    but are no longer identical to each other.
    """

    def __init__(self, seed: int):
        self.seed = seed
        self._num_calls = 0

    def __call__(self, shape, dtype=None, **kwargs):
        call_seed = self.seed + self._num_calls
        self._num_calls += 1
        return keras.initializers.GlorotUniform(seed=call_seed)(shape, dtype=dtype)

    def get_config(self):
        return {"seed": self.seed}


# Shared by every Dense kernel (via commonDense) and the positional embedding.
initializer = _DistinctSeededGlorotUniform(seed=GLOBAL_RANDOM_SEED)

In [None]:
EPOCHS_TO_TRAIN = 40
num_samples = 5000  # image files sampled from the dataset directory (before validation)

BATCH_SIZE = 32

# Image geometry after resizing.
H = 64
W = 64
C = 3

# Side length of the non-overlapping square patches the image is cut into.
P = 8
assert H == W
assert H % P == 0

h = 8  # number of attention heads

D_model = 1024
D_head = 128
D_fcn = 1024
num_layers = 8
N = (H * W) // (P * P)  # patches (tokens) per image

# Inclusive range of the hole's width/height drawn by random_obs_mask.
MASK_MAX_SIZE = 20
MASK_MIN_SIZE = 20
# Fail here with a clear message rather than deep inside np.random.randint.
assert MASK_MIN_SIZE <= MASK_MAX_SIZE, "MASK_MIN_SIZE must not exceed MASK_MAX_SIZE"
assert MASK_MAX_SIZE < H/2

# Training optimizer: Adam with decoupled weight decay.
# Commented-out alternative: optimizer = keras.optimizers.Adam(learning_rate=1e-4)
optimizer = keras.optimizers.AdamW(weight_decay=1e-2, learning_rate=1e-4)

# Float dtype used throughout the model. Mixed precision is currently disabled:
# keras.mixed_precision.set_global_policy("mixed_bfloat16")
FLOAT = tf.float32

In [None]:
def _finish_plot():
    """Add a colorbar and a timestamp overlay to the current axes, then show the figure."""
    plt.colorbar()
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    plt.text(
        0.99,
        0.01,
        stamp,
        ha="right",
        va="bottom",
        transform=plt.gca().transAxes,
        fontsize=12,
        color="white",
        alpha=0.9,
        bbox=dict(boxstyle="square,pad=0.1", facecolor="black", alpha=0.3),
    )
    plt.show()


def viz_img(img):
    """Display one image (singleton axes squeezed, cast to float32) with colorbar and timestamp."""
    as_float = tf.cast(img, tf.float32)
    plt.imshow(tf.squeeze(as_float).numpy(), cmap="gray")
    _finish_plot()


def viz_mask(mask):
    """Display a mask in grayscale with the color range pinned to [0, 1]."""
    plt.imshow(tf.squeeze(mask).numpy(), cmap="gray", vmin=0, vmax=1)
    _finish_plot()


def viz_grid(batch: tf.Tensor, max: int = 4):
    """Show up to `max` images from `batch` side by side in a single row.

    Each image is cast to float32 and clipped to [0, 1] before display.
    """
    shown = min(batch.shape[0], max)  # type: ignore
    # squeeze=False always yields a 2-D array of Axes, so one image needs no special case.
    fig, axes = plt.subplots(
        nrows=1, ncols=shown, figsize=(15, 15), dpi=300, squeeze=False
    )
    for idx, ax in enumerate(axes[0]):
        pixels = tf.clip_by_value(
            tf.cast(batch[idx], dtype=tf.float32), 0, 1  # type: ignore
        )
        ax.imshow(pixels.numpy())  # type: ignore
        ax.axis("off")
    plt.tight_layout()
    plt.show()


def random_obs_mask(H, W, MASK_MIN_SIZE, MASK_MAX_SIZE):
    """Draw one observation mask with a single rectangular hole.

    Returns a FLOAT tensor of shape (H, W, 1): 1 = observed, 0 = hole.
    The hole's width and height are each drawn uniformly from
    [MASK_MIN_SIZE, MASK_MAX_SIZE] (inclusive). The hole is placed uniformly so it lies
    entirely inside the image. Uses the global NumPy RNG (width, height, x, y drawn in
    that order), so results follow np.random.seed.
    """
    # Named hole_w/hole_h so they do not shadow the global head count `h`.
    hole_w, hole_h = [
        np.random.randint(MASK_MIN_SIZE, MASK_MAX_SIZE + 1) for _ in range(2)
    ]
    x0 = np.random.randint(0, W - hole_w + 1)
    y0 = np.random.randint(0, H - hole_h + 1)

    # Start fully observed, then zero out the hole rectangle.
    mask = np.ones((H, W, 1), dtype=np.float32)
    mask[y0 : y0 + hole_h, x0 : x0 + hole_w, :] = 0.0
    return tf.constant(mask, dtype=FLOAT)


def load_and_validate(file_path, ds_shape_advertised):
    """Decode one image file and flag whether it has the advertised shape.

    Returns (img, is_valid). img is the decoded image cast to FLOAT and divided by 255.
    is_valid is a scalar bool tensor, True iff tf.shape(img) equals ds_shape_advertised.
    """
    raw_bytes = tf.io.read_file(file_path)
    decoded = tf.image.decode_image(
        raw_bytes, channels=ds_shape_advertised[2], expand_animations=False
    )
    scaled = tf.cast(decoded, dtype=FLOAT) / 255.0
    shape_ok = tf.reduce_all(
        tf.equal(tf.shape(scaled), tf.constant(ds_shape_advertised))
    )
    return scaled, shape_ok

In [None]:
ds_shape_advertised = (512, 512, 3)
# Overridable so the notebook can run where the data lives elsewhere.
dataset_path = os.environ.get(
    "PORTRAITS_DATASET_PATH", "/mnt/Data/ML/datasets/portraits"
)

# Sort before shuffling. os.listdir returns entries in filesystem-dependent order, so
# shuffling it directly would select a different subset on another machine or
# filesystem, even with the fixed random seed.
all_files = sorted(
    os.path.join(dataset_path, f)
    for f in os.listdir(dataset_path)
    if f.endswith((".jpg", ".png"))
)
random.shuffle(all_files)
selected_files = all_files[:num_samples]
dataset = tf.data.Dataset.from_tensor_slices(selected_files)
# Decode, drop images whose shape differs from ds_shape_advertised, then resize.
dataset = dataset.map(lambda path: load_and_validate(path, ds_shape_advertised))
dataset = dataset.filter(lambda img, is_valid: is_valid)
dataset = dataset.map(lambda img, is_valid: img)
dataset = dataset.map(lambda img: tf.image.resize(img, (H, W)))

# Count surviving images with one full pass over the pipeline.
valid_count = dataset.reduce(
    tf.constant(0, dtype=tf.int32), lambda x, _: x + 1
).numpy()  # type: ignore

print(f"Valid images count: {valid_count}")
assert valid_count, "Everything's gone"

# One fixed observation mask per image, drawn from the seeded NumPy RNG.
masks = [
    random_obs_mask(H, W, MASK_MIN_SIZE, MASK_MAX_SIZE) for _ in range(valid_count)
]
mask_ds = tf.data.Dataset.from_tensor_slices(masks)

ds = tf.data.Dataset.zip((dataset, mask_ds))

# 80 / 10 / 10 split; validation gets the rounding remainder.
train_count = int(valid_count * 0.8)
test_count = int(valid_count * 0.1)
val_count = valid_count - train_count - test_count

train_ds = ds.take(train_count).batch(BATCH_SIZE)
test_ds = ds.skip(train_count).take(test_count).batch(BATCH_SIZE)
val_ds = ds.skip(train_count + test_count).take(val_count).batch(BATCH_SIZE)

# Ceil-division: the last batch of each split may be partial.
train_batches = -(train_count // -BATCH_SIZE)
test_batches = -(test_count // -BATCH_SIZE)
val_batches = -(val_count // -BATCH_SIZE)

print("Trn,Tst,Val:", train_count, test_count, val_count)
print("Trn,Tst,Val batches:", train_batches, test_batches, val_batches)
viz_grid(next(iter(train_ds.take(1).map(lambda img, mask: img))), 8)  # type: ignore
viz_grid(next(iter(train_ds.take(1).map(lambda img, mask: mask))), 8)  # type: ignore

In [None]:
def extract_patches(image: tf.Tensor) -> tf.Tensor:
    """Cut a batch of images into non-overlapping P x P patches.

    R^{BS x H x W x C} -> R^{BS x N x P^2 x C}, patches in row-major grid order.
    """
    raw = tf.image.extract_patches(
        images=image,
        sizes=[1, P, P, 1],
        strides=[1, P, P, 1],
        rates=[1, 1, 1, 1],
        padding="VALID",
    )
    # raw: (BS, rows, cols, P*P*C); fold the grid into one token axis.
    dims = tf.shape(raw)
    batch, rows, cols = dims[0], dims[1], dims[2]
    return tf.reshape(raw, [batch, rows * cols, P * P, -1])


def patches_to_imgs(patches: tf.Tensor) -> tf.Tensor:
    """Stitch a grid of P x P patches back into full images (inverse of extract_patches).

    Accepts any layout holding N * P * P * C values per batch element, e.g.
    (BS, N, P*P, C) or (BS, N, P, P, C). Returns R^{BS x H x W x C}.
    """
    side = H // P  # patches per row/column (H == W)
    batch = tf.shape(patches)[0]
    grid = tf.reshape(patches, [batch, side, side, P, P, C])
    # (BS, row, col, py, px, C) -> (BS, row, py, col, px, C): pixel rows become contiguous.
    grid = tf.transpose(grid, perm=[0, 1, 3, 2, 4, 5])
    return tf.reshape(grid, [batch, side * P, side * P, C])


# Sanity check: patchify + unpatchify of a real image must be lossless.
sample = next(iter(dataset.take(1)))
sample = tf.expand_dims(sample, axis=0)
round_trip = patches_to_imgs(extract_patches(sample))
tf.assert_equal(round_trip, sample)


def create_attention_mask(obs_mask_bool: tf.Tensor):
    """Build an additive attention mask that blocks keys from patches containing holes.

    obs_mask_bool: bool tensor R^{BS x H x W x 1}, True = observed. The trailing channel
        axis is required by tf.nn.max_pool2d.
    Returns A: FLOAT tensor R^{BS x N x N}. A[b, q, k] is 0.0 when key patch k is fully
    observed and -inf when it contains at least one hole pixel. Every query row gets
    the same key mask.
    """
    BS = tf.shape(obs_mask_bool)[0]
    # TF has no native min-pool, so max-pool the HOLE indicator instead: a patch counts
    # as a hole if ANY of its pixels is missing. float32 is used because int8 max-pool may
    # lack a GPU kernel (forcing a CPU round-trip); the values are exactly 0/1 either way.
    hole_indicator = tf.cast(tf.logical_not(obs_mask_bool), dtype=tf.float32)
    patch_has_hole = tf.nn.max_pool2d(
        hole_indicator,
        ksize=[P, P],
        strides=[P, P],
        padding="VALID",
    )
    patch_observed = tf.reshape(tf.equal(patch_has_hole, 0.0), [BS, N])  # (BS, N)
    key_observed = tf.tile(tf.expand_dims(patch_observed, axis=1), [1, N, 1])  # (BS, N, N)
    return tf.where(
        key_observed,
        tf.constant(0.0, dtype=FLOAT),  # no penalty
        tf.constant(-float("inf"), dtype=FLOAT),  # blocked
    )


def apply_obsv_mask(image: tf.Tensor, obs_mask_float: tf.Tensor) -> tf.Tensor:
    """Zero out hole pixels: elementwise image * obs_mask_float (mask broadcasts)."""
    masked = image * obs_mask_float
    return masked


def reconstruct(
    original: tf.Tensor, reconstruct: tf.Tensor, obs_mask_float: tf.Tensor
) -> tf.Tensor:
    """Composite: observed pixels from `original`, hole pixels from `reconstruct`.

    Any nonzero mask value counts as observed.
    """
    observed = tf.cast(obs_mask_float, dtype=tf.bool)
    keep_original = tf.cast(observed, FLOAT)
    keep_prediction = tf.cast(tf.logical_not(observed), FLOAT)
    return original * keep_original + reconstruct * keep_prediction


# Quick look at the first training image and its dtype.
sample = next(iter(train_ds.unbatch().take(1)))
first_image = sample[0]
print(first_image.dtype)
viz_img(first_image)

In [None]:
# Shared keyword arguments for every Dense layer in the model.
commonDense = dict(dtype=FLOAT, kernel_initializer=initializer)


class PatchEmbedding(keras.layers.Layer):
    """Linear projection of flattened patches plus a learned positional embedding.

    R^{BS x N x (P^2 . C)} -> R^{BS x N x D_model}
    """

    def __init__(self):
        super().__init__(dtype=FLOAT)
        self.proj = keras.layers.Dense(D_model, **commonDense)  # (P² * C) -> D_model

    def build(self, input_shape):
        # One learned D_model vector per patch position, broadcast over the batch.
        self.positional_embedding = self.add_weight(
            name="pos_embed", shape=(N, D_model), initializer=initializer
        )

    def call(self, patches_flat: tf.Tensor):
        return self.proj(patches_flat) + self.positional_embedding


class MultiHeadAttention(keras.layers.Layer):
    """Hand-written multi-head scaled dot-product self-attention with an additive mask.

    All heads are computed with one fused Dense per Q/K/V of width h * D_head. The
    result is then split into h heads; this is equivalent to h separate per-head
    projections. W_O projects the concatenated heads back to D_model.
    """

    def __init__(self):
        super().__init__(dtype=FLOAT)
        self.W_Q = keras.layers.Dense(h * D_head, **commonDense)
        self.W_K = keras.layers.Dense(h * D_head, **commonDense)
        self.W_V = keras.layers.Dense(h * D_head, **commonDense)
        self.W_O = keras.layers.Dense(D_model, **commonDense)

    @staticmethod
    def _split_heads(t):
        """(BS, N, h * D_head) -> (BS, h, N, D_head)."""
        return tf.transpose(tf.reshape(t, (-1, N, h, D_head)), [0, 2, 1, 3])

    def call(self, X, A):
        """
        X: R^{BS x N x D_model}, A: R^{BS x N x N} additive mask -> R^{BS x N x D_model}
        """
        queries = self._split_heads(self.W_Q(X))
        keys = self._split_heads(self.W_K(X))
        values = self._split_heads(self.W_V(X))

        scores = tf.matmul(queries, keys, transpose_b=True)  # (BS, h, N, N)
        scores = scores / tf.math.sqrt(tf.cast(D_head, scores.dtype))
        scores = scores + tf.expand_dims(A, 1)  # same mask for every head

        weights = tf.nn.softmax(scores, axis=-1)
        context = tf.matmul(weights, values)  # (BS, h, N, D_head)
        context = tf.reshape(
            tf.transpose(context, [0, 2, 1, 3]), (-1, N, h * D_head)
        )
        return self.W_O(context)


class KMultiHeadAttention(keras.layers.Layer):
    """Self-attention wrapper around keras.layers.MultiHeadAttention.

    attn_mask is forwarded unchanged as `attention_mask`. Keras reads it as a
    boolean/0-1 mask where True/1 means "query may attend to this key".
    """

    def __init__(self):
        super().__init__(dtype=FLOAT)
        mha_kwargs = dict(
            num_heads=h,
            key_dim=D_head,
            output_shape=D_model,
            use_bias=True,
            dtype=FLOAT,
        )
        self.mha = keras.layers.MultiHeadAttention(**mha_kwargs)

    def call(self, X, attn_mask):
        "R^{BS x N x D_model}, R^{BS x N x N} -> R^{BS x N x D_model}"
        query = value = X  # self-attention
        return self.mha(query, value, attention_mask=attn_mask)


class TransformerBlock(keras.layers.Layer):
    """Pre-norm transformer block: X + Attn(LN(X)), then X + FFN(LN(X))."""

    def __init__(self):
        super().__init__(dtype=FLOAT)
        self.attn = KMultiHeadAttention()
        self.norm1 = keras.layers.LayerNormalization(dtype=FLOAT)
        self.norm2 = keras.layers.LayerNormalization(dtype=FLOAT)
        self.ffn = keras.Sequential(
            [
                keras.layers.Dense(D_fcn, activation="relu", **commonDense),
                keras.layers.Dense(D_model, **commonDense),
            ]
        )

    def call(self, X, A):
        """X: R^{BS x N x D_model} -> R^{BS x N x D_model}.

        A is either the additive float mask from create_attention_mask (0.0 = may
        attend, -inf = blocked) or a boolean mask (True = may attend).
        """
        # keras.layers.MultiHeadAttention wants True = "may attend". A plain
        # tf.cast(A, bool) maps 0.0 -> False and -inf -> True, which would attend ONLY
        # to the blocked (hole) patches. Convert by finiteness instead.
        if A.dtype == tf.bool:
            attend_mask = A
        else:
            attend_mask = tf.math.is_finite(A)

        X_attn = self.attn(self.norm1(X), attend_mask)
        X = X + X_attn
        X_ffn = self.ffn(self.norm2(X))
        X = X + X_ffn
        return X


class Refiner(keras.layers.Layer):
    """Residual convolutional refinement: X + conv3(conv2(conv1(X))).

    R^{BS x H x W x C} -> R^{BS x H x W x C}. conv2 is dilated (rate 2); all convs use
    "same" padding, so spatial size is preserved.
    """

    def __init__(self):
        super().__init__(dtype=FLOAT)
        self.conv1 = keras.layers.Conv2D(64, 3, padding="same", activation="relu")
        self.conv2 = keras.layers.Conv2D(
            128, 3, padding="same", activation="relu", dilation_rate=2
        )
        self.conv3 = keras.layers.Conv2D(3, 3, padding="same")

    def call(self, X):
        out = X
        for conv in (self.conv1, self.conv2, self.conv3):
            out = conv(out)
        return out + X


class TransformerDecoder(keras.layers.Layer):
    """Map token embeddings back to pixel patches.

    R^{BS x N x D_model} -> R^{BS x N x P x P x C}; the final sigmoid keeps values in (0, 1).
    """

    def __init__(self):
        super().__init__(dtype=FLOAT)
        self.proj1 = keras.layers.Dense(D_model, activation="gelu", **commonDense)
        self.proj2 = keras.layers.Dense(P * P * C, activation="sigmoid", **commonDense)

    def call(self, X):
        hidden = self.proj1(X)
        pixels = self.proj2(hidden)  # (BS, N, P*P*C)
        return tf.reshape(pixels, (tf.shape(X)[0], N, P, P, C))


class ImageInpaintingTransformer(keras.Model):
    """Patch transformer that predicts a full image from a masked one.

    call(image, obs_mask_float) -> R^{BS x H x W x C}. Hole pixels are zeroed before
    patching. Every transformer block receives the attention mask built by
    create_attention_mask. The decoder prediction is returned for all pixels; use
    reconstruct() to paste the observed pixels back in.
    """

    def __init__(self):
        super().__init__(dtype=FLOAT)
        self.embed = PatchEmbedding()
        self.transformer_blocks = [TransformerBlock() for _ in range(num_layers)]
        self.decoder = TransformerDecoder()
        # NOTE(review): constructed but not used by call(), so it never gets weights.
        self.refiner = Refiner()

    def build(self, input_shape):
        """Create all weights by running one forward pass on dummy inputs.

        input_shape: (BS, H, W, C); only the batch size is read (defaults to 1 if unknown).
        Dummy tensors are used instead of a real batch, so building no longer depends
        on the global val_ds (which raises StopIteration when empty). call() needs
        BOTH the image and the observation mask; the all-ones mask means "nothing missing".
        """
        BS = input_shape[0] or 1
        dummy_images = tf.zeros((BS, H, W, C), dtype=FLOAT)
        dummy_obs_mask = tf.ones((BS, H, W, 1), dtype=FLOAT)
        self.call(dummy_images, dummy_obs_mask)
        self.built = True

    def call(self, image, obs_mask_float):
        image = tf.multiply(image, obs_mask_float)  # zero out hole pixels
        obs_mask_bool = tf.cast(obs_mask_float, dtype=tf.bool)
        patches = extract_patches(image)  # (BS, N, P*P, C)
        AttnMask = create_attention_mask(obs_mask_bool)  # (BS, N, N)

        BS = tf.shape(patches)[0]
        patches_flat = tf.reshape(patches, [BS, N, P**2 * C])
        X = self.embed(patches_flat)
        for block in self.transformer_blocks:
            X = block(X, AttnMask)

        decoded = self.decoder(X)  # (BS, N, P, P, C)
        return patches_to_imgs(decoded)


model = ImageInpaintingTransformer()
model.build((BATCH_SIZE, H, W, C))
model.summary()

# Identifier written with every row this run inserts into the `runs`/`metrics` tables.
run_id = uuid.uuid4()
print(run_id)
# Progress counters advanced by the training cell; re-running this cell restarts them.
session_epochs, session_steps = 0, 0


def costfunc(
    y_true: tf.Tensor,
    y_pred: tf.Tensor,
    obs_mask_float: tf.Tensor,
    hole_weight: float = 2.0,
    valid_weight: float = 1.0,
) -> tf.Tensor:
    """Weighted masked squared error; returns a scalar tensor.

    hole_loss  = sum(err * hole) / sum(hole)
    valid_loss = sum(err * obs)  / sum(obs)
    loss = hole_weight * hole_loss + valid_weight * valid_loss

    hole_weight/valid_weight default to the previously hard-coded 2 and 1.
    NOTE: with the (BS, H, W, 1) masks from random_obs_mask, the mask broadcasts over
    channels. The numerators then sum over C while the denominators count pixels, so
    each term is C times the per-element mean squared error.
    """
    errors = tf.square(tf.subtract(y_true, y_pred))
    hole_mask = 1.0 - obs_mask_float  # type: ignore

    # +1e-8 guards against an all-observed (or all-hole) batch.
    hole_loss = tf.reduce_sum(errors * hole_mask) / (tf.reduce_sum(hole_mask) + 1e-8)
    valid_loss = tf.reduce_sum(errors * obs_mask_float) / (tf.reduce_sum(obs_mask_float) + 1e-8)

    return hole_weight * hole_loss + valid_weight * valid_loss


def train_step(img: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
    """Run one optimizer update on a batch and return the scalar training loss.

    Raises RuntimeError if any trainable variable receives no gradient.
    """
    with tf.GradientTape() as tape:
        pred = model(img, mask, training=True)
        loss = costfunc(img, pred, mask)
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    # tape.gradient returns a list here (never None itself). A None ENTRY means that
    # variable is disconnected from the loss, which is what the old check meant to catch.
    missing = [v.name for v, g in zip(variables, gradients) if g is None]
    if missing:
        raise RuntimeError(f"No gradient for trainable variables: {missing}")
    optimizer.apply_gradients(zip(gradients, variables))
    return loss


def val_step(image: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
    """Compute the scalar loss on one batch without updating any weights."""
    reconstructed_img = model(image, mask, training=False)
    return costfunc(image, reconstructed_img, mask)

import inspect
import textwrap

# Log the loss function's source with the run, so runs with different losses can be
# told apart in the database.
loss_source = textwrap.dedent(inspect.getsource(costfunc))
cur.execute(
    "INSERT INTO runs (run_id, started_at, hyperparams) VALUES (%s, %s, %s)",
    (
        str(run_id),
        datetime.now(tz=pytz.UTC),
        Json(
            {
                "num_samples": num_samples,
                "optimizer": optimizer.get_config(),
                "D_model": D_model,
                "D_head": D_head,
                "D_fcn": D_fcn,
                "H": H,
                "W": W,
                'P' : P,
                "heads": h,
                'Float': str(FLOAT),
                "num_layers": num_layers,
                "batch_size": BATCH_SIZE,
                "loss": loss_source,
                # Settings that also change results but were not logged before.
                "epochs": EPOCHS_TO_TRAIN,
                "mask_min_size": MASK_MIN_SIZE,
                "mask_max_size": MASK_MAX_SIZE,
                "seed": GLOBAL_RANDOM_SEED,
                "dataset_path": dataset_path,
            }
        ),
    ),
)

In [None]:
# name = "decoder_unet_1.keras"
# model.load_weights(name)
# model.save(name)

In [None]:
assert isinstance(session_epochs, int)
assert isinstance(session_steps, int)
print(
    f"Starting training for {EPOCHS_TO_TRAIN - session_epochs} "
    f"epochs (already did {session_epochs})"
)


def _log_metric(epoch: int, phase: str, value: float) -> None:
    """Insert one per-epoch row for the current run into `metrics` (metric column = phase)."""
    cur.execute(
        "INSERT INTO metrics (run_id, epoch, phase, value, logged_at, metric) VALUES (%s, %s, %s, %s, %s, %s)",
        (
            str(run_id),
            int(epoch),
            phase,
            float(value),
            datetime.now(tz=pytz.UTC),
            phase,
        ),
    )


def _mean(total, count):
    """Average of `count` per-batch losses; NaN instead of ZeroDivisionError for an empty split."""
    return total / count if count else float("nan")


# Resuming (re-running this cell with session_epochs > 0) must keep the best validation
# loss seen so far; otherwise the next epoch would overwrite best_run.keras even if it is
# worse. Reset only for a fresh run (session_epochs == 0, as set by the model cell).
if session_epochs == 0 or "best_val_loss" not in globals():
    best_val_loss = float("inf")
    best_epoch = -1

for _ in range(EPOCHS_TO_TRAIN - session_epochs):
    epoch = session_epochs + 1

    epoch_loss = 0.0
    steps = 0
    pbar = tqdm(
        train_ds,
        desc=f"Epoch {epoch}",
        unit="step",
        total=train_batches,
    )
    for image_batch, mask_batch in pbar:
        training_step_loss = train_step(image_batch, mask_batch).numpy()
        epoch_loss += training_step_loss
        steps += 1
        session_steps += 1
        pbar.set_postfix(loss=f"{float(training_step_loss):.4f}")
    train_loss = _mean(epoch_loss, steps)
    _log_metric(epoch, "train", train_loss)

    val_loss_total = 0.0
    val_steps = 0
    pbar_val = tqdm(
        val_ds,
        desc=f"Epoch {epoch} Validation",
        unit="step",
        total=val_batches,
    )
    for val_image_batch, val_mask_batch in pbar_val:
        val_step_loss = val_step(val_image_batch, val_mask_batch).numpy()
        val_loss_total += val_step_loss
        val_steps += 1
        pbar_val.set_postfix(loss=f"{float(val_step_loss):.4f}")
    val_loss = _mean(val_loss_total, val_steps)
    _log_metric(epoch, "val", val_loss)

    # NaN never compares less than best_val_loss, so an empty val split never saves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch
        model.save("best_run.keras")
    print(
        f"Epoch {epoch} Summary:\n Steps:{steps} | Train Loss = {train_loss:.4f} | Validation Loss = {val_loss:.4f}"
    )
    session_epochs += 1

In [None]:
# Original batch, the masked input the model sees, and the raw model prediction.
img, obs_mask_float = next(iter(train_ds.take(1)))
viz_grid(img)
viz_grid(apply_obsv_mask(img, obs_mask_float))
model_out = model(img, obs_mask_float)
# Reuse model_out instead of running the model a second time on the same batch.
reconstructed = reconstruct(img, model_out, obs_mask_float)
# viz_grid(reconstructed)
viz_grid(model_out)

# viz_img(model_out[0])
# viz_img(img[0])

In [None]:
# Qualitative Eval
# visualize_unbatched_dataset(test_ds, 5)


# img = tf.image.decode_image(
#     tf.io.read_file("/home/navid/Dev/PaperTex/impl/naruto")
#     , dtype=tf.float32)
# img = tf.image.resize_with_crop_or_pad(img, H, W)
# img = tf.expand_dims(img, 0)
# tf.print(tf.shape(img))
# obvmask = tf.expand_dims(random_obs_mask(H, W, MASK_MIN_SIZE, MASK_MAX_SIZE), 0)
# tf.print(tf.shape(obvmask))