In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
# Global matplotlib default: thin (0.2) hatch lines for all figures drawn after this.
mpl.rcParams['hatch.linewidth'] = 0.2
import numpy as np
import pandas as pd
import pickle
from tqdm.notebook import tqdm
import polars as pl
import xgboost as xgb
# Echo the xgboost version into the cell output so saved runs record it.
print("xgboost version:", xgb.__version__)

import sys
import os
# Append the parent of the current working directory to sys.path so the
# `src.*` imports below can be resolved. NOTE(review): assumes the notebook is
# launched from a direct subdirectory of the repository root -- TODO confirm.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

# Project-local helpers from the `src` package (importable via the sys.path entry above).
from src.signal_categories import topological_category_labels, topological_category_colors, topological_category_labels_latex, topological_category_hatches, topological_categories_dic
from src.signal_categories import filetype_category_labels, filetype_category_colors, filetype_category_hatches
from src.signal_categories import del1g_detailed_category_labels, del1g_detailed_category_colors, del1g_detailed_category_labels_latex, del1g_detailed_category_hatches, del1g_detailed_categories_dic
from src.signal_categories import del1g_simple_category_labels, del1g_simple_category_colors, del1g_simple_category_labels_latex, del1g_simple_category_hatches, del1g_simple_categories_dic
from src.signal_categories import train_category_labels, train_category_labels_latex

from src.ntuple_variables.pandora_variables import pandora_scalar_second_half_training_vars

# Directory holding intermediate artifacts such as all_df.parquet.
from src.file_locations import intermediate_files_location

from src.plot_helpers import make_histogram_plot

from src.ntuple_variables.variables import combined_training_vars


In [None]:
# Load the combined dataframe written by an upstream step.
# The on-disk size is reported first so the cost of the read is visible.
all_df_parquet_path = f"{intermediate_files_location}/all_df.parquet"

print("loading all_df.parquet...")
parquet_size_gb = os.path.getsize(all_df_parquet_path) / 1024**3
print("file size:", parquet_size_gb, "GB")

all_df = pl.read_parquet(all_df_parquet_path)
print(f"{all_df.shape=}")


In [None]:
# Summarise the in-memory footprint of the loaded frame and how many columns
# of each dtype it holds, to judge whether downcasting is worthwhile.
print("Memory size analysis:")
size_gb = all_df.estimated_size() / 1024**3
print(f"Estimated size: {size_gb:.4f} GB")
print()

print("Column dtype counts:")
dtype_names = [str(dt) for dt in all_df.dtypes]
dtype_counts = pd.Series(dtype_names).value_counts().sort_index()
print(dtype_counts)
print()
print(f"Total columns: {all_df.width}")


In [None]:
# List every 64-bit numeric column; these are the candidates for the
# 32-bit downcast performed in the next cell.
schema_items = list(all_df.schema.items())
float64_cols = [name for name, dt in schema_items if dt == pl.Float64]
int64_cols = [name for name, dt in schema_items if dt == pl.Int64]

print("Float64 columns:")
for name in float64_cols:
    print(f"  {name}")
print(f"Total Float64 columns: {len(float64_cols)}")
print()

print("Int64 columns:")
for name in int64_cols:
    print(f"  {name}")
print(f"Total Int64 columns: {len(int64_cols)}")


In [None]:
# Convert Float64 to Float32 and Int64 to Int32 to reduce memory usage.
# - Float64 -> Float32 reduces numeric precision.
# - Int64 -> Int32: values outside the Int32 range are clipped to the range
#   bounds before the cast. Clipping rewrites those values, so the number of
#   affected values per column is counted and reported first rather than
#   changing data silently.
print("Converting dtypes to reduce memory usage...")
memory_before = all_df.estimated_size() / (1024**3)
print(f"Memory before conversion: {memory_before:.4f} GB")
print()

# Get columns to convert
float64_cols = [col for col, dtype in all_df.schema.items() if dtype == pl.Float64]
int64_cols = [col for col, dtype in all_df.schema.items() if dtype == pl.Int64]

print(f"Converting {len(float64_cols)} Float64 columns to Float32")
print(f"Converting {len(int64_cols)} Int64 columns to Int32 (clipping values to Int32 range)")
print()

# Convert Float64 to Float32
if float64_cols:
    all_df = all_df.with_columns([pl.col(col).cast(pl.Float32) for col in float64_cols])

# Int32 range: -2,147,483,648 to 2,147,483,647 (taken from numpy instead of hard-coding)
int32_min = int(np.iinfo(np.int32).min)
int32_max = int(np.iinfo(np.int32).max)

if int64_cols:
    # Count out-of-range values per column *before* clipping. Nulls compare
    # to null and are ignored by sum(), so they are never counted as clipped.
    out_of_range_counts = all_df.select([
        ((pl.col(col) < int32_min) | (pl.col(col) > int32_max)).sum().alias(col)
        for col in int64_cols
    ]).row(0, named=True)
    clipped_counts = {
        col_name: n_bad for col_name, n_bad in out_of_range_counts.items() if n_bad
    }
    if clipped_counts:
        print("WARNING: values outside the Int32 range will be clipped:")
        for col_name, n_bad in clipped_counts.items():
            print(f"  {col_name}: {n_bad} value(s) clipped")

    # Clip values to Int32 range, then cast to Int32
    all_df = all_df.with_columns([
        pl.col(col).clip(int32_min, int32_max).cast(pl.Int32)
        for col in int64_cols
    ])
    print(f"Converted {len(int64_cols)} Int64 columns to Int32")

memory_after = all_df.estimated_size() / (1024**3)
memory_saved_gb = memory_before - memory_after
# Guard against an empty frame, whose estimated size of 0 would otherwise
# raise ZeroDivisionError in the percentage.
memory_saved_pct = memory_saved_gb / memory_before * 100 if memory_before else 0.0
print(f"\nMemory after conversion: {memory_after:.4f} GB")
print(f"Memory saved: {memory_saved_gb:.4f} GB ({memory_saved_pct:.1f}%)")
print()

print("Updated dtype counts:")
dtype_counts = pd.Series([str(dtype) for dtype in all_df.schema.values()]).value_counts().sort_index()
print(dtype_counts)
