In [1]:
import csv
import difflib
import os
import re
from collections import Counter
from datetime import datetime

import pandas as pd

## Comparison of strings 
This tool can be used to compare the ODEs extracted for the same model across different prompts or architectures.

> `s1` is the original string, `s2` is the extracted string.

In [2]:
# Version tag of the extraction run to evaluate; edit by hand before running.
version = "002"

# Directory that holds both TSV files.
data_path = '.'

# Build both file locations first so the naming scheme is visible in one place.
correct_path = os.path.join(data_path, 'correct_eqs_list.tsv')
extracted_path = os.path.join(data_path, f'extracted_eqs_VERSION{version}.tsv')

# Both files are tab-separated.
df_correct = pd.read_csv(correct_path, sep='\t')
df_extracted = pd.read_csv(extracted_path, sep='\t')

In [4]:
def tokenize(s):
    """Split ``s`` into whitespace-separated tokens.

    Args:
        s: The string to split.

    Returns:
        A list of the non-empty tokens in ``s``. Leading or trailing
        whitespace does not produce empty tokens, and an empty or
        whitespace-only string yields ``[]``.
    """
    # str.split() with no separator splits on runs of whitespace and drops
    # empty strings, so no '' tokens leak into the diff. It also avoids the
    # invalid '\s' escape that the former non-raw regex literal contained.
    return s.split()
def untokenize(ts):
    """Join the tokens in ``ts`` into one string, separated by single spaces."""
    separator = ' '
    return separator.join(ts)
        
def equalize(s1, s2):
    """Align two strings token by token, padding gaps with underscores.

    Both strings are tokenized and matched with ``difflib.SequenceMatcher``.
    Wherever one side has tokens the other lacks, the other side receives a
    run of underscores of the same length as each unmatched token, so the two
    results line up when printed one above the other.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        A ``(aligned_s1, aligned_s2)`` tuple of space-joined token strings.
    """
    tokens_a = tokenize(s1)
    tokens_b = tokenize(s2)
    out_a, out_b = [], []

    # End offsets (exclusive) of the previous matching block in each sequence.
    end_a = end_b = 0
    matcher = difflib.SequenceMatcher(a=tokens_a, b=tokens_b)
    for block in matcher.get_matching_blocks():
        only_in_a = tokens_a[end_a:block.a]
        only_in_b = tokens_b[end_b:block.b]

        # Tokens present only in s1: keep them on side A, pad side B.
        out_b.extend('_' * len(tok) for tok in only_in_a)
        out_a.extend(only_in_a)

        # Tokens present only in s2: pad side A, keep them on side B.
        out_a.extend('_' * len(tok) for tok in only_in_b)
        out_b.extend(only_in_b)

        # The matching run itself is copied to both sides.
        out_a.extend(tokens_a[block.a:block.a + block.size])
        out_b.extend(tokens_b[block.b:block.b + block.size])

        end_a = block.a + block.size
        end_b = block.b + block.size

    return untokenize(out_a), untokenize(out_b)

def insert_newlines(string, every=64, window=10):
    """Break ``string`` into consecutive chunks of at most ``every`` characters.

    Despite the name, no newline characters are inserted; the chunks are
    returned as a list. For each chunk the cut point starts at ``every`` and
    backs off one character at a time, by at most ``window`` characters,
    until the chunk ends in a space. If no space is found the chunk is cut at
    the back-off limit.

    Args:
        string: Text to split.
        every: Maximum chunk length; must be at least 1.
        window: Maximum number of characters to back off while looking for a
            space to cut after.

    Returns:
        List of chunks whose concatenation equals ``string``.

    Raises:
        ValueError: If ``every`` is smaller than 1.
    """
    if every < 1:
        raise ValueError("every must be at least 1")

    # Never back off below one character. Without this floor a
    # ``window >= every`` lets the cut point reach 0 (or go negative), which
    # yields an empty chunk and therefore an endless loop.
    shortest = max(every - window, 1)

    chunks = []
    remaining = string
    while remaining:
        if len(remaining) > every:
            cut = every
            while remaining[cut - 1] != ' ' and cut > shortest:
                cut -= 1
        else:
            # The tail fits in one chunk.
            cut = len(remaining)
        chunks.append(remaining[:cut])
        remaining = remaining[cut:]
    return chunks

def show_comparison(s1, s2, width=40, margin=10, sidebyside=True, compact=False):
    """Print an aligned comparison of two strings.

    The strings are first aligned with ``equalize``.

    Args:
        s1: First string.
        s2: Second string.
        width: Column width used in side-by-side mode.
        margin: Back-off window passed to ``insert_newlines``.
        sidebyside: If true, print wrapped rows as ``left | right | ``;
            otherwise print the two aligned strings on consecutive lines.
        compact: In side-by-side mode, remove every underscore and collapse
            runs of spaces before padding each cell.
    """
    aligned_a, aligned_b = equalize(s1, s2)

    if not sidebyside:
        print(aligned_a)
        print(aligned_b)
        return

    rows_a = insert_newlines(aligned_a, width, margin)
    rows_b = insert_newlines(aligned_b, width, margin)

    def render(cell):
        # Compact mode strips underscores and squeezes the spaces left behind.
        if compact:
            cell = re.sub(' +', ' ', cell.replace('_', ''))
        return cell.ljust(width)

    for row in range(len(rows_a)):
        print(render(rows_a[row]) + ' | ' + render(rows_b[row]) + ' | ')

def find_errors(s1, s2):
    """List token-level differences between two strings.

    Args:
        s1: Reference string.
        s2: String compared against ``s1``.

    Returns:
        A list of human-readable messages, in a deterministic order:

        * ``Missing token: '...'`` once for every occurrence of a token that
          ``s1`` has more often than ``s2`` (in order of first appearance in
          ``s1``);
        * ``Extra token: '...'`` once for every occurrence of a token that
          ``s2`` has more often than ``s1`` (in order of first appearance in
          ``s2``);
        * ``Changed: 'old' → 'new'`` for tokens paired positionally inside
          each ``replace`` opcode reported by ``difflib.SequenceMatcher``.
    """
    tokens1 = tokenize(s1)
    tokens2 = tokenize(s2)
    counts1 = Counter(tokens1)
    counts2 = Counter(tokens2)

    errors = []

    # Multiset difference: unlike a plain set difference this also catches a
    # token that was dropped in one place but still occurs elsewhere.
    missing = counts1 - counts2
    for token in dict.fromkeys(tokens1):  # unique tokens, first-seen order
        errors.extend([f"Missing token: '{token}'"] * missing[token])

    extra = counts2 - counts1
    for token in dict.fromkeys(tokens2):
        errors.extend([f"Extra token: '{token}'"] * extra[token])

    matcher = difflib.SequenceMatcher(None, tokens1, tokens2)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'replace':
            # zip() stops at the shorter side, so a token is never paired
            # with one that lies outside the replaced span.
            for old, new in zip(tokens1[i1:i2], tokens2[j1:j2]):
                errors.append(f"Changed: '{old}' → '{new}'")

    return errors

def count_mismatches(eq1, eq2):
    """Count the underscore padding characters in two aligned strings.

    Only whitespace-separated tokens made up entirely of underscores are
    counted, so underscores that belong to identifiers such as ``lambda_`` or
    ``gamma_2`` no longer inflate the result.

    Args:
        eq1: First aligned string (e.g. the first value returned by
            ``equalize``).
        eq2: Second aligned string.

    Returns:
        Total number of underscore characters found in padding tokens of both
        strings.
    """
    def padding_chars(aligned):
        # NOTE(review): a genuine token consisting only of underscores cannot
        # be told apart from padding and is still counted.
        return sum(len(tok) for tok in aligned.split() if not tok.strip('_'))

    return padding_chars(eq1) + padding_chars(eq2)

def normalize_equation(s):
    """Canonicalize the spacing of an equation string.

    Line breaks and spacing differences are not extraction errors, so both
    strings are brought to the same layout before they are compared.

    Args:
        s: Equation text.

    Returns:
        ``s`` with all whitespace runs collapsed, no whitespace around
        brackets, parentheses, ``<``, ``>`` and commas, exactly one space on
        each side of every ``+ - * / =`` character, and no leading or
        trailing whitespace.
    """
    # Collapse newlines, tabs and repeated spaces into single spaces.
    s = re.sub(r'\s+', ' ', s)
    s = re.sub(r'\s*,\s*', ', ', s)
    # Remove whitespace on both sides of square brackets.
    s = re.sub(r'\s*\[\s*', '[', s)
    s = re.sub(r'\s*\]\s*', ']', s)
    # Remove whitespace around operators and punctuation (commas included) ...
    s = re.sub(r'\s*([+\-*/=()<>,])\s*', r'\1', s)
    # ... then pad every arithmetic/assignment character with single spaces.
    s = re.sub(r'([+\-*/=])', r' \1 ', s)
    return s.strip()


In [5]:
# Compare every reference model with its extracted counterpart and write a
# plain-text report followed by a summary.
output_file = f'str_comparison_results_version{version}.txt'

total_models = 0
total_mismatches = 0
models_with_errors = []

# The encoding is given explicitly because the error lines produced by
# find_errors contain '→', which a platform-dependent default encoding may
# not be able to represent.
with open(output_file, 'w', encoding='utf-8') as f:
    for idx, row in df_correct.iterrows():
        model_name = row['model']
        s1 = normalize_equation(str(row['correct_eqs']))

        matching = df_extracted[df_extracted['model'] == model_name]
        if matching.empty:
            f.write(f"\nModel {model_name} not found in extracted equations!\n")
            continue

        total_models += 1
        # There should be only one row per model; if there are several,
        # the first one is used.
        s2 = normalize_equation(str(matching.iloc[0]['extracted_eqs']))

        f.write(f"\n{'='*80}\n")
        f.write(f"Model: {model_name}\n")
        f.write(f"{'='*80}\n")
        f.write("Correct equation:\n")
        f.write(s1 + "\n\n")
        f.write("Extracted equation:\n")
        f.write(s2 + "\n\n")
        f.write('Above-below comparison\n')
        f.write('-'*80 + '\n')

        eq1, eq2 = equalize(s1, s2)
        f.write(eq1 + '\n')
        f.write(eq2 + '\n')
        f.write('\n')

        mismatch_count = count_mismatches(eq1, eq2)
        f.write(f"\nNumber of mismatches: {mismatch_count}\n")
        total_mismatches += mismatch_count

        errors = find_errors(s1, s2)
        if errors:
            models_with_errors.append((model_name, len(errors)))
            f.write(f"\nErrors found ({len(errors)}):\n")
            for error in errors:
                f.write(f"  - {error}\n")
        else:
            f.write("\nNo token-level errors found (equations may differ only in spacing/formatting)\n")

        # Echo the aligned pair in the notebook output as well.
        print(f"\nComparing model: {model_name}")
        print('Above-below comparison')
        print('-'*80)
        show_comparison(s1, s2, sidebyside=False)

    # Summary, written through the same handle instead of reopening the
    # file in append mode.
    f.write(f"\n\n{'='*80}\n")
    f.write("SUMMARY\n")
    f.write(f"{'='*80}\n")
    f.write(f"Total models compared: {total_models}\n")
    f.write(f"Total mismatches found: {total_mismatches}\n")
    f.write(f"Models with errors: {len(models_with_errors)}\n\n")

    if models_with_errors:
        f.write("Models with errors (model_name, error_count):\n")
        for model, count in models_with_errors:
            f.write(f"  - {model}: {count} errors\n")

print(f"\nResults saved to: {output_file}")


Comparing model: BIOMD0000000955
Above-below comparison
--------------------------------------------------------------------------------
odes = [sympy.Eq(S(t).diff(t), _______ _____ _____ ___ - S(t) * (alpha * I(t) + beta * D(t) + gamma * A(t) + delta * R(t))),sympy.Eq(I(t).diff(t),S(t) * (alpha * I(t) + beta * D(t) + gamma * A(t) + delta * R(t)) - (epsilon + zeta + lambda_) * I(t)),sympy.Eq(D(t).diff(t),epsilon * I(t) - (eta + rho) * D(t)),sympy.Eq(A(t).diff(t),zeta * I(t) - (theta + mu + kappa) * A(t)),sympy.Eq(R(t).diff(t),eta * D(t) + theta * A(t) - (nu + xi) * R(t)),sympy.Eq(T(t).diff(t),mu * A(t) + nu * R(t) - (sigma + tau) * T(t)),sympy.Eq(H(t).diff(t),lambda_ * I(t) + rho * D(t) + kappa * A(t) + xi * R(t) + sigma * T(t)),sympy.Eq(E(t).diff(t),tau * T(t))] _________ ___________ ____ ________ ____ _______ ____________ _____ ____ ____ ______
____ _ _______________________ FAILED: Error code: 429 - ____ _ ______ _ ____ _ ____ _ ____ _ _____ _ ____ _ _____ _ _______________________

In [None]:
# If you want to run only one comparison, do it manually like this:

# original_str = """odes = [
#     sympy.Eq(S(t).diff(t), - beta * S(t) * I(t) / N),
#     sympy.Eq(E(t).diff(t), beta * S(t) * I(t) / N - alpha * E(t)),
#     sympy.Eq(I(t).diff(t), alpha * E(t) - gamma * I(t)),
#     sympy.Eq(R(t).diff(t), gamma * I(t))
# ]"""

# extracted_str = """odes = [
#     sympy.Eq(S(t).diff(t), - beta * S(t) / N),
#     sympy.Eq(I(t).diff(t), beta * S(t) * I(t) / N - alpha * E(t)),
#     sympy.Eq(I(t).diff(t), alpha * E(t) - gamma_2 * I(t)),
#     sympy.Eq(R(t).diff(t), gamma * I(t))
# ]"""

# print(original_str)
# print(extracted_str)

# print()
# print('Above-below comparison')
# print('-------------------------------------------------------------------------------------')
# show_comparison(original_str, extracted_str, sidebyside=False)