In [30]:
import os
import logging
import shutil
import pickle
import random
import json
import pathlib
import numpy as np
import datetime
import torch
from numpy.lib.utils import source
import pandas as pd
import logging
from pathlib import Path
from collections import Counter
from sklearn import metrics
from sklearn.base import ClassifierMixin
from sklearn.model_selection import train_test_split
from flair.models.text_classification_model import TARSClassifier
from flair.data import Sentence, Corpus
from flair.datasets import SentenceDataset
from flair.trainers import ModelTrainer

import pandas as pd
import numpy as np
import torch
import pickle
from flair.embeddings import WordEmbeddings, FlairEmbeddings, SentenceTransformerDocumentEmbeddings, TransformerDocumentEmbeddings
from flair.data import Sentence


def _embed_sentences(embedding, sentences):
    """Embed every text with ``embedding`` and return the list of embedding tensors."""
    vectors = []
    for text in sentences:
        sentence = Sentence(text)
        embedding.embed(sentence)
        vectors.append(sentence.embedding)
    return vectors


def vectorize(sentences, embed_type='sentencetransformer'):
    """Turn an iterable of texts into a NumPy array of document embeddings.

    Args:
        sentences: iterable of raw text strings; each one is wrapped in a
            flair ``Sentence`` before embedding.
        embed_type: ``'sentencetransformer'`` (``bert-base-nli-mean-tokens``)
            or ``'roberta'`` (``roberta-base``).

    Returns:
        The stacked embeddings (one row per input text), converted with
        ``tensor_to_array``.

    Raises:
        Exception: if ``embed_type`` is not one of the supported names.
    """
    # TODO word / document embedding, check results
    if embed_type == 'sentencetransformer':
        embedding = SentenceTransformerDocumentEmbeddings('bert-base-nli-mean-tokens')
    elif embed_type == 'roberta':
        embedding = TransformerDocumentEmbeddings('roberta-base')
    else:
        raise Exception('unknown model - {}'.format(embed_type))

    try:
        vectors = _embed_sentences(embedding, sentences)
    except Exception:
        # Change the pretrained model if embedding fails. Record why, so the
        # switch to a different embedding space is visible in the logs.
        logging.getLogger().warning(
            'Embedding with %s failed, falling back to roberta-base',
            embed_type,
            exc_info=True,
        )
        embedding = TransformerDocumentEmbeddings('roberta-base')
        vectors = _embed_sentences(embedding, sentences)

    return tensor_to_array(torch.stack(vectors))


def tensor_to_array(vectors):
    """Return ``vectors`` as a NumPy array if it is a torch tensor, else unchanged.

    A tensor that requires grad is detached first, and a CUDA tensor is moved
    to the CPU, because ``Tensor.numpy()`` is only called on the result of
    those steps.
    """
    if not torch.is_tensor(vectors):
        return vectors

    tensor = vectors.detach() if vectors.requires_grad else vectors
    if tensor.is_cuda:
        tensor = tensor.cpu()
    return tensor.numpy()


def load(filepath):
    """Load a data file, picking the reader from the file extension.

    Supported extensions (matched case-insensitively):
        ``.csv``, ``.tsv``, ``.xls``/``.xlsx`` and ``.json`` (``orient='table'``)
        are read with pandas; ``.txt`` becomes a one-column DataFrame named
        ``text`` with one stripped line per row; ``.pkl`` is unpickled and
        returned as-is.

    Args:
        filepath: ``str`` or path-like object pointing at the file.

    Returns:
        A pandas DataFrame, or whatever object the ``.pkl`` file contains.

    Raises:
        ValueError: if the extension is not one of the supported ones.
    """
    path = os.fspath(filepath)
    lowered = path.lower()

    if lowered.endswith('.csv'):
        return pd.read_csv(path)
    if lowered.endswith('.tsv'):
        return pd.read_table(path)
    if lowered.endswith(('.xls', '.xlsx')):
        return pd.read_excel(path)
    if lowered.endswith('.txt'):
        with open(path, 'rt') as f:
            texts = [line.strip() for line in f]
        return pd.DataFrame(texts, columns=['text'])
    if lowered.endswith('.json'):
        with open(path, 'rt') as f:
            return pd.read_json(f, orient='table')
    if lowered.endswith('.pkl'):
        # Pickle data is binary; a text-mode handle cannot be unpickled.
        # NOTE: unpickling can execute arbitrary code, so only load trusted files.
        with open(path, 'rb') as f:
            return pickle.load(f)

    raise ValueError('unknown extension')


def save_output_file(filename, obj):
    """Pickle ``obj`` to ``<ACCUTUNING_WORKSPACE>/output/<filename>``.

    Returns the written file's path relative to the workspace, as a string.
    """
    # NOTE(review): a second ``save_output_file`` defined further down in this
    # file rebinds the name, so this definition is shadowed once the module has
    # finished loading.
    workspace = Path(os.environ.get('ACCUTUNING_WORKSPACE'))
    target = pathlib.Path(workspace, 'output', filename)
    target.write_bytes(pickle.dumps(obj))
    return str(target.relative_to(workspace))


# Set up logging with the library defaults and use the root logger
# (``getLogger()`` without a name) as the module-wide ``logger``.
logging.basicConfig()
logger = logging.getLogger()


def save_output_file(filename, obj):
    """Pickle ``obj`` to ``<ACCUTUNING_WORKSPACE>/output/<filename>`` and log it.

    Args:
        filename: file name (relative to the workspace ``output`` directory).
        obj: any picklable object.

    Returns:
        The written file's path relative to the workspace, as a string.

    Raises:
        Exception: if the ``ACCUTUNING_WORKSPACE`` environment variable is unset.
    """
    workspace_env = os.environ.get('ACCUTUNING_WORKSPACE')
    if workspace_env is None:
        # Path(None) would raise an opaque TypeError; name the real problem instead.
        raise Exception('no workspace env')

    workspace = Path(workspace_env)
    fp = pathlib.Path(workspace, 'output', filename)
    # The output directory is only created by the __main__ entry point, so
    # make sure it exists when this helper is called directly.
    fp.parent.mkdir(parents=True, exist_ok=True)
    fp.write_bytes(pickle.dumps(obj))
    logger.critical('Saved output file - {}'.format(filename))
    return str(fp.relative_to(workspace))


def open_output_file(filename, none_if_not_exist=False):
    """Unpickle and return ``<ACCUTUNING_WORKSPACE>/output/<filename>``.

    When the file is missing, returns ``None`` if ``none_if_not_exist`` is
    truthy and raises ``Exception`` otherwise.
    """
    workspace = Path(os.environ.get('ACCUTUNING_WORKSPACE'))
    fp = pathlib.Path(workspace, 'output', filename)

    if not fp.exists():
        if none_if_not_exist:
            return None
        raise Exception('{} does not exist'.format(fp))

    logger.critical('Open output file - {}'.format(filename))
    return pickle.loads(fp.read_bytes())


def zsl(
    data,
    target_column_nm,
    class_nm_list
):
    """Zero-shot label the texts in ``data[target_column_nm]`` and save the results.

    Steps: embed the texts with ``vectorize``, predict one label per text with
    the pre-trained TARS model (candidates are ``class_nm_list``), then fit a
    KNN classifier on the embeddings using the predicted labels as targets.
    Texts for which the model returns no label get the label ``'분류못함'``.

    Returns:
        dict mapping ``'vectors'``, ``'labels'``, ``'clusters'`` and
        ``'CLASSIFIER'`` to the workspace-relative paths of the pickles
        written through ``save_output_file``.
    """
    outputs = {}

    logger.critical('Start ZSL')
    started_at = datetime.datetime.now()

    texts = data[target_column_nm].values.tolist()

    # --- embeddings for every text ---------------------------------------
    logger.critical('ZSL - vectorizing texts')
    vectors = vectorize(texts)
    sentences = [Sentence(text) for text in texts]
    outputs['vectors'] = save_output_file('vectors.pkl', vectors)

    logger.critical('ZSL - loading pre-trained model')
    tars = TARSClassifier.load('./resources/taggers/agnews_all/final-model.pt')

    logger.critical('ZSL - predicting zero shot')
    predictions = []
    for sentence in sentences:
        tars.predict_zero_shot(sentence, class_nm_list)
        scored = sentence.to_dict()['labels']
        if scored:
            # max() keeps the first entry among ties, like list.index(max(...)).
            best = max(scored, key=lambda label: label['confidence'])
            predictions.append(best['value'])
        else:
            predictions.append('분류못함')

    outputs['labels'] = save_output_file('labels.pkl', predictions)
    outputs['clusters'] = save_output_file('clusters.pkl', list(set(predictions)))

    # --- classifier trained on the predicted labels ----------------------
    logger.critical('ZSL - building classifier')
    features = tensor_to_array(vectors)

    from sklearn.neighbors import KNeighborsClassifier
    knn = KNeighborsClassifier(
        n_neighbors=3,
        weights='distance',
        algorithm='auto',
        leaf_size=30,
        p=2,
        metric='euclidean',
        metric_params=None,
        n_jobs=None,
    )
    knn.fit(features, predictions)
    outputs['CLASSIFIER'] = save_output_file('classifier.pkl', knn)

    elapsed_time = datetime.datetime.now() - started_at
    logger.critical('ZSL - finished, Elapsed time {}'.format(elapsed_time))

    return outputs

def fsl(
    data,
    target_column_nm,
    samples,
    samples_target_column_nm,
    samples_tag_column_nm,
    related_stcs=None,
):
    """Few-shot label the texts in ``data[target_column_nm]`` and save the results.

    The pre-trained TARS model is fine-tuned on the labelled ``samples`` and
    then used to predict one label per text. A KNN classifier is fitted on the
    text embeddings with the predicted labels as targets. Texts for which the
    model returns no label get the label ``'분류못함'``.

    Args:
        data: table holding the texts to label.
        target_column_nm: column of ``data`` that contains the texts.
        samples: table of labelled examples used for fine-tuning.
        samples_target_column_nm: column of ``samples`` with the example texts.
        samples_tag_column_nm: column of ``samples`` with the example labels.
        related_stcs: not supported yet; must be ``None`` for labelling to run.

    Returns:
        dict mapping ``'vectors'``, ``'labels'``, ``'clusters'`` and
        ``'CLASSIFIER'`` to the workspace-relative paths of the pickles
        written through ``save_output_file``, or ``None`` when
        ``related_stcs`` is given (that path is still a TODO).

    Raises:
        ValueError: if ``samples`` is ``None``.
    """
    if samples is None:
        # Fail before the expensive embedding/model loading instead of with a
        # "NoneType is not subscriptable" error afterwards.
        raise ValueError('FSL requires labelled samples, got None')

    ret = {}
    logger.critical('Start FSL')

    #############################################################################
    # Load data and Vectorize
    logger.critical('FSL - vectorizing texts')
    texts = data[target_column_nm].values.tolist()
    vectors = vectorize(texts)
    sentences = [Sentence(text) for text in texts]
    ret['vectors'] = save_output_file(
        'vectors.pkl',
        vectors
    )

    logger.critical('FSL - loading pre-trained model')
    tars = TARSClassifier.load('resources/taggers/agnews_all/final-model.pt')  # Load model

    smpl_stcs = samples[samples_target_column_nm].values.tolist()
    smpl_tags = samples[samples_tag_column_nm].values.tolist()

    class_nm_list = list(set(smpl_tags))

    # TODO if related_stcs:
    if related_stcs is not None:
        # Previously this fell off the end of the function without a trace.
        logger.critical('FSL - related sentences are not supported yet, nothing was labelled')
        return None

    logger.critical('FSL - No related sentences, converting samples')
    # convert samples into labelled flair sentences
    tr = [
        Sentence(text).add_label('fsl', tag)
        for text, tag in zip(smpl_stcs, smpl_tags)
    ]

    train = SentenceDataset(tr)
    corpus = Corpus(train=train)

    # make the model aware of the desired set of labels from the new corpus
    tars.add_and_switch_to_new_task("fsl", label_dictionary=corpus.make_label_dictionary())
    # initialize the text classifier trainer with corpus
    trainer = ModelTrainer(tars, corpus)

    # train model
    logger.critical('FSL - Training samples')
    trainer.train(base_path='./resources/taggers/fsl',  # path to store the model artifacts
                  learning_rate=0.02,  # use very small learning rate
                  mini_batch_size=1,  # small mini-batch size since corpus is tiny
                  max_epochs=10,  # terminate after 10 epochs
                  train_with_dev=True,
                  )

    logger.critical('FSL - Loading the trained model')
    tars = TARSClassifier.load('./resources/taggers/fsl/final-model.pt')  # Load model

    logger.critical('FSL - Predicting using the trained model')
    predictions = []
    for sentence in sentences:
        tars.predict_zero_shot(sentence, class_nm_list)  # Predict zero-shot
        scored = sentence.to_dict()['labels']
        if scored:
            # keep the label with the highest confidence (first one on ties)
            best = max(scored, key=lambda label: label['confidence'])
            predictions.append(best['value'])
        else:
            predictions.append('분류못함')

    ret['labels'] = save_output_file(
        'labels.pkl',
        predictions,
    )

    ret['clusters'] = save_output_file(
        'clusters.pkl',
        list(set(predictions)),
    )

    #############################################################################
    # build classifier
    logger.critical('FSL - building classifier')
    X_train = tensor_to_array(vectors)
    y_train = predictions

    from sklearn.neighbors import KNeighborsClassifier
    classifier = KNeighborsClassifier(
        algorithm='auto', leaf_size=30, metric='euclidean',
        metric_params=None, n_jobs=None, n_neighbors=3, p=2,
        weights='distance')
    classifier.fit(X_train, y_train)
    ret['CLASSIFIER'] = save_output_file(
        'classifier.pkl',
        classifier,
    )

    return ret


def _run_lb_predict(conf, output_path, logger):
    """Handle the ``lb_predict`` worker: classify ``conf['texts']`` with a pickled classifier."""
    logger.critical('AccuTuning Labeler - lb_predict')
    classifier_fp = conf['classifier_fp']
    texts = conf['texts']
    bulk = conf['bulk']

    start_t = datetime.datetime.now()

    workspace_home = Path(os.environ.get('ACCUTUNING_WORKSPACE_ROOT', '/workspace'))

    # NOTE: unpickling can execute arbitrary code; the classifier file must be trusted.
    classifier = pickle.loads((workspace_home / classifier_fp).read_bytes())

    mid_t = datetime.datetime.now()

    pred_list = []
    for vector in vectorize(texts):
        # vectorize() yields 1-D rows, but the classifier needs a 2-D
        # (1, n_features) array, so reshape every row, tensor or not.
        row = np.asarray(tensor_to_array(vector)).reshape(1, -1)
        pred_list.append(classifier.predict(row))

    end_t = datetime.datetime.now()

    bulk_output_rel = ''
    if bulk:
        # Accept both str and Path values from the conf.
        bulk_output_fp = Path(conf['bulk_output_fp'])
        pd.DataFrame({'input': texts, 'output': pred_list}).to_csv(bulk_output_fp, index=False)
        bulk_output_rel = str(bulk_output_fp.relative_to(workspace_home))

    # timedelta.microseconds is only the sub-second component; floor-dividing
    # by one microsecond gives the whole elapsed time in microseconds.
    one_microsecond = datetime.timedelta(microseconds=1)

    # NOTE(review): the __main__ block of this file passes the ``output``
    # directory as ``output_path``; write_bytes needs a file path, so confirm
    # which file this result is meant to land in.
    output_path.write_bytes(
        pickle.dumps(
            dict(
                pred=pred_list[0],
                total_duration=(end_t - start_t) // one_microsecond,
                pred_duration=(end_t - mid_t) // one_microsecond,
                bulk_output_fp=bulk_output_rel,
            ),
        )
    )


def main(confpath, output_path, logger):
    """Run the labeler worker described by the pickled conf at ``confpath``.

    ``conf['labeler_worker_type']`` selects the job:
        * ``'zsl'`` / ``'fsl'``: load the data, run ``zsl`` / ``fsl`` and write
          the returned mapping to ``output_path / 'output.json'``.
        * ``'lb_predict'``: classify ``conf['texts']`` with a previously saved
          classifier and pickle the result to ``output_path``.

    Args:
        confpath: path of the pickled configuration dict.
        output_path: ``pathlib.Path`` used as described above.
        logger: logger used for progress messages.

    Raises:
        Exception: if the conf file is missing, the worker type is unknown, or
            the selected job fails (the original error is attached as the cause).
    """
    if not os.path.isfile(confpath):
        raise Exception('no conf file')
    with open(confpath, 'rb') as f:
        # NOTE: unpickling can execute arbitrary code; conf.pkl must be trusted.
        conf = pickle.load(f)

    logger.critical(
        {k: conf[k] for k in sorted(conf)}
    )

    worker_type = conf['labeler_worker_type']

    try:
        if worker_type == 'zsl':
            target_df = load(conf['file'])
            d = zsl(
                target_df,
                conf['target_column_nm'],
                conf.get('class_nm_list', None),
            )
            (output_path / 'output.json').write_text(
                json.dumps(d)
            )

        elif worker_type == 'fsl':
            target_column_nm = conf['target_column_nm']
            samples_fp = conf['samples_fp']
            samples_target_column_nm = conf.get('samples_target_column_nm', target_column_nm)
            samples_tag_column_nm = conf['samples_tag_column_nm']
            related_stcs = conf.get('related_stcs', None)

            target_df = load(conf['file'])
            sample_df = load(samples_fp) if samples_fp else None
            d = fsl(
                target_df,
                target_column_nm,
                sample_df,
                samples_target_column_nm,
                samples_tag_column_nm,
                related_stcs,
            )
            (output_path / 'output.json').write_text(
                json.dumps(d)
            )

        elif worker_type == "lb_predict":
            _run_lb_predict(conf, output_path, logger)

        else:
            raise Exception('unknown worker_type:{}'.format(worker_type))

    except Exception as e:
        # Callers get a plain Exception carrying the message; chain the
        # original so its type and traceback are not lost.
        raise Exception(str(e)) from e


def evaluate(target_df, pred):
    """Print classification reports for ``pred`` against ``target_df['tags']``.

    Two reports are printed: one over all predictions, and one that leaves out
    every item predicted as ``'분류못함'`` (the label assigned elsewhere in this
    file when the model returns no label).

    Args:
        target_df: table with a ``tags`` column holding the reference labels.
        pred: predicted labels, in the same order as the rows of ``target_df``.
    """
    tags = target_df['tags'].values.tolist()

    # Counter returns 0 for a missing key; converting it to a dict first made
    # this lookup raise KeyError whenever nothing was left unclassified.
    unclassified = Counter(pred)['분류못함']

    print('분류 못함 ', unclassified, '건 포함 metrics')
    print(metrics.classification_report(tags, pred))
    print('')

    kept = [(tag, p) for tag, p in zip(tags, pred) if p != '분류못함']
    tags_expt = [tag for tag, _ in kept]
    pred_expt = [p for _, p in kept]

    print('분류 못함 ', unclassified, '건 제외 metrics')
    print(metrics.classification_report(tags_expt, pred_expt))


if __name__ == "__main__":
    # Script entry point: prepares the workspace, runs main() and reports the
    # outcome through marker files in <workspace>/flag
    # (STARTED, then DONE or ERROR, and always FINISHED).
    workspace = os.environ.get('ACCUTUNING_WORKSPACE')
    if workspace is None:
        # Every path below is built from the workspace, so check it before
        # the first os.path.join instead of after it.
        raise Exception('no workspace env')

    flagpath = os.path.join(workspace, 'flag')
    if os.path.isdir(flagpath):
        shutil.rmtree(flagpath)
    os.makedirs(flagpath)
    Path(flagpath, 'STARTED').write_text('')

    ###############################################################
    # build directories
    intermediate_path = Path(workspace, 'intermediate')
    output_path = Path(workspace, 'output')
    intermediate_path.mkdir(exist_ok=True)
    output_path.mkdir(exist_ok=True)

    confpath = os.path.join(workspace, 'conf.pkl')

    # Mirror the log messages into <workspace>/labeling_fsl.log.
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(name)s - %(message)s')
    ch = logging.FileHandler(
        os.path.join(
            workspace,
            'labeling_fsl.log',
        ),
        mode='w'
    )
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    try:
        main(confpath, output_path, logger=logger)
        Path(flagpath, 'DONE').write_text('')
    except Exception as e:
        logger.critical(e, exc_info=True)
        Path(flagpath, 'ERROR').write_text(str(e))
    finally:
        Path(flagpath, 'FINISHED').write_text('')


CRITICAL:root:no conf file
Traceback (most recent call last):
  File "/tmp/ipykernel_539/1674242798.py", line 469, in <module>
    main(confpath, output_path, logger=logger)
  File "/tmp/ipykernel_539/1674242798.py", line 322, in main
    raise Exception('no conf file')
Exception: no conf file


## ZSL

In [31]:
# Inputs for the ZSL run: the table to label, the name of its text column,
# and the candidate class names handed to zsl().
data = pd.read_csv('sources/nnst_lt_1990.csv')
target = 'stcs'
cols = ['생활', '기술', '기타']

In [32]:
zsl(data, target, cols)

CRITICAL:root:Start ZSL
CRITICAL:root:ZSL - vectorizing texts
CRITICAL:root:Saved output file - vectors.pkl
CRITICAL:root:ZSL - loading pre-trained model


2021-07-21 01:37:10,284 loading file ./resources/taggers/agnews_all/final-model.pt


CRITICAL:root:ZSL - predicting zero shot


init TARS


CRITICAL:root:Saved output file - labels.pkl
CRITICAL:root:Saved output file - clusters.pkl
CRITICAL:root:ZSL - building classifier
CRITICAL:root:Saved output file - classifier.pkl
CRITICAL:root:ZSL - finished, Elapsed time 0:04:35.135375


{'vectors': 'output/vectors.pkl',
 'labels': 'output/labels.pkl',
 'clusters': 'output/clusters.pkl',
 'CLASSIFIER': 'output/classifier.pkl'}

In [35]:
# Reload the pickles written by the run above; `with` closes each file handle.
with open('output/vectors.pkl', 'rb') as f:
    vectors = pickle.load(f)
with open('output/labels.pkl', 'rb') as f:
    labels = pickle.load(f)
with open('output/clusters.pkl', 'rb') as f:
    clusters = pickle.load(f)

In [19]:
vectors

array([[-0.99578553, -0.00535922,  0.8316863 , ...,  0.19431722,
         0.11335345, -0.61855346],
       [-0.6068382 ,  0.12818483,  0.71900487, ...,  0.32662803,
         0.03273866, -0.59235084],
       [-0.80009407, -0.26659516,  0.8736665 , ...,  0.36003995,
         0.21039687, -0.6081318 ],
       ...,
       [-0.67031485, -0.04569531,  0.9157012 , ..., -0.03234894,
         0.3023741 , -0.47822633],
       [-0.59436864, -0.18450913,  1.0091789 , ...,  0.07618729,
         0.09041867, -0.26631007],
       [-0.6379605 , -0.15387407,  0.83737415, ..., -0.12898003,
        -0.0330771 , -0.38149998]], dtype=float32)

In [20]:
labels

['생활',
 '생활',
 '분류못함',
 '생활',
 '분류못함',
 '기술',
 '생활',
 '분류못함',
 '생활',
 '생활',
 '분류못함',
 '생활',
 '분류못함',
 '분류못함',
 '생활',
 '생활',
 '분류못함',
 '생활',
 '분류못함',
 '생활',
 '분류못함',
 '분류못함',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '분류못함',
 '분류못함',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '분류못함',
 '생활',
 '생활',
 '생활',
 '분류못함',
 '분류못함',
 '생활',
 '생활',
 '생활',
 '분류못함',
 '분류못함',
 '생활',
 '생활',
 '기술',
 '분류못함',
 '분류못함',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '분류못함',
 '생활',
 '생활',
 '생활',
 '분류못함',
 '생활',
 '생활',
 '분류못함',
 '생활',
 '생활',
 '분류못함',
 '기술',
 '생활',
 '분류못함',
 '생활',
 '분류못함',
 '생활',
 '분류못함',
 '생활',
 '생활',
 '분류못함',
 '생활',
 '생활',
 '생활',
 '분류못함',
 '분류못함',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '분류못함',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '분류못함',
 '생활',
 '분류못함',
 '생활',
 '분류못함',
 '생활',
 '생활',
 '분류못함',
 '분류못함',
 '생활',
 '생활',
 '생활',
 '기술',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '분류못함',
 '생활',
 '기술',
 '기술',
 '분류못함',
 '분류못함',
 '분류못함',
 '생활',
 '분류못함',
 '생활',
 '생활',
 '생

In [21]:
clusters

['분류못함', '생활', '기술', '기타']

In [36]:
from sklearn.metrics import classification_report

# Compare the `tags` column of `data` (presumably the reference labels — confirm)
# with the predicted `labels`.
print(classification_report(data['tags'], labels))

              precision    recall  f1-score   support

          기술       0.79      0.21      0.33       995
          기타       0.00      0.00      0.00         0
        분류못함       0.00      0.00      0.00         0
          생활       0.78      0.63      0.70       995

    accuracy                           0.42      1990
   macro avg       0.39      0.21      0.26      1990
weighted avg       0.78      0.42      0.52      1990



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


## FSL

In [23]:
# Labelled examples for the FSL run and the names of their text / label columns.
samples = pd.read_csv('sources/nnst_lt_10.csv')
samples_target_column_nm = 'stcs'
samples_tag_column_nm = 'tags'

# `data` and `target` are the ones defined in the ZSL section of this notebook.
fsl(data, target, samples, samples_target_column_nm, samples_tag_column_nm)

CRITICAL:root:Start FSL
CRITICAL:root:FSL - vectorizing texts
CRITICAL:root:Saved output file - vectors.pkl
CRITICAL:root:ZSL - loading pre-trained model


2021-07-21 01:20:26,643 loading file resources/taggers/agnews_all/final-model.pt


CRITICAL:root:FSL - No related sentences, converting samples


init TARS
2021-07-21 01:20:35,711 Computing label dictionary. Progress:


100%|██████████| 9/9 [00:00<00:00, 7368.48it/s]

2021-07-21 01:20:35,716 [b'\xec\x83\x9d\xed\x99\x9c', b'\xea\xb8\xb0\xec\x88\xa0']



CRITICAL:root:FSL - Training samples


2021-07-21 01:20:35,719 ----------------------------------------------------------------------------------------------------
2021-07-21 01:20:35,723 Model: "TARSClassifier(
  (document_embeddings): None
  (decoder): None
  (loss_function): None
  (tars_model): TextClassifier(
    (document_embeddings): TransformerDocumentEmbeddings(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(119547, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in

CRITICAL:root:FSL - Loading the trained model


2021-07-21 01:21:22,720 loading file ./resources/taggers/fsl/final-model.pt


CRITICAL:root:FSL - Predicting using the trained model


init TARS


CRITICAL:root:Saved output file - labels.pkl
CRITICAL:root:Saved output file - clusters.pkl
CRITICAL:root:FSL - building classifier
CRITICAL:root:Saved output file - classifier.pkl


{'vectors': 'output/vectors.pkl',
 'labels': 'output/labels.pkl',
 'clusters': 'output/clusters.pkl',
 'CLASSIFIER': 'output/classifier.pkl'}

In [24]:
# Reload the pickles written by the run above; `with` closes each file handle.
with open('output/vectors.pkl', 'rb') as f:
    vectors = pickle.load(f)
with open('output/labels.pkl', 'rb') as f:
    labels = pickle.load(f)
with open('output/clusters.pkl', 'rb') as f:
    clusters = pickle.load(f)

In [25]:
vectors.shape

(1990, 768)

In [26]:
labels

['생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '기술',
 '생활',
 '기술',
 '기술',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '기술',
 '분류못함',
 '기술',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '기술',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '기술',
 '생활',
 '기술',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활',
 '생활',
 '생활',
 '기술',
 '생활',
 '생활

In [27]:
clusters

['기술', '생활', '분류못함']

In [29]:
print(classification_report(data['tags'], labels))

              precision    recall  f1-score   support

          기술       0.80      0.74      0.77       995
        분류못함       0.00      0.00      0.00         0
          생활       0.77      0.81      0.79       995

    accuracy                           0.77      1990
   macro avg       0.52      0.52      0.52      1990
weighted avg       0.78      0.77      0.78      1990



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
