import numpy as np
import torch
from datasets import load_metric
from rouge import Rouge
from statistics import geometric_mean
from tqdm import tqdm
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast


def train_compute_metrics(pred):
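    # Score the argmax predictions with the pretrained base model and report
    # exp(mean NLL) as a perplexity metric, in the style of the Hugging Face
    # "Perplexity of fixed-length models" guide.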
    model = GPTNeoXForCausalLM.from_pretrained('nlpai-lab/kullm-polyglot-12.8b-v2')

    # Note: "logits" here are actually the argmax-decoded token ids, flattened
    # into a single sequence.
    logits = torch.tensor(pred.predictions.argmax(-1).flatten(), dtype=torch.int64)
    logits = logits.unsqueeze(0)  # torch.Size([1, 35200])

    max_length = model.config.max_position_embeddings  # 2048
    stride = 1024
    seq_len = logits.size(1)
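
    # Slide a window of up to `max_length` tokens across the sequence in steps
    # of `stride`, scoring only the tokens that are new to each window.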
    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may differ from `stride` on the last loop
        input_ids = logits[:, begin_loc:end_loc]
        target_ids = input_ids.clone()
        # Mask the overlap with the previous window so only new tokens are scored.
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)

            # The loss is a cross entropy averaged over all valid labels.
            # Since the model shifts the labels one position to the left
            # internally, the loss covers trg_len - 1 labels.
            neg_log_likelihood = outputs.loss

        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
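
    # Perplexity is the exponential of the mean negative log-likelihood.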
    ppl = torch.exp(torch.stack(nlls).mean())

    return {'perplexity': ppl}


def test_compute_metrics(pred):
    tokenizer = GPTNeoXTokenizerFast.from_pretrained('nlpai-lab/kullm-polyglot-12.8b-v2')
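
    # Report sacreBLEU, METEOR, ROUGE-L F1, and BERTScore F1, each aggregated
    # over test examples with a geometric mean.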

    # Load the metrics to use.
    metric_bleu = load_metric("sacrebleu")
    metric_meteor = load_metric("meteor")
    metric_rouge = Rouge(metrics=["rouge-1", "rouge-2", "rouge-3", "rouge-4", "rouge-5", "rouge-l"])
    metric_bertscore = load_metric("bertscore")

    # Split the training output pred into preds (model generations) and
    # labels (ground-truth references).
    preds = pred.predictions.argmax(-1)
    labels = pred.label_ids
    # Restore the -100 loss-mask positions to the pad token id so they decode cleanly.
    labels = np.where(pred.label_ids != -100, labels, tokenizer.pad_token_id)

    scores = {
        'sacre_bleu': [],
        'meteor': [],
        'rouge_l_f1': [],
        'bert_score_f1': [],
    }

    for i in range(len(preds)):
        decoded_preds = tokenizer.decode(preds[i], skip_special_tokens=True)
        decoded_labels = tokenizer.decode(labels[i], skip_special_tokens=True)
        # Keep only the text after the "### 응답:" ("### Response:") prompt
        # marker, dropping the trailing character.
        if "### 응답:" in decoded_preds:
            decoded_preds = decoded_preds.split('### 응답:\n')[1][:-1]

        # sacreBLEU expects a list of reference lists per prediction, hence the nesting.
        bleu_score = metric_bleu.compute(predictions=[decoded_preds], references=[[decoded_labels]])["score"]
        meteor_score = metric_meteor.compute(predictions=[decoded_preds], references=[decoded_labels])["meteor"]
        rouge_scores = metric_rouge.get_scores(decoded_preds, decoded_labels, avg=True)["rouge-l"]['f']
        bert_score = metric_bertscore.compute(predictions=[decoded_preds], references=[decoded_labels], lang='ko')["f1"][0]

        scores['sacre_bleu'].append(bleu_score / 100)  # sacreBLEU is on a 0-100 scale; rescale to 0-1
        scores['meteor'].append(meteor_score)
        scores['rouge_l_f1'].append(rouge_scores)
        scores['bert_score_f1'].append(bert_score)
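
    # Aggregate the per-example scores with a geometric mean. Caveat:
    # statistics.geometric_mean raises StatisticsError if any score is 0
    # (e.g. for an empty generation).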
    scores = {k: geometric_mean(v) for k, v in scores.items()}

    return {k: round(v, 5) for k, v in scores.items()}
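

# Usage sketch (an assumption, not part of the original code): both functions
# match the transformers.Trainer `compute_metrics` callback signature, which
# receives an EvalPrediction exposing `.predictions` and `.label_ids`, e.g.:
#
#   trainer = Trainer(
#       model=model,                        # hypothetical model/args/datasets
#       args=training_args,
#       train_dataset=train_dataset,
#       eval_dataset=eval_dataset,
#       compute_metrics=train_compute_metrics,  # or test_compute_metrics
#   )
#   metrics = trainer.evaluate()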