-
Notifications
You must be signed in to change notification settings - Fork 6
/
model-full.py
96 lines (78 loc) · 3.3 KB
/
model-full.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pandas as pd
import os
import random
import numpy as np
import torch
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger, EarlyStopping
import transformers
from transformers import TFAutoModel, AutoTokenizer
from tqdm.notebook import tqdm
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors, SentencePieceBPETokenizer
from tensorflow.keras import backend as K
def set_seed(seed=1):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy, PyTorch (CPU and all CUDA
    devices) and TensorFlow, exports ``PYTHONHASHSEED`` and switches the
    PyTorch cuDNN backend to deterministic mode.

    Args:
        seed: Integer seed applied to all generators. Defaults to 1.
    """
    # NOTE: hash randomization of the running interpreter is fixed at
    # start-up; exporting the variable here only reaches processes that are
    # launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    # The model in this file is built and trained with tf.keras, so the
    # TensorFlow global generator (weight init, dropout, shuffling) must be
    # seeded too; seeding only PyTorch leaves training runs irreproducible.
    tf.random.set_seed(seed)
def regular_encode(texts, tokenizer, maxlen=512):
    """Tokenize a batch of texts and return their token ids as an array.

    Args:
        texts: Batch of input texts, handed to the tokenizer unchanged.
        tokenizer: Object exposing ``batch_encode_plus`` (presumably a
            Hugging Face tokenizer -- confirm against the caller).
        maxlen: Forwarded as ``max_length``; padding up to it is requested.

    Returns:
        ``np.ndarray`` built from the ``'input_ids'`` entry of the encoding.
        Attention masks and token type ids are not requested.
    """
    encode_options = {
        'return_attention_masks': False,
        'return_token_type_ids': False,
        'pad_to_max_length': True,
        'max_length': maxlen,
    }
    encoded = tokenizer.batch_encode_plus(texts, **encode_options)
    token_ids = encoded['input_ids']
    return np.array(token_ids)
def f1(y_true, y_pred):
    """Batch-wise F1 score built from Keras backend ops.

    Both tensors are clipped to [0, 1] and rounded before counting, so the
    counts are taken over 0/1 values. Precision and recall are computed on
    the tensors passed in (a single batch when used as a Keras metric) and
    combined as ``2 * P * R / (P + R)``; ``K.epsilon()`` is added to every
    denominator to avoid division by zero.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted scores.

    Returns:
        Scalar tensor holding the F1 value.
    """
    eps = K.epsilon()

    def _count_ones(tensor):
        # Clip into [0, 1], round to 0/1, then sum over all elements.
        return K.sum(K.round(K.clip(tensor, 0, 1)))

    hits = _count_ones(y_true * y_pred)       # true positives
    predicted_positives = _count_ones(y_pred)
    actual_positives = _count_ones(y_true)

    precision_value = hits / (predicted_positives + eps)
    recall_value = hits / (actual_positives + eps)

    harmonic = (precision_value * recall_value) / (precision_value + recall_value + eps)
    return 2 * harmonic
def build_model(transformer, learning_rate=1e-5, max_len=512):
    """Build and compile a binary classifier on top of ``transformer``.

    The network feeds integer token ids through ``transformer``, keeps the
    vector at sequence position 0 of its first output, applies dropout and
    ends in a single sigmoid unit.

    Args:
        transformer: Callable Keras-compatible encoder whose first output is
            indexed as ``[:, 0, :]`` (presumably a ``TFAutoModel`` returning
            per-token hidden states -- confirm against the caller).
        learning_rate: Learning rate for the Adam optimizer.
        max_len: Length of the ``input_word_ids`` input.

    Returns:
        A compiled ``Model`` using binary cross-entropy loss and reporting
        the ``f1`` and ``accuracy`` metrics.
    """
    input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
    sequence_output = transformer(input_word_ids)[0]
    # Position 0 is used as the sequence summary (named for the [CLS] token).
    cls_token = sequence_output[:, 0, :]
    cls_token = Dropout(0.2)(cls_token)
    out = Dense(1, activation='sigmoid')(cls_token)

    model = Model(inputs=input_word_ids, outputs=out)
    # `lr` is a deprecated alias that newer Keras releases reject;
    # `learning_rate` is the supported keyword.
    model.compile(
        Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=[f1, 'accuracy'],
    )
    return model
def callback():
    """Return the list of Keras callbacks used during training.

    The list contains, in order:
      1. ``ReduceLROnPlateau`` -- halves the learning rate when ``val_loss``
         fails to improve by at least 0.0001 (patience 0).
      2. ``CSVLogger`` -- writes per-epoch logs to ``log.csv``.
      3. ``EarlyStopping`` -- stops after 4 epochs without ``val_loss``
         improvement and restores the best weights.

    Returns:
        list: The three callback instances described above.
    """
    # `epsilon` is the deprecated name of `min_delta`, and
    # `restore_best_weights` is an EarlyStopping option that is not a
    # documented ReduceLROnPlateau parameter, so it is passed only to
    # EarlyStopping below.
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=0,
        verbose=1,
        mode='min',
        min_delta=0.0001,
        min_lr=0,
    )
    csv_log = CSVLogger('log.csv')
    early_stop = EarlyStopping(
        monitor='val_loss',
        patience=4,
        verbose=0,
        mode='min',
        restore_best_weights=True,
    )
    return [reduce_lr, csv_log, early_stop]