In [1]:
import os
import torch

from collections import defaultdict, Counter, OrderedDict
import pandas as pd

In [2]:
from transformers import RobertaTokenizer, RobertaForMaskedLM

In [3]:
def load_vocabulary(file_path):
  
    with open(file_path, 'r') as file:
        # Read all lines, strip whitespace, and filter out empty lines
        vocabulary = [line.strip() for line in file if line.strip()]
    return vocabulary 



def load_parquet(file_path="temporary_words.parquet"):
   
    try:
        df = pd.read_parquet(file_path)
        if "words" in df.columns:
            return df["words"].tolist()
        else:
            raise ValueError("The Parquet file does not contain a 'words' column.")
    except Exception as e:
        print(f"Error loading Parquet file: {e}")
        return None



In [4]:
def filter_out_words_with_chars_parquet(vocab, chars, debug=True, output_file="temporary_words.parquet"):

    filtered_vocab = [word for word in vocab if not any(char in word for char in chars)]
    df = pd.DataFrame(filtered_vocab, columns=["words"])
    df.to_parquet(output_file, index=False)
    if debug:
        print(f"Filtered vocabulary saved to {output_file}")


def filter_in_words_with_chars_parquet(vocab, chars, debug=True, output_file="temporary_words.parquet"):
  
    filtered_vocab = [word for word in vocab if any(char in word for char in chars)]
    df = pd.DataFrame(filtered_vocab, columns=["words"])
    df.to_parquet(output_file, index=False)
    if debug:
        print(f"Filtered vocabulary saved to {output_file}")


def updateGivenWord(given_word, answer, guess):
    filled_dashes=0
    
    for str_idx, char in enumerate(answer):
        if guess == char:
            given_word[str_idx] = guess
            filled_dashes+=1

    return given_word, filled_dashes


In [5]:
def character_frequency_by_length(vocabulary):
   
    # Group words by their lengths
    length_groups = defaultdict(list)
    for word in vocabulary:
        length_groups[len(word)].append(word)
    
    # Compute character frequencies for each length group
    freq_by_length = {}
    for length, words in length_groups.items():
        # Flatten all characters from words of this length
        all_chars = ''.join(words)
        # Compute frequency of each character
        char_counts = Counter(all_chars)
        # Sort the frequencies by count (descending), then by character (alphabetical) for ties
        sorted_counts = OrderedDict(sorted(char_counts.items(), key=lambda x: (-x[1], x[0])))
        # Store the sorted frequencies
        freq_by_length[length] = sorted_counts
    
    return freq_by_length



def printFreqCounts(freq_counts):
    for length, freqs in sorted(freq_counts.items()):
        print(f"Length {length}: {dict(freqs)}")

In [6]:
def character_word_coverage_by_length(vocabulary):
    
    # Group words by their lengths
    length_groups = defaultdict(list)
    for word in vocabulary:
        length_groups[len(word)].append(word)
    
    # Compute word coverage for each character by length group
    coverage_by_length = {}
    for length, words in length_groups.items():
        # Create a set of characters per word and compute coverage
        char_coverage = defaultdict(int)
        for word in words:
            for char in set(word):  # Use `set` to avoid double-counting characters in the same word
                char_coverage[char] += 1
        
        # Sort the coverage dictionary by frequency (descending), then alphabetically
        sorted_coverage = OrderedDict(
            sorted(char_coverage.items(), key=lambda x: (-x[1], x[0]))
        )
        coverage_by_length[length] = sorted_coverage
    
    return coverage_by_length



def printWordCoverage(coverage_counts):
    for length, coverage in sorted(coverage_counts.items()):
        print(f"Length {length}: {dict(coverage)}")

In [7]:
def getTopChoice(freq, coverage, guessed_words, TOP_K=5):

   
    filtered_coverage = [char for char in coverage if char not in guessed_words]
    filtered_freq = [char for char in freq if char not in guessed_words]
    
    if len(filtered_coverage)==0 and len(filtered_freq)==0:
        return None
    
    elif len(filtered_coverage)>0 and len(filtered_freq)>0:
        top_coverage_chars = filtered_coverage[:min(TOP_K, len(freq))]
        top_freq_chars = filtered_freq[:min(TOP_K, len(coverage))]

        # Find the intersection while maintaining the order of `top_coverage_chars`
        common_chars = [char for char in top_coverage_chars if char in top_freq_chars]
        
        # If there's an intersection, choose the top coverage character from it
        if common_chars:
            return common_chars[0]  # The first common character in `top_coverage_chars`

        # If no intersection, return the most covered character
        return top_coverage_chars[0]
    
    elif len(filtered_coverage)>0:
        return filtered_coverage[0]
    else: 
        return filtered_freq[0]

In [8]:
# Initialize tokenizer 
class CharLevelTokenizer:
    def __init__(self, vocab):
        self.char_vocab = vocab
        self.char_to_id = {char: idx for idx, char in enumerate(vocab)}
        self.id_to_char = {idx: char for idx, char in enumerate(vocab)}

    def encode(self, text):
        return [self.char_to_id[char] if char in self.char_to_id else self.char_to_id["_"] for char in text]

    def decode(self, token_ids):
        return "".join([self.id_to_char[token_id] for token_id in token_ids])


def predict_masked_characters(input_sequence, tokenizer, model, mask_token_id):
    # Convert the input sequence to token IDs using the tokenizer
    input_ids = tokenizer.encode(input_sequence)

    # Convert input_ids to tensor and move it to the right device
    input_tensor = torch.tensor([input_ids])

    # Run the model to predict masked token positions
    with torch.no_grad():
        outputs = model(input_tensor)
        logits = outputs.logits

    # Extract the predicted token IDs for each masked position
    predicted_ids = torch.argmax(logits, dim=-1).squeeze().cpu().numpy()

    # Convert predicted IDs to characters using the tokenizer
    predicted_sequence = tokenizer.decode(predicted_ids)
    return predicted_sequence


def loadFinetunedModel(model_path = None):
      
    if model_path: 
        model = RobertaForMaskedLM.from_pretrained(model_path)

        # Set the model to evaluation mode for inference
        model.eval()
        
        return model

    else: 
        print("ENTER VALID MODEL PATH!!")
        return None



In [None]:
# Custom tokenizer to tokenize by lowercase characters only
char_vocab = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "[MASK]", "[PAD]", "_"]


# Initialize the custom lowercase character-level tokenizer
char_tokenizer = CharLevelTokenizer(char_vocab)
mask_token_id = char_tokenizer.char_to_id["[MASK]"]
pad_token_id = char_tokenizer.char_to_id["[PAD]"]


## Load finetuned-trained RoBERTa model for Masked Language Modeling (MLM)
# MODEL_PATH = './model/fine_tuned_roberta_char_level_uncased'
# model = loadFinetunedModel(MODEL_PATH)


MODEL_PATH = "facebook/bart-large" 
# MODEL_PATH="roberta-base"
# model = RobertaForMaskedLM.from_pretrained(MODEL_PATH )


# Resize the model's token embeddings to match the character-level vocab size
model.resize_token_embeddings(len(char_tokenizer.char_vocab))  # Resize for lowercase char-level tokens

config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

You are using a model of type bart to instantiate a model of type roberta. This is not supported for all configurations of models and can yield errors.


pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

KeyboardInterrupt: 

In [None]:
# vocabulary = ["a", "an", "ant", "bat", "cat", "me", "do", "see", "tree", "hangman", "meith"]

CHECKPOINT = 0
vocabulary = load_vocabulary("words_250000_train.txt")
vocabulary = vocabulary[CHECKPOINT:]

In [None]:
freq_counts = character_frequency_by_length(vocabulary)
printFreqCounts(freq_counts)

Length 1: {'c': 1, 'd': 1, 'e': 1, 'f': 1, 'g': 1, 'h': 1, 'i': 1, 'j': 1, 'k': 1, 'm': 1, 'o': 1, 'p': 1, 'r': 1, 's': 1, 'v': 1, 'w': 1, 'y': 1}
Length 2: {'c': 26, 'w': 26, 'd': 25, 'h': 25, 'r': 25, 'e': 24, 'i': 24, 'a': 22, 'o': 22, 'u': 22, 'f': 21, 'l': 21, 'n': 21, 's': 20, 't': 20, 'x': 20, 'p': 19, 'v': 19, 'b': 17, 'k': 17, 'm': 17, 'q': 16, 'g': 15, 'j': 15, 'y': 15, 'z': 14}
Length 3: {'a': 533, 's': 448, 'c': 397, 'e': 383, 'o': 364, 'i': 351, 't': 349, 'm': 347, 'd': 337, 'p': 333, 'r': 318, 'l': 287, 'b': 268, 'n': 247, 'u': 244, 'f': 223, 'g': 220, 'h': 187, 'v': 146, 'w': 146, 'y': 132, 'k': 118, 'x': 73, 'j': 66, 'q': 46, 'z': 40}
Length 4: {'a': 2188, 'e': 1860, 'o': 1519, 's': 1491, 'i': 1395, 'r': 1202, 'l': 1109, 't': 1017, 'n': 1000, 'u': 867, 'd': 849, 'm': 809, 'c': 771, 'p': 700, 'b': 609, 'h': 587, 'g': 548, 'y': 542, 'k': 535, 'f': 451, 'w': 356, 'v': 263, 'z': 180, 'j': 163, 'x': 101, 'q': 36}
Length 5: {'a': 6291, 'e': 5578, 's': 4178, 'o': 3775, 'r': 37

In [None]:
coverage_by_length = character_word_coverage_by_length(vocabulary)
printWordCoverage(coverage_by_length)

Length 1: {'c': 1, 'd': 1, 'e': 1, 'f': 1, 'g': 1, 'h': 1, 'i': 1, 'j': 1, 'k': 1, 'm': 1, 'o': 1, 'p': 1, 'r': 1, 's': 1, 'v': 1, 'w': 1, 'y': 1}
Length 2: {'c': 25, 'd': 25, 'w': 25, 'h': 24, 'i': 24, 'r': 24, 'e': 23, 'a': 22, 'f': 21, 'l': 21, 'n': 21, 'o': 21, 'u': 21, 's': 20, 't': 20, 'p': 19, 'v': 19, 'x': 19, 'k': 17, 'b': 16, 'm': 16, 'q': 16, 'g': 14, 'j': 14, 'y': 14, 'z': 13}
Length 3: {'a': 508, 's': 423, 'c': 372, 'e': 366, 'i': 342, 'o': 341, 't': 337, 'm': 328, 'd': 321, 'p': 315, 'r': 311, 'l': 274, 'b': 261, 'n': 239, 'u': 237, 'f': 216, 'g': 215, 'h': 182, 'v': 145, 'w': 143, 'y': 131, 'k': 116, 'x': 69, 'j': 66, 'q': 44, 'z': 38}
Length 4: {'a': 1985, 'e': 1672, 's': 1367, 'o': 1363, 'i': 1323, 'r': 1162, 'l': 1028, 'n': 957, 't': 942, 'u': 836, 'd': 808, 'm': 767, 'c': 711, 'p': 654, 'b': 579, 'h': 569, 'y': 535, 'g': 525, 'k': 515, 'f': 415, 'w': 351, 'v': 255, 'z': 167, 'j': 162, 'x': 98, 'q': 36}
Length 5: {'a': 5359, 'e': 4800, 's': 3782, 'r': 3497, 'i': 3435,

## UPDATE AFTER FILTERING

In [None]:
# # filter out `a`

# char=['a']
# filter_out_words_with_chars_parquet(vocabulary, char, output_file="temporary_words.parquet")
# vocabulary = load_parquet("temporary_words.parquet")

# freq_counts = character_frequency_by_length(vocabulary)
# printFreqCounts(freq_counts)

## LOGIC

In [None]:
def hangman(given_word, answer, NUM_TRAILS=6, temp_file_path="temporary_words.parquet", debug=True):
    
    # 1. Load vocab
    vocabulary = load_vocabulary("words_250000_train.txt")
    N = len(given_word)
    NUM_DASHES = N
    guessed_words=set()

    if debug:
            print(f"\n\n******\n=> GIVEN WORD : {given_word}\nTOTAL ATTEMPTS: {NUM_TRAILS}")
    
    while NUM_TRAILS>2:
        
        if debug:
            print(f"\n\n******")
        #     print(f"\n\n******\n=> GIVEN WORD : {given_word}\nATTEMPTS LEFT: {NUM_TRAILS}")

        # Generate meta-data: make list of length wise freq of words (for starting character)
        freq_counts = character_frequency_by_length(vocabulary)
        if debug:
            print(f"\nFREQ COUNTS:{freq_counts[N]}")
            # printFreqCounts(freq_counts)
        
        coverage_by_length = character_word_coverage_by_length(vocabulary)
        if debug:
            print(f"\nCOVERAGE COUNTS: {coverage_by_length[N]}")
            # printFreqCounts(coverage_by_length)


        # make the top selection list for that length: 
        freq_counts = list(freq_counts[N])
        coverage_by_length = list(coverage_by_length[N])
        
        if len(freq_counts)>0 or len(coverage_by_length)>0: 
            top_choice = getTopChoice(freq_counts , coverage_by_length, guessed_words)
           
        else: 
            break ## need to stop topChoice if theres nothing left for that length##

        
        if debug:
            print(f"\n\n=> CURRENT GUESS: {top_choice}")
        

        if top_choice in answer:
            # if hit then only consider all words with that char and re-compute freq and coverage
            given_word, filled_dashes = updateGivenWord(given_word, answer, top_choice)
            filter_in_words_with_chars_parquet(vocabulary, [top_choice], debug)
            NUM_DASHES -= filled_dashes

            if debug:
                print(f"HIT!!\nUPDATED WORD: {given_word}\nATTEMPTS LEFT: {NUM_TRAILS}") 

            if NUM_DASHES == 0:
                print(f"\n\n*** CRACKED WORD ***\nGUESSED WORD : {''.join(given_word)} \nANSWER: {answer}")
                # return ''.join(given_word)
                return 1
        
     
        else: 
            # else remove all the words with that char and re-compute the freq and coverage
            filter_out_words_with_chars_parquet(vocabulary, [top_choice],debug)
            NUM_TRAILS-=1

            if debug:
                print(f"MISS!!\nUPDATED WORD: {given_word}\nATTEMPTS LEFT: {NUM_TRAILS}") 

        guessed_words.add(top_choice)
        vocabulary = load_parquet(temp_file_path)

    # once we get the 1-2 chars then feed into the model 
    for _ in range(NUM_TRAILS):
        
        # need to update it so it doesnot predict same word everytime.  #
        
        pred_word = predict_masked_characters(given_word, char_tokenizer, model, mask_token_id)
        NUM_TRAILS-=1
        
        if pred_word == answer:
            print(f"\n\n***CRACKED WORD***\nGUESSED WORD : {''.join(given_word)} \nANSWER: {answer}")
            # return ''.join(given_word)
            return 1 
        
        elif debug:
            print(f"MISS!!\nUPDATED WORD: {pred_word}\nATTEMPTS LEFT: {NUM_TRAILS}") 
            # update the list of guessed chars

    if debug: 
        print("\n\nTRAILS EXHUASTED!!")
    return 0
    



In [None]:
answer = "hagaman"
NUM_TRAILS = 6
given_word = ['_']*len(answer)
debug=True
temp_file_path="temporary_words.parquet"


hangman(given_word, answer, NUM_TRAILS, temp_file_path, debug)




******
=> GIVEN WORD : ['_', '_', '_', '_', '_', '_', '_']
TOTAL ATTEMPTS: 6


******

FREQ COUNTS:OrderedDict([('e', 20564), ('a', 17455), ('i', 13867), ('r', 13151), ('s', 13029), ('n', 11601), ('o', 11335), ('l', 10641), ('t', 10211), ('d', 6959), ('u', 6944), ('c', 6665), ('m', 5480), ('h', 4991), ('p', 4961), ('g', 4960), ('b', 4173), ('y', 3338), ('k', 2691), ('f', 2551), ('w', 2191), ('v', 1586), ('z', 829), ('j', 602), ('x', 545), ('q', 316)])

COVERAGE COUNTS: OrderedDict([('e', 15976), ('a', 13702), ('i', 12066), ('r', 11624), ('s', 10972), ('n', 10223), ('o', 9526), ('l', 9193), ('t', 8879), ('u', 6450), ('d', 6232), ('c', 6107), ('m', 5069), ('h', 4769), ('p', 4487), ('g', 4485), ('b', 3826), ('y', 3245), ('k', 2565), ('f', 2250), ('w', 2120), ('v', 1542), ('z', 741), ('j', 594), ('x', 542), ('q', 315)])


=> CURRENT GUESS: e
Filtered vocabulary saved to temporary_words.parquet
MISS!!
UPDATED WORD: ['_', '_', '_', '_', '_', '_', '_']
ATTEMPTS LEFT: 5


******

FREQ COUNTS

0

## TEST

In [None]:
test_words = ['Perfectly', 'Efficiency', 'Literature', 'Calendar', 'Familiar', 'Confidence', 'Influence', 'Intuition', 'Adventure', 'Excellent', 'Motivation', 'Victory', 'Victory', 'Reasonable', 'Wanderer', 'Hospital', 'Tranquil', 'Delight', 'Dominating', 'Integration', 'Opportunity', 'Strategy', 'Overwhelmed', 'Community', 'Discovery', 'Wonderful', 'Resilient', 'Universe', 'Underestimate', 'Observation', 'Wisdom', 'Classroom', 'Learning', 'Beautiful', 'Satisfaction', 'Invisible', 'Beneficial', 'Calculate', 'Background', 'Marvelous', 'Mysterious', 'Overcome', 'Familiar', 'Hilarious', 'Reflection', 'Simplicity', 'Capacity', 'Wealthy', 'Courage', 'Victory', 'Contribute', 'Frontier', 'Essential', 'Ultimate', 'Impressive', 'Dominance', 'Recognize', 'Amazing', 'Defeated', 'Participate', 'Happiness', 'Elevation', 'Gateway', 'Financial', 'Blanket', 'Tolerance', 'Security', 'Knowledge', 'Freedom', 'Lantern', 'Journey', 'Journeying', 'Courageous', 'Reputable', 'Tolerable', 'Adventure', 'Practice', 'Quality', 'Amplified', 'Influence', 'Fragrance', 'Wanderlust', 'Fabulous', 'Wonderment', 'Supportive', 'Generosity', 'Resourceful', 'Imagination', 'Sufficient', 'Sunshine', 'Potential', 'Luminous', 'Motivation', 'Masterpiece', 'Forgetful', 'Generate', 'Satisfying', 'Nervously', 'Encounter', 'Harmony']


In [None]:
debug=True
correct=0


for answer in test_words[:10]:
    print("\n", answer)
    given_word = ['_']*len(answer)
    correct += hangman(given_word, answer, NUM_TRAILS, temp_file_path, debug)


acc = correct / len(test_words)
print(f"ACCURACY : {acc:.3}")


 Perfectly


******
=> GIVEN WORD : ['_', '_', '_', '_', '_', '_', '_', '_', '_']
TOTAL ATTEMPTS: 6


******

FREQ COUNTS:OrderedDict([('e', 31906), ('i', 23523), ('a', 23496), ('s', 20724), ('r', 20017), ('o', 18839), ('n', 18664), ('t', 17643), ('l', 15807), ('c', 11190), ('d', 10623), ('u', 9868), ('m', 8286), ('p', 8047), ('h', 7535), ('g', 6842), ('b', 5448), ('y', 4675), ('f', 3728), ('k', 2990), ('w', 2822), ('v', 2597), ('z', 1120), ('x', 837), ('q', 469), ('j', 458)])

COVERAGE COUNTS: OrderedDict([('e', 22034), ('i', 18588), ('a', 17906), ('r', 16738), ('s', 15810), ('n', 15309), ('o', 14789), ('t', 14709), ('l', 12879), ('c', 9813), ('d', 9415), ('u', 8848), ('m', 7490), ('p', 7226), ('h', 6994), ('g', 6250), ('b', 5038), ('y', 4451), ('f', 3340), ('k', 2812), ('w', 2649), ('v', 2517), ('z', 1051), ('x', 835), ('q', 466), ('j', 451)])


=> CURRENT GUESS: e
Filtered vocabulary saved to temporary_words.parquet
HIT!!
UPDATED WORD: ['_', 'e', '_', '_', 'e', '_', '_', '_', '_']
