In [1]:
import json
import os
from collections import Counter

import torch
import tqdm

from transformers import T5Tokenizer, T5ForConditionalGeneration

2023-01-31 14:07:58.011991: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Load the test split; `data` maps each sample id to its record.
# The `with` block closes the file handle as soon as parsing finishes.
with open("Complete_dataset/test.json") as test_file:
    data = json.load(test_file)

# List the per-trial JSON files. Filtering by extension (instead of removing a
# hard-coded ".DS_Store" entry) skips any non-JSON directory entries and does
# not raise ValueError when ".DS_Store" is absent, e.g. on a non-macOS checkout.
files = [
    name
    for name in os.listdir("Complete_dataset/CT json/")
    if name.endswith(".json")
]

In [3]:
# Parse every trial file, keyed by its file name without the extension.
# Each file is opened in a `with` block so its handle is closed right after
# parsing instead of being left open for the garbage collector to reclaim.
files_data = {}
for file_name in files:
    with open(f"Complete_dataset/CT json/{file_name}") as trial_file:
        files_data[os.path.splitext(file_name)[0]] = json.load(trial_file)

In [4]:
# Attach to every test record the evidence section(s) it refers to.
# No label is stored (the original kept that assignment commented out;
# presumably the test split ships without a "Label" field -- TODO confirm).
data_expanded = []
for sample_id, record in data.items():
    section = record["Section_id"]
    entry = {
        "id": sample_id,
        "statement": record["Statement"],
        "primary_evidence": files_data[record["Primary_id"]][section],
    }

    # Records that reference a second trial carry a "Secondary_id";
    # for those, the same section of the second trial is attached as well.
    secondary_id = record.get("Secondary_id")
    if secondary_id is not None:
        entry["secondary_evidence"] = files_data[secondary_id][section]

    data_expanded.append(entry)

In [5]:
# Tokenizer for the same checkpoint that the model cell loads.
tokenizer_checkpoint = "allenai/tk-instruct-11b-def"
tokenizer = T5Tokenizer.from_pretrained(tokenizer_checkpoint)

In [6]:
# Load the seq2seq checkpoint with bfloat16 weights; device_map="auto" leaves
# the placement of the weights across available devices to the library.
model = T5ForConditionalGeneration.from_pretrained(
    "allenai/tk-instruct-11b-def",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [13]:
def get_input_text(premise, hypothesis):
    """Build the instruction prompt for one premise/hypothesis pair.

    The prompt consists of a fixed task definition followed by a single
    example slot that holds the premise and a question about the hypothesis.

    Args:
        premise: Text inserted after "Input:".
        hypothesis: Text inserted into the "Does this imply that ...?" question.

    Returns:
        The complete prompt as a single string.
    """
    definition = (
        "Definition: In this task, you are given a premise and hypothesis. "
        "The task is to classify them into two categories: "
        "'Entailment' if the hypothesis supports the premise, "
        "'Contradiction' if it opposes the premise. "
    )
    example = (
        f"Input: {premise} \n"
        f" Question: Does this imply that {hypothesis}?  Entailment or Contradiction?"
    )
    return f"{definition}\nNow complete the following example -\n{example}"


In [14]:
# Turn every expanded record into a prompt for the model.
# Evidence lines are concatenated with no separator between them, and the
# label is a constant placeholder 0.
samples = []
for entry in data_expanded:
    premise_parts = [
        "Primary trial evidence are " + "".join(entry["primary_evidence"])
    ]
    # Only add the secondary clause when the record has non-empty secondary evidence.
    if entry.get("secondary_evidence"):
        premise_parts.append(
            "Secondary trial evidence are " + "".join(entry["secondary_evidence"])
        )
    prompt = get_input_text(" ".join(premise_parts), entry["statement"])
    samples.append({"text": prompt, "label": 0})

In [15]:
# Generate one answer per prompt. Decoding keeps the special tokens
# (e.g. "<pad>", "</s>"), which a later cell strips off.
labels = [sample["label"] for sample in samples]
pred = []
for sample in tqdm.tqdm(samples):
    encoded = tokenizer(sample["text"], return_tensors="pt")
    generated = model.generate(encoded.input_ids.to("cuda"))
    pred.append(tokenizer.decode(generated[0]))

100%|██████████| 500/500 [02:34<00:00,  3.24it/s]


In [16]:
# Display the raw decoded generations (still wrapped in special tokens).
pred

['<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Entailment</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction</s>',
 '<pad> Contradiction<

In [17]:
# Remove the special-token markers by value rather than by fixed character
# offsets. Slicing with p[5:][:-4] assumes every string is exactly
# "<pad> ...</s>" and eats real characters if this cell is run a second time
# or a generation ends without "</s>". Replacing the markers is idempotent.
pred = [p.replace("<pad>", "").replace("</s>", "").strip() for p in pred]

In [18]:
# Show the distinct labels the model produced.
print(set(pred))

{'Entailment', 'Contradiction'}


In [13]:
# Normalise casing: first character upper-case, the remainder lower-case.
pred = list(map(str.capitalize, pred))

In [19]:
# Count how often each label was predicted.
Counter(pred)

Counter({'Entailment': 137, 'Contradiction': 363})

In [20]:
# Pair each sample id with its prediction. `pred` is positional, so this relies
# on it holding exactly one entry per key of `data`, in the same order. zip()
# would otherwise silently drop the unmatched tail and yield an incomplete
# submission, so the lengths are checked first.
assert len(pred) == len(data), f"expected {len(data)} predictions, got {len(pred)}"
prediction_dict = {
    str(_id): {"Prediction": pred_x} for _id, pred_x in zip(data, pred)
}

In [21]:
# Write the submission file. The `with` block flushes and closes the handle
# before any later step reads results.json from disk.
with open("results.json", "w") as results_file:
    json.dump(prediction_dict, results_file, indent=4)

In [22]:
# Package results.json into a zip archive using the shell `zip` command.
!zip zeroshot_instruct_11b_2.zip results.json

  adding: results.json (deflated 74%)


In [17]:
# Display the id -> prediction mapping.
prediction_dict

{'9f978634-637c-472f-a588-6f4bb2fb121f': {'Prediction': 'Contradiction'},
 '20b34e62-97c2-4ca0-bb1d-7824dab0b8bb': {'Prediction': 'Contradiction'},
 '893a5337-2aa9-4a87-a020-4c2f03cd4aea': {'Prediction': 'Entailment'},
 'd401affc-f081-4eee-bd61-d109cc88f6de': {'Prediction': 'Entailment'},
 '791790a6-187b-4e4b-be5f-9e5304e9ec2c': {'Prediction': 'Contradiction'},
 'b95b7438-ec16-4d4d-826d-5891e7982b36': {'Prediction': 'Entailment'},
 '4988cb16-7dbb-4847-84e0-4a7957b32c72': {'Prediction': 'Contradiction'},
 'e244fc3a-53b3-4158-99c5-a45afc726af6': {'Prediction': 'Contradiction'},
 '56530063-b408-47f2-8421-6be825f5559c': {'Prediction': 'Entailment'},
 'd3379655-55b7-4e58-88c2-c3cd3e8cb557': {'Prediction': 'Contradiction'},
 '7bf988b4-5e6f-41c4-bef6-7b3549dd58d9': {'Prediction': 'Contradiction'},
 '84915e35-a8c9-4d26-ad09-4b1df48a6df8': {'Prediction': 'Entailment'},
 '64df451c-f868-49e6-9b4e-a325e62ce837': {'Prediction': 'Contradiction'},
 'dc4765a3-6168-4594-969b-3ce5ea7dc02a': {'Prediction

: 