In [1]:
import json
from nltk import Nonterminal
from analysis2 import to_cnf, cyk_parse
from generate_pcfg import PARSERS
import math
from collections import Counter

# Function to find non-overlapping longest valid substrings
def find_non_overlapping_longest(tokens, cnf_rules, C: Nonterminal):
    """
    Returns a list of (i, j, text) for the longest non-overlapping substrings
    tokens[i:j] valid under nonterminal C.
    """
    # Gather all valid spans
    hits = []
    N = len(tokens)
    for i in range(N):
        for j in range(i+1, N+1):
            substr = tokens[i:j]
            if cyk_parse(substr, cnf_rules, C)["valid"]:
                hits.append((i, j, " ".join(substr)))
    # Sort by length descending
    hits.sort(key=lambda x: x[1] - x[0], reverse=True)
    # Greedily select non-overlapping spans
    selected = []
    occupied = set()
    for i, j, text in hits:
        if any(pos in occupied for pos in range(i, j)):
            continue
        selected.append((i, j, text))
        for pos in range(i, j):
            occupied.add(pos)

    return selected, len(selected)

  from .autonotebook import tqdm as notebook_tqdm


In [46]:
# Parameters: adjust as needed
grammar_name = "Conditionals"    # key in PARSERS
epoch        = "epoch_20.pt"       # epoch identifier in results_log.json
symbol       = "C"                 # non-terminal to analyze

# Load CNF rules and sequences
parser = PARSERS[grammar_name]
cnf_rules, nonterminals, _ = to_cnf(parser)
C = Nonterminal(symbol)

with open("../results/results_log.json") as f:
    all_results = json.load(f)
sequences = all_results["ConditionalLoops"][epoch]["generated_sequences"]

In [None]:
# Find and display longest non-overlapping C-derived substrings
total_count = 0
valid_count = 0
for seq in sequences:
    tokens = seq.split()

    spans, count = find_non_overlapping_longest(tokens, cnf_rules, C)
    total_count += count
    #print(f"Seq: '{seq}' \n--> Found {len(spans)} non-overlapping substrings:")
    for i, j, text in spans:
        valid = False
        if tokens[i-1] not in {"cond", "not", "and"}:
            valid = True
        if j<len(tokens)-1 and tokens[j+1] in {"cond", "not", "and"}:
            valid = valid and True
        if valid:
            valid_count += 1
        elif j < len(tokens) - 1:
            print(tokens[i-1], tokens[j+1])
        
        #print(f"  - span ({i}, {j}): '{text}'")
    
    
print(f"{valid_count}/{total_count} valid non-overlapping substrings in this sequence.\n")

57/57 valid non-overlapping substrings in this sequence.



In [None]:
# Add this cell after your model loading cell
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizerFast
from model import GPT, FourLayer 

grammar_name = "ConditionalLoops"  # Name of the grammar
# Load model and tokenizer
def setup_model_and_tokenizer():
    # Load tokenizer
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_file=f"../data/{grammar_name}/{grammar_name}_1000/tokenizer.json",
        bos_token="<|bos|>",
        eos_token="<|eos|>"
    )
    
    # Create vocabulary mapping
    vocab = tokenizer.get_vocab()
    id_to_token = {v: k for k, v in vocab.items()}
    
    # Load model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPT(FourLayer()).to(device)

    checkpoint_dir = f"../data/{grammar_name}/{grammar_name}_1000/FourLayer/{epoch}"
    model.load_state_dict(
        torch.load(checkpoint_dir, map_location=device)
)
    
    return model, tokenizer, vocab, id_to_token

# Get logits for tokens in a sequence
def get_sequence_token_logits(model, sequence, vocab):
    """Get logits for each token in the sequence"""
    # Convert sequence to tensor
    tokens = sequence.split()

    # Create input tensor
    input_ids = torch.tensor([[vocab.get(t, 0) for t in tokens]])

    # Get model output
    with torch.no_grad():
        logits, _ = model(input_ids)
    
    # Get log probabilities
    log_probs = F.log_softmax(logits.squeeze(1), dim=-1)
    
    # Extract logits for observed tokens
    token_logits = []
    for i in range(len(tokens)-1):  # -1 because we predict the next token
        next_token = tokens[i+1]
        next_token_id = vocab.get(next_token, 0)
        # Get probability that model assigns to the actual next token
        token_logit = log_probs[0, i, next_token_id].item()
        token_logits.append(token_logit)
    
    return token_logits

# Get logits for a specific subsequence within a full sequence
def get_subsequence_logits_in_context(model, full_sequence, start_idx, end_idx, vocab):
    """Get logits for tokens in a subsequence within the context of the full sequence"""
    # Get full sequence logits
    print("full_sequence", full_sequence)
    full_logits = get_sequence_token_logits(model, full_sequence, vocab)
    print("full_logits", full_logits)
    print("length of full_logits", len(full_logits))
    print("start_idx", start_idx, "end_idx", end_idx)
    
    # Extract logits for the subsequence (offset by 1 since we predict next token)
    subseq_logits = full_logits[start_idx:end_idx]

    print(subseq_logits)
    
    # Calculate total log probability for the subsequence
    total_log_prob = sum(subseq_logits)
    
    return {
        "token_logits": subseq_logits,
        "log_prob": total_log_prob,
        "prob": math.exp(total_log_prob)
    }

# Setup model
model, tokenizer, vocab, id_to_token = setup_model_and_tokenizer()
print("Model loaded successfully")
    


number of parameters: 0.86M
Model loaded successfully


In [49]:
import math
from collections import Counter
from nltk import ViterbiParser

log_eps = math.log(1e-12)

def seq_log_pcfg(parser: ViterbiParser, text: str) -> float:
    toks   = text.split()
    parses = list(parser.parse(toks))
    return math.log(parses[0].prob()) if parses else log_eps


# 1) Gather all longest non-overlapping C‐derived substrings
diffs =0
selected_texts = []
for seq in sequences:
    tokens = seq.split()
    spans, count = find_non_overlapping_longest(tokens, cnf_rules, C)
    # spans is a list of (i,j,text)
    total_count += count
    #print(f"Seq: '{seq}' \n--> Found {len(spans)} non-overlapping substrings:")
    for i, j, text in spans:
        valid = False
        if tokens[i-1] not in {"cond", "not", "and"}:
            valid = True
        if j<len(tokens)-1 and tokens[j+1] in {"cond", "not", "and"}:
            valid = valid and True
        if valid:
            valid_count += 1
        elif j < len(tokens) - 1:
            print(tokens[i-1], tokens[j+1])
    
        lp_pcfg = seq_log_pcfg(parser, text)
        print(lp_pcfg)
        lp_neural = get_subsequence_logits_in_context(model, seq, i, j, vocab)

        diffs += abs(lp_neural["log_prob"] - lp_pcfg)
    selected_texts.extend(text for (_,_,text) in spans)

print(diffs)    


-4.1588830833596715
full_sequence if cond and not cond then if cond then if not cond then action else t t z else z
full_logits [-1.1778427362442017, -0.42783209681510925, -1.0726779699325562, -0.9511851072311401, -0.44496414065361023, -1.574199914932251, -0.8865269422531128, -0.49690911173820496, -0.7577444314956665, -0.9320438504219055, -0.9475969672203064, -0.4323291480541229, -0.665740966796875, -0.8581465482711792, -0.008859374560415745, -0.56456059217453, -0.541461169719696, -0.8143876791000366, -0.09167063236236572, -0.9104676246643066]
length of full_logits 20
start_idx 1 end_idx 5
[-0.42783209681510925, -1.0726779699325562, -0.9511851072311401, -0.44496414065361023]
-2.0794415416798357
full_sequence if cond and not cond then if cond then if not cond then action else t t z else z
full_logits [-2.414663076400757, -0.4646325707435608, -1.0828773975372314, -1.003840446472168, -0.45671403408050537, -1.2790613174438477, -0.9539477229118347, -0.46990036964416504, -0.48317593336105347,