In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [14]:
# ===================================================
# Building Blocks
# ===================================================
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by BatchNorm and ReLU.

    With ``padding=1`` the spatial size of the input is preserved; only the
    channel count changes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for source_ch in (in_ch, out_ch):
            stages.append(nn.Conv2d(source_ch, out_ch, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages."""
        return self.block(x)

# ===================================================
# Model Variants
# ===================================================

class PixelMLP(nn.Module):
    """Fully connected autoencoder for flat ``(B, input_dim)`` vectors.

    The encoder narrows ``input_dim -> 1024 -> 256 -> latent_dim`` and the
    decoder mirrors those widths back up to ``input_dim``. A ReLU follows
    every linear layer except the last one of each half.
    """

    def __init__(self, input_dim=6144, latent_dim=128):
        super().__init__()
        widths = [input_dim, 1024, 256, latent_dim]
        self.encoder = self._mlp(widths)
        self.decoder = self._mlp(widths[::-1])

    @staticmethod
    def _mlp(widths):
        """Chain Linear layers over consecutive widths, with ReLU in between."""
        layers = []
        final_idx = len(widths) - 2
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if idx != final_idx:
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for a batch of flat vectors."""
        latent = self.encoder(x)
        return self.decoder(latent), latent


class PatchCNN(nn.Module):
    """Patch autoencoder that squeezes a whole patch into one latent vector.

    The encoder applies two 3x3 convolutions, averages over all spatial
    positions, and projects to ``latent_dim``. The decoder expands the latent
    to a 256-channel vector, repeats it at every pixel of the input grid, and
    applies a final 3x3 convolution back to ``in_ch`` channels.
    """

    def __init__(self, in_ch=6144, latent_dim=128):
        super().__init__()
        encoder_layers = [
            nn.Conv2d(in_ch, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, latent_dim),
        ]
        decoder_layers = [
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)
        self.out_conv = nn.Conv2d(256, in_ch, 3, padding=1)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for ``x`` of shape (B, C, H, W)."""
        latent = self.encoder(x)
        batch = x.shape[0]
        height, width = x.shape[2], x.shape[3]
        # One 256-d vector per sample, broadcast (without copying) to H x W.
        seed = self.decoder(latent).view(batch, 256, 1, 1)
        tiled = seed.expand(batch, 256, height, width)
        return self.out_conv(tiled), latent

# ------------------------------
# UNet AutoEncoder with dynamic latent layers
# ------------------------------
class UNetAutoEncoder(nn.Module):
    """Convolutional autoencoder with a UNet-style encoder and dense bottleneck.

    The encoder is one ``DoubleConv`` followed by three max-pooled
    ``DoubleConv`` stages (64 -> 128 -> 256 -> 512 channels), so the spatial
    size is halved three times. The deepest feature map is flattened,
    projected to ``latent_dim`` values, projected back, and upsampled by three
    stride-2 transposed convolutions. No skip connections are used.

    The two bottleneck ``nn.Linear`` layers depend on the spatial size of the
    deepest feature map and are therefore created on the first ``forward``
    call. As a consequence:

    * every later input must have the same spatial size as the first one;
    * an optimizer constructed before the first forward pass does not contain
      the bottleneck parameters, so run one forward pass first.

    The decoder does not depend on the input size and is built in
    ``__init__`` so that ``.to(device)`` and optimizers see it immediately.

    Because pooling floors odd sizes, the reconstruction has spatial size
    ``8 * (H // 8)`` by ``8 * (W // 8)``; it matches the input only when both
    are multiples of 8.

    Args:
        in_ch: Number of input (and reconstructed) channels.
        latent_dim: Width of the dense latent vector.
        latent_patch: Selects the experimental branch in ``forward``.
    """

    def __init__(self, in_ch=6144, latent_dim=128, latent_patch=False):
        super().__init__()
        self.in_ch = in_ch
        self.latent_patch = latent_patch
        self.latent_dim = latent_dim

        # Encoder
        self.inc = DoubleConv(in_ch, 64)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(64, 128))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(128, 256))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(256, 512))

        # Decoder (shape-independent, so registered eagerly)
        self.up1 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv1 = DoubleConv(256, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = DoubleConv(128, 128)
        self.up3 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv3 = DoubleConv(64, 64)
        self.outc = nn.Conv2d(64, self.in_ch, kernel_size=1)

        # Latent layers (created on the first forward pass)
        self.to_latent = None
        self.from_latent = None

    def _build_latent_layers(self, x4_shape, device=None):
        """Create the flatten/linear bottleneck for a given deepest-map shape.

        Args:
            x4_shape: ``(B, C, H, W)`` shape of the deepest encoder output.
            device: Optional device to place the new layers on. Only the
                device is matched, not the dtype, so that parameters stay in
                their default precision under autocast.
        """
        _, C, Hs, Ws = x4_shape
        flat_dim = C * Hs * Ws

        self.to_latent = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, self.latent_dim),
            nn.ReLU(inplace=True),
        )
        self.from_latent = nn.Sequential(
            nn.Linear(self.latent_dim, flat_dim),
            nn.ReLU(inplace=True),
        )
        if device is not None:
            # Without this the new layers stay on the CPU even when the rest
            # of the model was already moved to an accelerator.
            self.to_latent.to(device)
            self.from_latent.to(device)

    def forward(self, x):
        """Encode ``x`` of shape (B, in_ch, H, W) and reconstruct it.

        Returns:
            ``(out, z)`` where ``z`` is the ``(B, latent_dim)`` latent and
            ``out`` is the reconstruction (see the class docstring for its
            spatial size).
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        if self.to_latent is None:
            self._build_latent_layers(x4.shape, device=x4.device)

        z = self.to_latent(x4)

        if self.latent_patch:
            # NOTE(review): this branch looks unfinished. ``z`` is
            # (B, latent_dim), so the view to (B, latent_dim, 25, 25) cannot
            # succeed for a non-empty batch, and the "reconstruction" is
            # random noise. Left unchanged until the intended design is known.
            B = x.shape[0]
            z_reshaped = z.view(B, self.latent_dim, 25, 25)
            out = torch.randn_like(x)
            return out, z_reshaped

        B, _, Hs, Ws = x4.shape
        x_dec = self.from_latent(z).view(B, 512, Hs, Ws)
        x_dec = self.conv1(self.up1(x_dec))
        x_dec = self.conv2(self.up2(x_dec))
        x_dec = self.conv3(self.up3(x_dec))
        out = self.outc(x_dec)
        return out, z

# ===================================================
# Unified Wrapper
# ===================================================
class UnifiedReflectanceAutoEncoder(nn.Module):
    """Thin wrapper that selects one autoencoder variant by name.

    Args:
        mode: One of ``"pixel"``, ``"patch"``, ``"unet_patch"``,
            ``"unet_full"``.
        **kwargs: Forwarded unchanged to the selected model's constructor.

    Raises:
        ValueError: If ``mode`` is not one of the names above.
    """

    def __init__(self, mode="pixel", **kwargs):
        super().__init__()
        if mode in ("unet_patch", "unet_full"):
            # NOTE(review): both UNet modes currently build the same model
            # (latent_patch=False). Presumably "unet_patch" was meant to
            # differ; confirm before changing.
            self.model = UNetAutoEncoder(latent_patch=False, **kwargs)
        elif mode == "patch":
            self.model = PatchCNN(**kwargs)
        elif mode == "pixel":
            self.model = PixelMLP(**kwargs)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def forward(self, x):
        """Delegate to the wrapped model and return its output unchanged."""
        return self.model(x)


# ===================================================
# Demo
# ===================================================
# Smoke test: build each enabled variant, push one random batch through it and
# print the input / latent / output shapes. In a notebook the names assigned
# here (cfgs, model, x, y, z, ...) remain available to later cells.
if __name__ == "__main__":
    # Maps mode name -> (random input batch, extra constructor kwargs).
    # Disabled entries are kept for quick re-enabling.
    cfgs = {
        # "pixel": (torch.rand(2, 6144), {}),
        # "patch": (torch.rand(2, 6144, 16, 16), {}),
        # "unet_patch": (torch.rand(2, 6144, 100, 100), {}),
        # 100 is not a multiple of 8, so with the UNet of this cell the
        # reconstruction comes back as 96x96 (see the printed output below).
        "unet_full": (torch.rand(2, 6144, 100, 100), {}),
    }

    for mode, (x, extra) in cfgs.items():
        print(f"\n--- {mode.upper()} ---")
        model = UnifiedReflectanceAutoEncoder(mode=mode, latent_dim=128, **extra)
        # Every variant returns a (reconstruction, latent) pair.
        y, z = model(x)
        print(f"Input: {tuple(x.shape)},  Latent: {tuple(z.shape)},  Output: {tuple(y.shape)}")



--- UNET_FULL ---
Input: (2, 6144, 100, 100),  Latent: (2, 128),  Output: (2, 6144, 96, 96)


In [17]:
# ===================================================
# CONFIGURATION (all parameters centralized here)
# ===================================================
# One sub-dict per model family; each model class reads only its own section.
CONFIG = dict(
    # PixelMLP: flat per-pixel vector -> hidden widths -> latent.
    pixel=dict(
        input_dim=6144,
        latent_dim=128,
        hidden_layers=[1024, 256],
    ),
    # PatchCNN: two conv layers whose widths come from conv_channels.
    patch=dict(
        in_ch=6144,
        latent_dim=128,
        conv_channels=[256, 128],
        kernel_size=3,
        padding=1,
    ),
    # UNetAutoEncoder: four encoder stages, one width per stage.
    unet=dict(
        in_ch=6144,
        latent_dim=128,
        encoder_channels=[64, 128, 256, 512],
        kernel_size=3,
        padding=1,
    ),
)


# ===================================================
# Building Blocks
# ===================================================
class DoubleConv(nn.Module):
    """Two conv -> BatchNorm -> ReLU stages sharing one kernel size/padding.

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Channels produced by both convolutions.
        kernel_size: Convolution kernel size (default 3).
        padding: Zero padding applied by each convolution (default 1).
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1):
        super().__init__()

        def stage(source_ch):
            # One conv/BN/ReLU triple mapping source_ch -> out_ch.
            return [
                nn.Conv2d(source_ch, out_ch, kernel_size, padding=padding),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        self.block = nn.Sequential(*stage(in_ch), *stage(out_ch))

    def forward(self, x):
        """Apply both stages to ``x``."""
        out = self.block(x)
        return out


# ===================================================
# Model Variants
# ===================================================
class PixelMLP(nn.Module):
    """Config-driven fully connected autoencoder for flat vectors.

    The encoder maps ``input_dim`` through every width in ``hidden_layers``
    down to ``latent_dim``; the decoder uses the same widths in reverse. A
    ReLU follows every linear layer except the last one of each half.

    Args:
        cfg: Mapping with keys ``"input_dim"`` (int), ``"latent_dim"`` (int)
            and ``"hidden_layers"`` (sequence of ints, may be empty, in which
            case each half is a single linear layer).
    """

    def __init__(self, cfg):
        super().__init__()
        input_dim = cfg["input_dim"]
        latent_dim = cfg["latent_dim"]
        hidden = list(cfg["hidden_layers"])

        # Both halves share one width list; the decoder simply walks it
        # backwards. This also covers an empty hidden list, which previously
        # failed on ``hidden[-1]``.
        widths = [input_dim, *hidden, latent_dim]
        self.encoder = self._mlp(widths)
        self.decoder = self._mlp(widths[::-1])

    @staticmethod
    def _mlp(widths):
        """Chain Linear layers over consecutive widths, with ReLU in between."""
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # No activation after the final projection of either half.
        return nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for ``x`` of shape (B, input_dim)."""
        z = self.encoder(x)
        y = self.decoder(z)
        return y, z


class PatchCNN(nn.Module):
    """Config-driven patch autoencoder with a single global latent vector.

    Args:
        cfg: Mapping with keys ``"in_ch"``, ``"latent_dim"``,
            ``"conv_channels"`` (two widths are used), ``"kernel_size"`` and
            ``"padding"``.
    """

    def __init__(self, cfg):
        super().__init__()
        in_ch = cfg["in_ch"]
        latent_dim = cfg["latent_dim"]
        widths = cfg["conv_channels"]
        kernel = cfg["kernel_size"]
        pad = cfg["padding"]
        wide, narrow = widths[0], widths[1]

        # Two convolutions, global average pool, then a linear projection.
        encoder_layers = [
            nn.Conv2d(in_ch, wide, kernel, padding=pad),
            nn.ReLU(),
            nn.Conv2d(wide, narrow, kernel, padding=pad),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(narrow, latent_dim),
        ]
        # Mirror of the encoder widths, ending with one vector per sample.
        decoder_layers = [
            nn.Linear(latent_dim, narrow),
            nn.ReLU(),
            nn.Linear(narrow, wide),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)
        self.out_conv = nn.Conv2d(wide, in_ch, kernel, padding=pad)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for ``x`` of shape (B, C, H, W)."""
        latent = self.encoder(x)
        batch = x.shape[0]
        seed = self.decoder(latent).view(batch, -1, 1, 1)
        # Broadcast the per-sample vector over the input's spatial grid.
        tiled = seed.expand(batch, seed.shape[1], x.shape[2], x.shape[3])
        return self.out_conv(tiled), latent


class UNetAutoEncoder(nn.Module):
    """Config-driven UNet-style autoencoder with a dense bottleneck.

    The encoder has four ``DoubleConv`` stages whose widths come from
    ``cfg["encoder_channels"]``, with 2x max pooling before stages 2-4. The
    deepest feature map is flattened, projected to ``latent_dim`` values and
    back, then upsampled by three stride-2 transposed convolutions whose
    widths mirror the encoder. No skip connections are used.

    The decoder widths are derived from ``encoder_channels`` (rather than
    being fixed at 512/256/128/64) so non-default configurations work. The
    decoder ``DoubleConv`` blocks keep their default 3x3 / padding-1 kernels.

    The two bottleneck ``nn.Linear`` layers depend on the spatial size of the
    deepest feature map and are created on the first ``forward`` call, on the
    same device as the input. Every later input must have the same spatial
    size, and an optimizer constructed before the first forward pass will not
    contain those parameters.

    With size-preserving convolutions the reconstruction has spatial size
    ``8 * (H // 8)`` by ``8 * (W // 8)``; it matches the input only when both
    are multiples of 8.

    Args:
        cfg: Mapping with keys ``"in_ch"``, ``"latent_dim"``,
            ``"encoder_channels"`` (four widths), ``"kernel_size"`` and
            ``"padding"``.
        latent_patch: Selects the experimental branch in ``forward``.
    """

    def __init__(self, cfg, latent_patch=False):
        super().__init__()
        self.in_ch = cfg["in_ch"]
        self.latent_dim = cfg["latent_dim"]
        self.latent_patch = latent_patch
        enc_ch = cfg["encoder_channels"]
        k = cfg["kernel_size"]
        p = cfg["padding"]

        # Encoder
        self.inc = DoubleConv(self.in_ch, enc_ch[0], k, p)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(enc_ch[0], enc_ch[1], k, p))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(enc_ch[1], enc_ch[2], k, p))
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(enc_ch[2], enc_ch[3], k, p))

        # Decoder: independent of the input size, so it is registered eagerly
        # and mirrors the encoder widths.
        self.up1 = nn.ConvTranspose2d(enc_ch[3], enc_ch[2], 2, stride=2)
        self.conv1 = DoubleConv(enc_ch[2], enc_ch[2])
        self.up2 = nn.ConvTranspose2d(enc_ch[2], enc_ch[1], 2, stride=2)
        self.conv2 = DoubleConv(enc_ch[1], enc_ch[1])
        self.up3 = nn.ConvTranspose2d(enc_ch[1], enc_ch[0], 2, stride=2)
        self.conv3 = DoubleConv(enc_ch[0], enc_ch[0])
        self.outc = nn.Conv2d(enc_ch[0], self.in_ch, 1)

        # Latent layers built on the first forward pass
        self.to_latent = None
        self.from_latent = None

    def _build_latent_layers(self, x4_shape, device=None):
        """Create the flatten/linear bottleneck for a given deepest-map shape.

        Args:
            x4_shape: ``(B, C, H, W)`` shape of the deepest encoder output.
            device: Optional device for the new layers. Only the device is
                matched, not the dtype, so parameters keep their default
                precision under autocast.
        """
        _, C, H, W = x4_shape
        flat_dim = C * H * W
        self.to_latent = nn.Sequential(nn.Flatten(), nn.Linear(flat_dim, self.latent_dim), nn.ReLU(True))
        self.from_latent = nn.Sequential(nn.Linear(self.latent_dim, flat_dim), nn.ReLU(True))
        if device is not None:
            # Layers created here would otherwise stay on the CPU even if the
            # model was already moved to an accelerator.
            self.to_latent.to(device)
            self.from_latent.to(device)

    def forward(self, x):
        """Encode ``x`` of shape (B, in_ch, H, W) and reconstruct it.

        Returns:
            ``(out, z)`` where ``z`` is the ``(B, latent_dim)`` latent and
            ``out`` is the reconstruction.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)

        if self.to_latent is None:
            self._build_latent_layers(x4.shape, device=x4.device)

        z = self.to_latent(x4)

        if self.latent_patch:
            # NOTE(review): this branch looks unfinished. ``z`` is
            # (B, latent_dim), so the view to (B, latent_dim, 25, 25) cannot
            # succeed for a non-empty batch, and the "reconstruction" is
            # random noise. Left unchanged until the intended design is known.
            return torch.randn_like(x), z.view(x.shape[0], self.latent_dim, 25, 25)

        # Use the actual channel count of the deepest map instead of a fixed 512.
        B, C, H, W = x4.shape
        x_dec = self.from_latent(z).view(B, C, H, W)
        x_dec = self.conv1(self.up1(x_dec))
        x_dec = self.conv2(self.up2(x_dec))
        x_dec = self.conv3(self.up3(x_dec))
        return self.outc(x_dec), z


# ===================================================
# Unified Wrapper
# ===================================================
class UnifiedReflectanceAutoEncoder(nn.Module):
    """Select and wrap one autoencoder variant by name.

    Args:
        mode: ``"pixel"``, ``"patch"``, ``"unet_patch"`` or ``"unet_full"``.
        cfg: Mapping with one section per model family (``"pixel"``,
            ``"patch"``, ``"unet"``). Falls back to the module-level
            ``CONFIG`` when ``None``.

    Raises:
        ValueError: If ``mode`` is not one of the names above.
    """

    def __init__(self, mode="pixel", cfg=None):
        super().__init__()
        settings = CONFIG if cfg is None else cfg
        if mode in ("unet_patch", "unet_full"):
            # Both UNet modes share the "unet" section and differ only in
            # the latent_patch flag.
            patch_latent = mode == "unet_patch"
            self.model = UNetAutoEncoder(settings["unet"], latent_patch=patch_latent)
        elif mode == "patch":
            self.model = PatchCNN(settings["patch"])
        elif mode == "pixel":
            self.model = PixelMLP(settings["pixel"])
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def forward(self, x):
        """Delegate to the wrapped model and return its output unchanged."""
        result = self.model(x)
        return result


# ===================================================
# Demo
# ===================================================
# Smoke test for the config-driven models: build each enabled variant, run one
# random batch through it and print the input / latent / output shapes. In a
# notebook the names assigned here remain available to later cells.
if __name__ == "__main__":
    cfg = CONFIG

    # Maps mode name -> random input batch; disabled entries are kept for
    # quick re-enabling.
    test_inputs = {
        # "pixel": torch.rand(2, cfg["pixel"]["input_dim"]),
        # "patch": torch.rand(2, cfg["patch"]["in_ch"], 16, 16),
        # 104 is a multiple of 8, so the reconstruction matches the input
        # size (see the printed output below).
        "unet_full": torch.rand(2, cfg["unet"]["in_ch"], 104, 104)
    }

    for mode, x in test_inputs.items():
        print(f"\n--- {mode.upper()} ---")
        model = UnifiedReflectanceAutoEncoder(mode, cfg)
        # Every variant returns a (reconstruction, latent) pair.
        y, z = model(x)
        print(f"Input: {tuple(x.shape)},  Latent: {tuple(z.shape)},  Output: {tuple(y.shape)}")



--- UNET_FULL ---
Input: (2, 6144, 104, 104),  Latent: (2, 128),  Output: (2, 6144, 104, 104)


In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# --------------------------------------------------
# CONFIG
# --------------------------------------------------
# Single nested settings object: "model" is handed to the network classes,
# "training" and "data" are read by the training / data helpers.
config = dict(
    model=dict(
        in_channels=96,
        latent_dim=128,
        encoder_channels=[64, 128, 256, 512],
        decoder_channels=[512, 256, 128, 64],
        kernel_size=3,
        padding=1,
        # True: the latent is the encoder's spatial feature map.
        # False: the latent is a single global vector per sample.
        spatial_latent=True,
    ),
    training=dict(
        epochs=3,
        lr=1e-3,
        batch_size=1,
        # Prefer the GPU when one is visible to PyTorch.
        device="cuda" if torch.cuda.is_available() else "cpu",
    ),
    # Shape of the random tensor used for smoke-test training (N samples of C x H x W).
    data=dict(N=4, C=96, H=112, W=112),
)


# --------------------------------------------------
# Double Conv
# --------------------------------------------------
class DoubleConv(nn.Module):
    """Conv -> BatchNorm -> ReLU applied twice.

    The first convolution maps ``in_ch`` to ``out_ch``; the second keeps
    ``out_ch``. Both use the same ``kernel_size`` and ``padding``.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1):
        super().__init__()
        layers = []
        for source_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(source_ch, out_ch, kernel_size, padding=padding),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv/BN/ReLU stages to ``x``."""
        return self.block(x)


# --------------------------------------------------
# Encoder
# --------------------------------------------------
class UNetEncoder(nn.Module):
    """Four ``DoubleConv`` stages, each followed by 2x max pooling.

    With size-preserving convolutions the spatial size shrinks by a factor of
    16 overall. Depending on ``cfg["spatial_latent"]`` the latent is either
    the pooled feature map itself or a global vector obtained by average
    pooling and a linear projection.

    Args:
        cfg: Mapping with keys ``"encoder_channels"`` (four widths),
            ``"in_channels"``, ``"latent_dim"``, ``"kernel_size"``,
            ``"padding"`` and ``"spatial_latent"``.
    """

    def __init__(self, cfg):
        super().__init__()
        widths = cfg["encoder_channels"]
        in_ch = cfg["in_channels"]
        latent_dim = cfg["latent_dim"]
        kernel = cfg["kernel_size"]
        pad = cfg["padding"]
        self.spatial_latent = cfg["spatial_latent"]

        self.enc1 = DoubleConv(in_ch, widths[0], kernel, pad)
        self.enc2 = DoubleConv(widths[0], widths[1], kernel, pad)
        self.enc3 = DoubleConv(widths[1], widths[2], kernel, pad)
        self.enc4 = DoubleConv(widths[2], widths[3], kernel, pad)

        # The projection is only needed for the global (non-spatial) latent.
        if not self.spatial_latent:
            self.fc_mu = nn.Linear(widths[3], latent_dim)

    def forward(self, x):
        """Encode ``x`` of shape (B, C, H, W).

        Returns:
            ``(latent, (H_out, W_out))`` where the second item is the spatial
            size of the deepest feature map. ``latent`` is that feature map
            when ``spatial_latent`` is true, otherwise a ``(B, latent_dim)``
            tensor.
        """
        batch, _, _, _ = x.shape
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4):
            x = F.max_pool2d(stage(x), 2)

        _, _, out_h, out_w = x.shape
        latent_hw = (out_h, out_w)
        if self.spatial_latent:
            return x, latent_hw

        pooled = F.adaptive_avg_pool2d(x, 1).view(batch, -1)
        return self.fc_mu(pooled), latent_hw


# --------------------------------------------------
# Decoder
# --------------------------------------------------
class UNetDecoder(nn.Module):
    """Upsampling decoder made of four stride-2 transposed convolutions.

    Each of the first three upsampling steps is followed by a ``DoubleConv``;
    the fourth maps directly to the output channels, giving a 16x increase in
    spatial size overall.

    For a global (non-spatial) latent, a linear layer ``fc`` first expands the
    latent vector to a ``(first_ch, H, W)`` feature map. Its output size
    depends on the latent spatial size, which can be supplied up front via the
    optional ``cfg["latent_hw"] = (H, W)``. When that key is absent the layer
    is created on the first ``forward`` call instead. In that case any
    optimizer constructed before the first forward pass does not contain the
    layer's parameters, so they would never be updated; a warning is emitted
    to make this visible.

    Args:
        cfg: Mapping with keys ``"decoder_channels"`` (four widths),
            ``"in_channels"``, ``"latent_dim"``, ``"spatial_latent"`` and,
            optionally, ``"latent_hw"``.
    """

    def __init__(self, cfg):
        super().__init__()
        c = cfg["decoder_channels"]
        out_ch = cfg["in_channels"]
        latent_dim = cfg["latent_dim"]
        self.spatial_latent = cfg["spatial_latent"]

        # Built eagerly when the latent spatial size is known, otherwise on
        # the first forward pass.
        self.fc = None
        latent_hw = cfg.get("latent_hw")
        if latent_hw is not None and not self.spatial_latent:
            height, width = latent_hw
            self.fc = nn.Linear(latent_dim, c[0] * height * width)

        self.up4 = nn.ConvTranspose2d(c[0], c[1], 2, stride=2)
        self.dec4 = DoubleConv(c[1], c[1])
        self.up3 = nn.ConvTranspose2d(c[1], c[2], 2, stride=2)
        self.dec3 = DoubleConv(c[2], c[2])
        self.up2 = nn.ConvTranspose2d(c[2], c[3], 2, stride=2)
        self.dec2 = DoubleConv(c[3], c[3])
        self.up1 = nn.ConvTranspose2d(c[3], out_ch, 2, stride=2)

    def forward(self, z, latent_hw, first_ch):
        """Decode a latent into a reconstruction.

        Args:
            z: Spatial feature map (``spatial_latent`` true) or a
                ``(B, latent_dim)`` tensor (``spatial_latent`` false).
            latent_hw: ``(H, W)`` of the deepest encoder feature map.
            first_ch: Channel count expected by the first upsampling layer.

        Returns:
            The reconstructed tensor.
        """
        if self.spatial_latent:
            x = z
        else:
            if self.fc is None:
                import warnings

                warnings.warn(
                    "UNetDecoder.fc is being created during forward(); optimizers "
                    "built before this call do not include its parameters. "
                    "Set cfg['latent_hw'] or run one forward pass before "
                    "creating the optimizer.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                self.fc = nn.Linear(z.size(1), first_ch * latent_hw[0] * latent_hw[1]).to(z.device)
            B = z.size(0)
            x = self.fc(z).view(B, first_ch, *latent_hw)

        x = self.dec4(self.up4(x))
        x = self.dec3(self.up3(x))
        x = self.dec2(self.up2(x))
        out = self.up1(x)
        return out


# --------------------------------------------------
# Combined
# --------------------------------------------------
class ReflectanceAutoEncoder(nn.Module):
    """Encoder/decoder pair built from one shared model config.

    Args:
        cfg: Model settings passed unchanged to ``UNetEncoder`` and
            ``UNetDecoder``; ``cfg["decoder_channels"][0]`` is also read on
            every forward pass.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.encoder = UNetEncoder(cfg)
        self.decoder = UNetDecoder(cfg)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for the input batch ``x``."""
        latent, latent_hw = self.encoder(x)
        # The decoder needs the width of its first layer to reshape a global latent.
        entry_ch = self.cfg["decoder_channels"][0]
        recon = self.decoder(latent, latent_hw, entry_ch)
        return recon, latent


# --------------------------------------------------
# Data + Training
# --------------------------------------------------
def get_random_dataloader(cfg):
    """Build a shuffled loader over uniform random tensors for smoke tests.

    Args:
        cfg: Mapping with ``cfg["data"]`` holding ``N``, ``C``, ``H``, ``W``
            and ``cfg["training"]["batch_size"]``.

    Returns:
        A ``DataLoader`` whose batches are ``(input, target)`` pairs drawn
        from the same tensor, i.e. the target equals the input.
    """
    sizes = cfg["data"]
    shape = (sizes["N"], sizes["C"], sizes["H"], sizes["W"])
    batch_size = cfg["training"]["batch_size"]
    samples = torch.rand(*shape)
    pairs = TensorDataset(samples, samples)
    return DataLoader(pairs, batch_size=batch_size, shuffle=True)


def train_autoencoder(model, dataloader, cfg):
    """Train ``model`` to reconstruct its input and print a shape summary.

    The model is moved to ``cfg["training"]["device"]`` and optimized with
    Adam on an MSE reconstruction loss. After training, one batch is run
    through the model in evaluation mode without gradient tracking and the
    input / latent / reconstruction shapes are printed. The model's
    train/eval mode is restored afterwards.

    The optimizer is created before the first forward pass, so parameters a
    model creates lazily inside ``forward`` are not updated by it.

    Args:
        model: Module whose ``forward`` returns ``(reconstruction, latent)``.
        dataloader: Iterable of ``(input, target)`` batches.
        cfg: Mapping with ``cfg["training"]`` holding ``device``, ``lr`` and
            ``epochs``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    tcfg = cfg["training"]
    device = tcfg["device"]
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=tcfg["lr"])
    loss_fn = nn.MSELoss()

    for epoch in range(tcfg["epochs"]):
        total_loss = 0.0
        num_batches = 0
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            recon, z = model(x)
            loss = loss_fn(recon, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        if num_batches == 0:
            # Previously this surfaced as a ZeroDivisionError in the average.
            raise ValueError("dataloader yielded no batches; nothing to train on")
        print(f"Epoch {epoch+1}/{tcfg['epochs']} | Avg Loss: {total_loss/num_batches:.6f}")

    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        raise ValueError("dataloader yielded no batches; nothing to evaluate")
    x, _ = first_batch
    x = x.to(device)

    # The shape check must not behave like a training step: eval mode keeps
    # BatchNorm running statistics untouched and no_grad avoids building an
    # autograd graph for a tensor that is only inspected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            recon, z = model(x)
    finally:
        model.train(was_training)

    print("\nEncoder/Decoder test:")
    print("  Input :", x.shape)
    print("  Latent:", (z.shape if isinstance(z, torch.Tensor) else 'tuple'))
    print("  Recon :", recon.shape)


# --------------------------------------------------
# Run
# --------------------------------------------------
# Smoke run: train the autoencoder on random data using the settings in
# `config`. In a notebook, `model` and `loader` remain available afterwards.
if __name__ == "__main__":
    model = ReflectanceAutoEncoder(config["model"])
    loader = get_random_dataloader(config)
    train_autoencoder(model, loader, config)


Epoch 1/3 | Avg Loss: 0.321875
Epoch 2/3 | Avg Loss: 0.224959
Epoch 3/3 | Avg Loss: 0.154078

Encoder/Decoder test:
  Input : torch.Size([1, 96, 112, 112])
  Latent: torch.Size([1, 512, 7, 7])
  Recon : torch.Size([1, 96, 112, 112])


In [33]:
import os, glob, re
import torch
import numpy as np
import imageio.v3 as iio
from torch.utils.data import Dataset, DataLoader
from concurrent.futures import ThreadPoolExecutor


# ==================================================
# Utility functions
# ==================================================
# Captures the two trailing integers of names ending in "<row>_<col>.exr" or
# "<row>-<col>.exr".
pat = re.compile(r'(\d+)[_-](\d+)\.exr$')


def exr_sort_key(path):
    """Return a ``(row, col)`` sort key parsed from an EXR file name.

    Only the base name of ``path`` is inspected. Names that do not end in two
    integers separated by ``_`` or ``-`` sort after all parseable names.
    """
    found = pat.search(os.path.basename(path))
    if not found:
        return (10**9, 10**9)
    row_text, col_text = found.groups()
    return (int(row_text), int(col_text))


def _read_exr(cfg, path):
    """Read one EXR image, crop it, and return it channels-first.

    Args:
        cfg: Mapping with ``"row_slice"`` and ``"col_slice"``, each a
            ``(start, stop)`` pair applied to the first two image axes.
        path: File to read.

    Returns:
        A float32 tensor with the image's last axis moved to the front
        (presumably ``(3, H, W)`` for RGB EXRs; the channel count is not
        checked here).
    """
    row_start, row_stop = cfg["row_slice"]
    col_start, col_stop = cfg["col_slice"]

    pixels = iio.imread(path).astype(np.float32)
    window = pixels[row_start:row_stop, col_start:col_stop, :]
    tensor = torch.from_numpy(window)
    return tensor.permute(2, 0, 1)


def load_OLAT(cfg, file_paths, num_threads=8):
    """Read all EXR files in parallel and concatenate them along channels.

    Args:
        cfg: Crop settings forwarded to ``_read_exr``.
        file_paths: Paths to read; the output keeps this order.
        num_threads: Size of the thread pool used for reading.

    Returns:
        One tensor formed by concatenating the per-file tensors along
        dimension 0.
    """

    def read_one(path):
        return _read_exr(cfg, path)

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        frames = list(pool.map(read_one, file_paths))
    stacked = torch.cat(frames, dim=0)
    return stacked


# ==================================================
# Dataset
# ==================================================
class OLATDataset(Dataset):
    """In-memory dataset with one stacked OLAT tensor per folder.

    For every entry of ``cfg["folders"]`` the EXR files under
    ``<root_dir>/<folder>/olat`` are sorted by ``exr_sort_key``, loaded with
    ``load_OLAT`` and kept in ``self.data``; the folder name is recorded in
    ``self.names``. Folders without EXR files are reported and skipped.

    Args:
        cfg: Mapping with ``"root_dir"``, ``"folders"`` and the crop settings
            used by ``load_OLAT``.
        num_threads: Thread count forwarded to ``load_OLAT``.
    """

    def __init__(self, cfg, num_threads=8):
        self.cfg = cfg
        self.data = []
        self.names = []

        for folder in cfg["folders"]:
            data_dir = os.path.join(cfg["root_dir"], folder, "olat")
            exr_files = glob.glob(os.path.join(data_dir, "*.exr"))
            exr_files.sort(key=exr_sort_key)
            if exr_files:
                sample = load_OLAT(cfg, exr_files, num_threads=num_threads)
                self.data.append(sample)
                self.names.append(folder)
            else:
                print(f"⚠️ No EXRs found in {data_dir}")

        print(f"✅ Loaded {len(self.data)} OLAT samples using {num_threads} threads")

    def __len__(self):
        """Number of folders that were loaded successfully."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, sample)``: the target of an autoencoder is its input."""
        sample = self.data[idx]
        return sample, sample


# ==================================================
# Dataloader factory (integrates with existing config)
# ==================================================
def get_dataloader(cfg, num_threads=8):
    """Create the OLAT dataset and a non-shuffling loader over it.

    Args:
        cfg: Dataset settings; ``cfg.get("batch_size", 1)`` sets the batch size.
        num_threads: Thread count forwarded to ``OLATDataset``.

    Returns:
        ``(dataloader, dataset)``.
    """
    dataset = OLATDataset(cfg, num_threads=num_threads)
    loader = DataLoader(dataset, batch_size=cfg.get("batch_size", 1), shuffle=False)
    return loader, dataset


# End-to-end run on real OLAT data: load the configured folders, size the
# model to the loaded channel count and train with the settings in `config`.
if __name__ == "__main__":
    # Dataset config
    data_cfg = {
        # NOTE(review): machine-specific absolute path; adjust per environment.
        "root_dir": "/home/gmh72/3DReconstruction/Blender_Rendering/data",
        "folders": ["diffuse_suzanne_white", "roger"],
        # Crop rows/cols 88..199, i.e. a 112 x 112 window of each image.
        "row_slice": [88, 200],
        "col_slice": [88, 200],
        "batch_size": 1,
    }

    dataloader, dataset = get_dataloader(data_cfg, num_threads=8)

    # Detect input channels automatically
    # (raises IndexError if no folder could be loaded).
    in_channels = dataset[0][0].shape[0]
    print(f"Detected input channels: {in_channels}")

    # Merge into model config
    # This mutates the shared `config` dict in place, so later cells see the
    # new channel count as well.
    config["model"]["in_channels"] = in_channels

    model = ReflectanceAutoEncoder(config["model"])
    train_autoencoder(model, dataloader, config)


✅ Loaded 2 OLAT samples using 8 threads
Detected input channels: 6144
Epoch 1/3 | Avg Loss: 0.000416
Epoch 2/3 | Avg Loss: 0.000142
Epoch 3/3 | Avg Loss: 0.000067

Encoder/Decoder test:
  Input : torch.Size([1, 6144, 112, 112])
  Latent: torch.Size([1, 512, 7, 7])
  Recon : torch.Size([1, 6144, 112, 112])
