In [8]:
import torch
from generate_sequences import generate_bigram_sequences_using_table
from tables import *

In [9]:
def sample_bigram_seqs_to_convergence(
    bigram_table: torch.Tensor,
    max_length: int = 128,
    batch_size: int = 256
):
    device = bigram_table.device  # Ensure everything stays on the same device
    counts = torch.zeros(len(bigram_table), device=device, dtype=torch.float32)
    n = bigram_table.shape[0]
    p = torch.full((n,), 1 / n, device=device, dtype=torch.float32)

    for i in range(1000 * n):
        counts_sum = counts.sum().clamp(min=1e-8)  # Avoid division by zero
        p_next = counts / counts_sum

        if i % 50 == 0:  # Compute norm less frequently for speedup
            norm = torch.norm(p_next - p)
            print(f"Iteration {i}, Norm: {norm.item():.6f}")
            if norm < 1e-8:
                break
        p = p_next

        # Generate bigram sequences in larger batches to better utilize GPU
        seqs = generate_bigram_sequences_using_table(batch_size, max_length, bigram_table)

        # Ensure seqs is on the same device before calling `.unique`
        seqs = seqs.to(device)
        uc, uc_counts = seqs.unique(return_counts=True)

        # GPU-optimized scatter_add_
        counts.scatter_add_(0, uc, uc_counts.to(counts.dtype))

    return p


In [10]:
bigram_table = create_normal_bigram_table(10000).to('cuda')

In [11]:
sample_bigram_seqs_to_convergence(bigram_table)

Iteration 0, Norm: 0.010000


KeyboardInterrupt: 

In [None]:
sample_bigram_seqs_to_convergence(bigram_table)

nan
nan
0.00711576733738184
0.005210084840655327
0.00290660560131073
0.0020947977900505066
0.0020090974867343903
0.0012562324991449714
0.001393466955050826
0.0012117131846025586
0.001045257318764925
0.0008443673723377287
0.0010063119698315859
0.0006478465511463583
0.0005971979117020965
0.0010532723972573876
0.0007574147311970592
0.00043739587999880314
0.0005724529619328678
0.0008251378312706947
0.0005973120569251478
0.0004623265704140067
0.0005956720560789108
0.0004878173931501806
0.00043426259071566164
0.0005230871611274779
0.0004728665226139128
0.00047515128972008824
0.00050107337301597
0.00032228868803940713
0.00045693537686020136
0.0005689171375706792
0.0004072487645316869
0.00027504703029990196
0.0002496831875760108
0.0002855987404473126
0.00023584833252243698
0.00035541149554774165
0.00018394945072941482
0.00034586648689582944
0.00016097835032269359
0.0003276806965004653
0.00017654357361607254
0.00022352708037942648
0.0002697829040698707
0.0002921372069977224
0.000151061918586492

tensor([0.0842, 0.0990, 0.1115, 0.1037, 0.0893, 0.1059, 0.0983, 0.1054, 0.1096,
        0.0931])

tensor([0.0842, 0.0990, 0.1115, 0.1037, 0.0893, 0.1059, 0.0983, 0.1054, 0.1096,
        0.0931])
tensor([0.0839, 0.0991, 0.1112, 0.1040, 0.0891, 0.1056, 0.0985, 0.1060, 0.1097,
        0.0929])

In [6]:
def get_stationary_distribution(bigram_table: torch.Tensor) -> torch.Tensor:
    """
    Get the stationary distribution of a bigram table.
    
    Args:
        bigram_table: `torch.Tensor` - the bigram table.
        
    Returns:
        `torch.Tensor` - the stationary distribution.
    """
    
    n = bigram_table.shape[0]
    p = torch.ones(n) / n
    for _ in range(10 * n):
        p_next = p @ bigram_table
        if torch.norm(p_next - p) < 1e-8:
            break
        p = p_next
    return p

In [7]:
def calculate_entropy_bigram(bigram_table: torch.Tensor) -> float:
    """
    Calculate entropy of a bigram table.
    
    Args:
        bigram_table: `torch.Tensor` - the bigram table.
        
    Returns:
        `float` - the entropy of the bigram table.
    """
    
    p = get_stationary_distribution(bigram_table)
    
    joint_probs = p[:, None] * bigram_table
    
    try:
        t = -1 * (joint_probs * joint_probs.log()).sum()
    except:
        t = -1 * (joint_probs * (joint_probs + 1e-10).log()).sum()
        
    return t.item()

In [8]:
calculate_entropy_bigram(bigram_table)

4.556020736694336