In [1]:
import numpy as np
from fenics import *
import matplotlib.pyplot as plt
from dolfin import *
import os
import copy
import torch
import pickle

# Create data storage class
class DataCollector:
    """Accumulate (diffusion field, solution) samples on a regular grid.

    Every snapshot keeps the diffusion-coefficient array exactly as passed in
    and the FEniCS solution resampled from mesh vertices onto a grid of the
    same shape. The gridded solution is indexed ``[x_index, y_index]``; the
    diffusion array is assumed to use the same layout.
    """

    def __init__(self, V):
        """Store the function space and start with empty sample lists.

        Args:
            V: Function space associated with the collector. It is only
                stored on the instance; the methods below do not read it.
        """
        self.V = V
        self.data = {
            'D': [],
            'u': [],
        }

    def collect_snapshot(self, d2, u):
        """Record one sample: the diffusion grid and the gridded solution.

        Args:
            d2: 2-D array of diffusion values. Stored by reference (no copy).
                Its shape ``(nx, ny)`` also defines the grid that the solution
                is binned onto.
            u: Solved FEniCS ``Function``; it is sampled at the vertices of
                the mesh it is defined on.

        Raises:
            ValueError: If ``d2`` is not two-dimensional.
        """
        D_array = d2
        # Take the grid size from the data instead of module-level globals.
        grid_nx, grid_ny = np.shape(D_array)

        # Use the mesh the solution actually lives on rather than a global.
        solution_mesh = u.function_space().mesh()
        mesh_coordinates = solution_mesh.coordinates()
        u_values = u.compute_vertex_values(solution_mesh)

        # Map each vertex to a grid cell; the index is truncated and clamped
        # to the last cell. Assumes vertex coordinates lie in [0, 1].
        x_idx = np.minimum(
            (mesh_coordinates[:, 0] * (grid_nx - 1)).astype(int), grid_nx - 1
        )
        y_idx = np.minimum(
            (mesh_coordinates[:, 1] * (grid_ny - 1)).astype(int), grid_ny - 1
        )

        # Several vertices can land in the same cell, so accumulate the sum
        # and the hit count. np.add.at is unbuffered, so repeated indices are
        # added one after another just like an explicit loop would.
        u_array = np.zeros((grid_nx, grid_ny))
        counter = np.zeros((grid_nx, grid_ny))
        np.add.at(u_array, (x_idx, y_idx), u_values)
        np.add.at(counter, (x_idx, y_idx), 1)

        # Average values where multiple points map to the same cell;
        # cells that received no vertex stay at zero (avoids division by zero)
        mask = counter > 0
        u_array[mask] /= counter[mask]

        # Store arrays directly
        self.data['D'].append(D_array)
        self.data['u'].append(u_array)

    def save_to_file(self, filename):
        """Write all samples to a NumPy ``.npz`` archive with keys ``D`` and ``u``.

        Args:
            filename: Target path passed to ``np.savez``.
        """
        # Convert lists to numpy arrays (one leading sample axis)
        D_data = np.array(self.data['D'])
        u_data = np.array(self.data['u'])

        # Save using numpy's savez
        np.savez(filename, D=D_data, u=u_data)

    def save_to_pkl(self, filename):
        """
        Pickle the samples as a list with one entry per sample.

        Dataset = [
            [X1, Y1, Theta1, Inputs_funcs1],
            [X2, Y2, Theta2, Inputs_funcs2],
            ...
        ]

        Per sample (N = number of grid points):
            X:            (N, 2) grid coordinates spanning [0, 1] x [0, 1]
            Y:            (N, 1) diffusion field D
            Theta:        empty array (no global parameters)
            Inputs_funcs: 1-tuple holding an (N, 3) array of [x, y, u]

        Args:
            filename: Path of the pickle file to write.
        """
        dataset = []

        for D_field, u_field in zip(self.data['D'], self.data['u']):
            grid_nx, grid_ny = np.shape(D_field)

            # Create mesh grid for X (input points).
            # The fields are indexed [x, y] and flattened in C order, so the
            # coordinate grids must use 'ij' indexing. The default 'xy'
            # ordering would pair every value with transposed coordinates.
            xx = np.linspace(0, 1, grid_nx)
            yy = np.linspace(0, 1, grid_ny)
            x_grid, y_grid = np.meshgrid(xx, yy, indexing='ij')
            X = np.column_stack((x_grid.flatten(), y_grid.flatten()))

            # Y: diffusion field values (N x 1)
            Y = D_field.flatten()[:, np.newaxis]

            # Theta: global parameters (none for this dataset)
            Theta = np.array([])

            # Inputs_funcs: u field with coordinates
            u_values = u_field.flatten()[:, np.newaxis]
            Inputs_funcs = (np.hstack((X, u_values)),)

            # Add this sample to the dataset
            dataset.append([X, Y, Theta, Inputs_funcs])

        # Save as pickle file
        with open(filename, 'wb') as f:
            pickle.dump(dataset, f)

def verify_mio_dataset(filename):
    """Load a pickled dataset and print a summary of its first sample.

    The file is expected to hold a sequence of ``[X, Y, Theta, Inputs_funcs]``
    entries, where ``Inputs_funcs[0]`` is an array whose third column holds
    the input-function (u) values.

    Args:
        filename: Path of the pickle file to inspect.

    Returns:
        The loaded dataset object. An empty dataset is returned as-is after
        the sample count has been printed.
    """
    # NOTE: pickle.load can execute arbitrary code while deserialising;
    # only call this on files produced by a trusted source.
    with open(filename, 'rb') as f:
        dataset = pickle.load(f)

    print(f"Dataset contains {len(dataset)} samples")

    # Nothing to inspect; avoid an IndexError on dataset[0]
    if len(dataset) == 0:
        return dataset

    # Check first sample
    X, Y, Theta, Inputs_funcs = dataset[0]
    # Columns 0-1 of the input function are coordinates, column 2 the values
    u_values = Inputs_funcs[0][:, 2]

    print(f"X shape: {X.shape} (N x N_in)")
    print(f"Y shape: {Y.shape} (N x N_out)")
    print(f"Theta shape: {Theta.shape} (N_theta,)")
    print(f"Inputs_funcs: {len(Inputs_funcs)} functions")
    # Y carries the D field; the input function carries u. The previous
    # version printed the u values under the "D range" label.
    print(f"D range: [{Y.min():.4f}, {Y.max():.4f}]")
    print(f"u range: [{u_values.min():.4f}, {u_values.max():.4f}]")

    for i, inp in enumerate(Inputs_funcs):
        print(f"  Input function {i+1} shape: {inp.shape}")

    return dataset

def GRF2d(Nx, Ny, alpha, tau):
    """Sample a non-negative 2-D random field with a power-law spectrum.

    Complex Gaussian noise is drawn on the half-spectrum grid, scaled by
    ``(|k| + tau) ** (-alpha)``, transformed back with an inverse real FFT,
    and the absolute value of the result is returned.

    Args:
        Nx: Number of grid points along the first axis.
        Ny: Number of grid points along the second axis.
        alpha: Exponent of the spectral weight ``(|k| + tau) ** (-alpha)``.
        tau: Positive shift added to ``|k|``; keeps the k = 0 weight finite.

    Returns:
        Array of shape ``(Nx, Ny)`` with non-negative entries.

    Raises:
        AssertionError: If ``tau`` is not positive.
    """
    assert tau > 0, 'tau must be positive'

    # Integer wavenumbers: full spectrum along x, half spectrum along y
    kx = np.fft.fftfreq(Nx, 1/Nx)
    ky = np.fft.rfftfreq(Ny, 1/Ny)
    Kx, Ky = np.meshgrid(kx, ky)
    # meshgrid returns (len(ky), len(kx)); transpose to (Nx, Ny//2 + 1)
    K = np.sqrt(Kx**2 + Ky**2).T

    lmbda = (K + tau)**(-alpha)
    # Real part is drawn first, then the imaginary part
    eta = (
        np.random.randn(*K.shape) + 1j*np.random.randn(*K.shape)
    )

    uhat = lmbda*eta
    # Pass the output shape explicitly: without `s`, irfft2 assumes an even
    # last axis and would return (Nx, Ny - 1) whenever Ny is odd.
    u = np.fft.irfft2(uhat, s=(Nx, Ny), norm='forward')
    return np.abs(u)

def create_diffusion_coefficient(V, d2):
    """Build a FEniCS ``Function`` on ``V`` from a gridded coefficient array.

    Each degree of freedom takes the value of the grid cell containing its
    coordinate: the index is ``int(coord * n)`` along each axis, clamped to
    the last cell so that coordinates equal to 1 stay in range.
    Assumes dof coordinates lie in [0, 1] along both axes.

    Args:
        V: Function space the coefficient is defined on.
        d2: 2-D array of coefficient values indexed ``[x_index, y_index]``.

    Returns:
        A ``Function`` on ``V`` whose dof values are sampled from ``d2``.
    """
    # Take the grid size from the array itself rather than a module-level
    # ``N`` so the function also works for grids of any (non-square) size.
    grid_values = np.asarray(d2, dtype=float)
    n_x, n_y = grid_values.shape

    D = Function(V)
    dof_coordinates = V.tabulate_dof_coordinates()

    # Cell index of every dof, clamped to the last cell on the upper boundary
    grid_x = np.minimum((dof_coordinates[:, 0] * n_x).astype(int), n_x - 1)
    grid_y = np.minimum((dof_coordinates[:, 1] * n_y).astype(int), n_y - 1)

    # Write all dof values in one call instead of one vector access per dof
    d = D.vector()
    d.set_local(grid_values[grid_x, grid_y])
    d.apply("insert")

    return D

# Main simulation code
output_dir = "Transformer2.0Train10000"
# exist_ok=True replaces the exists()/makedirs() pair and does not fail if
# the directory appears between the check and the creation
os.makedirs(output_dir, exist_ok=True)

# Number of samples to generate
num_samples = 10000

# Grid resolution and random-field parameters. They are the same for every
# sample, so they are set once instead of being reassigned in the loop.
N = 32
nx = ny = N
alpha = 5
tau = 2

# The mesh and function space do not depend on the sample: build them once
# rather than once per iteration.
mesh = RectangleMesh(Point(0, 0), Point(1, 1), nx, ny)
V = FunctionSpace(mesh, 'P', 1)
data_collector = DataCollector(V)


# Define boundary conditions: one predicate per edge of the unit square
def boundary_left(x, on_boundary):
    return on_boundary and near(x[0], 0, 1e-14)


def boundary_right(x, on_boundary):
    return on_boundary and near(x[0], 1, 1e-14)


def boundary_bottom(x, on_boundary):
    return on_boundary and near(x[1], 0, 1e-14)


def boundary_top(x, on_boundary):
    return on_boundary and near(x[1], 1, 1e-14)


# u = 1 on the left and bottom edges, u = 0 on the right and top edges.
# NOTE(review): the corners (0, 1) and (1, 0) belong to two edges with
# different prescribed values; confirm the resulting corner values are the
# intended ones.
bc_left = DirichletBC(V, Constant(1.0), boundary_left)
bc_right = DirichletBC(V, Constant(0.0), boundary_right)
bc_bottom = DirichletBC(V, Constant(1.0), boundary_bottom)
bc_top = DirichletBC(V, Constant(0.0), boundary_top)

bc = [bc_left, bc_right, bc_bottom, bc_top]

# Sample-independent pieces of the variational problem
u_trial = TrialFunction(V)
v = TestFunction(V)
f = Constant(0.0)
L = f*v*dx

# Generate multiple samples
for sample in range(num_samples):
    # Generate random diffusion field
    d_field = GRF2d(nx, ny, alpha, tau)
    d2 = d_field.copy()

    # Set diffusion coefficient
    D = create_diffusion_coefficient(V, d2)

    # Define variational problem.
    # The form previously also contained
    #   inner(grad(D), grad(u))*v*dx - inner(grad(u), grad(D))*v*dx,
    # two terms that cancel identically, so only the diffusion term is kept.
    # NOTE(review): if a non-divergence-form operator was intended, the
    # cancelled terms need to be revisited rather than restored as they were.
    a = -inner(grad(u_trial), D*grad(v))*dx

    # Solve
    u = Function(V)
    solve(a == L, u, bc)

    # Collect data
    data_collector.collect_snapshot(d2, u)

# Save the data in both formats
data_collector.save_to_file(f"{output_dir}/transformer2.0_train10000.npz")
data_collector.save_to_pkl(f"{output_dir}/transformer2.0_train10000.pkl")

dataset = verify_mio_dataset(f"{output_dir}/transformer2.0_train10000.pkl")

[hs394-dt-04.egr.duke.edu:2827474] shmem: mmap: an error occurred while determining whether or not /tmp/ompi.hs394-dt-04.1001/jf.0/1061158912/shared_mem_cuda_pool.hs394-dt-04 could be created.
[hs394-dt-04.egr.duke.edu:2827474] create_and_attach: unable to create shared memory BTL coordinating structure :: size 134217728 


Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational problem.
Solving linear variational p