In [None]:
import torch
import torch.nn as nn
import pickle

# Run on the CUDA device when PyTorch can see one; otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
class Encoder(nn.Module):
    """Two-layer bidirectional LSTM encoder with a bridge to a one-layer decoder.

    The top layer's forward and backward final states are concatenated and
    projected (Linear + tanh) down to ``hidden_dim`` so they can seed a
    single-layer decoder LSTM.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Index 0 is the padding id; its embedding row receives no gradient.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=dropout,
        )
        self.dropout = nn.Dropout(dropout)
        # Bridges squeeze the concatenated 2*hidden_dim states back to hidden_dim.
        self.bridge_h = nn.Linear(hidden_dim * 2, hidden_dim)
        self.bridge_c = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x, lengths=None):
        """Encode a batch of token ids.

        Args:
            x: token ids, (batch, seq_len) because the LSTM is batch_first.
            lengths: optional true length of each row; when given, the input is
                packed so padding positions are skipped by the LSTM.

        Returns:
            ``(outputs, (hidden_dec, cell_dec))``. ``outputs`` is
            (batch, T, 2 * hidden_dim), where T is seq_len, or max(lengths) when
            packed. Each decoder state is (1, batch, hidden_dim).
        """
        embedded = self.dropout(self.embedding(x))

        if lengths is None:
            outputs, (h_n, c_n) = self.lstm(embedded)
        else:
            # enforce_sorted=False lets rows arrive in any length order;
            # pack_padded_sequence wants the lengths tensor on the CPU.
            packed_in = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            packed_out, (h_n, c_n) = self.lstm(packed_in)
            outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # h_n / c_n are (num_layers * 2, batch, hidden_dim); the last two rows
        # belong to the top layer: forward direction at -2, backward at -1.
        top_h = torch.cat((h_n[-2], h_n[-1]), dim=1)
        top_c = torch.cat((c_n[-2], c_n[-1]), dim=1)

        hidden_dec = torch.tanh(self.bridge_h(top_h)).unsqueeze(0)
        cell_dec = torch.tanh(self.bridge_c(top_c)).unsqueeze(0)
        return outputs, (hidden_dec, cell_dec)


class Decoder(nn.Module):
    """Single-layer LSTM decoder that produces vocabulary logits one step at a time."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, dropout=0.3, encoder_embedding=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        # When an embedding module is supplied (e.g. the encoder's), share that
        # exact module; otherwise create a private one with padding id 0.
        self.embedding = (
            encoder_embedding
            if encoder_embedding is not None
            else nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        )
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=1, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden, cell):
        """Run one decoding step.

        Args:
            x: previous token ids, assumed (batch, 1) — the length-1 time axis is
                squeezed away before the output projection.
            hidden, cell: LSTM state, (1, batch, hidden_dim).

        Returns:
            ``(logits, hidden, cell)`` with logits shaped (batch, vocab_size)
            and the updated LSTM state.
        """
        step_in = self.dropout(self.embedding(x))
        step_out, (next_hidden, next_cell) = self.lstm(step_in, (hidden, cell))
        logits = self.fc(step_out.squeeze(1))
        return logits, next_hidden, next_cell


class Seq2SeqSpeller(nn.Module):
    """Encoder-decoder character model mapping a (mis)spelled word to a correction.

    The decoder is given the encoder's embedding module, so source and target
    tokens share a single embedding table.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, dropout=0.3):
        super().__init__()
        self.encoder = Encoder(vocab_size, embed_dim, hidden_dim, dropout)
        self.decoder = Decoder(vocab_size, embed_dim, hidden_dim, dropout,
                               encoder_embedding=self.encoder.embedding)
        self.vocab_size = vocab_size

    def forward(self, src, tgt, src_lens, teacher_forcing_ratio=0.5):
        """Decode ``tgt.size(1) - 1`` steps and return the per-step logits.

        Args:
            src: source token ids, (batch, src_len).
            tgt: target token ids, (batch, tgt_len). Column 0 is the first
                decoder input (presumably <sos> — confirm against the training data).
            src_lens: true length of each source row, forwarded to the encoder.
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth token ``tgt[:, t]`` as the next decoder input
                instead of the model's own argmax prediction. ``0`` gives pure
                free-running decoding.

        Returns:
            Tensor of shape (batch, tgt_len, vocab_size); position 0 stays zero.
        """
        batch_size = src.size(0)
        max_len = tgt.size(1)
        outputs = torch.zeros(batch_size, max_len, self.vocab_size, device=src.device)
        _, (hidden, cell) = self.encoder(src, src_lens)
        input_token = tgt[:, 0:1]
        for t in range(1, max_len):
            output, hidden, cell = self.decoder(input_token, hidden, cell)
            outputs[:, t] = output
            # Draw from torch's RNG (seedable via torch.manual_seed); skip the
            # draw entirely when forcing is disabled so ratio<=0 is deterministic.
            use_teacher = (
                teacher_forcing_ratio > 0
                and torch.rand(1).item() < teacher_forcing_ratio
            )
            if use_teacher:
                input_token = tgt[:, t:t + 1]
            else:
                input_token = output.argmax(1, keepdim=True)
        return outputs

In [None]:
class Vocabulary:
    """Character-level vocabulary for Georgian words.

    Index layout: the special tokens come first (<pad>=0, <sos>=1, <eos>=2),
    followed by the 33 letters of ``chars`` in the order written below.
    """

    def __init__(self):
        self.PAD = '<pad>'
        self.SOS = '<sos>'
        self.EOS = '<eos>'
        self.special_tokens = [self.PAD, self.SOS, self.EOS]
        self.chars = list('აბგდევზთიკლმნოპჟრსტუფქღყშჩცძწჭხჯჰ')

        ordered = [*self.special_tokens, *self.chars]
        self.token_to_idx = {tok: pos for pos, tok in enumerate(ordered)}
        self.idx_to_token = dict(enumerate(ordered))

        self.pad_idx = self.token_to_idx[self.PAD]
        self.sos_idx = self.token_to_idx[self.SOS]
        self.eos_idx = self.token_to_idx[self.EOS]
        self.vocab_size = len(ordered)

    def encode(self, word, add_sos=False, add_eos=False):
        """Map each known character of ``word`` to its id.

        Characters missing from the vocabulary are silently dropped, so the
        result can be shorter than ``word``. Optionally wraps the ids in
        <sos> / <eos>.
        """
        known = self.token_to_idx
        ids = [known[ch] for ch in word if ch in known]
        if add_sos:
            ids.insert(0, self.sos_idx)
        if add_eos:
            ids.append(self.eos_idx)
        return ids

    def decode(self, indices):
        """Turn ids back into a string.

        Stops at the first <eos>, skips <pad> and <sos>, and maps unknown ids
        to the empty string.
        """
        skipped = (self.pad_idx, self.sos_idx)
        pieces = []
        for idx in indices:
            if idx == self.eos_idx:
                break
            if idx in skipped:
                continue
            pieces.append(self.idx_to_token.get(idx, ''))
        return ''.join(pieces)

In [None]:
# Load the preprocessing artifacts saved alongside the trained model.
# SECURITY: pickle.load can execute arbitrary code embedded in the file —
# only open artifact files from a trusted source.
# NOTE(review): 'vocab' is presumably a pickled Vocabulary instance; unpickling
# it needs that class defined in this session (previous cell) — confirm.
with open('spellchecker_artifacts.pkl', 'rb') as f:
    artifacts = pickle.load(f)

# Notebook-level globals consumed by the model-loading cell below.
vocab = artifacts['vocab']
# Expected to provide 'embed_dim', 'hidden_dim' and 'dropout' (read by load_model).
config = artifacts['config']

print(f"Vocabulary size: {vocab.vocab_size}")
print(f"Config: {config}")
# Optional metadata; older artifact files may not carry it.
if 'training_info' in artifacts:
    print(f"Training info: {artifacts['training_info']}")

In [None]:
# Cell 5 - Load Model
def load_model(model_path, config, vocab):
    """Rebuild a Seq2SeqSpeller and restore its trained weights.

    Args:
        model_path: path to a saved ``state_dict`` checkpoint.
        config: mapping with 'embed_dim', 'hidden_dim' and 'dropout'.
        vocab: object exposing ``vocab_size`` (the Vocabulary from the artifacts).

    Returns:
        The model, moved to the notebook-level ``device`` and put in eval mode.
    """
    speller = Seq2SeqSpeller(
        vocab.vocab_size,
        config['embed_dim'],
        config['hidden_dim'],
        config['dropout'],
    )
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    weights = torch.load(model_path, map_location=device)
    speller.load_state_dict(weights)
    # nn.Module.to() and .eval() both return the module itself.
    return speller.to(device).eval()

# Instantiate the trained speller using `config` / `vocab` from the artifacts cell.
# The resulting global `model` is on `device` and already in eval mode.
model = load_model('spellchecker_model.pt', config, vocab)
print("Model loaded successfully!")

In [None]:
def beam_search_decode(model, word, vocab, beam_width=5, max_len=None):
    """Correct one word with length-normalised beam search.

    Args:
        model: a Seq2SeqSpeller-like module exposing ``encoder(src, lengths)``
            and ``decoder(token, hidden, cell)``.
        word: the input word. Characters unknown to ``vocab`` are dropped by
            ``vocab.encode``.
        vocab: provides ``encode``, ``decode``, ``sos_idx`` and ``eos_idx``.
        beam_width: number of hypotheses kept per step. It is clamped to the
            vocabulary size when expanding a hypothesis.
        max_len: maximum number of decoding steps; defaults to ``len(word) + 5``.

    Returns:
        The decoded string of the best hypothesis, where a hypothesis's score is
        its summed log-probability divided by its token count (<sos> included).
        Returns ``word`` unchanged when nothing in it can be encoded or no
        hypothesis survives.
    """
    if max_len is None:
        max_len = len(word) + 5

    src_ids = vocab.encode(word)
    if not src_ids:
        # Empty word or no in-vocabulary characters: nothing for the encoder
        # to read (an empty/float tensor would crash the embedding).
        return word

    # Build inputs on whatever device the model's weights live on.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()

    with torch.no_grad():
        src = torch.tensor([src_ids], dtype=torch.long, device=run_device)
        # Use the *encoded* length: encode() may drop characters, and a length
        # longer than the sequence makes pack_padded_sequence raise.
        src_len = torch.tensor([len(src_ids)])

        _, (hidden, cell) = model.encoder(src, src_len)

        beams = [(0.0, [vocab.sos_idx], hidden, cell)]
        completed = []

        for _ in range(max_len):
            candidates = []

            for score, seq, h, c in beams:
                if seq[-1] == vocab.eos_idx:
                    # Finished hypotheses leave the beam and are scored once.
                    completed.append((score / len(seq), seq))
                    continue

                inp = torch.tensor([[seq[-1]]], dtype=torch.long, device=run_device)
                output, new_h, new_c = model.decoder(inp, h, c)
                log_probs = torch.log_softmax(output, dim=-1)

                # topk raises if k exceeds the vocabulary size.
                k = min(beam_width, log_probs.size(-1))
                top = log_probs.topk(k)
                for log_p, idx in zip(top.values[0].tolist(), top.indices[0].tolist()):
                    candidates.append((score + log_p, seq + [idx], new_h, new_c))

            candidates.sort(key=lambda cand: cand[0], reverse=True)
            beams = candidates[:beam_width]
            if not beams:
                # Every hypothesis has finished (and is already in `completed`).
                break

        # Hypotheses still alive when the step budget ran out.
        completed.extend((s / len(seq), seq) for s, seq, _, _ in beams)

        if completed:
            best = max(completed, key=lambda x: x[0])
            return vocab.decode(best[1])

        return word

In [None]:
_cache = {'model': None, 'vocab': None}

def correct_word(word: str, model_path: str) -> str:
    """
    Takes a potentially misspelled Georgian word and returns the corrected version.
    """
    global _cache

    # Load model if not cached
    if _cache['model'] is None:
        with open('spellchecker_artifacts.pkl', 'rb') as f:
            artifacts = pickle.load(f)

        vocab = artifacts['vocab']
        config = artifacts['config']

        model = Seq2SeqSpeller(
            vocab.vocab_size,
            config['embed_dim'],
            config['hidden_dim'],
            config['dropout']
        )
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device)
        model.eval()

        _cache['model'] = model
        _cache['vocab'] = vocab

    model = _cache['model']
    vocab = _cache['vocab']

    # Handle edge cases
    if not word or len(word) < 1:
        return word

    # Filter non-Georgian characters
    georgian_chars = set('აბგდევზთიკლმნოპჟრსტუფქღყშჩცძწჭხჯჰ')
    if not all(c in georgian_chars for c in word):
        return word

    # Beam search decode
    result = beam_search_decode(model, word, vocab, beam_width=5)

    return result if result else word