In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import xgboost as xgb
import sklearn
import geopandas as gpd
import matplotlib.pyplot as plt
import subprocess
import sys
import seaborn as sns

import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os 
import pandas as pd


SEED = 42

def manual_seed(seed):
    """Make a run reproducible.

    Seeds NumPy's legacy global RNG, Python's ``random`` module and PyTorch
    (CPU plus the current and all CUDA devices), then forces deterministic
    cuDNN settings. cuDNN is disabled entirely, trading GPU speed for
    reproducibility.
    """
    # Same seeding order as before: numpy, random, torch CPU, current GPU, all GPUs.
    for seeder in (np.random.seed, random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.enabled = False
    cudnn.benchmark = False
    cudnn.deterministic = True

manual_seed(SEED)

# Load Data & Preprocessing

In [None]:
# Raw merged station table. preprocess_data reindexes each station by `time`,
# so it presumably holds at most one row per station per day — TODO confirm.
data = pd.read_csv('../input/btlaionkk/data_onkk_merged.csv')
data.head()

In [None]:
import math
from tqdm import tqdm

def preprocess_data(df):
    """Build the daily, feature-engineered table used by every model in this notebook.

    1. Reindex each station onto one shared daily calendar covering the whole
       dataset, so missing days become explicit NaN rows.
    2. Add calendar, wind, temperature/humidity, pressure and distance features.
    3. Add previous-day / next-day copies of the satellite columns, per station.

    Args:
        df (pd.DataFrame): raw data with at least ``time``, ``ID``, ``pm25``,
            ``lat``, ``lon``, ``WDIR``, ``WSPD``, ``TMP``, ``TX``, ``TN``,
            ``RH`` and ``PRES2M`` columns. It is not modified.

    Returns:
        pd.DataFrame: one row per station per calendar day. ``ID``, ``lat`` and
        ``lon`` are filled with the per-station ``nanmean`` (so ``ID`` becomes
        float), and the integer index restarts at 0 for every station.
    """
    df = df.copy()
    # Parse timestamps once, before taking the calendar bounds, so min/max are
    # chronological rather than lexicographic on raw strings.
    df['time'] = pd.to_datetime(df['time'])
    full_dates = pd.date_range(start=df['time'].min(), end=df['time'].max(), freq='D')

    ### 1. Per-station daily reindexing
    station_frames = []
    for station_id in tqdm(df.ID.unique()):
        station_data = df[df.ID == station_id]
        if station_data.empty:
            continue
        daily = (station_data.set_index('time')
                 .reindex(full_dates)
                 .rename_axis('time')
                 .reset_index())
        # NOTE: despite its name this is the day-over-day difference, not a lagged value.
        daily['pm25_lag1'] = daily['pm25'] - daily['pm25'].shift(1)
        # Days added by the reindex have NaN station attributes; use the station mean.
        for col in ('lat', 'lon', 'ID'):
            daily[col] = np.nanmean(daily[col].values)
        station_frames.append(daily)
    df = pd.concat(station_frames, axis=0)

    ### 2. Row-wise (time-independent) features
    wdir_rad = np.radians(df['WDIR'])
    df['WDIR_x'] = np.cos(wdir_rad)
    df['WDIR_y'] = np.sin(wdir_rad)
    df['day_of_year'] = df['time'].dt.dayofyear
    df['sin_day'] = np.sin(2 * np.pi * df['day_of_year'] / 365)
    df['cos_day'] = np.cos(2 * np.pi * df['day_of_year'] / 365)
    df['wind_u'] = df['WSPD'] * np.cos(wdir_rad)
    df['wind_v'] = df['WSPD'] * np.sin(wdir_rad)
    df['temp_range'] = df['TX'] - df['TN']
    df['day_of_week'] = df['time'].dt.dayofweek
    df['month'] = df['time'].dt.month
    # Season codes: spring (Mar-May)=1, summer=2, autumn=3, winter (Dec-Feb)=4.
    season_by_month = {12: 4, 1: 4, 2: 4, 3: 1, 4: 1, 5: 1,
                       6: 2, 7: 2, 8: 2, 9: 3, 10: 3, 11: 3}
    df['season'] = df['month'].map(season_by_month).astype(int)
    df['is_weekend'] = (df['day_of_week'] >= 5).astype(int)
    # Simple TMP x RH interaction term (not the NWS heat-index formula).
    df['heat_index'] = df['TMP'] * df['RH']

    # Dew point via the Magnus formula, vectorised over all rows.
    a, b = 17.27, 237.7
    gamma = np.log(df['RH'] / 100.0) + (a * df['TMP']) / (b + df['TMP'])
    df['dew_point'] = (b * gamma) / (a - gamma)

    # Great-circle (haversine) distance in km from each station to Hanoi, vectorised.
    hanoi_lat, hanoi_lon = np.radians(21.0278), np.radians(105.8342)
    lat1, lon1 = np.radians(df['lat']), np.radians(df['lon'])
    dlat, dlon = hanoi_lat - lat1, hanoi_lon - lon1
    hav = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(hanoi_lat) * np.sin(dlon / 2) ** 2
    earth_radius_km = 6371
    df['distance_to_hanoi'] = 2 * np.arcsin(np.sqrt(hav)) * earth_radius_km

    df['temp_wind'] = df['TMP'] * df['WSPD']
    df['rh_pressure'] = df['RH'] * df['PRES2M']
    df['wspd_squared'] = df['WSPD'] ** 2

    ### 3. Per-station previous/next-day copies of the satellite columns
    lagged_features = {
        'CO_column_number_density', 'Cloud', 'NO2_column_number_density',
        'O3_column_number_density', 'absorbing_aerosol_index',
    }
    station_frames = []
    for station_id in tqdm(df.ID.unique()):
        station_data = df[df.ID == station_id]
        if station_data.empty:
            continue
        # Moves 'time' back to the first column and gives the station a fresh RangeIndex.
        daily = station_data.set_index('time').reset_index()
        # Snapshot the matching columns first; new *_prev1/*_next1 columns are not revisited.
        for ft_name in [col for col in daily.columns if col in lagged_features]:
            if daily[ft_name].dtype not in ['float', 'int']:
                continue
            daily[f'{ft_name}_prev1'] = daily[ft_name].shift(1)
            daily[f'{ft_name}_next1'] = daily[ft_name].shift(-1)
        station_frames.append(daily)

    return pd.concat(station_frames, axis=0)

def create_timeseries_data(df, window_size=16, show_tqdm=True,
                           feature_names=None, fitted_scaler=None):
    """Sliding-window forecasting samples: ``window_size`` days in, the next day out.

    Each station is reindexed onto its own daily calendar (between its first and
    last date), scaled, and cut into windows. A window is dropped if any input
    or target value is NaN.

    Args:
        df (pd.DataFrame): processed data with ``ID``, ``time`` and all feature columns.
        window_size (int): number of input days per sample.
        show_tqdm (bool): show a progress bar over stations.
        feature_names (list[str] | None): feature columns; defaults to the
            notebook-level ``features``. Must contain ``'pm25'``.
        fitted_scaler: object with ``transform``; defaults to the
            notebook-level ``scaler``.

    Returns:
        (X, y): X has shape (n_samples, window_size, n_features) in
        ``feature_names`` order. y has shape (n_samples, n_features) with pm25
        moved to column 0 and the other features following in their original
        order. Both have 0 rows when no valid window exists.
    """
    if feature_names is None:
        feature_names = features  # notebook-level global
    if fitted_scaler is None:
        fitted_scaler = scaler  # notebook-level global
    feature_names = list(feature_names)
    n_features = len(feature_names)

    # Target column order: pm25 first, remaining features after it.
    pm25_idx = feature_names.index('pm25')
    target_order = [pm25_idx] + [j for j in range(n_features) if j != pm25_idx]

    Xs, ys = [], []
    station_ids = df.ID.unique()
    if show_tqdm:
        station_ids = tqdm(station_ids)
    for station_id in station_ids:
        station_data = df[df.ID == station_id].copy()
        if len(station_data) == 0:
            continue
        station_data['time'] = pd.to_datetime(station_data['time'])
        full_dates = pd.date_range(start=station_data['time'].min(), end=station_data['time'].max(), freq='D')
        station_data_daily = station_data.set_index('time').reindex(full_dates).rename_axis('time').reset_index()

        # The scaler works row by row, so scaling the whole station once is
        # equivalent to scaling every window, and avoids rescaling each day window_size times.
        scaled = fitted_scaler.transform(station_data_daily[feature_names].values)
        for i in range(len(scaled) - window_size):
            currX = scaled[i:i + window_size]
            curry = scaled[i + window_size][target_order]
            if np.isnan(np.sum(currX)) or np.isnan(np.sum(curry)):
                continue
            Xs.append(currX)
            ys.append(curry)

    if not Xs:
        # np.stack([]) raises; return correctly-shaped empty arrays instead.
        return np.empty((0, window_size, n_features)), np.empty((0, n_features))
    return np.stack(Xs), np.stack(ys)

def create_timeseries_data_missing(df, window_size=8, show_tqdm=True,
                                   feature_names=None, fitted_scaler=None):
    """Stack all stations of ``df`` into one dense, scaled array, keeping NaNs.

    Rows are taken in the order they appear in ``df`` (no reindexing). Every
    station must therefore have the same number of rows, which
    ``preprocess_data`` guarantees by reindexing onto a shared calendar.

    Args:
        df (pd.DataFrame): processed data for one split.
        window_size (int): unused; kept for backward compatibility.
        show_tqdm (bool): show a progress bar over stations.
        feature_names (list[str] | None): feature columns; defaults to the
            notebook-level ``features``.
        fitted_scaler: object with ``transform``; defaults to the
            notebook-level ``scaler``.

    Returns:
        np.ndarray of shape (T, S, N) = (rows per station, stations, features).
        The station axis follows ``df.ID.unique()`` order, and NaN marks missing
        values. Shape is (0, 0, N) when ``df`` has no stations.
    """
    if feature_names is None:
        feature_names = features  # notebook-level global
    if fitted_scaler is None:
        fitted_scaler = scaler  # notebook-level global
    feature_names = list(feature_names)

    # Iterate the stations of *this* frame (previously read the global raw `data`).
    station_ids = df.ID.unique()
    if show_tqdm:
        station_ids = tqdm(station_ids)
    per_station = []
    for station_id in station_ids:
        station_data = df[df.ID == station_id]
        if len(station_data) == 0:
            continue
        scaled = fitted_scaler.transform(station_data[feature_names].values)
        per_station.append(scaled[:, None])  # (T, 1, N)

    if not per_station:
        return np.empty((0, 0, len(feature_names)))
    return np.concatenate(per_station, axis=1)

In [None]:
# Daily-reindexed, feature-engineered frame: one row per station per calendar day.
data_processed = preprocess_data(data)
data_processed.shape

In [None]:
# Chronological split (no shuffling, so no look-ahead between splits):
#   train: time <  val_date
#   val:   val_date <= time < test_date
#   test:  time >= test_date
val_date = '2021-06-01'
test_date = '2021-08-01'
time_col = data_processed['time']
train = data_processed[time_col < val_date]
val = data_processed[(time_col >= val_date) & (time_col < test_date)]
test = data_processed[time_col >= test_date]


In [None]:
data_processed.columns  # all engineered columns available for feature selection

In [None]:
# Model input features, in column order. The order matters:
#   - 'pm25' must stay first: ImputeBilevelDataset checks mask[..., 0] as pm25.
#   - 'lat'/'lon' sit at positions 1 and 2, and 'sin_day'/'cos_day' at 3 and 4;
#     later cells index them by position.
# Commented-out names are engineered features that exist but are excluded.
features = [
    ### Features selection
    # 'time', 'ID',
    'pm25',
    'lat', 'lon',
    'sin_day', 'cos_day',
    'SQRT_SEA_DEM_LAT', 'WSPD',
    'WDIR',
    'TMP',
    'TX', 'TN', 'TP', 'RH', 'PRES2M',
    # 'pm25_lag1',
    # 'WDIR_x', 'WDIR_y',
    # 'day_of_year',
    'wind_u', 'wind_v',
    # 'temp_range',
    # 'day_of_week', 'month', 'season', 'is_weekend',
    'heat_index', 'dew_point', 'distance_to_hanoi', 'temp_wind',
    'rh_pressure',
    # 'wspd_squared',
    'CO_column_number_density', 'Cloud', 'NO2_column_number_density',
    'O3_column_number_density', 'absorbing_aerosol_index',
]

### Fit one shared scaler on the training split only, so val/test statistics
### do not leak into scaling. MinMaxScaler ignores NaNs when fitting.
scaler = MinMaxScaler()
scaler.fit(train[features].values)

In [None]:
# Alternative feature subset kept for comparison only; nothing below trains on it.
# The two prints list the names that differ between `features` and this subset.
selected_features = [
    "pm25",
    
    "SQRT_SEA_DEM_LAT",
    
    "TN", "dew_point", "heat_index", "TMP", "sin_day", "PRES2M",
    "distance_to_hanoi", "temp_wind", "cos_day", "TP", "TX", "wind_u",
    "rh_pressure",

    'CO_column_number_density', 'Cloud', 'NO2_column_number_density',
    'O3_column_number_density', 'absorbing_aerosol_index',
]

print([ft for ft in features if ft not in selected_features])  # in features, not in subset
print([ft for ft in selected_features if ft not in features])  # in subset, not in features

In [None]:
# Columns passed to the model through args.pos_embed_indices (presumably used
# for location / season embeddings inside MAE). Look them up by name so the
# indices stay correct if `features` is reordered.
POS_EMBED_FEATURES = ['lat', 'lon', 'sin_day', 'cos_day']
pos_embed_indices = [features.index(name) for name in POS_EMBED_FEATURES]
np.array(features)[pos_embed_indices]

# PyTorch Dataset

In [None]:
# For every station, build a sampling group made of the station itself plus its
# nearest neighbours (squared Euclidean distance on the scaled lat/lon features).
# ImputeDataset draws the stations of each window from one of these groups.
NUM_SAMPLING_NEIGHBORS = 9

X_train_np = create_timeseries_data_missing(train, show_tqdm=False)
# lat/lon are constant per station (preprocess_data fills them with the station
# mean), so the first time step is enough. Mixing an int, a slice and a list in
# one index puts the list axis first, giving shape (2, S).
latlon_idx = [features.index('lat'), features.index('lon')]
station_locs = X_train_np[0, :, latlon_idx]
coords = station_locs.T                                        # (S, 2)
num_stations = len(coords)

diff = coords[:, np.newaxis, :] - coords[np.newaxis, :, :]     # (S, S, 2)
sqdist = np.sum(diff ** 2, axis=2)                             # (S, S)
np.fill_diagonal(sqdist, np.inf)                               # a station is not its own neighbour

nearest_stations = np.argsort(sqdist, axis=1)[:, :NUM_SAMPLING_NEIGHBORS].tolist()  # (S, k)
# Derive the station count from the data instead of hard-coding 26.
stations_sampling = [[i] + nearest_stations[i] for i in range(num_stations)]

In [None]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm
import random

class TimeSeriesDataset(Dataset):
    """Pre-batched forecasting windows drawn from several window sizes.

    Every item is already a whole batch: ``batch_size`` windows sampled without
    replacement from one randomly chosen window size. The item index only
    bounds the epoch length; it does not select particular windows.
    """

    def __init__(self, df, window_sizes=(16, 32, 64), batch_size=16):
        """
        Args:
            df (pd.DataFrame): processed data passed to ``create_timeseries_data``.
            window_sizes (Sequence[int]): window lengths to precompute. The default
                is a tuple to avoid a shared mutable default; any sequence works.
            batch_size (int): windows per item.
        """
        self.df = df
        self.window_sizes = list(window_sizes)
        self.batch_size = batch_size
        self.data_dict = {}

        # Precompute (X, y) once per window size.
        for ws in tqdm(self.window_sizes):
            X, y = create_timeseries_data(df, window_size=ws, show_tqdm=False)
            self.data_dict[ws] = {'X': X, 'y': y}

    def __len__(self):
        # Number of full batches across all window sizes.
        total = sum(len(self.data_dict[ws]['X']) for ws in self.window_sizes)
        return total // self.batch_size

    def __getitem__(self, dummy_index):
        # Valid indices are 0 .. len-1. IndexError is what Python's sequence
        # protocol (and plain `for` iteration) expects at the end.
        if dummy_index >= len(self):
            raise IndexError(dummy_index)
        ws = random.choice(self.window_sizes)
        windows = self.data_dict[ws]
        sample_idx = random.sample(range(len(windows['X'])), k=self.batch_size)
        return {
            'X': windows['X'][sample_idx],
            'y': windows['y'][sample_idx],
        }


class ImputeDataset(Dataset):
    """Spatio-temporal crops for masked-imputation training and evaluation.

    The split is turned into one dense array ``X`` of shape (T, S, N)
    (time, station, feature) by ``create_timeseries_data_missing``. Each sample
    is a ``time_window_size`` x ``station_window_size`` crop. Its stations are
    drawn from one group of the notebook-level ``stations_sampling`` list.
    """

    def __init__(self, df, time_window_size=7, station_window_size=5, sampling_size=10, additional_mask_probs=0.3):
        """
        Args:
            df (pd.DataFrame): processed data for one split.
            time_window_size (int): consecutive days per sample.
            station_window_size (int): stations per sample.
            sampling_size (int): random station groups tried per window start.
            additional_mask_probs (float): probability of additionally hiding
                each observed entry. Observed-but-hidden entries are the
                reconstruction targets scored by ``eval_model``.
        """
        self.df = df
        self.time_window_size = time_window_size
        self.station_window_size = station_window_size
        self.sampling_size = sampling_size

        self.X = create_timeseries_data_missing(df, show_tqdm=False)
        # True where a value was actually observed.
        self.mask = ~np.isnan(self.X)
        # Observed entries left visible to the model; mask & ~additional_mask are hidden targets.
        self.additional_mask = (np.random.rand(*self.mask.shape) > additional_mask_probs) & self.mask
        # Zero-fill missing entries; the masks record where real data is.
        self.X[~self.mask] = 0.0

        self.sample_indices = []
        self._prepare()

    def _prepare(self):
        """Collect (time_indices, station_indices) crops that are at least half observed."""
        n_steps = self.X.shape[0]
        # Starts run 0 .. n_steps - time_window_size inclusive, so the final
        # window (ending on the last day) is included.
        for start in range(n_steps - self.time_window_size + 1):
            time_indices = np.arange(start, start + self.time_window_size)
            for _ in range(self.sampling_size):
                station_indices = random.sample(
                    random.choice(stations_sampling), k=self.station_window_size)
                if self.mask[np.ix_(time_indices, station_indices)].mean() < 0.5:
                    continue
                self.sample_indices.append([time_indices, station_indices])

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, i):
        """Return float32 tensors (values, observed mask, visible mask), each (time, station, feature)."""
        time_indices, station_indices = self.sample_indices[i]
        crop = np.ix_(time_indices, station_indices)
        return tuple(
            torch.tensor(array[crop], dtype=torch.float32)
            for array in (self.X, self.mask, self.additional_mask)
        )



class ImputeBilevelDataset(ImputeDataset):
    """``ImputeDataset`` variant used for the bilevel (teacher/student) stage.

    It has two differences from ``ImputeDataset``:
      - no extra masking by default (``additional_mask_probs=0.0``);
      - a crop is kept only if feature 0 (pm25, the first entry of ``features``)
        is observed for every sampled station on the crop's last day, in
        addition to the crop being at least half observed.

    Everything else (array construction, ``__len__``, ``__getitem__``) is
    inherited instead of copy-pasted.
    """

    def __init__(self, df, time_window_size=7, station_window_size=5, sampling_size=10, additional_mask_probs=0.0):
        """
        Args:
            df (pd.DataFrame): processed data for one split.
            time_window_size (int): consecutive days per sample.
            station_window_size (int): stations per sample.
            sampling_size (int): random station groups tried per window start.
            additional_mask_probs (float): probability of additionally hiding
                each observed entry (0.0 here by default).
        """
        super().__init__(df, time_window_size, station_window_size,
                         sampling_size, additional_mask_probs)

    def _prepare(self):
        """Collect crops whose last day has pm25 at every station and that are at least half observed."""
        n_steps = self.X.shape[0]
        # Starts run 0 .. n_steps - time_window_size inclusive, so the final window is included.
        for start in range(n_steps - self.time_window_size + 1):
            time_indices = np.arange(start, start + self.time_window_size)
            for _ in range(self.sampling_size):
                station_indices = random.sample(
                    random.choice(stations_sampling), k=self.station_window_size)
                mask_batch = self.mask[np.ix_(time_indices, station_indices)]
                if mask_batch[-1, :, 0].mean() < 1:
                    continue
                if mask_batch.mean() < 0.5:
                    continue
                self.sample_indices.append([time_indices, station_indices])



# Import MAE

In [None]:
import sys

# Make the local model code importable; the next cell imports `mae` and
# `bilevel_impute`, presumably located in ../models/.
sys.path.append('../models/')

In [None]:
from mae import MAE, LSTMStudent
from bilevel_impute import BiImpute

# Training Scripts

In [None]:
from sklearn.metrics import *
from copy import deepcopy
import time



def reset_all_parameters(module):
    """Re-initialise ``module`` and all of its submodules that define ``reset_parameters``.

    Traversal is ``nn.Module.apply`` (children first, then the module itself).
    Returns None.
    """
    def _reset_if_supported(submodule):
        if not hasattr(submodule, 'reset_parameters'):
            return
        submodule.reset_parameters()

    module.apply(_reset_if_supported)


def eval_model(model, data_loader, device=None):
    """Masked-reconstruction MSE of ``model`` on ``data_loader`` (lower is better).

    Each item is ``(batch, mask, visible_mask)`` with a leading DataLoader
    dimension of size 1. The model imputes from the visible entries
    (``model.forward_impute(batch, visible_mask)``). The squared error is
    averaged over entries that were observed (``mask``) but hidden from the
    model.

    A deep copy of the model is evaluated, so the caller's model keeps its
    train/eval mode and device.

    Args:
        model: module exposing ``forward_impute(batch, visible_mask)``.
        data_loader: iterable of ``(batch, mask, visible_mask)`` tensors.
        device: where to evaluate; defaults to the notebook-level ``args.device``.

    Returns:
        The mean of the per-item losses (NaN if the loader is empty).
    """
    if device is None:
        device = args.device
    model = deepcopy(model).eval().to(device)
    losses = []
    # Evaluation needs no gradients; skipping autograd avoids building graphs.
    with torch.no_grad():
        for batch, mask_batch, additional_mask_batch in data_loader:
            # Drop the DataLoader's batch dimension (batch_size=1).
            batch = batch[0].to(device)
            mask_batch = mask_batch[0].to(device)
            additional_mask_batch = additional_mask_batch[0].to(device)

            # Observed entries that were hidden from the model: the scoring targets.
            rec_mask = mask_batch - additional_mask_batch

            data_pred = model.forward_impute(batch, additional_mask_batch)
            loss = ((data_pred - batch) ** 2 * rec_mask).sum() / (rec_mask.sum() + 1e-8)
            losses.append(loss.item())

    return np.mean(losses)


def train_bi(models, optimizers, train_loader, train_bilevel_loader, val_bilevel_loader, val_loader):
    """Bilevel teacher/student training loop driven by ``BiImpute.step_fn``.

    Each of ``args.num_train_steps`` iterations draws one item from each of
    ``train_loader``, ``train_bilevel_loader`` and ``val_bilevel_loader``.
    Loaders are restarted when exhausted. The (values, observed-mask) pairs are
    handed to ``BiImpute.step_fn``, which performs the parameter updates
    (semantics defined in bilevel_impute).

    Every 100 iterations the teacher is scored on ``val_loader`` with
    ``eval_model`` (score = negative reconstruction loss). On improvement, the
    teacher weights are saved to ``best_model.pt`` and the test loss is
    printed. NOTE(review): this reads the notebook-level ``test_loader`` and
    peeks at the test split during training.

    Every 50000 iterations the student's parameters and optimizer state are reset.

    Args:
        models: ``(teacher, student)``.
        optimizers: ``(teacher_optimizer, student_optimizer)``.
        train_loader, train_bilevel_loader, val_bilevel_loader: sources of
            ``(values, mask, visible_mask)`` items; only the first two elements are used.
        val_loader: evaluation loader for ``eval_model``.

    Returns:
        ``(teacher, best_eval_score)``. NOTE(review): ``teacher`` holds the
        weights from the last iteration, not the best ones. The best weights
        exist only in ``best_model.pt``; ``best_ckpt_state_dict`` is never used.
    """
    torch.cuda.empty_cache()
    losses = []  # NOTE(review): unused
    # -9999.0 is the "not evaluated yet" sentinel.
    prev_eval_score = -9999.0
    best_eval_score = -9999.0
    scores_list = []  # NOTE(review): unused
    args.global_step = 0
    
    teacher, student = models
    teacher_optimizer, student_optimizer = optimizers
    bi = BiImpute()
    
    # Endless iterators: each helper restarts its loader when it runs out.
    train_iter = iter(train_loader)
    def get_next_train_sample():
        nonlocal train_iter
        try: return next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            return next(train_iter)
    
    train_bilevel_iter = iter(train_bilevel_loader)
    def get_next_train_bi_sample():
        nonlocal train_bilevel_iter
        try: return next(train_bilevel_iter)
        except StopIteration:
            train_bilevel_iter = iter(train_bilevel_loader)
            return next(train_bilevel_iter)
    
    val_bilevel_iter = iter(val_bilevel_loader)
    def get_next_val_bi_sample():
        nonlocal val_bilevel_iter
        try: return next(val_bilevel_iter)
        except StopIteration:
            val_bilevel_iter = iter(val_bilevel_loader)
            return next(val_bilevel_iter)
        
    pbar = tqdm(range(args.num_train_steps), position=0, leave=True)
    for iteration in pbar:
        # Re-set train mode every step (eval_model works on a deep copy, so this is defensive).
        teacher.train().to(args.device)
        student.train().to(args.device)

        # The third element (visible mask) is not used in this stage.
        t_batch, t_mask, _ = get_next_train_sample()
        t_batch_bi, t_mask_bi, _ = get_next_train_bi_sample()
        v_batch, v_mask, _ = get_next_val_bi_sample()
        # [0] drops the DataLoader batch dimension (batch_size=1).
        t_batch = t_batch[0].to(args.device)
        t_mask = t_mask[0].to(args.device)
        t_batch_bi = t_batch_bi[0].to(args.device)
        t_mask_bi = t_mask_bi[0].to(args.device)
        v_batch = v_batch[0].to(args.device)
        v_mask = v_mask[0].to(args.device)

        all_samples = [(t_batch, t_mask), (t_batch_bi, t_mask_bi), (v_batch, v_mask)]
        bi.step_fn(args, models, optimizers, all_samples)

        
        if (iteration + 1) % 20 == 0:
            if best_eval_score > -9999:
                pbar.set_description_str(
                f"Best Val Score: {best_eval_score:.4f} | Val Score: {prev_eval_score:.4f}\t")
        
        if (iteration + 1) % 100 == 0:
            
            eval_score = -eval_model(teacher, val_loader)
            prev_eval_score = eval_score
            if eval_score > best_eval_score:
                best_eval_score = eval_score
                best_ckpt_state_dict = deepcopy(teacher.state_dict())  # NOTE(review): never read
                torch.save(teacher.state_dict(), 'best_model.pt')
                # print("On val:", best_eval_score)
                # Reads the global test_loader (not a parameter of this function).
                print("On test:", eval_model(teacher, test_loader))
            
            # bi.step_info keys are populated by BiImpute.step_fn.
            if best_eval_score > -9999:
                print(f"""
    Step {args.global_step}: Best Val Score: {best_eval_score:.4f} | Val Score: {prev_eval_score:.4f}\t
    s: ({bi.step_info['nlll/student_on_t']:.4f}, {bi.step_info['nlll/student_on_v']:.4f}), t: {bi.step_info['mae']:.4f}
    """.strip(), end=', ')
            else:
                print(f"""
    Step {args.global_step}:
    s: ({bi.step_info['nlll/student_on_t']:.4f}, {bi.step_info['nlll/student_on_v']:.4f}), t: {bi.step_info['mae']:.4f}
    """.strip(), end=', ')


        # Periodically restart the student from fresh parameters and optimizer state.
        if (iteration + 1) % 50000 == 0:
            reset_all_parameters(student)
            student_optimizer.state.clear()
            student_optimizer.zero_grad()
            bi.moving_dot_product = None  # clear BiImpute's running statistic as well
            
        

        args.global_step += 1
        
    return teacher, best_eval_score



def train_model(model, train_loader, val_loader):
    """Self-supervised pretraining of ``model`` via ``model.forward_ssl``.

    Trains for at most ``args.num_pretrain_steps`` optimiser steps, spread over
    at most ``args.num_epochs`` passes over ``train_loader``. Every 500 steps
    the model is scored on ``val_loader`` with ``eval_model`` (score = negative
    reconstruction loss). Each new best is written to
    ``best_model_pretrain.pt``, and the best weights are loaded back into
    ``model`` before returning.

    Reads the notebook-level ``args``, ``SEED`` and ``manual_seed``, and sets
    ``args.global_step``.

    Returns:
        (model, best_eval_score). If no evaluation ran (fewer than 500 steps),
        ``best_eval_score`` stays -9999.0 and the initial weights are restored.
    """
    torch.cuda.empty_cache()
    manual_seed(SEED)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    prev_eval_score = -9999.0   # sentinel: "not evaluated yet"
    best_eval_score = -9999.0
    mean_loss = None            # EMA of the training loss, shown in the progress bar
    best_ckpt_state_dict = deepcopy(model.state_dict())

    model.train().to(args.device)
    args.global_step = 0
    for epoch in range(args.num_epochs):
        if args.global_step >= args.num_pretrain_steps:
            break
        print(f"Epoch {epoch}")
        pbar = tqdm(train_loader, position=0, leave=True)
        for batch, mask_batch, _visible_mask in pbar:
            # `>=` so that exactly num_pretrain_steps updates are made.
            if args.global_step >= args.num_pretrain_steps:
                break
            # Drop the DataLoader batch dimension (batch_size=1); each item is a full window.
            batch = batch[0].to(args.device)
            mask_batch = mask_batch[0].to(args.device)

            optimizer.zero_grad()
            loss = model.forward_ssl(batch, mask_batch, mask_ratio=args.mask_ratio)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            mean_loss = loss_value if mean_loss is None else 0.9 * mean_loss + 0.1 * loss_value

            step = args.global_step + 1
            if step % 20 == 0:
                if best_eval_score > -9999:
                    pbar.set_description_str(
                        f"Loss: {mean_loss:.4f} | Best Val Score: {best_eval_score:.4f} | Val Score: {prev_eval_score:.4f}\t")
                else:
                    pbar.set_description_str(
                        f"Loss: {mean_loss:.4f}\t")
            if step % 500 == 0:
                # Evaluate on the validation split. eval_model works on a deep copy,
                # so `model` stays in train mode.
                eval_score = -eval_model(model, val_loader)
                prev_eval_score = eval_score
                if eval_score > best_eval_score:
                    best_eval_score = eval_score
                    best_ckpt_state_dict = deepcopy(model.state_dict())
                    torch.save(model.state_dict(), 'best_model_pretrain.pt')
            args.global_step += 1
    model.load_state_dict(best_ckpt_state_dict)
    return model, best_eval_score

### Training args

In [None]:
from argparse import Namespace
import numpy as np

# Earlier hand-set configuration, kept for reference only.
# args = Namespace(
#     device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#     seed=1902,
#     ft_embed_dim=8,
#     ft_enc_nhead=1,
#     ft_enc_num_layers=3,
#     mae_hidden_dim=64,
#     mlp_ratio=4.,
#     dim_feedforward=64 * 4,
#     mae_nhead=4,
#     mae_num_layers=3,
#     mae_dropout=0.1,
#     activation='gelu',
#     time_window_size=5,
#     station_window_size=5,
#     num_epochs=2,
#     lr=3e-4,
#     weight_decay=1e-6,
#     mask_ratio=0.3,
# )

# Active configuration (values from an Optuna search, per the "# Optuna" note).
args = Namespace( # Optuna
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    seed=1902,  # used by the manual_seed(args.seed) call below; train_model re-seeds with SEED
    # Architecture settings, presumably read by MAE / LSTMStudent in ../models.
    ft_embed_dim=24,
    ft_enc_nhead=2,
    ft_enc_num_layers=4,
    mae_hidden_dim=160,
    mlp_ratio=4.36,
    dim_feedforward=697,
    mae_nhead=4,
    mae_num_layers=3,
    mae_dropout=0.29,
    activation='gelu',
    # Crop shape of every ImputeDataset sample (days x stations).
    time_window_size=5,
    station_window_size=5,
    # Teacher pretraining (train_model).
    num_epochs=1,
    lr=6.9e-5,
    weight_decay=4.12e-6,
    mask_ratio=0.31,

    # Student optimiser settings; student_lambda / student_dropout semantics live in ../models.
    student_lr=0.006619339793756876,
    student_weight_decay=3.903728867916741e-08,
    student_lambda=0.38463245884031055,
    student_dropout=0.36358270939782433,

    num_train_steps=2000,      # bilevel iterations in train_bi
    num_pretrain_steps=20000,  # cap on optimiser steps in train_model
)

# Derived from the feature setup above.
args.pos_embed_indices = pos_embed_indices
args.num_features = len(features)

# Re-seed every RNG with the configured seed before building datasets and models.
# `manual_seed` is the helper defined in the first code cell; it is not
# redefined here (the old duplicate definition was identical).
manual_seed(args.seed)

### Setup datasets and train model

In [None]:

# All datasets sample station groups from `stations_sampling` (each station plus
# its nearest neighbours, from the earlier nearest-station cell).
# Training sets use no extra masking: train_model and train_bi ignore the
# third (visible-mask) element of each item.
train_dataset = ImputeDataset(train, args.time_window_size, args.station_window_size,
                              sampling_size=200, additional_mask_probs=0.0)
# Second ImputeDataset over the same training split with the same settings (station groups are re-drawn).
train_bilevel_dataset = ImputeDataset(train, args.time_window_size, args.station_window_size,
                              sampling_size=200, additional_mask_probs=0.0)
# train_bilevel_dataset = ImputeBilevelDataset(train, args.time_window_size, args.station_window_size,
#                               sampling_size=200, additional_mask_probs=0.0)
# Bilevel validation crops: pm25 must be observed at every sampled station on the crop's last day.
val_bilevel_dataset = ImputeBilevelDataset(val, args.time_window_size, args.station_window_size,
                            sampling_size=200, additional_mask_probs=0.0)
# Evaluation sets hide 30% of observed entries; eval_model scores the reconstruction of exactly those.
val_dataset = ImputeDataset(val, args.time_window_size, args.station_window_size,
                            sampling_size=20, additional_mask_probs=0.3)
test_dataset = ImputeDataset(test, args.time_window_size, args.station_window_size,
                            sampling_size=10, additional_mask_probs=0.3)

# batch_size=1: each item is already a full (time x station x feature) crop;
# the training/eval loops strip this dimension with [0].
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
train_bilevel_loader = DataLoader(train_bilevel_dataset, batch_size=1, shuffle=True)
val_bilevel_loader = DataLoader(val_bilevel_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

teacher = MAE(args)
student = LSTMStudent(args)

teacher_optimizer = optim.AdamW(teacher.parameters(), lr=args.lr, weight_decay=args.weight_decay)
student_optimizer = optim.AdamW(student.parameters(), lr=args.student_lr, weight_decay=args.student_weight_decay)

models = (teacher, student)
optimizers = (teacher_optimizer, student_optimizer)

In [None]:
# Stage 1: self-supervised pretraining of the teacher; best weights also go to best_model_pretrain.pt.
teacher, best_eval_score = train_model(teacher, train_loader, val_loader)

In [None]:
# Stage 2: bilevel teacher/student training. Returns the teacher's last-step weights;
# the best-on-validation weights are saved to best_model.pt.
teacher, best_eval_score = train_bi(models, optimizers, train_loader, train_bilevel_loader, val_bilevel_loader, val_loader)

# Load model and create imputed dataset

In [None]:
# Use the best saved teacher: prefer the bilevel checkpoint, fall back to the
# pretraining one, and keep the in-memory teacher weights if neither file exists.
# A checkpoint that exists but fails to load (e.g. an architecture mismatch) now
# raises instead of being silently ignored.
best_model = deepcopy(teacher).eval().to(args.device)
for ckpt_path in ('./best_model.pt', './best_model_pretrain.pt'):
    if os.path.exists(ckpt_path):
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        best_model.load_state_dict(torch.load(ckpt_path, map_location=args.device))
        print(f"Loaded weights from {ckpt_path}")
        break
else:
    print("No checkpoint found; using the in-memory teacher weights.")

### Calculate nearest stations

In [None]:
# Recompute nearest neighbours, this time keeping station_window_size - 1 of them,
# so a station plus its neighbours fills exactly one crop.
X_train_np = create_timeseries_data_missing(train, show_tqdm=False)
station_locs = X_train_np[0, :, [1, 2]]    # (2, S): the list index moves to the front
coords = station_locs.T                    # (S, 2): scaled lat/lon per station

diff = coords[:, None, :] - coords[None, :, :]   # (S, S, 2) pairwise offsets
sqdist = (diff ** 2).sum(axis=2)                 # (S, S) squared distances
np.fill_diagonal(sqdist, np.inf)                 # exclude self-matches

k_neighbors = args.station_window_size - 1
nearest_stations = np.argsort(sqdist, axis=1)[:, :k_neighbors].tolist()   # (S, k_neighbors)


### 📌 Imputation Rules

To ensure the best use of incomplete data while maintaining strict evaluation integrity, we apply different imputation strategies for training, validation, and test sets. These strategies aim to **(1) reduce reconstruction error during training** and **(2) prevent information leakage during evaluation**.

#### 🔧 Training Set
For the training set, our goal is to provide the downstream LSTM model with the most accurate reconstructed values. Therefore, **future data is allowed** during imputation to maximize the information available. Each missing value is filled using predictions from a masked autoencoder (MAE) applied on a spatiotemporal window that includes the target station and its nearest neighbors. The MAE is trained and selected based on validation MSE.

We create multiple versions of the imputed training data:
- `imputed_X_train_np`: Fully imputed version for inspection or optional use.
- `imputed_X_train_np_dilate`: Only retains values that are supported by sufficiently reliable neighboring information.
- `imputed_X_train_np_restrict`: Masks out values that were originally missing in the PM2.5 feature.

The `imputed_X_train_np_dilate` version selectively retains imputed values based on local spatiotemporal support. For each station, we apply a 1D `binary_dilation` over time, which extends every valid PM2.5 observation by one time step in each direction. We then take the logical OR of this mask with the mask of the station’s nearest neighbor. A value is therefore kept only if the original data had support at that station (within one time step) or at its nearest neighbor. All other imputed values are discarded by resetting them to `NaN`.

#### 🧪 Validation & Test Sets
In contrast to training, **future data is strictly prohibited** during imputation on validation and test sets to prevent any leakage. Each sample is imputed **only using past and current data**, ensuring that the evaluation remains unbiased.

For each time point `t`, only the time window `[t - W + 1, t]` is used (where `W` is the window size). The MAE model imputes missing values in this window using the target station together with its precomputed nearest neighbors. Windows in which less than half of this station block is observed are skipped and left as `NaN`.

Only the **last time step in each window is retained** after imputation, corresponding to the target input of the LSTM predictor. Because every window needs `W` consecutive steps, the first `W - 1` time steps of each split are never imputed and remain `NaN`. Furthermore, any validation/test sample where the target PM2.5 label is missing is **excluded from downstream evaluation** to maintain fairness.

In [None]:
           
X_train_np = create_timeseries_data_missing(train, show_tqdm=False)
mask_np = ~np.isnan(X_train_np)

imputed_X_train_np = np.zeros_like(X_train_np) + np.nan
mask_blocks = mask_np.mean(axis=-1) < 0.5
mask_blocks.shape

T, S, N = X_train_np.shape
for t in tqdm(range(T)):
    if t + args.time_window_size >= T: break
    time_indices = np.arange(t, t + args.time_window_size)
    for s in range(S):
        if mask_blocks[t].mean() < 0.1: continue
        best_station_indices = None
        best_score = 0.0
        station_indices = [s] + nearest_stations[s]
        mask_batch = mask_np[time_indices][:, station_indices]
        if mask_batch.mean() > best_score:
            best_score = mask_batch.mean()
            best_station_indices = deepcopy(station_indices)
        if best_score < 0.5: 
            continue
        station_indices = best_station_indices
        batch = torch.tensor(
            X_train_np[time_indices][:, station_indices], dtype=torch.float32, device=args.device)
        batch = batch.nan_to_num(0.0)
        mask_batch = torch.tensor(
            mask_np[time_indices][:, station_indices], dtype=torch.float32, device=args.device)
        with torch.no_grad():
            data_pred = best_model.forward_impute(batch, mask_batch)
            data_pred = data_pred * (1 - mask_batch) + batch * mask_batch
        data_pred = data_pred.detach().cpu().numpy()
        imputed_X_train_np[time_indices, s] = data_pred[:, 0]
        
    

In [None]:
           
X_val_np = create_timeseries_data_missing(val, show_tqdm=False)
mask_np = ~np.isnan(X_val_np)

imputed_X_val_np = np.zeros_like(X_val_np) + np.nan
mask_blocks = mask_np.mean(axis=-1) < 0.5
mask_blocks.shape

T, S, N = X_val_np.shape
for t_reverse in tqdm(range(args.time_window_size - 1, T)):
    t = t_reverse - args.time_window_size + 1
    if t < 0: break
    time_indices = np.arange(t, t + args.time_window_size)
    for s in range(S):
        if mask_blocks[t].mean() < 0.1: continue
        best_station_indices = None
        best_score = 0.0
        station_indices = [s] + nearest_stations[s]
        mask_batch = mask_np[time_indices][:, station_indices]
        if mask_batch.mean() > best_score:
            best_score = mask_batch.mean()
            best_station_indices = deepcopy(station_indices)
        if best_score < 0.5: 
            continue
        station_indices = best_station_indices
        batch = torch.tensor(
            X_val_np[time_indices][:, station_indices], dtype=torch.float32, device=args.device)
        batch = batch.nan_to_num(0.0)
        mask_batch = torch.tensor(
            mask_np[time_indices][:, station_indices], dtype=torch.float32, device=args.device)
        with torch.no_grad():
            data_pred = best_model.forward_impute(batch, mask_batch)
            data_pred = data_pred * (1 - mask_batch) + batch * mask_batch
        data_pred = data_pred.detach().cpu().numpy()
        imputed_X_val_np[time_indices[-1], s] = data_pred[-1, 0]
        
    

In [None]:

X_test_np = create_timeseries_data_missing(test, show_tqdm=False)
mask_np = ~np.isnan(X_test_np)

imputed_X_test_np = np.zeros_like(X_test_np) + np.nan
mask_blocks = mask_np.mean(axis=-1) < 0.5
mask_blocks.shape

T, S, N = X_test_np.shape
for t_reverse in tqdm(range(args.time_window_size - 1, T)):
    t = t_reverse - args.time_window_size + 1
    if t < 0: break
    time_indices = np.arange(t, t + args.time_window_size)
    for s in range(S):
        if mask_blocks[t].mean() < 0.1: continue
        best_station_indices = None
        best_score = 0.0
        station_indices = [s] + nearest_stations[s]
        mask_batch = mask_np[time_indices][:, station_indices]
        if mask_batch.mean() > best_score:
            best_score = mask_batch.mean()
            best_station_indices = deepcopy(station_indices)
        if best_score < 0.5: 
            continue
        station_indices = best_station_indices
        batch = torch.tensor(
            X_test_np[time_indices][:, station_indices], dtype=torch.float32, device=args.device)
        batch = batch.nan_to_num(0.0)
        mask_batch = torch.tensor(
            mask_np[time_indices][:, station_indices], dtype=torch.float32, device=args.device)
        with torch.no_grad():
            data_pred = best_model.forward_impute(batch, mask_batch)
            data_pred = data_pred * (1 - mask_batch) + batch * mask_batch
        data_pred = data_pred.detach().cpu().numpy()
        imputed_X_test_np[time_indices[-1], s] = data_pred[-1, 0]
        
    
imputed_X_test_np_full = imputed_X_test_np.copy()

In [None]:

from scipy.ndimage import binary_dilation

# Observed PM2.5 (feature 0) per (time, station) in the training split.
train_observed = ~np.isnan(X_train_np[:, :, 0])
T, S = train_observed.shape

# Dilate each station's series by one step along time only. The (3, 1) structure never
# mixes stations and is equivalent to a per-column 1-D dilation with the default structure.
train_dilated = binary_dilation(train_observed, structure=np.ones((3, 1), dtype=bool))

# Keep a value if the station itself or its nearest neighbour has dilated support.
# Merging from the frozen `train_dilated` makes the result independent of station order.
# Updating the mask in place while reading it would chain support through already-merged
# neighbours and leave higher-index neighbours undilated.
train_keep_mask = train_dilated.copy()
for s in range(S):
    for s1 in nearest_stations[s][:1]:
        train_keep_mask[:, s] |= train_dilated[:, s1]

# Training variants: `dilate` keeps only supported cells; `restrict` keeps only cells
# whose PM2.5 was originally observed.
imputed_X_train_np_dilate = imputed_X_train_np.copy()
imputed_X_train_np_dilate[~train_keep_mask] = np.nan

imputed_X_train_np_restrict = imputed_X_train_np.copy()
imputed_X_train_np_restrict[~train_observed] = np.nan

# Validation/test: drop samples whose PM2.5 target is missing so they are excluded from evaluation.
imputed_X_val_np[np.isnan(X_val_np[:, :, 0])] = np.nan
imputed_X_test_np[np.isnan(X_test_np[:, :, 0])] = np.nan

In [None]:
import os
os.makedirs('imputed_data', exist_ok=True)

# Persist each imputed array as imputed_data/<name>.npy, in a fixed order.
arrays_to_save = {
    'imputed_X_train_np_dilate': imputed_X_train_np_dilate,
    'imputed_X_train_np_restrict': imputed_X_train_np_restrict,
    'imputed_X_train_np': imputed_X_train_np,
    'imputed_X_val_np': imputed_X_val_np,
    'imputed_X_test_np': imputed_X_test_np,
}
for stem, arr in arrays_to_save.items():
    np.save(f'imputed_data/{stem}.npy', arr)