# Test set inference notebook for 2nd place solution

- Refer to the discussion post [2nd place solution overview](https://www.kaggle.com/c/coleridgeinitiative-show-us-the-data/discussion/248296) for an explanation of the approach
- Contact @leecming if you have any questions about my approach
- Other than the competition dataset, the only external dataset used is a [HuggingFace fine-tuned model/tokenizer](https://www.kaggle.com/leecming/robertalabelclassifierrawipcc)

Broadly, the notebook does the following:
1. Search all test documents for strings in the form "LONG-NAME (ACRONYM)" and create mappings of documents to these LONG-NAMES/ACRONYMS
   - Accept acronyms only if they're longer than 3 characters 
2. Classify LONG-NAMES as datasets using a fine-tuned HuggingFace Transformer binary classifier 
   - Accept candidates only if they have probability above MIN_PROB
3. Build a dynamic lookup table from the LONG-NAMES whose document frequency exceeds HIGH_FREQ, plus the definite labels from the training dataset
4. Search the cleaned text of all documents for the LONG-NAMES from step 3. and generate a combined ID-to-candidate table
5. Collate a final set of predictions
   - Remove candidates too similar to definite training labels using the fuzzywuzzy string comparison functionality
   - Accept a candidate if it is a definite training label, OR exceeds the HIGH_FREQ document frequency, OR its raw (case-preserving) form matches the regex `([A-Z][a-z]+ )+(Study|Survey)$`, OR matches the regex `(Study|Survey) of`
   - Add a candidate's acronym if it is present in the RAW text
   - For any document that has no suitable candidates, predict an empty string
   



In [1]:
from collections import defaultdict, Counter
from fuzzywuzzy import fuzz
import json
import logging
import multiprocessing as mp
import numpy as np
import pandas as pd
import os
import re
import regex
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import nltk
from nltk.tokenize import sent_tokenize

# Model settings and hyper-parameters

In [2]:
# --- Input locations ---
# Competition test-set documents (one JSON file per document)
JSON_FEATURE_DIR = '../input/coleridgeinitiative-show-us-the-data/test/'
# HuggingFace model/tokenizer; also found at https://www.kaggle.com/leecming/robertalabelclassifierrawipcc
PRETRAINED_CLASSIFIER = '../input/robertalabelclassifierrawipcc/'

# --- Decision thresholds ---
MIN_PROB = 0.9           # a LONG-NAME is kept only if its classifier probability is strictly above this
HIGH_FREQ = 50           # a label found in more than this many documents is accepted on frequency alone
MATCHING_THRESHOLD = 90  # fuzzywuzzy similarity above which a candidate counts as a near-copy of a training label

In [3]:
# Preprocessing text cleaning code provided by organizers
def clean_text(txt):
    """
    Normalise text with the organisers' cleaning rule.

    str(txt) is lower-cased, every run of characters outside [A-Za-z0-9] becomes a
    single space, and leading/trailing spaces are removed.
    """
    lowered = str(txt).lower()
    return re.sub('[^A-Za-z0-9]+', ' ', lowered).strip()

# Schwartz-Hearst code 

Schwartz-Hearst algorithm as described in [A Simple Algorithm for Identifying Abbreviation Definitions in Biomedical Text](https://psb.stanford.edu/psb-online/proceedings/psb03/schwartz.pdf) and Python code adapted from [here](https://github.com/philgooch/abbreviation-extraction)

The long code chunk below is hidden for readability but can be interpreted as an advanced regex search algorithm used to search for strings in the form of LONG-NAME (ACRONYM) e.g., "Baltimore Longitudinal Study of Aging (BLSA)"

In [4]:
# Module logger; root handler prints timestamp, level and message at INFO and above
log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s : %(levelname)s : %(message)s')


class Candidate(str):
    """A string that also records the character span where it was found in a sentence."""

    def __init__(self, value):
        super().__init__()
        # Span offsets default to 0/0 until set_position() is called
        self.start, self.stop = 0, 0

    def set_position(self, start, stop):
        """Store the start/stop character offsets of this string within its sentence."""
        self.start, self.stop = start, stop


def yield_lines_from_file(file_path):
    """
    Lazily yield each line of a file with surrounding whitespace stripped.

    Lines are decoded as UTF-8; a line that is not valid UTF-8 falls back to
    Latin-1, which can decode any byte sequence.
    """
    with open(file_path, 'rb') as handle:
        for raw_line in handle:
            try:
                text = raw_line.decode('utf-8')
            except UnicodeDecodeError:
                # Latin-1 -> UTF-8 -> str round-trips unchanged, so decode once
                text = raw_line.decode('latin-1')
            yield text.strip()


def yield_lines_from_doc(doc_text):
    """Lazily yield every newline-separated line of doc_text, stripped of surrounding whitespace."""
    yield from (piece.strip() for piece in doc_text.split("\n"))


def best_candidates(sentence):
    """
    Yield the plausible abbreviations that appear in parentheses in a sentence.

    Each " (" (an opening parenthesis preceded by a space) starts a candidate. Its end
    is found by counting '(' as +1 and each of ')', ';' and ':' as -1 until the count
    reaches zero, so "(ABC; see below)" ends at the ';'. The text in between is
    stripped of surrounding whitespace and yielded only if it passes `conditions`.

    :param sentence: line read from input file
    :return: a Candidate iterator; each Candidate's start/stop are its character
             offsets in `sentence`
    :raises ValueError: if the counts of '(' and ')' differ, or the first ')' comes
             before the first '('. As this is a generator, the error surfaces on the
             first iteration rather than at call time.
    """

    if '(' in sentence:
        # Check some things first
        if sentence.count('(') != sentence.count(')'):
            raise ValueError("Unbalanced parentheses: {}".format(sentence))

        if sentence.find('(') > sentence.find(')'):
            raise ValueError("First parentheses is right: {}".format(sentence))

        # Search position for the next " ("; -1 makes the first search start at index 0
        close_index = -1
        while 1:
            # Look for open parenthesis. Need leading whitespace to avoid matching mathematical and chemical formulae
            open_index = sentence.find(' (', close_index + 1)

            if open_index == -1: break

            # Advance beyond whitespace
            open_index += 1

            # Look for closing parentheses
            close_index = open_index + 1
            open_count = 1
            skip = False
            while open_count:
                try:
                    char = sentence[close_index]
                except IndexError:
                    # We found an opening bracket but no associated closing bracket
                    # Skip the opening bracket
                    skip = True
                    break
                if char == '(':
                    open_count += 1
                elif char in [')', ';', ':']:
                    open_count -= 1
                close_index += 1

            if skip:
                # Resume the " (" search just after this unmatched '('
                close_index = open_index + 1
                continue

            # Output if conditions are met
            # close_index is one past the closing char, so [start:stop] is the text strictly inside
            start = open_index + 1
            stop = close_index - 1
            candidate = sentence[start:stop]

            # Take into account whitespace that should be removed
            start = start + len(candidate) - len(candidate.lstrip())
            stop = stop - len(candidate) + len(candidate.rstrip())
            candidate = sentence[start:stop]

            if conditions(candidate):
                new_candidate = Candidate(candidate)
                new_candidate.set_position(start, stop)
                yield new_candidate


def conditions(candidate):
    """
    Decide whether a parenthesised string is a plausible abbreviation (Schwartz & Hearst).

    A candidate is accepted only if all of the following hold:
      * 2 <= len(candidate) <= 10
      * it has at most 2 whitespace-separated tokens
      * it contains at least one letter (any Unicode letter, category L)
      * its first character is alphanumeric

    :param candidate: candidate abbreviation
    :return: True if this is a good candidate
    """
    # Length check first: it also guarantees candidate[0] exists below, so an empty
    # candidate (e.g. from "foo ()") is rejected instead of raising IndexError, which
    # would otherwise abort processing of the rest of the sentence in the caller.
    if not 2 <= len(candidate) <= 10:
        return False
    if len(candidate.split()) > 2:
        return False
    # str.isalpha() is true exactly for Unicode categories Lu/Ll/Lt/Lm/Lo, i.e. "any letter"
    if not any(ch.isalpha() for ch in candidate):
        return False
    return candidate[0].isalnum()


def get_definition(candidate, sentence):
    """
    Takes a candidate and a sentence and returns the definition candidate.

    The definition candidate is the set of tokens (in front of the candidate)
    that starts with a token starting with the first character of the candidate

    The text before the candidate's parenthesis is lower-cased and split on runs of
    whitespace/hyphens. Counting back from the parenthesis, the definition starts at
    the token where as many tokens starting with the candidate's first letter have
    been seen as the candidate contains that letter (e.g. one 'b' for "BLSA").

    :param candidate: candidate abbreviation (a Candidate whose .start is set)
    :param sentence: current sentence (single line from input file)
    :return: candidate definition for this abbreviation, as a Candidate whose
             start/stop are its character offsets in `sentence`
    :raises ValueError: if fewer tokens start with the key letter than the candidate
             contains it
    """
    # Take the tokens in front of the candidate
    # (slice ends before the " (" preceding the candidate; assumes candidate.start is
    # just past '(' — TODO confirm for candidates that had leading whitespace)
    tokens = regex.split(r'[\s\-]+', sentence[:candidate.start - 2].lower())
    # the char that we are looking for
    key = candidate[0].lower()

    # Count the number of tokens that start with the same character as the candidate
    # NOTE(review): filter(None, ...) drops empty tokens, so an index into first_chars
    # only lines up with the same index into tokens when tokens has no empty strings.
    first_chars = [t[0] for t in filter(None, tokens)]

    definition_freq = first_chars.count(key)
    candidate_freq = candidate.lower().count(key)

    # Look for the list of tokens in front of candidate that
    # have a sufficient number of tokens starting with key
    if candidate_freq <= definition_freq:
        # we should at least have a good number of starts
        count = 0
        start = 0
        start_index = len(first_chars) - 1
        # Widen the window one token at a time from the end until it holds
        # candidate_freq key-initial tokens; start_index is the first of them
        while count < candidate_freq:
            if abs(start) > len(first_chars):
                raise ValueError("candidate {} not found".format(candidate))
            start -= 1
            # Look up key in the definition
            try:
                start_index = first_chars.index(key, len(first_chars) + start)
            except ValueError:
                pass

            # Count the number of keys in definition
            count = first_chars[start_index:].count(key)

        # We found enough keys in the definition so return the definition as a definition candidate
        # Character offset rebuilt by re-joining tokens with single spaces; only exact when every
        # separator matched by the split above was a single character
        start = len(' '.join(tokens[:start_index]))
        # Up to (but excluding) the candidate's '('
        stop = candidate.start - 1
        candidate = sentence[start:stop]

        # Remove whitespace
        start = start + len(candidate) - len(candidate.lstrip())
        stop = stop - len(candidate) + len(candidate.rstrip())
        candidate = sentence[start:stop]

        new_candidate = Candidate(candidate)
        new_candidate.set_position(start, stop)
        return new_candidate

    else:
        raise ValueError('There are less keys in the tokens in front of candidate than there are in the candidate')


def select_definition(definition, abbrev):
    """
    Match an abbreviation against a definition candidate and return the part of the
    definition that accounts for it.

    Both strings are scanned right to left (case-insensitively): abbreviation
    characters must be found in order in the definition, and the abbreviation's first
    character must additionally match at the start of a word (the definition's first
    character, or one preceded by a non-alphanumeric character).

    Based on
    A simple algorithm for identifying abbreviation definitions in biomedical texts, Schwartz & Hearst
    :param definition: candidate definition (a Candidate with start/stop set)
    :param abbrev: candidate abbreviation
    :return: a Candidate holding the definition trimmed at the front to the matched
             word; its start/stop are copied unchanged from `definition`
    :raises ValueError: if the definition is shorter than the abbreviation, contains the
             abbreviation as a whole word, does not contain a word starting with the
             abbreviation's first character, has more words than
             min(len(abbrev) + 5, len(abbrev) * 2), or has unbalanced parentheses
    :raises IndexError: if the scan runs off the start of the definition
    """

    if len(definition) < len(abbrev):
        raise ValueError('Abbreviation is longer than definition')

    if abbrev in definition.split():
        raise ValueError('Abbreviation is full word of definition')

    # Negative indices: both scans start at the last character and move left
    s_index = -1
    l_index = -1

    while 1:
        try:
            long_char = definition[l_index].lower()
        except IndexError:
            # Ran off the start of the definition; propagated unchanged
            raise

        short_char = abbrev[s_index].lower()

        # Skip non-alphanumeric abbreviation characters.
        # NOTE(review): short_char is not re-read after this decrement, so the skipped
        # character is still the one compared against long_char in this pass.
        if not short_char.isalnum():
            s_index -= 1

        if s_index == -1 * len(abbrev):
            # Abbreviation's first character: on a match it must also begin a word
            if short_char == long_char:
                if l_index == -1 * len(definition) or not definition[l_index - 1].isalnum():
                    break
                else:
                    l_index -= 1
            else:
                l_index -= 1
                if l_index == -1 * (len(definition) + 1):
                    raise ValueError("definition {} was not found in {}".format(abbrev, definition))

        else:
            # Other characters: consume from the abbreviation on a match; always move
            # one character left in the definition
            if short_char == long_char:
                s_index -= 1
                l_index -= 1
            else:
                l_index -= 1

    # Trim the definition to start at the matched word (position is not adjusted)
    new_candidate = Candidate(definition[l_index:len(definition)])
    new_candidate.set_position(definition.start, definition.stop)
    definition = new_candidate

    tokens = len(definition.split())
    length = len(abbrev)

    # Schwartz & Hearst upper bound on definition length, in words
    if tokens > min([length + 5, length * 2]):
        raise ValueError("did not meet min(|A|+5, |A|*2) constraint")

    # Do not return definitions that contain unbalanced parentheses
    if definition.count('(') != definition.count(')'):
        raise ValueError("Unbalanced parentheses not allowed in a definition")

    return definition


def extract_abbreviation_definition_pairs(file_path=None,
                                          doc_text=None,
                                          most_common_definition=False,
                                          first_definition=False):
    """
    Find abbreviation -> definition pairs (Schwartz & Hearst) in a file or a text.

    Input is processed one line at a time, each line treated as a sentence.
    `file_path` takes precedence over `doc_text`; if neither is given (or both are
    empty), an empty dict is returned. Candidates or definitions rejected by the
    helpers (ValueError/IndexError) are skipped and logged at DEBUG level.

    :param file_path: path of a text file to read
    :param doc_text: newline-separated text, used when no file_path is given
    :param most_common_definition: map each abbreviation to its most frequent definition
    :param first_definition: map each abbreviation to its first definition
        (ignored when most_common_definition is set)
    :return: dict of Candidate abbreviation -> Candidate definition; with neither
        flag set, the last definition seen for each abbreviation wins
    """
    if file_path:
        lines = yield_lines_from_file(file_path)
    elif doc_text:
        lines = yield_lines_from_doc(doc_text)
    else:
        return dict()

    collect_definitions = bool(most_common_definition or first_definition)
    last_definition = dict()
    all_definitions = defaultdict(list)
    n_written = 0
    n_omitted = 0

    for line_no, sentence in enumerate(lines):
        # Remove any quotes around potential candidate terms
        clean_sentence = regex.sub(r'([(])[\'"\p{Pi}]|[\'"\p{Pf}]([);:])', r'\1\2', sentence)
        try:
            for abbrev in best_candidates(clean_sentence):
                try:
                    definition = get_definition(abbrev, clean_sentence)
                except (ValueError, IndexError) as err:
                    log.debug("{} Omitting candidate {}. Reason: {}".format(line_no, abbrev, err.args[0]))
                    n_omitted += 1
                    continue

                try:
                    definition = select_definition(definition, abbrev)
                except (ValueError, IndexError) as err:
                    log.debug("{} Omitting definition {} for candidate {}. Reason: {}".format(line_no, definition, abbrev, err.args[0]))
                    n_omitted += 1
                    continue

                if collect_definitions:
                    all_definitions[abbrev].append(definition)
                else:
                    last_definition[abbrev] = definition
                n_written += 1
        except (ValueError, IndexError) as err:
            # Errors escaping the candidate loop (e.g. unbalanced parentheses reported by
            # best_candidates) abandon the rest of this sentence
            log.debug("{} Error processing sentence {}: {}".format(line_no, sentence, err.args[0]))
    log.debug("{} abbreviations detected and kept ({} omitted)".format(n_written, n_omitted))

    if not collect_definitions:
        return last_definition
    if most_common_definition:
        return {abbrev: Counter(defs).most_common(1)[0][0] for abbrev, defs in all_definitions.items()}
    return {abbrev: defs[0] for abbrev, defs in all_definitions.items()}

# Dataset binary classifier 
Self-contained code to classify a list of input strings (as datasets or not); returns a NumPy array with one probability per input string

In [5]:
def classify_labels(string_list, batch_size=64, max_length=64):
    """
    Score strings with the fine-tuned binary dataset/not-dataset classifier.

    The model and tokenizer are loaded from PRETRAINED_CLASSIFIER on every call, cast
    to fp16 and run on the GPU (CUDA is required) in batches of `batch_size`.

    :param string_list: list of strings to classify
    :param batch_size: number of strings per forward pass (default 64, as before)
    :param max_length: tokens per string after truncation/padding (default 64, as before)
    :return: 1-D numpy array with the softmax probability of class index 1 for each
             string, in input order (the notebook treats class 1 as "dataset");
             an empty array when string_list is empty
    """
    # Nothing to score (e.g. no LONG-NAME (ACRONYM) pairs in the test docs):
    # np.concatenate([]) would raise ValueError, and loading the model would be wasted work.
    if len(string_list) == 0:
        return np.empty(0, dtype=np.float16)

    classifier = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_CLASSIFIER).half().cuda()
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_CLASSIFIER)
    classifier.eval()

    label_preds = []
    with torch.no_grad():
        for batch_idx_start in range(0, len(string_list), batch_size):
            # Slicing past the end is safe, so the end index needs no clamping
            current_batch = string_list[batch_idx_start:batch_idx_start + batch_size]
            batch_features = tokenizer(current_batch,
                                       truncation=True,
                                       max_length=max_length,
                                       padding='max_length',
                                       add_special_tokens=True,
                                       return_tensors='pt')
            batch_features = {k: v.cuda() for k, v in batch_features.items()}
            model_output = classifier(**batch_features,
                                      return_dict=True)
            # Probability of class index 1 from the softmax over the output logits
            batch_preds = torch.nn.Softmax(-1)(model_output['logits'])[:, 1].cpu().numpy()
            label_preds.append(batch_preds)

    return np.concatenate(label_preds)

# Document text search 
Load a given document's JSON file, split it into sentences, and search for LONG-NAME (ACRONYM) pairs

In [6]:
def get_abbrev_dict(jsonfile):
    """
    Extract LONG-NAME (ACRONYM) pairs from one test-set document.

    :param jsonfile: file name (not a path) of a document JSON inside JSON_FEATURE_DIR
    :return: (document id, {acronym: long name}); the id is the file name up to its
             first '.', and the mapping comes from extract_abbreviation_definition_pairs
    """
    curr_id = jsonfile.split('.')[0]
    # JSON text is UTF-8; don't let decoding depend on the platform's default locale encoding
    with open(os.path.join(JSON_FEATURE_DIR, jsonfile), encoding='utf-8') as f:
        raw_json = json.load(f)
    # raw_json is iterated as a list of dicts whose values are strings: join every
    # value of every section, then collapse all whitespace runs to single spaces
    raw_text = ' '.join(' '.join(section.values()) for section in raw_json)
    raw_text = ' '.join(raw_text.split())
    # One sentence per line, because the extractor processes its input line by line
    sentences = '\n'.join(sent_tokenize(raw_text))
    return curr_id, extract_abbreviation_definition_pairs(doc_text=sentences)

In [7]:
%%time
# Run the document text search across the test set using multiprocessing
with mp.Pool(12) as pooler:
    list_dicts = list(pooler.map(get_abbrev_dict, os.listdir(JSON_FEATURE_DIR)))

CPU times: user 11.5 ms, sys: 59 ms, total: 70.5 ms
Wall time: 405 ms


In [8]:
%%time
# Generate classifier predictions for all labels
candidate_labels = list(set(long for _, curr_dict in list_dicts for long in curr_dict.values()))
print('# candidate labels: {}'.format(len(candidate_labels)))
candidate_preds = classify_labels(candidate_labels)
candidate_prob_mapping = dict(zip(candidate_labels, candidate_preds))

# candidate labels: 69
CPU times: user 5.44 s, sys: 1.53 s, total: 6.97 s
Wall time: 17.4 s


In [9]:
# Keep only LONG-NAMEs the classifier scores above MIN_PROB, indexed by their cleaned form
long_id_mapping = defaultdict(set)  # cleaned LONG-NAME -> ids of documents it appears in
clean_raw_mapping = dict()  # cleaned LONG-NAME -> raw LONG-NAME (last one seen wins)
long_short_mapping = dict()  # cleaned LONG-NAME -> its acronym

for curr_id, curr_dict in list_dicts:
    for short, long in curr_dict.items():
        if not candidate_prob_mapping[long] > MIN_PROB:  # below the minimum probability
            continue
        cleaned_form = clean_text(long)
        long_id_mapping[cleaned_form].add(curr_id)
        clean_raw_mapping[cleaned_form] = long
        # Acronyms of 3 characters or fewer are not stored
        if len(short) > 3:
            long_short_mapping[cleaned_form] = short

print('# All associated IDs: {}'.format(len(set().union(*long_id_mapping.values()))))

# All associated IDs: 4


# Threshold and propagate
- Create a set of high frequency labels (high_prob_freq_labels) including the definite training labels
- Search through the cleaned texts of all documents for matches and add them as candidates for the given docs

In [10]:
# Labels the acronym search already found in more than HIGH_FREQ documents
high_prob_freq_labels = [label for label, doc_ids in long_id_mapping.items()
                         if len(doc_ids) > HIGH_FREQ]
print('# high prob/high freq labels: {}'.format(len(high_prob_freq_labels)))
print('# of associated IDs: {}'.format(
    len(set().union(*[doc_ids for doc_ids in long_id_mapping.values() if len(doc_ids) > HIGH_FREQ]))))
print(high_prob_freq_labels)

# high prob/high freq labels: 0
# of associated IDs: 0
[]


In [11]:
# Add every distinct cleaned label from the training set to the search list
def_labels = set(
    pd.read_csv('../input/coleridgeinitiative-show-us-the-data/train.csv')['cleaned_label']
    .drop_duplicates()
)
high_prob_freq_labels = set(high_prob_freq_labels) | def_labels
print('# combined w/ training labels: {}'.format(len(high_prob_freq_labels)))

# combined w/ training labels: 130


In [12]:
%%time
# Code to generate both raw text and cleaned text for given test json file
def get_text_for_jsonid(json_file):
    file_id = json_file.split('.')[0]
    with open(os.path.join(JSON_FEATURE_DIR, json_file)) as f:
        raw_json = json.load(f)
        raw_text = list(map(lambda x: ' '.join(x.values()), raw_json))
        raw_text = ' '.join(raw_text)
        raw_text = ' '.join(raw_text.split())
    return file_id, raw_text, clean_text(raw_text)

id_text_tuple = list(map(get_text_for_jsonid, os.listdir(JSON_FEATURE_DIR)))

id_to_raw_text = {x[0]:x[1] for x in id_text_tuple}
id_to_clean_text = {x[0]:x[2] for x in id_text_tuple}

CPU times: user 33.2 ms, sys: 1.36 ms, total: 34.5 ms
Wall time: 41.6 ms


In [13]:
# Substring-search every document's cleaned text for each label in the lookup table
for curr_label in high_prob_freq_labels:
    matching_ids = [doc_id for doc_id, doc_text in id_to_clean_text.items() if curr_label in doc_text]
    for curr_id in matching_ids:
        long_id_mapping[curr_label].add(curr_id)

associated_sets = [doc_ids for label, doc_ids in long_id_mapping.items() if label in high_prob_freq_labels]
print('# of associated IDs: {}'.format(len(set().union(*associated_sets))))

# of associated IDs: 4


In [14]:
# Invert long_id_mapping into: document id -> list of candidate labels found in it
id_to_pred_mapping = defaultdict(list)
for label, doc_ids in long_id_mapping.items():
    for doc_id in doc_ids:
        id_to_pred_mapping[doc_id].append(label)

# Collate at the ID level
For each candidate dataset of a given document:
1. Remove it if it is too similar to a definite training label (as assessed by the FuzzyWuzzy string-match algorithm)
2. Accept it if it

    a. Is a definite training label, OR

    b. Exceeds the HIGH_FREQ document frequency threshold, OR
    
    c. Regex matches `([A-Z][a-z]+ )+(Study|Survey)$` (checked on its raw, case-preserving form), OR
    
    d. Regex matches `(Study|Survey) of`
    
3. If the candidate is accepted, also accept its acronym if it is present in the RAW document text


In [15]:
# Collate the final set of predictions for each document
for curr_id, curr_pred_list in id_to_pred_mapping.items():
    # Sort in following descending priority (a definite training label, doc frequency, length of string);
    # reverse=True applied to 1/len puts shorter strings first among ties
    curr_pred_list = sorted(curr_pred_list,
                            key=lambda x: (x in def_labels, len(long_id_mapping[x]), 1. / len(x)), reverse=True)
    sieved_pred_list = []
    for curr_pred in curr_pred_list:
        # Drop a non-training candidate that is too similar to an already accepted training label
        # (cheap membership tests run first, so fuzz is only called when it can matter)
        match_found = curr_pred not in def_labels and any(
            other_pred in def_labels and fuzz.token_set_ratio(curr_pred, other_pred) > MATCHING_THRESHOLD
            for other_pred in sieved_pred_list)
        if match_found:
            continue

        # The regexes are applied to the raw, case-preserving form of the candidate
        accepted = (len(long_id_mapping[curr_pred]) > HIGH_FREQ
                    or curr_pred in def_labels
                    or re.search('([A-Z][a-z]+ )+(Study|Survey)$', clean_raw_mapping[curr_pred])
                    or re.search('(Study|Survey) of', clean_raw_mapping[curr_pred]))
        if not accepted:
            continue
        sieved_pred_list.append(curr_pred)

        # Add acronym as prediction if it appears, surrounded by spaces, in the raw document text.
        # A literal substring test is used: formatting the acronym into a regex unescaped lets
        # characters such as '.', '+', '?' or '[' act as metacharacters (wrong matches or re.error).
        acronym = long_short_mapping.get(curr_pred)
        if acronym is not None and ' {} '.format(acronym) in id_to_raw_text[curr_id]:
            sieved_pred_list.append(clean_text(acronym))

    id_to_pred_mapping[curr_id] = set(sieved_pred_list)

In [16]:
# Documents with no entry at all get a single empty-string prediction
all_ids = {json_name.split('.')[0] for json_name in os.listdir(JSON_FEATURE_DIR)}
pred_ids = set(id_to_pred_mapping)
missing_ids = all_ids - pred_ids

for curr_missing_id in missing_ids:
    id_to_pred_mapping[curr_missing_id].append('')

print('Added dummy labels for #{} missing IDs'.format(len(missing_ids)))

Added dummy labels for #0 missing IDs


In [17]:
# One row per document: its predictions sorted and joined with '|', rows ordered by Id.
# Passing `columns=` (instead of assigning .columns afterwards) also works when there are
# no documents at all, where assigning two column names to a 0-column frame would raise.
pred_df = pd.DataFrame(list(id_to_pred_mapping.items()), columns=['Id', 'PredictionString'])
pred_df['PredictionString'] = pred_df['PredictionString'].apply(lambda preds: '|'.join(sorted(preds)))
pred_df = pred_df.sort_values(['Id', 'PredictionString'])
pred_df.to_csv('submission.csv', index=False)