In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from gluonts.dataset.common import ListDataset
from uni2ts.model.moirai_moe import MoiraiMoEForecast, MoiraiMoEModule
from uni2ts.eval_util.plot import plot_single

import numpy as np
from sklearn.metrics import r2_score, mean_squared_error

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
import pandas as pd
import numpy as np
from gluonts.dataset.common import ListDataset
from uni2ts.model.moirai_moe import MoiraiMoEForecast, MoiraiMoEModule

# === Params ===
# TARGET is the column being forecast; the evaluation period is inclusive on both ends.
TARGET = "EXCESS_RET"
FREQ = "D"
context_length, prediction_length = 252, 1
start_date, end_date = pd.Timestamp("2016-01-01"), pd.Timestamp("2024-12-31")

# === Load & prep ===
# Parse dates and order rows by security then date; reset_index(drop=True) gives
# a clean 0..n-1 row index.
main_df = (
    pd.read_csv("main_data.csv")
    .assign(date=lambda frame: pd.to_datetime(frame["date"]))
    .sort_values(["PERMNO", "date"])
    .reset_index(drop=True)
)
# Non-numeric target values become NaN; "predicted" starts as an all-NaN placeholder.
main_df = main_df.assign(**{
    TARGET: pd.to_numeric(main_df[TARGET], errors="coerce"),
    "predicted": np.nan,
})

# === Model ===
pretrained_module = MoiraiMoEModule.from_pretrained("Salesforce/moirai-moe-1.0-R-small")
forecast_kwargs = dict(
    prediction_length=prediction_length,
    context_length=context_length,
    # important for small context length
    # NOTE(review): confirm 64 is a patch size this pretrained checkpoint supports.
    patch_size=64,
    num_samples=100,
    target_dim=1,
    feat_dynamic_real_dim=0,
    past_feat_dynamic_real_dim=0,
)
model = MoiraiMoEForecast(module=pretrained_module, **forecast_kwargs)
predictor = model.create_predictor(batch_size=16)

# === Forecast for all PERMNOs ===
# For each security: build one rolling context window per evaluation date, forecast
# one step ahead, and store the mean of the forecast samples in main_df["predicted"].
for permno, df_perm in main_df.groupby("PERMNO", sort=True):
    # reset_index() keeps the main_df row labels in an "index" column so each
    # forecast can be written back to the row it was made for.
    df_perm = df_perm.sort_values("date").reset_index()
    row_labels = df_perm["index"].to_numpy()
    dates = df_perm["date"]
    target_vals = pd.to_numeric(df_perm[TARGET], errors="coerce").to_numpy(dtype=float)
    # Same test as `date < start_date or date > end_date`, evaluated once per PERMNO.
    outside_eval = ((dates < start_date) | (dates > end_date)).to_numpy()

    contexts, write_indices = [], []
    # Rows earlier than `context_length` have no full history, so start from there.
    for pred_idx in range(context_length, len(df_perm)):
        if outside_eval[pred_idx]:
            continue

        left = pred_idx - context_length
        # The window ends just before pred_idx, so the forecast row itself is excluded.
        window = target_vals[left:pred_idx]
        if np.isnan(window).any():
            continue

        contexts.append({
            "start": dates.iloc[left],
            "target": window.tolist(),
        })
        write_indices.append(int(row_labels[pred_idx]))

    if not contexts:
        continue

    ds = ListDataset(contexts, freq=FREQ)
    yhats = []
    try:
        for _, fc in zip(write_indices, predictor.predict(ds)):
            yhats.append(float(fc.samples.mean(axis=0)[0]))  # one-step ahead
    except Exception as e:
        # Best effort: report the failure and move on to the next PERMNO.
        # Forecasts produced before the failure are still written below.
        print(f"⚠️ PERMNO {permno}: {e}")

    if yhats:
        # One batched write per PERMNO instead of one scalar .loc write per forecast.
        main_df.loc[write_indices[:len(yhats)], "predicted"] = yhats

display(main_df)

# === Save only evaluation slice with required columns ===
# Keep rows whose date lies in [start_date, end_date] (inclusive on both ends).
in_eval_window = main_df["date"].between(start_date, end_date)
export_df = main_df.loc[in_eval_window, ["PERMNO", "date", TARGET, "predicted"]].copy()

os.makedirs("Results", exist_ok=True)
# Derive the window size in the file name from context_length so the two cannot
# drift apart when the cell is re-run with another window.
export_path = f"Results/uni2ts_predictions_small_{context_length}.csv"
export_df.to_csv(export_path, index=False)

# The exported prediction column is named "predicted".
print(f"✅ Exported evaluation slice with PERMNO/date/{TARGET}/predicted to: {export_path}")
display(export_df.head(20))

Unnamed: 0,PERMNO,date,SICCD,COMNAM,PRC,RET,SHROUT,Category,rf,EXCESS_RET,...,rolling_mean_5,rolling_mean_21,rolling_mean_252,rolling_mean_512,EXCESS_RET_lag_1,EXCESS_RET_lag_2,EXCESS_RET_lag_3,EXCESS_RET_lag_4,EXCESS_RET_lag_5,predicted
0,11174,2000-01-03,4840.0,CROWN GROUP INC,5.10938,0.034810,9711.0,Telecom & Cable,0.00021,0.034600,...,,,,,,,,,,
1,11174,2000-01-04,4840.0,CROWN GROUP INC,5.00000,-0.021407,9711.0,Telecom & Cable,0.00021,-0.021617,...,,,,,0.034600,,,,,
2,11174,2000-01-05,4840.0,CROWN GROUP INC,4.87500,-0.025000,9711.0,Telecom & Cable,0.00021,-0.025210,...,,,,,-0.021617,0.034600,,,,
3,11174,2000-01-06,4840.0,CROWN GROUP INC,5.18750,0.064103,9711.0,Telecom & Cable,0.00021,0.063893,...,,,,,-0.025210,-0.021617,0.034600,,,
4,11174,2000-01-07,4840.0,CROWN GROUP INC,4.75000,-0.084337,9711.0,Telecom & Cable,0.00021,-0.084547,...,,,,,0.063893,-0.025210,-0.021617,0.034600,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
314445,87344,2024-12-24,7370.0,SIFY TECHNOLOGIES LTD,3.02000,0.016835,11779.0,Internet & Software,0.00017,0.016665,...,-0.011826,-0.005351,-0.002356,0.000115,-0.000170,0.023968,-0.070683,-0.034226,0.021982,0.045999
314446,87344,2024-12-26,7370.0,SIFY TECHNOLOGIES LTD,3.19000,0.056291,11779.0,Internet & Software,0.00017,0.056121,...,-0.012889,-0.004832,-0.002185,0.000162,0.016665,-0.000170,0.023968,-0.070683,-0.034226,0.040439
314447,87344,2024-12-27,7370.0,SIFY TECHNOLOGIES LTD,3.16000,-0.009404,11779.0,Internet & Software,0.00017,-0.009574,...,0.005180,-0.003422,-0.001940,0.000316,0.056121,0.016665,-0.000170,0.023968,-0.070683,0.043678
314448,87344,2024-12-30,7370.0,SIFY TECHNOLOGIES LTD,2.89000,-0.085443,11779.0,Internet & Software,0.00017,-0.085613,...,0.017402,-0.005107,-0.001869,0.000357,-0.009574,0.056121,0.016665,-0.000170,0.023968,0.027808


✅ Exported evaluation slice with PERMNO/date/EXCESS_RET/predicted_value to: Results/uni2ts_predictions_small_252.csv


Unnamed: 0,PERMNO,date,EXCESS_RET,predicted
4025,11174,2016-01-04,0.025478,0.002895
4026,11174,2016-01-05,0.002192,0.005526
4027,11174,2016-01-06,-0.060518,0.003745
4028,11174,2016-01-07,-0.008925,0.003748
4029,11174,2016-01-08,-0.02036,0.006513
4030,11174,2016-01-11,-0.024381,0.003473
4031,11174,2016-01-12,-0.035231,0.005999
4032,11174,2016-01-13,-0.066242,0.007854
4033,11174,2016-01-14,-0.007276,0.00778
4034,11174,2016-01-15,-0.015117,0.007878


In [6]:
# Show main_df as it currently exists in the kernel (it is built by the forecasting cell).
# NOTE(review): the saved output below has a `predicted_value` column, while the
# forecasting cell in this notebook writes `predicted` — the output looks stale; re-run.
display(main_df)

Unnamed: 0,PERMNO,date,SICCD,COMNAM,PRC,RET,SHROUT,Category,rf,EXCESS_RET,...,rolling_mean_5,rolling_mean_21,rolling_mean_252,rolling_mean_512,EXCESS_RET_lag_1,EXCESS_RET_lag_2,EXCESS_RET_lag_3,EXCESS_RET_lag_4,EXCESS_RET_lag_5,predicted_value
0,11174,2000-01-03,4840.0,CROWN GROUP INC,5.10938,0.034810,9711.0,Telecom & Cable,0.00021,0.034600,...,,,,,,,,,,
1,11174,2000-01-04,4840.0,CROWN GROUP INC,5.00000,-0.021407,9711.0,Telecom & Cable,0.00021,-0.021617,...,,,,,0.034600,,,,,
2,11174,2000-01-05,4840.0,CROWN GROUP INC,4.87500,-0.025000,9711.0,Telecom & Cable,0.00021,-0.025210,...,,,,,-0.021617,0.034600,,,,
3,11174,2000-01-06,4840.0,CROWN GROUP INC,5.18750,0.064103,9711.0,Telecom & Cable,0.00021,0.063893,...,,,,,-0.025210,-0.021617,0.034600,,,
4,11174,2000-01-07,4840.0,CROWN GROUP INC,4.75000,-0.084337,9711.0,Telecom & Cable,0.00021,-0.084547,...,,,,,0.063893,-0.025210,-0.021617,0.034600,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
314445,87344,2024-12-24,7370.0,SIFY TECHNOLOGIES LTD,3.02000,0.016835,11779.0,Internet & Software,0.00017,0.016665,...,-0.011826,-0.005351,-0.002356,0.000115,-0.000170,0.023968,-0.070683,-0.034226,0.021982,0.011252
314446,87344,2024-12-26,7370.0,SIFY TECHNOLOGIES LTD,3.19000,0.056291,11779.0,Internet & Software,0.00017,0.056121,...,-0.012889,-0.004832,-0.002185,0.000162,0.016665,-0.000170,0.023968,-0.070683,-0.034226,0.005625
314447,87344,2024-12-27,7370.0,SIFY TECHNOLOGIES LTD,3.16000,-0.009404,11779.0,Internet & Software,0.00017,-0.009574,...,0.005180,-0.003422,-0.001940,0.000316,0.056121,0.016665,-0.000170,0.023968,-0.070683,0.036292
314448,87344,2024-12-30,7370.0,SIFY TECHNOLOGIES LTD,2.89000,-0.085443,11779.0,Internet & Software,0.00017,-0.085613,...,0.017402,-0.005107,-0.001869,0.000357,-0.009574,0.056121,0.016665,-0.000170,0.023968,0.032050


In [24]:
# Rewrite the `date` column of this prediction file as ISO text (YYYY-MM-DD), in place.
# The file is overwritten, so the cell must tolerate being run more than once:
# dates are accepted in either dd/mm/YYYY (first run) or YYYY-MM-DD (later runs).
export_path = "Results/uni2ts_predictions_5_new.csv"
df = pd.read_csv(export_path)
display(df)

raw_dates = df['date']
parsed_dates = pd.to_datetime(raw_dates, format='%d/%m/%Y', errors='coerce')
# Non-missing values that did not parse as dd/mm/YYYY are expected to be ISO already.
not_day_first = parsed_dates.isna() & raw_dates.notna()
if not_day_first.any():
    # Raises ValueError if a value matches neither format, rather than silently dropping it.
    parsed_dates.loc[not_day_first] = pd.to_datetime(
        raw_dates[not_day_first], format='%Y-%m-%d'
    )
df['date'] = parsed_dates.dt.strftime('%Y-%m-%d')

df.to_csv(export_path, index=False)

Unnamed: 0,PERMNO,date,EXCESS_RET,predicted
0,11174,04/01/2016,0.025478,0.001030
1,11174,05/01/2016,0.002192,0.019545
2,11174,06/01/2016,-0.060518,0.009723
3,11174,07/01/2016,-0.008925,-0.001282
4,11174,08/01/2016,-0.020360,0.003059
...,...,...,...,...
113195,87344,24/12/2024,0.016665,0.005839
113196,87344,26/12/2024,0.056121,0.007711
113197,87344,27/12/2024,-0.009574,0.034208
113198,87344,30/12/2024,-0.085613,0.028785


In [19]:
import pandas as pd
import os

# === File paths ===
files = [
    "Results/uni2ts_predictions_base_5.csv",
    "Results/uni2ts_predictions_base_21.csv",
    "Results/uni2ts_predictions_base_252.csv",
    "Results/uni2ts_predictions_base_512.csv",
    "Results/uni2ts_predictions_small_5.csv",
    "Results/uni2ts_predictions_small_21.csv",
    "Results/uni2ts_predictions_small_252.csv",
    "Results/uni2ts_predictions_small_512.csv"
]
output_path = "Results/uni2ts_merged_predictions.csv"

# Use the first file as the "source of truth"
master_file_path = files[0]

# --- Step 1: Create the master keys from the base file ---
print(f"Loading master keys from: {master_file_path}")
master_df = pd.read_csv(master_file_path)

# Normalize the master keys: dates to midnight timestamps (unparseable -> NaT), PERMNO to int
master_df['date'] = pd.to_datetime(master_df['date'], errors='coerce').dt.normalize()
master_df['PERMNO'] = master_df['PERMNO'].astype(int)

# Start from the master keys plus the realised return ('EXCESS_RET').
# .copy() so that adding columns below does not write into a slice of master_df.
final_df = master_df[['PERMNO', 'date', 'EXCESS_RET']].copy()

# --- Step 2: Attach each file's predictions to the master keys ---
# Predictions are attached by ROW POSITION (the master keys are "forced" onto every
# file), which sidesteps date-format differences between files. That is only valid
# when every file has the same rows in the same order, so this is checked explicitly
# instead of being assumed.
for file in files:
    print(f"Processing: {file}...")
    # Generate a unique column name from the file's base name,
    # e.g. "uni2ts_predictions_base_5" -> "uni2ts_base_5"
    parts = os.path.splitext(os.path.basename(file))[0].split("_")
    modelsize = parts[2]
    window = parts[3]
    pred_col_name = f"uni2ts_{modelsize}_{window}"

    # Load the prediction file
    df_pred = pd.read_csv(file)

    if len(df_pred) != len(final_df):
        raise ValueError(
            f"{file} has {len(df_pred)} rows but the master file has {len(final_df)}; "
            "cannot align predictions by position"
        )
    same_permno_order = (
        df_pred['PERMNO'].astype(int).to_numpy() == final_df['PERMNO'].to_numpy()
    ).all()
    if not same_permno_order:
        raise ValueError(
            f"{file} lists PERMNOs in a different order than the master file; "
            "cannot align predictions by position"
        )

    # Direct positional assignment (assumes the prediction column is named 'predicted').
    # A merge on the forced keys would give the same result but would multiply rows
    # if any (PERMNO, date) key were duplicated, e.g. by NaT dates.
    final_df[pred_col_name] = df_pred['predicted'].to_numpy()

# --- Save the final result ---
os.makedirs("Results", exist_ok=True)
final_df.to_csv(output_path, index=False)
print(f"\n✅ Merged Uni2TS predictions with forced keys saved to: {output_path}")
print("\nFinal DataFrame head:")
print(final_df.head())
print("\nChecking for NaNs in the final DataFrame:")
print(final_df.isna().sum())

Loading master keys from: Results/uni2ts_predictions_base_5.csv
Processing: Results/uni2ts_predictions_base_5.csv...
Processing: Results/uni2ts_predictions_base_21.csv...
Processing: Results/uni2ts_predictions_base_252.csv...
Processing: Results/uni2ts_predictions_base_512.csv...
Processing: Results/uni2ts_predictions_small_5.csv...
Processing: Results/uni2ts_predictions_small_21.csv...
Processing: Results/uni2ts_predictions_small_252.csv...
Processing: Results/uni2ts_predictions_small_512.csv...

✅ Merged Uni2TS predictions with forced keys saved to: Results/uni2ts_merged_predictions.csv

Final DataFrame head:
   PERMNO       date  EXCESS_RET  uni2ts_base_5  uni2ts_base_21  \
0   11174 2016-01-04    0.025478       0.001030        0.014883   
1   11174 2016-01-05    0.002192       0.019545        0.019260   
2   11174 2016-01-06   -0.060518       0.009723        0.022006   
3   11174 2016-01-07   -0.008925      -0.001282        0.018895   
4   11174 2016-01-08   -0.020360       0.00305

In [8]:
# NOTE(review): `merged_df` is not defined in any cell visible in this notebook (the
# merge cell builds `final_df`); presumably it is left over from an earlier kernel
# session — confirm, otherwise this cell fails on Restart & Run All.
display(merged_df)
# Count of missing values per column.
na_count = merged_df.isna().sum()
print(na_count)

Unnamed: 0,PERMNO,date,EXCESS_RET,unit2ts_base_5,unit2ts_base_21,unit2ts_base_252,unit2ts_base_512,unit2ts_small_5,unit2ts_small_21,unit2ts_small_252,unit2ts_small_512
0,11174,2016-01-04,0.025478,0.001030,0.014883,0.007511,0.005008,0.004547,0.015602,0.002895,0.007404
1,11174,2016-01-05,0.002192,0.019545,0.019260,0.007842,0.006862,0.011989,0.020649,0.005526,0.008984
2,11174,2016-01-06,-0.060518,0.009723,0.022006,0.009646,0.007276,0.007948,0.021080,0.003745,0.008323
3,11174,2016-01-07,-0.008925,-0.001282,0.018895,0.005178,0.004460,-0.002494,0.012171,0.003748,0.002663
4,11174,2016-01-08,-0.020360,0.003059,0.013902,0.007630,0.003904,0.002424,0.026155,0.006513,0.004421
...,...,...,...,...,...,...,...,...,...,...,...
102920,87344,2024-12-24,0.016665,0.005839,0.001923,0.026779,0.040137,0.008724,0.014163,0.045999,0.032166
102921,87344,2024-12-26,0.056121,0.007711,0.005730,0.035131,0.030198,0.008835,0.006301,0.040439,0.032199
102922,87344,2024-12-27,-0.009574,0.034208,0.015319,0.048565,0.044167,0.033586,0.018984,0.043678,0.033241
102923,87344,2024-12-30,-0.085613,0.028785,0.016099,0.037105,0.034163,0.032892,0.008009,0.027808,0.030169


PERMNO               0
date                 0
EXCESS_RET           0
unit2ts_base_5       0
unit2ts_base_21      0
unit2ts_base_252     0
unit2ts_base_512     0
unit2ts_small_5      0
unit2ts_small_21     0
unit2ts_small_252    0
unit2ts_small_512    0
dtype: int64


In [20]:
import pandas as pd
import numpy as np
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import re

# Load predictions
# This is the table the merge cell saves to this path: PERMNO, date, EXCESS_RET,
# plus one `uni2ts_<size>_<window>` column per prediction file.
df = pd.read_csv("Results/uni2ts_merged_predictions.csv")

# Directional metrics
def directional_accuracy(y_true, y_pred):
    """Return the share of observations whose predicted sign equals the realised sign.

    Signs are taken with ``np.sign``, so a zero only matches another zero.
    """
    sign_matches = np.sign(y_true) == np.sign(y_pred)
    return np.mean(sign_matches)

def directional_up_accuracy(y_true, y_pred):
    """Among observations with a positive realised value, return the share predicted positive.

    Returns NaN when no realised value is positive.
    """
    up_days = y_true > 0
    if not up_days.any():
        return np.nan
    predicted_up = np.sign(y_pred[up_days]) == 1
    return np.mean(predicted_up)

def directional_down_accuracy(y_true, y_pred):
    """Among observations with a negative realised value, return the share predicted negative.

    Returns NaN when no realised value is negative.
    """
    down_days = y_true < 0
    if not down_days.any():
        return np.nan
    predicted_down = np.sign(y_pred[down_days]) == -1
    return np.mean(predicted_down)

# Every column other than the keys and the realised return is treated as a model's predictions.
base_cols = {"PERMNO", "date", "EXCESS_RET"}
pred_cols = [col for col in df.columns if col not in base_cols]

# Metric name -> scoring function, in the column order of the results table.
metric_fns = {
    'R2': r2_score,
    'MSE': mean_squared_error,
    'MAE': mean_absolute_error,
    'Directional Accuracy': directional_accuracy,
    'Directional Up': directional_up_accuracy,
    'Directional Down': directional_down_accuracy,
}
results_by_permno = []

for model_col in pred_cols:
    # Split "<model>_<window>": the trailing digits are the window, the rest is the model.
    model_name = model_col.rsplit("_", 1)[0]
    window_match = re.search(r'_(\d+)$', model_col)
    rolling_window = None if window_match is None else int(window_match.group(1))

    for permno, group in df.groupby("PERMNO"):
        # Score only rows where both the realised and the predicted value are present.
        scored = group.loc[group["EXCESS_RET"].notna() & group[model_col].notna()]
        y_true, y_pred = scored["EXCESS_RET"], scored[model_col]

        # Securities with fewer than 5 scored rows are left out of the table.
        if len(y_true) >= 5:
            row = {'Model': model_name, 'Rolling Window': rolling_window, 'PERMNO': permno}
            row.update({name: score(y_true, y_pred) for name, score in metric_fns.items()})
            results_by_permno.append(row)

# One row per (model, window, PERMNO); saved to CSV and displayed.
results_by_permno_df = pd.DataFrame(results_by_permno)
results_by_permno_df.to_csv("Results/uni2ts_metrics_by_permno.csv", index=False)
display(results_by_permno_df)

Unnamed: 0,Model,Rolling Window,PERMNO,R2,MSE,MAE,Directional Accuracy,Directional Up,Directional Down
0,uni2ts_base,5,11174,-0.471601,0.001584,0.027247,0.496908,0.852837,0.143486
1,uni2ts_base,5,20512,-0.451356,0.000433,0.014563,0.519435,0.875415,0.115094
2,uni2ts_base,5,29647,-0.383575,0.000190,0.009079,0.516784,0.864135,0.135940
3,uni2ts_base,5,39731,-0.531115,0.002723,0.036602,0.469523,0.854244,0.117698
4,uni2ts_base,5,40125,-0.443509,0.001508,0.023737,0.500442,0.846285,0.130156
...,...,...,...,...,...,...,...,...,...
395,uni2ts_small,512,86929,-0.285669,0.000979,0.022670,0.486749,1.000000,0.002584
396,uni2ts_small,512,86996,-0.258521,0.000386,0.014495,0.529594,1.000000,0.001876
397,uni2ts_small,512,87075,-0.253201,0.000536,0.016682,0.517668,0.999147,0.000917
398,uni2ts_small,512,87179,-0.269961,0.000614,0.017674,0.496908,0.998225,0.000000


In [21]:
# Aggregate by Model & Rolling Window: unweighted mean of each per-PERMNO metric
metric_cols = [
    'R2',
    'MSE',
    'MAE',
    'Directional Accuracy',
    'Directional Up',
    'Directional Down',
]
results_overall_df = (
    results_by_permno_df
    .groupby(['Model', 'Rolling Window'])[metric_cols]
    .mean()
    .reset_index()
)

# Save to CSV
results_overall_df.to_csv("Results/uni2ts_metrics_overall.csv", index=False)
print("✅ Saved: Results/uni2ts_metrics_overall.csv")
display(results_overall_df)


✅ Saved: Results/uni2ts_metrics_overall.csv


Unnamed: 0,Model,Rolling Window,R2,MSE,MAE,Directional Accuracy,Directional Up,Directional Down
0,uni2ts_base,5,-0.47912,0.001633,0.023164,0.495362,0.844394,0.147396
1,uni2ts_base,21,-0.305401,0.001434,0.02219,0.499011,0.982819,0.01717
2,uni2ts_base,252,-0.265078,0.001368,0.022202,0.498737,0.998717,0.001075
3,uni2ts_base,512,-0.254452,0.001354,0.022095,0.499046,0.999075,0.001336
4,uni2ts_small,5,-0.484052,0.001638,0.023175,0.495583,0.845228,0.147013
5,uni2ts_small,21,-0.307157,0.001429,0.022185,0.498701,0.982658,0.016692
6,uni2ts_small,252,-0.263773,0.001363,0.022181,0.49886,0.998618,0.001419
7,uni2ts_small,512,-0.254166,0.001356,0.0221,0.498949,0.998831,0.001384
