In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from sklearn.model_selection import train_test_split

from gptchem.data import get_polymer_data
from gptchem.evaluator import PolymerKLDivBenchmark, polymer_string2performance, string_distances, get_inverse_polymer_metrics
from gptchem.formatter import InverseDesignFormatter

In [3]:
data = get_polymer_data()

In [4]:
formatter = InverseDesignFormatter(
    representation_column="string",
    property_columns=["deltaGmin"],
    property_names=["adsorption_energy"],
)

In [5]:
formatted = formatter(data)

In [6]:
polymer_string2performance(formatted.iloc[0]["label"])

{'monomer_squence': 'W-A-B-W-W-A-A-A-R-W-B-B-R-R-B-R',
 'composition': {'W': 4, 'A': 4, 'B': 4, 'R': 4},
 'smiles': '[W][Ta][Tr][W][W][Ta][Ta][Ta][R][W][Tr][Tr][R][R][Tr][R]',
 'prediction': array([-6.1970377], dtype=float32),
 'features':    head_tail_[W]  head_tail_[Tr]  head_tail_[Ta]  head_tail_[R]  \
 0              1               0               0              1   
 
    total_clusters  num_[W]  max_[W]  min_[W]  mean_[W]  num_[Tr]  ...   [W]  \
 0               4     0.25        2        2       2.0      0.25  ...  0.25   
 
    [Tr]  [Ta]   [R]  rel_shannon  length  total_solvent  std_solvent  \
 0  0.25  0.25  0.25          0.5      16            480     3.535534   
 
    total_surface  std_surface  
 0            400          5.0  
 
 [1 rows x 31 columns]}

In [39]:
kldiv_benchmark = PolymerKLDivBenchmark(data.iloc[80:], 20)

In [40]:
kldiv_benchmark.score(data.iloc[0:40])

0.5348989420975668

In [7]:
?string_distances

[0;31mSignature:[0m [0mstring_distances[0m[0;34m([0m[0mtraining_set[0m[0;34m:[0m [0mCollection[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m,[0m [0mquery_string[0m[0;34m:[0m [0mstr[0m[0;34m)[0m [0;34m->[0m [0mdict[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Calculate the distances between the query string and the training set.

Args:
    training_set (Collection[str]): The training set
    query_string (str): The query string

Returns:
    dict: A dictionary with the distances, the min, max, mean and the expected length

Example:
    >>> training_set = ["AAA", "BBB", "CCC"]
    >>> query_string = "BBB"
    >>> result = string_distances(training_set, query_string)
    assert result["NormalizedLevenshtein_min"] == 0.0
    assert result["NormalizedLevenshtein_max"] == 1.0
[0;31mFile:[0m      ~/git/kjappelbaum/gptchem/src/gptchem/evaluator.py
[0;31mType:[0m      function


In [8]:
string_distances(formatted.iloc[0:10]["label"], formatted.iloc[1]["label"])

{'Levenshtein_min': 0.0,
 'Levenshtein_max': 18.0,
 'Levenshtein_mean': 12.4,
 'Levenshtein_std': 4.543126676640219,
 'NormalizedLevenshtein_min': 0.0,
 'NormalizedLevenshtein_max': 0.5142857142857142,
 'NormalizedLevenshtein_mean': 0.3734562211981567,
 'NormalizedLevenshtein_std': 0.13521010617991372,
 'LongestCommonSubsequence_min': 0.0,
 'LongestCommonSubsequence_max': 26.0,
 'LongestCommonSubsequence_mean': 19.6,
 'LongestCommonSubsequence_std': 6.916646586316233}

In [44]:
metrics = get_inverse_polymer_metrics(
    formatted.iloc[0:50]["label"] + "rfer",
    df_test=formatted.iloc[50:100],
    df_train=formatted.iloc[100:200],
)

In [36]:
metrics.keys()

dict_keys(['composition_mismatches', 'summary_composition_mismatches', 'losses', 'kldiv_score', 'valid_smiles_fraction', 'valid_indices', 'valid_polymers', 'unique_smiles_fraction', 'novel_smiles_fraction', 'generated_sequences', 'predictions'])

In [38]:
metrics['summary_composition_mismatches']

{'min': 0.0, 'max': 0.0, 'mean': 0.0, 'expected_len': 0.0, 'found_len': 0.0}

In [39]:
metrics['losses']

[3.5533907031616163,
 3.2233195198886495,
 4.139636470855033,
 3.439842991328943,
 3.5715008269694,
 4.320588906003822,
 3.6292993269612577,
 4.652228306003829,
 3.4233172461025703,
 2.179405461844887,
 1.7596549320061579,
 1.1487882770331463,
 2.520249638629826,
 2.9433567582519693,
 3.387179476071511,
 2.294553752803542,
 1.4690893938666463,
 1.8402308838602721,
 1.6390341962849915,
 4.1102977641777905,
 2.198821489453124,
 2.1232323561197903,
 1.6224677561197964,
 0.26506864505411354,
 1.1799203102945954,
 4.815044626306145,
 4.785310266236717,
 4.265409357242152,
 4.999765903474243,
 4.638475530416509,
 4.341379093323438,
 3.429638762830262,
 3.1306268734205407,
 4.567679195343022,
 3.625339593174921,
 2.6224203779208715,
 3.6466782944383382,
 2.6124288796569832,
 2.7522656308559874,
 2.474759842627634,
 3.5771372157979293,
 1.9830966799750378,
 0.9942113092020737,
 3.7167064427212146,
 2.6137557488532153,
 3.6028628273050884,
 1.0396825048909504,
 1.621060321158855,
 1.59507601600

In [43]:
metrics['novel_smiles_fraction']

1.0

In [9]:
formatted.iloc[0:10]

Unnamed: 0,prompt,completion,label,representation
0,What is a molecule with adsorption_energy -7.5...,W-A-B-W-W-A-A-A-R-W-B-B-R-R-B-R@@@,W-A-B-W-W-A-A-A-R-W-B-B-R-R-B-R,[-7.535286244444447]
1,What is a molecule with adsorption_energy -7.3...,R-W-W-R-R-B-B-B-A-A-A-W-W-A-R-B@@@,R-W-W-R-R-B-B-B-A-A-A-W-W-A-R-B,[-7.270527222222221]
2,What is a molecule with adsorption_energy -6.4...,A-R-A-W-B-W-A-R-B-W-A-B-B-R-W-R@@@,A-R-A-W-B-W-A-R-B-W-A-B-B-R-W-R,[-6.416311311111111]
3,What is a molecule with adsorption_energy -6.7...,W-A-R-A-B-B-B-W-A-W-B-R-A-W-R-R@@@,W-A-R-A-B-B-B-W-A-W-B-R-A-W-R-R,[-6.684815644444439]
4,What is a molecule with adsorption_energy -6.6...,R-R-B-B-W-R-A-W-R-W-A-B-A-A-W-B@@@,R-R-B-B-W-R-A-W-R-W-A-B-A-A-W-B,[-6.606492355555552]
5,What is a molecule with adsorption_energy -6.2...,W-B-B-A-W-A-W-W-A-R-W-R-B-R-W-B-A-R@@@,W-B-B-A-W-A-W-W-A-R-W-R-B-R-W-B-A-R,[-6.210183644444443]
6,What is a molecule with adsorption_energy -6.6...,W-W-B-W-R-B-R-A-W-B-A-A-W-R-W-B-A-R@@@,W-W-B-W-R-B-R-A-W-B-A-A-W-R-W-B-A-R,[-6.616805355555557]
7,What is a molecule with adsorption_energy -5.6...,B-B-A-B-A-W-R-W-R-W-B-R-W-A-W-W-A-R@@@,B-B-A-B-A-W-R-W-R-W-B-R-W-A-W-W-A-R,[-5.5528067555555545]
8,What is a molecule with adsorption_energy -6.4...,R-W-R-A-R-A-B-W-W-W-B-B-W-W-A-B-A-R@@@,R-W-R-A-R-A-B-W-W-W-B-B-W-W-A-B-A-R,[-6.361648266666666]
9,What is a molecule with adsorption_energy -6.6...,B-B-W-W-B-W-R-A-A-B-A-W-A-R-R-R-W-W@@@,B-B-W-W-B-W-R-A-A-B-A-W-A-R-R-R-W-W,[-6.6253376444444445]
