# In [None]:
def _summarize_prediction(pred):
    """Convert a predictive distribution into numpy ``mean``/``lower``/``upper`` arrays.

    ``lower``/``upper`` come from ``pred.confidence_region()`` (for GPyTorch
    distributions this is mean +/- 2 standard deviations). ``detach()`` guards
    ``.numpy()`` against tensors that still require grad; the caller also runs
    this under ``torch.no_grad()``.
    """
    lower, upper = pred.confidence_region()
    return {
        'mean': pred.mean.detach().cpu().numpy(),
        'lower': lower.detach().cpu().numpy(),
        'upper': upper.detach().cpu().numpy(),
    }


def _plot_fantasy_comparison(x_np, y_np, original, fantasy):
    """Plot original vs fantasy means/bands and the new points over a 1-D input."""
    x_flat = x_np.reshape(-1)
    # plt.plot joins points and fill_between builds its polygon in array
    # order, so sort by x once; unsorted inputs would draw zig-zag lines and
    # self-intersecting bands. The returned arrays are NOT reordered.
    order = x_flat.argsort()
    xs = x_flat[order]

    plt.figure(figsize=(15, 7))

    # Original model: mean curve plus confidence band
    plt.plot(xs, original['mean'][order], 'b-', label='Original Prediction')
    plt.fill_between(xs, original['lower'][order], original['upper'][order],
                     alpha=0.2, color='blue')

    # Fantasy model (conditioned on new data): mean curve plus confidence band
    plt.plot(xs, fantasy['mean'][order], 'r-', label='Fantasy Prediction')
    plt.fill_between(xs, fantasy['lower'][order], fantasy['upper'][order],
                     alpha=0.2, color='red')

    # The observations that were incorporated into the fantasy model
    plt.scatter(x_flat, y_np, c='k', marker='x', label='New Data')

    plt.title('Original vs Fantasy Model Predictions')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()


def get_fantasy_predictions(model, train_x, train_y, new_x, new_y, device='cuda'):
    """Compare a GP's predictions before and after conditioning on new data.

    The model and its likelihood are put in eval mode and used to predict at
    ``new_x``. Then ``model.get_fantasy_model(new_x, new_y)`` returns a model
    that also conditions on the new observations, which is evaluated at the
    same ``new_x``. When ``new_x`` has a single feature, both predictions
    (means and confidence bands, drawn sorted by x) and the new observations
    are plotted and shown with ``plt.show()``. For multi-feature inputs the
    plot is skipped with a warning and the predictions are still returned.

    Args:
        model: Trained GP model exposing ``likelihood`` and
            ``get_fantasy_model`` (e.g. a GPyTorch exact GP).
        train_x: Original training features. Currently unused; kept for
            interface compatibility.
        train_y: Original training targets. Currently unused.
        new_x: New data features to incorporate and to predict at.
        new_y: New data targets to incorporate.
        device: Currently unused; tensors are used on whatever device they
            already live on.

    Returns:
        dict with keys ``'original'`` and ``'fantasy'``. Each is a dict of
        numpy arrays ``'mean'``, ``'lower'``, ``'upper'`` in the same order
        as ``new_x``.

    Side effects:
        Leaves ``model`` and ``model.likelihood`` in eval mode.
    """
    model.eval()
    likelihood = model.likelihood
    likelihood.eval()

    with torch.no_grad():
        # Predicting with the original model first also populates the
        # prediction caches that GPyTorch's get_fantasy_model requires.
        orig_pred = likelihood(model(new_x))

        fantasy_model = model.get_fantasy_model(new_x, new_y)
        fantasy_pred = likelihood(fantasy_model(new_x))

        # Convert while still under no_grad: confidence bounds may be derived
        # from the covariance on access, and that must not build a graph.
        results = {
            'original': _summarize_prediction(orig_pred),
            'fantasy': _summarize_prediction(fantasy_pred),
        }
        x_np = new_x.detach().cpu().numpy()
        y_np = new_y.detach().cpu().numpy()

    # The comparison plot only makes sense against a single input feature;
    # flattening an (n, d>1) input would mismatch the length-n bands.
    if x_np.ndim == 1 or (x_np.ndim == 2 and x_np.shape[1] == 1):
        _plot_fantasy_comparison(x_np, y_np, results['original'], results['fantasy'])
    else:
        import warnings
        warnings.warn(
            "get_fantasy_predictions: new_x has more than one feature; "
            "skipping the 1-D comparison plot."
        )

    return results

# Usage example:
"""
# Assuming you have new data points
predictions = get_fantasy_predictions(
    model=model,
    train_x=train_x,
    train_y=train_y,
    new_x=new_x,  # Your new data features
    new_y=new_y,  # Your new data targets
    device=device
)

# Access predictions
original_preds = predictions['original']['mean']
fantasy_preds = predictions['fantasy']['mean']
"""