In [2]:
import numpy as np
import snorkel
from snorkel.labeling import labeling_function
from snorkel.labeling import LabelingFunction
from snorkel.labeling import PandasLFApplier
from snorkel.labeling import LFAnalysis
import re
from snorkel.labeling.model import MajorityLabelVoter
import json
import pandas as pd
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [3]:
%load_ext autoreload
%autoreload 2
import label_improve as li

In [4]:
# Loading the data 
dataset_name = "chemprot"
idx_to_label = json.load(open(f"../weak_datasets/{dataset_name}/label.json"))
label_to_idx = {l:i for i,l in idx_to_label.items()}
valid_df = li.chemprot_to_df(json.load(open(f"../weak_datasets/{dataset_name}/valid.json", "r")))
train_df = li.chemprot_to_df(json.load(open(f"../weak_datasets/{dataset_name}/train.json", "r")))
test_df = li.chemprot_to_df(json.load(open(f"../weak_datasets/{dataset_name}/test.json", "r")))

# Sample a dev set to help seed ideas for LFs
dev_df = train_df.sample(250, random_state=123)

# Original LFs

In [5]:
# chemprot functions:

ABSTAIN = -1
### Keyword based labeling functions ###

## Part of
#0
@labeling_function()
def lf_amino_acid(x):
    return 0 if 'amino acid' in x.text.lower() else ABSTAIN
#1
@labeling_function()
def lf_replace(x):
    return 0 if 'replace' in x.text.lower() else ABSTAIN
#2
@labeling_function()
def lf_mutant(x):
    return 0 if 'mutant' in x.text.lower() or 'mutat' in x.text.lower() else ABSTAIN
#3
## Regulator
@labeling_function()
def lf_bind(x):
    return 1 if 'bind' in x.text.lower() else ABSTAIN
#4
@labeling_function()
def lf_interact(x):
    return 1 if 'interact' in x.text.lower() else ABSTAIN
#5
@labeling_function()
def lf_affinity(x):
    return 1 if 'affinit' in x.text.lower() else ABSTAIN
#6
## Upregulator
# Activator
@labeling_function()
def lf_activate(x):
    return 2 if 'activat' in x.text.lower() else ABSTAIN
#7
@labeling_function()
def lf_increase(x):
    return 2 if 'increas' in x.text.lower() else ABSTAIN
#8 
@labeling_function()
def lf_induce(x):
    return 2 if 'induc' in x.text.lower() else ABSTAIN
#9
@labeling_function()
def lf_stimulate(x):
    return 2 if 'stimulat' in x.text.lower() else ABSTAIN
#10
@labeling_function()
def lf_upregulate(x):
    return 2 if 'upregulat' in x.text.lower() else ABSTAIN
#11
## Downregulator
@labeling_function()
def lf_downregulate(x):
    return 3 if 'downregulat' in x.text.lower() or 'down-regulat' in x.text.lower() else ABSTAIN
#12
@labeling_function()
def lf_reduce(x):
    return 3 if 'reduc' in x.text.lower() else ABSTAIN
#13
@labeling_function()
def lf_inhibit(x):
    return 3 if 'inhibit' in x.text.lower() else ABSTAIN
#14
@labeling_function()
def lf_decrease(x):
    return 3 if 'decreas' in x.text.lower() else ABSTAIN
    
    
#15
## Agonist
@labeling_function()
def lf_agonist(x):
    return 4 if ' agoni' in x.text.lower() or "\tagoni" in x.text.lower() else ABSTAIN

#16
## Antagonist
@labeling_function()
def lf_antagonist(x):
    return 5 if 'antagon' in x.text.lower() else ABSTAIN

#17
## Modulator
# TODO: Delete this LF, or change this to modulator ??
@labeling_function()
def lf_modulate(x):
    return 6 if 'modulat' in x.text.lower() else ABSTAIN

#18
@labeling_function()
def lf_allosteric(x):
    return 6 if 'allosteric' in x.text.lower() else ABSTAIN
#19
## Cofactor
@labeling_function()
def lf_cofactor(x):
    return 7 if 'cofactor' in x.text.lower() else ABSTAIN
#20
## Substrate/Product
@labeling_function()
def lf_substrate(x):
    return 8 if 'substrate' in x.text.lower() else ABSTAIN
#21
@labeling_function()
def lf_transport(x):
    return 8 if 'transport' in x.text.lower() else ABSTAIN
#22
@labeling_function()
def lf_catalyze(x):
    return 8 if 'catalyz' in x.text.lower() or 'catalys' in x.text.lower() else ABSTAIN
#23
@labeling_function()
def lf_product(x):
    return 8 if "produc" in x.text.lower() else ABSTAIN
#24
@labeling_function()
def lf_convert(x):
    return 8 if "conver" in x.text.lower() else ABSTAIN
#25
## NOT
@labeling_function()
def lf_not(x):
    return 9 if 'not' in x.text.lower() else ABSTAIN

In [6]:
lfs = [lf_amino_acid, lf_replace, lf_mutant, lf_bind, lf_interact, lf_affinity, lf_activate, lf_increase, lf_induce, lf_stimulate, lf_upregulate, lf_downregulate, lf_reduce, lf_inhibit, lf_decrease, lf_agonist, lf_antagonist, lf_modulate, lf_allosteric, lf_cofactor, lf_substrate, lf_transport, lf_catalyze, lf_product, lf_convert, lf_not]

# show the first row of the dataframe
dev_df = li.chemprot_enhanced(dev_df)

In [7]:
li.analysis_LFs(lfs, dev_df,10)

100%|██████████| 250/250 [00:00<00:00, 8422.64it/s]

Test Coverage: 0.888
acuracy for the not abstains
0.6170212765957447
acuracy for all
0.348





Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.,Conflict Ratio
lf_amino_acid,0,[0],0.036,0.028,0.02,4,5,0.444444,0.555556
lf_replace,1,[0],0.004,0.0,0.0,0,1,0.0,0.0
lf_mutant,2,[0],0.016,0.016,0.008,1,3,0.25,0.5
lf_bind,3,[1],0.12,0.096,0.08,20,10,0.666667,0.666667
lf_interact,4,[1],0.028,0.024,0.016,5,2,0.714286,0.571429
lf_affinity,5,[1],0.036,0.024,0.016,5,4,0.555556,0.444444
lf_activate,6,[2],0.096,0.076,0.06,9,15,0.375,0.625
lf_increase,7,[2],0.068,0.052,0.044,5,12,0.294118,0.647059
lf_induce,8,[2],0.228,0.18,0.16,15,42,0.263158,0.701754
lf_stimulate,9,[2],0.048,0.028,0.024,1,11,0.083333,0.5


# Improved LFs

In [8]:
# chemprot functions:
ABSTAIN = -1
### Keyword based labeling functions ###
## Part of
#0
@labeling_function()
def lf_amino_acid(x):
    if "which is present in" in x.text.lower():
        return 0
    return 0 if 'amino acid' in x.text.lower() else ABSTAIN
#1
@labeling_function()
def lf_replace(x):
    return 0 if 'replace' in x.text.lower() else ABSTAIN

#2
@labeling_function()
def lf_mutant(x):
    def find_word_index(words, target):
        for i, word in enumerate(words):
            if target in word:
                return i
        return -1

    words = x.text.lower().split()
    if any(' mutant' in word or ' mutat' in word for word in words):
        # if mutant is between the two entities
        if x.entity1_index == -1 or x.entity2_index == -1:
            return ABSTAIN
        if isinstance(x.entity1_index, int) and isinstance(x.entity2_index, int):
            if x.entity1_index < x.entity2_index:
                if any(' mutant' in word or ' mutat' in word for word in words[x.entity1_index:x.entity2_index]):
                    return 0
            else:
                if any(' mutant' in word or ' mutat' in word for word in words[x.entity2_index:x.entity1_index]):
                    return 0
        # if mutant is close to either of the entities
        mutant_index = find_word_index(words, ' mutant')
        mutat_index = find_word_index(words, ' mutat')
        if (mutant_index != -1 and (abs(x.entity1_index - mutant_index) < 4 or abs(x.entity2_index - mutant_index) < 4)) or \
           (mutat_index != -1 and (abs(x.entity1_index - mutat_index) < 4 or abs(x.entity2_index - mutat_index) < 4)):
            return 0
    return ABSTAIN

#3
## Regulator
@labeling_function()
def lf_bind(x):
    return 1 if 'bind' in x.text.lower() else ABSTAIN
#4
@labeling_function()
def lf_interact(x):
    return 1 if 'interact' in x.text.lower() else ABSTAIN
#5
@labeling_function()
def lf_affinity(x):
    return 1 if 'affinit' in x.text.lower() else ABSTAIN

#6
## Upregulator
# Activator
@labeling_function()
def lf_activate(x):
    return 2 if 'activat' in x.text.lower() else ABSTAIN

#7
@labeling_function()
def lf_increase(x):
    return 2 if 'increas' in x.text.lower() else ABSTAIN

#8 
@labeling_function()
def lf_induce(x):
    return 2 if 'induc' in x.text.lower() else ABSTAIN

#9
@labeling_function()
def lf_stimulate(x):
    # if stimulate is between the two entities
    if x.entity1_index == -1 or x.entity2_index == -1:
        return ABSTAIN
    if isinstance(x.entity1_index, int) and isinstance(x.entity2_index, int):
        if x.entity1_index < x.entity2_index:
            if x.text[x.entity1_index:x.entity2_index].count('stimulat') > 0:
                return 2
        else:
            if x.text[x.entity2_index:x.entity1_index].count('stimulat') > 0:
                return 2
    return ABSTAIN
#10
@labeling_function()
def lf_upregulate(x):
    if ('upregulat' in x.text.lower() or 'up-regulat' in x.text.lower()) and ('downregulat' in x.text.lower() or 'down-regulat' in x.text.lower()):
        entity1_index = x.text.lower().index(x.entity1.lower())
        entity2_index = x.text.lower().index(x.entity2.lower())
        # if up regulate is between the two entities
        if isinstance(entity1_index, int) and isinstance(entity2_index, int):
            if entity1_index < entity2_index:
                if x.text[entity1_index:entity2_index].count('upregulat') > 0 or x.text[entity1_index:entity2_index].count('up-regulat') > 0:
                    return 2
            else:
                if x.text[entity2_index:entity1_index].count('upregulat') > 0 or x.text[entity2_index:entity1_index].count('up-regulat') > 0:
                    return 2
        return ABSTAIN
    else:
        return 2 if 'upregulat' in x.text.lower() or 'up-regulat' in x.text.lower() else ABSTAIN
#11
## Downregulator
@labeling_function()
def lf_downregulate(x):
    if('downregulat' in x.text.lower() or 'down-regulat' in x.text.lower()) and ('upregulat' in x.text.lower() or 'up-regulat' in x.text.lower()):
        if x.entity1_index == -1 or x.entity2_index == -1:
            return ABSTAIN
        # if up regulate is between the two entities
        if isinstance(x.entity1_index, int) and isinstance(x.entity2_index, int):
            if x.entity1_index < x.entity2_index:
                if x.text[x.entity1_index:x.entity2_index].count('downregulat') > 0 or x.text[x.entity1_index:x.entity2_index].count('down-regulat') > 0:
                    return 3
            else:
                if x.text[x.entity2_index:x.entity1_index].count('downregulat') > 0 or x.text[x.entity2_index:x.entity1_index].count('down-regulat') > 0:
                    return 3
        return ABSTAIN
    return 3 if 'downregulat' in x.text.lower() or 'down-regulat' in x.text.lower() else ABSTAIN
#12
@labeling_function()
def lf_reduce(x):
    return 3 if ' reduc' in x.text.lower() else ABSTAIN
#13
@labeling_function()
def lf_inhibit(x):
    return 3 if 'inhibit' in x.text.lower() else ABSTAIN
#14
@labeling_function()
def lf_decrease(x):
    return 3 if 'decreas' in x.text.lower() else ABSTAIN
#15
## Agonist
@labeling_function()
def lf_agonist(x):
    return 4 if ' agoni' in x.text.lower() or "\tagoni" in x.text.lower() else ABSTAIN

#16
## Antagonist
@labeling_function()
def lf_antagonist(x):
    return 5 if ' antagon' in x.text.lower() else ABSTAIN

#17
## Modulator
# Delete this LF, or change this to modulator
@labeling_function()
def lf_modulate(x):
    return 6 if 'modulat' in x.text.lower() else ABSTAIN

#18
@labeling_function()
def lf_allosteric(x):
    return 6 if 'allosteric' in x.text.lower() else ABSTAIN
#19
## Cofactor
@labeling_function()
def lf_cofactor(x):
    return 7 if 'cofactor' in x.text.lower() else ABSTAIN
#20
## Substrate/Product
@labeling_function()
def lf_substrate(x):
    return 8 if 'substrate' in x.text.lower() else ABSTAIN
#21
@labeling_function()
def lf_transport(x):
    return 8 if 'transport' in x.text.lower() else ABSTAIN
#22
@labeling_function()
def lf_catalyze(x):
    if " enzyme" in x.text.lower() or "metabolized" in x.text.lower():
        return 8
    return 8 if 'catalyz' in x.text.lower() or 'catalys' in x.text.lower() else ABSTAIN
#23
@labeling_function()
def lf_product(x):
    return 8 if " produc" in x.text.lower() else ABSTAIN
#24
@labeling_function()
def lf_convert(x):
    return 8 if "conver" in x.text.lower() else ABSTAIN
#25
## NOT
@labeling_function()
def lf_not(x):
    entity1_index = x.text.lower().index(x.entity1.lower())
    entity2_index = x.text.lower().index(x.entity2.lower())
    # if the two entities are close to the word 'not'
    
    if ' not ' in x.text.lower():
        if abs(entity1_index - x.text.lower().index('not')) < 20 or abs(entity2_index - x.text.lower().index('not')) < 20:
            return 9
        # if not is between the two entities
        if abs(entity1_index - x.text.lower().index('not')) < 40 or abs(entity2_index - x.text.lower().index('not')) < 40:
            if entity1_index < entity2_index:
                if x.text[entity1_index:entity2_index].count('not') > 0:
                    return 9
            else:
                if x.text[entity2_index:entity1_index].count('not') > 0:
                    return 9
    return ABSTAIN

# 26 replace the 17 (18)
@labeling_function()
def lf_combined_modulator(x):
    sentence_lower = x.text.lower()

    specific_terms = ['allosteric modulator', 'positive modulator', 'negative modulator', 'non-competitive modulator', 'positive allosteric modulator']
    if any(term in sentence_lower for term in specific_terms):
        return 6

    modulating_terms = ['modulat', 'allosteric', 'potentiate']
    for term in modulating_terms:
        if term in sentence_lower:
            term_index = sentence_lower.index(term)
            if x.entity1_index == -1 or x.entity2_index == -1:
                return ABSTAIN
            if abs(term_index - x.entity1_index) < 20 or abs(term_index - x.entity2_index) < 20:
                return 6
    # check the first indcidences of modulator and positive
    if x.entity1 in sentence_lower and x.entity2 in sentence_lower:
        entity1_index = sentence_lower.index(x.entity1.lower())
        entity2_index = sentence_lower.index(x.entity2.lower())
        between_entities = sentence_lower[min(entity1_index, entity2_index):max(entity1_index, entity2_index)]
        if 'modulate' in between_entities:
            return 6

    if 'positive' in sentence_lower and 'modulator' in sentence_lower:
        pos_indices = [i for i, word in enumerate(sentence_lower.split()) if word == 'positive']
        mod_indices = [i for i, word in enumerate(sentence_lower.split()) if 'modulator' in word]
        if len(pos_indices) == 0 or len(mod_indices) == 0:
            return ABSTAIN
        min_distance = min(abs(p - m) for p in pos_indices for m in mod_indices)
        if min_distance <= 3:
            return 6

    return ABSTAIN

lfs = [lf_amino_acid, lf_replace, lf_mutant, lf_bind, lf_interact, lf_affinity, lf_activate, lf_increase, lf_stimulate, lf_upregulate, lf_downregulate, lf_reduce, lf_inhibit, lf_decrease, lf_agonist, lf_antagonist, lf_combined_modulator, lf_allosteric, lf_cofactor, lf_substrate, lf_transport, lf_catalyze, lf_product, lf_convert, lf_not]

In [9]:
dev_df = li.chemprot_enhanced(dev_df)
li.analysis_LFs(lfs, dev_df,10)

100%|██████████| 250/250 [00:00<00:00, 5958.80it/s]

Test Coverage: 0.832
acuracy for the not abstains
0.7379310344827587
acuracy for all
0.428





Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.,Conflict Ratio
lf_amino_acid,0,[0],0.04,0.02,0.02,5,5,0.5,0.5
lf_replace,1,[0],0.004,0.0,0.0,0,1,0.0,0.0
lf_mutant,2,[],0.0,0.0,0.0,0,0,0.0,
lf_bind,3,[1],0.12,0.076,0.06,20,10,0.666667,0.5
lf_interact,4,[1],0.028,0.02,0.012,5,2,0.714286,0.428571
lf_affinity,5,[1],0.036,0.024,0.016,5,4,0.555556,0.444444
lf_activate,6,[2],0.096,0.06,0.06,9,15,0.375,0.625
lf_increase,7,[2],0.068,0.044,0.044,5,12,0.294118,0.647059
lf_stimulate,8,[],0.0,0.0,0.0,0,0,0.0,
lf_upregulate,9,[2],0.016,0.008,0.008,1,3,0.25,0.5


In [10]:
test_df = li.chemprot_enhanced(test_df)
li.analysis_LFs(lfs, test_df,10)

 38%|███▊      | 616/1607 [00:00<00:00, 6155.72it/s]

100%|██████████| 1607/1607 [00:00<00:00, 6052.44it/s]


Test Coverage: 0.8214063472308649
acuracy for the not abstains
0.7014270032930845
acuracy for all
0.39763534536403233


Unnamed: 0,j,Polarity,Coverage,Overlaps,Conflicts,Correct,Incorrect,Emp. Acc.,Conflict Ratio
lf_amino_acid,0,[0],0.017424,0.012446,0.012446,13,15,0.464286,0.714286
lf_replace,1,[0],0.003111,0.0,0.0,3,2,0.6,0.0
lf_mutant,2,[],0.0,0.0,0.0,0,0,0.0,
lf_bind,3,[1],0.102676,0.071562,0.062228,101,64,0.612121,0.606061
lf_interact,4,[1],0.029247,0.020535,0.018046,20,27,0.425532,0.617021
lf_affinity,5,[1],0.049782,0.031114,0.024269,61,19,0.7625,0.4875
lf_activate,6,[2],0.118233,0.080274,0.071562,75,115,0.394737,0.605263
lf_increase,7,[2],0.097075,0.07094,0.060983,64,92,0.410256,0.628205
lf_stimulate,8,[],0.0,0.0,0.0,0,0,0.0,
lf_upregulate,9,[2],0.014312,0.00809,0.0056,13,10,0.565217,0.391304
