In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os, sys
sys.path.append('../')
import re
import torch
from utils.utils import *
from tqdm import tqdm
# Seed the random number generators for reproducibility. set_seed presumably
# comes from the `from utils.utils import *` star import above -- TODO confirm
# which RNGs (numpy / torch / random) it seeds.
set_seed(42)

In [None]:
def extract_number(filename):
    """Sort key for snapshot files named ``output_<n>.npy``.

    Returns the integer ``n`` found in ``filename``; names that do not contain
    the ``output_<n>.npy`` pattern map to ``float('inf')`` so they sort after
    every numbered snapshot.
    """
    found = re.search(r'output_(\d+)\.npy', filename)
    return float('inf') if found is None else int(found.group(1))

#  Load the dataset

In [None]:
# Number of atoms; also used below to name the save directory ../data/atoms_<N_atoms>.
N_atoms = 1024
# Raw simulation output: one sub-folder per run, processed by preprocess().
output_dir_short = '../data/output_atoms_1024_steps_2000000'
# output_dir_long = '../data/output_atoms_1024_steps_20000000'


In [None]:
def preprocess(output_dir, include_micro=False):
    """Convert every run folder under ``output_dir`` into per-run ``.npy`` arrays.

    Each sub-directory of ``output_dir`` is treated as one run; its temperature
    is parsed as the integer in the second ``_``-separated field of the folder
    name. From ``chemical_order.csv`` (columns ``step`` and ``delta_<pair>``) and
    ``log.csv`` (column ``time``) the following files are written into the same
    folder:

    * ``macro_val.npy`` -- (n_records, 6) array of the ``delta_<pair>`` columns in
      the order NbNb, NbMo, NbTa, MoMo, MoTa, TaTa;
    * ``time.npy``      -- the ``log.csv`` times with a leading 0.0 prepended;
    * ``T_state.npy``   -- int32 array of length n_records filled with the run temperature;
    * ``step.npy``      -- the ``step`` column of ``chemical_order.csv``.

    Parameters
    ----------
    output_dir : str
        Directory containing one sub-folder per run.
    include_micro : bool, default False
        If True, also stack the ``config_data/*.npy`` snapshots (ordered by the
        number in ``output_<n>.npy``) and save them as ``micro_val.npy``.

    Raises
    ------
    ValueError
        If the time series (including the prepended 0.0) and
        ``chemical_order.csv`` have different lengths; the rest of the notebook
        pairs them element-wise.
    """
    # Column order defines the last axis of macro_val.npy.
    macro_cols = ['delta_NbNb', 'delta_NbMo', 'delta_NbTa',
                  'delta_MoMo', 'delta_MoTa', 'delta_TaTa']

    # os.listdir order is arbitrary and may include stray files (e.g. .DS_Store)
    # whose names would break the temperature parsing below; keep sorted dirs only.
    run_folders = sorted(
        name for name in os.listdir(output_dir)
        if os.path.isdir(os.path.join(output_dir, name))
    )
    for folder in tqdm(run_folders):
        T = int(folder.split('_')[1])
        folder_path = os.path.join(output_dir, folder)
        df = pd.read_csv(os.path.join(folder_path, 'chemical_order.csv'))

        # ========= macroscopic data =========
        macro_val = df[macro_cols].to_numpy()

        # ========= T,steps,time data =========
        df_time = pd.read_csv(os.path.join(folder_path, 'log.csv'))
        time = np.insert(df_time['time'].values, 0, 0.0)  # Insert initial time step (t = 0)
        if time.shape[0] != macro_val.shape[0]:
            raise ValueError(
                f"{folder_path}: {time.shape[0]} time values (incl. t=0) but "
                f"{macro_val.shape[0]} rows in chemical_order.csv"
            )
        T_state = np.full(macro_val.shape[0], T, dtype=np.int32)
        step = df['step'].values

        # ========= microscopic data =========
        if include_micro:
            config_path = os.path.join(folder_path, 'config_data')
            snapshot_files = sorted(
                (f for f in os.listdir(config_path) if f.endswith('.npy')),
                key=extract_number,
            )
            micro_val = np.array(
                [np.load(os.path.join(config_path, f)) for f in snapshot_files]
            )
            np.save(os.path.join(folder_path, 'micro_val.npy'), micro_val)

        # ========= partial macroscopic data (disabled; needs micro_val, i.e. include_micro=True) =========
        # idx_partial = np.random.randint(0, 8, size=(macro_val.shape[0]-1,))
        # initial_grid = micro_val[0]
        # z0_partial = []
        # z1_partial = []
        # for j in range(idx_partial.shape[0]):
        #     idx = idx_partial[j]
        #     partial_grid_0 = micro_val[j].reshape(2, 8, 2, 8, 2, 8).transpose(0, 2, 4, 1, 3, 5).reshape(-1, 16, 16, 16)[idx]
        #     partial_grid_1 = micro_val[j+1].reshape(2, 8, 2, 8, 2, 8).transpose(0, 2, 4, 1, 3, 5).reshape(-1, 16, 16, 16)[idx]

        #     order_0 = cal_local_chemical_order(initial_grid, partial_grid_0)
        #     order_1 = cal_local_chemical_order(initial_grid, partial_grid_1)
        #     z0_partial.append(order_0)
        #     z1_partial.append(order_1)

        # z0_partial = np.array(z0_partial)
        # z1_partial = np.array(z1_partial)
        # macro_val_partial = macro_val[:-1] + (z1_partial - z0_partial)

        # ========= save data =========
        np.save(os.path.join(folder_path, 'macro_val.npy'), macro_val)
        np.save(os.path.join(folder_path, 'time.npy'), time)
        np.save(os.path.join(folder_path, 'T_state.npy'), T_state)
        np.save(os.path.join(folder_path, 'step.npy'), step)
        # np.save(os.path.join(folder_path, 'macro_val_partial.npy'), macro_val_partial)

In [None]:
# Write the per-run .npy files for the short-run dataset.
preprocess(output_dir=output_dir_short)
# preprocess(output_dir_long)

In [None]:

# micro_val.npy is not written by the preprocessing cell above unless the
# microscopic snapshots are explicitly exported, and micro_state is not used
# further down; only load it when every run folder has that file.
LOAD_MICRO = False

micro_state = []
macro_state = []
# macro_state_partial = []
T_state = []
time_state = []
step_state = []

# for output_dir in [output_dir_short, output_dir_long]:
for output_dir in [output_dir_short]:
# for output_dir in [output_dir_long]:
    # Sorted directories only: os.listdir order is arbitrary, so sorting keeps the
    # run order (and hence which runs land in the train/val split for a fixed seed)
    # independent of the filesystem; stray files are not runs.
    run_folders = sorted(
        name for name in os.listdir(output_dir)
        if os.path.isdir(os.path.join(output_dir, name))
    )
    for folder in tqdm(run_folders):
        folder_path = os.path.join(output_dir, folder)

        if LOAD_MICRO:
            micro_state.append(np.load(os.path.join(folder_path, 'micro_val.npy')))
        macro_state.append(np.load(os.path.join(folder_path, 'macro_val.npy')))
        # macro_state_partial.append(np.load(os.path.join(folder_path, 'macro_val_partial.npy')))
        time_state.append(np.load(os.path.join(folder_path, 'time.npy')))
        T_state.append(np.load(os.path.join(folder_path, 'T_state.npy')))
        step_state.append(np.load(os.path.join(folder_path, 'step.npy')))

# Stack along a new leading "run" axis; np.stack raises if runs differ in length.
micro_state = np.stack(micro_state, axis=0) if LOAD_MICRO else None
macro_state = np.stack(macro_state, axis=0)
# macro_state_partial = np.stack(macro_state_partial, axis=0)
time_state = np.stack(time_state, axis=0)
T_state = np.stack(T_state, axis=0)
step_state = np.stack(step_state, axis=0)

In [None]:
# Sanity check: shapes of the stacked per-run arrays.
tuple(arr.shape for arr in (macro_state, T_state, time_state, step_state))

In [None]:
# Final simulated time of each run at T = 2000 K.
idx = np.flatnonzero(T_state[:, 0] == 2000)
time_state[idx, -1]

In [None]:
# Time evolution of the six delta_<pair> order parameters for every run at one temperature.
T_plot = 800  # temperature (K) to plot
# Same order as the last axis of macro_state (columns stacked in preprocess).
pair_names = ['NbNb', 'NbMo', 'NbTa', 'MoMo', 'MoTa', 'TaTa']
indices = np.where(T_state[:, 0] == T_plot)[0]
if indices.size == 0:
    raise ValueError(f'no runs found at T = {T_plot} K')

fig, axs = plt.subplots(1, len(pair_names), figsize=(40, 6))
for i, (ax, pair) in enumerate(zip(axs, pair_names)):
    for j in indices:
        ax.plot(time_state[j], macro_state[j, :, i])
    # The x data is the simulated time, not the step index.
    ax.set_xlabel('Time (ps)')
    ax.set_ylabel(f'delta_{pair}')
    ax.axhline(0, color='black', linestyle='dashed', linewidth=1)
    ax.set_ylim(-3.5, 3.5)
    ax.tick_params(axis='y', labelsize=20)
    ax.grid()
# suptitle labels the whole figure; plt.title would only title the last panel.
fig.suptitle(f'T = {T_plot} K')
fig.tight_layout()

In [None]:
# Largest final simulated time reached at each temperature.
# Keys are Python ints (temperature in K); np.unique returns them in ascending order.
T_unique = np.unique(T_state[:, 0])
final_time = {
    T.item(): time_state[T_state[:, 0] == T, -1].max().item()
    for T in T_unique
}

In [None]:
# Largest final time among the T = 2000 K runs (raises KeyError if there are none).
final_time[2000]

In [None]:
# Piecewise Arrhenius fit of the final simulated time:
#     ln t = m * (1/T) + b   =>   t = A * exp(m / T),  A = exp(b),  m = E_a / k_B
# fitted independently in three temperature regimes.

# 1) temperatures (K) in ascending order and the matching final times
T_sorted = sorted(final_time)
temps = np.array(T_sorted, dtype=float)
t_final = np.array([final_time[k] for k in T_sorted], dtype=float)

# 2) regimes are selected by position in the ascending temperature list.
# NOTE(review): the boundaries 4 and 9 are hard-coded; confirm they agree with
# the temperature thresholds used in scale_function.
ARRHENIUS_REGIMES = {            # regime -> (selection, plot colour)
    'high': (slice(9, None), 'C1'),
    'medium': (slice(4, 9), 'C2'),
    'low': (slice(None, 4), 'C3'),
}

# 3) first-order polyfit of ln t against 1/T in each regime
arrhenius_params = {}            # regime -> (A, Ea_over_kB)
for regime, (sel, _) in ARRHENIUS_REGIMES.items():
    if temps[sel].size < 2:
        raise ValueError(f"need at least 2 temperatures to fit the {regime} regime")
    slope, intercept = np.polyfit(1.0 / temps[sel], np.log(t_final[sel]), 1)
    arrhenius_params[regime] = (np.exp(intercept), slope)
    A, Ea_over_kB = arrhenius_params[regime]
    print(f"Arrhenius fit for {regime} temperature: t = {A:.3e} · exp({Ea_over_kB:.1f}/T)")

# 4) plot data and fitted curves
fig, ax = plt.subplots(figsize=(8, 5))
ax.scatter(temps, t_final, label="data", color="C0")
for regime, (sel, color) in ARRHENIUS_REGIMES.items():
    A, Ea_over_kB = arrhenius_params[regime]
    ax.plot(temps[sel], A * np.exp(Ea_over_kB / temps[sel]),
            label=f"Arrhenius fit ({regime} T)", color=color)
ax.set_yscale("log")
ax.set_xlabel("Temperature (K)")
ax.set_ylabel("Final time (ps)")
ax.set_title("Arrhenius fit of final time")
ax.legend()
ax.grid(True, which="both", ls="--", alpha=0.5)
plt.show()

In [None]:
# train_idx, val_idx = train_test_split(np.arange(macro_state.shape[0]), test_size=0.1)
# train_idx.shape, val_idx.shape

In [None]:
def scale_function(T):
    """Piecewise Arrhenius scaling factor ``1 / (A * exp(Ea_over_kB / T))``.

    Multiplying a simulated time by this factor divides it by the Arrhenius
    curve ``A * exp(Ea_over_kB / T)`` of the regime that ``T`` falls in:

    * high   : T >= 1400 K
    * medium : 600 K < T < 1400 K
    * low    : T <= 600 K

    The constants are hard-coded; presumably they come from an Arrhenius fit of
    final time vs. temperature -- TODO confirm they match the current fit.

    Parameters
    ----------
    T : array_like or scalar
        Temperature(s) in K, any shape.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``T``. NaN temperatures match no
        regime and are left at 0.
    """
    # High temperature (>= 1400 K)
    A_high, Ea_over_kB_high = 3.946e-08, 18036.1
    # Medium temperature; alternative values: A_med = 1.877e-07, Ea_over_kB_med = 19477.0
    A_med, Ea_over_kB_med = 5.316e-07, 15650.0
    # Low temperature (<= 600 K); alternative values: A_low = 7.512e-06, Ea_over_kB_low = 15030.2
    A_low, Ea_over_kB_low = 1.847e-08, 16478.3

    # asarray lets scalars and lists through as well as arrays.
    T = np.asarray(T, dtype=float)
    # The medium range covers everything strictly between low and high. Stopping
    # it at 1200 K would leave 1200 K < T < 1400 K in no regime, with a scale of 0.
    regimes = (
        (T >= 1400, A_high, Ea_over_kB_high),
        ((T > 600) & (T < 1400), A_med, Ea_over_kB_med),
        (T <= 600, A_low, Ea_over_kB_low),
    )

    result = np.zeros_like(T, dtype=float)
    for mask, A, Ea_over_kB in regimes:
        # exp(-ln A - Ea/T) == 1 / (A * exp(Ea/T))
        result[mask] = np.exp(-np.log(A) - Ea_over_kB / T[mask])
    return result

In [None]:
# Per-record Arrhenius scaling factor from each run's temperature, applied element-wise.
scale_value = scale_function(T=T_state)
time_state_scaled = np.multiply(time_state, scale_value)

In [None]:
# Largest final time per temperature, measured on the scaled time axis.
final_time_scaled = {}
for T in T_unique:
    idx = np.flatnonzero(T_state[:, 0] == T)
    final_time_scaled[T.item()] = float(np.max(time_state_scaled[idx, -1]))

In [None]:
# Final time per temperature after Arrhenius scaling. The scaled time is
# time * scale_function(T), i.e. time divided by a fitted time, so it is no
# longer in ps; if the fit constants match the data these presumably cluster near 1.
fig, ax = plt.subplots(figsize=(8, 5))
# Pass lists, not dict views, so matplotlib sees plain 1-D sequences.
ax.scatter(list(final_time_scaled.keys()), list(final_time_scaled.values()),
           label="data", color="C0")
ax.set_xlabel("Temperature (K)")
ax.set_ylabel("Scaled final time")
ax.set_title("Final time after Arrhenius scaling")
ax.legend()
ax.grid(True, which="both", ls="--", alpha=0.5)
plt.show()

In [None]:
# Convert the NumPy arrays to torch tensors: float32 for the order parameters
# and times, int32 for temperatures and step counts.
macro_state, time_state, time_state_scaled = (
    torch.tensor(arr, dtype=torch.float32)
    for arr in (macro_state, time_state, time_state_scaled)
)
# macro_state_partial = torch.tensor(macro_state_partial, dtype=torch.float32)
T_state, step_state = (
    torch.tensor(arr, dtype=torch.int32) for arr in (T_state, step_state)
)


In [None]:
# Split run indices 80 / 20 into training and validation sets.
# train_test_split is not imported explicitly; presumably it arrives via the
# `from utils.utils import *` star import -- confirm how its RNG is seeded.
train_idx, val_idx = train_test_split(np.arange(len(macro_state)), test_size=0.2)
(train_idx.shape, val_idx.shape)

In [None]:
# Scaled time increment between consecutive records, plotted against the run temperature.
dt_scaled = time_state_scaled[:, 1:] - time_state_scaled[:, :-1]
fig, ax = plt.subplots()
ax.plot(T_state[:, :-1], dt_scaled, 'go', alpha=1, markersize=0.05)
ax.set_yscale('log')
ax.set_xlabel('Temperature (K)')
ax.set_ylabel('Scaled time increment')
ax.set_title('Scaled time step between consecutive records')
plt.show()

In [None]:
# Save the train / validation splits as torch tensors under ../data/atoms_<N_atoms>.
save_dir = f'../data/atoms_{N_atoms}'
# exist_ok avoids the separate exists() check and tolerates a pre-existing directory.
os.makedirs(save_dir, exist_ok=True)

# Each tensor is written twice: <name>.pt holds the train_idx rows,
# <name>_val.pt holds the val_idx rows.
tensors_to_save = {
    'macro_state': macro_state,
    'T_state': T_state,
    'time_state': time_state,
    'time_state_scaled': time_state_scaled,
}
for name, tensor in tensors_to_save.items():
    torch.save(tensor[train_idx], f'{save_dir}/{name}.pt')
    torch.save(tensor[val_idx], f'{save_dir}/{name}_val.pt')

# Disabled outputs:
# torch.save(micro_state[train_idx], f'{save_dir}/micro_state.pt')
# torch.save(micro_state[val_idx], f'{save_dir}/micro_state_val.pt')
# torch.save(idx_partial[train_idx], f'{save_dir}/idx_partial.pt')
# torch.save(macro_state_partial[train_idx], f'{save_dir}/macro_state_partial.pt')
# NOTE(review): step_state is built and converted to a tensor above but never
# saved -- confirm this is intended.