In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from typing import List, Tuple
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.optim import Adam
from tqdm import tqdm

In [2]:

class DataPreprocessor:
    def __init__(self, location_mask, measurement_mask, dataset, downsample_factor=30):
        """
        Args:
            location_mask (np.ndarray): (H, W, 2) coordinates.
            measurement_mask (np.ndarray): (H, W) or (H, W, 1), indicating measurement locations.
            dataset (np.ndarray): (N, H, W), N = number of samples.
            downsample_factor (int): Number of samples to average together.
        """
        self.location_mask = location_mask
        self.measurement_mask = measurement_mask
        self.dataset = dataset
        self.downsample_factor = downsample_factor

    def downsample(self):
        """Downsample by averaging over 'downsample_factor' samples."""
        n_samples = self.dataset.shape[0]
        n_groups = n_samples // self.downsample_factor  # ✅ define this
        ds_data = []

        for i in range(n_groups):
            start = i * self.downsample_factor
            end = start + self.downsample_factor
            avg_sample = np.mean(self.dataset[start:end], axis=0)
            ds_data.append(avg_sample)

        return np.stack(ds_data)
    
    def get_prepared_data(self):
        X = self.downsample()
        return X


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple

class ConditionalVAE(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 loc_channels: int = 2,
                 mask_channels: int = 1,
                 latent_dim: int = 16,
                 hidden_dims: List[int] = [32, 64, 128, 256, 512],
                 img_size: int = 40):
        super().__init__()

        self.latent_dim = latent_dim
        self.img_size = img_size
        self.in_channels = in_channels
        self.loc_channels = loc_channels
        self.mask_channels = mask_channels

        # ---------------- Encoder ----------------
        enc_in_channels = in_channels + loc_channels + mask_channels
        self.hidden_dims_enc = hidden_dims
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(enc_in_channels, h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU()
                )
            )
            enc_in_channels = h_dim
        self.encoder = nn.Sequential(*modules)

        # Compute encoder output sizes dynamically
        encoder_sizes = self.compute_encoder_sizes(img_size, hidden_dims)
        self.final_spatial = encoder_sizes[-1]
        flat_size = hidden_dims[-1] * self.final_spatial**2

        # Latent space
        self.fc_mu = nn.Linear(flat_size, latent_dim)
        self.fc_var = nn.Linear(flat_size, latent_dim)

        # ---------------- Decoder ----------------
        self.decoder_input = nn.Linear(
            latent_dim + (loc_channels + mask_channels) * img_size * img_size,
            flat_size
        )

        self.hidden_dims_dec = hidden_dims[::-1]

        dec_modules = []
        for i in range(len(self.hidden_dims_dec) - 1):
            dec_modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        self.hidden_dims_dec[i],
                        self.hidden_dims_dec[i+1],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=0  # always 0, we will pad later
                    ),
                    nn.BatchNorm2d(self.hidden_dims_dec[i+1]),
                    nn.LeakyReLU()
                )
            )
        self.decoder = nn.Sequential(*dec_modules)

        self.final_mu = nn.Conv2d(self.hidden_dims_dec[-1], out_channels=in_channels, kernel_size=3, padding=1)
        self.final_logvar = nn.Conv2d(self.hidden_dims_dec[-1], out_channels=in_channels, kernel_size=3, padding=1)

    # ---------------- Static methods ----------------
    @staticmethod
    def compute_encoder_sizes(img_size, hidden_dims, kernel_size=3, stride=2, padding=1):
        sizes = []
        size = img_size
        for _ in hidden_dims:
            size = (size + 2*padding - kernel_size) // stride + 1
            sizes.append(size)
        return sizes

    # ---------------- Forward methods ----------------
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        result = self.encoder(x)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return mu, log_var

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def decode(self, z: torch.Tensor, loc_mask: torch.Tensor, meas_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Combine conditioning info
        cond = torch.cat([loc_mask, meas_mask], dim=1)
        cond_flat = cond.view(cond.size(0), -1)
        dec_input = torch.cat([z, cond_flat], dim=1)

        result = self.decoder_input(dec_input)
        result = result.view(-1, self.hidden_dims_dec[0], self.final_spatial, self.final_spatial)
        result = self.decoder(result)

        # Zero-pad to ensure output is img_size x img_size
        _, _, h, w = result.shape
        pad_h = self.img_size - h
        pad_w = self.img_size - w
        if pad_h > 0 or pad_w > 0:
            # pad = (left, right, top, bottom)
            result = F.pad(result, (0, pad_w, 0, pad_h))

        mu_pred = self.final_mu(result)
        logvar_pred = torch.clamp(self.final_logvar(result), -10, 10)
        return mu_pred, logvar_pred

    def forward(self, x: torch.Tensor, loc_mask: torch.Tensor, meas_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Tile location mask if needed
        if loc_mask.dim() == 3:
            loc_mask = loc_mask.unsqueeze(0).repeat(x.size(0), 1, 1, 1)

        enc_in = torch.cat([x, loc_mask, meas_mask], dim=1)
        mu, log_var = self.encode(enc_in)
        z = self.reparameterize(mu, log_var)
        mu_pred, logvar_pred = self.decode(z, loc_mask, meas_mask)
        return mu_pred, logvar_pred, mu, log_var

    # ---------------- Loss function ----------------
    def loss_function(self, x: torch.Tensor, mu_pred: torch.Tensor, logvar_pred: torch.Tensor,
                      mu: torch.Tensor, log_var: torch.Tensor, kld_weight: float = 1e-3) -> dict:
        const = torch.log(torch.tensor(2.0 * torch.pi, device=x.device))
        recon_var = torch.exp(logvar_pred)
        nll_element = 0.5 * ((x - mu_pred)**2 / recon_var + logvar_pred + const)
        nll_loss = torch.mean(torch.sum(nll_element, dim=[1,2,3]))

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
        total_loss = nll_loss + kld_weight * kld_loss
        return {"loss": total_loss, "NLL": nll_loss, "KLD": kld_loss}


In [None]:
""" N, H, W = 1500, 40, 40
location_mask = np.random.rand(H, W, 2)
measurement_mask = np.random.randint(0, 2, (H, W))
dataset = np.random.rand(N, H, W)

prep = DataPreprocessor(location_mask, measurement_mask, dataset, downsample_factor=30)
X = prep.get_prepared_data()  # shape (N//30, H, W)

# Convert to torch tensors
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)         # (batch, 1, H, W)
loc_mask = torch.tensor(location_mask, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)  # (1, 2, H, W)
meas_mask = torch.tensor(measurement_mask, dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)

# Repeat masks for all samples in batch
loc_mask = loc_mask.repeat(X.size(0), 1, 1, 1)
meas_mask = meas_mask.repeat(X.size(0), 1, 1, 1)

# Create Dataset + DataLoader
dataset = TensorDataset(X, loc_mask, meas_mask)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_data, val_data = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
val_loader = DataLoader(val_data, batch_size=8, shuffle=False)

# ===============================================================
# 2️⃣ Initialize model + optimizer
# ===============================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_dims=[32, 64, 128, 256, 512]

model = ConditionalVAE(
    in_channels=1,
    loc_channels=2,
    mask_channels=1,
    latent_dim=16,
    img_size=H,
    hidden_dims=hidden_dims,
).to(device)

optimizer = Adam(model.parameters(), lr=1e-3)
epochs = 50
kld_weight = 1e-3  # you can anneal this if needed

# ===============================================================
# 3️⃣ Training loop
# ===============================================================
for epoch in range(epochs):
    model.train()
    train_loss, val_loss = 0.0, 0.0

    for x_batch, loc_batch, mask_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
        x_batch = x_batch.to(device)
        loc_batch = loc_batch.to(device)
        mask_batch = mask_batch.to(device)

        optimizer.zero_grad()
        mu_pred, logvar_pred, mu, log_var = model(x_batch,loc_batch, mask_batch)
        losses = model.loss_function(x_batch, mu_pred, logvar_pred, mu, log_var, kld_weight)
        loss = losses["loss"]

        loss.backward()
        optimizer.step()

        train_loss += loss.item() * x_batch.size(0)

    # Validation loop
    model.eval()
    with torch.no_grad():
        for x_batch, loc_batch, mask_batch in val_loader:
            x_batch = x_batch.to(device)
            loc_batch = loc_batch.to(device)
            mask_batch = mask_batch.to(device)
            mu_pred, logvar_pred, mu, log_var = model(x_batch, loc_batch, mask_batch)
            losses = model.loss_function(x_batch, mu_pred, logvar_pred, mu, log_var, kld_weight)
            val_loss += losses["loss"].item() * x_batch.size(0)

    train_loss /= len(train_loader.dataset)
    val_loss /= len(val_loader.dataset)
    print(f"Epoch [{epoch+1}/{epochs}] | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

# ===============================================================
# 4️⃣ Save model
# ===============================================================
torch.save(model.state_dict(), "sivae_sinr.pt")
print("✅ Training complete! Model saved.") """

                                                         

Epoch [1/50] | Train Loss: 1638.8605 | Val Loss: 1629.9093


                                                         

Epoch [2/50] | Train Loss: 1516.1191 | Val Loss: 1558.7402


                                                         

Epoch [3/50] | Train Loss: 1415.8355 | Val Loss: 1479.9607


                                                         

Epoch [4/50] | Train Loss: 1310.7165 | Val Loss: 1388.1142


                                                         

Epoch [5/50] | Train Loss: 1204.3517 | Val Loss: 1371.9483


                                                         

Epoch [6/50] | Train Loss: 1103.7463 | Val Loss: 1357.9556


                                                         

Epoch [7/50] | Train Loss: 1023.3861 | Val Loss: 1312.6687


                                                         

Epoch [8/50] | Train Loss: 964.4285 | Val Loss: 1306.8110


                                                         

KeyboardInterrupt: 

In [5]:
# [CELL 4 - REPLACEMENT]

N, H, W = 1500, 40, 40
location_mask = np.random.rand(H, W, 2)
# This original measurement_mask is now just a dummy placeholder for the DataPreprocessor,
# we will not use it in the model.
measurement_mask = np.random.randint(0, 2, (H, W)) 
dataset = np.random.rand(N, H, W)

prep = DataPreprocessor(location_mask, measurement_mask, dataset, downsample_factor=30)
X = prep.get_prepared_data()  # shape (N//30, H, W)

# ===============================================================
# 1️⃣ Data Setup with Dynamic Masking
# ===============================================================

# Convert to torch tensors
X_data = torch.tensor(X, dtype=torch.float32).unsqueeze(1) # (batch, 1, H, W)

# Create a *single* base location mask (coordinate grid)
# We'll repeat this mask per-batch in the training loop
loc_mask_base = torch.tensor(location_mask, dtype=torch.float32).permute(2, 0, 1) # (2, H, W)

# --- MODIFICATION START ---

# 1. Define the 21 measurement locations (as (y, x) coordinates)
# !!! NOTE: Replace this random selection with your *actual* 21 coordinates !!!
all_locations_flat = np.arange(H * W)
np.random.shuffle(all_locations_flat)
train_loc_indices_flat = all_locations_flat[:21]
val_loc_indices_flat = train_loc_indices_flat[:9] # Use first 9 for fixed validation

my_train_coords = [
    (5, 10), (5, 12), (8, 20), (8, 22), (10, 30), (10, 32),
    (15, 5), (15, 8), (17, 15), (17, 18), (20, 25), (20, 28),
    (25, 10), (25, 12), (28, 20), (28, 22), (30, 30), (30, 32),
    (35, 5), (35, 8), (38, 15)
]

# 2. Define your 9 fixed validation locations (a subset of the 21)
# !!! Replace these with your actual 9 validation coordinates !!!
my_val_coords = [
    (5, 10), (8, 20), (10, 30), (15, 5), (17, 15), 
    (20, 25), (25, 10), (28, 20), (30, 30)
]
# For Training
train_coords_y = torch.tensor([y for y, x in my_train_coords], dtype=torch.long)
train_coords_x = torch.tensor([x for y, x in my_train_coords], dtype=torch.long)
num_train_locs = len(train_coords_y) # This must be 21

# For Validation
val_coords_y = torch.tensor([y for y, x in my_val_coords], dtype=torch.long)
val_coords_x = torch.tensor([x for y, x in my_val_coords], dtype=torch.long)
num_train_locs = len(train_coords_y) # 21

# 2. Define k-sampling parameters
k_choices = torch.tensor([5, 9, 15, 18])
k_probs = torch.tensor([0.133, 0.6, 0.133, 0.133])

# --- MODIFICATION END ---

# Create Dataset (ONLY with X, the ground truth)
# The loc_mask and meas_mask will be generated on-the-fly
dataset = TensorDataset(X_data) 
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_data, val_data = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
val_loader = DataLoader(val_data, batch_size=8, shuffle=False)

# ===============================================================
# 2️⃣ Initialize model + optimizer
# ===============================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_dims=[32, 64, 128, 256, 512]

model = ConditionalVAE(
    in_channels=1,
    loc_channels=2,
    mask_channels=1,
    latent_dim=16,
    img_size=H,
    hidden_dims=hidden_dims,
).to(device)

optimizer = Adam(model.parameters(), lr=1e-3)
epochs = 50
kld_weight = 1e-3

# --- MODIFICATION: Move base tensors to device ---
loc_mask_base = loc_mask_base.to(device)
train_coords_y = train_coords_y.to(device)
train_coords_x = train_coords_x.to(device)
val_coords_y = val_coords_y.to(device)
val_coords_x = val_coords_x.to(device)
k_choices = k_choices.to(device)
k_probs = k_probs.to(device)
# ---

# ===============================================================
# 3️⃣ Training loop
# ===============================================================
for epoch in range(epochs):
    model.train()
    train_loss, val_loss = 0.0, 0.0

    # --- MODIFICATION: Loader only yields x_batch ---
    for (x_batch,) in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
        x_batch = x_batch.to(device) # This is the ground truth
        batch_size = x_batch.size(0)

        # 1. Create location mask batch (repeat the base grid)
        loc_batch = loc_mask_base.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        
        # 2. Create dynamic measurement mask batch
        # Sample k for each item in the batch
        k_samples_indices = torch.multinomial(k_probs, num_samples=batch_size, replacement=True)
        k_values = k_choices[k_samples_indices]
        
        # Create masks one by one (since k varies)
        mask_batch = torch.zeros(batch_size, 1, H, W, device=device)
        for i in range(batch_size):
            k = k_values[i].item()
            # Randomly pick k indices from the 21 available
            indices_to_pick = torch.randperm(num_train_locs, device=device)[:k]
            # Get the y, x coordinates for these k points
            y = train_coords_y[indices_to_pick]
            x = train_coords_x[indices_to_pick]
            # Set the mask
            mask_batch[i, 0, y, x] = 1.0
        
        # 3. Create masked input
        x_input_batch = x_batch * mask_batch # Encoder sees only observed values

        optimizer.zero_grad()
        
        # Model takes masked input, but loss compares to ground truth
        mu_pred, logvar_pred, mu, log_var = model(x_input_batch, loc_batch, mask_batch)
        losses = model.loss_function(x_batch, mu_pred, logvar_pred, mu, log_var, kld_weight) # Loss uses original x_batch
        loss = losses["loss"]

        loss.backward()
        optimizer.step()

        train_loss += loss.item() * x_batch.size(0)

    # Validation loop
    model.eval()
    
    # --- MODIFICATION: Create the *fixed* 9-point validation mask *once* ---
    val_mask_base = torch.zeros(1, 1, H, W, device=device)
    val_mask_base[0, 0, val_coords_y, val_coords_x] = 1.0
    
    with torch.no_grad():
        # --- MODIFICATION: Loader only yields x_batch ---
        for (x_batch,) in val_loader: 
            x_batch = x_batch.to(device)
            batch_size = x_batch.size(0)
            
            # 1. Create loc_batch
            loc_batch = loc_mask_base.unsqueeze(0).repeat(batch_size, 1, 1, 1)
            
            # 2. Create mask_batch (repeat the fixed val mask)
            mask_batch = val_mask_base.repeat(batch_size, 1, 1, 1)
            
            # 3. Create masked input
            x_input_batch = x_batch * mask_batch
            
            mu_pred, logvar_pred, mu, log_var = model(x_input_batch, loc_batch, mask_batch)
            losses = model.loss_function(x_batch, mu_pred, logvar_pred, mu, log_var, kld_weight)
            val_loss += losses["loss"].item() * x_batch.size(0)

    train_loss /= len(train_loader.dataset)
    val_loss /= len(val_loader.dataset)
    print(f"Epoch [{epoch+1}/{epochs}] | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

# ===============================================================
# 4️⃣ Save model
# ===============================================================
#torch.save(model.state_dict(), "sivae_sinr.pt")
print("✅ Training complete! Model saved.")

                                                         

Epoch [1/50] | Train Loss: 1655.3141 | Val Loss: 1654.7728


                                                         

Epoch [2/50] | Train Loss: 1541.3323 | Val Loss: 1584.3346


                                                         

Epoch [3/50] | Train Loss: 1439.6963 | Val Loss: 1482.5812


                                                         

Epoch [4/50] | Train Loss: 1337.1976 | Val Loss: 1363.5209


                                                         

Epoch [5/50] | Train Loss: 1233.0296 | Val Loss: 1248.8597


                                                         

Epoch [6/50] | Train Loss: 1137.2128 | Val Loss: 1138.6981


                                                         

Epoch [7/50] | Train Loss: 1047.0307 | Val Loss: 1069.8155


                                                         

Epoch [8/50] | Train Loss: 990.3392 | Val Loss: 1082.3796


                                                         

Epoch [9/50] | Train Loss: 957.6909 | Val Loss: 1068.3966


                                                  

KeyboardInterrupt: 