# Prediction demo

In [14]:
import sys
from pathlib import Path

sys.path.append(str(Path.cwd().parent))

import numpy as np
import matplotlib.pyplot as plt
from src.predictors.rain_predictor import RainPredictor
import torch

### Load the sample and split it into input frames and the target frame

In [24]:
# Load the sample sequence. For .npz archives np.load returns an NpzFile that
# holds the underlying file open; using it as a context manager releases the
# handle once the array has been read into memory.
with np.load("../data/sample_ir069_9f.npz") as npz:
    tensor = npz["tensor"]

# Split along the first axis: every entry except the last is model input, the
# last one is the target the model should reproduce.
# NOTE(review): axis 0 is assumed to be time (frames) and axis 1 channels,
# matching how the visualization cell indexes input_data[i, channel] — confirm.
input_data = tensor[:-1]
output_ground_truth = tensor[-1]

### Visualize input data

In [None]:
# Index of the channel to display for every frame.
channel = 0

# Lay the input frames out on a grid with a fixed number of columns and as many
# rows as needed, so this cell works for any number of input frames instead of
# assuming exactly 2 x 4 = 8 (which raised IndexError for longer sequences).
n_frames = input_data.shape[0]
n_cols = 4
n_rows = max(1, (n_frames + n_cols - 1) // n_cols)

# squeeze=False guarantees a 2-D array of axes even for a single row, so
# ravel() always yields a flat sequence of Axes.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False
)
axes = axes.ravel()

for i, ax in enumerate(axes):
    if i < n_frames:
        ax.imshow(input_data[i, channel], cmap="gray")
        ax.set_title(f"Frame {i}")
    # Hide ticks/frames on every panel, including unused trailing slots.
    ax.axis("off")

plt.tight_layout()
plt.show()

### Predict the next frame and compare it with the ground truth

In [None]:
model_path = "../outputs/hydra_tests_sample/model/model_sample.ckpt"

# Use the GPU when one is available, otherwise fall back to the CPU so the demo
# also runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location places the restored weights on `device`; eval() puts the module
# in inference mode (affects layers such as dropout/batch-norm, if present).
model = RainPredictor.load_from_checkpoint(model_path, map_location=device)
model = model.to(device).eval()

# Add a leading batch dimension of size 1.
# NOTE(review): the tensor keeps the dtype stored in the .npz, and the axis
# order is taken as-is from input_data — confirm both match what
# RainPredictor.forward expects.
X = torch.from_numpy(input_data).unsqueeze(0).to(device)

with torch.no_grad():
    output = model(X)

# Bring the prediction back to host memory for plotting.
output = output.cpu()

# Compare the same channel in both panels (the ground truth was previously
# always shown for channel 0, regardless of `channel`).
ground_truth_img = output_ground_truth[channel]
prediction_img = output[0, channel].numpy()

# Share one intensity range between the panels; imshow otherwise autoscales
# each image independently, which can hide differences in absolute intensity.
vmin = min(ground_truth_img.min(), prediction_img.min())
vmax = max(ground_truth_img.max(), prediction_img.max())

fig, axes = plt.subplots(1, 2, figsize=(10, 8))
panels = [
    (ground_truth_img, "Ground Truth"),
    (prediction_img, "Prediction"),
]
for ax, (image, title) in zip(axes.ravel(), panels):
    ax.imshow(image, cmap="gray", vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()