In [12]:
from collections import Counter

import torch

class NGramModel:
    """Character-level N-gram model stored as dense count/probability tensors.

    Every word is padded with ``N - 1`` leading ``'.'`` tokens and one trailing
    ``'.'`` token, so ``'.'`` (index 0) acts as both start and end marker.

    Attributes:
        N: order of the model (length of each counted gram).
        words: the training words exactly as passed in.
        stoi / itos: character <-> integer index mappings; ``'.'`` is index 0.
        vocab_size: number of distinct symbols including ``'.'``.
        N_counts: integer tensor of shape ``(vocab_size,) * N`` with raw
            N-gram counts.
        P: float16 tensor of the same shape; each slice along the last axis
            holds next-character probabilities for one context (all zeros for
            contexts never seen in training).
    """

    def __init__(self, N, words):
        """Count all N-grams in ``words`` and normalise them to probabilities.

        Args:
            N: gram length, must be >= 1.
            words: iterable of strings. A literal ``'.'`` inside a word is
                mapped to the boundary token (index 0) rather than getting its
                own index, which keeps the index space contiguous.

        Raises:
            AssertionError: if ``N < 1``.
        """
        assert N >= 1, "N must be at least 1"
        self.N = N
        self.words = words

        # Build vocabulary. '.' is reserved for index 0, so it is removed from
        # the enumerated alphabet; otherwise overwriting its entry would leave
        # a gap and make the largest index exceed vocab_size - 1.
        chars = sorted(set(''.join(words)) - {'.'})
        self.stoi = {s: i + 1 for i, s in enumerate(chars)}
        self.stoi['.'] = 0
        self.itos = {i: s for s, i in self.stoi.items()}
        self.vocab_size = len(self.stoi)

        # Count N-grams in a Python Counter first: it cannot overflow and is
        # much cheaper than incrementing single tensor elements.
        gram_counts = Counter()
        padding = [0] * (N - 1)
        for w in words:
            idxs = padding + [self.stoi[ch] for ch in w] + [0]
            gram_counts.update(zip(*[idxs[i:] for i in range(N)]))

        # Pick the narrowest integer dtype that can hold the largest count, so
        # small corpora keep the compact int16 table while frequent grams
        # (> 32767 occurrences) no longer wrap around to negative values.
        max_count = max(gram_counts.values(), default=0)
        count_dtype = next(
            dt for dt in (torch.int16, torch.int32, torch.int64)
            if max_count <= torch.iinfo(dt).max
        )
        shape = (self.vocab_size,) * N
        self.N_counts = torch.zeros(shape, dtype=count_dtype)
        if gram_counts:
            # One column of indices per tensor axis -> a single scatter write.
            index = torch.tensor(list(gram_counts.keys()), dtype=torch.long)
            values = torch.tensor(list(gram_counts.values()), dtype=count_dtype)
            self.N_counts[index.unbind(dim=1)] = values

        # Normalise along the last axis. Unseen contexts have a row total of
        # zero; clamping the denominator to 1 leaves those rows at 0 without
        # producing NaNs (and without an extra full-size temporary).
        probs = self.N_counts.float()
        row_totals = probs.sum(dim=-1, keepdim=True)
        probs /= row_totals.clamp_(min=1.0)
        self.P = probs.to(torch.float16)

    def sample(self, num_samples=1, seed=None):
        """Generate strings by sampling one character at a time from ``P``.

        Args:
            num_samples: number of strings to generate.
            seed: optional integer; when given, a dedicated ``torch.Generator``
                seeded with it is used so repeated calls give the same result.

        Returns:
            List of ``num_samples`` strings, each ending with ``'.'``.

        Raises:
            RuntimeError: if the current context has no probability mass
                (e.g. the model was built from an empty word list).
        """
        g = None
        if seed is not None:
            g = torch.Generator()
            g.manual_seed(seed)

        results = []
        for _ in range(num_samples):
            out = []
            context = [0] * (self.N - 1)  # start from the all-'.' context
            while True:
                # Sample in float32: the row is only vocab_size long, and this
                # avoids relying on a float16 multinomial kernel.
                p = self.P[tuple(context)].float()
                if p.sum() <= 0:
                    raise RuntimeError(
                        "cannot sample: no probability mass for context "
                        + repr(''.join(self.itos[i] for i in context))
                    )
                ix = torch.multinomial(
                    p, num_samples=1, replacement=True, generator=g
                ).item()
                out.append(self.itos[ix])
                if ix == 0:
                    break
                if self.N > 1:
                    # Slide the context window forward by one character.
                    context = context[1:] + [ix]
            results.append(''.join(out))
        return results

In [13]:
# Example usage: read one training word per line, then fit a 6-gram model.
# The context manager closes the file handle as soon as the lines are read.
with open('names 2.txt', 'r') as f:
    words = f.read().splitlines()
model = NGramModel(N=6, words=words)

In [24]:
# Draw 20 strings from the fitted model and show them one per line.
samples = model.sample(num_samples=20)
print("\n".join(samples))

gimena.
mea.
glynis.
aubree.
binny.
bobbe.
rayna.
kuswara.
sydel.
reggie.
loki.
patchit.
young.
viviana.
vincente.
bienvenido.
blythe.
devan.
ramin.
kourtnay.
