In [98]:
import pandas as pd
from ast import literal_eval
from collections import Counter


In [99]:
df = pd.read_pickle("../results/schembl_summs_v3_alg_cleaned_labels.pkl")

In [100]:
vocab_df = pd.read_csv("../results/schembl_summs_v4_vocab_gpt_cleaned_eps_0.340_diff_20030_clusters_errors_fixed.csv")

vocab_df.loc[0, "gpt_cleaned_labels"] = ""

# remove row with API REQUEST ERROR
# vocab_df = vocab_df[vocab_df["gpt_cleaned_labels"] != "API REQUEST ERROR"]

In [101]:
print(len(vocab_df))

words = set()
for i in vocab_df['gpt_cleaned_labels']:
    words.update([i])
print(len(words))

# print clusters with more than 2 words (when gpt messes up / cannot determine)
print(vocab_df[vocab_df['gpt_cleaned_labels'].str.split().str.len() > 2]["gpt_cleaned_labels"])

# drop clusters with more than 2 words (when gpt messes up / cannot determine)
vocab_df = vocab_df[~(vocab_df['gpt_cleaned_labels'].str.split().str.len() > 2)].reset_index(drop=True)

words = set()
for i in vocab_df['gpt_cleaned_labels']:
    words.update([i])
print(len(words))

20030
19610
391     No centroid descriptor can be determined from ...
702                                               b, c, d
1146    No average descriptor can be determined from t...
1172    No centroid descriptor can be determined as th...
1837    No centroid descriptor can be determined as th...
2181    No centroid descriptor can be determined as th...
Name: gpt_cleaned_labels, dtype: object
19604


In [102]:
# make each element in original_clustered_labels its own row, repeating the gpt_cleaned_labels
vocab_df = vocab_df.assign(original_clustered_labels=vocab_df['original_clustered_labels'].str.split(',')).explode('original_clustered_labels')
vocab_df = vocab_df.reset_index()
# convert to dictionary
vocab_dict = dict(zip(vocab_df["original_clustered_labels"].tolist(), vocab_df["gpt_cleaned_labels"].tolist()))

In [103]:
# remove leading spaces from keys
vocab_dict = {k.strip(): v for k, v in vocab_dict.items()}

In [104]:
df["summarizations"] = df["summarizations"].map(set)

In [105]:
# map vocab_dict to each value in summarizations, ignoring when the value is not in vocab_dict and defaulting to the original value
df['summarizations'] = df['summarizations'].map(lambda x: [vocab_dict.get(i, i) for i in x])

In [106]:
words = set()
for i in df["summarizations"]:
    
    words.update(set(i))
print(len(words))
# remove empty strings from df["summarizations"]
df["summarizations"] = df["summarizations"].apply(lambda x: [i for i in x if i != ""])

words = set()
for i in df["summarizations"]:
    words.update(set(i))
print(len(words))

19617
19616


In [107]:
# count all words in df["summarizations"]
counter = Counter()
for summarization in df["summarizations"]:
    counter.update(summarization)
counter.most_common()

[('inhibitor', 36218),
 ('treatment', 32390),
 ('disease', 16016),
 ('compound', 14648),
 ('derivative', 13040),
 ('cancer', 11369),
 ('receptor', 10031),
 ('disorder', 9318),
 ('modulator', 8468),
 ('antagonist', 7619),
 ('agent', 7618),
 ('therapeutic', 7399),
 ('kinase', 7100),
 ('pharmaceutical', 6284),
 ('composition', 6259),
 ('organic', 5089),
 ('agonist', 4433),
 ('inflammatory', 4232),
 ('protein', 4178),
 ('anti-microbial', 4124),
 ('inhibitory', 3737),
 ('activity', 3689),
 ('device', 3572),
 ('antiviral', 3499),
 ('acid', 3294),
 ('polymer', 3062),
 ('anti-inflammatory', 2989),
 ('cell', 2932),
 ('pain', 2833),
 ('diabetes', 2799),
 ('inhibition', 2778),
 ('therapy', 2754),
 ('high', 2558),
 ('modulate', 2533),
 ('treat', 2421),
 ('cardiovascular', 2391),
 ('prevention', 2370),
 ('electroluminescence', 2350),
 ('anti-tumor', 2284),
 ('medicament', 2088),
 ('material', 2062),
 ('drug', 2058),
 ('autoimmune', 1995),
 ('control', 1987),
 ('synthesis', 1986),
 ('neurodegenerati

In [108]:
df[["smiles", "cid", "patent_ids", "summarizations"]].to_csv("../results/schembl_summs_v4_gpt_cleaned_eps_0.340_diff_20030_final.csv")

In [109]:
# remove all entries if they contain less than 10 counts

counter_trimmed = counter.copy()
for key in counter.keys():
    if counter[key] < 50:
        del counter_trimmed[key]

print(len(counter_trimmed))

# print(counter_trimmed.most_common()[-100:])

1544


In [110]:
# remove entries from df["summariations"] if they are not in counter_trimmed
df["summarizations"] = df["summarizations"].apply(lambda x: [y for y in x if y in counter_trimmed])


In [111]:
print(len(df))

df = df.dropna(subset=["summarizations"])

print(len(df))

99454
99454


In [114]:
df[["smiles", "cid", "patent_ids", "summarizations"]].to_csv("../results/schembl_summs_v5_final.csv", index=False)

In [115]:
from rdkit import Chem
df["fingerprint"] = df["smiles"].apply(lambda x: Chem.RDKFingerprint(Chem.MolFromSmiles(x)))



In [117]:
import numpy as np
df["fingerprint"] = df["fingerprint"].apply(np.array)

In [118]:
df[["smiles", "cid", "patent_ids", "summarizations", "fingerprint"]].to_pickle("../results/schembl_summs_v5_final_fp.pkl")