# In [None]:
# Default batch size, plus a collator that pads each batch dynamically and
# returns TensorFlow tensors.
batch_size = 32
collator = DataCollatorWithPadding(tokenizer = tokenizer, return_tensors = "tf")

def _chunk_spans(token_ids, idx_max_len, special_ids = (0, 2, 32000)) :
    """Return inclusive ``[start, end]`` spans of consecutive non-special tokens.

    Positions whose id is in ``special_ids`` split the sequence into runs.
    Each run becomes one ``[start, end]`` pair whose ``end`` is inclusive.
    The list is padded with ``[-1, -1]`` rows up to ``idx_max_len`` rows.
    It is not truncated when there are more runs than that.
    """
    ids = np.asarray(token_ids)
    pos = np.flatnonzero(~np.isin(ids, special_ids))
    spans = []
    if pos.size :
        # A run ends wherever two kept positions are not adjacent.
        breaks = np.flatnonzero(np.diff(pos) != 1)
        starts = pos[np.concatenate(([0], breaks + 1))]
        ends = pos[np.concatenate((breaks, [pos.size - 1]))]
        spans = [[int(s), int(e)] for s, e in zip(starts, ends)]
    # Build distinct padding rows (no shared list objects).
    spans += [[-1, -1] for _ in range(idx_max_len - len(spans))]
    return spans


def _pad_syntax(gov_lists, tag_lists, max_len) :
    """Pad governor offsets with -1 and tags with "padding", then label-encode the tags."""
    gov_out, tag_out = [], []
    for gov, tags in zip(gov_lists, tag_lists) :
        gov_out.append(list(gov) + [-1] * (max_len - len(gov)))
        tag_out.append(tag_labelr.transform(list(tags) + ["padding"] * (max_len - len(tags))))
    return gov_out, tag_out


def tokenizing(inputs, prem_max_len, hypo_max_len, training, special_ids = (0, 2, 32000)) :
    """Tokenize chunked premise/hypothesis pairs and attach padded syntax features.

    Used as a batched ``datasets.Dataset.map`` function (see ``get_dataset``).

    Args:
        inputs: batch dict with columns ``premise_chunk``, ``hypo_chunk``,
            ``premise_rel_gov_idx``, ``premise_chunk_tag``,
            ``hypo_rel_gov_idx``, ``hypo_chunk_tag`` and, when training,
            ``label``.
        prem_max_len: length the premise governor/tag lists are padded to
            (longer lists are left as-is).
        hypo_max_len: same, for the hypothesis lists.
        training: when true, ``inputs["label"]`` is copied to ``labels``.
        special_ids: token ids excluded from chunk spans. The default keeps the
            previously hard-coded ids 0, 2 and 32000. These are presumably
            [CLS], [SEP] and the added [WORD] marker for this tokenizer;
            confirm this.

    Returns:
        The tokenizer output, extended with ``chunk_index``, ``prem_gov_idx``,
        ``prem_tag``, ``hypo_gov_idx``, ``hypo_tag`` and (when training)
        ``labels``.

    Relies on the module-level ``tokenizer``, ``tag_labelr`` and ``idx_max_len``.
    """
    model_inputs = tokenizer(inputs["premise_chunk"], inputs["hypo_chunk"])

    # One list of [start, end] (inclusive) chunk spans per example.
    model_inputs["chunk_index"] = [_chunk_spans(ids, idx_max_len, special_ids)
                                   for ids in model_inputs.input_ids]

    model_inputs["prem_gov_idx"], model_inputs["prem_tag"] = _pad_syntax(
        inputs["premise_rel_gov_idx"], inputs["premise_chunk_tag"], prem_max_len)
    model_inputs["hypo_gov_idx"], model_inputs["hypo_tag"] = _pad_syntax(
        inputs["hypo_rel_gov_idx"], inputs["hypo_chunk_tag"], hypo_max_len)

    if training :
        model_inputs["labels"] = inputs["label"]

    return model_inputs

def get_dataset(inputs, collator, batch_size, idx_max_len, training) :
    """Build a batched ``tf.data.Dataset`` from a pandas DataFrame of chunked pairs.

    Args:
        inputs: DataFrame with the columns ``tokenizing`` reads.
        collator: collate function used to pad each batch.
        batch_size: examples per batch.
        idx_max_len: currently unused here. ``tokenizing`` reads the
            module-level ``idx_max_len`` instead.
        training: when true, labels are attached and batches are shuffled.
            Otherwise the original row order is preserved, so predictions
            line up with ``inputs``.

    Returns:
        A ``tf.data.Dataset`` of feature dicts. In training mode it yields
        ``(features, labels)`` pairs.

    Uses the module-level ``premise_max_len`` and ``hypo_max_len``.
    """
    dataset = datasets.Dataset.from_pandas(inputs)
    tokenized_inputs = dataset.map(tokenizing,
                                   batched = True,
                                   fn_kwargs = {"training" : training,
                                                "prem_max_len" : premise_max_len,
                                                "hypo_max_len" : hypo_max_len})

    columns = ["input_ids", "attention_mask", "token_type_ids", "chunk_index", "prem_gov_idx", "hypo_gov_idx", "prem_tag", "hypo_tag"]

    return tokenized_inputs.to_tf_dataset(
        batch_size = batch_size,
        columns = columns,
        # Shuffle training data only. Inference output must stay aligned with the input rows.
        shuffle = bool(training),
        collate_fn = collator,
        label_cols = "labels" if training else None,
        drop_remainder = False
    )




def chunk_mean(premise_len, hypo_len, x, chunk_idx, sep_idx) :
    """Average token embeddings over each chunk span and split them by sentence.

    Args:
        premise_len: number of rows each example's premise output is
            zero-padded to.
        hypo_len: number of rows each example's hypothesis output is
            zero-padded to.
        x: token embeddings indexed as ``x[batch, position, :]``. Presumably
            this is the backbone's last hidden state, shaped
            (batch, seq, hidden); TODO confirm.
        chunk_idx: per-example ``[start, end]`` spans with an *inclusive* end.
            Padding rows are ``[-1, -1]``.
        sep_idx: coordinates whose column 1 is the position of the first
            separator in each row. Presumably shaped (batch, 2); confirm.
            A span ending before that position is a premise chunk; all others
            are hypothesis chunks.

    Returns:
        ``(premise, hypothesis)`` stacked over the batch. Each example is
        zero-padded to ``premise_len`` / ``hypo_len`` rows. Examples with
        more chunks than that are not truncated, and the final stack fails.
    """
    premise = tf.TensorArray(tf.float32, size = 0, dynamic_size = True)
    hypothesis = tf.TensorArray(tf.float32, size = 0, dynamic_size = True)
    for batch in range(len(chunk_idx)) :
        temp_premise = tf.TensorArray(tf.float32, size = 0, dynamic_size = True)
        temp_hypothesis = tf.TensorArray(tf.float32, size = 0, dynamic_size = True)
        for idx in chunk_idx[batch] :
            # Padding rows [-1, -1] sum to a negative value; real spans are non-negative.
            if tf.reduce_sum(idx) <= 0 :
                continue
            # The span end is inclusive, so slice one past it. This also covers single-token spans.
            chunk_vec = tf.reduce_mean(x[batch, idx[0] : idx[1] + 1, :], axis = 0)
            if sep_idx[batch, 1] > tf.cast(idx[1], tf.int64) :
                temp_premise = temp_premise.write(temp_premise.size(), chunk_vec)
            else :
                temp_hypothesis = temp_hypothesis.write(temp_hypothesis.size(), chunk_vec)

        # Pad every example, not only those with padding rows, so the batch stacks rectangularly.
        pad = tf.zeros_like(x[batch, 0, :])
        for _ in range(premise_len - temp_premise.size()) :
            temp_premise = temp_premise.write(temp_premise.size(), pad)
        for _ in range(hypo_len - temp_hypothesis.size()) :
            temp_hypothesis = temp_hypothesis.write(temp_hypothesis.size(), pad)

        premise = premise.write(premise.size(), temp_premise.stack())
        hypothesis = hypothesis.write(hypothesis.size(), temp_hypothesis.stack())
    return premise.stack(), hypothesis.stack()
        
def syntax_struct_info_vec(w, gov_idx, tag) :
    """Combine each chunk vector with its governor's vector and its tag embedding.

    For row ``b`` and chunk ``p`` with offset ``o = gov_idx[b][p]``:

    * if ``o >= 0``, the output is ``w[b][p] + w[b][p + o] + tag[b][p]``;
    * if ``o < 0`` (the -1 padding value), the output is a zero vector.

    Note that any genuine negative offset is treated as padding too.
    Results are stacked per row and then over the batch.
    """
    rows = tf.TensorArray(tf.float32, size = 0, dynamic_size = True)
    for b in range(len(gov_idx)) :
        w_b, gov_b, tag_b = w[b], gov_idx[b], tag[b]
        row = tf.TensorArray(tf.float32, size = 0, dynamic_size = True)
        for pos in range(len(gov_b)) :
            offset = gov_b[pos]
            if offset < 0 :
                vec = tf.zeros_like(w_b[pos])
            else :
                vec = w_b[pos] + w_b[offset + pos] + tag_b[pos]
            row = row.write(row.size(), vec)
        rows = rows.write(rows.size(), row.stack())
    return rows.stack()

        
class SyntaxStructureEmbedding(keras.layers.Layer) :
    """Embed integer syntax-tag ids into dense vectors.

    Args:
        input_dim: size of the tag vocabulary. Every id passed to ``call``
            must be smaller than this.
        output_dim: embedding width (768 by default, as before).
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """
    def __init__(self, input_dim, output_dim = 768, **kwargs) :
        super(SyntaxStructureEmbedding, self).__init__(**kwargs)
        self.embedder = keras.layers.Embedding(input_dim = input_dim, output_dim = output_dim)

    def call(self, x) :
        return self.embedder(x)
    

class InfoConnectingLayer(keras.layers.Layer) :
    """Fuse chunk vectors ``w`` with syntax vectors ``t`` and encode them with a BiLSTM.

    Computes ``b = (U t) @ w + V_w w + V_t t + bias`` and returns the BiLSTM's
    final output (forward and backward states, concatenated by default).

    Args:
        input_shape: number of chunk rows (the padded max length). It is the
            width of ``U`` and the row count of ``bias``.
        rnn_size: units per LSTM direction.
        hidden_size: feature width of ``w`` / ``t`` (768 by default, as before).
        **kwargs: forwarded to ``keras.layers.Layer``.

    NOTE: ``V_t`` is applied to the raw ``t``, so its kernel is
    (hidden_size, hidden_size). Checkpoints saved while ``V_t`` was applied to
    ``U t`` have a different kernel shape and will not load.
    """
    def __init__(self, input_shape, rnn_size, hidden_size = 768, **kwargs) :
        super(InfoConnectingLayer, self).__init__(**kwargs)

        self.bias_shape = input_shape
        self.hidden_size = hidden_size
        self.u = keras.layers.Dense(self.bias_shape, activation = None, use_bias = False)
        self.v_w = keras.layers.Dense(hidden_size, activation = None, use_bias = False)
        self.v_t = keras.layers.Dense(hidden_size, activation = None, use_bias = False)

        self.bi_lstm = keras.layers.Bidirectional(keras.layers.LSTM(rnn_size))

    def build(self, input_shape) :
        self.bias = self.add_weight(name = "bias",
                                    shape = (self.bias_shape, self.hidden_size))
        super(InfoConnectingLayer, self).build(input_shape)

    def call(self, w, t) :
        # Keep `t` unchanged: the V_t term uses the raw syntax vectors, not U t.
        t_proj = self.u(t)
        term_1 = tf.matmul(t_proj, w)
        term_2 = self.v_w(w)
        term_3 = self.v_t(t)

        b = term_1 + term_2 + term_3 + self.bias

        return self.bi_lstm(b)
    
class ClassifierLayer(keras.layers.Layer) :
    """Score a (premise, hypothesis) encoding pair and return class probabilities.

    Args:
        rnn_size: width of ``h_prem`` / ``h_hypo`` (assumed equal; TODO confirm).
        n_class: number of output classes (softmax outputs, no bias).
        **kwargs: forwarded to ``keras.layers.Layer``.
    """
    def __init__(self, rnn_size, n_class, **kwargs) :
        super(ClassifierLayer, self).__init__(**kwargs)
        self.u = keras.layers.Dense(rnn_size, activation = None, use_bias = False)
        self.v_prem = keras.layers.Dense(rnn_size, activation = None, use_bias = False)
        self.v_hypo = keras.layers.Dense(rnn_size, activation = None, use_bias = False)
        # Created once here instead of building a new Flatten layer on every call.
        self.flatten = keras.layers.Flatten()
        self.outputs = keras.layers.Dense(n_class, activation = "softmax", use_bias = False)

    def call(self, h_prem, h_hypo) :
        # Interaction term: u(h_prem[..., None]) @ h_hypo[..., None], flattened back to 2-D.
        term_1 = self.u(tf.expand_dims(h_prem, 2))
        term_1 = tf.matmul(term_1, tf.expand_dims(h_hypo, 2))
        term_1 = self.flatten(term_1)
        term_2 = self.v_prem(h_prem)
        term_3 = self.v_hypo(h_hypo)
        l = term_1 + term_2 + term_3
        return self.outputs(l)
    
class SyntaxRoBERTa(keras.models.Model) :
    """RoBERTa backbone plus chunk-level syntax features for premise/hypothesis classification.

    Args:
        backbone: transformers TF model whose output has ``last_hidden_state``.
        rnn_size: LSTM units per direction in the info-connecting layers.
        prem_max_len: padded number of premise chunks.
        hypo_max_len: padded number of hypothesis chunks.
        n_tags: size of the syntax-tag vocabulary. If omitted, it is taken
            from the module-level ``tag_labelr.classes_`` (assumed to be a
            fitted sklearn LabelEncoder; confirm). When that is unavailable,
            the old max-length sizing is used.
        n_class: number of output classes (3 by default, as before).
        **kwargs: forwarded to ``keras.models.Model``.
    """
    def __init__(self, backbone, rnn_size, prem_max_len, hypo_max_len, n_tags = None, n_class = 3, **kwargs) :
        super(SyntaxRoBERTa, self).__init__(**kwargs)

        if n_tags is None :
            # Tag ids come from tag_labelr.transform, so the embedding vocabulary
            # must cover every tag class, not the sequence length.
            try :
                n_tags = len(tag_labelr.classes_)
            except (NameError, AttributeError) :
                n_tags = None

        self.backbone = backbone
        self.prem_tag_embedder = SyntaxStructureEmbedding(n_tags or prem_max_len)
        self.hypo_tag_embedder = SyntaxStructureEmbedding(n_tags or hypo_max_len)
        self.premise_info = InfoConnectingLayer(prem_max_len, rnn_size)
        self.hypo_info = InfoConnectingLayer(hypo_max_len, rnn_size)
        # The BiLSTM concatenates both directions, hence 2 * rnn_size.
        self.classifier = ClassifierLayer(rnn_size * 2, n_class)

        self.prem_len = prem_max_len
        self.hypo_len = hypo_max_len

    def call(self, x) :
        backbone_input = {k : v for k, v in x.items() if k in ["input_ids", "token_type_ids", "attention_mask"]}
        input_ids = x["input_ids"]
        chunk_idx = x["chunk_index"]
        prem_gov_idx = x["prem_gov_idx"]
        hypo_gov_idx = x["hypo_gov_idx"]

        prem_tag_emb = self.prem_tag_embedder(x["prem_tag"])
        hypo_tag_emb = self.hypo_tag_embedder(x["hypo_tag"])

        seq_emb = self.backbone(backbone_input).last_hidden_state

        # (row, position) of the first separator (id 2) in each row. This does not
        # depend on every row having exactly two separators.
        is_sep = tf.cast(tf.equal(input_ids, 2), tf.int64)
        first_sep = tf.argmax(is_sep, axis = 1)
        rows = tf.range(tf.shape(input_ids, out_type = tf.int64)[0])
        sep_idx = tf.stack([rows, first_sep], axis = 1)

        p, h = chunk_mean(self.prem_len, self.hypo_len, seq_emb, chunk_idx, sep_idx)

        p_t = syntax_struct_info_vec(p, prem_gov_idx, prem_tag_emb)
        h_t = syntax_struct_info_vec(h, hypo_gov_idx, hypo_tag_emb)

        p_i = self.premise_info(p, p_t)
        h_i = self.hypo_info(h, h_t)
        return self.classifier(p_i, h_i)
    
batch_size = 2

train_data = get_dataset(chunk_data, collator, batch_size, idx_max_len, True)
# Cast every feature and the labels to int32 before feeding the model.
train_data = train_data.map(lambda x, y : ({k : tf.cast(v, tf.int32) for k, v in x.items()}, tf.cast(y, tf.int32)),
                            num_parallel_calls = tf.data.experimental.AUTOTUNE).prefetch(tf.data.experimental.AUTOTUNE)


roberta = TFRobertaModel.from_pretrained("klue/roberta-base", from_pt = True)
# Match the embedding matrix to the tokenizer's vocabulary, which presumably
# includes added tokens such as the chunk marker (confirm).
roberta.resize_token_embeddings(len(tokenizer))

model = SyntaxRoBERTa(roberta, 64, premise_max_len, hypo_max_len)
loss_fn = keras.losses.sparse_categorical_crossentropy
epochs = 5

optimizer = keras.optimizers.Adam(learning_rate = 1e-5)

total_loss = []   # one list of per-batch losses per epoch
total_acc = []    # one list of per-batch accuracies per epoch

################## Training loop ###################

for epoch in range(epochs) :
    # Sliding windows over the last 20 batches, used for the progress-bar readout.
    cum_loss = deque(maxlen = 20)
    cum_acc = deque(maxlen = 20)

    batch_loss = []
    batch_acc = []

    with tqdm(train_data, unit = "batch") as tepoch :
        tepoch.set_description(f"Epoch {epoch}")
        for x, y in tepoch :
            with tf.GradientTape() as tape :
                y_hat = model(x)
                # Mean over the batch, so the objective's scale does not depend on batch size.
                loss = tf.reduce_mean(loss_fn(y, y_hat))
            grads = tape.gradient(loss, model.trainable_weights,
                                  unconnected_gradients = tf.UnconnectedGradients.ZERO)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))

            curr_loss = float(loss)
            # y holds integer class ids, so use the sparse metric (per example) and average it.
            curr_acc = float(tf.reduce_mean(keras.metrics.sparse_categorical_accuracy(y, y_hat)))

            batch_loss.append(curr_loss)
            batch_acc.append(curr_acc)

            cum_loss.append(curr_loss)
            cum_acc.append(curr_acc)

            tepoch.set_postfix(loss = sum(cum_loss) / len(cum_loss),
                               accuracy = sum(cum_acc) / len(cum_acc))
    total_loss.append(batch_loss)
    total_acc.append(batch_acc)

# In [None]:
parser = Parser(API.KKMA)

# Dependency types whose adjacent words may be merged into one chunk.
_MERGE_DEP_TYPES = ("CMP", "MOD", "AJT")


def _edge_tag(edge) :
    """Return "TYPE-DEPTYPE" for a governor edge, or just "TYPE" when it has no dependency type."""
    return edge.type + '-' + edge.depType if edge.depType else edge.type


def make_chunk_pair(parser, sentence, chunk_token = "[WORD]") :
    """Split a parsed sentence into word chunks and record each chunk's syntax info.

    Two adjacent words stay in the same chunk in either case:

    * both dependency types are CMP/MOD/AJT;
    * both phrase types are equal and not "X".

    Otherwise ``chunk_token`` is inserted after the first word. Only the last
    word of each chunk contributes a tag and a governor entry.

    Args:
        parser: callable returning parsed sentences with ``.words``. Each word
            exposes ``.surface`` and ``.governorEdge``, with fields ``type``,
            ``depType``, ``src`` and ``dest``.
        sentence: raw text to parse.
        chunk_token: marker appended after each chunk.

    Returns:
        A dict with four keys:

        - "chunk": flat list of word surfaces, with ``chunk_token`` after every chunk.
        - "chunk_tag": one tag string per chunk.
        - "gov_idx": per chunk, ``src.id`` + separators before that word +
          tokens of previous sentences. Presumably this is the governor word's
          position in "chunk", assuming word ids are 0-based positions; confirm.
          It is 0 when there is no governor.
        - "rel_gov_idx": per chunk, the chunk number of the edge's src word
          minus that of its dest word. It is 0 when there is no governor.
    """
    analyzed = parser(sentence)

    chunk = []
    tag_info = []
    gov_id = []
    chunk_length = []
    rel_gov_idx = []

    for an in analyzed :
        words = an.words
        tmp_chunk = []
        tmp_gov_id = []
        tmp_idx = []

        for f, s in zip(words[:-1], words[1:]) :
            f_edge, s_edge = f.governorEdge, s.governorEdge
            same_chunk = ((f_edge.depType in _MERGE_DEP_TYPES and s_edge.depType in _MERGE_DEP_TYPES)
                          or (f_edge.type == s_edge.type and (f_edge.type != 'X' or s_edge.type != 'X')))
            tmp_chunk.append(f.surface)
            if not same_chunk :
                tmp_chunk.append(chunk_token)
                tmp_gov_id.append(f_edge if f_edge.src else 0)
                tag_info.append(_edge_tag(f_edge))

        # The last word always closes a chunk. Use the same depType guard as above,
        # so a missing depType no longer raises a TypeError.
        last = words[-1]
        tmp_chunk.append(last.surface)
        tmp_chunk.append(chunk_token)
        tag_info.append(_edge_tag(last.governorEdge))
        # A single-word sentence is always recorded as having no governor.
        tmp_gov_id.append(last.governorEdge if len(words) > 1 and last.governorEdge.src else 0)

        # Chunk number (separators seen so far) of each real word, in word order.
        chunk_no = np.cumsum(np.array(tmp_chunk) == chunk_token).tolist()
        tmp_words = [n for n, tok in zip(chunk_no, tmp_chunk) if tok != chunk_token]

        chunk_length.append(len(chunk[-1]) if chunk else 0)
        chunk.append(tmp_chunk)
        offset = sum(chunk_length)   # tokens in all previous sentences

        # Iterate over a snapshot, because tmp_gov_id is overwritten in place.
        for pos, edge in enumerate(list(tmp_gov_id)) :
            if edge != 0 :
                head = edge.src.id
                tmp_gov_id[pos] = head + tmp_words[head] + offset
                tmp_idx.append(tmp_words[head] - tmp_words[edge.dest.id])
            else :
                tmp_idx.append(0)
        gov_id.append(tmp_gov_id)
        rel_gov_idx.append(tmp_idx)

    result = {"chunk" : [c for sent in chunk for c in sent],
              "chunk_tag" : tag_info,
              "gov_idx" : [g for sent in gov_id for g in sent],
              "rel_gov_idx" : [g for sent in rel_gov_idx for g in sent]}

    return result

# In [None]:
# Scratch check on a single premise: parse it into chunks, then drop chunk
# separators whose preceding token contains none of the Okt morphemes.
from konlpy.tag import Okt

sentence = original.premise[41]
chunk_token = '[WORD]'

# Reuse make_chunk_pair instead of repeating its body here.
parsed = make_chunk_pair(parser, sentence, chunk_token)
chunk = parsed["chunk"]
tag_info = parsed["chunk_tag"]
gov_id = parsed["gov_idx"]
rel_gov_idx = parsed["rel_gov_idx"]

##### Added #####
tagger = Okt()
okt = tagger.morphs(sentence)

ori_list = np.where(np.array(chunk) == chunk_token)[0]   # positions of separators in `chunk`
res_list = ori_list - 1                                   # token right before each separator

chunk_mark = []   # separator positions to delete from `chunk`
idx_mark = []     # chunk numbers whose tag / rel_gov_idx entries are deleted
cur_n_word = 0

for j, tok_pos in enumerate(res_list) :
    cur_n_word -= 1

    if cur_n_word < 1 :
        # NOTE(review): `list_test` is not defined in this cell and must come from
        # elsewhere in the notebook. Confirm its purpose or remove this call.
        list_test.pop()

        cur_idx = j
        cur_n_word = rel_gov_idx[j]

    if not any(morph in chunk[tok_pos] for morph in okt) :
        chunk_mark.append(int(ori_list[j]))
        idx_mark.append(j)

        rel_gov_idx[cur_idx] -= rel_gov_idx[j]

drop_chunks = set(idx_mark)
drop_positions = set(chunk_mark)
rel_gov_idx = [g for i, g in enumerate(rel_gov_idx) if i not in drop_chunks]
tag_info = [t for i, t in enumerate(tag_info) if i not in drop_chunks]
chunk = [c for i, c in enumerate(chunk) if i not in drop_positions]

# NOTE(review): this extends `idx_mark` only after it has already been applied
# above. Confirm whether it belongs earlier.
idx_mark.extend(l[0] for l in list_test)
