In [106]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn import decomposition
from sklearn.manifold import TSNE
from mpl_toolkits.mplot3d import Axes3D
from sklearn.model_selection import KFold
from scipy.stats import zscore
import pingouin as pg
import os 
from sklearn.model_selection import train_test_split

In [107]:
folder_path = '/Users/vappaji/Documents/Capstone/dasked/negative_complete'
csv_files = [f for f in os.listdir(folder_path) if f.endswith('.csv')]

neg_comp = pd.concat(
    [pd.read_csv(os.path.join(folder_path, file)) for file in csv_files],
    ignore_index=True
)

In [108]:
def rename_columns(df):
    def transform_column_name(col_name):
        if "_level_" in col_name:
            parts = col_name.split("_level_")
            if len(parts) == 2 and parts[1].isdigit():
                var, n = parts
                return f"{float(n)}-{var}"
        return col_name

    df.columns = [transform_column_name(col) for col in df.columns]
    return df

In [109]:
rename_columns(neg_comp)

Unnamed: 0,_uid_,id,longitude,latitude,initialdate,days,year,month,1000.0-air,975.0-air,...,350.0-shum,300.0-shum,275.0-shum,250.0-shum,225.0-shum,200.0-shum,175.0-shum,150.0-shum,125.0-shum,100.0-shum
0,911487.0,14546434.0,-1.829501e+06,2.793841e+06,2013-01-02,0,2013,1,294.643036,292.469025,...,,,,,,,,,,
1,911774.0,14546481.0,-2.170116e+06,2.658773e+06,2013-01-07,0,2013,1,294.640686,292.790039,...,,,,,,,,,,
2,946008.0,14546352.0,-1.331762e+06,1.275915e+06,2013-01-08,4,2013,1,297.110016,294.878967,...,,,,,,,,,,
3,945832.0,14546500.0,-1.009945e+06,1.745870e+06,2013-01-06,2,2013,1,296.712036,294.741302,...,,,,,,,,,,
4,945776.0,14546487.0,-2.097758e+06,2.615743e+06,2013-01-17,10,2013,1,295.502563,293.332947,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
167056,1015403.0,2028453.0,-9.699338e+05,2.668803e+06,2002-10-19,10,2002,10,297.223389,295.136230,...,0.000999,0.000452,0.000340,0.000207,0.000111,0.000062,0.000031,0.000013,0.000005,0.000005
167057,1015387.0,2028460.0,-8.675489e+05,2.509888e+06,2002-10-29,9,2002,10,297.188965,295.116150,...,0.000502,0.000263,0.000191,0.000140,0.000100,0.000059,0.000030,0.000012,0.000007,0.000006
167058,1015387.0,2028460.0,-8.992530e+05,2.508546e+06,2002-10-24,9,2002,10,294.148193,292.283539,...,0.000574,0.000300,0.000175,0.000112,0.000090,0.000062,0.000035,0.000011,0.000005,0.000006
167059,1015387.0,2028460.0,-8.911107e+05,2.483461e+06,2002-10-28,9,2002,10,297.547546,295.347107,...,0.000583,0.000258,0.000251,0.000219,0.000131,0.000064,0.000026,0.000014,0.000006,0.000005


In [110]:
# fire vertical profiles 
folder_path = '/Users/vappaji/Documents/Capstone/dasked/fires'
csv_files = [f for f in os.listdir(folder_path) if f.endswith('.csv')]

fires = pd.concat(
    [pd.read_csv(os.path.join(folder_path, file)) for file in csv_files],
    ignore_index=True
)
rename_columns(fires)

Unnamed: 0,_uid_,id,longitude,latitude,initialdate,days,year,month,1000.0-air,975.0-air,...,350.0-vwnd,300.0-vwnd,275.0-vwnd,250.0-vwnd,225.0-vwnd,200.0-vwnd,175.0-vwnd,150.0-vwnd,125.0-vwnd,100.0-vwnd
0,962781.0,2028104.0,-2.170446e+06,2.677716e+06,2002-12-01,2,2002,12,296.131104,293.962830,...,1.586227,-0.632767,0.159988,2.491455,1.893295,-0.658127,-3.132828,-5.439087,-5.666870,-4.197739
1,923530.0,1948424.0,-2.179160e+06,2.714006e+06,2002-12-01,4,2002,12,295.974854,293.837830,...,1.242477,-0.164017,0.769363,2.725830,2.362045,-0.048752,-2.617203,-5.220337,-5.323120,-3.822739
2,923381.0,1948349.0,-2.078244e+06,3.151850e+06,2002-12-01,3,2002,12,296.646729,294.525330,...,-4.070023,-0.898392,1.472488,2.960205,3.002670,1.451248,-0.335953,-3.532837,-3.979370,-1.353989
3,962781.0,2028104.0,-2.170446e+06,2.677716e+06,2002-12-02,2,2002,12,295.735413,293.567657,...,-1.312195,0.808426,0.911270,2.838562,3.909058,4.740891,4.695358,2.106171,-0.095612,2.609711
4,923530.0,1948424.0,-2.179160e+06,2.714006e+06,2002-12-02,4,2002,12,295.688538,293.520782,...,-1.374695,0.902176,0.817520,2.432312,3.268433,4.115891,4.257858,1.668671,-0.408112,2.359711
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
167056,205533.0,14926058.0,-1.121581e+06,2.836049e+06,2013-03-27,4,2013,3,295.519104,293.545898,...,9.950073,13.267899,13.905899,13.356018,12.767487,13.590591,14.012405,9.711319,1.913452,-0.567387
167057,205534.0,14926059.0,-1.164551e+06,2.836001e+06,2013-03-28,5,2013,3,293.917847,292.141418,...,7.532578,9.355545,11.197037,12.581055,12.312088,12.084091,13.940674,13.135620,9.329224,5.347836
167058,204287.0,14923501.0,-1.265955e+06,2.290768e+06,2013-03-28,4,2013,3,295.324097,293.110168,...,5.188828,6.464920,8.431412,10.987305,14.749588,15.037216,10.362549,4.494995,0.376099,0.582211
167059,205533.0,14926058.0,-1.121581e+06,2.836049e+06,2013-03-28,4,2013,3,293.917847,292.141418,...,7.532578,9.355545,11.197037,12.581055,12.312088,12.084091,13.940674,13.135620,9.329224,5.347836


In [111]:
# vertical profiles at lag-2 of fire events 

folder_path = '/Users/vappaji/Documents/Capstone/dasked/T-4'
csv_files = [f for f in os.listdir(folder_path) if f.endswith('.csv')]

T4 = pd.concat(
    [pd.read_csv(os.path.join(folder_path, file)) for file in csv_files],
    ignore_index=True
)


In [112]:
rename_columns(T4)

Unnamed: 0,_uid_,id,longitude,latitude,initialdate,days,year,month,1000.0-air,975.0-air,...,350.0-vwnd,300.0-vwnd,275.0-vwnd,250.0-vwnd,225.0-vwnd,200.0-vwnd,175.0-vwnd,150.0-vwnd,125.0-vwnd,100.0-vwnd
0,891945.0,16877687.0,-2.052250e+06,2.415330e+06,2014-10-02,4,2014,10,298.292480,296.348145,...,6.110306,7.626801,9.322876,10.173004,7.952545,5.158890,5.785675,0.088181,-2.049866,0.239731
1,891890.0,16877577.0,-1.463973e+06,1.405232e+06,2014-10-06,2,2014,10,299.104553,297.064453,...,6.189728,2.523926,3.669907,3.645828,1.647827,0.066162,-0.031174,1.522186,-2.480713,1.755081
2,836244.0,16754643.0,-1.461894e+06,1.404166e+06,2014-10-02,8,2014,10,299.542480,297.363770,...,8.063431,5.314301,5.729126,5.219879,3.796295,1.862015,0.285675,-7.255569,0.059509,5.161606
3,891941.0,16877683.0,-1.670926e+06,2.831041e+06,2014-10-02,0,2014,10,297.276855,295.301270,...,5.782181,10.658051,13.776001,14.813629,15.390045,12.674515,8.582550,1.931931,-0.690491,-0.900894
4,926007.0,16877682.0,-1.857254e+06,2.920378e+06,2014-10-06,5,2014,10,298.338928,296.548828,...,3.986603,7.336426,11.013657,13.739578,13.069702,12.456787,9.156326,7.600311,5.925537,1.426956
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
167056,1002116.0,16953505.0,-1.706563e+06,1.253436e+06,2014-09-21,11,2014,9,298.787109,296.741943,...,0.504166,-0.201813,0.613770,-0.730286,-2.202560,-2.434326,0.402359,3.550522,0.480240,3.930145
167057,1002117.0,16953506.0,-1.708001e+06,1.251899e+06,2014-09-26,0,2014,9,299.467041,297.241211,...,3.443390,-1.643951,-8.192886,-12.166809,-8.798096,-6.198624,-10.033310,-14.064423,-6.526245,2.225998
167058,1002152.0,16953779.0,-7.099453e+05,1.270970e+06,2014-09-23,2,2014,9,299.325439,297.101807,...,-3.188080,-0.578354,-1.378342,-2.873901,-4.023285,-7.202957,-5.992508,-5.131149,-3.476746,-2.029449
167059,1002118.0,16953509.0,-1.755894e+06,1.191284e+06,2014-09-23,8,2014,9,299.466064,297.258057,...,-3.250580,-1.015854,-1.956467,-2.592651,-4.523285,-7.796707,-4.851883,-5.006149,-3.711121,-2.091949


In [113]:
#combine non-fire condition profiles (for balanced class dist)

all_neg = pd.concat([neg_comp, T4], axis=0, ignore_index=True)
all_neg

Unnamed: 0,_uid_,id,longitude,latitude,initialdate,days,year,month,1000.0-air,975.0-air,...,350.0-shum,300.0-shum,275.0-shum,250.0-shum,225.0-shum,200.0-shum,175.0-shum,150.0-shum,125.0-shum,100.0-shum
0,911487.0,14546434.0,-1.829501e+06,2.793841e+06,2013-01-02,0,2013,1,294.643036,292.469025,...,,,,,,,,,,
1,911774.0,14546481.0,-2.170116e+06,2.658773e+06,2013-01-07,0,2013,1,294.640686,292.790039,...,,,,,,,,,,
2,946008.0,14546352.0,-1.331762e+06,1.275915e+06,2013-01-08,4,2013,1,297.110016,294.878967,...,,,,,,,,,,
3,945832.0,14546500.0,-1.009945e+06,1.745870e+06,2013-01-06,2,2013,1,296.712036,294.741302,...,,,,,,,,,,
4,945776.0,14546487.0,-2.097758e+06,2.615743e+06,2013-01-17,10,2013,1,295.502563,293.332947,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
334117,1002116.0,16953505.0,-1.706563e+06,1.253436e+06,2014-09-21,11,2014,9,298.787109,296.741943,...,,,,,,,,,,
334118,1002117.0,16953506.0,-1.708001e+06,1.251899e+06,2014-09-26,0,2014,9,299.467041,297.241211,...,,,,,,,,,,
334119,1002152.0,16953779.0,-7.099453e+05,1.270970e+06,2014-09-23,2,2014,9,299.325439,297.101807,...,,,,,,,,,,
334120,1002118.0,16953509.0,-1.755894e+06,1.191284e+06,2014-09-23,8,2014,9,299.466064,297.258057,...,,,,,,,,,,


In [114]:
# add class labels 
T4['fire_next_period'] = 1
neg_comp = neg_comp.assign(fire_next_period=0)

combined = pd.concat([T4, neg_comp], ignore_index=True)
combined

Unnamed: 0,_uid_,id,longitude,latitude,initialdate,days,year,month,1000.0-air,975.0-air,...,300.0-vwnd,275.0-vwnd,250.0-vwnd,225.0-vwnd,200.0-vwnd,175.0-vwnd,150.0-vwnd,125.0-vwnd,100.0-vwnd,fire_next_period
0,891945.0,16877687.0,-2.052250e+06,2.415330e+06,2014-10-02,4,2014,10,298.292480,296.348145,...,7.626801,9.322876,10.173004,7.952545,5.158890,5.785675,0.088181,-2.049866,0.239731,1
1,891890.0,16877577.0,-1.463973e+06,1.405232e+06,2014-10-06,2,2014,10,299.104553,297.064453,...,2.523926,3.669907,3.645828,1.647827,0.066162,-0.031174,1.522186,-2.480713,1.755081,1
2,836244.0,16754643.0,-1.461894e+06,1.404166e+06,2014-10-02,8,2014,10,299.542480,297.363770,...,5.314301,5.729126,5.219879,3.796295,1.862015,0.285675,-7.255569,0.059509,5.161606,1
3,891941.0,16877683.0,-1.670926e+06,2.831041e+06,2014-10-02,0,2014,10,297.276855,295.301270,...,10.658051,13.776001,14.813629,15.390045,12.674515,8.582550,1.931931,-0.690491,-0.900894,1
4,926007.0,16877682.0,-1.857254e+06,2.920378e+06,2014-10-06,5,2014,10,298.338928,296.548828,...,7.336426,11.013657,13.739578,13.069702,12.456787,9.156326,7.600311,5.925537,1.426956,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
334117,1015403.0,2028453.0,-9.699338e+05,2.668803e+06,2002-10-19,10,2002,10,297.223389,295.136230,...,0.809494,-0.369461,-1.643112,-2.768219,-2.061630,-1.590317,-4.362610,-4.742157,-4.796143,0
334118,1015387.0,2028460.0,-8.675489e+05,2.509888e+06,2002-10-29,9,2002,10,297.188965,295.116150,...,3.725388,3.678986,3.594803,2.755310,5.925201,5.278214,3.404724,5.478104,6.952393,0
334119,1015387.0,2028460.0,-8.992530e+05,2.508546e+06,2002-10-24,9,2002,10,294.148193,292.283539,...,6.585938,9.110397,13.467072,17.133102,20.856232,22.029144,21.463440,16.519577,7.349997,0
334120,1015387.0,2028460.0,-8.911107e+05,2.483461e+06,2002-10-28,9,2002,10,297.547546,295.347107,...,8.268936,6.632019,3.199814,8.435440,9.696411,9.954269,8.207336,2.081482,4.140839,0


In [115]:
combined = combined.drop(columns=['_uid_','id','initialdate','days'])

### Baseline logistic regression with one lag

In [117]:

X = combined.drop(['fire_next_period'], axis=1)
y = combined['fire_next_period']

# dev-test split
combined_X_dev, combined_X_test, combined_y_dev, combined_y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# training-val split 
combined_X_train, combined_X_val, combined_y_train, combined_y_val = train_test_split(
    combined_X_dev, combined_y_dev, test_size=0.25, random_state=42 
)


In [118]:
#consider alternate imputation options, for now, mean (bc others take too long)  
from sklearn.impute import SimpleImputer

imputer = SimpleImputer(strategy='mean')  
combined_X_train = pd.DataFrame(imputer.fit_transform(combined_X_train), columns=combined_X_train.columns)
combined_X_val = pd.DataFrame(imputer.transform(combined_X_val), columns=combined_X_val.columns)
combined_X_test = pd.DataFrame(imputer.transform(combined_X_test), columns=combined_X_test.columns)

In [119]:
# normalize vertical profiles 
vertical_profile_cols = [col for col in combined.columns if any(x in col for x in ['-air', '-hgt', '-omega', '-shum', '-tke', '-uwnd', '-vwnd'])]
scaler = StandardScaler()
combined_X_train[vertical_profile_cols] = scaler.fit_transform(combined_X_train[vertical_profile_cols])
combined_X_val[vertical_profile_cols] = scaler.transform(combined_X_val[vertical_profile_cols])
combined_X_test[vertical_profile_cols] = scaler.transform(combined_X_test[vertical_profile_cols])


In [120]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, roc_auc_score

model = LogisticRegression(random_state=42, max_iter=1000, class_weight='balanced')  # Increase max_iter if convergence issues
model.fit(combined_X_train, combined_y_train)

# Predict on the validation set
val_predictions = model.predict(combined_X_val)
val_probs = model.predict_proba(combined_X_val)[:, 1]

# Evaluate on validation set
print("Validation Set Performance:")
print(classification_report(combined_y_val, val_predictions))
print(f"Validation ROC-AUC: {roc_auc_score(combined_y_val, val_probs)}")

# Predict on the test set
test_predictions = model.predict(combined_X_test)
test_probs = model.predict_proba(combined_X_test)[:, 1]

# Evaluate on test set
print("\nTest Set Performance:")
print(classification_report(combined_y_test, test_predictions))
print(f"Test ROC-AUC: {roc_auc_score(combined_y_test, test_probs)}")

Validation Set Performance:
              precision    recall  f1-score   support

           0       0.68      0.62      0.65     33326
           1       0.65      0.71      0.68     33499

    accuracy                           0.66     66825
   macro avg       0.67      0.66      0.66     66825
weighted avg       0.67      0.66      0.66     66825

Validation ROC-AUC: 0.7058338042883121

Test Set Performance:
              precision    recall  f1-score   support

           0       0.68      0.62      0.65     33339
           1       0.65      0.71      0.68     33486

    accuracy                           0.66     66825
   macro avg       0.67      0.66      0.66     66825
weighted avg       0.67      0.66      0.66     66825

Test ROC-AUC: 0.7045801998644999


In [104]:
rename_columns(T2)
print(T2.describe())


              _uid_            id     longitude      latitude           days  \
count  1.670610e+05  1.670610e+05  1.670610e+05  1.670610e+05  167061.000000   
mean   5.444790e+05  1.422087e+07 -1.568003e+06  2.276302e+06       8.091170   
std    2.583259e+05  7.774883e+06  4.458747e+05  5.244969e+05       9.901401   
min    1.000000e+00  9.668550e+05 -2.347082e+06  1.015877e+06       0.000000   
25%    3.497530e+05  7.044573e+06 -1.946325e+06  1.913814e+06       2.000000   
50%    5.560590e+05  1.419805e+07 -1.569169e+06  2.388195e+06       5.000000   
75%    7.467020e+05  2.115142e+07 -1.299289e+06  2.730666e+06      10.000000   
max    1.224020e+06  2.708806e+07 -4.462469e+05  3.209308e+06      85.000000   

                year         month     1000.0-air      975.0-air  \
count  167061.000000  167061.00000  167061.000000  167061.000000   
mean     2012.230461       6.07277     296.542004     294.583553   
std         6.478014       2.18580       1.229597       1.190302   
min    

In [102]:
print(T4.describe())


              _uid_            id     longitude      latitude           days  \
count  1.670610e+05  1.670610e+05  1.670610e+05  1.670610e+05  167061.000000   
mean   5.444790e+05  1.422087e+07 -1.568003e+06  2.276302e+06       8.091170   
std    2.583259e+05  7.774883e+06  4.458747e+05  5.244969e+05       9.901401   
min    1.000000e+00  9.668550e+05 -2.347082e+06  1.015877e+06       0.000000   
25%    3.497530e+05  7.044573e+06 -1.946325e+06  1.913814e+06       2.000000   
50%    5.560590e+05  1.419805e+07 -1.569169e+06  2.388195e+06       5.000000   
75%    7.467020e+05  2.115142e+07 -1.299289e+06  2.730666e+06      10.000000   
max    1.224020e+06  2.708806e+07 -4.462469e+05  3.209308e+06      85.000000   

                year         month     1000.0-air      975.0-air  \
count  167061.000000  167061.00000  167061.000000  167061.000000   
mean     2012.230461       6.07277     296.542004     294.583553   
std         6.478014       2.18580       1.229597       1.190302   
min    

In [76]:
T1['lag'] = 1
T1['fire_next_period'] = 1
T2['fire_next_period'] = 1
T2['lag'] = 2
T4['fire_next_period'] = 1
T4['lag'] = 4

In [78]:
temp = pd.concat([T1, T2, T4], ignore_index=True)

In [80]:
temp

Unnamed: 0,_uid_,id,longitude,latitude,initialdate,days,year,month,1000.0-air,975.0-air,...,275.0-vwnd,250.0-vwnd,225.0-vwnd,200.0-vwnd,175.0-vwnd,150.0-vwnd,125.0-vwnd,100.0-vwnd,lag,fire_next_period
0,891945.0,16877687.0,-2.052250e+06,2.415330e+06,2014-10-02,4,2014,10,298.292480,296.348145,...,9.322876,10.173004,7.952545,5.158890,5.785675,0.088181,-2.049866,0.239731,1,1
1,891890.0,16877577.0,-1.463973e+06,1.405232e+06,2014-10-06,2,2014,10,299.104553,297.064453,...,3.669907,3.645828,1.647827,0.066162,-0.031174,1.522186,-2.480713,1.755081,1,1
2,836244.0,16754643.0,-1.461894e+06,1.404166e+06,2014-10-02,8,2014,10,299.542480,297.363770,...,5.729126,5.219879,3.796295,1.862015,0.285675,-7.255569,0.059509,5.161606,1,1
3,891941.0,16877683.0,-1.670926e+06,2.831041e+06,2014-10-02,0,2014,10,297.276855,295.301270,...,13.776001,14.813629,15.390045,12.674515,8.582550,1.931931,-0.690491,-0.900894,1,1
4,926007.0,16877682.0,-1.857254e+06,2.920378e+06,2014-10-06,5,2014,10,298.338928,296.548828,...,11.013657,13.739578,13.069702,12.456787,9.156326,7.600311,5.925537,1.426956,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
501178,1002116.0,16953505.0,-1.706563e+06,1.253436e+06,2014-09-21,11,2014,9,298.787109,296.741943,...,0.613770,-0.730286,-2.202560,-2.434326,0.402359,3.550522,0.480240,3.930145,4,1
501179,1002117.0,16953506.0,-1.708001e+06,1.251899e+06,2014-09-26,0,2014,9,299.467041,297.241211,...,-8.192886,-12.166809,-8.798096,-6.198624,-10.033310,-14.064423,-6.526245,2.225998,4,1
501180,1002152.0,16953779.0,-7.099453e+05,1.270970e+06,2014-09-23,2,2014,9,299.325439,297.101807,...,-1.378342,-2.873901,-4.023285,-7.202957,-5.992508,-5.131149,-3.476746,-2.029449,4,1
501181,1002118.0,16953509.0,-1.755894e+06,1.191284e+06,2014-09-23,8,2014,9,299.466064,297.258057,...,-1.956467,-2.592651,-4.523285,-7.796707,-4.851883,-5.006149,-3.711121,-2.091949,4,1


In [55]:
#duplicate rows for control lags 

def duplicate_non_fire_lags(non_fire_df, lags):
    lagged_data = []
    for lag in lags:
        temp = non_fire_df.copy()
        temp['lag'] = lag
        lagged_data.append(temp)
    return pd.concat(lagged_data, ignore_index=True)

non_fire_lagged = duplicate_non_fire_lags(neg_comp, lags=[1, 2, 4])
non_fire_lagged['fire_next_period'] = 0

In [57]:
non_fire_lagged

Unnamed: 0,_uid_,id,longitude,latitude,initialdate,days,year,month,1000.0-air,975.0-air,...,275.0-shum,250.0-shum,225.0-shum,200.0-shum,175.0-shum,150.0-shum,125.0-shum,100.0-shum,fire_next_period,lag
0,911487.0,14546434.0,-1.829501e+06,2.793841e+06,2013-01-02,0,2013,1,294.643036,292.469025,...,,,,,,,,,0,1
1,911774.0,14546481.0,-2.170116e+06,2.658773e+06,2013-01-07,0,2013,1,294.640686,292.790039,...,,,,,,,,,0,1
2,946008.0,14546352.0,-1.331762e+06,1.275915e+06,2013-01-08,4,2013,1,297.110016,294.878967,...,,,,,,,,,0,1
3,945832.0,14546500.0,-1.009945e+06,1.745870e+06,2013-01-06,2,2013,1,296.712036,294.741302,...,,,,,,,,,0,1
4,945776.0,14546487.0,-2.097758e+06,2.615743e+06,2013-01-17,10,2013,1,295.502563,293.332947,...,,,,,,,,,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
501178,1015403.0,2028453.0,-9.699338e+05,2.668803e+06,2002-10-19,10,2002,10,297.223389,295.136230,...,0.000340,0.000207,0.000111,0.000062,0.000031,0.000013,0.000005,0.000005,0,4
501179,1015387.0,2028460.0,-8.675489e+05,2.509888e+06,2002-10-29,9,2002,10,297.188965,295.116150,...,0.000191,0.000140,0.000100,0.000059,0.000030,0.000012,0.000007,0.000006,0,4
501180,1015387.0,2028460.0,-8.992530e+05,2.508546e+06,2002-10-24,9,2002,10,294.148193,292.283539,...,0.000175,0.000112,0.000090,0.000062,0.000035,0.000011,0.000005,0.000006,0,4
501181,1015387.0,2028460.0,-8.911107e+05,2.483461e+06,2002-10-28,9,2002,10,297.547546,295.347107,...,0.000251,0.000219,0.000131,0.000064,0.000026,0.000014,0.000006,0.000005,0,4


In [82]:
combined_lags =  pd.concat([temp, non_fire_lagged], ignore_index=True)
combined_lags

Unnamed: 0,_uid_,id,longitude,latitude,initialdate,days,year,month,1000.0-air,975.0-air,...,275.0-vwnd,250.0-vwnd,225.0-vwnd,200.0-vwnd,175.0-vwnd,150.0-vwnd,125.0-vwnd,100.0-vwnd,lag,fire_next_period
0,891945.0,16877687.0,-2.052250e+06,2.415330e+06,2014-10-02,4,2014,10,298.292480,296.348145,...,9.322876,10.173004,7.952545,5.158890,5.785675,0.088181,-2.049866,0.239731,1,1
1,891890.0,16877577.0,-1.463973e+06,1.405232e+06,2014-10-06,2,2014,10,299.104553,297.064453,...,3.669907,3.645828,1.647827,0.066162,-0.031174,1.522186,-2.480713,1.755081,1,1
2,836244.0,16754643.0,-1.461894e+06,1.404166e+06,2014-10-02,8,2014,10,299.542480,297.363770,...,5.729126,5.219879,3.796295,1.862015,0.285675,-7.255569,0.059509,5.161606,1,1
3,891941.0,16877683.0,-1.670926e+06,2.831041e+06,2014-10-02,0,2014,10,297.276855,295.301270,...,13.776001,14.813629,15.390045,12.674515,8.582550,1.931931,-0.690491,-0.900894,1,1
4,926007.0,16877682.0,-1.857254e+06,2.920378e+06,2014-10-06,5,2014,10,298.338928,296.548828,...,11.013657,13.739578,13.069702,12.456787,9.156326,7.600311,5.925537,1.426956,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1002361,1015403.0,2028453.0,-9.699338e+05,2.668803e+06,2002-10-19,10,2002,10,297.223389,295.136230,...,-0.369461,-1.643112,-2.768219,-2.061630,-1.590317,-4.362610,-4.742157,-4.796143,4,0
1002362,1015387.0,2028460.0,-8.675489e+05,2.509888e+06,2002-10-29,9,2002,10,297.188965,295.116150,...,3.678986,3.594803,2.755310,5.925201,5.278214,3.404724,5.478104,6.952393,4,0
1002363,1015387.0,2028460.0,-8.992530e+05,2.508546e+06,2002-10-24,9,2002,10,294.148193,292.283539,...,9.110397,13.467072,17.133102,20.856232,22.029144,21.463440,16.519577,7.349997,4,0
1002364,1015387.0,2028460.0,-8.911107e+05,2.483461e+06,2002-10-28,9,2002,10,297.547546,295.347107,...,6.632019,3.199814,8.435440,9.696411,9.954269,8.207336,2.081482,4.140839,4,0


In [84]:
combined_lags = pd.get_dummies(combined_lags, columns=['lag'], drop_first=True)


In [86]:
combined_lags

Unnamed: 0,_uid_,id,longitude,latitude,initialdate,days,year,month,1000.0-air,975.0-air,...,250.0-vwnd,225.0-vwnd,200.0-vwnd,175.0-vwnd,150.0-vwnd,125.0-vwnd,100.0-vwnd,fire_next_period,lag_2,lag_4
0,891945.0,16877687.0,-2.052250e+06,2.415330e+06,2014-10-02,4,2014,10,298.292480,296.348145,...,10.173004,7.952545,5.158890,5.785675,0.088181,-2.049866,0.239731,1,False,False
1,891890.0,16877577.0,-1.463973e+06,1.405232e+06,2014-10-06,2,2014,10,299.104553,297.064453,...,3.645828,1.647827,0.066162,-0.031174,1.522186,-2.480713,1.755081,1,False,False
2,836244.0,16754643.0,-1.461894e+06,1.404166e+06,2014-10-02,8,2014,10,299.542480,297.363770,...,5.219879,3.796295,1.862015,0.285675,-7.255569,0.059509,5.161606,1,False,False
3,891941.0,16877683.0,-1.670926e+06,2.831041e+06,2014-10-02,0,2014,10,297.276855,295.301270,...,14.813629,15.390045,12.674515,8.582550,1.931931,-0.690491,-0.900894,1,False,False
4,926007.0,16877682.0,-1.857254e+06,2.920378e+06,2014-10-06,5,2014,10,298.338928,296.548828,...,13.739578,13.069702,12.456787,9.156326,7.600311,5.925537,1.426956,1,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1002361,1015403.0,2028453.0,-9.699338e+05,2.668803e+06,2002-10-19,10,2002,10,297.223389,295.136230,...,-1.643112,-2.768219,-2.061630,-1.590317,-4.362610,-4.742157,-4.796143,0,False,True
1002362,1015387.0,2028460.0,-8.675489e+05,2.509888e+06,2002-10-29,9,2002,10,297.188965,295.116150,...,3.594803,2.755310,5.925201,5.278214,3.404724,5.478104,6.952393,0,False,True
1002363,1015387.0,2028460.0,-8.992530e+05,2.508546e+06,2002-10-24,9,2002,10,294.148193,292.283539,...,13.467072,17.133102,20.856232,22.029144,21.463440,16.519577,7.349997,0,False,True
1002364,1015387.0,2028460.0,-8.911107e+05,2.483461e+06,2002-10-28,9,2002,10,297.547546,295.347107,...,3.199814,8.435440,9.696411,9.954269,8.207336,2.081482,4.140839,0,False,True


In [90]:
from sklearn.model_selection import GroupShuffleSplit

# Define features (X) and target (y)
X = combined_lags.drop(['fire_next_period'], axis=1)
y = combined_lags['fire_next_period']

# Add group identifiers (e.g., `_uid_` and `id`) for grouping
groups = combined_lags['_uid_'].astype(str) + '_' + combined_lags['id'].astype(str)


# Initialize GroupShuffleSplit for dev-test split
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
dev_idx, test_idx = next(gss.split(X, y, groups=groups))

# Split into dev and test sets
combined_X_dev = X.iloc[dev_idx]
combined_X_test = X.iloc[test_idx]
combined_y_dev = y.iloc[dev_idx]
combined_y_test = y.iloc[test_idx]

# Perform training-validation split within the dev set
gss_dev = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
train_idx, val_idx = next(gss_dev.split(combined_X_dev, combined_y_dev, groups=groups.iloc[dev_idx]))

combined_X_train = combined_X_dev.iloc[train_idx]
combined_X_val = combined_X_dev.iloc[val_idx]
combined_y_train = combined_y_dev.iloc[train_idx]
combined_y_val = combined_y_dev.iloc[val_idx]


In [94]:
vertical_profile_cols = [col for col in combined.columns if any(x in col for x in ['-air', '-hgt', '-omega', '-shum', '-tke', '-uwnd', '-vwnd'])]
scaler = StandardScaler()
combined_X_train.loc[:, vertical_profile_cols] = scaler.fit_transform(combined_X_train[vertical_profile_cols])
combined_X_val.loc[:, vertical_profile_cols] = scaler.transform(combined_X_val[vertical_profile_cols])
combined_X_test.loc[:, vertical_profile_cols] = scaler.transform(combined_X_test[vertical_profile_cols])


In [96]:
from sklearn.impute import KNNImputer
knn_imputer = KNNImputer(n_neighbors=5, weights="uniform") 
# Apply the imputer to the training set
combined_X_train[vertical_profile_cols] = knn_imputer.fit_transform(combined_X_train[vertical_profile_cols])

# Apply the trained imputer to validation and test sets
combined_X_val[vertical_profile_cols] = knn_imputer.transform(combined_X_val[vertical_profile_cols])
combined_X_test[vertical_profile_cols] = knn_imputer.transform(combined_X_test[vertical_profile_cols])


KeyboardInterrupt: 

In [None]:
model = LogisticRegression(random_state=42, max_iter=1000, class_weight='balanced')  # Increase max_iter if convergence issues
model.fit(combined_X_train, combined_y_train)

# Predict on the validation set
val_predictions = model.predict(combined_X_val)
val_probs = model.predict_proba(combined_X_val)[:, 1]

# Evaluate on validation set
print("Validation Set Performance:")
print(classification_report(combined_y_val, val_predictions))
print(f"Validation ROC-AUC: {roc_auc_score(combined_y_val, val_probs)}")

# Predict on the test set
test_predictions = model.predict(combined_X_test)
test_probs = model.predict_proba(combined_X_test)[:, 1]

# Evaluate on test set
print("\nTest Set Performance:")
print(classification_report(combined_y_test, test_predictions))
print(f"Test ROC-AUC: {roc_auc_score(combined_y_test, test_probs)}")