## Tokenizer

### Byte Pair Encoding

In [1]:
import re
import collections
import pickle
from typing import Dict, List, Tuple, Set, Optional
import numpy as np

class BPETokenizer:
    """
    A Byte Pair Encoding tokenizer implementation from scratch.
    Suitable for code-mixed text like Tamil-English.
    Enhanced with padding and model preparation capabilities.
    """
    def __init__(self, vocab_size: int = 10000, max_length: int = 128):
        """
        Initialize the BPE tokenizer.

        Args:
            vocab_size: Target vocabulary size (number of merge operations + initial characters)
            max_length: Default maximum sequence length for padding
        """
        self.vocab_size = vocab_size
        self.max_length = max_length
        self.word_freqs = collections.defaultdict(int)
        self.vocab = {}
        self.merges = []
        self.special_tokens = {
            "<PAD>": 0,
            "<UNK>": 1,
            "<BOS>": 2,
            "<EOS>": 3,
            "<SEP>": 4
        }
        # Initialize vocab with special tokens
        self.vocab = {token: idx for token, idx in self.special_tokens.items()}

    def _preprocess_text(self, text: str) -> List[str]:
        """
        Preprocesses the text:
        1. Lowercase the text
        2. Tokenize into words
        3. Add spaces before and after each word

        Args:
            text: Input text

        Returns:
            List of preprocessed words
        """
        # You may need to adjust this for Tamil-specific preprocessing
        text = text.lower()

        # Split by whitespace and punctuation
        # This regex preserves punctuation as separate tokens
        words = re.findall(r'\b\w+\b|[^\w\s]', text)

        # Add word boundary markers
        words = [f"▁{word}" for word in words]

        return words

    def _get_character_level_tokens(self, word: str) -> List[str]:
        """
        Split a word into character-level tokens.

        Args:
            word: Input word

        Returns:
            List of characters
        """
        return list(word)

    def _get_stats(self, words: List[List[str]]) -> Dict[Tuple[str, str], int]:
        """
        Count frequency of token pairs across all words.

        Args:
            words: List of words, where each word is a list of tokens

        Returns:
            Dictionary mapping token pairs to their frequencies
        """
        pairs = collections.defaultdict(int)
        for word in words:
            for i in range(len(word) - 1):
                pair = (word[i], word[i + 1])
                pairs[pair] += 1
        return pairs

    def _merge_pair(self, words: List[List[str]], pair: Tuple[str, str]) -> List[List[str]]:
        """
        Merge all occurrences of a token pair in the vocabulary.

        Args:
            words: List of words, where each word is a list of tokens
            pair: The pair of tokens to merge

        Returns:
            Updated list of words with the specified pair merged
        """
        first, second = pair
        new_words = []

        for word in words:
            i = 0
            new_word = []
            while i < len(word):
                if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_words.append(new_word)

        return new_words

    def train(self, corpus: List[str]):
        """
        Train the BPE tokenizer on a corpus.

        Args:
            corpus: List of text examples
        """
        # Preprocess the corpus
        all_words = []
        for text in corpus:
            all_words.extend(self._preprocess_text(text))

        # Count word frequencies
        for word in all_words:
            self.word_freqs[word] += 1

        # Split words into characters
        words = [self._get_character_level_tokens(word) for word in self.word_freqs.keys()]

        # Add all characters to vocabulary
        unique_chars = set()
        for word in words:
            unique_chars.update(word)

        # Add characters to vocabulary
        for char in sorted(unique_chars):
            if char not in self.vocab:
                self.vocab[char] = len(self.vocab)

        # Start with the base vocabulary size (special tokens + unique characters)
        base_vocab_size = len(self.vocab)
        num_merges = self.vocab_size - base_vocab_size
        # print(words)
        # Perform merge operations
        for i in range(num_merges):
            # Get pair statistics
            pairs = self._get_stats(words)
            if not pairs:
                break

            # Find the most frequent pair
            best_pair = max(pairs, key=pairs.get)

            # Merge the pair in all words
            words = self._merge_pair(words, best_pair)

            # Add the merged token to the vocabulary
            merged_token = best_pair[0] + best_pair[1]
            if merged_token not in self.vocab:
                self.vocab[merged_token] = len(self.vocab)

            # Add the merge operation to the list
            self.merges.append(best_pair)

            # Print progress
            if (i + 1) % 100 == 0:
                print(f"Merge operation {i+1}/{num_merges}, vocab size: {len(self.vocab)}")

        # print("Merges")
        # print(self.merges)
        # print()
        # Finalize vocabulary: create reverse mapping
        self.id_to_token = {idx: token for token, idx in self.vocab.items()}

        print(f"BPE training complete. Final vocabulary size: {len(self.vocab)}")

    def tokenize(self, text: str) -> List[int]:
        """
        Tokenize a text string into token IDs.

        Args:
            text: Input text

        Returns:
            List of token IDs
        """
        # Preprocess the text
        words = self._preprocess_text(text)
        token_ids = []

        for word in words:
            # Start with characters
            current_tokens = self._get_character_level_tokens(word)

            # Apply merges in the same order as during training
            for pair in self.merges:
                current_tokens = self._merge_token_list(current_tokens, pair)
            # print(f"current tokens: {word}:: {current_tokens}")
            # Convert tokens to IDs
            for token in current_tokens:
                if token in self.vocab:
                    token_ids.append(self.vocab[token])
                else:
                    # Handle unknown tokens
                    token_ids.append(self.special_tokens["<UNK>"])

        return token_ids

    def _merge_token_list(self, tokens: List[str], pair: Tuple[str, str]) -> List[str]:
        """
        Apply a single merge operation to a list of tokens.

        Args:
            tokens: List of tokens
            pair: The pair of tokens to merge

        Returns:
            Updated list of tokens
        """
        first, second = pair
        i = 0
        result = []

        while i < len(tokens):
            if i < len(tokens) - 1 and tokens[i] == first and tokens[i + 1] == second:
                result.append(first + second)
                i += 2
            else:
                result.append(tokens[i])
                i += 1

        return result

    def decode(self, token_ids: List[int]) -> str:
        """
        Convert token IDs back to text.

        Args:
            token_ids: List of token IDs

        Returns:
            Decoded text
        """
        # Filter out padding tokens
        token_ids = [idx for idx in token_ids if idx != self.special_tokens["<PAD>"]]

        # Remove special tokens from beginning and end if present
        if token_ids and token_ids[0] == self.special_tokens["<BOS>"]:
            token_ids = token_ids[1:]
        if token_ids and token_ids[-1] == self.special_tokens["<EOS>"]:
            token_ids = token_ids[:-1]

        tokens = [self.id_to_token.get(idx, "<UNK>") for idx in token_ids]

        # Join tokens and remove whitespace marker
        text = ''.join(tokens)
        text = text.replace('▁', ' ').strip()

        return text

    def pad_sequences(self, token_ids_list: List[List[int]], max_length: int = None,
                      padding: str = 'post', truncating: str = 'post') -> List[List[int]]:
        """
        Pad sequences to the same length.

        Args:
            token_ids_list: List of token ID sequences
            max_length: Maximum length to pad to (default: self.max_length or longest sequence)
            padding: 'pre' or 'post' (where to add padding)
            truncating: 'pre' or 'post' (where to truncate if needed)

        Returns:
            List of padded sequences
        """
        # Find max length if not specified
        if max_length is None:
            max_length = self.max_length or max(len(seq) for seq in token_ids_list)

        padded_sequences = []
        for seq in token_ids_list:
            # Truncate if necessary
            if len(seq) > max_length:
                if truncating == 'pre':
                    seq = seq[-max_length:]
                else:  # truncating == 'post'
                    seq = seq[:max_length]

            # Calculate padding
            pad_length = max_length - len(seq)

            # Add padding
            if padding == 'pre':
                padded_seq = [self.special_tokens['<PAD>']] * pad_length + seq
            else:  # padding == 'post'
                padded_seq = seq + [self.special_tokens['<PAD>']] * pad_length

            padded_sequences.append(padded_seq)

        return padded_sequences

    def create_attention_mask(self, padded_sequences: List[List[int]]) -> List[List[int]]:
        """
        Create attention masks for padded sequences (1 for real tokens, 0 for padding).

        Args:
            padded_sequences: List of padded token ID sequences

        Returns:
            List of attention masks
        """
        masks = []
        for seq in padded_sequences:
            mask = [1 if token_id != self.special_tokens['<PAD>'] else 0 for token_id in seq]
            masks.append(mask)
        return masks

    def encode_for_model(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """
        Tokenize text and add special tokens for model input.

        Args:
            text: Input text
            add_special_tokens: Whether to add <BOS> and <EOS> tokens

        Returns:
            List of token IDs ready for model input
        """
        token_ids = self.tokenize(text)

        if add_special_tokens:
            token_ids = [self.special_tokens['<BOS>']] + token_ids + [self.special_tokens['<EOS>']]

        return token_ids

    def prepare_model_inputs(self, texts: List[str], add_special_tokens: bool = True,
                            max_length: int = None, return_attention_mask: bool = True):
        """
        Prepare inputs ready for model training or inference.

        Args:
            texts: List of input texts
            add_special_tokens: Whether to add <BOS> and <EOS> tokens
            max_length: Maximum sequence length (will pad/truncate to this length)
            return_attention_mask: Whether to return attention masks

        Returns:
            Dictionary of model inputs
        """
        # Tokenize all texts
        all_token_ids = [self.encode_for_model(text, add_special_tokens) for text in texts]

        # Use default max_length if not specified
        if max_length is None:
            max_length = self.max_length

        # Pad sequences
        padded_sequences = self.pad_sequences(all_token_ids, max_length=max_length)

        # Prepare outputs
        model_inputs = {
            'input_ids': padded_sequences
        }

        if return_attention_mask:
            attention_masks = self.create_attention_mask(padded_sequences)
            model_inputs['attention_mask'] = attention_masks

        return model_inputs

    def batch_encode(self, texts: List[str], batch_size: int = 32, add_special_tokens: bool = True,
                    max_length: int = None, return_attention_mask: bool = True):
        """
        Encode texts and group them into batches ready for model input.

        Args:
            texts: List of input texts
            batch_size: Size of each batch
            add_special_tokens: Whether to add <BOS> and <EOS> tokens
            max_length: Maximum sequence length
            return_attention_mask: Whether to return attention masks

        Returns:
            List of batched inputs (each a dictionary with 'input_ids' and optionally 'attention_mask')
        """
        # Prepare all inputs
        all_inputs = self.prepare_model_inputs(
            texts,
            add_special_tokens=add_special_tokens,
            max_length=max_length,
            return_attention_mask=return_attention_mask
        )

        # Create batches
        batches = []
        num_samples = len(texts)

        for i in range(0, num_samples, batch_size):
            batch_indices = slice(i, min(i + batch_size, num_samples))
            batch = {
                'input_ids': all_inputs['input_ids'][batch_indices]
            }

            if return_attention_mask:
                batch['attention_mask'] = all_inputs['attention_mask'][batch_indices]

            batches.append(batch)

        return batches

    def save(self, path: str):
        """
        Save the tokenizer to a file.

        Args:
            path: Path to save the tokenizer
        """
        with open(path, 'wb') as f:
            pickle.dump({
                'vocab': self.vocab,
                'merges': self.merges,
                'word_freqs': self.word_freqs,
                'special_tokens': self.special_tokens,
                'vocab_size': self.vocab_size,
                'max_length': self.max_length,
                'id_to_token': self.id_to_token if hasattr(self, 'id_to_token') else None
            }, f)

    @classmethod
    def load(cls, path: str):
        """
        Load a tokenizer from a file.

        Args:
            path: Path to the saved tokenizer

        Returns:
            Loaded BPETokenizer instance
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)

        tokenizer = cls(vocab_size=data['vocab_size'], max_length=data.get('max_length', 128))
        tokenizer.vocab = data['vocab']
        tokenizer.merges = data['merges']
        tokenizer.word_freqs = data['word_freqs']
        tokenizer.special_tokens = data['special_tokens']
        if data.get('id_to_token'):
            tokenizer.id_to_token = data['id_to_token']
        else:
            tokenizer.id_to_token = {idx: token for token, idx in tokenizer.vocab.items()}

        return tokenizer



### Wordpiece Tokenizer

In [2]:
import re
import collections
import pickle
from typing import Dict, List, Tuple, Set, Optional
import numpy as np

class WordPieceTokenizer:
    def __init__(self, vocab_size: int = 10000, max_length: int = 128):
        self.vocab_size = vocab_size
        self.max_length = max_length
        self.word_freqs = collections.defaultdict(int)
        self.vocab = {}
        self.merges = []
        self.special_tokens = {
            "<PAD>": 0,
            "<UNK>": 1,
            "<BOS>": 2,
            "<EOS>": 3,
            "<SEP>": 4,
            "<MASK>": 5
        }

        # Initialize vocab with special tokens
        self.vocab = {token: idx for token, idx in self.special_tokens.items()}
        for i in range(33, 65):
            char = chr(i)
            if char not in self.vocab:
                self.vocab[char] = len(self.vocab)

        for i in range(91, 97):
            char = chr(i)
            if char not in self.vocab:
                self.vocab[char] = len(self.vocab)

    def _preprocess_text(self, text: str) -> List[str]:
        text = text.lower()
        words = re.findall(r'\b\w+\b|[^\w\s]', text)
        words = [f"▁{word}" for word in words]
        return words

    def train(self, corpus: List[str]):
        """
        Train the WordPiece tokenizer on a corpus with strict vocabulary size control.
        """
        # Preprocess the corpus
        all_words = []
        for text in corpus:
            all_words.extend(self._preprocess_text(text))

        # Count word frequencies
        for word in all_words:
            self.word_freqs[word] += 1

        # Generate initial token candidates
        token_candidates = {}

        # Start with single characters and most frequent whole words
        for word in self.word_freqs.keys():
            # Add characters
            for char in set(word):
                if char not in token_candidates:
                    token_candidates[char] = self.word_freqs.get(word, 0)

            # Generate subwords up to 5 characters long
            for i in range(len(word)):
                for j in range(i+1, min(i+6, len(word)+1)):
                    subword = word[i:j]
                    if len(subword) > 1:
                        token_candidates[subword] = token_candidates.get(subword, 0) + self.word_freqs.get(word, 0)

        # Sort token candidates by frequency
        sorted_candidates = sorted(token_candidates.items(), key=lambda x: x[1], reverse=True)

        # Limit vocabulary size, accounting for special tokens
        max_vocab_tokens = self.vocab_size - len(self.special_tokens)
        # print(sorted_candidates)
        # Add top candidates to vocabulary
        for token, freq in sorted_candidates[:max_vocab_tokens]:
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab)
            if len(self.vocab) == self.vocab_size:
                break

        # Finalize vocabulary: create reverse mapping
        self.id_to_token = {idx: token for token, idx in self.vocab.items()}

        print(f"WordPiece training complete. Final vocabulary size: {len(self.vocab)}")

    def tokenize(self, text: str) -> List[int]:
        """
        Tokenize a text string into token IDs using WordPiece algorithm.
        """
        # Preprocess the text
        words = self._preprocess_text(text)
        token_ids = []

        for word in words:
            # WordPiece tokenization
            final_tokens = []

            # Start with the full word
            while word:
                # Find the longest possible subword
                found_match = False
                for length in range(len(word), 0, -1):
                    subword = word[:length]
                    if subword in self.vocab:
                        final_tokens.append(subword)
                        word = word[length:]
                        found_match = True
                        break

                # If no match found, use the first character as UNK
                if not found_match:
                    final_tokens.append('<UNK>')
                    word = word[1:]

            # Convert tokens to IDs
            for token in final_tokens:
                if token in self.vocab:
                    token_ids.append(self.vocab[token])
                else:
                    # Handle unknown tokens
                    token_ids.append(self.special_tokens["<UNK>"])

        return token_ids


    def decode(self, token_ids: List[int]) -> str:
        """
        Convert token IDs back to text.
        Args:
            token_ids: List of token IDs
        Returns:
            Decoded text
        """
        # Filter out padding tokens
        token_ids = [idx for idx in token_ids if idx != self.special_tokens["<PAD>"]]

        # Remove special tokens from beginning and end if present
        if token_ids and token_ids[0] == self.special_tokens["<BOS>"]:
            token_ids = token_ids[1:]
        if token_ids and token_ids[-1] == self.special_tokens["<EOS>"]:
            token_ids = token_ids[:-1]

        tokens = [self.id_to_token.get(idx, "<UNK>") for idx in token_ids]

        # Join tokens and remove whitespace marker
        text = ''.join(tokens)
        text = text.replace('▁', '').strip()
        return text

    def pad_sequences(self, token_ids_list: List[List[int]], max_length: int = None,
                      padding: str = 'post', truncating: str = 'post') -> List[List[int]]:
        """
        Pad sequences to the same length.

        Args:
            token_ids_list: List of token ID sequences
            max_length: Maximum length to pad to (default: self.max_length or longest sequence)
            padding: 'pre' or 'post' (where to add padding)
            truncating: 'pre' or 'post' (where to truncate if needed)

        Returns:
            List of padded sequences
        """
        # Find max length if not specified
        if max_length is None:
            max_length = self.max_length or max(len(seq) for seq in token_ids_list)

        padded_sequences = []
        for seq in token_ids_list:
            # Truncate if necessary
            if len(seq) > max_length:
                if truncating == 'pre':
                    seq = seq[-max_length:]
                else:  # truncating == 'post'
                    seq = seq[:max_length]

            # Calculate padding
            pad_length = max_length - len(seq)

            # Add padding
            if padding == 'pre':
                padded_seq = [self.special_tokens['<PAD>']] * pad_length + seq
            else:  # padding == 'post'
                padded_seq = seq + [self.special_tokens['<PAD>']] * pad_length

            padded_sequences.append(padded_seq)

        return padded_sequences

    def create_attention_mask(self, padded_sequences: List[List[int]]) -> List[List[int]]:
        """
        Create attention masks for padded sequences (1 for real tokens, 0 for padding).

        Args:
            padded_sequences: List of padded token ID sequences

        Returns:
            List of attention masks
        """
        masks = []
        for seq in padded_sequences:
            mask = [1 if token_id != self.special_tokens['<PAD>'] else 0 for token_id in seq]
            masks.append(mask)
        return masks

    def encode_for_model(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """
        Tokenize text and add special tokens for model input.

        Args:
            text: Input text
            add_special_tokens: Whether to add <BOS> and <EOS> tokens

        Returns:
            List of token IDs ready for model input
        """
        token_ids = self.tokenize(text)

        if add_special_tokens:
            token_ids = [self.special_tokens['<BOS>']] + token_ids + [self.special_tokens['<EOS>']]

        return token_ids

    def prepare_model_inputs(self, texts: List[str], add_special_tokens: bool = True,
                            max_length: int = None, return_attention_mask: bool = True):
        """
        Prepare inputs ready for model training or inference.

        Args:
            texts: List of input texts
            add_special_tokens: Whether to add <BOS> and <EOS> tokens
            max_length: Maximum sequence length (will pad/truncate to this length)
            return_attention_mask: Whether to return attention masks

        Returns:
            Dictionary of model inputs
        """
        # Tokenize all texts
        all_token_ids = [self.encode_for_model(text, add_special_tokens) for text in texts]

        # Use default max_length if not specified
        if max_length is None:
            max_length = self.max_length

        # Pad sequences
        padded_sequences = self.pad_sequences(all_token_ids, max_length=max_length)

        # Prepare outputs
        model_inputs = {
            'input_ids': padded_sequences
        }

        if return_attention_mask:
            attention_masks = self.create_attention_mask(padded_sequences)
            model_inputs['attention_mask'] = attention_masks

        return model_inputs

    def save(self, path: str):
        """
        Save the tokenizer to a file.
        Args:
            path: Path to save the tokenizer
        """
        with open(path, 'wb') as f:
            pickle.dump({
                'vocab': self.vocab,
                'word_freqs': self.word_freqs,
                'special_tokens': self.special_tokens,
                'vocab_size': self.vocab_size,
                'max_length': self.max_length,
                'id_to_token': self.id_to_token
            }, f)

    @classmethod
    def load(cls, path: str):
        """
        Load a tokenizer from a file.
        Args:
            path: Path to the saved tokenizer
        Returns:
            Loaded WordPieceTokenizer instance
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)

        tokenizer = cls(vocab_size=data['vocab_size'], max_length=data.get('max_length', 128))
        tokenizer.vocab = data['vocab']
        tokenizer.word_freqs = data['word_freqs']
        tokenizer.special_tokens = data['special_tokens']
        tokenizer.id_to_token = data['id_to_token']

        return tokenizer

### Sentence piece tokenizer

In [3]:
import re
import collections
import pickle
from typing import Dict, List, Tuple, Set, Optional
import math

class SentencePieceTokenizer:
    """
    A SentencePiece tokenizer implementation from scratch.
    Supports unigram language model tokenization with subword units.
    """
    def __init__(self, vocab_size: int = 10000, max_length: int = 128,
                 max_piece_length: int = 16, alpha: float = 0.1):
        """
        Initialize the SentencePiece tokenizer.

        Args:
            vocab_size: Target vocabulary size
            max_length: Maximum sequence length for padding
            max_piece_length: Maximum length of a subword piece
            alpha: Smoothing parameter for unigram language model
        """
        self.vocab_size = vocab_size
        self.max_length = max_length
        self.max_piece_length = max_piece_length
        self.alpha = alpha

        # Special tokens
        self.special_tokens = {
            "<PAD>": 0,
            "<UNK>": 1,
            "<BOS>": 2,
            "<EOS>": 3,
            "<SEP>": 4
        }

        # Additional special token mappings for printable ASCII
        for i in range(33, 65):
            self.special_tokens[chr(i)] = i - 27
        for i in range(91, 127):
            self.special_tokens[chr(i)] = i - 53

        self.vocab = {token: idx for token, idx in self.special_tokens.items()}
        self.token_frequencies = {}
        self.vocab_scores = {}

    def _preprocess_text(self, text: str) -> str:
        """
        Preprocess input text.

        Args:
            text: Input text
        Returns:
            Preprocessed text
        """
        # Lowercase and normalize
        text = text.lower()
        # Add word boundary markers
        text = '▁' + text.replace(' ', ' ▁')
        return text

    def _enumerate_patterns(self, text: str) -> List[str]:
        """
        Generate all possible subword pieces of a text.

        Args:
            text: Input text
        Returns:
            List of all possible subword pieces
        """
        pieces = []
        for length in range(1, min(len(text) + 1, self.max_piece_length + 1)):
            for start in range(len(text) - length + 1):
                pieces.append(text[start:start+length])
        return pieces

    def _compute_unigram_loss(self, vocab: Dict[str, int]) -> float:
        """
        Compute the unigram language model loss.

        Args:
            vocab: Current vocabulary
        Returns:
            Total loss of the unigram model
        """
        total_loss = 0
        for text, freq in self.token_frequencies.items():
            piece_loss = float('inf')
            for piece_length in range(1, len(text) + 1):
                current_loss = 0
                for start in range(0, len(text), piece_length):
                    end = start + piece_length
                    if end > len(text):
                        break
                    piece = text[start:end]
                    if piece in vocab:
                        current_loss -= math.log(self.vocab_scores.get(piece, 1))
                    else:
                        # Unknown piece penalty
                        current_loss += 10
                piece_loss = min(piece_loss, current_loss)
            total_loss += freq * piece_loss
        return total_loss

    def train(self, corpus: List[str]):
        """
        Train the SentencePiece tokenizer.

        Args:
            corpus: List of text examples
        """
        # Preprocess and count frequencies
        preprocessed_corpus = [self._preprocess_text(text) for text in corpus]

        # Generate initial set of candidates
        candidates = set()
        for text in preprocessed_corpus:
            candidates.update(self._enumerate_patterns(text))

        # Initial frequency estimation
        token_freqs = collections.defaultdict(int)
        for text in preprocessed_corpus:
            for candidate in candidates:
                token_freqs[candidate] += text.count(candidate)

        # Sort candidates by frequency
        sorted_candidates = sorted(token_freqs.items(), key=lambda x: x[1], reverse=True)

        # Initialize vocabulary with most frequent tokens
        vocab = {token: idx for idx, (token, _) in enumerate(sorted_candidates[:self.vocab_size],
                                                             start=len(self.special_tokens))}

        # Iterative refinement with unigram loss minimization
        iterations = 10
        for _ in range(iterations):
            # Update token frequencies and scores
            self.token_frequencies = {k: v for k, v in token_freqs.items() if k in vocab}

            # Compute token scores (log probability)
            total_tokens = sum(self.token_frequencies.values())
            self.vocab_scores = {
                token: math.log((freq + self.alpha) / (total_tokens + self.alpha * len(vocab)))
                for token, freq in self.token_frequencies.items()
            }

            # Prune vocabulary based on loss
            new_vocab = {}
            candidates = sorted(
                [(token, score) for token, score in self.vocab_scores.items()],
                key=lambda x: x[1]
            )

            for token, idx in self.special_tokens.items():
                new_vocab[token] = idx

            # Keep special tokens and top vocabulary
            for token, _ in candidates[:self.vocab_size]:
                if token not in new_vocab:
                    new_vocab[token] = len(new_vocab)
                if len(new_vocab) == self.vocab_size:
                    break

            # Merge special tokens with new vocabulary
            

            vocab = new_vocab

        # Finalize vocabulary
        self.vocab = vocab
        self.id_to_token = {idx: token for token, idx in self.vocab.items()}

        print(f"SentencePiece training complete. Final vocabulary size: {len(self.vocab)}")

    def tokenize(self, text: str) -> List[int]:
        """
        Tokenize text into token IDs.

        Args:
            text: Input text
        Returns:
            List of token IDs
        """
        # Preprocess text
        text = self._preprocess_text(text)

        # Greedy longest match tokenization
        token_ids = []
        while text:
            best_piece = None
            for length in range(min(len(text), self.max_piece_length), 0, -1):
                piece = text[:length]
                if piece in self.vocab:
                    best_piece = piece
                    break

            if best_piece is None:
                # Unknown token
                token_ids.append(self.special_tokens["<UNK>"])
                text = text[1:]
            else:
                token_ids.append(self.vocab[best_piece])
                text = text[len(best_piece):]

        return token_ids

    def decode(self, token_ids: List[int]) -> str:
        """
        Convert token IDs back to text.

        Args:
            token_ids: List of token IDs
        Returns:
            Decoded text
        """
        # Filter out padding tokens
        token_ids = [idx for idx in token_ids if idx != self.special_tokens["<PAD>"]]

        # Remove special tokens from beginning and end if present
        if token_ids and token_ids[0] == self.special_tokens["<BOS>"]:
            token_ids = token_ids[1:]
        if token_ids and token_ids[-1] == self.special_tokens["<EOS>"]:
            token_ids = token_ids[:-1]

        # Convert to tokens
        tokens = [self.id_to_token.get(idx, "<UNK>") for idx in token_ids]

        # Reconstruct text with proper handling of word boundary markers
        text = ''.join(tokens)
        text = text.replace('▁', '').strip()

        return text


    def pad_sequences(self, token_ids_list: List[List[int]], max_length: int = None,
                     padding: str = 'post', truncating: str = 'post') -> List[List[int]]:
        """
        Pad sequences to the same length.

        Args:
            token_ids_list: List of token ID sequences
            max_length: Maximum length to pad to
            padding: 'pre' or 'post'
            truncating: 'pre' or 'post'
        Returns:
            List of padded sequences
        """
        # Find max length if not specified
        if max_length is None:
            max_length = self.max_length or max(len(seq) for seq in token_ids_list)

        padded_sequences = []
        for seq in token_ids_list:
            # Truncate if necessary
            if len(seq) > max_length:
                seq = seq[:max_length] if truncating == 'post' else seq[-max_length:]

            # Pad if necessary
            pad_length = max_length - len(seq)
            if padding == 'pre':
                padded_seq = [self.special_tokens['<PAD>']] * pad_length + seq
            else:  # padding == 'post'
                padded_seq = seq + [self.special_tokens['<PAD>']] * pad_length

            padded_sequences.append(padded_seq)

        return padded_sequences

    def create_attention_mask(self, padded_sequences: List[List[int]]) -> List[List[int]]:
        """
        Create attention masks for padded sequences.

        Args:
            padded_sequences: List of padded token ID sequences
        Returns:
            List of attention masks
        """
        masks = []
        for seq in padded_sequences:
            mask = [1 if token_id != self.special_tokens['<PAD>'] else 0 for token_id in seq]
            masks.append(mask)
        return masks

    def encode_for_model(self, text: str, add_special_tokens: bool = True) -> List[int]:
        """
        Tokenize text and add special tokens for model input.

        Args:
            text: Input text
            add_special_tokens: Whether to add <BOS> and <EOS> tokens
        Returns:
            List of token IDs ready for model input
        """
        token_ids = self.tokenize(text)
        if add_special_tokens:
            token_ids = [self.special_tokens['<BOS>']] + token_ids + [self.special_tokens['<EOS>']]
        return token_ids

    def prepare_model_inputs(self, texts: List[str], add_special_tokens: bool = True,
                            max_length: int = None, return_attention_mask: bool = True):
        """
        Prepare inputs ready for model training or inference.

        Args:
            texts: List of input texts
            add_special_tokens: Whether to add <BOS> and <EOS> tokens
            max_length: Maximum sequence length (will pad/truncate to this length)
            return_attention_mask: Whether to return attention masks

        Returns:
            Dictionary of model inputs
        """
        # Tokenize all texts
        all_token_ids = [self.encode_for_model(text, add_special_tokens) for text in texts]

        # Use default max_length if not specified
        if max_length is None:
            max_length = self.max_length

        # Pad sequences
        padded_sequences = self.pad_sequences(all_token_ids, max_length=max_length)

        # Prepare outputs
        model_inputs = {
            'input_ids': padded_sequences
        }

        if return_attention_mask:
            attention_masks = self.create_attention_mask(padded_sequences)
            model_inputs['attention_mask'] = attention_masks

        return model_inputs

    def save(self, path: str):
        """
        Save the tokenizer to a file.

        Args:
            path: Path to save the tokenizer
        """
        with open(path, 'wb') as f:
            pickle.dump({
                'vocab': self.vocab,
                'token_frequencies': self.token_frequencies,
                'vocab_scores': self.vocab_scores,
                'special_tokens': self.special_tokens,
                'vocab_size': self.vocab_size,
                'max_length': self.max_length,
                'id_to_token': self.id_to_token
            }, f)

    @classmethod
    def load(cls, path: str):
        """
        Load a tokenizer from a file.

        Args:
            path: Path to the saved tokenizer
        Returns:
            Loaded SentencePieceTokenizer instance
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)

        tokenizer = cls(vocab_size=data['vocab_size'], max_length=data.get('max_length', 128))
        tokenizer.vocab = data['vocab']
        tokenizer.token_frequencies = data['token_frequencies']
        tokenizer.vocab_scores = data['vocab_scores']
        tokenizer.special_tokens = data['special_tokens']
        tokenizer.id_to_token = data['id_to_token']

        return tokenizer


## Co-BERT

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler
# from transformers import BertTokenizer
# from datasets import load_dataset
import math
import random
import numpy as np
from tqdm import tqdm


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# # Load a smaller dataset - just a portion of Wikipedia
# print("Loading dataset...")
# wiki_dataset = load_dataset("wikipedia", "20220301.en", split="train[:5%]")
# wiki_dataset = wiki_dataset.remove_columns([col for col in wiki_dataset.column_names if col != "text"])
# print(f"Dataset size: {len(wiki_dataset)} documents")


# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


class BertConfig:
    def __init__(self,
                vocab_size=30522,
                hidden_size=768,
                num_hidden_layers=6,  # Reduced
                num_attention_heads=12,
                intermediate_size=3072,
                hidden_dropout_prob=0.1,
                attention_probs_dropout_prob=0.1,
                max_position_embeddings=512,
                type_vocab_size=2,
                initializer_range=0.02,
                layer_norm_eps=1e-12):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

# Multi-Head Attention
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"Hidden size ({config.hidden_size}) not divisible by number of attention heads ({config.num_attention_heads})")

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Create query, key, value projections
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Reshape to separate multiple heads"""
        batch_size, seq_length, _ = x.size()
        x = x.view(batch_size, seq_length, self.num_attention_heads, self.attention_head_size)
        return x.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_length, head_size)

    def forward(self, hidden_states, attention_mask=None):

        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Calculate attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Apply attention mask if provided
        if attention_mask is not None:
            # Mask has shape [batch_size, 1, 1, seq_length]
            attention_scores = attention_scores + attention_mask

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        batch_size, seq_length, _, _ = context_layer.size()
        context_layer = context_layer.view(batch_size, seq_length, self.all_head_size)

        return context_layer

# Output projection after self-attention
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

# Complete attention block
class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self_attention = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        self_outputs = self.self_attention(hidden_states, attention_mask)
        attention_output = self.output(self_outputs, hidden_states)
        return attention_output

# Feed-forward network
class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.gelu = nn.GELU()

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states

# Output layer after feed-forward
class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

# Complete encoder layer
class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

# Stack of encoder layers
class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None):
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
        return hidden_states

# Embeddings (token, position, segment)
class BertEmbeddings(nn.Module):
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Position IDs (0, 1, 2, ..., max_len)
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1))
        )

    def forward(self, input_ids, token_type_ids=None, position_ids=None):
        seq_length = input_ids.size(1)

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

# Masked Language Model head
class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.GELU()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states

# Next Sentence Prediction head
class BertNextSentenceHead(nn.Module):
    def __init__(self, config):
        super(BertNextSentenceHead, self).__init__()
        self.dense = nn.Linear(config.hidden_size, 2)  # Binary classification: IsNext or NotNext

    def forward(self, pooled_output):
        return self.dense(pooled_output)

# Pooler layer for sentence-level tasks
class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Take hidden state of [CLS] token
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

# Complete BERT model
class BertModel(nn.Module):
    def __init__(self, config):
        super(BertModel, self).__init__()
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Reshape attention mask for broadcast to attention heads
        # [batch_size, seq_length] -> [batch_size, 1, 1, seq_length]
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Convert mask values: 0 -> -10000, 1 -> 0
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Get embeddings
        embedding_output = self.embeddings(input_ids, token_type_ids, position_ids)

        # Pass through encoder layers
        sequence_output = self.encoder(embedding_output, extended_attention_mask)

        # Apply pooling for sentence representation
        pooled_output = self.pooler(sequence_output)

        return sequence_output, pooled_output

# BERT for pre-training (combines MLM and NSP tasks)
class BertForPreTraining(nn.Module):
    def __init__(self, config):
        super(BertForPreTraining, self).__init__()
        self.bert = BertModel(config)
        self.cls = nn.ModuleDict({
            'predictions': BertLMPredictionHead(config),
            'seq_relationship': BertNextSentenceHead(config)
        })

        # Initialize weights
        self.apply(self._init_weights)

        # Tie input and output embeddings
        self.cls['predictions'].decoder.weight = self.bert.embeddings.word_embeddings.weight

    def _init_weights(self, module):
        """Initialize weights for linear layers and embeddings"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
        sequence_output, pooled_output = self.bert(
            input_ids,
            token_type_ids,
            attention_mask
        )

        # MLM prediction scores
        prediction_scores = self.cls['predictions'](sequence_output)

        # NSP prediction scores
        seq_relationship_score = self.cls['seq_relationship'](pooled_output)

        # Calculate loss if labels are provided
        outputs = (prediction_scores, seq_relationship_score)

        if masked_lm_labels is not None and next_sentence_label is not None:
            # MLM loss (CrossEntropyLoss)
            mlm_loss = F.cross_entropy(
                prediction_scores.view(-1, prediction_scores.size(-1)),
                masked_lm_labels.view(-1),
                ignore_index=-100  # Ignore padding tokens
            )

            # NSP loss (CrossEntropyLoss)
            nsp_loss = F.cross_entropy(
                seq_relationship_score.view(-1, 2),
                next_sentence_label.view(-1)
            )

            total_loss = mlm_loss + nsp_loss
            outputs = (total_loss,) + outputs

        return outputs

# Dataset class for BERT pre-training
import torch
from torch.utils.data import Dataset
import random

class BertPretrainingDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=128, mlm_probability=0.15):
        """
        Initialize the BERT pretraining dataset with a custom BPE tokenizer.

        Args:
            texts (List[str]): List of input texts
            tokenizer (BPETokenizer): Custom BPE tokenizer
            max_length (int): Maximum sequence length
            mlm_probability (float): Probability of masking tokens
        """
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mlm_probability = mlm_probability

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]

        # Split text into sentences
        sentences = text.split('. ')
        if len(sentences) < 2:
            # print("YES")
            # If there's only one sentence, duplicate it
            sentences.append(sentences[0])

        # Create sentence pairs for NSP (50% are actual next sentences)
        is_next = random.choice([True, False])
        if is_next:
            # Actual next sentence
            first_idx = random.randint(0, len(sentences) - 2)
            second_idx = first_idx + 1
        else:
            # Random sentence (not next)
            first_idx = random.randint(0, len(sentences) - 1)
            second_idx = random.randint(0, len(sentences) - 1)
            while second_idx == first_idx + 1:  # Ensure it's not actually the next sentence
                second_idx = random.randint(0, len(sentences) - 1)

        # Get the sentences
        sentence_a = sentences[first_idx]
        sentence_b = sentences[second_idx]

        # Combine sentences similar to BERT's [CLS] sentence_a [SEP] sentence_b [SEP]
        combined_text = f"{sentence_a} {self.tokenizer.id_to_token[self.tokenizer.special_tokens['<SEP>']]} {sentence_b}"

        # Prepare model inputs using custom BPE tokenizer
        model_inputs = self.tokenizer.prepare_model_inputs(
            [combined_text],
            add_special_tokens=True,
            max_length=self.max_length,
            return_attention_mask=True
        )

        # Convert to PyTorch tensors
        input_ids = torch.tensor(model_inputs['input_ids'][0], dtype=torch.long)
        attention_mask = torch.tensor(model_inputs['attention_mask'][0], dtype=torch.long)

        # Create token type ids (0 for first sentence, 1 for second sentence)
        token_type_ids = torch.zeros_like(input_ids, dtype=torch.long)
        sep_token_id = self.tokenizer.special_tokens['<SEP>']
        sep_indices = torch.where(input_ids == sep_token_id)[0]
        if len(sep_indices) > 0:
            token_type_ids[sep_indices[0]+1:] = 1

        # Create masked LM labels
        mlm_labels = input_ids.clone()

        # Create a mask for special tokens
        special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        special_token_ids = set([
            self.tokenizer.special_tokens['<PAD>'],
            self.tokenizer.special_tokens['<UNK>'],
            self.tokenizer.special_tokens['<BOS>'],
            self.tokenizer.special_tokens['<EOS>'],
            self.tokenizer.special_tokens.get('<SEP>', -1)
        ])
        for special_id in special_token_ids:
            special_tokens_mask |= (input_ids == special_id)

        # Get probability mask for tokens to predict (15% of non-special tokens)
        probability_matrix = torch.full(input_ids.shape, self.mlm_probability)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Get indices of tokens to mask
        masked_indices = torch.bernoulli(probability_matrix).bool()
        mlm_labels[~masked_indices] = -100  # -100 index will be ignored in loss

        # Mask tokens (80% of selected tokens)
        indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.vocab["<UNK>"]

        # Replace with random words (10% of selected tokens)
        indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer.vocab), input_ids.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]

        # NSP label: 1 for IsNext, 0 for NotNext
        nsp_label = 1 if is_next else 0

        return {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask,
            'masked_lm_labels': mlm_labels,
            'next_sentence_label': torch.tensor(nsp_label, dtype=torch.long)
        }


# Preprocess the dataset
def preprocess_dataset(dataset):
    # Filter out very short documents
    filtered_dataset = [doc for doc in dataset if len(doc.split()) > 50]
    return filtered_dataset

# Training loop
def train():
    model.train()
    epoch_loss = 0

    progress_bar = tqdm(train_dataloader, desc="Training")
    for batch in progress_bar:
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(
            input_ids=batch['input_ids'],
            token_type_ids=batch['token_type_ids'],
            attention_mask=batch['attention_mask'],
            masked_lm_labels=batch['masked_lm_labels'],
            next_sentence_label=batch['next_sentence_label']
        )

        loss = outputs[0]

        # Backward pass
        loss.backward()

        # Clip gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Update weights
        optimizer.step()
        scheduler.step()

        # Update progress bar
        progress_bar.set_postfix({'loss': loss.item()})
        epoch_loss += loss.item()

    return epoch_loss / len(train_dataloader)


Using device: cuda


## Dataset

In [5]:
import re

def clean_sentences(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()

    text = text.replace('\ufeff', '')
    text = text.replace('\n', ' ')
    sentences = re.split(r'(?<=\.)\s+', text)
    sentences = [sentence.strip() for sentence in sentences if sentence.strip()]

    return sentences

def count_unique_words(sentences):
    unique_words = set()

    for sentence in sentences:
        words = re.findall(r'\b\w+\b', sentence.lower())
        unique_words.update(words)

    return len(unique_words)


In [6]:
tanglish_path = '/kaggle/input/code-mix-dataset-for-pretraining/tanglish.txt'
tanglish = clean_sentences(tanglish_path)
print(f"Number of unique words in the corpus: {count_unique_words(tanglish)}")
ct = len(tanglish)
print(f"Total number of sentences: {ct}")
tamil_corpus = tanglish[:int(ct*0.8)]

Number of unique words in the corpus: 5995
Total number of sentences: 1581


In [7]:
kanglish_path = '/kaggle/input/code-mix-dataset-for-pretraining/Kanglish.txt'
kanglish = clean_sentences(kanglish_path)
print(f"Number of unique words in the corpus: {count_unique_words(kanglish)}")
ct = len(kanglish)
print(f"Total number of sentences: {ct}")
kannada_corpus = kanglish[:int(ct*0.8)]

Number of unique words in the corpus: 5867
Total number of sentences: 1722


In [8]:
hinglish_path = '/kaggle/input/code-mix-dataset-for-pretraining/english based hinglish.txt'
hinglish = clean_sentences(hinglish_path)
print(f"Number of unique words in the corpus: {count_unique_words(hinglish)}")
ct = len(hinglish)
print(f"Total number of sentences: {ct}")
hindi_corpus = hinglish[:int(ct*0.8)]

Number of unique words in the corpus: 5846
Total number of sentences: 1736


## BPE instances

In [9]:
# Initialize the custom BPE tokenizer
bpe_tam_tokenizer = BPETokenizer(vocab_size=2500, max_length=128)
bpe_tam_tokenizer.train(tamil_corpus)


Merge operation 100/2448, vocab size: 152
Merge operation 200/2448, vocab size: 252
Merge operation 300/2448, vocab size: 352
Merge operation 400/2448, vocab size: 452
Merge operation 500/2448, vocab size: 552
Merge operation 600/2448, vocab size: 652
Merge operation 700/2448, vocab size: 752
Merge operation 800/2448, vocab size: 852
Merge operation 900/2448, vocab size: 952
Merge operation 1000/2448, vocab size: 1052
Merge operation 1100/2448, vocab size: 1152
Merge operation 1200/2448, vocab size: 1252
Merge operation 1300/2448, vocab size: 1352
Merge operation 1400/2448, vocab size: 1452
Merge operation 1500/2448, vocab size: 1552
Merge operation 1600/2448, vocab size: 1652
Merge operation 1700/2448, vocab size: 1752
Merge operation 1800/2448, vocab size: 1852
Merge operation 1900/2448, vocab size: 1952
Merge operation 2000/2448, vocab size: 2052
Merge operation 2100/2448, vocab size: 2152
Merge operation 2200/2448, vocab size: 2252
Merge operation 2300/2448, vocab size: 2352
Merge 

In [28]:
bert_tam_dataset = BertPretrainingDataset(
    texts=tanglish,
    tokenizer=bpe_tam_tokenizer,
    max_length=128,
    mlm_probability=0.15
)

train_sampler = RandomSampler(bert_tam_dataset)
train_dataloader = DataLoader(
    bert_tam_dataset,
    sampler=train_sampler,
    batch_size=16,  # Reduced batch size for Colab
    num_workers=2
)

# Initialize the BERT model
config = BertConfig()
model = BertForPreTraining(config)
model = model.to(device)

# Print model architecture
print(f"Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# Training parameters
num_epochs = 3   # For demonstration, increase for better results
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)

# Learning rate scheduler
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Main training function
print("Starting training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    avg_loss = train()
    print(f"Average loss: {avg_loss:.4f}")

# Save the model
# model_save_path = "bert_pretrained_small.pt"
# torch.save(model.state_dict(), model_save_path)
# print(f"Model saved to {model_save_path}")

print("Training complete!")

Model Parameters: 67579196
Starting training...
Epoch 1/3


Training: 100%|██████████| 99/99 [00:46<00:00,  2.11it/s, loss=8.03]


Average loss: 9.1090
Epoch 2/3


Training: 100%|██████████| 99/99 [00:48<00:00,  2.06it/s, loss=7.55]


Average loss: 7.4071
Epoch 3/3


Training: 100%|██████████| 99/99 [00:46<00:00,  2.12it/s, loss=7.38]

Average loss: 7.2053
Training complete!





In [11]:
bert_tam_dataset[0]

{'input_ids': tensor([   2,  128,   80,   42, 1626,  920,  152,  185, 1627,  126,  318,    1,
         1628, 1181,  770, 1182,  107, 2204, 1183,  597,  240, 1629, 1630, 1631,
          771, 1632,    1,  112,   51,   10,   51,    1,  197,   38,   51,    1,
          128,   80,   42, 1626,  920,  152,    1, 1627,    1,    1, 1180, 1628,
         1181,  770, 1182,  107,  433, 1183,  597,    1, 1629, 1630,    1, 1351,
         1632,    1,  112,   51,   10,    3,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0

In [12]:
# Initialize the custom BPE tokenizer
bpe_kan_tokenizer = BPETokenizer(vocab_size=2500, max_length=128)
bpe_kan_tokenizer.train(kannada_corpus)


Merge operation 100/2447, vocab size: 153
Merge operation 200/2447, vocab size: 253
Merge operation 300/2447, vocab size: 353
Merge operation 400/2447, vocab size: 453
Merge operation 500/2447, vocab size: 553
Merge operation 600/2447, vocab size: 653
Merge operation 700/2447, vocab size: 753
Merge operation 800/2447, vocab size: 853
Merge operation 900/2447, vocab size: 953
Merge operation 1000/2447, vocab size: 1053
Merge operation 1100/2447, vocab size: 1153
Merge operation 1200/2447, vocab size: 1253
Merge operation 1300/2447, vocab size: 1353
Merge operation 1400/2447, vocab size: 1453
Merge operation 1500/2447, vocab size: 1553
Merge operation 1600/2447, vocab size: 1653
Merge operation 1700/2447, vocab size: 1753
Merge operation 1800/2447, vocab size: 1853
Merge operation 1900/2447, vocab size: 1953
Merge operation 2000/2447, vocab size: 2053
Merge operation 2100/2447, vocab size: 2153
Merge operation 2200/2447, vocab size: 2253
Merge operation 2300/2447, vocab size: 2353
Merge 

In [29]:
bert_kan_dataset = BertPretrainingDataset(
    texts=kanglish,
    tokenizer=bpe_kan_tokenizer,
    max_length=128,
    mlm_probability=0.15
)

train_sampler = RandomSampler(bert_kan_dataset)
train_dataloader = DataLoader(
    bert_kan_dataset,
    sampler=train_sampler,
    batch_size=16,  # Reduced batch size for Colab
    num_workers=2
)

# Initialize the BERT model
config = BertConfig()
model = BertForPreTraining(config)
model = model.to(device)

# Print model architecture
print(f"Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# Training parameters
num_epochs = 3  # For demonstration, increase for better results
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)

# Learning rate scheduler
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Main training function
print("Starting training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    avg_loss = train()
    print(f"Average loss: {avg_loss:.4f}")

print("Training complete!")

Model Parameters: 67579196
Starting training...
Epoch 1/3


Training: 100%|██████████| 108/108 [00:56<00:00,  1.90it/s, loss=7.91]


Average loss: 9.3310
Epoch 2/3


Training: 100%|██████████| 108/108 [00:55<00:00,  1.94it/s, loss=6.82]


Average loss: 7.5609
Epoch 3/3


Training: 100%|██████████| 108/108 [00:54<00:00,  1.97it/s, loss=7.17]

Average loss: 7.3230
Training complete!





In [14]:
# Initialize the custom BPE tokenizer
bpe_hin_tokenizer = BPETokenizer(vocab_size=2500, max_length=128)
bpe_hin_tokenizer.train(hindi_corpus)


Merge operation 100/2391, vocab size: 209
Merge operation 200/2391, vocab size: 309
Merge operation 300/2391, vocab size: 409
Merge operation 400/2391, vocab size: 509
Merge operation 500/2391, vocab size: 609
Merge operation 600/2391, vocab size: 709
Merge operation 700/2391, vocab size: 809
Merge operation 800/2391, vocab size: 909
Merge operation 900/2391, vocab size: 1009
Merge operation 1000/2391, vocab size: 1109
Merge operation 1100/2391, vocab size: 1209
Merge operation 1200/2391, vocab size: 1309
Merge operation 1300/2391, vocab size: 1409
Merge operation 1400/2391, vocab size: 1509
Merge operation 1500/2391, vocab size: 1609
Merge operation 1600/2391, vocab size: 1709
Merge operation 1700/2391, vocab size: 1809
Merge operation 1800/2391, vocab size: 1909
Merge operation 1900/2391, vocab size: 2009
Merge operation 2000/2391, vocab size: 2109
Merge operation 2100/2391, vocab size: 2209
Merge operation 2200/2391, vocab size: 2309
Merge operation 2300/2391, vocab size: 2409
BPE t

In [30]:
bert_hin_dataset = BertPretrainingDataset(
    texts=hinglish,
    tokenizer=bpe_hin_tokenizer,
    max_length=128,
    mlm_probability=0.15
)

train_sampler = RandomSampler(bert_hin_dataset)
train_dataloader = DataLoader(
    bert_hin_dataset,
    sampler=train_sampler,
    batch_size=16,  # Reduced batch size for Colab
    num_workers=2
)

# Initialize the BERT model
config = BertConfig()
model = BertForPreTraining(config)
model = model.to(device)

# Print model architecture
print(f"Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# Training parameters
num_epochs = 3  # For demonstration, increase for better results
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)

# Learning rate scheduler
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Main training function
print("Starting training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    avg_loss = train()
    print(f"Average loss: {avg_loss:.4f}")

print("Training complete!")

Model Parameters: 67579196
Starting training...
Epoch 1/3


Training: 100%|██████████| 109/109 [01:33<00:00,  1.16it/s, loss=7.92]


Average loss: 9.1658
Epoch 2/3


Training: 100%|██████████| 109/109 [01:32<00:00,  1.17it/s, loss=7.22]


Average loss: 7.3683
Epoch 3/3


Training: 100%|██████████| 109/109 [01:36<00:00,  1.13it/s, loss=7.63]

Average loss: 7.1859
Training complete!





## Wordpiece instances

In [16]:
# Initialize the custom word piece tokenizer
wp_tam_tokenizer = WordPieceTokenizer(vocab_size=2500, max_length=128)
wp_tam_tokenizer.train(tamil_corpus)


WordPiece training complete. Final vocabulary size: 2500


In [31]:
bert_tam_dataset = BertPretrainingDataset(
    texts=tanglish,
    tokenizer=wp_tam_tokenizer,
    max_length=128,
    mlm_probability=0.15
)

train_sampler = RandomSampler(bert_tam_dataset)
train_dataloader = DataLoader(
    bert_tam_dataset,
    sampler=train_sampler,
    batch_size=16,  # Reduced batch size for Colab
    num_workers=2
)

# Initialize the BERT model
config = BertConfig()
model = BertForPreTraining(config)
model = model.to(device)

# Print model architecture
print(f"Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# Training parameters
num_epochs = 3  # For demonstration, increase for better results
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)

# Learning rate scheduler
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Main training function
print("Starting training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    avg_loss = train()
    print(f"Average loss: {avg_loss:.4f}")

print("Training complete!")

Model Parameters: 67579196
Starting training...
Epoch 1/3


Training: 100%|██████████| 99/99 [00:28<00:00,  3.48it/s, loss=7.97]


Average loss: 9.4637
Epoch 2/3


Training: 100%|██████████| 99/99 [00:27<00:00,  3.54it/s, loss=7.82]


Average loss: 7.6975
Epoch 3/3


Training: 100%|██████████| 99/99 [00:28<00:00,  3.53it/s, loss=7.55]

Average loss: 7.4368
Training complete!





In [18]:
# Initialize the custom word piece tokenizer
wp_kan_tokenizer = WordPieceTokenizer(vocab_size=2500, max_length=128)
wp_kan_tokenizer.train(kannada_corpus)


WordPiece training complete. Final vocabulary size: 2500


In [32]:
bert_kan_dataset = BertPretrainingDataset(
    texts=kanglish,
    tokenizer=wp_kan_tokenizer,
    max_length=128,
    mlm_probability=0.15
)

train_sampler = RandomSampler(bert_kan_dataset)
train_dataloader = DataLoader(
    bert_kan_dataset,
    sampler=train_sampler,
    batch_size=16,  # Reduced batch size for Colab
    num_workers=2
)

# Initialize the BERT model
config = BertConfig()
model = BertForPreTraining(config)
model = model.to(device)

# Print model architecture
print(f"Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# Training parameters
num_epochs = 3  # For demonstration, increase for better results
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)

# Learning rate scheduler
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Main training function
print("Starting training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    avg_loss = train()
    print(f"Average loss: {avg_loss:.4f}")

print("Training complete!")

Model Parameters: 67579196
Starting training...
Epoch 1/3


Training: 100%|██████████| 108/108 [00:30<00:00,  3.51it/s, loss=8.21]


Average loss: 9.4098
Epoch 2/3


Training: 100%|██████████| 108/108 [00:30<00:00,  3.53it/s, loss=7.34]


Average loss: 7.6294
Epoch 3/3


Training: 100%|██████████| 108/108 [00:30<00:00,  3.53it/s, loss=6.85]

Average loss: 7.4229
Training complete!





In [20]:
# Initialize the custom word piece tokenizer
wp_hin_tokenizer = WordPieceTokenizer(vocab_size=2500, max_length=128)
wp_hin_tokenizer.train(hindi_corpus)


WordPiece training complete. Final vocabulary size: 2500


In [33]:
bert_hin_dataset = BertPretrainingDataset(
    texts=hinglish,
    tokenizer=wp_hin_tokenizer,
    max_length=128,
    mlm_probability=0.15
)

train_sampler = RandomSampler(bert_hin_dataset)
train_dataloader = DataLoader(
    bert_hin_dataset,
    sampler=train_sampler,
    batch_size=16,  # Reduced batch size for Colab
    num_workers=2
)

# Initialize the BERT model
config = BertConfig()
model = BertForPreTraining(config)
model = model.to(device)

# Print model architecture
print(f"Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# Training parameters
num_epochs = 3  # For demonstration, increase for better results
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)

# Learning rate scheduler
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Main training function
print("Starting training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    avg_loss = train()
    print(f"Average loss: {avg_loss:.4f}")

print("Training complete!")

Model Parameters: 67579196
Starting training...
Epoch 1/3


Training: 100%|██████████| 109/109 [00:31<00:00,  3.49it/s, loss=7.5] 


Average loss: 9.1235
Epoch 2/3


Training: 100%|██████████| 109/109 [00:31<00:00,  3.48it/s, loss=7.16]


Average loss: 7.2237
Epoch 3/3


Training: 100%|██████████| 109/109 [00:31<00:00,  3.49it/s, loss=6.64]

Average loss: 7.0627
Training complete!





## Sentence Piece instances

In [22]:
# Initialize the custom word piece tokenizer
sp_tam_tokenizer = SentencePieceTokenizer(vocab_size=2500, max_length=128)
sp_tam_tokenizer.train(tamil_corpus)


SentencePiece training complete. Final vocabulary size: 2500


In [34]:
bert_tam_dataset = BertPretrainingDataset(
    texts=tanglish,
    tokenizer=sp_tam_tokenizer,
    max_length=128,
    mlm_probability=0.15
)

train_sampler = RandomSampler(bert_tam_dataset)
train_dataloader = DataLoader(
    bert_tam_dataset,
    sampler=train_sampler,
    batch_size=16,  # Reduced batch size for Colab
    num_workers=2
)

# Initialize the BERT model
config = BertConfig()
model = BertForPreTraining(config)
model = model.to(device)

# Print model architecture
print(f"Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# Training parameters
num_epochs = 3  # For demonstration, increase for better results
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)

# Learning rate scheduler
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Main training function
print("Starting training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    avg_loss = train()
    print(f"Average loss: {avg_loss:.4f}")

print("Training complete!")

Model Parameters: 67579196
Starting training...
Epoch 1/3


Training: 100%|██████████| 99/99 [00:28<00:00,  3.53it/s, loss=8.41]


Average loss: 9.6572
Epoch 2/3


Training: 100%|██████████| 99/99 [00:28<00:00,  3.53it/s, loss=7.74]


Average loss: 7.9721
Epoch 3/3


Training: 100%|██████████| 99/99 [00:28<00:00,  3.53it/s, loss=7.72]

Average loss: 7.7397
Training complete!





In [24]:
# Initialize the custom word piece tokenizer
sp_kan_tokenizer = SentencePieceTokenizer(vocab_size=2500, max_length=128)
sp_kan_tokenizer.train(kannada_corpus)


SentencePiece training complete. Final vocabulary size: 2500


In [35]:
bert_kan_dataset = BertPretrainingDataset(
    texts=kanglish,
    tokenizer=sp_kan_tokenizer,
    max_length=128,
    mlm_probability=0.15
)

train_sampler = RandomSampler(bert_kan_dataset)
train_dataloader = DataLoader(
    bert_kan_dataset,
    sampler=train_sampler,
    batch_size=16,  # Reduced batch size for Colab
    num_workers=2
)

# Initialize the BERT model
config = BertConfig()
model = BertForPreTraining(config)
model = model.to(device)

# Print model architecture
print(f"Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# Training parameters
num_epochs = 3  # For demonstration, increase for better results
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)

# Learning rate scheduler
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Main training function
print("Starting training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    avg_loss = train()
    print(f"Average loss: {avg_loss:.4f}")

print("Training complete!")

Model Parameters: 67579196
Starting training...
Epoch 1/3


Training: 100%|██████████| 108/108 [00:30<00:00,  3.53it/s, loss=8.13]


Average loss: 9.5295
Epoch 2/3


Training: 100%|██████████| 108/108 [00:30<00:00,  3.53it/s, loss=7.57]


Average loss: 7.9325
Epoch 3/3


Training: 100%|██████████| 108/108 [00:30<00:00,  3.54it/s, loss=7.6] 

Average loss: 7.7356
Training complete!





In [26]:
# Initialize the custom word piece tokenizer
sp_hin_tokenizer = SentencePieceTokenizer(vocab_size=2500, max_length=128)
sp_hin_tokenizer.train(hindi_corpus)


SentencePiece training complete. Final vocabulary size: 2500


In [36]:
bert_hin_dataset = BertPretrainingDataset(
    texts=hinglish,
    tokenizer=sp_hin_tokenizer,
    max_length=128,
    mlm_probability=0.15
)

train_sampler = RandomSampler(bert_hin_dataset)
train_dataloader = DataLoader(
    bert_hin_dataset,
    sampler=train_sampler,
    batch_size=16,  # Reduced batch size for Colab
    num_workers=2
)

# Initialize the BERT model
config = BertConfig()
model = BertForPreTraining(config)
model = model.to(device)

# Print model architecture
print(f"Model Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# Training parameters
num_epochs = 3  # For demonstration, increase for better results
total_steps = len(train_dataloader) * num_epochs
warmup_steps = int(0.1 * total_steps)

# Learning rate scheduler
from transformers import get_linear_schedule_with_warmup
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Main training function
print("Starting training...")
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    avg_loss = train()
    print(f"Average loss: {avg_loss:.4f}")

print("Training complete!")

Model Parameters: 67579196
Starting training...
Epoch 1/3


Training: 100%|██████████| 109/109 [00:31<00:00,  3.49it/s, loss=7.89]


Average loss: 9.5183
Epoch 2/3


Training: 100%|██████████| 109/109 [00:31<00:00,  3.48it/s, loss=7.49]


Average loss: 7.7196
Epoch 3/3


Training: 100%|██████████| 109/109 [00:31<00:00,  3.48it/s, loss=7.76]

Average loss: 7.5617
Training complete!



