# T5 Text2Text Generation on SciQ
Train T5 for distractor generation on the SciQ dataset.<br>
Training is done directly with the trainer.<br>

### GPU

In [1]:
!nvidia-smi

Thu Aug 17 09:00:27 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA TITAN RTX               On  | 00000000:09:00.0 Off |                  N/A |
| 41%   39C    P8              35W / 280W |      1MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA TITAN RTX               On  | 00000000:0A:00.0 Off |  

In [2]:
project_name = "test on T5 with T5"
import os

os.environ["WANDB_PROJECT"] = project_name

### import

In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch



### Loading the dataset

In [13]:
from datasets import load_dataset

dataset = load_dataset("sciq")

Using custom data configuration default
Reusing dataset sciq (/user_data/.cache/huggingface/datasets/sciq/default/0.1.0/50e5c6e3795b55463819d399ec417bfd4c3c621105e00295ddb5f3633d708493)
100%|██████████| 3/3 [00:00<00:00,  8.98it/s]


In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 11679
    })
    validation: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support'],
        num_rows: 1000
    })
})

In [6]:
dataset['train'][0]

{'question': 'What type of organism is commonly used in preparation of foods such as cheese and yogurt?',
 'distractor3': 'viruses',
 'distractor1': 'protozoa',
 'distractor2': 'gymnosperms',
 'correct_answer': 'mesophilic organisms',
 'support': 'Mesophiles grow best in moderate temperature, typically between 25°C and 40°C (77°F and 104°F). Mesophiles are often found living in or on the bodies of humans or other animals. The optimal growth temperature of many pathogenic mesophiles is 37°C (98°F), the normal human body temperature. Mesophilic organisms have important uses in food preparation, including cheese, yogurt, beer and wine.'}

In [14]:
train = dataset['train']
valid = dataset['validation']
test = dataset['test']

In [15]:
train = list(train)
test = list(test)
valid = list(valid)

In [9]:
len(train), len(valid), len(test)

(11679, 1000, 1000)

In [10]:
train[0]

{'question': 'What type of organism is commonly used in preparation of foods such as cheese and yogurt?',
 'distractor3': 'viruses',
 'distractor1': 'protozoa',
 'distractor2': 'gymnosperms',
 'correct_answer': 'mesophilic organisms',
 'support': 'Mesophiles grow best in moderate temperature, typically between 25°C and 40°C (77°F and 104°F). Mesophiles are often found living in or on the bodies of humans or other animals. The optimal growth temperature of many pathogenic mesophiles is 37°C (98°F), the normal human body temperature. Mesophilic organisms have important uses in food preparation, including cheese, yogurt, beer and wine.'}

### Prepare data

In [16]:
def processData(data):
    
    sentences = []
    labels = []
    answers = []
    for d in data:
        sentence = d['question']
        distractors = [d['distractor1'], d['distractor2'], d['distractor3']]
        answer = d['correct_answer']
        
        # Strip stray whitespace from the dataset's distractor labels
        distractors = [dis.strip() for dis in distractors]
        
        sentences.append(sentence)
        labels.append('_ of distractors are ' + ', '.join(distractors))
        answers.append(answer)
        
    return sentences, answers, labels

In [17]:
train_sent, train_answer, train_label = processData(train)
valid_sent, valid_answer, valid_label = processData(valid)
test_sent, test_answer, test_label = processData(test)

In [None]:
len(train_sent), len(train_answer), len(train_label)

(11679, 11679, 11679)

In [14]:
for idx in range(2):
    print(train_sent[idx])
    print(train_answer[idx])
    print(train_label[idx])
    print()

What type of organism is commonly used in preparation of foods such as cheese and yogurt?
mesophilic organisms
_ of distractors are protozoa, gymnosperms, viruses

What phenomenon makes global winds blow northeast to southwest or the reverse in the northern hemisphere and northwest to southeast or the reverse in the southern hemisphere?
coriolis effect
_ of distractors are muon effect, centrifugal effect, tropical effect
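
The target template `_ of distractors are a, b, c` has to be parsed back into a list before predictions can be scored. A minimal sketch of that round trip follows. The helper names `build_target` and `parse_target` are hypothetical and not part of this notebook; the split logic mirrors what `compute_metrics` does later.

```python
PREFIX = "_ of distractors are "  # template used by processData above


def build_target(distractors):
    """Join whitespace-stripped distractors into one target string."""
    return PREFIX + ", ".join(d.strip() for d in distractors)


def parse_target(text):
    """Recover the distractor list from a target (or generated) string."""
    parts = text.split(", ")
    # The first item still carries the template prefix; keep what follows "are ".
    parts[0] = parts[0].split("are ")[-1]
    return parts


target = build_target(["protozoa", "gymnosperms ", "viruses"])
print(target)
print(parse_target(target))
```

One consequence of this scheme is that a distractor which itself contains `, ` would be split into two items when parsed.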



In [18]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')

In [19]:
train_encodings = tokenizer(train_sent, train_answer,truncation=True, padding=True)
valid_encodings = tokenizer(valid_sent, valid_answer,truncation=True, padding=True)
test_encodings = tokenizer(test_sent, test_answer,truncation=True, padding=True)

In [17]:
train_encodings.keys()

dict_keys(['input_ids', 'attention_mask'])

In [18]:
print(train_encodings.input_ids[0])

[363, 686, 13, 9329, 19, 5871, 261, 16, 4537, 13, 4371, 224, 38, 3285, 11, 19168, 58, 1, 140, 7, 21144, 447, 9329, 7, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [19]:
tokenizer.decode(train_encodings.input_ids[0])

'What type of organism is commonly used in preparation of foods such as cheese and yogurt?</s> mesophilic organisms</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>'

In [20]:
def add_labels(encodings, distractors):
    """Tokenize the target strings and attach them to `encodings` as labels."""
    # Pad every target to the longest one in this split.
    distractors_encodings = tokenizer(distractors, padding=True)

    # input_ids is already a list of token-id lists, one per example.
    encodings["labels"] = list(distractors_encodings.input_ids)
    return encodings
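
`add_labels` stores the padded label ids as they are, so padding positions (id 0 for T5) remain in the labels. A common variant, which is not what the run recorded in this notebook does, replaces padding with -100, the value the cross-entropy loss in Hugging Face models skips. The sketch below uses a hypothetical helper `mask_padding` and assumes a pad id of 0.

```python
PAD_TOKEN_ID = 0     # assumed T5 pad id, matching the zeros in the encodings above
IGNORE_INDEX = -100  # label value skipped by the loss


def mask_padding(label_ids, pad_token_id=PAD_TOKEN_ID):
    """Replace padding ids with IGNORE_INDEX in every label row."""
    return [
        [IGNORE_INDEX if tok == pad_token_id else tok for tok in row]
        for row in label_ids
    ]


# Two made-up label rows, padded to the same length.
padded = [[3, 834, 13, 1, 0, 0], [3, 834, 13, 7, 9, 1]]
masked = mask_padding(padded)
print(masked)
```

Switching to masked labels would likely change the reported losses, so numbers obtained that way would not be directly comparable with the ones in this notebook.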

In [21]:
train_encodings = add_labels(train_encodings, train_label)
valid_encodings = add_labels(valid_encodings, valid_label)
test_encodings = add_labels(test_encodings, test_label)

In [22]:
len(train_encodings.input_ids)

11679

In [22]:
class SciqDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

train_dataset = SciqDataset(train_encodings)
valid_dataset = SciqDataset(valid_encodings)
test_dataset = SciqDataset(test_encodings)

In [24]:
len(train_dataset), len(valid_dataset), len(test_dataset)

(11679, 1000, 1000)

### Fine-tuning

In [25]:
from transformers import T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
import torch

model = T5ForConditionalGeneration.from_pretrained("t5-base")
model.resize_token_embeddings(len(tokenizer))

Embedding(32100, 768)

In [26]:
batch_size = 16
args = Seq2SeqTrainingArguments(
    output_dir = "./results",
    save_strategy = "epoch",
    evaluation_strategy = "epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=50,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="P@1",
    weight_decay=0.01,
    predict_with_generate=True,
    eval_accumulation_steps = 1,
    report_to="wandb" if os.getenv("WANDB_PROJECT") else "none"
)

In [27]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [28]:
import numpy as np
def compute_metrics(p):
    predictions, labels = p
    
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Parsed distractor lists for every example
    predicted = []
    true_label = []
    
    for k in range(len(decoded_labels)):
        pred = decoded_preds[k]
        label = decoded_labels[k]

        pred_list = pred.split(', ')
        label_list = label.split(', ')
        
        pred_list[0] = pred_list[0].split('are ')[-1]
        label_list[0] = label_list[0].split('are ')[-1]

        predicted.append(pred_list)
        true_label.append(label_list)

    # evaluation metrics
    p1 = 0
    p3 = 0
    r3 = 0
    f3 = 0
    for idx in range(len(true_label)):
        distractors = predicted[idx]
        labels = true_label[idx]

        act_set = set(labels)
        pred1_set = set(distractors[:1])
        pred3_set = set(distractors[:3])

        p_1 = len(act_set & pred1_set) / float(1)
        p_3 = len(act_set & pred3_set) / float(3)
        r_3 = len(act_set & pred3_set) / float(len(act_set))

        if p_3 == 0 and r_3 == 0:
            f1_3 = 0
        else:
            f1_3 = 2 * (p_3 * r_3 / (p_3 + r_3))

        p1+=p_1
        p3+=p_3
        r3+=r_3
        f3+=f1_3

    avg_p1 = p1 / len(true_label)
    avg_p3 = p3 / len(true_label)
    avg_r3 = r3 / len(true_label)
    avg_f3 = f3 / len(true_label)

    result = {'P@1': avg_p1,
              'P@3': avg_p3,
              'R@3': avg_r3,
              'F1@3': avg_f3}
    
    return result
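
To make the metric definitions concrete, here is a worked example for a single question. It is a standalone sketch: `prf_at_3` is a hypothetical helper that repeats the per-example arithmetic of `compute_metrics`, and the prediction list is made up for illustration.

```python
def prf_at_3(predicted, gold):
    """Per-example P@1, P@3, R@3 and F1@3 using exact string matches."""
    gold_set = set(gold)
    hits_1 = len(gold_set & set(predicted[:1]))
    hits_3 = len(gold_set & set(predicted[:3]))
    p1 = hits_1 / 1
    p3 = hits_3 / 3
    r3 = hits_3 / len(gold_set)
    f3 = 0.0 if p3 + r3 == 0 else 2 * p3 * r3 / (p3 + r3)
    return {"P@1": p1, "P@3": p3, "R@3": r3, "F1@3": f3}


gold = ["protozoa", "gymnosperms", "viruses"]
predicted = ["protozoa", "bacteria", "viruses"]  # made-up model output
scores = prf_at_3(predicted, gold)
print(scores)  # P@1 is 1.0; P@3, R@3 and F1@3 are all 2/3
```

Matching is exact and case-sensitive, so a prediction such as `oxygen` would not count as a hit against the gold distractor `Oxygen`.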

In [29]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [30]:
trainer.train()

***** Running training *****
  Num examples = 11679
  Num Epochs = 50
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 36500
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mms0004284[0m. Use [1m`wandb login --relogin`[0m to force relogin




| Epoch | Training Loss | Validation Loss | P@1 | P@3 | R@3 | F1@3 |
|---|---|---|---|---|---|---|
| 1 | 0.6135 | 0.510426 | 0.12 | 0.054667 | 0.05431 | 0.054378 |
| 2 | 0.404 | 0.50372 | 0.155 | 0.076333 | 0.07631 | 0.076222 |
| 3 | 0.364 | 0.494545 | 0.163 | 0.086 | 0.086143 | 0.085956 |
| 4 | 0.3507 | 0.494843 | 0.174 | 0.091667 | 0.091976 | 0.091733 |
| 5 | 0.3267 | 0.496472 | 0.188 | 0.096 | 0.096333 | 0.096089 |
| 6 | 0.3133 | 0.500004 | 0.195 | 0.108667 | 0.109 | 0.108756 |
| 7 | 0.2942 | 0.507427 | 0.187 | 0.107333 | 0.107643 | 0.1074 |
| 8 | 0.2818 | 0.512803 | 0.215 | 0.119 | 0.1195 | 0.1192 |
| 9 | 0.2688 | 0.523899 | 0.224 | 0.123667 | 0.124333 | 0.123933 |
| 10 | 0.2562 | 0.527542 | 0.216 | 0.124 | 0.124167 | 0.123978 |
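
With `load_best_model_at_end=True` and `metric_for_best_model="P@1"`, checkpoints are ranked by validation P@1 rather than by validation loss. The sketch below is plain Python, not the trainer's own selection code, and it only covers the ten epochs shown in this output (the run has 50). It illustrates that the two criteria pick different epochs on these rows.

```python
# (epoch, validation loss, P@1) copied from the rows logged above
history = [
    (1, 0.510426, 0.120), (2, 0.503720, 0.155), (3, 0.494545, 0.163),
    (4, 0.494843, 0.174), (5, 0.496472, 0.188), (6, 0.500004, 0.195),
    (7, 0.507427, 0.187), (8, 0.512803, 0.215), (9, 0.523899, 0.224),
    (10, 0.527542, 0.216),
]

best_by_p1 = max(history, key=lambda row: row[2])    # higher P@1 is better
best_by_loss = min(history, key=lambda row: row[1])  # lower loss is better

print("best epoch by P@1:", best_by_p1[0])
print("best epoch by validation loss:", best_by_loss[0])
```

On these rows the validation loss bottoms out early while P@1 keeps improving, which is one reason to select on the task metric here.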


***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
Saving model checkpoint to ./results/checkpoint-730
Configuration saved in ./results/checkpoint-730/config.json
Model weights saved in ./results/checkpoint-730/pytorch_model.bin
tokenizer config file saved in ./results/checkpoint-730/tokenizer_config.json
Special tokens file saved in ./results/checkpoint-730/special_tokens_map.json
Deleting older checkpoint [results/checkpoint-16790] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8
Saving model checkpoint to ./results/checkpoint-1460
Configuration saved in ./results/checkpoint-1460/config.json
Saving model checkpoint to ./results/checkpoint-1460
Configuration saved in ./results/checkpoint-1460/config.

TrainOutput(global_step=36500, training_loss=0.16777305414905286, metrics={'train_runtime': 9610.4423, 'train_samples_per_second': 60.762, 'train_steps_per_second': 3.798, 'total_flos': 7.7093176916736e+16, 'train_loss': 0.16777305414905286, 'epoch': 50.0})
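
The step counts in the log can be checked by hand. The sketch below assumes the last partial batch of each epoch is kept, which is consistent with the first checkpoint being saved at step 730. The input numbers are taken from the log and `TrainOutput` above.

```python
import math

num_examples = 11679      # "Num examples" in the training log
total_batch_size = 16     # "Total train batch size" in the training log
num_epochs = 50
train_runtime = 9610.4423  # seconds, from TrainOutput

steps_per_epoch = math.ceil(num_examples / total_batch_size)
total_steps = steps_per_epoch * num_epochs
samples_per_second = num_examples * num_epochs / train_runtime
steps_per_second = total_steps / train_runtime

print(steps_per_epoch, total_steps)
print(round(samples_per_second, 3), round(steps_per_second, 3))
```

This gives 730 steps per epoch and 36500 steps in total, matching `Total optimization steps` and `global_step`.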

In [31]:
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8


{'eval_loss': 0.6068299412727356,
 'eval_P@1': 0.225,
 'eval_P@3': 0.14466666666666675,
 'eval_R@3': 0.1448095238095239,
 'eval_F1@3': 0.1446222222222223,
 'eval_runtime': 36.3616,
 'eval_samples_per_second': 27.502,
 'eval_steps_per_second': 1.733,
 'epoch': 50.0}

In [32]:
trainer.save_model('/user_data/CTG/model/t5-base-text2text-sciq-off-shelf')

Saving model checkpoint to /user_data/CTG/model/t5-base-text2text-sciq-8b
Configuration saved in /user_data/CTG/model/t5-base-text2text-sciq-8b/config.json
Model weights saved in /user_data/CTG/model/t5-base-text2text-sciq-8b/pytorch_model.bin
tokenizer config file saved in /user_data/CTG/model/t5-base-text2text-sciq-8b/tokenizer_config.json
Special tokens file saved in /user_data/CTG/model/t5-base-text2text-sciq-8b/special_tokens_map.json


In [33]:
predictions, labels, metrics = trainer.predict(valid_dataset)
print('valid: ')
metrics

***** Running Prediction *****
  Num examples = 1000
  Batch size = 8


valid: 


{'test_loss': 0.6068299412727356,
 'test_P@1': 0.225,
 'test_P@3': 0.14466666666666675,
 'test_R@3': 0.1448095238095239,
 'test_F1@3': 0.1446222222222223,
 'test_runtime': 36.333,
 'test_samples_per_second': 27.523,
 'test_steps_per_second': 1.734}

In [34]:
predictions, labels, metrics = trainer.predict(test_dataset)
print('test: ')
metrics

***** Running Prediction *****
  Num examples = 1000
  Batch size = 8


test: 


{'test_loss': 0.7199090719223022,
 'test_P@1': 0.243,
 'test_P@3': 0.1566666666666668,
 'test_R@3': 0.1567539682539684,
 'test_F1@3': 0.1565666666666668,
 'test_runtime': 33.6847,
 'test_samples_per_second': 29.687,
 'test_steps_per_second': 1.87}

In [30]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
import torch

tokenizer = T5Tokenizer.from_pretrained("/user_data/CTG/model/t5-base-text2text-sciq-off-shelf")
model = T5ForConditionalGeneration.from_pretrained("/user_data/CTG/model/t5-base-text2text-sciq-off-shelf")

Didn't find file /user_data/CTG/model/t5-base-text2text-sciq-8b/added_tokens.json. We won't load it.
loading file /user_data/CTG/model/t5-base-text2text-sciq-8b/spiece.model
loading file None
loading file /user_data/CTG/model/t5-base-text2text-sciq-8b/special_tokens_map.json
loading file /user_data/CTG/model/t5-base-text2text-sciq-8b/tokenizer_config.json
loading configuration file /user_data/CTG/model/t5-base-text2text-sciq-8b/config.json
Model config T5Config {
  "_name_or_path": "t5-base",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 3072,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "relu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  

In [27]:
batch_size = 16
args = Seq2SeqTrainingArguments(
    output_dir = "./results",
    save_strategy = "epoch",
    evaluation_strategy = "epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=50,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="P@1",
    weight_decay=0.01,
    predict_with_generate=True,
    eval_accumulation_steps = 1,
)

In [24]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [25]:
import numpy as np
def compute_metrics(p):
    predictions, labels = p
    
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Parsed distractor lists for every example
    predicted = []
    true_label = []
    
    for k in range(len(decoded_labels)):
        pred = decoded_preds[k]
        label = decoded_labels[k]

        pred_list = pred.split(', ')
        label_list = label.split(', ')
        
        pred_list[0] = pred_list[0].split('are ')[-1]
        label_list[0] = label_list[0].split('are ')[-1]

        predicted.append(pred_list)
        true_label.append(label_list)

    # evaluation metrics
    p1 = 0
    p3 = 0
    r3 = 0
    f3 = 0
    for idx in range(len(true_label)):
        distractors = predicted[idx]
        labels = true_label[idx]

        act_set = set(labels)
        pred1_set = set(distractors[:1])
        pred3_set = set(distractors[:3])

        p_1 = len(act_set & pred1_set) / float(1)
        p_3 = len(act_set & pred3_set) / float(3)
        r_3 = len(act_set & pred3_set) / float(len(act_set))

        if p_3 == 0 and r_3 == 0:
            f1_3 = 0
        else:
            f1_3 = 2 * (p_3 * r_3 / (p_3 + r_3))

        p1+=p_1
        p3+=p_3
        r3+=r_3
        f3+=f1_3

    avg_p1 = p1 / len(true_label)
    avg_p3 = p3 / len(true_label)
    avg_r3 = r3 / len(true_label)
    avg_f3 = f3 / len(true_label)

    result = {'P@1': avg_p1,
              'P@3': avg_p3,
              'R@3': avg_r3,
              'F1@3': avg_f3}
    
    return result

In [31]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [32]:
test_predictions, test_labels, test_metrics = trainer.predict(test_dataset)
test_metrics

***** Running Prediction *****
  Num examples = 1000
  Batch size = 16


{'test_loss': 0.8025060892105103,
 'test_P@1': 0.195,
 'test_P@3': 0.1366666666666666,
 'test_R@3': 0.1368888888888888,
 'test_F1@3': 0.13659999999999992,
 'test_runtime': 27.5539,
 'test_samples_per_second': 36.292,
 'test_steps_per_second': 1.161}

In [41]:
decoded_labels = tokenizer.batch_decode(test_labels)

In [42]:
for d in decoded_labels:
    print(d)

_ of distractors are antioxidants, Oxygen, residues</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
_ of distractors are adult, male, phenotype</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
_ of distractors are Bones, Muscles, Thumbs</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
_ of distractors are depth, latitude, variation</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
_ of distractors are mountain ranges, fossils, magma</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
_ of distractors are produce hormones, nitrogen hormones, Human Hormones</s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <p

In [43]:
import json
def write_json(data, path):
    # The context manager closes the file even if writing fails.
    with open(path, "w") as json_file:
        json.dump(data, json_file)

In [44]:
def save_data(data, predictions, labels, file_name):
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Parsed distractor lists for every example
    predicted = []
    true_label = []
    
    for k in range(len(decoded_labels)):
        pred = decoded_preds[k]
        label = decoded_labels[k]

        pred_list = pred.split(', ')
        label_list = label.split(', ')
        
        pred_list[0] = pred_list[0].split('are ')[-1]
        label_list[0] = label_list[0].split('are ')[-1]

        predicted.append(pred_list)
        true_label.append(label_list)
    
    
    # evaluation metrics
    for idx in range(len(true_label)):
        distractors = predicted[idx]
        labels = true_label[idx]
        
        data[idx]['pred_distractors'] = distractors

        act_set = set(labels)
        pred1_set = set(distractors[:1])
        pred3_set = set(distractors[:3])

        p_1 = len(act_set & pred1_set) / float(1)
        p_3 = len(act_set & pred3_set) / float(3)
        r_3 = len(act_set & pred3_set) / float(len(act_set))

        if p_3 == 0 and r_3 == 0:
            f1_3 = 0
        else:
            f1_3 = 2 * (p_3 * r_3 / (p_3 + r_3))
            
        data[idx]['metric'] = {'P@1': p_1, 'P@3': p_3, 'R@3': r_3, 'F1@3': f1_3}
        
    write_json(data, file_name)
    print(file_name + ' is saved :)')

In [45]:
test_labels

array([[  3, 834,  13, ...,   0,   0,   0],
       [  3, 834,  13, ...,   0,   0,   0],
       [  3, 834,  13, ...,   0,   0,   0],
       ...,
       [  3, 834,  13, ...,   0,   0,   0],
       [  3, 834,  13, ...,   0,   0,   0],
       [  3, 834,  13, ...,   0,   0,   0]])

In [46]:
save_data(test, test_predictions, test_labels, '/user_data/CTG/test_result/sciq_test_t5_text2text_off_shelf.json')

/user_data/CTG/test_result/sciq_test_t5_text2text_off_shell_8b.json is saved :)


### Result

In [47]:
import json
def read_data(path):
    with open(path) as f:
        data = json.load(f)
    return data
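
Because `save_data` stores a `metric` dictionary on every record, the corpus-level scores can be recomputed from the saved JSON without rerunning the model. The following is a minimal sketch: `mean_metrics` is a hypothetical helper and the two records are made up to show the expected shape.

```python
METRIC_KEYS = ["P@1", "P@3", "R@3", "F1@3"]


def mean_metrics(records):
    """Average each per-example metric over a list of saved records."""
    n = len(records)
    return {key: sum(r["metric"][key] for r in records) / n for key in METRIC_KEYS}


# Made-up records in the shape written by save_data.
records = [
    {"metric": {"P@1": 1.0, "P@3": 1.0, "R@3": 1.0, "F1@3": 1.0}},
    {"metric": {"P@1": 0.0, "P@3": 0.0, "R@3": 0.0, "F1@3": 0}},
]
averages = mean_metrics(records)
print(averages)
```

Applied to the list returned by `read_data`, this should reproduce the averages computed by `compute_metrics`, assuming the file holds the same predictions.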

In [48]:
test = read_data('/user_data/CTG/test_result/sciq_test_t5_text2text_off_shelf.json')

In [49]:
for i in range(0, 100, 7):
    example = test[i]
    sentence = example['question']
    answer = example['correct_answer']
    distractors = [example['distractor1'], example['distractor2'], example['distractor3']]
    pred_distractors = example['pred_distractors']
    metric = example['metric']
    
    print('question:', sentence.replace('**blank**', '_'))
    print('answer:', answer)
    print('distractors:', distractors)
    print('predict:', pred_distractors)
    print('metric:', metric)
    print()

question: Compounds that are capable of accepting electrons, such as o 2 or f2, are called what?
answer: oxidants
distractors: ['antioxidants', 'Oxygen', 'residues']
predict: ['oxides', 'catalysts', 'erodants']
metric: {'P@1': 0.0, 'P@3': 0.0, 'R@3': 0.0, 'F1@3': 0}

question: Which type of tree is dominant in temperate forests?
answer: deciduous
distractors: ['vines', 'fungus', 'shrubs']
predict: ['perennial', 'conifer', 'evergreen']
metric: {'P@1': 0.0, 'P@3': 0.0, 'R@3': 0.0, 'F1@3': 0}

question: Only about one percent of plants have lost what ability, turning them into consumers and even predators, instead of producers?
answer: photosynthesis
distractors: ['flowering', 'rooting', 'growth']
predict: ['reproduction', 'glycolysis', 'photosynthesis']
metric: {'P@1': 0.0, 'P@3': 0.0, 'R@3': 0.0, 'F1@3': 0}

question: Presence of a cell wall, large central vacuole, and organelles called plastids distinguish what type of cell?
answer: plant
distractors: ['animal', 'reproductive', 'hetero