## Main Analysis Notebook

This notebook contains all pre-processing of the raw data and the analysis with HuggingFace.

### Pure HF Training

In [None]:
#!conda install -c conda-forge datasets evaluate ipykernel jupyter jupyterlab keras nb_conda_kernels openpyxl pytorch scikit-learn transformers tqdm wandb
#!ipython kernel install --user --name=cc2
#!pip install transformers -U
#!pip install tokenizers==0.12.1 #maybe

In [20]:
!conda install -c conda-forge openpyxl --yes

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /Users/dnsosa/opt/miniconda3/envs/cc37

  added / updated specs:
    - openpyxl


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    et_xmlfile-1.0.1           |          py_1001          11 KB  conda-forge
    openpyxl-3.0.10            |   py37h8052db5_1         537 KB  conda-forge
    ------------------------------------------------------------
                                           Total:         548 KB

The following NEW packages will be INSTALLED:

  et_xmlfile         conda-f

In [None]:
# TODO 1: Implement negex benchmark 
# TODO 2: Implement word overlap benchmark
# TODO 3: Implement benchmarks based on Vader (polarity detection)

# TODO 4: Implement final evaluation (all test stuff)


In [215]:
# Pre-processing functions

from datasets import ClassLabel

# Canonical NLI label vocabulary: label string -> integer class id.
label_map = {"entailment": 0, "neutral": 1, "contradiction": 2}
# ClassLabel feature built from the same names. The names are listed in the
# dict's insertion order, so position i in `names` matches id i in `label_map`.
ClassLabels = ClassLabel(num_classes=len(label_map), names=list(label_map.keys()))


def label_str_to_num(example):
    """Replace the string label of a single example with its integer class id.

    Args:
        example: mapping with a 'labels' entry holding one of the names known
            to `ClassLabels` ("entailment", "neutral", "contradiction").

    Returns:
        The same mapping, with 'labels' overwritten by the integer id.
    """
    # Pass the bare string: `str2int` maps a str to an int but maps a list to
    # a list, so wrapping the label in [...] stored `[id]` instead of `id` and
    # broke the round trip with `label_num_to_str` (which calls int(...)).
    example['labels'] = ClassLabels.str2int(example['labels'])
    return example

def label_num_to_str(example):
    """Replace the integer class id in example['labels'] with its label name.

    The id -> name table is derived from `label_map` on every call; the stored
    value is coerced with int() before the lookup.
    """
    id_to_name = dict((idx, name) for name, idx in label_map.items())
    label_id = int(example['labels'])
    example['labels'] = id_to_name[label_id]
    return example


        
def preprocess_nli_corpus_for_pytorch(corpus_id, tokenizer, truncation, mancon_neutral_frac=1, mancon_train_frac=0.67, SEED=42):
    """Load one NLI corpus and tokenize its sentence pairs.

    Args:
        corpus_id: one of "multinli", "mednli", "manconcorpus", "roam".
        tokenizer: callable applied to the (sentence1, sentence2) columns.
        truncation: forwarded to the tokenizer as its `truncation` argument.
        mancon_neutral_frac: forwarded to `create_mancon_dataset`.
        mancon_train_frac: forwarded to `create_mancon_dataset`.
        SEED: forwarded to the multinli and manconcorpus builders.

    Returns:
        The result of mapping the tokenizer over the raw dataset with every
        original column except 'labels' removed, or None (after printing a
        message) when `corpus_id` is not recognised.
    """
    # (corpus id, zero-argument builder) pairs. The builders are lambdas so
    # that only the requested corpus is loaded; the file paths they use are
    # module-level names looked up when the builder runs.
    corpus_builders = (
        ("multinli", lambda: create_multinli_dataset(SEED=SEED)),
        ("mednli", lambda: create_mednli_dataset(mednli_train_path, mednli_dev_path, mednli_test_path)),
        ("manconcorpus", lambda: create_mancon_dataset(mancon_xml_path, mancon_neutral_frac, mancon_train_frac, SEED)),
        ("roam", lambda: create_roam_dataset(roam_path)),
    )
    build_corpus = next((build for name, build in corpus_builders if corpus_id == name), None)
    if build_corpus is None:
        print("Invalid corpus ID. Pre-processing failed. ")
        return None
    raw_dataset = build_corpus()

    # Everything except the label column is dropped once tokenized.
    columns_to_drop = raw_dataset['train'].column_names
    columns_to_drop.remove('labels')

    # Kept exactly as written: this is the function handed to `.map`.
    def tokenize_data(example, tokenizer=tokenizer):
        return tokenizer(example["sentence1"], example["sentence2"], truncation=truncation)

    return raw_dataset.map(tokenize_data, batched=True, remove_columns=columns_to_drop)


In [216]:
import pandas as pd

# Every corpus is tokenized with the same tokenizer and truncation setting.
_shared_tokenize_kwargs = dict(tokenizer=tokenizer, truncation=config['truncation'])

tokenized_datasets_multi = preprocess_nli_corpus_for_pytorch("multinli", **_shared_tokenize_kwargs)
tokenized_datasets_med = preprocess_nli_corpus_for_pytorch("mednli", **_shared_tokenize_kwargs)
tokenized_datasets_man = preprocess_nli_corpus_for_pytorch("manconcorpus", **_shared_tokenize_kwargs)
tokenized_datasets_roam = preprocess_nli_corpus_for_pytorch("roam", **_shared_tokenize_kwargs)


Using custom data configuration default
Found cached dataset multi_nli (/Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-cf746adc131b9a82.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-219eae8003b88bad.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-5f6c1b53460da3cd.arrow
Loading cached split indices for dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-9b83d1256632f1d6.arrow and /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-e0c94d6a3ca27823.arrow


  0%|          | 0/393 [00:00<?, ?ba/s]

  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/5 [00:00<?, ?ba/s]

Using custom data configuration default-7d9106e9c4160845
Found cached dataset json (/Users/dnsosa/.cache/huggingface/datasets/json/default-7d9106e9c4160845/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-7d9106e9c4160845/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-0cd8e3899082131a.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-7d9106e9c4160845/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-af3ebcd197c3c424.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-7d9106e9c4160845/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-e3b18813c149330b.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-7d9106e9c4160845/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-31f2acf7f6950d78.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-7d9106e9c4160845/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad937

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [208]:
tokenized_datasets_man

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 1566
    })
    val: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 755
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 501
    })
})

In [217]:
# Collect the train split of every tokenized corpus and print each schema so
# the feature types can be compared across corpora.
train_dataset_list = [
    tokenized_datasets_multi,
    tokenized_datasets_med,
    tokenized_datasets_man,
    tokenized_datasets_roam,
]
train_subset_list = [corpus_splits["train"] for corpus_splits in train_dataset_list]
for dataset in train_subset_list:
    print(dataset.features)

{'labels': ClassLabel(num_classes=3, names=['entailment', 'neutral', 'contradiction'], id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}
{'labels': ClassLabel(num_classes=3, names=['entailment', 'neutral', 'contradiction'], id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}
{'labels': ClassLabel(num_classes=3, names=['entailment', 'neutral', 'contradiction'], id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}
{'labels': ClassLabel(num_classes=3, names=['entailment', 'neutral', 'contradiction'], id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence

wandb: Network error (ConnectTimeout), entering retry loop.


In [161]:
import datasets

# Merge the per-corpus train splits into one training set. The shuffle is
# seeded with the notebook-wide SEED so the mixed ordering is the same on
# every run; the previous unseeded shuffle produced a different ordering each
# time the cell was executed.
train_subset_list_combined = datasets.concatenate_datasets(train_subset_list).shuffle(seed=SEED)
train_subset_list_combined

Dataset({
    features: ['labels', 'input_ids', 'attention_mask'],
    num_rows: 406060
})

In [36]:
tokenized_datasets_med

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 11232
    })
    val: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 1395
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 1422
    })
})

In [37]:
tokenized_datasets_man

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 1692
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 565
    })
    val: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 565
    })
})

In [38]:
tokenized_datasets_roam

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 434
    })
    val: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 157
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 187
    })
})

In [4]:
import os

import torch
import wandb

from datasets import load_dataset, Dataset, DatasetDict
from torch.utils.data import DataLoader
from transformers import AdamW, AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, get_scheduler,  Trainer, TrainingArguments
from tqdm.notebook import tqdm

import evaluate


print("Packages loaded.")

# Global random seed.
SEED = 42

# Input files, all resolved relative to the parent of the working directory.
root_dir = os.path.abspath("..")
mednli_train_path = os.path.join(root_dir, 'input/mednli/mli_train_v1.jsonl')
mednli_dev_path = os.path.join(root_dir, 'input/mednli/mli_dev_v1.jsonl')
mednli_test_path = os.path.join(root_dir, 'input/mednli/mli_test_v1.jsonl')
mancon_xml_path = os.path.join(root_dir, 'input/manconcorpus/ManConCorpus.xml')
roam_path = os.path.join(root_dir, 'input/cord-training/Roam_annotations_trainvaltest_split_V2.xlsx')

# Corpus to train on and the name of its validation split.
in_dataset = "mednli"
val_set_name = "val"
#val_set_mapper[{"multinli": "validation_matched"}]

# Run configuration; the same dict is handed to wandb.init as the run config.
config = {
    "truncation": True,
    "mancon_neutral_frac": 1,
    "train_val_frac": 0.8,
    "num_epochs": 8,
    "batch_size": 8,
    "wandb_log_interval": 10,
    "dataset": in_dataset,
    "learning_rate": 3e-5,
}

# Tell wandb which notebook this run belongs to. The variable has to exist
# before wandb.init() is called. The previous `%env "WANDB_NOTEBOOK_NAME" "..."`
# line ran after init and, because of its quoting, created a variable whose
# name and value both contained literal double quotes, so the intended
# WANDB_NOTEBOOK_NAME variable was never set.
os.environ["WANDB_NOTEBOOK_NAME"] = "Main CC Pipeline Analysis Notebook"
wandb.init(project='Contra Claims 10_22', config=config)

print("WandB initialized.")


    
# Pretrained checkpoint used for both the tokenizer and the classifier.
# The commented-out lines are alternative checkpoints.
checkpoint = "allenai/biomed_roberta_base"
#checkpoint = "bert-base-uncased"
#checkpoint = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
#checkpoint = "gsarti/biobert-nli"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

print("Tokenizer loaded.")

# Load and tokenize the corpus named in the run configuration.
tokenized_datasets = preprocess_nli_corpus_for_pytorch(config['dataset'], tokenizer=tokenizer, truncation=config['truncation'])

print(f"{in_dataset} tokenized.")

# Collator built from the tokenizer; it is used as collate_fn for the DataLoader below.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# NOTE: Change from 100
# Only the first 1000 training examples are used here (debug-sized subset).
train_dataloader = DataLoader(
    tokenized_datasets["train"].select(range(1000)), shuffle=True, batch_size=config['batch_size'], collate_fn=data_collator
)
#eval_dataloader = DataLoader(
#    tokenized_datasets[val_set_name], batch_size=config['batch_size'], collate_fn=data_collator
#)
# NOTE(review): the validation loader above is disabled, so evaluation runs on
# the same training subset the model was fitted on. Restore the commented
# loader before reporting metrics.
eval_dataloader = train_dataloader

# Three output labels, matching the three NLI classes in `label_map`.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=3)
optimizer = AdamW(model.parameters(), lr=config['learning_rate'])
wandb.watch(model, log_freq=100)


print("Model loaded.")

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

print(f"Using device {device}.")

# One scheduler step is taken per batch, so the total is epochs * batches per epoch.
num_training_steps = config['num_epochs'] * len(train_dataloader)
# Linear schedule with no warmup steps.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

#progress_bar = tqdm(range(num_training_steps))

print("Beginning training...")
print(f"# Epochs: {config['num_epochs']}")
# Put the model in training mode before the loop.
model.train()

# The progress bar advances once per epoch, not once per batch.
#for epoch in range(config['num_epochs']):
for epoch in tqdm(range(config['num_epochs'])):
    for batch_idx, batch in enumerate(train_dataloader):
        # Move every tensor of the batch to the selected device.
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        # Update the weights, advance the learning-rate schedule, then clear
        # the gradients so the next batch starts from zero.
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        #progress_bar.update(1)
        
        # Log every `wandb_log_interval`-th batch of each epoch
        # (batch_idx restarts at 0 at the start of every epoch).
        if batch_idx % config['wandb_log_interval'] == 0:
            wandb.log({"epoch": epoch, "training_loss": loss})

print("Training complete.")
print("Beginning evaluation...")

# One metric object per score; each accumulates predictions batch by batch.
# NOTE(review): `average='macro'` is passed again to compute() below. Confirm
# whether the load-time keyword has any effect.
acc_metric = evaluate.load('accuracy')
f1_metric = evaluate.load('f1', average='macro')
precision_metric = evaluate.load('precision', average='macro')
recall_metric = evaluate.load('recall', average='macro')

# Put the model in evaluation mode before scoring.
model.eval()
#for batch_idx, batch in enumerate(eval_dataloader):
for batch_idx, batch in enumerate(tqdm(eval_dataloader)):
    batch = {k: v.to(device) for k, v in batch.items()}
    # No gradients are needed for scoring.
    with torch.no_grad():
        outputs = model(**batch)

    # Predicted class = index of the largest logit along the last dimension.
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    for metric in [acc_metric, f1_metric, precision_metric, recall_metric]:
        metric.add_batch(predictions=predictions, references=batch["labels"])

# Start from the accuracy result and merge in the macro-averaged scores.
results = acc_metric.compute()
for metric in [f1_metric, precision_metric, recall_metric]:
    results.update(metric.compute(average='macro'))
    
wandb.log(results)
#torch.onnx.export(model, batch, "model.onnx")
#wandb.save("model.onnx")

print(f"Results: {results}")
print("Evaluation complete.")



Packages loaded.


[34m[1mwandb[0m: Currently logged in as: [33mdnsosa[0m. Use [1m`wandb login --relogin`[0m to force relogin


env: "WANDB_NOTEBOOK_NAME"="Main CC Pipeline Analysis Notebook"
WandB initialized.
Tokenizer loaded.


Using custom data configuration default-7d9106e9c4160845
Found cached dataset json (/Users/dnsosa/.cache/huggingface/datasets/json/default-7d9106e9c4160845/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/11232 [00:00<?, ?ex/s]

  0%|          | 0/1395 [00:00<?, ?ex/s]

  0%|          | 0/1422 [00:00<?, ?ex/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

mednli tokenized.


Some weights of the model checkpoint at allenai/biomed_roberta_base were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.bias', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at allenai/biomed_roberta_base and are newly initialized: ['classi

Model loaded.
Using device cpu.
Beginning training...
# Epochs: 8




  0%|          | 0/8 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [255]:
# Processing the Roam annotated corpus again

import os
import pandas as pd

# Locate the annotated Roam spreadsheet under <cwd>/../input and load its
# "Docs" sheet.
# NOTE(review): `root_dir` and `roam_path` are also assigned in the training
# cell, where `roam_path` points at a different .xlsx file. Running this cell
# rebinds both names, so a later preprocess_nli_corpus_for_pytorch("roam")
# call would read this file instead. Consider distinct variable names.
root_dir = os.path.join(os.getcwd(),  "../") # MIGHT NEED TO CHANGE LATER
input_dir = os.path.join(root_dir, "input")
roam_path = os.path.join(input_dir, "Coronawhy-Contra-Claims-Scaling-v2-annotated-2020-10-21.xlsx")
active_sheet = "Docs"
roam_data = pd.read_excel(roam_path, sheet_name=active_sheet)

def process_text(s):
    r"""Split a raw annotation text into its claim texts.

    Every "Claim1:\n\n" header is removed, then the remainder is split on the
    "\n\nClaim2:\n\n" separator. Returns the resulting list of pieces.
    """
    claim2_separator = "\n\nClaim2:\n\n"
    without_claim1_header = s.replace("Claim1:\n\n", "")
    return without_claim1_header.split(claim2_separator)

def normalize_tags(s):
    """Return the tag with every "STRICT_" occurrence removed, in lower case."""
    return s.replace("STRICT_", "").lower()

##df['example'] = df['text'].apply(process_text)

# Drop rows with any missing value, derive a normalized `label` column from
# `tags`, then discard rows whose label contains "question" or "duplicate".
roam_data = roam_data.dropna().assign(label=lambda frame: frame['tags'].apply(normalize_tags))
roam_data = roam_data.loc[lambda frame: ~frame['label'].str.contains("question|duplicate")]



def splitter(in_str, index):
    r"""Return the `index`-th blank-line-separated chunk of `in_str`.

    Trailing whitespace is stripped first, then the text is split on "\n\n".
    """
    chunks = in_str.rstrip().split("\n\n")
    return chunks[index]

# The raw text is laid out as "Claim1:", claim 1, "Claim2:", claim 2, separated
# by blank lines, so chunks 1 and 3 hold the two claim texts.
# NOTE(review): assumes no blank line occurs inside a claim - verify on the data.
roam_data["claim1"] = roam_data.text.transform(lambda x: splitter(x, 1))
roam_data["claim2"] = roam_data.text.transform(lambda x: splitter(x, 3))


## df = pd.read_csv("Demo-Annotations-R1-2020-08-21.csv")


# Drugs searched for in each claim. A leading space on an entry
# (" chloroquine") marks it as "must start a word", which keeps it from
# matching inside a longer drug name such as "hydroxychloroquine".
drug_list = ["hydroxychloroquine", " chloroquine", "tocilizumab", "remdesivir", "vitamin d", "lopinavir", "dexamethasone"]
def identify_drugs_mentioned(claim):
    """Return the `drug_list` entries that occur in `claim`.

    Matching is case-sensitive. Entries are returned exactly as they are
    spelled in `drug_list` (including any leading space) and in `drug_list`
    order.

    Args:
        claim: claim text to search.

    Returns:
        List of matching `drug_list` entries; empty when none match.
    """
    import re  # local import: this cell does not import `re` itself

    found_drugs = []
    for drug in drug_list:
        bare_name = drug.lstrip()
        if bare_name != drug:
            # Word-start entry: accept any position not preceded by a word
            # character. The literal leading space used before missed the drug
            # at the very start of a claim or right after "(" or "/".
            is_mentioned = re.search(r"(?<!\w)" + re.escape(bare_name), claim) is not None
        else:
            is_mentioned = drug in claim
        if is_mentioned:
            found_drugs.append(drug)

    return found_drugs

        
# For each claim, store the list of `drug_list` entries found in its text.
roam_data["claim1_drugs"] = roam_data['claim1'].apply(identify_drugs_mentioned)
roam_data["claim2_drugs"] = roam_data['claim2'].apply(identify_drugs_mentioned)

In [294]:
# Show every column, never wrap wide frames, and never truncate cell text.
pd.set_option('display.max_columns', None)
pd.set_option('display.expand_frame_repr', False)
# `None` is the supported way to disable truncation. The previous value -1 is
# deprecated since pandas 1.0 and emits a FutureWarning. The full option name
# is used instead of the abbreviated 'max_colwidth'.
pd.set_option('display.max_colwidth', None)

roam_data.head(3)

  This is separate from the ipykernel package so we can avoid doing imports until


Unnamed: 0,docnum,tags,source,text,row_id,paper1_cord_uid,paper2_cord_uid,label,claim1,claim2,claim1_drugs,claim2_drugs
0,0,NEUTRAL,Demo-Annotations_43573ac3-R1,"Claim1:\n\nchloroquine has been recommended by some authors to be used for the treatment of patients infected with this virus however chloroquine may have side effects and drug resistance problems.\n\nClaim2:\n\non the basis of hydroxychloroquine's superior antiviral and prophylactic activity, as well as its more tolerable safety profile in comparison to chloroquine, we believe that hydroxychloroquine may be a promising drug for the treatment of sars-cov-2 infection [24] .",323,rc5bn6jc,sdij1d90,neutral,chloroquine has been recommended by some authors to be used for the treatment of patients infected with this virus however chloroquine may have side effects and drug resistance problems.,"on the basis of hydroxychloroquine's superior antiviral and prophylactic activity, as well as its more tolerable safety profile in comparison to chloroquine, we believe that hydroxychloroquine may be a promising drug for the treatment of sars-cov-2 infection [24] .",[ chloroquine],"[hydroxychloroquine, chloroquine]"
1,1,NEUTRAL,Demo-Annotations_43573ac3-R1,"Claim1:\n\n15 our regression model identified age as a determinant in responsiveness to lopinavir/ritonavir, with efficacy being related to younger ages.\n\nClaim2:\n\nthese findings formed the basis of a recent randomized clinical treatment trial which showed that the triple combination of antiviral therapy with ifn -1b, lopinavirritonavir, and ribavirin is safe and highly effective in shortening the duration of virus shedding, decreasing cytokine responses, alleviating symptoms, and facilitating the discharge of patients with mild to moderate covid-19 (47) .",413,36amafub,rirbffi6,neutral,"15 our regression model identified age as a determinant in responsiveness to lopinavir/ritonavir, with efficacy being related to younger ages.","these findings formed the basis of a recent randomized clinical treatment trial which showed that the triple combination of antiviral therapy with ifn -1b, lopinavirritonavir, and ribavirin is safe and highly effective in shortening the duration of virus shedding, decreasing cytokine responses, alleviating symptoms, and facilitating the discharge of patients with mild to moderate covid-19 (47) .",[lopinavir],[lopinavir]
2,2,NEUTRAL,Demo-Annotations_43573ac3-R1,"Claim1:\n\nn/a fax: +90 322 458 88 54 inhibition of the raas, weight loss, vitamin d supplementation, management of osa as well as prevention of sarcopenia/frailty.\n\nClaim2:\n\nwe observed weak but beneficial class effects of -blockers, mtor/pi3k inhibitors and vitamin d analogues and a mild amplification of the viral phenotype with -agonists.",431,1emlkii0,27f9241x,neutral,"n/a fax: +90 322 458 88 54 inhibition of the raas, weight loss, vitamin d supplementation, management of osa as well as prevention of sarcopenia/frailty.","we observed weak but beneficial class effects of -blockers, mtor/pi3k inhibitors and vitamin d analogues and a mild amplification of the viral phenotype with -agonists.",[vitamin d],[vitamin d]


In [225]:
len(roam_data[roam_data.claim1_drugs == roam_data.claim2_drugs])

556

In [226]:
len(roam_data)

987

In [228]:
roam_data[roam_data.claim1_drugs == roam_data.claim2_drugs].head(30)

Unnamed: 0,docnum,tags,source,text,row_id,paper1_cord_uid,paper2_cord_uid,label,claim1,claim2,claim1_drugs,claim2_drugs
1,1,NEUTRAL,Demo-Annotations_43573ac3-R1,"Claim1:\n\n15 our regression model identified age as a determinant in responsiveness to lopinavir/ritonavir, with efficacy being related to younger ages.\n\nClaim2:\n\nthese findings formed the basis of a recent randomized clinical treatment trial which showed that the triple combination of antiviral therapy with ifn -1b, lopinavirritonavir, and ribavirin is safe and highly effective in shortening the duration of virus shedding, decreasing cytokine responses, alleviating symptoms, and facilitating the discharge of patients with mild to moderate covid-19 (47) .",413,36amafub,rirbffi6,neutral,"15 our regression model identified age as a determinant in responsiveness to lopinavir/ritonavir, with efficacy being related to younger ages.","these findings formed the basis of a recent randomized clinical treatment trial which showed that the triple combination of antiviral therapy with ifn -1b, lopinavirritonavir, and ribavirin is safe and highly effective in shortening the duration of virus shedding, decreasing cytokine responses, alleviating symptoms, and facilitating the discharge of patients with mild to moderate covid-19 (47) .",[lopinavir],[lopinavir]
2,2,NEUTRAL,Demo-Annotations_43573ac3-R1,"Claim1:\n\nn/a fax: +90 322 458 88 54 inhibition of the raas, weight loss, vitamin d supplementation, management of osa as well as prevention of sarcopenia/frailty.\n\nClaim2:\n\nwe observed weak but beneficial class effects of -blockers, mtor/pi3k inhibitors and vitamin d analogues and a mild amplification of the viral phenotype with -agonists.",431,1emlkii0,27f9241x,neutral,"n/a fax: +90 322 458 88 54 inhibition of the raas, weight loss, vitamin d supplementation, management of osa as well as prevention of sarcopenia/frailty.","we observed weak but beneficial class effects of -blockers, mtor/pi3k inhibitors and vitamin d analogues and a mild amplification of the viral phenotype with -agonists.",[vitamin d],[vitamin d]
4,4,ENTAILMENT,Demo-Annotations_43573ac3-R1,"Claim1:\n\nfurthermore, despite the favorable outcomes of high dose dexamethasone treatment for the acute respiratory distress syndrome 20 , the role of intravenous steroids in ktr with covid-19 remains unknown.\n\nClaim2:\n\nit is important to underline that the immunomodulating therapies used in our patient (metilprednisolone, igiv, plasma-exchange) have not aggravated the sars-cov-2 infection, and besides having a positive effect on the management of the anti-nmda encephalitis, they could also have been helpful in the treatment of covid-19 due to immunosuppression/anti-inflammation properties, and among steroids especially dexamethasone [8].",14,v17l6t5u,kmzum2a9,entailment,"furthermore, despite the favorable outcomes of high dose dexamethasone treatment for the acute respiratory distress syndrome 20 , the role of intravenous steroids in ktr with covid-19 remains unknown.","it is important to underline that the immunomodulating therapies used in our patient (metilprednisolone, igiv, plasma-exchange) have not aggravated the sars-cov-2 infection, and besides having a positive effect on the management of the anti-nmda encephalitis, they could also have been helpful in the treatment of covid-19 due to immunosuppression/anti-inflammation properties, and among steroids especially dexamethasone [8].",[dexamethasone],[dexamethasone]
5,5,CONTRADICTION,Demo-Annotations_43573ac3-R1,"Claim1:\n\nword count: 248 take home: this study demonstrates that the use of hydroxychloroquine with or without azithromycin might have benefits in positive-to-negative conversion of sars-cov-2 and reduction of progression rate, but was associated with increased mortality in covid-19.\n\nClaim2:\n\nsimilarly, a report from france with 181 patients with covid-19 pneumonia found no difference in terms of mortality and icu admission in patients treated with hydroxychloroquine compared to the standard of care [5] .",967,2f6nj4to,n6juf8tw,contradiction,"word count: 248 take home: this study demonstrates that the use of hydroxychloroquine with or without azithromycin might have benefits in positive-to-negative conversion of sars-cov-2 and reduction of progression rate, but was associated with increased mortality in covid-19.","similarly, a report from france with 181 patients with covid-19 pneumonia found no difference in terms of mortality and icu admission in patients treated with hydroxychloroquine compared to the standard of care [5] .",[hydroxychloroquine],[hydroxychloroquine]
7,7,NEUTRAL,Demo-Annotations_43573ac3-R1,"Claim1:\n\nto date, data about the use of tocilizumab in the treatment of acute lung injury in patients\n\nClaim2:\n\nwe report herein our experience regarding the off-label treatment with tocilizumab of the first 51 patients with severe and critical forms of covid-19.",594,3od9m8gh,e5hi63rm,neutral,"to date, data about the use of tocilizumab in the treatment of acute lung injury in patients",we report herein our experience regarding the off-label treatment with tocilizumab of the first 51 patients with severe and critical forms of covid-19.,[tocilizumab],[tocilizumab]
9,9,NEUTRAL,Demo-Annotations_43573ac3-R1,"Claim1:\n\nthe role of convalescent plasma transfusion and recent antiviral agents such as ivermectin and remdesivir, in improving covid-19 prognosis in high-risk patients remains to be demonstrated.\n\nClaim2:\n\nremdesivir, an adenosine analog that incorporates into rna and causes premature termination, is a promising antiviral agent, with preliminary reports suggesting clinical benefit in adult patients.",790,6kn4mr04,t6w7p90m,neutral,"the role of convalescent plasma transfusion and recent antiviral agents such as ivermectin and remdesivir, in improving covid-19 prognosis in high-risk patients remains to be demonstrated.","remdesivir, an adenosine analog that incorporates into rna and causes premature termination, is a promising antiviral agent, with preliminary reports suggesting clinical benefit in adult patients.",[remdesivir],[remdesivir]
10,10,ENTAILMENT,Demo-Annotations_43573ac3-R1,"Claim1:\n\n12, 13 the findings of our study are similar to the observational study from a new york hospital which reported no beneficial effect of hydroxychloroquine treatment on respiratory failure or mortality in patients hospitalized with covid-19.\n\nClaim2:\n\nour findings suggest that patients treated by association of hydroxychloroquine andazithromycinare at greater risk of mortality compared with the 'neither drug' group.",371,yxuzc18x,7qdjea6f,entailment,"12, 13 the findings of our study are similar to the observational study from a new york hospital which reported no beneficial effect of hydroxychloroquine treatment on respiratory failure or mortality in patients hospitalized with covid-19.",our findings suggest that patients treated by association of hydroxychloroquine andazithromycinare at greater risk of mortality compared with the 'neither drug' group.,[hydroxychloroquine],[hydroxychloroquine]
12,12,NEUTRAL,Demo-Annotations_43573ac3-R1,"Claim1:\n\nthe improvement in oxygenation and inflammatory markers with no increase secondary infections or intra-hospital mortality suggests that tocilizumab could be an option for patients with progressive covid-19 after initiation of systemic corticosteroids.\n\nClaim2:\n\nthe mortality rate over a 30-days observation period was 27%, including 5/6 patients (83%) who were receiving invasive ventilation at the time of tocilizumab treatment and 9 of 45 (20%) who were receiving noninvasive oxygen support.",393,f7omk8ut,e5hi63rm,neutral,the improvement in oxygenation and inflammatory markers with no increase secondary infections or intra-hospital mortality suggests that tocilizumab could be an option for patients with progressive covid-19 after initiation of systemic corticosteroids.,"the mortality rate over a 30-days observation period was 27%, including 5/6 patients (83%) who were receiving invasive ventilation at the time of tocilizumab treatment and 9 of 45 (20%) who were receiving noninvasive oxygen support.",[tocilizumab],[tocilizumab]
14,14,NEUTRAL,Demo-Annotations_43573ac3-R1,"Claim1:\n\nwe used tocilizumab in both of our cases to counteract the hyperinflammatory state and cytokine storm associated with increased mortality in covid 19 patients in china.\n\nClaim2:\n\nthe mortality rate over a 30-days observation period was 27%, including 5/6 patients (83%) who were receiving invasive ventilation at the time of tocilizumab treatment and 9 of 45 (20%) who were receiving noninvasive oxygen support.",146,l0rk37zb,e5hi63rm,neutral,we used tocilizumab in both of our cases to counteract the hyperinflammatory state and cytokine storm associated with increased mortality in covid 19 patients in china.,"the mortality rate over a 30-days observation period was 27%, including 5/6 patients (83%) who were receiving invasive ventilation at the time of tocilizumab treatment and 9 of 45 (20%) who were receiving noninvasive oxygen support.",[tocilizumab],[tocilizumab]
19,19,NEUTRAL,Demo-Annotations_43573ac3-R1,"Claim1:\n\nto date, data about the use of tocilizumab in the treatment of acute lung injury in patients\n\nClaim2:\n\nin addition, this case demonstrates the effective use of tocilizumab for the treatment of sars-cov-2 and suggests the superiority of tocilizumab over lenzilumab in the management of this cytokine-mediated syndrome.",595,3od9m8gh,2y8x05a1,neutral,"to date, data about the use of tocilizumab in the treatment of acute lung injury in patients","in addition, this case demonstrates the effective use of tocilizumab for the treatment of sars-cov-2 and suggests the superiority of tocilizumab over lenzilumab in the management of this cytokine-mediated syndrome.",[tocilizumab],[tocilizumab]


In [289]:
len(set(roam_data.claim2))

201

In [286]:
roam_all_claims = set(roam_data.claim1).union(set(roam_data.claim2))
len(roam_all_claims)

229

In [277]:
import numpy as np

# Define the seed and split fraction before any random step so all of them can use it.
SEED = 42
ph_train_frac = 0.75

# Every distinct claim text that appears in either column of roam_data.
roam_all_claims = set(roam_data.claim1).union(roam_data.claim2)
n_claims = len(roam_all_claims)

# String hashing is randomized per process, so the iteration order of a set of strings
# differs between kernel sessions. Sort first (key=str tolerates any non-string entry)
# and permute with a seeded generator so the resulting splits are reproducible.
ph_rng = np.random.RandomState(SEED)
selected_claims = ph_rng.choice(sorted(roam_all_claims, key=str), size=n_claims, replace=False)

# Pair every claim with itself; an identical pair is labelled ENTAILMENT.
roam_ph = pd.DataFrame({"sentence1": selected_claims,
                        "sentence2": selected_claims,
                        "labels": "ENTAILMENT"})

# ph_train_frac goes to train; the remainder is halved into val and test.
train_ph, valtest_ph = train_test_split(roam_ph, test_size=(1 - ph_train_frac), shuffle=True, random_state=SEED)
val_ph, test_ph = train_test_split(valtest_ph, test_size=0.5, shuffle=True, random_state=SEED)
raw_roam_ph_df_dict = {"train": train_ph.reset_index(drop=True), 
                       "val": val_ph.reset_index(drop=True), 
                       "test": test_ph.reset_index(drop=True)}

raw_roam_ph_df_dict



{'train':                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               

In [278]:
# Show the test split of the identical-claim (sentence1 == sentence2) pairs.
raw_roam_ph_df_dict['test']

Unnamed: 0,sentence1,sentence2,labels
0,"the observation that in vitro vitamin d 3 was much more effective than either 25(oh)d 3 or 1,25(oh) 2 d 3 in stabilizing endothelial membranes thereby reducing inflammation may help explain the interesting clinical observations that extremely high doses of vitamin d have been effective in treating or at least reducing symptoms of some autoimmune disorders including psoriasis, vitiligo, and multiple sclerosis [37, 85] .","the observation that in vitro vitamin d 3 was much more effective than either 25(oh)d 3 or 1,25(oh) 2 d 3 in stabilizing endothelial membranes thereby reducing inflammation may help explain the interesting clinical observations that extremely high doses of vitamin d have been effective in treating or at least reducing symptoms of some autoimmune disorders including psoriasis, vitiligo, and multiple sclerosis [37, 85] .",ENTAILMENT
1,"the role of convalescent plasma transfusion and recent antiviral agents such as ivermectin and remdesivir, in improving covid-19 prognosis in high-risk patients remains to be demonstrated.","the role of convalescent plasma transfusion and recent antiviral agents such as ivermectin and remdesivir, in improving covid-19 prognosis in high-risk patients remains to be demonstrated.",ENTAILMENT
2,this short write-up explores the potential efficacy and established safety of chloroquine in covid-19.,this short write-up explores the potential efficacy and established safety of chloroquine in covid-19.,ENTAILMENT
3,"an in vitro study found that remdesivir and chloroquine inhibit viral infection, but further study is required [84,85].","an in vitro study found that remdesivir and chloroquine inhibit viral infection, but further study is required [84,85].",ENTAILMENT
4,as very recent studies conducted on remdesivir and hydroxychloroquine (two among the most promising treatments) failed to demonstrate efficacy in patients hospitalized for a documented sars-cov-2 pneumonia [45 47]; the findings of the present analysis could be considered as particularly critical for defining new approaches for the battle against this major endemic disease.,as very recent studies conducted on remdesivir and hydroxychloroquine (two among the most promising treatments) failed to demonstrate efficacy in patients hospitalized for a documented sars-cov-2 pneumonia [45 47]; the findings of the present analysis could be considered as particularly critical for defining new approaches for the battle against this major endemic disease.,ENTAILMENT
5,"it is important to underline that the immunomodulating therapies used in our patient (metilprednisolone, igiv, plasma-exchange) have not aggravated the sars-cov-2 infection, and besides having a positive effect on the management of the anti-nmda encephalitis, they could also have been helpful in the treatment of covid-19 due to immunosuppression/anti-inflammation properties, and among steroids especially dexamethasone [8].","it is important to underline that the immunomodulating therapies used in our patient (metilprednisolone, igiv, plasma-exchange) have not aggravated the sars-cov-2 infection, and besides having a positive effect on the management of the anti-nmda encephalitis, they could also have been helpful in the treatment of covid-19 due to immunosuppression/anti-inflammation properties, and among steroids especially dexamethasone [8].",ENTAILMENT
6,"word count: 248 take home: this study demonstrates that the use of hydroxychloroquine with or without azithromycin might have benefits in positive-to-negative conversion of sars-cov-2 and reduction of progression rate, but was associated with increased mortality in covid-19.","word count: 248 take home: this study demonstrates that the use of hydroxychloroquine with or without azithromycin might have benefits in positive-to-negative conversion of sars-cov-2 and reduction of progression rate, but was associated with increased mortality in covid-19.",ENTAILMENT
7,"after delivery, there was a clinical deterioration; she was treated with lopinavir-ritonavir, oseltamivir, hydroxychloroquine, meropenem, and vancomycin and received corticosteroid pulse therapy, emergency plasmapheresis, and invasive ventilation.","after delivery, there was a clinical deterioration; she was treated with lopinavir-ritonavir, oseltamivir, hydroxychloroquine, meropenem, and vancomycin and received corticosteroid pulse therapy, emergency plasmapheresis, and invasive ventilation.",ENTAILMENT
8,"it was indicated for acute coivd-19 patients when there is no response to the commonly used and recommended treatment with hydroxychloroquine and azithromycin [5, 14] .","it was indicated for acute coivd-19 patients when there is no response to the commonly used and recommended treatment with hydroxychloroquine and azithromycin [5, 14] .",ENTAILMENT
9,"the third patient, who had started receiving hydroxychloroquine 5 days after the admission, was transferred to intensive care 2 days later and was then prescribed an off-label therapy with ritonavir and lopinavir for severe acute respiratory syndrome coronavirus 2 (sars-cov-2) pneumonia; this patient developed a left bundle branch block due to hydroxychloroquine on hospital day 8 [not all durations of treatments to reactions onsets stated].","the third patient, who had started receiving hydroxychloroquine 5 days after the admission, was transferred to intensive care 2 days later and was then prescribed an off-label therapy with ritonavir and lopinavir for severe acute respiratory syndrome coronavirus 2 (sars-cov-2) pneumonia; this patient developed a left bundle branch block due to hydroxychloroquine on hospital day 8 [not all durations of treatments to reactions onsets stated].",ENTAILMENT


In [319]:
def drug_swap(row, drugs=None, rng=None):
    """Return the row's claim text with its drug replaced by a different, randomly chosen drug.

    Parameters
    ----------
    row : mapping-like (e.g. a DataFrame row)
        Must provide ``row["claim"]`` (the claim text) and ``row["claim_drugs"]``,
        a sequence whose first element is the drug name to swap out.
    drugs : sequence of str, optional
        Pool of drug names to draw the replacement from. Defaults to the
        notebook-level ``drug_list`` (defined in another cell).
    rng : object with a ``choice`` method, optional
        Random source, e.g. ``np.random.RandomState(seed)``. Defaults to the
        global ``np.random`` state.

    Returns
    -------
    str
        The claim with each standalone occurrence of the original drug replaced.

    Raises
    ------
    ValueError
        If the pool contains no drug other than the one being replaced.
    """
    import re

    c_drug = row["claim_drugs"][0].strip()
    candidate_pool = drug_list if drugs is None else drugs
    other_drugs = [drug for drug in candidate_pool if drug != c_drug]
    if not other_drugs:
        raise ValueError(f"No alternative drug available to swap for {c_drug!r}")

    chooser = np.random if rng is None else rng
    new_drug = str(chooser.choice(other_drugs))

    # Do not match when the drug name is the tail of a longer token: swapping
    # "chloroquine" must leave "hydroxychloroquine" intact. No trailing boundary is
    # required, so suffixed forms (e.g. a plural "s") are still swapped as before.
    pattern = r"(?<![A-Za-z0-9])" + re.escape(c_drug)
    # A function replacement keeps any backslashes in new_drug literal.
    return re.sub(pattern, lambda _match: new_drug, row["claim"])

# Stack claim1 and claim2 (with their drug lists) into one long frame of individual claims.
roam_claims = pd.concat([pd.DataFrame({"claim": roam_data.claim1, "claim_drugs": roam_data.claim1_drugs}),
                         pd.DataFrame({"claim": roam_data.claim2, "claim_drugs": roam_data.claim2_drugs})])
roam_claims["n_drugs"] = roam_claims["claim_drugs"].apply(len)
# Tuples are hashable, which drop_duplicates() below needs.
roam_claims["claim_drugs"] = roam_claims["claim_drugs"].apply(tuple)
# Keep the unfiltered frame; a later cell builds the claim -> drugs lookup from it.
roam_claims_all = roam_claims.copy()
# Only single-drug claims are swapped (see the open question on multi-drug claims below).
roam_claims = roam_claims[roam_claims.n_drugs == 1].drop_duplicates()

# drug_swap draws from NumPy's global RNG; seed it so the swapped claims, and
# therefore the splits built from them, are the same on every run.
np.random.seed(SEED)
roam_claims["swapped_claim1"] = roam_claims.apply(drug_swap, axis=1)
roam_claims["swapped_claim2"] = roam_claims.apply(drug_swap, axis=1)

# On multiple drugs mentioned at the same time. What would the swap procedure be? Replace the names of all the drugs with the resampled drug (or drugs?)?

print(len(roam_claims))

# Each claim yields two pairs, (claim, swapped) and (swapped, claim), so the
# swapped sentence appears on both sides. All pairs are labelled NEUTRAL.
roam_dd_ph = pd.concat([pd.DataFrame({"sentence1": roam_claims.claim, "sentence2": roam_claims.swapped_claim1}),
                        pd.DataFrame({"sentence1": roam_claims.swapped_claim2, "sentence2": roam_claims.claim})]).sample(frac=1, random_state=SEED)
roam_dd_ph["labels"] = "NEUTRAL"
# ph_train_frac goes to train; the remainder is halved into val and test.
train_dd_ph, valtest_dd_ph = train_test_split(roam_dd_ph, test_size=(1 - ph_train_frac), shuffle=True, random_state=SEED)
val_dd_ph, test_dd_ph = train_test_split(valtest_dd_ph, test_size=0.5, shuffle=True, random_state=SEED)
raw_roam_dd_ph_df_dict = {"train": train_dd_ph.reset_index(drop=True), 
                          "val": val_dd_ph.reset_index(drop=True), 
                          "test": test_dd_ph.reset_index(drop=True)}



176


In [None]:
# Preview the single-drug claims alongside their two drug-swapped variants.
roam_claims.head()



In [320]:
# Lookup from claim text to its drug tuple; when a claim text occurs more than once,
# the entry from its last occurrence is the one kept.
roam_claims_drugs_dict = {
    claim_text: claim_drug_tuple
    for claim_text, claim_drug_tuple in zip(roam_claims_all.claim, roam_claims_all.claim_drugs)
}

In [322]:
# Number of distinct claim texts in the claim -> drugs lookup.
len(roam_claims_drugs_dict)

229

In [333]:
n_dd = 400
dd_train_frac = 0.75
# Seeded generator so the sampled pairs, and the splits built from them, are reproducible.
dd_rng = np.random.RandomState(SEED)
all_claims = list(roam_claims_drugs_dict.keys())
# Convert each drug tuple to a set once, instead of on every comparison.
claim_drug_sets = {claim: set(drugs) for claim, drugs in roam_claims_drugs_dict.items()}
first_claims = dd_rng.choice(all_claims, size=n_dd, replace=True)

claim_pairs = []

for claim in first_claims:
    claim_drugs = claim_drug_sets[claim]
    # Partners must share no drug with the claim. The claim itself is excluded so a
    # claim with an empty drug set cannot be paired with itself and labelled NEUTRAL.
    compatible_claims = [other for other in all_claims
                         if other != claim and claim_drugs.isdisjoint(claim_drug_sets[other])]
    if not compatible_claims:
        # Rejection sampling would spin forever here; fail with a clear message instead.
        raise ValueError(f"No claim with a disjoint drug set exists for: {claim!r}")
    # A uniform draw from the compatible claims has the same distribution as
    # drawing from all claims and rejecting the incompatible ones.
    candidate_claim = dd_rng.choice(compatible_claims)
    claim_pairs.append([claim, candidate_claim])

roam_dd = pd.DataFrame(claim_pairs, columns=["sentence1", "sentence2"])
roam_dd["labels"] = "NEUTRAL"
# dd_train_frac goes to train; the remainder is halved into val and test.
train_dd, valtest_dd = train_test_split(roam_dd, test_size=(1 - dd_train_frac), shuffle=True, random_state=SEED)
val_dd, test_dd = train_test_split(valtest_dd, test_size=0.5, shuffle=True, random_state=SEED)
raw_roam_dd_df_dict = {"train": train_dd.reset_index(drop=True), 
                       "val": val_dd.reset_index(drop=True), 
                       "test": test_dd.reset_index(drop=True)}


In [331]:
# Peek at the drug-swapped pairs (a claim vs. the same claim with its drug replaced).
roam_dd_ph.head(3)

Unnamed: 0,sentence1,sentence2,labels
120,"conclusion: further research is needed to confirm the correlation between latitude and covid-19 fatalities, and to determine the optimum amounts of safe sunlight exposure and/or tocilizumab oral supplementation to reduce covid-19 fatalities in populations that are at high risk for tocilizumab deficiency.","conclusion: further research is needed to confirm the correlation between latitude and covid-19 fatalities, and to determine the optimum amounts of safe sunlight exposure and/or vitamin d oral supplementation to reduce covid-19 fatalities in populations that are at high risk for vitamin d deficiency.",NEUTRAL
420,"to assess the overall effect of vitamin d supplementation on risk of acute respiratory infection (ari), and to identify factors modifying this effect.","to assess the overall effect of dexamethasone supplementation on risk of acute respiratory infection (ari), and to identify factors modifying this effect.",NEUTRAL
129,"secondary outcomes were incidence of uri and lri, analysed separately; incidence of emergency department attendance and/or hospital admission for ari; death due to ari or respiratory failure; use of antibiotics to treat an ari; absence from work or school due to ari; incidence of serious adverse events; death due to any cause; and incidence of potential adverse reactions to vitamin d (hypercalcaemia and renal stones).","secondary outcomes were incidence of uri and lri, analysed separately; incidence of emergency department attendance and/or hospital admission for ari; death due to ari or respiratory failure; use of antibiotics to treat an ari; absence from work or school due to ari; incidence of serious adverse events; death due to any cause; and incidence of potential adverse reactions to hydroxychloroquine (hypercalcaemia and renal stones).",NEUTRAL


In [332]:
# Peek at the identical-claim pairs (sentence1 == sentence2).
roam_ph.head(3)

Unnamed: 0,sentence1,sentence2,labels
0,"randomized clinical trials are needed to identify safe and effective treatments for covid-19, including those that definitively delineate the incidence of adverse effects and efficacy of hydroxychloroquine in hospitalized patients.","randomized clinical trials are needed to identify safe and effective treatments for covid-19, including those that definitively delineate the incidence of adverse effects and efficacy of hydroxychloroquine in hospitalized patients.",ENTAILMENT
1,"on the other hand, despite the efficacy of hydroxychloroquine in covid-19 treatment is unclears9,s10, dose reduction could be associated with the decreased potency of hydroxychloroquine for covid-19 treatment.","on the other hand, despite the efficacy of hydroxychloroquine in covid-19 treatment is unclears9,s10, dose reduction could be associated with the decreased potency of hydroxychloroquine for covid-19 treatment.",ENTAILMENT
2,"clinicians should monitor covid-19 patients when treating them with chloroquine or other qt-prolonging drugs, with special attention to females, patients with structural heart disease, baseline qt interval on ecg, concomitant use of other qt-prolonging medications, potassium or magnesium abnormalities and bradycardia.","clinicians should monitor covid-19 patients when treating them with chloroquine or other qt-prolonging drugs, with special attention to females, patients with structural heart disease, baseline qt interval on ecg, concomitant use of other qt-prolonging medications, potassium or magnesium abnormalities and bradycardia.",ENTAILMENT


In [1]:
import covid_lit_contra_claims as clcc
from covid_lit_contra_claims.data.constants import model_id_mapper
from covid_lit_contra_claims.data.DataLoader import load_train_datasets, load_additional_eval_datasets
from covid_lit_contra_claims.data.DataExperiments import prepare_training_data

from transformers import AutoTokenizer


#out_dir = 
# Run configuration for this session.
model = "biobert"  # key into model_id_mapper, resolved to a checkpoint below
# Underscore-joined dataset names, passed as one string to the loaders below.
train_datasets = "multinli_mednli_mancon_roam_roamAll_roamPH_roamDD_roamDDPH"
eval_datasets = train_datasets  # evaluate on the same dataset selection used for training
truncation = True
train_prep_experiment = "shuffled"  # passed to prepare_training_data in a later cell
data_ratios = 2  # passed to prepare_training_data in a later cell
SEED = 42

# Loading tokenizer here because needed in data loading and model loading
checkpoint = model_id_mapper[model]
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load the train/val/test dataset dicts for every dataset named in train_datasets.
train_dataset_dict, val_dataset_dict, test_dataset_dict = load_train_datasets(train_datasets, tokenizer,
                                                                              truncation=truncation,
                                                                              SEED=SEED)

# Load the evaluation-only datasets.
# Two versions of CovidNLI: One where test is a separate network from train
# NOTE(review): it is not visible from this cell which of the two versions is loaded here; confirm in DataLoader.
eval_dataset_dict = load_additional_eval_datasets(eval_datasets, tokenizer,
                                                  truncation=truncation,
                                                  SEED=SEED)

# Conduct any input preprocessing for various experiments.
# (The prepare_training_data call itself lives in a later cell of this notebook.)
# Note currently only using data_ratio parameter for training data, NOT val data.


====Creating multinli Dataset object for train/val/test...====


Using custom data configuration default
Found cached dataset multi_nli (/Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-cf746adc131b9a82.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-219eae8003b88bad.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-5f6c1b53460da3cd.arrow
Loading cached split indices for dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-9b83d1256632f1d6.arrow and /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-e0c94d6a3ca27823.arrow
Loading cached processed dataset at /Users/d

====...done.====
====Creating mednli Dataset object for train/val/test...====


Using custom data configuration default-e5247ea137d095d5
Found cached dataset json (/Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-b4834287848785dd.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-7f0f4a316a8e5bb3.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-af90fb7c154f5da6.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-e40fa9fcc0faf7b1.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad937

====...done.====
====Creating mancon Dataset object for train/val/test...====


Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

====...done.====
====Creating roam Dataset object for train/val/test...====


  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

====...done.====
====Creating roamAll Dataset object for train/val/test...====


Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

====...done.====
====Creating roamPH Dataset object for train/val/test...====


Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

====...done.====
====Creating roamDD Dataset object for train/val/test...====


Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

====...done.====
====Creating roamDDPH Dataset object for train/val/test...====


Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

====...done.====
====Creating multinli Dataset object for evaluation only...====


Using custom data configuration default
Found cached dataset multi_nli (/Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-cf746adc131b9a82.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-219eae8003b88bad.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-5f6c1b53460da3cd.arrow
Loading cached split indices for dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-9b83d1256632f1d6.arrow and /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-e0c94d6a3ca27823.arrow
Loading cached processed dataset at /Users/d

====...done.====
====Creating mednli Dataset object for evaluation only...====


Using custom data configuration default-e5247ea137d095d5
Found cached dataset json (/Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-b4834287848785dd.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-7f0f4a316a8e5bb3.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-af90fb7c154f5da6.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-e40fa9fcc0faf7b1.arrow
Loading cached processed dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad937

====...done.====
====Creating mancon Dataset object for evaluation only...====


Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

====...done.====
====Creating roam Dataset object for evaluation only...====


  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

====...done.====
====Creating roamAll Dataset object for evaluation only...====


Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

====...done.====
====Creating roamPH Dataset object for evaluation only...====


Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

====...done.====
====Creating roamDD Dataset object for evaluation only...====


Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

====...done.====
====Creating roamDDPH Dataset object for evaluation only...====


Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Loading cached shuffled indices for dataset at /Users/dnsosa/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39/cache-4e70c3b4af46282f.arrow
Loading cached shuffled indices for dataset at /Users/dnsosa/.cache/huggingface/datasets/json/default-e5247ea137d095d5/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab/cache-c9999423c7edd30e.arrow


====...done.====


In [2]:
# Inspect the per-dataset training splits returned by load_train_datasets.
train_dataset_dict

OrderedDict([('multinli',
              Dataset({
                  features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
                  num_rows: 392702
              })),
             ('mednli',
              Dataset({
                  features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
                  num_rows: 11232
              })),
             ('mancon',
              Dataset({
                  features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
                  num_rows: 1396
              })),
             ('roam',
              Dataset({
                  features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
                  num_rows: 434
              })),
             ('roamAll',
              Dataset({
                  features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
                  num_rows: 740
              })),
             ('roamPH',
              Dataset({
     

In [3]:
# List the dataset names in the prepared training dict, in their prepared order.
# NOTE(review): the cell that assigns prepared_train_dataset_dict comes after this one in the
# notebook, so on a fresh top-to-bottom run this presumably raises NameError; consider moving it below.
prepared_train_dataset_dict.keys()

odict_keys(['roamPH', 'roam', 'roamAll', 'mednli', 'mancon', 'roamDD', 'roamDDPH', 'multinli'])

In [7]:
# Build the training sets for the configured experiment (train_prep_experiment, data_ratios, SEED
# come from the configuration cell). What each option does is defined inside prepare_training_data.
prepared_train_dataset_dict = prepare_training_data(train_dataset_dict, train_prep_experiment, SEED=SEED, data_ratios=data_ratios)
prepared_train_dataset_dict

OrderedDict([('roam',
              Dataset({
                  features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
                  num_rows: 434
              })),
             ('roamAll',
              Dataset({
                  features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
                  num_rows: 740
              })),
             ('roamDD',
              Dataset({
                  features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
                  num_rows: 300
              })),
             ('roamDDPH',
              Dataset({
                  features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
                  num_rows: 264
              })),
             ('mancon',
              Dataset({
                  features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
                  num_rows: 1000
              })),
             ('roamPH',
              Dataset({
          

In [14]:
big_datasets = ["multinli", "mednli", "mancon"]
data_ratio = 4

# How many of the "big" datasets are actually loaded; this is the starting exponent.
ratio_multiplier = sum(1 for name in set(big_datasets) if name in train_dataset_dict.keys())

# Target size is 500 * data_ratio ** k, with k decreasing by one for each loaded big dataset
# (in list order), capped at the number of rows the dataset really has.
big_datasets_new_counts = {}
for big_dataset in big_datasets:
    if big_dataset not in train_dataset_dict:
        continue
    big_dataset_count = 500 * data_ratio ** ratio_multiplier
    available_rows = train_dataset_dict[big_dataset].num_rows
    big_datasets_new_counts[big_dataset] = min(big_dataset_count, available_rows)
    ratio_multiplier -= 1

print(big_datasets_new_counts)

{'multinli': 32000, 'mednli': 8000, 'mancon': 1738}


In [17]:
# Try the subsampling: shuffle each big dataset with a fixed seed and keep its first new_count rows.
# zz is rebound on every iteration; it is only printed to check the resulting row counts.
for big_dataset, new_count in big_datasets_new_counts.items():
    zz = train_dataset_dict[big_dataset].shuffle(seed=42).select(range(new_count))
    print(zz)


Dataset({
    features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 32000
})
Dataset({
    features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 8000
})
Dataset({
    features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1738
})


In [35]:
from collections import OrderedDict

# Small OrderedDict with mixed key/value types, used to experiment with key ordering.
xx = OrderedDict([
    ("test", 5),
    ("8", "pizza"),
    ("french", "fry"),
])

In [38]:
# Print the entries in iteration order (insertion order for an OrderedDict).
for k,v in xx.items():
    print(k,v)

test 5
8 pizza
french fry


In [55]:
import random

# Fixed seed so the shuffled order below is the same on every run.
random.seed(42)

# Shuffle the (key, value) pairs in place, then rebuild an OrderedDict that keeps the shuffled order.
xx_items = list(xx.items())
random.shuffle(xx_items)
yy = OrderedDict(xx_items)

for k,v in yy.items():
    print(k,v)

8 pizza
test 5
french fry


In [56]:
# The shuffled (key, value) list that yy was built from.
xx_items

[('8', 'pizza'), ('test', 5), ('french', 'fry')]

In [60]:
from collections import OrderedDict
from datasets import concatenate_datasets

# Concatenate every training split in train_dataset_dict into a single Dataset.
combined = concatenate_datasets(list(train_dataset_dict.values()))

In [58]:
# Inspect which splits are currently held in train_dataset_dict.
train_dataset_dict

{'roam': Dataset({
     features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 434
 }),
 'roamAll': Dataset({
     features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 740
 }),
 'roamPH': Dataset({
     features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 171
 }),
 'roamDD': Dataset({
     features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 300
 }),
 'roamDDPH': Dataset({
     features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 264
 })}

In [66]:
# Sanity check: show three random rows of the combined training set.
# Seeded so the displayed sample is the same on every run.
combined.shuffle(seed=SEED).select(range(3))[:]

{'labels': [1, 0, 0],
 'input_ids': [[101,
   1103,
   1329,
   1104,
   1142,
   1207,
   7606,
   1107,
   4612,
   1114,
   1168,
   3252,
   6665,
   1176,
   1103,
   7991,
   1104,
   1957,
   5557,
   1216,
   1112,
   177,
   19694,
   16844,
   1732,
   10885,
   1186,
   12934,
   1162,
   1137,
   1231,
   1306,
   4704,
   11083,
   3161,
   117,
   1336,
   1129,
   6315,
   1107,
   1884,
   18312,
   118,
   1627,
   4420,
   119,
   102,
   1145,
   117,
   3209,
   5557,
   2345,
   1107,
   1952,
   122,
   117,
   1216,
   1112,
   1231,
   1306,
   4704,
   11083,
   3161,
   117,
   1120,
   10961,
   3906,
   25740,
   117,
   21718,
   12934,
   21704,
   1197,
   117,
   1105,
   1532,
   26950,
   4063,
   117,
   1105,
   1106,
   6617,
   2646,
   10337,
   1918,
   1830,
   1169,
   1129,
   2234,
   1112,
   14115,
   1111,
   1884,
   18312,
   118,
   1627,
   1191,
   1152,
   5424,
   1106,
   1129,
   3903,
   1107,
   3724,
   1105,
   7300,
   2527,


In [69]:
# Rebuild the underscore-joined dataset-name string from the dict's keys.
'_'.join(train_dataset_dict.keys())

'roam_roamAll_roamPH_roamDD_roamDDPH'