# In [4]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Sample a square n_points x n_points grid covering [-2, 2] on both axes
# (200 points per axis -> 40,000 grid points / training samples).
n_points = 200
x = np.linspace(start=-2, stop=2, num=n_points)
y = np.linspace(start=-2, stop=2, num=n_points)
# indexing='xy' (NumPy's default) gives x_grid[i, j] == x[j] and y_grid[i, j] == y[i].
x_grid, y_grid = np.meshgrid(x, y, indexing='xy')

# Define the complex 2D function
def complex_2d_function(x, y):
    return np.sin(5 * x) * np.cos(5 * y) * np.exp(-(x**2 + y**2))

# Generate the function values
f_xy = complex_2d_function(x_grid, y_grid)

# Prepare training data
x_flat = x_grid.flatten()
y_flat = y_grid.flatten()
z_flat = f_xy.flatten()

# Create the input features as pairs of (x, y)
train_input = np.vstack((x_flat, y_flat)).T
train_label = z_flat

# Convert to PyTorch tensors
train_input_tensor = torch.tensor(train_input, dtype=torch.float32)
train_label_tensor = torch.tensor(train_label, dtype=torch.float32).view(-1, 1)

# Create DataLoader for batch processing
batch_size = 64
train_dataset = TensorDataset(train_input_tensor, train_label_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Define the MLP model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(2, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, 1)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize the model, loss function and optimizer
model = MLP()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 200

model.train()
for epoch in tqdm(range(num_epochs)):
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
    
    if (epoch+1) % 20 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Evaluation
model.eval()
with torch.no_grad():
    predictions = model(train_input_tensor.to(device)).cpu().numpy()

# Reshape predictions to the original grid shape
predictions_grid = predictions.reshape(x_grid.shape)

# Plotting the original and predicted function
fig = plt.figure(figsize=(12, 5))

# Original function
ax1 = fig.add_subplot(121, projection='3d')
ax1.plot_surface(x_grid, y_grid, f_xy, cmap='viridis')
ax1.set_title('Original Function')
ax1.set_xlabel('x')
ax1.set_ylabel('y')

# Predicted function
ax2 = fig.add_subplot(122, projection='3d')
ax2.plot_surface(x_grid, y_grid, predictions_grid, cmap='viridis')
ax2.set_title('MLP Predicted Function')
ax2.set_xlabel('x')
ax2.set_ylabel('y')

plt.show()

# Heatmap of the original and predicted function
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.imshow(f_xy, cmap='viridis', extent=[-2, 2, -2, 2])
plt.colorbar()
plt.title('Heatmap of Original Function')
plt.xlabel('x')
plt.ylabel('y')

plt.subplot(1, 2, 2)
plt.imshow(predictions_grid, cmap='viridis', extent=[-2, 2, -2, 2])
plt.colorbar()
plt.title('Heatmap of MLP Predicted Function')
plt.xlabel('x')
plt.ylabel('y')

plt.show()


KeyboardInterrupt: 