In [1]:
import datasets
import re

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

from tokenizers import Tokenizer
from tokenizers.models import *
from tokenizers.trainers import *
from tokenizers.pre_tokenizers import *
from tokenizers.processors import TemplateProcessing
from tokenizers.normalizers import Lowercase

import numpy as np
import pandas as pd

from tqdm import tqdm

In [2]:
if torch.cuda.is_available():
    [print(f"Device {i}: {torch.cuda.get_device_properties(i)}") for i in range(torch.cuda.device_count())]

Device 0: _CudaDeviceProperties(name='NVIDIA GeForce RTX 3070 Ti Laptop GPU', major=8, minor=6, total_memory=7982MB, multi_processor_count=46)


## Loading dataset

In [7]:
# dataset = datasets.load_dataset("cnn_dailymail", '3.0.0')
dataset = datasets.load_dataset("mlsum", 'ru')
dataset

Found cached dataset mlsum (/home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'summary', 'topic', 'url', 'title', 'date'],
        num_rows: 25556
    })
    validation: Dataset({
        features: ['text', 'summary', 'topic', 'url', 'title', 'date'],
        num_rows: 750
    })
    test: Dataset({
        features: ['text', 'summary', 'topic', 'url', 'title', 'date'],
        num_rows: 757
    })
})

In [8]:
dataset = dataset.remove_columns(['topic', 'url', 'title', 'date'])
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'summary'],
        num_rows: 25556
    })
    validation: Dataset({
        features: ['text', 'summary'],
        num_rows: 750
    })
    test: Dataset({
        features: ['text', 'summary'],
        num_rows: 757
    })
})

In [9]:
def clean_text(row):
    text_cleaner = lambda x: re.sub(r"\s+", " ",
                               re.sub("[^А-Яа-яЁёa-zA-Z0-9 -$?!,“”.%&\"\'=+*^<>\[\]]", "",
                                      re.sub("'", " ' ",
                                          re.sub("<[^>]+>", "",
                                                 re.sub("@\S+", "[REF]",
                                                        re.sub("https?:\/\/.*[\r\n]*", "[URL]", x))
                                                 )
                                          )
                                      )
                               ).strip()
    return {
        'text': [text_cleaner(x) for x in row['text']],
        'summary': [text_cleaner(x) for x in row['summary']],
    }

In [10]:
cleaned_dataset = dataset.map(clean_text, batch_size=100, batched=True, num_proc=10)
cleaned_dataset

            

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-2db87da7ba5a2c35.arrow
Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-cac6670a808bf77e.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-b92895f714dedce9.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-dba0b6cbe919a37e.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-303ec6742961adb8.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-67749293cd2eec98.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-1aea483462929a7c.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-7222a720cff0f6e9.arrow


  

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-f97a3f8478cb3859.arrow
Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-dfb8da3da708fb24.arrow


           

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-879b0b657d9f9de3.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-e52a161f5404eaff.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-54ad44783beca6be.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-899bdbc1b6134a24.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-ee5c53387e0780e9.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-d4c29d2976f14535.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-2ad12b4b052e0750.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-cd4fda0a196ac2aa.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-52a832c1b08ba2fa.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-7a905a8c5d0ce574.arrow


           

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-9a5213556b0212ed.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-99137d5328fa28bc.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-eb0a9561d0004c1e.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-db9fd8a74e89fe4a.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-938dc6765aea3b4c.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-8e5d34aff7414212.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-73d2040bb74196f2.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-1a2e6098cb4a0a12.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-8ba8ad64aa1aa110.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-0c782645405d06eb.arrow


DatasetDict({
    train: Dataset({
        features: ['text', 'summary'],
        num_rows: 25556
    })
    validation: Dataset({
        features: ['text', 'summary'],
        num_rows: 750
    })
    test: Dataset({
        features: ['text', 'summary'],
        num_rows: 757
    })
})

In [11]:
filtered_dataset = cleaned_dataset.filter(lambda row: len(row['text']) > 10 and len(row['summary']) > 10, num_proc=5)
filtered_dataset

 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-d52a46ca47f3c888_00000_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-d52a46ca47f3c888_00001_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-d52a46ca47f3c888_00002_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-d52a46ca47f3c888_00003_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-d52a46ca47f3c888_00004_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-e881add5aabdb6cb_00000_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-e881add5aabdb6cb_00001_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-e881add5aabdb6cb_00002_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-e881add5aabdb6cb_00003_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-e881add5aabdb6cb_00004_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-8be6550d478f32b4_00000_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-8be6550d478f32b4_00001_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-8be6550d478f32b4_00002_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-8be6550d478f32b4_00003_of_00005.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-8be6550d478f32b4_00004_of_00005.arrow


DatasetDict({
    train: Dataset({
        features: ['text', 'summary'],
        num_rows: 25556
    })
    validation: Dataset({
        features: ['text', 'summary'],
        num_rows: 750
    })
    test: Dataset({
        features: ['text', 'summary'],
        num_rows: 757
    })
})

## Custom tokenizer trainer

In [68]:
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
trainer = WordLevelTrainer(vocab_size=8_000,
                     min_frequency=1,
                     show_progress=True,
                     special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
tokenizer.pre_tokenizer = Sequence([Punctuation(), Whitespace()])
tokenizer.normalizer = Lowercase()

In [69]:
%%time
tokenizer.train_from_iterator(filtered_dataset['train']['text'], trainer)

CPU times: user 28.7 s, sys: 99.6 ms, total: 28.8 s
Wall time: 28.8 s


In [70]:
tokenizer.get_vocab_size()

8000

In [71]:
tokenizer.get_vocab()

{'пропал': 7211,
 'говорили': 652,
 'президенту': 1697,
 'готовности': 5174,
 'поста': 5297,
 'научного': 7260,
 '1015': 4402,
 'подписал': 4271,
 'инструменты': 7790,
 'неоднократно': 2089,
 'творчества': 4394,
 'высказывания': 7106,
 'голосование': 4904,
 'дабы': 5697,
 'приезда': 7433,
 'активный': 7441,
 'оружия': 1546,
 'процедуры': 3492,
 'уехать': 3976,
 'женщин': 880,
 'будущее': 1447,
 'запах': 4742,
 'оружием': 3499,
 'вроде': 544,
 'центры': 3395,
 'партию': 2635,
 'страны': 178,
 'этот': 105,
 'кабинете': 4014,
 'источников': 4262,
 'работала': 1829,
 'стало': 252,
 'музыки': 2628,
 'московские': 3174,
 'отрасли': 2754,
 'думала': 3065,
 'сутки': 1851,
 'главный': 480,
 'милиции': 1737,
 'рубль': 3111,
 'спорт': 1540,
 'позвонили': 4121,
 'для': 26,
 'решит': 5783,
 'футбольного': 6515,
 'онк': 7206,
 'партнерами': 7503,
 'полагают': 7379,
 'танцевать': 7686,
 'постепенно': 1603,
 'вошла': 7562,
 '36': 2761,
 'владимир': 288,
 'готов': 860,
 'вред': 4226,
 'официальные': 47

In [72]:
tokenizer.enable_truncation(512)
tokenizer.enable_padding(direction="right", pad_id=tokenizer.token_to_id("[PAD]"), length=512)

In [73]:
tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A [SEP]",
    pair="[CLS] $A [SEP] $B:1 [SEP]:1",
    special_tokens=[
        ("[CLS]", tokenizer.token_to_id("[CLS]")),
        ("[SEP]", tokenizer.token_to_id("[SEP]")),
    ],
)

In [74]:
tokenizer.save(f"ru_word_tokenizer_{tokenizer.get_vocab_size()}.json")

## Loading tokenizer

In [12]:
# tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
# tokenizer.model_max_length = 1024
# tokenizer.vocab_size

In [13]:
tokenizer = Tokenizer.from_file("ru_word_tokenizer_8000.json")

In [14]:
tokenizer.enable_truncation(512)
tokenizer.enable_padding(direction="right", pad_id=tokenizer.token_to_id("[PAD]"), length=512)

In [15]:
def tokenize_function(row):
    return {
        'text': [x.ids for x in tokenizer.encode_batch(row['text'])],
        'summary': [x.ids for x in tokenizer.encode_batch(row['summary'])]
    }

In [16]:
tokenized_dataset = filtered_dataset.map(tokenize_function, batch_size=100, batched=True, num_proc=6)
tokenized_dataset

        

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-b32c8073d41bd42f.arrow
Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-87203aa6216fd3d1.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-37f580ce6d1f93cb.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-bc80142709b1541e.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-962bc637dfc5fd3b.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-2774228a1aa4981e.arrow


       

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-0bc354b14c44103b.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-6cdbcca7ac9e2ea2.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-5b93ebc7a0e91a3c.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-08d0ed7f48e9dc82.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-da3e878a6d2ab042.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-088c78827e050124.arrow


       

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-1af9224433812071.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-a2c66bb8e63b2da0.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-3cbdb262b21b37d1.arrow


 

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-d6a19241a48bc1b5.arrow


  

Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-dd9d12ea5a744a14.arrow
Loading cached processed dataset at /home/hivaze/.cache/huggingface/datasets/mlsum/ru/1.0.0/033c69bbbf1eb198d444f668be75f297cb86251c0671a3d063d1c53c2f231076/cache-85efff09b58822c4.arrow


DatasetDict({
    train: Dataset({
        features: ['text', 'summary'],
        num_rows: 25556
    })
    validation: Dataset({
        features: ['text', 'summary'],
        num_rows: 750
    })
    test: Dataset({
        features: ['text', 'summary'],
        num_rows: 757
    })
})

In [17]:
tokenized_dataset.set_format('torch')

## Modeling

In [18]:
class SimpleLSTMWithEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_dim, output_dim, num_layers, num_heads=4, inner_dropout=0.1):
        super(SimpleLSTMWithEmbedding, self).__init__()

        self.embed = nn.Embedding(vocab_size, hidden_dim)

        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=num_heads,
                                               batch_first=True, bias=False,
                                               dropout=inner_dropout)

        self.lstm = nn.LSTM(input_size=hidden_dim,
                            hidden_size=hidden_dim,
                            batch_first=True,
                            dropout=inner_dropout,
                            num_layers=num_layers,
                            bidirectional=False)

        self.layer_norm = nn.LayerNorm(hidden_dim)

        self.out_proj = nn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x):

        x = self.embed(x)
        # attn = self.attention(x, x, x, need_weights=False)[0]
        x, hidden = self.lstm(x)

        x = self.layer_norm(x)
        x = self.out_proj(x)

        return x, hidden, None

    @torch.no_grad()
    def sample(
            self,
            prompt: str,
            tokenizer: Tokenizer,
            num_steps: int = 10,
            temperature: float = 1.0
        ):
        prompt_encoding = tokenizer.encode(f"[CLS] {prompt}", add_special_tokens=False)
        token_ids = torch.tensor(prompt_encoding.ids, device=self.embed.weight.device)
        num_tokens = (1 - np.array(prompt_encoding.special_tokens_mask)).sum()
        for t in tqdm(range(num_steps), desc=f"Sampling {num_steps} steps.."):
            logits = self.forward(token_ids)[0][t + num_tokens].softmax(-1)
            logits_t = logits / temperature
            # p_wt = torch.distributions.Categorical(logits=logits_t)
            p_wt = torch.distributions.Categorical(probs=logits_t)
            tokens_t = p_wt.sample()
            token_ids[t + num_tokens] = tokens_t.item()
        return token_ids.detach()

In [None]:
model = SimpleLSTMWithEmbedding(vocab_size=tokenizer.get_vocab_size(),
                                hidden_dim=128,
                                output_dim=tokenizer.get_vocab_size(),
                                num_layers=3).cuda()
model

In [None]:
num_epochs = 6
batch_size = 16
clip_grad = 0.25

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

train_loader = DataLoader(tokenized_dataset['train'], batch_size=batch_size, pin_memory=True)
valid_loader = DataLoader(tokenized_dataset['validation'], batch_size=batch_size, pin_memory=True)

# Track loss
training_loss, validation_loss = [], []

for i in range(num_epochs):

    print(f'Starting epoch {i}...')

    epoch_training_loss = 0
    epoch_validation_loss = 0

    model.eval()

    with torch.no_grad():

        for batch in tqdm(valid_loader, f"Validation batch", total=len(valid_loader)):
            batch_size = batch['text'].shape[0]
            model_out, _, _ = model(batch['text'].cuda())

            one_hot_target = torch.zeros(model_out.shape, dtype=torch.float, device='cuda')
            target = F.one_hot(batch['text'])
            one_hot_target[:, :, :target.shape[2]] = target

            loss = F.cross_entropy(input=model_out, target=one_hot_target)

            epoch_validation_loss += loss.detach().cpu().numpy()

            del model_out, target, one_hot_target

        validation_loss.append(epoch_validation_loss / len(valid_loader))

    model.train()

    for batch in tqdm(train_loader, f"Train batch", total=len(train_loader)):
        batch_size = batch['text'].shape[0]
        model_out, _, _ = model(batch['text'].cuda())

        one_hot_target = torch.zeros(model_out.shape, dtype=torch.float, device='cuda')
        target = F.one_hot(batch['text'])
        one_hot_target[:, :, :target.shape[2]] = target

        loss = F.cross_entropy(input=model_out, target=one_hot_target)

        optimizer.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        optimizer.step()

        epoch_training_loss += loss.detach().cpu().numpy()

        if torch.rand(1).item() < 0.75 * (1 / len(train_loader)):
            print(f"Batch loss {loss.detach().item()}")

        del model_out, target, one_hot_target

    training_loss.append(epoch_training_loss / len(train_loader))

    print(f'Epoch {i} finidhed, training loss: {training_loss[-1]}, validation loss: {validation_loss[-1]}')

In [18]:
model.cpu()

SimpleLSTMWithEmbedding(
  (embed): Embedding(8000, 128)
  (attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=False)
  )
  (lstm): LSTM(128, 128, num_layers=3, batch_first=True, dropout=0.1)
  (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (out_proj): Linear(in_features=128, out_features=8000, bias=False)
)

In [37]:
tokenizer.decode(model.sample("Do you speak", tokenizer, temperature=0.8, num_steps=10).cpu().numpy())

Sampling 10 steps..: 100%|██████████| 10/10 [00:00<00:00, 16.53it/s]


'Do you speak welcomed scientists phone wearing heads scientists welcomed Age They scientists'

In [34]:
tokenizer.decode(model(tokenized_dataset['test']['document'][-3])[0].argmax(-1).detach().numpy(), skip_special_tokens=False)

'[CLS] CNN A federal criminal investigation into a deadly [UNK] boat [UNK] on a ended reports was [UNK] after the US facing whom determined the 17 deaths resulted from " misconduct , [UNK] or [UNK] to the duties " by the captain of the [UNK] boat , according to a court theyre filed Wednesday . The new investigation also is looking at another [UNK] boat captain and officials at the company that kept the tourist efforts , the court document says . There are several investigations into the July 19 incident , including the state of ended , which is also looking into criminal currently , and the National Transportation ones Board , which is trying to determine what caused the [UNK] . [UNK] of some of the people who died are , in four cases , suing [UNK] Entertainment , which runs the [UNK] boat [UNK] called [UNK] the [UNK] [UNK] . On Wednesday , the US government attached a motion to the civil cases , asking that a court rule that federal investigators be allowed first to talk to the [UNK] 

In [31]:
tokenizer.decode(tokenized_dataset['test']['summary'][-2].numpy())

'Note to tweeting politicians Watch what you post , because will remember it forever . The website is politicians \' deleted tweets , the rest of us to or over them at our , The Atlantic reports . The site \' s current includes a few , including John McCain Vladimir Putin \' s tears and Rep . Jeff Miller posting a link to a poll that asked , " Was Obama born in the United States ? " A few are more odd than obvious , us to ask what politicians were thinking . Why , for example , did Rep . Tom remove a tweet about going out one night with his wife ? Or Rep . delete one about her visit to a cancer ? Perhaps Rep . Stephen \' s tweet comparing The to The Games is a more obvious case , but the online of a politician \' s mind can be indeed .'

In [19]:
torch.rand(1).item()

0.9066071510314941

In [22]:
F.cross_entropy(input=model_out, target=torch.randint(0, 2, [32, 2048, 5000]).float())

tensor(7811.1294)

In [24]:
m = torch.zeros([32, 2048, 5000])
target = F.one_hot(batch['summary'])
m[:, :, :target.shape[2]] = target

In [29]:
m.shape

torch.Size([32, 2048, 5000])

In [63]:
F.cross_entropy(input=model_out, target=F.one_hot(batch['summary']))

RuntimeError: Expected target size [32, 5000], got [32, 2048, 4998]

In [61]:
F.one_hot(batch['summary']).float().shape

torch.Size([32, 2048, 4998])

In [34]:
tokenizer.enable_padding(length=)

{'length': None,
 'pad_to_multiple_of': None,
 'pad_id': 0,
 'pad_token': '[PAD]',
 'pad_type_id': 0,
 'direction': 'right'}

In [32]:
model_out.shape

torch.Size([32, 2048, 5000])

In [19]:
model(torch.randint(0, 100, [10, 10]).cuda())

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [21]:
torch.randint(0, 100, [10, 10]).cpu()

tensor([[90, 27, 63,  8, 96, 20, 92, 38, 69, 71],
        [55, 67, 17, 30, 83, 26, 93, 56, 62, 55],
        [ 8, 43, 45, 43, 52, 28, 21, 74, 39, 17],
        [58, 82, 20, 57,  1, 54,  2, 58, 28, 69],
        [60,  5, 94, 98, 70, 19, 86, 38, 10,  1],
        [65, 59, 44, 79, 39, 10, 13, 46, 45, 62],
        [53, 73, 97,  7, 53, 45, 83, 67, 63, 10],
        [90, 19,  7, 54, 85, 56,  6, 21, 78, 68],
        [91, 67, 38, 68, 44, 90, 11,  2, 34, 33],
        [97, 86, 48, 34, 93, 20,  5, 23, 85, 63]])

In [68]:
model(tokenized_dataset['train']['document'][0])

tensor([[ 0.0238,  0.0274, -0.0123,  ...,  0.0540, -0.0579, -0.0230],
        [ 0.0052,  0.0147, -0.0199,  ...,  0.0334, -0.0693, -0.0249],
        [-0.0004,  0.0063, -0.0295,  ...,  0.0413, -0.0693, -0.0271],
        ...,
        [ 0.0615,  0.0317,  0.0262,  ...,  0.0457, -0.0094,  0.0165],
        [ 0.0615,  0.0317,  0.0262,  ...,  0.0457, -0.0094,  0.0165],
        [ 0.0615,  0.0317,  0.0262,  ...,  0.0457, -0.0094,  0.0165]],
       grad_fn=<AddmmBackward0>)

In [19]:
target = tokenized_dataset['validation'][:10]['document']
target.shape

torch.Size([10, 2048])

In [20]:
target.unsqueeze(1).shape

torch.Size([10, 1, 2048])

In [21]:
model_out = model(target)
model_out.shape

torch.Size([10, 2048, 5000])

In [22]:
model_result = nn.LogSoftmax(dim=2)(model_out).argmax(-1)
model_result.shape

torch.Size([10, 2048])

In [23]:
model_result

tensor([[2746, 3379, 2746,  ..., 1612, 4819, 2746],
        [3921, 3921, 4347,  ..., 3379, 3379, 3379],
        [2746, 3332, 3379,  ..., 2343, 2746, 2746],
        ...,
        [2746, 2746, 2746,  ..., 3332, 2619, 3332],
        [2746, 2746, 2746,  ..., 2746, 2746, 2746],
        [3379, 2746, 2746,  ..., 3379, 3379, 3379]])

In [24]:
torch.empty(3, dtype=torch.long).random_(5).shape

torch.Size([3])

In [2]:
F.cross_entropy(input=model_out, target=F.one_hot(target).float())

NameError: name 'F' is not defined

In [112]:
tokenized_dataset['validation'].shuffle()[:10]['document'].__len__()

10

In [53]:
tokenizer.vocab_size

30522

In [55]:
nn.Embedding(tokenizer.vocab_size, 10)(torch.randint(0, 4858, size=[2, 1024])).shape

torch.Size([2, 1024, 10])

In [56]:
nn.Embedding(num_embeddings=tokenizer.vocab_size, embedding_dim=100)(tokenized_dataset['train']['document'][:2]).shape

torch.Size([2, 1024, 100])

In [47]:
tokenized_dataset['train']['document'][:2].shape

torch.Size([2, 1024])

In [84]:
nn.LSTM(input_size=1024, hidden_size=100, batch_first=True, num_layers=1, bidirectional=False)(torch.rand(size=[1, 1024]))[0].shape

torch.Size([1, 100])

In [66]:
torch.rand(size=[2, 1024, 100]).view(2, -1).shape

torch.Size([2, 102400])

In [74]:
nn.Softmax(dim=2)(torch.rand(size=[2, 1024, tokenizer.vocab_size])).shape

torch.Size([2, 1024, 30522])