<a href="https://colab.research.google.com/github/lazyYC/EasyReadNews/blob/main/load_%26_predict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### 工具準備
* clone repository
* install & import 需要用到套件

In [None]:
import os
# Clone the project repository into the current working directory.
!git clone 'https://github.com/lazyYC/EasyReadNews.git'
# Switch into the clone; later cells use relative paths such as ./DATA/... and
# ./checkpoints/... (presumably shipped with the repository -- verify).
os.chdir('./EasyReadNews')

In [None]:
# Packages for the summarization step (quiet install).
!pip install -q transformers  rouge-score
# newspaper3k provides the `newspaper` module imported below.
!pip3 install newspaper3k
# Packages for the translation step.
!pip install 'torch>=1.6.0' editdistance matplotlib sacrebleu sacremoses sentencepiece tqdm wandb
!pip install --upgrade jupyter ipywidgets
# Install fairseq from source, pinned to commit 9a1c497.
!git clone https://github.com/pytorch/fairseq.git
!cd fairseq && git checkout 9a1c497
!pip install --upgrade ./fairseq/

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from newspaper import fulltext
import requests, sys, random, re, logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import numpy as np
import tqdm.auto as tqdm
from pathlib import Path
from argparse import Namespace
from fairseq import utils
import matplotlib.pyplot as plt
import sentencepiece as spm
from config import *
from fairseq.tasks.translation import TranslationConfig, TranslationTask
# Load the pretrained tokenizer and seq2seq model used for the summarization step.
# NOTE: `model` is rebound further down in this file to the translation model.
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-cnn-12-6")

### 載入transformers, newspaper套件
* 輸入網址透過 newspaper 抓取新聞內容
* 交由 pre-trained model 產生 summarization
* 產出 txt 檔交由後續 translation model 使用

In [None]:
# Ask for the article URL and download the page.
url = input("Please paste the url of the news to be summarized.")
# Bound the request time and fail fast on HTTP errors, so an error page is
# never summarized as if it were the article.
response = requests.get(url, timeout=30)
response.raise_for_status()
# Extract the article body from the downloaded HTML.
text = fulltext(response.text)
ARTICLE_TO_SUMMARIZE = text
# truncation=True is required for max_length to take effect; without it an
# article longer than 1024 tokens is handed to the model unshortened.
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, truncation=True, return_tensors='pt')

# Generate Summary
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=100, early_stopping=True)
# Only one article was submitted, so decode the single generated sequence.
output = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
# One entry per '.'-separated piece of the summary.
str2 = output.split('.')

# Write the pieces one per line, with no newline after the last piece.
with open("summary_out_put.txt", "w") as f:
    f.write('\n'.join(str2))
# Echo the file back. The loop variable is deliberately not named `data`, so
# the `torch.utils.data` module imported at the top stays bound to that name.
with open("summary_out_put.txt", "r") as f:
    for summary_line in f:
        print(summary_line)

### 翻譯模型
* 引入所需要的tasks, models, config等等
* 引入預測的函式

In [None]:
## set device
## If the GPU cannot be used, comment out the two lines below.
cuda_env = utils.CudaEnvironment()
utils.CudaEnvironment.pretty_print_cuda_env_list([cuda_env])
# Use the first CUDA device when torch reports one, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## setup task
# Translation task configuration; the data directory and the language pair
# are read from `config`.
task_cfg = TranslationConfig(
    data=config.datadir,
    source_lang=config.source_lang,
    target_lang=config.target_lang,
    train_subset="train",
    required_seq_len_multiple=8,
    dataset_impl="mmap",
    upsample_primary=1,
)
task = TranslationTask.setup_task(task_cfg)

##logging
# Log to stdout at INFO level as "timestamp | level | logger name | message".
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level="INFO", # "DEBUG" "WARNING" "ERROR"
    stream=sys.stdout,
)
proj = "hw5.seq2seq"
logger = logging.getLogger(proj)
# Optional Weights & Biases run, named after the stem of config.savedir.
if config.use_wandb:
    import wandb
    wandb.init(project=proj, name=Path(config.savedir).stem, config=config)

## seed
# Seed Python, PyTorch (CPU and every CUDA device) and NumPy with the same value.
seed = 73
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  
np.random.seed(seed)  
# Turn off cuDNN autotuning and request deterministic cuDNN behavior.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
def load_data_iterator(task, split, epoch=1, max_tokens=4000, num_workers=1, cached=True):
    """Create a batch iterator over one dataset split of ``task``.

    Args:
        task: Object providing ``dataset``, ``max_positions`` and
            ``get_batch_iterator``.
        split: Name of the split handed to ``task.dataset``.
        epoch: Epoch number forwarded to the iterator.
        max_tokens: Token budget forwarded to the iterator; also used when
            resolving the maximum positions.
        num_workers: Forwarded to the iterator.
        cached: When False, the iterator cache is disabled.

    Returns:
        Whatever ``task.get_batch_iterator`` returns.
    """
    split_dataset = task.dataset(split)
    position_limit = utils.resolve_max_positions(task.max_positions(), max_tokens)
    # Leaving the iterator cache enabled (cached=True) is faster, but then
    # changing max_tokens after the first call of this function has no effect.
    skip_cache = not cached
    return task.get_batch_iterator(
        dataset=split_dataset,
        max_tokens=max_tokens,
        max_sentences=None,
        max_positions=position_limit,
        ignore_invalid_inputs=True,
        seed=seed,
        num_workers=num_workers,
        epoch=epoch,
        disable_iterator_cache=skip_cache,
    )

def try_load_checkpoint(model, optimizer=None, name=None):
    """Restore ``model`` (and the optimizer step) from a checkpoint if it exists.

    Args:
        model: Module whose ``load_state_dict`` receives the checkpoint's
            ``"model"`` entry.
        optimizer: Optional optimizer wrapper; when given, its ``_step`` is set
            from the checkpoint's ``["optim"]["step"]``.
        name: Checkpoint file name inside ``config.savedir``; defaults to
            ``"checkpoint_last.pt"`` when empty or None.

    Returns:
        None. The outcome is reported through ``logger``.
    """
    name = name if name else "checkpoint_last.pt"
    checkpath = Path(config.savedir) / name
    if not checkpath.exists():
        logger.info(f"no checkpoints found at {checkpath}!")
        return

    # Deserialize onto the CPU so a checkpoint saved on a GPU machine also
    # loads on a CPU-only runtime; load_state_dict then copies the values into
    # the model's parameters on whatever device they already live on.
    check = torch.load(checkpath, map_location="cpu")
    model.load_state_dict(check["model"])
    stats = check["stats"]
    step = "unknown"
    if optimizer is not None:
        optimizer._step = step = check["optim"]["step"]
    logger.info(f"loaded checkpoint {checkpath}: step={step} loss={stats['loss']} bleu={stats['bleu']}")

from NoamOpt import *

## set args
# Architecture hyper-parameters for the translation model, assigned one
# attribute at a time on an initially empty namespace.
arch_args = Namespace()
arch_args.encoder_embed_dim = 256
arch_args.encoder_ffn_embed_dim = 1024
arch_args.encoder_layers = 4
arch_args.decoder_embed_dim = 256
arch_args.decoder_ffn_embed_dim = 1024
arch_args.decoder_layers = 4
arch_args.share_decoder_input_output_embed = True
arch_args.dropout = 0.3
def add_transformer_args(args):
    """Fill ``args`` in place with the transformer settings used by this file.

    Sets the attention-head counts, pre-normalization flags, activation
    function and maximum source/target positions, then lets fairseq's
    ``base_architecture`` add every transformer default not set explicitly.

    Args:
        args: Namespace-like object to mutate.

    Returns:
        None. ``args`` is modified in place.
    """
    args.encoder_attention_heads = 4
    args.encoder_normalize_before = True

    args.decoder_attention_heads = 4
    args.decoder_normalize_before = True

    args.activation_fn = "relu"
    args.max_source_positions = 1024
    args.max_target_positions = 1024

    # Fill in the transformer default parameters that we did not set.
    from fairseq.models.transformer import base_architecture
    # Apply the defaults to the namespace that was passed in, rather than to
    # the module-level `arch_args`, so the function works for any caller.
    base_architecture(args)

add_transformer_args(arch_args)

In [None]:
from classSeq2Seq import *
# Build the translation model; this rebinds `model`, replacing the
# summarization model loaded earlier in the file.
# NOTE(review): build_model presumably comes from classSeq2Seq -- confirm.
model = build_model(arch_args, task)
logger.info(model)

# NoamOpt wraps AdamW; the AdamW learning rate starts at 0 and the schedule is
# parameterized by the encoder embedding size and config.lr_factor / lr_warmup.
optimizer = NoamOpt(
    model_size=arch_args.encoder_embed_dim, 
    factor=config.lr_factor, 
    warmup=config.lr_warmup, 
    optimizer=torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.0001))
# Plot optimizer.rate(i) for steps 1..99999, labelled "<model_size>:<warmup>".
plt.plot(np.arange(1, 100000), [optimizer.rate(i) for i in range(1, 100000)])
plt.legend([f"{optimizer.model_size}:{optimizer.warmup}"])

# Restore model weights and optimizer step from the checkpoint named by config.resume.
try_load_checkpoint(model, optimizer, name=config.resume)

In [None]:
# Load the trained weights used for inference. Previously the value returned by
# torch.load was discarded, so the checkpoint never reached the model.
# map_location=device lets a GPU-saved checkpoint load on a CPU-only runtime.
# NOTE(review): the "model" key mirrors what try_load_checkpoint reads --
# confirm this checkpoint file uses the same layout.
checkpoint = torch.load('./checkpoints/transformer/checkpoint13.pt', map_location=device)
model.load_state_dict(checkpoint["model"])
model.to(device)

In [None]:
## Delete test files that may be left over from the previous prediction run.
to_del_raw = ['test.en', 'test.raw.en', 'test.raw.clean.en']
for f in to_del_raw:
  if Path(f'./DATA/rawdata/ted2020/{f}').exists():
    !rm ./DATA/rawdata/ted2020/{f}
    print(f'{f} is deleted now.')
  else:
    print(f'{f} does not exists.')
    
## Provide the (summarized) text file to be translated: copy the summary
## output to the raw English test path.
!cp './summary_out_put.txt' './DATA/rawdata/ted2020/test.raw.en' 

* 把摘要模型的output，複製一份作為翻譯模型的input

In [None]:
# Copy the summarizer output to the raw English test path read by the cells below.
!cp './summary_out_put.txt' './DATA/rawdata/ted2020/test.raw.en' 

In [None]:
# Give test.zh exactly as many lines as test.raw.en by repeating its first line.
with open('./DATA/rawdata/ted2020/test.raw.en', 'r') as en:
    length = len(en.readlines())
# Read the template line and close the file before reopening it for writing.
with open('./DATA/rawdata/ted2020/test.zh', 'r') as zh:
    repeat = zh.readline()
# Guarantee exactly one trailing newline: a template without one would glue
# every repetition onto a single line and break the line-count parity.
repeat = repeat.rstrip('\n') + '\n'
with open('./DATA/rawdata/ted2020/test.zh', 'w') as zh:
    zh.writelines([repeat] * length)

### 欲翻譯文本的前處理
* 清除多餘標點符號
* 斷成 subwords，以解決常遇到的未登錄詞問題
* binarize

In [None]:
from cleanse import *
# Clean the raw test corpus for the en/zh pair; ratio/min_len/max_len are all -1.
# NOTE(review): judging by the path read below, this presumably writes
# test.raw.clean.<lang> files -- confirm against the cleanse module.
clean_corpus('./DATA/rawdata/ted2020/test.raw', 'en', 'zh', ratio=-1, min_len=-1, max_len=-1)
# Load the SentencePiece model used to split text into subwords.
spm_model = spm.SentencePieceProcessor(model_file=str('./DATA/rawdata/ted2020/spm8000.model'))

# Input data
if Path('./DATA/rawdata/ted2020/test.en').exists():
  print('資料已經轉成subwords，跳過此步驟')
else:
  # Encode each cleaned English line into subword pieces, written space-separated
  # with one sentence per line.
  with open('./DATA/rawdata/ted2020/test.en', 'w+') as out_f:
    with open('./DATA/rawdata/ted2020/test.raw.clean.en' ,'r') as f:
      for line in f:
        line = line.strip()
        tok = spm_model.encode(line, out_type=str)
        print(' '.join(tok), file=out_f)
# !head {'./DATA/rawdata/ted2020/test.en'} -n 10

## Remove the binarized test files left by the previous binarize run.
p = './DATA/data-bin/ted2020/'
ToDelEveryPred = ['test.en-zh.en.bin', 'test.en-zh.en.idx', 'test.en-zh.zh.bin', 'test.en-zh.zh.idx']
for f in ToDelEveryPred:
  if Path(p + f).exists():
    !rm {p}{f}
    print(f'{f} is deleted now')

## binarize
# Run fairseq's preprocess CLI on the test prefix, reusing the existing en/zh
# dictionaries and writing the result to binpath.
binpath = Path('./DATA/data-bin', 'ted2020')
!python -m fairseq_cli.preprocess \
    --source-lang 'en'\
    --target-lang 'zh'\
    --testpref './DATA/rawdata/ted2020/test'\
    --destdir {binpath}\
    --srcdict './DATA/data-bin/ted2020/dict.en.txt' \
    --tgtdict './DATA/data-bin/ted2020/dict.zh.txt' \
    --workers 2

#################################################################

# Build the generator used by inference_step for the translation model,
# configured from `config`.
sequence_generator = task.build_generator([model], config)

def decode(toks, dictionary):
    """Convert a tensor of token ids into a human-readable sentence.

    Args:
        toks: Tensor of token ids; it is cast to int and moved to the CPU.
        dictionary: Object whose ``string`` method renders the ids, using
            ``config.post_process`` as its second argument.

    Returns:
        The rendered sentence, or ``"<unk>"`` when the rendering is empty.
    """
    cpu_ids = toks.int().cpu()
    sentence = dictionary.string(cpu_ids, config.post_process)
    if not sentence:
        return "<unk>"
    return sentence

def inference_step(sample, model):
    """Generate translations for one batch and decode them to text.

    Args:
        sample: Batch dict providing ``["net_input"]["src_tokens"]`` and
            ``["target"]``.
        model: Model handed to the module-level ``sequence_generator``.

    Returns:
        A ``(srcs, hyps, refs)`` tuple of lists holding, per sentence, the
        decoded source, the decoded best hypothesis and the decoded reference.
    """
    gen_out = sequence_generator.generate([model], sample)
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    srcs, hyps, refs = [], [], []
    # For every sentence collect the input, the output and the reference
    # answer (originally gathered for a later BLEU computation).
    for row, beams in enumerate(gen_out):
        source_ids = utils.strip_pad(sample["net_input"]["src_tokens"][row], src_dict.pad())
        srcs.append(decode(source_ids, src_dict))

        # Index 0 selects the top-scoring hypothesis of the beam.
        best_ids = beams[0]["tokens"]
        hyps.append(decode(best_ids, tgt_dict))

        reference_ids = utils.strip_pad(sample["target"][row], tgt_dict.pad())
        refs.append(decode(reference_ids, tgt_dict))
    return srcs, hyps, refs


def generate_prediction(model, task, split="test", outfile="./prediction.txt"):
    """Translate every sentence of ``split`` and write the results to ``outfile``.

    Args:
        model: Translation model; it is switched to eval mode.
        task: Task used to load the split and build the batch iterator.
        split: Name of the dataset split to translate.
        outfile: Path of the text file that receives one hypothesis per line.

    Returns:
        None. The hypotheses are printed and written to ``outfile``.
    """
    task.load_dataset(split=split, epoch=1)
    itr = load_data_iterator(task, split, 1, config.max_tokens, config.num_workers).next_epoch_itr(shuffle=False)
    idxs = []
    hyps = []

    model.eval()
    progress = tqdm.tqdm(itr, desc="prediction")
    with torch.no_grad():
        for sample in progress:
            # Move the batch onto the inference device.
            sample = utils.move_to_cuda(sample, device=device)

            # Run inference; only the hypotheses are needed here.
            _, batch_hyps, _ = inference_step(sample, model)

            hyps.extend(batch_hyps)
            # Store plain ints so the sort below compares Python numbers
            # rather than tensors.
            idxs.extend(int(sentence_id) for sentence_id in sample['id'])

    # Restore the order the sentences had at preprocessing time.
    hyps = [hyp for _, hyp in sorted(zip(idxs, hyps))]
    print(hyps)
    # Explicit UTF-8: the translation target is Chinese, which a non-UTF-8
    # platform default encoding may be unable to write.
    with open(outfile, "w", encoding="utf-8") as f:
        f.writelines(f"{hyp}\n" for hyp in hyps)

### 進行預測

In [None]:
# Translate the test split; by default the result is written to ./prediction.txt.
generate_prediction(model, task)
# Remove the binarized test files that were produced for this prediction.
p = './DATA/data-bin/ted2020/'
ToDelEveryPred = ['test.en-zh.en.bin', 'test.en-zh.en.idx', 'test.en-zh.zh.bin', 'test.en-zh.zh.idx']
for f in ToDelEveryPred:
  if Path(p + f).exists():
    !rm {p}{f}
    print(f'{f} is deleted now')