Подготим сначала датасет по текстовому файлу. Создадим для этого специальный класс

In [23]:
import sentencepiece as spm
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer
import numpy as np
import tensorflow as tf

class TextDataset:
    
    def __init__(self, data_file, sp_model_prefix, vocab_size = 2000, normalization_rule_name = 'nmt_nfkc_cf',
                 model_type = 'bpe', max_length = 128):

        SentencePieceTrainer.train(input=data_file, vocab_size=vocab_size, model_type=model_type, model_prefix=sp_model_prefix,
            normalization_rule_name=normalization_rule_name, pad_id=0, bos_id=1, eos_id=2, unk_id=3)
        
        self.sp_model = SentencePieceProcessor(model_file=sp_model_prefix + '.model')
        
        with open(data_file, 'r', encoding='utf-8') as file:
            self.texts = [line.strip() for line in file.readlines()]
        
        self.max_length = max_length
        self.vocab_size = self.sp_model.vocab_size()
        self._encoded_cache = None
        
    def encode_texts(self):
        if self._encoded_cache is not None:
            return self._encoded_cache
        sequences = []
        targets = []
        for text in self.texts:
            encoded = self.sp_model.encode(text)
            if len(encoded) > self.max_length - 2:
                encoded = encoded[:self.max_length - 2]
            sequence = [self.sp_model.bos_id()] + encoded + [self.sp_model.eos_id()]
            if len(sequence) < self.max_length:
                sequence = sequence + [self.sp_model.pad_id()] * (self.max_length - len(sequence))
            else:
                sequence = sequence[:self.max_length]
            target = sequence[1:] + [self.sp_model.pad_id()]
            sequences.append(sequence)
            targets.append(target)
        self._encoded_cache = (np.array(sequences), np.array(targets))
        return self._encoded_cache
    
    def create_tf_dataset(self, batch_size = 32, shuffle = True):
        sequences, targets = self.encode_texts()
        dataset = tf.data.Dataset.from_tensor_slices((sequences, targets))
        if shuffle:
            dataset = dataset.shuffle(buffer_size=len(sequences))
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        return dataset

Загрузим датасет из файла

In [None]:
dataset = TextDataset(data_file="dataset.txt", sp_model_prefix="jokes_spm", vocab_size=2000, max_length=100)

X, y = dataset.encode_texts()
tf_dataset = dataset.create_tf_dataset(batch_size=32, shuffle=True)

Создадим и обучим рекуррентную сеть на основе LSTM

In [12]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout, Bidirectional

model = Sequential([
    Embedding(input_dim=dataset.vocab_size, output_dim=128, input_length=dataset.max_length, mask_zero=True),
    Bidirectional(LSTM(256, return_sequences=True)),
    Dropout(0.3),
    LSTM(256, return_sequences=True),
    Dropout(0.3),
    Dense(dataset.vocab_size, activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(tf_dataset, epochs=10, verbose=1)

Epoch 1/10
[1m3774/3774[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m107s[0m 28ms/step - accuracy: 0.0853 - loss: 4.8587
Epoch 2/10
[1m3774/3774[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m104s[0m 28ms/step - accuracy: 0.3270 - loss: 0.7651
Epoch 3/10
[1m3774/3774[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m104s[0m 28ms/step - accuracy: 0.3885 - loss: 0.1167
Epoch 4/10
[1m3774/3774[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m105s[0m 28ms/step - accuracy: 0.3954 - loss: 0.0297
Epoch 5/10
[1m3774/3774[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m105s[0m 28ms/step - accuracy: 0.3968 - loss: 0.0135
Epoch 6/10
[1m3774/3774[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m105s[0m 28ms/step - accuracy: 0.3957 - loss: 0.0081
Epoch 7/10
[1m3774/3774[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m105s[0m 28ms/step - accuracy: 0.3973 - loss: 0.0055
Epoch 8/10
[1m3774/3774[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m105s[0m 28ms/step - accuracy: 0.3976 - loss: 0.0041


<keras.src.callbacks.history.History at 0x7c1a3f419450>

Создадим класс для генерации с помощью обученной модели

In [19]:
class TextGenerator:
    
    def __init__(self, model, sp_model, max_length, temperature):

        self.model = model
        self.sp_model = sp_model
        self.max_length = max_length
        self.temperature = temperature
        self.bos_id = sp_model.bos_id()
        self.eos_id = sp_model.eos_id()
        self.pad_id = sp_model.pad_id()
        self.unk_id = sp_model.unk_id()
        
    def preprocess_text(self, text):
        encoded = self.sp_model.encode(text)
        if len(encoded) > self.max_length - 2:
            encoded = encoded[:self.max_length - 2]
        sequence = encoded
        if len(sequence) < self.max_length:
            sequence = sequence + [self.pad_id] * (self.max_length - len(sequence))
        else:
            sequence = sequence[:self.max_length]
        return np.array([sequence])
    
    def sample_next_token(self, logits):
        logits = logits / self.temperature
        probs = np.exp(logits) / np.sum(np.exp(logits))
        return np.random.choice(len(probs), p=probs)
    
    def generate_sequence(self, prompt, max_gen_length = 50):
        input_sequence = self.preprocess_text(prompt)
        current_length = np.sum(input_sequence[0] != self.pad_id)
        generated_tokens = []
        
        for i in range(self.max_length):
            if current_length >= self.max_length:
                break
            predictions = self.model.predict(input_sequence, verbose=0)
            last_token_logits = predictions[0, current_length - 1, :]
            next_token = self.sample_next_token(last_token_logits)
            input_sequence[0, current_length] = next_token
            generated_tokens.append(next_token)
            current_length += 1
        
        if prompt:
            generated_text = self.sp_model.decode(generated_tokens)
            return prompt + generated_text
        else:
            all_tokens = input_sequence[0][:current_length].tolist()
            if all_tokens and all_tokens[0] == self.bos_id:
                all_tokens = all_tokens[1:]
            return self.sp_model.decode(all_tokens)

Сгенерируем последовательность по промпту

In [20]:
generator = TextGenerator(model=model, sp_model=dataset.sp_model, max_length=100, temperature=0.8)
prompt = "Анекдот"
generated = generator.generate_sequence(prompt=prompt)
generated

'Анекдотак сказал зап закры запа каждо клаить государ заня тутбыного ино занима случатный твойвом стоит мар.- выпу ту против блиндешь папактор крожил уби рабо этой подходит соност человек зво соба,паатдом расвтоняхшать думаю?денок, эвить! апские кла работатьльный ша оргасемскогоэ сразу ту которая амери некостит владими когда прекра длянымичился сто говор человего девушки постели слулет студен шиниид конекро'

In [22]:
with open('generate_result.txt', 'w') as file:
    file.write(generated)