In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
from matplotlib import pyplot as plt
import pathlib

import acoustic_no.cno.cno_layers as cno_layers
from acoustic_no.cno.cno_model import CNOModel
from acoustic_no.data import ShuffledAcousticDataset

In [None]:
# Choose the compute backend in order of preference: NVIDIA CUDA, then Apple MPS, else CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

In [None]:
# Build the CNO activation layer with ReLU as its base nonlinearity; the next cell compares
# its output against plain F.relu on the same input.
# NOTE(review): the layer is never moved to `device`; if ActivationLayer owns parameters or
# buffers, applying it to a GPU/MPS tensor may fail — confirm.
cno_activation = cno_layers.ActivationLayer("relu")

In [None]:
# Compare the CNO activation with a plain pointwise ReLU on a small random 3-channel "image".
# A dedicated CPU generator makes the input reproducible without reseeding the global RNG,
# which would silently affect every later random draw (e.g. model initialisation).
demo_generator = torch.Generator().manual_seed(0)
x = torch.randn(1, 3, 32, 32, generator=demo_generator).to(device)
# Clamp the input to [0, 1] so it is valid float RGB data for imshow.
x = torch.clamp(x, 0, 1)
x_cno = cno_activation(x)  # CNO activation layer
x_base = F.relu(x)  # reference: plain pointwise ReLU

# Show the three images side by side.
# NOTE(review): imshow clips float RGB values outside [0, 1] (with a warning), so any
# overshoot produced by the CNO activation is not visible here.
panels = [
    (x, "Input Image"),
    (x_cno, "CNO Activation"),
    (x_base, "Base Activation (ReLU)"),
]
fig, ax = plt.subplots(1, 3, figsize=(10, 5))
for axis, (image, title) in zip(ax, panels):
    # Take the first batch item and reorder (C, H, W) -> (H, W, C) for imshow.
    axis.imshow(image[0].cpu().detach().numpy().transpose(1, 2, 0))
    axis.set_title(title)
plt.show()  # Display the images

# Load Dataset

In [None]:
# Location of the processed training data; the path is relative to the kernel's working directory.
TRAINING_DATA_DIR = pathlib.Path("../resources/dataset/processed/training")

# Open the full dataset; it is split into train/validation in the next cell.
dataset = ShuffledAcousticDataset(dataset_dir=TRAINING_DATA_DIR)
print(f"Dataset size: {len(dataset)}")

In [None]:
# Split the dataset by index: the first 80% for training, the remaining 20% for validation.
TRAIN_FRACTION = 0.8
train_size = int(TRAIN_FRACTION * len(dataset))
train_dataset = Subset(dataset, range(train_size))
val_dataset = Subset(dataset, range(train_size, len(dataset)))

# Keep only a small prefix of each split so experiments run quickly.
# NOTE(review): a prefix is only a *random* subset if ShuffledAcousticDataset already stores
# samples in shuffled order — confirm.
# The requested sizes are clamped to what each split actually holds. Without the clamp,
# len() over-reports the subset size (so len(train_loader) is wrong) and loading a batch
# raises IndexError from the inner `range` indices.
N_TRAIN, N_VAL = 1024, 16
train_dataset = Subset(train_dataset, range(min(N_TRAIN, len(train_dataset))))
val_dataset = Subset(val_dataset, range(min(N_VAL, len(val_dataset))))

# Loader settings shared by both splits.
# NOTE(review): shuffle=False on the training loader means every epoch sees batches in the
# same order; presumably intentional because the dataset is pre-shuffled — TODO confirm.
BATCH_SIZE = 4
NUM_WORKERS = 10
loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

# A Simple CNO Model

In [None]:
# Architecture hyperparameters for the baseline CNO.
# input_channels must match the channel count of batch["x"], and the output must match
# batch["y"] for the MSE loss used in training.
cno_config = dict(
    input_channels=193,
    hidden_channels=[64, 64],
    layer_sizes=[2, 2],
    output_channels=64,
)
model = CNOModel(**cno_config)
model.to(device)  # place parameters on the selected backend

# Display the module structure.
model

In [None]:
# Initialize training parameters
num_epochs = 8
learning_rate = 0.001
log_every = 10  # print a running-average training loss every this many batches
criterion = nn.MSELoss()  # Use Mean Squared Error loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop: one optimisation pass over train_loader, then a held-out evaluation on
# val_loader (previously built but never used) so over-fitting is visible per epoch.
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    for i, batch in enumerate(train_loader):
        inputs, targets = batch["x"].to(device), batch["y"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if (i + 1) % log_every == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / (i + 1):.4f}")

    # max(1, ...) avoids ZeroDivisionError if the loader yields no batches.
    print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {running_loss / max(1, len(train_loader)):.4f}")

    # Validation pass: no gradients, model in eval mode.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_batch in val_loader:
            val_pred = model(val_batch["x"].to(device))
            val_loss += criterion(val_pred, val_batch["y"].to(device)).item()
    print(f"Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {val_loss / max(1, len(val_loader)):.4f}")

In [None]:
# Inference on the first validation sample
model.eval()
eval_data = val_dataset[0]
x = eval_data["x"]
p = eval_data["y"]
v = eval_data["v"]  # unused in this cell
a = eval_data["a"]  # unused in this cell
with torch.no_grad():
    pred = model(x.unsqueeze(0).to(device))

# Compare the last channel of the prediction with the matching ground-truth channel.
pred_last = pred[0, -1].cpu().numpy()
true_last = p[-1].cpu().numpy()
PRESSURE_LIMIT = 10  # symmetric colour range shared by all panels

panels = [
    (pred_last, "Predicted Pressure", "viridis"),
    (true_last, "Ground Truth Pressure", "viridis"),
    # Signed error: a diverging colormap centred on 0 separates over- from under-prediction.
    (pred_last - true_last, "Difference", "RdBu_r"),
]

# Plot the results, with a colorbar per panel so the values can be read off.
fig, ax = plt.subplots(1, 3, figsize=(12, 6))
for axis, (image, title, cmap) in zip(ax, panels):
    im = axis.imshow(image, cmap=cmap, vmin=-PRESSURE_LIMIT, vmax=PRESSURE_LIMIT)
    axis.set_title(title)
    fig.colorbar(im, ax=axis, shrink=0.7)
plt.tight_layout()
plt.show()
