In [None]:
import sys
import os
# Make the project-level modules (`data`, `kernels`, `models`, `likelihoods`,
# imported in later cells) importable.
# NOTE(review): '../..' is resolved against the current working directory, not
# this notebook's file location — assumes the kernel was started from the
# notebook's own folder.
sys.path.append(os.path.abspath('../..'))
import torch
from data import PVWeatherGenerator, SystemLoader


In [None]:
# --- Data parameters ----------------------------------------------------------
DAY_INIT = 0                  # forwarded to PVWeatherGenerator as `day_init`
DAY_MIN, DAY_MAX = 8, 16      # daily window bounds, in hours (converted to minutes below)
N_DAYS_FOLD = 7               # days per fold; scales the loader's train_interval
MINUTE_INTERVAL = 5           # sampling period, in minutes

# Samples per day inside [DAY_MIN, DAY_MAX] at MINUTE_INTERVAL resolution.
DAILY_DATA_POINTS = 60 * (DAY_MAX - DAY_MIN) // MINUTE_INTERVAL

N_SYSTEMS = 6                 # number of PV systems; each is one GP task
CIRCLE_COORDS = (53.28, -3.05)  # centre of the system-selection area, presumably (lat, lon)
RADIUS = 0.25                 # selection radius around CIRCLE_COORDS (units not stated here)

In [None]:
# Build the PV/weather dataset: N_SYSTEMS systems, presumably selected within
# RADIUS of CIRCLE_COORDS, covering 365 days starting at DAY_INIT and sampled
# every MINUTE_INTERVAL minutes. Later cells only use the resulting frame `df`.
generator = PVWeatherGenerator(
    coords=CIRCLE_COORDS,
    radius=RADIUS,
    n_systems=N_SYSTEMS,
    day_init=DAY_INIT,
    n_days=365,
    minute_interval=MINUTE_INTERVAL,
)
df = generator.df

In [None]:
# Length (in samples) handed to SystemLoader as `train_interval`:
# N_DAYS_FOLD days, each holding DAILY_DATA_POINTS samples.
individual_interval = int(N_DAYS_FOLD * DAILY_DATA_POINTS)
loader = SystemLoader(df, train_interval=individual_interval)

In [None]:
# from pv_plot import plot_grid
# CIRCLE_COORDS = (53.28, -3.05)
# RADIUS = 0.25
# plot_grid(df, CIRCLE_COORDS, RADIUS)

In [None]:
# Peek at the first fold only to read the number of input columns `d`;
# the remaining five items of the fold tuple are ignored here.
(x, _, _, _, _, _) = next(iter(loader))
d = x.shape[1]
print(d)

In [None]:
from kernels import get_mean_covar_weather

# One GP task per PV system; roughly half as many latent processes, plus one.
num_tasks = N_SYSTEMS
num_latents = 1 + num_tasks // 2

# NOTE(review): this mean/covar pair is overwritten on every iteration of the
# training loop, which rebuilds it with `d - 1` inputs and
# weather_kernel='matern'. Confirm whether this call (with `d`) is still
# needed and which input dimensionality is intended.
mean, covar = get_mean_covar_weather(
    num_latents=num_latents,
    d=d,
    combine='product',
)

# Stride used to sub-sample the training inputs/targets/task indices before fitting.
interval = 6

In [None]:
from matplotlib import pyplot as plt
import numpy as np
from models import HadamardGPModel
from likelihoods import HadamardBetaLikelihood

# Train one multi-task GP (HadamardGPModel with HadamardBetaLikelihood) per
# (train, test) fold yielded by `loader`. Each model is fitted on every
# `interval`-th training point, then predicts on the full train and test inputs.
#
# The return values of both model.predict(...) calls are collected in
# `fold_predictions` as (train_prediction, test_prediction) tuples, one per
# fold, so every fold's results survive the loop. After the loop, `model` is
# only the last fold's model.
#
# NOTE(review): the kernel below is built from `d - 1` inputs, whereas the
# cell defining `num_latents` builds one from `d`; confirm which input
# dimensionality get_mean_covar_weather expects.
fold_predictions = []
for X_tr, Y_tr, X_te, Y_te, T_tr, T_te in loader:
    # Rebuilt every fold so each model starts from freshly constructed
    # mean/kernel modules.
    mean, covar = get_mean_covar_weather(num_latents, d-1, combine='product', weather_kernel='matern')
    model = HadamardGPModel(
        X=X_tr[::interval],
        y=Y_tr[::interval],
        mean_module=mean,
        covar_module=covar,
        likelihood=HadamardBetaLikelihood(num_tasks=num_tasks, scale=20),
        num_tasks=num_tasks,
        num_latents=num_latents,
        learn_inducing_locations=True,
        inducing_proportion=1.0,
        jitter=1e-6,
    )
    model.set_cpu()
    # task_indices uses the same stride as X and y above so rows stay aligned.
    model.fit(
        n_iter=100,
        lr=0.2,
        task_indices=T_tr[::interval],
        verbose=True,
    )
    train_prediction = model.predict(X_tr, T_tr)
    test_prediction = model.predict(X_te, T_te)
    fold_predictions.append((train_prediction, test_prediction))

    # --- Per-task plotting (currently disabled) -------------------------------
    # fig, ax = plt.subplots(num_tasks // 2, 2, figsize=(30, 5 * (num_tasks)), sharex=True, sharey=True)
    #  ax = ax.flatten()
    # plt.rcParams['font.serif'] = ['Times New Roman']


    # for i in range(num_tasks):
    #     _, y_tr, _, y_te = loader.train_test_split_individual(i)
    #     n_tr, n_te = y_tr.shape[0], y_te.shape[0]
    #     t = torch.linspace(0, int(N_DAYS_FOLD * DAILY_DATA_POINTS), n_tr + n_te)
    #     t_tr, t_te = t[:n_tr], t[n_tr:]

    #     y_pred_tr, lower_tr, upper_tr = model.get_i_prediction(i, T_tr)
    #     ax[i].scatter(t_tr, y_tr, color='black', marker='x', label='Observed Data', alpha=0.5)
    #     ax[i].scatter(t_te, y_te, color='black', marker='x', alpha=0.5)
    #     ax[i].plot(t_tr, y_pred_tr, color='blue')
    #     ax[i].fill_between(t_tr, lower_tr, upper_tr, color='blue', alpha=0.1)

    # model.predict(X_te, T_te)
    # pred_dist = model.predict_dist()

    # for i in range(num_tasks):
    #     x_tr, y_tr, x_te, y_te = loader.train_test_split_individual(i)
    #     n_tr, n_te = y_tr.shape[0], y_te.shape[0]
    #     t = torch.linspace(0, int(N_DAYS_FOLD * DAILY_DATA_POINTS), n_tr + n_te)
    #     t_tr, t_te = t[:n_tr], t[n_tr:]

    #     y_pred_te, lower_te, upper_te = model.get_i_prediction(i, T_te)
#         if i == 0:
#             ax[i].plot(t_te, y_pred_te, color='red')
#             ax[i].fill_between(t_te, lower_te, upper_te, color='red', alpha=0.1)
#         else:
#             ax[i].plot(t_te, y_pred_te, color='red')
#             ax[i].fill_between(t_te, lower_te, upper_te, color='red', alpha=0.1)
#         ax[i].axvline(t_tr.max(), color='black', linestyle='--', label='Train/Test Split')
#         ax[i].set_ylim(-0.01, 1.01)
#         ax[i].set_title(f'Task {i+1}', fontsize=30)

#         # set y label for left column
#         if i % 2 == 0:
#             ax[i].set_ylabel('PV Output (0-1 Scale)', fontsize=30)
#             # add y ticks
#             y_ticks = [0, 0.0, 0.25, 0.5, 0.75, 1.0]
#             ax[i].set_yticklabels(y_ticks, fontsize=25)
#         # set x label for bottom row
#         if i >= num_tasks - 2:
#             ax[i].set_xlabel('Time Steps (5 Minute Intervals)', fontsize=30)
#             # add x ticks
#             ax[i].set_xticklabels([0, 0, 100, 200, 300, 400, 500, 600], fontsize=25)

#         if i == 0:
#             ax[i].legend(fontsize=30)

#     for i in range(num_tasks, len(ax)):
#         ax[i].axis('off')
#     break

# plt.tight_layout()
# plt.show()


In [None]:
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# Schematic of the covariance construction named in the panel titles:
# top row a_q a_q^T, bottom row k_q(X, X), right-hand panel
# A A^T (Hadamard product, \odot) K(X, X).
# The matrices are random placeholders (replace with your actual data); the
# generator is seeded so the figure is identical on every Restart & Run All.
_rng = np.random.default_rng(0)
_latent_factor = _rng.random((4, 10, 10))
_latent_covar = _rng.random((4, 10, 10))
_covar = _rng.random((10, 10))

# Global rcParams change: it also affects every later figure in this kernel.
# If Arial is not installed, matplotlib warns and falls back to its default font.
plt.rcParams['font.family'] = 'Arial'

n_panels = _latent_factor.shape[0]

# One column per latent (two rows each) plus a final column, twice as wide,
# for the combined matrix spanning both rows. Sizing that column through
# width_ratios keeps it inside the figure and lets tight_layout handle spacing,
# instead of hand-placing the axes with a hard-coded rectangle.
fig = plt.figure(figsize=(18, 8))
gs = gridspec.GridSpec(2, n_panels + 1, figure=fig,
                       width_ratios=[1] * n_panels + [2])

for i in range(n_panels):
    q = i + 1  # 1-based latent index shown in the LaTeX titles

    ax = fig.add_subplot(gs[0, i])
    ax.imshow(_latent_factor[i], cmap='viridis')
    # Single math expression; braces keep multi-digit indices fully subscripted.
    ax.set_title(rf"$\mathbf{{a}}_{{{q}}}\mathbf{{a}}_{{{q}}}^T$", fontsize=20)
    ax.set(xticks=[], yticks=[])

    ax = fig.add_subplot(gs[1, i])
    ax.imshow(_latent_covar[i], cmap='viridis')
    ax.set_title(rf"$k_{{{q}}}(\mathbf{{X}}, \mathbf{{X}})$", fontsize=20)
    ax.set(xticks=[], yticks=[])

# Combined matrix in the last column, spanning both rows.
ax2 = fig.add_subplot(gs[:, -1])
ax2.imshow(_covar, cmap='viridis', aspect='equal')
ax2.set_title(r"$\mathbf{A}\mathbf{A}^T \odot K(\mathbf{X}, \mathbf{X})$", fontsize=20)
ax2.set(xticks=[], yticks=[])

fig.tight_layout()
plt.show()

