In [1]:
import os
import sys
import numpy as np
from tqdm import tqdm
import time
import pickle

In [2]:
import logging

In [3]:
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s [%(levelname)s] %(message)s", 
                    level=logging.INFO,
                    handlers=[
                        logging.FileHandler(os.path.join("./subtext_test_result.log")),
                        logging.StreamHandler()
                    ])

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils

In [5]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

## Call model (BertSum/SubtextDivider)

### (1) BertSum

In [6]:
# path 추가
sys.path.append('/home/sks/korea_univ/21_1/TA/project/video_summarizer/Youtube-Summarizer/src/bertsum')

In [7]:
from models.data_loader import TextLoader, load_dataset
from bertsum import args, ExtTransformerEncoder, ExtSummarizer, WindowEmbedder

2021-05-02 19:55:37,975 [INFO] PyTorch version 1.1.0 available.


In [8]:
# Settings
device = "cpu" if args.visible_gpus == -1 else "cuda"
loader = TextLoader(args, device)

# model setting
ckpt_path = '/home/sks/korea_univ/21_1/TA/project/video_summarizer/Youtube-Summarizer/src/bertsum/checkpoint/model_step_24000.pt'
checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
bert_model = ExtSummarizer(args, device, checkpoint)
bert_model.eval()

using cached model
using cached model
using cached model
using cached model


2021-05-02 19:55:41,562 [INFO] loading configuration file ./tmp/kobert_from_pretrained/config.json
2021-05-02 19:55:41,564 [INFO] Model config BertConfig {
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "type_vocab_size": 2,
  "vocab_size": 8002
}

2021-05-02 19:55:41,564 [INFO] loading weights file ./tmp/kobert_from_pretrained/pytorch_model.bin


using cached model


2021-05-02 19:55:42,850 [INFO] All model checkpoint weights were used when initializing BertModel.

2021-05-02 19:55:42,850 [INFO] All the weights of BertModel were initialized from the model checkpoint at ./tmp/kobert_from_pretrained.
If your task is similar to the task the model of the ckeckpoint was trained on, you can already use BertModel for predictions without further training.


ExtSummarizer(
  (bert): Bert(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(8004, 768)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm(torch.Size([768]),

In [9]:
embedder = WindowEmbedder(model=bert_model, text_loader=loader)

### (2) SubtextDivider

In [10]:
from model import ChunkClassifier

subtext_model = ChunkClassifier(x_features=768)

model_path = '/home/sks/korea_univ/21_1/TA/project/video_summarizer/Youtube-Summarizer/src/subtext/chunk_nn/ckpt/chunk_model_ckpt.pt'
subtext_model.load_state_dict(torch.load(model_path))

subtext_model.eval()

ChunkClassifier(
  (block1): Sequential(
    (0): Conv1d(768, 128, kernel_size=(6,), stride=(1,))
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (block2): Sequential(
    (0): Linear(in_features=384, out_features=16, bias=True)
    (1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=16, out_features=1, bias=True)
  )
)

### (3) Score Calculator

In [11]:
def get_divscore(src_doc=[], embedder=None, divider=None):
    embedding = embedder.get_embeddings(src_doc).transpose(1, 0).unsqueeze(0)
    score = divider(embedding).item()
    return score

# Dataset for evaluation

In [12]:
import json
def load_jsonl(input_path) -> list:
    """
    Read list of objects from a JSON lines file.
    """
    data = []
    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line.rstrip('\n|\r')))
    print('Loaded {} records from {}'.format(len(data), input_path))
    return data

In [13]:
news_df = load_jsonl('/home/sks/korea_univ/21_1/TA/project/video_summarizer/Youtube-Summarizer/src/bertsum/dataset/train.jsonl')

Loaded 260697 records from /home/sks/korea_univ/21_1/TA/project/video_summarizer/Youtube-Summarizer/src/bertsum/dataset/train.jsonl


In [14]:
# 전처리
# (1) 글자 개수가 너무 작은 경우 없애기 (30글자 이상)
# (2) 문장이 적은 경우 해당 기사 없애기 (10문장 이상)
news_clean = []
for news in news_df:
    news_article = news['article_original']
    if len(news_article) >= 10:
        article_clean = [sent for sent in news_article if len(sent) >= 30]
        news_clean.append(article_clean)

In [15]:
import random

def make_mixed_doc(news_dataset=None, max_num=1000):
    mixed_doc_set = []
    for i in range(max_num):
        lh_count = min(random.randint(7, 10), len(news_dataset[i]))
        rh_count = min(random.randint(7, 10), len(news_dataset[i+1]))

        lh_news = news_dataset[i][:lh_count]
        rh_news = news_dataset[i+1][:rh_count]
        
        gt = lh_count - 1

        src_doc = '\n'.join((lh_news + rh_news))
        mixed_doc_set.append((src_doc, gt))
        
    return mixed_doc_set

In [16]:
random.seed(2020011135)
mixed_doc_list = make_mixed_doc(news_dataset=news_clean, max_num=1000)

## Evaluation

In [17]:
# Settings
loader = TextLoader(args, device)
window_size = 4

using cached model
using cached model


In [18]:
err_cnt = 0
acc_cnt = 0
ws = 4

div_result = []
for i, a_set in enumerate(mixed_doc_list):
    
    if (i+1) % 20 == 0:
        logger.info(f"working on {i+1}th doc: Accuracy so far is {acc_cnt/(acc_cnt+err_cnt)*100:.2f}%")
        
    src_doc = a_set[0].split('\n')
    gt = a_set[1]
    
    cands = [src_doc[i:i+ws*2] for i, _ in enumerate(src_doc) if i <= len(src_doc) - ws*2]
    
    # 가끔 한문장이 너무길어서 잘리는 경우가 있음..
    try:
        div_scores = [get_divscore(src_doc=cand, embedder=embedder, divider=subtext_model) for cand in cands]
        div_point = div_scores.index(max(div_scores)) + ws - 1

        if div_point == gt:
            acc_cnt += 1
        else:
            err_cnt += 1
            
    except RuntimeError as e:
        print(f"Error occurred at {i}th article")
    
#     sents = [sent for sent in src_doc.split('\n') if sent]
#     lh_sent, rh_sent = [], []
#     for i, sent in enumerate(sents):
#         if i <= div_point:
#             lh_sent.append(sent)
#         else:
#             rh_sent.append(sent)
            
#     result_sents = lh_sent + ["----------------[DIV]---------------"] + rh_sent
#     div_result.append((result_sents, div_scores, div_point, gt))

2021-05-02 19:57:07,751 [INFO] working on 20th doc: Accuracy so far is 73.68%


Error occurred at 30th article
Error occurred at 32th article
Error occurred at 33th article


2021-05-02 19:57:52,749 [INFO] working on 40th doc: Accuracy so far is 77.78%
2021-05-02 19:58:43,279 [INFO] working on 60th doc: Accuracy so far is 80.36%


Error occurred at 66th article
Error occurred at 67th article


2021-05-02 19:59:27,099 [INFO] working on 80th doc: Accuracy so far is 75.68%
2021-05-02 20:00:19,010 [INFO] working on 100th doc: Accuracy so far is 79.79%
2021-05-02 20:01:05,288 [INFO] working on 120th doc: Accuracy so far is 79.82%
2021-05-02 20:01:53,274 [INFO] working on 140th doc: Accuracy so far is 78.36%


Error occurred at 156th article


2021-05-02 20:02:39,330 [INFO] working on 160th doc: Accuracy so far is 79.74%
2021-05-02 20:03:28,138 [INFO] working on 180th doc: Accuracy so far is 79.77%


Error occurred at 191th article


2021-05-02 20:04:10,577 [INFO] working on 200th doc: Accuracy so far is 80.21%


Error occurred at 206th article
Error occurred at 207th article


2021-05-02 20:04:55,522 [INFO] working on 220th doc: Accuracy so far is 80.00%


Error occurred at 221th article
Error occurred at 222th article


2021-05-02 20:05:41,893 [INFO] working on 240th doc: Accuracy so far is 79.39%


Error occurred at 248th article


2021-05-02 20:06:30,241 [INFO] working on 260th doc: Accuracy so far is 79.35%


Error occurred at 265th article
Error occurred at 266th article


2021-05-02 20:07:16,536 [INFO] working on 280th doc: Accuracy so far is 78.87%


Error occurred at 291th article


2021-05-02 20:08:05,110 [INFO] working on 300th doc: Accuracy so far is 76.76%


Error occurred at 300th article
Error occurred at 312th article
Error occurred at 315th article
Error occurred at 316th article


2021-05-02 20:08:52,520 [INFO] working on 320th doc: Accuracy so far is 77.33%
2021-05-02 20:09:38,502 [INFO] working on 340th doc: Accuracy so far is 77.12%


Error occurred at 338th article


2021-05-02 20:10:25,756 [INFO] working on 360th doc: Accuracy so far is 77.29%
2021-05-02 20:11:14,110 [INFO] working on 380th doc: Accuracy so far is 77.16%


Error occurred at 387th article


2021-05-02 20:12:00,515 [INFO] working on 400th doc: Accuracy so far is 77.25%
2021-05-02 20:12:48,850 [INFO] working on 420th doc: Accuracy so far is 77.89%
2021-05-02 20:13:37,001 [INFO] working on 440th doc: Accuracy so far is 78.95%


Error occurred at 446th article
Error occurred at 447th article


2021-05-02 20:14:19,661 [INFO] working on 460th doc: Accuracy so far is 79.82%
2021-05-02 20:15:05,472 [INFO] working on 480th doc: Accuracy so far is 79.39%
2021-05-02 20:15:58,735 [INFO] working on 500th doc: Accuracy so far is 79.41%
2021-05-02 20:16:53,206 [INFO] working on 520th doc: Accuracy so far is 79.84%
2021-05-02 20:17:42,199 [INFO] working on 540th doc: Accuracy so far is 79.65%


Error occurred at 546th article


2021-05-02 20:18:28,825 [INFO] working on 560th doc: Accuracy so far is 79.63%
2021-05-02 20:19:19,594 [INFO] working on 580th doc: Accuracy so far is 79.82%
2021-05-02 20:20:10,408 [INFO] working on 600th doc: Accuracy so far is 80.00%


Error occurred at 599th article
Error occurred at 600th article


2021-05-02 20:20:55,960 [INFO] working on 620th doc: Accuracy so far is 79.93%


Error occurred at 631th article


2021-05-02 20:21:42,848 [INFO] working on 640th doc: Accuracy so far is 80.39%


Error occurred at 641th article
Error occurred at 648th article
Error occurred at 649th article


2021-05-02 20:22:23,885 [INFO] working on 660th doc: Accuracy so far is 80.60%


Error occurred at 663th article
Error occurred at 664th article
Error occurred at 667th article


2021-05-02 20:23:05,795 [INFO] working on 680th doc: Accuracy so far is 80.50%
2021-05-02 20:23:52,135 [INFO] working on 700th doc: Accuracy so far is 80.33%
2021-05-02 20:24:43,570 [INFO] working on 720th doc: Accuracy so far is 80.47%


Error occurred at 726th article


2021-05-02 20:25:32,328 [INFO] working on 740th doc: Accuracy so far is 80.57%
2021-05-02 20:26:20,843 [INFO] working on 760th doc: Accuracy so far is 80.28%
2021-05-02 20:27:06,800 [INFO] working on 780th doc: Accuracy so far is 80.54%
2021-05-02 20:27:59,666 [INFO] working on 800th doc: Accuracy so far is 80.50%


Error occurred at 798th article
Error occurred at 801th article


2021-05-02 20:28:44,063 [INFO] working on 820th doc: Accuracy so far is 80.46%
2021-05-02 20:29:38,363 [INFO] working on 840th doc: Accuracy so far is 80.45%


Error occurred at 854th article
Error occurred at 855th article


2021-05-02 20:30:22,763 [INFO] working on 860th doc: Accuracy so far is 80.27%


Error occurred at 869th article
Error occurred at 870th article


2021-05-02 20:31:04,383 [INFO] working on 880th doc: Accuracy so far is 80.33%
2021-05-02 20:31:47,949 [INFO] working on 900th doc: Accuracy so far is 79.98%
2021-05-02 20:32:33,937 [INFO] working on 920th doc: Accuracy so far is 79.64%
2021-05-02 20:33:25,227 [INFO] working on 940th doc: Accuracy so far is 79.76%
2021-05-02 20:34:13,123 [INFO] working on 960th doc: Accuracy so far is 79.33%


Error occurred at 967th article
Error occurred at 968th article


2021-05-02 20:34:59,979 [INFO] working on 980th doc: Accuracy so far is 79.51%


Error occurred at 983th article
Error occurred at 984th article


2021-05-02 20:35:49,014 [INFO] working on 1000th doc: Accuracy so far is 79.27%


In [22]:
print(f"{acc_cnt / (acc_cnt + err_cnt)*100:.2f}%")

79.29%


In [26]:
tmp_src = mixed_doc_list[100][0].split('\n')
tmp_cands = [tmp_src[i:i+ws*2] for i, _ in enumerate(tmp_src) if i <= len(tmp_src) - ws*2]

In [30]:
idx = [get_divscore(src_doc=cand, embedder=embedder, divider=subtext_model) for cand in tmp_cands]

In [32]:
idx

[-6.0954909324646,
 -7.948309421539307,
 -9.31032657623291,
 -7.533929347991943,
 1.1708508729934692,
 3.9606902599334717,
 5.838548183441162,
 3.081315755844116,
 -5.48937463760376,
 -11.055376052856445]

In [31]:
idx.index(max(idx))

6

In [33]:
tmp_cands[6]

['또 2016년에는 당진시 청소년 문화시설에 대한 실태조사를 통해 청소년 문화시설 만족도를 파악하고, 청소년들의 요구가 반영된 다양한 문화활동을 지원할 수 있도록 노력했다.',
 '올해에는 당진시 청소년이 바라는 진로교육 및 체험 내용과 형태를 파악하고 진로·직업 교육에 대한 의견실태조사가 진행됐다.',
 '또 다양한 진로교육 및 체험의 기회를 제공하고자 청소년어울림마당과 연계해 진로체험부스를 운영해 3년 연속 1위를 차지했다.',
 '한편 당진시청소년참여위원회는 다음달 22일 아동·청소년시설 인프라 확충방안을 주제로 한 토론회에서 당진시 청소년들의 입장을 발표할 예정이다.',
 "지난해 5월 발생한 '해외 유령주식' 사건과 같은 해외주식 결제 관련 사고의 재발을 방지하기 위한 대책이다.",
 '이병래 한국예탁결제원 사장(사진)은 27일 서울 여의도에서 열린 기자간담회에서 이 같은 내용을 포함한 하반기 주요 사업 계획을 발표했다.',
 '우선 예탁원은 해외주식 거래와 관련해 외국 보관기관에 과실 책임이 있을 때 배상을 요구할 수 있도록 특약을 체결하기로 했다.',
 '예탁원은 "유상증자, 무상증자, 액면분할 등과 관련해 권리행사가 필요한 경우 주식 보관기관이 이를 통지해줘야 하는데, 외국 기관이 국내에 이 정보를 제때 전달하지 않아 투자자에게 손해가 발생하면 해당 기관에 책임을 묻겠다는 것"이라고 설명했다.']