In [1]:
import os
import json

import torch
import tqdm

from transformers import AutoModelForCausalLM, AutoTokenizer

2023-02-05 15:15:51.727901: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
data = json.load(open("Complete_dataset/test.json"))
files = os.listdir("Complete_dataset/CT json/")
files.remove(".DS_Store")

In [3]:
files_data = {file[:-5]:json.load(open(f"Complete_dataset/CT json/{file}")) for file in files}

In [4]:
data_expanded = []
for _id, value in data.items():
    temp = {}
    temp["id"] = _id
    p_nctid = value["Primary_id"]
    s_nctid = value.get("Secondary_id")
    section_id = value["Section_id"]
    statement = value["Statement"]
    primary_evidence = files_data[p_nctid][section_id]
    temp["statement"] = statement
    temp["primary_evidence"] = primary_evidence
#     temp["label"] = value["Label"]
    
    if s_nctid is not None:
        secondary_evidence = files_data[s_nctid][section_id]
        temp["secondary_evidence"] = secondary_evidence
    
    data_expanded.append(temp)

In [5]:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-iml-max-30b")

In [6]:
model = AutoModelForCausalLM.from_pretrained("facebook/opt-iml-max-30b", torch_dtype=torch.float16).cuda()

In [None]:
def get_input_text(premise, hypothesis):
    options_prefix = "OPTIONS:\n- "
    separator = "\n- "
    options_ = options_prefix + f"{separator}".join(["Entailment","Contradiction"])
    return f"{premise} \n Question: Does this imply that {hypothesis}? {options_}\n A:"


In [None]:
samples = []
for sample in data_expanded:
    primary_evidence = "".join(sample['primary_evidence'])
    sentence = f"Primary trial evidence are {primary_evidence}"
    secondary_evidence = sample.get("secondary_evidence")
    if secondary_evidence:
        secondary_evidence = "".join(sample['secondary_evidence'])
        sentence = f"{sentence} Secondary trial evidence are {secondary_evidence}"
    input_text = get_input_text(sentence, sample['statement'])
    temp = {"text":input_text, "label":0}
    samples.append(temp)

In [None]:
labels = []
pred = []
with torch.inference_mode():
    for sample in tqdm.tqdm(samples):
        labels.append(sample["label"])
        input_ids = tokenizer(sample["text"], return_tensors="pt",).input_ids.to("cuda")
        outputs = model.generate(input_ids,max_new_tokens=8)
        pred.append(tokenizer.decode(outputs[0]))

Input length of input_ids is 269, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
  0%|          | 1/500 [00:06<50:03,  6.02s/it]Input length of input_ids is 138, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Input length of input_ids is 178, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
  1%|          | 3/500 [00:06<13:27,  1.62s/it]Input length of input_ids is 681, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
  1%|          | 4/500 [00:06<09:32,  1.15s/it]Input length of input_ids is 456, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
  1%|          | 5/500 [00:06<06:50,  1.21it/s]Input length of input_ids is 110, but `max_length` i

In [None]:
pred

['<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradict

In [None]:
pred = [p[5:][:-4].strip() for p in pred]

In [None]:
print(set(pred))
from collections import Counter

{'Contradiction', 'Entailment'}


In [None]:
Counter(pred)

Counter({'Contradiction': 264, 'Entailment': 236})

In [None]:
prediction_dict = {}
for _id,pred_x in zip(data, pred):
    prediction_dict[str(_id)] = {"Prediction":pred_x}

In [None]:
json.dump(prediction_dict, open("results.json",'w'),indent=4)

In [16]:
!zip results_xxl_train_1024_2_7e_6_Dev_repro.zip results.json

  adding: results.json (deflated 73%)


In [17]:
prediction_dict

{'9f978634-637c-472f-a588-6f4bb2fb121f': {'Prediction': 'Contradiction'},
 '20b34e62-97c2-4ca0-bb1d-7824dab0b8bb': {'Prediction': 'Entailment'},
 '893a5337-2aa9-4a87-a020-4c2f03cd4aea': {'Prediction': 'Entailment'},
 'd401affc-f081-4eee-bd61-d109cc88f6de': {'Prediction': 'Entailment'},
 '791790a6-187b-4e4b-be5f-9e5304e9ec2c': {'Prediction': 'Contradiction'},
 'b95b7438-ec16-4d4d-826d-5891e7982b36': {'Prediction': 'Entailment'},
 '4988cb16-7dbb-4847-84e0-4a7957b32c72': {'Prediction': 'Entailment'},
 'e244fc3a-53b3-4158-99c5-a45afc726af6': {'Prediction': 'Contradiction'},
 '56530063-b408-47f2-8421-6be825f5559c': {'Prediction': 'Entailment'},
 'd3379655-55b7-4e58-88c2-c3cd3e8cb557': {'Prediction': 'Contradiction'},
 '7bf988b4-5e6f-41c4-bef6-7b3549dd58d9': {'Prediction': 'Contradiction'},
 '84915e35-a8c9-4d26-ad09-4b1df48a6df8': {'Prediction': 'Entailment'},
 '64df451c-f868-49e6-9b4e-a325e62ce837': {'Prediction': 'Entailment'},
 'dc4765a3-6168-4594-969b-3ce5ea7dc02a': {'Prediction': 'Contr