In [None]:
import json
import os

# Must be set before torch/transformers are imported so CUDA sees only these GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import torch
from sentence_splitter import split_text_into_sentences
from tqdm import tqdm
from transformers import RobertaForSequenceClassification, RobertaTokenizer

In [None]:
def load_jsonl(path):
    """Read a JSON Lines file into a list of parsed records.

    Blank lines (e.g. a trailing newline at end of file) are skipped instead of
    being passed to `json.loads`, which would raise on an empty string.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Training split of the medical dataset.
train = load_jsonl("../data/out/medical/datasets/train.jsonl")
# Selected subset (top_p0.1) of wikidoc produced for mistral-medical.
selected = load_jsonl("../data/out/medical/mistral-medical/selected/wikidoc/top_p0.1.jsonl")
# Full medical_meadow_wikidoc dataset.
wikidoc = load_jsonl("../data/out/medical/datasets/medical_meadow_wikidoc.jsonl")

In [None]:
# Load the fine-tuned RoBERTa generics classifier. num_labels=1 gives a single
# output score per sentence (thresholded further down in this notebook).
model = RobertaForSequenceClassification.from_pretrained(
    "../data/generics/roberta_generics_classifier",
    num_labels=1,
    device_map="auto",
)
# Inference only: make eval mode (no dropout) explicit rather than relying on
# the from_pretrained default.
model.eval()
# NOTE(review): tokenizer comes from the base "roberta-base" checkpoint; this
# assumes the classifier was fine-tuned without vocabulary changes — confirm.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

In [None]:
SENTENCES_CACHE = "../data/out/medical/datasets/sentences.json"


def extract_sentences_from_jsonl(jsonl):
    """Flatten prompt/completion records into a single list of sentences.

    Each record contributes its whole `prompt` as one entry, followed by the
    sentences of its `completion` as split by `split_text_into_sentences`.
    """
    out = []
    for item in tqdm(jsonl):
        out.append(item["prompt"])
        out.extend(split_text_into_sentences(item["completion"], language="en"))
    return out


# Reuse the cached split when it exists. Only a missing or unreadable (corrupt)
# cache triggers recomputation; any other error (KeyError, KeyboardInterrupt, ...)
# propagates instead of being swallowed and then overwriting the cache.
try:
    with open(SENTENCES_CACHE, "r", encoding="utf-8") as f:
        sentences = json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
    sentences = {
        "train": extract_sentences_from_jsonl(train),
        "selected": extract_sentences_from_jsonl(selected),
        "wikidoc": extract_sentences_from_jsonl(wikidoc),
    }
    with open(SENTENCES_CACHE, "w", encoding="utf-8") as f:
        json.dump(sentences, f)

In [None]:
@torch.no_grad()  # inference only: don't build an autograd graph per batch
def is_generic(batch, threshold=0.8):
    """Flag which sentences in `batch` the generics classifier deems generic.

    Uses the notebook-level `model` and `tokenizer`.

    Args:
        batch: list of sentence strings (non-empty).
        threshold: cutoff on the classifier's single output score; a sentence
            is flagged when its score is strictly greater than this value.

    Returns:
        1-D numpy bool array with one entry per sentence in `batch`.
    """
    inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
    # The model was loaded with device_map="auto"; send inputs to the model's
    # device instead of assuming the default CUDA device exists.
    inputs = {key: value.to(model.device) for key, value in inputs.items()}
    logits = model(**inputs).logits
    # logits is (batch, num_labels) = (batch, 1). squeeze(-1) keeps a 1-D result
    # even for a single-sentence batch; a bare squeeze() would return a 0-d array
    # that cannot be iterated by list.extend.
    scores = logits.squeeze(-1).float().cpu().numpy()
    return scores > threshold


def sentence_to_generics(sentences, batch_size=64, threshold=0.8):
    """Classify `sentences` in batches of `batch_size`.

    Returns a list of bools aligned with `sentences` (True = generic).
    """
    generics = []
    for start in tqdm(range(0, len(sentences), batch_size)):
        batch = sentences[start:start + batch_size]
        generics.extend(is_generic(batch, threshold=threshold).tolist())
    return generics

In [None]:
# One generic/non-generic flag per sentence, aligned with sentences[split].
are_generics = {
    split: sentence_to_generics(sentences[split])
    for split in ("train", "selected", "wikidoc")
}

In [None]:
# Print the fraction (0-1, not a percentage) of sentences flagged generic per split.
for split, flags in are_generics.items():
    print(f"{split}: {sum(flags) / len(flags)}")

In [None]:
import re
# A sentence counts as a question if it contains a "?" anywhere.
are_questions = {
    split: ["?" in s for s in sentences[split]]
    for split in ("train", "selected", "wikidoc")
}
# Print the fraction (0-1) of question sentences per split.
for split, flags in are_questions.items():
    print(f"{split}: {sum(flags) / len(flags)}")

In [None]:
import re
# Match "if" only as a whole word. A plain substring test would also fire on
# words such as "different", "life", "significant" or "identified".
CONDITIONAL_RE = re.compile(r"\bif\b", re.IGNORECASE)
are_conditionals = {
    split: [CONDITIONAL_RE.search(s) is not None for s in sentences[split]]
    for split in ("train", "selected", "wikidoc")
}
# Print the fraction (0-1) of sentences containing the word "if" per split.
for split, flags in are_conditionals.items():
    print(f"{split}: {sum(flags) / len(flags)}")

In [None]:
import re
# Match "no" as a whole word, in any case and position. This also catches
# "no." / "no," / "no" at end of text and lowercase sentence-initial "no",
# which the substring checks " no " / "No " missed.
# NOTE(review): only "no" is counted; other negation cues ("not", "never",
# "without", "n't") are not. Extend NEGATION_RE if a broader measure is intended.
NEGATION_RE = re.compile(r"\bno\b", re.IGNORECASE)
are_negatives = {
    split: [NEGATION_RE.search(s) is not None for s in sentences[split]]
    for split in ("train", "selected", "wikidoc")
}
# Print the fraction (0-1) of sentences containing the word "no" per split.
for split, flags in are_negatives.items():
    print(f"{split}: {sum(flags) / len(flags)}")

In [None]:
# Print each sentence flagged as generic, with its index, for the listed splits.
for key in ["wikidoc"]:
    print(key)
    # Loop variable is `flag`, not `is_generic`: reusing that name would
    # overwrite the classifier function and break re-running the
    # classification cell afterwards.
    for i, (sentence, flag) in enumerate(zip(sentences[key], are_generics[key])):
        if flag:
            print(f"{i}: {sentence}")
    print()