In [None]:
import multiprocessing
import re
import subprocess
import time

import pandas as pd

# Configuration
input_csv_path = "input.csv"        # Update this path as needed
text_column = "assistant"           # CSV column holding the training text
corpus_path = "full_corpus.txt"     # One document per line; read by the C++ trainer
num_threads = multiprocessing.cpu_count()
vocab_size = 50000  # Target vocabulary size (includes the 256 byte-level tokens)
min_frequency = 2   # Minimum frequency for pairs to be considered

print(f'CPU threads available: {num_threads}')
print(f'Target vocabulary size: {vocab_size}')
print(f'Minimum pair frequency: {min_frequency}')

# Load dataset
print('\nLoading dataset...')
start_time = time.time()
df = pd.read_csv(input_csv_path)
print(f'Dataset loaded in {time.time() - start_time:.2f}s')
print(f'Total examples: {len(df)}')

# Write corpus to text file for C++ processing.
# The C++ loader treats each line as one document, so newlines inside a text are
# escaped to the two characters backslash + 'n'. Missing values are dropped:
# str(NaN) would otherwise inject the literal document "nan" into the corpus.
print('Writing corpus to file...')
texts = df[text_column].dropna().astype(str)
# newline='\n' disables platform newline translation, so on Windows no '\r' is
# appended to each document (C++ getline splits on '\n' only).
with open(corpus_path, 'w', encoding='utf-8', newline='\n') as f:
    for text in texts:
        f.write(text.replace('\n', '\\n') + '\n')

num_documents = len(texts)
num_skipped = len(df) - num_documents
print(f'Wrote {num_documents} documents to {corpus_path}')
if num_skipped:
    print(f'Skipped {num_skipped} rows with a missing {text_column!r} value')
del df, texts  # Free memory


In [None]:
%%writefile complete_bpe.cpp
#include <bits/stdc++.h>
#include <thread>
#include <mutex>
#include <atomic>
using namespace std;

// One candidate merge: an adjacent token pair together with its corpus frequency.
struct PairCount {
    long long pair;    // packed key: (first_token << 32) | (unsigned)second_token
    long long count;   // number of adjacent occurrences across all documents
    int first_token;   // left token id of the pair
    int second_token;  // right token id of the pair
};

// Byte-level BPE trainer.
//
// Each line of the corpus file is one document, initially tokenized as its raw
// bytes (ids 0-255). Every iteration recounts all adjacent pairs, selects a batch
// of the most frequent non-overlapping pairs, and replaces them with new token ids.
// This repeats until the target vocabulary size is reached or no pair reaches
// min_frequency. The learned vocabulary is written to vocab.csv in the working
// directory.
class CompleteBPE {
private:
    vector<vector<int>> docs;   // token ids of each document (one per corpus line)
    int num_threads;            // worker threads used by every parallel pass (>= 1)
    int vocab_size;             // target vocabulary size, including the 256 byte tokens
    int min_frequency;          // pairs occurring fewer times than this are never merged
    int current_vocab_id;       // next id to assign, i.e. the current vocabulary size
    int max_batch_size;         // upper bound on merges learned per counting pass (>= 1)
    mutex mtx;                  // serializes folding per-thread counts into the global map

    // Vocabulary tracking
    map<int, string> vocab; // token_id -> human-readable string representation
    map<int, pair<int, int>> merges; // token_id -> (first_token, second_token) for merged tokens

    // Packs an ordered token pair into a single 64-bit hash key.
    static long long packPair(int first, int second) {
        return ((long long)first << 32) | (unsigned int)second;
    }

    // Runs worker(lo, hi) concurrently on num_threads contiguous slices of [0, n)
    // and waits for all of them. Slice bounds are computed in size_t so t * n
    // cannot overflow for corpora with many documents.
    template <class Worker>
    void parallelFor(size_t n, Worker worker) {
        const size_t slices = (size_t)num_threads;
        vector<thread> threads;
        threads.reserve(slices);
        for(size_t t = 0; t < slices; ++t) {
            threads.emplace_back(worker, t * n / slices, (t + 1) * n / slices);
        }
        for(auto& th : threads) {
            th.join();
        }
    }

public:
    // corpus_file:  text file with one document per line.
    // threads:      worker thread count; values < 1 are treated as 1.
    // target_vocab: stop once this many tokens (256 bytes + merges) exist.
    // min_freq:     minimum pair count for a merge to be considered.
    // max_batch:    maximum number of merges learned per counting pass.
    // Throws runtime_error if the corpus file cannot be opened.
    CompleteBPE(const string& corpus_file, int threads, int target_vocab, int min_freq,
                int max_batch = 100)
        : num_threads(max(1, threads)), vocab_size(target_vocab), min_frequency(min_freq),
          current_vocab_id(256), max_batch_size(max(1, max_batch)) {
        initializeVocab();
        loadCorpus(corpus_file);
    }

    // Assigns display names to the 256 byte-level tokens (ids 0-255).
    void initializeVocab() {
        for(int i = 0; i < 256; ++i) {
            if(i == 0) {
                vocab[i] = "<NULL>";
            } else if(i == 9) {
                vocab[i] = "<TAB>";
            } else if(i == 10) {
                vocab[i] = "<LF>";
            } else if(i == 13) {
                vocab[i] = "<CR>";
            } else if(i == 32) {
                vocab[i] = "<SPACE>";
            } else if(i >= 33 && i <= 126) {
                // Printable ASCII
                vocab[i] = string(1, (char)i);
            } else {
                // Non-printable bytes
                vocab[i] = "<BYTE_" + to_string(i) + ">";
            }
        }
    }

    // Reads the corpus (one document per line) and converts every line into its
    // byte values (0-255) in parallel. Throws runtime_error if the file cannot be opened.
    void loadCorpus(const string& corpus_file) {
        cout << "Loading corpus..." << endl;
        auto start = chrono::high_resolution_clock::now();

        ifstream fin(corpus_file);
        if(!fin) {
            throw runtime_error("cannot open corpus file: " + corpus_file);
        }
        vector<string> lines;
        string line;
        while(getline(fin, line)) {
            lines.push_back(move(line));  // getline clears `line` before each read
        }
        fin.close();

        const size_t N = lines.size();
        docs.resize(N);

        parallelFor(N, [&](size_t lo, size_t hi) {
            for(size_t i = lo; i < hi; ++i) {
                const string& s = lines[i];
                vector<int>& doc = docs[i];
                doc.reserve(s.size());
                // Iterate as unsigned char so bytes >= 128 map to 128-255, not negatives.
                for(unsigned char c : s) {
                    doc.push_back(c);
                }
            }
        });

        auto end = chrono::high_resolution_clock::now();
        double load_time = chrono::duration<double>(end - start).count();
        cout << "Corpus loaded: " << N << " documents in " << load_time << "s" << endl;
    }

    // Counts every adjacent token pair across all documents.
    // Returns packPair(first, second) -> number of occurrences.
    unordered_map<long long, long long> countPairs() {
        unordered_map<long long, long long> global_counts;
        global_counts.reserve(1000000);

        parallelFor(docs.size(), [&](size_t lo, size_t hi) {
            // Count privately first so the shared map is locked only once per thread.
            unordered_map<long long, long long> local_counts;
            for(size_t i = lo; i < hi; ++i) {
                const vector<int>& tokens = docs[i];
                for(size_t j = 0; j + 1 < tokens.size(); ++j) {
                    local_counts[packPair(tokens[j], tokens[j + 1])]++;
                }
            }

            lock_guard<mutex> lock(mtx);
            for(const auto& kv : local_counts) {
                global_counts[kv.first] += kv.second;
            }
        });

        return global_counts;
    }

    // Greedily picks the most frequent pairs (count >= min_frequency) that cannot
    // overlap each other in a token sequence. The batch holds at most max_batch_size
    // pairs and never more than the remaining vocabulary budget.
    vector<PairCount> getBatchMerges(const unordered_map<long long, long long>& counts) {
        vector<PairCount> candidates;
        for(const auto& kv : counts) {
            if(kv.second < min_frequency) continue;
            PairCount pc;
            pc.pair = kv.first;
            pc.count = kv.second;
            pc.first_token = (int)(kv.first >> 32);
            pc.second_token = (int)(kv.first & 0xFFFFFFFF);
            candidates.push_back(pc);
        }

        // Highest count first. Ties are broken by pair key so the learned vocabulary
        // does not depend on hash-map iteration order.
        sort(candidates.begin(), candidates.end(), [](const PairCount& a, const PairCount& b) {
            if(a.count != b.count) return a.count > b.count;
            return a.pair < b.pair;
        });

        // Without this cap the last batch could overshoot vocab_size.
        const size_t budget = (size_t)max(0, min(max_batch_size, vocab_size - current_vocab_id));

        vector<PairCount> batch;
        unordered_set<int> batch_firsts;   // first tokens of accepted pairs
        unordered_set<int> batch_seconds;  // second tokens of accepted pairs

        for(const auto& candidate : candidates) {
            if(batch.size() >= budget) break;

            // Pairs overlap (as in "x y z" for (x,y) and (y,z)) when one pair's second
            // token is another's first token. Counts for overlapping pairs are stale
            // after the first merge, so only one of them may be merged per batch.
            if(batch_seconds.count(candidate.first_token) ||
               batch_firsts.count(candidate.second_token)) {
                continue;
            }

            batch.push_back(candidate);
            batch_firsts.insert(candidate.first_token);
            batch_seconds.insert(candidate.second_token);
        }

        return batch;
    }

    // Assigns consecutive new token ids to the pairs in `batch`, records them in
    // vocab/merges, and rewrites every document with a left-to-right scan in parallel.
    void applyMerges(const vector<PairCount>& batch) {
        if(batch.empty()) return;

        unordered_map<long long, int> merge_map;
        merge_map.reserve(batch.size());
        for(const auto& m : batch) {
            const int new_token_id = current_vocab_id++;
            merge_map[m.pair] = new_token_id;

            // Track the merge and its display string (concatenation of the parts).
            merges[new_token_id] = {m.first_token, m.second_token};
            const string first_str = vocab[m.first_token];
            const string second_str = vocab[m.second_token];
            vocab[new_token_id] = first_str + second_str;
        }

        parallelFor(docs.size(), [&](size_t lo, size_t hi) {
            for(size_t i = lo; i < hi; ++i) {
                vector<int>& tokens = docs[i];
                vector<int> new_tokens;
                new_tokens.reserve(tokens.size());

                for(size_t j = 0; j < tokens.size(); ) {
                    if(j + 1 < tokens.size()) {
                        auto it = merge_map.find(packPair(tokens[j], tokens[j + 1]));
                        if(it != merge_map.end()) {
                            new_tokens.push_back(it->second);
                            j += 2;
                            continue;
                        }
                    }
                    new_tokens.push_back(tokens[j]);
                    j++;
                }
                tokens.swap(new_tokens);
            }
        });
    }

    // Runs count -> select -> merge iterations until the vocabulary reaches vocab_size
    // or no pair qualifies, prints timing statistics, then writes vocab.csv.
    void runCompleteBPE() {
        cout << "\nStarting complete BPE process..." << endl;
        auto start_total = chrono::high_resolution_clock::now();

        int iteration = 0;

        while(current_vocab_id < vocab_size) {
            iteration++;
            cout << "\n--- Iteration " << iteration << " ---" << endl;
            cout << "Current vocab size: " << current_vocab_id << endl;

            // Count pairs
            auto start_count = chrono::high_resolution_clock::now();
            auto counts = countPairs();
            auto end_count = chrono::high_resolution_clock::now();
            double count_time = chrono::duration<double>(end_count - start_count).count();

            cout << "Pair counting: " << count_time << "s (" << counts.size() << " unique pairs)" << endl;

            // Get batch merges
            auto start_batch = chrono::high_resolution_clock::now();
            auto batch = getBatchMerges(counts);
            auto end_batch = chrono::high_resolution_clock::now();
            double batch_time = chrono::duration<double>(end_batch - start_batch).count();

            if(batch.empty()) {
                cout << "No more pairs to merge (min frequency: " << min_frequency << ")" << endl;
                break;
            }

            cout << "Batch selection: " << batch_time << "s (" << batch.size() << " merges)" << endl;

            // Apply merges
            auto start_merge = chrono::high_resolution_clock::now();
            applyMerges(batch);
            auto end_merge = chrono::high_resolution_clock::now();
            double merge_time = chrono::duration<double>(end_merge - start_merge).count();

            cout << "Merge application: " << merge_time << "s" << endl;

            // Show top merges; batch ids were assigned consecutively ending at current_vocab_id - 1.
            cout << "Top merges in this batch:" << endl;
            for(int i = 0; i < min(5, (int)batch.size()); ++i) {
                cout << "  (" << batch[i].first_token << ", " << batch[i].second_token
                     << ") -> " << (current_vocab_id - batch.size() + i)
                     << " (count: " << batch[i].count << ")" << endl;
            }

            if(current_vocab_id >= vocab_size) {
                cout << "Reached target vocabulary size!" << endl;
                break;
            }
        }

        auto end_total = chrono::high_resolution_clock::now();
        double total_time = chrono::duration<double>(end_total - start_total).count();

        cout << "\n" << string(50, '=') << endl;
        cout << "BPE COMPLETE!" << endl;
        cout << "Final vocabulary size: " << current_vocab_id << endl;
        cout << "Total iterations: " << iteration << endl;
        cout << "Total time: " << total_time << "s" << endl;
        cout << "Threads used: " << num_threads << endl;
        cout << string(50, '=') << endl;

        // Save vocabulary to CSV
        saveVocabulary();
    }

    // Writes every token to vocab.csv as
    // token_id,"token_string",is_merged,first_token,second_token
    // with embedded double quotes doubled (RFC 4180 style). Byte tokens leave the
    // last two columns empty. Throws runtime_error if the file cannot be written.
    void saveVocabulary() {
        cout << "\nSaving vocabulary to vocab.csv..." << endl;
        auto start = chrono::high_resolution_clock::now();

        ofstream fout("vocab.csv");
        if(!fout) {
            throw runtime_error("cannot open vocab.csv for writing");
        }
        fout << "token_id,token_string,is_merged,first_token,second_token\n";

        for(int i = 0; i < current_vocab_id; ++i) {
            fout << i << ",\"";
            for(char c : vocab.at(i)) {
                if(c == '"') {
                    fout << "\"\"";
                } else {
                    fout << c;
                }
            }
            fout << "\"";

            auto it = merges.find(i);
            if(it != merges.end()) {
                // This is a merged token
                fout << ",true," << it->second.first << "," << it->second.second;
            } else {
                // This is a byte-level token
                fout << ",false,,";
            }
            fout << "\n";
        }

        fout.close();
        if(!fout) {
            throw runtime_error("failed while writing vocab.csv");
        }

        auto end = chrono::high_resolution_clock::now();
        double save_time = chrono::duration<double>(end - start).count();

        cout << "Vocabulary saved: " << current_vocab_id << " tokens in " << save_time << "s" << endl;
        cout << "Output file: vocab.csv" << endl;
    }
};

// Entry point: ./complete_bpe corpus.txt num_threads vocab_size min_frequency
// Trains the BPE vocabulary and writes vocab.csv. Returns 1 on invalid arguments
// or if training throws (e.g. unreadable corpus); 0 on success.
int main(int argc, char** argv) {
    if(argc < 5) {
        cerr << "Usage: ./complete_bpe corpus.txt num_threads vocab_size min_frequency" << endl;
        return 1;
    }

    string corpus_file = argv[1];
    int num_threads = 0;
    int vocab_size = 0;
    int min_frequency = 0;
    try {
        num_threads = stoi(argv[2]);
        vocab_size = stoi(argv[3]);
        min_frequency = stoi(argv[4]);
    } catch(const exception& e) {
        // stoi throws invalid_argument / out_of_range on malformed numbers.
        cerr << "Error: num_threads, vocab_size and min_frequency must be integers ("
             << e.what() << ")" << endl;
        return 1;
    }

    if(num_threads < 1) {
        cerr << "Error: num_threads must be at least 1" << endl;
        return 1;
    }

    try {
        CompleteBPE bpe(corpus_file, num_threads, vocab_size, min_frequency);
        bpe.runCompleteBPE();
    } catch(const exception& e) {
        cerr << "Error: " << e.what() << endl;
        return 1;
    }

    return 0;
}


In [None]:
# Run configuration banner, printed before the long-running steps so the log reads in order
print('=' * 60)
print('RUNNING COMPLETE BPE PROCESS ON FULL DATASET')
print('=' * 60)
print(f'Dataset: {input_csv_path}')
print(f'Threads: {num_threads}')
print(f'Target vocab size: {vocab_size}')
print(f'Min frequency: {min_frequency}')
print('=' * 60)

# Compile the complete BPE implementation
print('Compiling complete BPE C++ code...')
start_compile = time.time()

compile_result = subprocess.run([
    'g++', '-O3', '-march=native', '-std=c++17', '-pthread',
    'complete_bpe.cpp', '-o', 'complete_bpe'
], capture_output=True, text=True)

compile_time = time.time() - start_compile

if compile_result.returncode != 0:
    print('Compilation failed:')
    print(compile_result.stderr)
    print(compile_result.stdout)
    # Stop here: running a missing or stale binary would report meaningless results.
    raise RuntimeError('g++ failed to compile complete_bpe.cpp; see compiler output above')

print(f'Compilation successful! ({compile_time:.2f}s)')
print('Optimizations: -O3 -march=native')
print('Features: Multi-threading, batch merging, conflict detection')

# Run the complete BPE process
start_bpe = time.time()

result = subprocess.run([
    './complete_bpe',
    'full_corpus.txt',
    str(num_threads),
    str(vocab_size),
    str(min_frequency)
], capture_output=True, text=True)

total_bpe_time = time.time() - start_bpe

print(result.stdout)
if result.stderr:
    print('STDERR:', result.stderr)
if result.returncode != 0:
    raise RuntimeError(f'complete_bpe exited with code {result.returncode}')

# The trainer can stop before the target (no pair reaches min_frequency), so report
# the size it actually reached; fall back to the target if the line is missing.
final_vocab_match = re.search(r'Final vocabulary size: (\d+)', result.stdout)
final_vocab_size = int(final_vocab_match.group(1)) if final_vocab_match else vocab_size

# Count documents from the corpus file instead of re-reading the whole CSV.
# Binary iteration splits on b'\n' only, matching the C++ getline-based loader.
with open('full_corpus.txt', 'rb') as corpus_file:
    num_documents = sum(1 for _ in corpus_file)

print(f'\nWall clock time for complete BPE: {total_bpe_time:.2f}s')
print(f'Average time per 1000 vocab tokens: {total_bpe_time / (final_vocab_size / 1000):.2f}s')

# Performance summary and analysis
print('\n' + '=' * 60)
print('PERFORMANCE SUMMARY')
print('=' * 60)
print('Implementation: Multi-threaded C++ with batch merging')
print(f'CPU threads utilized: {num_threads}')
print(f'Dataset size: {num_documents} documents')
print(f'Target vocabulary: {vocab_size} tokens')
print(f'Final vocabulary: {final_vocab_size} tokens')
print(f'Minimum frequency threshold: {min_frequency}')
print(f'Total processing time: {total_bpe_time:.2f}s')
print(f'Throughput: {final_vocab_size / total_bpe_time:.1f} vocab tokens/second')
print('\nFeatures implemented:')
print('✓ Multi-threaded corpus loading and processing')
print('✓ Parallel pair counting across documents')
print('✓ Batch merging with conflict detection (xy,yz patterns)')
print('✓ Greedy selection by frequency')
print('✓ Parallel merge application')
print('✓ Complete BPE process until target vocabulary')
print('=' * 60)


