In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from einops import rearrange, repeat
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class xLSTMLayer(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size        
        self.input_linear = nn.Linear(input_size, hidden_size * 4)  # для i, f, o, g        
        nn.init.orthogonal_(self.input_linear.weight)
        nn.init.zeros_(self.input_linear.bias)
        
    def forward(self, x, hidden_states):
        h_prev, c_prev = hidden_states
        gates = self.input_linear(x)
        i, f, o, g = gates.chunk(4, dim=-1)        
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)        
        c = f * c_prev.unsqueeze(1) + i * g        
        h = o * torch.tanh(c)
        
        return h, (h[:, -1, :], c[:, -1, :])

In [3]:
class PoetryxLSTM(nn.Module):
    def __init__(self, vocab_size, hidden_size=128, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList([
            xLSTMLayer(
                input_size=hidden_size, 
                hidden_size=hidden_size
            ) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(hidden_size)
        self.head = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
    def init_hidden(self, batch_size, device):
        return (torch.zeros(batch_size, self.hidden_size).to(device),
                torch.zeros(batch_size, self.hidden_size).to(device))
        
    def forward(self, x, hidden_states=None):
        batch_size, seq_len = x.size()
        device = x.device
        
        if hidden_states is None:
            hidden_states = [self.init_hidden(batch_size, device) for _ in range(self.num_layers)]
        
        x = self.embedding(x)
        
        new_hidden_states = []
        for i, layer in enumerate(self.layers):
            h_prev, c_prev = hidden_states[i]
            x, (h_new, c_new) = layer(x, (h_prev, c_prev))
            new_hidden_states.append((h_new, c_new))
            
        x = self.norm(x)
        logits = self.head(x)
        
        return logits, new_hidden_states
    
    def generate(self, tokenizer, prompt, max_length=30, temperature=0.7, device='cpu'):
        self.eval()
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        hidden = None
        
        for _ in range(max_length):
            with torch.no_grad():
                logits, hidden = self(input_ids[:, -1:], hidden)
            
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            
            if next_token.item() == tokenizer.eos_token_id:
                break
        
        return tokenizer.decode(input_ids[0], skip_special_tokens=True)

In [4]:
class PoetryDataset(Dataset):
    def __init__(self, filepath, tokenizer, max_length=64):
        self.df = pd.read_csv(filepath)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.df = self.df.dropna(subset=['text'])
        self.texts = self.df['text'].tolist()
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        encoded = self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return encoded['input_ids'].squeeze(0)

tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
tokenizer.add_special_tokens({'pad_token': '[PAD]', 'eos_token': '</s>'})

# Загрузка данных
dataset = PoetryDataset("/home/lad1chka/russianPoetryWithTheme.csv", tokenizer)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = tokenizer.vocab_size + len(tokenizer.added_tokens_encoder)
model = PoetryxLSTM(vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
epochs = 5

def train(model, dataloader, epochs):
    model.train()
    losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0
        hidden = None
        
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
            input_ids = batch.to(device)
            batch_size = input_ids.size(0)
            hidden = [model.init_hidden(batch_size, device) for _ in range(model.num_layers)]
            optimizer.zero_grad()
            inputs = input_ids[:, :-1]
            targets = input_ids[:, 1:]
            logits, hidden = model(inputs, hidden)
            loss = criterion(
                logits.view(-1, logits.size(-1)), 
                targets.contiguous().view(-1)
            )
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(dataloader)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")
        
        # Пример генерации
        sample = model.generate(tokenizer, "Белая береза", device=device)
        print(f"Пример генерации: {sample}")
    
    return losses

loss_history = train(model, dataloader, epochs)

Epoch 1: 100%|████████████████████████████████| 522/522 [28:47<00:00,  3.31s/it]


Epoch 1, Loss: 8.3141
Пример генерации: Белая береза. С..... И. Р В. За. в з. И, В и вм о не ки, се


Epoch 2: 100%|████████████████████████████████| 522/522 [28:17<00:00,  3.25s/it]


Epoch 2, Loss: 6.3885
Пример генерации: Белая береза весе. - Б Я востся - во, я, не вту.. И не в бтя, на гл И


Epoch 3: 100%|████████████████████████████████| 522/522 [28:22<00:00,  3.26s/it]


Epoch 3, Loss: 6.0477
Пример генерации: Белая березачью и д, у. Все. В. Не поэт и ст цвет воз свидй стрной - и сл увык


Epoch 4: 100%|████████████████████████████████| 522/522 [28:19<00:00,  3.26s/it]


Epoch 4, Loss: 5.7812
Пример генерации: Белая береза. Пчий... И Он м, в у, с по умс в перкаллаья!ки


Epoch 5: 100%|████████████████████████████████| 522/522 [29:07<00:00,  3.35s/it]


Epoch 5, Loss: 5.5831
Пример генерации: Белая береза нож, м с м на ка собе мне,! Не во мне со жить, на в во не тре.. Н
