In [1]:
# Try legacy pip resolver for AzureML install
!pip install azureml-core==1.61.0 azureml-widgets==1.61.0 --use-deprecated=legacy-resolver --quiet
# If you need notebook features, also run:
# !pip install azureml-contrib-notebook==1.61.0 azureml-dataset-runtime==1.61.0 --use-deprecated=legacy-resolver --quiet


[33mDEPRECATION: pytorch-lightning 1.7.7 has a non-standard dependency specifier torch>=1.9.*. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pytorch-lightning or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
# Install AzureML core packages in smaller groups if needed
#!pip install azureml-core==1.61.0 azureml-widgets==1.61.0 --quiet
# If you need notebook features, also run:
# !pip install azureml-contrib-notebook==1.61.0 azureml-dataset-runtime==1.61.0 --quiet


In [3]:
# Add project root to sys.path for imports
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))


In [4]:
# Downgrade pip for compatibility, then install correct package versions
#!pip install pip==23.3.1 --quiet
#!pip install pytorch-forecasting==0.10.3 pytorch-lightning==1.7.7 torch==1.13.1 --quiet
# Install compatible torchmetrics version for pytorch-forecasting 0.10.3
#!pip install torchmetrics==0.10.0 --quiet
# Downgrade numpy for compatibility with pytorch-forecasting 0.10.3
#!pip install numpy==1.23.5 --quiet

## 1. Data Preparation
Generate or load synthetic oil well data.

In [5]:
import pandas as pd
from data.generate_well_data import generate_synthetic_well_data
df = generate_synthetic_well_data()
df.head()

Unnamed: 0,well_id,day,oil_rate,gas_rate,water_cut,choke_size,reservoir_pressure,weather
0,1,0,1014.901425,491.975591,0.103917,16,3995.964171,rain
1,1,1,990.372633,501.74213,0.082081,16,4055.348927,clear
2,1,2,1008.501804,494.787422,0.111462,16,3916.928967,clear
3,1,3,1029.342489,493.779319,0.070893,24,4053.68574,storm
4,1,4,971.237134,473.670672,0.127177,20,3978.264364,rain


## 2. Azure ML Experiment Tracking Setup
Configure Azure ML workspace and experiment.

In [6]:
'''
from azureml.core import Workspace, Experiment
ws = Workspace.from_config()
experiment = Experiment(ws, 'oil-production-forecasting')
'''


"\nfrom azureml.core import Workspace, Experiment\nws = Workspace.from_config()\nexperiment = Experiment(ws, 'oil-production-forecasting')\n"

## 3. Model Training: Temporal Fusion Transformer
Train TFT on the synthetic dataset.

In [7]:
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
import pytorch_lightning as pl
import torch
import numpy as np
# Ensure well_id and choke_size are string for categorical encoding
df["well_id"] = df["well_id"].astype(str)
df["choke_size"] = df["choke_size"].astype(str)
# Check and drop missing and infinite values
print("Missing values per column before drop:")
print(df.isnull().sum())
df = df.replace([np.inf, -np.inf], np.nan)
df = df.dropna().reset_index(drop=True)
print("Missing values per column after drop:")
print(df.isnull().sum())
print("Any infinite values:", (~np.isfinite(df.select_dtypes(include=[float, int]))).any())
# --- Filter out groups (wells) with insufficient history for sequence length ---
min_encoder_length = 60
min_prediction_length = 7
min_total_length = min_encoder_length + min_prediction_length
group_sizes = df.groupby('well_id').size()
print("Group sizes (rows per well_id):")
print(group_sizes.describe())
print(group_sizes.sort_values())
sufficient_history_wells = group_sizes[group_sizes >= min_total_length].index
print(f"Number of wells with sufficient history: {len(sufficient_history_wells)} / {len(group_sizes)}")
df = df[df['well_id'].isin(sufficient_history_wells)].reset_index(drop=True)
print("Shape after filtering short groups:", df.shape)
# Prepare dataset
dataset = TimeSeriesDataSet(
    df,
    time_idx='day',
    target='oil_rate',
    group_ids=['well_id'],
    min_encoder_length=min_encoder_length,
    max_encoder_length=min_encoder_length,
    min_prediction_length=min_prediction_length,
    max_prediction_length=min_prediction_length,
    static_categoricals=['well_id'],
    time_varying_known_categoricals=['weather', 'choke_size'],
    time_varying_known_reals=['day', 'reservoir_pressure'],
    time_varying_unknown_reals=['oil_rate', 'gas_rate', 'water_cut'],
    target_normalizer=GroupNormalizer(groups=['well_id']),
)
# Use to_dataloader() to avoid DataLoader edge cases
train_dataloader = dataset.to_dataloader(train=True, batch_size=32, shuffle=True, drop_last=True)
# Test DataLoader batches for None
for i, batch in enumerate(train_dataloader):
    if batch is None:
        print(f"Batch {i} is None!")
    else:
        print(f"Batch {i} type: {type(batch)}")
    if i >= 2:
        break
model = TemporalFusionTransformer.from_dataset(dataset)
# Use 'accelerator' and 'devices' instead of deprecated 'gpus' argument
trainer = pl.Trainer(max_epochs=10, accelerator='gpu' if torch.cuda.is_available() else 'cpu', devices=1 if torch.cuda.is_available() else 1)
trainer.fit(model, train_dataloader)


  import pkg_resources
Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


Missing values per column before drop:
well_id               0
day                   0
oil_rate              0
gas_rate              0
water_cut             0
choke_size            0
reservoir_pressure    0
weather               0
dtype: int64
Missing values per column after drop:
well_id               0
day                   0
oil_rate              0
gas_rate              0
water_cut             0
choke_size            0
reservoir_pressure    0
weather               0
dtype: int64
Any infinite values: day                   False
oil_rate              False
gas_rate              False
water_cut             False
reservoir_pressure    False
dtype: bool
Group sizes (rows per well_id):
count     10.0
mean     365.0
std        0.0
min      365.0
25%      365.0
50%      365.0
75%      365.0
max      365.0
dtype: float64
well_id
1     365
10    365
2     365
3     365
4     365
5     365
6     365
7     365
8     365
9     365
dtype: int64
Number of wells with sufficient history: 10 / 10
Sha

  rank_zero_warn(
  rank_zero_warn(
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn(

   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1  | logging_metrics                    | ModuleList                      | 0     
2  | input_embeddings                   | MultiEmbedding                  | 81    
3  | prescalers                         | ModuleDict                      | 80    
4  | static_variable_selection          | VariableSelectionNetwork        | 48    
5  | encoder_variable_selection         | VariableSelectionNetwork        | 3.4 K 
6  | decoder_variable_selection         | VariableSelectionNetwork        | 1.4 K 
7  | static_context_va

Training: 0it [00:00, ?it/s]

  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=10` reached.


In [8]:
# --- Check for missing days in each well_id group ---
def check_missing_days(df):
    missing = {}
    for well, group in df.groupby('well_id'):
        expected = set(range(group['day'].min(), group['day'].max() + 1))
        actual = set(group['day'])
        missing_days = expected - actual
        if missing_days:
            missing[well] = sorted(missing_days)
    return missing

missing_days = check_missing_days(df)
print(f"Number of wells with missing days: {len(missing_days)}")
if missing_days:
    for well, days in list(missing_days.items())[:3]:
        print(f"well_id {well} missing days: {days[:10]}{'...' if len(days) > 10 else ''}")


Number of wells with missing days: 0


In [9]:
# --- Check dtypes and sample dataset items ---
print("Column dtypes:")
print(df.dtypes)

print("\nSample values for categorical columns:")
print("well_id:", df['well_id'].unique()[:3])
print("choke_size:", df['choke_size'].unique()[:3])
print("weather:", df['weather'].unique()[:3])

print("\nFirst 5 items from TimeSeriesDataSet:")
for i in range(5):
    item = dataset[i]
    print(f"Item {i}: type={type(item)}, is None={item is None}")
    if item is None:
        print("Found None item at index", i)


Column dtypes:
well_id                object
day                     int64
oil_rate              float64
gas_rate              float64
water_cut             float64
choke_size             object
reservoir_pressure    float64
weather                object
dtype: object

Sample values for categorical columns:
well_id: ['1' '2' '3']
choke_size: ['16' '24' '20']
weather: ['rain' 'clear' 'storm']

First 5 items from TimeSeriesDataSet:
Item 0: type=<class 'tuple'>, is None=False
Item 1: type=<class 'tuple'>, is None=False
Item 2: type=<class 'tuple'>, is None=False
Item 3: type=<class 'tuple'>, is None=False
Item 4: type=<class 'tuple'>, is None=False


In [10]:
# --- Diagnostics: Check for duplicate (well_id, day) and monotonicity ---
duplicates = df.duplicated(subset=["well_id", "day"]).sum()
print(f"Number of duplicate (well_id, day) pairs: {duplicates}")
if duplicates > 0:
    print(df[df.duplicated(subset=["well_id", "day"], keep=False)].sort_values(["well_id", "day"]))

# Check monotonicity of 'day' within each well
def is_monotonic(group):
    return group["day"].is_monotonic_increasing
monotonic = df.groupby("well_id").apply(is_monotonic)
print(f"Number of wells with non-monotonic 'day': {(~monotonic).sum()}")
if (~monotonic).sum() > 0:
    print(monotonic[~monotonic])

# Print minimum group size after filtering
group_sizes = df.groupby('well_id').size()
print(f"Minimum group size after filtering: {group_sizes.min()}")


Number of duplicate (well_id, day) pairs: 0
Number of wells with non-monotonic 'day': 0
Minimum group size after filtering: 365


## 4. Log Experiments to Azure ML
Log parameters, metrics, and model artifacts.

In [None]:
# Local experiment tracking (no AzureML required)
import json
import os

val_metrics = trainer.callback_metrics if 'trainer' in locals() else {}

# Save metrics to a local JSON file
metrics_path = os.path.join(os.getcwd(), 'local_experiment_metrics.json')
with open(metrics_path, 'w') as f:
    json.dump({
        'mae': float(val_metrics.get('val_mae', 0)),
        'rmse': float(val_metrics.get('val_rmse', 0))
    }, f, indent=2)
print(f"Metrics saved to {metrics_path}")

# Save model weights using PyTorch
import torch
# Save only the model weights (state_dict)
torch.save(model.state_dict(), 'tft_model.pth')
print("Model weights saved as tft_model.pth")

Metrics saved to /Users/justin/energy-ai-azure/notebooks/local_experiment_metrics.json
Model checkpoint saved as tft_model.ckpt


## 5. Model Card
See [MODEL_CARD.md](../models/production_forecasting/MODEL_CARD.md) for details on intended use, data, evaluation, and monitoring.

## 6. Deploy Model to Azure ML Endpoint
Deploy the trained TFT model as a real-time endpoint.

In [12]:
# See deployment script: models/production_forecasting/deploy_azureml.py
# !python ../../models/production_forecasting/deploy_azureml.py

## 7. Drift Detection Setup
Monitor feature distributions for drift using KS test.

In [13]:
from deployment.drift_detection import detect_drift
reference = df[df['day'] < 335]  # First 335 days
current = df[df['day'] >= 335]  # Last 30 days
for feature in ['oil_rate', 'gas_rate', 'water_cut']:
    drift = detect_drift(reference, current, feature)
    print(f'Drift in {feature}: {drift}')

Drift in oil_rate: True
Drift in gas_rate: True
Drift in water_cut: True
