In [1]:
import os

os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#os.environ["CUDA_LAUNCH_BLOCKING"] = '1'
os.environ['WANDB_DISABLED'] = 'True'

In [2]:
import torch
import re
import random
import extendNER_new_bioe_2 as ex #bio or bioe
from transformers import ElectraForTokenClassification, AutoConfig, AutoTokenizer
from transformers import AdamW
import pandas as pd
from torch.utils.data import Dataset as DS
from datasets import Dataset
from transformers import DefaultDataCollator
from transformers import Trainer, TrainingArguments
from seqeval.metrics import classification_report
from sklearn.metrics import f1_score
import numpy as np
from datasets import load_metric
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from seqeval.scheme import IOB1, IOB2, IOE1, IOE2, IOBES, BILOU, Entities, Prefix, Tag
from seqeval.scheme import IOBES

In [4]:
class NERDataset(DS):
    def __init__(self, dataset, tokenizer, label2id, max_length):
        super().__init__()
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_length = max_length
        self.dataset = dataset
        self.data = []
        
        for i in range(len(self.dataset['label'])):
        #for i in range(3):
            text = dataset['label'][i]
            tagged_words = re.findall('<.*?:.*?>', text)
            
            word2ids = dict()
            for tagged_word in tagged_words:
                tag_splited = tagged_word.strip('<>').split(':')
                tag = tag_splited[-1]
                word = ':'.join(tag_splited[:-1])
                    
                word_tok = self.tokenizer.encode(word)[1:-1]
                if word not in word2ids:

                    label_id = [self.label2id['B-'+tag]]
                    if len(word_tok) > 1 :
                        label_id.extend([self.label2id['I-'+tag]] * (len(word_tok)-1))
                    word2ids[word] = {
                        'target_ids': word_tok,
                        'label_id': label_id 
                    }
                    text = text.replace(tagged_word, word)

            tokenized = self.tokenizer(text, truncation=True, max_length=self.max_length, padding='max_length')
            if 0 in tokenized['input_ids']:
                tok_length = tokenized['input_ids'].index(0)
            else:
                tok_length = self.max_length
            label_input = tokenized['input_ids'][:tok_length]
            labels = self._gen_labels(label_input, word2ids)
            labels.insert(0, -100)
            pad = [-100] * (self.max_length - len(labels))
            labels.extend(pad)

            temp = {
                'input_ids' : tokenized['input_ids'],
                'attention_mask' : tokenized['attention_mask'],
                'labels' : labels
            }
            
            self.data.append(temp)
        
    def _gen_labels(self, input_ids, word2ids):
        sequence = input_ids[1:-1]
        labels = [0] * len(sequence)
        
        for v in word2ids.values():
            target_ids = v['target_ids']
            label_id = v['label_id']
            
            i=0
            target_ids_length = len(target_ids)
            
            while i < len(sequence):
                if sequence[i:i + target_ids_length] == target_ids:
                    labels[i:i + target_ids_length] = label_id
                    i = i + target_ids_length
                else:
                    i += 1
                    
        return labels
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return {
            'input_ids': self.data[idx]['input_ids'],
            'attention_mask': self.data[idx]['attention_mask'],
            'labels': self.data[idx]['labels'],
        }

In [5]:
class IOBE(IOBES):
    allowed_prefix = Prefix.I | Prefix.O | Prefix.B | Prefix.E
    start_patterns = {
        (Prefix.ANY, Prefix.B, Tag.ANY),
        (Prefix.ANY, Prefix.S, Tag.ANY)
    }
    inside_patterns = {
        (Prefix.B, Prefix.I, Tag.SAME),
        (Prefix.B, Prefix.E, Tag.SAME),
        (Prefix.I, Prefix.I, Tag.SAME),
        (Prefix.I, Prefix.E, Tag.SAME)
    }
    end_patterns = {
        (Prefix.S, Prefix.ANY, Tag.ANY),
        (Prefix.E, Prefix.ANY, Tag.ANY),
        (Prefix.B, Prefix.O, Tag.ANY),
        (Prefix.B, Prefix.I, Tag.DIFF),
        (Prefix.B, Prefix.B, Tag.ANY),
    }

In [6]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    #import pdb;pdb.set_trace()
    predictions = predictions.flatten()
    labels = labels.flatten()
    npre = []
    nlab = []
 
    for i in range(len(labels)):
        if labels[i] != -100:
            npre.append(predictions[i])
            nlab.append(labels[i])
    npre = torch.tensor(npre)
    nlab = torch.tensor(nlab)
    
    label_indices = label_arr.copy()
    npre = [label_indices[pred] for pred in npre]
    nlab = [label_indices[label] for label in nlab]
    del label_indices[label_indices.index("O")]
    entity_level_metrics = classification_report(
        [nlab], [npre], digits=3,
        suffix=False,
        mode= 'strict', scheme=IOBE, 
        zero_division=True, output_dict=True
    )

    metrics = {}
    
    for key in entity_level_metrics.keys():
        if len(key) == 3:
            metrics[key+"_f1"] = entity_level_metrics[key]['f1-score']
            metrics[key+"_recall"] = entity_level_metrics[key]['recall']
            metrics[key+"_precision"] = entity_level_metrics[key]['precision']
            
        if key == 'macro avg':
            metrics["entity_macro_f1"] = entity_level_metrics['macro avg']['f1-score']
            metrics["entity_macro_precision"] = entity_level_metrics['macro avg']['precision']
            metrics["entity_macro_recall"] = entity_level_metrics['macro avg']['recall']
            
    return metrics

In [7]:
def base_train(config):
      
    model = ElectraForTokenClassification.from_pretrained(config['base_model_dir'], num_labels=3)
    tokenizer = AutoTokenizer.from_pretrained(config['base_model_dir'])
    train_file = pd.read_csv(config['train_file'])
    valid_file = pd.read_csv(config['valid_file'], sep='\t') 

    ent = config['new_entity']
    
    label2id = {
        'O' : 0,
        'B-'+ent : 1,
        'I-'+ent : 2
    }

    train_data = NERDataset(train_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    valid_data = NERDataset(valid_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    id2label = {label2id[label] : label for label in label2id.keys()}
    
    model.config.label2id = label2id
    model.config.id2label = id2label

    global label_arr
    
    label_arr = []
    for v in id2label.values():
        label_arr.append(v)

    data_collator = DefaultDataCollator()
    device = torch.device("cuda")
    model.to(device)
     
    training_args = TrainingArguments(

        output_dir=config['output_dir'],
        do_eval = True,
        learning_rate=config['learning_rate'],
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=config['train_epoch'],
        weight_decay=0.1,
        save_strategy = 'epoch',
        logging_strategy = 'epoch',
        evaluation_strategy = 'epoch',
        load_best_model_at_end = True,
        label_names = ['labels'],
        metric_for_best_model = 'entity_macro_f1',
        warmup_ratio = 0.05,
        no_cuda = False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset = valid_data,
        data_collator=data_collator,
        tokenizer = tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    trainer.save_model(config['output_dir'] + '/final')
    
    for f_name in os.listdir(config['output_dir']):
        if f_name.startswith('checkpoint'):
            for f in os.listdir(config['output_dir']+'/'+f_name):
                os.remove(config['output_dir']+'/'+f_name+'/'+f)
            os.rmdir(config['output_dir']+'/'+f_name)



In [8]:
# from transformers import TrainerCallback

# class customCallBack(TrainerCallback):
#     def on_epoch_end(self, args, state, control, model, **kwargs):
#         import pdb;pdb.set_trace()
        

In [9]:
def cl_train(config):
    
    teacher = ElectraForTokenClassification.from_pretrained(config['teacher_dir'])
    student = ex.extendNER.from_pretrained(config['teacher_dir'])
    labels = student.num_labels
    student.config.id2label[labels-2] = 'B-' + config['new_entity']
    student.config.id2label[labels-1] = 'I-' + config['new_entity']
    student.config.label2id['B-'+config['new_entity']] = labels-2
    student.config.label2id['I-'+config['new_entity']] = labels-1
    student.ce = config['ce']
    student.kd = config['kd']
    student.T = config['T']
    torch.manual_seed(42)
    student.classifier = torch.nn.Linear(768, labels)
    
    global label_arr
    
    label_arr = []
    for v in student.config.id2label.values():
        label_arr.append(v)
    
    tokenizer = AutoTokenizer.from_pretrained(config['teacher_dir'])
    
    train_file = pd.read_csv(config['train_file'])
    valid_file = pd.read_csv(config['valid_file'], sep='\t')
    
    label2id = student.config.label2id
    
    train_dataset = NERDataset(train_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    valid_dataset = NERDataset(valid_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    
    temp_train = []
    temp_valid = []
    for e in train_dataset:
        temp_train.append(e)
        
    for e in valid_dataset:
        temp_valid.append(e)   
        
    teacher_input = pd.DataFrame(temp_train)
    student_input = teacher_input.copy()
    teacher_input = teacher_input.drop(columns=['labels'])
    teacher_input = Dataset.from_pandas(teacher_input)
    
    data_collator = DefaultDataCollator()
    
    trainer = Trainer(
        model=teacher,
        data_collator=data_collator
    )

    prediction = trainer.predict(teacher_input)
    
    soft_label = prediction[0].tolist()
    student_input['soft_labels'] = soft_label
    student_input = student_input.rename(columns={'labels' :'hard_labels'})

    valid_input = pd.DataFrame(temp_valid)

    student_input = Dataset.from_pandas(student_input)
    valid_input = Dataset.from_pandas(valid_input)
    
    data_collator = DefaultDataCollator()
    
    device = torch.device("cuda")
    student.to(device)
    
    training_args = TrainingArguments(
        output_dir=config['output_dir'],
        do_eval = True,
        learning_rate=config['learning_rate'],
        per_device_train_batch_size=config['batch'],
        per_device_eval_batch_size=config['batch'],
        num_train_epochs=config['train_epoch'],
        weight_decay=0.1,
        save_strategy = 'epoch',
        logging_strategy = 'epoch',
        evaluation_strategy = 'epoch',
        load_best_model_at_end = True,
        label_names = ['labels'],
        metric_for_best_model = 'entity_macro_f1',
        warmup_ratio = 0.05,
        no_cuda = False,
    )

    trainer = Trainer(
        model=student,
        args=training_args,
        train_dataset=student_input,
        eval_dataset = valid_input,
        data_collator=data_collator,
        tokenizer = tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.save_model(config['output_dir'] + '/final')
    
    for f_name in os.listdir(config['output_dir']):
        if f_name.startswith('checkpoint'):
            for f in os.listdir(config['output_dir']+'/'+f_name):
                os.remove(config['output_dir']+'/'+f_name+'/'+f)
            os.rmdir(config['output_dir']+'/'+f_name)
    

In [10]:
def test(config):
    
    model = ElectraForTokenClassification.from_pretrained(config['model'])
    
    global label_arr
    
    label_arr = []
    for v in model.config.id2label.values():
        label_arr.append(v)
    
    tokenizer = AutoTokenizer.from_pretrained(config['model'])
    
    test_file = pd.read_csv(config['test_file'], sep='\t')
    label2id = model.config.label2id
    
    test_dataset = NERDataset(test_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    
    device = torch.device("cuda")
    model.to(device)
    
    data_collator = DefaultDataCollator()
    
    training_args = TrainingArguments(
        output_dir=config['model'],
        per_device_eval_batch_size=32,
    )

    trainer = Trainer(
        args=training_args,
        model=model,
        data_collator=data_collator,
        compute_metrics=compute_metrics
    )
    
    metrics = trainer.evaluate(test_dataset)
    trainer.save_metrics(split='test', metrics=metrics)
    

In [11]:
import os
import shutil

if(__name__=="__main__"):
    
    directory = './ce_full'

    for per_num_int in range(1,7):
        per_num = str(per_num_int)
        permutation_dir = "train_data/perm_" + per_num

        data_file_dict = {
            "train" : [],
            "valid" : []

        }

    #     permutations = [['ORG', 'PER', 'CVL', 'DAT', 'LOC', 'QNT'],
    #                    ['DAT', 'QNT', 'PER', 'LOC', 'ORG', 'CVL'],
    #                    ['CVL', 'LOC', 'ORG', 'QNT', 'DAT', 'PER'],
    #                    ['QNT', 'ORG', 'DAT', 'PER', 'CVL', 'LOC'],
    #                    ['LOC', 'CVL', 'QNT', 'ORG', 'PER', 'DAT'],
    #                    ['PER', 'DAT', 'LOC', 'CVL', 'QNT', 'ORG']]

        for i in range(6):
            for f_name in os.listdir(permutation_dir + '/'):
                if f_name.startswith('d'+str(i)):
                    data_file_dict['train'].append(f_name)
                    break

            for f_name in os.listdir('test_data/perm' + str(per_num)):
                if f_name.startswith('eval_'+str(i+1)):
                    data_file_dict['valid'].append(f_name)
                    break

        for i in range(6):

            new_ent = data_file_dict['train'][i].split('_')[1][:3]
            if i == 0:

                base_config = {
                    'base_model_dir' : 'monologg/koelectra-base-v3-discriminator',
                    'train_file' : permutation_dir + '/' + data_file_dict['train'][0],
                    'valid_file' : 'test_data/perm' + str(per_num) + '/' + data_file_dict['valid'][0],
                    'output_dir' : directory + '/perm_' + per_num + '/step1/',
                    'train_epoch' : 10,
                    'learning_rate' : 5e-05,
                    'new_entity' : new_ent
                }

                base_train(base_config)

            else:

                cl_config = {
                    'teacher_dir' : directory + '/perm_' + per_num + '/step' + str(i) + '/final',
                    'new_entity' : new_ent,
                    'train_file' : permutation_dir + '/' + data_file_dict['train'][i],
                    'valid_file' : 'test_data/perm' + str(per_num) + '/' + data_file_dict['valid'][i],
                    'output_dir' : directory + '/perm_' + per_num + '/step' + str(i+1) + '/',
                    'train_epoch' : 20,
                    'learning_rate' : 5e-05,
                    'batch' : 16,
                    'ce' : 1,
                    'kd' : 1,
                    'T' : 2
                }

                cl_train(cl_config)

            test_config = {
                'model': directory + '/perm_' + per_num + '/step' + str(i+1) + '/final',
                'test_file': 'test_data/perm' + str(per_num) + '/' + data_file_dict['valid'][i]
            }    

            test(test_config)

    # folder = directory
    # for perm in os.listdir(folder):
    #     for step in os.listdir(folder + '/'+perm):
    #         for files in os.listdir(folder+'/'+perm+'/'+step+'/final'):
    #             if files.startswith('test'):
    #                 pass
    #             else:
    #                 if os.path.isdir(folder+'/'+perm+'/'+step+'/final/'+files):
    #                     shutil.rmtree(folder+'/'+perm+'/'+step+'/final/'+files)
    #                 else:
    #                     os.remove(folder+'/'+perm+'/'+step+'/final/'+files) 

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraForTokenClassification: ['discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForTokenClassification were not initialized from the model checkpoint at monologg/koelectra-base-v3-discriminator and are newly initialized: ['classifier

Epoch,Training Loss,Validation Loss,Org F1,Org Recall,Org Precision,Entity Macro F1,Entity Macro Precision,Entity Macro Recall
1,0.2784,0.038075,0.82515,0.87659,0.779412,0.82515,0.779412,0.87659
2,0.0341,0.033878,0.849349,0.871501,0.828295,0.849349,0.828295,0.871501
3,0.0163,0.038848,0.859077,0.888041,0.831943,0.859077,0.831943,0.888041
4,0.008,0.047528,0.840093,0.922392,0.771277,0.840093,0.771277,0.922392
5,0.0054,0.048643,0.850589,0.872774,0.829504,0.850589,0.829504,0.872774
6,0.0033,0.052179,0.843529,0.912214,0.784464,0.843529,0.784464,0.912214
7,0.002,0.055098,0.850602,0.898219,0.80778,0.850602,0.80778,0.898219
8,0.0012,0.056828,0.852816,0.895674,0.813873,0.852816,0.813873,0.895674
9,0.0011,0.058852,0.853865,0.899491,0.812644,0.853865,0.812644,0.899491
10,0.001,0.059077,0.853556,0.908397,0.804961,0.853556,0.804961,0.908397


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Some weights of the model checkpoint at ./ce_full/perm_1/step1/final were not used when initializing extendNER: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing extendNER from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing extendNER from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Epoch,Training Loss,Validation Loss,Org F1,Org Recall,Org Precision,Per F1,Per Recall,Per Precision,Entity Macro F1,Entity Macro Precision,Entity Macro Recall
1,0.5793,0.480743,0.0612,0.031807,0.806452,0.152903,0.918495,0.083393,0.107051,0.444922,0.475151
2,0.0427,0.616746,0.002535,0.001272,0.333333,0.133785,0.835423,0.072715,0.06816,0.203024,0.418348
3,0.0189,0.534238,0.671569,0.522901,0.938356,0.181122,0.926332,0.100374,0.426345,0.519365,0.724617
4,0.0062,0.604311,0.299465,0.178117,0.939597,0.227404,0.802508,0.132471,0.263435,0.536034,0.490312
5,0.0099,0.795538,0.737221,0.651399,0.849088,0.103017,0.703762,0.055576,0.420119,0.452332,0.677581
6,0.0082,0.984511,0.263904,0.153944,0.923664,0.116969,0.778997,0.063232,0.190436,0.493448,0.46647
7,0.0031,0.866057,0.597222,0.437659,0.939891,0.124296,0.830721,0.067174,0.360759,0.503532,0.63419
8,0.0021,0.840348,0.638544,0.491094,0.91253,0.119965,0.84953,0.064539,0.379254,0.488534,0.670312
9,0.0018,0.805771,0.616554,0.464377,0.917085,0.143754,0.84326,0.078575,0.380154,0.49783,0.653818
10,0.0016,0.816168,0.623006,0.47201,0.916049,0.145218,0.840125,0.079478,0.384112,0.497764,0.656068


KeyboardInterrupt: 