In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
data_path = "/content/drive/MyDrive/NLP/spelling_dictionary.json"

In [15]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizer, BertModel
import json

class SpellingDataset(Dataset):
    def __init__(self, spelling_dict, tokenizer, max_length=32):  # Reduced max_length
        self.pairs = []
        for correct, misspellings in spelling_dict.items():
            for misspelling in misspellings:
                self.pairs.append((misspelling, correct))
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        misspelling, correct = self.pairs[idx]

        # Tokenize without special tokens for input
        misspelling_encoding = self.tokenizer(
            misspelling,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=False,  # No special tokens
            return_tensors='pt'
        )

        correct_encoding = self.tokenizer(
            correct,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=False,  # No special tokens
            return_tensors='pt'
        )

        return {
            'input_ids': misspelling_encoding['input_ids'].squeeze(),
            'attention_mask': misspelling_encoding['attention_mask'].squeeze(),
            'labels': correct_encoding['input_ids'].squeeze()
        }

    def __len__(self):
        return len(self.pairs)

class SpellingCorrector(nn.Module):
    def __init__(self, vocab_size, hidden_size=768):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Freeze BERT parameters
        for param in self.bert.parameters():
            param.requires_grad = False

        # Add character-level embedding
        self.char_embedding = nn.Embedding(vocab_size, 128)

        # Bidirectional LSTM for sequence modeling
        lstm_hidden_size = hidden_size  # This will be the size per direction
        self.lstm = nn.LSTM(
            input_size=hidden_size + 128,  # Combined BERT and char embeddings
            hidden_size=lstm_hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=0.2
        )

        # The LSTM output will be 2*hidden_size due to bidirectional
        lstm_output_size = lstm_hidden_size * 2

        # Linear layer before batch norm
        self.pre_norm = nn.Linear(lstm_output_size, hidden_size)

        # Batch norm on the hidden dimension
        self.batch_norm = nn.BatchNorm1d(hidden_size)

        # Output layer
        self.output = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, vocab_size)
        )

    def forward(self, input_ids, attention_mask):
        # Get BERT embeddings
        bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]

        # Get character embeddings
        char_embeds = self.char_embedding(input_ids)

        # Combine embeddings
        combined = torch.cat([bert_output, char_embeds], dim=-1)

        # LSTM processing
        lstm_out, _ = self.lstm(combined)

        # Pre-norm linear layer
        hidden = self.pre_norm(lstm_out)

        # Reshape for batch norm
        batch_size, seq_len, hidden_dim = hidden.shape
        hidden = hidden.permute(0, 2, 1)  # [batch, hidden, seq_len]
        hidden = self.batch_norm(hidden)
        hidden = hidden.permute(0, 2, 1)  # [batch, seq_len, hidden]

        # Generate output probabilities
        logits = self.output(hidden)

        return logits

def train_and_evaluate(num_epochs=10, data_path='spelling_dictionary.json'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    with open(data_path, 'r') as f:
        spelling_dict = json.load(f)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = SpellingCorrector(tokenizer.vocab_size).to(device)

    # Create dataset with smaller max_length
    dataset = SpellingDataset(spelling_dict, tokenizer, max_length=32)

    # Split dataset
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # Training setup
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
    best_accuracy = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            output = model(batch['input_ids'], batch['attention_mask'])

            # Reshape for loss calculation
            loss = criterion(
                output.view(-1, tokenizer.vocab_size),
                batch['labels'].view(-1)
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        # Evaluate
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(batch['input_ids'], batch['attention_mask'])

                predictions = outputs.argmax(dim=-1)
                correct += (predictions == batch['labels']).sum().item()
                total += batch['labels'].numel()

                # Print some examples
                if epoch % 2 == 0:
                    for pred, actual in zip(predictions[:2], batch['labels'][:2]):
                        pred_text = tokenizer.decode(pred, skip_special_tokens=True)
                        actual_text = tokenizer.decode(actual, skip_special_tokens=True)
                        print(f"Pred: {pred_text} | Actual: {actual_text}")

        accuracy = correct / total
        print(f"Accuracy: {accuracy:.4f}")

        scheduler.step(avg_loss)

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': accuracy,
            }, 'best_spelling_corrector.pth')

    return model, tokenizer




In [16]:
model, tokenizer = train_and_evaluate(
    num_epochs=10,  # You can adjust this
    data_path= data_path  # Path to your JSON file
)

Using device: cuda
Epoch 1, Average Loss: 7.4443
Pred: conscientiousddd'''''''''''''''''''''''inc inc | Actual: conscience
Pred: biscuitsesies sdd s'''''''''''''''''''''' approximately approximately approximately | Actual: whistling
Pred: acknowledgingmentmentmentdddd beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful approximately approximately subterranean | Actual: acknowledging
Pred: sauceateiesquadd beautiful beautiful beautiful'''''beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful beautiful'''' approximately approximately approximately approximately | Actual: anxious
Pred: guaranteeesesnnd guarantee guarantee guarantee'''' guarantee guarantee guarantee guarantee guarantee guarantee guarantee guarantee guarantee guarantee guarantee guarantee guarantee or or or or or or | Actual: purple

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizer, BertModel
import json

class SpellingDataset(Dataset):
    def __init__(self, spelling_dict, tokenizer, max_length=32):
        self.pairs = []
        for correct, misspellings in spelling_dict.items():
            for misspelling in misspellings:
                self.pairs.append((misspelling, correct))
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        misspelling, correct = self.pairs[idx]

        # Tokenize without special tokens for input
        misspelling_encoding = self.tokenizer(
            misspelling,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )

        correct_encoding = self.tokenizer(
            correct,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )

        return {
            'input_ids': misspelling_encoding['input_ids'].squeeze(),
            'attention_mask': misspelling_encoding['attention_mask'].squeeze(),
            'labels': correct_encoding['input_ids'].squeeze()
        }

    def __len__(self):
        return len(self.pairs)

class SpellingCorrector(nn.Module):
    def __init__(self, vocab_size, hidden_size=768):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Freeze BERT parameters
        for param in self.bert.parameters():
            param.requires_grad = False

        # Reduce embedding size for characters
        self.char_embedding = nn.Embedding(vocab_size, 64)  # Reduced from 128

        # Add position embeddings
        self.pos_embedding = nn.Parameter(torch.randn(1, 32, 64))  # For max_length=32

        # Transformer encoder layer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size + 64,  # Combined size
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # Output layers with residual connections
        self.fc1 = nn.Linear(hidden_size + 64, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size // 2)
        self.fc3 = nn.Linear(hidden_size // 2, vocab_size)

        self.layer_norm1 = nn.LayerNorm(hidden_size + 64)
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask):
        # Get BERT embeddings
        bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]

        # Get character embeddings and add positional encoding
        char_embeds = self.char_embedding(input_ids)
        char_embeds = char_embeds + self.pos_embedding[:, :char_embeds.size(1), :]

        # Combine embeddings
        combined = torch.cat([bert_output, char_embeds], dim=-1)
        combined = self.layer_norm1(combined)

        # Transform through transformer
        transformed = self.transformer(combined, src_key_padding_mask=~attention_mask.bool())

        # Process through output layers with residual connections
        out = self.fc1(transformed)
        out = self.layer_norm2(out)
        out = self.dropout(F.relu(out))

        residual = out
        out = self.fc2(out)
        out = F.relu(out)
        out = self.dropout(out)

        out = self.fc3(out)

        return out

def train_and_evaluate(num_epochs=10, data_path='spelling_dictionary.json'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    with open(data_path, 'r') as f:
        spelling_dict = json.load(f)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = SpellingCorrector(tokenizer.vocab_size).to(device)

    # Create dataset
    dataset = SpellingDataset(spelling_dict, tokenizer, max_length=32)

    # Split dataset with better ratio
    train_size = int(0.9 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64)

    # Training setup
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=5, T_mult=2, eta_min=1e-6
    )

    best_accuracy = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            output = model(batch['input_ids'], batch['attention_mask'])

            # Reshape for loss calculation
            loss = criterion(
                output.view(-1, tokenizer.vocab_size),
                batch['labels'].view(-1)
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        # Evaluate
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(batch['input_ids'], batch['attention_mask'])

                predictions = outputs.argmax(dim=-1)
                correct += (predictions == batch['labels']).sum().item()
                total += batch['labels'].numel()

                # Print some examples
                if epoch % 2 == 0:
                    for pred, actual in zip(predictions[:2], batch['labels'][:2]):
                        pred_text = tokenizer.decode(pred, skip_special_tokens=True)
                        actual_text = tokenizer.decode(actual, skip_special_tokens=True)
                        print(f"Pred: {pred_text} | Actual: {actual_text}")

        accuracy = correct / total
        print(f"Accuracy: {accuracy:.4f}")

        scheduler.step()

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': accuracy,
            }, 'best_spelling_corrector.pth')

    return model, tokenizer

# To train and evaluate:
model, tokenizer = train_and_evaluate(num_epochs=10, data_path= data_path)

Using device: cuda
Epoch 1, Average Loss: 8.3573
Pred: re'''''''''''''''''''''''''''''' ' | Actual: women
Pred: enthusiastic'''''''''''''''''''''''''''''' ' | Actual: analysis


  output = torch._nested_tensor_from_mask(


Pred: opposite'''''''''''''''''''''''''''''' ' | Actual: secondary
Pred: enthusiastic''ious'''''''''''''''''''''''''''' | Actual: principles
Pred: re'''''''''''''''''''''''''''''' ' | Actual: pilot's
Pred: con''ious'''''''''''''''''''''''''''' | Actual: accommodations
Pred: enthusiastic'''''''''''''''''''''''''''''' ' | Actual: card
Pred: con'veniousven'''''''''''''''''''''''''' ' | Actual: specifically
Pred: inc'''''''''''''''''''''''''''''' ' | Actual: analysis
Pred: inc''ious'''''''''''''''''''''''''''' | Actual: technical
Pred: enthusiastic'''''''''''''''''''''''''''''' ' | Actual: accustomed
Pred: answer'''''''''''''''''''''''''''''' ' | Actual: phase
Pred: inc''ious'''''''''''''''''''''''''''' | Actual: memorandum
Pred: answer'''''''''''''''''''''''''''''' ' | Actual: knife
Pred: con''ious'''''''''''''''''''''''''''' | Actual: manoeuvrable
Pred: material'''''''''''''''''''''''''''''' ' | Actual: quiz
Pred: inc''ious'''''''''''''''''''''''''''' | Actual: rheumatism
Pred: answer'''

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizer, BertModel
import json

class SpellingDataset(Dataset):
    def __init__(self, spelling_dict, tokenizer, max_length=32):
        self.pairs = []
        for correct, misspellings in spelling_dict.items():
            for misspelling in misspellings:
                self.pairs.append((misspelling, correct))
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        misspelling, correct = self.pairs[idx]

        misspelling_encoding = self.tokenizer(
            misspelling,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )

        correct_encoding = self.tokenizer(
            correct,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )

        return {
            'input_ids': misspelling_encoding['input_ids'].squeeze(),
            'attention_mask': misspelling_encoding['attention_mask'].squeeze(),
            'labels': correct_encoding['input_ids'].squeeze()
        }

    def __len__(self):
        return len(self.pairs)

class SpellingCorrector(nn.Module):
    def __init__(self, vocab_size, hidden_size=768):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Freeze only some BERT layers
        for param in list(self.bert.parameters())[:-2]:  # Keep last 2 layers trainable
            param.requires_grad = False

        self.char_embedding = nn.Embedding(vocab_size, 64)

        self.encoder = nn.GRU(
            input_size=hidden_size + 64,
            hidden_size=hidden_size,
            num_layers=2,
            dropout=0.1,
            bidirectional=True,
            batch_first=True
        )

        self.output = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, vocab_size)
        )

    def forward(self, input_ids, attention_mask):
        # Get BERT embeddings
        bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]

        # Get character embeddings
        char_embeds = self.char_embedding(input_ids)

        # Combine embeddings
        combined = torch.cat([bert_output, char_embeds], dim=-1)

        # Encode sequence
        encoded, _ = self.encoder(combined)

        # Generate output probabilities
        logits = self.output(encoded)

        return logits

def train_and_evaluate(num_epochs=10, data_path='spelling_dictionary.json'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    with open(data_path, 'r') as f:
        spelling_dict = json.load(f)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = SpellingCorrector(tokenizer.vocab_size).to(device)

    dataset = SpellingDataset(spelling_dict, tokenizer, max_length=32)

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32)

    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    optimizer = torch.optim.AdamW([
        {'params': model.bert.parameters(), 'lr': 1e-5},
        {'params': model.char_embedding.parameters(), 'lr': 1e-3},
        {'params': model.encoder.parameters(), 'lr': 1e-3},
        {'params': model.output.parameters(), 'lr': 1e-3}
    ])

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=[1e-5, 1e-3, 1e-3, 1e-3],
        epochs=num_epochs,
        steps_per_epoch=len(train_loader)
    )

    best_accuracy = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            output = model(batch['input_ids'], batch['attention_mask'])

            loss = criterion(
                output.view(-1, tokenizer.vocab_size),
                batch['labels'].view(-1)
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(batch['input_ids'], batch['attention_mask'])

                predictions = outputs.argmax(dim=-1)
                correct += (predictions == batch['labels']).sum().item()
                total += batch['labels'].numel()

                if epoch % 2 == 0:
                    for pred, actual in zip(predictions[:2], batch['labels'][:2]):
                        pred_text = tokenizer.decode(pred, skip_special_tokens=True)
                        actual_text = tokenizer.decode(actual, skip_special_tokens=True)
                        print(f"Pred: {pred_text} | Actual: {actual_text}")

        accuracy = correct / total
        print(f"Accuracy: {accuracy:.4f}")

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': accuracy,
            }, 'best_spelling_corrector.pth')

    return model, tokenizer
model, tokenizer = train_and_evaluate(num_epochs=10, data_path= data_path)

Using device: cuda
Epoch 1, Average Loss: 7.8209
Pred: thin'sddddddddddddddddddddddddddddd | Actual: next
Pred: anniversaryonveniencedddddddddddddddddddddddddddd | Actual: overshadowed
Pred: fraternallyddddddddddddddddddddddddddddd | Actual: frightened
Pred: shaggy''ddddddddddddddddddddddddddddd | Actual: judge
Pred: vaenanceddddddddddddddddddddddddddddd | Actual: chapter
Pred: sympatheticlynceddddddddddddddddddddddddddddd | Actual: sympathetic
Pred: shaggy 'bonddddddddddddddddddddddddddddd | Actual: opposite
Pred: magnificentheientiousdddddddddddddddddddddddddddd | Actual: documents
Pred: exaggerateddddddddddddddddddddddddddddd | Actual: existence
Pred: guaranteelyciesddddddddddddddddddddddddddddd | Actual: dangerous
Pred: sound'siousdddddddddddddddddddddddddddd | Actual: class
Pred: enthusiasmesnceiousdddddddddddddddddddddddddddd | Actual: memories
Pred: shaggy''ddddddddddddddddddddddddddddd | Actual: married
Pred: recently 'bonddddddddddddddddddddddddddddd | Actual: standard
Pred: e

In [24]:
from transformers import get_linear_schedule_with_warmup
class SpellingCorrector(nn.Module):
    def __init__(self, vocab_size, hidden_size=768, max_length=32):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.max_length = max_length

        # Freeze BERT
        for param in self.bert.parameters():
            param.requires_grad = False

        self.char_embedding = nn.Embedding(vocab_size, 32)

        self.encoder = nn.GRU(
            input_size=hidden_size + 32,
            hidden_size=256,
            num_layers=1,
            bidirectional=True,
            batch_first=True
        )

        # Remove the attention mechanism to simplify
        self.output = nn.Sequential(
            nn.Linear(512, vocab_size),
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, input_ids, attention_mask):
        # Get BERT embeddings
        with torch.no_grad():
            bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]

        # Get character embeddings
        char_embeds = self.char_embedding(input_ids)

        # Combine embeddings
        combined = torch.cat([bert_output, char_embeds], dim=-1)

        # Encode sequence
        encoded, _ = self.encoder(combined)

        # Generate output probabilities
        logits = self.output(encoded)

        return logits

def train_and_evaluate(num_epochs=10, data_path='spelling_dictionary.json'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    with open(data_path, 'r') as f:
        spelling_dict = json.load(f)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = SpellingCorrector(tokenizer.vocab_size).to(device)

    dataset = SpellingDataset(spelling_dict, tokenizer, max_length=32)

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=16)

    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    num_training_steps = num_epochs * len(train_loader)
    num_warmup_steps = num_training_steps // 10

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    best_accuracy = 0
    patience = 3
    no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            output = model(batch['input_ids'], batch['attention_mask'])

            # Reshape for loss calculation to match dimensions
            loss = criterion(
                output.contiguous().view(-1, tokenizer.vocab_size),
                batch['labels'].contiguous().view(-1)
            )

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(batch['input_ids'], batch['attention_mask'])

                # Get predictions for each position in the sequence
                predictions = outputs.argmax(dim=-1)

                # Mask out padding tokens
                mask = batch['labels'] != tokenizer.pad_token_id
                correct += ((predictions == batch['labels']) & mask).sum().item()
                total += mask.sum().item()

                if epoch % 2 == 0:
                    for pred, actual in zip(predictions[:2], batch['labels'][:2]):
                        pred_text = tokenizer.decode(pred, skip_special_tokens=True)
                        actual_text = tokenizer.decode(actual, skip_special_tokens=True)
                        print(f"Pred: {pred_text} | Actual: {actual_text}")

        accuracy = correct / total if total > 0 else 0
        print(f"Accuracy: {accuracy:.4f}")

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            no_improve = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': accuracy,
            }, 'best_spelling_corrector.pth')
        else:
            no_improve += 1
            if no_improve >= patience:
                print("Early stopping triggered")
                break

    return model, tokenizer

model, tokenizer = train_and_evaluate(num_epochs=10, data_path= data_path)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Pred: court 'iousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousious | Actual: capacity
Pred: psychology'siousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousious | Actual: occurrence
Pred: court 'iousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousious | Actual: involving
Pred: magnificent 'iousiousiousdiousiousiousiousiousddddddiousiousiousiousiousdddddddddd | Actual: refrigerator
Pred: miscellaneous'siousdddddddddddddddddddddddddddd | Actual: initial
Pred: con 'iousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousdd | Actual: possibility
Pred: ''siousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousiousious | Actual: allo

In [25]:
class SpellingCorrector(nn.Module):
    def __init__(self, vocab_size, hidden_size=768):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Freeze BERT
        for param in self.bert.parameters():
            param.requires_grad = False

        self.char_embedding = nn.Embedding(vocab_size, 32)

        # Use a single GRU layer
        self.gru = nn.GRU(
            input_size=hidden_size + 32,
            hidden_size=256,
            num_layers=1,
            bidirectional=True,
            batch_first=True
        )

        # Global average pooling followed by output layers
        self.output = nn.Sequential(
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, vocab_size)
        )

    def forward(self, input_ids, attention_mask):
        # Get BERT embeddings
        with torch.no_grad():
            bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]

        # Get character embeddings
        char_embeds = self.char_embedding(input_ids)

        # Combine embeddings
        combined = torch.cat([bert_output, char_embeds], dim=-1)

        # Process through GRU
        gru_out, _ = self.gru(combined)

        # Global average pooling
        pooled = torch.mean(gru_out * attention_mask.unsqueeze(-1), dim=1)

        # Generate single word prediction
        logits = self.output(pooled)

        return logits

def train_and_evaluate(num_epochs=10, data_path='spelling_dictionary.json'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    with open(data_path, 'r') as f:
        spelling_dict = json.load(f)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = SpellingCorrector(tokenizer.vocab_size).to(device)

    # Modify the dataset to use only first token of correct word
    class ModifiedSpellingDataset(Dataset):
        def __init__(self, spelling_dict, tokenizer, max_length=32):
            self.pairs = []
            for correct, misspellings in spelling_dict.items():
                # Get the first token ID of the correct word
                correct_id = tokenizer(
                    correct,
                    padding='max_length',
                    max_length=1,
                    truncation=True,
                    add_special_tokens=False,
                    return_tensors='pt'
                )['input_ids'][0][0]

                for misspelling in misspellings:
                    self.pairs.append((misspelling, correct_id))
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __getitem__(self, idx):
            misspelling, correct_id = self.pairs[idx]

            misspelling_encoding = self.tokenizer(
                misspelling,
                padding='max_length',
                max_length=self.max_length,
                truncation=True,
                add_special_tokens=False,
                return_tensors='pt'
            )

            return {
                'input_ids': misspelling_encoding['input_ids'].squeeze(),
                'attention_mask': misspelling_encoding['attention_mask'].squeeze(),
                'labels': correct_id
            }

        def __len__(self):
            return len(self.pairs)

    dataset = ModifiedSpellingDataset(spelling_dict, tokenizer, max_length=32)

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    num_training_steps = num_epochs * len(train_loader)
    num_warmup_steps = num_training_steps // 10

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    best_accuracy = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            output = model(batch['input_ids'], batch['attention_mask'])

            loss = criterion(output, batch['labels'])

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(batch['input_ids'], batch['attention_mask'])

                predictions = outputs.argmax(dim=-1)
                correct += (predictions == batch['labels']).sum().item()
                total += len(predictions)

                if epoch % 2 == 0:
                    for pred, actual in zip(predictions[:2], batch['labels'][:2]):
                        pred_text = tokenizer.decode([pred], skip_special_tokens=True)
                        actual_text = tokenizer.decode([actual], skip_special_tokens=True)
                        print(f"Pred: {pred_text} | Actual: {actual_text}")

        accuracy = correct / total
        print(f"Accuracy: {accuracy:.4f}")

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': accuracy,
            }, 'best_spelling_corrector.pth')

    return model, tokenizer
model, tokenizer = train_and_evaluate(num_epochs=10, data_path= data_path)

Using device: cuda
Epoch 1, Average Loss: 9.0999
Pred: especially | Actual: base
Pred: miscellaneous | Actual: suspicion
Pred: miscellaneous | Actual: chapter
Pred: miscellaneous | Actual: peculiar
Pred: miscellaneous | Actual: indefinite
Pred: enthusiasm | Actual: avoid
Pred: miscellaneous | Actual: tell
Pred: miscellaneous | Actual: trial
Pred: miscellaneous | Actual: technically
Pred: miscellaneous | Actual: general
Pred: miscellaneous | Actual: smoking
Pred: miscellaneous | Actual: jam
Pred: miscellaneous | Actual: region
Pred: miscellaneous | Actual: gradually
Pred: miscellaneous | Actual: annoyance
Pred: miscellaneous | Actual: street
Pred: miscellaneous | Actual: fundamental
Pred: miscellaneous | Actual: magnificent
Pred: miscellaneous | Actual: recipe
Pred: miscellaneous | Actual: con
Pred: miscellaneous | Actual: above
Pred: miscellaneous | Actual: juice
Pred: miscellaneous | Actual: execution
Pred: miscellaneous | Actual: initiation
Pred: miscellaneous | Actual: un
Pred: misc

In [28]:
!pip install python-Levenshtein

Collecting python-Levenshtein
  Downloading python_Levenshtein-0.26.1-py3-none-any.whl.metadata (3.7 kB)
Collecting Levenshtein==0.26.1 (from python-Levenshtein)
  Downloading levenshtein-0.26.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.2 kB)
Collecting rapidfuzz<4.0.0,>=3.9.0 (from Levenshtein==0.26.1->python-Levenshtein)
  Downloading rapidfuzz-3.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading python_Levenshtein-0.26.1-py3-none-any.whl (9.4 kB)
Downloading levenshtein-0.26.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (162 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.7/162.7 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rapidfuzz-3.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m97.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages

In [29]:
class SpellingCorrector(nn.Module):
    def __init__(self, vocab_size, hidden_size=768):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Freeze BERT
        for param in self.bert.parameters():
            param.requires_grad = False

        # Character and n-gram embeddings
        self.char_embedding = nn.Embedding(vocab_size, 32)
        self.ngram_embedding = nn.Embedding(vocab_size, 32)
        self.edit_dist_embedding = nn.Embedding(10, 32)

        # Feature combination
        self.feature_combine = nn.Sequential(
            nn.Linear(hidden_size + 96, hidden_size),  # 96 = 32 * 3 (char + ngram + edit)
            nn.LayerNorm(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # BiGRU encoder
        self.encoder = nn.GRU(
            input_size=hidden_size,
            hidden_size=256,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=0.1
        )

        # Output layers
        self.output = nn.Sequential(
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, vocab_size)
        )

    def forward(self, input_ids, attention_mask, ngram_ids=None, edit_distances=None):
        # Get BERT embeddings
        with torch.no_grad():
            bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]

        # Get character embeddings
        char_embeds = self.char_embedding(input_ids)
        char_embeds = torch.mean(char_embeds, dim=1)

        # Get n-gram embeddings
        if ngram_ids is not None:
            ngram_embeds = self.ngram_embedding(ngram_ids)
            ngram_embeds = torch.mean(ngram_embeds, dim=1)
        else:
            ngram_embeds = torch.zeros(input_ids.size(0), 32, device=input_ids.device)

        # Get edit distance embeddings
        if edit_distances is not None:
            edit_embeds = self.edit_dist_embedding(edit_distances)
        else:
            edit_embeds = torch.zeros(input_ids.size(0), 32, device=input_ids.device)

        # Get BERT sequence representation
        bert_pooled = torch.mean(bert_output * attention_mask.unsqueeze(-1), dim=1)

        # Combine all features
        combined = torch.cat([
            bert_pooled,
            char_embeds,
            ngram_embeds,
            edit_embeds
        ], dim=-1)

        # Process through feature combiner
        features = self.feature_combine(combined)
        features = features.unsqueeze(1).expand(-1, input_ids.size(1), -1)

        # Encode sequence
        encoded, _ = self.encoder(features)

        # Global average pooling
        pooled = torch.mean(encoded * attention_mask.unsqueeze(-1), dim=1)

        # Generate logits
        logits = self.output(pooled)

        return logits



import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
import json
from Levenshtein import distance  # You might need to install this: pip install python-Levenshtein

def get_char_ngrams(word, n=3):
    if len(word) < n:
        return [word]
    return [word[i:i+n] for i in range(len(word)-n+1)]

class SpellingDataset(Dataset):
    def __init__(self, spelling_dict, tokenizer, max_length=32):
        self.pairs = []
        for correct, misspellings in spelling_dict.items():
            # Get target token ID
            correct_id = tokenizer(
                correct,
                add_special_tokens=False,
                return_tensors='pt'
            )['input_ids'][0][0]

            for misspelling in misspellings:
                # Get character n-grams
                ngrams = get_char_ngrams(misspelling)
                # Convert to IDs, handling unknown tokens
                ngram_ids = []
                for ngram in ngrams:
                    try:
                        ngram_id = tokenizer.convert_tokens_to_ids(ngram)
                        ngram_ids.append(ngram_id)
                    except:
                        ngram_ids.append(0)

                # Pad ngram_ids
                ngram_ids = ngram_ids[:max_length]
                ngram_ids.extend([0] * (max_length - len(ngram_ids)))

                # Calculate edit distance
                edit_dist = min(distance(misspelling, correct), 9)

                self.pairs.append((
                    misspelling,
                    correct_id,
                    torch.tensor(ngram_ids),
                    edit_dist
                ))

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        misspelling, correct_id, ngram_ids, edit_dist = self.pairs[idx]

        encoding = self.tokenizer(
            misspelling,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': torch.tensor(correct_id),
            'ngram_ids': ngram_ids,
            'edit_distances': torch.tensor(edit_dist)
        }

    def __len__(self):
        return len(self.pairs)



def train_and_evaluate(num_epochs=10, data_path='spelling_dictionary.json'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    with open(data_path, 'r') as f:
        spelling_dict = json.load(f)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = SpellingCorrector(tokenizer.vocab_size).to(device)

    dataset = SpellingDataset(spelling_dict, tokenizer, max_length=32)

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

    num_training_steps = num_epochs * len(train_loader)
    num_warmup_steps = num_training_steps // 10

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    best_accuracy = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            output = model(
                batch['input_ids'],
                batch['attention_mask'],
                batch['ngram_ids'],
                batch['edit_distances']
            )

            loss = criterion(output, batch['labels'])

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(
                    batch['input_ids'],
                    batch['attention_mask'],
                    batch['ngram_ids'],
                    batch['edit_distances']
                )

                predictions = outputs.argmax(dim=-1)
                correct += (predictions == batch['labels']).sum().item()
                total += len(predictions)

                if epoch % 2 == 0:
                    for pred, actual in zip(predictions[:2], batch['labels'][:2]):
                        pred_text = tokenizer.decode([pred], skip_special_tokens=True)
                        actual_text = tokenizer.decode([actual], skip_special_tokens=True)
                        print(f"Pred: {pred_text} | Actual: {actual_text}")

        accuracy = correct / total
        print(f"Accuracy: {accuracy:.4f}")

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': accuracy,
            }, 'best_spelling_corrector.pth')

    return model, tokenizer
model, tokenizer = train_and_evaluate(num_epochs=10, data_path= data_path)


Using device: cuda


  'labels': torch.tensor(correct_id),


Epoch 1, Average Loss: 9.9641
Pred: facilities | Actual: comparatively
Pred: con | Actual: schemes
Pred: con | Actual: del
Pred: facilities | Actual: tournament
Pred: con | Actual: mysterious
Pred: con | Actual: ruler
Pred: con | Actual: violence
Pred: facilities | Actual: ru
Pred: con | Actual: memory
Pred: facilities | Actual: un
Pred: con | Actual: phase
Pred: facilities | Actual: examination
Pred: facilities | Actual: accommodations
Pred: facilities | Actual: del
Pred: con | Actual: rope
Pred: con | Actual: city
Pred: con | Actual: enjoy
Pred: facilities | Actual: strategy
Pred: con | Actual: altitude
Pred: con | Actual: climb
Pred: facilities | Actual: extraordinary
Pred: facilities | Actual: expenditure
Pred: facilities | Actual: criticism
Pred: facilities | Actual: folk
Pred: con | Actual: admitted
Pred: con | Actual: who
Pred: facilities | Actual: opposite
Pred: facilities | Actual: acquaintance
Pred: con | Actual: initiation
Pred: con | Actual: material
Pred: facilities | Actu

In [30]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import BertTokenizer, BertModel
import json

class SpellingDataset(Dataset):
    def __init__(self, spelling_dict, tokenizer, max_length=32):
        self.pairs = []
        for correct, misspellings in spelling_dict.items():
            # Convert correct word to token IDs
            correct_encoding = tokenizer(
                correct,
                padding='max_length',
                max_length=max_length,
                truncation=True,
                add_special_tokens=False,
                return_tensors='pt'
            )

            for misspelling in misspellings:
                self.pairs.append((misspelling, correct))

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        misspelling, correct = self.pairs[idx]

        # Encode misspelling
        misspelling_encoding = self.tokenizer(
            misspelling,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )

        # Encode correct word
        correct_encoding = self.tokenizer(
            correct,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=False,
            return_tensors='pt'
        )

        return {
            'input_ids': misspelling_encoding['input_ids'].squeeze(),
            'attention_mask': misspelling_encoding['attention_mask'].squeeze(),
            'labels': correct_encoding['input_ids'].squeeze()
        }

    def __len__(self):
        return len(self.pairs)

class SpellingCorrector(nn.Module):
    def __init__(self, vocab_size, hidden_size=768):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

        # Freeze BERT parameters
        for param in self.bert.parameters():
            param.requires_grad = False

        # Simple GRU decoder
        self.decoder = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True
        )

        # Output projection
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, attention_mask):
        # Get BERT embeddings
        bert_output = self.bert(input_ids, attention_mask=attention_mask)[0]

        # Decode
        decoded, _ = self.decoder(bert_output)

        # Project to vocabulary
        logits = self.output(decoded)

        return logits

def train_and_evaluate(num_epochs=10, data_path='spelling_dictionary.json'):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    with open(data_path, 'r') as f:
        spelling_dict = json.load(f)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = SpellingCorrector(tokenizer.vocab_size).to(device)

    # Create dataset
    dataset = SpellingDataset(spelling_dict, tokenizer)

    # Split dataset
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=16)

    # Training setup
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    best_accuracy = 0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()
            output = model(batch['input_ids'], batch['attention_mask'])

            # Reshape for loss calculation
            loss = criterion(
                output.view(-1, tokenizer.vocab_size),
                batch['labels'].view(-1)
            )

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}, Average Loss: {avg_loss:.4f}")

        # Evaluate
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in test_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(batch['input_ids'], batch['attention_mask'])

                # Only compare first token for accuracy
                predictions = outputs[:, 0].argmax(dim=-1)
                correct += (predictions == batch['labels'][:, 0]).sum().item()
                total += len(predictions)

                # Print examples
                if epoch % 2 == 0:
                    for pred, actual in zip(predictions[:2], batch['labels'][:2]):
                        pred_text = tokenizer.decode([pred], skip_special_tokens=True)
                        actual_text = tokenizer.decode([actual[0]], skip_special_tokens=True)
                        print(f"Pred: {pred_text} | Actual: {actual_text}")

        accuracy = correct / total
        print(f"Accuracy: {accuracy:.4f}")

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'accuracy': accuracy,
            }, 'best_spelling_corrector.pth')

    return model, tokenizer

# Example usage:
# model, tokenizer = train_and_evaluate(num_epochs=10, data_path='your_spelling_dictionary.json')

model, tokenizer = train_and_evaluate(num_epochs=10, data_path= data_path)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Pred: destination | Actual: temporarily
Pred: magnificent | Actual: rainy
Pred: magnificent | Actual: phase
Pred: inc | Actual: inc
Pred: inc | Actual: opportunities
Pred: guarantee | Actual: happy
Pred: magnificent | Actual: punishment
Pred: inc | Actual: attorneys
Pred: magnificent | Actual: anything
Pred: miscellaneous | Actual: thin
Pred: miscellaneous | Actual: extraordinary
Pred: magnificent | Actual: pilot
Pred: inc | Actual: sneaking
Pred: enthusiasm | Actual: phase
Pred: inc | Actual: as
Pred: magnificent | Actual: genius
Pred: guarantee | Actual: referring
Pred: con | Actual: cancelled
Pred: enthusiasm | Actual: una
Pred: magnificent | Actual: driven
Pred: magnificent | Actual: special
Pred: con | Actual: exhibition
Pred: enthusiasm | Actual: sometimes
Pred: enthusiasm | Actual: thousand
Pred: magnificent | Actual: organization
Pred: con | Actual: playground
Pred: enthusiasm | Actual: accompaniment
Pred: unanimo

In [4]:
import shutil

# Define the source and destination paths
source_path = 'spelling_corrector_model.pth'  # Path to the saved model
destination_path = '/content/drive/MyDrive/NLP/spelling_corrector_model.pth'  # Destination path in your Drive

# Copy the file
shutil.copy(source_path, destination_path)

print(f"Model copied to: {destination_path}")

FileNotFoundError: [Errno 2] No such file or directory: 'spelling_corrector_model.pth'