In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple

# ---------------- Conditional VAE ----------------
class ConditionalVAE(nn.Module):
    """Conditional VAE that reconstructs a sparsely observed 2-D field.

    The encoder sees ``x`` stacked channel-wise with a location map
    (``loc_channels``) and a measurement mask (``mask_channels``) and maps it to
    a diagonal-Gaussian latent of size ``latent_dim``. Decoding then follows one
    of two paths, chosen at construction time:

    * ``coord_decode=True``: an MLP evaluated at query coordinates
      ``[B, N_points, 2]`` together with the latent code; outputs have shape
      ``[B, N_points, 1]``.
    * ``coord_decode=False``: a transposed-convolution decoder conditioned on
      the flattened masks, producing ``[B, in_channels, img_size, img_size]``.

    Both paths return a predicted mean and a log-variance clamped to
    ``[LOGVAR_MIN, LOGVAR_MAX]``.
    """

    # Clamp range for predicted log-variances (shared by both decoders) so that
    # exp(logvar) in the Gaussian NLL stays finite.
    LOGVAR_MIN = -10.0
    LOGVAR_MAX = 10.0

    def __init__(self,
                 in_channels: int = 1,
                 loc_channels: int = 2,
                 mask_channels: int = 1,
                 latent_dim: int = 16,
                 hidden_dims: Optional[List[int]] = None,
                 img_size: int = 40,
                 coord_decode: bool = True):
        """Build the encoder, the latent heads and the selected decoder.

        Args:
            in_channels: Channels of the observed field ``x``.
            loc_channels: Channels of the location map concatenated to ``x``.
            mask_channels: Channels of the measurement mask concatenated to ``x``.
            latent_dim: Size of the latent code.
            hidden_dims: Encoder conv widths, one stride-2 conv layer each.
                Defaults to ``[32, 64, 128, 256]``. The list is copied.
            img_size: Side length of the square input grid.
            coord_decode: Use the coordinate-MLP decoder (True) or the CNN decoder.

        Raises:
            ValueError: If ``hidden_dims`` is empty.
        """
        super().__init__()
        # None sentinel instead of a shared mutable default; copy so neither the
        # default nor the caller's list is aliased by this module.
        hidden_dims = [32, 64, 128, 256] if hidden_dims is None else list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")

        self.latent_dim = latent_dim
        self.img_size = img_size
        self.in_channels = in_channels
        self.loc_channels = loc_channels
        self.mask_channels = mask_channels
        self.coord_decode = coord_decode

        # ---------------- Encoder ----------------
        enc_in_channels = in_channels + loc_channels + mask_channels
        self.hidden_dims_enc = hidden_dims
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(enc_in_channels, h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU()
                )
            )
            enc_in_channels = h_dim
        self.encoder = nn.Sequential(*modules)

        # Spatial size after the stride-2 encoder, needed to size the FC heads.
        encoder_sizes = self.compute_encoder_sizes(img_size, hidden_dims)
        self.final_spatial = encoder_sizes[-1]
        flat_size = hidden_dims[-1] * self.final_spatial ** 2

        # Latent space
        self.fc_mu = nn.Linear(flat_size, latent_dim)
        self.fc_var = nn.Linear(flat_size, latent_dim)

        if not coord_decode:
            # ---------------- CNN decoder ----------------
            # Input: latent code + flattened (loc_mask, meas_mask) over the full grid.
            self.decoder_input = nn.Linear(
                latent_dim + (loc_channels + mask_channels) * img_size * img_size,
                flat_size
            )
            self.hidden_dims_dec = hidden_dims[::-1]
            dec_modules = []
            for i in range(len(self.hidden_dims_dec) - 1):
                dec_modules.append(
                    nn.Sequential(
                        nn.ConvTranspose2d(
                            self.hidden_dims_dec[i],
                            self.hidden_dims_dec[i + 1],
                            kernel_size=3,
                            stride=2,
                            padding=1,
                            output_padding=0
                        ),
                        nn.BatchNorm2d(self.hidden_dims_dec[i + 1]),
                        nn.LeakyReLU()
                    )
                )
            self.decoder = nn.Sequential(*dec_modules)
            self.final_mu = nn.Conv2d(self.hidden_dims_dec[-1], out_channels=in_channels, kernel_size=3, padding=1)
            self.final_logvar = nn.Conv2d(self.hidden_dims_dec[-1], out_channels=in_channels, kernel_size=3, padding=1)
        else:
            # ---------------- Coordinate-based decoder ----------------
            # Input per query point: latent code + 2 coordinates.
            self.coord_mlp = nn.Sequential(
                nn.Linear(latent_dim + 2, 128),
                nn.ReLU(),
                nn.Linear(128, 128),
                nn.ReLU(),
                nn.Linear(128, 2)  # column 0: mean, column 1: log-variance
            )

    # ---------------- Static methods ----------------
    @staticmethod
    def compute_encoder_sizes(img_size, hidden_dims, kernel_size=3, stride=2, padding=1):
        """Return the spatial side length after each encoder conv layer."""
        sizes = []
        size = img_size
        for _ in hidden_dims:
            size = (size + 2 * padding - kernel_size) // stride + 1
            sizes.append(size)
        return sizes

    # ---------------- Forward methods ----------------
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map the stacked encoder input to latent ``(mu, log_var)``, each ``[B, latent_dim]``."""
        result = self.encoder(x)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return mu, log_var

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def decode(self, z: torch.Tensor, loc_mask: torch.Tensor = None, meas_mask: torch.Tensor = None,
               coords: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode ``z`` into a predicted mean and clamped log-variance.

        Args:
            z: Latent codes ``[B, latent_dim]``.
            loc_mask, meas_mask: Conditioning maps (CNN decoder only).
            coords: Query points ``[B, N_points, 2]`` (coordinate decoder only);
                integer grid indices are accepted and cast to ``z``'s dtype.

        Returns:
            ``(mu_pred, logvar_pred)``: ``[B, N_points, 1]`` each for the
            coordinate decoder, ``[B, in_channels, img_size, img_size]`` for the CNN.

        Raises:
            ValueError: If the inputs required by the active decoder are missing.
        """
        if self.coord_decode:
            if coords is None:
                raise ValueError("coords [B, N_points, 2] is required when coord_decode=True")
            _, n_points, _ = coords.shape
            # The MLP needs one floating-point input tensor; grid indices are
            # often passed as long tensors.
            coords = coords.to(device=z.device, dtype=z.dtype)
            z_exp = z.unsqueeze(1).expand(-1, n_points, -1)  # [B, N, latent_dim], no copy
            mlp_input = torch.cat([z_exp, coords], dim=-1)  # [B, N, latent_dim + 2]
            out = self.coord_mlp(mlp_input)  # [B, N, 2]
            mu_pred = out[..., 0:1]
            logvar_pred = torch.clamp(out[..., 1:2], self.LOGVAR_MIN, self.LOGVAR_MAX)
            return mu_pred, logvar_pred

        if loc_mask is None or meas_mask is None:
            raise ValueError("loc_mask and meas_mask are required when coord_decode=False")

        # CNN decode (full grid)
        cond = torch.cat([loc_mask, meas_mask], dim=1)
        cond_flat = cond.view(cond.size(0), -1)
        dec_input = torch.cat([z, cond_flat], dim=1)
        result = self.decoder_input(dec_input)
        result = result.view(-1, self.hidden_dims_dec[0], self.final_spatial, self.final_spatial)
        result = self.decoder(result)

        # NOTE(review): each transposed conv maps side s to 2*s - 1, so the
        # decoder does not get back to img_size (default dims, img_size=40:
        # 3 -> 5 -> 9 -> 17). The rest is zero-padded on the bottom/right, and
        # that region carries little latent signal. Consider upsampling instead.
        _, _, h, w = result.shape
        pad_h = self.img_size - h
        pad_w = self.img_size - w
        if pad_h > 0 or pad_w > 0:
            result = F.pad(result, (0, pad_w, 0, pad_h))

        mu_pred = self.final_mu(result)
        logvar_pred = torch.clamp(self.final_logvar(result), self.LOGVAR_MIN, self.LOGVAR_MAX)
        return mu_pred, logvar_pred

    def forward(self, x: torch.Tensor, loc_mask: torch.Tensor = None, meas_mask: torch.Tensor = None,
                coords: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, sample ``z`` and decode.

        A 3-D ``loc_mask`` ``[C, H, W]`` is broadcast across the batch. The
        masks are concatenated to ``x`` only when both are given.

        Returns:
            ``(mu_pred, logvar_pred, mu, log_var)``.
        """
        if loc_mask is not None and loc_mask.dim() == 3:
            loc_mask = loc_mask.unsqueeze(0).expand(x.size(0), -1, -1, -1)

        enc_in = x
        if loc_mask is not None and meas_mask is not None:
            enc_in = torch.cat([x, loc_mask, meas_mask], dim=1)

        mu, log_var = self.encode(enc_in)
        z = self.reparameterize(mu, log_var)
        mu_pred, logvar_pred = self.decode(z, loc_mask, meas_mask, coords)
        return mu_pred, logvar_pred, mu, log_var

    # ---------------- Loss function ----------------
    def loss_function(self, x: torch.Tensor, mu_pred: torch.Tensor, logvar_pred: torch.Tensor,
                      mu: torch.Tensor, log_var: torch.Tensor, meas_mask: torch.Tensor = None,
                      kld_weight: float = 1e-3) -> dict:
        """Gaussian NLL (optionally masked) plus weighted KL divergence.

        ``x``, ``mu_pred`` and ``logvar_pred`` (and ``meas_mask``, if given) must
        share a broadcast-compatible layout. With a mask, the per-sample NLL is
        averaged over masked-in elements; without one, it is summed per sample.
        Both are then averaged over the batch.

        Returns:
            ``{"loss": total, "NLL": nll, "KLD": kld}``.
        """
        const = torch.log(torch.tensor(2.0 * torch.pi, device=x.device))
        recon_var = torch.exp(logvar_pred)
        nll_element = 0.5 * ((x - mu_pred) ** 2 / recon_var + logvar_pred + const)

        if meas_mask is not None:
            # mask over valid measurements
            valid_mask = meas_mask.bool()
            nll_element = nll_element * valid_mask
            nll_loss = torch.mean(torch.sum(nll_element, dim=list(range(1, nll_element.dim()))) /
                                  (valid_mask.sum(dim=list(range(1, valid_mask.dim()))) + 1e-8))
        else:
            nll_loss = torch.mean(torch.sum(nll_element, dim=list(range(1, nll_element.dim()))))

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
        total_loss = nll_loss + kld_weight * kld_loss
        return {"loss": total_loss, "NLL": nll_loss, "KLD": kld_loss}


In [None]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split, TensorDataset
from torch.optim import Adam
from tqdm import tqdm

# -----------------------------
# 1. Define constants
# -----------------------------
H, W = 40, 40

# -----------------------------
# 2. Load SINR measurements
# -----------------------------
df_dataset = pd.read_csv('combined_sinr.csv')  # one row per sample, one SINR_* column per sensor
sinr_cols = [col for col in df_dataset.columns if col.startswith('SINR_')]
N = len(sinr_cols)  # number of sensor columns (NOT the number of samples)
num_samples = df_dataset.shape[0]
print(f"Loaded SINR data: {num_samples} samples, {N} sensor columns")

sinr_data = df_dataset[sinr_cols].to_numpy()  # (num_samples, N)
sinr_data_transposed = sinr_data.T  # (N, num_samples)

# -----------------------------
# 3. Load location grid
# -----------------------------
df_grid = pd.read_csv('original_1600_grid_points.csv')  # H*W rows
y_values = df_grid['y'].to_numpy()
x_values = df_grid['x'].to_numpy()
# Assumes the rows are in row-major (y-major) grid order -- TODO confirm.
location_values = np.stack([y_values, x_values], axis=-1).reshape(H, W, 2)
loc_mask_base = torch.tensor(location_values, dtype=torch.float32).permute(2, 0, 1)  # (2, H, W)

# -----------------------------
# 4. Load sensor mask
# -----------------------------
df_mask = pd.read_csv('mask.csv')  # H*W rows
FLAG_COL = 'mask'
boolean_sensor_mask = (df_mask[FLAG_COL] == 1).to_numpy()  # (H*W,)
num_sensors_from_mask = boolean_sensor_mask.sum()
print(f"Number of measured sensors from mask: {num_sensors_from_mask}")

# Map the flat mask to 2-D grid indices (row-major).
all_y_pixel_indices = np.arange(H * W) // W
all_x_pixel_indices = np.arange(H * W) % W
sensor_y_np = all_y_pixel_indices[boolean_sensor_mask]
sensor_x_np = all_x_pixel_indices[boolean_sensor_mask]
train_coords_y = torch.tensor(sensor_y_np, dtype=torch.long)
train_coords_x = torch.tensor(sensor_x_np, dtype=torch.long)
num_train_locs = len(train_coords_y)
print(f"Number of training coordinates: {num_train_locs}")

if num_train_locs != N:
    raise ValueError(
        f"mask.csv flags {num_train_locs} sensor locations but combined_sinr.csv "
        f"has {N} SINR columns; they must match one-to-one"
    )

# -----------------------------
# 5. Build dataset images: one H x W image per SAMPLE
# -----------------------------
# SINR column j is written to the j-th flagged mask position (row-major order).
dataset_images = np.zeros((num_samples, H, W), dtype=np.float32)
dataset_images[:, sensor_y_np, sensor_x_np] = sinr_data

dataset_images = torch.tensor(dataset_images, dtype=torch.float32)  # (num_samples, H, W)
# Mark measured pixels by position, not by value, so that a genuine 0 dB
# reading still counts as a measurement.
sensor_mask = torch.zeros_like(dataset_images, dtype=torch.bool)
sensor_mask[:, train_coords_y, train_coords_x] = True

# -----------------------------
# 6. Normalize sensor values only
# -----------------------------
sensor_vals = dataset_images[sensor_mask]
mean = sensor_vals.mean()
std = sensor_vals.std()
dataset_norm = torch.zeros_like(dataset_images)
dataset_norm[sensor_mask] = (dataset_images[sensor_mask] - mean) / (std + 1e-6)
X_data = dataset_norm.unsqueeze(1)  # (num_samples, 1, H, W)
global_normalization_mean = mean
global_normalization_std = std

# -----------------------------
# 7. Load valid coordinates CSV
# -----------------------------
df_valid_coords = pd.read_csv('valid_coords.csv')  # columns: y, x
# Integer grid indices: used for gathering targets and as decoder query points.
valid_coords = torch.tensor(df_valid_coords[['y', 'x']].to_numpy(), dtype=torch.long)  # (num_valid, 2)

# -----------------------------
# 8. k-sampling parameters (number of context sensors per training sample)
# -----------------------------
k_choices = torch.tensor([5, 9, 15, 18])
k_probs = torch.tensor([0.133, 0.6, 0.133, 0.133])  # multinomial renormalizes

# -----------------------------
# 9. Create train/val splits
# -----------------------------
dataset = TensorDataset(X_data)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_data, val_data = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
val_loader = DataLoader(val_data, batch_size=8, shuffle=False)

# -----------------------------
# 10. Initialize model + optimizer
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loc_mask_base = loc_mask_base.to(device)
train_coords_y = train_coords_y.to(device)
train_coords_x = train_coords_x.to(device)
valid_coords = valid_coords.to(device)
k_choices = k_choices.to(device)
k_probs = k_probs.to(device)
X_data = X_data.to(device)

valid_y = valid_coords[:, 0]
valid_x = valid_coords[:, 1]
# The coordinate decoder is an MLP, so query it with floating-point coordinates.
valid_coords_f = valid_coords.float()


def _gather_at_valid_coords(grid: torch.Tensor) -> torch.Tensor:
    """Sample a (B, 1, H, W) grid at valid_coords, giving (B, num_valid, 1).

    This matches the coordinate decoder's output layout, so targets and masks
    line up point-for-point with mu_pred / logvar_pred.
    """
    return grid[:, 0, valid_y, valid_x].unsqueeze(-1)


hidden_dims = [32, 64, 128, 256, 512]

model = ConditionalVAE(
    in_channels=1,
    loc_channels=2,
    mask_channels=1,
    latent_dim=16,
    img_size=H,
    hidden_dims=hidden_dims,
    coord_decode=True
).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
epochs = 50
kld_weight = 1e-3

# Fixed validation context: the first 9 sensor locations (built once).
val_coords_y = train_coords_y[:9]
val_coords_x = train_coords_x[:9]
val_mask_base = torch.zeros(1, 1, H, W, device=device)
val_mask_base[0, 0, val_coords_y, val_coords_x] = 1.0

# -----------------------------
# 11. Training loop
# -----------------------------
for epoch in range(epochs):
    model.train()
    train_loss = 0.0

    for (x_batch,) in tqdm(train_loader, leave=False):
        x_batch = x_batch.to(device)
        batch_size = x_batch.size(0)

        loc_batch = loc_mask_base.unsqueeze(0).repeat(batch_size, 1, 1, 1)
        coords_batch = valid_coords_f.unsqueeze(0).expand(batch_size, -1, -1)

        # Draw a context size k per sample, then k distinct sensor locations.
        k_samples_indices = torch.multinomial(k_probs, num_samples=batch_size, replacement=True)
        k_values = k_choices[k_samples_indices]

        mask_batch = torch.zeros(batch_size, 1, H, W, device=device)
        for i in range(batch_size):
            k = k_values[i].item()
            indices_to_pick = torch.randperm(num_train_locs, device=device)[:k]
            mask_batch[i, 0, train_coords_y[indices_to_pick], train_coords_x[indices_to_pick]] = 1.0

        x_input_batch = x_batch * mask_batch

        mu_pred, logvar_pred, mu, log_var = model(
            x_input_batch, loc_batch, mask_batch, coords=coords_batch
        )

        # The decoder predicts at valid_coords ([B, num_valid, 1]). Compare
        # against ground truth and mask sampled at those same points, not the
        # full [B, 1, H, W] grid, which does not broadcast against it.
        # NOTE(review): the NLL only covers the k context points, so the model is
        # never scored on held-out sensors; consider all sensors as targets.
        losses = model.loss_function(
            _gather_at_valid_coords(x_batch), mu_pred, logvar_pred, mu, log_var,
            meas_mask=_gather_at_valid_coords(mask_batch), kld_weight=kld_weight,
        )
        loss = losses["loss"]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * batch_size

    train_loss /= len(train_loader.dataset)

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for (x_batch,) in val_loader:
            x_batch = x_batch.to(device)
            batch_size = x_batch.size(0)
            loc_batch = loc_mask_base.unsqueeze(0).repeat(batch_size, 1, 1, 1)
            mask_batch = val_mask_base.repeat(batch_size, 1, 1, 1)
            x_input_batch = x_batch * mask_batch

            mu_pred, logvar_pred, mu, log_var = model(
                x_input_batch, loc_batch, mask_batch,
                coords=valid_coords_f.unsqueeze(0).expand(batch_size, -1, -1)
            )
            losses = model.loss_function(
                _gather_at_valid_coords(x_batch), mu_pred, logvar_pred, mu, log_var,
                meas_mask=_gather_at_valid_coords(mask_batch), kld_weight=kld_weight,
            )
            val_loss += losses["loss"].item() * batch_size

    val_loss /= len(val_loader.dataset)
    print(f"Epoch [{epoch+1}/{epochs}] | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

print("✅ Training complete!")

In [None]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt

# Predict on the first validation batch only, conditioned on the fixed
# validation context mask (val_mask_base from the training cell).
model.eval()
first_batch = next(iter(val_loader), None)
if first_batch is None:
    raise RuntimeError("val_loader is empty; nothing to predict on")
(x_batch,) = first_batch

x_batch = x_batch.to(device)
batch_size = x_batch.size(0)
loc_batch = loc_mask_base.unsqueeze(0).repeat(batch_size, 1, 1, 1)
mask_batch = val_mask_base.repeat(batch_size, 1, 1, 1)
x_input_batch = x_batch * mask_batch

with torch.no_grad():
    # Query the coordinate decoder at valid_coords (as floats, since it is an MLP).
    mu_pred, logvar_pred, _, _ = model(
        x_input_batch, loc_batch, mask_batch,
        coords=valid_coords.float().unsqueeze(0).expand(batch_size, -1, -1)
    )

# Per-point predictions at valid_coords, shape (batch_size, num_valid, 1):
# normalized mean and standard deviation.
pred_mean = mu_pred.cpu()
pred_std = torch.exp(0.5 * logvar_pred).cpu()


In [None]:
# Scatter per-point predictions (batch_size, num_valid, 1) back onto an H x W grid.
# valid_coords lives on `device`; pred_grid is on the CPU, so index with CPU copies.
coords_cpu = valid_coords.cpu()
y_coords = coords_cpu[:, 0]
x_coords = coords_cpu[:, 1]
pred_grid = torch.zeros((batch_size, H, W))
pred_grid[:, y_coords, x_coords] = pred_mean[:, :, 0]

# Inverse normalization (cells without a prediction become the global mean).
pred_grid_real = pred_grid * (global_normalization_std + 1e-6) + global_normalization_mean


In [None]:
# Ground truth must come from the same validation batch that produced
# pred_grid_real (x_batch, taken from val_loader by the prediction cell).
# X_data[i] would pair predictions with unrelated samples, because
# random_split decides which rows form the validation set.
# NOTE(review): the 9 validation context sensors were given to the model as
# input, and sensors outside valid_coords have no prediction (the grid holds
# the global mean there). Confirm valid_coords covers all sensors.
sensor_y = train_coords_y.cpu()
sensor_x = train_coords_x.cpu()
scale = global_normalization_std + 1e-6
gt_batch = x_batch[:, 0].cpu()[:, sensor_y, sensor_x] * scale + global_normalization_mean  # (batch_size, num_sensors)

metrics = []
for i in range(batch_size):
    gt = gt_batch[i].numpy()
    pred = pred_grid_real[i, sensor_y, sensor_x].numpy()
    metrics.append((mean_squared_error(gt, pred), mean_absolute_error(gt, pred), r2_score(gt, pred)))

avg_mse, avg_mae, avg_r2 = np.mean(np.asarray(metrics), axis=0)

print("Validation Metrics (measured points only):")
print(f"  MSE: {avg_mse:.4f}, MAE: {avg_mae:.4f}, R²: {avg_r2:.4f}")


In [None]:
# Predicted SINR map for sample 0, with sensor locations overlaid.
fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(pred_grid_real[0].numpy(), origin='lower', cmap='viridis')
ax.scatter(train_coords_x.cpu(), train_coords_y.cpu(), color='red', marker='x', label='Measured sensors')
# Bind the colorbar to the image explicitly: pyplot.scatter makes the scatter the
# "current image", so a bare plt.colorbar() would describe the red markers.
fig.colorbar(im, ax=ax, label='SINR (dB)')
ax.legend()
ax.set_title('Predicted SINR map (sample 0)')
plt.show()


In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error

# -----------------------------
# 1. Set model to eval
# -----------------------------
model.eval()
# NOTE(review): every sample (train and validation) goes through in one batch,
# with all sensors given as input. The metrics below therefore measure
# reconstruction of observed points, not held-out interpolation. Chunk this
# loop if X_data is large.
batch_size = X_data.size(0)

# -----------------------------
# 2. Prepare location + measurement masks
# -----------------------------
loc_batch = loc_mask_base.unsqueeze(0).repeat(batch_size, 1, 1, 1)

# Use all measured sensor locations as the conditioning mask.
eval_mask = torch.zeros(batch_size, 1, H, W, device=device)
eval_mask[:, 0, train_coords_y, train_coords_x] = 1.0

# -----------------------------
# 3. Forward pass to predict at valid coordinates
# -----------------------------
coords_batch = valid_coords.float().unsqueeze(0).expand(batch_size, -1, -1)
with torch.no_grad():
    mu_pred, logvar_pred, _, _ = model(
        X_data * eval_mask,
        loc_batch,
        eval_mask,
        coords=coords_batch
    )

# -----------------------------
# 4. Convert predictions to original scale
# -----------------------------
# Decoder outputs are (batch_size, num_valid, 1): drop the trailing axis.
# Undo normalization with the same (std + 1e-6) divisor used to build X_data.
scale = global_normalization_std + 1e-6
mu_pred = mu_pred.squeeze(-1)  # (batch_size, num_valid)
pred_sinr = mu_pred * scale + global_normalization_mean

# Optional: std/uncertainty in original units
std_pred = torch.exp(0.5 * logvar_pred.squeeze(-1)) * scale

# -----------------------------
# 5. Compute metrics at valid coordinates
# -----------------------------
y_vals = valid_coords[:, 0]
x_vals = valid_coords[:, 1]
true_vals = X_data[:, 0, y_vals, x_vals] * scale + global_normalization_mean  # (batch_size, num_valid)

true_flat = true_vals.cpu().numpy().ravel()
pred_flat = pred_sinr.cpu().numpy().ravel()
mse = mean_squared_error(true_flat, pred_flat)
mae = mean_absolute_error(true_flat, pred_flat)
print(f"Evaluation Metrics | MSE: {mse:.4f} | MAE: {mae:.4f}")

# -----------------------------
# 6. Visualize predictions vs ground truth (first sample)
# -----------------------------
sample_idx = 0
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.scatter(x_vals.cpu(), y_vals.cpu(), c=true_vals[sample_idx].cpu(), cmap='viridis', s=60)
plt.colorbar(label='Ground truth SINR')
plt.title('Ground truth SINR at valid coordinates')

plt.subplot(1, 2, 2)
plt.scatter(x_vals.cpu(), y_vals.cpu(), c=pred_sinr[sample_idx].cpu(), cmap='viridis', s=60)
plt.colorbar(label='Predicted SINR')
plt.title('Predicted SINR at valid coordinates')

plt.show()

# -----------------------------
# 7. Optional: Visualize uncertainty
# -----------------------------
plt.figure(figsize=(6, 5))
plt.scatter(x_vals.cpu(), y_vals.cpu(), c=std_pred[sample_idx].cpu(), cmap='inferno', s=60)
plt.colorbar(label='Predicted std (uncertainty)')
plt.title('Prediction Uncertainty at valid coordinates')
plt.show()
