## Implementing Benchmarks

In [23]:
import itertools
import logging
import os
import re
import ssl
#import gensim.downloader as api
import nltk
import pandas as pd
import numpy as np
import warnings


#from fse import SplitIndexedList
#from fse.models import uSIF
from nltk.sentiment.vader import SentimentIntensityAnalyzer
#from textblob import TextBlob

from nltk.tokenize.treebank import TreebankWordTokenizer

from datasets import load_dataset
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression

import torch
import torch.nn as nn
import torch.utils.data
#from torch_model_base import TorchModelBase
#from torch_rnn_classifier import TorchRNNClassifier, TorchRNNModel
#from torch_shallow_neural_classifier import TorchShallowNeuralClassifier

#import nli
#import utils


def classify_polarity_similarity(df):
    # Load GLOVE, which is necessary for uSIF embeddings
    if not os.environ.get('PYTHONHTTPSVERIFY', '') and getattr(ssl, '_create_unverified_context', None):
        ssl._create_default_https_context = ssl._create_unverified_context

    logging.basicConfig(format='%(asctime)s : %(threadName)s : %(levelname)s : %(message)s', level=logging.INFO)
    glove = api.load("glove-wiki-gigaword-300")

    s = SplitIndexedList(input_claims)
    # print(len(s))
    # NOTE, MANY repeats

    # Train the uSIF model
    model = uSIF(glove, workers=2, lang_freq="en")
    model.train(s)
    res_list = model.sv.similar_by_sentence(topic.split(), model=model, indexable=s.items, topn=50)
    results = list(zip(*res_list))
    d = {'claim': results[0], 'similarity_to_topic': results[2], 'drug': [drug] * len(results[0]),
         'topic': [topic] * len(results[0])}

    
    all_claims_df["polarity_vader"] = all_claims_df.apply(lambda row: polarity_v_score(row['claims']), axis=1)
    
    def polarity_v_score(text: str) -> float:
        """
        Calculate polarity of a sentence using Vader.

        :param text: input sentence
        :return: polarity value of sentence. Ranges from -1 (negative) to 1 (positive).
        """
        nltk.download('vader_lexicon')
        vader = SentimentIntensityAnalyzer()
        return vader.polarity_scores(text)['compound']


In [102]:
from collections import Counter
from itertools import product
from nltk.tokenize.treebank import TreebankWordTokenizer
from sklearn.linear_model import LogisticRegression



def classify_negative_parity(premise, hypothesis):
    pass


In [52]:
def load_roam_sep_data(roam_path):
    label_map = {"entailment": 0, "neutral": 1, "contradiction": 2}
    roam_df_list = []
    splits = ["Train", "Val", "Test"]
    for data_split in splits:
        roam_df = pd.read_excel(roam_path, sheet_name=data_split)
        roam_df = roam_df.drop(roam_df.columns[0], axis=1)
        roam_df = roam_df.dropna().reset_index(drop=True)
        roam_df = roam_df.rename(columns={"text1": "sentence1", "text2": "sentence2", "annotation": "label"})
        roam_df = roam_df[roam_df["labels"].isin(label_map.keys())]
        roam_df.replace({"label": label_map})
        roam_df_list.append(roam_df)
    
    return roam_df_list

Unnamed: 0,sentence1,sentence2,labels
0,mortality at 28 days was significantly lower i...,"finally, the most recent and promising researc...",entailment
1,an in vitro study found that remdesivir and ch...,specific therapeutic procedures suggested to i...,entailment
2,an in vitro study found that remdesivir and ch...,in the novel coronavirus pneumonia diagnosis ...,neutral
3,an in vitro study found that remdesivir and ch...,the most prominent finding to emerge from this...,neutral
4,an in vitro study found that remdesivir and ch...,"remdesivir, favipiravir, baricinitib, and anak...",neutral
...,...,...,...
429,"in the dexamethasone group, the incidence of d...",we recommend decreasing the dose of dexamethas...,neutral
430,we report herein our experience regarding the ...,our case also suggests that a brief course of ...,neutral
431,qt prolongation should be considered when usin...,conclusions therapeutic regimens of ifn- + lo...,neutral
432,roads less traveled might also be considered o...,arabi and colleagues initiated a placebo-contr...,neutral


In [1]:


snli = load_dataset("snli")


Downloading builder script:   0%|          | 0.00/3.82k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

Downloading and preparing dataset snli/plain_text (download: 90.17 MiB, generated: 65.51 MiB, post-processed: Unknown size, total: 155.68 MiB) to /Users/dnsosa/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b...


Downloading:   0%|          | 0.00/1.93k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/65.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Dataset snli downloaded and prepared to /Users/dnsosa/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [54]:
snli


DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 550152
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
})

## CS 224 NLI Baselines

In [100]:
# From NLI
# References: https://github.com/cgpotts/cs224u/blob/afd64b41f845b0f444b152d0f7acf2a45228349a/nli.py#L186
from covid_lit_contra_claims.evaluation.nli_utils import fit_classifier_with_hyperparameter_search

tokenizer = TreebankWordTokenizer()

# Hypothesis only benchmark
def hypothesis_only_unigrams_phi(ex):
    return Counter(tokenizer.tokenize(ex.hypothesis))

def premise_only_unigrams_phi(ex):
    return Counter(tokenizer.tokenize(ex.premise))

def word_overlap_phi(ex):
    words1 = [w.lower() for w in tokenizer.tokenize(ex.premise)]
    words2 = [w.lower() for w in tokenizer.tokenize(ex.hypothesis)]
    return Counter([(w1, w2) for w1, w2 in product(words1, words2)])

def word_cross_product_phi(ex):
    words1 = [w.lower() for w in tokenizer.tokenize(ex.premise)]
    words2 = [w.lower() for w in tokenizer.tokenize(ex.hypothesis)]
    return Counter([(w1, w2) for w1, w2 in product(words1, words2)])


In [98]:
def fit_softmax(X, y):
    mod = LogisticRegression(
        fit_intercept=True,
        solver='liblinear',
        multi_class='ovr')
    mod.fit(X, y)
    return mod

def fit_softmax_with_hyperparameter_search(X, y):
    """
    A MaxEnt model of dataset with hyperparameter cross-validation.

    Parameters
    ----------
    X : 2d np.array
        The matrix of features, one example per row.

    y : list
        The list of labels for rows in `X`.

    Returns
    -------
    sklearn.linear_model.LogisticRegression
        A trained model instance, the best model found.

    """

    mod = LogisticRegression(
        fit_intercept=True,
        max_iter=5,  ## A small number of iterations.
        solver='liblinear',
        multi_class='ovr')

    param_grid = {
        'C': [0.4, 0.6, 0.8, 1.0],
        'penalty': ['l1','l2']}

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        bestmod = fit_classifier_with_hyperparameter_search(
            X, y, mod, param_grid=param_grid, cv=3)

    return bestmod

In [88]:
def calculate_baseline_metrics(dataset, baseline_classifier, hp_opt=False):
    
    if hp_opt:
        baseline_classifier_experiment_xval = nli.experiment(
            train_reader=nli.NLIReader(dataset['train']),
            phi=baseline_classifier,
            train_func=fit_softmax_with_hyperparameter_search,
            assess_reader=None,
            verbose=False)

        optimized_baseline_classifier = baseline_classifier_experiment_xval['model']

        del baseline_classifier_experiment_xval

        def fit_optimized_baseline_classifier(X, y):
            optimized_baseline_classifier.max_iter = 1000 # To convergence in this phase!
            optimized_baseline_classifier.fit(X, y)
            return optimized_baseline_classifier
        
    
        train_func = fit_optimized_baseline_classifier
        
    else: 
        train_func = fit_softmax
    
    baseline_results = nli.experiment(train_reader=nli.NLIReader(dataset['train']),
                                      phi=baseline_classifier,
                                      train_func=train_func,
                                      assess_reader=nli.NLIReader(dataset['val']))

    return baseline_results


### Load Dataset

In [74]:
%%time

from datasets import load_dataset
from covid_lit_contra_claims.evaluation import nli

from covid_lit_contra_claims.data.CreateDatasetUtilities import load_roam_full_data
from covid_lit_contra_claims.data.CreateDataset import create_roam_dataset
from covid_lit_contra_claims.data.constants import *

roam_dataset = create_roam_dataset(ROAM_SEP_PATH)
roam_dataset = roam_dataset.rename_column("labels", "label")
roam_dataset = roam_dataset.rename_column("sentence1", "premise")
roam_dataset = roam_dataset.rename_column("sentence2", "hypothesis")

# snli = load_dataset("snli")

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

CPU times: user 495 ms, sys: 38.2 ms, total: 533 ms
Wall time: 549 ms


### Word Overlap

In [89]:
%%time
word_overlap_results = calculate_baseline_metrics(roam_dataset, word_overlap_phi, hp_opt=True)

Best params: {'C': 0.8, 'penalty': 'l1'}
Best score: 0.482
               precision    recall  f1-score   support

contradiction      0.143     0.073     0.097        41
   entailment      0.292     0.171     0.215        41
      neutral      0.473     0.707     0.567        75

     accuracy                          0.401       157
    macro avg      0.303     0.317     0.293       157
 weighted avg      0.340     0.401     0.352       157

CPU times: user 31.3 s, sys: 1.17 s, total: 32.5 s
Wall time: 13.3 s


### Word Cross-Product

In [90]:
%%time
word_cross_product_results = calculate_baseline_metrics(roam_dataset, word_cross_product_phi, hp_opt=True)

Best params: {'C': 1.0, 'penalty': 'l2'}
Best score: 0.571
               precision    recall  f1-score   support

contradiction      0.000     0.000     0.000        41
   entailment      0.278     0.244     0.260        41
      neutral      0.496     0.800     0.612        75

     accuracy                          0.446       157
    macro avg      0.258     0.348     0.291       157
 weighted avg      0.309     0.446     0.360       157

CPU times: user 32.8 s, sys: 680 ms, total: 33.5 s
Wall time: 11 s


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### Hypothesis- and Premise-Only Unigrams

In [101]:
%%time
hypothesis_unigrams_results = calculate_baseline_metrics(roam_dataset, hypothesis_only_unigrams_phi, hp_opt=False)
premise_unigrams_results = calculate_baseline_metrics(roam_dataset, premise_only_unigrams_phi, hp_opt=False)

               precision    recall  f1-score   support

contradiction      0.400     0.049     0.087        41
   entailment      0.500     0.171     0.255        41
      neutral      0.522     0.960     0.676        75

     accuracy                          0.516       157
    macro avg      0.474     0.393     0.339       157
 weighted avg      0.484     0.516     0.412       157

               precision    recall  f1-score   support

contradiction      0.000     0.000     0.000        41
   entailment      0.229     0.195     0.211        41
      neutral      0.434     0.707     0.538        75

     accuracy                          0.389       157
    macro avg      0.221     0.301     0.250       157
 weighted avg      0.267     0.389     0.312       157

CPU times: user 279 ms, sys: 8.27 ms, total: 287 ms
Wall time: 288 ms


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
