In [1]:
import math
import pytorch_lightning as pl
import torch
torch.manual_seed(42)
import torch.nn as nn
from torch.nn import Linear
from torch.nn import functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from tokenizers import Tokenizer
from torch.nn.utils.rnn import pad_sequence
import string
import pdb
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
from torch.utils.data import DataLoader, WeightedRandomSampler
from functools import partial

2023-12-28 18:00:27.619593: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-12-28 18:00:27.650609: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-28 18:00:27.650634: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-28 18:00:27.651394: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-28 18:00:27.656189: I tensorflow/core/platform/cpu_feature_guar

In [2]:
# Upper bound on decoded sequence length: Seq2Seq.generate() allocates its output
# buffer with MAX_LEN positions and decodes at most MAX_LEN - 1 new tokens.
# The value equals the default `max_len` of PositionalEncoding in this notebook.
MAX_LEN = 256

In [3]:
def gen_trg_mask(length, device):
    """Build the causal (look-ahead) attention mask for a target sequence.

    Args:
        length: Number of target positions.
        device: Device on which the mask is allocated.

    Returns:
        A ``(length, length)`` float tensor that holds ``-inf`` strictly above
        the main diagonal and ``0`` on and below it.
    """
    blocked = torch.full((length, length), float("-inf"), device=device)
    # triu(diagonal=1) zeroes the diagonal and everything beneath it, leaving
    # -inf only where a position would attend to a later one.
    return blocked.triu(diagonal=1)

def create_padding_mask(tensor, pad_idx):
    """Mark the padding entries of a token tensor.

    Args:
        tensor: Tensor of token ids with at least two dimensions.
        pad_idx: Id that denotes padding.

    Returns:
        Boolean tensor equal to ``tensor == pad_idx`` with dimensions 0 and 1
        swapped (``True`` where the token is padding).
    """
    is_pad = torch.eq(tensor, pad_idx)
    return torch.transpose(is_pad, 0, 1)


def masked_accuracy(y_true: torch.Tensor, y_pred: torch.Tensor, pad_idx):
    """Fraction of non-padding positions where the prediction equals the label.

    Args:
        y_true: Ground-truth token ids.
        y_pred: Predicted token ids, aligned with ``y_true``.
        pad_idx: Label value to exclude from the computation.

    Returns:
        A 0-dim float64 tensor with the mean of the element-wise matches over
        the positions where ``y_true != pad_idx``.
    """
    keep = y_true.ne(pad_idx)
    labels = y_true.masked_select(keep)
    guesses = y_pred.masked_select(keep)
    hits = labels.eq(guesses)
    return hits.to(torch.float64).mean()


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence of embeddings.

    Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    Even feature columns hold ``sin(position * freq)`` and odd columns hold
    ``cos(position * freq)``. The table is stored in the buffer ``pe`` with
    shape ``[max_len, 1, d_model]``.

    Args:
        d_model: Embedding width. Both even and odd widths are supported.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions pre-computed in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 256):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns. For odd d_model the
        # last frequency has no cosine partner, so it is dropped here; without
        # the slice the assignment fails with a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            ``x`` plus the encodings of the first ``seq_len`` positions, passed
            through dropout.
        """
        # pe has a singleton batch dimension, so it broadcasts over the batch.
        x = x + self.pe[: x.size(0)]

        return self.dropout(x)


class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is multiplied by ``sqrt(emb_size)``.

    Adapted from https://pytorch.org/tutorials/beginner/translation_transformer.html

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_size: Width of each embedding vector.
    """

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)

    def forward(self, tokens: torch.Tensor):
        """Embed ``tokens`` (cast to int64) and scale by ``sqrt(emb_size)``."""
        token_ids = tokens.long()
        scale = math.sqrt(self.emb_size)
        return self.embedding(token_ids) * scale


class Seq2Seq(pl.LightningModule):
    """Transformer encoder-decoder that maps a corrupted token sequence to a corrected one.

    Source and target tokens share a single embedding table and a single
    vocabulary of ``out_vocab_size`` entries.

    Args:
        out_vocab_size: Vocabulary size; also the number of output classes.
        pad_idx: Padding token id. Padding is masked in attention and ignored
            by the loss and the accuracy.
        tokenizer: Object providing ``encode(str).ids``, ``decode(ids)`` and
            ``token_to_id(str)``; it must know the ``"[CLS]"`` and ``"[SEP]"``
            tokens.
        channels: Model width (``d_model``) of the embeddings and transformer.
        dropout: Dropout probability for the positional encoder and transformer.
    """

    def __init__(
        self,
        out_vocab_size,
        pad_idx,
        tokenizer,
        channels=256,
        dropout=0.1
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.dropout = dropout
        self.out_vocab_size = out_vocab_size

        # One embedding table serves both the encoder and the decoder inputs.
        self.embeddings = TokenEmbedding(
            vocab_size=self.out_vocab_size, emb_size=channels
        )

        self.pos_encoder = PositionalEncoding(d_model=channels, dropout=dropout)

        self.transformer = torch.nn.Transformer(
            d_model=channels,
            nhead=4,
            num_encoder_layers=6,
            num_decoder_layers=6,
            dim_feedforward=1024,
            dropout=dropout,
        )

        self.linear = Linear(channels, out_vocab_size)

        # Not referenced by any method of this class; kept so the module
        # layout stays the same as before.
        self.do = nn.Dropout(p=self.dropout)
        self.tokenizer = tokenizer

        # Validation-accuracy accumulators. They are created here so that
        # on_validation_epoch_end() never hits a missing attribute, and reset
        # again at the start of every validation epoch.
        self.acc_sum = 0
        self.acc_num = 0

    def init_weights(self) -> None:
        """Re-initialise the embedding table and output projection uniformly in [-0.1, 0.1]."""
        init_range = 0.1
        # The weight lives on the nn.Embedding wrapped by TokenEmbedding.
        self.embeddings.embedding.weight.data.uniform_(-init_range, init_range)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-init_range, init_range)

    def encode_src(self, src):
        """Run the encoder on source token ids.

        Args:
            src: Token ids laid out ``[batch, src_len]``.

        Returns:
            Encoder memory laid out ``[src_len, batch, channels]``.
        """
        src = src.permute(1, 0)  # -> [src_len, batch], the transformer's layout
        src_pad_mask = create_padding_mask(src, self.pad_idx)  # -> [batch, src_len]
        src = self.embeddings(src)
        src = self.pos_encoder(src)
        src = self.transformer.encoder(src, src_key_padding_mask=src_pad_mask)
        return src

    def decode_trg(self, trg, memory):
        """Run the decoder on target token ids and project to vocabulary logits.

        Args:
            trg: Target-side input token ids laid out ``[batch, trg_len]``.
            memory: Encoder output as returned by :meth:`encode_src`.

        Returns:
            Logits laid out ``[batch, trg_len, out_vocab_size]``.
        """
        trg = trg.permute(1, 0)  # -> [trg_len, batch]
        out_sequence_len = trg.size(0)
        trg_pad_mask = create_padding_mask(trg, self.pad_idx)
        trg = self.embeddings(trg)
        trg = self.pos_encoder(trg)
        # Causal mask: each position attends only to itself and earlier positions.
        trg_mask = gen_trg_mask(out_sequence_len, trg.device)
        out = self.transformer.decoder(
            tgt=trg, memory=memory, tgt_mask=trg_mask, tgt_key_padding_mask=trg_pad_mask
        )
        out = out.permute(1, 0, 2)  # back to batch-first
        out = self.linear(out)
        return out

    def forward(self, x):
        """Compute logits for a ``(src, trg)`` pair of batch-first token-id tensors."""
        src, trg = x
        src = self.encode_src(src)
        out = self.decode_trg(trg=trg, memory=src)
        return out

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, name="train")

    def on_validation_epoch_start(self):
        """Reset the running validation-accuracy accumulators."""
        self.acc_sum = 0
        self.acc_num = 0

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, name="valid")

    def on_validation_epoch_end(self):
        """Log and print the mean of the per-batch validation accuracies."""
        if self.acc_num == 0:
            # No validation batch was processed; nothing to average.
            return
        avg_acc = (self.acc_sum / self.acc_num).item()
        self.log("val_avg_acc", avg_acc)
        print("Epoch, accuracy:", self.current_epoch, round(avg_acc, 4))

    def test_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, name="test")

    def ids_to_str(self, ids):
        """Decode a NumPy array of token ids, cutting at the first ``[SEP]`` if present."""
        # Use the tokenizer owned by the model, not a notebook-level global.
        is_end = np.nonzero(ids == self.tokenizer.token_to_id("[SEP]"))[0]
        if len(is_end) > 0:
            end = is_end[0]
            ids = ids[:end]

        return self.tokenizer.decode(ids)

    def generate(self, input_str):
        """Greedily decode a corrected version of ``input_str``.

        Decoding starts from ``[CLS]`` and stops after ``[SEP]`` is produced or
        once ``MAX_LEN`` positions are filled. The module is put in eval mode
        and gradients are disabled for the duration of the call; the previous
        training mode is restored afterwards.

        Args:
            input_str: Raw text to correct.

        Returns:
            The decoded output string.
        """
        cls_id = self.tokenizer.token_to_id("[CLS]")
        sep_id = self.tokenizer.token_to_id("[SEP]")

        was_training = self.training
        self.eval()  # dropout must be off, otherwise greedy decoding is noisy
        try:
            with torch.no_grad():
                input_tokens = torch.tensor(
                    self.tokenizer.encode(input_str).ids,
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)
                src = self.encode_src(input_tokens)

                outputs = torch.zeros(
                    (MAX_LEN, src.size(1)), dtype=torch.long, device=src.device
                )
                outputs[0] = cls_id
                n_generated = 1

                for i in range(1, MAX_LEN):
                    out = self.decode_trg(outputs[:i].T, memory=src)
                    # Greedy choice from the logits of the last decoded position.
                    next_token = out[0, -1].argmax()
                    outputs[i] = next_token
                    n_generated = i + 1
                    if next_token == sep_id:
                        break
        finally:
            self.train(was_training)

        # Decode only the positions that were actually written; the rest of
        # the buffer is still zero-filled.
        output_ids = outputs[:n_generated, 0].tolist()
        return self.tokenizer.decode(output_ids)

    def _log_validation_samples(self, src, trg_out, logits, max_samples=16):
        """Write up to ``max_samples`` corrupted/predicted/truth triples to TensorBoard.

        Does nothing when the trainer has no ``TensorBoardLogger`` attached.
        """
        tb_logger = None
        for logger in self.trainer.loggers:
            if isinstance(logger, TensorBoardLogger):
                tb_logger = logger.experiment
                break
        if tb_logger is None:
            return

        _, this_pred = torch.max(logits, 2)
        # The batch may hold fewer than max_samples rows.
        for i in range(min(max_samples, src.size(0))):
            text_truth = self.ids_to_str(trg_out[i].cpu().numpy())
            text_corrupted = self.ids_to_str(src[i].cpu().numpy())
            text_pred = self.ids_to_str(this_pred[i].cpu().numpy())

            output = "Corrupted: {}<br>Corrected: {}<br>GrndTruth: {}".format(text_corrupted, text_pred, text_truth)
            tb_logger.add_text(f'Validation #{i}, target: {text_truth}', output, self.global_step)

    def _step(self, batch, batch_idx, name="train"):
        """Shared train/valid/test step using teacher forcing.

        Args:
            batch: ``(src, trg)`` batch-first token-id tensors.
            batch_idx: Index of the batch within the epoch.
            name: Prefix for the logged metrics (``"train"``, ``"valid"`` or ``"test"``).

        Returns:
            Cross-entropy loss with padding targets ignored.
        """
        src, trg = batch
        # The decoder sees trg[:-1] and is trained to predict trg[1:].
        trg_in, trg_out = trg[:, :-1], trg[:, 1:]
        y_hat_orig = self((src, trg_in))
        y_hat = y_hat_orig.view(-1, y_hat_orig.size(2))
        y = trg_out.contiguous().view(-1)

        if batch_idx == 0 and name == "valid":
            self._log_validation_samples(src, trg_out, y_hat_orig)

        loss = F.cross_entropy(y_hat, y, ignore_index=self.pad_idx)
        _, predicted = torch.max(y_hat, 1)
        acc = masked_accuracy(y, predicted, pad_idx=self.pad_idx)

        self.log(f"{name}_loss", loss)
        self.log(f"{name}_acc", acc)

        # Only validation batches feed the average reported by
        # on_validation_epoch_end().
        if name == "valid":
            self.acc_sum += acc
            self.acc_num += 1

        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=3e-5)

In [4]:
def apply_perturbation_to_text(text):
    """Return a randomly corrupted copy of ``text`` using NumPy's global RNG.

    The steps, in order:
      1. With probability 0.1, return ``text`` untouched.
      2. With probability 0.25 (and more than one word), delete one random word.
      3. For each word, possibly swap a matching ending (each check fires with
         probability 0.2): ium->ius, ius->ium, us->i, nem->nus.
      4. Re-join the words with single spaces; each character is replaced by a
         random lowercase letter or space with probability 0.02, and a random
         extra character is inserted after it with probability 0.02.

    Args:
        text: Clean input string.

    Returns:
        The perturbed string (or ``text`` itself in case 1).
    """
    # Ideas noted by the author but not implemented here:
    #   - substitute e for c or c for e; t for r
    #   - m for ni, m for ui, n for ii, n for u, u for ii, ium for mm, f for s
    #   - cut out the middles of words, not whole words

    # There's a chance it'll be perfect
    if np.random.uniform() < 0.1:
        return text

    alphabet = list(string.ascii_lowercase + ' ')
    words = text.split()

    # The random draw happens before the length check, as in a short-circuit `and`.
    drop_roll = np.random.uniform()
    if drop_roll < 0.25 and len(words) > 1:
        words.pop(np.random.randint(len(words)))

    # (old suffix, replacement), tried in order; the first one that fires wins.
    ending_swaps = (("ium", "ius"), ("ius", "ium"), ("us", "i"), ("nem", "nus"))
    for pos, word in enumerate(words):
        for old_suffix, new_suffix in ending_swaps:
            if word.endswith(old_suffix) and np.random.uniform() < 0.2:
                words[pos] = word[: -len(old_suffix)] + new_suffix
                break

    new_str = ""
    for char in " ".join(words):
        # Substitute the character, or keep it.
        if np.random.uniform() < 0.02:
            new_str += np.random.choice(alphabet)
        else:
            new_str += char

        # There's a chance for a new character to pop up
        if np.random.uniform() < 0.02:
            new_str += np.random.choice(alphabet)

    return new_str

def generate_batch(data_batch, pad_idx):
    """Collate ``(src, trg)`` tensor pairs into two padded, batch-first tensors.

    Args:
        data_batch: Iterable of ``(src_tensor, trg_tensor)`` pairs.
        pad_idx: Value used to right-pad shorter sequences.

    Returns:
        ``(src, trg)``, each produced by ``pad_sequence`` with ``batch_first=True``.
    """
    pairs = list(data_batch)
    sources = [source for source, _ in pairs]
    targets = [target for _, target in pairs]
    src = pad_sequence(sources, padding_value=pad_idx, batch_first=True)
    trg = pad_sequence(targets, padding_value=pad_idx, batch_first=True)
    return src, trg



class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that yields (perturbed, clean) token-id tensors per text sample.

    Args:
        samples: Indexable collection of clean text strings.
        hf_tokenizer: Tokenizer whose ``encode(text).ids`` returns token ids.
    """

    def __init__(self, samples, hf_tokenizer):
        self.hf_tokenizer = hf_tokenizer
        self.samples = samples
        self.n_samples = len(samples)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """Return ``(x, y)``: int64 ids of a freshly perturbed copy and of the clean text."""
        clean_text = self.samples[idx]
        # A new random perturbation is drawn on every access.
        noisy_text = apply_perturbation_to_text(clean_text)

        encode = self.hf_tokenizer.encode
        noisy_ids = torch.tensor(encode(noisy_text).ids, dtype=torch.long)
        clean_ids = torch.tensor(encode(clean_text).ids, dtype=torch.long)
        return noisy_ids, clean_ids


In [5]:
def split_text(text):
    """Cut ``text`` into consecutive chunks of a random number of words.

    Each chunk length is drawn with ``np.random.randint(2, 30)`` (2 to 29
    words); the final chunk holds whatever words remain. At least one chunk is
    always produced, so empty input yields ``[""]``.

    Args:
        text: Whitespace-separated text.

    Returns:
        List of chunks, each a single-space-joined run of words.
    """
    tokens = text.split()
    chunks = []
    start = 0
    more = True
    while more:
        stop = start + np.random.randint(2, 30)
        chunks.append(" ".join(tokens[start:stop]))
        start = stop
        more = start < len(tokens)

    return chunks

# Collect training samples from every corpus. Each sample gets the weight
# 1 / (number of samples from its source), so the weights of every source sum to 1.
lines, weights = [], []

# Running-text corpora: cut into random-length word chunks.
for filename in ("decretum.txt", "corpus_thomisticum.txt", "misc_medieval.txt"):
    with open(filename) as f:
        data = f.read()
    new_lines = split_text(data)
    lines.extend(new_lines)
    weights.extend([1. / len(new_lines)] * len(new_lines))

# This file already holds one sample per line.
# NOTE(review): readlines() keeps each line's trailing newline, unlike the
# chunks built by split_text -- confirm this is intended.
with open("cases_training_lines.txt") as f:
    new_lines = f.readlines()
lines.extend(new_lines)
weights.extend([1. / len(new_lines)] * len(new_lines))

# NumPy arrays allow selecting samples and weights by an index array.
lines = np.array(lines)
weights = np.array(weights)

tokenizer = Tokenizer.from_file("latin_tokenizer.json")

# Fixed-seed 90/10 split over sample indices.
train_inds, val_inds = train_test_split(
    np.arange(len(lines)), test_size=0.1, random_state=1337
)

train_data = Dataset(samples=lines[train_inds], hf_tokenizer=tokenizer)
val_data = Dataset(samples=lines[val_inds], hf_tokenizer=tokenizer)

In [6]:
def generate_batch(data_batch, pad_idx):
    """Collate ``(src, trg)`` tensor pairs into two padded, batch-first tensors.

    NOTE: a function with this name and the same behavior is also defined in
    an earlier cell of this notebook; this definition shadows it.

    Args:
        data_batch: Iterable of ``(src_tensor, trg_tensor)`` pairs.
        pad_idx: Value used to right-pad shorter sequences.

    Returns:
        ``(src, trg)``, each produced by ``pad_sequence`` with ``batch_first=True``.
    """
    src_items, trg_items = [], []
    for pair in data_batch:
        src_item, trg_item = pair
        src_items.append(src_item)
        trg_items.append(trg_item)

    def pad(seqs):
        return pad_sequence(seqs, padding_value=pad_idx, batch_first=True)

    return pad(src_items), pad(trg_items)

# Settings shared by both loaders; batches are padded with the [PAD] token id.
loader_kwargs = dict(
    batch_size=128,
    num_workers=4,
    collate_fn=partial(generate_batch, pad_idx=tokenizer.token_to_id("[PAD]")),
)

# Samples are drawn with replacement according to the per-source weights.
train_loader = DataLoader(
    train_data,
    sampler=WeightedRandomSampler(
        weights[train_inds], num_samples=len(train_inds), replacement=True
    ),
    **loader_kwargs,
)
# NOTE(review): validation also draws a weighted random sample with
# replacement, so the evaluated set can differ between epochs -- confirm
# this is intended.
val_loader = DataLoader(
    val_data,
    sampler=WeightedRandomSampler(
        weights[val_inds], num_samples=len(val_inds), replacement=True
    ),
    **loader_kwargs,
)

print("len(train_data)", len(train_data))
print("len(val_data)", len(val_data))

len(train_data) 705769
len(val_data) 78419


In [7]:
# Continue training from an existing checkpoint. The constructor arguments are
# passed explicitly alongside the checkpoint path.
model = Seq2Seq.load_from_checkpoint(
    "checker-v20.ckpt",
    out_vocab_size=tokenizer.get_vocab_size(),
    pad_idx=tokenizer.token_to_id("[PAD]"),
    tokenizer=tokenizer,
    dropout=0.1,
)

# Alternative (disabled): build a fresh model with the same keyword arguments,
#     model = Seq2Seq(out_vocab_size=..., pad_idx=..., tokenizer=tokenizer, dropout=0.1)
# optionally followed by
#     model.load_state_dict(torch.load("autocorrect.pt"))

# Keep the checkpoint with the highest logged "valid_acc".
checkpoint_callback = ModelCheckpoint(
    monitor="valid_acc",
    mode="max",
    dirpath="./",
    filename="checker",
)

logger = TensorBoardLogger(save_dir="./", name="autocorrect_logs")

trainer = pl.Trainer(
    max_epochs=2000,
    logger=logger,
    callbacks=[checkpoint_callback],
)

trainer.fit(model, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type               | Params
---------------------------------------------------
0 | embeddings  | TokenEmbedding     | 1.3 M 
1 | pos_encoder | PositionalEncoding | 0     
2 | transformer | Transformer        | 11.1 M
3 | linear      | Linear             | 1.3 M 
4 | do          | Dropout            | 0     
------------------------------------

Sanity Checking: 0it [00:00, ?it/s]

Epoch, accuracy: 0 0.9121




Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
#torch.save(model.state_dict(), "autocorrect_special_tokens.pt")

In [20]:
# #Test the model we just trained

# model = Seq2Seq(
#     out_vocab_size=tokenizer.get_vocab_size(),
#     pad_idx=tokenizer.token_to_id("[PAD]"),
#     tokenizer=tokenizer,
#     dropout=0.1
# )
# model.load_state_dict(torch.load("autocorrect_special_tokens.pt"))
# #model.generate("dignitatm divinae Scripturae pertinet, ut sub una littera multos sensus contineat, ut sic et diversis intellectibi rhominum conveniat")

# input_str = "probatorem et clericum conuictum et eunt ab eius duas vaccas et sex bonuculos pro xxx. S.. vnde idem Galfridus"
# input_tokens = torch.tensor(model.tokenizer.encode(input_str).ids, dtype=torch.long).unsqueeze(0)
# src = model.encode_src(input_tokens)

# outputs = torch.zeros((MAX_LEN, src.size(1)), dtype=torch.long, device=src.device)
# outputs[0] = model.tokenizer.token_to_id("[CLS]")

# for i in range(1, MAX_LEN):
#     out = model.decode_trg(outputs[:i].T, memory=src)
#     _, next_token = torch.max(out, 2)
#     next_token = next_token[0,-1]
#     outputs[i] = next_token
#     if next_token == model.tokenizer.token_to_id("[SEP]"):
#         break
        
# output_ids = list(outputs.squeeze().numpy())
# print(model.tokenizer.decode(output_ids))
# print(model.tokenizer.decode(output_ids[:input_tokens.size(1)-1]))

In [56]:
#Test the model
model = Seq2Seq.load_from_checkpoint("checker-v20.ckpt", out_vocab_size=tokenizer.get_vocab_size(),
    pad_idx=tokenizer.token_to_id("[PAD]"),
    tokenizer=tokenizer,
    dropout=0.1)
model.load_state_dict(torch.load("autocorrect_special_tokens.pt"))
model.to("cuda:0")
#model.generate("Sidure fecit capitus dominus feodi mediatur et inmedatur essendi hic etc. sei etc. Et iuratores veniunt qui dicunt super sacr in")

model.generate("iurlepartcipant, nonne magis nos? sedea non sumus usi hac potestate, sed omnia tolleramus, ne quod inpedntum demus evangelio Christi.")

'iure participant, nonne magis nos? sed ea non sumus usi hac potestate, sed omnia tolleramus, ne quod inpentum demus evangelio Christi.'