# Shared setup: imports, paths, and image loader

In [None]:
import os
import cv2
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
)
from tensorflow.keras.optimizers import Adam


# Passed to cv2.resize as ``dsize``, i.e. (width, height). build_unet's default
# input_shape of (256, 256, 3) must stay in sync with this value.
IMAGE_SIZE = (256, 256)

# Data and model roots default to the original author's machine. Set the
# environment variables below to run the notebook elsewhere without editing code.
BASE_PREP = os.environ.get(
    "STYLE_UNET_BASE_PREP",
    "/Users/amayakof/Desktop/2025_autumn/deep_learning/SIS/3/project/data/preprocessed",
)
MODEL_DIR = os.environ.get(
    "STYLE_UNET_MODEL_DIR",
    "/Users/amayakof/Desktop/2025_autumn/deep_learning/SIS/3/project/models/upd",
)

# Trained models are saved here by train_style; create it up front.
os.makedirs(MODEL_DIR, exist_ok=True)


# ------------------ LOADER ------------------
def load_images_from_folder(folder, size=None):
    """Load all PNG/JPEG images in ``folder`` as float32 RGB arrays in [0, 1].

    Files are read non-recursively in (case-sensitive) sorted filename order.
    Unreadable files are reported with a "Skipping:" message and left out.
    Leaving one out shifts the position of every later image, so callers that
    pair two folders by index should compare the returned lengths.

    Args:
        folder: directory to scan.
        size: ``dsize`` for ``cv2.resize``, i.e. ``(width, height)``. This is
            the reverse of NumPy's ``(rows, cols)`` order. Defaults to the
            module-level ``IMAGE_SIZE``.

    Returns:
        float32 array of shape ``(N, height, width, 3)``. When no image is
        loaded, an empty array with the same trailing shape is returned. A bare
        ``np.array([])`` would instead be a shapeless float64 ``(0,)`` array.
    """
    dsize = tuple(IMAGE_SIZE if size is None else size)
    width, height = dsize

    names = sorted(
        f for f in os.listdir(folder)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    )

    imgs = []
    for fname in names:
        path = os.path.join(folder, fname)

        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:  # unreadable or not actually an image
            print("Skipping:", path)
            continue

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
        img = cv2.resize(img, dsize)
        imgs.append(img.astype("float32") / 255.0)

    if not imgs:
        return np.empty((0, height, width, 3), dtype=np.float32)
    return np.array(imgs)


# U-Net model

In [None]:


# ------------------ U-NET ARCHITECTURE ------------------
def build_unet(input_shape=(256, 256, 3)):
    """Build a three-level U-Net that maps an image to a 3-channel image.

    The encoder applies two 3x3 ReLU convolutions per level (32, 64, 128
    filters), each followed by 2x max-pooling. The bottleneck uses 256 filters.
    The decoder mirrors the encoder: each level upsamples 2x, concatenates the
    matching encoder feature map (skip connection), and applies two 3x3 ReLU
    convolutions. A final 3x3 sigmoid convolution yields 3 channels in (0, 1),
    the same range as the /255-scaled images from the loader.

    Because of the three pool/upsample pairs, the spatial dimensions should be
    divisible by 8 so the skip concatenations line up.
    """
    def _double_conv(tensor, filters):
        # Two same-padded 3x3 ReLU convolutions at a fixed filter count.
        for _ in range(2):
            tensor = Conv2D(filters, 3, activation="relu", padding="same")(tensor)
        return tensor

    inputs = Input(shape=input_shape)

    features = inputs
    skips = []
    for filters in (32, 64, 128):
        features = _double_conv(features, filters)
        skips.append(features)
        features = MaxPooling2D()(features)

    features = _double_conv(features, 256)

    # Walk back up, pairing each level with its encoder counterpart.
    for filters, skip in zip((128, 64, 32), reversed(skips)):
        features = UpSampling2D()(features)
        features = Concatenate()([features, skip])
        features = _double_conv(features, filters)

    outputs = Conv2D(3, 3, activation="sigmoid", padding="same")(features)

    return Model(inputs, outputs)

In [None]:


# ------------------ LOSS FUNCTIONS ------------------

def mae_basic(y_true, y_pred):
    """Mean absolute error averaged over every element (batch, pixels, channels)."""
    abs_err = tf.abs(y_true - y_pred)
    return tf.reduce_mean(abs_err)

def nightvis_weighted_mae(y_true, y_pred):
    """MAE weighted towards bright target pixels (fixes model black collapse).

    Each pixel's weight is ``1 + 4 * brightness``, where brightness is the
    channel mean of ``y_true``. For targets in [0, 1], weights therefore range
    from 1 (black) to 5 (white). The weights are broadcast across channels
    before averaging the weighted absolute error over all elements.
    """
    brightness = tf.reduce_mean(y_true, axis=-1, keepdims=True)
    pixel_weight = 1 + 4 * brightness
    weighted_err = pixel_weight * tf.abs(y_true - y_pred)
    return tf.reduce_mean(weighted_err)



# ------------------ SHARED TRAIN FUNCTION ------------------

def _unpaired_filenames(input_folder, style_folder):
    """Return, sorted, the image filenames present in only one of the two folders."""
    exts = (".png", ".jpg", ".jpeg")

    def image_names(folder):
        return {f for f in os.listdir(folder) if f.lower().endswith(exts)}

    return sorted(image_names(input_folder) ^ image_names(style_folder))


def train_style(style_name, loss_fn, lr=1e-4, epochs=200, batch_size=2,
                val_split=0.2, seed=42):
    """Train a U-Net mapping the shared input images to one style, then save it.

    Inputs are loaded from ``BASE_PREP/input`` and targets from
    ``BASE_PREP/style/<style_name>``. The two folders are loaded separately and
    paired by position in sorted filename order, so they are expected to hold
    identically named images. A warning is printed when the names differ.

    Args:
        style_name: subfolder under ``BASE_PREP/style``. Also used in the saved
            model's filename.
        loss_fn: Keras-compatible loss ``(y_true, y_pred) -> tensor``.
        lr: Adam learning rate.
        epochs: number of training epochs.
        batch_size: samples per gradient step.
        val_split: fraction of pairs held out for validation (``test_size`` of
            ``train_test_split``).
        seed: ``random_state`` for the train/validation split.

    Returns:
        The ``History`` object returned by ``model.fit``.

    Raises:
        ValueError: if the two folders yield different numbers of images, or
            none at all.
    """
    print(f"\n=== Training style: {style_name} ===")

    input_folder = f"{BASE_PREP}/input"
    style_folder = f"{BASE_PREP}/style/{style_name}"

    # Pairing is purely positional, so differing filenames silently produce
    # mismatched (input, target) pairs. Surface that before a long training run.
    unpaired = _unpaired_filenames(input_folder, style_folder)
    if unpaired:
        print(
            f"Warning: {len(unpaired)} image name(s) appear in only one of "
            f"{input_folder} / {style_folder}, e.g. {unpaired[:3]}; "
            "pairs are matched by sorted position only."
        )

    X = load_images_from_folder(input_folder)
    Y = load_images_from_folder(style_folder)

    if len(X) != len(Y):
        raise ValueError(
            f"Image count mismatch for style {style_name}: "
            f"{len(X)} inputs vs {len(Y)} targets"
        )
    if len(X) == 0:
        raise ValueError(
            f"No images found for style {style_name} in {input_folder}"
        )

    X_train, X_val, y_train, y_val = train_test_split(
        X, Y, test_size=val_split, random_state=seed
    )

    model = build_unet()
    model.compile(
        optimizer=Adam(lr),
        loss=loss_fn,
        metrics=["mae"],
    )

    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=epochs,
        batch_size=batch_size,
        verbose=1,
    )

    out_path = f"{MODEL_DIR}/autoencoder_{style_name}_upd.keras"
    model.save(out_path)
    print(f"Saved model: {out_path}")

    return history

In [None]:
# ---- Train BLUR ----
# Keep the returned History so the loss curves can be inspected after the run.
history_blur = train_style(
    style_name="blur",
    loss_fn=mae_basic,
    lr=1e-4,
    epochs=150,
    batch_size=2
)


In [None]:
# ---- Train POSTER ----
# Keep the returned History so the loss curves can be inspected after the run.
history_poster = train_style(
    style_name="poster",
    loss_fn=mae_basic,
    lr=5e-5,
    epochs=200,
    batch_size=2
)


In [None]:
# ---- Train OUTLINE ----
# Keep the returned History so the loss curves can be inspected after the run.
history_outline = train_style(
    style_name="outline",
    loss_fn=mae_basic,
    lr=1e-4,
    epochs=200,
    batch_size=2
)


In [None]:
# ---- Train NIGHT VISION (hard style!) ----
# Uses the brightness-weighted loss; keep the returned History so the loss
# curves of this longest run can be inspected afterwards.
history_night_vis = train_style(
    style_name="night_vis",
    loss_fn=nightvis_weighted_mae,
    lr=1e-4,
    epochs=350,
    batch_size=2
)
