In [1]:
import gc
import io
import multiprocessing as mp
from copy import deepcopy
from urllib.parse import urlparse

import google.cloud.storage as gcs
import pandas as pd
import seaborn as sns
import Levenshtein
from retrying import retry
from tqdm.auto import tqdm

In [2]:
GCP_PROJECT = "dena-ai-training-28-gcp"


@retry(stop_max_attempt_number=3)
def download_from_gcs(path: str):
    url = urlparse(path)
    bucket_name = url.netloc
    storage_client = gcs.Client(project=GCP_PROJECT)
    bucket = storage_client.get_bucket(bucket_name)
    blob = gcs.Blob(url.path[1:], bucket)
    return blob.download_as_string()


def load_prediction(path: str):
    if path.split("/")[-2] == "kf-bms-candidates-v2":
        model = path.split("/")[-3]
    else:
        model = path.split("/")[-2]
    df = pd.read_csv(io.BytesIO(download_from_gcs(path)))\
        .assign(model=model, filename=path.split("/")[-1])\
        .query("is_valid | image_id.isin(@NO_VALID_IMAGE_IDs)", engine="python")
    return df

In [3]:
n_valid_InChIs = pd.read_csv(io.BytesIO(download_from_gcs("gs://kfujikawa-kaggle-bms-molecular-generation/kfujikawa/kf-bms-candidates-v2/test_n_valid_InChIs.csv")))
NO_VALID_IMAGE_IDs = n_valid_InChIs.query("n_valid_InChIs == 0").image_id
len(NO_VALID_IMAGE_IDs)

5152

# Load predictions

In [4]:
TEST_FILENAMES = [
    "test_kf_0523.csv",
    "test_kf_0525.csv",
    "test_kf_0527.csv",
    "test_yokoo_0527.csv",
    "test_camaro_0525.csv",
    "test_yokoo_0531.csv",
    "test_kf_0531_renormed.csv",
    "test_camaro_old_submissions.csv",
    "test_kf_0531.csv",
    "test_camaro_0531.csv",
    "test_yokoo_0601.csv",
]
KF_MODELS = [
    "1109_vtnt_bert_512-1024-denoise-5",
    "1113_swin_large_bert_384",
]
KF_TEST_CSVs = [
    f"gs://kfujikawa-kaggle-bms-molecular-generation/kfujikawa/{model}/kf-bms-candidates-v2/{filename}"
    for model in KF_MODELS
    for filename in TEST_FILENAMES
]

In [5]:
LYAKAAP_TEST_CSVs = """
gs://kfujikawa-kaggle-bms-molecular-generation/yokoo/v52/test_yokoo_0527.csv
gs://kfujikawa-kaggle-bms-molecular-generation/yokoo/v52/test_kf_0523.csv
gs://kfujikawa-kaggle-bms-molecular-generation/yokoo/v52/test_kf_0525.csv
gs://kfujikawa-kaggle-bms-molecular-generation/yokoo/v52/test_kf_0527.csv
gs://kfujikawa-kaggle-bms-molecular-generation/yokoo/v52/test_camaro_0525.csv
gs://kfujikawa-kaggle-bms-molecular-generation/yokoo/v52/test_camaro_old_submissions.csv
gs://kfujikawa-kaggle-bms-molecular-generation/yokoo/v52/test_kf_0531_renormed.csv
gs://kfujikawa-kaggle-bms-molecular-generation/yokoo/v52/test_yokoo_0531.csv
gs://kfujikawa-kaggle-bms-molecular-generation/yokoo/v52/test_kf_0531.csv
gs://kfujikawa-kaggle-bms-molecular-generation/yokoo/v52/test_camaro_0531.csv
gs://kfujikawa-kaggle-bms-molecular-generation/yokoo/v52/test_yokoo_0601.csv
""".strip().split()

In [6]:
with mp.Pool() as pool:
    total = len(KF_TEST_CSVs) + len(LYAKAAP_TEST_CSVs)
    iterator = pool.imap_unordered(load_prediction, [*KF_TEST_CSVs, *LYAKAAP_TEST_CSVs])
    kyakaap_df = pd.concat(list(tqdm(iterator, total=total)), ignore_index=True)
kyakaap_df = kyakaap_df.drop_duplicates(subset=["model", "image_id", "InChI"])
display(kyakaap_df.head(1))
display(kyakaap_df.groupby("model").image_id.count())
with pd.option_context("display.float_format", '{:.4f}'.format, "display.max_columns", None):
    display(kyakaap_df.groupby(["filename", "model"]).describe().T)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=33.0), HTML(value='')))




Unnamed: 0,image_id,InChI,levenshtein,is_valid,normed_score,model,filename
0,00073e401fa1,InChI=1S/C20H27NO9S/c1-12(22)26-11-14-16-18(28...,,True,0.04187,1109_vtnt_bert_512-1024-denoise-5,test_camaro_0531.csv


model
1109_vtnt_bert_512-1024-denoise-5    5811688
1113_swin_large_bert_384             5811688
v52                                  5811688
Name: image_id, dtype: int64

Unnamed: 0_level_0,filename,test_camaro_0525.csv,test_camaro_0525.csv,test_camaro_0525.csv,test_camaro_0531.csv,test_camaro_0531.csv,test_camaro_0531.csv,test_camaro_old_submissions.csv,test_camaro_old_submissions.csv,test_camaro_old_submissions.csv,test_kf_0523.csv,test_kf_0523.csv,test_kf_0523.csv,test_kf_0525.csv,test_kf_0525.csv,test_kf_0525.csv,test_kf_0527.csv,test_kf_0527.csv,test_kf_0527.csv,test_kf_0531.csv,test_kf_0531.csv,test_kf_0531.csv,test_kf_0531_renormed.csv,test_kf_0531_renormed.csv,test_kf_0531_renormed.csv,test_yokoo_0527.csv,test_yokoo_0527.csv,test_yokoo_0527.csv,test_yokoo_0531.csv,test_yokoo_0531.csv,test_yokoo_0531.csv,test_yokoo_0601.csv,test_yokoo_0601.csv,test_yokoo_0601.csv
Unnamed: 0_level_1,model,1109_vtnt_bert_512-1024-denoise-5,1113_swin_large_bert_384,v52,1109_vtnt_bert_512-1024-denoise-5,1113_swin_large_bert_384,v52,1109_vtnt_bert_512-1024-denoise-5,1113_swin_large_bert_384,v52,1109_vtnt_bert_512-1024-denoise-5,1113_swin_large_bert_384,v52,1109_vtnt_bert_512-1024-denoise-5,1113_swin_large_bert_384,v52,1109_vtnt_bert_512-1024-denoise-5,1113_swin_large_bert_384,v52,1109_vtnt_bert_512-1024-denoise-5,1113_swin_large_bert_384,v52,1109_vtnt_bert_512-1024-denoise-5,1113_swin_large_bert_384,v52,1109_vtnt_bert_512-1024-denoise-5,1113_swin_large_bert_384,v52,1109_vtnt_bert_512-1024-denoise-5,1113_swin_large_bert_384,v52,1109_vtnt_bert_512-1024-denoise-5,1113_swin_large_bert_384,v52
levenshtein,count,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
levenshtein,mean,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
levenshtein,std,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
levenshtein,min,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
levenshtein,25%,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
levenshtein,50%,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
levenshtein,75%,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
levenshtein,max,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
normed_score,count,30694.0,30918.0,30918.0,11367.0,11367.0,11287.0,41850.0,41850.0,43041.0,3753511.0,3753511.0,3753511.0,304700.0,304700.0,304700.0,906933.0,906933.0,906933.0,24209.0,24209.0,24209.0,63822.0,63598.0,63678.0,8605.0,8605.0,8605.0,413958.0,413958.0,412767.0,252039.0,252039.0,252039.0
normed_score,mean,0.1987,0.2217,0.2159,0.1729,0.1751,0.1704,0.2397,0.2484,0.2327,0.0737,0.0802,0.086,0.1751,0.2002,0.2265,0.1679,0.1897,0.2112,0.1292,0.1156,0.1658,0.1472,0.1509,0.1733,0.1589,0.1362,0.0355,0.2342,0.2096,0.1251,0.0976,0.1014,0.0881


In [7]:
kyakaap_df = kyakaap_df.groupby(["image_id", "InChI"])[["is_valid", "normed_score"]].mean().reset_index()
kyakaap_df["score_rank"] = kyakaap_df.normed_score.rank() / len(kyakaap_df)

# Camaro

In [8]:
CAMARO_TEST_CSVs = """
gs://kfujikawa-kaggle-bms-molecular-generation/camaro/exp084/test_kf_0523.csv
gs://kfujikawa-kaggle-bms-molecular-generation/camaro/exp084/test_kf_0525.csv
gs://kfujikawa-kaggle-bms-molecular-generation/camaro/exp084/test_kf_0527.csv
gs://kfujikawa-kaggle-bms-molecular-generation/camaro/exp084/test_yokoo_0527.csv
gs://kfujikawa-kaggle-bms-molecular-generation/camaro/exp084/test_yokoo_0531.csv
gs://kfujikawa-kaggle-bms-molecular-generation/camaro/exp084/test_yokoo_0601.csv
""".strip().split()

In [10]:
with mp.Pool() as pool:
    total = len(CAMARO_TEST_CSVs)
    iterator = pool.imap_unordered(load_prediction, CAMARO_TEST_CSVs)
    camaro_df = pd.concat(list(tqdm(iterator, total=total)), ignore_index=True)
camaro_df = camaro_df.drop_duplicates(subset=["model", "image_id", "InChI"])
display(camaro_df.head(1))
display(camaro_df.groupby("model").image_id.count())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6.0), HTML(value='')))




Unnamed: 0,image_id,InChI,levenshtein,is_valid,focal_score,ce_score,model,filename
0,007706c43e8d,InChI=1S/C16H32O6/c1-5-17-11-13-21-9-3-7-19-15...,,False,0.18025,0.521818,exp084,test_yokoo_0527.csv


model
exp084    5639758
Name: image_id, dtype: int64

In [11]:
camaro_df = camaro_df.groupby(["image_id", "InChI"])[["focal_score", "is_valid"]].mean().reset_index()
camaro_df["score_rank"] = camaro_df.focal_score.rank() / len(camaro_df)

# Blend

In [12]:
sort_keys = dict(
    image_id=True,
    is_valid=False,
    score_rank=True,
)
merged_df = pd.concat([
    camaro_df,
    kyakaap_df,
], ignore_index=True)
merged_ensembled_df = merged_df.groupby(["image_id", "InChI"])[["score_rank", "is_valid"]].mean().reset_index()
merged_ensembled_df = merged_ensembled_df.sort_values(
    by=list(sort_keys.keys()),
    ascending=list(sort_keys.values()),
).groupby("image_id").first().reset_index()

In [14]:
submission_df = merged_ensembled_df[["image_id", "InChI"]]
assert len(submission_df) == 1616107
submission_df.to_csv("submission_0705_1109+1113+084_0601.csv", index=False)
!head submission_0702_kyakaap+084_0601.csv
!wc submission_0702_kyakaap+084_0601.csv

image_id,InChI
00000d2a601c,"InChI=1S/C10H14BrN5S/c1-6-10(11)9(16(3)14-6)4-7(12-2)8-5-13-17-15-8/h5,7,12H,4H2,1-3H3"
00001f7fc849,"InChI=1S/C14H18ClN3/c1-2-7-16-9-13-10-17-14(18-13)8-11-3-5-12(15)6-4-11/h3-6,10,16H,2,7-9H2,1H3,(H,17,18)"
000037687605,"InChI=1S/C16H13BrN2O/c1-11(20)12-6-7-13(9-18)16(8-12)19-10-14-4-2-3-5-15(14)17/h2-8,19H,10H2,1H3"
00004b6d55b6,"InChI=1S/C14H19FN4O/c1-14(2,3)12-13(16)17-18-19(12)8-9-5-6-10(20-4)7-11(9)15/h5-7H,8,16H2,1-4H3"
00004df0fe53,"InChI=1S/C9H12O2/c1-4-5-2-6-7(3-5)11-9(10)8(4)6/h4-8H,2-3H2,1H3/t4-,5+,6+,7+,8+/m1/s1"
000085dab281,"InChI=1S/C20H38O/c1-20(2)18-16-14-12-10-8-6-4-3-5-7-9-11-13-15-17-19-21/h17,20H,3-16,18H2,1-2H3"
00008decfc8d,"InChI=1S/C15H26N2/c1-5-10-16-15(11-12(3)6-2)14-9-7-8-13(4)17-14/h7-9,12,15-16H,5-6,10-11H2,1-4H3"
00008e8fe68c,"InChI=1S/C22H25Cl2N3O6/c1-6-32-20-13(23)8-9-14(21(20)33-7-2)26-27-18(12(3)28)22(29)25-19-16(31-5)11-10-15(30-4)17(19)24/h8-11,18H,6-7H2,1-5H3,(H,25,29)"
000095714f0f,"InChI=1S/C25H30ClN3O2/c1-17-4-9-23

In [15]:
baseline_df = pd.read_csv("submission_LB059.csv")
merged_df = submission_df.merge(baseline_df, on="image_id")
merged_df["levenshtein"] = [
    Levenshtein.distance(x, y)
    for x, y in tqdm(merged_df[["InChI_x", "InChI_y"]].values)
]
print(merged_df.query("InChI_x != InChI_y").shape)
print(merged_df.levenshtein.mean())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1616107.0), HTML(value='')))


(13226, 4)
0.15749823495597753
