In [3]:
import torch
import torch.nn.functional as F

In [4]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
# Checkpoint directory from which both the tokenizer and the masked-LM are
# loaded.
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# only run where this directory exists.
loadstr = '/home/bstadt/root/tlm/models/tlm-2025-08-05_16-42-11/checkpoint-10500/'
tokenizer = AutoTokenizer.from_pretrained(loadstr)
model = AutoModelForMaskedLM.from_pretrained(loadstr)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def get_top_fills(phrase, model, tokenizer, top_k=5):
    """
    Get the top k most likely fills for a phrase with mask tokens.

    Args:
        phrase: String containing one or more mask tokens to fill.
        model: Masked language model. It is called as ``model(**inputs)`` and
            must return an object exposing ``.logits``.
        tokenizer: Tokenizer providing ``encode_plus``, ``decode`` and
            ``mask_token_id``.
        top_k: Number of top fills to return (default 5). Values larger than
            the vocabulary size are clamped to the vocabulary size.

    Returns:
        List of ``(fill_text, probability)`` tuples sorted by descending
        probability. With a single mask, ``fill_text`` is the decoded token.
        With several masks, ``fill_text`` is the per-mask tokens joined by a
        single space and ``probability`` is the product of the per-position
        probabilities (positions are treated as independent, so this is an
        approximation of the true joint top-k, not a beam search).

    Raises:
        ValueError: If the phrase contains no mask token.
    """
    from itertools import product

    model.eval()

    # Tokenize without special tokens, then move the tensors to the model's
    # device (when the model exposes one) so a GPU-resident model works too.
    inputs = tokenizer.encode_plus(phrase, add_special_tokens=False, return_tensors='pt')
    model_device = getattr(model, 'device', None)
    if model_device is not None:
        inputs = {
            name: (value.to(model_device) if hasattr(value, 'to') else value)
            for name, value in inputs.items()
        }

    # torch.where on the (1, seq_len) id matrix returns (row, col) indices;
    # the column indices are the mask positions within the sequence.
    mask_locs = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
    if len(mask_locs) == 0:
        raise ValueError("No mask tokens found in phrase")

    with torch.no_grad():
        logits = model(**inputs).logits[0]  # Remove batch dimension
        mask_probs = F.softmax(logits[mask_locs], dim=-1)  # (n_masks, vocab_size)

    # torch.topk raises if k exceeds the size of the dimension being ranked.
    k = min(top_k, mask_probs.shape[-1])

    # Top-k candidate (token_text, probability) pairs for every mask position.
    candidates_per_mask = []
    for position_probs in mask_probs:
        top_probs, top_ids = torch.topk(position_probs, k)
        candidates_per_mask.append([
            (tokenizer.decode([token_id]), prob)
            for prob, token_id in zip(top_probs.tolist(), top_ids.tolist())
        ])

    # Combine one candidate per mask. With a single mask this reduces to the
    # per-position top-k list itself (1.0 * p == p, and the order is kept by
    # the stable sort below).
    # NOTE: the number of combinations is k ** n_masks.
    results = []
    for combo in product(*candidates_per_mask):
        joint_prob = 1.0
        for _, prob in combo:
            joint_prob *= prob
        results.append((' '.join(token for token, _ in combo), joint_prob))

    # Sort by probability and return top k
    results.sort(key=lambda item: item[1], reverse=True)
    return results[:top_k]


In [None]:
# Years covered by the analysis and the token spelling used for each of them.
years = [*range(1990, 2020)]
year_fills = [f'[YEAR:{year}]' for year in years]
# Element 1 of each encoding is taken as the id of the year token.
# NOTE(review): presumably element 0 is a special token the tokenizer
# prepends and each year string maps to a single token - confirm.
year_fill_token_ids = [tokenizer.encode(year_fill)[1] for year_fill in year_fills]
device = model.device
def lyear(phrase, model, tokenizer):
    """
    Score every year token as the fill for a mask prepended to ``phrase``.

    The text ``'[MASK] ' + phrase`` is encoded without special tokens, the
    model is run once, and the logits at sequence position 0 are restricted to
    the ids in the module-level ``year_fill_token_ids``.

    Args:
        phrase: Text to condition on.
        model: Masked language model exposing ``.device`` and returning an
            object with ``.logits`` when called with ``input_ids=...``.
        tokenizer: Tokenizer providing ``encode``.

    Returns:
        Tuple ``(years, year_sublogits, year_subprobs)``: the module-level
        ``years`` list, the position-0 logits of the year tokens, and their
        softmax (a distribution over the year tokens only).
    """
    model.eval()  # same inference-mode guarantee as get_top_fills

    # NOTE(review): assumes '[MASK]' encodes to exactly one token at position 0
    # for this tokenizer - confirm.
    year_template = '[MASK] ' + phrase
    input_ids = tokenizer.encode(year_template, add_special_tokens=False, return_tensors='pt')
    with torch.no_grad():
        # Follow the device of the model that was passed in, rather than a
        # module-level `device` captured from some other model object.
        outputs = model(input_ids=input_ids.to(model.device))
        first_position_logits = outputs.logits[0][0]
        year_sublogits = first_position_logits[year_fill_token_ids]
        year_subprobs = F.softmax(year_sublogits, dim=0)

    return years, year_sublogits, year_subprobs

In [115]:
import os
import urllib.request

import matplotlib.font_manager as fm
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

# Download the Junicode font once and reuse the cached copy on later runs.
# NOTE(review): the file is fetched over plain HTTP with no integrity check;
# consider HTTPS or a pinned local copy.
font_url = "http://calcifercomputing.com/fonts/junicode/TTF/Junicode-Regular.ttf"
font_path = "/tmp/Junicode-Regular.ttf"
# Skip the network round-trip when a non-empty copy is already on disk, so
# re-running the cell does not require network access.
if not (os.path.isfile(font_path) and os.path.getsize(font_path) > 0):
    urllib.request.urlretrieve(font_url, font_path)
junicode_prop = fm.FontProperties(fname=font_path)


def visuzlize_process(target, top_k=10, plot_max=5, nucleus=None, saveas=None):
    """
    Plot how the preferred fills for a masked phrase shift across years.

    Steps:
      1. If ``nucleus`` is not given, take the ``top_k`` fills of ``target``
         from ``get_top_fills``.
      2. For every fill, divide its year distribution (``lyear(fill, ...)``)
         by the year distribution of ``target`` itself, then normalise these
         ratios across fills separately for every year.
      3. Keep the fills whose normalised series has a significant linear
         trend over the years (``|slope| >= 1e-6`` and ``p <= 0.01``).
      4. Show the (at most ``plot_max``) most significant ones as a stacked
         area chart, optionally saving it as HTML.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        target: Phrase containing a mask token.
        top_k: Number of candidate fills requested when ``nucleus`` is None.
        plot_max: Maximum number of significant traces to plot; when more are
            significant, those with the smallest p-values are kept. ``None``
            disables the cap.
        nucleus: Optional explicit list of fill strings to analyse.
        saveas: Optional path; when truthy the figure is written there as HTML.

    Returns:
        None. The figure is displayed (and optionally saved), or a message is
        printed when no fill shows a significant trend.
    """
    # Imported once at function scope instead of on every loop iteration.
    from scipy import stats
    import plotly.graph_objects as go

    slope_tolerance = 1e-6   # slopes below this magnitude count as flat
    p_value_threshold = 0.01  # trends with a larger p-value are not significant

    if nucleus is None:
        results = get_top_fills(target, model, tokenizer, top_k=top_k)
        nucleus = [result[0] for result in results]
    print('building posterior for nucleus: ', nucleus)

    if len(nucleus) == 0:
        # Nothing to analyse; np.stack below would fail on an empty list.
        print("No significant temporal trends found.")
        return

    years, _, template_year_subprobs = lyear(target, model, tokenizer)
    bayes_by_fill = []
    for fill in tqdm(nucleus):
        # NOTE(review): the year distribution is computed from the bare fill
        # string, not from `target` with the mask replaced by the fill -
        # confirm this is the intended conditioning.
        _, _, fill_probs = lyear(fill, model, tokenizer)
        bayes_factors = fill_probs / template_year_subprobs
        bayes_by_fill.append(bayes_factors.detach().cpu().numpy())

    # Rows are fills, columns are years; normalise across fills per year and
    # transpose so that posteriors[:, i] is the time series of fill i.
    bayes_matrix = np.stack(bayes_by_fill)
    posteriors = (bayes_matrix / bayes_matrix.sum(axis=0, keepdims=True)).T

    # Collect (p_value, label, series) for every fill with a significant slope.
    significant = []
    for column, word in enumerate(nucleus):
        time_series = posteriors[:, column]
        slope, _, _, p_value, _ = stats.linregress(years, time_series)
        is_stationary = abs(slope) < slope_tolerance or p_value > p_value_threshold
        if not is_stationary:
            significant.append((p_value, word, time_series))

    # Keep only the `plot_max` most significant traces (smallest p-values).
    # The list is left in nucleus order when it already fits.
    if plot_max is not None and len(significant) > plot_max:
        significant = sorted(significant, key=lambda entry: entry[0])[:plot_max]

    if not significant:
        print("No significant temporal trends found.")
        return

    good_colors = ['#EA5526', '#4462BD', '#51915B', '#8064A2', '#E5B700']

    # Plot significant traces using plotly as a stacked area chart
    fig = go.Figure()
    for trace_index, (_, label, trace) in enumerate(significant):
        color = good_colors[trace_index % len(good_colors)]
        # Convert hex to rgba
        hex_color = color.lstrip('#')
        r, g, b = (int(hex_color[offset:offset + 2], 16) for offset in (0, 2, 4))

        fig.add_trace(go.Scatter(
            x=years,
            y=trace,
            mode='lines',
            stackgroup='one',
            name=f'{label}',
            line=dict(width=0.5, color=color),
            fillcolor=f'rgba({r}, {g}, {b}, 1)',
            hovertemplate=f'<b>{label}</b>  %{{y:.3f}}<extra></extra>'
        ))

    fig.update_layout(
        title=dict(
            text=f'Significant Temporal Trends for "{target}"',
            font=dict(family="Junicode", size=24, color="black"),
            x=0.5
        ),
        xaxis=dict(
            title=dict(text='Year', font=dict(family="Junicode", size=20, color="black")),
            tickfont=dict(family="Geist Mono", size=12, color="black"),
            linecolor='black',
            linewidth=1,
            showgrid=False,
            zeroline=False,
        ),
        yaxis=dict(
            title=dict(text='Posterior Probability', font=dict(family="Junicode", size=20, color="black")),
            tickfont=dict(family="Geist Mono", size=12, color="black"),
            linecolor='black',
            linewidth=1,
            showgrid=False,
            zeroline=False,
        ),
        legend=dict(
            font=dict(family="Junicode", size=16, color="black")
        ),
        hoverlabel=dict(
            font=dict(family="Geist Mono", size=12)
        ),
        hovermode='x',
        width=1000,
        height=500,
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)',
        showlegend=True
    )
    config = {'displayModeBar': False}
    fig.show(config=config)
    if saveas:
        fig.write_html(saveas, config=config)


# Correctly spelled alias; the original name is kept for existing callers.
visualize_process = visuzlize_process

In [116]:
# Trend plot for the 25 most likely fills; also saved as a standalone HTML file.
visuzlize_process(
    'The United States is a [MASK] country.',
    top_k=25,
    saveas='us_country.html',
)

building posterior for nucleus:  [' big', ' great', ' good', ' large', ' major', ' wonderful', ' nice', ' powerful', ' huge', ' famous', ' free', ' small', ' strong', ' foreign', ' popular', ' new', ' rich', ' beautiful', ' federal', ' young', ' democratic', ' wealthy', ' separate', ' western', ' British']


100%|██████████| 25/25 [00:03<00:00,  7.60it/s]


In [111]:
# Trend plot for the 25 most likely fills; also saved as a standalone HTML file.
visuzlize_process(
    'I am generally [MASK] about the future.',
    top_k=25,
    saveas='future.html',
)

building posterior for nucleus:  [' optimistic', ' concerned', ' hopeful', ' uncertain', ' positive', ' confident', ' enthusiastic', ' excited', ' worried', ' anxious', ' cautious', ' unsure', ' knowledgeable', ' skeptical', ' curious', ' thinking', ' cynical', ' happy', ' optimism', ' nervous', ' negative', ' informed', ' good', ' comfortable', ' certain']


100%|██████████| 25/25 [00:03<00:00,  7.91it/s]


In [112]:
# Saved under its own file name: 'future.html' is already the output of the
# "about the future" cell, and reusing it would overwrite one plot with the other.
visuzlize_process(
    'I am generally [MASK] about my future.',
    top_k=25,
    saveas='my_future.html',
)


building posterior for nucleus:  [' optimistic', ' concerned', ' positive', ' confident', ' hopeful', ' enthusiastic', ' uncertain', ' excited', ' worried', ' unsure', ' anxious', 'istic', ' happy', ' cynical', ' satisfied', ' good', ' pleased', ' encouraged', ' cautious', ' certain', ' secure', ' cheerful', ' skeptical', ' thinking', ' indifferent']


100%|██████████| 25/25 [00:03<00:00,  6.96it/s]
