## Load

In [None]:
!pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html

In [None]:
!pip install pyrouge --upgrade
!pip install https://github.com/bheinzerling/pyrouge/archive/master.zip
!pip install pyrouge
!pip show pyrouge
!git clone https://github.com/andersjo/pyrouge.git
from pyrouge import Rouge155
!pyrouge_set_rouge_path 'pyrouge/tools/ROUGE-1.5.5'

In [None]:
!pip install transformers
!pip install tensorboardX
!pip install easydict

## Code

In [None]:
import os

# Switch the working directory to the KorBertSum sources. The path is relative
# to the current working directory, so re-running this cell after it has
# already succeeded will fail unless a nested KorBertSum/src exists.
# Presumably this is what lets the `models` / `others` imports and the
# '../...' paths used later resolve -- verify against the repository layout.
os.chdir('KorBertSum/src')

In [None]:
import torch
import numpy as np
from models import data_loader, model_builder
from models.model_builder import Summarizer
from others.logging import logger, init_logger
from models.data_loader import load_dataset
from transformers import BertConfig, BertTokenizer
from tensorboardX import SummaryWriter
from models.reporter import ReportMgr
from models.stats import Statistics
import easydict
from multiprocessing.dummy import Pool as ThreadPool

In [None]:
import multiprocessing

# Number of CPUs reported by the operating system.
NUM_CORES = multiprocessing.cpu_count()

In [None]:
print(NUM_CORES)  # author's note: 32 (presumably the value on the original machine)

## Preprocessing

In [None]:
def _tally_parameters(model):
    n_params = sum([p.nelement() for p in model.parameters()])
    return n_params

def build_trainer(args, device_id, model,
                  optim):
    """
    Build a `Trainer` together with its TensorBoard-backed report manager.

    Args:
        args: option namespace. This function reads `accum_count`,
            `world_size`, `gpu_ranks`, `model_path` and `report_every`;
            the `Trainer` constructor reads further fields.
        device_id (int): index into `args.gpu_ranks`. Any negative value
            selects the CPU path (rank 0, zero GPUs).
        model: model handed to the `Trainer`. May be None, in which case the
            parameter count is not logged.
        optim: optimizer handed to the `Trainer`; passed through untouched
            (None is accepted here).

    Returns:
        Trainer: the configured trainer.
    """
    grad_accum_count = args.accum_count
    n_gpu = args.world_size

    if device_id >= 0:
        gpu_rank = int(args.gpu_ranks[device_id])
    else:
        # CPU path: no GPUs, single rank.
        gpu_rank = 0
        n_gpu = 0

    print('gpu_rank %d' % gpu_rank)

    # TensorBoard event files go to the same directory as the checkpoints.
    # NOTE: a new SummaryWriter is created on every call and is not closed
    # here; callers that build many trainers accumulate open writers.
    tensorboard_log_dir = args.model_path
    writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")

    report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer)

    trainer = Trainer(args, model, optim, grad_accum_count, n_gpu, gpu_rank, report_manager)

    # Compare against None explicitly: container-style modules define
    # __len__, so an empty one would be falsy and silently skipped.
    if model is not None:
        n_params = _tally_parameters(model)
        logger.info('* number of parameters: %d' % n_params)

    return trainer

class Trainer(object):
    """
    Drives gradient accumulation, checkpointing, reporting and
    extractive-summary inference for a sentence-scoring model.

    Args:
        args: option namespace. `save_checkpoint_steps` is read at
            construction time and `model_path` when a checkpoint is saved.
        model: the model to train / run. It is called as
            `model(src, segs, clss, mask, mask_cls)` and must return
            `(sent_scores, mask)`. May be None.
        optim: optimizer; only `optim.step()` is called by this class.
        grad_accum_count (int): number of batches whose gradients are
            accumulated before one optimizer step. Must be > 0.
        n_gpu (int): number of GPUs. Values > 1 enable the multi-GPU
            gradient / statistics gathering paths.
        gpu_rank (int): rank of this process (stored, not used in this class).
        report_manager: object providing `start()`, `report_training()` and
            `report_step()`, or None to disable reporting.
    """

    def __init__(self,  args, model,  optim,
                  grad_accum_count=1, n_gpu=1, gpu_rank=1,
                  report_manager=None):
        # Basic attributes.
        self.args = args
        self.save_checkpoint_steps = args.save_checkpoint_steps
        self.model = model
        self.optim = optim
        self.grad_accum_count = grad_accum_count
        self.n_gpu = n_gpu
        self.gpu_rank = gpu_rank
        self.report_manager = report_manager

        # Element-wise loss; masking and reduction happen in
        # `_gradient_accumulation`.
        self.loss = torch.nn.BCELoss(reduction='none')
        assert grad_accum_count > 0
        # Set model in training mode.
        if model is not None:
            self.model.train()

    def summ(self, test_iter, step, cal_lead=False, cal_oracle=False):
        """Rank the sentences of the test data for extractive summarization.

        Args:
            test_iter: iterable of batches exposing `src`, `labels`, `segs`,
                `clss`, `mask`, `mask_cls` and `batch_size`.
            step: unused; kept for interface compatibility.
            cal_lead (bool): if True, skip the model and select sentences in
                document order.
            cal_oracle (bool): if True (and `cal_lead` is False), skip the
                model and select the positions whose label equals 1.

        Returns:
            The selection computed for the LAST batch yielded by `test_iter`
            (earlier batches are overwritten), or an empty list when the
            iterator yields nothing. In model mode this is a 2-D numpy array
            of sentence indices sorted by descending score; in lead / oracle
            mode it is a list of index lists. No trigram blocking is applied.
        """
        # Set model in validating mode.
        if not cal_lead and not cal_oracle:
            self.model.eval()

        # Defined up-front so that an empty iterator returns a value instead
        # of raising UnboundLocalError.
        selected_ids = []

        with torch.no_grad():
            for batch in test_iter:
                src = batch.src
                labels = batch.labels
                segs = batch.segs
                clss = batch.clss
                mask = batch.mask
                mask_cls = batch.mask_cls

                if cal_lead:
                    selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                elif cal_oracle:
                    selected_ids = [[j for j in range(batch.clss.size(1)) if labels[i][j] == 1] for i in
                                    range(batch.batch_size)]
                else:
                    sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)
                    # NOTE(review): adding the returned mask presumably lifts
                    # real sentences above padded positions -- confirm
                    # against the model's forward().
                    sent_scores = sent_scores + mask.float()
                    sent_scores = sent_scores.cpu().data.numpy()
                    # Negate so that argsort yields descending score order.
                    selected_ids = np.argsort(-sent_scores, 1)
        return selected_ids

    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        """Run forward/backward over `true_batchs` and step the optimizer.

        With `grad_accum_count == 1` the gradients are zeroed and the
        optimizer stepped for every batch; with `grad_accum_count > 1` the
        gradients are zeroed once, accumulated over all batches, and the
        optimizer is stepped once at the end.

        Args:
            true_batchs: iterable of batches (same attributes as in `summ`).
            normalization: passed through to `Statistics` for each batch.
            total_stats: statistics object updated in place per batch.
            report_stats: statistics object updated in place per batch.
        """
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            if self.grad_accum_count == 1:
                self.model.zero_grad()

            src = batch.src
            labels = batch.labels
            segs = batch.segs
            clss = batch.clss
            mask = batch.mask
            mask_cls = batch.mask_cls

            sent_scores, mask = self.model(src, segs, clss, mask, mask_cls)

            loss = self.loss(sent_scores, labels.float())
            # Zero out the loss at masked positions, then reduce to a scalar.
            loss = (loss*mask.float()).sum()
            # `loss` is already a scalar here, so numel() is 1 and the
            # division leaves it unchanged; kept as in the original.
            (loss/loss.numel()).backward()

            batch_stats = Statistics(float(loss.cpu().data.numpy()), normalization)

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # Update the parameters after every batch when not accumulating.
            if self.grad_accum_count == 1:
                self._sync_grads_and_step()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            self._sync_grads_and_step()

    def _sync_grads_and_step(self):
        """Gather gradients across GPUs (when n_gpu > 1) and step the optimizer."""
        if self.n_gpu > 1:
            grads = [p.grad.data for p in self.model.parameters()
                     if p.requires_grad
                     and p.grad is not None]
            # NOTE(review): `distributed` is not imported anywhere in the
            # visible code, so this path raises NameError unless it is
            # provided elsewhere -- confirm before using n_gpu > 1.
            distributed.all_reduce_and_rescale_tensors(
                grads, float(1))
        self.optim.step()

    def _save(self, step):
        """Write a checkpoint for `step` unless the target file already exists.

        The checkpoint holds the model state dict, the option namespace and
        the optimizer object itself.

        Returns:
            `(checkpoint, checkpoint_path)` when a file was written, otherwise
            None (the file already existed and nothing was saved).
        """
        real_model = self.model

        model_state_dict = real_model.state_dict()
        checkpoint = {
            'model': model_state_dict,
            'opt': self.args,
            'optim': self.optim,
        }
        checkpoint_path = os.path.join(self.args.model_path, 'model_step_%d.pt' % step)
        logger.info("Saving checkpoint %s" % checkpoint_path)
        if not os.path.exists(checkpoint_path):
            torch.save(checkpoint, checkpoint_path)
            return checkpoint, checkpoint_path
        return None

    def _start_report_manager(self, start_time=None):
        """
        Simple function to start report manager (if any)
        """
        if self.report_manager is not None:
            if start_time is None:
                self.report_manager.start()
            else:
                self.report_manager.start_time = start_time

    def _maybe_gather_stats(self, stat):
        """
        Gather statistics in multi-processes cases

        Args:
            stat: a Statistics object to gather, or None (None is returned
                unchanged)

        Returns:
            stat: the updated (or unchanged) stat object
        """
        if stat is not None and self.n_gpu > 1:
            return Statistics.all_gather_stats(stat)
        return stat

    def _maybe_report_training(self, step, num_steps, learning_rate,
                               report_stats):
        """
        Forward training stats to the report manager (if one is set) and
        return whatever its `report_training` returns; otherwise return None.
        """
        if self.report_manager is not None:
            return self.report_manager.report_training(
                step, num_steps, learning_rate, report_stats,
                multigpu=self.n_gpu > 1)

    def _report_step(self, learning_rate, step, train_stats=None,
                     valid_stats=None):
        """
        Forward per-step stats to the report manager (if one is set) and
        return whatever its `report_step` returns; otherwise return None.
        """
        if self.report_manager is not None:
            return self.report_manager.report_step(
                learning_rate, step, train_stats=train_stats,
                valid_stats=valid_stats)

    def _maybe_save(self, step):
        """
        Save the model if a model saver is set
        """
        # The constructor never assigns `model_saver`, so treat a missing
        # attribute the same as "no saver" instead of raising AttributeError.
        model_saver = getattr(self, 'model_saver', None)
        if model_saver is not None:
            model_saver.maybe_save(step)

class BertData():
    """Converts a document (a sequence of sentences) into BERT input features."""

    def __init__(self):
        # Load the multilingual cased tokenizer and remember the ids of the
        # special tokens used below.
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
        vocab = self.tokenizer.vocab
        self.sep_vid = vocab['[SEP]']
        self.cls_vid = vocab['[CLS]']
        self.pad_vid = vocab['[PAD]']

    def preprocess(self, src):
        """Build token ids, segment ids and [CLS] positions for ``src``.

        Each element of ``src`` is joined with single spaces to form one
        sentence string. Elements of length <= 1 are dropped, each remaining
        element is cut to 2000 items and at most 1000 elements are kept.

        Returns:
            None when ``src`` is empty or fewer than 3 elements survive the
            filtering; otherwise the tuple
            ``(token_ids, None, segment_ids, cls_positions, src_txt, None)``
            where the two None slots are the (absent) labels and target text.
        """
        if not len(src):
            return None

        joined_all = [' '.join(sentence) for sentence in src]
        kept_positions = [pos for pos, sentence in enumerate(src) if len(sentence) > 1]

        # Per-sentence cap first, then the cap on the number of sentences.
        trimmed = [src[pos][:2000] for pos in kept_positions][:1000]
        if len(trimmed) < 3:
            return None

        text = ' [SEP] [CLS] '.join([' '.join(sentence) for sentence in trimmed])
        # Keep at most 510 sub-tokens so that, with the surrounding [CLS] and
        # [SEP], the sequence has at most 512 tokens.
        subtokens = ['[CLS]'] + self.tokenizer.tokenize(text)[:510] + ['[SEP]']
        token_ids = self.tokenizer.convert_tokens_to_ids(subtokens)

        # Segment lengths are the gaps between consecutive [SEP] positions;
        # segment ids alternate 0, 1, 0, 1, ...
        sep_positions = [-1] + [pos for pos, tok in enumerate(token_ids) if tok == self.sep_vid]
        segments_ids = []
        for seg_no, (prev, cur) in enumerate(zip(sep_positions, sep_positions[1:])):
            segments_ids.extend((cur - prev) * [seg_no % 2])

        cls_ids = [pos for pos, tok in enumerate(token_ids) if tok == self.cls_vid]
        src_txt = [joined_all[pos] for pos in kept_positions]
        return token_ids, None, segments_ids, cls_ids, src_txt, None
    
def _lazy_dataset_loader(pt_file):
  yield  pt_file

## Params

In [None]:
# Run-time options shared by the model, the data loader and the trainer.
# An EasyDict is used so the rest of the notebook can read the entries as
# attributes (args.batch_size, args.model_path, ...).
# All '../...' paths are relative to the current working directory.
args = easydict.EasyDict({
    "encoder":'classifier',
    "mode":'test',
    "bert_data_path":'../bert_data/korean',
    "model_path":'../models/bert_classifier',
    "result_path":'../results',
    "temp_dir":'../temp',
    "batch_size":1000,
    "use_interval":True,
    "hidden_size":128,
    "ff_size":512,
    "heads":4,
    "inter_layers":2,
    "rnn_size":512,
    "param_init":0,
    "param_init_glorot":True,
    "dropout":0.1,
    "optim":'adam',
    "lr":2e-3,
    "report_every":1,
    "save_checkpoint_steps":5,
    "block_trigram":True,
    "recall_eval":False,
    
    "accum_count":1,
    "world_size":1,
    "visible_gpus":'-1', # cpu
    "gpu_ranks":'0',
    "log_file":'../logs/train.txt',
    "test_from":'../models/bert_classifier/model_step_25000.pt'
})
# Option names that the checkpoint-loading cell copies from the checkpoint's
# saved options onto `args`, overriding the values above.
# Note: 'ff_actv' has no entry in `args`; it is only set if the checkpoint
# provides it.
model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers','encoder','ff_actv', 'use_interval','rnn_size']

## Test code

In [None]:
# Load the trained summarizer once; the resulting module-level names
# (`device`, `device_id`, `step`, `model`) are used by `test()` below.
pt = ''
step = None

init_logger(args.log_file)
device = "cpu" if args.visible_gpus == '-1' else "cuda"
device_id = 0 if device == "cuda" else -1

cp = args.test_from
try:
    # e.g. '.../model_step_25000.pt' -> 25000
    step = int(cp.split('.')[-2].split('_')[-1])
except (IndexError, ValueError):
    # The file name does not follow the '<name>_<step>.<ext>' pattern.
    step = 0

# An explicit `pt` path takes precedence over args.test_from.
test_from = pt if pt != '' else args.test_from
logger.info('Loading checkpoint from %s' % test_from)
# NOTE: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# The map_location callable keeps every tensor on the CPU.
checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)

# Architecture options stored in the checkpoint override the ones in `args`
# so that the model is rebuilt with the shapes it was trained with.
opt = vars(checkpoint['opt'])
for k in opt.keys():
    if k in model_flags:
        setattr(args, k, opt[k])

config = BertConfig.from_pretrained('bert-base-multilingual-cased')
model = Summarizer(args, device, load_pretrained_bert=False, bert_config = config)
model.load_cp(checkpoint)
model.eval()

In [None]:
# def test(args, input_list, device_id, pt, step):
def test(args, input_list):
  """Run the summarizer over ``input_list`` and return its sentence ranking.

  Uses the module-level ``device``, ``device_id``, ``model`` and ``step``.
  A fresh trainer is built on every call.

  Returns:
      tuple: ``(result, input_list)`` where ``result`` is whatever
      ``Trainer.summ`` returns and ``input_list`` is passed back unchanged.
  """
  dataset_iter = _lazy_dataset_loader(input_list)
  test_iter = data_loader.Dataloader(
      args, dataset_iter, args.batch_size, device,
      shuffle=False, is_test=True)
  summarizer = build_trainer(args, device_id, model, None)
  return summarizer.summ(test_iter, step), input_list

# args.gpu_ranks = [int(i) for i in args.gpu_ranks.split(',')]
# os.environ["CUDA_VISIBLE_DEVICES"] = args.visible_gpus

In [None]:
def txt2input(text):
    """Convert a newline-separated article into the summarizer's input list.

    Args:
        text (str): article text with one sentence per line; empty lines are
            ignored.

    Returns:
        list: a single-element list holding one example dict with the keys
        ``src``, ``labels``, ``segs``, ``clss``, ``src_txt`` and ``tgt_txt``.
        ``labels`` is a fixed placeholder (``[0, 1, 2]``) and ``tgt_txt`` is
        None, since no reference summary exists at inference time.

    Raises:
        ValueError: if ``BertData.preprocess`` rejects the text (it returns
            None when fewer than 3 usable sentences remain).
    """
    sentences = list(filter(None, text.split('\n')))

    # Building BertData loads the tokenizer, which is expensive; create it
    # once and keep it on the function object for later calls.
    bertdata = getattr(txt2input, "_bertdata", None)
    if bertdata is None:
        bertdata = txt2input._bertdata = BertData()

    txt_data = bertdata.preprocess(sentences)
    if txt_data is None:
        # Fail with a clear message instead of "'NoneType' object is not
        # subscriptable".
        raise ValueError(
            "text is too short to summarize: preprocessing needs at least "
            "3 sentences longer than one character"
        )

    src_ids, _labels, segs, clss, src_txt, _tgt = txt_data
    return [{
        "src": src_ids,
        "labels": [0, 1, 2],
        "segs": segs,
        "clss": clss,
        "src_txt": src_txt,
        "tgt_txt": None,
    }]

## Load data

In [None]:
import pandas as pd

# Load the article table (path is relative to the current working directory).
# Later cells read its 'content' column and drop rows by index label.
df_covid = pd.read_csv("../한겨레_코로나_content.csv")
df_covid.head()

In [None]:
from tqdm import tqdm

# Pair every row's index label with its article body. The labels are taken
# from the DataFrame itself rather than from a hard-coded row count, so the
# pairs stay aligned (and complete) whatever the size of the CSV, and the
# labels remain valid for the later `df_covid.drop(idx)` calls.
idxs = df_covid.index.tolist()
texts = df_covid['content'].tolist()
idx_text = list(zip(idxs, texts))

## Data processing

In [None]:
from tqdm import tqdm
import re

# Compiled once instead of on every line: matches e-mail addresses.
_EMAIL_RE = re.compile(r"([\w\.-]+)@([\w\.-]+)(\.[\w\.]+)")
_REPORTER_TOKEN = "기자"
# Lines with this many characters or fewer are discarded.
_MIN_LINE_LEN = 50


def _strip_reporter_bylines(line):
    """Remove every stand-alone "기자" token together with the token before it."""
    tokens = line.split(" ")
    while _REPORTER_TOKEN in tokens:
        pos = tokens.index(_REPORTER_TOKEN)
        del tokens[pos]
        # Only remove the preceding token (the reporter's name) when there is
        # one; at position 0 an index of -1 would delete the LAST token.
        if pos > 0:
            del tokens[pos - 1]
    return ' '.join(tokens)


def _clean_article(content):
    """Return ``content`` with short lines, e-mail addresses and bylines removed.

    Every kept line is terminated with a newline.
    """
    kept_lines = []
    for line in content.split('\n'):
        if len(line) <= _MIN_LINE_LEN:
            continue
        line = _EMAIL_RE.sub("", line)  # e-mail check
        if " 기자" in line:  # reporter check
            line = _strip_reporter_bylines(line)
        kept_lines.append(line + '\n')
    return ''.join(kept_lines)


# Replace each (index, text) pair with a mutable [index, cleaned_text] list.
# Non-string contents (e.g. missing values) are left untouched.
for i, (idx, content) in enumerate(tqdm(idx_text)):
    if isinstance(content, str) and content != "":
        content = _clean_article(content)
    idx_text[i] = [idx, content]

## Test

In [None]:
import threading
import time

# `inference` runs on many pool threads at once; this lock serialises the
# read-modify-write of the shared `df_covid` frame so no drop is lost.
_DF_COVID_LOCK = threading.Lock()


def inference(texts):
    """Summarize one article and record the result.

    Appends ``{"index": idx, "summary": ...}`` to the module-level
    ``new_list`` and drops row ``idx`` from the module-level ``df_covid``.

    Args:
        texts: an ``(idx, text)`` pair; ``idx`` is a ``df_covid`` index label
            and ``text`` the newline-separated article (may be "" or NaN).

    Returns:
        str: always the empty string.
    """
    global df_covid
    # NOTE(review): fixed one-second delay per article kept from the original;
    # its purpose is not evident from this file.
    time.sleep(1)
    idx, text = texts[0], texts[1]

    if text == "" or pd.isna(text):
        summary = ""
    else:
        # Same filtering that BertData.preprocess applies (drops empty and
        # one-character lines), so the indices returned by the model refer to
        # this exact list.
        sentences = [s for s in text.split('\n') if len(s) > 1]
        if len(sentences) < 3:
            # The article is too short for preprocess (it would return None);
            # use the text itself as the summary.
            summary = text
        else:
            input_data = txt2input(text)
            ranked, _ = test(args, input_data)
            top_sentences = [sentences[i] for i in ranked[0][:2]]
            # join() also copes with fewer than two ranked sentences.
            summary = " ".join(top_sentences)

    new_list.append(
        {
            "index" : idx,
            "summary" : summary
        }
    )
    with _DF_COVID_LOCK:
        df_covid = df_covid.drop(idx, axis=0)
    return ''

In [None]:
# Thread-backed pool (multiprocessing.dummy) with 100 worker threads.
# NOTE(review): NUM_CORES is computed earlier but not used to size this pool.
pool = ThreadPool(100)

In [None]:
new_list = []
batch = idx_text[0:1000]

try:
    # Drain the iterator; `inference` records its results in `new_list`.
    for _ in tqdm(pool.imap_unordered(inference, batch), total=len(batch)):
        pass
except RuntimeError as err:
    # Stop the workers, then fall through to saving what was produced so far.
    print("inference aborted with RuntimeError: %s" % err)
    pool.close()
    pool.terminate()
    pool.join()
finally:
    # Persist the results whether the batch completed or failed; previously
    # they were written only when a RuntimeError occurred, so a successful
    # run saved nothing.
    df_covid.to_csv("../한겨레_코로나_content_1.csv")

    col_name = ["index", "summary"]
    news_df = pd.DataFrame(new_list, columns=col_name)

    csv_file_name = "../한겨레_covid_summary_1.csv"

    news_df.to_csv(csv_file_name)

news_df.head()

## Test example

In [None]:
# text = '''
# 일본 정부가 문재인 대통령이 도쿄올림픽을 계기로 일본을 방문하는 방향으로 한일 양국이 조율하고 있다는 15일 요미우리신문의 보도를 부인했다.

# 정부 대변인인 가토 가쓰노부(加藤勝信) 관방장관은 이날 오전 정례 기자회견에서 관련 질문에 “말씀하신 보도와 같은 사실이 없는 것으로 안다”고 밝혔다.

# 앞서 요미우리는 한국 측이 도쿄올림픽을 계기로 한 문 대통령의 방일을 타진했고, 일본 측은 수용하는 방향이라고 이날 보도했다.

# 한국 측은 문 대통령의 방일 때 스가 요시히데(菅義偉) 총리와 처음으로 정상회담을 하겠다는 생각이라고 요미우리는 전했다.

# 가토 장관은 한일 정상회담에 대한 일본 정부의 자세에 대해 “그런 사실이 없기 때문에 가정의 질문에 대해 답하는 것을 삼가겠다”고 말했다.

# 그는 한국 측의 독도방어훈련에 ‘어떤 대항 조치를 생각하고 있느냐’는 질문에는 “한국 해군의 훈련에 대해 정부로서는 강한 관심을 가지고 주시하는 상황이어서 지금 시점에선 논평을 삼가겠다”고 말을 아꼈다.

# 가토 장관은 “다케시마(竹島·일본이 주장하는 독도의 명칭)는 역사적 사실에 비춰봐도, 국제법상으로도 명백한 일본 고유의 영토”라며 독도 영유권 주장을 되풀이했다.

# 그러면서 “다케시마 문제에 대해서는 계속 우리나라의 영토, 영해, 영공을 단호히 지키겠다는 결의로 냉정하고 의연하게 대응해갈 생각”이라고 밝혔다.
# '''

In [None]:
# input_data = txt2input(text)
# sum_list = test(args, input_data)
# sum_list[0]

In [None]:
# [list(filter(None, text.split('\n')))[i] for i in sum_list[0][0][:3]][0]