In [2]:
import json
import os
import re
from collections import Counter, defaultdict
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch as t
from torch.utils.data import DataLoader, Dataset, Subset, random_split

In [5]:
# Root of the Sefaria plain-text export. Set the SEFARIA_TXT_DIR environment variable to
# point elsewhere instead of editing this machine-specific default.
SEFARIA_TXT_DIR = os.environ.get('SEFARIA_TXT_DIR', '/Users/lee/Judaism/Sefaria_Fun/sefaria_txt')
# The trailing separator (joined '') matters: later cells append sub-paths with `+`.
mishnah_path = os.path.join(SEFARIA_TXT_DIR, 'Mishnah', '')
gemara_path = os.path.join(SEFARIA_TXT_DIR, 'Talmud', 'Bavli', '')

## Clean a sample Gemara tractate (English Sanhedrin).

In [6]:
# Load one Mishnah and one Gemara tractate to prototype the cleaning steps on.
# Bytes are decoded explicitly as UTF-8, and `with` closes each file handle right away.
with open(mishnah_path + 'Seder Zeraim/Mishnah Berakhot/Hebrew/merged.txt', 'rb') as mishnah_file:
    example_mishnah = mishnah_file.read().decode(encoding='utf-8')
with open(gemara_path + 'Seder Nezikin/Sanhedrin/English/merged.txt', 'rb') as gemara_file:
    example_gemara = gemara_file.read().decode(encoding='utf-8')

In [3]:
#print(example_gemara)

In [7]:
# Drop the file header: keep everything from the first run of three newlines onward,
# minus that run's first two characters.
# NOTE(review): the batch-cleaning loop searches for four newlines here, not three;
# confirm which delimiter matches the Sefaria header before trusting this preview.
header_end = example_gemara.find('\n\n\n')
headerless_gemara = example_gemara[header_end:][2:]

In [8]:
# Replace each blank-line-preceded "Daf ..." page-marker line with a single newline,
# trim the first two characters, then delete any remaining runs of four newlines.
without_daf_markers = re.sub(r'\n\nDaf .*\n', '\n', headerless_gemara)
dafless_gemara = re.sub(r'\n\n\n\n', '', without_daf_markers[2:])

In [9]:
#print(dafless_gemara)

In [10]:
# Strip inline <...> markup; the lazy `.*?` removes each tag separately instead of
# everything between the first '<' and the last '>' on a line.
tag_pattern = re.compile(r'<.*?>')
clean_gemara = tag_pattern.sub('', dafless_gemara)

In [11]:
#print(clean_gemara)

## Remove extra lines.

In [12]:
# Collapse each double newline into one (left-to-right, non-overlapping).
english_gemara = clean_gemara.replace('\n\n', '\n')

## Clean every Bavli tractate and write `clean.txt` next to each `merged.txt`.

In [12]:
# Every directory path under the Bavli root, in os.walk order.
directories = [dirpath for dirpath, _subdirs, _files in os.walk(gemara_path)]

In [13]:
def _language_dirs(language):
    """Directories whose path mentions `language` and is not part of a commentary tree."""
    return [path for path in directories if language in path and 'Commentary' not in path]


english = _language_dirs('English')
hebrew = _language_dirs('Hebrew')

In [14]:
def clean_merged_gemara(raw_text: str, source: str = '<text>') -> str:
    """Turn the contents of a Sefaria ``merged.txt`` tractate into plain running text.

    Steps:
      1. drop the header up to the first run of four newlines (plus two more characters);
      2. replace blank-line-preceded "Daf ..." page-marker lines with a newline;
      3. delete remaining runs of four newlines;
      4. strip ``<...>`` markup tags;
      5. collapse double newlines.
    A trailing blank line is appended so concatenated tractates stay separated.

    Args:
        raw_text: decoded contents of ``merged.txt``.
        source: label used in the error message (e.g. the file path).

    Raises:
        ValueError: if no four-newline run is found. Without this check, ``find``
            returning -1 would silently reduce the whole tractate to an empty string.
    """
    header_end = raw_text.find('\n\n\n\n')
    if header_end == -1:
        raise ValueError(f'{source}: no four-newline run separating the header from the text')
    text = raw_text[header_end:][2:]
    text = re.sub(r'\n\nDaf .*\n', '\n', text)[2:]
    text = re.sub(r'\n\n\n\n', '', text)
    text = re.sub(r'<.*?>', '', text)
    text = re.sub('\n\n', '\n', text)
    return text + '\n\n'


# Clean every English and Hebrew tractate, writing clean.txt beside its merged.txt.
for gemara_dir in english + hebrew:
    merged_path = os.path.join(gemara_dir, 'merged.txt')
    with open(merged_path, 'rb') as merged_file:
        raw_gemara = merged_file.read().decode(encoding='utf-8')
    gemara = clean_merged_gemara(raw_gemara, source=merged_path)
    # Write UTF-8 bytes explicitly. Text mode would use the platform's default encoding
    # (which may not cover Hebrew) and translate newlines on Windows, whereas clean.txt is
    # read back later as raw UTF-8 bytes.
    with open(os.path.join(gemara_dir, 'clean.txt'), 'wb') as clean_file:
        clean_file.write(gemara.encode('utf-8'))

In [16]:
def _read_clean_texts(language_dirs):
    """Concatenate each directory's ``clean.txt`` (decoded as UTF-8), in the given order."""
    parts = []
    for language_dir in language_dirs:
        with open(language_dir + '/clean.txt', 'rb') as clean_file:
            parts.append(clean_file.read().decode(encoding='utf-8'))
    # One join instead of repeated `+=` on a growing string.
    return ''.join(parts)


# Re-scan the tree so this cell stands on its own. Sort the paths because os.walk order
# depends on the filesystem; sorting keeps the corpus order reproducible across machines.
directories = sorted(dirpath for dirpath, _subdirs, _files in os.walk(gemara_path))
english = [i for i in directories if 'English' in i and 'Commentary' not in i]
hebrew = [i for i in directories if 'Hebrew' in i and 'Commentary' not in i]

gemara_english = _read_clean_texts(english)
gemara_hebrew = _read_clean_texts(hebrew)

In [17]:
# Save both corpora as UTF-8 explicitly: open()'s default encoding is platform-dependent
# and may not be able to encode the Hebrew/Aramaic text.
with open("gemara_english.txt", "w", encoding="utf-8") as text_file:
    text_file.write(gemara_english)

with open("gemara_hebramaic.txt", "w", encoding="utf-8") as text_file:
    text_file.write(gemara_hebrew)

## Tokenize and prepare for training.

In [None]:
class TalmudTokenizer:
    """Whitespace-word BPE tokenizer trained on the Gemara text.

    Every word of a text except the first is tokenized with ``space_prefix`` prepended,
    and ``decode`` turns that marker back into a space. ``train`` learns merges of
    adjacent symbols until the vocabulary reaches ``vocab_size`` (or nothing is left to
    merge). ``tokenize`` then splits each word greedily into the longest vocabulary
    entries.
    """

    def __init__(self, vocab_size: int = 16000):
        self.vocab_size = vocab_size
        self.vocab: Dict[str, int] = {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3}
        self.inverse_vocab: Dict[int, str] = {v: k for k, v in self.vocab.items()}
        # (left, right) symbol pair -> merged token, in the order the merges were learned.
        self.merges: Dict[Tuple[str, str], str] = {}
        # U+0120 'Ġ', the GPT-2 style word-start marker. A mis-decoded copy of it ('Ä'
        # followed by a space-like character) contains whitespace. That breaks training,
        # which stores symbols space-separated and re-splits them with str.split(); the
        # marker must therefore never contain whitespace.
        self.space_prefix = '\u0120'

    def _get_stats(self, vocab):
        """Count adjacent symbol pairs, weighted by word frequency.

        Args:
            vocab: mapping of space-separated symbol strings to their frequency.

        Returns:
            defaultdict mapping (left, right) symbol pairs to summed frequency.
        """
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[symbols[i], symbols[i + 1]] += freq
        return pairs

    def _merge_vocab(self, pair, v_in):
        """Return a copy of ``v_in`` with every standalone ``pair`` fused into one symbol."""
        v_out = {}
        bigram = re.escape(' '.join(pair))
        # The look-arounds match whole symbols only, never the tail or head of a longer one.
        p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        for word in v_in:
            w_out = p.sub(''.join(pair), word)
            v_out[w_out] = v_in[word]
        return v_out

    def train(self, text: str):
        """Learn merges from ``text`` until ``vocab_size`` tokens exist or no pairs remain.

        Both the prefixed and the bare form of every word are counted, because
        ``tokenize`` leaves the first word of a text unprefixed.
        """
        print("Starting tokenizer training...")

        non_prefixed_words = text.split()
        words = [self.space_prefix + word for word in non_prefixed_words]

        # Seed with single characters in sorted order. Set iteration order for str
        # depends on per-process hash randomization, which would reassign ids every run.
        chars = sorted(set(''.join(words + non_prefixed_words)))
        for char in chars:
            if char not in self.vocab:
                self.vocab[char] = len(self.vocab)
                self.inverse_vocab[len(self.vocab) - 1] = char

        print(f"Initial vocabulary size: {len(self.vocab)}")

        # Represent each word as space-separated symbols, weighted by how often it occurs.
        vocab = Counter(' '.join(word) for word in words)
        vocab.update(' '.join(word) for word in non_prefixed_words)

        num_merges = self.vocab_size - len(self.vocab)
        for i in range(num_merges):
            pairs = self._get_stats(vocab)
            if not pairs:
                print(f"No more pairs to merge after {i} iterations")
                break

            best = max(pairs, key=pairs.get)
            vocab = self._merge_vocab(best, vocab)
            new_token = ''.join(best)
            self.merges[best] = new_token

            if new_token not in self.vocab:
                self.vocab[new_token] = len(self.vocab)
                self.inverse_vocab[len(self.vocab) - 1] = new_token

            if len(self.vocab) >= self.vocab_size:
                print(f"Reached target vocabulary size after {i+1} iterations")
                break

            if i % 100 == 0:
                print(f"Completed {i} merges. Current vocab size: {len(self.vocab)}")

        print(f"Final vocabulary size: {len(self.vocab)}")
        print(f"Number of merges: {len(self.merges)}")

    def _tokenize_word(self, word: str) -> List[str]:
        """Split one (possibly prefixed) word into vocabulary entries, longest match first.

        Vocabulary entries are plain strings without separators, so candidates are
        substrings of the raw word. A character with no vocabulary entry is emitted on
        its own, and ``tokenize`` maps it to ``<UNK>``.
        """
        if word in self.vocab:
            return [word]

        tokens: List[str] = []
        start = 0
        while start < len(word):
            end = len(word)
            while end > start and word[start:end] not in self.vocab:
                end -= 1
            if end == start:
                end = start + 1  # unknown character: emit it alone
            tokens.append(word[start:end])
            start = end
        return tokens

    def tokenize(self, text: str) -> List[int]:
        """Encode ``text`` to token ids; the first word is not given ``space_prefix``."""
        words = text.split()
        tokens = []
        for i, word in enumerate(words):
            if i == 0 or word.startswith(self.space_prefix):
                tokens.extend(self._tokenize_word(word))
            else:
                tokens.extend(self._tokenize_word(self.space_prefix + word))
        return [self.vocab.get(token, self.vocab["<UNK>"]) for token in tokens]

    def decode(self, token_ids: List[int]) -> str:
        """Map ids back to text, turning ``space_prefix`` into spaces; unknown ids give ``<UNK>``."""
        tokens = [self.inverse_vocab.get(token_id, "<UNK>") for token_id in token_ids]
        text = ''.join(tokens).replace(self.space_prefix, ' ')
        return text.strip()

    def save(self, path: str):
        """Write vocab.json, merges.json and config.json into directory ``path``."""
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, 'vocab.json'), 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=2)
        with open(os.path.join(path, 'merges.json'), 'w', encoding='utf-8') as f:
            # JSON keys must be strings, so each pair is stored as "left right".
            json.dump({' '.join(k): v for k, v in self.merges.items()}, f, ensure_ascii=False, indent=2)
        with open(os.path.join(path, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump({'vocab_size': self.vocab_size, 'space_prefix': self.space_prefix}, f, indent=2)

    @classmethod
    def load(cls, path: str):
        """Rebuild a tokenizer from the files written by ``save``."""
        with open(os.path.join(path, 'config.json'), 'r', encoding='utf-8') as f:
            config = json.load(f)
        tokenizer = cls(vocab_size=config['vocab_size'])
        tokenizer.space_prefix = config['space_prefix']

        with open(os.path.join(path, 'vocab.json'), 'r', encoding='utf-8') as f:
            tokenizer.vocab = json.load(f)
        tokenizer.inverse_vocab = {int(v): k for k, v in tokenizer.vocab.items()}

        with open(os.path.join(path, 'merges.json'), 'r', encoding='utf-8') as f:
            merges = json.load(f)
            tokenizer.merges = {tuple(k.split()): v for k, v in merges.items()}

        return tokenizer

In [None]:
# Train and save the tokenizer.
tokenizer = TalmudTokenizer(vocab_size=16000)
tokenizer.train(gemara)
tokenizer.save("talmud_tokenizer")
print("Tokenizer saved.")

# Load the saved tokenizer
loaded_tokenizer = TalmudTokenizer.load("talmud_tokenizer")
print("Tokenizer loaded.")

# Test the loaded tokenizer
test_sentence = "Rav Pappa said to Rabbi Akiva, from where do we learn about sandwiches?"
encoded = loaded_tokenizer.tokenize(test_sentence)
decoded = loaded_tokenizer.decode(encoded)

print(f"\nTest sentence: {test_sentence}")
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")

# Verify that the loaded tokenizer produces the same results as the original
original_encoded = loaded_tokenizer.tokenize(test_sentence)
print(f"\nOriginal encoded: {original_encoded}")
print(f"Loaded tokenizer encoded: {encoded}")
print(f"Encodings match: {original_encoded == encoded}")

# Print some statistics
print(f"\nVocabulary size: {len(loaded_tokenizer.vocab)}")
print(f"Number of merges: {len(loaded_tokenizer.merges)}")

In [None]:
class SequenceDataset(Dataset):
    """Sliding-window next-token dataset over a flat token sequence.

    Item ``idx`` is ``(tokens[idx : idx + L], tokens[idx + 1 : idx + L + 1])`` as
    ``torch.long`` tensors, where ``L`` is ``sequence_length``.
    """

    def __init__(self, tokens, sequence_length):
        self.tokens = tokens
        self.sequence_length = sequence_length

    def __len__(self):
        # Each window needs sequence_length + 1 tokens (inputs plus shifted targets).
        # Clamp at 0 so a corpus shorter than one window gives an empty dataset instead of
        # len() raising on a negative value.
        return max(0, len(self.tokens) - self.sequence_length)

    def __getitem__(self, idx):
        chunk = self.tokens[idx:idx + self.sequence_length + 1]
        return t.tensor(chunk[:-1], dtype=t.long), t.tensor(chunk[1:], dtype=t.long)


def prepare_data_for_training(tokens, sequence_length, batch_size, val_split=0.1):
    """Build train/validation DataLoaders from a flat token sequence.

    The split is contiguous: training windows come from the start of the sequence and
    validation windows from the end. A random split of overlapping windows would put
    almost every validation token into some training window, so validation loss would
    not measure generalization. When a validation set exists, the last
    ``sequence_length`` training windows are dropped so no token is in both sets.

    Args:
        tokens: flat sequence of token ids (anything SequenceDataset can slice).
        sequence_length: number of input tokens per example.
        batch_size: examples per batch for both loaders.
        val_split: fraction of windows reserved for validation.

    Returns:
        (train_loader, val_loader); only the training loader shuffles.
    """
    dataset = SequenceDataset(tokens, sequence_length)
    n_windows = len(dataset)

    val_size = int(val_split * n_windows)
    train_size = n_windows - val_size
    # Window i covers tokens[i : i + sequence_length + 1]. Stopping training at
    # train_size - sequence_length keeps its last token below index train_size, which is
    # where the first validation window begins.
    train_end = max(0, train_size - sequence_length) if val_size > 0 else train_size

    train_dataset = Subset(dataset, range(train_end))
    val_dataset = Subset(dataset, range(train_size, n_windows))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader