# In [None]:
##Multitask model


#Import libraries
import numpy as np
import torch
import gpytorch
import pandas as pd 
import matplotlib.pyplot as plt
import os
from gpytorch.constraints import Positive
from scipy.ndimage import generic_filter
from gpytorch.constraints import GreaterThan
torch.set_default_dtype(torch.float64)

#Load in data

DATA_PATH = 'C:/Users/Tom/OneDrive - Georgia Institute of Technology/Desktop/CEREAL Repo/demo_outboard_CEREAL_edited.csv'
df = pd.read_csv(DATA_PATH)

# Build every tensor in torch's current default dtype so inputs and targets
# share one precision instead of being hard-coded to float32.
DTYPE = torch.get_default_dtype()
x = torch.tensor(df[['r','z','azim']].values, dtype=DTYPE)
y = torch.tensor(df['Vmag'].values, dtype=DTYPE)
Vr = torch.tensor(df['Vr'].values, dtype=DTYPE)
Vz = torch.tensor(df['Vz'].values, dtype=DTYPE)

# Keep only rows that are NaN-free in the inputs AND in every target column:
# a NaN in Vr or Vz would otherwise end up inside train_y.
mask = ~(torch.isnan(x).any(dim=1) | torch.isnan(y) | torch.isnan(Vr) | torch.isnan(Vz))
x = x[mask]
y = y[mask]
Vr = Vr[mask]
Vz = Vz[mask]

# Standardize output. Done after masking: computed on the unmasked data a
# single NaN would turn the mean/std (and hence all of y_stand) into NaN.
y_stand = (y - y.mean())/y.std()
print(y)
#Define x and y coordinates, should y being velo magnitude?

# Independent copy of the inputs; clone().detach() is the warning-free
# equivalent of torch.tensor(<existing tensor>).
train_x = x.clone().detach()
# Two-task target matrix: column 0 is Vr, column 1 is Vz.
train_y = torch.stack([Vr,Vz],dim=-1)

#Define fourier kernel
class FourierKernel(gpytorch.kernels.Kernel):
    """Truncated Fourier-series kernel acting on a single input column.

    For inputs ``a`` and ``b`` (column 0 of whatever the kernel receives in
    ``forward``) the covariance is ``F(a) @ diag(lambdas) @ F(b).T``, where
    ``F`` is the feature matrix built by :meth:`fourier_matrix_fast` with
    columns ``[1, sin(h1*x), cos(h1*x), sin(h2*x), cos(h2*x), ...]``.

    Args:
        harmonics: Sequence of harmonic multipliers; each entry contributes
            one sine and one cosine feature.
        active_dims: Forwarded unchanged to ``gpytorch.kernels.Kernel``.
    """
    is_stationary = True # Required for kronecker

    def __init__(self, harmonics, active_dims = None):
        # NOTE(review): has_lengthscale=True is kept so the set of registered
        # parameters is unchanged, but forward() never reads the lengthscale.
        super().__init__(has_lengthscale=True, active_dims = active_dims)
        # One weight for the constant feature plus one per sine/cosine feature.
        n_lambdas = 2*len(harmonics)+1

        # Define hyperparameter, lambda
        # NOTE(review): it is the RAW parameter that starts at 0.5. With
        # GreaterThan's default softplus transform the constrained lambda
        # starts near softplus(0.5) + 1e-6 ~= 0.974, not at 0.5 - confirm
        # which of the two was intended.
        self.register_parameter(name="raw_lambdas", parameter=torch.nn.Parameter(torch.ones(n_lambdas)*0.5))
        self.register_constraint("raw_lambdas", GreaterThan(1e-6))

        self.HARMONICS = harmonics

        # Multivariate normal prior on the constrained lambdas, centred at 0.5
        # with covariance 0.1 * I.
        self.register_prior(
                "lambdas_prior",
                gpytorch.priors.MultivariateNormalPrior(torch.ones(n_lambdas)*0.5, torch.eye(n_lambdas)*0.1),
                lambda m: m.lambdas,
                lambda m, v : m._set_lambdas(v),)

    @property
    def lambdas(self):
        """Constrained per-feature weights (constraint transform of ``raw_lambdas``)."""
        return self.raw_lambdas_constraint.transform(self.raw_lambdas)

    @lambdas.setter
    def lambdas(self, values):
        return self._set_lambdas(values)

    def _set_lambdas(self, values):
        """Set the constrained lambdas by writing the matching raw values.

        Args:
            values: Tensor (or anything ``torch.as_tensor`` accepts) holding
                the desired constrained values.
        """
        if not torch.is_tensor(values):
            values = torch.as_tensor(values).to(self.raw_lambdas)
        # The constraint is registered under the parameter name "raw_lambdas",
        # so it is exposed as `raw_lambdas_constraint`; the previous
        # `self.lambdas_constraint` lookup raised AttributeError.
        self.initialize(raw_lambdas=self.raw_lambdas_constraint.inverse_transform(values))

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate the kernel between column 0 of ``x1`` and of ``x2``.

        Args:
            x1: Tensor of shape ``(..., n, d)``; only ``x1[..., 0]`` is used.
            x2: Tensor of shape ``(..., m, d)``; only ``x2[..., 0]`` is used.
            diag: If true, return only the diagonal (requires ``n == m``).
            **params: Accepted for interface compatibility; unused.

        Returns:
            ``(..., n, m)`` covariance tensor, or ``(..., n)`` when ``diag``.
        """
        Fa = self.fourier_matrix_fast(x1[..., 0], self.HARMONICS)
        Fb = self.fourier_matrix_fast(x2[..., 0], self.HARMONICS)
        lambdas = self.lambdas
        # matmul does not type-promote, so bring everything to one dtype first
        # (e.g. float32 inputs with float64 parameters).
        common = torch.promote_types(torch.promote_types(Fa.dtype, Fb.dtype), lambdas.dtype)
        Fa, Fb, lambdas = Fa.to(common), Fb.to(common), lambdas.to(common)
        # Scaling the columns of Fa equals Fa @ diag(lambdas) without building
        # the dense diagonal matrix.
        weighted = Fa * lambdas
        if diag:
            # Row-wise dot products give the diagonal without the full matrix.
            return (weighted * Fb).sum(dim=-1)
        return weighted @ Fb.transpose(-1, -2)

    def fourier_matrix_fast(self, circumferential_locations, harmonics):
        """Build the Fourier feature matrix for the given locations.

        Args:
            circumferential_locations: Tensor of shape ``(..., n)``.
            harmonics: Sequence (or tensor) of ``H`` harmonic multipliers.

        Returns:
            Tensor of shape ``(..., n, 2*H + 1)`` on the same device and with
            the same dtype as ``circumferential_locations``. Column 0 is all
            ones, followed by interleaved ``sin(h*x), cos(h*x)`` pairs in the
            order of ``harmonics``.
        """
        locations = circumferential_locations
        # Create the multipliers next to the data so no global GPU flag or
        # explicit .cuda() call is needed.
        harmonics = torch.as_tensor(harmonics, dtype=locations.dtype, device=locations.device)
        phase = locations.unsqueeze(-1) * harmonics                    # (..., n, H)
        sin_cos = torch.stack((torch.sin(phase), torch.cos(phase)), dim=-1)  # (..., n, H, 2)
        # Flattening the last two axes interleaves sin/cos per harmonic.
        interleaved = sin_cos.flatten(start_dim=-2)                     # (..., n, 2H)
        ones_column = torch.ones_like(locations).unsqueeze(-1)
        return torch.cat((ones_column, interleaved), dim=-1)
    

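# NOTE: this section started out as pseudo-code; the class below is the model the script actually builds and trains.
# It assumes the training inputs are N x 3 with columns (r, z, psi), where psi is the 'azim' column of the CSV.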

class MultitaskGPModel(gpytorch.models.ExactGP):
    """Exact two-task GP with a Matern-times-Fourier data kernel.

    The data kernel is a Matern-5/2 kernel with ARD over input columns 0 and 1
    multiplied by a ``FourierKernel`` (harmonics 1 through 10) over input
    column 2. It is wrapped in a rank-1 ``MultitaskKernel`` for two tasks,
    and each task gets its own constant mean.
    """

    def __init__(self,train_x,train_y,likelihood):
        super().__init__(train_x, train_y, likelihood)
        n_tasks = 2

        # One constant mean per task.
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(),
            num_tasks=n_tasks,
        )

        # Kernel over columns 0 and 1, with a half-normal lengthscale prior.
        matern_part = gpytorch.kernels.MaternKernel(
            nu=5/2,
            ard_num_dims=2,
            active_dims=[0,1],
            lengthscale_prior=gpytorch.priors.HalfNormalPrior(0.5),
        )
        # Kernel over column 2.
        fourier_part = FourierKernel(harmonics=list(range(1, 11)), active_dims=[2])

        self.covar_module = gpytorch.kernels.MultitaskKernel(
            matern_part * fourier_part,
            num_tasks=n_tasks,
            rank=1,
        )

    def forward(self,x):
        """Return the multitask prior distribution at inputs ``x``."""
        task_means = self.mean_module(x)
        task_covar = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(task_means, task_covar)



# Two-task Gaussian likelihood: small initial noise, Gamma(1.5, 10) prior on it.
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2)
likelihood.noise = 0.001
likelihood.register_prior("noise_prior",gpytorch.priors.GammaPrior(1.5,10),"noise")
model = MultitaskGPModel(train_x, train_y, likelihood) #Call on GP model
model.train()
likelihood.train() # For optimizer setup
# Module-level flag; kept because code elsewhere in the file may read it.
use_gpu=False
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Includes GaussianLikelihood parameters
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model) #Loss function

training_iter = 250
for i in range(training_iter):
    optimizer.zero_grad()
    # Evaluate on the exact tensor the model was constructed with, rather
    # than on a separate (merely equal-valued) copy of the inputs.
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
        print(f"Iter {i+1}/{training_iter} - Loss: {loss.item():.3f}")

model.eval()
likelihood.eval()


# 50 x 50 prediction grid spanning the observed r and z ranges.
r_grid = np.linspace(df['r'].min(), df['r'].max(), 50)
z_grid = np.linspace(df['z'].min(), df['z'].max(), 50)
R, Z = np.meshgrid(r_grid, z_grid)
#Az = np.zeros_like(R)  # Fix azimuthal angle at 0 for 2D slice
Az = np.zeros_like(R)  # Fix azimuthal angle at 0 for 2D slice
print(Az)
# One row per grid point, columns (r, z, azim). Use the training dtype so
# predictions never mix precisions with the training data.
x_test = torch.tensor(np.column_stack([R.ravel(), Z.ravel(), Az.ravel()]), dtype=train_x.dtype)
print(np.shape(x_test))
#with torch.no_grad(), gpytorch.settings.fast_pred_var():
#    pred = likelihood(model(x_test))
#    Vr_pred = pred.mean[:,0].numpy().reshape(R.shape)
#    Vz_pred = pred.mean[:,1].numpy().reshape(R.shape)
#    Vr_var = pred.variance[:,0].numpy().reshape(R.shape)
#    Vz_var = pred.variance[:,1].numpy().reshape(R.shape)

#Azimuth sweep to plot predictions
#azim_vec = np.linspace(df['azim'].min(), df['azim'].max(), 20)
#fig,axes = plt.subplots(2,4,figsize=(15,13),constrained_layout = True)
#for ax, az in zip(axes.ravel(), azim_vec):
#    Az = np.full_like(R, az)
#    x_test = torch.tensor(np.vstack([R.ravel(), Z.ravel(), Az.ravel()]), dtype=torch.float32).T
#    with torch.no_grad(), gpytorch.settings.fast_pred_var():
#        pred = likelihood(model(x_test))
#        Vr_pred = pred.mean[:,0].numpy().reshape(R.shape)
#        Vz_pred = pred.mean[:,1].numpy().reshape(R.shape)
#        Vr_var = pred.variance[:,0].numpy().reshape(R.shape)
#        Vz_var = pred.variance[:,1].numpy().reshape(R.shape)
#    q = ax.quiver(R, Z, Vr_pred, Vz_pred,cmap='viridis')
#    ax.set_title(f"azim = {az:.2f} rad")
#    ax.set_xlabel('R')
#    ax.set_ylabel('Z')
#fig.colorbar(cp, ax=axes.ravel().tolist(), label='Predicted vmag (normalized mean)')
#plt.suptitle('GP Predictions of Vr,Vz at Various Azimuthal Angles')
#plt.show()