# Does the resampling experiment help with predicting DaG sentences?

In [1]:
from pathlib import Path
import warnings

import numpy as np
import pandas as pd
import plydata as ply
from sqlalchemy import create_engine

from snorkel.labeling.analysis import LFAnalysis
from snorkeling_helper.generative_model_helper import (
    sample_lfs,
    run_generative_label_function_sampler,
)

warnings.filterwarnings("ignore")

In [2]:
# Connection settings for the `pubmed_central_db` Postgres database.
# NOTE(review): the credentials are hard-coded in the notebook — confirm this is acceptable.
username, password, dbname = "danich1", "snorkel", "pubmed_central_db"

# SQLAlchemy URL for the psycopg2 driver; the `host` query parameter points at
# /var/run/postgresql.
database_str = f"postgresql+psycopg2://{username}:{password}@/{dbname}?host=/var/run/postgresql"
conn = create_engine(database_str)

## Load the data

In [3]:
# Directory holding the label-function candidate TSVs read in the cells below
label_candidates_dir = Path("../label_candidates/output")
# Directory the DaG dataset mapper (dag_dataset_mapper.tsv) is read from
notebook_output_dir = Path("../generative_model_training/output/DaG")

In [4]:
# Label-function output for the abstract training candidates (tab-separated file)
abstract_candidates_file = (
    label_candidates_dir / "dg_abstract_train_candidates_resampling.tsv"
)
L_abstracts = pd.read_csv(str(abstract_candidates_file), sep="\t")

# Report the table size, then preview the first rows transposed (one row per column)
print(L_abstracts.shape)
L_abstracts.head().T

(1539670, 105)


Unnamed: 0,0,1,2,3,4
LF_HETNET_DISEASES,-1,-1,-1,-1,-1
LF_HETNET_DOAF,-1,-1,-1,-1,-1
LF_HETNET_DisGeNET,-1,-1,-1,1,-1
LF_HETNET_GWAS,-1,-1,-1,-1,-1
LF_HETNET_DaG_ABSENT,0,0,0,-1,0
...,...,...,...,...,...
LF_GG_BICLUSTER_INCREASES_EXPRESSION,-1,-1,-1,-1,-1
LF_GG_BICLUSTER_SIGNALING,-1,-1,1,-1,-1
LF_GG_BICLUSTER_IDENTICAL_PROTEIN,-1,-1,-1,-1,-1
LF_GG_BICLUSTER_CELL_PRODUCTION,-1,-1,-1,-1,-1


In [5]:
# The dev and test candidates live in one file; rows with split == 1 are kept as L_dev
dev_test_file = label_candidates_dir / "dg_dev_test_candidates_resampling.tsv"
L_dev = pd.read_csv(str(dev_test_file), sep="\t") >> ply.query("split==1")

print(L_dev.shape)
L_dev.head().T

(975, 107)


Unnamed: 0,0,8,25,33,50
LF_HETNET_DISEASES,-1.0,1.0,1.0,1.0,1.0
LF_HETNET_DOAF,-1.0,1.0,-1.0,-1.0,-1.0
LF_HETNET_DisGeNET,1.0,1.0,-1.0,-1.0,-1.0
LF_HETNET_GWAS,-1.0,-1.0,-1.0,-1.0,-1.0
LF_HETNET_DaG_ABSENT,-1.0,-1.0,-1.0,-1.0,-1.0
...,...,...,...,...,...
LF_GG_BICLUSTER_IDENTICAL_PROTEIN,-1.0,-1.0,-1.0,-1.0,-1.0
LF_GG_BICLUSTER_CELL_PRODUCTION,-1.0,-1.0,-1.0,-1.0,-1.0
split,1.0,1.0,1.0,1.0,1.0
document_id,23520.0,629602.0,993337.0,1434797.0,1350353.0


In [6]:
# The dev and test candidates live in one file; rows with split == 2 are kept as L_test
dev_test_file = label_candidates_dir / "dg_dev_test_candidates_resampling.tsv"
L_test = pd.read_csv(str(dev_test_file), sep="\t") >> ply.query("split==2")

print(L_test.shape)
L_test.head().T

(1000, 107)


Unnamed: 0,1,2,3,4,5
LF_HETNET_DISEASES,1.0,-1.0,1.0,-1.0,-1.0
LF_HETNET_DOAF,-1.0,-1.0,-1.0,-1.0,-1.0
LF_HETNET_DisGeNET,-1.0,-1.0,-1.0,-1.0,-1.0
LF_HETNET_GWAS,-1.0,-1.0,-1.0,-1.0,-1.0
LF_HETNET_DaG_ABSENT,-1.0,0.0,-1.0,0.0,0.0
...,...,...,...,...,...
LF_GG_BICLUSTER_IDENTICAL_PROTEIN,-1.0,-1.0,0.0,-1.0,-1.0
LF_GG_BICLUSTER_CELL_PRODUCTION,-1.0,-1.0,0.0,0.0,-1.0
split,2.0,2.0,2.0,2.0,2.0
document_id,217570.0,31141.0,266209.0,201631.0,394935.0


## Filter Candidates Based on Document ID

In [7]:
# Grab the document ids for resampling
sql = """
select dg_candidates.sentence_id, document_id, dg_candidates.candidate_id from sentence
inner join (
  select candidate.candidate_id, disease_gene.sentence_id from disease_gene
  inner join candidate on candidate.candidate_id=disease_gene.candidate_id
  ) as dg_candidates
on sentence.sentence_id = dg_candidates.sentence_id
"""
candidate_doc_df = pd.read_sql(sql, database_str)
candidate_doc_df.head()

Unnamed: 0,sentence_id,document_id,candidate_id
0,577814033,8168034,3623
1,592544979,12960042,11812
2,298670465,26635731,33796
3,409575187,23452434,51810
4,588478777,18324346,57316


In [8]:
# Distinct document ids that appear in the dev split.
# NOTE(review): the name says "dev_test" but only L_dev is consulted here —
# confirm whether L_test document ids were meant to be included as well.
dev_documents = L_dev >> ply.select("document_id") >> ply.distinct()
dev_test_ids = dev_documents >> ply.pull("document_id")

# Candidate ids belonging to those documents.
# NOTE(review): filtered_candidate_id is not referenced by the later cells shown
# in this notebook — verify it is still needed.
document_filter = f"document_id in {list(dev_test_ids)}"
filtered_candidate_id = (
    candidate_doc_df >> ply.query(document_filter) >> ply.pull("candidate_id")
)

In [9]:
# Table assigning each document_id to a dataset label (the `dataset` column)
dataset_mapper_file = notebook_output_dir / "dag_dataset_mapper.tsv"
sorted_train_df = pd.read_csv(str(dataset_mapper_file), sep="\t")
sorted_train_df.head()

Unnamed: 0,document_id,dataset
0,8168034,train
1,12960042,train
2,26635731,train
3,23452434,tune
4,18324346,train


In [10]:
# Attach each candidate to its document's dataset label, then keep the
# candidates whose document is labelled 'train'.
# NOTE: despite its name, trained_documents holds candidate ids, not document ids.
candidates_with_dataset = sorted_train_df >> ply.inner_join(
    candidate_doc_df, on="document_id"
)
trained_documents = (
    candidates_with_dataset
    >> ply.query("dataset=='train'")
    >> ply.pull("candidate_id")
)

In [11]:
# Keep only the abstract label rows whose candidate comes from a 'train' document
train_candidate_filter = f"candidate_id in {list(trained_documents)}"
filtered_L_abstracts = L_abstracts >> ply.query(train_candidate_filter)

print(filtered_L_abstracts.shape)
filtered_L_abstracts.head()

(1076965, 105)


Unnamed: 0,LF_HETNET_DISEASES,LF_HETNET_DOAF,LF_HETNET_DisGeNET,LF_HETNET_GWAS,LF_HETNET_DaG_ABSENT,LF_DG_IS_BIOMARKER,LF_DaG_ASSOCIATION,LF_DaG_WEAK_ASSOCIATION,LF_DaG_NO_ASSOCIATION,LF_DaG_CELLULAR_ACTIVITY,...,LF_GG_NO_VERB,LF_GG_BICLUSTER_BINDING,LF_GG_BICLUSTER_ENHANCES,LF_GG_BICLUSTER_ACTIVATES,LF_GG_BICLUSTER_AFFECTS_EXPRESSION,LF_GG_BICLUSTER_INCREASES_EXPRESSION,LF_GG_BICLUSTER_SIGNALING,LF_GG_BICLUSTER_IDENTICAL_PROTEIN,LF_GG_BICLUSTER_CELL_PRODUCTION,candidate_id
0,-1,-1,-1,-1,0,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,121
1,-1,-1,-1,-1,0,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,122
2,-1,-1,-1,-1,0,-1,-1,-1,-1,-1,...,-1,1,-1,-1,-1,-1,1,-1,-1,124
3,-1,-1,1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,148
4,-1,-1,-1,-1,0,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,190


## Construct the Grid Search

In [12]:
# Global Grid
epochs_grid = [500]
l2_param_grid = np.linspace(0.01, 5, num=5)
lr_grid = [1e-2]
grid = list(
    zip(epochs_grid * len(l2_param_grid), l2_param_grid, lr_grid * len(l2_param_grid))
)

# Abstracts

In [13]:
# Label matrix without the candidate_id column, so only label-function columns remain
abstract_lf_matrix = filtered_L_abstracts >> ply.select("candidate_id", drop=True)

analysis_module = LFAnalysis(abstract_lf_matrix)

# Summarise the label functions and index the rows by label-function column name
abstract_lf_summary = analysis_module.lf_summary()
abstract_lf_summary.index = abstract_lf_matrix.columns.tolist()
abstract_lf_summary

Unnamed: 0,Polarity,Coverage,Overlaps,Conflicts
LF_HETNET_DISEASES,[1],0.325606,0.325606,0.325606
LF_HETNET_DOAF,[1],0.145826,0.145826,0.145826
LF_HETNET_DisGeNET,[1],0.299600,0.299600,0.299600
LF_HETNET_GWAS,[1],0.036476,0.036476,0.036476
LF_HETNET_DaG_ABSENT,[0],0.574512,0.574512,0.574512
...,...,...,...,...
LF_GG_BICLUSTER_AFFECTS_EXPRESSION,[1],0.020184,0.020184,0.020184
LF_GG_BICLUSTER_INCREASES_EXPRESSION,[0],0.042169,0.042169,0.042169
LF_GG_BICLUSTER_SIGNALING,[1],0.046558,0.046558,0.046558
LF_GG_BICLUSTER_IDENTICAL_PROTEIN,[0],0.023896,0.023896,0.023896


# Set Up for Resampling

In [14]:
# First five label-matrix columns; passed to the sampler as `lf_columns_base`
lf_columns_base = L_abstracts.columns[:5].tolist()
# Last label-matrix column, kept as a one-element list
candidate_id_field = L_abstracts.columns[-1:].tolist()
dev_column_base = ["split", "curated_dsh", "document_id"]
# Accumulates the records returned by the label-function sampler
data_columns = []

# Abstracts - All Label Functions

In [15]:
# Column positions [dag_start, dag_end) of the label functions that get sampled
dag_start, dag_end = 5, 104
dag_lf_range = range(dag_start, dag_end)

# Spaced out number of samples, including the total
size_of_samples = [1, 33, 65, 97, len(dag_lf_range)]
# Passed to sample_lfs together with each sample size
number_of_samples = 50

In [16]:
# Map each sample size to the label-function samples drawn for it.
# The same random_state is used for every size.
lf_count = len(dag_lf_range)
sampled_lfs_dict = {}
for sample_size in size_of_samples:
    sampled_lfs_dict[sample_size] = sample_lfs(
        list(dag_lf_range),  # fresh list per call, as sample_lfs may modify it
        lf_count,
        sample_size,
        number_of_samples,
        random_state=100,
    )

In [17]:
# Keyword settings for this sampler run
sampler_options = {
    "lf_columns_base": lf_columns_base,
    "grid_param": grid,
    "marginals_df_file": "",
    "curated_label": "curated_dsh",
    "entity_label": "ALL",
    "data_source": "abstract",
}

# Run the sampler on the train / dev / test label matrices and append whatever
# it returns to the running results list.
data_columns.extend(
    run_generative_label_function_sampler(
        filtered_L_abstracts,
        L_dev,
        L_test,
        sampled_lfs_dict,
        **sampler_options,
    )
)

100%|██████████| 50/50 [03:22<00:00,  4.05s/it]
100%|██████████| 50/50 [11:12<00:00, 13.46s/it]
100%|██████████| 50/50 [21:08<00:00, 25.37s/it]
100%|██████████| 50/50 [4:56:01<00:00, 355.23s/it]  
100%|██████████| 50/50 [50:12<00:00, 60.25s/it]


# Write Performance to File

In [18]:
# Collect the accumulated sampler results into a single table and display it
performance_df = pd.DataFrame.from_records(data_columns)
performance_df

Unnamed: 0,lf_num,auroc,aupr,bce_loss,sampled_lf_name,label_source,data_source,model,epochs,l2_param,lr_param
0,1,0.491089,0.401374,2.412427,"LF_HETNET_DISEASES,LF_HETNET_DOAF,LF_HETNET_Di...",ALL,abstract,tune,500,1.2575,0.01
1,1,0.560246,0.437185,1.846836,"LF_HETNET_DISEASES,LF_HETNET_DOAF,LF_HETNET_Di...",ALL,abstract,test,500,1.2575,0.01
2,1,0.491389,0.398175,2.412709,"LF_HETNET_DISEASES,LF_HETNET_DOAF,LF_HETNET_Di...",ALL,abstract,tune,500,1.2575,0.01
3,1,0.562735,0.422589,1.851670,"LF_HETNET_DISEASES,LF_HETNET_DOAF,LF_HETNET_Di...",ALL,abstract,test,500,1.2575,0.01
4,1,0.506065,0.404286,2.217284,"LF_HETNET_DISEASES,LF_HETNET_DOAF,LF_HETNET_Di...",ALL,abstract,tune,500,1.2575,0.01
...,...,...,...,...,...,...,...,...,...,...,...
495,99,0.702126,0.533069,1.565527,"LF_HETNET_DISEASES,LF_HETNET_DOAF,LF_HETNET_Di...",ALL,abstract,test,500,1.2575,0.01
496,99,0.720316,0.584358,1.508298,"LF_HETNET_DISEASES,LF_HETNET_DOAF,LF_HETNET_Di...",ALL,abstract,tune,500,1.2575,0.01
497,99,0.702126,0.533069,1.565527,"LF_HETNET_DISEASES,LF_HETNET_DOAF,LF_HETNET_Di...",ALL,abstract,test,500,1.2575,0.01
498,99,0.720316,0.584358,1.508298,"LF_HETNET_DISEASES,LF_HETNET_DOAF,LF_HETNET_Di...",ALL,abstract,tune,500,1.2575,0.01


In [19]:
# Write the performance table as a tab-separated file under ./output.
# Create the directory first so a missing folder cannot make the write fail
# after the long sampling run above.
performance_output_dir = Path("output")
performance_output_dir.mkdir(parents=True, exist_ok=True)
performance_df.to_csv(
    str(performance_output_dir / "ALL_DaG_performance.tsv"),
    index=False,
    sep="\t",
)