# Dependencies

In [1]:
!pip install striprtf

Collecting striprtf
  Downloading striprtf-0.0.29-py3-none-any.whl.metadata (2.3 kB)
Downloading striprtf-0.0.29-py3-none-any.whl (7.9 kB)
Installing collected packages: striprtf
Successfully installed striprtf-0.0.29


In [2]:
import striprtf
import re
from striprtf.striprtf import rtf_to_text
import string

import nltk
from nltk.tokenize import word_tokenize
nltk.download('punkt')
nltk.download('punkt_tab')

import matplotlib.pyplot as plt
import numpy as np
import re, string
from collections import Counter

import math
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, random_split
import random
from tqdm.auto import tqdm

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /usr/share/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [3]:
import collections 

# Load data

In [4]:
def load_convert_clean(file_path, cleaner_func):
    """Loads an RTF file, converts it to plain text, and cleans it."""
    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            rtf_content = f.read()

        # Convert RTF to plain text
        plain_text = rtf_to_text(rtf_content)

        # Apply the cleaning function
        text_prepared = cleaner_func(plain_text)
        
        return text_prepared

    except FileNotFoundError:
        print(f"Error: File not found at {file_path}. Returning empty string.")
        return ""
    except Exception as e:
        print(f"An error occurred while processing {file_path}: {e}")
        return ""

In [5]:
def clean_text(text):
    """
    Cleans the input text by removing non-Cyrillic characters, digits, and extra whitespace.

    :param text: The original text to be cleaned.
    :return: A cleaned string containing only Cyrillic characters, with digits and extra whitespace removed.
    """
    cleaned_text = re.sub(r"[^\u0400-\u04FF\s]", "", text)  
    cleaned_text = re.sub(r"\d+", "", cleaned_text) 
    cleaned_text = cleaned_text.lower()
    cleaned_text = re.sub(r"[\n\t]", " ", cleaned_text)  
    cleaned_text = re.sub(r"\s+", " ", cleaned_text) 
    return cleaned_text.strip()

In [6]:
file_path_1 = '/kaggle/input/dune-frank-herbert/dune/dune.rtf'
file_path_2 = '/kaggle/input/dune-frank-herbert/dune/dune-messiah.rtfd/TXT.rtf'

text_dune = load_convert_clean(file_path_1, clean_text)
text_dune_messiah = load_convert_clean(file_path_2, clean_text)

separator = "\n\n--- END OF DUNE / START OF DUNE MESSIAH ---\n\n"
combined_text = text_dune + separator + text_dune_messiah

In [7]:
len(combined_text)

1422330

# Prepare data

In [8]:
word_list = combined_text.split()
word_count = len(word_list)

print(f" Total number of words in the prepared text: {word_count}")

 Total number of words in the prepared text: 224040


In [9]:
MAX_VOCAB = 20000
word_list = combined_text.split()
counter = collections.Counter(word_list)
specials = ["[PAD]", "[UNK]"]
most_common = counter.most_common(MAX_VOCAB - len(specials))
itos = specials + [w for w, _ in most_common]       
stoi = {w: i for i, w in enumerate(itos)}           
PAD_IDX = stoi["[PAD]"]
UNK_IDX = stoi["[UNK]"]

print(f"Vocabulary Size: {len(stoi)}")
print(f"Example UNK Index (for 'pad'): {stoi.get('pad', UNK_IDX)}")

Vocabulary Size: 20000
Example UNK Index (for 'pad'): 1


Let's encode text and add padding.

In [10]:
tokens = [word_tokenize(t) for t in combined_text]

In [11]:
def encode(tokens):
    return [stoi.get(t, UNK_IDX) for t in tokens]

encoded_texts = [encode(tok_list) for tok_list in tokens]
print(encoded_texts[0][:20])

[1772]


In [12]:
MAX_LEN = 100  
def pad_sequence(seq):
    seq = seq[:MAX_LEN] + [PAD_IDX] * max(0, MAX_LEN - len(seq))
    return torch.tensor(seq, dtype=torch.long)

X = torch.stack([pad_sequence(seq) for seq in encoded_texts])
print(X.shape)

torch.Size([1422330, 100])


Let's create target: next tokens.

In [13]:
X_src = X[:, :-1].clone()
Y_tgt = X[:, 1:].clone()

In [14]:
MAX_LEN_TRAINING = X_src.size(1)

print(f"Original X shape: {X.shape}")
print(f"Input (X_src) shape: {X_src.shape}")
print(f"Target (Y_tgt) shape: {Y_tgt.shape}")

Original X shape: torch.Size([1422330, 100])
Input (X_src) shape: torch.Size([1422330, 99])
Target (Y_tgt) shape: torch.Size([1422330, 99])


# Model

In [15]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return x

In [16]:
class CausalTransformerDecoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        emb_dim,
        pad_idx,
        n_heads=4,
        n_layers=2,
        dim_feedforward=256,
        dropout=0.1,
        max_len=5000,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.emb_dim = emb_dim
        self.max_len = max_len

        # Word Embedding
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=emb_dim,
            padding_idx=pad_idx
        )

        # Positional Encoding
        self.pos_encoding = PositionalEncoding(d_model=emb_dim, max_len=max_len)

        # Transformer Decoder Layers Stack.
        decoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerEncoder(
            decoder_layer,
            num_layers=n_layers
        )

        # Final Head: Output logits for the entire vocabulary at each position
        self.fc_out = nn.Linear(emb_dim, vocab_size)

        self.dropout = nn.Dropout(dropout)
    
    def _generate_causal_mask(self, size):
        mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
        return mask

    def forward(self, x):
        seq_len = x.size(1)

        # Embedding + Scaling
        x = self.embedding(x) #* math.sqrt(self.emb_dim)  # [B, T, D]

        # Positional Encoding
        x = self.pos_encoding(x)  # [B, T, D]
        x = self.dropout(x)

        # Causal Mask for Autoregressive Generation
        causal_mask = self._generate_causal_mask(seq_len).to(x.device)
        
        # Padding Mask
        src_key_padding_mask = (x.squeeze(0).sum(dim=-1) == 0) # [B, T] - if you re-pad, but based on your input 'X', this might be (x == self.pad_idx)

        # Transformer Decoder Stack
        decoder_output = self.transformer_decoder(
            x,
            mask=causal_mask,              # Causal Mask (T, T)
            src_key_padding_mask=(x.mean(dim=2) == self.embedding.padding_idx) # Better padding mask: [B, T] 
        ) # [B, T, D]

        # Final Logits: predict next token for every position
        logits = self.fc_out(decoder_output)  # [B, T, V]
        return logits

# Setup

In [17]:
# Model Hyperparameters
EMB_DIM = 256
N_HEADS = 8
N_LAYERS = 4
DIM_FEEDFORWARD = 1024
DROPOUT = 0.1

# Training Parameters
LEARNING_RATE = 3e-5
BATCH_SIZE = 64
N_EPOCHS = 10

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# Training

Stuff data into dataset.

In [18]:
total_samples = X_src.size(0)
train_size = int(0.9 * total_samples)
val_size = total_samples - train_size

In [19]:
full_dataset = TensorDataset(X_src, Y_tgt)
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

In [20]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

Let's initialize model.

In [21]:
model = CausalTransformerDecoder(
    vocab_size=MAX_VOCAB,
    emb_dim=EMB_DIM,
    pad_idx=PAD_IDX,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    dim_feedforward=DIM_FEEDFORWARD,
    dropout=DROPOUT,
    max_len=MAX_LEN_TRAINING + 1
).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(f"Model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters.")

Model initialized with 13,419,040 trainable parameters.


In [22]:
model.float()

CausalTransformerDecoder(
  (embedding): Embedding(20000, 256, padding_idx=0)
  (pos_encoding): PositionalEncoding()
  (transformer_decoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc_out): Linear(in_features=256, out_features=20000, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

Training loop:

In [23]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    
    epoch_iterator = tqdm(dataloader, desc="Training Batch", leave=False)
    
    for batch_idx, (src, tgt) in enumerate(epoch_iterator):
        src, tgt = src.to(device), tgt.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass: Get logits [B, T, V]
        logits = model(src)

        # Calculate Loss (Reshape for CrossEntropyLoss)
        logits = logits.view(-1, logits.size(-1))
        tgt = tgt.view(-1)
        
        loss = criterion(logits, tgt)

        # Backward pass and optimization
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) 
        optimizer.step()

        current_loss = loss.item()
        total_loss += current_loss
        
        epoch_iterator.set_postfix(loss=f'{current_loss:.4f}')

    return total_loss / len(dataloader)

In [24]:
def evaluate_epoch(model, dataloader, criterion, device):
    model.eval() 
    total_loss = 0
    
    val_iterator = tqdm(dataloader, desc="Validation Batch", leave=False)
    
    with torch.no_grad():
        for src, tgt in val_iterator:
            src, tgt = src.to(device), tgt.to(device)

            # Forward pass
            logits = model(src)

            # Calculate Loss (Reshape for CrossEntropyLoss)
            logits = logits.view(-1, logits.size(-1))
            tgt = tgt.view(-1)
            
            loss = criterion(logits, tgt)

            total_loss += loss.item()
            val_iterator.set_postfix(val_loss=f'{loss.item():.4f}')

    return total_loss / len(dataloader)

In [25]:
#N_EPOCHS = 10 
N_EPOCHS = 2
print("\nStarting training...")

for epoch in range(N_EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{N_EPOCHS} ---")
    
    avg_train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1} Training Complete. Average Training Loss: {avg_train_loss:.4f}")

    avg_val_loss = evaluate_epoch(model, val_loader, criterion, device)
    print(f"Epoch {epoch+1} Complete. Average Validation Loss: {avg_val_loss:.4f}")

print("\nTraining complete!")


Starting training...

--- Epoch 1/2 ---


Training Batch:   0%|          | 0/20002 [00:00<?, ?it/s]

Epoch 1 Training Complete. Average Training Loss: nan


Validation Batch:   0%|          | 0/2223 [00:00<?, ?it/s]

Epoch 1 Complete. Average Validation Loss: nan

--- Epoch 2/2 ---


Training Batch:   0%|          | 0/20002 [00:00<?, ?it/s]

Epoch 2 Training Complete. Average Training Loss: nan


Validation Batch:   0%|          | 0/2223 [00:00<?, ?it/s]

Epoch 2 Complete. Average Validation Loss: nan

Training complete!


# Inference

In [26]:
def indices_to_text(indices, itos, stop_token=None):
    words = []
    for idx in indices:
        word = itos[idx] 
        if word == stop_token:
            break
        words.append(word)
    return " ".join(words)

In [27]:
def generate_text_greedy(
    model,
    prompt_text,
    stoi,
    itos,
    max_new_tokens=50,
    max_seq_len=100,
    pad_idx=PAD_IDX,
    device=device
):
    """
    Generates text autoregressively using greedy decoding.

    Args:
        model: Trained CausalTransformerDecoder instance.
        prompt_text: The starting string for generation.
        stoi: Word-to-index mapping.
        itos: Index-to-word mapping.
        max_new_tokens: Maximum number of tokens to generate.
        max_seq_len: The maximum sequence length the model was trained with (MAX_LEN - 1).
        pad_idx: The index used for padding.
        device: The device to run on.
    """
    model.eval()
    
    # Prepare the initial prompt 
    prompt_tokens = [stoi.get(t, UNK_IDX) for t in word_tokenize(prompt_text)]
    
    # Convert prompt to a tensor [1, T_prompt]
    input_ids = torch.tensor(prompt_tokens, dtype=torch.long, device=device).unsqueeze(0)
    
    print(f"Starting generation from prompt: '{prompt_text}'")

    # Autoregressive Generation Loop
    for _ in range(max_new_tokens):
        current_len = input_ids.size(1)
        if current_len > max_seq_len:
            input_ids = input_ids[:, -max_seq_len:]
            current_len = max_seq_len

        # Forward Pass
        with torch.no_grad():
            output_logits = model(input_ids) 
        
        # Select the prediction for the *last* token
        logits_last_token = output_logits[:, -1, :] 
        
        # Greedy selection: take the index with the highest probability
        next_token_idx = torch.argmax(logits_last_token, dim=-1).unsqueeze(0)
        
        # Check for stop condition
        if next_token_idx.item() == pad_idx or current_len + 1 > max_seq_len:
            break
            
        # Append the new token to the sequence for the next step
        input_ids = torch.cat([input_ids, next_token_idx], dim=-1)

    # Final Conversion
    full_indices = input_ids[0].cpu()
    generated_indices = full_indices[len(prompt_tokens):] 
    
    # Decode the full sequence for printing
    generated_text = indices_to_text(generated_indices, itos, stop_token="[PAD]")
    
    return generated_text

In [28]:
MAX_SEQ_LEN_MODEL = 99
model.to(device)

CausalTransformerDecoder(
  (embedding): Embedding(20000, 256, padding_idx=0)
  (pos_encoding): PositionalEncoding()
  (transformer_decoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc_out): Linear(in_features=256, out_features=20000, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [29]:
prompt = "Імператор взяв свій меч"

generated_continuation = generate_text_greedy(
    model=model,
    prompt_text=prompt,
    stoi=stoi,
    itos=itos,
    max_new_tokens=30,
    max_seq_len=MAX_SEQ_LEN_MODEL,
    pad_idx=PAD_IDX,
    device=device
)

print("-" * 50)
print(f"Prompt: {prompt}")
print(f"Generated Text: {prompt} {generated_continuation}")
print("-" * 50)

Starting generation from prompt: 'Імператор взяв свій меч'
--------------------------------------------------
Prompt: Імператор взяв свій меч
Generated Text: Імператор взяв свій меч слухаєш зниження розлетілися ставку відчувалися зниження зниження візуального візуального задніх думкою задніх думкою задніх розклав розклав розклав розклав розклав расової задніх підтримувало залишки тканину задніх повернути рідній вільних задніх периметру
--------------------------------------------------
