In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os, sys
import re
import torch
from tqdm import tqdm
# from sklearn.model_selection import train_test_split
import warnings 
import seaborn as sns
warnings.filterwarnings('ignore')

In [None]:
# Atom count of the smaller full-sampling system (../data/atoms_1024) and the
# temperature (K) selected by the analysis cells below.
N_atoms, T = 1024, 2000

In [None]:
# Full-sampling trajectories of the 1024-atom system. idx_1024 holds the row
# indices whose first T_state entry equals the selected temperature T.
# NOTE: torch.load unpickles its input -- only load trusted local files.
macro_state_1024, T_state_1024, time_state_1024, time_state_scaled_1024 = (
    torch.load(f'../data/atoms_1024/{name}.pt')
    for name in ('macro_state', 'T_state', 'time_state', 'time_state_scaled')
)
idx_1024 = torch.nonzero(T_state_1024[:, 0] == T, as_tuple=True)[0]

In [None]:
# Full-sampling trajectories of the 8192-atom system (plus the partial macro
# states). idx_8192 holds the row indices whose first T_state entry equals T.
# NOTE: torch.load unpickles its input -- only load trusted local files.
(macro_state_8192, macro_state_partial_8192, T_state_8192,
 time_state_8192, time_state_scaled_8192) = (
    torch.load(f'../data/atoms_8192/{name}.pt')
    for name in ('macro_state', 'macro_state_partial', 'T_state',
                 'time_state', 'time_state_scaled')
)
idx_8192 = torch.nonzero(T_state_8192[:, 0] == T, as_tuple=True)[0]

In [None]:
# Seed-0 partial-sampling run at temperature T. Swap the active `folder` line
# to switch system size.
# folder = '../data/partial_sampling_atoms_8192'
folder = '../data/partial_sampling_atoms_65536'
kmc_time, z1_train_partial, z0_train, step = (
    np.load(os.path.join(folder, f'{prefix}_T_{T}_seed_0.npy'))
    for prefix in ('kmc_times', 'z1_train_partial', 'z0_train', 'step')
)

In [None]:
# Trajectories of the six macro-state components at temperature T
# (red: 1024-atom runs, blue: 8192-atom runs), plotted against time_state.
fig = plt.figure(figsize=(40, 6))
for i in range(6):
    axes = fig.add_subplot(1, 6, i+1)
    for j in idx_1024:
        axes.plot(time_state_1024[j], macro_state_1024[j, :, i], 'r')
    for j in idx_8192:
        axes.plot(time_state_8192[j], macro_state_8192[j, :, i], 'b')
    if idx_8192.shape[0] > 0:
        # Limit the x-range to the span of the 8192-atom trajectories
        # (cast the 0-d tensor to a plain float for matplotlib).
        axes.set_xlim(0, float(torch.max(time_state_8192[idx_8192])))

    # The x data is time_state (KMC time), not a step index.
    axes.set_xlabel('Time', fontsize=20)
    axes.set_ylabel(f'Delta_{i}', fontsize=20)
    axes.set_ylim(-3.5, 3.5)
    axes.tick_params(axis='x', labelsize=20)
    axes.tick_params(axis='y', labelsize=20)
# plt.title would only label the last panel; title the whole figure instead.
fig.suptitle(f'T = {T} K', fontsize=20)
plt.tight_layout()

In [None]:
# Mean macro state over the selected 1024-atom runs at their last frame.
macro_state_1024[idx_1024, -1].mean(dim=0)

In [None]:
# Shape of the 1024-atom time array.
time_state_1024.size()

In [None]:
# Marginal distribution of each macro-state component at temperature T:
# full-sampling trajectories (1024 and 8192 atoms, last frame along axis 1
# excluded) vs the partial-sampling initial states z0_train.
fig = plt.figure(figsize=(40, 6))
for i in range(6):
    axes = fig.add_subplot(1, 6, i+1)

    # Density-normalized histograms so series of different sizes are comparable.
    sns.histplot(macro_state_1024[idx_1024, :-1, i].flatten(),
                 ax=axes, kde=True, bins=50, alpha=0.5, color='skyblue',
                 label='Full Sampling 1024 atoms', stat='density')
    sns.histplot(macro_state_8192[idx_8192, :-1, i].flatten(),
                 ax=axes, kde=True, bins=50, alpha=0.5, color='tab:green',
                 label='Full Sampling 8192 atoms', stat='density')
    sns.histplot(z0_train[:, i].flatten(),
                 ax=axes, kde=True, bins=50, alpha=0.5, color='tab:orange',
                 label='Partial Sampling', stat='density')

    axes.set_title(f'Component {i}')
    axes.set_xlabel('Normalized Delta')
    axes.set_ylabel('Density')
    axes.tick_params(axis='x', labelsize=20)
    axes.tick_params(axis='y', labelsize=20)
    # Every panel uses the same three colours, so one legend identifies them all.
    if i == 0:
        axes.legend(fontsize=14)

plt.tight_layout()
plt.show()

In [None]:
# Shape of the seed-0 partial-sampling initial states.
np.shape(z0_train)

In [None]:
# Shape of one component slice of the partial macro states at temperature T.
# Index component 0 explicitly rather than relying on the loop variable `i`
# leaked from a previous cell; every component slice has the same shape.
macro_state_partial_8192[idx_8192, :, 0].shape

In [None]:
# Normalized increments (z1 - z0) / sqrt(dt) of each component for the
# partial-sampling run currently loaded in kmc_time / z0_train / z1_train_partial.
sqrt_dt = np.sqrt(kmc_time[:])
fig = plt.figure(figsize=(40, 6))
for i in range(6):
    axes = fig.add_subplot(1, 6, i+1)
    increment = z1_train_partial[:, i] - z0_train[:, i]
    sns.histplot(increment / sqrt_dt, ax=axes, kde=True, bins=50, alpha=0.5,
                 color='tab:orange', label='Partial Sampling', stat='density')
    axes.set_title(f'Component {i}')
    axes.set_xlabel('Normalized Delta')
    axes.set_ylabel('Density')

plt.tight_layout()
plt.show()

In [None]:
# Distribution of KMC time increments: the partial-sampling run currently in
# kmc_time vs consecutive-frame differences of the full-sampling trajectories
# selected at temperature T.
fig = plt.figure(figsize=(8, 6))
axes = fig.add_subplot(1, 1, 1)

sns.kdeplot(kmc_time, ax=axes, alpha=0.5, color='skyblue',
            label='Partial Sampling')
sns.kdeplot((time_state_8192[idx_8192, 1:] - time_state_8192[idx_8192, :-1]).flatten(),
            ax=axes, alpha=0.5, color='tab:orange', label='Full Sampling 8192 atoms')
sns.kdeplot((time_state_1024[idx_1024, 1:] - time_state_1024[idx_1024, :-1]).flatten(),
            ax=axes, alpha=0.5, color='tab:red', label='Full Sampling 1024 atoms')

# Title/labels describe this plot; the old title used a leaked loop variable `i`.
axes.set_title(f'KMC time-step distribution (T = {T} K)')
axes.set_xlabel('Time step')
axes.set_ylabel('Density')
axes.legend()

plt.tight_layout()
plt.show()

In [None]:
# Aggregate the partial-sampling runs (selected temperatures x seeds) into
# stacked arrays with one row per (temperature, seed) run.
# folder = '../data/partial_sampling_atoms_8192'
# folder = '../data/partial_sampling_atoms_65536'
folder = '../data/partial_sampling_atoms_524288'

# Step counts are multiplied by (atoms in the partial-sampling system) / N_atoms.
# The atom count is parsed from the folder name. With N_atoms == 1024 this
# reproduces the previously hard-coded factors: 8192 -> x8, 65536 -> x64,
# 524288 -> x64*8.
n_atoms_partial = int(
    re.search(r'atoms_(\d+)', os.path.basename(os.path.normpath(folder))).group(1)
)
assert n_atoms_partial % N_atoms == 0, (
    f'{n_atoms_partial} atoms is not a multiple of N_atoms={N_atoms}'
)
step_multiplier = n_atoms_partial // N_atoms

# Full sweep: [300, 400, 500, 600, 700, 800, 900, 1000, 1200, 1400, 1600,
#              1800, 2000, 2200, 2400, 2600, 2800, 3000]
temperatures = [2000]
seeds = range(100)

time_step_list = []
z1_train_partial_list = []
z0_train_list = []
T_list = []
step_list = []

# Dedicated loop/temporary names so the notebook-level T and the seed-0 arrays
# (kmc_time, z0_train, z1_train_partial, step) loaded earlier are not clobbered.
for T_run in temperatures:
    for seed in seeds:
        run_kmc_time = np.load(os.path.join(folder, f'kmc_times_T_{T_run}_seed_{seed}.npy'))
        run_z1 = np.load(os.path.join(folder, f'z1_train_partial_T_{T_run}_seed_{seed}.npy'))
        run_z0 = np.load(os.path.join(folder, f'z0_train_T_{T_run}_seed_{seed}.npy'))
        run_step = np.load(os.path.join(folder, f'step_T_{T_run}_seed_{seed}.npy'))

        time_step_list.append(run_kmc_time)
        z1_train_partial_list.append(run_z1)
        z0_train_list.append(run_z0)
        T_list.append(T_run * np.ones_like(run_kmc_time))
        step_list.append(run_step * step_multiplier)

T_list = np.stack(T_list, axis=0)
time_step_list = np.stack(time_step_list, axis=0)
z1_train_partial_list = np.stack(z1_train_partial_list, axis=0)
z0_train_list = np.stack(z0_train_list, axis=0)
step_list = np.stack(step_list, axis=0)

T_list.shape, z1_train_partial_list.shape, z0_train_list.shape, step_list.shape

In [None]:
# Rescaled step counts of the aggregated partial-sampling runs vs temperature.
plt.scatter(T_list.flatten(), step_list.flatten(), alpha=0.5, s=0.01)
plt.yscale('log')
plt.xlabel('Temperature (K)')
plt.ylabel('Step (rescaled)')
plt.title('Partial-sampling step counts')
plt.show()

In [None]:
# Estimate the physical duration of a full trajectory at every loaded
# temperature: mean KMC time per (rescaled) step times the trajectory length.
# NOTE(review): the trajectory length was 2e6, 2e6 * 8 and 2e6 * 64 in earlier
# revisions (and 2e7 for T <= 1000 at one point). It still reads 2e6 * 64 while
# the aggregated folder may have changed -- confirm it matches `folder`.
n_trajectory_steps = 2e6 * 64

final_time = {}
T_unique = np.unique(T_list[:, 0])
# Use `temp`, not `T`, so the notebook-level temperature T is not overwritten.
for temp in T_unique:
    rows = np.where(T_list[:, 0] == temp)[0]
    mean_time = np.mean(time_step_list[rows] / step_list[rows])
    final_time[temp] = mean_time * n_trajectory_steps

In [None]:
def scale_function(T, time_table=None):
    """Return 1 / final_time for each temperature in ``T``.

    Multiplying KMC time steps by this factor rescales them so that a full
    trajectory at each temperature spans unit time.

    Args:
        T: scalar temperature or array-like of temperatures (anything
            ``np.asarray`` accepts). Each value is truncated with ``int``
            before the lookup.
        time_table: mapping temperature -> final trajectory time. Defaults to
            the notebook-level ``final_time`` dict.

    Returns:
        For scalar ``T``, the scalar ``1 / time_table[int(T)]``; otherwise a
        float array with the same shape as ``T``.

    Raises:
        KeyError: if a temperature is missing from the table.
    """
    if time_table is None:
        time_table = final_time
    T = np.asarray(T)

    # Scalar case
    if T.ndim == 0:
        return 1 / time_table[int(T.item())]

    # Array case: look up each distinct temperature once instead of once per
    # element, then scatter the results back via the inverse indices. The
    # reshape covers NumPy versions where `inverse` is returned flattened.
    unique_T, inverse = np.unique(T, return_inverse=True)
    inv_times = np.array([1 / time_table[int(t)] for t in unique_T], dtype=float)
    return inv_times[inverse].reshape(T.shape)

In [None]:
# Inspect the estimated final trajectory time at T = 2000 K.
final_time[2000]

In [None]:
# T = np.array(list(final_time.keys()), dtype=float)      # e.g. [200,300,…
# t = np.array(list(final_time.values()), dtype=float)    # corresponding final times

# # 2) prepare for linear fit:  y = ln t, x = 1/T
# idx_high = np.where(T >= 900)[0]
# x_high = 1.0 / T[idx_high]
# y_high = np.log(t[idx_high])

# idx_low = np.where(T < 900)[0]
# x_low = 1.0 / T[idx_low]
# y_low = np.log(t[idx_low])

# # 3) do a 1st‐order polyfit: y ≈ m*x + b
# m_high, b_high = np.polyfit(x_high, y_high, 1)
# A_high = np.exp(b_high)                # prefactor
# Ea_over_kB_high = m_high               # slope = Eₐ/k_B

# print(f"Arrhenius fit for high temperature: t = {A_high:.3e} · exp({Ea_over_kB_high:.1f}/T)")

# m_low, b_low = np.polyfit(x_low, y_low, 1)
# A_low = np.exp(b_low)                # prefactor
# Ea_over_kB_low = m_low               # slope = Eₐ/k_B
# print(f"Arrhenius fit for low temperature: t = {A_low:.3e} · exp({Ea_over_kB_low:.1f}/T)")


# # 4) compute fitted curve
# t_fit_high = A_high * np.exp(Ea_over_kB_high / T[idx_high])
# t_fit_low = A_low * np.exp(Ea_over_kB_low / T[idx_low])

# # 5) plot data & fit
# plt.figure(figsize=(8,5))
# plt.scatter(T, t, label="data", color="C0")
# plt.plot(T[idx_high], t_fit_high, label="Arrhenius fit (high T)", color="C1")
# plt.plot(T[idx_low], t_fit_low, label="Arrhenius fit (low T)", color="C3")
# plt.yscale("log")
# plt.xlabel("Temperature (K)")
# plt.ylabel("Final time (ps)")
# plt.title("Arrhenius fit of final time")
# plt.legend()
# plt.grid(True, which="both", ls="--", alpha=0.5)
# plt.show()

In [None]:
# T = np.array(list(final_time.keys()), dtype=float)      # e.g. [200,300,…
# t = np.array(list(final_time.values()), dtype=float)    # corresponding final times

# # 2) prepare for linear fit:  y = ln t, x = 1/T
# x_high = 1.0 / T[:]
# y_high = np.log(t[:])

# # 3) do a 1st‐order polyfit: y ≈ m*x + b
# m_high, b_high = np.polyfit(x_high, y_high, 1)
# A_high = np.exp(b_high)                # prefactor
# Ea_over_kB_high = m_high               # slope = Eₐ/k_B

# print(f"Arrhenius fit for high temperature: t = {A_high:.3e} · exp({Ea_over_kB_high:.1f}/T)")


# # 4) compute fitted curve
# t_fit_high = A_high * np.exp(Ea_over_kB_high / T[:])

# # 5) plot data & fit
# plt.figure(figsize=(8,5))
# plt.scatter(T, t, label="data", color="C0")
# plt.plot(T[:], t_fit_high, label="Arrhenius fit (high T)", color="C1")
# plt.yscale("log")
# plt.xlabel("Temperature (K)")
# plt.ylabel("Final time (ps)")
# plt.title("Arrhenius fit of final time")
# plt.legend()
# plt.grid(True, which="both", ls="--", alpha=0.5)
# plt.show()

In [None]:
# def scale_function(T):
#     # Parameters for different temperature ranges
#     # High temperature (>= 1400K)
#     A_high = 2.198e-09
#     Ea_over_kB_high = 19786.1

#     # Low temperature (<= 600 k)
#     A_low = 4.992e-06
#     Ea_over_kB_low = 14436.1

#     # Create boolean masks for temperature ranges
#     high_temp_mask = T >= 900
#     low_temp_mask = T < 900 
    
#     # Initialize result array
#     result = np.zeros_like(T, dtype=float)
    
#     # Calculate scaling for high temperatures
#     if np.any(high_temp_mask):
#         result[high_temp_mask] = np.exp(-np.log(A_high) - Ea_over_kB_high/T[high_temp_mask])
        
#     # Calculate scaling for low temperatures
#     if np.any(low_temp_mask):
#         result[low_temp_mask] = np.exp(-np.log(A_low) - Ea_over_kB_low/T[low_temp_mask])
    
#     return result

In [None]:
# def scale_function(T):
#     # Parameters for different temperature ranges
#     # A = 5.259e-09
#     # Ea_over_kB = 17562.5
#     # A = 4.753e-09
#     A = 3.054e-09 

#     # Ea_over_kB = 18143.9
#     Ea_over_kB = 19184.8
#     result = np.exp(-np.log(A) - Ea_over_kB / T)
#     return result

In [None]:
# Final trajectory time at each temperature after applying scale_function.
final_time_scaled = {
    temp_key: scale_function(temp_key) * t_final
    for temp_key, t_final in final_time.items()
}

In [None]:
# Sanity check: with the lookup-table scale_function (1 / final_time[T]), every
# scaled final time should be ~1.
plt.figure(figsize=(8,5))
# Materialize the dict views; matplotlib expects sequences/arrays.
plt.scatter(list(final_time_scaled.keys()), list(final_time_scaled.values()),
            label="data", color="C0")
plt.xlabel("Temperature (K)")
plt.ylabel("Scaled final time")
plt.legend()
plt.grid(True, which="both", ls="--", alpha=0.5)
plt.show()

In [None]:
# Rescale every KMC time step by scale_function at its temperature.
# NOTE: this rebinds time_step_list over the raw values -- running the cell a
# second time applies the scaling twice; re-run the aggregation cell first.
time_step_list = np.multiply(time_step_list, scale_function(T_list))

In [None]:
# Scaled time steps vs temperature. Green: consecutive-frame differences of the
# 8192-atom scaled full-sampling times. Blue: rescaled partial-sampling steps.
plt.plot(T_state_8192[:, :-1], time_state_scaled_8192[:, 1:] - time_state_scaled_8192[:, :-1], 'go', alpha=1, markersize=0.05)
plt.plot(T_list, time_step_list, 'bo', alpha=0.05, markersize=0.05)
plt.yscale('log')
plt.xlabel('Temperature (K)')
plt.ylabel('Scaled time step')
plt.show()

In [None]:
# Per-run increment of the macro state over one partial-sampling step.
delta_z = np.subtract(z1_train_partial_list, z0_train_list)
np.shape(delta_z)

In [None]:
# Normalized increments delta_z / sqrt(dt) of the aggregated partial-sampling
# runs, compared across temperatures. Temperatures with no loaded runs are
# skipped (only the temperatures in the aggregation cell are available).
delta_z = z1_train_partial_list - z0_train_list
index = np.where((T_list[:, 0] == 500))[0]
index_2000 = np.where((T_list[:, 0] == 2000))[0]
# (row indices, legend label, colour, alpha) for each temperature series
series = [
    (index, 'T = 500 K', 'tab:orange', 0.5),
    (index_2000, 'T = 2000 K', 'skyblue', 0.7),
]
fig = plt.figure(figsize=(40, 6))
for i in range(6):
    axes = fig.add_subplot(1, 6, i+1)

    for rows, label, color, alpha in series:
        if rows.size == 0:
            continue
        sns.histplot(delta_z[rows, :, i].flatten() / np.sqrt(time_step_list[rows]).flatten(),
                     ax=axes, kde=True, bins=50, alpha=alpha, color=color,
                     label=label, stat='density')

    axes.set_title(f'Component {i}')
    axes.set_xlabel('Normalized Delta')
    axes.set_ylabel('Density')
    # One legend suffices; only draw it when at least one series was plotted.
    if i == 0 and axes.get_legend_handles_labels()[0]:
        axes.legend()

plt.tight_layout()
plt.show()

In [None]:
# Convert the aggregated arrays to torch tensors (float32 data, int32
# temperatures) and save them into `folder`. The notebook names are rebound to
# the tensors.
z0_train_list, z1_train_partial_list, time_step_list = (
    torch.tensor(arr, dtype=torch.float32)
    for arr in (z0_train_list, z1_train_partial_list, time_step_list)
)
T_list = torch.tensor(T_list, dtype=torch.int32)

for tensor, filename in (
    (z0_train_list, 'z0_train.pt'),
    (z1_train_partial_list, 'z1_train_partial.pt'),
    (time_step_list, 'time_step.pt'),
    (T_list, 'T.pt'),
):
    torch.save(tensor, f'{folder}/{filename}')

In [None]:
# Show the folder the tensors above were written to.
folder 