In [1]:
import os

os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#os.environ["CUDA_LAUNCH_BLOCKING"] = '1'
os.environ['WANDB_DISABLED'] = 'True'

In [2]:
import torch
import re
import random
from transformers import ElectraForTokenClassification, AutoConfig, AutoTokenizer
from transformers import AdamW
import pandas as pd
from torch.utils.data import Dataset as DS
from datasets import Dataset
from transformers import DefaultDataCollator
from transformers import Trainer, TrainingArguments
from seqeval.metrics import classification_report
from sklearn.metrics import f1_score
import numpy as np
from datasets import load_metric

In [3]:
from seqeval.scheme import IOB1, IOB2, IOE1, IOE2, IOBES, BILOU, Entities, Prefix, Tag
from seqeval.scheme import IOBES

In [4]:
permutations = [['ORG', 'PER', 'CVL', 'DAT', 'LOC', 'QNT'],
               ['DAT', 'QNT', 'PER', 'LOC', 'ORG', 'CVL'],
               ['CVL', 'LOC', 'ORG', 'QNT', 'DAT', 'PER'],
               ['QNT', 'ORG', 'DAT', 'PER', 'CVL', 'LOC'],
               ['LOC', 'CVL', 'QNT', 'ORG', 'PER', 'DAT'],
               ['PER', 'DAT', 'LOC', 'CVL', 'QNT', 'ORG']]

In [5]:
class NERDataset(DS):
    def __init__(self, dataset, tokenizer, label2id, max_length):
        super().__init__()
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_length = max_length
        self.dataset = dataset
        self.data = []
        
        for i in range(len(self.dataset['label'])):
            text = dataset['label'][i]
            tagged_words = re.findall('<.*?:.*?>', text)
            
            word2ids = dict()
            for tagged_word in tagged_words:
                tag_splited = tagged_word.strip('<>').split(':')
                tag = tag_splited[-1]
                word = ':'.join(tag_splited[:-1])
                    
                word_tok = self.tokenizer.encode(word)[1:-1]
                if word not in word2ids:

                    label_id = [self.label2id['B-'+tag]]
                    if len(word_tok) > 1 :
                        label_id.extend([self.label2id['I-'+tag]] * (len(word_tok)-1))
                    word2ids[word] = {
                        'target_ids': word_tok,
                        'label_id': label_id 
                    }
                    text = text.replace(tagged_word, word)

            tokenized = self.tokenizer(text, truncation=True, max_length=self.max_length, padding='max_length')
            if 0 in tokenized['input_ids']:
                tok_length = tokenized['input_ids'].index(0)
            else:
                tok_length = self.max_length
            label_input = tokenized['input_ids'][:tok_length]
            labels = self._gen_labels(label_input, word2ids)
            labels.insert(0, -100)
            pad = [-100] * (self.max_length - len(labels))
            labels.extend(pad)

            temp = {
                'input_ids' : tokenized['input_ids'],
                'attention_mask' : tokenized['attention_mask'],
                'labels' : labels
            }
            
            self.data.append(temp)
        
    def _gen_labels(self, input_ids, word2ids):
        sequence = input_ids[1:-1]
        labels = [0] * len(sequence)
        
        for v in word2ids.values():
            target_ids = v['target_ids']
            label_id = v['label_id']
            
            i=0
            target_ids_length = len(target_ids)
            
            while i < len(sequence):
                if sequence[i:i + target_ids_length] == target_ids:
                    labels[i:i + target_ids_length] = label_id
                    i = i + target_ids_length
                else:
                    i += 1
                    
        return labels
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return {
            'input_ids': self.data[idx]['input_ids'],
            'attention_mask': self.data[idx]['attention_mask'],
            'labels': self.data[idx]['labels'],
        }

In [6]:
class IOBE(IOBES):
    allowed_prefix = Prefix.I | Prefix.O | Prefix.B | Prefix.E
    start_patterns = {
        (Prefix.ANY, Prefix.B, Tag.ANY),
        (Prefix.ANY, Prefix.S, Tag.ANY)
    }
    inside_patterns = {
        (Prefix.B, Prefix.I, Tag.SAME),
        (Prefix.B, Prefix.E, Tag.SAME),
        (Prefix.I, Prefix.I, Tag.SAME),
        (Prefix.I, Prefix.E, Tag.SAME)
    }
    end_patterns = {
        (Prefix.S, Prefix.ANY, Tag.ANY),
        (Prefix.E, Prefix.ANY, Tag.ANY),
        (Prefix.B, Prefix.O, Tag.ANY),
        (Prefix.B, Prefix.I, Tag.DIFF),
        (Prefix.B, Prefix.B, Tag.ANY),
    }

In [7]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    #import pdb;pdb.set_trace()
    predictions = predictions.flatten()
    labels = labels.flatten()
    npre = []
    nlab = []
 
    for i in range(len(labels)):
        if labels[i] != -100:
            npre.append(predictions[i])
            nlab.append(labels[i])
    npre = torch.tensor(npre)
    nlab = torch.tensor(nlab)
    
    label_indices = label_arr.copy()
    npre = [label_indices[pred] for pred in npre]
    nlab = [label_indices[label] for label in nlab]
    del label_indices[label_indices.index("O")]
    entity_level_metrics = classification_report(
        [nlab], [npre], digits=3,
        suffix=False,
        mode= 'strict', scheme=IOBE, 
        zero_division=True, output_dict=True
    )

    metrics = {}
    #import pdb;pdb.set_trace()
    for key in entity_level_metrics.keys():
        if len(key) == 3:
            metrics[key+"_f1"] = entity_level_metrics[key]['f1-score']
            metrics[key+"_recall"] = entity_level_metrics[key]['recall']
            metrics[key+"_precision"] = entity_level_metrics[key]['precision']
            
        if key == 'macro avg':
            metrics["entity_macro_f1"] = entity_level_metrics['macro avg']['f1-score']
            metrics["entity_macro_precision"] = entity_level_metrics['macro avg']['precision']
            metrics["entity_macro_recall"] = entity_level_metrics['macro avg']['recall']
            
    return metrics

In [8]:
def base_train(config):
    
    model = ElectraForTokenClassification.from_pretrained(config['base_model_dir'], num_labels=13)
    tokenizer = AutoTokenizer.from_pretrained(config['base_model_dir'])
    train_file = pd.read_csv(config['train_file'])
    valid_file = pd.read_csv(config['valid_file'], sep='\t') 
    
    label2id = {'O':0}
    count = 1
    for i in range(6):
        label2id['B-'+permutations[config['perm']][i]] = count
        count += 1
        label2id['I-'+permutations[config['perm']][i]] = count
        count += 1

    train_data = NERDataset(train_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    valid_data = NERDataset(valid_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    id2label = {label2id[label] : label for label in label2id.keys()}
    
    #import pdb;pdb.set_trace()
    model.config.label2id = label2id
    model.config.id2label = id2label

    global label_arr
    
    label_arr = []
    for v in id2label.values():
        label_arr.append(v)
        

    data_collator = DefaultDataCollator()
    device = torch.device("cuda")
    model.to(device)
     
    training_args = TrainingArguments(

        output_dir=config['output_dir'],
        do_eval = True,
        learning_rate=config['learning_rate'],
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=config['train_epoch'],
        weight_decay=0.1,
        save_strategy = 'epoch',
        logging_strategy = 'epoch',
        evaluation_strategy = 'epoch',
        load_best_model_at_end = True,
        label_names = ['labels'],
        metric_for_best_model = 'entity_macro_f1',
        warmup_ratio = 0.05,
        no_cuda = False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset = valid_data,
        data_collator=data_collator,
        tokenizer = tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    trainer.save_model(config['output_dir'] + '/final')
    
    for f_name in os.listdir(config['output_dir']):
        if f_name.startswith('checkpoint'):
            for f in os.listdir(config['output_dir']+'/'+f_name):
                os.remove(config['output_dir']+'/'+f_name+'/'+f)
            os.rmdir(config['output_dir']+'/'+f_name)



In [9]:
def cl_train(config):
    
    #model = ElectraForTokenClassification.from_pretrained(config['base_model_dir'], num_labgels=(config['step']-1)*2+1)
    model = ElectraForTokenClassification.from_pretrained(config['base_model_dir'])
    # new_layer = nn.Linear(512, config['step']*2+1)
    # new_layer.weight.data[:(config['step']-1)*2+1,:] = model.classifier.weight.data
    # model.classifer = new_layer
    tokenizer = AutoTokenizer.from_pretrained(config['base_model_dir'])
    train_file = pd.read_csv(config['train_file'])
    valid_file = pd.read_csv(config['valid_file'], sep='\t') 

    label2id = {'O':0}
    count = 1
    for i in range(6):
        label2id['B-'+permutations[config['perm']][i]] = count
        count += 1
        label2id['I-'+permutations[config['perm']][i]] = count
        count += 1

    
    train_data = NERDataset(train_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    valid_data = NERDataset(valid_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    id2label = {label2id[label] : label for label in label2id.keys()}
    
    # model.config.label2id = label2id
    # model.config.id2label = id2label

    global label_arr
    
    label_arr = []
    for v in id2label.values():
        label_arr.append(v)

    data_collator = DefaultDataCollator()
    device = torch.device("cuda")
    model.to(device)
     
    training_args = TrainingArguments(

        output_dir=config['output_dir'],
        do_eval = True,
        learning_rate=config['learning_rate'],
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=config['train_epoch'],
        weight_decay=0.1,
        save_strategy = 'epoch',
        logging_strategy = 'epoch',
        evaluation_strategy = 'epoch',
        load_best_model_at_end = True,
        label_names = ['labels'],
        metric_for_best_model = 'entity_macro_f1',
        warmup_ratio = 0.05,
        no_cuda = False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset = valid_data,
        data_collator=data_collator,
        tokenizer = tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    trainer.save_model(config['output_dir'] + '/final')
    
    for f_name in os.listdir(config['output_dir']):
        if f_name.startswith('checkpoint'):
            for f in os.listdir(config['output_dir']+'/'+f_name):
                os.remove(config['output_dir']+'/'+f_name+'/'+f)
            os.rmdir(config['output_dir']+'/'+f_name)



In [10]:
def test(config):
    
    model = ElectraForTokenClassification.from_pretrained(config['model'])
    
    global label_arr
    
    label_arr = []
    for v in model.config.id2label.values():
        label_arr.append(v)
    
    tokenizer = AutoTokenizer.from_pretrained(config['model'])
    
    test_file = pd.read_csv(config['test_file'], sep='\t')
    label2id = model.config.label2id
    
    test_dataset = NERDataset(test_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    
    device = torch.device("cuda")
    model.to(device)
    
    data_collator = DefaultDataCollator()
    
    training_args = TrainingArguments(
        output_dir=config['model'],
        per_device_eval_batch_size=32,
    )

    trainer = Trainer(
        args=training_args,
        model=model,
        data_collator=data_collator,
        compute_metrics=compute_metrics
    )
    
    metrics = trainer.evaluate(test_dataset)
    trainer.save_metrics(split='test', metrics=metrics)
    

In [11]:
import os

if(__name__=="__main__"):
    
    for per_num_int in range(6,7):
        per_num = str(per_num_int)
        permutation_dir = "train_data/perm_" + per_num

        data_file_dict = {
            "train" : [],
            "valid" : []
        }

        for i in range(6):
            for f_name in os.listdir(permutation_dir + '/'):
                if f_name.startswith('d'+str(i)):
                    data_file_dict['train'].append(f_name)
                    break

            for f_name in os.listdir('test_data/perm' + str(per_num)):
                if f_name.startswith('eval_'+str(i+1)):
                    data_file_dict['valid'].append(f_name)
                    break

        for i in range(3, 6):
            if i == 0:
                #import pdb;pdb.set_trace()
                base_config = {
                    'base_model_dir' : 'monologg/koelectra-base-v3-discriminator',
                    'train_file' : permutation_dir + '/' + data_file_dict['train'][i],
                    'valid_file' : 'test_data/perm' + str(per_num) + '/' + data_file_dict['valid'][i],
                    'output_dir' : 'transfer/perm_' + per_num + '/step' + str(i+1),
                    'train_epoch' : 10,
                    'learning_rate' : 5e-05,
                    'step' : i+1,
                    'perm': per_num_int-1
                }

                base_train(base_config)
            else:
                base_config = {
                    'base_model_dir' : 'transfer/perm_' + per_num + '/step' + str(i) + '/final',
                    'train_file' : permutation_dir + '/' + data_file_dict['train'][i],
                    'valid_file' : 'test_data/perm' + str(per_num) + '/' + data_file_dict['valid'][i],
                    'output_dir' : 'transfer/perm_' + per_num + '/step' + str(i+1),
                    'train_epoch' : 10,
                    'learning_rate' : 5e-05,
                    'step' : i+1,
                    'perm': per_num_int-1
                }
                
                cl_train(base_config)
                
            
            test_config = {
                'model':'transfer/perm_' + per_num + '/step' + str(i+1) + '/final',
                'test_file': 'test_data/perm' + str(per_num) + '/' + data_file_dict['valid'][i]
            }    

            test(test_config)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Epoch,Training Loss,Validation Loss,Cvl F1,Cvl Recall,Cvl Precision,Dat F1,Dat Recall,Dat Precision,Loc F1,Loc Recall,Loc Precision,Per F1,Per Recall,Per Precision,Entity Macro F1,Entity Macro Precision,Entity Macro Recall
1,0.1491,0.755121,0.831583,0.843497,0.82,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.207896,0.955,0.210874
2,0.0347,0.772937,0.865526,0.91036,0.8249,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.216381,0.956225,0.22759
3,0.0144,0.884716,0.870871,0.852314,0.890253,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.217718,0.972563,0.213079
4,0.0069,0.895565,0.856089,0.937546,0.787654,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.214022,0.946914,0.234386
5,0.0033,0.923057,0.882941,0.908891,0.858432,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.220735,0.964608,0.227223
6,0.0018,0.947555,0.882457,0.918442,0.849185,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.220614,0.962296,0.229611
7,0.0011,0.98687,0.891281,0.912564,0.870968,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.22282,0.967742,0.228141
8,0.0009,0.989217,0.880702,0.922116,0.842848,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.220175,0.960712,0.230529
9,0.0006,0.996134,0.886759,0.917708,0.85783,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.22169,0.964457,0.229427
10,0.0006,0.998169,0.884099,0.919177,0.8516,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.221025,0.9629,0.229794


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Epoch,Training Loss,Validation Loss,Cvl F1,Cvl Recall,Cvl Precision,Dat F1,Dat Recall,Dat Precision,Loc F1,Loc Recall,Loc Precision,Per F1,Per Recall,Per Precision,Qnt F1,Qnt Recall,Qnt Precision,Entity Macro F1,Entity Macro Precision,Entity Macro Recall
1,0.1602,1.163131,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.7557,0.789116,0.725,0.15114,0.945,0.157823
2,0.0125,1.20019,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.759644,0.870748,0.673684,0.151929,0.934737,0.17415
3,0.0073,1.246466,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.822006,0.863946,0.783951,0.164401,0.95679,0.172789
4,0.0032,1.291755,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.768328,0.891156,0.675258,0.153666,0.935052,0.178231
5,0.0022,1.312481,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.823151,0.870748,0.780488,0.16463,0.956098,0.17415
6,0.0011,1.339807,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.809969,0.884354,0.747126,0.161994,0.949425,0.176871
7,0.0007,1.353956,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.836013,0.884354,0.792683,0.167203,0.958537,0.176871
8,0.0004,1.379888,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.845902,0.877551,0.816456,0.16918,0.963291,0.17551
9,0.0005,1.379584,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.825806,0.870748,0.785276,0.165161,0.957055,0.17415
10,0.0003,1.383638,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.825806,0.870748,0.785276,0.165161,0.957055,0.17415


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Epoch,Training Loss,Validation Loss,Cvl F1,Cvl Recall,Cvl Precision,Dat F1,Dat Recall,Dat Precision,Loc F1,Loc Recall,Loc Precision,Org F1,Org Recall,Org Precision,Per F1,Per Recall,Per Precision,Qnt F1,Qnt Recall,Qnt Precision,Entity Macro F1,Entity Macro Precision,Entity Macro Recall
1,0.1822,1.352348,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.771473,0.948784,0.65,0.0,0.0,1.0,0.0,0.0,1.0,0.128579,0.941667,0.158131
2,0.0276,1.582582,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.848635,0.8758,0.823105,0.0,0.0,1.0,0.0,0.0,1.0,0.141439,0.970517,0.145967
3,0.0118,1.685463,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.857508,0.859155,0.855867,0.0,0.0,1.0,0.0,0.0,1.0,0.142918,0.975978,0.143192
4,0.0058,1.757653,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.857316,0.903969,0.815242,0.0,0.0,1.0,0.0,0.0,1.0,0.142886,0.969207,0.150662
5,0.0026,1.889577,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.860149,0.889885,0.832335,0.0,0.0,1.0,0.0,0.0,1.0,0.143358,0.972056,0.148314
6,0.0009,1.900879,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.86284,0.914213,0.816934,0.0,0.0,1.0,0.0,0.0,1.0,0.143807,0.969489,0.152369
7,0.0012,1.9319,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.856802,0.919334,0.802235,0.0,0.0,1.0,0.0,0.0,1.0,0.1428,0.967039,0.153222
8,0.0007,1.933947,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.84852,0.93598,0.776008,0.0,0.0,1.0,0.0,0.0,1.0,0.14142,0.962668,0.155997
9,0.0004,1.959331,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.865817,0.912932,0.823326,0.0,0.0,1.0,0.0,0.0,1.0,0.144303,0.970554,0.152155
10,0.0003,1.970217,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.863527,0.915493,0.817143,0.0,0.0,1.0,0.0,0.0,1.0,0.143921,0.969524,0.152582


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
