In [1]:
import re
import pickle
import math
import os
import pandas as pd
import rouge
import codecs

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AdamW
from tqdm import tqdm

from src.model.logger import Logger
from src.model.data_full import RawFilesDataset
from src.model.loss import ParagraphLoss
from src.model.generate_utils import toks_to_str

from rake_nltk import Rake
from nltk.corpus import stopwords

# from src.model.generate_utils  import generate_paragraph
# from src.model.eval_utils import evaluate_doc_model
# from src.model.model import GPT2BaseModel


[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\edbon\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\edbon\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


# generate

In [None]:
from tokenizers.decoders import ByteLevel
# ByteLevel decoder from `tokenizers`; the token-inspection cell uses it to render
# individual byte-level BPE pieces as readable text.
decoder = ByteLevel()


In [None]:
# Walk the first generated sequence token by token: print every id together with its
# decoded piece, and collect tokens until the end marker is hit.
# NOTE(review): on a fresh kernel `encoder` and `sample_output` are undefined here
# (the tokenizer is created later as `text_encoder`, the output by the generate cell).
# NOTE(review): '_end_' is not one of the special tokens registered on the tokenizer in
# this notebook — confirm it exists in `encoder`'s vocabulary, otherwise the break
# condition may not correspond to the real end of the text.
end_tok = encoder.convert_tokens_to_ids('_end_')

collected = []
for tok_id in (tok.item() for tok in sample_output[0]):
    shown = decoder.decode([encoder.convert_ids_to_tokens(tok_id, skip_special_tokens=True)])
    print(tok_id, repr(shown))
    if tok_id == end_tok:
        break
    collected.append(encoder.convert_ids_to_tokens(tok_id))

str_rep = encoder.convert_tokens_to_string(collected)

# Make sure the text is non-empty and contains at least one '.', so the ROUGE scorer
# does not complain about a text with no sentences.
if str_rep and "." not in str_rep:
    str_rep = f"{str_rep}."
elif not str_rep:
    str_rep = "unk."

print(encoder.decode(sample_output[0], skip_special_tokens=False, clean_up_tokenization_spaces=False))
print("-" * 50)
print(str_rep)

# evaluate doc

In [None]:
class Config:
    """Minimal hyper-parameter namespace passed as ``args`` to ``GPT2BaseModel``.

    NOTE(review): the meaning of these fields is defined by GPT2BaseModel, which is
    not visible in this notebook; the per-field notes below are assumptions — confirm.
    """

    # presumably the strength of a repetition penalty applied during generation — TODO confirm
    repeattheta: float = 1.5
    # presumably asks the model to also return attention weights — TODO confirm
    output_attentions: bool = True

In [None]:
# Hyper-parameter namespace handed to GPT2BaseModel as its first argument.
args = Config()

In [None]:
# Vocabulary size, passed as `vocab` when building GPT2BaseModel.
# NOTE(review): `encoder` is never defined in this notebook on a fresh kernel (the
# tokenizer is created later as `text_encoder`); this relies on leftover kernel state.
vocab = len(encoder)

In [None]:
# Build the document model on CPU for evaluation.
# NOTE(review): the GPT2BaseModel import is commented out in the imports cell and
# `config` / `encoder` are not defined in this notebook, so this raises NameError on a
# fresh kernel. The meaning of gen_len / lastidx / includeprev is defined by
# GPT2BaseModel — confirm there.
doc_model = GPT2BaseModel(args, vocab=vocab, n_ctx=config['n_ctx'], gen_len=401, lastidx=encoder.eos_token_id, includeprev=False, device='cpu')

In [None]:
# Run the project's document-level evaluation on CPU, saving results under 'out'.
# NOTE(review): the evaluate_doc_model import is commented out in the imports cell and
# `val_loader` is not defined in this notebook, so this raises NameError on a fresh kernel.
# NOTE(review): p=90 looks like a percentage (the generate cell uses top_p=0.95) —
# confirm the unit evaluate_doc_model expects.
evaluate_doc_model(model=doc_model, val_loader=val_loader, text_encoder=encoder, device='cpu', beam=0, gen_len=401, k=0, p=90, save_file='out', max_len=512, gen_dir=None, tgt_dir=None, min_len=100)

In [None]:
import json

In [None]:
# Scratch check: with ensure_ascii=False the Cyrillic string is written as-is
# (UTF-8) instead of \uXXXX escapes.
serialized = json.dumps("Моя строка", ensure_ascii=False)
with open('text.txt', 'w', encoding='utf-8') as out_file:
    out_file.write(serialized)

In [None]:
# Load the generated test-set outputs (tab-separated, no header row) and preview them.
gens_columns = ['id', 'plot', 'context', 'part', 'text']
df = pd.read_csv(
    'generated/test.gens.tsv',
    sep='\t',
    header=None,
    names=gens_columns,
)
df.head()

In [None]:
# Inspect the generated text of the row labelled 0.
df.loc[0, 'text']

# Generate

## rake

* в исходном коде язык для токенизаторов не задаётся (используются их настройки по умолчанию);
* английский токенизатор предложений делит текст лучше: например, русский не смог разбить на предложения фрагмент 'Король дал за дочкой богатое приданое, наградил зятя большим чином и задал пир на весь мир.\nЖивут молодые месяц, и два, и три.'

In [None]:
from rake_nltk import Metric, Rake

In [None]:
# Load one raw tale for the RAKE keyword experiments and normalize punctuation:
# collapse '...' into '.' and replace em dashes with hyphens.
story = '111 Волшебное кольцо.txt'
path = 'dataset/raw'
with open(os.path.join(path, story), 'r', encoding='utf-8') as f:
    text = f.read()

# Both substitutions are literal, so str.replace is enough. This also avoids writing
# '\.' inside a non-raw string literal, which is an invalid escape sequence
# (SyntaxWarning on Python 3.12+).
text = text.replace('...', '.').replace('—', '-')

## metrics

In [None]:
def rouge_scores(hyps, refs):
    """Score hypotheses against references with ROUGE, averaged over all pairs.

    Args:
        hyps: generated text(s); this notebook passes a single string.
        refs: reference text(s), aligned with ``hyps``.

    Returns:
        Whatever ``rouge.Rouge().get_scores(hyps, refs, avg=True)`` returns.
    """
    scorer = rouge.Rouge()
    return scorer.get_scores(hyps, refs, avg=True)

## samples

In [2]:
text_encoder = GPT2Tokenizer.from_pretrained(
    "sberbank-ai/rugpt3small_based_on_gpt2",
    add_prefix_space=True,
)
# Register the markers used in the model inputs: <s> / </s> as sequence boundaries,
# '_kw_' between keyword phrases, '_endkw_' closing the keyword list and '[SEP]'
# separating the prompt from the text to generate.
text_encoder.add_special_tokens(
    {
        'bos_token': '<s>',
        'eos_token': '</s>',
        'additional_special_tokens': ['[SEP]', '_kw_', '_endkw_'],
    }
)

3

In [3]:
# The checkpoint holds the whole pickled model object (it is used directly via
# .eval() / .generate() below); map it onto the CPU.
# NOTE(review): torch.load and pickle.load can execute arbitrary code from the file —
# only load artifacts you produced yourself. Recent PyTorch releases default
# torch.load to weights_only=True, which rejects whole-model pickles like this one.
with open('savedir/savedir/s_all_nodiscourse/checkpoints/checkpoint.pt', 'rb') as ckpt_file:
    model = torch.load(ckpt_file, map_location='cpu')

# Pickled test split saved alongside the checkpoint.
with open('savedir/savedir/s_all_nodiscourse/test_dataset', 'rb') as split_file:
    test = pickle.load(split_file)

In [4]:
# Wrap the test split in the project dataset; batch_size=1 without shuffling, so
# batches come out one sample at a time in a fixed order.
# NOTE(review): the meaning of 2048 and n_ctx=70 is defined by RawFilesDataset — confirm there.
test_dataset = RawFilesDataset(test, text_encoder, 2048, n_ctx=70)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [5]:
# Switch the model to evaluation mode and take the first test sample (the loader
# yields one sample per batch, unshuffled).
model.eval()
batch = next(iter(test_loader))

In [6]:
# Split the first sample of the batch into the prompt (everything up to and including
# [SEP]) and the reference continuation (after [SEP] up to and including the
# end-of-sequence token).
septok = text_encoder.convert_tokens_to_ids('[SEP]')
endtok = text_encoder.eos_token_id
input_ids, mask = batch['sample'], batch['mask']
sample_ids = input_ids[0]

# Calling .item() directly on torch.where(...)[0] only works when the token occurs
# exactly once. It fails with an opaque error when the token is missing or repeated
# (e.g. if padding reuses the end-of-sequence id — unverified). Take the first [SEP],
# then the first end-of-sequence token after it, and fail clearly otherwise.
sep_positions = torch.where(sample_ids == septok)[0]
if sep_positions.numel() == 0:
    raise ValueError("sample contains no '[SEP]' token")
sep_idx = sep_positions[0].item()

eos_positions = torch.where(sample_ids[sep_idx+1:] == endtok)[0]
if eos_positions.numel() == 0:
    raise ValueError("sample contains no end-of-sequence token after '[SEP]'")
eos_idx = sep_idx + 1 + eos_positions[0].item()

context = input_ids[:, :sep_idx+1]
target_txt = input_ids[:, sep_idx+1:eos_idx+1]

In [7]:
# Decode verbatim: special tokens (<s>, _kw_, _endkw_, [SEP], </s>) are kept and no
# tokenization-space cleanup is applied, so the text mirrors the model input.
# (Stripping '—' / '...' from the prompt and re-encoding it was tried here and is disabled.)
context_txt = text_encoder.decode(
    context[0],
    skip_special_tokens=False,
    clean_up_tokenization_spaces=False,
)

# Reference continuation used for ROUGE; target_txt runs up to and including eos_idx.
refs = text_encoder.decode(
    target_txt[0],
    clean_up_tokenization_spaces=False,
    skip_special_tokens=False,
)

In [8]:
# Whitespace-separated word count of the decoded prompt (special tokens included).
len(context_txt.split())

54

In [9]:
# Show the full prompt: keyword phrases separated by _kw_, closed by _endkw_ and [SEP].
print(context_txt)

<s> царь тотчас снял кольцо _kw_ кот васька говорит _kw_ кот взял кольцо _kw_ днем королевна носит кольцо _kw_ кот васька сел собаке _kw_ дня сидел кот васька _kw_ нам помогут кольцо найти _kw_ достать чудодейное кольцо _kw_ добывать чудодейное кольцо _kw_ кот васька цап _kw_ задумали кот васька _kw_ чудодейное кольцо _endkw_ [SEP]


In [None]:
# Seed torch's RNG so the sampled continuation (and the ROUGE scores computed from it
# below) can be reproduced on re-run; presumably generate() samples through torch.
GENERATION_SEED = 42
torch.manual_seed(GENERATION_SEED)

# Sampling hyper-parameters (num_beams, top_p, top_k, temperature, repetition_penalty,
# no_repeat_ngram_size, early_stopping) follow https://arxiv.org/pdf/2108.03502.pdf.
# do_sample=True together with num_beams > 1 gives beam-search multinomial sampling.
# In transformers, max_length / min_length count the prompt tokens as well, so the
# continuation is bounded by (512 - prompt length) tokens.
sample_output = model.generate(
    context,
    # attention_mask is not passed: the prompt is a single sequence cut at [SEP].
    max_length=512,
    min_length=100,
    do_sample=True,
    num_beams=20,
    top_p=0.95,
    top_k=3,
    temperature=1.0,
    repetition_penalty=2.0,
    no_repeat_ngram_size=3,
    early_stopping=True,
    num_return_sequences=1,
    eos_token_id=endtok,
    forced_eos_token_id=endtok,
    bos_token_id=text_encoder.bos_token_id,
    # NOTE(review): decoder_start_token_id is only used by encoder-decoder models in
    # transformers' generate(); presumably ignored for this GPT-2 model — confirm.
    decoder_start_token_id=septok,
)

In [None]:
# Generated continuation: skip the prompt prefix of the output. The prompt is
# `context`, whose length equals sep_idx + 1.
prompt_len = context.shape[1]
hyps1 = text_encoder.decode(
    sample_output[0][prompt_len:],
    skip_special_tokens=False,
    clean_up_tokenization_spaces=False,
)


In [None]:
# ROUGE of the generated continuation against the reference; both texts were decoded
# with special tokens kept.
print(rouge_scores(hyps1, refs))

In [None]:
# Show the generated continuation.
print(hyps1)