In [1]:
import os

# Process-wide environment, set in this cell before torch/transformers are
# imported in the next one:
#   - number GPUs in PCI bus order and expose only GPU 0 to this process;
#   - WANDB_DISABLED is intended to keep the Trainer from logging to
#     Weights & Biases.
# Debugging aid (off by default): add 'CUDA_LAUNCH_BLOCKING': '1' below to get
# synchronous CUDA error reporting.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '0',
    'WANDB_DISABLED': 'True',
})

In [2]:
import random
import re
import shutil

import numpy as np
import pandas as pd
import torch
import wandb
from datasets import Dataset
from datasets import load_metric
from seqeval.metrics import classification_report
from sklearn.metrics import f1_score
from torch.utils.data import Dataset as DS
from transformers import AdamW
from transformers import DefaultDataCollator
from transformers import ElectraForTokenClassification, AutoConfig, AutoTokenizer
from transformers import Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from seqeval.scheme import IOB1, IOB2, IOE1, IOE2, IOBES, BILOU, Entities, Prefix, Tag
from seqeval.scheme import IOBES

In [4]:
permutations = [['ORG', 'PER', 'CVL', 'DAT', 'LOC', 'QNT'],
               ['DAT', 'QNT', 'PER', 'LOC', 'ORG', 'CVL'],
               ['CVL', 'LOC', 'ORG', 'QNT', 'DAT', 'PER'],
               ['QNT', 'ORG', 'DAT', 'PER', 'CVL', 'LOC'],
               ['LOC', 'CVL', 'QNT', 'ORG', 'PER', 'DAT'],
               ['PER', 'DAT', 'LOC', 'CVL', 'QNT', 'ORG']]

In [5]:
class NERDataset(DS):
    def __init__(self, dataset, tokenizer, label2id, max_length):
        super().__init__()
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_length = max_length
        self.dataset = dataset
        self.data = []
        
        for i in range(len(self.dataset['label'])):
            text = dataset['label'][i]
            tagged_words = re.findall('<.*?:.*?>', text)
            
            word2ids = dict()
            for tagged_word in tagged_words:
                tag_splited = tagged_word.strip('<>').split(':')
                tag = tag_splited[-1]
                word = ':'.join(tag_splited[:-1])
                    
                word_tok = self.tokenizer.encode(word)[1:-1]
                if word not in word2ids:

                    label_id = [self.label2id['B-'+tag]]
                    if len(word_tok) > 1 :
                        label_id.extend([self.label2id['I-'+tag]] * (len(word_tok)-1))
                    word2ids[word] = {
                        'target_ids': word_tok,
                        'label_id': label_id 
                    }
                    text = text.replace(tagged_word, word)

            tokenized = self.tokenizer(text, truncation=True, max_length=self.max_length, padding='max_length')
            if 0 in tokenized['input_ids']:
                tok_length = tokenized['input_ids'].index(0)
            else:
                tok_length = self.max_length
            label_input = tokenized['input_ids'][:tok_length]
            labels = self._gen_labels(label_input, word2ids)
            labels.insert(0, -100)
            pad = [-100] * (self.max_length - len(labels))
            labels.extend(pad)

            temp = {
                'input_ids' : tokenized['input_ids'],
                'attention_mask' : tokenized['attention_mask'],
                'labels' : labels
            }
            
            self.data.append(temp)
        
    def _gen_labels(self, input_ids, word2ids):
        sequence = input_ids[1:-1]
        labels = [0] * len(sequence)
        
        for v in word2ids.values():
            target_ids = v['target_ids']
            label_id = v['label_id']
            
            i=0
            target_ids_length = len(target_ids)
            
            while i < len(sequence):
                if sequence[i:i + target_ids_length] == target_ids:
                    labels[i:i + target_ids_length] = label_id
                    i = i + target_ids_length
                else:
                    i += 1
                    
        return labels
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return {
            'input_ids': self.data[idx]['input_ids'],
            'attention_mask': self.data[idx]['attention_mask'],
            'labels': self.data[idx]['labels'],
        }

In [6]:
class IOBE(IOBES):
    allowed_prefix = Prefix.I | Prefix.O | Prefix.B | Prefix.E
    start_patterns = {
        (Prefix.ANY, Prefix.B, Tag.ANY),
        (Prefix.ANY, Prefix.S, Tag.ANY)
    }
    inside_patterns = {
        (Prefix.B, Prefix.I, Tag.SAME),
        (Prefix.B, Prefix.E, Tag.SAME),
        (Prefix.I, Prefix.I, Tag.SAME),
        (Prefix.I, Prefix.E, Tag.SAME)
    }
    end_patterns = {
        (Prefix.S, Prefix.ANY, Tag.ANY),
        (Prefix.E, Prefix.ANY, Tag.ANY),
        (Prefix.B, Prefix.O, Tag.ANY),
        (Prefix.B, Prefix.I, Tag.DIFF),
        (Prefix.B, Prefix.B, Tag.ANY),
    }

In [7]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    #import pdb;pdb.set_trace()
    predictions = predictions.flatten()
    labels = labels.flatten()
    npre = []
    nlab = []
 
    for i in range(len(labels)):
        if labels[i] != -100:
            npre.append(predictions[i])
            nlab.append(labels[i])
    npre = torch.tensor(npre)
    nlab = torch.tensor(nlab)
    
    label_indices = label_arr.copy()
    npre = [label_indices[pred] for pred in npre]
    nlab = [label_indices[label] for label in nlab]
    del label_indices[label_indices.index("O")]
    entity_level_metrics = classification_report(
        [nlab], [npre], digits=3,
        suffix=False,
        mode= 'strict', scheme=IOBE, 
        zero_division=True, output_dict=True
    )

    metrics = {}
    #import pdb;pdb.set_trace()
    for key in entity_level_metrics.keys():
        if len(key) == 3:
            metrics[key+"_f1"] = entity_level_metrics[key]['f1-score']
            metrics[key+"_recall"] = entity_level_metrics[key]['recall']
            metrics[key+"_precision"] = entity_level_metrics[key]['precision']
            
        if key == 'macro avg':
            metrics["entity_macro_f1"] = entity_level_metrics['macro avg']['f1-score']
            metrics["entity_macro_precision"] = entity_level_metrics['macro avg']['precision']
            metrics["entity_macro_recall"] = entity_level_metrics['macro avg']['recall']
            
    return metrics

In [8]:
def base_train(config):
    
    model = ElectraForTokenClassification.from_pretrained(config['base_model_dir'], num_labels=13)
    tokenizer = AutoTokenizer.from_pretrained(config['base_model_dir'])
    train_file = pd.read_csv(config['train_file'])
    valid_file = pd.read_csv(config['valid_file'], sep='\t') 
    
    label2id = {'O':0}
    count = 1
    for i in range(6):
        label2id['B-'+permutations[config['perm']][i]] = count
        count += 1
        label2id['I-'+permutations[config['perm']][i]] = count
        count += 1

    train_data = NERDataset(train_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    valid_data = NERDataset(valid_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    id2label = {label2id[label] : label for label in label2id.keys()}
    
    #import pdb;pdb.set_trace()
    model.config.label2id = label2id
    model.config.id2label = id2label

    global label_arr
    
    label_arr = []
    for v in id2label.values():
        label_arr.append(v)
        

    data_collator = DefaultDataCollator()
    device = torch.device("cuda")
    model.to(device)
     
    training_args = TrainingArguments(

        output_dir=config['output_dir'],
        do_eval = True,
        learning_rate=config['learning_rate'],
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=config['train_epoch'],
        weight_decay=0.1,
        save_strategy = 'epoch',
        logging_strategy = 'epoch',
        evaluation_strategy = 'epoch',
        load_best_model_at_end = True,
        label_names = ['labels'],
        metric_for_best_model = 'entity_macro_f1',
        warmup_ratio = 0.05,
        no_cuda = False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset = valid_data,
        data_collator=data_collator,
        tokenizer = tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    trainer.save_model(config['output_dir'] + '/final')
    
    for f_name in os.listdir(config['output_dir']):
        if f_name.startswith('checkpoint'):
            for f in os.listdir(config['output_dir']+'/'+f_name):
                os.remove(config['output_dir']+'/'+f_name+'/'+f)
            os.rmdir(config['output_dir']+'/'+f_name)



In [9]:
def cl_train(config):
    
    model = ElectraForTokenClassification.from_pretrained(config['base_model_dir'])
    tokenizer = AutoTokenizer.from_pretrained(config['base_model_dir'])
    train_file = pd.read_csv(config['train_file'])
    valid_file = pd.read_csv(config['valid_file'], sep='\t') 

    label2id = model.config.label2id
    
    train_data = NERDataset(train_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    valid_data = NERDataset(valid_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    id2label = {label2id[label] : label for label in label2id.keys()}

    global label_arr
    
    label_arr = []
    for v in id2label.values():
        label_arr.append(v)

    data_collator = DefaultDataCollator()
    device = torch.device("cuda")
    model.to(device)
     
    training_args = TrainingArguments(

        output_dir=config['output_dir'],
        do_eval = True,
        learning_rate=config['learning_rate'],
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=config['train_epoch'],
        weight_decay=0.1,
        save_strategy = 'epoch',
        logging_strategy = 'epoch',
        evaluation_strategy = 'epoch',
        load_best_model_at_end = True,
        label_names = ['labels'],
        metric_for_best_model = 'entity_macro_f1',
        warmup_ratio = 0.05,
        no_cuda = False
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset = valid_data,
        data_collator=data_collator,
        tokenizer = tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    trainer.save_model(config['output_dir'] + '/final')
    
    for f_name in os.listdir(config['output_dir']):
        if f_name.startswith('checkpoint'):
            for f in os.listdir(config['output_dir']+'/'+f_name):
                os.remove(config['output_dir']+'/'+f_name+'/'+f)
            os.rmdir(config['output_dir']+'/'+f_name)



In [10]:
def test(config):
    
    model = ElectraForTokenClassification.from_pretrained(config['model'])
    
    global label_arr
    
    label_arr = []
    for v in model.config.id2label.values():
        label_arr.append(v)
    
    tokenizer = AutoTokenizer.from_pretrained(config['model'])
    
    test_file = pd.read_csv(config['test_file'], sep='\t')
    label2id = model.config.label2id
    
    test_dataset = NERDataset(test_file, tokenizer=tokenizer, max_length=300, label2id=label2id)
    
    device = torch.device("cuda")
    model.to(device)
    
    data_collator = DefaultDataCollator()
    
    training_args = TrainingArguments(
        output_dir=config['model'],
        per_device_eval_batch_size=32,
    )

    trainer = Trainer(
        args=training_args,
        model=model,
        data_collator=data_collator,
        compute_metrics=compute_metrics
    )
    
    metrics = trainer.evaluate(test_dataset)
    trainer.save_metrics(split='test', metrics=metrics)
    

In [11]:
import os

if(__name__=="__main__"):
    
    for per_num_int in range(1,7):
        per_num = str(per_num_int)
        permutation_dir = "train_data/perm_" + per_num

        data_file_dict = {
            "train" : [],
            "valid" : []
        }

        for i in range(6):
            for f_name in os.listdir(permutation_dir + '/'):
                if f_name.startswith('d'+str(i+1)):
                    data_file_dict['train'].append(f_name)
                    break

            for f_name in os.listdir('test_data/perm' + str(per_num)):
                if f_name.startswith('eval_'+str(i+1)):
                    data_file_dict['valid'].append(f_name)
                    break

        for i in range(6):
            if i == 0:
                #import pdb;pdb.set_trace()
                base_config = {
                    'base_model_dir' : 'monologg/koelectra-base-v3-discriminator',
                    'train_file' : permutation_dir + '/' + data_file_dict['train'][i],
                    'valid_file' : 'test_data/perm' + str(per_num) + '/' + data_file_dict['valid'][i],
                    'output_dir' : 'transfer/perm_' + per_num + '/step' + str(i+1),
                    'train_epoch' : 10,
                    'learning_rate' : 5e-05,
                    'step' : i+1,
                    'perm': per_num_int-1
                }

                base_train(base_config)
            else:
                base_config = {
                    'base_model_dir' : 'transfer/perm_' + per_num + '/step' + str(i) + '/final',
                    'train_file' : permutation_dir + '/' + data_file_dict['train'][i],
                    'valid_file' : 'test_data/perm' + str(per_num) + '/' + data_file_dict['valid'][i],
                    'output_dir' : 'transfer/perm_' + per_num + '/step' + str(i+1),
                    'train_epoch' : 10,
                    'learning_rate' : 5e-05,
                    'step' : i+1,
                    'perm': per_num_int-1
                }
                
                cl_train(base_config)
                
            
            test_config = {
                'model':'transfer/perm_' + per_num + '/step' + str(i+1) + '/final',
                'test_file': 'test_data/perm' + str(per_num) + '/' + data_file_dict['valid'][i]
            }    

            test(test_config)

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraForTokenClassification: ['discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight']
- This IS expected if you are initializing ElectraForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForTokenClassification were not initialized from the model checkpoint at monologg/koelectra-base-v3-discriminator and are newly initialized: ['classifier

Epoch,Training Loss,Validation Loss,Org F1,Org Recall,Org Precision,Entity Macro F1,Entity Macro Precision,Entity Macro Recall
1,0.4922,0.306961,0.0,0.0,1.0,0.0,1.0,0.0
2,0.1546,0.339264,0.0,0.0,1.0,0.0,1.0,0.0


KeyboardInterrupt: 