## Import packages

In [1]:
import pandas as pd 
import numpy as np  
import matplotlib.pyplot as plt
import gpytorch
import os
import tqdm.notebook as tn 
import torch 
import re 

## Initializations
### GPU Usage
Choose whether to run on a GPU (and which one) by setting `use_gpu` and `gpu_id` below.

In [2]:
def print_with_tag(message, tag=None):
    """Print `message`, prefixed by "[tag] " when a (truthy) tag is supplied; returns None."""
    text = message if not tag else f"[{tag}] {message}"
    return print(text)

def initialize_devices(use_gpu, gpu_id=None):
    """Select the torch device used by the rest of the notebook.

    Parameters
    ----------
    use_gpu : bool
        True to run on a CUDA GPU, False to run on CPU. Any other value falls
        back to CPU (with a message).
    gpu_id : int, optional
        CUDA device index used when ``use_gpu`` is True (default 0).

    Returns
    -------
    torch.device
        ``cuda:<gpu_id>`` when a GPU is requested and CUDA is available,
        otherwise ``cpu``.

    Raises
    ------
    AssertionError
        If ``gpu_id`` is given but is not an integer.
    ValueError
        If ``gpu_id`` is not the index of a visible CUDA device.
    """
    print_tag = 'DEVICE'
    if use_gpu == True: # USE GPU
        if gpu_id is None:
            gpu_id = 0  # default GPU id if not specified
        else:
            # Make sure the user-specified gpu_id is an integer
            assert isinstance(gpu_id, int), f"Invalid GPU ID: {gpu_id}. GPU ID must be an integer."

        if not torch.cuda.is_available():
            # Without this check a CUDA device object is returned and the notebook
            # only fails later, at the first tensor.to(output_device).
            print_with_tag(f"GPU {gpu_id} requested but CUDA is not available. Defaulting to CPU.", print_tag)
            return torch.device('cpu')
        num_gpus = torch.cuda.device_count()
        if not 0 <= gpu_id < num_gpus:
            raise ValueError(f"Invalid GPU ID: {gpu_id}. Available GPU IDs are 0..{num_gpus - 1}.")

        output_device = torch.device('cuda:'+str(gpu_id)) # output device
        print_with_tag(f"Planning to run on GPU {gpu_id}", print_tag)

    elif use_gpu == False: # DO NOT USE GPU
        os.environ["CUDA_VISIBLE_DEVICES"] = "" # make sure no GPUs are available in the environment
        output_device = torch.device('cpu')
        print_with_tag('Using CPU.', print_tag)

    else: # INVALID INPUT, USE CPU
        os.environ["CUDA_VISIBLE_DEVICES"] = "" # make sure no GPUs are available in the environment
        output_device = torch.device('cpu')
        print_with_tag('Invalid use_gpu parameter. Defaulting to CPU.', print_tag)
    return output_device

# Select the compute device used by every later cell (GPU 0 here); pass use_gpu=False to run on CPU.
output_device = initialize_devices(use_gpu=True, gpu_id=0)

[DEVICE] Planning to run on GPU 0


## Data Preparation

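We use the following columns from the CSV: ['xc', 'yb', 'surf', 'airfoil', 'alpha', 'M', 'Re', 'chord', 'taper ratio', 'span', 'le sweep', 'cp', 'source']. We also use 'case' (the case identifier for the train/test split) and 'std' (the per-point standard deviation, used later as test error bars).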

In [3]:
def create_train_test_set(cp_file_path, coordinates_file_path, test_case_nums):
    """Load the wing Cp dataset and split it into train/test sets by case number.

    Rows with angle of attack above 12 deg are discarded (only below-stall data
    is used). Remaining rows whose case number is in ``test_case_nums`` form the
    test set; all other rows are training data.

    Parameters
    ----------
    cp_file_path : str
        CSV with one row per measurement point. Columns read: 'xc', 'surf',
        'yb', 'alpha', 'M', 'chord', 'taper ratio', 'span', 'le sweep', 'cp',
        'case', 'source', 'airfoil', 'std'.
    coordinates_file_path : str
        CSV aligned row-by-row with ``cp_file_path``. Its first column is
        dropped; the remaining columns are the airfoil-coordinate features.
    test_case_nums : sequence of int
        Case numbers held out for testing. Numbers that do not occur in the
        data are simply ignored.

    Returns
    -------
    [all_x, all_y, all_cases],
    [train_x, train_y, train_cases, train_afs, train_indices],
    [test_x, test_y, test_cases, test_std, test_afs, test_surf],
    [data, af_coord]

    The ``*_x`` tensors have columns: airfoil coordinates, xhat, yhat,
    spanwise location 'yb', alpha [rad], M, taper ratio, LE sweep [rad],
    span/chord, case number. ``train_y``/``test_y`` are normalized as
    ``(cp - mean(training cp)) * 10``; ``all_y`` is raw cp. ``train_indices``
    are row positions in the returned (filtered) ``data`` frame.
    """
    # load in CSV files
    raw_data = pd.read_csv(cp_file_path)
    raw_afcoord_data = pd.read_csv(coordinates_file_path).values[:, 1:]

    # Pre-process: keep only below-stall angles of attack
    below_stall = (raw_data['alpha'] <= 12).to_numpy()
    data = raw_data.loc[below_stall, ['xc', 'surf', 'yb', 'alpha', 'M', 'chord', 'taper ratio', 'span',
                                      'le sweep', 'cp', 'case', 'source', 'airfoil']]
    af_coord = raw_afcoord_data[below_stall]
    std_all = raw_data.loc[below_stall, 'std'].to_numpy()

    ## Map x/c in [0, 1] to xhat in [-1, 1]; yhat = +sin(arccos(xhat)) on 'U', -sin(arccos(xhat)) otherwise
    xhat = data['xc'].values*2-1
    yhat = np.where(data['surf'] == 'U', np.sin(np.arccos(xhat)), -np.sin(np.arccos(xhat)))

    ## Convert degrees to radians
    alpha = torch.deg2rad(torch.tensor(data['alpha'].values))
    sweep_alpha = torch.deg2rad(torch.tensor(data['le sweep'].values))

    ## Spanwise location is used as given; span/chord ratio as a separate feature
    spanwise = data['yb'].values  # earlier variant: *data['span'].values/data['chord'].values
    ar = data['span'].values/data['chord'].values

    ## Feature matrix for every row, built once; train/test are row subsets of it
    case_col = data[['case']].values
    all_x = torch.tensor(np.hstack((af_coord, xhat.reshape((-1,1)), yhat.reshape((-1,1)), spanwise.reshape((-1,1)),
                                    alpha.reshape((-1,1)), data[['M', 'taper ratio']].values,
                                    sweep_alpha.reshape((-1,1)), ar.reshape((-1,1)), case_col)))
    all_cases = case_col.flatten()
    all_y = torch.tensor(data['cp'].values)

    ## Split rows by case: held-out cases -> test, everything else -> train.
    # (A mask avoids np.delete on positional indices, which broke for case numbers
    # outside 1..max(case) and assumed contiguous numbering.)
    is_test = np.isin(all_cases, np.asarray(test_case_nums))
    test_indices = np.nonzero(is_test)[0]
    train_indices = np.nonzero(~is_test)[0]

    ## Training cases (targets normalized with the training mean and a fixed scale)
    train_x = all_x[train_indices]
    train_cases = all_cases[train_indices]
    train_y = all_y[train_indices]
    scaler_mean = torch.mean(train_y)
    scaler_scale = 10
    train_y = (train_y - scaler_mean) * scaler_scale
    train_afs = af_coord[train_indices]

    ## Test cases (normalized with the same training statistics)
    test_x = all_x[test_indices]
    test_y = (all_y[test_indices] - scaler_mean) * scaler_scale
    test_cases = all_cases[test_indices]
    test_afs = af_coord[test_indices]
    test_std = std_all[test_indices]
    test_surf = data['surf'].values[test_indices]

    return [all_x, all_y, all_cases], [train_x, train_y, train_cases, train_afs, train_indices], [test_x, test_y, test_cases, test_std, test_afs, test_surf], [data, af_coord]

    
alls, trains, tests, [raw_data, raw_af_coord] = create_train_test_set(
    cp_file_path = './data/wing_dataset_20241114_moredata.csv', 
    coordinates_file_path = './data/airfoil_coordinates_20241114_moredata.csv',
    test_case_nums=[44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
       61, 62, 63, 92, 93, 94, 95, 96, 97, 98, 100, 91, 14, 64, 65, 66, 67, 68]
)

all_x = alls[0]
train_x, train_y, train_case, train_af, train_inds = trains[0], trains[1], trains[2], trains[3], trains[4]
test_x, test_y, test_case, test_std, test_afs, test_surf = tests[0], tests[1], tests[2], tests[3], tests[4], tests[5]
all_cases_unique = np.unique(alls[2])

# push to gpu
train_x = train_x.to(output_device)
train_y = train_y.to(output_device)

# Normalization constants used by later cells to move between raw cp and the
# model's target scale. They must equal the ones create_train_test_set used:
# the mean raw cp over the training rows (recomputed here exactly as the function
# does; `raw_data` is the filtered frame it returned and `train_inds` index into it)
# and a fixed scale of 10. This was previously hard-coded as -0.2241, which
# silently goes stale whenever the dataset or the train/test split changes.
scaler_mean = torch.mean(torch.tensor(raw_data['cp'].values)[train_inds]).item()
scaler_scale = 10

Notes: 
- training data goes from 1.0 (TE, upper) to 0.0 to -1.0 (TE, lower)

In [4]:
# Re-read the complete CSV. This rebinds `raw_data` to the unfiltered frame
# (the previous cell bound it to the filtered frame returned by create_train_test_set).
raw_data = pd.read_csv('./data/wing_dataset_20241114_moredata.csv')

## subsample to below stall angles (alpha <= 12 deg) and keep the columns of interest
data = raw_data.loc[raw_data['alpha'] <= 12,
                    ['xc', 'surf', 'yb', 'alpha', 'M', 'chord', 'taper ratio', 'span',
                     'le sweep', 'cp', 'case', 'source', 'airfoil']]

### Select specific test case 

In [5]:
# Best used for validation data 
def select_test_case(case_num, tests):
    """Select one test case, undo the target normalization and print a summary.

    Parameters
    ----------
    case_num : int or sequence of int
        Case number(s) to keep, typically a one-element list such as ``[92]``.
    tests : list
        ``[test_x, test_y, test_case, test_std, test_afs, test_surf]``.

    Returns
    -------
    list
        ``[true_test_x, true_test_y, true_test_std, true_test_surf]``, where
        ``true_test_y`` is cp on its original scale
        (``y / scaler_scale + scaler_mean``) and ``true_test_std`` is always a
        1-D numpy array.

    Notes
    -----
    Selecting exactly case 92 replaces the data with the tunnel-corrected values
    in ``./data/testcase3_tunnel_corrections.csv``. That file provides no std,
    so zeros are returned for it. Relies on the notebook globals
    ``scaler_scale``, ``scaler_mean`` and ``raw_data``.
    """
    # Accept a bare int as well as a list/array of case numbers
    case_nums = np.atleast_1d(case_num)

    # select the requested case(s) from the test data
    test_x_, test_y_, test_case_, test_std_, test_afs_, test_surf_ = tests[0], tests[1], tests[2], tests[3], tests[4], tests[5]
    sel_indices = np.nonzero(np.isin(test_case_, case_nums))[0]
    true_test_x = test_x_[sel_indices]
    true_test_y = test_y_[sel_indices]/scaler_scale + scaler_mean
    true_test_std = test_std_[sel_indices]
    true_test_surf = test_surf_[sel_indices]

    # Override test case 3 with tunnel corrections
    # Note that other cases in this technical expt were not corrected
    if np.array_equal(case_nums, [92]):
        testcase3 = pd.read_csv('./data/testcase3_tunnel_corrections.csv')
        n_pts = testcase3['xc'].shape[0]
        af_coords_ = np.tile(test_afs_[sel_indices][0], (n_pts, 1))
        xhat_ = (testcase3['xc'].values*2-1)[:,None]
        yhat_ = np.where(testcase3['surf'] == 'U', np.sin(np.arccos(xhat_.flatten())), -np.sin(np.arccos(xhat_.flatten())))[:,None]
        alpha_ = np.deg2rad(4.1)*np.ones_like(xhat_)  # corrected angle of attack, 4.1 deg
        sweep_alpha_ = 0.0*np.ones_like(xhat_)
        spanwise_ = testcase3['yb'].values[:,None]
        ar_ = 3.2882*np.ones_like(xhat_)
        m_ = 0.17*np.ones_like(xhat_)
        tr_ = 1.0*np.ones_like(xhat_)
        true_test_x = torch.tensor(np.hstack((af_coords_, xhat_, yhat_, spanwise_, alpha_, m_, tr_, sweep_alpha_, ar_, np.ones_like(xhat_)*92)))
        true_test_y = torch.tensor(testcase3['cp'].values)
        # 1-D numpy, the same type/shape as test_std on the regular path, so that
        # downstream `2*true_test_std[mask]` yields shape (k,) rather than (k, 1).
        true_test_std = np.zeros(n_pts)
        true_test_surf = testcase3['surf'].values

    print_with_tag('Test Case Information: ', 'LOG')
    print('    Source: '  + raw_data['source'][raw_data['case']==case_nums[0]].values[0])
    print('    Geometry: ')
    print('        Airfoil section: ' + raw_data['airfoil'][raw_data['case']==case_nums[0]].values[0])
    print('        LE sweep angle : ' + str(np.round(np.rad2deg(true_test_x[0, -3].item()), 2)) + ' deg')
    print('        Taper ratio    : ' + str(true_test_x[0, -4].item()))
    print('        Semi-span / c  : ' + str(np.round(true_test_x[0, -2].item(), 2)))
    print('    Operating condition: ')
    print('        Angle of attack : ' + str(np.round(np.rad2deg(true_test_x[0, -6].item()), 2)) + ' deg')
    print('        Freestream Mach : ' + str(true_test_x[0, -5].item()))
    return [true_test_x, true_test_y, true_test_std, true_test_surf]

# take base case and turn it into high resolution test data - best used for contour data 
def generate_test_data(airfoil_input, le_sweep_ang, taper_ratio, semispan, alpha, mach, num_pt_side=120, num_secs=42, ):
    """Build a dense input grid (chordwise x spanwise) for one wing configuration.

    Parameters
    ----------
    airfoil_input : sequence
        ``[airfoil_coordinates, case_number]``. The coordinates are flattened and
        repeated on every row; the case number fills the last column.
    le_sweep_ang : float
        Leading-edge sweep angle in degrees (stored in radians).
    taper_ratio, semispan, mach : float
        Copied into their feature columns.
    alpha : float
        Angle of attack in degrees (stored in radians).
    num_pt_side : int or (array, array)
        Either the number of cosine-spaced points per surface, or explicit x/c
        arrays for the upper and lower surfaces.
    num_secs : int
        Number of cosine-spaced spanwise sections in [0, 1].

    Returns
    -------
    torch.Tensor
        One row per (section, chordwise point), sections outermost. Columns:
        airfoil coordinates, xhat, yhat, spanwise location, alpha [rad], M,
        taper, sweep [rad], semispan, case.
    """
    # Chordwise locations on each surface, as xhat = 2*x/c - 1
    if isinstance(num_pt_side, (int, np.integer)):
        n_side = int(num_pt_side)
        xhat_upper = np.flip(1.0 - 2*np.cos(np.linspace(0.0, np.pi/2, n_side)))  # TE (1) -> LE (-1)
        xhat_lower = np.flip(1.0 - 2*np.cos(np.linspace(np.pi/2, 0.0, n_side)))  # LE (-1) -> TE (1)
    else:
        xhat_upper = num_pt_side[0]*2-1
        xhat_lower = num_pt_side[1]*2-1
    # yhat: unit-circle ordinate, positive on the upper surface, negative on the lower
    yhat_upper = np.sin(np.arccos(xhat_upper))
    yhat_lower = -np.sin(np.arccos(xhat_lower))

    pts_per_section = xhat_upper.shape[0] + xhat_lower.shape[0]
    tiling_num = pts_per_section * num_secs

    # Airfoil coordinates and case number, repeated on every row
    af_coord = np.tile(np.asarray(airfoil_input[0]).reshape((1, -1)), (tiling_num, 1))
    af_case = np.ones((tiling_num, 1)) * airfoil_input[1]  # airfoil case (is a guide)

    # Wing parametrization and operating condition (constant columns)
    sweep = np.ones((tiling_num, 1)) * np.deg2rad(le_sweep_ang)
    taper = np.ones((tiling_num, 1)) * taper_ratio
    span = np.ones((tiling_num, 1)) * semispan
    alph = np.ones((tiling_num, 1)) * np.deg2rad(alpha)
    minf = np.ones((tiling_num, 1)) * mach

    # Spanwise sections, cosine spaced in [0, 1]; each repeated for all chordwise points
    span_locs = np.flip(np.cos(np.linspace(0.0, np.pi/2, num_secs)))
    span_locs_all = np.kron(span_locs, np.ones(pts_per_section))[:, None]

    # (xhat, yhat) pairs for one section, repeated for every spanwise section
    section_xy = np.column_stack((np.hstack((xhat_upper, xhat_lower)), np.hstack((yhat_upper, yhat_lower))))
    xyhat = np.tile(section_xy, (num_secs, 1))

    return torch.tensor(np.hstack((af_coord, xyhat, span_locs_all, alph, minf, taper, sweep, span, af_case)))

# For test cases shown in paper 
# test case 1: 91
# test case 2: 44
# test case 3: 92
target_case = 92 
true_tests = select_test_case([target_case], [test_x, test_y, test_case, test_std, test_afs, test_surf])
# Move the input/target tensors to the configured device. `.to(output_device)`
# honours the selected GPU index (`.cuda()` always targets the current default
# GPU) and is a no-op on CPU, so no separate CPU branch is needed.
true_test_x, true_test_y = true_tests[0].to(output_device), true_tests[1].to(output_device)
true_test_std, true_test_surf = true_tests[2], true_tests[3]

# Split the selected case into upper ('U') and lower ('L') surface points.
# Feature column 56 is xhat (mapped back to x/c here); column 58 is the spanwise location.
is_upper, is_lower = true_test_surf == 'U', true_test_surf == 'L'
true_test_x_u, true_test_x_l = (true_test_x[is_upper][:,56] + 1) / 2, (true_test_x[is_lower][:,56] + 1) / 2
true_test_y_u, true_test_y_l = true_test_x[is_upper][:,58], true_test_x[is_lower][:,58]
true_test_cp_u, true_test_cp_l = true_test_y[is_upper], true_test_y[is_lower]
true_test_ci_u, true_test_ci_l = 2*true_test_std[is_upper], 2*true_test_std[is_lower]

# Test-x for high resolution prediction 
# num_pts_per_side, num_span_sec = 120, 42 
# test_AoA = 8.85
# test_Minf = 0.13
# test_wing_sweep  = 0.0
# test_wing_taper  = 1.0
# test_wing_semiAR = 2.95
# test_wing_airfoil = [true_test_x[0, :28*2].cpu().detach().numpy(), target_case]

[LOG] Test Case Information: 
    Source: NASA-TP-3151
    Geometry: 
        Airfoil section: NACA0015
        LE sweep angle : 0.0 deg
        Taper ratio    : 1.0
        Semi-span / c  : 3.29
    Operating condition: 
        Angle of attack : 4.1 deg
        Freestream Mach : 0.17


## Large Airfoil Model-prior

In [6]:
# Precompute a LAM (Large Airfoil Model) prior mean/covariance for every case in `all_x`.
# NOTE(review): an earlier note here read "run only if ..." with the condition left unstated;
# this cell loads the LAM and calls it num_per_case times per case, so it is expensive.
from large_airfoil_model_lightweight import lam_adapt # This is the updated lam_adapt with SVDKL model 

lam_model, lam_likelihood = lam_adapt.unpack_model(model_version='v2', output_device=output_device, verbose=True)

num_per_case = 3          # number of angles of attack sampled per case, from 0 deg up to the case's AoA
num_pts_per_surface = 600 # passed as num_auto_points; also the row count of the first block reordered below
mean_list = []            # one averaged prior mean per case, on the normalized target scale
cov_list = []             # one averaged prior covariance per case, on the normalized target scale
case_iterable = tn.tqdm(all_x[:, -1].unique())  # the last feature column holds the case number
for case in case_iterable:
    # First row of this case; only case-level features (geometry, AoA, Mach) are read from it
    case_train_entry_ = all_x[all_x[:, -1]==case][0]
    # Obtain operating conditions from case 
    case_aoa_ = np.rad2deg(case_train_entry_[-6].item())  # alpha column is stored in radians
    case_aoa_range_ = np.linspace(0.0, case_aoa_, num_per_case)
    case_mach_ = case_train_entry_[-5].item()
    
    # Generate array describing airfoil geometry 
    # this entry does not scale the airfoil and goes from TE - LE - TE 
    case_xc_ = np.array([0, 0.0025, 0.0075, 0.01, 0.015, 0.02, 0.025, 0.05, 0.075, 0.1, 
                                0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6,
                                0.65, 0.70, 0.75, 0.8, 0.85, 0.90, 0.95, 1.0])
    # First and second blocks of 28 coordinate features; by their names these are the
    # upper (zcu) and lower (zcl) surface ordinates at case_xc_ — TODO confirm against the data
    case_zcu_ = case_train_entry_[:28].cpu().detach().numpy()
    case_zcl_ = case_train_entry_[28:28*2].cpu().detach().numpy()
    # (x, z) pairs: first block with x running 1 -> 0, second block with x running 0 -> 1
    case_af_ = np.hstack((np.hstack((np.flip(case_xc_), case_xc_))[:,None], 
                          np.hstack((case_zcu_, case_zcl_))[:,None]))
    
    pred_mu_ = []
    pred_cov_ = []
    # Sweep thru AoA to get the "middle ground" prior
    for aoa__ in case_aoa_range_:
        case_input__ = lam_adapt.input_data(case_af_, aoa__, case_mach_, num_auto_points=num_pts_per_surface, output_device=output_device, model_version='v2')
        case_tensor__ = case_input__.assemble_tensor()
        # rearrange: reverse the order of the first num_pts_per_surface rows, keep the rest as-is
        case_tensor__ = torch.vstack((torch.flip(case_tensor__[:num_pts_per_surface, :], dims=[0]), case_tensor__[num_pts_per_surface:, :]))
        pred__ = lam_model.predict(case_tensor__,)
        # Map LAM outputs to the normalized target scale: (cp - scaler_mean) * scaler_scale,
        # and scale the covariance by scaler_scale**2 accordingly
        pred_mu_.append((pred__['cp_distribution'].mean.cpu().detach().numpy()-scaler_mean) * scaler_scale)
        pred_cov_.append(pred__['cp_distribution'].covariance_matrix.cpu().detach().numpy() * scaler_scale**2)

    # Element-wise average over the AoA sweep (of the means, and of the covariance matrices)
    mean_list.append(np.mean(np.stack(pred_mu_), axis=0))
    cov_list.append(np.mean(np.stack(pred_cov_), axis=0)) 
    # xhat = 2*xc - 1 of the LAM output points; overwritten every iteration, so the value
    # kept after the loop comes from the last case / AoA evaluated
    xhat_locs = pred__['xc'].cpu().detach().numpy()*2-1

[DEBUG] Loading in model...
[DEBUG] Loading complete!


  0%|          | 0/87 [00:00<?, ?it/s]

In [7]:
# import positivity constraint
from gpytorch.constraints import GreaterThan, Interval 
    
# Make a smooth function based on a superposition of generalized logistic sigmoid functions
def makeSmoothFx(X, vals, deltaXs, rates):
    """Smooth piecewise-constant function built from logistic sigmoids.

    With two values: ``vals[0]`` for X well below ``deltaXs[0]``, ``vals[1]``
    well above it, with steepness ``rates[0]``.
    With three values: a ``vals[0] -> vals[1]`` step at ``X = -deltaXs[0]``
    (steepness ``rates[0]``) plus a ``vals[1] -> vals[2]`` step at
    ``X = deltaXs[1]`` (steepness ``rates[1]``).

    Raises
    ------
    ValueError
        If ``len(vals)`` is not 2 or 3.
    """
    if len(vals) == 3:
        return (vals[0]
                + generalizedLogiSig(X, asym_L=0.0, asym_R=vals[1]-vals[0], rate=rates[0], shift=-deltaXs[0], nu=1)
                + generalizedLogiSig(X, asym_L=0.0, asym_R=vals[2]-vals[1], rate=rates[1], shift=deltaXs[1], nu=1))
    if len(vals) == 2:
        return vals[0] + generalizedLogiSig(X, asym_L=0.0, asym_R=vals[1]-vals[0], rate=rates[0], shift=deltaXs[0], nu=1)
    # Previously fell through to an UnboundLocalError on `smoothedFx`
    raise ValueError(f"makeSmoothFx supports 2 or 3 values, got {len(vals)}")

# Generalized logistic sigmoid function
def generalizedLogiSig(X, asym_L = 0.0, asym_R = 1.0, rate = 1, shift = 0.0, nu = 1):
    """Generalized logistic: asym_L + (asym_R - asym_L) / (1 + exp(-rate*(X - shift)))**(1/nu)."""
    y = asym_L + (asym_R - asym_L)/(1 + torch.exp(-rate*(X-shift)))**(1/nu)
    return y

# Spatially varying Matern 3/2 kernel
class SVH_Matern_2var_1d(gpytorch.kernels.Kernel):
    """1-D Matern-3/2 kernel with input-dependent lengthscale and variance.

        k(x, x') = s(x, x') * (1 + sqrt(3) r) * exp(-sqrt(3) r),
        r = |x - x'| / sqrt(l2(x, x')),

    with l2(x, x') = f(x; ell_1) * f(x'; ell_1) and s(x, x') = f(x; var_1) * f(x'; var_1),
    where f(x; v) = makeSmoothFx(x, v, [0.7], rates) blends from v[0] to v[1]
    around x = 0.7 with logistic steepness rates[0]. Inputs are flattened, so x1
    and x2 are expected to carry a single feature column.

    Hyperparameters (constrained and with priors):
        ell_1 : two values (inboard, outboard) blended into the squared-lengthscale factor
        var_1 : two values (inboard, outboard) blended into the variance factor
        locs  : a transition location; NOT used by forward() (the transition is fixed at 0.7)
    """
    # NOTE(review): hyperparameters vary with absolute position, so the kernel is not
    # stationary in the mathematical sense; the flag is kept unchanged — confirm intent.
    is_stationary = True
    def __init__(self, rates=[40], **kwargs):
        # Initialize the base gpytorch Kernel
        super().__init__(**kwargs)
        self.rates = rates
        # Raw hyperparameters blended by makeSmoothFx
        self.register_parameter(name="raw_ell_1", parameter=torch.nn.Parameter(torch.tensor([1.0, 0.1]))) # two lengthscales, inboard & outboard
        self.register_parameter(name="raw_var_1", parameter=torch.nn.Parameter(torch.tensor([1.0, 1.0]))) # two vars, accordingly
        self.register_parameter(name="raw_locs", parameter=torch.nn.Parameter(torch.tensor([0.75]))) # one location of transition

        # Include Constraints
        self.register_constraint("raw_ell_1", GreaterThan(1e-2))
        self.register_constraint("raw_var_1", GreaterThan(1e-2))
        self.register_constraint("raw_locs", Interval(1e-2, 1.0))
        # prior
        self.register_prior(
                "ell_1_prior",
                gpytorch.priors.MultivariateNormalPrior(torch.as_tensor([1.0, 0.1]), torch.diag(torch.tensor([0.05, 0.05])**2)),
                lambda m: m.ell_1,
                lambda m, v : m._set_ell_1(v),)
        self.register_prior(
                "var_1_prior",
                gpytorch.priors.MultivariateNormalPrior(torch.as_tensor([1.0, 1.0]), torch.eye(2)*1**2),  # torch.eye(2)*1
                lambda m: m.var_1,
                lambda m, v : m._set_var_1(v),)
        self.register_prior( # This is not used as of now
                "locs_prior",
                gpytorch.priors.NormalPrior(0.75, 0.1),
                lambda m: m.locs,
                lambda m, v : m._set_locs(v),)

    @property
    def ell_1(self):
        return self.raw_ell_1_constraint.transform(self.raw_ell_1)
    @ell_1.setter
    def ell_1(self, values):
        return self._set_ell_1(values)
    def _set_ell_1(self, values):
        if not torch.is_tensor(values):
            values = torch.as_tensor(values).to(self.ell_1)
        self.initialize(raw_ell_1=self.raw_ell_1_constraint.inverse_transform(values))

    @property
    def var_1(self):
        return self.raw_var_1_constraint.transform(self.raw_var_1)
    @var_1.setter
    def var_1(self, value):
        return self._set_var_1(value)
    def _set_var_1(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.var_1)
        self.initialize(raw_var_1=self.raw_var_1_constraint.inverse_transform(value))

    @property
    def locs(self):
        return self.raw_locs_constraint.transform(self.raw_locs)
    @locs.setter
    def locs(self, values):
        return self._set_locs(values)
    def _set_locs(self, values):
        if not torch.is_tensor(values):
            values = torch.as_tensor(values).to(self.locs)
        self.initialize(raw_locs=self.raw_locs_constraint.inverse_transform(values))

    def forward(self, x1, x2, diag=False, **params):
        """Matern-3/2 covariance: k = s * (1 + sqrt(3) r) * exp(-sqrt(3) r).

        Returns the full (len(x1), len(x2)) matrix, or only the point-wise
        values k(x1[i], x2[i]) as a vector when ``diag=True`` (which gpytorch
        requests for predictive variances).
        """
        sqrt3 = 3.0 ** 0.5
        if diag:
            length_sq = self._smooth(x1, self.ell_1) * self._smooth(x2, self.ell_1)
            scale = self._smooth(x1, self.var_1) * self._smooth(x2, self.var_1)
            dist = self.covar_dist(x1, x2, diag=True).reshape(-1)
        else:
            length_sq = self.spatially_varying_function(x1, x2, self.ell_1)
            scale = self.spatially_varying_function(x1, x2, self.var_1)
            dist = self.covar_dist(x1, x2)
        scaled_dist = dist / length_sq.pow(0.5)
        return (1 + sqrt3 * scaled_dist) * torch.exp(-sqrt3 * scaled_dist) * scale

    def _smooth(self, x, hyperparams):
        """Per-point factor f(x) = makeSmoothFx(x, hyperparams, [0.7], self.rates) on flattened x."""
        return makeSmoothFx(x.flatten(), hyperparams, [0.7], self.rates)

    def spatially_varying_function(self, x1, x2, hyperparams):
        """Return the (len(x1), len(x2)) outer product f(x1[i]) * f(x2[j])."""
        return self._smooth(x1, hyperparams).reshape(-1, 1) * self._smooth(x2, hyperparams).reshape(1, -1)

In [8]:
from copy import deepcopy
import tqdm.notebook as tn  
import pickle
from scipy.interpolate import RegularGridInterpolator, CubicSpline
""" 
Covariance class 
"""
class lam_prior_cov(gpytorch.kernels.Kernel):
    """Covariance module backed by precomputed LAM-prior covariance interpolators.

    After gpytorch's ``active_dims`` selection each input row is
    ``[xhat, yhat, spanwise location, case number]``. Each row's chordwise
    position is mapped to a signed x/c: ``-(xhat + 1)/2`` where ``yhat >= 0``
    and ``(xhat + 1)/2`` where ``yhat < 0``. For a row a of x1 and a row b of
    x2 with the same case number and the same spanwise location,

        K[a, b] = interpolator[g]((xc2[b], xc1[a])),

    where ``g`` is the index of the entry of ``guide`` closest to that case
    number. All other entries are 0, i.e. different cases and different
    spanwise sections are treated as independent.

    Args:
        covar_interpolator: indexable collection of callables, one per guide
            case, each called with a ``(xc2, xc1)`` tuple of equally-shaped arrays.
        train_x: training inputs (all feature columns). Their covariance is
            computed once at construction and served from a cache afterwards.
        guide: numpy array of case numbers aligned with ``covar_interpolator``.
        output_device: device of the returned covariance.
        verbose: print progress messages.
        active_dims: feature columns handed to ``forward``.
    """
    # NOTE(review): the covariance depends on absolute positions and case numbers, so it
    # is not stationary in the mathematical sense; the flag is kept unchanged — confirm intent.
    is_stationary = True
    def __init__(self, covar_interpolator, train_x, guide, output_device='cpu', 
                 verbose=False, active_dims=torch.tensor([56, 57, 58, 64], dtype=int), **kwargs):
        # Initialize the base gpytorch Kernel
        super().__init__(**kwargs)
        self.output_device = output_device # output device, cpu or cuda
        self.interpolator = covar_interpolator # covariance interpolators, one per guide case
        self.active_dims = active_dims.to(self.output_device)
        self.register_buffer("active_dims", self.active_dims)
        self.train_x = train_x # training data to use as reference
        # Training inputs as forward() receives them (after active_dims selection);
        # used to recognise K(train, train) requests that can be served from the cache.
        self._train_x_active = train_x.index_select(-1, self.active_dims.to(train_x.device))
        self.case_guide = torch.from_numpy(guide).cpu() # case guide
        self.verbose = verbose # show logs or nah
        self.train_cache = None  # cache for training data covariance
        # Cache the covariance matrix for training data
        print_with_tag('Computing cache for training data covariance matrix.', 'LOG') if self.verbose else None
        self.train_cache = self(self.train_x).evaluate()

    # this runs at least x2 faster on CPU
    def forward_(self, x1, x2, diag=False, **params):
        """Legacy loop-based implementation; not called in this section of the notebook.

        Evaluate the LAM-prior covariance matrix to obtain K_{x1, x2}.
        x1 and x2 should follow:
        [xhat, yhat, spanwise location, case_number]
        """
        # If cache exists and we are computing training data covar, Kxx
        if self.train_cache is not None and x1.shape[0] == self.train_x.shape[0] and x1.shape[0] == x2.shape[0]:
            # If the covariance is already cached, use it
            # print_with_tag('Reading in cache for training data covariance matrix.', 'LOG') if self.verbose else None
            K = self.train_cache
        else:
            # K = [[A  , ..., B  ]
            #      [..., ..., ...]
            #      [C  , ..., D  ]]
            # obtain unique spanwise locations
            x1_ = x1.cpu()
            x2_ = x2.cpu()
            # Compute the covariance matrix using the interpolator
            x1_unique_case = torch.unique(x1_[:,3])
            x2_unique_case = torch.unique(x2_[:,3])
            k_row = [] # covariance matrix row-wise piece (i.e. [A, ..., B] or [C, ..., D])
            with torch.no_grad():
                # Iterate through case by case

                for i in x1_unique_case:

                    k_block = [] # represents individual blocks made from the "case" (i.e. A, B, C, D)
                    for j in x2_unique_case:
                        # Obtain relevant location within the row
                        target_span_x1_idx = torch.argmin(torch.abs(torch.as_tensor(self.case_guide) - i))
                        target_span_x2_idx = torch.argmin(torch.abs(torch.as_tensor(self.case_guide) - j))

                        # subsample from x1 and x2 the relevant cases
                        x1_subsample = x1_[x1_[:,3]==i].clone()
                        x2_subsample = x2_[x2_[:,3]==j].clone()

                        # Get the spanwise locations from each case
                        x1_unique_span = torch.unique(x1_subsample[:,2])
                        x2_unique_span = torch.unique(x2_subsample[:,2])
                        k_segment = [] # represents individual spanwise locations within each case
                        for u in x1_unique_span:
                            k_subsegment = []
                            for v in x2_unique_span:
                                x1_subsubsample = x1_subsample[x1_subsample[:,2]==u].clone()
                                x2_subsubsample = x2_subsample[x2_subsample[:,2]==v].clone()

                                # Take spanwise locations, convert to non-dimensional chordwise locations (x/c) from xhat yhat
                                x1_subsubsample[x1_subsubsample[:,1] >= 0.0, 0] = -(x1_subsubsample[x1_subsubsample[:,1] >= 0.0, 0] + 1)/2.0
                                x1_subsubsample[x1_subsubsample[:,1]  < 0.0, 0] =  (x1_subsubsample[x1_subsubsample[:,1]  < 0.0, 0] + 1)/2.0

                                x2_subsubsample[x2_subsubsample[:,1] >= 0.0, 0] = -(x2_subsubsample[x2_subsubsample[:,1] >= 0.0, 0] + 1)/2.0
                                x2_subsubsample[x2_subsubsample[:,1]  < 0.0, 0] =  (x2_subsubsample[x2_subsubsample[:,1]  < 0.0, 0] + 1)/2.0

                                x1x1, x2x2 = np.meshgrid(x1_subsubsample[:,0].cpu(), x2_subsubsample[:,0].cpu())

                                # Evaluate the covariance matrix using the interpolator
                                if u == v: # Diagonal elements
                                    k_subsegment.append(torch.tensor(self.interpolator[target_span_x2_idx]((x2x2, x1x1))).T)
                                else:  # assume independence b/n cases
                                    k_subsegment.append(torch.zeros((x1_subsubsample.shape[0], x2_subsubsample.shape[0])))
                            # add 'em to be concated a single block
                            k_segment.append(torch.hstack(k_subsegment))
                        # concatenate the each spanwise segment into a single block
                        if i == j:
                            k_block.append(torch.cat(k_segment))
                        else:
                            k_block.append(torch.zeros((x1_subsample.shape[0], x2_subsample.shape[0])))
                    # append each element of the row into a single row
                    k_row.append(torch.hstack(k_block))
                # concatenate the rows into a single covariance matrix
                K = torch.cat(k_row)
                K = K.to(self.output_device)
        return K

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate K(x1, x2) (or its diagonal when ``diag=True``) from the interpolators.

        x1 and x2 rows are ``[xhat, yhat, spanwise location, case number]``.
        """
        # Serve the cached training covariance only for exactly the training inputs
        # (a row-count check alone would also match unrelated inputs of the same size).
        if self.train_cache is not None and self._is_train_pair(x1, x2):
            return torch.diagonal(self.train_cache, dim1=-2, dim2=-1) if diag else self.train_cache

        x1_, x2_ = x1.cpu(), x2.cpu()
        if diag:
            return self._forward_diag(x1_, x2_)

        # Float64 inputs give a float64 K; a default-dtype (float32) buffer would
        # silently downcast the interpolator output.
        K = torch.zeros(x1_.shape[0], x2_.shape[0], dtype=x1_.dtype)
        with torch.no_grad():
            xc1, xc2 = self._signed_xc(x1_), self._signed_xc(x2_)
            x2_groups = self._group_rows(x2_)
            for (case, span), rows1 in self._group_rows(x1_).items():
                rows2 = x2_groups.get((case, span))
                if rows2 is None:
                    continue  # no matching (case, span) in x2 -> block stays zero (independent)
                X1, X2 = torch.meshgrid(xc1[rows1], xc2[rows2], indexing='ij')
                # With 'ij' indexing the result already has shape (len(rows1), len(rows2)),
                # entry [a, b] = interpolator((xc2[b], xc1[a])); no transpose is needed.
                block = self.interpolator[self._guide_index(case)]((X2.numpy(), X1.numpy()))
                K[rows1[:, None], rows2] = torch.as_tensor(block, dtype=K.dtype)
        return K.to(self.output_device)

    def _forward_diag(self, x1_, x2_):
        """Point-wise covariances K[i, i] for equally long CPU inputs x1_ and x2_."""
        out = torch.zeros(x1_.shape[0], dtype=x1_.dtype)
        with torch.no_grad():
            xc1, xc2 = self._signed_xc(x1_), self._signed_xc(x2_)
            same_group = (x1_[:, 2] == x2_[:, 2]) & (x1_[:, 3] == x2_[:, 3])
            for case in torch.unique(x1_[same_group, 3]).tolist():
                rows = (same_group & (x1_[:, 3] == case)).nonzero(as_tuple=True)[0]
                vals = self.interpolator[self._guide_index(case)]((xc2[rows].numpy(), xc1[rows].numpy()))
                out[rows] = torch.as_tensor(vals, dtype=out.dtype)
        return out.to(self.output_device)

    def _is_train_pair(self, x1, x2):
        """True iff both x1 and x2 are exactly the (active-dims) training inputs."""
        ref = self._train_x_active
        return all(x.shape == ref.shape and x.dtype == ref.dtype and x.device == ref.device
                   and torch.equal(x, ref) for x in (x1, x2))

    def _guide_index(self, case):
        """Index of the guide case numerically closest to `case`."""
        return torch.argmin(torch.abs(self.case_guide - case)).item()

    @staticmethod
    def _signed_xc(x):
        """Signed x/c from xhat (col 0) and yhat (col 1): -(xhat+1)/2 if yhat >= 0, (xhat+1)/2 if yhat < 0."""
        xc = x[:, 0].detach().clone()
        upper, lower = x[:, 1] >= 0.0, x[:, 1] < 0.0
        xc[upper] = -(xc[upper] + 1) / 2.0
        xc[lower] = (xc[lower] + 1) / 2.0
        return xc

    @staticmethod
    def _group_rows(x):
        """Map (case, spanwise location) -> row indices of x holding that pair."""
        pairs, inverse = torch.unique(x[:, 2:4], dim=0, return_inverse=True)
        return {(case, span): (inverse == g).nonzero(as_tuple=True)[0]
                for g, (span, case) in enumerate(pairs.tolist())}
from copy import deepcopy

"""
mean class 
"""
class lam_prior_mean(gpytorch.means.Mean):
    """Mean module that evaluates precomputed LAM-prior mean interpolators.

    For each input row, xhat (column 56) is mapped to a signed x/c using yhat
    (column 57): ``-(xhat + 1)/2`` where ``yhat >= 0`` and ``(xhat + 1)/2``
    where ``yhat < 0``. That value is evaluated with the interpolator of the
    guide case closest to the row's case number (last column). Results are
    returned in the same row order as the input; the training-input result is
    cached at construction.

    Args:
        mean_interpolator: indexable collection of callables, one per guide case.
        train_x: training inputs (all feature columns).
        guide: numpy array of case numbers aligned with ``mean_interpolator``.
        output_device: device of the returned mean.
        verbose: print progress messages.
    """
    def __init__(self, mean_interpolator, train_x, guide, output_device='cpu', 
                 verbose=False, **kwargs):
        # Initialize the base gpytorch Mean
        super().__init__(**kwargs)
        self.output_device = output_device # output device, cpu or cuda
        self.interpolator = mean_interpolator # mean interpolators, one per guide case
        self.train_x = train_x # training data to use as reference
        self.case_guide = torch.from_numpy(guide).cpu() # case guide
        self.verbose = verbose # show logs or nah
        self.train_cache = None  # cache for the training-data mean
        # Cache the mean for training data
        print_with_tag('Computing cache for training data mean.', 'LOG') if self.verbose else None
        self.train_cache = self(self.train_x)

    def forward(self, x):
        """Return the prior mean for each row of x, in the same row order."""
        # Serve the cache only for exactly the training inputs (a row-count check alone
        # would also match unrelated inputs of the same size).
        if self.train_cache is not None and self._is_train_input(x):
            return self.train_cache
        with torch.no_grad():
            mu = self._evaluate(x.cpu())
        return mu.flatten().to(self.output_device)

    def _is_train_input(self, x):
        """True iff x is exactly the training input tensor's contents."""
        ref = self.train_x
        return (x.shape == ref.shape and x.dtype == ref.dtype and x.device == ref.device
                and torch.equal(x, ref))

    def _evaluate(self, x_):
        """Evaluate the interpolators case by case, writing results back to each row's position.

        Writing into the original row positions keeps the mean aligned with the
        input even when rows are not grouped by ascending case number — e.g. the
        concatenated [train; test] inputs gpytorch uses at prediction time.
        """
        case_col = x_[:, -1]
        mu = None
        for case in torch.unique(case_col).tolist():
            rows = case_col == case
            guide_idx = torch.argmin(torch.abs(self.case_guide - case)).item()
            xc = self._signed_xc(x_[rows, 56], x_[rows, 57])
            vals = torch.as_tensor(self.interpolator[guide_idx](xc)).flatten()
            if mu is None:
                mu = torch.zeros(x_.shape[0], dtype=vals.dtype)
            mu[rows] = vals.to(mu.dtype)
        if mu is None:  # empty input
            mu = torch.zeros(0, dtype=x_.dtype)
        return mu

    @staticmethod
    def _signed_xc(xhat, yhat):
        """Signed x/c: -(xhat+1)/2 where yhat >= 0, (xhat+1)/2 where yhat < 0."""
        xc = xhat.clone()
        upper, lower = yhat >= 0.0, yhat < 0.0
        xc[upper] = -(xc[upper] + 1) / 2
        xc[lower] = (xc[lower] + 1) / 2
        return xc
    
class lam_prior():
    """Manage the LAM-prior mean and covariance modules together.

    priors: either a path to a pickle file containing keys 'prior_mean', 'prior_cov' and
        'ex_input' (xhat locations are read from column -2 of ex_input[0]), or a list/tuple
        ``[prior_mean_list, prior_cov_list, xhat_locations]``.
    case_guides: case numbers that the precomputed priors correspond to.
    output_device: device forwarded to lam_prior_mean / lam_prior_cov.
    verbose: forwarded to the mean/covariance modules.

    NOTE(review): relies on the notebook globals ``train_x`` and ``num_pts_per_surface``
    (the xhat locations are split into the first num_pts_per_surface points and the rest);
    both must be defined before constructing this object.
    """
    def __init__(self, priors, case_guides, output_device='cpu', verbose=False):
        self.verbose = verbose
        if isinstance(priors, str):  # path to a pickle file
            import pickle
            # SECURITY: pickle.load can execute arbitrary code - only load trusted prior files.
            with open(priors, 'rb') as file:
                pickle_loaded = pickle.load(file)
            prior_mean_list = pickle_loaded['prior_mean']
            prior_cov_list = pickle_loaded['prior_cov']
            updated_inputs = pickle_loaded['ex_input'][0][:, -2]
        elif isinstance(priors, (list, tuple)):  # in-memory [means, covariances, xhat locations]
            prior_mean_list = priors[0]
            prior_cov_list = priors[1]
            updated_inputs = priors[2]
        else:
            raise ValueError('invalid input')
        # TODO: streamline - the surface split relies on the global num_pts_per_surface
        mu_interpolator, covar_interpolator, cases = self.create_interpolators(
            prior_mean_list, prior_cov_list,
            [updated_inputs[:num_pts_per_surface], updated_inputs[num_pts_per_surface:]],
            case_guides)

        self.mean_module = lam_prior_mean(mu_interpolator, train_x, cases, output_device=output_device, verbose=self.verbose)
        self.covar_module = lam_prior_cov(covar_interpolator, train_x, cases, output_device=output_device, verbose=self.verbose)

    def create_interpolators(self, mean_list, covar_list, xhat_locs, cases):
        """Create per-case interpolators for the LAM-prior mean and covariance.

        Doing this rigorously would require calling the LAM itself (to be implemented later).
        For simplicity, the prior for each case is precomputed with the LAM and interpolated
        onto the query locations.

        mean_list: list of LAM mean predictions, one per case
        covar_list: list of LAM covariance matrices, one per case
        xhat_locs: [upper-surface xhat, lower-surface xhat]
        cases: case numbers used as guides

        Returns (mean CubicSplines, covariance RegularGridInterpolators, cases).
        """
        # Convert xhat locations to non-dimensional chordwise coordinates
        xc_u = (xhat_locs[0]+1)/2  # upper surface
        xc_l = (xhat_locs[1]+1)/2  # lower surface
        # Signed chordwise coordinate shared by every case: upper surface negated (dropping its
        # last point), lower surface kept positive. Loop-invariant, so computed once.
        xc_unravel = np.hstack((-xc_u[:-1], xc_l))
        # Index of the repeated zero in the stacked data, removed so sizes match xc_unravel
        dup_idx = xc_u.shape[0]

        mu_interp, cov_interp = [], []
        for i in range(len(cases)):
            popped_mean = np.delete(mean_list[i], dup_idx, 0)
            popped_covar = np.delete(np.delete(covar_list[i], dup_idx, 0), dup_idx, 1)

            mu_interp.append(CubicSpline(xc_unravel, popped_mean))
            cov_interp.append(RegularGridInterpolator((xc_unravel, xc_unravel), popped_covar))
        # Returning `cases` directly also works when there are no cases (previously UnboundLocalError)
        return mu_interp, cov_interp, cases

## Define model

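Model input columns: `[56, xh, yh, yb, a, m, l, Lambda, ar, case]`. Columns 0–55 hold the 56 airfoil-description features. They are followed by:

- $\hat{x}$ (column 56)
- $\hat{y}$ (column 57; its sign separates the two surfaces)
- spanwise location $\eta = y/b$ (column 58, kept out of the feature extractor)
- angle of attack
- Mach number
- taper ratio $\lambda$
- leading-edge sweep $\Lambda$
- `ar` (presumably aspect ratio)
- case index (last column, also kept out of the feature extractor)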

In [9]:
# from kernel_functions import SVH_Matern_2var_1d
# Define model 

# Wrapper for convenience 
class large_wing_model():
    def __init__(self, model, likelihood, optimizer, scheduler, verbose=False, output_device='cpu'):
        # I want to have DKL_GP in here with optimizer and stuff
        self.output_device = output_device
        self.model = model.to(output_device)
        self.likelihood = likelihood.to(output_device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.history = {'loss': [],
                        'validation_err': [],
                        'checkpoints': []}
        self.save_directory = './'
        self.verbose = verbose
        
    def train_model(self, training_iterations, save_name, save_interval=200):
        # Set up loss function - Marginal Log Likelihood 
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model)  
        # from soft_dtw_cuda import SoftDTW
        # sdtw = SoftDTW(use_cuda=True, gamma=0.1)
        checkpt_ct = 1 
        sub_iter = 0 
        main_iter = 0 
        iterator = tn.tqdm(range(training_iterations))
        for i in iterator: 
            self.model.train()
            self.likelihood.train()
            
            # Zero backprop gradients
            self.optimizer.zero_grad()
        
            # Get output from model
            output = self.model(train_x)
            
            # Calculate loss 
            loss = -mll(output, train_y)  
            loss.backward()
            iterator.set_postfix(loss=loss.item())
            self.history['loss'].append(loss.item())
            
            if sub_iter >= save_interval-1:
                with torch.no_grad():
                    val_pred = self.predict(test_wing=true_test_x.to(self.output_device), save_to=None) 
                    # Use the sdtw from https://github.com/Maghoumi/pytorch-softdtw-cuda if ya wants
                    val_loss = 0.0 #sdtw(true_test_y.reshape((1, -1, 1)).cuda(), val_pred[0].mean.reshape((1, -1, 1)).cuda())
                    self.history['validation_err'].append(val_loss)  # .item()
                    val_loss_2  = ((true_test_y.cuda() - val_pred[0].mean.cuda())**2).sum()
                    a = np.trapezoid(np.abs(true_test_y.cpu().detach().numpy()))
                    b = np.trapezoid(np.abs(val_pred[0].mean.cpu().detach().numpy()))
                    print(f'Checkpoint {checkpt_ct} validation loss: {np.round(val_loss, 2)}, {np.round(np.abs(a-b), 2)}, {np.round(val_loss_2.item(), 2)}') # .item()
                self.history['checkpoints'].append(save_name + str(checkpt_ct)) 
                self.save_checkpoint(self.save_directory + self.history['checkpoints'][-1])
                checkpt_ct += 1
                sub_iter = 0  
            
            
            self.optimizer.step()
            self.scheduler.step()
            sub_iter += 1 
         
    """ Save checkpoint to user-provided file_path """
    def save_checkpoint(self, file_path): 
        torch.save({
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'likelihood': self.likelihood.state_dict()
        }, file_path)
         
    
    """ Load saved checkpoint from user-provided file_path"""
    def load_checkpoint(self, file_path):
        loaded_ = torch.load(file_path, weights_only=False, map_location='cpu')
        self.model.load_state_dict(loaded_['model_state_dict'])
        self.optimizer.load_state_dict(loaded_['optimizer_state_dict'])
        self.likelihood.load_state_dict(loaded_['likelihood']) 
        print_with_tag(f'Loading weight: {file_path}', 'LWM') if self.verbose else None

    def _gen_save_string(prefix, lr, decay, nn_dims, losses, iters, override_date=None):
        from datetime import date
        if override_date is not None:
            save_str_date = override_date
        else:
            today = date.today()
            save_str_date = today.strftime("%Y%m%d") + '_'
        save_str_prefix = prefix + '_' 
        save_str_optimizer = 'lr' + str(lr) + '_' + 'decay' + str(decay) + '_'
        save_str_dims = 'dims' + str(nn_dims[0]) + '-' + str(nn_dims[1]) + '-' + str(nn_dims[2]) + '-' + str(nn_dims[3]) + '_'
        save_str_checkpt = 'checkpt' + str(iters) + '_'
        save_str_loss = 'loss' + str(losses[0]) + '_'
        save_str_trainmetric = 'trainMetric' + str(losses[1]) + '_'
        save_str_testmetric = 'testLoss' + str(losses[2])
        return save_str_date + save_str_checkpt + save_str_prefix + save_str_optimizer + save_str_dims + save_str_loss + save_str_trainmetric + save_str_testmetric 
    
    """
    Generate predictions from the model;
    Manual implementation with LAM priors
    """
    def predict(self, test_wing, target_weights:list=[None], condition={'cl': None}, save_to:str=None, get_runtime=False):
        
        
        # Assemble input tensor from wing      
        if isinstance(test_wing, input_wing):
            test_data_list = test_wing.assemble_tensor()
        elif isinstance(test_wing, torch.Tensor): 
            test_data_list = [test_wing]
        else:
            raise ValueError("Invalid input wing")
        
        output_distr = [] 
        
        
        self.model.eval() 
        self.likelihood.eval()
        import time
        start_time = time.time()
        
        for test_idx, test_data in enumerate(test_data_list):
            print_with_tag('Generating prediction...', 'LWM') if self.verbose else None 
            # Get relevant prior 
            if self.model.lam_priors is not None: # if using LAM prior 
                Kxx_prior = self.model.prior_cov.train_cache 
                Kxs_prior = self.model.prior_cov(train_x, test_data).to(self.output_device).evaluate() #.evaluate()#.to(self.output_device)
                Kss_prior = self.model.prior_cov(test_data).to(self.output_device).evaluate() #.evaluate()#.to(self.output_device) 
                muxx_prior = self.model.mean_module.train_cache #.to(self.output_device)#.cpu().detach().numpy() 
                muss_prior = self.model.mean_module(test_data) #.to(self.output_device)#.cpu().detach().numpy() 
            else:  
                print("Error.... no prior")
                
            with torch.no_grad():   
                torch.cuda.empty_cache()
                # Get average from weights
                pred_mu = torch.zeros(test_data.shape[0]).to(self.output_device)
                pred_cov = torch.zeros(test_data.shape[0], test_data.shape[0]).to(self.output_device)
                for weight in target_weights:
                    if weight is not None: 
                        self.load_checkpoint(weight)
                        self.model.eval()
                        self.likelihood.eval()
                        
                    # Need to optimize?
                    def _project_x(x):
                        with torch.no_grad():   
                            proj_x = self.model.feature_extractor(torch.hstack((x[:,:58], x[:,59:-1]))) # do without eta 
                            proj_x = self.model.scale_to_bounds(proj_x)  # Make the NN values "nice" 
                            proj_x = torch.hstack((proj_x, x[:, 58:59])) # re-attach eta
                            return proj_x 
                    
                    # Feed train and test data thru m
                    proj_xx = _project_x(train_x)
                    proj_ss = _project_x(test_data)
                        
                    Kxx_base = self.model.covar_module(proj_xx).to(self.output_device).evaluate()
                    Kxs_base = self.model.covar_module(proj_xx, proj_ss).to(self.output_device).evaluate()
                    Kss_base = self.model.covar_module(proj_ss).to(self.output_device).evaluate() 

                    Kxx = Kxx_base + Kxx_prior  #.cpu().detach().numpy()
                    Kxs = Kxs_base + Kxs_prior #.cpu().detach().numpy()
                    Kss = Kss_base + Kss_prior #.cpu().detach().numpy()
                    jitter = torch.eye(Kxs.shape[0]).to(self.output_device)*1e-4
                    
                    if len(self.likelihood.noise) == 1:
                        sigma = torch.eye(Kxs.shape[0]).to(self.output_device)*self.likelihood.noise.item()
                    else: 
                        sigma = torch.diag(self.likelihood.noise).to(self.output_device)
                    
                    Kxx = (Kxx + Kxx.T)/2# symmetrize the matrix
                    L = torch.linalg.cholesky(Kxx + sigma + jitter, upper=False)
                    s1 = torch.linalg.solve_triangular(L, (train_y - muxx_prior).reshape((-1,1)), upper=False).T
  
                    s2 = torch.linalg.solve_triangular(L, Kxs, upper=False) 
                    
                    pred_mu += (s1 @ s2).flatten() + muss_prior.flatten()
                    pred_cov += Kss - s2.T @ s2
                    # Stabilize pred_cov
                    # pred_cov = (pred_cov + pred_cov.T)/2 
                pred_mu /= len(target_weights)
                pred_cov /= len(target_weights)
                output_distr.append(gpytorch.distributions.MultivariateNormal(pred_mu.cpu()/scaler_scale + scaler_mean, pred_cov.cpu()/scaler_scale**2 + torch.eye(pred_cov.shape[0]).cpu()*1e-4))
                end_time = time.time()
                print('runtime:', end_time - start_time)
                if isinstance(test_wing, input_wing):
                    if test_wing.input_type == 'mesh':
                        test_wing.mesh['zones_data'][test_idx][:, -2] = pred_mu.flatten().cpu().detach().numpy()/scaler_scale + scaler_mean # Assign mean value 
                        test_wing.mesh['zones_data'][test_idx][:, -1] = 2*np.sqrt(np.diag(pred_cov.cpu().detach().numpy())).flatten()/scaler_scale # Assign confidence interval 
                
                # Make sure that this is compatible with mesh style later 
                # right now this is only compatible with grid style
                if condition['cl'] is not None: # Expand this later   
                    # any condition should be defined as a [mean, st_dev]
                    condition_mean = condition['cl'][0] # target mean
                    condition_stdev = condition['cl'][1] # target std
                    print_with_tag(f'Constraining posterior to C_L = N({condition_mean}, {condition_stdev}^2)', 'LWM') if self.verbose else None 
                    
                    # Establish distributions
                    # Target distribution - constraint 
                    target_dist = torch.distributions.normal.Normal(loc=condition_mean, scale=condition_stdev)   
                    # generate samples to generate a single proposal distribution 
                    xhat_arr = test_wing.xhat.reshape((-1, test_wing.chord_resolution*2)) # xhat location, reshaped into grid 
                    eta_arr = test_wing.eta.reshape((-1, test_wing.chord_resolution*2)) # eta location, reshaped into grid 
                    accumulated_CL_samples = []
                    accumulated_cp_samples = []
                    accumulated_log_weights = []
                    cp_samples = output_distr[-1].rsample(torch.Size([10000]))
                    cp_samples_ = cp_samples.reshape((10000, -1, test_wing.chord_resolution*2)).detach().numpy()
                    cl_samples = -np.trapezoid(cp_samples_[:, :, :test_wing.chord_resolution], x=(xhat_arr[0][:test_wing.chord_resolution]+1)/2, axis=2) + \
                                    np.trapezoid(cp_samples_[:, :, test_wing.chord_resolution:], x=(xhat_arr[0][test_wing.chord_resolution:]+1)/2, axis=2)
                    CL_samples = np.trapezoid(cl_samples, x=eta_arr[:, 0], axis=1)
                    # Proposal distribution - original posterior 
                    proposal_dist = torch.distributions.Normal(loc=np.mean(CL_samples), scale=np.std(CL_samples))
                    print(np.mean(CL_samples), np.std(CL_samples))
                    # Set up sampling iterator 
                    # --- REJECTION SAMPLING INSTEAD OF IMPORTANCE SAMPLING ---
                    filtered_cp_samples = []
                    filtered_cl_samples = []
                    filtered_CL_samples = []

                    desired_samples = condition['num_samples']
                    max_attempts = 1000  # To avoid infinite loops
                    attempts = 0
                    tqdm_iterable = tn.tqdm(range(max_attempts))
                    for i in tqdm_iterable:  
                        batch_size = 2000
                        cp_samples = output_distr[-1].rsample(torch.Size([batch_size]))
                        cp_samples_ = cp_samples.reshape((batch_size, -1, test_wing.chord_resolution*2)).detach().numpy()
                        cl_samples = -np.trapezoid(cp_samples_[:, :, :test_wing.chord_resolution], x=(xhat_arr[0][:test_wing.chord_resolution]+1)/2, axis=2) + \
                                        np.trapezoid(cp_samples_[:, :, test_wing.chord_resolution:], x=(xhat_arr[0][test_wing.chord_resolution:]+1)/2, axis=2)
                        CL_samples = np.trapezoid(cl_samples, x=eta_arr[:, 0], axis=1) 
                        # Compute acceptance probabilities (normalized to max=1)
    
                        target_probs = torch.exp(target_dist.log_prob(torch.from_numpy(CL_samples))).detach().numpy()
                        proposal_prob = torch.exp(proposal_dist.log_prob(torch.from_numpy(CL_samples))).detach().numpy()
                        # max_prob = np.max(target_probs)
                        # acceptance_probs = target_probs / proposal_prob #max_prob
                        ratios = target_probs / proposal_prob
                        M = np.max(ratios)
                        acceptance_probs = ratios / M 

                        # Draw uniform random numbers for acceptance
                        random_vals = np.random.uniform(0, 1, size=batch_size)
                        accept_mask = random_vals < acceptance_probs

                        # Store accepted samples
                        filtered_cp_samples.extend(cp_samples[accept_mask])
                        filtered_cl_samples.extend(cl_samples[accept_mask])
                        filtered_CL_samples.extend(CL_samples[accept_mask])

                        if  len(filtered_CL_samples) > desired_samples:
                            break
                        
                        attempts += 1
                        tqdm_iterable.set_postfix({'Samples':len(filtered_CL_samples)})

                    # Truncate to desired number of samples
                    filtered_cp_samples = filtered_cp_samples[:desired_samples]
                    filtered_cl_samples = filtered_cl_samples[:desired_samples]
                    filtered_CL_samples = filtered_CL_samples[:desired_samples]
                    
                    # Make into np array
                    filtered_cp_samples = np.vstack(filtered_cp_samples)
                    filtered_cl_samples = np.vstack(filtered_cl_samples)
                    filtered_CL_samples = np.vstack(filtered_CL_samples).flatten()
                    
                    print(f"Accepted {len(filtered_CL_samples)} samples after {attempts} attempts.")
                    plt.hist(filtered_CL_samples, bins=20)
                    conditioned_cp_posterior = torch.distributions.multivariate_normal.MultivariateNormal(torch.mean(torch.from_numpy(filtered_cp_samples), dim=0).flatten(), 
                                                                                                          torch.cov(torch.from_numpy(filtered_cp_samples).T) + torch.eye(filtered_cp_samples.shape[1])*1e-6)
                    return conditioned_cp_posterior, filtered_CL_samples
        
        # If user wants to save 
        if save_to is not None:
            # self.__write_to_mesh(test_wing=test_wing, case_solution=output_distr, output_filename=save_to)
            self.__save_to_solution_file(test_wing=test_wing, output_filename=save_to)
        return output_distr
    
    """
    Save predictions out to a reference mesh file 
    .dat file exported from tecplot 
    """
    def __save_to_solution_file(self, test_wing, output_filename):
            write_file = []
            test_wing.mesh['header']
            delimiter = '\n'
            write_file.append(delimiter.join(test_wing.mesh['header']))
            for i in range(len(test_wing.mesh['zones_data'])):
                # append the zone header 
                write_file.append(delimiter.join(test_wing.mesh['zones_header'][i])) 
                # Convert zone data into string 
                data_lines = []
                for row in test_wing.mesh['zones_data'][i]:
                    formatted = ['{: .9E}'.format(row[0])]  # First value with leading space
                    formatted += ['{:.9E}'.format(val) for val in row[1:]]  # Rest of the row
                    line = ' '.join(formatted)
                    data_lines.append(line)
                data_str = '\n'.join(data_lines)
                write_file.append(data_str)
                write_file.append(delimiter.join(test_wing.mesh['zones_connectivity'][i])) 
            # Join all together 
            write_file = delimiter.join(write_file)
            
            with open(output_filename, "w") as f:
                f.write(write_file)
                
    def __write_to_mesh(self, test_wing, case_solution:gpytorch.distributions.MultivariateNormal, output_filename:str=None):
        # Default file name is the original file name with lwm appended
        if output_filename is None:
            output_filename = test_wing.mesh.file_directory[:-3] + 'lwm.dat'
            
        # Read in node information
        wing_nodes = test_wing.mesh['nodes'][0]
        all_nodes = test_wing.mesh['nodes'][1]
        
        # Mesh file block 1 - header w textual information
        title_str = f'TITLE:     = "{test_wing.case_name}"'
        variable_str = 'VARIABLES = "X"\n"Y"\n"Z"\n"mean C_p"\n"Confidence region"'
        zone_str = 'ZONE T="wing"'
        id_str = 'STRANDID=1, SOLUTIONTIME=2'
        mesh_str = f'Nodes={str(np.unique(wing_nodes[:,1:]).shape[0])}, Elements={str(test_wing.mesh["N_marker_elems"])}, ZONETYPE={test_wing.mesh["type_str"]}'
        pack_str = 'DATAPACKING=POINT'
        DT_str = 'DT=(SINGLE SINGLE SINGLE SINGLE, SINGLE)'
        block1 = [title_str, variable_str, zone_str, id_str, mesh_str, pack_str, DT_str]

        # block 2 - coordinates info
            # get unique node numbers
        wing_node_indices = np.sort(np.unique(wing_nodes[:,1:])) # indices of wing nodes
        wing_node_new_indices = np.arange(1, np.unique(wing_nodes[:,1:]).shape[0]+1)
        mask = np.isin(all_nodes[:, -1].astype(int), wing_node_indices)
        wing_node_coordinates = all_nodes[mask][:, :3] # coordinates of wing nodes 
            
        case_solution_mean = case_solution.mean[:,None].cpu().detach().numpy()
        case_solution_2sig = 2*np.sqrt(np.diag(case_solution.covariance_matrix.cpu().detach().numpy()))[:, None]
        block2 = np.hstack((wing_node_coordinates, case_solution_mean, case_solution_2sig)).astype(float) 

        # block 3 - nodes info
        block3 = np.zeros_like(wing_nodes[:, 1:])
        for j in range(1, wing_nodes.shape[1]):
            for i in range(wing_nodes.shape[0]):
                new_val = wing_node_new_indices[np.argwhere(wing_node_indices==wing_nodes[i, j])[0][0]]
                block3[i, j-1] = new_val
                
        # save file 
        with open(output_filename, "w") as f:
            # Write each string on a new line
            for line in block1:
                f.write(line + "\n")
            
            # Write arrays back to back
            np.savetxt(f, block2, fmt='%.6f', delimiter=' ')
            np.savetxt(f, block3, fmt='%d', delimiter=' ')
            print_with_tag(f'Prediction saved as {output_filename}', 'LWM')
    
    def __plot_directly(self, test_wing, case_solution:gpytorch.distributions.MultivariateNormal, output_filename:str=None):
        1
        
## Main DKL Model
class DKL_model(gpytorch.models.ExactGP):
    """Deep-kernel-learning exact GP.

    The spanwise location (column ``spanwise_dim_idx``) and the last (case) column are withheld
    from the NN feature extractor. The extractor's output is scaled to [-1, 1], and eta is
    re-attached as the last latent dimension. The kernel is an ARD Matern-5/2 on the latent
    features multiplied by SVH_Matern_2var_1d on eta. With a LAM prior, the prior's mean module is
    the GP mean and its covariance is added to the kernel; otherwise a constant mean is used.

    feature_extractor: module exposing ``nn_dims``; ``nn_dims[-1]`` is the latent dimension.
    """
    def __init__(self, train_x, train_y, likelihood,  # GP parameters
                 feature_extractor,  # NN parameters
                 lam_priors:lam_prior=None,  # LAM-prior parameters
                 verbose:bool=False):
        super().__init__(train_x, train_y, likelihood)
        self.feature_extractor = feature_extractor
        self.lam_priors = lam_priors
        self.spanwise_dim_idx = 58  # column of the spanwise location, eta, in the input data
        self.verbose = verbose
        self.train_x = train_x
        if verbose:
            print_with_tag('Intializing the model', 'LOG')

        if self.lam_priors is not None:
            if verbose:
                print_with_tag('The model is using a LAM prior', 'LOG')
            self.mean_module = self.lam_priors.mean_module  # LAM-prior mean
            self.prior_cov = self.lam_priors.covar_module  # LAM-prior covariance, added in forward()
        else:
            if verbose:
                print_with_tag('The model is using a constant mean prior', 'LOG')
            self.mean_module = gpytorch.means.ConstantMean()

        # Latent features occupy dims 0..n_latent-1 and eta is appended at index n_latent.
        # Derived from nn_dims so active_dims always agree with ard_num_dims
        # (previously hard-coded for n_latent == 14).
        n_latent = self.feature_extractor.nn_dims[-1]
        self.covar_module = gpytorch.kernels.MaternKernel(nu=5/2, ard_num_dims=n_latent,
                                                         lengthscale_prior=gpytorch.priors.NormalPrior(0.5, 1.0),
                                                         active_dims=list(range(n_latent))) \
            * SVH_Matern_2var_1d(rates=[15], active_dims=[n_latent])

        self.scale_to_bounds = gpytorch.utils.grid.ScaleToBounds(-1.0, 1.0)  # scale feature extractor outputs

    def forward(self, x):
        """Return the GP prior MultivariateNormal at inputs x (rows are samples, columns features)."""
        eta_idx = self.spanwise_dim_idx
        # Feature extractor sees every column except eta and the trailing case column
        projected_x = self.feature_extractor(torch.hstack((x[:, :eta_idx], x[:, eta_idx+1:-1])))
        projected_x = self.scale_to_bounds(projected_x)
        projected_x = torch.hstack((projected_x, x[:, eta_idx:eta_idx+1]))  # re-attach eta

        mean_x = self.mean_module(x)
        covar_x = self.covar_module(projected_x)
        if self.lam_priors is not None:
            covar_x = covar_x + self.prior_cov(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        
        
## Feature Extractor 
class nn_feature_extractor(torch.nn.Sequential):
    """MLP feature extractor: Linear, then (ReLU -> Dropout -> Linear) for each further entry of nn_dims.

    train_x: training inputs; only ``train_x.shape[-1]`` is used. Two columns (eta and the case
        index) are withheld from the network, so the input width is ``train_x.shape[-1] - 2``.
    nn_dims: output width of each linear layer; ``nn_dims[-1]`` is the latent dimension.
    dropout: dropout probability between layers (default 0.2).

    Layers are named linear1..linearN, relu1..relu{N-1} and dropout1..dropout{N-1}. For 4 or 6
    dims the parameter names and order match the former hand-unrolled version, so saved state
    dicts still load. Dropout has no parameters, so the old 'dropout6' name does not matter.
    Any length is supported; 5 dims previously raised IndexError.
    """
    def __init__(self, train_x, nn_dims, dropout=0.2):
        super().__init__()
        self.nn_dims = nn_dims
        in_features = train_x.shape[-1] - 2  # minus eta and case columns
        for layer_idx, out_features in enumerate(nn_dims, start=1):
            if layer_idx > 1:
                self.add_module(f'relu{layer_idx-1}', torch.nn.ReLU())
                self.add_module(f'dropout{layer_idx-1}', torch.nn.Dropout(dropout))
            self.add_module(f'linear{layer_idx}', torch.nn.Linear(in_features, out_features))
            in_features = out_features
            
class input_wing():
    #We want the following from the CSV = ['xc', 'yb', 'surf', 'airfoil', 'alpha', 'M', 'Re', 'chord',' taper ratio', 'span', 'le sweep', 'cp', 'source']
    def __init__(self, airfoil:lam_adapt.input_data, chord_length:float, 
                 span_length:float, sweep_angle:float, taper_ratio:float, reynolds_number:float, 
                 mach_number:float, angle_of_attack:float, case_num:int, output_device:torch.device='cpu', verbose:bool=False,
                 source_file:str=None, coordinate_marker:str=None, case_name:str='WING',
                 span_resolution=20):
        # CASE NUM NEEDS TO BE REPLACED EVENTUALLY
        self.case_num = case_num
        self.case_name = case_name 
        self.verbose = verbose 
        self.output_device = output_device 
        
        self.airfoil = airfoil
        self.chord_dimensional = chord_length # dimensional chord length, root chord assumed for now 
        self.span_dimensional = span_length # dimensional span length
        self.le_sweep = sweep_angle # leading edge sweep angle in degrees, \Lambda
        self.taper_ratio = taper_ratio # taper ratio, \lambda
        self.reynolds = reynolds_number # Re (not used)
        self.mach = mach_number # Freestream mach number 
        self.alph = angle_of_attack  # angle of attack (degrees)
        self.alph_rad = np.deg2rad(angle_of_attack) # angle of attack (radians)
        
        # Derived quantities 
        self.semispan_ratio = self.span_dimensional / self.chord_dimensional # semi-span to c_root ratio (s/c_root)
        self.xhat = None # xhat and zhat, should be filled out when running assemble_tensor() function
        self.zhat = None # 
        self.eta = None # 
        self.coordinates = None # this ?  
        self.case_title = None
        self.input_tensor = None
        self.mesh = {
                'file_directory': None,
                'coordinate_marker': None,
                'header': None,
                'zones_header': [],
                'zones_data': [],  
                'zones_connectivity': [],
                'type_str': None
            } 
        self.coordinate_marker = coordinate_marker
        
        # tensor quantity 
        self.span_resolution = span_resolution # Number of unique spanwise locations, either an int or an array of pre-determine locations
        self.chord_resolution = 200 # per surface 
        # Use mesh source file if given, make default tensor if not 
        if source_file is not None: 
            assert coordinate_marker is not None, "A list of marker string that defines the coordinates is required"
            self.mesh['file_directory'] = source_file
            self.mesh['coordinate_marker'] = coordinate_marker
            self.input_type = 'mesh' 
        else: 
            self.input_type = 'default' 
            
    def assemble_tensor(self):
        if self.input_type == 'default':
            out_tensor = []
            # Initialize chordwise coordinates 
            xhat_init = np.hstack((np.linspace(-1.0, 1.0, self.chord_resolution), np.linspace(-1.0, 1.0, self.chord_resolution)))
            zhat_init = np.ones_like(xhat_init)
            zhat_init[:self.chord_resolution] = np.sin(np.arccos(xhat_init[:self.chord_resolution]))
            zhat_init[0] = 0.0
            zhat_init[self.chord_resolution:] = -np.sin(np.arccos(xhat_init[self.chord_resolution:]))
            zhat_init[self.chord_resolution] = 0.0
            # Set up eta 
            if isinstance(self.span_resolution, int): # if int, do uniform spacing
                self.eta = np.repeat(np.linspace(0.0, 1.0, self.span_resolution), self.chord_resolution*2) 
            elif isinstance(self.span_resolution, np.ndarray): 
                self.eta = np.repeat(self.span_resolution, self.chord_resolution*2)  
            else: 
                raise ValueError("Invalid spanwise resolution")
            # Tile out full coordinates
            xhat = np.tile(xhat_init, self.span_resolution) if isinstance(self.span_resolution, int) else np.tile(xhat_init, self.span_resolution.shape[0])
            zhat = np.tile(zhat_init, self.span_resolution) if isinstance(self.span_resolution, int) else np.tile(zhat_init, self.span_resolution.shape[0]) 
            # Other stuff
            alpha_array = np.ones_like(xhat) * self.alph_rad
            mach_array = np.ones_like(xhat) * self.mach
            taper_array = np.ones_like(xhat) * self.taper_ratio
            le_sweep_array = np.ones_like(xhat) * self.le_sweep
            semispan_array = np.ones_like(xhat) * self.semispan_ratio
            case_array = np.ones_like(xhat) * self.case_num
            self.xhat = xhat 
            self.zhat = zhat 
            # Get airfoil coordinates
            out_tensor_pt1 = torch.from_numpy(np.tile(self.airfoil.retrieve_airfoil_input(true_values=True), (xhat.shape[0], 1))) 
            print('Reminder: temp fix remove')
            out_tensor_pt1 = torch.hstack((torch.flip(out_tensor_pt1[:, :28], dims=[1]), out_tensor_pt1[:, 28:]))# FIX 
            # Assemble tensor  
            out_tensor_pt2 = torch.from_numpy(np.hstack((self.xhat[:, None], self.zhat[:, None], self.eta[:, None], 
                                alpha_array[:, None], mach_array[:, None], taper_array[:, None], 
                                le_sweep_array[:, None], semispan_array[:, None], case_array[:, None])))
            if len(out_tensor) == 0: # only append if this is the first time assemble_tensor has been run
                out_tensor.append(torch.hstack((out_tensor_pt1, out_tensor_pt2)).to(self.output_device) )
            self.input_tensor = out_tensor
              
        elif self.input_type == 'mesh':
            out_tensor = self._assemble_from_solution()
            print_with_tag('Input tensor assembled from source mesh file', 'WING') if self.verbose else None
        else: 
            print('Implement more here ')
        return out_tensor
    
    
    def _assemble_from_solution(self):
        out_tensor = []
        assert self.mesh['file_directory'] is not None, "You must provide a source solutions file"
        file_path = self.mesh['file_directory']
        coordinate_names = [f'"{self.coordinate_marker[0]}"', f'"{self.coordinate_marker[1]}"',  f'"{self.coordinate_marker[2]}"']

        file_title = "dummy file"  # Replace this 

        output_header = [] 
        with open(file_path, 'r') as file:
            file_content = file.read()
            # Figure out how many zones there are 
            zone_idx = [match.start() for match in re.finditer('ZONE T', file_content)]
            num_zones = len(zone_idx)
            zone_idx = zone_idx + [len(file_content)]
            
            # Extract Meta-information from file header 
            print('[LOG] Reading file header...')
            current_line = file_content
            line_start_idx = 0 # index where the line starts 
            entry_str = ''
            while entry_str != 'ZONE T': # repeat until zones 
                # Isolate line 
                current_line = file_content[line_start_idx:]
                line_end_idx = [m.start() for m in re.finditer(r'\n ?[A-Z]', current_line)][0] #current_line.find('\n') # index where the line ends 
                current_line = current_line[:line_end_idx] # current line 

                # Identify what its about  
                if '=' in current_line:
                    entry_str = current_line.split('=')[0].strip() 
                    
                # Perform relevant action depending on line 
                match entry_str: 
                    case "TITLE": # Assign new title 
                        new_entry  = "TITLE = " + file_title
                        output_header.append(new_entry) 
                    case "VARIABLES": # Identify variable names 
                        list_variables = current_line.split('=')[1].split('\n')
                        list_variables = [i.strip() for i in list_variables]
                        coordinates_idx = [list_variables.index(j) for j in coordinate_names]

                        new_entry = 'VARIABLES = "x/c"\n"y/b"\n"z/c"\n"Pressure coefficient"\n"Confidence interval"'
                        output_header.append(new_entry)  
                
                # Move onto next line 
                line_start_idx += line_end_idx + 1 # +2 due to \n
            
            # Iterate through zones and calculate the model predicted values 
            for i in range(num_zones): # [:1]
                print(f'[LOG] Reading Zone {i+1}...')
                zone_content = file_content[zone_idx[i]:zone_idx[i+1]] # The main "zone data"
                zone_header = []
                current_line = zone_content
                line_start_idx = 0 # index where the line starts 
                line_end_idx = 0
                current_line = zone_content[line_start_idx:]
                entry_str = ''
                while len([m.start() for m in re.finditer(r'\n ?[A-Z]', current_line)]) != 0: # repeat until the data matrix 
                    # Isolate line  
                    line_end_idx = [m.start() for m in re.finditer(r'\n ?[A-Z]', current_line)][0] # index where the line ends 
                    current_line = current_line[:line_end_idx] # current line 

                    # Identify what its about  
                    if '=' in current_line:
                        entry_str = current_line.split('=')[0].strip() 
                    # Perform relevant action depending on line 
                    match entry_str: 
                        case "ZONE T": 
                            # Add the new_entry
                            zone_header.append(current_line) 
                        case "STRANDID": # STRANDID=1, SOLUTIONTIME=0
                            # Nothing needs to change here, append the same thing 
                            zone_header.append(current_line) 
                        case "Nodes": # Nodes=1579, Elements=919, ZONETYPE=FETriangle
                            split_line = current_line.split(',')
                            nodes, elements, zonetype = split_line[0], split_line[1], split_line[2] 
                            num_nodes = int(nodes.split('=')[-1]) # number of nodes in mesh ## 
                            num_elems = int(elements.split('=')[-1]) # number of elements in mesh ## 
                            zonetype = zonetype.split('=')[-1]  # zone type  
                            
                            # Nothing needs to change here, append the same thing 
                            zone_header.append(current_line) 
                        case "DATAPACKING": # DATAPACKING=POINT
                            # Nothing needs to change here, append the same thing 
                            zone_header.append(current_line) 
                        case "AUXDATA Time": 
                            # Nothing needs to change here, append the same thing 
                            zone_header.append(current_line) 
                    
                    # Move onto next line 
                    line_start_idx += line_end_idx + 1 # 
                    current_line = zone_content[line_start_idx:]
                    
                # Add DT entry       
                new_entry = " DT=(SINGLE SINGLE SINGLE SINGLE SINGLE )" # requires 5, x, y, z, Cp, 2sig range
                zone_header.append(new_entry) 
                
                mesh_information = zone_content[zone_content.find(')\n')+2:] 
                output_locations = mesh_information.split('\n')[:num_nodes] # Locations at which predictions will be generated
                output_connectivity = mesh_information.split('\n')[num_nodes:-1] # Information regarding connectivity of nodes 

                # Convert output locations to LWM input 
                output_locations_array = np.array([[float(num) for num in line.split()] for line in output_locations])
                wing_node_xc = output_locations_array[:, coordinates_idx[0]]
                wing_node_yb = output_locations_array[:, coordinates_idx[1]]
                wing_node_zc = output_locations_array[:, coordinates_idx[2]]
                
                xc = wing_node_xc / self.chord_dimensional # This needs to be converted to hats 
                xc[xc>1.0] = 1.0 # cut off
                zc = wing_node_zc / self.chord_dimensional # This needs to be converted to hats
                xhat = xc*2 - 1 # transform to xhat 
                zhat = np.zeros_like(xhat)
                if len(zhat[zc>0.0]) > 0:
                    zhat[zc>0.0] = np.sin(np.arccos(xhat[zc>0.0])) # transform to zhat for upper
                if len(zhat[zc<=0.0]) > 0:
                    zhat[zc<=0.0] = -np.sin(np.arccos(xhat[zc<=0.0])) # transform to zhat for lower
                # self.zhat = zhat 
                # self.xhat = xhat 
                self.eta = wing_node_yb / 3.35#self.span_dimensional # get eta, non-dimensional spanwise
                # generate other parts of the input tensor
                alpha_array = np.ones_like(xhat) * self.alph_rad
                mach_array = np.ones_like(xhat) * self.mach
                taper_array = np.ones_like(xhat) * self.taper_ratio
                le_sweep_array = np.ones_like(xhat) * self.le_sweep
                semispan_array = np.ones_like(xhat) * self.semispan_ratio
                case_array = np.ones_like(xhat) * self.case_num
                self.xhat = xhat 
                self.zhat = zhat
                # Get airfoil coordinates
                out_tensor_pt1 = torch.from_numpy(np.tile(self.airfoil.retrieve_airfoil_input(true_values=True), (xhat.shape[0], 1)))  
                print('temp fix')
                out_tensor_pt1 = torch.hstack((torch.flip(out_tensor_pt1[:, :28], dims=[1]), out_tensor_pt1[:, 28:]))# FIX f
                # Assemble tensor  
                out_tensor_pt2 = torch.from_numpy(np.hstack((self.xhat[:, None], self.zhat[:, None], self.eta[:, None], 
                                    alpha_array[:, None], mach_array[:, None], taper_array[:, None], 
                                    le_sweep_array[:, None], semispan_array[:, None], case_array[:, None])))
                out_tensor.append(torch.hstack((out_tensor_pt1, out_tensor_pt2)).to(self.output_device))
                # Store useful information to dict 
                self.mesh['zones_connectivity'].append(output_connectivity)
                self.mesh['zones_header'].append(zone_header)
                self.mesh['zones_data'].append(np.hstack((wing_node_xc[:, None], wing_node_yb[:, None], wing_node_zc[:, None], np.zeros_like(wing_node_zc[:, None]), np.zeros_like(wing_node_zc[:, None])))) 
                # Cp and 2sig are dummy data that must be replaced during predictions
            # Store more useful information to dict  
            self.mesh['header'] = output_header
            
            
            self.input_tensor = out_tensor 
            self.input_type = 'mesh'
            return out_tensor 

## Train

In [10]:
from torch.optim.lr_scheduler import StepLR

# --- Training configuration ---------------------------------------------------
SEED = 0                                   # seeds network-weight initialization and dropout
NOISE_VAR = 0.01                           # fixed noise variance per point, scaled by scaler_scale**2
NN_DIMS = [1000, 1000, 1000, 500, 50, 14]  # feature-extractor layer widths
LR_STEP_SIZE = 2000                        # StepLR: decay the LR every LR_STEP_SIZE scheduler.step() calls
LR_GAMMA = 0.1                             # StepLR: multiplicative LR decay factor
lr = 1e-3                                  # Adam learning rate

torch.manual_seed(SEED)  # make the randomly initialized network reproducible

# Initialize LAM priors
lam_priors = lam_prior(priors=[mean_list, cov_list, xhat_locs], output_device=output_device,
                       case_guides=all_cases_unique, verbose=True)

# NOTE(review): with learn_additional_noise=False, gpytorch's FixedNoiseGaussianLikelihood does
# not create the learnable extra noise term, so this noise_prior appears to be unused -- confirm
# against the installed gpytorch version before relying on it.
likelihood = gpytorch.likelihoods.FixedNoiseGaussianLikelihood(
    torch.ones_like(train_y) * NOISE_VAR * scaler_scale**2,
    learn_additional_noise=False,
    noise_prior=gpytorch.priors.NormalPrior(0.1 * scaler_scale, 0.01),
)
feature_extractor = nn_feature_extractor(train_x, nn_dims=NN_DIMS)
model = DKL_model(train_x, train_y, likelihood,
                  feature_extractor=feature_extractor,
                  lam_priors=lam_priors,
                  verbose=True)

# Optimizer and step-decay learning-rate schedule
optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=lr)
scheduler = StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

# Create LWM instance
lwm = large_wing_model(model, likelihood, optimizer, scheduler, output_device=output_device, verbose=True)

[LOG] Computing cache for training data mean.
[LOG] Computing cache for training data covariance matrix.
[LOG] Intializing the model
[LOG] The model is using a LAM prior


In [11]:
# Short demonstration training run. save_name is presumably the checkpoint filename prefix.
# NOTE(review): if save_interval counts iterations, a 10-iteration run never reaches the
# 100-iteration save point -- confirm against large_wing_model.train_model.
NUM_TRAIN_ITERATIONS = 10
lwm.train_model(NUM_TRAIN_ITERATIONS, save_name='testrun_', save_interval=100)

  0%|          | 0/10 [00:00<?, ?it/s]

## Example prediction

In [20]:
# First create an instance of the input_wing class.
# Names of the x (chordwise), y (spanwise) and z (chord-normal) variables in the mesh file.
chordwise_var_name = 'CoordinateX'
spanwise_var_name = 'CoordinateY'
chordnormal_var_name = 'CoordinateZ'

# Flow condition for the prediction (angle of attack in degrees, freestream Mach number).
pred_alpha_deg = 4.0
pred_mach = 0.2

# The AoA and M given to the airfoil description do not matter here.
test_airfoil = lam_adapt.input_data('NACA 0015', pred_alpha_deg, pred_mach, output_device='cpu', model_version='v2')
pred_wing = input_wing(test_airfoil, chord_length=1, span_length=3.3, sweep_angle=0, taper_ratio=1,
                       reynolds_number=None,  # Re is only stored by input_wing, never used
                       mach_number=pred_mach, angle_of_attack=pred_alpha_deg, case_num=target_case,
                       output_device=output_device,
                       source_file='example_mesh.dat',
                       coordinate_marker=[chordwise_var_name, spanwise_var_name, chordnormal_var_name],
                       verbose=True)

In [22]:
# Generate the prediction for `pred_wing` using the weight files listed below; nothing is
# written to disk (save_to=None).
# For posterior conditioning, a condition such as
#     {'cl': [0.353, 0.001], 'num_samples': 1000}
# can be added (presumably via a `condition=` keyword -- confirm against lwm.predict).
prediction = lwm.predict(
    pred_wing,
    target_weights=['./weights/weights_1', './weights/weights_2'],
    save_to=None,
)

[LOG] Reading file header...
[LOG] Reading Zone 1...
temp fix
[LOG] Reading Zone 2...
temp fix
[WING] Input tensor assembled from source mesh file
[LWM] Generating prediction...
[LWM] Loading weight: ./weights/weights_1
[LWM] Loading weight: ./weights/weights_2
runtime: 1.9324729442596436
[LWM] Generating prediction...
[LWM] Loading weight: ./weights/weights_1
[LWM] Loading weight: ./weights/weights_2
runtime: 22.930493593215942
