In [1]:
import gc
import io
from copy import deepcopy
from urllib.parse import urlparse

import cv2
import numpy as np
import pandas as pd
import seaborn as sns
import Levenshtein
import ipywidgets
import plotly.express as px
import matplotlib.pyplot as plt
from retrying import retry
from tqdm.auto import tqdm
from rdkit import Chem

In [2]:
TEST_CSVs = """
/work/input/camaro/exp084/test_camaro_0525.csv
/work/input/camaro/exp084_rescore_test_kf_0523.csv
/work/input/camaro/exp084_rescore_test_kf_0525.csv
/work/input/camaro/exp084_rescore_test_kf_0527.csv
""".strip().split()

In [3]:
def load_prediction(path: str):
    df = pd.read_csv(path)\
        .assign(filename=path.split("/")[-1])
    return df

In [4]:
test_df = pd.concat([load_prediction(path) for path in tqdm(TEST_CSVs)], ignore_index=True)
test_df = test_df.rename(columns={'focal_score':'normed_score'})
display(test_df.head(1))
with pd.option_context("display.float_format", '{:.4f}'.format):
    display(test_df.groupby(["filename"]).describe().T)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4.0), HTML(value='')))




Unnamed: 0,image_id,InChI,levenshtein,is_valid,normed_score,ce_score,filename
0,00000d2a601c,InChI=1S/C10H14BrN5S/c1-6-10(11)9(16(3)14-6)4-...,,True,9.234664e-09,3.7e-05,test_camaro_0525.csv


Unnamed: 0,filename,exp084_rescore_test_kf_0523.csv,exp084_rescore_test_kf_0525.csv,exp084_rescore_test_kf_0527.csv,test_camaro_0525.csv
levenshtein,count,0.0,0.0,0.0,0.0
levenshtein,mean,,,,
levenshtein,std,,,,
levenshtein,min,,,,
levenshtein,25%,,,,
levenshtein,50%,,,,
levenshtein,75%,,,,
levenshtein,max,,,,
normed_score,count,9655806.0,1578538.0,17255.0,1616103.0
normed_score,mean,0.0805,0.1093,0.0666,0.0034


In [9]:
%%time
sort_keys = dict(
    image_id=True,
    is_valid=False,
    normed_score=True,
)
lb085_df = test_df.groupby(["image_id", "InChI"])[["is_valid", "normed_score"]].mean().reset_index()
lb085_df = lb085_df.sort_values(
    by=list(sort_keys.keys()),
    ascending=list(sort_keys.values()),
).groupby("image_id").first().reset_index()

CPU times: user 1min 9s, sys: 2.3 s, total: 1min 12s
Wall time: 1min 12s


In [10]:
%%time
sort_keys = dict(
    image_id=True,
    is_valid=False,
    normed_score=True,
)
lb063_df = test_df.query("filename != 'test_camaro_0525.csv'")\
    .groupby(["image_id", "InChI"])[["is_valid", "normed_score"]]\
    .mean().reset_index()
lb063_df = lb063_df.sort_values(
    by=list(sort_keys.keys()),
    ascending=list(sort_keys.values()),
).groupby("image_id").first().reset_index()

CPU times: user 1min 14s, sys: 2.65 s, total: 1min 17s
Wall time: 1min 17s


In [12]:
merged_df = lb063_df.merge(lb085_df, on=["image_id"], suffixes=["_lb063", "_lb085"])
diff_df = merged_df.query("InChI_lb063 != InChI_lb085").copy()
diff_df["diff_normed_score"] = (diff_df.normed_score_lb063 - diff_df.normed_score_lb085).abs()
diff_df["levenshtein"] = [
    Levenshtein.distance(x, y)
    for x, y in diff_df[["InChI_lb063", "InChI_lb085"]].values
]
diff_df = diff_df.sort_values("diff_normed_score", ascending=False)

In [13]:
with pd.option_context("display.max_colwidth", None):
    display(diff_df)

Unnamed: 0,image_id,InChI_lb063,is_valid_lb063,normed_score_lb063,InChI_lb085,is_valid_lb085,normed_score_lb085,diff_normed_score,levenshtein
823676,8278642346fd,"InChI=1S/C6H2F3N3/c7-2-1-3(8)5-6(4(2)9)11-12-10-5/h1H,(H,10,11,12)",True,1.196999,"InChI=1S/C6H4ClN3/c7-4-2-1-3-5-6(4)9-10-8-5/h1-3H,(H,8,9,10)/i1D,3D",True,2.755862e-08,1.196999,29
348138,37210cb64fc3,"InChI=1S/C2H2ClF3O/c3-2(4,5)1-7-6/h1H2/i1D2",True,1.095005,"InChI=1S/C2H3F3O/c3-2(4)1-6-5/h2H,1H2/i1D2",True,1.312955e-08,1.095005,10
1134653,b3cb267055ca,"InChI=1S/C14H13ClN2O/c15-11-5-10(6-12(18)7-11)14-16-8-9-3-1-2-4-13(9)17-14/h5-8,18H,1-4H2/i8D",True,1.042551,"InChI=1S/C14H12ClFN2O/c15-13-11-3-1-2-4-12(11)17-14(18-13)8-5-9(16)7-10(19)6-8/h5-7,19H,1-4H2",True,9.826707e-09,1.042551,36
344241,367f00c9bc94,"InChI=1S/C21H33N3O2/c1-6-15(2)24(20(25)26-21(3,4)5)19-18-16(10-12-22-19)11-14-23-13-8-7-9-17(18)23/h10,12,15,17H,6-9,11,13-14H2,1-5H3",True,0.753532,"InChI=1S/C21H35N3O2/c1-7-16(3)24(20(25)26-21(4,5)6)19-17(12-11-14-22-19)18-13-9-10-15-23(18)8-2/h11-12,14,16,18H,7-10,13,15H2,1-6H3",True,1.505410e-08,0.753532,42
1316646,d09b9bb51f23,"InChI=1S/C23H16O/c1-3-10-18-16(7-1)9-5-11-19(18)20-12-6-14-23-21(20)15-17-8-2-4-13-22(17)24-23/h1-14H,15H2",True,0.691327,"InChI=1S/C23H14O/c1-2-12-19-17(8-1)22-18-11-4-7-14-6-3-9-15(21(14)18)16-10-5-13-20(24-19)23(16)22/h1-13,22H",True,9.510845e-09,0.691327,48
...,...,...,...,...,...,...,...,...,...
958885,97ef2742ddf7,"InChI=1S/C14H29NO2/c1-6-14(5,12(4)16)9-15-13-7-10(2)17-11(3)8-13/h10-13,15-16H,6-9H2,1-5H3/t10-,11+,12-,13?,14+/m0/s1",True,,"InChI=1S/C14H29NO2/c1-6-14(5,12(4)16)9-15-13-7-10(2)17-11(3)8-13/h10-13,15-16H,6-9H2,1-5H3/t10-,11-,12+,14+/m0/s1",True,3.537280e-07,,6
959589,980c36908e80,"InChI=1S/C19H23F3N4O2/c20-9-12-5-13(12)11-28-18(17(10-23)26-3-1-24-2-4-26)19(27)25-16-7-14(21)6-15(22)8-16/h6-8,10,12-13,23-24H,1-5,9,11H2,(H,25,27)/b18-17+,23-10?/t12-,13+/m0/s1",True,,"InChI=1S/C19H23F3N4O2/c20-9-12-5-13(12)11-28-18(17(10-23)26-3-1-24-2-4-26)19(27)25-16-7-14(21)6-15(22)8-16/h6-8,10,12-13,23-24H,1-5,9,11H2,(H,25,27)/b18-17+/t12-,13-/m1/s1",True,1.219893e-08,,9
1084275,abde89662b72,"InChI=1S/C13H17BrFN5O4/c1-4(2)23-11-6-9(14)20(19-10(6)17-13(16)18-11)12-7(15)8(22)5(3-21)24-12/h4-5,7-8,12,21-22H,3H2,1-2H3,(H2,16,17,19)/p+1/t5-,7+,8+,12-/m1/s1",True,,"InChI=1S/C13H17BrFN5O4/c1-4(2)23-11-6-9(14)20(19-10(6)17-13(16)18-11)12-7(15)8(22)5(3-21)24-12/h4-5,7-8,12,21-22H,3H2,1-2H3,(H2,16,17,18,19)/t5-,7-,8+,12+/m0/s1",True,1.011821e-05,,9
1344096,d4f18e45fbb9,"InChI=1S/C24H32N2O3/c1-15(27)25-17-10-8-16(9-11-17)22(29)26-13-12-24(4)19-6-5-7-20(28)18(19)14-21(26)23(24,2)3/h5-7,10,16,21,28H,8-9,11-14H2,1-4H3,(H,25,27)/t16?,21-,24+/m0/s1",True,,"InChI=1S/C24H30N2O3/c1-15(27)25-17-10-8-16(9-11-17)22(29)26-13-12-24(4)19-6-5-7-20(28)18(19)14-21(26)23(24,2)3/h5-8,10,21,28H,9,11-14H2,1-4H3,(H,25,27)/t21-,24-/m1/s1",True,1.224649e-08,,13


In [47]:
slider = ipywidgets.IntSlider(min=0, max=len(diff_df))
@ipywidgets.interact(i=slider)
def visualize(i):
    sample = diff_df.sort_values("levenshtein", ascending=False).iloc[i]
    image_path = f"/work/input/bms-molecular-translation/test/{'/'.join(sample.image_id[:3])}/{sample.image_id}.png"
    image = cv2.imread(image_path)
    h, w, d = image.shape
    print(f"Levenshtein: {sample.levenshtein}")
    print(f"上: LB063 (score: {sample.normed_score_lb063})")
    print(sample.InChI_lb063)
    print(f"下: LB085 (score: {sample.normed_score_lb085})")
    print(sample.InChI_lb085)
    if h > w:
        image = np.flipud(image.transpose(1, 0, 2))
    plt.figure(figsize = (20, 20))
    plt.imshow(image)
    display(Chem.MolFromInchi(sample.InChI_lb063))
    display(Chem.MolFromInchi(sample.InChI_lb085))

interactive(children=(IntSlider(value=0, description='i', max=48699), Output()), _dom_classes=('widget-interac…

In [46]:
len("InChI=1S/C59H109NO5/c1-3-5-7-9-11-13-15-17-18-19-22-25-28-31-35-39-43-47-51-57(62)56(55-61)60-58(63)52-48-44-40-36-32-29-26-23-20-21-24-27-30-34-38-42-46-50-54-65-59(64)53-49-45-41-37-33-16-14-12-10-8-6-4-2/h10,12,16,24,27,30,34,47,51,56-57,61-62H,3-9,11,13-15,17-23,25-26,28-29,31-33,35-46,48-50,")

297

In [26]:
# LB差: 0.22
# Public LB: 1616107 // 4 = 40万件
# → Levenshtein差: 40万 * 0.22 = 8.8万
# 3%しか違いがない → 1.2万件

404026

In [27]:
48699 / 1616107

0.030133524574796098

In [35]:
IMAGE.shape

(1805, 135, 3)

In [51]:
diff_df.levenshtein.sum()

730081