# Similarity Matching Between Test and Train+Val Datasets

This notebook loads the train+val and test datasets (stored as pickle files) and computes similarity for each test battery based on features that are computed solely from `trimmed_q_d_n`.

The features used are:

- **slope_last_k_cycles** for k in [10, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
- **mean_grad_last_k_cycles** for k in [10, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
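- **cycle**

Note: `trimmed_q_d_n_avg` and `total_cycles` are not currently part of the feature list built in the code below; only the windowed features above and `cycle` are matched.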

For each test battery and each such feature, the notebook computes the absolute difference to every train+val battery that has the same feature. It then sorts the results (smallest difference first) and keeps the top `top_n` matches (`top_n = 10` in the matching cell). The `_SCORE` columns hold that absolute difference, so lower means more similar. The final CSV file contains one row per feature per test battery with the following columns:

```
TEST_BATTERY_ID, TEST_BATTERY_QUERY_FEATURE, TEST_BATTERY_QUERY_VALUE, 
    1_Most_Similar_Battery_ID, 1_Most_Similar_Battery_ID_SCORE, 
    2_Most_Similar_Battery_ID, 2_Most_Similar_Battery_ID_SCORE, 
    ... up to {top_n}_Most_Similar_Battery_ID, {top_n}_Most_Similar_Battery_ID_SCORE
```

In [1]:
import pickle
import csv

# Locations of the preprocessed pickles -- edit these for your environment.
train_val_file = "/home/jaf/battery-lifespan-kg/resources/processed/processed_for_kg_v2.pkl"  # Update with the actual path
test_file = "/home/jaf/battery-lifespan-kg/resources/processed/processed_test.pkl"            # Update with the actual path


def _load_pickle(path):
    """Deserialize and return the single object stored in the pickle file at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Presumably each object maps battery ID -> per-battery feature dict
# (the matching cell iterates `.items()` and looks up feature names) -- TODO confirm.
train_val_data = _load_pickle(train_val_file)
test_data = _load_pickle(test_file)

print(f"Loaded {len(train_val_data)} batteries from train+val dataset.")
print(f"Loaded {len(test_data)} batteries from test dataset.")

Loaded 83 batteries from train+val dataset.
Loaded 40 batteries from test dataset.


In [2]:
# Window sizes (number of trailing cycles) for the windowed trimmed_q_d_n features.
k_values = [10, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]

# For every window size the slope feature comes first, then the mean-gradient
# feature; "cycle" is appended last as the only non-windowed feature.
feature_names = [
    name
    for k in k_values
    for name in (f"slope_last_{k}_cycles", f"mean_grad_last_{k}_cycles")
]
feature_names.append("cycle")

print("Features to be matched:", feature_names)

Features to be matched: ['slope_last_10_cycles', 'mean_grad_last_10_cycles', 'slope_last_50_cycles', 'mean_grad_last_50_cycles', 'slope_last_100_cycles', 'mean_grad_last_100_cycles', 'slope_last_200_cycles', 'mean_grad_last_200_cycles', 'slope_last_300_cycles', 'mean_grad_last_300_cycles', 'slope_last_400_cycles', 'mean_grad_last_400_cycles', 'slope_last_500_cycles', 'mean_grad_last_500_cycles', 'slope_last_600_cycles', 'mean_grad_last_600_cycles', 'slope_last_700_cycles', 'mean_grad_last_700_cycles', 'slope_last_800_cycles', 'mean_grad_last_800_cycles', 'slope_last_900_cycles', 'mean_grad_last_900_cycles', 'slope_last_1000_cycles', 'mean_grad_last_1000_cycles', 'cycle']


In [3]:
# For each test battery and each feature, rank the train+val batteries by the
# absolute difference of that feature's value and keep the `top_n` closest.
# The score written to the CSV is that absolute difference (lower = more similar).
import math

top_n = 10  # number of nearest train+val batteries kept per (test battery, feature)
rows = []


def _is_nan(value):
    """Return True if ``value`` is a scalar NaN (Python or NumPy float); False otherwise."""
    try:
        return math.isnan(value)
    except TypeError:
        return False


for test_bat_id, test_features in test_data.items():
    for feature in feature_names:
        # Only features present on the test battery can be matched.
        if feature not in test_features:
            continue
        test_value = test_features[feature]

        # Absolute difference to every train+val battery that also has this feature.
        similarities = [
            (train_bat_id, abs(test_value - train_features[feature]))
            for train_bat_id, train_features in train_val_data.items()
            if feature in train_features
        ]

        # Smallest difference first. NaN compares False against everything, which
        # can leave a plain sort mis-ordered, so NaN differences are pushed to the end.
        similarities.sort(key=lambda pair: (_is_nan(pair[1]), pair[1]))

        # Keep at most top_n matches (fewer if fewer batteries have this feature).
        top_matches = similarities[:top_n]

        # Row layout: [TEST_BATTERY_ID, TEST_BATTERY_QUERY_FEATURE, TEST_BATTERY_QUERY_VALUE,
        #              1_Most_Similar_Battery_ID, 1_Most_Similar_Battery_ID_SCORE, ...,
        #              {top_n}_Most_Similar_Battery_ID, {top_n}_Most_Similar_Battery_ID_SCORE]
        row = [test_bat_id, feature, test_value]
        for match_id, score in top_matches:
            row.extend([match_id, score])
        # Pad short rows so every row has exactly as many fields as the header.
        row.extend(["", ""] * (top_n - len(top_matches)))
        rows.append(row)

# CSV header matching the row layout above.
header = ["TEST_BATTERY_ID", "TEST_BATTERY_QUERY_FEATURE", "TEST_BATTERY_QUERY_VALUE"]
for i in range(1, top_n + 1):
    header.extend([f"{i}_Most_Similar_Battery_ID", f"{i}_Most_Similar_Battery_ID_SCORE"])

# Write the results to a CSV file
output_csv_file = "/home/jaf/battery-lifespan-kg/resources/output_similarity.csv"  # Update with the desired output path
with open(output_csv_file, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(header)
    writer.writerows(rows)

print(f"CSV file saved at {output_csv_file}")

CSV file saved at /home/jaf/battery-lifespan-kg/resources/output_similarity.csv


# Use only the test samples with an existing label (a matching charging policy) in the dataset

In [4]:
# Charging policy of each battery, keyed by battery ID. The ID prefix (b1/b2/b3)
# groups the batteries; b3 policies carry an extra "-newstructure" suffix.
battery_id_to_charging_policy = {
    'b1c0': '3.6C(80%)-3.6C',
    'b1c1': '3.6C(80%)-3.6C',
    'b1c2': '3.6C(80%)-3.6C',
    'b1c3': '4C(80%)-4C',
    'b1c4': '4C(80%)-4C',
    'b1c5': '4.4C(80%)-4.4C',
    'b1c6': '4.8C(80%)-4.8C',
    'b1c7': '4.8C(80%)-4.8C',
    'b1c9': '5.4C(40%)-3.6C',
    'b1c11': '5.4C(50%)-3C',
    'b1c14': '5.4C(60%)-3C',
    'b1c15': '5.4C(60%)-3C',
    'b1c16': '5.4C(60%)-3.6C',
    'b1c17': '5.4C(60%)-3.6C',
    'b1c18': '5.4C(70%)-3C',
    'b1c19': '5.4C(70%)-3C',
    'b1c20': '5.4C(80%)-5.4C',
    'b1c21': '5.4C(80%)-5.4C',
    'b1c23': '6C(30%)-3.6C',
    'b1c24': '6C(40%)-3C',
    'b1c25': '6C(40%)-3C',
    'b1c26': '6C(40%)-3.6C',
    'b1c27': '6C(40%)-3.6C',
    'b1c28': '6C(50%)-3C',
    'b1c29': '6C(50%)-3C',
    'b1c30': '6C(50%)-3.6C',
    'b1c31': '6C(50%)-3.6C',
    'b1c32': '6C(60%)-3C',
    'b1c33': '6C(60%)-3C',
    'b1c34': '7C(30%)-3.6C',
    'b1c35': '7C(30%)-3.6C',
    'b1c36': '7C(40%)-3C',
    'b1c37': '7C(40%)-3C',
    'b1c38': '7C(40%)-3.6C',
    'b1c39': '7C(40%)-3.6C',
    'b1c40': '8C(15%)-3.6C',
    'b1c41': '8C(15%)-3.6C',
    'b1c42': '8C(25%)-3.6C',
    'b1c43': '8C(25%)-3.6C',
    'b1c44': '8C(35%)-3.6C',
    'b1c45': '8C(35%)-3.6C',
    'b2c0': '1C(4%)-6C',
    'b2c1': '2C(10%)-6C',
    'b2c2': '2C(2%)-5C',
    'b2c3': '2C(7%)-5.5C',
    'b2c4': '3.6C(22%)-5.5C',
    'b2c5': '3.6C(2%)-4.85C',
    'b2c6': '3.6C(30%)-6C',
    'b2c10': '3.6C(9%)-5C',
    'b2c11': '4C(13%)-5C',
    'b2c12': '4C(31%)-5C',  # was '4C(31%)-5': trailing "C" was missing, unlike every other policy
    'b2c13': '4C(40%)-6C',
    'b2c14': '4C(4%)-4.85C',
    'b2c17': '4.4C(24%)-5C',
    'b2c18': '4.4C(47%)-5.5C',
    'b2c19': '4.4C(55%)-6C',
    'b2c20': '4.4C(8%)-4.85C',
    'b2c21': '4.65C(19%)-4.85C',
    'b2c22': '4.65C(44%)-5C',
    'b2c23': '4.65C(69%)-6C',
    'b2c24': '4.8C(80%)-4.8C',
    'b2c25': '4.8C(80%)-4.8C',
    'b2c26': '4.8C(80%)-4.8C',
    'b2c27': '4.9C(27%)-4.75C',
    'b2c28': '4.9C(61%)-4.5C',
    'b2c29': '4.9C(69%)-4.25C',
    'b2c30': '5.2C(10%)-4.75C',
    'b2c31': '5.2C(37%)-4.5C',
    'b2c32': '5.2C(50%)-4.25C',
    'b2c33': '5.2C(58%)-4C',
    'b2c34': '5.2C(66%)-3.5C',
    'b2c35': '5.2C(71%)-3C',
    'b2c36': '5.6C(25%)-4.5C',
    'b2c37': '5.6C(38%)-4.25C',
    'b2c38': '5.6C(47%)-4C',
    'b2c39': '5.6C(58%)-3.5C',
    'b2c40': '5.6C(5%)-4.75C',
    'b2c41': '5.6C(65%)-3C',
    'b2c42': '6C(20%)-4.5C',
    'b2c43': '6C(31%)-4.25C',
    'b2c44': '6C(40%)-4C',
    'b2c45': '6C(4%)-4.75C',
    'b2c46': '6C(52%)-3.5C',
    'b2c47': '6C(60%)-3C',
    'b3c0': '5C(67%)-4C-newstructure',
    'b3c1': '5.3C(54%)-4C-newstructure',
    'b3c3': '5.6C(36%)-4.3C-newstructure',
    'b3c4': '5.6C(19%)-4.6C-newstructure',
    'b3c5': '5.6C(36%)-4.3C-newstructure',
    'b3c6': '3.7C(31%)-5.9C-newstructure',
    'b3c7': '4.8C(80%)-4.8C-newstructure',
    'b3c8': '5C(67%)-4C-newstructure',
    'b3c9': '5.3C(54%)-4C-newstructure',
    'b3c10': '4.8C(80%)-4.8C-newstructure',
    'b3c11': '5.6C(19%)-4.6C-newstructure',
    'b3c12': '5.6C(36%)-4.3C-newstructure',
    'b3c13': '5.6C(19%)-4.6C-newstructure',
    'b3c14': '5.6C(36%)-4.3C-newstructure',
    'b3c15': '5.9C(15%)-4.6C-newstructure',
    'b3c16': '4.8C(80%)-4.8C-newstructure',
    'b3c17': '5.3C(54%)-4C-newstructure',
    'b3c18': '5.6C(19%)-4.6C-newstructure',
    'b3c19': '5.6C(36%)-4.3C-newstructure',
    'b3c20': '5C(67%)-4C-newstructure',
    'b3c21': '3.7C(31%)-5.9C-newstructure',
    'b3c22': '5.9C(60%)-3.1C-newstructure',
    'b3c24': '5C(67%)-4C-newstructure',
    'b3c25': '5.3C(54%)-4C-newstructure',
    'b3c26': '5.6C(19%)-4.6C-newstructure',
    'b3c27': '5.6C(36%)-4.3C-newstructure',
    'b3c28': '3.7C(31%)-5.9C-newstructure',
    'b3c29': '5.9C(15%)-4.6C-newstructure',
    'b3c30': '5.3C(54%)-4C-newstructure',
    'b3c31': '5.9C(60%)-3.1C-newstructure',
    'b3c33': '5C(67%)-4C-newstructure',
    'b3c34': '5.3C(54%)-4C-newstructure',
    'b3c35': '5.6C(19%)-4.6C-newstructure',
    'b3c36': '5.6C(36%)-4.3C-newstructure',
    'b3c38': '5C(67%)-4C-newstructure',
    'b3c39': '5.3C(54%)-4C-newstructure',
    'b3c40': '5.6C(19%)-4.6C-newstructure',
    'b3c41': '5.6C(36%)-4.3C-newstructure',
    'b3c44': '5.3C(54%)-4C-newstructure',
    'b3c45': '4.8C(80%)-4.8C-newstructure',
}


def _base_policy(policy: str) -> str:
    """Return ``policy`` with the "-newstructure" marker removed so b3 policies compare to b1/b2 ones."""
    return policy.replace("-newstructure", "")


# All charging policies used by b1*/b2* batteries (normalized; they currently
# carry no "-newstructure" suffix, but normalizing keeps the comparison symmetric).
valid_policies = {
    _base_policy(policy)
    for battery_id, policy in battery_id_to_charging_policy.items()
    if battery_id.startswith(("b1", "b2"))
}

# b3* batteries whose normalized policy also occurs among the b1*/b2* batteries.
matching_b3_keys = {
    battery_id
    for battery_id, policy in battery_id_to_charging_policy.items()
    if battery_id.startswith("b3") and _base_policy(policy) in valid_policies
}

print(matching_b3_keys)

{'b3c16', 'b3c7', 'b3c45', 'b3c10'}


In [5]:
import pandas as pd

# Full similarity table written by the matching cell.
df = pd.read_csv('/home/jaf/battery-lifespan-kg/resources/output_similarity.csv')

# Boolean mask: True where the row's test battery is one of `matching_b3_keys`
# (computed in the previous cell from the charging-policy table).
is_matched = df['TEST_BATTERY_ID'].isin(matching_b3_keys)
filtered_df = df.loc[is_matched]

# Persist only the matched rows; the DataFrame index is not written.
filtered_df.to_csv(
    '/home/jaf/battery-lifespan-kg/resources/matched_label_output_similarity.csv',
    index=False,
)

In [6]:
# For each matched test battery, dump its q_d_n sequence to <testset_dir>/<battery_id>.txt,
# one element per line; each line is written as "<value>," (trailing comma kept for
# compatibility with whatever reads these files).
import os

testset_dir = '/home/jaf/battery-lifespan-kg/resources/testset'
# Create the output directory up front so a fresh checkout does not fail with FileNotFoundError.
os.makedirs(testset_dir, exist_ok=True)

# Sets have no stable iteration order; sort so files are written and logged deterministically.
for matching_b3_key in sorted(matching_b3_keys):
    q_d_n = test_data[matching_b3_key]['q_d_n']
    out_path = os.path.join(testset_dir, f'{matching_b3_key}.txt')
    with open(out_path, 'w') as f:
        for item in q_d_n:
            f.write("%s,\n" % item)
    # Log the path that was actually written.
    print(f"Saved q_d_n for {matching_b3_key} to {out_path}")
        

Saved q_d_n for b3c16 to b3c16_q_d_n.txt
Saved q_d_n for b3c7 to b3c7_q_d_n.txt
Saved q_d_n for b3c45 to b3c45_q_d_n.txt
Saved q_d_n for b3c10 to b3c10_q_d_n.txt


# Generating questions

In [7]:
import os
import csv
from dotenv import load_dotenv

# --- Retrieve the OpenAI API key ---
load_dotenv()  # Load variables from a .env file (if present) into the environment
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # None if unset

# --- Initialize OpenAI LLM ---
from langchain.chat_models import ChatOpenAI
from langchain import PromptTemplate, LLMChain

llm = ChatOpenAI(
    model_name="gpt-4",
    openai_api_key=OPENAI_API_KEY,
    temperature=0.0  # adjust to your preference
)

# --- Define Prompt Template ---
prompt_template = PromptTemplate(
    # "battery_id" is declared so generate_sample_question can pass it, but the
    # template text below does not reference it.
    input_variables=["battery_id", "query_feature", "n"],
    template=(
        "You are helping to generate test questions for a Knowledge Graph RAG pipeline.\n"
        "For an unknown battery with query feature '{query_feature}', "
        "please generate {n} variations of a sample question. Each question should be on a separate line and similar to:\n"
        "'Can you find batteries similar to one with {query_feature}?'\n"
        "Make sure each variation is phrased slightly differently.\n"
        # process_csv splits the reply on line breaks, so ask for bare questions only.
        "Output only the {n} questions, one per line, with no numbering, bullets, surrounding quotes, or any other text."
    )
)

# --- Create LLMChain ---
llm_chain = LLMChain(llm=llm, prompt=prompt_template)

def generate_sample_question(battery_id: str, query_feature: str, n: int) -> str:
    """
    Ask the module-level ``llm_chain`` for ``n`` phrasings of a similarity question.

    All three arguments are forwarded to the chain as keyword inputs. The raw
    completion is returned with leading/trailing whitespace removed; callers are
    expected to split it into individual questions on line breaks.
    """
    chain_inputs = {"battery_id": battery_id, "query_feature": query_feature, "n": n}
    raw_text = llm_chain.run(**chain_inputs)
    return raw_text.strip()

def read_csv(input_file: str) -> list:
    """Load ``input_file`` as UTF-8 CSV and return its data rows as dicts keyed by the header row."""
    with open(input_file, encoding='utf-8', newline='') as handle:
        return [record for record in csv.DictReader(handle)]

def write_csv(rows: list, fieldnames: list, output_file: str):
    """Write dict ``rows`` to ``output_file`` as UTF-8 CSV: header first, columns ordered by ``fieldnames``."""
    with open(output_file, "w", encoding='utf-8', newline='') as handle:
        dict_writer = csv.DictWriter(handle, fieldnames=fieldnames)
        dict_writer.writeheader()
        for record in rows:
            dict_writer.writerow(record)

def _clean_question_line(line: str) -> str:
    """Strip a leading list marker ("1.", "2)", "-", "*", "•") and one pair of wrapping quotes from an LLM output line."""
    import re

    text = re.sub(r"^(?:\d+[.)]|[-*\u2022])\s+", "", line.strip())
    # The prompt shows its example question in quotes, so the LLM may echo them.
    if len(text) >= 2 and text[0] == text[-1] and text[0] in "\"'":
        text = text[1:-1].strip()
    return text


def process_csv(input_file: str, output_file: str, n: int):
    """
    Add LLM-generated sample questions to every row of a CSV file.

    Steps:
      - Read all rows of ``input_file`` (via ``read_csv``).
      - For each row, ask ``generate_sample_question`` for ``n`` variations using the
        row's TEST_BATTERY_ID and TEST_BATTERY_QUERY_FEATURE ("" if a column is missing).
      - Split the reply into one question per non-empty line, stripping list markers
        ("1.", "2)", "-", "*", "•") and wrapping quotes the LLM may add.
      - Store them in new columns SAMPLE_QUESTION_1 ... SAMPLE_QUESTION_n; missing
        variations become "" and surplus lines are ignored.
      - Write the updated rows to ``output_file`` (via ``write_csv``).

    Args:
        input_file: Path of the CSV to read.
        output_file: Path of the CSV to write.
        n: Number of question variations (and SAMPLE_QUESTION columns) per row.
    """
    rows = read_csv(input_file)

    # Output columns: the input columns (from the first row) followed by SAMPLE_QUESTION_1..n.
    fieldnames = list(rows[0].keys()) if rows else []
    question_columns = [f"SAMPLE_QUESTION_{i}" for i in range(1, n + 1)]
    fieldnames.extend(question_columns)

    for row in rows:
        sample_questions_text = generate_sample_question(
            row.get("TEST_BATTERY_ID", ""),
            row.get("TEST_BATTERY_QUERY_FEATURE", ""),
            n,
        )

        # One cleaned question per non-empty line.
        sample_questions = [
            cleaned
            for cleaned in (_clean_question_line(s) for s in sample_questions_text.splitlines())
            if cleaned
        ]

        # Assign each variation to its own column; pad with "" if the LLM returned fewer than n.
        for idx, key in enumerate(question_columns):
            row[key] = sample_questions[idx] if idx < len(sample_questions) else ""

    write_csv(rows, fieldnames, output_file)

  llm = ChatOpenAI(
  llm_chain = LLMChain(llm=llm, prompt=prompt_template)


In [8]:
BURN_YOUR_API_KEY_NOW = False  # Set to True to burn your money

# Guard: process_csv calls the LLM once per CSV row, so question generation only
# runs when explicitly enabled above.
if BURN_YOUR_API_KEY_NOW:
    input_csv, output_csv, n = (
        "/home/jaf/battery-lifespan-kg/resources/matched_label_output_similarity.csv",  # input: filtered similarity table
        "/home/jaf/battery-lifespan-kg/resources/matched_label_with_question.csv",      # output: input rows + SAMPLE_QUESTION_1..n
        3,                                                                              # question variations per row
    )
    process_csv(input_csv, output_csv, n)
    print(f"Processing complete. Output saved to {output_csv}")