## Package Imports

In [None]:
import os
import random

# Data Analysis
import pandas as pd
import polars as pl
import numpy as np

# Visualization
import seaborn as sns
import matplotlib.pyplot as plt


# Stats & ML

from sklearn.ensemble import IsolationForest

from sklearn.impute import KNNImputer

from sklearn.linear_model import LinearRegression
from statsmodels.stats.outliers_influence import variance_inflation_factor
from sklearn.model_selection import TimeSeriesSplit
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.impute import SimpleImputer
import polars as pl
import numpy as np



pl.Config.set_tbl_rows(-1)
%matplotlib inline

## Data Loading

#### Load all water meter devices

In [None]:
try:
    df = pl.read_csv('../data/Abyei_water_meters.csv', infer_schema_length=100000)
    print(f"Data loaded successfully")
except Exception as e:
    print(f"An error occurred while loading the data: {e}")

#### Load devices to enhance

In [None]:
try:
    poor_devices = pl.read_csv('../exports/devices_to_clean.csv', infer_schema_length=100000)
    print(f"Data loaded successfully")
except Exception as e:
    print(f"An error occurred while loading the data: {e}")

In [None]:
poor_devices.head()

#### Filter out the data for these devices from the main df

In [None]:
# Get the list of device IDs from the filtered devices
filtered_device_ids = poor_devices['DEVICE_ID'].to_list()

filtered_data = df.filter(pl.col('DEVICE_ID').is_in(filtered_device_ids))

In [None]:
filtered_data.shape[0]

In [None]:
level_1_devices = filtered_data.filter(pl.col('OGI_LEVEL') == 1)['MISSION_DEVICE_TAG'].n_unique()
level_2_devices = filtered_data.filter(pl.col('OGI_LEVEL') == 2)['MISSION_DEVICE_TAG'].n_unique()
level_3_devices= filtered_data.filter(pl.col('OGI_LEVEL') == 3)['MISSION_DEVICE_TAG'].n_unique()

In [None]:
level_1_devices , level_2_devices, level_3_devices

All the poor devices are in level 2.

#### How many devices are available in total?

In [None]:
filtered_data['DEVICE_ID'].n_unique()

In [None]:
filtered_data.head()

#### Convert date string to date object

In [None]:
filtered_data = filtered_data.with_columns(
    pl.col("TAG_VALUE_DATE").cast(pl.Datetime)
)

In [None]:
filtered_data.sort('DEVICE_ID').head()

### Daily resample

In [None]:
cummulative_daily_consumption = (
    filtered_data
    .with_columns(
        pl.col("TAG_VALUE_DATE").dt.truncate("1d").cast(pl.Date).alias("DATE")
    )
    .sort(["DEVICE_ID", "DATE"]) 
    .group_by(["DEVICE_ID", "DATE"])
    .agg([
        pl.col("TAG_VALUE_RAW").max().alias("CUMMULATIVE_CONSUMPTION"),
        pl.col("OGI_LONG").first(),
        pl.col("OGI_LAT").first(),
    ])
    .sort(["DEVICE_ID", "DATE"])
)

In [None]:
cummulative_daily_consumption.head()

In [None]:
cummulative_daily_consumption.write_csv('../exports/abyei_to_clean_complete.csv')

In [None]:
cummulative_daily_consumption.shape[0]

In [None]:
# Select 5 random devices
random_devices = random.sample(cummulative_daily_consumption['DEVICE_ID'].unique().to_list(), 5)

# Create the plot
plt.figure(figsize=(12, 6))

for device_id in random_devices:
  device_data = cummulative_daily_consumption.filter(pl.col("DEVICE_ID") == device_id)
  device_data_pd = device_data.to_pandas()
  plt.plot(device_data_pd['DATE'], device_data_pd['CUMMULATIVE_CONSUMPTION'], label=f'Device {device_id}')

plt.xlabel('Date')
plt.ylabel('Cummulative Daily Consumption')
plt.title('Cummulative Daily Consumption of 5 Random Devices')
plt.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

### Validity

In [None]:
daily_diff = cummulative_daily_consumption.with_columns(
        (pl.col("CUMMULATIVE_CONSUMPTION") - pl.col("CUMMULATIVE_CONSUMPTION").shift(1))
        .over("DEVICE_ID")
        .alias("DAILY_DIFF")
    ).filter(pl.col("DAILY_DIFF").is_not_null())

In [None]:
# Step 1: Flag invalid rows based on business rules
daily_diff = daily_diff.with_columns(
    pl.when(
        (pl.col("CUMMULATIVE_CONSUMPTION") == 0) | (pl.col("DAILY_DIFF") < 0)
    )
    .then(0)  # Flag as True/1 if invalid
    .otherwise(1)  # Otherwise False / 0
    .alias("VALIDITY")
)
daily_diff.head()

In [None]:
daily_diff['VALIDITY'].value_counts()

#### Validity Enhancement

In [None]:
def correct_detected_resets(df: pl.DataFrame, device_col: str, date_col: str, consumption_col: str):
    # Sort data
    df = df.sort([device_col, date_col])


    # Combine invalid records to detect resets (negative increments or zero cumulative)
    df = df.with_columns(
        (pl.col('VALIDITY')==0).cast(pl.Boolean).alias('reset_flag')
    )

    print(f"Number of resets detected: {df['reset_flag'].sum()}")

    def correct_consumption(df):
      reset_indices = df.filter(pl.col('reset_flag')).select([date_col]).to_series()

      for reset_date in reset_indices:
          # Get the last valid cumulative value before the reset
          prev_value = (
            df.filter(pl.col(date_col) < reset_date)
            .sort(date_col, descending=True)
            .select(consumption_col)  # Select only the consumption value
            .limit(1)  # Get the most recent entry
            .to_series()
          )
          if len(prev_value) > 0:
              prev_value = prev_value[0]

              # Calculate the shift required to correct the reset
              reset_value = df.filter(pl.col(date_col) == reset_date).select(pl.col(consumption_col)).to_series()[0]
              shift_value = prev_value - reset_value

              # Apply the shift to all subsequent points
              df = df.with_columns(
                  pl.when(pl.col(date_col) >= reset_date)
                  .then(pl.col(consumption_col) + shift_value)
                  .otherwise(pl.col(consumption_col))
                  .alias(consumption_col)
              )
      return df

    # Apply correction to dataframe
    df = df if df['reset_flag'].sum() == 0 else correct_consumption(df)

    return df



In [None]:
def plot_validity_reset_correction(original_df, corrected_df, device_id):
    plt.figure(figsize=(12, 6))

    # Plot original cumulative consumption
    plt.plot(original_df['DATE'], original_df['CUMMULATIVE_CONSUMPTION'], label='Original', color='blue')

    # Plot corrected cumulative consumption
    plt.plot(corrected_df['DATE'], corrected_df['CUMMULATIVE_CONSUMPTION'], label='Corrected', color='green')

    # Highlight reset points (where VALIDITY == False)
    invalid_points = original_df.filter(pl.col('VALIDITY') == False)

    plt.scatter(
        invalid_points['DATE'],
        invalid_points['CUMMULATIVE_CONSUMPTION'],
        color='red',
        label='Invalid Data (Reset Detected)'
    )

    plt.title(f'Cumulative Consumption - Device {device_id}')
    plt.xlabel('Date')
    plt.ylabel('Cumulative Consumption')
    plt.legend()
    plt.grid(True)
    plt.show()


In [None]:
corrected_dataframes = []

for device_id in daily_diff['DEVICE_ID'].unique():
    device_df = daily_diff.filter(pl.col('DEVICE_ID') == device_id)

    device_df = device_df.with_columns(
        pl.col('CUMMULATIVE_CONSUMPTION').alias('ORIGINAL_CUMMULATIVE')
    )
    # Correct cumulative consumption (already part of your loop)
    corrected_device_df = correct_detected_resets(
      device_df,
      device_col='DEVICE_ID',
      date_col='DATE',
      consumption_col='CUMMULATIVE_CONSUMPTION'
    )

    # Append to the list for concatenation later
    corrected_dataframes.append(corrected_device_df)

# Concatenate all corrected dataframes into a single dataframe
corrected_cum_daily_consumption = pl.concat(corrected_dataframes)


In [None]:
corrected_cum_daily_consumption.head()

In [None]:
corrected_cum_daily_consumption.tail()

In [None]:
corrected_cum_daily_consumption.write_csv('../exports/corrected_cum_daily_consumption.csv')

#### Visualization and exports

In [None]:
# Function to plot original vs reconstructed cumulative consumption
def plot_comparison(corrected_data, device_id):
    plt.figure(figsize=(12, 6))

    # Extract data for the selected device
    device_data = corrected_data.filter(pl.col("DEVICE_ID") == device_id).sort("DATE")

    # Identify resets (sharp drops in cumulative consumption)
    resets = device_data.filter(pl.col("VALIDITY") == 0)  # Assuming resets are marked as invalid

    # Plot original cumulative consumption (before reconstruction)
    plt.plot(
        device_data["DATE"], 
        device_data["ORIGINAL_CUMMULATIVE"], 
        color="red", 
        linestyle="-", 
        linewidth=1.5, 
        label="Original Cumulative"
    )

    # Plot reconstructed cumulative consumption (after reset correction)
    plt.plot(
        device_data["DATE"], 
        device_data["CUMMULATIVE_CONSUMPTION"], 
        color="green", 
        linestyle="-", 
        linewidth=1.5, 
        label="Reconstructed Cumulative"
    )

    # Add reset annotations
    for row in resets.iter_rows(named=True):  # Corrected iteration
        plt.scatter(
            row["DATE"], 
            row["ORIGINAL_CUMMULATIVE"], 
            color="black", 
            marker="o", 
            label="Reset Point" if "Reset Point" not in plt.gca().get_legend_handles_labels()[1] else "",  
            zorder=3
        )
        plt.annotate(
            "Reset", 
            xy=(row["DATE"], row["ORIGINAL_CUMMULATIVE"]), 
            xytext=(row["DATE"], row["ORIGINAL_CUMMULATIVE"] + 8000),  # Adjust annotation position
            arrowprops=dict(facecolor='black', arrowstyle="->", lw=1.5),
            fontsize=10,
            color="black"
        )

    # Add labels and legend
    plt.title(f"Cumulative Consumption Reconstruction for Device {device_id}", fontsize=14)
    plt.xlabel("Date", fontsize=12)
    plt.ylabel("Cumulative Consumption", fontsize=12)
    plt.legend(fontsize=10)
    plt.grid(True, linestyle="--", alpha=0.7)
    # plt.tight_layout()
    plt.savefig("../visualizations/plots/cumulative_consumption_reconstruction.png", dpi=300, bbox_inches="tight")
    # Show the plot
    plt.show()

# Example usage
plot_comparison(corrected_cum_daily_consumption, 1187)  # Replace with actual dataset variable


In [None]:
# Sample 20 random devices
sampled_devices = np.random.choice(
    corrected_cum_daily_consumption['DEVICE_ID'].unique(),
    size=10,
)

# Plotting function for daily increments
def plot_device_increments(df, device_id):
    device_data = df.filter(pl.col('DEVICE_ID') == device_id)

    plt.figure(figsize=(12, 6))
    plt.plot(
        device_data['DATE'],
        device_data['CUMMULATIVE_CONSUMPTION'],
        label=f'Device {device_id}',
        color='blue'
    )

    plt.title(f'Daily Cummulative Consumption for Device {device_id}')
    plt.xlabel('Date')
    plt.ylabel('Daily Consumption')
    plt.legend()
    plt.grid(True)
    plt.show()

# Plot for 20 random devices
for device_id in sampled_devices:
    plot_device_increments(corrected_cum_daily_consumption, device_id)


In [None]:
import matplotlib.pyplot as plt

# Count the number of devices with at least one reset
reset_counts = (
    corrected_cum_daily_consumption
    .filter(pl.col('reset_flag') == True)  # Filter devices with resets
    .group_by('DEVICE_ID')
    .count()
)

# Total number of devices that experienced a reset
total_devices_reset = reset_counts.height
total_devices = len(daily_diff['DEVICE_ID'].unique())
total_devices_without_reset = total_devices - total_devices_reset

# Plotting
plt.figure(figsize=(8, 6))
bars = plt.bar(
    ['Devices with Reset', 'Devices without Reset'], 
    [total_devices_reset, total_devices_without_reset],
    color=['red', 'blue']
)

# Add text labels on bars
for bar in bars:
    height = bar.get_height()
    plt.text(
        bar.get_x() + bar.get_width() / 2,  # X-position
        height,  # Y-position (above bar)
        f'{height}',  # Label text
        ha='center', va='bottom', fontsize=12, fontweight='bold'
    )

# Title and labels
plt.title('Number of Devices with and without Resets')
plt.ylabel('Number of Devices')

# Save the figure as a PNG 
output_path = "../visualizations/plots/devices_with_without_resets.png"
plt.savefig(output_path, dpi=300, bbox_inches='tight')

plt.show()


In [None]:
# Count the number of resets per device
reset_counts = (
    corrected_cum_daily_consumption
    .filter(pl.col('reset_flag') == True)
    .group_by('DEVICE_ID')
    .count()
)

plt.figure(figsize=(10, 6))
plt.hist(reset_counts['count'], bins=10, color='blue', edgecolor='black')
plt.xlabel('Number of Resets')
plt.ylabel('Number of Devices')
plt.title('Distribution of Reset Counts per Device')
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.savefig("../visualizations/plots/reset_count_distribution.png", dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Compute total reported days per device
reported_days = corrected_cum_daily_consumption.group_by('DEVICE_ID').agg(
    pl.count('reset_flag').alias('TOTAL_REPORTED_DAYS')
)

# Merge with reset counts
reset_merged = reset_counts.join(reported_days, on='DEVICE_ID')

plt.figure(figsize=(10, 6))
plt.scatter(reset_merged['TOTAL_REPORTED_DAYS'], reset_merged['count'], alpha=0.7, color='purple')
plt.xlabel('Total Reported Days per Device')
plt.ylabel('Number of Resets')
plt.title('Scatter Plot: Resets vs. Total Reported Days')
plt.grid(True, linestyle="--", alpha=0.7)
plt.savefig("../visualizations/plots/resets_vs_reported_days.png", dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Filter a single device with resets for visualizatio
device_id = reset_counts['DEVICE_ID'][0]  # Pick a device with resets
device_data = corrected_cum_daily_consumption.filter(pl.col('DEVICE_ID') == device_id)

plt.figure(figsize=(12, 6))
sns.lineplot(x=device_data['DATE'], y=device_data['ORIGINAL_CUMMULATIVE'], label="Corrected")
plt.scatter(
    device_data.filter(pl.col('reset_flag') == True)['DATE'],
    device_data.filter(pl.col('reset_flag') == True)['ORIGINAL_CUMMULATIVE'],
    color='red', label="Reset Points", zorder=3
)
plt.xlabel('Date')
plt.ylabel('Cumulative Consumption')
plt.title(f'Cumulative Consumption for Device {device_id} (Resets Highlighted)')
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.savefig(f"../visualizations/plots/cumulative_resets_device_{device_id}.png", dpi=300, bbox_inches='tight')
plt.show()


In [None]:
total_devices = len(corrected_cum_daily_consumption['DEVICE_ID'].unique())
plt.figure(figsize=(8, 8))
plt.pie(
    [total_devices_reset, total_devices - total_devices_reset],
    labels=["Devices with Resets", "Devices without Resets"],
    autopct='%1.1f%%',
    colors=['red', 'green'],
    startangle=40
)
plt.title("Proportion of Devices with and without Resets")
plt.savefig("../visualizations/plots/devices_with_without_resets_pie.png", dpi=300, bbox_inches='tight')
plt.show()

### Validity Recheck

In [None]:
corrected_daily_consumption=corrected_cum_daily_consumption.with_columns(
        (pl.col("CUMMULATIVE_CONSUMPTION") - pl.col("CUMMULATIVE_CONSUMPTION").shift(1))
        .over("DEVICE_ID")
        .alias("DAILY_CONSUMPTION")
    ).filter(pl.col("DAILY_CONSUMPTION").is_not_null())

In [None]:
corrected_daily_consumption.head()

In [None]:
# Step 1: Flag invalid rows based on business rules
validity_check = corrected_daily_consumption.with_columns(
    pl.when(
        (pl.col("CUMMULATIVE_CONSUMPTION") == 0) | (pl.col("DAILY_CONSUMPTION") < 0)
    )
    .then(0)  # Flag as True/1 if invalid
    .otherwise(1)  # Otherwise False / 0
    .alias("VALIDITY")
)
validity_check.head()

In [None]:
validity_check['VALIDITY'].value_counts()

All records are now valid

### Accuracy

In [None]:
# IQR method
def detect_outliers_iqr(data):
    q1 = np.percentile(data, 25)
    q3 = np.percentile(data, 75)
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    return (data < lower_bound) | (data > upper_bound)

# MAD method
def detect_outliers_mad(data, threshold=3):
    median = np.median(data)
    mad = np.median(np.abs(data - median))
    modified_z_score = 0.6745 * (data - median) / mad
    return np.abs(modified_z_score) > threshold

# Isolation Forest method
def detect_outliers_isolation_forest(data):
    iso = IsolationForest(random_state=42,contamination='auto',)
    return iso.fit_predict(data.reshape(-1, 1)) == -1


In [None]:
# Outlier detection workflow for a single device
def detect_outliers_per_device(device_data):
    # Mark `VALIDITY=0` as outliers automatically
    device_data = device_data.with_columns(
        pl.when(pl.col("VALIDITY") == 0)
        .then(True)  # Automatically mark as outlier
        .otherwise(False)
        .alias("FINAL_OUTLIER")
    )

    # Filter out records with VALIDITY=0 before outlier detection
    valid_data = device_data.filter(pl.col("VALIDITY") == 1)

    # Proceed only if valid data exists
    if not valid_data.is_empty():
      
        data = valid_data["CUMMULATIVE_CONSUMPTION"].to_numpy()

        # Initialize empty columns for all possible outlier methods
        valid_data = valid_data.with_columns(
            # pl.lit(False).alias("OUTLIER_ZSCORE"),
            pl.lit(False).alias("OUTLIER_IQR"),
            pl.lit(False).alias("OUTLIER_MAD"),
            pl.lit(False).alias("OUTLIER_ISO"),
        )

        # Check normality
        
        iqr_outliers = detect_outliers_iqr(data)
        mad_outliers = detect_outliers_mad(data)
        iso_outliers = detect_outliers_isolation_forest(data)
        
        valid_data = valid_data.with_columns(
            pl.Series("OUTLIER_ISO", iso_outliers),
            pl.Series("OUTLIER_MAD", mad_outliers),
            pl.Series("OUTLIER_IQR", iqr_outliers),
        )
        # Combine Z-score and IQR using logical AND
        final_outlier = np.logical_and(iqr_outliers, np.logical_and(mad_outliers, iso_outliers))

        # Update FINAL_OUTLIER for valid records
        valid_data = valid_data.with_columns(
            pl.Series("FINAL_OUTLIER", final_outlier)
        )

        # Add missing columns to invalid data with default values
        invalid_data = device_data.filter(pl.col("VALIDITY") == 0).with_columns(
            # pl.lit(False).alias("OUTLIER_ZSCORE"),
            pl.lit(False).alias("OUTLIER_IQR"),
            pl.lit(False).alias("OUTLIER_MAD"),
            pl.lit(False).alias("OUTLIER_ISO"),
        )

        # Combine updated valid data with invalid data
        device_data = valid_data.vstack(invalid_data)
        
    return device_data

# Group data by DEVICE_ID and apply the workflow
result = []
for device_id, group in corrected_daily_consumption.group_by("DEVICE_ID"):
    processed_group = detect_outliers_per_device(group)
    result.append(processed_group)

# Concatenate results into a single DataFrame
result_df = pl.concat(result)

In [None]:
result_df['OUTLIER_MAD'].value_counts()

In [None]:
result_df['OUTLIER_IQR'].value_counts()

In [None]:
result_df['OUTLIER_ISO'].value_counts()

In [None]:
result_df['FINAL_OUTLIER'].value_counts()

In [None]:
result_df.head()

#### Outlier Detection on Original Cummulative

In [None]:
# Outlier detection workflow for a single device
def detect_outliers_per_device(device_data):
    # Mark `VALIDITY=0` as outliers automatically
    device_data = device_data.with_columns(
        pl.when(pl.col("VALIDITY") == 0)
        .then(True)  # Automatically mark as outlier
        .otherwise(False)
        .alias("FINAL_OUTLIER")
    )

    # Filter out records with VALIDITY=0 before outlier detection
    valid_data = device_data.filter(pl.col("VALIDITY") == 1)

    # Proceed only if valid data exists
    if not valid_data.is_empty():
      
        data = valid_data["ORIGINAL_CUMMULATIVE"].to_numpy()

        # Initialize empty columns for all possible outlier methods
        valid_data = valid_data.with_columns(
            # pl.lit(False).alias("OUTLIER_ZSCORE"),
            pl.lit(False).alias("OUTLIER_IQR"),
            pl.lit(False).alias("OUTLIER_MAD"),
            pl.lit(False).alias("OUTLIER_ISO"),
        )

        iqr_outliers = detect_outliers_iqr(data)
        mad_outliers = detect_outliers_mad(data)
        iso_outliers = detect_outliers_isolation_forest(data)
        
        valid_data = valid_data.with_columns(
            pl.Series("OUTLIER_ISO", iso_outliers),
            pl.Series("OUTLIER_MAD", mad_outliers),
            pl.Series("OUTLIER_IQR", iqr_outliers),
        )
        # Combine Z-score and IQR using logical AND
        final_outlier = np.logical_and(iqr_outliers, np.logical_and(mad_outliers, iso_outliers))

        # Update FINAL_OUTLIER for valid records
        valid_data = valid_data.with_columns(
            pl.Series("FINAL_OUTLIER", final_outlier)
        )

        # Add missing columns to invalid data with default values
        invalid_data = device_data.filter(pl.col("VALIDITY") == 0).with_columns(
            pl.lit(False).alias("OUTLIER_IQR"),
            pl.lit(False).alias("OUTLIER_MAD"),
            pl.lit(False).alias("OUTLIER_ISO"),
        )

        # Combine updated valid data with invalid data
        device_data = valid_data.vstack(invalid_data)
        
    return device_data

# Group data by DEVICE_ID and apply the workflow
result = []
for device_id, group in corrected_daily_consumption.group_by("DEVICE_ID"):
    processed_group = detect_outliers_per_device(group)
    result.append(processed_group)

# Concatenate results into a single DataFrame
original_result_df = pl.concat(result)

In [None]:
original_result_df['FINAL_OUTLIER'].value_counts()

In [None]:
# Define anomaly counts before and after reconstruction
anomaly_counts = {'Before Reconstruction': 2063, 'After Reconstruction': 1699}

# Create a bar chart
plt.figure(figsize=(8, 6))
plt.bar(anomaly_counts.keys(), anomaly_counts.values(), color=['red', 'green'])
plt.ylabel('Number of Anomalies')
plt.title('Anomaly Count Before and After Reconstruction')

# Annotate the bars with actual values
for i, (label, count) in enumerate(anomaly_counts.items()):
    plt.text(i, count + 5, str(count), ha='center', fontsize=12)

# Save the figure as a PNG file
plt.savefig("../visualizations/plots/anomaly_count_before_after_reconstruction.png", dpi=300, bbox_inches='tight')

# Show the plot
plt.show()


#### Visualization

In [None]:
# Step 1: Aggregate to count outliers per device
outlier_summary = (
    result_df
    .filter(pl.col("FINAL_OUTLIER") == 1)  # Filter rows marked as outliers
    .group_by("DEVICE_ID")
    .agg(pl.count().alias("OUTLIER_COUNT"))  # Count outliers per device
    .sort("OUTLIER_COUNT", descending=True)  # Sort devices by the number of outliers
)

# Step 2: Select top N devices with the most outliers
top_n = 5  # Adjust this value as needed
top_devices = outlier_summary.head(top_n)

# Step 3: Filter data for the top devices
top_device_ids = top_devices["DEVICE_ID"].to_list()
top_device_data = result_df.filter(pl.col("DEVICE_ID").is_in(top_device_ids))


for device_id in top_device_ids:
    device_data = top_device_data.filter(pl.col("DEVICE_ID") == device_id)
    # Step 4: Visualize cumulative consumption for the top devices
    plt.figure(figsize=(12, 8))
    # Plot valid points
    plt.plot(
        device_data.filter(pl.col("FINAL_OUTLIER") == 0)["DATE"],
        device_data.filter(pl.col("FINAL_OUTLIER") == 0)["CUMMULATIVE_CONSUMPTION"],
        marker="o",
        label=f"{device_id} - Valid",
        linestyle="-",
        alpha=0.8,
    )
    
    # Plot outliers
    plt.scatter(
        device_data.filter(pl.col("FINAL_OUTLIER") == 1)["DATE"],
        device_data.filter(pl.col("FINAL_OUTLIER") == 1)["CUMMULATIVE_CONSUMPTION"],
        color="red",
        label=f"{device_id} - Outlier",
        alpha=0.8,
    )

    # Add plot details
    plt.title("Top Devices with the Most Outliers", fontsize=14)
    plt.xlabel("Date", fontsize=12)
    plt.ylabel("Cumulative Consumption", fontsize=12)
    plt.legend(fontsize=10, title="Legend")
    plt.grid(alpha=0.5, linestyle="--")
    plt.tight_layout()
    
    # Show plot
    plt.show()

#### Nullify column

In [None]:
### Keep a copy of daily_consumption before normalizing the column in place
result_df = result_df.with_columns(pl.col('CUMMULATIVE_CONSUMPTION').alias('CUMMULATIVE_CONSUMPTION_COPY'))

In [None]:
# Step 1: Create an Accuracy Column
result_df = result_df.with_columns(
    pl.when(pl.col("FINAL_OUTLIER") == 0)
    .then(1)  # Mark as accurate if FINAL_OUTLIER is 0
    .otherwise(0)  # Mark as inaccurate if FINAL_OUTLIER is not 0
    .alias("ACCURACY")
)

# Step 2: Preview the DataFrame with the new Accuracy column
result_df.head()

#### Nullify inaccurate records and impute them to increase completeness and accuracy.

In [None]:
result_df = result_df.with_columns(
        pl.when(pl.col("FINAL_OUTLIER"))
        .then(None)  # Replace outliers with None (null)
        .otherwise(pl.col("CUMMULATIVE_CONSUMPTION"))
        .alias("CUMMULATIVE_CONSUMPTION")
    )

In [None]:
cum_nan_count_column = result_df.select(
    pl.col("CUMMULATIVE_CONSUMPTION").is_null().sum().alias("NaN_Count")
)
print(cum_nan_count_column)

In [None]:
result_df.head()

In [None]:
result_df=result_df.drop(['VALIDITY','ORIGINAL_CUMMULATIVE','reset_flag','FINAL_OUTLIER',
                          'OUTLIER_IQR','OUTLIER_MAD','OUTLIER_ISO','ACCURACY','DAILY_DIFF'])

In [None]:
result_df.head()

In [None]:
result_df.shape[0]

In [None]:
result_df['DEVICE_ID'].n_unique()

### Device Selection

Remove these device ids: 4759, 1307, 2049, 2048

These devices either have constant and long period of no data or they are small in records

In [None]:
result_df =result_df.filter(pl.col('DEVICE_ID')!=1307)
result_df = result_df.filter(pl.col('DEVICE_ID')!=2049)
result_df = result_df.filter(pl.col('DEVICE_ID')!=2048)

In [None]:
# 

#### Null Percentage

Remove devices with small null percentage which can be ignored or deleted. So focus on devices with > 4% missing rates because the less than 4% are between 1 - 6 missing records. These are because during reconstruction anomaly did not detect them again.

In [None]:
# Step 1: Add an `IS_NULL` column to identify null values
result_df = result_df.with_columns(
    pl.when(pl.col("CUMMULATIVE_CONSUMPTION").is_null())
    .then(1)
    .otherwise(0)
    .alias("IS_NULL")
)

In [None]:
# Step 2: Calculate total rows and null rows per device
null_percentage_df = result_df.group_by("DEVICE_ID").agg(
    pl.col("IS_NULL").sum().alias("NULL_COUNT"),  # Count of null values
    pl.col("CUMMULATIVE_CONSUMPTION").count().alias("TOTAL_COUNT")  # Total rows
).with_columns(
    # Step 3: Calculate null percentage
    (pl.col("NULL_COUNT") / pl.col("TOTAL_COUNT") * 100).alias("NULL_PERCENTAGE")
)

In [None]:
null_percentage_df.sort('NULL_PERCENTAGE', descending=True)

In [None]:
# Step 4: Filter devices with missing rates over 4%
high_null_devices = null_percentage_df.filter(pl.col("NULL_PERCENTAGE") > 4).select("DEVICE_ID")

# Step 5: Keep only records from devices with high missing rates
filtered_result_df = result_df.join(high_null_devices, on="DEVICE_ID", how="inner")

# Check how many devices remain after filtering
print(f"Remaining devices with high missing rates: {filtered_result_df['DEVICE_ID'].n_unique()}")

# Display the first few rows of the filtered dataset
filtered_result_df.head()

In [None]:
# Define counts for devices with outliers before and after reconstruction
device_counts = {
    'Before Reconstruction': 56,
    # 'After Reconstruction': 55,
    'Proceeding to Imputation': 19
}

# Create a bar chart
plt.figure(figsize=(8, 6))
plt.bar(device_counts.keys(), device_counts.values(), color=['red', 'green', 'blue'])
plt.ylabel('Number of Devices with Outliers')
plt.title('Devices with Outliers Before and After Reconstruction')

# Annotate bars with actual values
for i, (label, count) in enumerate(device_counts.items()):
    plt.text(i, count + 0.5, str(count), ha='center', fontsize=12)

# Save the figure as a PNG file
plt.savefig("../visualizations/plots/devices_with_outliers_before_after.png", dpi=300, bbox_inches='tight')

# Show the plot
plt.show()


## Imputation

### Statistics

Mean, Median, Linear Interpolation, Cubic Interpolation

In [None]:
result_df=filtered_result_df
result_df.head()

#### Mean

In [None]:
result_df = result_df.with_columns([
    pl.col("CUMMULATIVE_CONSUMPTION").fill_null(
        pl.col("CUMMULATIVE_CONSUMPTION").mean().over("DEVICE_ID")
    ).alias("MEAN_IMPUTED")
])

#### Median

In [None]:
result_df = result_df.with_columns([
    pl.col("CUMMULATIVE_CONSUMPTION").fill_null(
        pl.col("CUMMULATIVE_CONSUMPTION").median().over("DEVICE_ID")
    ).alias("MEDIAN_IMPUTED")
])

#### Linear Interpolation

In [None]:
# Convert Polars DataFrame to Pandas for interpolation
df_pd = result_df.to_pandas()

In [None]:
# Interpolate by Linear Method, ensuring all missing values are filled
df_pd["LINEAR_IMPUTED"] = (
    df_pd.groupby("DEVICE_ID", group_keys=False)["CUMMULATIVE_CONSUMPTION"]
    .apply(lambda x: x.interpolate(method="linear", limit_direction="both"))
    .reset_index(drop=True)  # Align index with the original DataFrame
)

#### Cubic Interpolation

In [None]:
# Interpolate by Cubic Method
df_pd["CUBIC_IMPUTED"] = (
    df_pd.groupby("DEVICE_ID", group_keys=False)["CUMMULATIVE_CONSUMPTION"]
    .apply(lambda x: x.interpolate(method="cubic",limit_direction="both"))
    .reset_index(drop=True)  # Align index with the original DataFrame
)

# Fill any remaining nulls using forward and backward fill
df_pd["CUBIC_IMPUTED"].fillna(method="ffill", inplace=True)  # Fill forward
df_pd["CUBIC_IMPUTED"].fillna(method="bfill", inplace=True)  # Fill backward


In [None]:
# Convert back to Polars
result_df = pl.from_pandas(df_pd)

In [None]:
result_df.head()

#### Confirm all records are imputed

In [None]:
null_count = df_pd["LINEAR_IMPUTED"].isna().sum()
print(f"Total null values in LINEAR_IMPUTED: {null_count}")

In [None]:
null_count = df_pd["CUBIC_IMPUTED"].isna().sum()
print(f"Total null values in CUBIC_IMPUTED: {null_count}")

### Naive

#### Forward Fill

In [None]:
# Forward Fill
result_df = result_df.with_columns(
    pl.col("CUMMULATIVE_CONSUMPTION").forward_fill().alias("FFILL_IMPUTED")
)

In [None]:
# Count null values in the "FFILL_IMPUTED" column
null_count = result_df.select(pl.col("FFILL_IMPUTED").is_null().sum()).item()
print(f"Total null values in FFILL_IMPUTED: {null_count}")


In [None]:
result_df = result_df.with_columns(
    pl.col("CUMMULATIVE_CONSUMPTION")
    .forward_fill()  
    .backward_fill()# Fill any remaining nulls after backward fill
    .alias("FFILL_IMPUTED")
)

# Check if any nulls remain
null_count = result_df.select(pl.col("BFILL_IMPUTED").is_null().sum()).item()
print(f"Total null values in BFILL_IMPUTED after fixes: {null_count}")


#### Backward Fill

In [None]:
# Backward Fill
result_df = result_df.with_columns(
    pl.col("CUMMULATIVE_CONSUMPTION").backward_fill().alias("BFILL_IMPUTED")
)

In [None]:
# Count null values in the "FFILL_IMPUTED" column
null_count = result_df.select(pl.col("BFILL_IMPUTED").is_null().sum()).item()
print(f"Total null values in FFILL_IMPUTED: {null_count}")

Use forward fill to fix the last one.

In [None]:
result_df = result_df.with_columns(
    pl.col("CUMMULATIVE_CONSUMPTION")
    .backward_fill()
    .forward_fill()  # Fill any remaining nulls after backward fill
    .alias("BFILL_IMPUTED")
)

# Check if any nulls remain
null_count = result_df.select(pl.col("BFILL_IMPUTED").is_null().sum()).item()
print(f"Total null values in BFILL_IMPUTED after fixes: {null_count}")


### ML

- KNN
- LR
- SAITS

### Load Exogenous parameters and merge
- Climatic from GEE: https://code.earthengine.google.com/?accept_repo=users/jolaiyaemmanuel/thesis

In [None]:
climate_df = pl.read_csv('../data/daily_climate_gee.csv')

In [None]:
climate_df = climate_df.rename({'date': 'DATE'})

# Convert the 'DATE' column in climate_df to datetime objects
climate_df = climate_df.with_columns(pl.col('DATE').cast(pl.Date))

In [None]:
climate_df = climate_df.drop('.geo')

In [None]:
# Perform the join operation
result_df = result_df.with_columns(pl.col('DATE').cast(pl.Date))
merged_df = result_df.join(climate_df, on='DATE', how='left')

# Print some info
print(f"Shape of the original DataFrame: {df.shape}")
print(f"Shape of the climate DataFrame: {climate_df.shape}")
print(f"Shape of the merged DataFrame: {merged_df.shape}")

merged_df.head()

In [None]:
# Assuming your DataFrame `df` has these columns: TEMPERATURE_MEAN, TEMPERATURE_MIN, TEMPERATURE_MAX
merged_df = merged_df.with_columns([
    (pl.col("temp_max") - 273.15).alias("temp_max"),
    (pl.col("temp_min") - 273.15).alias("temp_min"),
])

### Feature Engineering
Only used in LR

#### Temporal

In [None]:
merged_df= merged_df.with_columns(
    pl.col('DATE').dt.weekday().alias('day_of_week'),
    pl.col('DATE').dt.month().alias('month'),
    pl.col('DATE').dt.week().alias('week_of_year'),
    pl.col('DATE').dt.quarter().alias('quarter'),
    pl.col('DATE').dt.day().alias('day'),
    pl.col('DATE').dt.year().alias('year'),
    # Add 'is_weekend' column: 1 if Saturday (5) or Sunday (6), else 0
    pl.col('DATE').dt.weekday().is_in([5, 6]).cast(pl.Int8).alias('is_weekend')
)

merged_df = merged_df.with_columns([
    np.sin(2 * np.pi * pl.col('day_of_week') / 7).alias('sin_day_of_week'),
    np.cos(2 * np.pi * pl.col('day_of_week') / 7).alias('cos_day_of_week'),
    np.sin(2 * np.pi * pl.col('month') / 12).alias('sin_month'),
    np.cos(2 * np.pi * pl.col('month') / 12).alias('cos_month')
])

#### Lags

In [None]:
lags = [1, 2, 3,]
for lag in lags:
    merged_df = merged_df.with_columns(
        pl.col('CUMMULATIVE_CONSUMPTION').shift(lag).alias(f'lag_{lag}')
    )
    
for lag in lags:
    merged_df = merged_df.with_columns(
        pl.col(f"lag_{lag}").fill_null(strategy="backward")
    )

#### Holiday

In [None]:
# List of holiday dates
holidays_2022 = [
    '2022-01-01', '2022-04-15', '2022-04-16', '2022-04-17', '2022-04-18',
    '2022-05-01', '2022-05-03', '2022-05-16', '2022-07-09', '2022-07-10',
    '2022-07-30', '2022-12-24', '2022-12-25', '2022-12-26'
]

holidays_2023 = [
    '2023-01-01', '2023-04-07', '2023-04-08', '2023-04-09', '2023-04-10',
    '2023-04-21', '2023-05-01', '2023-05-16', '2023-06-28', '2023-07-09',
    '2023-07-30', '2023-12-24', '2023-12-25', '2023-12-26'
]

holidays_2024 = [
    '2024-01-01', '2024-01-09', '2024-03-29', '2024-03-30', '2024-03-31',
    '2024-04-01', '2024-04-10', '2024-05-01', '2024-05-16', '2024-06-16',
    '2024-07-09', '2024-07-30', '2024-12-24', '2024-12-25', '2024-12-26'
]

holiday_dates = holidays_2022 + holidays_2023 + holidays_2024

holidays_df = pl.DataFrame({
    'date': pl.Series(holiday_dates).cast(pl.Date),
    'is_holiday': pl.Series([1] * len(holiday_dates))
})

In [None]:
merged_df = merged_df.join(
    holidays_df,
    left_on='DATE',
    right_on='date',
    how='left'
)

In [None]:
# Fill NaN values in 'is_holiday' with 0 (non-holiday)
merged_df = merged_df.with_columns(
    pl.col('is_holiday').fill_null(0)
)

In [None]:
merged_df.head()

#### Normalizing
Min-Max

In [None]:
merged_df.head()

In [None]:
normalized_df = merged_df.with_columns(
    [
        ((pl.col(col) - pl.col(col).min().over("DEVICE_ID")) /
         (pl.col(col).max().over("DEVICE_ID") - pl.col(col).min().over("DEVICE_ID")))
        .alias(col)
        for col in merged_df.columns
        if col not in ["DEVICE_ID", "DATE","OGI_LONG","OGI_LAT",
                       "MEAN_IMPUTED","MEDIAN_IMPUTED","FFILL_IMPUTED",
                       "BFILL_IMPUTED","LINEAR_IMPUTED",
                       "CUMMULATIVE_CONSUMPTION_COPY",
                      "CUBIC_IMPUTED"]  # Exclude non-numeric or identifier columns
    ]
)

In [None]:
normalized_df.head()

In [None]:
normalized_df.filter(pl.col('CUMMULATIVE_CONSUMPTION').is_null()).head()

### Table - Missing rate

In [None]:
# Make a copy of the dataframe to avoid modifying the original
numeric_df_pd_modified = normalized_df.to_pandas().copy()

# Dictionary to store missing data statistics per device
missing_data_per_device = []

# Loop through each device and introduce 10% additional nulls
for device_id, group in numeric_df_pd_modified.groupby("DEVICE_ID"):
    # Total number of records before introducing new nulls
    total_records_before = len(group)
    
    # Count missing values before introducing new nulls
    missing_before = group["CUMMULATIVE_CONSUMPTION"].isna().sum()
    
    # Identify non-null indices for the current device
    non_null_indices = group[group["CUMMULATIVE_CONSUMPTION"].notna()].index
    
    # Determine the number of new nulls to introduce (10% of available non-null values)
    num_new_nulls = int(0.1 * len(non_null_indices))
    
    if num_new_nulls > 0:
        # Randomly select indices to introduce new nulls
        new_null_indices = np.random.choice(non_null_indices, num_new_nulls, replace=False)
        # Introduce the new nulls in the modified dataframe
        numeric_df_pd_modified.loc[new_null_indices, "CUMMULATIVE_CONSUMPTION"] = np.nan
    
    # Count missing values after introducing new nulls
    missing_after = numeric_df_pd_modified.loc[group.index, "CUMMULATIVE_CONSUMPTION"].isna().sum()
    
    # Total number of records should decrease due to introduced nulls
    total_records_after = total_records_before - num_new_nulls
    
    # Store the results
    missing_data_per_device.append({
        "DEVICE_ID": device_id,
        "Total Records Before": total_records_before,
        "Missing Before": missing_before,
        "Total Records After": total_records_after,
        "Missing After": missing_after
    })

# Create a summary dataframe
missing_data_summary_per_device = pd.DataFrame(missing_data_per_device)

# Return the modified dataframe
missing_data_summary_per_device.head()

In [None]:
missing_data_summary_per_device

#### Correlation Analysis

In [None]:
# Select numeric columns only
numeric_columns = [col for col in normalized_df.columns if normalized_df[col].dtype in [pl.Float64, pl.Int64]]
numeric_df = normalized_df.select(numeric_columns)

In [None]:
# List of exogenous features (replace with your feature names)
exogenous_features = [
    'precip_max',
    'precip_min',
    'temp_max',
    'temp_min',
    'day_of_week',
    'month',
    'week_of_year',
    'quarter',
    'day',
    'year',
    'is_weekend',
    'sin_day_of_week',
    'cos_day_of_week',
    'sin_month',
    'cos_month',
    'lag_1',
    'lag_2',
    'lag_3',
    'is_holiday'
]

# Manual method

def compute_device_correlations_spearman(
    df: pl.DataFrame, 
    target_column: str, 
    exogenous_features: list
) -> pl.DataFrame:
    # Get unique device IDs
    device_ids = df['DEVICE_ID'].unique().to_list()
    
    # List to store correlation results
    all_correlations = []
    
    # Compute correlations for each device
    for device_id in device_ids:
        # Filter data for the specific device
        device_data = df.filter(pl.col('DEVICE_ID') == device_id)

        # Compute Spearman correlation for each feature
        # Spearman correlation = Pearson correlation on rank-transformed data
        correlations = []
        for feature in exogenous_features:
            # Rank-transform both columns
            ranked_feature = pl.col(feature).rank("average")
            ranked_target = pl.col(target_column).rank("average")
            
            # Compute the correlation of these ranks
            spearman_corr = device_data.select(
                pl.corr(ranked_feature, ranked_target)
            )[0, 0]
            
            correlations.append({
                "device_id": device_id,
                "feature": feature,
                "correlation": spearman_corr
            })
        
        all_correlations.extend(correlations)
    
    # Convert to Polars DataFrame
    correlation_df = pl.DataFrame(all_correlations)
    
    return correlation_df

# Usage
correlation_results = compute_device_correlations_spearman(
    numeric_df,
    "CUMMULATIVE_CONSUMPTION_COPY",  # Replace with your target column
    exogenous_features
)


In [None]:
# Step 2: Analyze the Results
# Find the top correlated features across all devices
top_features = (
    correlation_results.group_by("feature")
    .agg(pl.col("correlation").abs().mean().alias("mean_correlation"))
    .sort("mean_correlation", descending=True)
)

In [None]:
top_features

### Colinearity

Since we're using linear regression, we have to prevent collinearity.

In [None]:
numeric_df.head()

In [None]:
# Dictionary to store VIF results per device
vif_per_device = []

# Loop through each device and compute VIF
for device_id, group in numeric_df.to_pandas().groupby("DEVICE_ID"):
    # Select lagged features
    lags_df = group[["lag_1", "lag_2", "lag_3"]].replace([np.inf, -np.inf], np.nan).dropna()

    # Ensure there are enough observations to compute VIF
    if lags_df.shape[0] > 3:  # VIF requires at least as many rows as columns
        vif_data = pd.DataFrame()
        vif_data["Feature"] = lags_df.columns
        vif_data["VIF"] = [variance_inflation_factor(lags_df.values, i) for i in range(lags_df.shape[1])]
        vif_data["DEVICE_ID"] = device_id

        # Append results
        vif_per_device.append(vif_data)

# Combine all device VIF results
vif_results_df = pd.concat(vif_per_device, ignore_index=True)

# Compute the mean VIF per feature across all devices
average_vif = vif_results_df.groupby("Feature")["VIF"].mean().reset_index()

In [None]:
average_vif

Based on this, only lag 1 is kept.

In [None]:
columns_to_remove = top_features.filter(pl.col('mean_correlation') < 0.5)
# remove colinear features too
cols_to_remove_list = columns_to_remove['feature'].to_list() + ['lag_2','lag_3']

In [None]:
cols_to_remove_list

In [None]:
# Filter features with mean correlation > 0.5
exogenous_features = top_features.filter(pl.col('mean_correlation') > 0.5)

# Convert to list and exclude 'lag_1', 'lag_2', and 'lag_3'
exogenous_features = [feature for feature in exogenous_features['feature'].to_list() if feature not in ["lag_2", "lag_3"]]

In [None]:
normalized_df.head()

In [None]:
modelling_df = normalized_df.drop(cols_to_remove_list)

In [None]:
modelling_df.head()

In [None]:
modelling_df['DEVICE_ID'].n_unique(), modelling_df['DEVICE_ID'].shape[0]

### KNN

In [None]:
modelling_df.head()

In [None]:
# Define range of k values to test
K_VALUES = list(range(1, 11))  # Testing k from 1 to 10

# Columns used as input for KNN (but only imputing CUMMULATIVE_CONSUMPTION)
target_column = "CUMMULATIVE_CONSUMPTION"
feature_columns = ["year", "lag_1",]
selected_features = [target_column] + feature_columns

# Dictionary to store best k per device
best_k_per_device = {}

# List to store final imputed data
imputed_data = []

# Loop through each device group
for device_id, group in modelling_df.group_by("DEVICE_ID"):
    print(f'Optimizing k for Device: {device_id}')

    # Extract features as NumPy array
    values = group.select(selected_features).to_numpy()

    # Identify missing values in CUMMULATIVE_CONSUMPTION
    missing_mask = np.isnan(values[:, 0])  # Only for target column

    # Ensure enough non-null values for RMSE evaluation
    if np.sum(~missing_mask) < 5:
        print(f"Skipping device {device_id} due to insufficient data for validation.")
        best_k_per_device[device_id] = None
        continue

    # Create a copy for evaluation (masking known values for validation)
    known_values = values[:, 0].copy()  # Keep only the target column
    known_indices = np.where(~missing_mask)[0]  # Indices of known values

    if len(known_indices) > 5:  # Ensure at least 5 known values exist for testing
        np.random.shuffle(known_indices)
        test_indices = known_indices[: max(1, len(known_indices) // 10)]  # Mask 10% of known values for validation
        test_mask = np.zeros_like(values[:, 0], dtype=bool)
        test_mask[test_indices] = True

        train_values = values.copy()
        train_values[test_mask, 0] = np.nan  # Mask only `CUMMULATIVE_CONSUMPTION` for evaluation

        # Optuna optimization
        def objective(trial):
            k = trial.suggest_int("k", 1, 10)
            knn_imputer = KNNImputer(n_neighbors=k)
            imputed = knn_imputer.fit_transform(train_values)

            # Compute RMSE only for masked `CUMMULATIVE_CONSUMPTION` values
            rmse = np.sqrt(mean_squared_error(known_values[test_mask], imputed[test_mask, 0]))
            return rmse

        study = optuna.create_study(direction="minimize")
        study.optimize(objective, n_trials=10)

        best_k = study.best_params["k"]
    else:
        best_k = 5  # Default k if too few known values exist

    best_k_per_device[device_id] = best_k  # Store best k for the device

    # Perform final imputation using best k (only imputing target column)
    knn_imputer = KNNImputer(n_neighbors=best_k)
    final_imputed = knn_imputer.fit_transform(values)[:, 0]  # Only keep target column

    # Append imputed values to group without altering other features
    group = group.with_columns(pl.Series("KNN_IMPUTED", final_imputed))

    imputed_data.append(group)

# Combine all processed groups
imputed_df = pl.concat(imputed_data)

# Print best k values per device
print("Best k per device:", best_k_per_device)

In [None]:
imputed_df.head()

### LR

#### LR + TSSplit

In [None]:
# Define target column and features
target_column = "CUMMULATIVE_CONSUMPTION"
feature_columns = ["year", "lag_1"]
features = feature_columns

# List to store imputed data for all devices
all_devices = []

# Get unique device IDs
device_ids = imputed_df["DEVICE_ID"].unique().to_list()

# Loop through each device
for device_id in device_ids:
    # Filter data for the current device
    device_data = imputed_df.filter(pl.col("DEVICE_ID") == device_id)

    # Separate rows with missing and non-missing target values
    non_missing_df = device_data.filter(~pl.col(target_column).is_null())
    missing_df = device_data.filter(pl.col(target_column).is_null())

    # Skip the device if there are no missing values
    if missing_df.is_empty():
        all_devices.append(device_data.with_columns(
            pl.col(target_column).alias("LIN_REG_IMPUTED_EXO")
        ))
        continue

    # Prepare training data (features and target) for Linear Regression
    X_train = non_missing_df.select(features).to_pandas()
    y_train = non_missing_df.select(target_column).to_pandas().values.ravel()

    # Prepare data for prediction (features only)
    X_predict = missing_df.select(features).to_pandas()

    # ✅ Fix: Impute missing feature values (replace NaNs with mean) especially for the lags
    imputer = SimpleImputer(strategy="mean")
    X_train = imputer.fit_transform(X_train)
    X_predict = imputer.transform(X_predict)

    # Time Series Cross-Validation
    tscv = TimeSeriesSplit(n_splits=5)
    rmse_scores = []

    for train_idx, test_idx in tscv.split(X_train):
        X_tr, X_val = X_train[train_idx], X_train[test_idx]
        y_tr, y_val = y_train[train_idx], y_train[test_idx]

        # Train Linear Regression model
        model = LinearRegression()
        model.fit(X_tr, y_tr)

        # Predict on validation set
        y_pred = model.predict(X_val)

        # Compute RMSE for validation set
        rmse = np.sqrt(mean_squared_error(y_val, y_pred))
        rmse_scores.append(rmse)

    # Compute average RMSE across folds
    mean_rmse = np.mean(rmse_scores)
    print(f"Device {device_id} - Mean RMSE: {mean_rmse}")

    # Train final model on entire available dataset
    final_model = LinearRegression()
    final_model.fit(X_train, y_train)

    # Predict missing values
    predicted_values = final_model.predict(X_predict)

    # Add the predicted values to the missing_df in a new column
    missing_df = missing_df.with_columns(
        pl.Series(predicted_values).alias("LIN_REG_IMPUTED_EXO")
    )

    # For non-missing rows, copy the original target column into the new column
    non_missing_df = non_missing_df.with_columns(
        pl.col(target_column).alias("LIN_REG_IMPUTED_EXO")
    )

    # Combine the non-missing and imputed data for the current device
    device_combined = pl.concat([non_missing_df, missing_df]).sort("DATE")
    all_devices.append(device_combined)

# Combine data from all devices
final_df = pl.concat(all_devices)


In [None]:
# # Define your target column and features
# target_column = "CUMMULATIVE_CONSUMPTION"
# feature_columns = ["year", "lag_1",]
# features = feature_columns

# # List to store imputed data for all devices
# all_devices = []

# # Get unique device IDs
# device_ids = imputed_df["DEVICE_ID"].unique().to_list()

# # Loop through each device
# for device_id in device_ids:
#     # Filter data for the current device
#     device_data = imputed_df.filter(pl.col("DEVICE_ID") == device_id)

#     # Separate rows with missing and non-missing target values
#     non_missing_df = device_data.filter(~pl.col(target_column).is_null())
#     missing_df = device_data.filter(pl.col(target_column).is_null())

#     # Skip the device if there are no missing values
#     if missing_df.is_empty():
#         # Add non-missing data to the final list as is
#         all_devices.append(device_data.with_columns(
#             pl.col(target_column).alias("LIN_REG_IMPUTED_EXO")
#         ))
#         continue

#     # Prepare training data (features and target) for Linear Regression
#     X_train = non_missing_df.select(features).to_pandas()
#     y_train = non_missing_df.select(target_column).to_pandas().values.ravel()

#     # Prepare data for prediction (features only)
#     X_predict = missing_df.select(features).to_pandas()

#     # Train Linear Regression model
#     lr_model = LinearRegression()
#     lr_model.fit(X_train, y_train)

#     # Predict missing values
#     predicted_values = lr_model.predict(X_predict)

#     # Add the predicted values to the missing_df in a new column
#     missing_df = missing_df.with_columns(
#         pl.Series(predicted_values).alias("LIN_REG_IMPUTED_EXO")
#     )

#     # For non-missing rows, copy the original target column into the new column
#     non_missing_df = non_missing_df.with_columns(
#         pl.col(target_column).alias("LIN_REG_IMPUTED_EXO")
#     )

#     # Combine the non-missing and imputed data for the current device
#     device_combined = pl.concat([non_missing_df, missing_df]).sort("DATE")
#     all_devices.append(device_combined)

# # Combine data from all devices
# final_df = pl.concat(all_devices)

### SAITS

In [None]:
# final_df = final_df.with_columns(pl.col('DATE').cast(pl.Date))

In [None]:
# final_df = final_df.sort(['DEVICE_ID','DATE'])

In [None]:
# final_df.head()

In [None]:
# # Add an index column to your Polars DataFrame
# final_df = final_df.with_row_count(name="index")

In [None]:
# # Define the feature columns to include
# feature_columns = ["CUMMULATIVE_CONSUMPTION"]

# # Dictionary to store imputed values for each device
# imputed_values_dict = {}

# for id, device in enumerate(final_df['DEVICE_ID'].unique()):
#     print(f"Processing device: {id + 1} of {len(final_df['DEVICE_ID'].unique())}")

#     # Filter and sort data for the device
#     device_data = final_df.filter(pl.col('DEVICE_ID') == device).sort('DATE')

#     # Select relevant columns and convert to NumPy array
#     data_for_imputation = device_data.select(feature_columns).to_numpy()

#     # Define X for SAITS
#     X = data_for_imputation.reshape(1, -1, len(feature_columns))

#     # Keep a copy of the original data for evaluation
#     X_ori = X.copy()

#     # Apply MCAR (mask 10% of the observed values for validation)
#     X = mcar(X, 0.1)  # Mask 10% of the observed data

#     dataset = {"X": X}

#     # Initialize the SAITS model
#     saits = SAITS(
#         n_steps=X.shape[1],  # Number of time steps
#         n_features=X.shape[2],  # Number of features
#         n_layers=2,
#         d_model=256,
#         d_ffn=128,
#         n_heads=4,
#         d_k=64,
#         d_v=64,
#         dropout=0.1,
#         epochs=100,
#         batch_size=32,
#         attn_dropout=0.1,
#         patience=3,
#         saving_path="./saits",  # Path for saving model checkpoints
#         model_saving_strategy="best",
#     )

#     # Train the model and impute the missing values
#     saits.fit(dataset)
#     imputation = saits.impute(dataset)

#     # Extract imputed values for the target column (CUMMULATIVE_CONSUMPTION)
#     imputed_values = imputation[0, :, 0]  # Extract the imputed values for the first feature

#     # Store the imputed values in the dictionary
#     imputed_values_dict[device] = {
#         "indices": device_data['index'].to_numpy(),  # Save indices for mapping back to original data
#         "imputed_values": imputed_values,
#     }

#     # Calculate MAE and RMSE
#     indicating_mask = np.isnan(X) ^ np.isnan(X_ori)  # Mask indicating artificially missing values
#     mae = calc_mae(imputation, np.nan_to_num(X_ori), indicating_mask)
#     rmse = calc_rmse(imputation, np.nan_to_num(X_ori), indicating_mask)

#     print(f"Device {device} - MAE: {mae}")
#     print(f"Device {device} - RMSE: {rmse}")

In [None]:
# # Initialize an empty list to collect rows
# rows = []

# # Iterate over each device in the dictionary
# for device_id, data in imputed_values_dict.items():
#     indices = data["indices"]
#     imputed_values = data["imputed_values"]

#     # Create a row for each index and value
#     for idx, value in zip(indices, imputed_values):
#         rows.append({"DEVICE_ID": device_id, "index": idx, "imputed_value": value})

# # Create a DataFrame from the collected rows
# imputed_df = pd.DataFrame(rows)

# # Sort the DataFrame by index if needed
# imputed_df = imputed_df.sort_values("index").reset_index(drop=True)

In [None]:
# imputed_df = imputed_df.sort_values("index").reset_index(drop=True)

In [None]:
# imputed_df.head()

In [None]:
# imputed_df = pl.from_dataframe(imputed_df)

In [None]:
# imputed_df = imputed_df.with_columns(pl.col('index').cast(pl.UInt32))

In [None]:
# merged_df = final_df.join(imputed_df, on=["DEVICE_ID","index"], how="left")
# merged_df.head()

In [None]:
# Create an array of 10 random device IDs from the unique device IDs in your dataset
random_devices = random.sample(imputed_df['DEVICE_ID'].unique().to_list(), 10)

In [None]:

for device in random_devices:  # Use the same random devices as before
    device_data = imputed_df.filter(pl.col('DEVICE_ID') == device).to_pandas()
    
    plt.figure(figsize=(10, 6))

    # Plot original cumulative consumption (with outliers removed)
    plt.plot(device_data['DATE'], device_data['CUMMULATIVE_CONSUMPTION'], 
             linestyle='-', color='blue', label='Original (No Outliers)')

    
    # Plot KNN values
    plt.plot(device_data['DATE'], device_data['KNN_IMPUTED'], 
             linestyle='--', color='red', label='KNN Imputed Values')
    plt.xlabel('Date')
    plt.ylabel('Cumulative Consumption')
    plt.title(f'Cumulative Consumption with Imputed Values for Device ID: {device}')
    plt.xticks(rotation=45)
    plt.legend()
    plt.tight_layout()
    plt.show()

### Data Reconstruction

In [None]:
final_df.head()

#### Reverse Normalization

In [None]:
# List of features to reverse normalize
features_reverse_normalize = [
    'KNN_IMPUTED',
    'LIN_REG_IMPUTED_EXO',
    # 'imputed_values'
]

scaling_params = final_df.group_by("DEVICE_ID").agg([
    pl.col("CUMMULATIVE_CONSUMPTION_COPY").min().alias("DAILY_CONSUMPTION_min"),
    pl.col("CUMMULATIVE_CONSUMPTION_COPY").max().alias("DAILY_CONSUMPTION_max")
])

In [None]:
scaling_params.head()

In [None]:
# Join scaling params with merged_df
merged_df = final_df.join(scaling_params, on="DEVICE_ID", how="left")

In [None]:
merged_df.head()

In [None]:
# Denormalize KNN predictions
merged_df= merged_df.with_columns(
    (
        (pl.col("KNN_IMPUTED") * 
         (pl.col("DAILY_CONSUMPTION_max") - pl.col("DAILY_CONSUMPTION_min"))) +
        pl.col("DAILY_CONSUMPTION_min")
    ).alias("DENORMALIZED_KNN_IMPUTED")
)

In [None]:
# Denormalize LR predictions
merged_df = merged_df.with_columns(
    (
        (pl.col("LIN_REG_IMPUTED_EXO") * 
         (pl.col("DAILY_CONSUMPTION_max") - pl.col("DAILY_CONSUMPTION_min"))) +
        pl.col("DAILY_CONSUMPTION_min")
    ).alias("DENORMALIZED_LIN_REG_IMPUTED")
)

In [None]:
# # Denormalize SAITS predictions
# merged_df = merged_df.with_columns(
#     (
#         (pl.col("imputed_value") * 
#          (pl.col("DAILY_CONSUMPTION_max") - pl.col("DAILY_CONSUMPTION_min"))) +
#         pl.col("DAILY_CONSUMPTION_min")
#     ).alias("DENORMALIZED_SAITS_IMPUTED")
# )

In [None]:
merged_df.head()

In [None]:
reconstructed_df = merged_df.select([
    "DATE","DEVICE_ID","OGI_LONG","OGI_LAT",
    "MEAN_IMPUTED","MEDIAN_IMPUTED","FFILL_IMPUTED","BFILL_IMPUTED","LINEAR_IMPUTED","CUBIC_IMPUTED",
    "DENORMALIZED_KNN_IMPUTED","CUMMULATIVE_CONSUMPTION_COPY","DENORMALIZED_LIN_REG_IMPUTED",
])

# "DENORMALIZED_SAITS_IMPUTED",

In [None]:
reconstructed_df.head()

### Export imputed data

In [None]:
reconstructed_df.write_csv('../exports/imputed_water_meters_v2.csv')

In [None]:
reconstructed_df['DEVICE_ID'].n_unique()