# Classfication

In [1]:
import venn
import pandas as pd
from data.splitters import scaffold_split
from data.loader import MoleculeDataset ##

test_len = {}
data_root = "dataset/"
for feature in ['CNN']:
    for dt in ['bace', 'bbbp', 'tox21', 'toxcast', 'sider', 'hiv', 'clintox', 'freesolv', 'lipophilicity', 'esol']:
        dataset = MoleculeDataset(data_root + dt, dataset=dt, feature=feature)
        smiles_list = pd.read_csv(data_root + dt + '/processed/smiles.csv', header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
        test_len[dt] = len(test_dataset)
test_len

{'bace': 152,
 'bbbp': 198,
 'tox21': 778,
 'toxcast': 854,
 'sider': 136,
 'hiv': 4076,
 'clintox': 144,
 'freesolv': 63,
 'lipophilicity': 420,
 'esol': 113}

In [2]:
import numpy as np
import pandas as pd
from pathlib import Path

dts = ['bace', 'bbbp', 'tox21', 'toxcast', 'sider', 'hiv']
result = {dt: {42: {}, 43: {}, 44: {}, 45: {}, 46: {}} for dt in dts}
exps = [fd for fd in Path('experiments').glob("*") if fd.stem != '20250710']
for fd in exps:
    feature = fd.stem
    print(feature)
    
    files = [f for f in fd.glob('*.csv') if 'result' not in f.stem]
    for f in files:
        fs = f.stem.split('_')
        dt, seed = fs[0], int(fs[3])
        
        if dt.lower() not in dts:
            continue
        
        data = pd.read_csv(f)
        tc = [c for c in data.columns if 'task' in c]
        for yt, yp in list(zip(tc[::2], tc[1::2])):
            i = yt.split('_')[1]
            data[f'correct_{i}'] = (data[yt] == data[yp]).astype(int)
            correct = data[data[f'correct_{i}'] == 1]
            correct_smiles = set(correct['smiles'].tolist())
            
            if feature == '2D-GNN-tuto':
                feature = '2D-GNN'

            result[dt.lower()][seed][f'{feature}_task{i}'] = correct_smiles

2D-GNN-tuto
3D-GNN
ChemBERTa
CNN
FP-MACCS
FP-Morgan


In [3]:
def draw_venn(sts, dt_name, title_name, file_path):
    import matplotlib.pyplot as plt
    plt.rcParams['font.family'] = 'Arial'
    
    global test_len
    
    labels = venn.get_labels([sts[0][1], sts[1][1], sts[2][1], sts[3][1], sts[4][1]], fill=['number'])
    venn.venn3(labels, names=[sts[0][0], sts[1][0], sts[2][0], sts[3][0], sts[4][0]])

    success = sum([int(v) for v in labels.values()])
    st1 = sum(map(int, [labels[k] for k in labels.keys() if '1' in k[0]]))
    st2 = sum(map(int, [labels[k] for k in labels.keys() if '1' in k[1]]))
    st3 = sum(map(int, [labels[k] for k in labels.keys() if '1' in k[2]]))
    st4 = sum(map(int, [labels[k] for k in labels.keys() if '1' in k[3]]))
    st5 = sum(map(int, [labels[k] for k in labels.keys() if '1' in k[4]]))

    fail = test_len[dt_name] - success
    memo1_title = '[Test Set]'
    memo1 = f'Samples: {test_len[dt_name]}\nFailed: {fail} ({fail / test_len[dt_name]:.0%})'

    memo2_title = '[Oracle Accuracy]'
    memo2 = f'{sts[0][0]}: {st1} ({st1 / success:.0%})\n{sts[1][0]}: {st2} ({st2 / success:.0%})\n{sts[2][0]}: {st3} ({st3 / success:.0%})\n{sts[3][0]}: {st4} ({st4 / success:.0%})\n{sts[4][0]}: {st5} ({st5 / success:.0%})'
    plt.text(0, 1.05, memo1_title, fontsize=18)
    plt.text(0, 0.99, memo1, fontsize=13)
    plt.text(0, 0.94, memo2_title, fontsize=18)
    plt.text(0, 0.79, memo2, fontsize=13)
    plt.title(f'{title_name}', fontsize=30, fontweight='bold', pad=10)
    plt.tight_layout()
    plt.savefig(file_path, dpi=300)
    plt.close()

In [4]:
def draw_venn_v2(lbs, legends, title_name, file_path, test_samples, fail, task_nums):
    import matplotlib.pyplot as plt
    plt.rcParams['font.family'] = 'Arial'
    
    global test_len
    venn.venn3(lbs, names=legends)

    success = sum([int(v) for v in lbs.values()])
    st1 = sum(map(int, [lbs[k] for k in lbs.keys() if '1' in k[0]]))
    st2 = sum(map(int, [lbs[k] for k in lbs.keys() if '1' in k[1]]))
    st3 = sum(map(int, [lbs[k] for k in lbs.keys() if '1' in k[2]]))
    st4 = sum(map(int, [lbs[k] for k in lbs.keys() if '1' in k[3]]))
    st5 = sum(map(int, [lbs[k] for k in lbs.keys() if '1' in k[4]]))

    memo1_title = '[Test Set]'
    memo1 = f'Samples: {test_samples} (num_tasks: {task_nums})\nFailed: {fail} ({fail / test_samples:.1%})'

    memo2_title = '[Oracle Accuracy]'
    memo2 = f'All: {success}\n{legends[0]}: {st1} ({st1 / success:.1%})\n{legends[1]}: {st2} ({st2 / success:.1%})\n{legends[2]}: {st3} ({st3 / success:.1%})\n{legends[3]}: {st4} ({st4 / success:.1%})\n{legends[4]}: {st5} ({st5 / success:.1%})'
    plt.text(0, 1.05, memo1_title, fontsize=16)
    plt.text(0, 1.00, memo1, fontsize=12)
    plt.text(0, 0.95, memo2_title, fontsize=16)
    plt.text(0, 0.80, memo2, fontsize=12)
    plt.title(f'{title_name}', fontsize=30, fontweight='bold', pad=10)
    plt.tight_layout()
    plt.savefig(file_path, dpi=300)
    plt.close()

In [58]:
def draw_venn_v3(lbs, legends, title_name, file_path, test_samples, fail, task_nums):
    import matplotlib.pyplot as plt
    plt.rcParams['font.family'] = 'Arial'
    
    global test_len
    venn.venn3(lbs, names=legends)

    success = sum([float(v) for v in lbs.values()])
    st1 = sum(map(float, [lbs[k] for k in lbs.keys() if '1' in k[0]]))
    st2 = sum(map(float, [lbs[k] for k in lbs.keys() if '1' in k[1]]))
    st3 = sum(map(float, [lbs[k] for k in lbs.keys() if '1' in k[2]]))
    st4 = sum(map(float, [lbs[k] for k in lbs.keys() if '1' in k[3]]))
    st5 = sum(map(float, [lbs[k] for k in lbs.keys() if '1' in k[4]]))

    memo1_title = '[Test Set]'
    memo1 = f'Samples: {test_samples} (num_tasks: {task_nums})\nFailed: {fail} ({fail / test_samples:.1%})'

    memo2_title = '[Oracle Accuracy]'
    memo2 = f'All: {success}\n{legends[0]}: {st1} ({st1 / success:.1%})\n{legends[1]}: {st2} ({st2 / success:.1%})\n{legends[2]}: {st3} ({st3 / success:.1%})\n{legends[3]}: {st4} ({st4 / success:.1%})\n{legends[4]}: {st5} ({st5 / success:.1%})'
    plt.text(0, 1.05, memo1_title, fontsize=16)
    plt.text(0, 1.00, memo1, fontsize=12)
    plt.text(0, 0.95, memo2_title, fontsize=16)
    plt.text(0, 0.80, memo2, fontsize=12)
    plt.title(f'{title_name}', fontsize=30, fontweight='bold', pad=10)
    plt.tight_layout()
    plt.savefig(file_path, dpi=300)
    plt.close()

In [None]:
# tox21, toxcast, sider (multi-task)
for dt in ['bace', 'bbbp', 'hiv', 'tox21', 'toxcast', 'sider']:
    out = Path(f'Venn/classification/')
    out.mkdir(parents=True, exist_ok=True)
    
    print(dt)

    fail = []
    success_smi = []
    total_score_v1 = {}
    total_score_v2 = {}
    for sd in [42, 43, 44, 45, 46]:
        pred_data = result[dt][sd]
        
        task_score_v1 = {}
        task_score_v2 = {}
        task_num = [c.split('_')[1] for c in pred_data.keys() if 'CNN' in c]
        for tn in task_num:
            fp_maccs = pred_data[f'FP-MACCS_{tn}']
            fp_morgan = pred_data[f'FP-Morgan_{tn}']
            cnn = pred_data[f'CNN_{tn}']
            chemberta = pred_data[f'ChemBERTa_{tn}']
            graph_2d = pred_data[f'2D-GNN_{tn}']
            graph_3d = pred_data[f'3D-GNN_{tn}']
            
            union = set.union(*[fp_maccs, fp_morgan, cnn, chemberta, graph_2d, graph_3d])
            success_smi.append(union)
            fail.append(test_len[dt] - len(union))
            
            sts1 = [('MACCS', fp_maccs), ('1D-CNN', cnn), ('ChemBERTa', chemberta), ('2D-GNN', graph_2d), ('3D-GNN', graph_3d)]
            labels_v1 = venn.get_labels([sts1[0][1], sts1[1][1], sts1[2][1], sts1[3][1], sts1[4][1]], fill=['number'])
            
            sts2 = [('Morgan', fp_morgan), ('1D-CNN', cnn), ('ChemBERTa', chemberta), ('2D-GNN', graph_2d), ('3D-GNN', graph_3d)]
            labels_v2 = venn.get_labels([sts2[0][1], sts2[1][1], sts2[2][1], sts2[3][1], sts2[4][1]], fill=['number'])
            
            for k, v in labels_v1.items():
                task_score_v1[k] = task_score_v1.get(k, 0) + int(v)
                
            for k, v in labels_v2.items():
                task_score_v2[k] = task_score_v2.get(k, 0) + int(v)
        
        for k, v in task_score_v1.items():
            total_score_v1[k] = total_score_v1.get(k, 0) + v

        for k, v in task_score_v2.items():
            total_score_v2[k] = total_score_v2.get(k, 0) + v
    
    total_score_v1 = {k: v / 5 for k, v in total_score_v1.items()}  
    v1_sum = sum(total_score_v1.values())
    total_score_v1 = {k: float(f"{(v / v1_sum)*100:.2f}") for k, v in total_score_v1.items()}
    
    total_score_v2 = {k: v / 5 for k, v in total_score_v2.items()}
    v2_sum = sum(total_score_v2.values())
    total_score_v2 = {k: float(f"{(v / v2_sum)*100:.2f}") for k, v in total_score_v2.items()}
    
    print(total_score_v1)    
    print(total_score_v2)
    fail = int(f"{sum(fail) / len(fail):.0f}")
    print('Failed:', fail)
    
    success_smi = set.union(*success_smi)
    draw_venn_v3(total_score_v1, ['MACCS', '1D-CNN', 'ChemBERTa', '2D-GNN', '3D-GNN'], f"{dt.upper()}", f'Venn/classification/{dt.upper()}_MACCS.png', test_samples, fail, len(task_num))
    draw_venn_v3(total_score_v2, ['Morgan', '1D-CNN', 'ChemBERTa', '2D-GNN', '3D-GNN'], f"{dt.upper()}", f'Venn/classification/{dt.upper()}_Morgan.png', test_samples, fail, len(task_num))
    # break

print('Done')

bace
{'00001': 33, '00010': 22, '00011': 12, '00100': 7, '00101': 3, '00110': 3, '00111': 5, '01000': 20, '01001': 18, '01010': 9, '01011': 11, '01100': 3, '01101': 4, '01110': 3, '01111': 10, '10000': 20, '10001': 17, '10010': 9, '10011': 6, '10100': 2, '10101': 5, '10110': 6, '10111': 22, '11000': 16, '11001': 15, '11010': 9, '11011': 25, '11100': 8, '11101': 23, '11110': 27, '11111': 296}
--------------------------------
{'00001': 6.6, '00010': 4.4, '00011': 2.4, '00100': 1.4, '00101': 0.6, '00110': 0.6, '00111': 1.0, '01000': 4.0, '01001': 3.6, '01010': 1.8, '01011': 2.2, '01100': 0.6, '01101': 0.8, '01110': 0.6, '01111': 2.0, '10000': 4.0, '10001': 3.4, '10010': 1.8, '10011': 1.2, '10100': 0.4, '10101': 1.0, '10110': 1.2, '10111': 4.4, '11000': 3.2, '11001': 3.0, '11010': 1.8, '11011': 5.0, '11100': 1.6, '11101': 4.6, '11110': 5.4, '11111': 59.2}
133.8
{'00001': 4.93, '00010': 3.29, '00011': 1.79, '00100': 1.05, '00101': 0.45, '00110': 0.45, '00111': 0.75, '01000': 2.99, '01001': 

In [44]:
test_len['bace'] - len(success_smi)

8

In [42]:
fail

16.6

In [33]:
sum(total_score_v1.values()), sum(total_score_v2.values())

(97, 102)

# Regression (corr > heatmap)

In [6]:
dts = ['esol', 'freesolv', 'lipophilicity']
reg_result = {dt: {42: [], 43: [], 44: [], 45: [], 46: []} for dt in dts}
exps = [fd for fd in Path('experiments').glob("*") if fd.stem != '20250710']
for fd in exps:
    feature = fd.stem
    print(feature)
    
    files = [f for f in fd.glob('*.csv') if 'result' not in f.stem]
    for f in files:
        fs = f.stem.split('_')
        dt, seed = fs[0], int(fs[3])
        
        if dt.lower() not in dts:
            continue
        
        if feature == '2D-GNN-tuto':
            feature = '2D-GNN'
        elif feature == 'FP-Morgan':
            feature = 'Morgan'
        elif feature == 'FP-MACCS':
            feature = 'MACCS'
        elif feature == 'CNN':
            feature = '1D-CNN'

        data = pd.read_csv(f).rename(columns={'task_0_yt': 'y', 'task_0_yp': feature})
        reg_result[dt.lower()][seed].append(data)
        
        if len(reg_result[dt.lower()][seed]) == 6:
            from functools import reduce
            reg_result[dt.lower()][seed] = reduce(lambda left, right: pd.merge(left, right, on=['smiles', 'y']), reg_result[dt.lower()][seed])


2D-GNN-tuto
3D-GNN
ChemBERTa
CNN
FP-MACCS
FP-Morgan


In [7]:
reg_avg = {}
for dt in ['esol', 'freesolv', 'lipophilicity']:
    dfs = [reg_result[dt][sd] for sd in [42, 43, 44, 45, 46]]
    mean_values = sum(df[dfs[0].columns[2:]] for df in dfs) / len(dfs)
    reg_avg[dt] = pd.concat([dfs[0][['smiles', 'y']], mean_values], axis=1).reset_index(drop=True)

reg_avg['esol']

Unnamed: 0,smiles,y,2D-GNN,3D-GNN,ChemBERTa,1D-CNN,MACCS,Morgan
0,c1cc2ccc3cccc4ccc(c1)c2c34,-6.176,-7.446912,-6.166356,-5.951371,-6.559772,-5.880040,-3.758644
1,Cc1cc(=O)[nH]c(=S)[nH]1,-2.436,-1.049061,-1.472271,-1.555253,-2.005546,-2.544720,-2.474862
2,Oc1ccc(cc1)C2(OC(=O)c3ccccc23)c4ccc(O)cc4,-2.900,-4.167159,-4.081903,-4.540090,-4.024740,-4.559403,-3.439948
3,c1ccc2c(c1)cc3ccc4cccc5ccc2c3c45,-8.699,-8.027051,-7.756307,-6.662787,-7.640496,-5.880040,-7.077580
4,C1=Cc2cccc3cccc1c23,-3.960,-6.744610,-4.339922,-3.898568,-5.621331,-7.344441,-3.118640
...,...,...,...,...,...,...,...,...
108,ClC4=C(Cl)C5(Cl)C3C1CC(C2OC12)C3C4(Cl)C5(Cl)Cl,-6.290,-6.515555,-5.874908,-5.522565,-4.995311,-6.537668,-4.450267
109,c1ccsc1,-1.330,-0.521195,-0.873121,-1.007786,-1.002033,-1.320460,-2.255828
110,c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43,-7.870,-7.306877,-8.250020,-7.153696,-6.912476,-5.880040,-8.272818
111,Cc1occc1C(=O)Nc2ccccc2,-3.300,-1.902024,-2.972885,-4.160515,-3.373551,-3.337472,-3.476054


In [12]:
import pandas as pd
import seaborn as sns
from pathlib import Path
import matplotlib.pyplot as plt

for dt in ['esol', 'freesolv', 'lipophilicity']:
    out = Path(f'Venn/regression')
    out.mkdir(parents=True, exist_ok=True)
    
    df = reg_avg[dt][['y', 'MACCS', 'Morgan', '1D-CNN', 'ChemBERTa', '2D-GNN', '3D-GNN']]
    corr_df = df[df.columns].corr()
    
    if dt == 'freesolv':
        vmin, vmax = 0, 1
    else:
        # vmin, vmax = 0.5, 1
        vmin, vmax = 0, 1

    plt.figure(figsize=(8, 6))
    sns.heatmap(corr_df, annot=True, cmap='coolwarm', vmin=vmin, vmax=vmax)
    plt.title(dt.upper())
    plt.tight_layout()
    plt.savefig(str(out / f'{dt.upper()}.png'), dpi=300)
    plt.close()

print('Done')

Done
