In [1]:
from pathlib import Path

import click
import numpy as np
import pandas as pd
import Levenshtein
from tqdm.auto import tqdm
from loguru import logger

In [2]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# NOTE(review): normalize_inchi_batch is not referenced in the cells visible here — confirm it is still needed.
from nncomp_molecule.preprocessors import normalize_inchi_batch, disable_rdlogger
# NOTE(review): presumably silences RDKit log output for the session — confirm against nncomp_molecule.
disable_rdlogger()

In [16]:
# Directory that receives the CSV files written by the export cell.
OUTDIR = Path("/work/input/kfujikawa/kf-bms-candidates-v2")

# Columns (and their order) kept in every exported CSV.
OUT_COLUMNS = ["image_id", "InChI", "levenshtein", "is_valid"]

## Load valid: kf_0523, kf_0525, kf_0527, yokoo_0527

In [4]:
# Prediction CSVs read with has_label=True by the loading code.
VALID_CSVs = [
    "/work/output/1113_swin_large_bert_384/valid_beam=1.csv",
    "/work/output/1113_swin_large_bert_384/valid_beam=4.csv",
    "/work/output/9005_1102+1105+1106/valid_beam=1.csv",
    "/work/output/9006_1103+1106+1109/valid_beam=1.csv",
    "/work/output/9006_1103+1106+1109/valid_beam=4.csv",
    "/work/output/9007_1109+1113/valid_beam=1.csv",
    "/work/output/9007_1109+1113/valid_beam=4.csv",
    "/work/output/9007_1109+1113/valid_beam=8.csv",
]

# Prediction CSVs read with has_label=False by the loading code.
TEST_CSVs = [
    "/work/output/1113_swin_large_bert_384/test_beam=1.csv",
    "/work/output/1113_swin_large_bert_384/test_beam=4.csv",
    "/work/output/9005_1102+1105+1106/test_beam=1.csv",
    "/work/output/9006_1103+1106+1109/test_beam=1.csv",
    "/work/output/9006_1103+1106+1109/test_beam=4.csv",
    "/work/output/9007_1109+1113/test_beam=1.csv",
    "/work/output/9007_1109+1113/test_beam=4.csv",
    "/work/output/9007_1109+1113/test_beam=8.csv",
]

def check_normalization_error(inchi):
    """Report whether an InChI string contains one of the tracked markers.

    Args:
        inchi: InChI string to inspect.

    Returns:
        True if "?", "/q" or "/p" occurs anywhere in ``inchi``, else False.
    """
    # NOTE(review): "/q" and "/p" look like the InChI charge and proton layer
    # prefixes — confirm that flagging those layers is the intent.
    return any(marker in inchi for marker in ("?", "/q", "/p"))
    
# Read every prediction file, tagging each row with whether it comes from the
# VALID_CSVs list (has_label=True) or the TEST_CSVs list (has_label=False),
# and keep only rows whose is_valid flag is set.
prediction_frames = []
for csv_paths, labelled in ((VALID_CSVs, True), (TEST_CSVs, False)):
    for csv_path in tqdm(csv_paths):
        prediction_frames.append(
            pd.read_csv(csv_path).assign(has_label=labelled).query("is_valid")
        )

# Stack all files, drop repeated (image_id, normed_InChI) pairs, then keep only
# the rows whose normalized InChI is flagged by check_normalization_error.
merged_df = (
    pd.concat(prediction_frames, ignore_index=True)
    .drop_duplicates(subset=["image_id", "normed_InChI"])
    .assign(has_error=lambda frame: frame.normed_InChI.apply(check_normalization_error))
    .query("has_error")
)
merged_df

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=8.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=8.0), HTML(value='')))




Unnamed: 0,image_id,InChI,InChI_GT,score,is_valid,normed_InChI,normed_score,levenshtein,has_label,has_error
337,4011015f315c,"InChI=1S/C26H36ClN4O10PS/c1-25(2,16-28-24(36)3...","InChI=1S/C26H36ClN4O10PS/c1-25(2,16-28-24(36)3...",0.379618,True,"InChI=1S/C26H36ClN4O10PS/c1-25(2,16-28-24(36)3...",0.095704,1.0,True,True
1513,a0347566e676,InChI=1S/C22H28FIN3O9P/c1-12(2)34-19(30)13(3)2...,InChI=1S/C22H28FIN3O9P/c1-12(2)34-19(30)13(3)2...,0.400350,True,InChI=1S/C22H28FIN3O9P/c1-12(2)34-19(30)13(3)2...,0.143347,5.0,True,True
2068,43f348271e34,InChI=1S/C25H29N9O4/c1-3-27-24(38)34-20-9-19(3...,InChI=1S/C25H29N9O4/c1-3-27-24(38)34-20-9-19(3...,0.354699,True,InChI=1S/C25H28N9O4/c1-3-27-24(38)34-20-9-19(3...,1.020441,38.0,True,True
2288,b9ae8356058b,InChI=1S/C39H41N3O3S/c1-5-7-8-9-13-26(3)24-29(...,InChI=1S/C39H41N3O3S/c1-5-7-8-9-13-26(3)24-29(...,0.396320,True,InChI=1S/C39H41N3O3S/c1-5-7-8-9-13-26(3)24-29(...,0.685778,25.0,True,True
2409,e40253e626ad,InChI=1S/C28H26N4O3/c1-28-26(34-3)17(29-2)12-2...,InChI=1S/C29H28N4O4/c1-29-26(35-3)16(30-2)13-1...,0.368293,True,InChI=1S/C28H26N4O3/c1-28-26(34-3)17(29-2)12-2...,0.199130,89.0,True,True
...,...,...,...,...,...,...,...,...,...,...
23592812,192bebe0ca86,"InChI=1S/C4H2F6N/c5-3(6,7)2(1-11(2)10)4(8,9)10...",,0.168190,True,"InChI=1S/C4H2F6N/c5-3(6,7)2-1-11(2)10-4(2,8)9/...",1.464674,,False,True
23593234,83ae8d3f12b0,InChI=1S/C5H6BBrO4/c6-10-4(8)2-1-3-5(9)11-7/h1...,,0.016278,True,InChI=1S/C5H7BBrO4/c6-10-4(8)2-1-3-5(9)11-7/h8...,1.923140,,False,True
23593894,61c21f9bcf36,"InChI=1S/C3Cl2F4O2/c4-1(6)2(5,7)11-3(8,9)10-1/...",,0.543798,True,"InChI=1S/C3Cl2F4O2/c4-1(6)2(5,7)11-3(8,9)10-1/...",0.491674,,False,True
23593895,61c21f9bcf36,"InChI=1S/C3Cl2F4O2/c4-1(6)2(5,7)11-3(8,9)10-1/...",,0.465926,True,"InChI=1S/C3Cl2F4O2/c4-1(6)2(5,7)11-3(8,9)10-1/...",0.496943,,False,True


In [19]:
# Take an independent copy of the labelled rows so the new column is not
# written back into merged_df.
valid_df = merged_df.query("has_label").copy()

# Edit distance between the InChI column and the InChI_GT column, row by row.
valid_df["renormed_levenshtein"] = [
    Levenshtein.distance(predicted, truth)
    for predicted, truth in zip(valid_df["InChI"], valid_df["InChI_GT"])
]

# Compare the loaded levenshtein column against the freshly computed one.
valid_df[["levenshtein", "renormed_levenshtein"]].agg(["count", "mean"])

Unnamed: 0,levenshtein,renormed_levenshtein
count,9598.0,9598.0
mean,31.724213,26.74974


In [18]:
# Create the destination first: DataFrame.to_csv does not create missing
# parent directories, so a fresh environment would otherwise fail here.
OUTDIR.mkdir(parents=True, exist_ok=True)

# Write the labelled and unlabelled rows to separate files, restricted to OUT_COLUMNS.
# NOTE(review): the exported `levenshtein` column is the value loaded from the
# prediction CSVs; it is not recomputed against the `InChI` column written
# next to it — confirm downstream consumers expect that.
merged_df.query("has_label")[OUT_COLUMNS].to_csv(OUTDIR / "valid_kf_0531_renormed.csv", index=False)
merged_df.query("~has_label")[OUT_COLUMNS].to_csv(OUTDIR / "test_kf_0531_renormed.csv", index=False)