In [431]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import re
import numpy as np
from typing import List, Tuple, Dict
import random

In [432]:
class DataGenerator:
    """Create synthetic sentences that each contain one phone number or e-mail address.

    Args:
        seed: Optional seed. When given, every draw comes from a private
            ``random.Random(seed)`` so samples are reproducible without touching
            the global ``random`` state. When ``None`` (the default) the global
            ``random`` module is used.
    """

    # Sentence templates; ``{entity}`` is replaced by the generated phone number
    # or e-mail and always begins a whitespace-separated word.
    PHONE_TEMPLATES = (
        "call me at {entity}",
        "my number is {entity}",
        "contact me at my {entity} phone for info",
        "reach out to {entity}",
        "you can call {entity}",
        "{entity} is my contact",
        "here is my contact info: {entity}",
    )
    EMAIL_TEMPLATES = (
        "contact me by sending an email at {entity} for info",
        "email me at {entity} email address",
        "email {entity} address",
        "reach out to {entity}",
        "send email to {entity}",
        "{entity} is my contact",
        "here is my contact info: {entity}",
    )

    def __init__(self, seed=None):
        self.names = ['john', 'jane', 'bob', 'alice', 'charlie', 'diana', 'jamy12354']
        self.domains = ['gmail.com', 'yahoo.com', 'outlook.com', 'company.com']
        # The ``random`` module exposes randint/choice/random itself, so it can
        # stand in for a Random instance and keep the unseeded behaviour unchanged.
        self._rng = random if seed is None else random.Random(seed)

    def generate_phone(self):
        """Return a random 10-digit number as ``(AAA) EEE-NNNN``, ``AAA-EEE-NNNN``,
        ``AAA.EEE.NNNN`` or ``AAA EEE NNNN`` (AAA/EEE in 100-999, NNNN in 1000-9999)."""
        area = f"{self._rng.randint(100, 999)}"
        exchange = f"{self._rng.randint(100, 999)}"
        number = f"{self._rng.randint(1000, 9999)}"

        formats = [
            f"({area}) {exchange}-{number}",
            f"{area}-{exchange}-{number}",
            f"{area}.{exchange}.{number}",
            f"{area} {exchange} {number}",
        ]
        return self._rng.choice(formats)

    def generate_email(self):
        """Return ``<name><1-99>@<domain>`` drawn from ``self.names`` and ``self.domains``."""
        name = self._rng.choice(self.names)
        num = self._rng.randint(1, 99)
        domain = self._rng.choice(self.domains)
        return f"{name}{num}@{domain}"

    def generate_sample(self):
        """Return ``(text, entity, entity_type)``.

        ``entity_type`` is "PHONE" or "EMAIL" with equal probability; ``text`` is a
        random template of that type with the entity substituted in.
        """
        if self._rng.random() < 0.5:
            entity = self.generate_phone()
            entity_type = "PHONE"
            templates = self.PHONE_TEMPLATES
        else:
            entity = self.generate_email()
            entity_type = "EMAIL"
            templates = self.EMAIL_TEMPLATES

        text = self._rng.choice(templates).format(entity=entity)
        return text, entity, entity_type

In [433]:
class SimpleTokenizer:
    def __init__(self, max_length=64):
        self.word_to_idx = {'<PAD>': 0, '<UNK>': 1, '<NUMBER>': 2, '<LETTER>': 3, '<SPECIAL>': 4}
        self.idx_to_word = {0: '<PAD>', 1: '<UNK>', 2: '<NUMBER>', 3: '<LETTER>', 4: '<SPECIAL>'}
        self.vocab_size = 5
        self.max_length = max_length
        
        # Pre-add common characters and numbers to avoid UNK
        all_chars = list('abcdefghijklmnopqrstuvwxyz0123456789.-@():/+')
        
        self.special_chars = list('.-@():/+')
        self.letter_chars = list('abcdefghijklmnopqrstuvwxyz')
        self.number_chars = list('0123456789')
        
        for char in all_chars:            
            if char not in self.word_to_idx:
                self.word_to_idx[char] = self.vocab_size
                self.idx_to_word[self.vocab_size] = char
                self.vocab_size += 1
        
    def build_vocab(self, texts):
        words = set()
        for text in texts:
            tokens = self.tokenize(text)
            words.update(tokens)
        
        # Add all unique tokens to vocabulary
        for word in sorted(words):
            if word not in self.word_to_idx:
                self.word_to_idx[word] = self.vocab_size
                self.idx_to_word[self.vocab_size] = word
                self.vocab_size += 1
                
        print(f"Built vocabulary with {self.vocab_size} tokens")
        print(f"Sample vocab: {list(self.word_to_idx.keys())}")
    
    def tokenize(self, text):
        # Character-level tokenization for better handling of numbers and punctuation
        text = text.lower().strip()
        tokens = []
        
        # Split into words first, then handle each word
        words = text.split()

        for i, word in enumerate(words):
            if i > 0:  # Add space token between words
                tokens.append(' ')
            
            # For each word, split into meaningful chunks
            current_token = ""
            for char in word:
                if char in self.number_chars:
                    current_token += '<NUMBER>'
                if char in self.letter_chars:
                    current_token += '<LETTER>'
                if char in self.special_chars:
                    current_token += char
                else:
                    # Punctuation - add current token and the punctuation
                    if current_token:
                        tokens.append(current_token)
                        current_token = ""
                    tokens.append(char)
            
            # Add any remaining token
            if current_token:
                tokens.append(current_token)
        
        return tokens
    
    def encode(self, text):
        tokens = self.tokenize(text)
        indices = [self.word_to_idx.get(token, 1) for token in tokens]
        
        # Pad or truncate
        if len(indices) < self.max_length:
            indices.extend([0] * (self.max_length - len(indices)))
        else:
            indices = indices[:self.max_length]
            
        return indices, len([idx for idx in indices if idx != 0])  # Return actual length
    
    def decode(self, indices):
        return [self.idx_to_word.get(idx, '<UNK>') for idx in indices if idx != 0]


In [434]:
def create_labels_for_text(text, entity, entity_type, tokenizer):
    """Create BIO label ids marking ``entity`` inside ``text``.

    Label ids: 0 = O, 1 = B-PHONE, 2 = I-PHONE, 3 = B-EMAIL, 4 = I-EMAIL; any
    ``entity_type`` other than ``"PHONE"`` is labelled as EMAIL.

    Both strings are tokenized with ``tokenizer.tokenize``. The first start
    position from which all entity tokens can be matched in order is used; while
    matching, a text-side ``' '`` token that does not equal the expected entity
    token is skipped. If the entity is not found, every label is O. The result is
    padded with O and truncated to ``tokenizer.max_length``.
    """
    text_tokens = tokenizer.tokenize(text)
    target = tokenizer.tokenize(entity)

    def _match_at(start):
        """Text positions aligned with ``target`` when matching from ``start``, or None."""
        hits = []
        pos = start
        while len(hits) < len(target) and pos < len(text_tokens):
            current = text_tokens[pos]
            if current == target[len(hits)]:
                hits.append(pos)
            elif current != ' ':
                return None
            pos += 1
        return hits if len(hits) == len(target) else None

    positions = []
    for start in range(len(text_tokens) - len(target) + 1):
        hits = _match_at(start)
        if hits is not None:
            positions = hits
            break

    begin_id, inside_id = (1, 2) if entity_type == "PHONE" else (3, 4)
    labels = [0] * len(text_tokens)
    if positions:
        labels[positions[0]] = begin_id
        for pos in positions[1:]:
            labels[pos] = inside_id

    labels = labels + [0] * (tokenizer.max_length - len(labels))
    return labels[:tokenizer.max_length]


In [435]:
def extract_entities(model, tokenizer, text):
    """Extract phone numbers and emails from text by decoding the model's BIO tags.

    Args:
        model: token classifier called as ``model(token_ids, attention_mask)``;
            its last output dimension is arg-maxed into indices of
            ``['O', 'B-PHONE', 'I-PHONE', 'B-EMAIL', 'I-EMAIL']``.
        tokenizer: object with ``encode(text) -> (ids, n_real)`` and
            ``decode(ids) -> list[str]``.
        text: raw sentence.

    Returns:
        ``(entities, token_labels)``. ``entities`` is a list of
        ``(entity_type, entity_text)`` in order of appearance, where
        ``entity_text`` concatenates the entity's tokens without ``' '`` tokens
        (tokenizer markers such as ``<NUMBER>`` are kept). ``token_labels``
        pairs each real token with its predicted label.

    An ``I-X`` tag that does not continue an open ``X`` entity starts a new
    entity instead of being discarded. Decoded tokens and labels are printed.
    """
    model.eval()
    label_names = ['O', 'B-PHONE', 'I-PHONE', 'B-EMAIL', 'I-EMAIL']

    token_ids, actual_length = tokenizer.encode(text)

    # Build the inputs on the model's device so a GPU-resident model also works.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    tokens_tensor = torch.tensor(token_ids, device=device).unsqueeze(0)
    attention_mask = torch.tensor(
        [1] * actual_length + [0] * (len(token_ids) - actual_length), device=device
    ).unsqueeze(0)

    with torch.no_grad():
        outputs = model(tokens_tensor, attention_mask)
        # Take the single batch row explicitly; squeeze() would turn a
        # length-1 sequence into a bare int.
        predictions = torch.argmax(outputs, dim=-1)[0].tolist()

    predicted_labels = [label_names[pred] for pred in predictions[:actual_length]]
    actual_tokens = tokenizer.decode(token_ids)[:actual_length]

    print(f"Actual - Tokens: {actual_tokens}")
    print(f"Predicted - Labels: {predicted_labels}")

    entities = []
    current_type = None
    current_parts = []

    for token, label in zip(actual_tokens, predicted_labels):
        tag, _, label_type = label.partition('-')

        if tag == 'I' and label_type == current_type:
            # Continuation of the open entity; space tokens are not part of it.
            if token != ' ':
                current_parts.append(token)
            continue

        # Anything else closes the open entity (if it collected any tokens).
        if current_parts:
            entities.append((current_type, ''.join(current_parts).replace(' ', '')))

        if tag in ('B', 'I'):
            # B-X starts an entity; an orphan I-X is repaired into a start too.
            current_type = label_type
            current_parts = [token] if token != ' ' else []
        else:
            current_type = None
            current_parts = []

    if current_parts:
        entities.append((current_type, ''.join(current_parts).replace(' ', '')))

    return entities, list(zip(actual_tokens, predicted_labels))

In [436]:
class SimpleTransformerExtractor(nn.Module):
    """Per-token classifier: embedding + sinusoidal positions + Transformer encoder + linear head.

    Args:
        vocab_size: number of token ids; id 0 is the padding id.
        d_model: embedding and encoder width.
        nhead: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        num_classes: labels predicted per token.
        max_seq_len: longest accepted sequence (size of the positional table).
        dim_feedforward: hidden width of each encoder layer's feed-forward block.
        dropout: dropout inside the encoder layers and before the classifier.
    """

    def __init__(self, vocab_size, d_model=128, nhead=8, num_layers=2, num_classes=5,
                 max_seq_len=64, dim_feedforward=256, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Final LayerNorm + dropout before the per-token classifier.
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(d_model, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform Linear weights with zero bias; N(0, 0.1) embeddings with the padding row zeroed."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, 0, 0.1)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()

    def forward(self, x, attention_mask=None):
        """Return per-token class logits.

        Args:
            x: (batch, seq_len) token ids with seq_len <= max_seq_len.
            attention_mask: optional (batch, seq_len) mask, truthy for real tokens.
                When omitted, positions whose id is 0 are treated as padding.

        Returns:
            (batch, seq_len, num_classes) logits.

        Raises:
            ValueError: if ``x`` is not 2-D or is longer than ``max_seq_len``.
        """
        if x.dim() != 2:
            raise ValueError(f"expected (batch, seq_len) token ids, got shape {tuple(x.shape)}")
        if x.size(1) > self.max_seq_len:
            raise ValueError(
                f"sequence length {x.size(1)} exceeds max_seq_len={self.max_seq_len}"
            )

        # src_key_padding_mask uses True for positions to ignore (padding).
        padding_mask = (x == 0) if attention_mask is None else ~attention_mask.bool()

        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        x = self.embedding(x) * (self.d_model ** 0.5)
        x = self.pos_encoding(x)
        x = self.transformer(x, src_key_padding_mask=padding_mask)

        x = self.layer_norm(x)
        x = self.dropout(x)
        return self.classifier(x)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even columns hold ``sin(pos / 10000^(2i/d_model))`` and odd columns the
    matching cosine. Odd ``d_model`` is supported: the last (even) column then
    has no cosine partner.

    Args:
        d_model: embedding width.
        max_len: number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # 1 / 10000^(2i / d_model) for every even column index 2i.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-np.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd columns; for odd d_model div_term has
        # one extra entry, which previously made this assignment fail.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Stored as (1, max_len, d_model) so it broadcasts over the batch; a
        # buffer follows .to(device) and state_dict but is not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add encodings for the first ``x.size(1)`` positions; ``x`` is (batch, seq, d_model)."""
        return x + self.pe[:, :x.size(1)]


class ExtractionDataset(Dataset):
    """Map-style dataset of padded token ids, BIO label ids and a padding mask.

    ``samples`` holds ``(text, entity, entity_type)`` triples; encoding and label
    creation happen lazily in ``__getitem__``.
    """

    def __init__(self, samples, tokenizer):
        self.samples = samples
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        text, entity, entity_type = self.samples[idx]

        token_ids, n_real = self.tokenizer.encode(text)
        label_ids = create_labels_for_text(text, entity, entity_type, self.tokenizer)
        # Boolean mask: True for the first n_real (real) positions, False for padding.
        mask = torch.arange(len(token_ids)) < n_real

        return {
            'tokens': torch.tensor(token_ids, dtype=torch.long),
            'labels': torch.tensor(label_ids, dtype=torch.long),
            'attention_mask': mask,
        }

def collate_fn(batch):
    """Stack a list of per-sample dicts into batch tensors.

    Returns:
        ``(tokens, labels, attention_mask)``, each stacked along a new leading
        batch dimension.
    """
    fields = ('tokens', 'labels', 'attention_mask')
    stacked = {name: torch.stack([sample[name] for sample in batch]) for name in fields}
    return stacked['tokens'], stacked['labels'], stacked['attention_mask']

def _run_epoch(model, loader, criterion, num_classes, optimizer=None, max_grad_norm=1.0):
    """Run one pass over ``loader``: train when ``optimizer`` is given, otherwise evaluate.

    Loss and accuracy use only positions whose attention mask is set. In
    training, a batch with NaN loss is reported and skipped.

    Returns:
        ``(avg_loss, accuracy)``: loss averaged over contributing batches (NaN if
        there were none) and token accuracy over non-padded positions (0 if none).
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    loss_batches = 0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for batch_idx, (tokens, labels, attention_mask) in enumerate(loader):
            if training:
                optimizer.zero_grad()

            outputs = model(tokens, attention_mask)

            # Only score non-padded tokens.
            active = attention_mask.view(-1).bool()
            active_logits = outputs.view(-1, num_classes)[active]
            active_labels = labels.view(-1)[active]

            loss = criterion(active_logits, active_labels)

            if training:
                if torch.isnan(loss):
                    print(f"NaN loss detected at batch {batch_idx}")
                    continue
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
                optimizer.step()

            total_loss += loss.item()
            loss_batches += 1

            predictions = torch.argmax(active_logits, dim=-1)
            correct += (predictions == active_labels).sum().item()
            total += active_labels.size(0)

    avg_loss = total_loss / loss_batches if loss_batches else float('nan')
    accuracy = correct / total if total else 0
    return avg_loss, accuracy


def train_model(num_samples=2000, epochs=15, max_length=64, seed=None, restore_best=True):
    """Generate synthetic data, build a tokenizer and train a ``SimpleTransformerExtractor``.

    Args:
        num_samples: number of synthetic sentences; the first 80% train, the rest validate.
        epochs: number of training epochs.
        max_length: tokenizer padding/truncation length and the model's
            ``max_seq_len``. ``SimpleTokenizer`` emits about two tokens per
            letter/digit, so long sentences are truncated at this length.
        seed: if given, seeds ``random`` and ``torch`` so data generation,
            weight initialisation and shuffling are reproducible.
        restore_best: if True, reload the weights from the epoch with the lowest
            validation loss before returning.

    Returns:
        ``(model, tokenizer)``.
    """
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)

    print("Generating training data...")
    generator = DataGenerator()
    samples = [generator.generate_sample() for _ in range(num_samples)]

    tokenizer = SimpleTokenizer(max_length=max_length)
    tokenizer.build_vocab([text for text, _, _ in samples])

    split_idx = int(0.8 * len(samples))
    train_samples = samples[:split_idx]
    val_samples = samples[split_idx:]
    print(f"Training samples: {len(train_samples)}, Validation samples: {len(val_samples)}")

    train_loader = DataLoader(ExtractionDataset(train_samples, tokenizer),
                              batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(ExtractionDataset(val_samples, tokenizer),
                            batch_size=32, shuffle=False, collate_fn=collate_fn)

    num_classes = 5  # O, B-PHONE, I-PHONE, B-EMAIL, I-EMAIL
    model = SimpleTransformerExtractor(
        vocab_size=tokenizer.vocab_size,
        d_model=128,
        nhead=8,
        num_layers=2,
        num_classes=num_classes,
        max_seq_len=max_length,
    )

    # Down-weight the dominant O class so entity tags are not ignored.
    class_weights = torch.tensor([0.1, 2.0, 2.0, 2.0, 2.0])
    criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-100)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)

    print("Starting training...")
    best_val_loss = float('inf')
    best_state = None

    for epoch in range(epochs):
        train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, num_classes, optimizer)
        scheduler.step()
        val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, num_classes)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Snapshot (not a live reference) of the current best weights.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print(f"  New best validation loss!")

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored weights with best validation loss {best_val_loss:.4f}")

    return model, tokenizer




In [437]:
# Train the extractor. Data generation is random, so results vary between runs
# unless the global random / torch seeds are fixed beforehand.
N_SAMPLES = 3000
N_EPOCHS = 5

print("Training transformer model for phone/email extraction...")
model, tokenizer = train_model(num_samples=N_SAMPLES, epochs=N_EPOCHS)

Training transformer model for phone/email extraction...
Generating training data...
Built vocabulary with 55 tokens
Sample vocab: ['<PAD>', '<UNK>', '<NUMBER>', '<LETTER>', '<SPECIAL>', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '@', '(', ')', ':', '/', '+', ' ', '(<NUMBER>', '-<NUMBER>', '.<LETTER>', '.<NUMBER>', '@<LETTER>']
Training samples: 2400, Validation samples: 600
Starting training...
Epoch 1/5
  Train Loss: 0.2534, Train Acc: 0.7292
  Val Loss: 0.0692, Val Acc: 0.9036
  New best validation loss!
Epoch 2/5
  Train Loss: 0.0503, Train Acc: 0.9160
  Val Loss: 0.0172, Val Acc: 0.9422
  New best validation loss!
Epoch 3/5
  Train Loss: 0.0153, Train Acc: 0.9738
  Val Loss: 0.0023, Val Acc: 0.9966
  New best validation loss!
Epoch 4/5
  Train Loss: 0.0073, Train Acc: 0.9893
  Val Loss: 0.0033, Val Acc: 0.9983
Epoch 5/5
  Train Loss: 0.

In [438]:
# Smoke-test the trained model on hand-written sentences, including formats the
# generator never produces (7-digit numbers, unseen domains such as text.ai).
test_texts = [
    "call me at (555) 123-4567",
    "email me at john@gmail.com",
    "here is my cell (343) 232 554",
    "email me at j123@gmail.com",
    "cell number 222 343-234",
    "contact info: (121) 334-5456",
    "contact info: mike@text.ai"
]

# Type markers the tokenizer inserts; stripped so only original characters remain.
MARKER_TOKENS = ('<LETTER>', '<NUMBER>')

print("\n" + "="*50)
print("Testing extraction:")
print("="*50)

for text in test_texts:
    print(f"\nOriginal text: {text}")

    # Show tokenization
    tokens = tokenizer.tokenize(text)
    print(f"Tokenized: {tokens}")

    # Check vocabulary coverage (id 1 is <UNK>)
    unknown_tokens = [t for t in tokens if tokenizer.word_to_idx.get(t) == 1]
    if unknown_tokens:
        print(f"Unknown tokens: {unknown_tokens}")

    entities, token_labels = extract_entities(model, tokenizer, text)

    # Collect every entity per type; a plain dict assignment would keep only
    # the last entity when a sentence contains several of the same type.
    extracted = {}
    for name, value in entities:
        for marker in MARKER_TOKENS:
            value = value.replace(marker, '')
        extracted.setdefault(name, []).append(value)
    print(f"Extracted entities: {extracted}")
    print("-" * 30)


Testing extraction:

Original text: call me at (555) 123-4567
Tokenized: ['<LETTER>', 'c', '<LETTER>', 'a', '<LETTER>', 'l', '<LETTER>', 'l', ' ', '<LETTER>', 'm', '<LETTER>', 'e', ' ', '<LETTER>', 'a', '<LETTER>', 't', ' ', '(<NUMBER>', '5', '<NUMBER>', '5', '<NUMBER>', '5', ')', ' ', '<NUMBER>', '1', '<NUMBER>', '2', '<NUMBER>', '3', '-<NUMBER>', '4', '<NUMBER>', '5', '<NUMBER>', '6', '<NUMBER>', '7']
Actual - Tokens: ['<LETTER>', 'c', '<LETTER>', 'a', '<LETTER>', 'l', '<LETTER>', 'l', ' ', '<LETTER>', 'm', '<LETTER>', 'e', ' ', '<LETTER>', 'a', '<LETTER>', 't', ' ', '(<NUMBER>', '5', '<NUMBER>', '5', '<NUMBER>', '5', ')', ' ', '<NUMBER>', '1', '<NUMBER>', '2', '<NUMBER>', '3', '-<NUMBER>', '4', '<NUMBER>', '5', '<NUMBER>', '6', '<NUMBER>', '7']
Predicted - Labels: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-PHONE', 'I-PHONE', 'I-PHONE', 'I-PHONE', 'I-PHONE', 'I-PHONE', 'I-PHONE', 'I-PHONE', 'I-PHONE', 'I-PHONE', 'I-PHONE', 'I-P