# Impact of Offset
In this notebook we assess the impact of varying the offset of each input individually, to work out which inputs contribute most to the increase in prediction error when they are miscalibrated. The miscalibration is added to the **raw PV value**.

In [None]:
from utils import load_lcls
import torch
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_percentage_error
from lume_model.torch import LUMEModule
import json
from copy import deepcopy
from torch.nn import MSELoss
from botorch.models.transforms.input import InputTransform

In [None]:
# Load the PV metadata used by the cells below (which read the
# 'pv_to_sim_factor' and 'sim_name_to_pv_name' entries).
# The `with` block closes the file on exit, so no explicit close() is needed.
with open('configs/pv_info.json', 'r') as f:
    pv_info = json.load(f)

pv_info

In [None]:
# Build the surrogate model with the project helper `load_lcls` from the three files below.
nn_model = load_lcls('configs/lcls_variables.yml', 'configs/normalization.json', 'torch_model.pt')
# Take independent copies of the first output / input transformer *before* the model is
# modified, so they stay available after the output transformers are cleared below.
# NOTE(review): the output side is read through the public `output_transformers` attribute
# and the input side through the private `_input_transformers` -- presumably both expose the
# same kind of list; confirm against lume_model.
# NOTE(review): `input_transformer` is not used in the cells visible in this notebook section.
output_transformer = deepcopy(nn_model.output_transformers[0])
input_transformer = deepcopy(nn_model._input_transformers[0])
# we remove the output transformation so we can make comparisons between the outcomes using MSE
nn_model._output_transformers = []

In [None]:
# Read the test inputs and targets from disk and convert both to float64 tensors.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects; this is only
# safe for trusted files.
x_test, y_test = (
    torch.from_numpy(np.load(npy_path, allow_pickle=True).astype('float64'))
    for npy_path in ('data/x_raw_small.npy', 'data/y_raw_small.npy')
)
# Apply the saved copy of the output transformer to the targets (the model's own
# output transformers were removed when the model was loaded).
y_test = output_transformer(y_test)
print(x_test.shape)
print(y_test.shape)

In [None]:
# One PV -> sim conversion factor per model feature, in the order of `nn_model.features`:
# each simulation feature name is mapped to its PV name, which is then used to look up the factor.
conversions = torch.tensor([
    pv_info['pv_to_sim_factor'][pv_info['sim_name_to_pv_name'][sim_name]]
    for sim_name in nn_model.features
])

class PVtoSimFactor(InputTransform, torch.nn.Module):
    """Input transform that rescales values with a multiplicative conversion factor.

    ``transform`` multiplies its input by the stored factors and ``untransform``
    divides by them, so the two are inverses of each other as long as no factor
    is zero.
    """

    def __init__(self, conversion: torch.Tensor) -> None:
        """Store the conversion factors and enable the transform in train and eval mode.

        Args:
            conversion: Multiplicative factors. They must broadcast against the
                tensors later passed to ``transform`` / ``untransform``.
        """
        super().__init__()
        # Register the factors as a buffer instead of a plain attribute so they
        # move with the module on ``.to()`` / ``.double()`` and are part of
        # ``state_dict()``. ``as_tensor`` keeps non-tensor inputs working.
        self.register_buffer("_conversion", torch.as_tensor(conversion))
        # Flags consulted by the InputTransform base class to decide when the
        # transform is applied.
        self.transform_on_train = True
        self.transform_on_eval = True
        self.transform_on_fantasize = False

    def transform(self, x):
        """Return ``x`` multiplied element-wise by the conversion factors."""
        return x * self._conversion

    def untransform(self, x):
        """Return ``x`` divided element-wise by the conversion factors (inverse of ``transform``)."""
        return x / self._conversion
    
pv_to_sim = PVtoSimFactor(conversions)
# Divide the loaded inputs by the conversion factors to get the PV-side values,
# then convert back again to check the round trip.
x_test_pv = pv_to_sim.untransform(x_test)
x_test_transformed = pv_to_sim.transform(x_test_pv)

# verify that the transformations work as expected
print(x_test)
print(x_test_pv)
print(x_test_transformed)
# Check the round trip explicitly instead of relying on eyeballing the printed tensors.
# A zero conversion factor would produce inf/nan here and be caught by this assertion.
assert torch.allclose(x_test_transformed, x_test, equal_nan=True), (
    "untransform -> transform round trip does not reproduce x_test; "
    "check the conversion factors (e.g. for zeros)"
)


In [None]:
# Put the PV -> sim conversion in front of the model's existing input transformers.
# The guard keeps this cell idempotent: without it, re-running the cell would stack a
# second conversion in front of the first and silently scale the inputs twice.
if not any(isinstance(t, PVtoSimFactor) for t in nn_model._input_transformers):
    nn_model._input_transformers.insert(0, pv_to_sim)
print(nn_model.input_transformers)
# Wrap the model so it can be called directly on input tensors in the cells below.
base_model = LUMEModule(nn_model, nn_model.features, nn_model.outputs)

For each input, we add a varying degree of miscalibration offset (offset only, no scaling to begin with) and study its effect on the model's prediction.

In [None]:
# Offsets expressed as a fraction of the mean: 9 evenly spaced steps from -10% to +10%.
offset_degrees = torch.linspace(start=-0.1, end=0.1, steps=9)

In [None]:
# Inputs whose lower and upper range bounds are equal, i.e. whose value range has zero width.
constants = [
    name
    for name, variable in nn_model.input_variables.items()
    if variable.value_range[0] == variable.value_range[1]
]
print(constants)

In [None]:
mse_loss = MSELoss()

# Offset fractions for which a histogram of the shifted input is drawn. They are
# matched with a tolerance because the swept values are float32 tensor elements
# and exact equality with Python floats is fragile.
hist_offset_degrees = (-0.1, 0.0, 0.1)
hist_match_tol = 1e-6

# Figure 1: MSE versus offset, constant inputs on top and variable inputs below.
fig, ax = plt.subplots(2, 1, sharex='all', figsize=(12, 8))

# Figure 2: one histogram panel per input feature.
# NOTE(review): the 4x4 grid assumes at most 16 features -- confirm against nn_model.features.
fig2, ax2 = plt.subplots(4, 4, figsize=(12, 10))
ax2 = ax2.ravel()

# Only forward passes are needed here, so skip building the autograd graph.
with torch.no_grad():
    # The reference prediction depends on neither the feature nor the offset,
    # so compute it once instead of on every inner iteration.
    true_result = base_model(x_test_pv)

    for i, feature_name in enumerate(nn_model.features):
        errors = []
        offset_tensor = torch.zeros_like(x_test)
        ax2[i].set_title(feature_name, fontsize=8)
        for offset_degree in offset_degrees:
            # add the offset (a fraction of this input's mean) to this input column only
            offset_value = x_test_pv[:, i].mean() * offset_degree
            offset_tensor[:, i] = offset_value
            x_test_offset_input = x_test_pv + offset_tensor

            if any(abs(float(offset_degree) - d) < hist_match_tol for d in hist_offset_degrees):
                ax2[i].hist(x_test_offset_input[:, i], bins=20, label=f'{offset_degree:.2f}', alpha=0.75)

            # pass the shifted input through the model and compare with the unshifted prediction
            offset_result = base_model(x_test_offset_input)

            mse = mse_loss(true_result, offset_result)
            errors.append(mse.item())
        if feature_name in constants:
            linestyle = 'dashed'
            print(feature_name, errors)
            ax[0].plot(offset_degrees, errors, linestyle=linestyle, label=feature_name)
        else:
            linestyle = 'solid'
            ax[1].plot(offset_degrees, errors, linestyle=linestyle, label=feature_name)

ax2[-1].legend()
# Hand-picked y-limits; the two panels use very different scales.
ax[0].set_ylim(0.0, 100)
ax[1].set_ylim(0.0, 0.002)
ax[0].set_title('Constant inputs')
ax[1].set_title('Variable inputs')
ax[0].set_ylabel('MSE')
ax[1].set_ylabel('MSE')
ax[1].set_xlabel('Offset (fraction of the input mean)')
ax[0].legend()
ax[1].legend(loc='upper right')
fig2.tight_layout()
fig.tight_layout()
plt.show()