In [2]:
import numpy as np
import snorkel
from snorkel.labeling import labeling_function
from snorkel.labeling import LabelingFunction
from snorkel.labeling import PandasLFApplier
from snorkel.labeling import LFAnalysis
import re
from snorkel.labeling.model import MajorityLabelVoter
import json
import pandas as pd
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [3]:
%load_ext autoreload
%autoreload 2
import label_improve as li

In [4]:
# Loading the data 
dataset_name = "chemprot"
idx_to_label = json.load(open(f"../weak_datasets/{dataset_name}/label.json"))
label_to_idx = {l:i for i,l in idx_to_label.items()}
valid_df = li.chemprot_to_df(json.load(open(f"../weak_datasets/{dataset_name}/valid.json", "r")))
train_df = li.chemprot_to_df(json.load(open(f"../weak_datasets/{dataset_name}/train.json", "r")))
test_df = li.chemprot_to_df(json.load(open(f"../weak_datasets/{dataset_name}/test.json", "r")))

# Sample a dev set to help seed ideas for LFs
dev_df = train_df.sample(250, random_state=123)

In [5]:
dev_df

Unnamed: 0,text,label,entity1,entity2,span1,span2,weak_labels
10408,"Unlike OFQ II(1-17), high concentrations of it...",1,mu (mu) opioid receptor,naloxone,"[124, 147]","[196, 204]","[-1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1..."
12234,Photolytic release of free alanine results in ...,1,ASCT2,alanine,"[136, 141]","[27, 34]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -..."
14360,Our results showed that bone remodeling was si...,5,CCR1,Met,"[166, 170]","[128, 131]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -..."
10387,Experimental evidence from the use of agents w...,1,BuChE,MF-8622,"[75, 80]","[103, 110]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -..."
2750,"PURPOSE: Dasatinib (BMS-354825), a potent oral...",3,ABL,BMS-354825,"[99, 102]","[20, 30]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -..."
...,...,...,...,...,...,...,...
1763,This study demonstrates enhanced cardiostimula...,1,beta(1)-adrenoceptors,CGP 12177A,"[141, 162]","[54, 64]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -..."
12127,Thymidylate synthase and thymidine kinase are ...,8,thymidine kinase,pyrimidine nucleotide,"[25, 41]","[107, 128]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -..."
3952,CONCLUSIONS A diet that partially replaces car...,2,insulin,carbohydrate,"[89, 96]","[43, 55]","[-1, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1..."
1186,Bovine tracheal smooth muscle strips were incu...,4,muscarinic receptor,4-(m-chlorophenyl-carbamoyloxy)-2-butynyltrime...,"[299, 318]","[327, 385]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -..."


In [8]:
# chemprot functions:
ABSTAIN = -1
### Keyword based labeling functions ###
## Part of
#0
@labeling_function()
def lf_amino_acid(x):
    if "which is present in" in x.text.lower():
        return 0
    return 0 if 'amino acid' in x.text.lower() else ABSTAIN
#1
@labeling_function()
def lf_replace(x):
    return 0 if 'replace' in x.text.lower() else ABSTAIN

#2 TODO: 0.1988
@labeling_function()
def lf_mutant(x):
    def find_word_index(words, target):
        for i, word in enumerate(words):
            if target in word:
                return i
        return -1

    words = x.text.lower().split()
    if any(' mutant' in word or ' mutat' in word for word in words):
        # if mutant is between the two entities
        if x.entity1_index == -1 or x.entity2_index == -1:
            return ABSTAIN
        if isinstance(x.entity1_index, int) and isinstance(x.entity2_index, int):
            if x.entity1_index < x.entity2_index:
                if any(' mutant' in word or ' mutat' in word for word in words[x.entity1_index:x.entity2_index]):
                    return 0
            else:
                if any(' mutant' in word or ' mutat' in word for word in words[x.entity2_index:x.entity1_index]):
                    return 0
        # if mutant is close to either of the entities
        mutant_index = find_word_index(words, ' mutant')
        mutat_index = find_word_index(words, ' mutat')
        if (mutant_index != -1 and (abs(x.entity1_index - mutant_index) < 4 or abs(x.entity2_index - mutant_index) < 4)) or \
           (mutat_index != -1 and (abs(x.entity1_index - mutat_index) < 4 or abs(x.entity2_index - mutat_index) < 4)):
            return 0
    return ABSTAIN

#3
## Regulator
@labeling_function()
def lf_bind(x):
    return 1 if 'bind' in x.text.lower() else ABSTAIN
#4
@labeling_function()
def lf_interact(x):
    return 1 if 'interact' in x.text.lower() else ABSTAIN
#5
@labeling_function()
def lf_affinity(x):
    return 1 if 'affinit' in x.text.lower() else ABSTAIN
#6 TODO: 0.3578
## Upregulator
# Activator
@labeling_function()
def lf_activate(x):
    return 2 if 'activat' in x.text.lower() else ABSTAIN
#7
@labeling_function()
def lf_increase(x):
    return 2 if 'increas' in x.text.lower() else ABSTAIN
#8 TODO: 
@labeling_function()
def lf_induce(x):
    return 2 if 'induc' in x.text.lower() else ABSTAIN
#9 TODO: 
@labeling_function()
def lf_stimulate(x):
    # if stimulate is between the two entities
    if x.entity1_index == -1 or x.entity2_index == -1:
        return ABSTAIN
    if isinstance(x.entity1_index, int) and isinstance(x.entity2_index, int):
        if x.entity1_index < x.entity2_index:
            if x.text[x.entity1_index:x.entity2_index].count('stimulat') > 0:
                return 2
        else:
            if x.text[x.entity2_index:x.entity1_index].count('stimulat') > 0:
                return 2
    return ABSTAIN
#10
@labeling_function()
def lf_upregulate(x):
    if ('upregulat' in x.text.lower() or 'up-regulat' in x.text.lower()) and ('downregulat' in x.text.lower() or 'down-regulat' in x.text.lower()):
        entity1_index = x.text.lower().index(x.entity1.lower())
        entity2_index = x.text.lower().index(x.entity2.lower())
        # if up regulate is between the two entities
        if isinstance(entity1_index, int) and isinstance(entity2_index, int):
            if entity1_index < entity2_index:
                if x.text[entity1_index:entity2_index].count('upregulat') > 0 or x.text[entity1_index:entity2_index].count('up-regulat') > 0:
                    return 2
            else:
                if x.text[entity2_index:entity1_index].count('upregulat') > 0 or x.text[entity2_index:entity1_index].count('up-regulat') > 0:
                    return 2
        return ABSTAIN
    else:
        return 2 if 'upregulat' in x.text.lower() or 'up-regulat' in x.text.lower() else ABSTAIN
#11
## Downregulator
@labeling_function()
def lf_downregulate(x):
    if('downregulat' in x.text.lower() or 'down-regulat' in x.text.lower()) and ('upregulat' in x.text.lower() or 'up-regulat' in x.text.lower()):
        if x.entity1_index == -1 or x.entity2_index == -1:
            return ABSTAIN
        # if up regulate is between the two entities
        if isinstance(x.entity1_index, int) and isinstance(x.entity2_index, int):
            if x.entity1_index < x.entity2_index:
                if x.text[x.entity1_index:x.entity2_index].count('downregulat') > 0 or x.text[x.entity1_index:x.entity2_index].count('down-regulat') > 0:
                    return 3
            else:
                if x.text[x.entity2_index:x.entity1_index].count('downregulat') > 0 or x.text[x.entity2_index:x.entity1_index].count('down-regulat') > 0:
                    return 3
        return ABSTAIN
    return 3 if 'downregulat' in x.text.lower() or 'down-regulat' in x.text.lower() else ABSTAIN
#12
@labeling_function()
def lf_reduce(x):
    return 3 if ' reduc' in x.text.lower() else ABSTAIN
#13
@labeling_function()
def lf_inhibit(x):
    return 3 if 'inhibit' in x.text.lower() else ABSTAIN
#14
@labeling_function()
def lf_decrease(x):
    return 3 if 'decreas' in x.text.lower() else ABSTAIN
#15
## Agonist
@labeling_function()
def lf_agonist(x):
    return 4 if ' agoni' in x.text.lower() or "\tagoni" in x.text.lower() else ABSTAIN

#16
## Antagonist
@labeling_function()
def lf_antagonist(x):
    return 5 if ' antagon' in x.text.lower() else ABSTAIN

#17
## Modulator
# TODO: Delete this LF, or change this to modulator ??
@labeling_function()
def lf_modulate(x):
    return 6 if 'modulat' in x.text.lower() else ABSTAIN

#18
@labeling_function()
def lf_allosteric(x):
    return 6 if 'allosteric' in x.text.lower() else ABSTAIN
#19
## Cofactor
@labeling_function()
def lf_cofactor(x):
    return 7 if 'cofactor' in x.text.lower() else ABSTAIN
#20
## Substrate/Product
@labeling_function()
def lf_substrate(x):
    return 8 if 'substrate' in x.text.lower() else ABSTAIN
#21
@labeling_function()
def lf_transport(x):
    return 8 if 'transport' in x.text.lower() else ABSTAIN
#22
@labeling_function()
def lf_catalyze(x):
    if " enzyme" in x.text.lower() or "metabolized" in x.text.lower():
        return 8
    return 8 if 'catalyz' in x.text.lower() or 'catalys' in x.text.lower() else ABSTAIN
#23
@labeling_function()
def lf_product(x):
    return 8 if " produc" in x.text.lower() else ABSTAIN
#24
@labeling_function()
def lf_convert(x):
    return 8 if "conver" in x.text.lower() else ABSTAIN
#25
## NOT
@labeling_function()
def lf_not(x):
    entity1_index = x.text.lower().index(x.entity1.lower())
    entity2_index = x.text.lower().index(x.entity2.lower())
    # if the two entities are close to the word 'not'
    
    if ' not ' in x.text.lower():
        if abs(entity1_index - x.text.lower().index('not')) < 20 or abs(entity2_index - x.text.lower().index('not')) < 20:
            return 9
        # if not is between the two entities
        if abs(entity1_index - x.text.lower().index('not')) < 40 or abs(entity2_index - x.text.lower().index('not')) < 40:
            if entity1_index < entity2_index:
                if x.text[entity1_index:entity2_index].count('not') > 0:
                    return 9
            else:
                if x.text[entity2_index:entity1_index].count('not') > 0:
                    return 9
    return ABSTAIN

# 26 replace the 17 (18)
@labeling_function()
def lf_combined_modulator(x):
    sentence_lower = x.text.lower()

    specific_terms = ['allosteric modulator', 'positive modulator', 'negative modulator', 'non-competitive modulator', 'positive allosteric modulator']
    if any(term in sentence_lower for term in specific_terms):
        return 6

    modulating_terms = ['modulat', 'allosteric', 'potentiate']
    for term in modulating_terms:
        if term in sentence_lower:
            term_index = sentence_lower.index(term)
            if x.entity1_index == -1 or x.entity2_index == -1:
                return ABSTAIN
            if abs(term_index - x.entity1_index) < 20 or abs(term_index - x.entity2_index) < 20:
                return 6
    # check the first indcidences of modulator and positive
    if x.entity1 in sentence_lower and x.entity2 in sentence_lower:
        entity1_index = sentence_lower.index(x.entity1.lower())
        entity2_index = sentence_lower.index(x.entity2.lower())
        between_entities = sentence_lower[min(entity1_index, entity2_index):max(entity1_index, entity2_index)]
        if 'modulate' in between_entities:
            return 6

    if 'positive' in sentence_lower and 'modulator' in sentence_lower:
        pos_indices = [i for i, word in enumerate(sentence_lower.split()) if word == 'positive']
        mod_indices = [i for i, word in enumerate(sentence_lower.split()) if 'modulator' in word]
        if len(pos_indices) == 0 or len(mod_indices) == 0:
            return ABSTAIN
        min_distance = min(abs(p - m) for p in pos_indices for m in mod_indices)
        if min_distance <= 3:
            return 6

    return ABSTAIN

lfs = [lf_amino_acid, lf_replace, lf_mutant, lf_bind, lf_interact, lf_affinity, lf_activate, lf_increase, lf_stimulate, lf_upregulate, lf_downregulate, lf_reduce, lf_inhibit, lf_decrease, lf_agonist, lf_antagonist, lf_combined_modulator, lf_allosteric, lf_cofactor, lf_substrate, lf_transport, lf_catalyze, lf_product, lf_convert, lf_not]

In [9]:
dev_df = train_df.sample(250, random_state=123)

In [10]:
# show the first row of the dataframe
dev_df = dev_df
dev_df = li.chemprot_enhanced(dev_df)
L_dev = li.apply_LFs(lfs, dev_df)

  0%|          | 0/1607 [00:00<?, ?it/s]

100%|██████████| 1607/1607 [00:00<00:00, 6300.93it/s]


In [12]:
print("Test Coverage:", li.calc_coverage(L_dev))
lf_analysis = LFAnalysis(L_dev, lfs = lfs).lf_summary(Y = dev_df.label.values)

# Calculates how many of an LFs votes result in conflicts (helpful signal for debugging LFs)
lf_analysis['Conflict Ratio'] = lf_analysis['Conflicts'] / lf_analysis['Coverage']
lf_analysis

Test Coverage: 0.8214063472308649


Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.,Conflict Ratio
lf_amino_acid,0,[0],0.017424,0.012446,0.012446,13,15,0.464286,0.714286
lf_replace,1,[0],0.003111,0.0,0.0,3,2,0.6,0.0
lf_mutant,2,[],0.0,0.0,0.0,0,0,0.0,
lf_bind,3,[1],0.102676,0.071562,0.062228,101,64,0.612121,0.606061
lf_interact,4,[1],0.029247,0.020535,0.018046,20,27,0.425532,0.617021
lf_affinity,5,[1],0.049782,0.031114,0.024269,61,19,0.7625,0.4875
lf_activate,6,[2],0.118233,0.080274,0.071562,75,115,0.394737,0.605263
lf_increase,7,[2],0.097075,0.07094,0.060983,64,92,0.410256,0.628205
lf_stimulate,8,[],0.0,0.0,0.0,0,0,0.0,
lf_upregulate,9,[2],0.014312,0.00809,0.0056,13,10,0.565217,0.391304


In [13]:
# Calculate accuracy on the validation set (Ideally do this only at the end)
majority_model = MajorityLabelVoter(10)
preds_valid = majority_model.predict(L=L_dev)

print("acuracy for the not abstains")
print((preds_valid[preds_valid != -1] == dev_df[preds_valid != -1].label.values).mean())
print("acuracy for all")
print((preds_valid == dev_df.label.values).mean())

acuracy for the not abstains
0.7014270032930845
acuracy for all
0.39763534536403233
Number of absent predictions 696


100%|██████████| 187/187 [00:00<00:00, 6387.92it/s]


In [14]:
new_train = li.chemprot_df_with_new_lf(train_df, lfs)
chemprot = li.df_to_chemprot(new_train)
li.save_dataset(chemprot, "../weak_datasets/chemprot2/train.json")
new_test = li.chemprot_df_with_new_lf(test_df, lfs)
chemprot = li.df_to_chemprot(new_test)
li.save_dataset(chemprot, "../weak_datasets/chemprot2/test.json")
new_valid = li.chemprot_df_with_new_lf(valid_df, lfs)
chemprot = li.df_to_chemprot(new_valid)
li.save_dataset(chemprot, "../weak_datasets/chemprot2/valid.json")

100%|██████████| 12861/12861 [00:02<00:00, 6238.01it/s]
100%|██████████| 1607/1607 [00:00<00:00, 6347.90it/s]
100%|██████████| 1607/1607 [00:00<00:00, 6379.41it/s]


# Original label functions

In [206]:
# chemprot functions:

ABSTAIN = -1
### Keyword based labeling functions ###

## Part of
#0
@labeling_function()
def lf_amino_acid(x):
    return 0 if 'amino acid' in x.text.lower() else ABSTAIN
#1
@labeling_function()
def lf_replace(x):
    return 0 if 'replace' in x.text.lower() else ABSTAIN
#2
@labeling_function()
def lf_mutant(x):
    return 0 if 'mutant' in x.text.lower() or 'mutat' in x.text.lower() else ABSTAIN
#3
## Regulator
@labeling_function()
def lf_bind(x):
    return 1 if 'bind' in x.text.lower() else ABSTAIN
#4
@labeling_function()
def lf_interact(x):
    return 1 if 'interact' in x.text.lower() else ABSTAIN
#5
@labeling_function()
def lf_affinity(x):
    return 1 if 'affinit' in x.text.lower() else ABSTAIN
#6
## Upregulator
# Activator
@labeling_function()
def lf_activate(x):
    return 2 if 'activat' in x.text.lower() else ABSTAIN
#7
@labeling_function()
def lf_increase(x):
    return 2 if 'increas' in x.text.lower() else ABSTAIN
#8 
@labeling_function()
def lf_induce(x):
    return 2 if 'induc' in x.text.lower() else ABSTAIN
#9
@labeling_function()
def lf_stimulate(x):
    return 2 if 'stimulat' in x.text.lower() else ABSTAIN
#10
@labeling_function()
def lf_upregulate(x):
    return 2 if 'upregulat' in x.text.lower() else ABSTAIN
#11
## Downregulator
@labeling_function()
def lf_downregulate(x):
    return 3 if 'downregulat' in x.text.lower() or 'down-regulat' in x.text.lower() else ABSTAIN
#12
@labeling_function()
def lf_reduce(x):
    return 3 if 'reduc' in x.text.lower() else ABSTAIN
#13
@labeling_function()
def lf_inhibit(x):
    return 3 if 'inhibit' in x.text.lower() else ABSTAIN
#14
@labeling_function()
def lf_decrease(x):
    return 3 if 'decreas' in x.text.lower() else ABSTAIN
    
    
#15
## Agonist
@labeling_function()
def lf_agonist(x):
    return 4 if ' agoni' in x.text.lower() or "\tagoni" in x.text.lower() else ABSTAIN

#16
## Antagonist
@labeling_function()
def lf_antagonist(x):
    return 5 if 'antagon' in x.text.lower() else ABSTAIN

#17
## Modulator
# TODO: Delete this LF, or change this to modulator ??
@labeling_function()
def lf_modulate(x):
    return 6 if 'modulat' in x.text.lower() else ABSTAIN

#18
@labeling_function()
def lf_allosteric(x):
    return 6 if 'allosteric' in x.text.lower() else ABSTAIN
#19
## Cofactor
@labeling_function()
def lf_cofactor(x):
    return 7 if 'cofactor' in x.text.lower() else ABSTAIN
#20
## Substrate/Product
@labeling_function()
def lf_substrate(x):
    return 8 if 'substrate' in x.text.lower() else ABSTAIN
#21
@labeling_function()
def lf_transport(x):
    return 8 if 'transport' in x.text.lower() else ABSTAIN
#22
@labeling_function()
def lf_catalyze(x):
    return 8 if 'catalyz' in x.text.lower() or 'catalys' in x.text.lower() else ABSTAIN
#23
@labeling_function()
def lf_product(x):
    return 8 if "produc" in x.text.lower() else ABSTAIN
#24
@labeling_function()
def lf_convert(x):
    return 8 if "conver" in x.text.lower() else ABSTAIN
#25
## NOT
@labeling_function()
def lf_not(x):
    return 9 if 'not' in x.text.lower() else ABSTAIN

In [207]:
lfs = [lf_amino_acid, lf_replace, lf_mutant, lf_bind, lf_interact, lf_affinity, lf_activate, lf_increase, lf_induce, lf_stimulate, lf_upregulate, lf_downregulate, lf_reduce, lf_inhibit, lf_decrease, lf_agonist, lf_antagonist, lf_modulate, lf_allosteric, lf_cofactor, lf_substrate, lf_transport, lf_catalyze, lf_product, lf_convert, lf_not]

# show the first row of the dataframe
train_dev = dev_df
train_dev = li.chemprot_enhanced(train_dev)
L_dev2 = li.apply_LFs(lfs, train_dev)
L_dev2

  0%|          | 0/1607 [00:00<?, ?it/s]

100%|██████████| 1607/1607 [00:00<00:00, 8560.79it/s]


array([[-1, -1, -1, ..., -1, -1, -1],
       [-1, -1, -1, ...,  8, -1, -1],
       [-1, -1, -1, ..., -1, -1, -1],
       ...,
       [-1, -1, -1, ..., -1, -1, -1],
       [-1, -1, -1, ..., -1, -1, -1],
       [-1, -1, -1, ..., -1, -1, -1]])

In [209]:
# Calculate accuracy on the validation set (Ideally do this only at the end)
majority_model = MajorityLabelVoter(10)
preds_valid = majority_model.predict(L=L_dev2)
print("Test Coverage:", li.calc_coverage(L_dev2))
lf_analysis = LFAnalysis(L_dev2, lfs = lfs).lf_summary(Y = train_dev.label.values)
# Calculates how many of an LFs votes result in conflicts (helpful signal for debugging LFs)
lf_analysis['Conflict Ratio'] = lf_analysis['Conflicts'] / lf_analysis['Coverage']
lf_analysis

Test Coverage: 0.8637212196639701


Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.,Conflict Ratio
lf_amino_acid,0,[0],0.016801,0.01369,0.012446,12,15,0.444444,0.740741
lf_replace,1,[0],0.003111,0.000622,0.000622,3,2,0.6,0.2
lf_mutant,2,[0],0.032981,0.023024,0.02178,7,46,0.132075,0.660377
lf_bind,3,[1],0.102676,0.080896,0.073429,101,64,0.612121,0.715152
lf_interact,4,[1],0.029247,0.023024,0.02178,20,27,0.425532,0.744681
lf_affinity,5,[1],0.049782,0.033603,0.02738,61,19,0.7625,0.55
lf_activate,6,[2],0.118233,0.091475,0.075296,75,115,0.394737,0.636842
lf_increase,7,[2],0.097075,0.07654,0.065961,64,92,0.410256,0.679487
lf_induce,8,[2],0.141879,0.112632,0.099564,56,172,0.245614,0.701754
lf_stimulate,9,[2],0.04107,0.034225,0.029247,16,50,0.242424,0.712121


In [210]:
print("acurracy: for the not abstain")
print((preds_valid[preds_valid != -1] == train_dev[preds_valid != -1].label.values).mean())
print("acurracy: for all the data")
print((preds_valid == train_dev.label.values).mean())
incorrect_indices = np.where(preds_valid[preds_valid != -1] != train_dev[preds_valid != -1].label.values)[0]
incorrect_predictions_df = train_dev.iloc[incorrect_indices]

acurracy: for the not abstain
0.6515679442508711
acurracy: for all the data
0.3490976975731176
