In [709]:
from pathlib import Path
from urllib.request import urlretrieve

import torch
from torch.nn import functional as F
from matplotlib import pyplot as plt

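## Load data

Download `ortak_kelimeler.txt` (a Turkish word list, one word per line) from the TDKDictionaryCrawler repository.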

In [None]:
# Download the word list once. Unlike `!wget`, re-running this cell is a no-op when
# the file already exists (wget would fetch it again as `ortak_kelimeler.txt.1`),
# and it does not depend on the `wget` binary being installed.
WORDS_URL = "https://raw.githubusercontent.com/ncarkaci/TDKDictionaryCrawler/master/ortak_kelimeler.txt"
WORDS_PATH = Path("ortak_kelimeler.txt")
if not WORDS_PATH.exists():
    urlretrieve(WORDS_URL, WORDS_PATH)

In [710]:
# The file contains Turkish letters (ç, ğ, ı, ö, ş, ü); decode it explicitly as
# UTF-8 so the result does not depend on the platform's default locale encoding.
with open("ortak_kelimeler.txt", encoding="utf-8") as f:
    text = f.read()

In [711]:
# One word per line. Drop blank lines (e.g. the empty string a trailing newline
# produces) so they do not become zero-length training words.
words = [w for w in text.split("\n") if w]

In [712]:
# Vocabulary: every distinct character that occurs in the word list.
long_text = "".join(text.split("\n"))
ids_list = list({*long_text})
len(ids_list)

29

In [713]:
# Order the vocabulary by code point so the id assignment below is deterministic
# (set iteration order is not).
ordinal_char_pairs = sorted((ord(ch), ch) for ch in ids_list)
ordinal_char_pairs

[(97, 'a'),
 (98, 'b'),
 (99, 'c'),
 (100, 'd'),
 (101, 'e'),
 (102, 'f'),
 (103, 'g'),
 (104, 'h'),
 (105, 'i'),
 (106, 'j'),
 (107, 'k'),
 (108, 'l'),
 (109, 'm'),
 (110, 'n'),
 (111, 'o'),
 (112, 'p'),
 (114, 'r'),
 (115, 's'),
 (116, 't'),
 (117, 'u'),
 (118, 'v'),
 (121, 'y'),
 (122, 'z'),
 (231, 'ç'),
 (246, 'ö'),
 (252, 'ü'),
 (287, 'ğ'),
 (305, 'ı'),
 (351, 'ş')]

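## Encode / Decode

Map each character to an integer id (in code-point order) and back. One extra id is reserved for `.`, the word-boundary token: generation starts from it and stops when it is sampled again.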

In [714]:
# id -> char follows the code-point order of ordinal_char_pairs; char -> id is its inverse.
itoc = dict(enumerate(ch for _, ch in ordinal_char_pairs))
ctoi = {ch: idx for idx, ch in itoc.items()}

In [715]:
# Reserve one extra id for '.', the word-boundary token. Guarded so that re-running
# this cell does not allocate a second id (which would fall outside the count tables).
if "." not in ctoi:
    ctoi["."] = len(ctoi)
    itoc[ctoi["."]] = "."
assert len(ctoi) == len(itoc) and itoc[ctoi["."]] == "."

In [716]:
def encode(input: str):
    """Translate a string into its list of integer ids (KeyError on an unknown character)."""
    return list(map(ctoi.__getitem__, input))


def decode(ids):
    """Translate an iterable of integer ids back into the corresponding string."""
    return "".join(map(itoc.__getitem__, ids))

In [717]:
# Sanity check: decode must invert encode, both on a sample word and on every
# training word (which also confirms each character in `words` has an id).
assert decode(encode("zemberek")) == "zemberek"
assert all(decode(encode(w)) == w for w in words)

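## n-Gram Model

A count-based character n-gram model. `train` tabulates n-gram counts and normalises them into next-character probabilities. `generate` samples a word starting from the `.` boundary token, backing off to lower-order tables while the history is shorter than n − 1 characters.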

In [718]:
class NGram:
    """Count-based character n-gram language model with backoff for short histories.

    ``train`` turns n-gram counts into a dense conditional table ``probs`` whose last
    axis is the next character. It also derives lower-order tables by summing out
    that last axis. ``generate`` samples a word character by character, starting from
    the '.' boundary token and stopping when '.' is sampled again.

    NOTE(review): ``train`` calls ``get_data`` and ``count``, which are not defined
    anywhere in this notebook. They presumably lived in since-deleted cells, so
    Restart & Run All fails until they are restored. ``generate`` reads the globals
    ``ctoi`` and ``decode``.
    """

    def __init__(self, n):
        """
        Args:
            n: model order, i.e. characters per n-gram (the context is n - 1 chars).
        """
        self.n = n
        self.probs = None      # highest-order conditional table; set by train()
        self.all_probs = None  # [probs, then tables of decreasing order]; set by train()

    def train(self, words):
        """Fit the probability tables on ``words`` and print the training-set loss.

        Args:
            words: list of training words (strings).

        Side effects:
            Sets ``self.probs`` and ``self.all_probs``. Prints the mean negative
            log-likelihood of the training n-grams.
        """
        # presumably an (N, n) integer tensor of n-gram ids, one row per n-gram — TODO confirm
        ids_list = get_data(words, n=self.n)

        # presumably a dense count tensor of shape (V,) * n (cf. the (30, 30) check for n=2)
        counts = count(ids_list, n=self.n)

        # Tiny pseudo-count: this is an epsilon that avoids log(0) and 0/0 for unseen
        # contexts, not meaningful smoothing.
        counts = counts.float() + 1e-10

        # Normalise over the last axis: P(next char | previous n-1 chars).
        probs = counts / counts.sum(dim=-1, keepdim=True)

        # Mean NLL of every training n-gram (row of ids_list) under the fitted table.
        loss = -probs[tuple(ids_list.T)].log().mean()
        print(f"loss: {loss.item()}")

        self.all_probs = [probs]

        # Lower-order tables. Each pass sums out the last (next-character) axis,
        # which leaves counts of the n-grams' shorter prefixes. Normalising the new
        # last axis gives the next-lower-order conditional table. generate() only
        # consults these while the history is shorter than n - 1, i.e. word-initially.
        counts_ = counts
        for _ in range(self.n - 2):
            counts_ = counts_.sum(dim=-1)
            probs_ = counts_ / counts_.sum(dim=-1, keepdim=True)
            self.all_probs.append(probs_)

        self.probs = probs

    def generate(self):
        """Sample one word from the trained model.

        The word starts from the '.' boundary token, and characters are drawn until '.'
        is sampled again. At each step the highest-order table whose context fits in
        the history so far is used. For n = 1 that is the unigram table with an empty
        context. Sampling uses torch's global RNG, so call ``torch.manual_seed`` first
        for reproducible output.

        Returns:
            The sampled word without boundary tokens (may be empty).

        Raises:
            RuntimeError: if called before ``train``.
        """
        if self.all_probs is None:
            raise RuntimeError("NGram.generate() called before train()")

        vocab_size = self.probs.shape[-1]
        chars = [ctoi['.']]

        while True:
            # all_probs is ordered from highest to lowest order; a table with ndim k
            # conditions on the last k - 1 characters.
            probs = next(t for t in self.all_probs if t.dim() - 1 <= len(chars))
            ctx_len = probs.dim() - 1
            # Slice with an explicit start: chars[-0:] would be the whole list.
            context = chars[len(chars) - ctx_len:]
            p = probs[tuple(context)]

            assert p.shape == (vocab_size,), p.shape

            new_char = torch.multinomial(p, 1, replacement=True).item()

            if new_char == ctoi['.']:
                return decode(chars[1:])

            chars.append(new_char)

In [719]:
# Bigram model (one character of context).
# NOTE(review): `nn` conventionally aliases `torch.nn`; a name like `n2` would read better.
nn = NGram(n=2)

In [720]:
# Fit the bigram model; train() prints the mean training-set NLL.
# (Saved run reported: loss: 2.522550344467163)
nn.train(words=words)

loss: 2.522550344467163


In [721]:
# A bigram model has no backoff tables: a single (V, V) conditional table whose
# rows are proper distributions. V comes from the vocabulary, not a magic 30.
assert len(nn.all_probs) == 1
assert nn.all_probs[-1].shape == (len(itoc), len(itoc))
assert torch.allclose(nn.all_probs[-1].sum(dim=-1), torch.ones(len(itoc)))

In [722]:
# Trigram model (two characters of context).
n3 = NGram(n=3)

In [723]:
# Fit the trigram model; train() prints the mean training-set NLL.
# (Saved run reported: loss: 2.0314273834228516)
n3.train(words=words)

loss: 2.0314273834228516


In [724]:
# Trigram: the full (V, V, V) table plus one (V, V) backoff table; every row of
# every table is a proper distribution. V comes from the vocabulary, not a magic 30.
assert len(n3.all_probs) == 2
assert n3.all_probs[0].shape == (len(itoc),) * 3
assert n3.all_probs[-1].shape == (len(itoc),) * 2
assert all(torch.allclose(t.sum(dim=-1), torch.ones(t.shape[:-1])) for t in n3.all_probs)

In [725]:
# 4-gram model (three characters of context).
n4 = NGram(n=4)

In [726]:
# Fit the 4-gram model; train() prints the mean training-set NLL.
# (Saved run reported: loss: 1.5213545560836792)
n4.train(words=words)

loss: 1.5213545560836792


In [727]:
# Fixed seed so the ten bigram samples are reproducible.
torch.manual_seed(35)
print("\n".join(nn.generate() for _ in range(10)))

atettı
ik
öncırısörarsı
kakl
satsabiçinorekom
bil
stlllıdaşapek
k
s
ir


In [728]:
# Same seed as the bigram cell, so the samples are directly comparable.
torch.manual_seed(35)
print("\n".join(n3.generate() for _ in range(10)))

atettı
ik
ölülhası
arsız
akl
satsabiliyonuk
müttırmaklak
düğmek
kor
ih


In [729]:
# Same seed again, for comparison with the lower-order models.
torch.manual_seed(35)
print("\n".join(n4.generate() for _ in range(10)))

ateştirik
ölümsesuraksız
aklata
sabuçunutukla
biliktenleşme
pekleşmek
yoğuk
poloji
şantık
çarkeci
