# Bart-DMLM(train-sciq-passage-level) Text2Text Generation on Sciq
Fine-tune BART for distractor generation on the Sciq dataset<br>
Training is done directly with the Hugging Face trainer (`Seq2SeqTrainer`)<br>

### GPU

In [1]:
!nvidia-smi

Fri Jun 16 07:10:37 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |


|   0  NVIDIA TITAN RTX    Off  | 00000000:09:00.0 Off |                  N/A |
| 39%   42C    P0    81W / 280W |      0MiB / 24217MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA TITAN RTX    Off  | 00000000:0A:00.0 Off |                  N/A |
| 33%   42C    P0    61W / 280W |      0MiB / 24220MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                        

In [2]:
project_name = "test on sciq with Bart"
import os

os.environ["WANDB_PROJECT"] = project_name
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

### import

In [1]:
from transformers import BartTokenizer, BartForConditionalGeneration
import torch



### Loading the dataset

In [2]:
import json

In [3]:
def read_data(item):
    path = '../../../../data/Sciq/sciq_{}.json'.format(item)
    with open(path) as f:
        data = json.load(f)
    return data

In [4]:
train = read_data('train')
valid = read_data('valid')
test = read_data('test')

In [5]:
train = list(train)
test = list(test)
valid = list(valid)

In [8]:
len(train), len(valid), len(test)

(11679, 1000, 1000)

In [9]:
train[0]

{'question': 'What type of organism is commonly used in preparation of foods such as cheese and yogurt?',
 'distractor3': 'viruses',
 'distractor1': 'protozoa',
 'distractor2': 'gymnosperms',
 'correct_answer': 'mesophilic organisms',
 'support': 'Mesophiles grow best in moderate temperature, typically between 25°C and 40°C (77°F and 104°F). Mesophiles are often found living in or on the bodies of humans or other animals. The optimal growth temperature of many pathogenic mesophiles is 37°C (98°F), the normal human body temperature. Mesophilic organisms have important uses in food preparation, including cheese, yogurt, beer and wine.'}

### Prepare data

In [6]:
def processData(data):
    
    sentences = []
    labels = []
    answers = []
    for d in data:
        sentence = d['question']
        distractors = [d['distractor1'], d['distractor2'], d['distractor3']]
        answer = d['correct_answer']
        
        # Strip stray whitespace around the distractor labels in the dataset
        distractors = [dis.strip() for dis in distractors]
        
        sentences.append(sentence)
        labels.append('_ of distractors are ' + ', '.join(distractors))
        answers.append(answer)
        
    return sentences, answers, labels
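To make the target format concrete, here is a minimal stand-alone sketch that applies the same per-record transformation as `processData` to one hand-written record (the first training example shown above, reduced to the fields `processData` reads). `build_target` is a hypothetical helper, not part of the original notebook, and the extra spaces around two distractors were added only to show the `strip()`.

```python
def build_target(record):
    """Mirror processData for one record: strip the three distractors and
    join them into the '_ of distractors are d1, d2, d3' target string."""
    distractors = [record['distractor1'], record['distractor2'], record['distractor3']]
    distractors = [d.strip() for d in distractors]
    target = '_ of distractors are ' + ', '.join(distractors)
    return record['question'], record['correct_answer'], target


example = {
    'question': 'What type of organism is commonly used in preparation of foods such as cheese and yogurt?',
    'distractor3': 'viruses',
    'distractor1': ' protozoa',     # stray spaces added for illustration
    'distractor2': 'gymnosperms ',
    'correct_answer': 'mesophilic organisms',
}
source, answer, target = build_target(example)
print(target)  # _ of distractors are protozoa, gymnosperms, viruses
```

The question and answer become the encoder input pair, while the target string becomes the decoder label.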

In [7]:
train_sent, train_answer, train_label = processData(train)
valid_sent, valid_answer, valid_label = processData(valid)
test_sent, test_answer, test_label = processData(test)

In [12]:
print(test_label[0])

_ of distractors are antioxidants, Oxygen, residues


In [13]:
for l in test_label:
    if 'ultraviolet light' in l:
        print(l)

_ of distractors are invisible light, sunlight, ultraviolet light


In [14]:
len(train_sent), len(train_answer), len(train_label)

(11679, 11679, 11679)

In [15]:
for idx in range(2):
    print(train_sent[idx])
    print(train_answer[idx])
    print(train_label[idx])
    print()

What type of organism is commonly used in preparation of foods such as cheese and yogurt?
mesophilic organisms
_ of distractors are protozoa, gymnosperms, viruses

What phenomenon makes global winds blow northeast to southwest or the reverse in the northern hemisphere and northwest to southeast or the reverse in the southern hemisphere?
coriolis effect
_ of distractors are muon effect, centrifugal effect, tropical effect



In [8]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

In [9]:
# Encode each (question, answer) pair as "<s>question</s></s>answer</s>", padded to the longest sequence
train_encodings = tokenizer(train_sent, train_answer, truncation=True, padding=True)
valid_encodings = tokenizer(valid_sent, valid_answer, truncation=True, padding=True)
test_encodings = tokenizer(test_sent, test_answer, truncation=True, padding=True)

In [18]:
train_encodings.keys()

dict_keys(['input_ids', 'attention_mask'])

In [19]:
print(train_encodings.input_ids[0])

[0, 2264, 1907, 9, 33993, 16, 10266, 341, 11, 7094, 9, 6592, 215, 25, 7134, 8, 24351, 116, 2, 2, 12579, 6673, 22586, 28340, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [20]:
tokenizer.decode(train_encodings.input_ids[0])

'<s>What type of organism is commonly used in preparation of foods such as cheese and yogurt?</s></s>mesophilic organisms</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'

In [10]:
def add_labels(encodings, distractors):
    # Tokenize the target strings and attach them as decoder labels;
    # padding=True pads each label sequence with tokenizer.pad_token_id
    distractors_encodings = tokenizer(distractors, padding=True)
    encodings["labels"] = distractors_encodings.input_ids
    return encodings

In [11]:
train_encodings = add_labels(train_encodings, train_label)
valid_encodings = add_labels(valid_encodings, valid_label)
test_encodings = add_labels(test_encodings, test_label)
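`add_labels` pads the label sequences with the tokenizer's pad id (1 for `facebook/bart-base`, visible as the trailing 1s in `input_ids` above). `DataCollatorForSeq2Seq` only uses `-100` for label padding it adds itself, and here the labels arrive already padded, so pad positions also contribute to the training loss. If you want the loss to ignore them, a common approach is to replace pad ids in the labels with `-100`. A minimal sketch on plain lists, so it runs without a tokenizer; `mask_pad_labels` is a hypothetical helper, and the results reported in this notebook were produced without it.

```python
def mask_pad_labels(label_ids, pad_token_id, ignore_index=-100):
    """Replace pad ids in each label sequence with ignore_index so that
    the cross-entropy loss skips those positions."""
    return [
        [tok if tok != pad_token_id else ignore_index for tok in seq]
        for seq in label_ids
    ]


# Toy label sequences using BART's special ids: 0 = <s>, 2 = </s>, 1 = <pad>
labels = [[0, 100, 200, 2, 1, 1], [0, 300, 2, 1, 1, 1]]
print(mask_pad_labels(labels, pad_token_id=1))
# [[0, 100, 200, 2, -100, -100], [0, 300, 2, -100, -100, -100]]
```

In the notebook this would correspond to `encodings["labels"] = mask_pad_labels(distractors_encodings.input_ids, tokenizer.pad_token_id)`; `compute_metrics` already maps `-100` back to the pad id before decoding.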

In [23]:
len(train_encodings.input_ids)

11679

In [12]:
class SciqDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

train_dataset = SciqDataset(train_encodings)
valid_dataset = SciqDataset(valid_encodings)
test_dataset = SciqDataset(test_encodings)

In [25]:
len(train_dataset), len(valid_dataset), len(test_dataset)

(11679, 1000, 1000)

### Fine-tuning

In [26]:
from transformers import BartForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
import torch

model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model.resize_token_embeddings(len(tokenizer))

Embedding(50265, 768, padding_idx=1)

In [27]:
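# Wrap the model in a ModuleDict under the key 'model' so its parameter names match
# the 'model.'-prefixed keys of the pre-training checkpoint (apparently a PyTorch Lightning .ckpt)
model_dict = torch.nn.ModuleDict({
    'model': model,
})
checkpoint = torch.load('/user_data/Cloze/dtt_mask_lm_model/bart/sciq_train_3dtt_passage_level_12/checkpoints/epoch=03-dev_loss=0.08.ckpt')
model_dict.load_state_dict(checkpoint['state_dict'])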

<All keys matched successfully>
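The `ModuleDict` wrapper works because every key in the checkpoint's `state_dict` presumably starts with `model.` (hence "All keys matched"). An equivalent approach, sketched here on a toy dictionary under that same assumption, is to strip the prefix and load straight into the BART model. `strip_prefix` is a hypothetical helper, not part of the original notebook.

```python
def strip_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict keeping only keys that start with prefix,
    with the prefix removed (e.g. 'model.model.shared.weight' -> 'model.shared.weight')."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


# Toy stand-in for checkpoint['state_dict']; real values would be tensors
toy_state = {"model.model.shared.weight": "W", "model.lm_head.weight": "H"}
print(strip_prefix(toy_state))
# {'model.shared.weight': 'W', 'lm_head.weight': 'H'}
```

With a real checkpoint this would be used as `model.load_state_dict(strip_prefix(checkpoint['state_dict']))`; keys without the prefix are dropped, so a strict load will report them as missing if the assumption does not hold.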

In [28]:
batch_size = 16
args = Seq2SeqTrainingArguments(
    output_dir = "./results-3",
    save_strategy = "epoch",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=50,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="P@1",
    weight_decay=0.01,
    predict_with_generate=True,
    eval_accumulation_steps = 1,
    report_to="wandb" if os.getenv("WANDB_PROJECT") else "none"
)

In [29]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [21]:
import numpy as np
def compute_metrics(p):
    predictions, labels = p
    
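    # Map any -100 padding in the generated ids to the pad token so batch_decode only sees valid ids
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)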
    
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Parse each decoded string into a list of distractors
    predicted = []
    true_label = []
    
    for k in range(len(decoded_labels)):
        pred = decoded_preds[k]
        label = decoded_labels[k]

        pred_list = pred.split(', ')
        label_list = label.split(', ')
        
        pred_list[0] = pred_list[0].split('are ')[-1]
        label_list[0] = label_list[0].split('are ')[-1]

        predicted.append(pred_list)
        true_label.append(label_list)

    # Per-question P@1, P@3, R@3 and F1@3 (exact string match), averaged over all questions
    p1 = 0
    p3 = 0
    r3 = 0
    f3 = 0
    for idx in range(len(true_label)):
        distractors = predicted[idx]
        labels = true_label[idx]

        act_set = set(labels)
        pred1_set = set(distractors[:1])
        pred3_set = set(distractors[:3])

        p_1 = len(act_set & pred1_set) / float(1)
        p_3 = len(act_set & pred3_set) / float(3)
        r_3 = len(act_set & pred3_set) / float(len(act_set))

        if p_3 == 0 and r_3 == 0:
            f1_3 = 0
        else:
            f1_3 = 2 * (p_3 * r_3 / (p_3 + r_3))

        p1+=p_1
        p3+=p_3
        r3+=r_3
        f3+=f1_3

    avg_p1 = p1 / len(true_label)
    avg_p3 = p3 / len(true_label)
    avg_r3 = r3 / len(true_label)
    avg_f3 = f3 / len(true_label)

    result = {'P@1': avg_p1,
              'P@3': avg_p3,
              'R@3': avg_r3,
              'F1@3': avg_f3}
    
    return result
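As a worked example of the per-question metrics above: P@1 checks whether the first generated distractor is one of the gold distractors, P@3 divides the number of matches among the first three by 3, R@3 divides it by the number of distinct gold distractors, and F1@3 is their harmonic mean. The sketch below recomputes them for one hypothetical prediction; `score_one` mirrors the loop body of `compute_metrics` and is not part of the original notebook.

```python
def score_one(pred, gold):
    """Per-question P@1, P@3, R@3, F1@3 as in compute_metrics (exact string match)."""
    act_set = set(gold)
    hits1 = len(act_set & set(pred[:1]))
    hits3 = len(act_set & set(pred[:3]))
    p_1 = hits1 / 1
    p_3 = hits3 / 3
    r_3 = hits3 / len(act_set)
    f1_3 = 0 if p_3 == 0 and r_3 == 0 else 2 * p_3 * r_3 / (p_3 + r_3)
    return {'P@1': p_1, 'P@3': p_3, 'R@3': r_3, 'F1@3': f1_3}


# Hypothetical prediction: 2 of 3 gold distractors recovered, but not in first place
gold = ['protozoa', 'gymnosperms', 'viruses']
pred = ['bacteria', 'viruses', 'protozoa']
print(score_one(pred, gold))
# P@1 = 0.0, while P@3, R@3 and F1@3 are all 2/3
```

Because matching is exact and case-sensitive, a prediction such as `'oxygen'` does not match the gold distractor `'Oxygen'`.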

In [31]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [32]:
trainer.train()

***** Running training *****
  Num examples = 11679
  Num Epochs = 50
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 18250
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: ms0004284. Use `wandb login --relogin` to force relogin




| Epoch | Training Loss | Validation Loss | P@1 | P@3 | R@3 | F1@3 |
|---|---|---|---|---|---|---|
| 1 | No log | 0.672532 | 0.113 | 0.091 | 0.090619 | 0.090644 |
| 2 | 0.555600 | 0.662672 | 0.155 | 0.103333 | 0.103119 | 0.103044 |
| 3 | 0.463800 | 0.652122 | 0.163 | 0.108 | 0.107619 | 0.107644 |
| 4 | 0.463800 | 0.652413 | 0.174 | 0.115667 | 0.115452 | 0.115378 |
| 5 | 0.429400 | 0.655532 | 0.176 | 0.117333 | 0.1175 | 0.117311 |
| 6 | 0.400700 | 0.657002 | 0.19 | 0.121 | 0.120786 | 0.120711 |
| 7 | 0.380000 | 0.665266 | 0.187 | 0.128333 | 0.128119 | 0.128044 |
| 8 | 0.380000 | 0.663973 | 0.178 | 0.129333 | 0.1295 | 0.129311 |
| 9 | 0.357900 | 0.674554 | 0.181 | 0.13 | 0.129786 | 0.129711 |
| 10 | 0.338200 | 0.679074 | 0.2 | 0.133 | 0.132786 | 0.132711 |


***** Running Evaluation *****
  Num examples = 1000
  Batch size = 16
Saving model checkpoint to ./results-3/checkpoint-365
Configuration saved in ./results-3/checkpoint-365/config.json
Model weights saved in ./results-3/checkpoint-365/pytorch_model.bin
tokenizer config file saved in ./results-3/checkpoint-365/tokenizer_config.json
Special tokens file saved in ./results-3/checkpoint-365/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 16
Saving model checkpoint to ./results-3/checkpoint-730
Configuration saved in ./results-3/checkpoint-730/config.json
Model weights saved in ./results-3/checkpoint-730/pytorch_model.bin
tokenizer config file saved in ./results-3/checkpoint-730/tokenizer_config.json
Special tokens file saved in ./results-3/checkpoint-730/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 16
Saving model checkpoint to ./results-3/checkpoint-1095
Configuration saved in ./results-3/checkpoin

TrainOutput(global_step=18250, training_loss=0.2197168498627127, metrics={'train_runtime': 8239.634, 'train_samples_per_second': 70.871, 'train_steps_per_second': 2.215, 'total_flos': 3.5814186809856e+16, 'train_loss': 0.2197168498627127, 'epoch': 50.0})
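The step counts in the log follow from the run configuration: with two GPUs and 16 examples per device, the effective batch size is 32, so one epoch is ceil(11679 / 32) = 365 steps (matching `checkpoint-365`, `checkpoint-730`, ...) and 50 epochs give 18250 steps. A quick check of that arithmetic, and of the throughput figures in `TrainOutput`:

```python
import math

num_examples = 11679
per_device_batch = 16
num_gpus = 2            # two TITAN RTX cards visible in nvidia-smi
epochs = 50
train_runtime = 8239.634  # seconds, from TrainOutput

total_batch = per_device_batch * num_gpus               # 32
steps_per_epoch = math.ceil(num_examples / total_batch)
total_steps = steps_per_epoch * epochs
print(total_batch, steps_per_epoch, total_steps)        # 32 365 18250

print(round(total_steps / train_runtime, 3))            # 2.215 steps/s
print(round(num_examples * epochs / train_runtime, 3))  # 70.871 samples/s
```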

In [33]:
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 1000
  Batch size = 16


{'eval_loss': 0.735399067401886,
 'eval_P@1': 0.21,
 'eval_P@3': 0.14833333333333337,
 'eval_R@3': 0.14866666666666672,
 'eval_F1@3': 0.14837777777777783,
 'eval_runtime': 31.7818,
 'eval_samples_per_second': 31.465,
 'eval_steps_per_second': 1.007,
 'epoch': 50.0}

In [34]:
trainer.save_model('/user_data/CTG/train/DG/Sciq/Bart_sciq_train_passage_level/sciq/bart-base-text2text-sciq-pretrain-on-sciq-train-passage-level-e3')

Saving model checkpoint to /user_data/CTG/train/DG/Sciq/Bart_sciq_train_passage_level/sciq/bart-base-text2text-sciq-pretrain-on-sciq-train-passage-level-e3
Configuration saved in /user_data/CTG/train/DG/Sciq/Bart_sciq_train_passage_level/sciq/bart-base-text2text-sciq-pretrain-on-sciq-train-passage-level-e3/config.json
Model weights saved in /user_data/CTG/train/DG/Sciq/Bart_sciq_train_passage_level/sciq/bart-base-text2text-sciq-pretrain-on-sciq-train-passage-level-e3/pytorch_model.bin
tokenizer config file saved in /user_data/CTG/train/DG/Sciq/Bart_sciq_train_passage_level/sciq/bart-base-text2text-sciq-pretrain-on-sciq-train-passage-level-e3/tokenizer_config.json
Special tokens file saved in /user_data/CTG/train/DG/Sciq/Bart_sciq_train_passage_level/sciq/bart-base-text2text-sciq-pretrain-on-sciq-train-passage-level-e3/special_tokens_map.json


In [35]:
predictions, labels, metrics = trainer.predict(valid_dataset)
print('valid: ')
metrics

***** Running Prediction *****
  Num examples = 1000
  Batch size = 16


valid: 


{'test_loss': 0.735399067401886,
 'test_P@1': 0.21,
 'test_P@3': 0.14833333333333337,
 'test_R@3': 0.14866666666666672,
 'test_F1@3': 0.14837777777777783,
 'test_runtime': 30.2363,
 'test_samples_per_second': 33.073,
 'test_steps_per_second': 1.058}

In [36]:
predictions, labels, metrics = trainer.predict(test_dataset)
print('test: ')
metrics

***** Running Prediction *****
  Num examples = 1000
  Batch size = 16


test: 


{'test_loss': 0.8131418228149414,
 'test_P@1': 0.215,
 'test_P@3': 0.16333333333333347,
 'test_R@3': 0.16358730158730167,
 'test_F1@3': 0.1633000000000001,
 'test_runtime': 29.1814,
 'test_samples_per_second': 34.268,
 'test_steps_per_second': 1.097}

In [37]:
stop  # undefined name used as a breakpoint: the NameError halts "Run All" here

NameError: name 'stop' is not defined

In [13]:
import json

def write_json(data, path):
    # Serialize `data` to a JSON file at `path`
    with open(path, "w") as f:
        json.dump(data, f)

In [14]:
def save_data(data, predictions, labels, file_name):
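    # Map any -100 padding in the generated ids to the pad token so batch_decode only sees valid ids
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)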
    
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Parse each decoded string into a list of distractors
    predicted = []
    true_label = []
    
    for k in range(len(decoded_labels)):
        pred = decoded_preds[k]
        label = decoded_labels[k]

        pred_list = pred.split(', ')
        label_list = label.split(', ')
        
        pred_list[0] = pred_list[0].split('are ')[-1]
        label_list[0] = label_list[0].split('are ')[-1]

        predicted.append(pred_list)
        true_label.append(label_list)
    
    
    # Store each question's predicted distractors and its P@1, P@3, R@3, F1@3
    for idx in range(len(true_label)):
        distractors = predicted[idx]
        labels = true_label[idx]
        
        data[idx]['pred_distractors'] = distractors

        act_set = set(labels)
        pred1_set = set(distractors[:1])
        pred3_set = set(distractors[:3])

        p_1 = len(act_set & pred1_set) / float(1)
        p_3 = len(act_set & pred3_set) / float(3)
        r_3 = len(act_set & pred3_set) / float(len(act_set))

        if p_3 == 0 and r_3 == 0:
            f1_3 = 0
        else:
            f1_3 = 2 * (p_3 * r_3 / (p_3 + r_3))
            
        data[idx]['metric'] = {'P@1': p_1, 'P@3': p_3, 'R@3': r_3, 'F1@3': f1_3}
        
    write_json(data, file_name)
    print(file_name + ' is saved :)')
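`compute_metrics` and `save_data` parse the decoded strings identically: split on `', '`, then keep only the text after the last `'are '` in the first item. A hypothetical helper `parse_distractors` (not in the original notebook) that both functions could share, shown on the label format produced by `processData`:

```python
def parse_distractors(text):
    """Split a decoded '_ of distractors are d1, d2, d3' string into a list,
    exactly as compute_metrics and save_data do."""
    items = text.split(', ')
    # Drop the prompt prefix: keep only what follows the last 'are ' in the first item
    items[0] = items[0].split('are ')[-1]
    return items


print(parse_distractors('_ of distractors are antioxidants, Oxygen, residues'))
# ['antioxidants', 'Oxygen', 'residues']
```

Because the prefix is removed by splitting on `'are '`, a first distractor that itself contains `'are '` (for example `'hardware parts'`) would be truncated to the text after it (`'parts'`); stripping the exact prefix with `str.startswith` would avoid that, at the cost of failing on generations that paraphrase the prefix.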

In [24]:
save_data(test, test_predictions, test_labels, '/user_data/CTG/train/DG/Sciq/Bart_sciq_train_passage_level/sciq_test_retsult/sciq_test_bart_pretrain_on_sciq_train_passage_level.json')

/user_data/CTG/train/DG/Sciq/Bart_sciq_train_passage_level/sciq_test_retsult/sciq_test_bart_pretrain_on_sciq_train_passage_level.json is saved :)


In [16]:
from transformers import BartTokenizer, BartForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
import torch

tokenizer = BartTokenizer.from_pretrained("/user_data/CTG/train/DG/Sciq/Bart_sciq_train_passage_level/sciq/bart-base-text2text-sciq-pretrain-on-sciq-train-passage-level-e3")
model = BartForConditionalGeneration.from_pretrained("/user_data/CTG/train/DG/Sciq/Bart_sciq_train_passage_level/sciq/bart-base-text2text-sciq-pretrain-on-sciq-train-passage-level-e3")

In [17]:
batch_size = 64
args = Seq2SeqTrainingArguments(
    output_dir = "./results",
    save_strategy = "epoch",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=50,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="P@1",
    weight_decay=0.01,
    predict_with_generate=True,
    eval_accumulation_steps = 1,
)

In [18]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [22]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [23]:
test_predictions, test_labels, test_metrics = trainer.predict(test_dataset)
test_metrics

***** Running Prediction *****
  Num examples = 1000
  Batch size = 64


{'test_loss': 0.8056509494781494,
 'test_P@1': 0.215,
 'test_P@3': 0.16333333333333347,
 'test_R@3': 0.16358730158730167,
 'test_F1@3': 0.1633000000000001,
 'test_runtime': 41.9636,
 'test_samples_per_second': 23.83,
 'test_steps_per_second': 0.191}

### Result

In [None]:
import json
def read_data(path):
    with open(path) as f:
        data = json.load(f)
    return data

In [None]:
test = read_data('/user_data/CTG/test_result/sciq_test_t5_text2text_pretrain_on_sciq_training_set_passage_level_600000.json')
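Since `save_data` stores a `metric` dict for every question, the corpus-level scores can be recovered from a saved JSON file by averaging, which is the same averaging `compute_metrics` performs. A minimal sketch on two hypothetical records; with a real file loaded via `read_data`, `records` would be `test`. `average_metrics` is a hypothetical helper.

```python
def average_metrics(records, keys=('P@1', 'P@3', 'R@3', 'F1@3')):
    """Average the per-question 'metric' dicts written by save_data."""
    n = len(records)
    return {k: sum(r['metric'][k] for r in records) / n for k in keys}


# Two hypothetical saved records
records = [
    {'metric': {'P@1': 1.0, 'P@3': 1 / 3, 'R@3': 1 / 3, 'F1@3': 1 / 3}},
    {'metric': {'P@1': 0.0, 'P@3': 0.0, 'R@3': 0.0, 'F1@3': 0}},
]
print(average_metrics(records))
# P@1 is 0.5; P@3, R@3 and F1@3 are each about 0.1667
```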

In [None]:
for i in range(0, 100, 7):
    example = test[i]
    sentence = example['question']
    answer = example['correct_answer']
    distractors = [example['distractor1'], example['distractor2'], example['distractor3']]
    pred_distractors = example['pred_distractors']
    metric = example['metric']
    
    print('question:', sentence.replace('**blank**', '_'))
    print('answer:', answer)
    print('distractors:', distractors)
    print('predict:', pred_distractors)
    print('metric:', metric)
    print()

question: Compounds that are capable of accepting electrons, such as o 2 or f2, are called what?
answer: oxidants
distractors: ['antioxidants', 'Oxygen', 'residues']
predict: ['oxides', 'carbonates', 'soils']
metric: {'P@1': 0.0, 'P@3': 0.0, 'R@3': 0.0, 'F1@3': 0}

question: Which type of tree is dominant in temperate forests?
answer: deciduous
distractors: ['vines', 'fungus', 'shrubs']
predict: ['perennial', 'conifer', 'annual']
metric: {'P@1': 0.0, 'P@3': 0.0, 'R@3': 0.0, 'F1@3': 0}

question: Only about one percent of plants have lost what ability, turning them into consumers and even predators, instead of producers?
answer: photosynthesis
distractors: ['flowering', 'rooting', 'growth']
predict: ['germination', 'death', 'reproduction']
metric: {'P@1': 0.0, 'P@3': 0.0, 'R@3': 0.0, 'F1@3': 0}

question: Presence of a cell wall, large central vacuole, and organelles called plastids distinguish what type of cell?
answer: plant
distractors: ['animal', 'reproductive', 'heterotroph']
predi