### Set Seed

In [10]:
import random
import re

import numpy as np
import torch
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator, and PyTorch's
    generators for all CUDA devices.

    Args:
        seed: Integer seed passed unchanged to each library.

    Returns:
        None. The global generator state of each library is modified.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [11]:
from src.model import *
from src.tokenizer import *
from src.utils import *

In [12]:
@torch.no_grad()
def generate_text(model, config, start_text, encoder, max_new_tokens=100, temperature=1.0, top_k=None, seed=None):
    """Sample a continuation of ``start_text`` from ``model``, one token at a time.

    Args:
        model: Module called as ``model(tokens)``; must return a 2-tuple whose
            first item is the logits (indexed here as ``[:, -1, :]``, i.e.
            assumed to be (batch, sequence, vocab) — TODO confirm against
            the model definition).
        config: Mapping that provides ``'block_size'``, the maximum number of
            trailing tokens fed to the model at each step.
        start_text: Prompt string, tokenised with ``encoder.encoder``.
        encoder: Object exposing ``encoder(text)`` -> token ids and
            ``decoder(ids)`` -> text.
        max_new_tokens: Number of tokens to append to the prompt.
        temperature: Positive divisor applied to the logits before softmax.
        top_k: If given, only the ``top_k`` highest logits remain sampleable.
            Values above the vocabulary size are clamped to it.
        seed: Seed handed to ``set_seed``. When ``None`` a seed in
            ``[0, 10000]`` is drawn from Python's global ``random`` generator.

    Returns:
        The decoded text of the prompt followed by the sampled tokens.

    Raises:
        ValueError: If ``temperature`` is not positive, ``top_k`` is less
            than 1, or ``start_text`` encodes to zero tokens.

    Side effects:
        Reseeds the global RNGs through ``set_seed`` and switches ``model``
        to eval mode (it is not switched back).
    """
    # Reject settings that would otherwise surface as inf/NaN probabilities
    # and an opaque error inside torch.multinomial.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be >= 1 or None, got {top_k}")

    if seed is None:
        seed = random.randint(0, 10_000)
    set_seed(seed)

    model.eval()
    device = next(model.parameters()).device
    block_size = config['block_size']

    # Encode the prompt and add a leading batch dimension of size 1.
    encoded = encoder.encoder(start_text)
    if len(encoded) == 0:
        raise ValueError("start_text must encode to at least one token")
    x = torch.tensor(encoded, dtype=torch.long)[None, :].to(device)

    for _ in range(max_new_tokens):
        # Feed at most the last `block_size` tokens; slicing a shorter
        # sequence this way simply yields the whole sequence.
        x_cond = x[:, -block_size:]

        logits, _ = model(x_cond)
        # Keep only the final position's logits and apply the temperature.
        logits = logits[:, -1, :] / temperature

        if top_k is not None:
            # torch.topk raises when k exceeds the dimension size, so clamp.
            k = min(top_k, logits.size(-1))
            top_logits, top_indices = torch.topk(logits, k)
            # Everything outside the top k gets -inf, i.e. probability zero.
            logits = torch.full_like(logits, float('-inf'))
            logits.scatter_(1, top_indices, top_logits)

        probs = torch.nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        x = torch.cat((x, next_token), dim=1)

    # Decode the single batch row (prompt + generated tokens) back to text.
    out = x[0].tolist()
    return encoder.decoder(out)

def clean_repetition(text):
    """Collapse immediately repeated words and phrases in generated text.

    Four regex passes are applied in order:

    1. Back-to-back copies of the same "the nobel prize in <word>" phrase
       (case-insensitive) are reduced to one copy.
    2. Any run of the same word separated only by whitespace
       ("the the the") is reduced to a single occurrence (case-sensitive).
    3. A repeated "<4 digits>," token (with its trailing whitespace) is
       reduced to one copy.
    4. Two or more consecutive "the nobel prize in <word> " phrases, even
       with different words, are reduced to the last phrase of the run.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    # Pass 1: identical "Nobel Prize in <field>" phrases repeated in a row.
    text = re.sub(r'(the nobel prize in \w+)( \1)+', r'\1', text, flags=re.IGNORECASE)

    # Pass 2: runs of one repeated word, e.g. "the the", "was was was".
    # The trailing group is repeated (`+`) so a run of any length collapses
    # in a single substitution; a pair-only pattern would turn
    # "the the the" into "the the" because re.sub never rescans its output.
    text = re.sub(r'\b(\w+)(?:\s+\1\b)+', r'\1', text)

    # Pass 3: the same "<year>, " repeated in close succession.
    text = re.sub(r'(\d{4},\s*)\1+', r'\1', text)

    # Pass 4: consecutive "Nobel Prize in <field> " phrases with possibly
    # different fields; `\1` holds the last repetition, so that one is kept.
    text = re.sub(r'(the nobel prize in \w+ ){2,}', r'\1', text, flags=re.IGNORECASE)

    return text


In [13]:
# Settings dictionary handed to GPT(config); generate_text also reads
# 'block_size' from it. The grouping comments below are inferred from the key
# names only — NOTE(review): confirm against src.model.
config = {
    # Model size / shape
    'n_layer': 8,
    'n_head': 16,
    'n_embd': 512,
    'vocab_size': 50257,
    'block_size': 128,
    # Dropout probabilities (presumably embedding / residual / attention)
    'embd_pdrop': 0.1,
    'resid_pdrop': 0.1,
    'attn_pdrop': 0.1,
    # Runtime
    'device': 'cpu',
    'num_workers': 3,
    # Training-related values (presumably unused when only generating)
    'max_iters': None,
    'batch_size': 64,
    'learning_rate': 3e-4,
    'betas': (0.9, 0.95),
    'weight_decay': 0.1,
    'grad_norm_clip': 1.0,
}


In [14]:
# Instantiate the model from the config dict. `GPT` is not defined in this
# notebook; presumably it comes from the `from src.model import *` star-import.
model = GPT(config)

number of parameters: 51.02M


In [15]:
# Prefer the GPU for deserialising the checkpoint when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint is a plain state dict, so restrict unpickling to tensors and
# basic containers (weights_only=True, available since torch 1.13) instead of
# allowing arbitrary pickled objects to be executed on load.
model.load_state_dict(
    torch.load(
        './saved_models/model_shakespeare_new_v5_latest.pth',
        map_location=device,
        weights_only=True,
    )
)

<All keys matched successfully>

In [16]:
# Place the model on the CPU (the same device named in config['device'])
# and print a confirmation line.
model.to('cpu')
print("model loaded")

model loaded


In [17]:
import pickle

# Load the pickled encoder object into `bpe`.
# NOTE(review): pickle.load can execute arbitrary code while deserialising —
# only open pickle files that come from a trusted source.
with open('./saved_models/encoder_shakespeare_v5.pkl', 'rb') as f:
    bpe = pickle.load(f)

### Basic Prompts

In [18]:
def print_ot_for_prompt(input_, max_new_tokens=100, temperature=0.7, top_k=50):
    """Generate text for one prompt and print the prompt and the result.

    Uses the notebook-level ``model``, ``config`` and ``bpe`` objects. The
    generated text is passed through ``clean_repetition`` before printing.

    Args:
        input_: Prompt string to continue.
        max_new_tokens: Number of tokens to generate (default 100).
        temperature: Sampling temperature forwarded to ``generate_text``
            (default 0.7).
        top_k: Top-k cutoff forwarded to ``generate_text`` (default 50).

    Returns:
        None. Output is printed, followed by a separator line.
    """
    output = generate_text(
        model,
        config,
        input_,
        bpe,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
    )

    output = clean_repetition(output)

    print("input: ", input_)
    print("output: ", output+"\n")
    print("----------------------------------------------------------------------------------")

# Short speaker-name prompts to continue.
prompt_lis = ["first citizen", "hermione:", "menenius"]

# Generate and print one continuation per prompt, in list order.
for prompt in prompt_lis:
    print_ot_for_prompt(prompt)

input:  first citizen
output:  first citizen:
he cannot help the joy of others, proud disdain,
unless the loving welshmen can clear
ne'gainst the strong suspicion.

second murderer:
no, first let's reason with him.

clarence:
where art thou, keeper? give me a cup of wine.

second murderer:
you shall have wine enough, my lord, anon.

clarence:
in god's name, what art thou?

----------------------------------------------------------------------------------
input:  hermione:
output:  hermione:
nay, but you will?

polixenes:
i may not, verily.

hermione:
verily!
you put me off with limber vows; but i,
though you would seek to unsphere the
stars with oaths,
should yet say 'sir, no going.' verily,
you shall not go: a lady's 'verily' 's
as potent as a lord's. will you

----------------------------------------------------------------------------------
input:  menenius
output:  menenius:
why, what of that?

first citizen:
the former agents, if they did complain,
what could the belly answer?

me