In [24]:
import numpy as np
import pandas as pd
import random
import torch

In [25]:
# Set the random seed for reproducibility
RANDOM_STATE = 0
N_JOBS = 8
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
random.seed(RANDOM_STATE)

In [26]:
HYPERPARAMETERS = {
    'Input Preprocessing' : {
        'Mask Proportions' : [0.1, 0.2, 0.4, 0.8],
        'Batch Size' : 4
    },
    'Input Embedding' : {
        'Surface Embedding' : {
            'Grid Dimension' : 3,
            'Channels Dimension' : 8,
        },
        'Pre-Encoder' : {
            'Branch Channels Dimension' : 4,
            'Number of Blocks' : 2,
        }
    },
}

## Dataset

In [27]:
aapl_googl_data = pd.read_csv('volatility_surface_AAPL_GOOGL_2013_01_2013_06.csv', parse_dates=True, index_col=[0, 1], date_format="ISO8601")
aapl_googl_data

Unnamed: 0_level_0,Unnamed: 1_level_0,Log Moneyness,Time to Maturity,Implied Volatility,Market Return,Market Volatility,Treasury Rate
Datetime,Symbol,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
2013-01-02,AAPL,-0.316688,0.007937,0.3726,0.025086,14.680000,0.055
2013-01-02,AAPL,-0.316688,0.007937,0.6095,0.025086,14.680000,0.055
2013-01-02,AAPL,-0.304266,0.007937,0.3726,0.025086,14.680000,0.055
2013-01-02,AAPL,-0.304266,0.007937,0.6095,0.025086,14.680000,0.055
2013-01-02,AAPL,-0.291996,0.007937,0.3726,0.025086,14.680000,0.055
...,...,...,...,...,...,...,...
2013-06-28,GOOGL,0.427518,2.253968,0.2430,-0.004299,16.860001,0.030
2013-06-28,GOOGL,0.434898,2.253968,0.2383,-0.004299,16.860001,0.030
2013-06-28,GOOGL,0.434898,2.253968,0.2426,-0.004299,16.860001,0.030
2013-06-28,GOOGL,0.442224,2.253968,0.2402,-0.004299,16.860001,0.030


In [28]:
import gc
from joblib_progress import joblib_progress
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.pipeline import Pipeline
from joblib import Parallel, delayed

def implied_volatility_surface_datasets(
    options_market_data, 
    proportions, 
    n_jobs=1,
    random_state=0,
    n_chunks=1
):
    def mask_surface(
        date, 
        symbol, 
        surface, 
        rng
    ):
        def mask_surface_with_proportion(
            surface_data, 
            proportion, 
        ):
            n_clusters = int(np.ceil(1 / proportion))
            points_coordinates = surface_data['points_coordinates']
            points_volatilities = surface_data['points_volatilities']

            # Create the clustering pipeline
            pipeline = Pipeline([
                ('scaler', StandardScaler()),
                ('kmeans', KMeans(n_clusters=n_clusters, random_state=random_state, n_init='auto'))
            ])
            
            # Fit the pipeline to the data points
            labels = pipeline.fit_predict(points_coordinates)
            
            single_surface_datasets = []
            for cluster in range(n_clusters):
                cluster_indices = np.where(labels == cluster)[0]
                num_to_mask = int(np.ceil(len(cluster_indices) * proportion))
                masked_indices = rng.choice(cluster_indices, size=num_to_mask, replace=False)
                
                for idx in masked_indices:
                    unmasked_indices = np.setdiff1d(cluster_indices, masked_indices)

                    single_surface_datasets.append({
                        'Datetime': surface_data['datetime'],
                        'Symbol': surface_data['symbol'],
                        'Market Features': surface_data['market_features'],
                        'Input Surface': {
                            'Log Moneyness': points_coordinates[unmasked_indices, 0],
                            'Time to Maturity': points_coordinates[unmasked_indices, 1],
                            'Implied Volatility': points_volatilities[unmasked_indices]
                        },
                        'Query Point': {
                            'Log Moneyness': points_coordinates[idx, 0],
                            'Time to Maturity': points_coordinates[idx, 1]
                        },
                        'Target Volatility': points_volatilities[idx]
                    })

            return single_surface_datasets
        
        surface_data = {
            'datetime': date,
            'symbol': symbol,
            'points_coordinates': surface[['Log Moneyness', 'Time to Maturity']].values,
            'points_volatilities': surface['Implied Volatility'].values,
            'market_features': {
                'Market Return': surface['Market Return'].values[0],
                'Market Volatility': surface['Market Volatility'].values[0],
                'Treasury Rate': surface['Treasury Rate'].values[0]
            }
        }
        
        datasets = []
        for proportion in proportions:
            datasets.extend(mask_surface_with_proportion(surface_data, proportion))

        return datasets

    rng = np.random.default_rng(random_state)
    all_surfaces = list(options_market_data.groupby(level=['Datetime', 'Symbol']))
    n_surfaces = len(all_surfaces)
    
    # Split the array into 'n_chunks' chunks
    chunks = np.array_split(range(n_surfaces), n_chunks)
    # Initialize the list to hold all results
    surface_datasets = []
    # Process each chunk sequentially
    with joblib_progress("Surfaces...", total=n_surfaces): 
        for chunk in chunks:
            # Process the current chunk in parallel
            output = Parallel(n_jobs=n_jobs)(
                delayed(mask_surface)(date, symbol, surface, rng)
                for (date, symbol), surface in [all_surfaces[i] for i in chunk]
            )
            # Extend the overall results with the current chunk's results
            surface_datasets.extend(output)
            gc.collect()  

    # Flatten the list of lists into a single list of datasets
    return [item for sublist in surface_datasets for item in sublist]

aapl_googl_dataset = implied_volatility_surface_datasets(
    aapl_googl_data,
    HYPERPARAMETERS['Input Preprocessing']['Mask Proportions'],
    n_jobs=N_JOBS,
    random_state=RANDOM_STATE,
    n_chunks=4
)

Output()

In [29]:
# import pickle

# with open('aapl_googl_dataset.pickle', 'wb') as handle:
#     pickle.dump(aapl_googl_dataset, handle, protocol=pickle.HIGHEST_PROTOCOL)

# with open('aapl_googl_dataset.pickle', 'rb') as handle:
#     aapl_googl_dataset_ = pickle.load(handle)


In [30]:
len(aapl_googl_dataset)

863511

In [31]:
aapl_googl_dataset[0]

{'Datetime': Timestamp('2013-01-02 00:00:00'),
 'Symbol': 'AAPL',
 'Market Features': {'Market Return': 0.0250861159586972,
  'Market Volatility': 14.68000030517578,
  'Treasury Rate': 0.0549999997019767},
 'Input Surface': {'Log Moneyness': array([-0.74747141, -0.72842322, -0.72842322, -0.70973108, -0.69138194,
         -0.69138194, -0.67336344, -0.67336344, -0.63827212, -0.63827212,
         -0.62117768, -0.62117768, -0.60437057, -0.60437057, -0.58784126,
         -0.58784126, -0.57158074, -0.5555804 , -0.5555804 , -0.53983205,
         -0.53983205, -0.52432786, -0.52432786, -0.50906039, -0.50906039,
         -0.49402251, -0.49402251, -0.47920742, -0.47920742, -0.46460862,
         -0.46460862, -0.45021989, -0.45021989, -0.43603525, -0.43603525,
         -0.42204901, -0.42204901, -0.40825569, -0.40825569, -0.39465004,
         -0.39465004, -0.74747141, -0.74747141, -0.72842322, -0.70973108,
         -0.70973108, -0.69138194, -0.69138194, -0.67336344, -0.67336344,
         -0.65566386

In [32]:
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import Dataset

class IVSurfaceDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        data_point = self.data[idx]

        # Convert each component of the data point into tensors as appropriate
        return {
            'Datetime': data_point['Datetime'],
            'Symbol': data_point['Symbol'],
            'Market Features': {
                'Market Return': torch.tensor(data_point['Market Features']['Market Return'], dtype=torch.float32),
                'Market Volatility': torch.tensor(data_point['Market Features']['Market Volatility'], dtype=torch.float32),
                'Treasury Rate': torch.tensor(data_point['Market Features']['Treasury Rate'], dtype=torch.float32),
            },
            'Input Surface': {
                'Log Moneyness': torch.tensor(data_point['Input Surface']['Log Moneyness'], dtype=torch.float32),
                'Time to Maturity': torch.tensor(data_point['Input Surface']['Time to Maturity'], dtype=torch.float32),
                'Implied Volatility': torch.tensor(data_point['Input Surface']['Implied Volatility'], dtype=torch.float32),
            },
            'Query Point': {
                'Log Moneyness': torch.tensor(data_point['Query Point']['Log Moneyness'], dtype=torch.float32),
                'Time to Maturity': torch.tensor(data_point['Query Point']['Time to Maturity'], dtype=torch.float32),
            },
            'Target Volatility': torch.tensor(data_point['Target Volatility'], dtype=torch.float32),
        }

    def collate_fn(batch):
        # Organize batch data by structuring as a dictionary with batched components
        batched_data = {
            'Datetime': [item['Datetime'] for item in batch],
            'Symbol': [item['Symbol'] for item in batch],
            'Market Features': {
                'Market Return': default_collate([item['Market Features']['Market Return'] for item in batch]),
                'Market Volatility': default_collate([item['Market Features']['Market Volatility'] for item in batch]),
                'Treasury Rate': default_collate([item['Market Features']['Treasury Rate'] for item in batch]),
            },
            'Input Surface': {
                'Log Moneyness': [item['Input Surface']['Log Moneyness'] for item in batch],
                'Time to Maturity': [item['Input Surface']['Time to Maturity'] for item in batch],
                'Implied Volatility': [item['Input Surface']['Implied Volatility'] for item in batch],
            },
            'Query Point': {
                'Log Moneyness': default_collate([item['Query Point']['Log Moneyness'] for item in batch]),
                'Time to Maturity': default_collate([item['Query Point']['Time to Maturity'] for item in batch]),
            },
            'Target Volatility': default_collate([item['Target Volatility'] for item in batch]),
        }

        return batched_data



aapl_googl_data_loader = DataLoader(
    IVSurfaceDataset(aapl_googl_dataset), 
    batch_size=HYPERPARAMETERS['Input Preprocessing']['Batch Size'], 
    shuffle=True, 
    num_workers=0, 
    collate_fn=IVSurfaceDataset.collate_fn
)

# Fetch one batch from the DataLoader
batch = next(iter(aapl_googl_data_loader))
batch

{'Datetime': [Timestamp('2013-03-19 00:00:00'),
  Timestamp('2013-03-19 00:00:00'),
  Timestamp('2013-04-08 00:00:00'),
  Timestamp('2013-03-27 00:00:00')],
 'Symbol': ['GOOGL', 'AAPL', 'GOOGL', 'AAPL'],
 'Market Features': {'Market Return': tensor([-0.0024, -0.0024,  0.0063, -0.0006]),
  'Market Volatility': tensor([14.3900, 14.3900, 13.1900, 13.1500]),
  'Treasury Rate': tensor([0.0700, 0.0700, 0.0550, 0.0800])},
 'Input Surface': {'Log Moneyness': [tensor([-0.1766, -0.1766, -0.1620, -0.1334, -0.1194, -0.1194, -0.1056, -0.1056,
           -0.0853, -0.0786, -0.0653, -0.0653, -0.0588, -0.0588, -0.0523, -0.0458,
           -0.0394, -0.0394, -0.0330, -0.0330, -0.0266, -0.0266, -0.0203, -0.0141,
           -0.0078, -0.0078, -0.0016, -0.0016,  0.0106,  0.0106,  0.0167,  0.0228,
            0.0228,  0.0288,  0.0288,  0.0347,  0.0347,  0.0407,  0.0407,  0.0524,
            0.0583,  0.0641,  0.0698,  0.0756,  0.0756,  0.0813,  0.0813,  0.0869,
            0.0869,  0.0926,  0.0926,  0.0982,  0

## Input Embedding

### Surface Embedding

#### Components

In [33]:
import torch
import torch.nn as nn

class SurfaceBatchNorm(nn.Module):
    def __init__(self, num_features=1, eps=1e-5, momentum=0.1):
        super(SurfaceBatchNorm, self).__init__()
        self.log_moneyness_bn = nn.BatchNorm1d(num_features, eps, momentum)
        self.time_to_maturity_bn = nn.BatchNorm1d(num_features, eps, momentum)
        self.implied_volatility_bn = nn.BatchNorm1d(num_features, eps, momentum)
        self.market_return_bn = nn.BatchNorm1d(num_features, eps, momentum)
        self.market_volatility_bn = nn.BatchNorm1d(num_features, eps, momentum)
        self.treasury_rate_bn = nn.BatchNorm1d(num_features, eps, momentum)

    def forward(self, batch):
        # Concatenate all tensors from the Input Surface into one tensor for each feature
        input_surface_log_moneyness = torch.cat([x for x in batch['Input Surface']['Log Moneyness']])
        input_surface_time_to_maturity = torch.cat([x for x in batch['Input Surface']['Time to Maturity']])
        input_surface_implied_volatility = torch.cat([x for x in batch['Input Surface']['Implied Volatility']])

        # Concatenate Input Surface tensors with Query Point tensors
        total_log_moneyness = torch.cat([input_surface_log_moneyness, batch['Query Point']['Log Moneyness']])
        total_time_to_maturity = torch.cat([input_surface_time_to_maturity, batch['Query Point']['Time to Maturity']])

        # Normalize Log Moneyness and Time to Maturity
        norm_log_moneyness = self.log_moneyness_bn(total_log_moneyness.unsqueeze(1)).squeeze(1)
        norm_time_to_maturity = self.time_to_maturity_bn(total_time_to_maturity.unsqueeze(1)).squeeze(1)

        # Normalize Implied Volatility (only from Input Surface)
        norm_implied_volatility = self.implied_volatility_bn(input_surface_implied_volatility.unsqueeze(1)).squeeze(1)

        # Split the normalized results back to corresponding structures
        input_surface_sizes = [len(x) for x in batch['Input Surface']['Log Moneyness']]
        total_input_size = sum(input_surface_sizes)

        # Normalizing Market Features
        market_features = batch['Market Features']
        norm_market_return = self.market_return_bn(market_features['Market Return'].unsqueeze(1)).squeeze(1)
        norm_market_volatility = self.market_volatility_bn(market_features['Market Volatility'].unsqueeze(1)).squeeze(1)
        norm_treasury_rate = self.treasury_rate_bn(market_features['Treasury Rate'].unsqueeze(1)).squeeze(1)

        # Reconstructing the batch with normalized data
        output = {
            'Datetime': batch['Datetime'],
            'Symbol': batch['Symbol'],
            'Market Features': {
                'Market Return': norm_market_return,
                'Market Volatility': norm_market_volatility,
                'Treasury Rate': norm_treasury_rate
            },
            'Input Surface': {
                'Log Moneyness': list(torch.split(norm_log_moneyness[:total_input_size], input_surface_sizes)),
                'Time to Maturity': list(torch.split(norm_time_to_maturity[:total_input_size], input_surface_sizes)),
                'Implied Volatility': list(torch.split(norm_implied_volatility, input_surface_sizes))
            },
            'Query Point': {
                'Log Moneyness': norm_log_moneyness[total_input_size:],
                'Time to Maturity': norm_time_to_maturity[total_input_size:]
            },
            'Target Volatility': batch['Target Volatility']
        }

        return output

# Usage
surfacebatchnorm = SurfaceBatchNorm()
processed_batch = surfacebatchnorm(batch)
processed_batch

{'Datetime': [Timestamp('2013-03-19 00:00:00'),
  Timestamp('2013-03-19 00:00:00'),
  Timestamp('2013-04-08 00:00:00'),
  Timestamp('2013-03-27 00:00:00')],
 'Symbol': ['GOOGL', 'AAPL', 'GOOGL', 'AAPL'],
 'Market Features': {'Market Return': tensor([-0.5515, -0.5515,  1.2702, -0.1672], grad_fn=<SqueezeBackward1>),
  'Market Volatility': tensor([ 0.9997,  0.9997, -0.9669, -1.0325], grad_fn=<SqueezeBackward1>),
  'Treasury Rate': tensor([ 0.1320,  0.1320, -1.4519,  1.1879], grad_fn=<SqueezeBackward1>)},
 'Input Surface': {'Log Moneyness': [tensor([-0.8688, -0.8688, -0.8136, -0.7055, -0.6526, -0.6526, -0.6005, -0.6005,
           -0.5235, -0.4982, -0.4481, -0.4481, -0.4233, -0.4233, -0.3987, -0.3742,
           -0.3499, -0.3499, -0.3257, -0.3257, -0.3017, -0.3017, -0.2779, -0.2541,
           -0.2306, -0.2306, -0.2072, -0.2072, -0.1608, -0.1608, -0.1378, -0.1149,
           -0.1149, -0.0922, -0.0922, -0.0696, -0.0696, -0.0472, -0.0472, -0.0027,
            0.0194,  0.0413,  0.0631,  0.084

In [34]:
import torch
import torch.nn as nn
import numpy as np

# class ParametricContinuousKernel(nn.Module):
#     def __init__(self, input_dim, hidden_dim, hidden_layers, output_dim=1, dropout_prob=0.1):
#         super(ParametricContinuousKernel, self).__init__()
#         layers = []
#         current_dim = input_dim
#         for _ in range(hidden_layers):
#             layers.append(nn.Linear(current_dim, hidden_dim))
#             layers.append(nn.GELU())
#             layers.append(nn.Dropout(dropout_prob))
#             current_dim = hidden_dim
#         layers.append(nn.Linear(hidden_dim, output_dim))
#         self.net = nn.Sequential(*layers)

#     def forward(self, x):
#         return self.net(x)

class EllipticalRBFKernel(nn.Module):
    def __init__(self, input_dim):
        super(EllipticalRBFKernel, self).__init__()
        # Initialize the bandwidth parameters for each dimension
        # We use log-space parameterization for stability in optimization (exp to ensure positivity)
        self.log_bandwidth = nn.Parameter(torch.zeros(input_dim))  # Initialized to exp(0) = 1

    def forward(self, distances):
        # Scale the distances by the bandwidths
        # torch.exp(self.log_bandwidth) converts log bandwidth back to the standard scale
        scaled_distances = distances / torch.exp(self.log_bandwidth)

        # Compute the RBF kernel output using the scaled distances
        # The RBF kernel formula exp(-0.5 * (scaled distance)^2)
        kernel_values = torch.exp(-0.5 * torch.sum(scaled_distances ** 2, dim=-1))

        return kernel_values

class SurfaceContinuousKernelEmbedding(nn.Module):
    def __init__(self, grid_dim):
        super(SurfaceContinuousKernelEmbedding, self).__init__()
        self.grid_dim = grid_dim
        self.kernel = EllipticalRBFKernel(input_dim=2)

        # Create a regular grid in (0, 1)x(0, 1), excluding 0 and 1
        grid_points = torch.linspace(1 / (grid_dim + 1), 1 - 1 / (grid_dim + 1), grid_dim)
        mesh_x, mesh_y = torch.meshgrid(grid_points, grid_points, indexing='ij')
        self.grid_points = torch.stack([mesh_x.flatten(), mesh_y.flatten()], dim=-1)
        self.grid_points = torch.erfinv(2 * self.grid_points - 1) * np.sqrt(2)  # inverse CDF of normal

    def forward(self, input_surface_batch):
        batch_size = len(input_surface_batch['Log Moneyness'])
        batch_embedded_surfaces = []

        for i in range(batch_size):
            # Extract the coordinates and implied volatilities for each surface in the batch
            surface_coords = torch.stack([
                input_surface_batch['Log Moneyness'][i], 
                input_surface_batch['Time to Maturity'][i]
            ], dim=-1)
            surface_ivs = input_surface_batch['Implied Volatility'][i]

            # Initialize the output grid for the current surface
            embedded_surface = torch.zeros((self.grid_dim, self.grid_dim), dtype=torch.float32, device=surface_coords.device)

            # Compute the convolution for each point on the output grid
            for idx, grid_point in enumerate(self.grid_points):
                # Calculate the distance from each input point to the current grid point
                point_differences = surface_coords - grid_point

                # Apply the parametric kernel to these differences
                kernel_outputs = self.kernel(point_differences)

                # Compute the weighted sum of IVs based on the kernel outputs
                embedded_surface[idx // self.grid_dim, idx % self.grid_dim] = (kernel_outputs * surface_ivs).sum()

            # Append the encoded surface for this input surface to the batch list
            batch_embedded_surfaces.append(embedded_surface)

        # Stack all encoded surfaces to form a batch tensor
        return torch.stack(batch_embedded_surfaces)


# Example of initializing and using this module
grid_dim = HYPERPARAMETERS['Input Embedding']['Surface Embedding']['Grid Dimension']
# kernel_hidden_dim = HYPERPARAMETERS['Input Embedding']['Surface Embedding']['Kernel Hidden Layer Dimension']
# kernel_hidden_layers = HYPERPARAMETERS['Input Embedding']['Surface Embedding']['Kernel Hidden Layer Count']
# kernel_dropout_prob = HYPERPARAMETERS['Input Embedding']['Surface Embedding']['Kernel Dropout Probability']

continuous_kernel_embedding = SurfaceContinuousKernelEmbedding(grid_dim=grid_dim)
continuous_kernel_embedding_batch = continuous_kernel_embedding(processed_batch['Input Surface'])
continuous_kernel_embedding_batch

tensor([[[-153.7712, -177.7758, -159.2608],
         [-245.6186, -271.3016, -232.3041],
         [-253.6777, -276.9804, -233.8199]],

        [[   6.2822,    6.6540,    5.2081],
         [  19.3516,   20.0094,   14.9506],
         [  40.7884,   41.6916,   30.4323]],

        [[  47.4555,  -12.7430,  -64.3593],
         [  35.0965,  -41.7696, -101.9358],
         [  13.6700,  -58.6713, -110.4599]],

        [[ 267.1636,  204.8168,  101.9104],
         [ 215.7052,  162.4388,   78.6083],
         [ 130.2385,   96.9473,   46.0434]]], grad_fn=<StackBackward0>)

In [35]:
import torch
import torch.nn as nn

class SurfaceProjectionEmbedding(nn.Module):
    def __init__(self, in_channels, d_embedding, grid_dim):
        super(SurfaceProjectionEmbedding, self).__init__()
        # Initialize the 1x1 convolution layer
        self.conv1x1 = nn.Conv2d(in_channels, d_embedding, kernel_size=1)
        # Initialize layer normalization across the channel, height, and width dimensions
        self.layer_norm = nn.LayerNorm([d_embedding, grid_dim, grid_dim])  # Normalizes across (channels, height, width)

    def forward(self, x):
        # Ensure x has dimensions: (batch_size, channels, height, width)
        # Add a channel dimension if necessary
        if x.dim() == 3:  # assuming x has dimensions (batch_size, height, width)
            x = x.unsqueeze(1)  # add channel dimension
        # Apply the 1x1 convolution to project the input to a higher dimensional space
        x = self.conv1x1(x)
        # Normalize the features across each channel, maintaining the spatial dimensions
        x = self.layer_norm(x)
        return x
    
d_embedding = HYPERPARAMETERS['Input Embedding']['Surface Embedding']['Channels Dimension']  # Desired number of output channels
torch.manual_seed(RANDOM_STATE)
# Create the module
projection_embedding = SurfaceProjectionEmbedding(1, d_embedding, grid_dim)   
projection_embedding_batch = projection_embedding(continuous_kernel_embedding_batch)
projection_embedding_batch

tensor([[[[-7.3848e-02, -7.2383e-02, -7.3513e-02],
          [-6.8243e-02, -6.6675e-02, -6.9055e-02],
          [-6.7751e-02, -6.6328e-02, -6.8963e-02]],

         [[-7.5281e-01, -8.5778e-01, -7.7682e-01],
          [-1.1545e+00, -1.2668e+00, -1.0962e+00],
          [-1.1897e+00, -1.2916e+00, -1.1029e+00]],

         [[ 9.4675e-01,  1.1078e+00,  9.8358e-01],
          [ 1.5630e+00,  1.7353e+00,  1.4737e+00],
          [ 1.6171e+00,  1.7734e+00,  1.4838e+00]],

         [[ 8.3842e-01,  9.8243e-01,  8.7136e-01],
          [ 1.3894e+00,  1.5435e+00,  1.3096e+00],
          [ 1.4378e+00,  1.5776e+00,  1.3187e+00]],

         [[ 3.9251e-01,  4.6788e-01,  4.0975e-01],
          [ 6.8089e-01,  7.6153e-01,  6.3909e-01],
          [ 7.0620e-01,  7.7936e-01,  6.4385e-01]],

         [[-4.2406e-01, -4.7653e-01, -4.3606e-01],
          [-6.2484e-01, -6.8098e-01, -5.9573e-01],
          [-6.4245e-01, -6.9339e-01, -5.9904e-01]],

         [[-6.1034e-02, -5.7156e-02, -6.0147e-02],
          [-4.6199e

In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class SurfacePositionalEncoding(nn.Module):
    def __init__(self, grid_dim, d_embedding):
        super(SurfacePositionalEncoding, self).__init__()
        self.grid_dim = grid_dim
        self.d_embedding = d_embedding
        
        # Create a regular grid in (0, 1)x(0, 1), excluding 0 and 1
        grid_points = torch.linspace(1 / (grid_dim + 1), 1 - 1 / (grid_dim + 1), grid_dim)
        mesh_x, mesh_y = torch.meshgrid(grid_points, grid_points, indexing='ij')
        self.grid_points = torch.stack([mesh_x.flatten(), mesh_y.flatten()], dim=-1)
        self.grid_points = torch.erfinv(2 * self.grid_points - 1) * np.sqrt(2)  # inverse CDF of normal

        # Initialize learnable scaling parameter (the base for positional encoding)
        self.log_scale = nn.Parameter(torch.log(torch.tensor(10000.0)))
        self.factor = nn.Parameter(torch.tensor(1.0))  # Learnable scale for the positional encoding contribution

        # Layer normalization for final output
        self.layer_norm = nn.LayerNorm([d_embedding, grid_dim, grid_dim])  # Normalizes across (channels, height, width)


    def forward(self, x):
        # x is the output from the 1x1 convolution layer with shape (batch_size, d_embedding, grid_dim, grid_dim)
        scale = torch.exp(self.log_scale)
        pos_enc = torch.zeros_like(x)

        # Repeat grid_points to match the batch size and reshape for broadcasting
        batch_grid_points = self.grid_points.repeat(x.shape[0], 1, 1).view(x.shape[0], self.grid_dim*self.grid_dim, 2)
        
        for i in range(self.d_embedding // 4):
            # Calculate positional encodings for both dimensions
            div_factor = scale ** (4 * i / self.d_embedding)
            pos_enc[:, 4 * i, :, :] = torch.sin(batch_grid_points[:, :, 0].view(x.shape[0], self.grid_dim, self.grid_dim) / div_factor)
            pos_enc[:, 4 * i + 1, :, :] = torch.cos(batch_grid_points[:, :, 0].view(x.shape[0], self.grid_dim, self.grid_dim) / div_factor)
            pos_enc[:, 4 * i + 2, :, :] = torch.sin(batch_grid_points[:, :, 1].view(x.shape[0], self.grid_dim, self.grid_dim) / div_factor)
            pos_enc[:, 4 * i + 3, :, :] = torch.cos(batch_grid_points[:, :, 1].view(x.shape[0], self.grid_dim, self.grid_dim) / div_factor)

        # Apply the learned scale to positional encoding and add to the input
        x = x + self.factor * pos_enc
        # Normalize the final output
        x = self.layer_norm(x) 

        return x

# Create the SurfacePositionalEncoding module
positional_encoder = SurfacePositionalEncoding(grid_dim, d_embedding)

# Apply positional encoding
positional_encoded_embedding_batch = positional_encoder(projection_embedding_batch)
positional_encoded_embedding_batch

tensor([[[[-1.2986e+00, -1.2970e+00, -1.2983e+00],
          [-5.9434e-01, -5.9259e-01, -5.9525e-01],
          [ 1.0423e-01,  1.0582e-01,  1.0287e-01]],

         [[-4.8653e-01, -6.0386e-01, -5.1336e-01],
          [-6.9072e-01, -8.1626e-01, -6.2564e-01],
          [-9.7487e-01, -1.0888e+00, -8.7780e-01]],

         [[-1.5788e-01,  7.2016e-01,  1.2793e+00],
          [ 5.3092e-01,  1.4215e+00,  1.8271e+00],
          [ 5.9136e-01,  1.4641e+00,  1.8385e+00]],

         [[ 1.2920e+00,  1.6978e+00,  1.3288e+00],
          [ 1.9079e+00,  2.3249e+00,  1.8186e+00],
          [ 1.9620e+00,  2.3630e+00,  1.8288e+00]],

         [[-8.6886e-02, -2.6436e-03, -6.7621e-02],
          [ 2.4298e-01,  3.3312e-01,  1.9626e-01],
          [ 2.7881e-01,  3.6059e-01,  2.0912e-01]],

         [[ 1.2566e-01,  6.7003e-02,  1.1224e-01],
          [-9.8737e-02, -1.6149e-01, -6.6205e-02],
          [-1.1845e-01, -1.7539e-01, -6.9934e-02]],

         [[-5.9383e-01, -5.8195e-01, -5.7776e-01],
          [-5.7724e

#### Block

In [37]:
class SurfaceEmbedding(nn.Module):
    def __init__(self, grid_dim, d_embedding, eps=1e-5, momentum=0.1):
        super(SurfaceEmbedding, self).__init__()
        # Initialize all sub-modules
        self.surface_batchnorm = SurfaceBatchNorm(1, eps, momentum)
        self.surface_continuous_kernel_embedding = SurfaceContinuousKernelEmbedding(grid_dim)
        self.surface_projection_embedding = SurfaceProjectionEmbedding(1, d_embedding, grid_dim)
        self.positional_encoding = SurfacePositionalEncoding(grid_dim, d_embedding)

    def forward(self, batch):
        # Process the batch with SurfaceBatchNorm
        processed_batch = self.surface_batchnorm(batch)
        
        # Generate continuous kernel embeddings from the processed 'Input Surface'
        continuous_kernel_embedding_batch = self.surface_continuous_kernel_embedding(processed_batch['Input Surface'])
        
        # Project the embeddings using 1x1 convolution
        projection_embedding_batch = self.surface_projection_embedding(continuous_kernel_embedding_batch)
        
        # Apply positional encoding to the projected embeddings
        positional_encoded_embedding_batch = self.positional_encoding(projection_embedding_batch)

        # Return both the positionally encoded embeddings and the processed batch
        return positional_encoded_embedding_batch, processed_batch

torch.manual_seed(RANDOM_STATE)
grid_dim = HYPERPARAMETERS['Input Embedding']['Surface Embedding']['Grid Dimension']
d_embedding = HYPERPARAMETERS['Input Embedding']['Surface Embedding']['Channels Dimension']  # Desired number of output channels
surface_embedding = SurfaceEmbedding(grid_dim, d_embedding)
positional_encoded_embedding_batch, processed_batch = surface_embedding(batch)
positional_encoded_embedding_batch, processed_batch

(tensor([[[[-1.2986e+00, -1.2970e+00, -1.2983e+00],
           [-5.9434e-01, -5.9259e-01, -5.9525e-01],
           [ 1.0423e-01,  1.0582e-01,  1.0287e-01]],
 
          [[-4.8653e-01, -6.0386e-01, -5.1336e-01],
           [-6.9072e-01, -8.1626e-01, -6.2564e-01],
           [-9.7487e-01, -1.0888e+00, -8.7780e-01]],
 
          [[-1.5788e-01,  7.2016e-01,  1.2793e+00],
           [ 5.3092e-01,  1.4215e+00,  1.8271e+00],
           [ 5.9136e-01,  1.4641e+00,  1.8385e+00]],
 
          [[ 1.2920e+00,  1.6978e+00,  1.3288e+00],
           [ 1.9079e+00,  2.3249e+00,  1.8186e+00],
           [ 1.9620e+00,  2.3630e+00,  1.8288e+00]],
 
          [[-8.6886e-02, -2.6436e-03, -6.7621e-02],
           [ 2.4298e-01,  3.3312e-01,  1.9626e-01],
           [ 2.7881e-01,  3.6059e-01,  2.0912e-01]],
 
          [[ 1.2566e-01,  6.7003e-02,  1.1224e-01],
           [-9.8737e-02, -1.6149e-01, -6.6205e-02],
           [-1.1845e-01, -1.7539e-01, -6.9934e-02]],
 
          [[-5.9383e-01, -5.8195e-01, -5.7776e

### Pre-Encoder

#### Block

In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PreEncoder(nn.Module):
    def __init__(self, d_embedding, branch_channels, grid_dim):
        super(PreEncoder, self).__init__()
        # Initial channel configuration is common to all branches
        self.branch1 = nn.Sequential(
            nn.Conv2d(d_embedding, branch_channels, kernel_size=1),
            nn.BatchNorm2d(branch_channels),
            nn.GELU()
        )
        
        self.branch2 = nn.Sequential(
            nn.Conv2d(d_embedding, branch_channels, kernel_size=1),
            nn.BatchNorm2d(branch_channels),
            nn.GELU(),
            nn.Conv2d(branch_channels, branch_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(branch_channels),
            nn.GELU()
        )
        
        self.branch3 = nn.Sequential(
            nn.Conv2d(d_embedding, branch_channels, kernel_size=1),
            nn.BatchNorm2d(branch_channels),
            nn.GELU(),
            nn.Conv2d(branch_channels, branch_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(branch_channels),
            nn.GELU(),
            nn.Conv2d(branch_channels, branch_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(branch_channels),
            nn.GELU()
        )
        
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(d_embedding, branch_channels, kernel_size=1),
            nn.BatchNorm2d(branch_channels),
            nn.GELU()
        )
        
        # Reduce the concatenated channels back to the original number of channels
        self.conv_reduce = nn.Conv2d(branch_channels * 4, d_embedding, kernel_size=1)
        self.bn_reduce = nn.BatchNorm2d(d_embedding)
        self.scale = nn.Parameter(torch.tensor(1.0))  # Learnable scale for residual connection
        self.layer_norm = nn.LayerNorm([d_embedding, grid_dim, grid_dim])  # Normalize across (C, H, W)

    def forward(self, x):
        # Apply each branch to the input
        out1 = self.branch1(x)
        out2 = self.branch2(x)
        out3 = self.branch3(x)
        out4 = self.branch4(x)
        
        # Concatenate the outputs from each branch
        concatenated = torch.cat([out1, out2, out3, out4], dim=1)
        
        # Reduce back to the initial number of channels
        reduced = self.conv_reduce(concatenated)
        reduced = self.bn_reduce(reduced)
        
        # Add the residual connection with scale
        residual = x + self.scale * reduced
        residual = F.gelu(residual)  # Apply GELU after adding the residual
        
        # Normalize the output
        output = self.layer_norm(residual)
        
        return output

torch.manual_seed(RANDOM_STATE)
branch_channels = HYPERPARAMETERS['Input Embedding']['Pre-Encoder']['Branch Channels Dimension']
pre_encoder = PreEncoder(d_embedding, branch_channels, grid_dim)
pre_encoded_batch = pre_encoder(positional_encoded_embedding_batch)
pre_encoded_batch

tensor([[[[-0.8756, -0.8372, -0.8025],
          [-0.3041,  0.2683, -0.7046],
          [-0.8682, -0.1383, -0.7486]],

         [[-0.7419, -0.7990, -0.8037],
          [-0.1814, -0.8315, -0.7411],
          [-0.6794, -0.7307, -0.8436]],

         [[-0.1706, -0.3510,  0.4306],
          [ 1.6556,  1.8931,  0.7270],
          [ 0.3351,  3.1226,  1.4218]],

         [[ 0.2243,  2.3709,  0.9548],
          [ 0.0581,  2.1856, -0.5534],
          [-0.0790, -0.8032, -0.7506]],

         [[ 1.2536,  0.4729,  0.3198],
          [ 1.3939,  0.7090,  0.7465],
          [ 0.8147,  2.2066,  0.8328]],

         [[-0.8763, -0.8221, -0.7994],
          [ 1.1149,  1.0301,  0.2769],
          [ 0.0215,  2.4174,  0.2235]],

         [[-0.7372, -0.8416, -0.7338],
          [-0.8225,  0.0567, -0.5726],
          [-0.8654, -0.7333, -0.5401]],

         [[-0.7915, -0.8727, -0.7533],
          [-0.7430, -0.8666, -0.7146],
          [-0.7090, -0.7032, -0.7013]]],


        [[[-0.6158, -0.7225, -0.7801],
       

### Final Block

In [39]:
import torch
import torch.nn as nn

class InputEmbedding(nn.Module):
    def __init__(self, grid_dim, d_embedding, branch_channels, num_pre_encoder_blocks, eps=1e-5, momentum=0.1):
        super(InputEmbedding, self).__init__()
        # Initialize the Surface Embedding module
        self.surface_embedding = SurfaceEmbedding(grid_dim, d_embedding, eps, momentum)
        
        # Initialize multiple PreEncoder blocks
        self.pre_encoders = nn.ModuleList([
            PreEncoder(d_embedding, branch_channels, grid_dim) for _ in range(num_pre_encoder_blocks)
        ])

    def forward(self, batch):
        # Process batch through SurfaceEmbedding to get initial embeddings and the processed batch
        positional_encoded_embedding_batch, processed_batch = self.surface_embedding(batch)
        
        # Sequentially pass the output through each PreEncoder block
        pre_encoded_batch = positional_encoded_embedding_batch
        for pre_encoder in self.pre_encoders:
            pre_encoded_batch = pre_encoder(pre_encoded_batch)
        
        # The final output of the last PreEncoder block is the output of the module
        return pre_encoded_batch, processed_batch

# Example usage
torch.manual_seed(RANDOM_STATE)
grid_dim = HYPERPARAMETERS['Input Embedding']['Surface Embedding']['Grid Dimension']
d_embedding = HYPERPARAMETERS['Input Embedding']['Surface Embedding']['Channels Dimension']
branch_channels = HYPERPARAMETERS['Input Embedding']['Pre-Encoder']['Branch Channels Dimension']
num_pre_encoder_blocks = HYPERPARAMETERS['Input Embedding']['Pre-Encoder']['Number of Blocks']

input_embedding = InputEmbedding(grid_dim, d_embedding, branch_channels, num_pre_encoder_blocks)
pre_encoded_batch, processed_batch = input_embedding(batch)
pre_encoded_batch, processed_batch

(tensor([[[[-0.4830,  0.3449, -0.7433],
           [-0.1301,  1.5596,  0.0593],
           [-0.0939,  0.1099, -0.6984]],
 
          [[-0.7244, -0.5963, -0.6813],
           [-0.6949, -0.6773, -0.7043],
           [-0.6710, -0.6599, -0.6220]],
 
          [[-0.7427,  1.7081,  1.9834],
           [ 0.1006,  1.0234,  0.2937],
           [-0.4557,  1.7954,  2.1521]],
 
          [[ 1.8443,  0.8829,  0.9189],
           [ 1.3698,  2.5122,  0.9914],
           [ 1.6683,  3.1743,  1.9738]],
 
          [[-0.6140, -0.7369, -0.7326],
           [-0.6261, -0.7396, -0.6632],
           [-0.6169, -0.7421, -0.7374]],
 
          [[-0.6327, -0.6124, -0.5836],
           [-0.7140,  1.2474, -0.4187],
           [-0.6561,  0.9005,  0.6096]],
 
          [[ 0.0178, -0.7432, -0.6877],
           [-0.4061, -0.6993, -0.6284],
           [ 0.4814, -0.7257, -0.7278]],
 
          [[-0.6896, -0.7023, -0.6554],
           [-0.6652, -0.6467, -0.6423],
           [-0.6616, -0.6215, -0.6155]]],
 
 
         [[[-