In [1]:
# | default_exp output


In [61]:
# |export
import math
import re
from collections import Counter, defaultdict
from typing import Iterable, List, Optional, Tuple
import os
import joblib
import numpy as np
import pandas as pd
from nbdev.showdoc import *
from rdkit import Chem, DataStructs
from rdkit.Chem.Fingerprints import FingerprintMols
from sklearn.metrics import (max_error, mean_absolute_error,
                             mean_squared_error, r2_score)
from strsimpy.levenshtein import Levenshtein
from strsimpy.longest_common_subsequence import LongestCommonSubsequence
from strsimpy.normalized_levenshtein import NormalizedLevenshtein

from gpt3forchem.api_wrappers import extract_inverse_prediction, query_gpt3
from gpt3forchem.baselines import compute_fragprints
from gpt3forchem.data import POLYMER_FEATURES
from gpt3forchem.input import encode_categorical_value

from tqdm import tqdm
from loguru import logger

# Analyzing results

> Analyze the outputs of the models


To measure how different our outputs are from the input data, we'll use string distances.


In [3]:
# | export 

_DEFAULT_AGGREGATIONS =  [
        ("min", lambda x: np.min(x)),
        ("max", lambda x: np.max(x)),
        ("mean", lambda x: np.mean(x)),
        ("std", lambda x: np.std(x)),
    ]

def aggregate_array(array, aggregations: Optional[List[Tuple[str, callable]]]= None): 
    if aggregations is None:
        aggregations = _DEFAULT_AGGREGATIONS

    aggregated_array = {}
    for k,v in aggregations:
        aggregated_array[k] = v(array)
    return aggregated_array

In [4]:
aggregate_array(np.array([1,2,3,4,5]), aggregations=[("mean", lambda x: np.mean(x))])

{'mean': 3.0}

If no aggregation functions are specified, the default aggregation functions are used.

In [5]:
aggregate_array(np.array([1,2,3,4,5]))

{'min': 1, 'max': 5, 'mean': 3.0, 'std': 1.4142135623730951}

In [6]:
# |export
def string_distances(
    training_set: Iterable[str], # string representations of the compounds in the training set
    query_string: str # string representation of the compound to be queried
):

    distances = defaultdict(list)

    metrics = [
        ("Levenshtein", Levenshtein()),
        ("NormalizedLevenshtein", NormalizedLevenshtein()),
        ("LongestCommonSubsequence", LongestCommonSubsequence()),
    ]

    aggregations = [
        ("min", lambda x: np.min(x)),
        ("max", lambda x: np.max(x)),
        ("mean", lambda x: np.mean(x)),
        ("std", lambda x: np.std(x)),
    ]

    for training_string in training_set:
        for metric_name, metric in metrics:
            distances[metric_name].append(
                metric.distance(training_string, query_string)
            )

    aggregated_distances = {}

    for k, v in distances.items():
        for agg_name, agg_func in aggregations:
            aggregated_distances[f"{k}_{agg_name}"] = agg_func(v)

    return aggregated_distances


In [7]:
# |hide
training_set = ["AAA", "BBB", "CCC"]
query_string = "BBB"
result = string_distances(training_set, query_string)

assert result["NormalizedLevenshtein_min"] == 0.0
assert result["NormalizedLevenshtein_max"] == 1.0


In [None]:
cm.overall_stat

{'Overall ACC': 0.352,
 'Overall RACCU': 0.260694,
 'Overall RACC': 0.204088,
 'Kappa': 0.18583964056327834,
 'Gwet AC1': 0.205077201356521,
 'Bennett S': 0.18999999999999995,
 'Kappa Standard Error': 0.026835443670701672,
 'Kappa Unbiased': 0.12350231162739109,
 'Scott PI': 0.12350231162739109,
 'Kappa No Prevalence': -0.29600000000000004,
 'Kappa 95% CI': (0.13324217096870306, 0.23843711015785363),
 'Standard Error': 0.02135865164283551,
 '95% CI': (0.3101370427800424, 0.39386295721995757),
 'Chi-Squared': 'None',
 'Phi-Squared': 'None',
 'Cramer V': 'None',
 'Response Entropy': 1.3814056651434996,
 'Reference Entropy': 2.316058449955823,
 'Cross Entropy': 1.1144477741367746,
 'Joint Entropy': 3.4137961407287363,
 'Conditional Entropy': 1.0977376907729133,
 'Mutual Information': 0.28366797437058633,
 'KL Divergence': 'None',
 'Lambda B': 0.1941747572815534,
 'Lambda A': 0.1906005221932115,
 'Chi-Squared DF': 16,
 'Overall J': (0.7983837510803802, 0.15967675021607602),
 'Hamming Loss'

In [8]:
# |export

def is_valid_smiles(smiles: str) -> bool:
    """We say a SMILES is valid if RDKit can parse it."""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return False
        return True
    except:
        return False

In [9]:
is_valid_smiles('aba')

[20:14:37] SMILES Parse Error: syntax error while parsing: aba
[20:14:37] SMILES Parse Error: Failed parsing SMILES 'aba' for input: 'aba'


False

In [10]:
is_valid_smiles("CCC")

True

In [11]:
# |export
def is_string_in_training_data(string: str, training_data: Iterable[str]) -> bool:
    """Check if a string is in the training data.
    
    Note that this is not an exact check of a molecule is in the training data 
    as the model might in principle generate an equivalent, non-canonical SMILES.
    However, one might expect that if a model remembers the training data
    it will simple remember the canonical SMILES.
    """
    return string in training_data

In [12]:
is_string_in_training_data('a a hahah', ['a', 'b', 'c'])

False

In [13]:
is_string_in_training_data('a a hahah', ['a', 'b', 'c', 'a a hahah'])

True

In [14]:
# | export

def get_similarity_to_train_mols(smiles: str, train_smiles: List[str]) -> List[float]: 
    train_mols = [Chem.MolFromSmiles(x) for x in train_smiles]
    mol = Chem.MolFromSmiles(smiles)

    train_fps = [Chem.RDKFingerprint(x) for x in train_mols]
    fp = Chem.RDKFingerprint(mol)

    s = DataStructs.BulkTanimotoSimilarity(fp, train_fps)
    return s


In [15]:
get_similarity_to_train_mols('CCC', ['CCC', 'CCCNC', "N"])

[1.0, 0.25, 0.0]

In [16]:
aggregate_array(get_similarity_to_train_mols('CCC', ['CCC', 'CCCNC']))

{'min': 0.25, 'max': 1.0, 'mean': 0.625, 'std': 0.375}

In [17]:
#| export 

def extract_numeric_prediction(predictions: List[str], is_int: bool = True):
    converter = int if is_int else float
    converted = []
    for p in predictions:
        try:
            converted.append(converter(p))
        except:
            converted.append(np.nan)
    return converted

In [18]:
extract_numeric_prediction(['1', '2', '3', '4', '5', 'nknik'])

[1, 2, 3, 4, 5, nan]

In [42]:
# | export 

def get_continuos_binned_distance(prediction, bin, bins):
    in_bin = (prediction >= bins[bin][0]) & (prediction < bins[bin][1])
    if in_bin:
        loss = 0
    else:
        # compute the minimum distance to bin
        left_edge_distance = abs(prediction - bins[bin][0])
        right_edge_distance = abs(prediction - bins[bin][1])
        loss = min(left_edge_distance, right_edge_distance)
    return loss


## Polymers

> Code specific for the polymer test case


In [19]:
# |export


def convert2smiles(string):
    new_encoding = {"A": "[Ta]", "B": "[Tr]", "W": "[W]", "R": "[R]"}

    for k, v in new_encoding.items():
        string = string.replace(k, v)

    string = string.replace("-", "")

    return string


To train the model, we simply use single letters, without any special characters such as brackets.


In [20]:
convert2smiles("AWWRRA")


'[Ta][W][W][R][R][Ta]'

To get the composition from the prompt, we will check how often we find a given monomer in the string.


In [67]:
# |export
def get_num_monomer(string, monomer):
    num = re.findall(f"([\d]+) {monomer}", string)
    try:
        num = int(num[0])
    except Exception:
        num = 0
    return num


In [68]:
get_num_monomer("Polymer with 3 A, 5 B and 0 C", "A")


3

In [71]:
assert get_num_monomer('what is a polymer with large adsorption energy and 4 A, 4 B, 12 W, and 12 R?###', 'R') == 12

In [69]:
# |export
def get_prompt_compostion(prompt):
    composition = {}

    for monomer in ["R", "W", "A", "B"]:
        composition[monomer] = get_num_monomer(prompt, monomer)

    return composition


In [70]:
get_prompt_compostion('what is a polymer with large adsorption energy and 4 A, 4 B, 12 W, and 12 R?###')

{'R': 12, 'W': 12, 'A': 4, 'B': 4}

In [45]:
# |export


def get_target(string, target_name="adsorption", numerically_encoded=True):
    if numerically_encoded:
        num = re.findall(f"([\d+]) {target_name}", string)
        return int(num[0])
    else:
        val = re.findall(f"(very large|large|medium|small|very small) {target_name}", string)
        return val[0]

In [47]:
get_target('what is a polymer with large adsorption energy', numerically_encoded=False)

'large'

In [48]:
get_target('what is a polymer with very large adsorption energy', numerically_encoded=False)

'very large'

In [53]:
assert get_target('what is a polymer with very small adsorption energy', numerically_encoded=False) == 'very small'
assert get_target('what is a polymer with large adsorption energy and 8 A, 12 B, 12 W, and 8 R?###', numerically_encoded=False) == 'large'

In [55]:
# |export


def get_polymer_prompt_data(prompt, numerically_encoded=False):
    composition = get_prompt_compostion(prompt)

    return composition, get_target(prompt, numerically_encoded=numerically_encoded)


In [58]:
get_polymer_prompt_data('what is a polymer with 3 adsorption energy and 8 A, 8 B, 10 W, and 8 R?###', numerically_encoded=True)

({'R': 8, 'W': 0, 'A': 8, 'B': 8}, 3)

In [57]:
get_polymer_prompt_data('what is a polymer with large adsorption energy and 8 A, 12 B, 12 W, and 8 R?###', numerically_encoded=False) 

({'R': 8, 'W': 2, 'A': 8, 'B': 2}, 'large')

In [26]:
# |export


def get_polymer_completion_composition(string):
    parts = string.split("-")
    counts = Counter(parts)
    return dict(counts)


In [27]:
get_polymer_completion_composition('W-W-R-W-R-W-A-W-R-B-W-A-R-B-W-A-R-B-W-A-R-B-W-R-A-B-W-B-R-A-B')

{'W': 10, 'R': 8, 'A': 6, 'B': 7}

In [64]:
get_polymer_completion_composition('W-R-W-R-R-W-R-W-R-W-R-W-R-W-A-R-W-R-W-A-R-B-W-B-A-R-B-W-B-W-R-A-B')

{'W': 12, 'R': 12, 'A': 4, 'B': 5}

In [28]:
assert get_polymer_completion_composition('W-W') == {'W': 2}
assert get_polymer_completion_composition('W-R-A-B') == {'W': 1, 'R': 1, 'A': 1, 'B': 1}

Next, let's reuse the featurizer [we used in the original work on those polymers](https://www.nature.com/articles/s41467-021-22437-0).

In [29]:
# |export

# Copyright 2020 PyPAL authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Turn a Polymer SMILES into features"""


def featurize_many_polymers(smiless: list) -> pd.DataFrame:
    """Utility function that runs featurizaton on a
    list of linear polymer smiles and returns a dataframe"""
    features = []
    for smiles in smiless:
        pmsf = LinearPolymerSmilesFeaturizer(smiles)
        features.append(pmsf.featurize())
    return pd.DataFrame(features)


class LinearPolymerSmilesFeaturizer:
    """Compute features for linear polymers"""

    def __init__(self, smiles: str, normalized_cluster_stats: bool = True):
        self.smiles = smiles
        assert "(" not in smiles, "This featurizer does not work for branched polymers"
        self.characters = ["[W]", "[Tr]", "[Ta]", "[R]"]
        self.replacement_dict = dict(
            list(zip(self.characters, ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"]))
        )
        self.normalized_cluster_stats = normalized_cluster_stats
        self.surface_interactions = {"[W]": 30, "[Ta]": 20, "[Tr]": 30, "[R]": 20}
        self.solvent_interactions = {"[W]": 30, "[Ta]": 25, "[Tr]": 35, "[R]": 30}
        self._character_count = None
        self._balance = None
        self._relative_shannon = None
        self._cluster_stats = None
        self._head_tail_feat = None
        self.features = None

    @staticmethod
    def get_head_tail_features(string: str, characters: list) -> dict:
        """0/1/2 encoded feature indicating if the building block is at start/end of the polymer chain"""
        is_head_tail = [0] * len(characters)

        for i, char in enumerate(characters):
            if string.startswith(char):
                is_head_tail[i] += 1
            if string.endswith(char):
                is_head_tail[i] += 1

        new_keys = ["head_tail_" + char for char in characters]
        return dict(list(zip(new_keys, is_head_tail)))

    @staticmethod
    def get_cluster_stats(
        s: str, replacement_dict: dict, normalized: bool = True
    ) -> dict:  # pylint:disable=invalid-name
        """Statistics describing clusters such as [Tr][Tr][Tr]"""
        clusters = LinearPolymerSmilesFeaturizer.find_clusters(s, replacement_dict)
        cluster_stats = {}
        cluster_stats["total_clusters"] = 0
        for key, value in clusters.items():
            if value:
                cluster_stats["num" + "_" + key] = len(value)
                cluster_stats["total_clusters"] += len(value)
                cluster_stats["max" + "_" + key] = max(value)
                cluster_stats["min" + "_" + key] = min(value)
                cluster_stats["mean" + "_" + key] = np.mean(value)
            else:
                cluster_stats["num" + "_" + key] = 0
                cluster_stats["max" + "_" + key] = 0
                cluster_stats["min" + "_" + key] = 0
                cluster_stats["mean" + "_" + key] = 0

        if normalized:
            for key, value in cluster_stats.items():
                if "num" in key:
                    try:
                        cluster_stats[key] = value / cluster_stats["total_clusters"]
                    except ZeroDivisionError:
                        cluster_stats[key] = 0

        return cluster_stats

    @staticmethod
    def find_clusters(s: str, replacement_dict: dict) -> dict:  # pylint:disable=invalid-name
        """Use regex to find clusters"""
        clusters = re.findall(
            r"((\w)\2{1,})", LinearPolymerSmilesFeaturizer._multiple_replace(s, replacement_dict)
        )
        cluster_dict = dict(
            list(zip(replacement_dict.keys(), [[] for i in replacement_dict.keys()]))
        )
        inv_replacement_dict = {v: k for k, v in replacement_dict.items()}
        for cluster, character in clusters:
            cluster_dict[inv_replacement_dict[character]].append(len(cluster))

        return cluster_dict

    @staticmethod
    def _multiple_replace(s: str, replacement_dict: dict) -> str:  # pylint:disable=invalid-name
        for word in replacement_dict:
            s = s.replace(word, replacement_dict[word])
        return s

    @staticmethod
    def get_counts(smiles: str, characters: list) -> dict:
        """Count characters in SMILES string"""
        counts = [smiles.count(char) for char in characters]
        return dict(list(zip(characters, counts)))

    @staticmethod
    def get_relative_shannon(character_count: dict) -> float:
        """Shannon entropy of string relative to maximum entropy of a string of the same length"""
        counts = [c for c in character_count.values() if c > 0]
        length = sum(counts)
        probs = [count / length for count in counts]
        ideal_entropy = LinearPolymerSmilesFeaturizer._entropy_max(length)
        entropy = -sum([p * math.log(p) / math.log(2.0) for p in probs])

        return entropy / ideal_entropy

    @staticmethod
    def _entropy_max(length: int) -> float:
        "Calculates the max Shannon entropy of a string with given length"

        prob = 1.0 / length

        return -1.0 * length * prob * math.log(prob) / math.log(2.0)

    @staticmethod
    def get_balance(character_count: dict) -> dict:
        """Frequencies of characters"""
        counts = list(character_count.values())
        length = sum(counts)
        frequencies = [c / length for c in counts]
        return dict(list(zip(character_count.keys(), frequencies)))

    def _featurize(self):
        """Run all available featurization methods"""
        self._character_count = LinearPolymerSmilesFeaturizer.get_counts(
            self.smiles, self.characters
        )
        self._balance = LinearPolymerSmilesFeaturizer.get_balance(self._character_count)
        self._relative_shannon = LinearPolymerSmilesFeaturizer.get_relative_shannon(
            self._character_count
        )
        self._cluster_stats = LinearPolymerSmilesFeaturizer.get_cluster_stats(
            self.smiles, self.replacement_dict, self.normalized_cluster_stats
        )
        self._head_tail_feat = LinearPolymerSmilesFeaturizer.get_head_tail_features(
            self.smiles, self.characters
        )

        self.features = self._head_tail_feat
        self.features.update(self._cluster_stats)
        self.features.update(self._balance)
        self.features["rel_shannon"] = self._relative_shannon
        self.features["length"] = sum(self._character_count.values())
        solvent_interactions = sum(
            [
                [self.solvent_interactions[char]] * count
                for char, count in self._character_count.items()
            ],
            [],
        )
        self.features["total_solvent"] = sum(solvent_interactions)
        self.features["std_solvent"] = np.std(solvent_interactions)
        surface_interactions = sum(
            [
                [self.surface_interactions[char]] * count
                for char, count in self._character_count.items()
            ],
            [],
        )
        self.features["total_surface"] = sum(surface_interactions)
        self.features["std_surface"] = np.std(surface_interactions)

    def featurize(self) -> dict:
        """Run featurization"""
        self._featurize()
        return self.features


To evaluate the new polymers our model came up with, we have to build a model that predicts the target (the adsorption energy). We have done so in a separate module (`27_polymer_delta_g_model.ipynb`) and saved the output in the `models` directory.

In [30]:
# |export

def polymer_string2performance(string, model_dir = '../models'):
    # we need to perform a bunch of tasks here:
    # 1) Featurize
    # 2) Query the model

    DELTA_G_MODEL = joblib.load(os.path.join(model_dir, 'delta_g_model.joblib'))

    predicted_monomer_sequence = string.split("@")[0].strip()
    monomer_sq = re.findall("[(R|W|A|B)\-(R|W|A|B)]+", predicted_monomer_sequence)[0]
    composition = get_polymer_completion_composition(monomer_sq)
    smiles = convert2smiles(predicted_monomer_sequence)

    features = pd.DataFrame(featurize_many_polymers([smiles]))
    prediction = DELTA_G_MODEL.predict(features[POLYMER_FEATURES])
    return {
        "monomer_squence": monomer_sq,
        "composition": composition,
        "smiles": smiles,
        "prediction": prediction,
    }


In [35]:
polymer_string2performance('W-W-R-W-R-W-A-W-R-B-W-A-R-B-W-A-R-B-W-A-R-B-W-R-A-B-W-B-R-A-B')

{'monomer_squence': 'W-W-R-W-R-W-A-W-R-B-W-A-R-B-W-A-R-B-W-A-R-B-W-R-A-B-W-B-R-A-B',
 'composition': {'W': 10, 'R': 8, 'A': 6, 'B': 7},
 'smiles': '[W][W][R][W][R][W][Ta][W][R][Tr][W][Ta][R][Tr][W][Ta][R][Tr][W][Ta][R][Tr][W][R][Ta][Tr][W][Tr][R][Ta][Tr]',
 'prediction': array([-7.335449], dtype=float32)}

In [32]:
# |export


def composition_mismatch(composition: dict, found: dict):
    distances = []

    # We also might have the case the there are keys that the input did not contain
    all_keys = set(composition.keys()) & set(found.keys())

    expected_len = []
    found_len = []

    for key in all_keys:
        try:
            expected = composition[key]
        except KeyError:
            expected = 0
        expected_len.append(expected)
        try:
            f = found[key]
        except KeyError:
            f = 0
        found_len.append(f)

        distances.append(np.abs(expected - f))

    expected_len = sum(expected_len)
    found_len = sum(found_len)
    return {
        "distances": distances,
        "min": np.min(distances),
        "max": np.max(distances),
        "mean": np.mean(distances),
        "expected_len": expected_len,
        "found_len": found_len,
    }


In [65]:
composition_mismatch(
    {"A": 4, "B": 4, "R": 12, "W": 12},
    {'W': 12, 'R': 12, 'A': 4, 'B': 5}
)

{'distances': [0, 0, 0, 1],
 'min': 0,
 'max': 1,
 'mean': 0.25,
 'expected_len': 32,
 'found_len': 33}

In [50]:
# | export

def get_inverse_polymer_metrics(completion_texts, df_test, df_train, bins, max_num_train_sequences = 2000, numerically_encoded = False):
    losses = []
    composition_mismatches = []

    train_sequences = [polymer_string2performance(seq)["monomer_squence"] for seq in df_train["completion"]]
    print(f"Using {len(train_sequences)} training sequences")
    for i, row in tqdm(df_test.iterrows(), total=len(completion_texts)):
        if i < len(completion_texts):
            try:
                composition, bin = get_polymer_prompt_data(row["prompt"], numerically_encoded = numerically_encoded)
                completion_data = polymer_string2performance(completion_texts[i])
                bin = bin if numerically_encoded else encode_categorical_value(bin)
                loss = get_continuos_binned_distance(completion_data["prediction"][0], bin, bins)
                losses.append(loss)

                mm = composition_mismatch(composition, completion_data["composition"])
 
                distances = string_distances(
                    train_sequences[:max_num_train_sequences], completion_data["monomer_squence"]
                )
                mm.update(completion_data)
                mm.update(distances)
                mm.update({"loss": loss})
                composition_mismatches.append(mm)
            except Exception as e:
                logger.exception(f'Error in get_inverse_polymer_metrics {e}')
    return losses, pd.DataFrame(composition_mismatches)


In [33]:
# |export


def get_regression_metrics(
    y_true,  # actual values (ArrayLike)
    y_pred,  # predicted values (ArrayLike)
) -> dict:

    try:
        return {
            "r2": r2_score(y_true, y_pred),
            "max_error": max_error(y_true, y_pred),
            "mean_absolute_error": mean_absolute_error(y_true, y_pred),
            "mean_squared_error": mean_squared_error(y_true, y_pred),
        }
    except Exception:
        return {
            "r2": np.nan,
            "max_error": np.nan,
            "mean_absolute_error": np.nan,
            "mean_squared_error": np.nan,
        }


In [34]:
get_regression_metrics([1, 2, 3, 4, 5], [1, 2, 3, 4, 5])


{'r2': 1.0,
 'max_error': 0,
 'mean_absolute_error': 0.0,
 'mean_squared_error': 0.0}

## Photoswitches

Code specific for the photoswitch case study.


First, we'll have some wrapper around GPR models that predict for us the $\pi-\pi^*$ and $n-\pi^*$ transition energies. 
For simplicity, we'll just go via joblib files.

In [None]:
# |export
def _predict_photoswitch(smiles_string: str,pi_pi_star_model_file='../models/pi_pi_star_model.joblib', n_pi_star_model_file='../models/n_pi_star_model.joblib'):
    """Predicting for a single SMILES string. Not really efficient due to the I/O overhead in loading the model."""
    pi_pi_star_model = joblib.load(pi_pi_star_model_file)
    n_pi_star_model = joblib.load(n_pi_star_model_file)
    fragprints = compute_fragprints([smiles_string])
    return pi_pi_star_model.predict(fragprints)[0], n_pi_star_model.predict(fragprints)[0]

In [None]:
# |export
def predict_photoswitch(smiles: Iterable[str], pi_pi_star_model_file='../models/pi_pi_star_model.joblib', n_pi_star_model_file='../models/n_pi_star_model.joblib'): 
    """Predicting for a single SMILES string. Not really efficient due to the I/O overhead in loading the model."""
    if not isinstance(smiles, Iterable):
        smiles = [smiles]
    pi_pi_star_model = joblib.load(pi_pi_star_model_file)
    n_pi_star_model = joblib.load(n_pi_star_model_file)
    fragprints = compute_fragprints(smiles)
    return pi_pi_star_model.predict(fragprints), n_pi_star_model.predict(fragprints)

In [None]:
predict_photoswitch(['C1=CC=C(/N=N/C2=CC=C(NCCC#N)C=C2)C=C1'])

(array([[390.91004025]]), array([[446.54990223]]))

In [None]:
# |  export

_PI_PI_STAR_REGEX = r'pi-pi\* transition wavelength of ([.\d]+) nm'
_N_PI_STAR_REGEX = r'n-pi\* transition wavelength of ([.\d]+) nm'

def get_expected_wavelengths(prompt): 
    pi_pi_star_match = re.search(_PI_PI_STAR_REGEX, prompt)
    n_pi_star_match = re.search(_N_PI_STAR_REGEX, prompt)
    pi_pi_star = float(pi_pi_star_match.group(1)) if pi_pi_star_match else None
    n_pi_star = float(n_pi_star_match.group(1)) if n_pi_star_match else None
    return pi_pi_star, n_pi_star

In [None]:
get_expected_wavelengths('What is a molecule pi-pi* transition wavelength of 404.0 nm###')

(404.0, None)

In [None]:
get_expected_wavelengths('What is a molecule pi-pi* transition wavelength of 321.0 nm and n-pi* transition wavelength of 424.0 nm###')

(321.0, 424.0)

The full evaluation is wrapped upped in the function below. Note that sampling at temperature > 0 is associated with some randomness. In other works, people samples $k$ times and took the best prediction for analysis. In the function below, we do not do this; we only sample once.

Query multiple times to estimate the variance.

In [None]:
# | export
def test_inverse_photoswitch(
    prompt_frame, model, train_smiles, temperature, max_tokens: int = 80
):
    completions = query_gpt3(
        model, prompt_frame, max_tokens=max_tokens, temperature=temperature
    )
    predictions = np.array(
        [
            extract_inverse_prediction(completions, i)
            for i in range(len(completions["choices"]))
        ]
    )

    valid_smiles = [is_valid_smiles(smiles) for smiles in predictions]

    smiles_in_train = [
        is_string_in_training_data(smiles, train_smiles)
        for smiles in predictions[valid_smiles]
    ]

    expected_pi_pi_star, expected_n_pi_star = [], []

    for prompt in prompt_frame["prompt"].values:
        pi_pi_star, n_pi_star = get_expected_wavelengths(prompt)
        expected_pi_pi_star.append(pi_pi_star)
        expected_n_pi_star.append(n_pi_star)

    expected_pi_pi_star = np.array(expected_pi_pi_star)
    expected_n_pi_star = np.array(expected_n_pi_star)

    has_expected_n_pi_star = np.array(
        [n_pi_star is not None for n_pi_star in expected_n_pi_star]
    )

    try:
        predicted_pi_pi_star, predicted_n_pi_star = predict_photoswitch(
            predictions[valid_smiles]
        )

        predicted_pi_pi_star = predicted_pi_pi_star.flatten()
        predicted_n_pi_star = predicted_n_pi_star.flatten()

        pi_pi_star_metrics = get_regression_metrics(
            expected_pi_pi_star[valid_smiles],
            predicted_pi_pi_star,
        )

        mask_n_valid_smiles = [
            has_expected_n_pi_star[i] for i in range(len(valid_smiles)) if valid_smiles[i]
        ]
        n_pi_star_metrics = get_regression_metrics(
            expected_n_pi_star[valid_smiles & has_expected_n_pi_star],
            np.array(predicted_n_pi_star)[mask_n_valid_smiles],
        )

        error_pi_pi_star = np.abs(expected_pi_pi_star[valid_smiles] - predicted_pi_pi_star)
        error_n_pi_star = np.abs(
            expected_n_pi_star[valid_smiles & has_expected_n_pi_star]
            - np.array(predicted_n_pi_star)[mask_n_valid_smiles]
        )

        min_error_pi_pi_star = predictions[valid_smiles][np.argmin(error_pi_pi_star)]
        min_error_n_pi_star = predictions[valid_smiles & has_expected_n_pi_star][np.argmin(error_n_pi_star)]

        max_error_pi_pi_star = predictions[valid_smiles][np.argmax(error_pi_pi_star)]
        max_error_n_pi_star = predictions[valid_smiles & has_expected_n_pi_star][np.argmax(error_n_pi_star)]

        error_pi_pi_star_w_n = np.abs(
            expected_pi_pi_star[valid_smiles & has_expected_n_pi_star]
            - np.array(predicted_pi_pi_star)[mask_n_valid_smiles]
        )

        total_error_pi_pi_star = error_n_pi_star + error_pi_pi_star_w_n
        min_total_error_pi_pi_star = predictions[valid_smiles & has_expected_n_pi_star][
            np.argmin(total_error_pi_pi_star)
        ]
        max_total_error_pi_pi_star = predictions[valid_smiles & has_expected_n_pi_star][
            np.argmax(total_error_pi_pi_star)
        ]

        mol_similarity_metrics = pd.DataFrame(
            [
                aggregate_array(get_similarity_to_train_mols(smile, train_smiles))
                for smile in predictions[valid_smiles]
            ]
        )
    except Exception:
        smiles_in_train = []
        predicted_pi_pi_star = None
        predicted_n_pi_star = None
        expected_pi_pi_star = None
        expected_n_pi_star = None
        valid_smiles = []
        pi_pi_star_metrics= None
        n_pi_star_metrics = None
        min_error_pi_pi_star = None
        max_error_pi_pi_star = None
        min_error_n_pi_star = None
        max_error_n_pi_star = None
        min_total_error_pi_pi_star = None
        max_total_error_pi_pi_star = None
        mol_similarity_metrics = pd.DataFrame([])

    results = {
        "meta": {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "model": model,
        },
        "predictions": predictions,
        "valid_smiles": valid_smiles,
        "smiles_in_train": smiles_in_train,
        "predicted_pi_pi_star": predicted_pi_pi_star,
        "predicted_n_pi_star": predicted_n_pi_star,
        "expected_pi_pi_star": expected_pi_pi_star,
        "expected_n_pi_star": expected_n_pi_star,
        "fractions_valid_smiles": np.mean(valid_smiles),
        "fractions_smiles_in_train": np.mean(smiles_in_train),
        "pi_pi_star_metrics": pi_pi_star_metrics,
        "n_pi_star_metrics": n_pi_star_metrics,
        "examples": {
            "min_error_pi_pi_star": min_error_pi_pi_star,
            "max_error_pi_pi_star": max_error_pi_pi_star,
            "min_error_n_pi_star": min_error_n_pi_star,
            "max_error_n_pi_star": max_error_n_pi_star,
            "min_total_error_pi_pi_star": min_total_error_pi_pi_star,
            "max_total_error_pi_pi_star": max_total_error_pi_pi_star,
        },
        "mol_similarity_metrics": mol_similarity_metrics,
        "mol_similarity_metrics_mean": mol_similarity_metrics.mean(),
        "mol_similarity_metrics_std": mol_similarity_metrics.std(),
    }

    return results
