In [1]:
import h5py
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import joblib
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import pandas as pd
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_dataset(system_id=0, points=1000, path="lotka_volterra_data.h5"):
    """Read the first ``points`` time steps of one Lotka-Volterra system.

    Only the requested slice is read from the HDF5 file (h5py datasets support
    numpy-style slicing). The previous version loaded every trajectory into
    memory on each call and returned views that kept that whole array alive.
    This matters because the evaluation loop calls this once per system.

    Args:
        system_id: index along the first axis of the ``trajectories`` dataset.
        points: number of leading time steps to return.
        path: HDF5 file containing the ``trajectories`` and ``time`` datasets.

    Returns:
        (prey, predator, times): prey and predator are columns 0 and 1 of the
        trajectory for ``system_id``; times holds the first ``points`` entries
        of the ``time`` dataset.
    """
    with h5py.File(path, "r") as f:
        # Slice on the dataset itself so h5py reads just this system's rows.
        trajectory = f["trajectories"][system_id, :points]
        times = f["time"][:points]

    prey = trajectory[:, 0]
    predator = trajectory[:, 1]
    return prey, predator, times

In [3]:
def load_qwen(model_name="Qwen/Qwen2.5-0.5B-Instruct"):
    """Load a causal LM and its tokenizer, with a trainable bias on the logits.

    All pretrained parameters are frozen. A zero-initialised bias is attached
    to ``model.lm_head`` and is the only trainable parameter.

    Args:
        model_name: Hugging Face model id or local path. Defaults to the Qwen2.5
            0.5B instruct checkpoint used throughout this notebook.

    Returns:
        (model, tokenizer)

    Raises:
        RuntimeError: if ``model.lm_head`` already has a bias.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    # Freeze all parameters except LM head bias
    for param in model.parameters():
        param.requires_grad = False

    lm_head = model.lm_head
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if lm_head.bias is not None:
        raise RuntimeError("load_qwen expects lm_head without a bias; found an existing one")

    # Size, dtype and device come from the head's own weight rather than
    # config.vocab_size / torch defaults. This way the bias always matches the
    # logits the head actually produces.
    head_weight = lm_head.weight
    lm_head.bias = torch.nn.Parameter(
        torch.zeros(head_weight.shape[0], dtype=head_weight.dtype, device=head_weight.device)
    )
    lm_head.bias.requires_grad = True  # nn.Parameter defaults to True; kept explicit

    return model, tokenizer

In [4]:
def alpha_scaler(data, alpha, decimals=3):
    """Divide ``data`` by its ``alpha``-th percentile and round the result.

    Args:
        data: array-like of numbers.
        alpha: percentile (0-100) used as the scale, as in ``np.percentile``.
        decimals: number of decimal places kept by ``np.round``.

    Returns:
        numpy array of rescaled, rounded values.

    Raises:
        ValueError: if the percentile is zero or not finite. Dividing by it
            would otherwise produce inf/nan values that are silently encoded
            as text downstream.
    """
    values = np.asarray(data)
    scale = np.percentile(values, alpha)
    if not np.isfinite(scale) or scale == 0:
        raise ValueError(f"alpha_scaler: {alpha}th percentile is {scale}; cannot rescale by it")
    return np.round(values / scale, decimals=decimals)

In [5]:
def encoding(prey, predator):
    """Serialise two aligned series as text: ``a0,b0;a1,b1;...``.

    Each time step becomes ``prey,predator`` (values rendered with ``str``),
    and steps are joined with ``;``.
    """
    paired = np.column_stack((prey, predator))
    steps = (",".join(str(value) for value in row) for row in paired)
    return ";".join(steps)

def decoding(data):
    """Parse text produced by ``encoding`` back into (prey, predator) arrays.

    Steps are separated by ``;`` and blank steps (e.g. from a trailing ``;``)
    are skipped. Each remaining step must contain exactly two comma-separated
    numbers.

    Returns:
        (prey, predator) as float arrays. Both are empty for an empty input.

    Raises:
        ValueError: if a step does not contain exactly two values, or a value
            is not a number.
    """
    rows = []
    for step in data.split(";"):
        if not step.strip():
            continue
        fields = step.split(",")
        if len(fields) != 2:
            raise ValueError(
                f"decoding: expected 2 comma-separated values per step, got {len(fields)} in {step!r}"
            )
        rows.append([float(field) for field in fields])

    # reshape keeps a (0, 2) shape for empty input so the column slices below work.
    decoded = np.array(rows, dtype=float).reshape(-1, 2)
    return decoded[:, 0], decoded[:, 1]

In [6]:
model, tokenizer = load_qwen()

def process_data(system_id=0, points=1000, alpha=40, decimals=3):
    """Load one system, rescale both series, encode them as text and tokenize.

    Returns:
        (tokenized_data, encoded_text, combined, times). ``combined`` stacks
        the columns [prey, predator, scaled_prey, scaled_predator].
    """
    raw_prey, raw_predator, times = get_dataset(system_id=system_id, points=points)
    scaled_prey, scaled_predator = (
        alpha_scaler(series, alpha=alpha, decimals=decimals)
        for series in (raw_prey, raw_predator)
    )
    text = encoding(scaled_prey, scaled_predator)
    tokens = tokenizer(text, return_tensors="pt")
    combined = np.column_stack((raw_prey, raw_predator, scaled_prey, scaled_predator))
    return tokens, text, combined, times

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [7]:
# Tiny 3-step example: show the text encoding and the token ids it maps to.
example_tokenized_data, example_preprocessed_data = process_data(points=3)[:2]
print(f"Preprocessed data: {example_preprocessed_data}")
print(f"Tokenized results: {example_tokenized_data['input_ids'][0].tolist()}")

Preprocessed data: 1.303,1.413;1.016,1.058;0.936,0.766
Tokenized results: [16, 13, 18, 15, 18, 11, 16, 13, 19, 16, 18, 26, 16, 13, 15, 16, 21, 11, 16, 13, 15, 20, 23, 26, 15, 13, 24, 18, 21, 11, 15, 13, 22, 21, 21]


In [8]:
# Preprocess system 0 with the default settings and cache its tokens to disk.
tokenized_data, preprocessed_data, combined_data, times = process_data(
    system_id=0, points=1000, alpha=40, decimals=3
)
joblib.dump(tokenized_data, filename="tokenized_data.pkl")

['tokenized_data.pkl']

In [9]:
# joblib.load unpickles; only load files this notebook wrote itself (previous cell).
loaded_data = joblib.load("tokenized_data.pkl")
with torch.no_grad():
    # Generate exactly one new token. do_sample=False forces greedy decoding:
    # the checkpoint's generation config may enable sampling, which would make
    # reruns irreproducible.
    output = model.generate(
        loaded_data["input_ids"],
        attention_mask=loaded_data["attention_mask"],
        max_new_tokens=1,
        do_sample=False,
    )
# NOTE(review): output[0] is the prompt plus one new token. The "prediction"
# therefore mostly echoes the input, and the metrics below largely measure the
# text round-trip rather than forecasting skill.
# skip_special_tokens drops e.g. an end-of-text token that float() cannot parse.
prediction = tokenizer.decode(output[0], skip_special_tokens=True)

pred_prey, pred_predator = decoding(prediction)
actual_prey, actual_predator = combined_data[:, 2], combined_data[:, 3]

In [10]:
# sklearn metrics take (y_true, y_pred). R^2 is not symmetric, so the ground
# truth must come first (MSE/MAE are symmetric but use the same order for clarity).
print("MEAN PREY SQUARED ERROR:", mean_squared_error(actual_prey, pred_prey))
print("MEAN PREY ABSOLUTE ERROR:", mean_absolute_error(actual_prey, pred_prey))
print("R2 PREY SCORE:", r2_score(actual_prey, pred_prey))
print("MEAN PREDATOR SQUARED ERROR:", mean_squared_error(actual_predator, pred_predator))
print("MEAN PREDATOR ABSOLUTE ERROR:", mean_absolute_error(actual_predator, pred_predator))
print("R2 PREDATOR SCORE:", r2_score(actual_predator, pred_predator))

MEAN PREY SQUARED ERROR: 1.1251086102890977e-15
MEAN PREY ABSOLUTE ERROR: 2.6926994319831364e-08
R2 PREY SCORE: 0.9999999999999922
MEAN PREDATOR SQUARED ERROR: 9.336275755370366e-16
MEAN PREDATOR ABSOLUTE ERROR: 2.517938614277071e-08
R2 PREDATOR SCORE: 0.999999999999995


In [None]:
def _series_metrics(actual, predicted):
    """Return MSE, MAE and R^2 of ``predicted`` against ground truth ``actual``.

    sklearn's metrics take ``(y_true, y_pred)``. R^2 is not symmetric, so the
    argument order matters.
    """
    return {
        "mse": mean_squared_error(actual, predicted),
        "mae": mean_absolute_error(actual, predicted),
        "r2": r2_score(actual, predicted),
    }


def calculate_metrics(alpha=40, decimals=3, n_systems=1000, points=1000,
                      output_path="metrics_results.csv"):
    """Score the model's one-token continuation for each system and save a table.

    For each ``system_id`` in ``range(n_systems)``, the series is prepared with
    ``process_data``, the global ``model`` greedily generates one token, and the
    decoded text is compared with the scaled ground truth.

    NOTE(review): ``output[0]`` holds the prompt plus the single new token. These
    metrics therefore mostly measure how faithfully the prompt round-trips
    through the tokenizer, not forecasting skill.

    Args:
        alpha: percentile passed to ``alpha_scaler`` via ``process_data``.
        decimals: rounding passed to ``alpha_scaler`` via ``process_data``.
        n_systems: number of systems (ids ``0 .. n_systems-1``) to evaluate.
        points: time steps per system passed to ``process_data``.
        output_path: CSV file the table is written to; ``None`` skips writing.

    Returns:
        DataFrame with one row per system and columns prey_mse, prey_mae,
        prey_r2, predator_mse, predator_mae, predator_r2.
    """
    columns = [f"{series}_{metric}"
               for series in ("prey", "predator")
               for metric in ("mse", "mae", "r2")]
    records = []

    for system_id in tqdm.tqdm(range(n_systems), desc="calculating", unit="it"):
        tokenized_data, _encoded, combined_data, _times = process_data(
            system_id=system_id, points=points, alpha=alpha, decimals=decimals
        )

        with torch.no_grad():
            # Greedy decoding: the checkpoint's generation config may enable
            # sampling, which would make the table irreproducible.
            output = model.generate(
                tokenized_data["input_ids"],
                attention_mask=tokenized_data["attention_mask"],
                max_new_tokens=1,
                do_sample=False,
            )
        # Drop special tokens (e.g. end-of-text) that decoding() cannot parse.
        prediction = tokenizer.decode(output[0], skip_special_tokens=True)

        pred_prey, pred_predator = decoding(prediction)
        actual_prey, actual_predator = combined_data[:, 2], combined_data[:, 3]

        record = {}
        for series, actual, predicted in (("prey", actual_prey, pred_prey),
                                          ("predator", actual_predator, pred_predator)):
            for metric, value in _series_metrics(actual, predicted).items():
                record[f"{series}_{metric}"] = value
        records.append(record)

    # Build the frame once from records; explicit columns keep the order stable
    # (and give the right header even when n_systems == 0).
    metrics_df = pd.DataFrame(records, columns=columns)

    if output_path is not None:
        metrics_df.to_csv(output_path, index=False)

    return metrics_df

In [14]:
# Evaluate every system with the default settings; also writes metrics_results.csv.
# The recorded run took about an hour (see the progress bar output).
metrics_df = calculate_metrics()

calculating: 100%|██████████| 1000/1000 [1:01:32<00:00,  3.69s/iteration]


In [15]:
# One row per evaluated system_id; columns are MSE / MAE / R^2 for prey and predator.
metrics_df

Unnamed: 0,prey_mse,prey_mae,prey_r2,predator_mse,predator_mae,predator_r2
0,1.125109e-15,2.692699e-08,1.0,9.336276e-16,2.517939e-08,1.000000
1,8.248190e-16,2.422094e-08,1.0,6.399947e-07,8.002496e-05,0.999985
2,8.022380e-16,2.317905e-08,1.0,9.489099e-16,2.531528e-08,1.000000
3,4.945630e-15,4.566669e-08,1.0,2.583567e-15,3.705502e-08,1.000000
4,7.596991e-16,2.123952e-08,1.0,1.325676e-14,4.559040e-08,1.000000
...,...,...,...,...,...,...
995,9.258569e-16,2.269864e-08,1.0,1.346868e-15,2.665699e-08,1.000000
996,1.078725e-15,2.629519e-08,1.0,1.001742e-15,2.654076e-08,1.000000
997,2.132516e-15,3.430128e-08,1.0,1.401789e-15,2.796412e-08,1.000000
998,2.161327e-15,3.474712e-08,1.0,3.677341e-15,4.023314e-08,1.000000


In [16]:
# def calculate_flops(batch_size, sequence_length, embedding_dim, num_heads, num_layers):
#     token_embedding = batch_size * sequence_length * (embedding_dim * 2 - 1)
    
#     attention = (batch_size * sequence_length * (embedding_dim * 2 - 1)*3 + 
#                  batch_size * num_heads * sequence_length**2 * 2 )