In [None]:
import torch
import random
from transformers import BertTokenizer, BertForMaskedLM

# Seed every RNG the sampling cells use (torch.multinomial, random.choice)
# so that Restart & Run All reproduces the same generations.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Load pre-trained BERT model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(model_name)

# Evaluation mode disables dropout; the trailing ';' suppresses the very long
# module repr that would otherwise be displayed as the cell output.
model.eval();


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [None]:
# Export a TorchScript trace of the model. Bind it to a NEW name instead of
# rebinding `model`: the generation cells below call `model(...).logits`, and a
# traced module may return plain dict/tuple outputs (and is recorded from this
# single example input), so overwriting `model` could break them.
# TODO confirm for the installed torch/transformers versions.
model.to("cpu")
example_input_ids = tokenizer.encode('random', return_tensors='pt').long()
traced_model = torch.jit.trace(model.forward, example_input_ids, strict=False)

# NOTE(review): assumes Google Drive is mounted at /content/drive. Only the
# weights (state_dict) are saved here; use traced_model.save(...) if the
# TorchScript program itself is what should be persisted.
model_save_path = '/content/drive/MyDrive/Colab Notebooks/bert_base_uncased_2.pt'
torch.save(traced_model.state_dict(), model_save_path)



In [None]:
# Show which integer dtype the tokenizer hands back for a tensor encode;
# the tracing cell casts its example input with .long() regardless.
tokens = tokenizer.encode(text='random', return_tensors='pt')
print(tokens.dtype)

torch.int64


In [None]:
# Release cached GPU memory when a GPU is present (nothing to do on CPU-only runtimes).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Function to perform top-p sampling
def top_p_sampling(logits, top_p=0.9):
    """Sample one token id from ``logits`` using nucleus (top-p) filtering.

    Keeps the smallest set of highest-probability tokens whose cumulative
    probability exceeds ``top_p`` (always at least the most likely token),
    renormalises over that set and draws one id from it.

    Args:
        logits: 1-D tensor of unnormalised scores, one per vocabulary entry.
            It is NOT modified (callers often pass a view into model output).
        top_p: cumulative-probability threshold.

    Returns:
        The sampled token id as a Python ``int``.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark tokens past the threshold, shifted right by one so the token that
    # first crosses top_p is still kept; the top token is never removed.
    sorted_to_remove = cumulative_probs > top_p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Map the mask from sorted order back to vocabulary order and apply it to a
    # copy, leaving the caller's tensor untouched.
    to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
    filtered_logits = logits.masked_fill(to_remove, float('-inf'))

    # Sample from the filtered, renormalised distribution.
    probabilities = torch.softmax(filtered_logits, dim=-1)
    return torch.multinomial(probabilities, 1).item()

In [None]:
def constraint_score(token):
    """Score a decoded token for the energy-based sampler.

    Returns -3 for tokens starting with 'p'/'P', -1 for tokens that are empty
    or start with a space or one of ``,.!?;:-``, and 2 otherwise.

    Note the direction: ``minimal_energy_sampling`` subtracts
    ``lambda1 * (epsilon1 - score)`` from a token's energy. With the values
    used there (lambda1=3, epsilon1=0), a LOWER score lowers the energy and
    makes the token MORE likely, so 'p' tokens are favoured.
    """
    if not token:
        # Empty decode: treat like whitespace instead of raising IndexError.
        return -1
    first_char = token[0]
    if first_char in {'p', 'P'}:
        return -3
    if first_char in {' ', ',', '.', '!', '?', ';', ':', '-'}:
        return -1
    return 2

def minimal_energy_sampling(logits, lambda1=3, epsilon1=0, num_samples=999,
                            score_fn=None, decode_fn=None):
    """Sample a token id after re-weighting candidates with a constraint score.

    The base energy of each token is its negative log-probability under
    ``softmax(logits)``. ``num_samples`` ids are drawn (with replacement) from
    that distribution. Every DISTINCT drawn id is decoded and scored, and its
    energy is reduced by ``lambda1 * (epsilon1 - score)``. One id is then
    sampled from ``softmax(-energy)``.

    Args:
        logits: 1-D tensor of unnormalised scores over the vocabulary
            (not modified).
        lambda1: strength of the constraint adjustment.
        epsilon1: reference score; with positive ``lambda1``, tokens scoring
            below it get lower energy (become more likely).
        num_samples: number of candidate draws used to choose which ids to
            score; values <= 0 skip the adjustment entirely.
        score_fn: maps a decoded token string to a number; defaults to the
            notebook's ``constraint_score``.
        decode_fn: maps a list of token ids to a string; defaults to the
            notebook's global ``tokenizer.decode``.

    Returns:
        The sampled token id as a Python ``int``.
    """
    if score_fn is None:
        score_fn = constraint_score
    if decode_fn is None:
        decode_fn = tokenizer.decode

    probabilities = torch.softmax(logits, dim=-1)
    # log_softmax avoids log(0) = inf for probabilities that underflow.
    energy = -torch.log_softmax(logits, dim=-1)

    if num_samples > 0:
        # One batched draw replaces a per-sample Python loop. It has the same
        # distribution (i.i.d. draws from a fixed distribution) but far fewer calls.
        draws = torch.multinomial(probabilities, num_samples, replacement=True)
        for token_id in torch.unique(draws).tolist():
            score = score_fn(decode_fn([token_id]))
            energy[token_id] -= lambda1 * (epsilon1 - score)

    adjusted_probabilities = torch.softmax(-energy, dim=-1)
    return torch.multinomial(adjusted_probabilities, 1).item()

In [None]:
# Function to generate text using autoregressive approach with BERT
def generate_autoregressive_text(prompt, max_length=20, iterations=20, top_p=0.9):
    """Fill [MASK] slots after ``prompt`` left to right, one token per iteration.

    The encoded prompt (including the tokenizer's [CLS]/[SEP]) is followed by
    [MASK] tokens up to ``max_length`` total tokens. Each iteration runs
    ``model`` on the current sequence, samples a token for the left-most
    remaining [MASK] with ``minimal_energy_sampling``, and prints the partial text.

    Args:
        prompt: text to continue.
        max_length: total sequence length in tokens, including special tokens.
            If the encoded prompt is already this long, no masks are added.
        iterations: maximum number of tokens to fill in.
        top_p: only used by the commented-out ``top_p_sampling`` alternative.

    Returns:
        The decoded sequence with "[CLS] " and " [SEP]" removed. Masks that were
        not filled remain as "[MASK]" in the text.
    """
    def _display_text(ids):
        # Decode keeping [MASK] visible, then drop the [CLS]/[SEP] markers.
        text = tokenizer.decode(ids, skip_special_tokens=False)
        return text.replace('[CLS] ', '').replace(' [SEP]', '')

    # Tokenize the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Append [MASK] slots. torch.full keeps the prompt's integer dtype even when
    # no masks are needed (a tensor built from an empty nested list is float and
    # could promote the ids through torch.cat). max(..., 0) stops an over-long
    # prompt from producing a negative count.
    number_of_masks = max(max_length - input_ids.shape[1], 0)
    mask_padding = torch.full((1, number_of_masks), tokenizer.mask_token_id,
                              dtype=input_ids.dtype)
    masked_input_ids = torch.cat([input_ids, mask_padding], dim=1)

    # Bind the result up front so it is defined even if no mask gets filled.
    generated_text = _display_text(masked_input_ids[0])

    # Generate tokens autoregressively
    for i in range(iterations):
        with torch.no_grad():
            outputs = model(masked_input_ids)
            predictions = outputs.logits

            # Positions of the remaining [MASK] tokens; stop once none are left.
            mask_token_index = (masked_input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
            if len(mask_token_index[1]) == 0:
                break

            # nonzero returns positions in ascending order, so [0] is the left-most mask.
            first_mask_token_index = mask_token_index[1][0].item()
            next_token_logits = predictions[0, first_mask_token_index]

            # Alternative sampler:
            # predicted_token_id = top_p_sampling(next_token_logits, top_p)
            predicted_token_id = minimal_energy_sampling(next_token_logits)
            masked_input_ids[0, first_mask_token_index] = predicted_token_id

            generated_text = _display_text(masked_input_ids[0])
            print(f"Iteration {i + 1}: {generated_text}")

    return generated_text

# Example: continue the prompt one masked slot at a time, left to right.
prompt = "in the midst of"
final_generated_text = generate_autoregressive_text(prompt)
print("\nFinal Generated Text:", final_generated_text)

Iteration 1: in the midst of pure [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
Iteration 2: in the midst of pure, [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
Iteration 3: in the midst of pure, practically [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
Iteration 4: in the midst of pure, practically perfectly [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
Iteration 5: in the midst of pure, practically perfectly pleasant [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
Iteration 6: in the midst of pure, practically perfectly pleasant pain [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
Iteration 7: in the midst of pure, practically perfectly pleasant pain, [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
Iteration 8: in the midst of pure, practically perfectly pleasant pain, placed [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]
Iterati

In [None]:
# Function to generate text using non-autoregressive approach with BERT
def generate_non_sequential_text(prompt, max_length=20, iterations=20, top_p=0.9):
    """Fill [MASK] slots after ``prompt`` in random order, one per iteration.

    The encoded prompt (including the tokenizer's [CLS]/[SEP]) is followed by
    [MASK] tokens up to ``max_length`` total tokens. Each iteration runs
    ``model`` on the current sequence, picks a remaining [MASK] uniformly at
    random (``random.choice``), samples its token with
    ``minimal_energy_sampling``, and prints the partial text.

    Args:
        prompt: text to continue.
        max_length: total sequence length in tokens, including special tokens.
            If the encoded prompt is already this long, no masks are added.
        iterations: maximum number of tokens to fill in.
        top_p: only used by the commented-out ``top_p_sampling`` alternative.

    Returns:
        The decoded sequence with all special tokens (including any unfilled
        [MASK]) removed.
    """
    # Tokenize the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Append [MASK] slots. torch.full keeps the prompt's integer dtype even when
    # no masks are needed (a tensor built from an empty nested list is float and
    # could promote the ids through torch.cat). max(..., 0) stops an over-long
    # prompt from producing a negative count.
    number_of_masks = max(max_length - input_ids.shape[1], 0)
    mask_padding = torch.full((1, number_of_masks), tokenizer.mask_token_id,
                              dtype=input_ids.dtype)
    masked_input_ids = torch.cat([input_ids, mask_padding], dim=1)

    # Generate tokens non-autoregressively
    for i in range(iterations):
        with torch.no_grad():
            outputs = model(masked_input_ids)
            predictions = outputs.logits

            # Positions of the remaining [MASK] tokens; stop once none are left.
            mask_token_indices = (masked_input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
            if len(mask_token_indices[1]) == 0:
                break

            # Randomly select one [MASK] position to replace.
            random_mask_token_index = random.choice(mask_token_indices[1]).item()
            next_token_logits = predictions[0, random_mask_token_index]

            # Alternative sampler:
            # predicted_token_id = top_p_sampling(next_token_logits, top_p)
            predicted_token_id = minimal_energy_sampling(next_token_logits)
            masked_input_ids[0, random_mask_token_index] = predicted_token_id

            # Show progress with the remaining [MASK] tokens still visible.
            generated_text = tokenizer.decode(masked_input_ids[0], skip_special_tokens=False)
            generated_text = generated_text.replace('[CLS] ', '').replace(' [SEP]', '')
            print(f"Iteration {i + 1}: {generated_text}")

    # Final output without special tokens
    final_generated_text = tokenizer.decode(masked_input_ids[0], skip_special_tokens=True)
    return final_generated_text

# Example: fill the same prompt's masked slots in a random order.
prompt = "in the midst of"
non_autoregressive = generate_non_sequential_text(prompt)
print("Final Generated Text:", non_autoregressive, end="\n\n")

Iteration 1: in the midst of [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK], [MASK] [MASK] [MASK] [MASK]
Iteration 2: in the midst of [MASK] [MASK] [MASK] [MASK], [MASK] [MASK] [MASK] [MASK], [MASK] [MASK] [MASK] [MASK]
Iteration 3: in the midst of [MASK] [MASK] [MASK] [MASK], [MASK] [MASK] [MASK] paul, [MASK] [MASK] [MASK] [MASK]
Iteration 4: in the midst of [MASK] [MASK] [MASK] [MASK], [MASK] [MASK] [MASK] paul, [MASK] [MASK] [MASK] procession
Iteration 5: in the midst of [MASK] [MASK] [MASK] [MASK], pastor [MASK] [MASK] paul, [MASK] [MASK] [MASK] procession
Iteration 6: in the midst of [MASK] pope [MASK] [MASK], pastor [MASK] [MASK] paul, [MASK] [MASK] [MASK] procession
Iteration 7: in the midst of [MASK] pope [MASK] [MASK], pastor [MASK] [MASK] paul, planned [MASK] [MASK] procession
Iteration 8: in the midst of [MASK] pope [MASK] [MASK], pastor [MASK] [MASK] paul, planned personal [MASK] procession
Iteration 9: in the midst of [MASK] pope [MASK] [MASK], pastor peter