# Base Model Offline Training & Visualization
Architecture: 4 → 64 → 64 → 32 → 4

In [None]:
"""Unified Training Utilities for Neural Network-Based Flow Field Compression"""

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import pyarrow.csv as pv
from torcheval.metrics import PeakSignalNoiseRatio
from torchmetrics.image import StructuralSimilarityIndexMeasure
import matplotlib.pyplot as plt
import time
import os
import json


class FlowFieldDataset(Dataset):
    """In-memory flow-field samples with column-wise min-max normalization.

    The CSV is read with pyarrow using eight fixed column names. The first
    four columns (x, y, z, t) become the inputs and the last four
    (Vx, Vy, Pressure, TKE) the targets. Both halves are cast to float32 and
    scaled per column. The statistics used for the scaling stay on the
    instance as ``input_min`` / ``input_max`` / ``input_range`` and the
    matching ``target_*`` attributes so values can be mapped back later.
    """

    def __init__(self, filepath):
        """Load ``filepath`` and normalize inputs and targets."""
        print(f"Loading dataset from {filepath}")
        # NOTE(review): explicit column_names are supplied here, so pyarrow
        # presumably treats the first row as data rather than as a header --
        # confirm the CSV has no header line.
        read_options = pv.ReadOptions(
            column_names=['x', 'y', 'z', 't', 'Vx', 'Vy', 'Pressure', 'TKE']
        )
        table = pv.read_csv(filepath, read_options=read_options)
        data = table.to_pandas().values
        raw_inputs = data[:, :4].astype(np.float32)
        raw_targets = data[:, 4:].astype(np.float32)
        print("Applying min-max normalization to [0, 1] range")
        (self.inputs, self.input_min,
         self.input_max, self.input_range) = self._min_max_scale(raw_inputs)
        (self.targets, self.target_min,
         self.target_max, self.target_range) = self._min_max_scale(raw_targets)
        print(f"Dataset loaded: {len(self)} samples")
        print(f"Input shape: {self.inputs.shape}, Target shape: {self.targets.shape}")

    @staticmethod
    def _min_max_scale(values):
        """Scale each column of ``values`` and return the statistics used.

        Returns:
            Tuple ``(scaled, col_min, col_max, col_range)``.
        """
        col_min = values.min(axis=0)
        col_max = values.max(axis=0)
        col_range = col_max - col_min
        # A constant column has zero range; dividing by 1 instead keeps it
        # finite (it normalizes to all zeros).
        col_range[col_range == 0] = 1.0
        return (values - col_min) / col_range, col_min, col_max, col_range

    def __len__(self):
        """Return the number of samples (rows of ``inputs``)."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at ``idx`` as torch tensors."""
        return torch.from_numpy(self.inputs[idx]), torch.from_numpy(self.targets[idx])

    def denormalize_input(self, normalized):
        """Map normalized inputs back to original units.

        Accepts a NumPy array or a torch tensor and returns a NumPy array.
        """
        if isinstance(normalized, torch.Tensor):
            # detach() first: .numpy() raises on tensors that require grad.
            normalized = normalized.detach().cpu().numpy()
        return normalized * self.input_range + self.input_min

    def denormalize_target(self, normalized):
        """Map normalized targets back to original units.

        Accepts a NumPy array or a torch tensor and returns a NumPy array.
        """
        if isinstance(normalized, torch.Tensor):
            # detach() first: .numpy() raises on tensors that require grad.
            normalized = normalized.detach().cpu().numpy()
        return normalized * self.target_range + self.target_min


class BaseCompressor(nn.Module):
    """Fully connected network with layer widths 4 -> 64 -> 64 -> 32 -> 4.

    A ReLU follows every linear layer except the last one.
    """

    def __init__(self):
        super().__init__()
        widths = (4, 64, 64, 32, 4)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()  # the output layer has no activation
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` through the sequential stack and return the result."""
        return self.model(x)


class LargeCompressor(nn.Module):
    """Fully connected network with layer widths 4 -> 128 -> 128 -> 64 -> 4.

    A ReLU follows every linear layer except the last one.
    """

    def __init__(self):
        super().__init__()
        widths = (4, 128, 128, 64, 4)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.pop()  # the output layer has no activation
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` through the sequential stack and return the result."""
        return self.model(x)


def compute_psnr_ssim(predictions, targets, device):
    """Return ``(psnr, ssim)`` as Python floats for predictions vs. targets.

    Both tensors are moved to ``device`` first. PSNR uses torcheval's
    ``PeakSignalNoiseRatio`` with its default arguments. SSIM uses
    torchmetrics' ``StructuralSimilarityIndexMeasure`` with
    ``gaussian_kernel=False`` and ``kernel_size=1``.
    """
    preds_dev = predictions.to(device)
    targets_dev = targets.to(device)

    def as_single_image(batch):
        # (N, C) -> (N, 1, 1, C) -> (1, 1, N, C): the whole batch is handed
        # to SSIM as one single-channel image of N rows and C columns.
        return batch.view(-1, 1, 1, batch.shape[1]).permute(1, 2, 0, 3)

    # PSNR
    psnr_metric = PeakSignalNoiseRatio().to(device)
    psnr_metric.update(preds_dev, targets_dev)
    psnr_value = psnr_metric.compute().item()

    # SSIM
    ssim_metric = StructuralSimilarityIndexMeasure(
        gaussian_kernel=False, kernel_size=1
    ).to(device)
    ssim_metric.update(as_single_image(preds_dev), as_single_image(targets_dev))
    ssim_value = ssim_metric.compute().item()

    return psnr_value, ssim_value


def compute_relative_error(predictions, targets):
    """Return the relative L2 error ``||pred - target|| / ||target||`` in percent.

    Args:
        predictions: Tensor of predicted values.
        targets: Tensor of reference values; must be broadcast-compatible
            with ``predictions``.

    Returns:
        The error as a Python float. When the target norm is zero the
        result is ``0.0`` if the predictions match exactly and ``inf``
        otherwise.
    """
    error = torch.norm(predictions - targets)
    target_norm = torch.norm(targets)
    if target_norm.item() == 0.0:
        # 0 / 0 would give NaN; an exact match on an all-zero (or empty)
        # target is reported as zero error instead.
        return 0.0 if error.item() == 0.0 else float('inf')
    return (error / target_norm * 100).item()


def train_model(model, train_loader, dataset, device, epochs, model_name, output_dir):
    """Train ``model`` with Adam and MSE loss, then save weights and normalization stats.

    Args:
        model: ``nn.Module`` to optimize in place.
        train_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        dataset: Object supporting ``len()`` and exposing ``input_min``,
            ``input_max``, ``input_range``, ``target_min``, ``target_max``
            and ``target_range`` arrays (written to the JSON side file).
        device: Device each batch is moved to.
        epochs: Number of passes over ``train_loader``.
        model_name: Prefix for the saved file names.
        output_dir: Directory for the outputs; created if missing.

    Returns:
        Dict with one list per key (``loss``, ``psnr``, ``ssim``,
        ``relative_error``, ``time_per_epoch``), one entry per epoch.
        With ``epochs == 0`` the lists are empty; the files are still
        written and the final-metrics summary is skipped.

    Raises:
        ZeroDivisionError: If ``epochs > 0`` and ``train_loader`` has length
            zero (the epoch loss is averaged over the number of batches).
    """
    os.makedirs(output_dir, exist_ok=True)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    metrics = {'loss': [], 'psnr': [], 'ssim': [], 'relative_error': [], 'time_per_epoch': []}
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\nTraining: {model_name}")
    print(f"Training samples: {len(dataset)}")
    print(f"Epochs: {epochs}")
    print(f"Device: {device}")
    print(f"Model parameters: {total_params:,}")
    # Size estimate assumes 4 bytes per parameter.
    print(f"Model size: {total_params * 4 / 1024:.2f} KB\n")

    for epoch in range(epochs):
        epoch_start = time.time()
        model.train()
        epoch_loss = 0.0
        all_predictions = []
        all_targets = []
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # Quality metrics reuse the outputs of the training forward pass,
            # so each batch reflects the weights in use at that step.
            all_predictions.append(outputs.detach())
            all_targets.append(targets)
        # Mean of the per-batch mean losses.
        epoch_loss /= len(train_loader)
        metrics['loss'].append(epoch_loss)

        all_predictions = torch.cat(all_predictions, dim=0)
        all_targets = torch.cat(all_targets, dim=0)
        psnr, ssim = compute_psnr_ssim(all_predictions, all_targets, device)
        rel_error = compute_relative_error(all_predictions, all_targets)
        metrics['psnr'].append(psnr)
        metrics['ssim'].append(ssim)
        metrics['relative_error'].append(rel_error)
        epoch_time = time.time() - epoch_start
        metrics['time_per_epoch'].append(epoch_time)

        # Log the first epoch and every tenth epoch after that.
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Epoch {epoch+1}/{epochs}: Loss={epoch_loss:.6f}, "
                  f"PSNR={psnr:.2f} dB, SSIM={ssim:.4f}, "
                  f"RE={rel_error:.2f}%, Time={epoch_time:.2f}s")

    model_path = os.path.join(output_dir, f'{model_name}_minmax.pth')
    torch.save(model.state_dict(), model_path)
    print(f"\nModel saved: {model_path}")

    # Persist the scaling statistics next to the weights so predictions can
    # be denormalized without reloading the training data.
    norm_params = {
        'input_min': dataset.input_min.tolist(),
        'input_max': dataset.input_max.tolist(),
        'input_range': dataset.input_range.tolist(),
        'target_min': dataset.target_min.tolist(),
        'target_max': dataset.target_max.tolist(),
        'target_range': dataset.target_range.tolist()
    }
    norm_path = os.path.join(output_dir, f'{model_name}_normalization.json')
    with open(norm_path, 'w') as f:
        json.dump(norm_params, f, indent=2)
    print(f"Normalization parameters saved: {norm_path}")

    print("\nTraining completed")
    if not metrics['loss']:
        # No epoch ran, so there is no "last" entry to summarize.
        print("No epochs were run; no final metrics to report.\n")
        return metrics

    print(f"Final Loss: {metrics['loss'][-1]:.6f}")
    print(f"Final PSNR: {metrics['psnr'][-1]:.2f} dB")
    print(f"Final SSIM: {metrics['ssim'][-1]:.4f}")
    print(f"Final Relative Error: {metrics['relative_error'][-1]:.2f}%")
    print(f"PSNR improvement: {metrics['psnr'][0]:.2f} \u2192 {metrics['psnr'][-1]:.2f} dB "
          f"({metrics['psnr'][-1]-metrics['psnr'][0]:+.2f} dB)\n")
    return metrics


def export_metrics_csv(metrics, output_path):
    """Write per-epoch metrics to ``output_path`` as CSV, without an index column.

    A 1-based ``epoch`` column is added and the columns are written in the
    order epoch, loss, psnr, ssim, relative_error, time_per_epoch.
    """
    column_order = ['epoch', 'loss', 'psnr', 'ssim', 'relative_error', 'time_per_epoch']
    frame = pd.DataFrame(metrics)
    frame = frame.assign(epoch=range(1, len(frame) + 1))
    frame[column_order].to_csv(output_path, index=False)
    print(f"Metrics saved: {output_path}")


# Confirmation message, printed after all definitions above have been executed.
print("Utilities loaded successfully.")

## Training

In [None]:
# --- Run configuration -------------------------------------------------
DATA_FILE = "/kaggle/input/ml-test-loader-original-data-csv/ML_test_loader_original_data.csv"
OUTPUT_DIR = "/kaggle/working/results2/base_model_offline"
EPOCHS = 150
BATCH_SIZE = 512

# --- Device selection: CUDA when available, otherwise CPU ---------------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# --- Data: the whole file is loaded and normalized once, then batched ---
dataset = FlowFieldDataset(DATA_FILE)
train_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

# --- Model and training -------------------------------------------------
model = BaseCompressor().to(device)

metrics = train_model(model, train_loader, dataset, device, EPOCHS, 'base_model', OUTPUT_DIR)

# Per-epoch history goes next to the saved weights.
export_metrics_csv(metrics, f"{OUTPUT_DIR}/base_model_metrics.csv")

In [None]:
# 2x2 grid of training curves: loss, PSNR, SSIM and relative error per epoch.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
epochs_range = range(1, len(metrics['loss']) + 1)

# (grid position, metrics key, line format, y-label, title, optional y-limits)
panel_specs = [
    ((0, 0), 'loss', 'b-', 'MSE Loss', 'Training Loss (Normalized)', None),
    ((0, 1), 'psnr', 'g-', 'PSNR (dB)', 'Peak Signal-to-Noise Ratio', None),
    ((1, 0), 'ssim', 'purple', 'SSIM', 'Structural Similarity Index', [0, 1.05]),
    ((1, 1), 'relative_error', 'r-', 'Relative Error (%)', 'Reconstruction Error', None),
]

for grid_pos, metric_key, line_fmt, y_label, panel_title, y_limits in panel_specs:
    panel = axes[grid_pos]
    panel.plot(epochs_range, metrics[metric_key], line_fmt, linewidth=2)
    panel.set_xlabel('Epoch', fontweight='bold')
    panel.set_ylabel(y_label, fontweight='bold')
    panel.set_title(panel_title, fontweight='bold')
    panel.grid(True, alpha=0.3)
    if y_limits is not None:
        panel.set_ylim(y_limits)

plt.suptitle('Base Model (64-64-32) - Offline Training',
             fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/base_model_training_progress.png", dpi=300, bbox_inches='tight')
print(f"Training plot saved: {OUTPUT_DIR}/base_model_training_progress.png")
plt.show()

## Visualization

In [None]:
# --- Visualization configuration ---------------------------------------
MODEL_DIR = "/kaggle/working/results2/base_model_offline"
VIZ_OUTPUT_DIR = "/kaggle/working/results2/base_model_offline_viz"
TIMESTEP = 0.0396

os.makedirs(VIZ_OUTPUT_DIR, exist_ok=True)

print("Base model offline visualization")
print(f"Device: {device}")

# --- Reload the saved weights and switch to inference mode --------------
saved_weights = torch.load(f"{MODEL_DIR}/base_model_minmax.pth", map_location=device)
model.load_state_dict(saved_weights)
model.eval()

print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} parameters")

# --- Predict every sample in a single forward pass ----------------------
print(f"Generating predictions for {len(dataset):,} points")
all_inputs = torch.from_numpy(dataset.inputs).to(device)
with torch.no_grad():
    all_predictions = model(all_inputs).cpu()

all_targets = torch.from_numpy(dataset.targets)

psnr, ssim = compute_psnr_ssim(all_predictions, all_targets, device)
print(f"\nEvaluation metrics (normalized space):")
print(f"PSNR: {psnr:.2f} dB")
print(f"SSIM: {ssim:.4f}")

# --- Back to original units for plotting --------------------------------
predictions_denorm = dataset.denormalize_target(all_predictions.numpy())
targets_denorm = dataset.denormalize_target(all_targets.numpy())
coords_denorm = dataset.denormalize_input(dataset.inputs)

# Keep only rows whose time coordinate (column 3) is within 1e-6 of TIMESTEP.
time_column = coords_denorm[:, 3]
timestep_mask = np.abs(time_column - TIMESTEP) < 1e-6
x, y = coords_denorm[timestep_mask, 0], coords_denorm[timestep_mask, 1]
pred_t, target_t = predictions_denorm[timestep_mask], targets_denorm[timestep_mask]
errors_t = np.abs(target_t - pred_t)

print(f"Visualizing {len(x):,} points at timestep {TIMESTEP}")

In [None]:
# One row per output feature; columns show ground truth, prediction and
# absolute error at the selected timestep.
feature_names = ['Vx', 'Vy', 'Pressure', 'TKE']
feature_indices = [0, 1, 2, 3]

fig, axes = plt.subplots(4, 3, figsize=(18, 20))

for row, (idx, name) in enumerate(zip(feature_indices, feature_names)):
    original = target_t[:, idx]
    predicted = pred_t[:, idx]
    error = errors_t[:, idx]

    # (title prefix, values used for colouring, colormap) for each column
    column_specs = (
        ('Original', original, 'jet'),
        ('Prediction', predicted, 'jet'),
        ('Error', error, 'hot'),
    )
    for col, (label, values, colormap) in enumerate(column_specs):
        panel = axes[row, col]
        points = panel.scatter(x, y, c=values, cmap=colormap, s=0.5, alpha=0.8)
        panel.set_title(f'{label}: {name}', fontweight='bold')
        panel.set_aspect('equal')
        panel.grid(True, alpha=0.3)
        plt.colorbar(points, ax=panel)

plt.suptitle(f'Base Model (64-64-32) - PSNR: {psnr:.2f} dB - Timestep: {TIMESTEP}',
             fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig(f"{VIZ_OUTPUT_DIR}/base_flow_visualization.png", dpi=150, bbox_inches='tight')
print(f"Visualization saved: {VIZ_OUTPUT_DIR}/base_flow_visualization.png")
plt.show()

# Store the evaluation summary alongside the figure.
viz_metrics = {
    'model': 'base_model_minmax.pth',
    'architecture': '4-64-64-32-4',
    'parameters': sum(p.numel() for p in model.parameters()),
    'psnr_db': float(psnr),
    'ssim': float(ssim)
}

with open(f"{VIZ_OUTPUT_DIR}/evaluation_metrics.json", 'w') as f:
    json.dump(viz_metrics, f, indent=2)

print("Base model visualization completed")