In [12]:
import os
import re
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW
from nltk.corpus import stopwords
from rake_nltk import Rake
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader

# Prompt with text

In [46]:
# Training hyper-parameters and checkpoint location used by the cells below.
params_config = dict(
    pad_len=1024,                 # each sample is padded/truncated to this many token ids
    train_batch_size=2,           # batch size (the validation loader reuses it too)
    lr=6.25e-5,                   # optimizer learning rate
    chk_path='savedir/gpt3full',  # file that train_epoch writes checkpoints to
)

In [52]:
model_name = "sberbank-ai/rugpt3medium_based_on_gpt2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG (on all devices) so that randomly initialised weights, such as
# the embedding rows added for new special tokens, and DataLoader shuffling are
# reproducible across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

In [51]:
text_encoder = GPT2Tokenizer.from_pretrained(model_name, add_prefix_space=True)

# Register the control tokens used to frame each training sample:
# _start_ / _end_ delimit the sample, _kw_ ... _endkw_ wrap the keyword prompt,
# and _classify_ separates the prompt from the target text.
text_encoder.add_special_tokens(dict(
    bos_token='_start_',
    cls_token='_classify_',
    eos_token='_end_',
    additional_special_tokens=['_kw_', '_endkw_'],
))

# Load the pretrained LM, then grow its embedding matrix to cover the added tokens.
model = GPT2LMHeadModel.from_pretrained(model_name)
model.resize_token_embeddings(len(text_encoder))

Embedding(50262, 1024)

## prepare texts

In [31]:
def preprocess_texts(path: str, output_path: str, file_names: list, output_file: str, topK: int):
    """Extract RAKE keywords from raw text files and write a keywords/text TSV.

    For every file in ``file_names`` (read from ``path``):
      * all whitespace runs are collapsed to a single space, so the text fits on
        one tab-free line;
      * the top ``topK`` RAKE phrases are extracted, and those with more than two
        words are joined with the literal ``[SEP]`` marker;
      * if the Cyrillic part of the file name (before its first dot) is longer
        than two characters, it is lower-cased and prepended as the first keyword.
        The keyword string is then capped at 100 space-separated words.

    The output ``output_path/output_file`` starts with the header line
    ``[KEYWORDS]<TAB>[TEXT]``, followed by one ``keywords<TAB>text`` line per
    successfully processed file. Files on which keyword extraction raises are
    reported and skipped.

    Args:
        path: Directory containing the raw input files.
        output_path: Directory the TSV is written to.
        file_names: File names (relative to ``path``) to process.
        output_file: Name of the TSV file to create or overwrite.
        topK: Maximum number of RAKE phrases considered per file.
    """
    # Loading the stop-word corpus is expensive, so build the extractor once.
    # extract_keywords_from_text recomputes the ranking on every call.
    # NOTE(review): stopwords.words() without an argument loads the stop words of
    # every NLTK language, not only Russian. Confirm this is intended.
    rake = Rake(language='russian', stopwords=stopwords.words())

    sentences_to_write = ["[KEYWORDS]\t[TEXT]\n"]

    for file in file_names:
        with open(os.path.join(path, file), 'r', encoding='utf-8') as f:
            raw_text = f.read()

        # Collapse every whitespace run (newlines, tabs, \u2003, repeated spaces)
        # to one space. The output is one tab-separated record per line, so a
        # stray tab or newline would corrupt the record.
        text = re.sub(r'\s+', ' ', raw_text).strip()

        try:
            rake.extract_keywords_from_text(text)
            top_features = rake.get_ranked_phrases()[:topK]
        except Exception as exc:
            # Best effort: report the failing file and keep going.
            print(f'skipping {file!r} ({exc}): {text[:50]}')
            continue

        keywords = "[SEP]".join(kw.strip() for kw in top_features if len(kw.split()) > 2).strip()

        # Keep only Cyrillic letters from the file name (including Ё/ё, which lie
        # outside the А-я code-point range).
        title = re.sub("[^А-Яа-яЁё]", " ", file.split('.')[0]).strip()

        if len(title) > 2:
            keywords = title.lower() + '[SEP]' + keywords
            words = keywords.split(' ')
            if len(words) > 100:
                keywords = ' '.join(words[:100]).strip()

        sentences_to_write.append(keywords + '\t' + text + '\n')

    # Open the output only once everything is collected; `with` guarantees the
    # handle is closed even if writing fails.
    with open(os.path.join(output_path, output_file), 'w', encoding='utf-8') as f_out:
        f_out.writelines(sentences_to_write)

In [33]:
path = 'dataset/raw/'
output_path = 'dataset/full/'
# Fixed seed so the train/val/test split is identical across kernel restarts.
SPLIT_SEED = 42
# Number of RAKE phrases considered per document.
TOP_K_KEYWORDS = 5

os.makedirs(output_path, exist_ok=True)

# Sort so the split depends only on the seed, not on filesystem listing order.
input_files = sorted(os.listdir(path))

# Remove stale outputs from a previous run (regular files only).
for out_file in os.listdir(output_path):
    out_file_path = os.path.join(output_path, out_file)
    if os.path.isfile(out_file_path):
        os.remove(out_file_path)

# 90/5/5: hold out 10% of the files, then split the holdout evenly into val/test.
train, testval = train_test_split(input_files, test_size=0.1, random_state=SPLIT_SEED)
val, test = train_test_split(testval, test_size=0.5, random_state=SPLIT_SEED)

preprocess_texts(path, output_path, train, 'train_full', TOP_K_KEYWORDS)
preprocess_texts(path, output_path, val, 'val_full', TOP_K_KEYWORDS)
preprocess_texts(path, output_path, test, 'test_full', TOP_K_KEYWORDS)

## dataset

In [102]:
class FullDataset(Dataset):
    """Keyword-to-text language-modelling dataset read from a TSV file.

    Every line after the header of ``data_file`` is ``keywords<TAB>text``.
    Lines that do not split into exactly two fields, or whose text field is
    empty, are skipped. An empty keyword field is allowed.

    Each sample is laid out as

        bos, _kw_, <keywords ids>, _endkw_, cls, <text ids>, eos

    and then padded or truncated to exactly ``pad_len`` token ids. Items are
    dicts of 1-D tensors of length ``pad_len``:

    * ``sample``: input ids, with padding positions filled with the BOS id;
    * ``mask``: 1 for real tokens and 0 for padding;
    * ``label``: equal to ``sample`` on real tokens and -100 on padding.
    """

    def __init__(self,
                 data_file: str,
                 tokenizer: "GPT2Tokenizer",
                 pad_len: int,
                 ):
        """
        Args:
            data_file: Path to the TSV written by ``preprocess_texts``.
            tokenizer: Tokenizer with ``bos``/``eos``/``cls`` tokens and the
                ``_kw_``/``_endkw_`` additional special tokens registered.
            pad_len: Fixed length of every returned sequence.
        """
        with open(data_file, "rb") as f:
            lines = f.readlines()

        self.data = []
        # lines[0] is the "[KEYWORDS]\t[TEXT]" header.
        for line in lines[1:]:
            # Strip only the line terminator before splitting. A whole-line
            # strip() would also remove the leading tab of a record with an
            # empty keyword field, and that record would be dropped as malformed.
            fields = [field.strip() for field in line.decode("utf-8", "ignore").rstrip("\r\n").split('\t')]
            if len(fields) == 2 and fields[1]:
                self.data.append(fields)

        self.tokenizer = tokenizer
        self.pad_len = pad_len

        # Control-token ids do not change per item; look them up once.
        self._bos_id = tokenizer.bos_token_id
        self._eos_id = tokenizer.eos_token_id
        self._cls_id = tokenizer.cls_token_id
        self._kw_id = tokenizer.convert_tokens_to_ids('_kw_')
        self._endkw_id = tokenizer.convert_tokens_to_ids('_endkw_')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, indx):
        context, target_txt = self.data[indx]

        sample = ([self._bos_id, self._kw_id]
                  + self.tokenizer.encode(context)
                  + [self._endkw_id, self._cls_id]
                  + self.tokenizer.encode(target_txt)
                  + [self._eos_id])

        if len(sample) > self.pad_len:
            # Truncate, but keep an explicit end-of-text marker in the last slot.
            sample = sample[:self.pad_len]
            sample[-1] = self._eos_id

        n_pad = self.pad_len - len(sample)
        mask = [1] * len(sample) + [0] * n_pad
        # -100 is the ignore index for the LM loss, so padding is not scored.
        label = sample + [-100] * n_pad
        # The padding id itself is irrelevant: it is masked out and not scored.
        sample = sample + [self._bos_id] * n_pad

        return {
            'sample': torch.tensor(sample),
            'mask': torch.tensor(mask),
            'label': torch.tensor(label),
        }

In [103]:
train_dataset = FullDataset(os.path.join(output_path, 'train_full'), text_encoder, params_config['pad_len'])
train_loader = DataLoader(train_dataset, params_config['train_batch_size'], shuffle=True)

# Evaluation results do not depend on order, so keep the validation pass
# deterministic rather than shuffling it.
val_dataset = FullDataset(os.path.join(output_path, 'val_full'), text_encoder, params_config['pad_len'])
val_loader = DataLoader(val_dataset, params_config['train_batch_size'], shuffle=False)

## train

In [87]:
model.to(device)
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are pinned to transformers.AdamW's defaults (1e-6 and 0.0),
# because torch's own defaults differ (1e-8 and 0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=params_config['lr'], eps=1e-6, weight_decay=0.0)

In [88]:
def train_epoch(model, loader, test_loader, optimizer, epoch_num, device, log_interval=10, chk_path=None):
    """Run one training epoch and return window-averaged training losses.

    Every ``log_interval`` optimisation steps, the mean training loss over that
    window is appended to the returned list, and a checkpoint containing the
    model and optimizer state dicts is written to ``chk_path``. A trailing
    partial window (fewer than ``log_interval`` steps) is not reported.

    Args:
        model: Causal LM called as ``model(input_ids, attention_mask=..., labels=...)``.
            The loss is read from ``outputs[0]``.
        loader: Iterable of batches with ``'sample'``, ``'mask'`` and ``'label'`` tensors.
        test_loader: Currently unused; kept for call compatibility.
        optimizer: Optimizer over ``model``'s parameters.
        epoch_num: Currently unused; kept for call compatibility.
        device: Device the batch tensors are moved to.
        log_interval: Number of steps per averaging/checkpoint window.
        chk_path: Checkpoint file. Defaults to ``params_config['chk_path']``.

    Returns:
        List of window-averaged training losses (Python floats).
    """
    if chk_path is None:
        chk_path = params_config['chk_path']
    # Make sure the checkpoint directory exists before the first save.
    chk_dir = os.path.dirname(chk_path)
    if chk_dir:
        os.makedirs(chk_dir, exist_ok=True)

    # from_pretrained() returns the model in eval mode (dropout disabled), so
    # switch to training mode explicitly.
    model.train()

    losses = []
    window_losses = []
    for step, batch in enumerate(loader, start=1):
        optimizer.zero_grad()
        input_ids = batch['sample'].to(device)
        mask = batch['mask'].to(device)
        label = batch['label'].to(device)

        outputs = model(input_ids, attention_mask=mask, labels=label)
        loss = outputs[0]
        window_losses.append(loss.detach().item())
        loss.backward()
        optimizer.step()

        if step % log_interval == 0:
            avg_train_loss = sum(window_losses) / len(window_losses)
            losses.append(avg_train_loss)
            window_losses = []
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
                chk_path)
    return losses

In [None]:
EPOCHS = 1
# Loss history across all epochs (window-averaged training losses).
losses = []
for epoch in range(EPOCHS):
    ep_losses = train_epoch(model, train_loader, val_loader, optimizer, epoch, device, log_interval=2)
    # Accumulate instead of discarding: previously each epoch's losses were lost.
    losses.extend(ep_losses)

# evaluate

In [105]:
# Inspect the token-type id the tokenizer reports for padding.
# NOTE(review): FullDataset pads input ids with bos_token_id, not this value.
getattr(text_encoder, 'pad_token_type_id')

0

In [106]:
# Inspect the id assigned to the '_classify_' token registered as cls_token.
getattr(text_encoder, 'cls_token_id')

50258