In [1]:
import torch
from torch import nn, tensor, randn, optim
import pandas
# import nltk # See https://www.nltk.org/data.html
from nltk import tokenize
from torch.utils.data import Dataset, DataLoader
import random

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# Токенизация

char_tokenize = lambda text: tokenize.simple.CharTokenizer().tokenize(text)
word_lower_tokenize = lambda text: [t.lower() for t in tokenize.WordPunctTokenizer().tokenize(text)]
word_tokenize = lambda text: tokenize.WordPunctTokenizer().tokenize(text)

def byte_pair_encode(data, merge_count, chars_to_ignore=None):
    token_set = set(char_tokenize(data))
    return byte_pair_encode_continue(data, token_set.copy(), merge_count, chars_to_ignore)

import threading

def _bpe_count_pairs(token, token_set, pair_counts, data):
    for B in token_set:
        pair = token + B
        if pair in token_set: continue
        pair_counts.append((pair, data.count(pair)))


default_chars_to_ignore = {' ', '.', ',', '\n'}

def byte_pair_encode_continue(data, token_set, merge_count, chars_to_ignore=None):
    # print('starts with ', token_set)
    token_set -= (chars_to_ignore or default_chars_to_ignore)
    while merge_count != 0:
        new_token_set = token_set.copy()
        pair_counts = []
        found_pairs = set()
        for A in token_set:
            for B in token_set:
                pair = A + B
                if pair in new_token_set: continue
                if pair in found_pairs: continue
                pair_counts.append((pair, data.count(pair)))
                found_pairs.add(pair)
                max_freq = 0
                most_frequent = []
        for v in pair_counts:
            if v[1] > max_freq:
                max_freq = v[1]
                most_frequent = [v[0]]
            elif v[1] == max_freq:
                most_frequent.append(v[0])
        for t in most_frequent:
            new_token_set.add(t)
            merge_count -= 1
            if merge_count == 0: break
        token_set = new_token_set
    return token_set | (chars_to_ignore or default_chars_to_ignore)

def bpe_tokenize(text):
    result = []
    words = text.split(' ')
    for w in words:
        # print('# tokenize word\t', w)
        while len(w) > 0:
            # print('#  step with\t', w)
            for i in range(len(w)):
                i = len(w) - i
                # print('##\t', w[:i])
                if w[:i] in token_to_id:
                    # print('# found ', w[:i])
                    result.append(w[:i])
                    w = w[i:]
                    # print(w, 'still')
                    break
        result.append(' ')
    return result

# T -- Time
get_x = lambda data, T, pos: data[pos : pos + T]
get_y = lambda data, T, pos: data[pos + 1 : pos + T + 1]

# B -- batch size
def get_batch(data, B, T):
    positions = torch.randint(len(data) - T, (B,))
    # positions = torch.randint(len(data), (B,))
    x = torch.stack([get_x(data, T, pos) for pos in positions])
    y = torch.stack([get_y(data, T, pos) for pos in positions]) 
    # x = torch.stack([torch.zeros(T, device=device) + get_x(data[pos], T, 0) for pos in positions])
    # y = torch.stack([torch.zeros(T, device=device) + get_y(data[pos], T, 0) for pos in positions])
    return x, y

eval_iters = 80

@torch.no_grad()
def estimate_loss():
    out = []
    model.eval()
    for data in [train_data, validation_data]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(data, batch_len, block_len)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out.append(losses.mean())
    model.train()
    return out

In [3]:
# Создание файла данных
with open("data/glados-portal2.csv", "r", encoding='utf-16') as f:
    strings = f.readlines()
with open("data/extension.csv", "r", encoding='utf-16') as f:
    strings += f.readlines()
random.shuffle(strings)
with open("data/data.csv", "w", encoding='utf-16') as f:
    f.writelines(strings)

In [None]:
# Вывести случайные строки оригинала
with open("data/glados-portal2.csv", "r", encoding='utf-16') as f:
    strings = f.readlines()
random.shuffle(strings)
for s in strings[:80]: print(s[:-1])

In [5]:
with open("data/data.csv", "r", encoding='utf-16') as f:
    text = f.read()

In [4]:
# Загрузить токены КПБ
import pickle
bpe_token_file = open(r'bpe_tokens.pkl', 'rb')
bpe_token_set = pickle.load(bpe_token_file)
bpe_token_file.close()
list(bpe_token_set)[:10]

['ri', 'O', 'es', 'ut', 'K', 'one', 'n', 'b', 'ex', ' ']

In [9]:
# Загрузить, дополнить и сохранить токены КПБ
import pickle
bpe_token_file = open(r'bpe_tokens.pkl', 'rb')
bpe_token_set = pickle.load(bpe_token_file)
bpe_token_file.close()
bpe_token_set_old = bpe_token_set
# bpe_token_set |= set(char_tokenize(text))
bpe_token_set = byte_pair_encode_continue(text, bpe_token_set, 100)
bpe_token_file = open(r'bpe_tokens.pkl', 'wb')
pickle.dump(bpe_token_set, bpe_token_file)
bpe_token_file.close()
print('+', bpe_token_set - bpe_token_set_old - default_chars_to_ignore)

+ set()


In [15]:
# Сгенерировать и сохранить **новый** набор токенов кодированием пар байтов *вместо имеющегося*
# import pickle
# bpe_token_file = open(r'bpe_tokens.pkl', 'wb')
# bpe_token_set = byte_pair_encode(text, 40)
# pickle.dump(bpe_token_set, bpe_token_file)
# bpe_token_file.close()

In [11]:
print(bpe_tokenize("100%!\n"))

['1', '0', '0', '%', '!', '\n', ' ']


In [10]:
# tokens = word_tokenize(text)
# tokens = list(bpe_token_set | set(word_tokenize(text)))
tokens = list(bpe_token_set)
vocab =  sorted(tokens)
vocab = ['<PAD>', '<UNK>'] + vocab
vocab_size = len(vocab)

token_to_id = {t:i for i,t in enumerate(vocab)}
id_to_token = {i:t for t,i in token_to_id.items()}

encode = lambda data: [token_to_id[t if t in token_to_id else '<UNK>'] for t in data]
decode = lambda data: ''.join([id_to_token[i] for i in data])
# decode = lambda data: ' '.join([id_to_token[i] for i in data])

torch.manual_seed(0)

tokenize_func = bpe_tokenize
# tokenize_func = word_tokenize
data = tensor(encode(tokenize_func(text)), device=device)
# data = [tensor(encode(tokenize_func(string)), device=device) for string in text.split('\n')]
train_data = data[:(len(data)//5)]
validation_data = data[(len(data)//5):]

In [12]:
block_len = 48  # T (time) -- Размер контекстного окна
channel_len = 192 # C -- размер эмбеддинга токена, вектора, хранящего его информацию
batch_len = 64  # B -- количество матриц (T, C) в пакете
head_count = 6
block_count = 6
dropout_rate = 0.2

class Head(nn.Module):
    def __init__(self, length):
        super().__init__()
        self.key   = nn.Linear(channel_len, length, bias=False)
        self.query = nn.Linear(channel_len, length, bias=False)
        self.value = nn.Linear(channel_len, length, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_len, block_len))) # https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2,-1) * C**(-0.5)
        # wei = q @ k.transpose(-2,-1) * k.shape[-1]**(-0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = nn.functional.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHead(nn.Module):
    def __init__(self, head_count, head_len):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_len) for _ in range(head_count)])
        self.proj = nn.Linear(head_len * head_count, channel_len)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(self.proj(out))
        return out

class LinearReLU(nn.Module):
    def __init__(self, channel_len):
        super().__init__()
        self.nn = nn.Sequential(
            nn.Linear(channel_len, 4 * channel_len),
            nn.ReLU(),
            nn.Linear(4 * channel_len, channel_len),
            nn.Dropout())
    def forward(self, x):
        return self.nn(x)

class Block(nn.Module):
    def __init__(self, channel_len, head_count):
        super().__init__()
        head_len = channel_len // head_count
        self.sa = MultiHead(head_count, head_len)
        self.ffwd = LinearReLU(channel_len)
        self.ln1 = nn.LayerNorm(channel_len)
        self.ln2 = nn.LayerNorm(channel_len)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class TransformerLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, channel_len)
        self.position_embedding_table = nn.Embedding(block_len, channel_len)
        self.blocks = nn.Sequential(*[Block(channel_len, head_count) for _ in range(block_count)])
        self.layer_norm = nn.LayerNorm(channel_len)
        self.head = nn.Linear(channel_len, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.token_embedding_table(idx) + self.position_embedding_table(torch.arange(T, device=device))
        x = self.blocks(x)
        logits = self.head(x)
        if targets is None: return logits

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = nn.functional.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_end = idx[:, -block_len:]
            logits = self(idx_end)
            logits = logits[:, -1, :] # Берём последний токен в каждом слое порции: (B, T, C) -> (B, C)
            probs = nn.functional.softmax(logits, dim=-1) # Оценки в вероятности
            # Выбор числа от 0 до channel_len по данным вероятностям для каждого слоя порции
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1) # Добавляем в конец порции результаты по измерению T
        return idx

model = TransformerLM()
model.to(device)


TransformerLM(
  (token_embedding_table): Embedding(320, 192)
  (position_embedding_table): Embedding(48, 192)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHead(
        (heads): ModuleList(
          (0-5): 6 x Head(
            (key): Linear(in_features=192, out_features=32, bias=False)
            (query): Linear(in_features=192, out_features=32, bias=False)
            (value): Linear(in_features=192, out_features=32, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (ffwd): LinearReLU(
        (nn): Sequential(
          (0): Linear(in_features=192, out_features=768, bias=True)
          (1): ReLU()
          (2): Linear(in_features=768, out_features=192, bias=True)
          (3): Dropout(p=0.5, inplace=False)
        )
      )
      (ln1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (ln

In [13]:
# Обучение

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0002)

def train(model, epochs=10):
    for epoch in range(epochs):
        if epoch % 500 == 0 and epoch != 0 or epoch == epochs - 1:
            losses = estimate_loss()
            print(f"Epoch {epoch}/{epochs}\ttrain loss {losses[0]:.4f}, validation loss {losses[1]:.4f}")

        xb, yb = get_batch(train_data, batch_len, block_len)
        logits, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        if epoch % 100 == 0 and epoch % 500 != 0:
            print(f"Epoch {epoch}/{epochs}")

In [None]:
# Загрузка состояния модели
checkpoint = torch.load('text_transformer.pth')
model.load_state_dict(checkpoint['model_state'])
token_to_id = checkpoint['token_to_id']
id_to_token = checkpoint['id_to_token']

In [23]:
train(model, 200)

Epoch 100/200
Epoch 199/200	train loss 0.7159, validation loss 3.4884


In [24]:
encoded_input = tensor([encode(tokenize_func("You are"))], device=device)
print("Сгенерированный текст:", decode(model.generate(encoded_input, 80)[0].tolist()))

Сгенерированный текст: You are stilling for  Reulty are operation, we of test wors requires a surprtemance, my are worn you aslet's difushed a man's stertle prov


In [25]:
# Сохранение модели
torch.save({
    'model_state': model.state_dict(),
    'token_to_id': token_to_id,
    'id_to_token': id_to_token
}, 'text_transformer.pth')