# Generation Notebook 

In [None]:
# ==== Import ====
import sys, os, json, torch, tiktoken, textwrap

# Add Repo Root
sys.path.append(os.path.abspath(".."))

from models.soloGPT_v1_model import SoloGPT_v1

In [2]:
# ==== Load Config ====
# Configuration dict that is passed to the SoloGPT_v1 constructor below.
# The path is relative to the current working directory.
# JSON is UTF-8 by specification, so pin the encoding instead of relying on
# the platform's locale default.
with open("../config/soloGPT_v1_config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

In [3]:
# ==== Set Device ====
# Pick the best available backend in priority order: CUDA, then MPS, then CPU.
# The conditional expression short-circuits, so the MPS check only runs when
# CUDA is unavailable.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

In [4]:
# ==== Load tokenizer ====
# Look up tiktoken's encoding registered under the name "gpt2"; generate()
# uses it to turn the prompt into token ids and the sampled ids back into text.
# NOTE(review): the model summary shows an embedding table with 50257 rows,
# which presumably matches this encoding's vocabulary size — confirm.
tokenizer = tiktoken.get_encoding("gpt2")

In [5]:
# ==== Load Model ====
# Build the network from the loaded config and move it to the selected device.
model = SoloGPT_v1(config).to(device)

# torch.load() unpickles the checkpoint. weights_only=True restricts the
# unpickler to tensors and plain containers, so a tampered file cannot execute
# arbitrary code while being loaded.
# NOTE(review): assumes the .bin file holds a plain state_dict — confirm.
model.load_state_dict(
    torch.load(
        "../outputs/pytorch_model.bin",
        map_location=device,
        weights_only=True,
    )
)

# Put the module in evaluation mode (dropout layers become no-ops). Kept as the
# last expression so the cell also displays the module summary.
model.eval()

SoloGPT_v1(
  (input_embed): Embedding(50257, 1024)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.15, inplace=False)
  )
  (decoder): Custom_Decoder(
    (layers): ModuleList(
      (0-7): 8 x Custom_DecoderLayer(
        (mha): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
        )
        (dropout1): Dropout(p=0.15, inplace=False)
        (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (ff): FeedForward(
          (linear_layer1): Linear(in_features=1024, out_features=4096, bias=True)
          (linear_layer2): Linear(in_features=4096, out_features=1024, bias=True)
          (dropout): Dropout(p=0.15, inplace=False)
        )
        (dropout2): Dropout(p=0.15, inplace=False)
        (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
    )
    (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (output_layer): Linear(in_features=10

In [6]:
def generate(prompt, max_new_tokens=100, temperature=1.0, top_k=40):
    """Continue ``prompt`` by sampling tokens autoregressively from ``model``.

    Uses the notebook-level ``model``, ``tokenizer`` and ``device``. At every
    step the whole sequence so far is fed to the model, the logits at the last
    position are taken, and one token is drawn from the ``top_k`` most likely
    candidates after temperature scaling.

    Args:
        prompt: Text to continue. Must encode to at least one token.
        max_new_tokens: Number of tokens to append to the prompt.
        temperature: Divisor applied to the logits before softmax. ``0``
            selects the single most likely token at each step (greedy
            decoding).
        top_k: Number of highest-scoring tokens to sample from. Values larger
            than the vocabulary are clamped to it; ``None`` samples from the
            whole vocabulary.

    Returns:
        The decoded text of the prompt followed by the generated tokens.

    Raises:
        ValueError: If ``temperature`` is negative, ``top_k`` is smaller than
            1, or ``prompt`` encodes to no tokens.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be >= 1 or None, got {top_k}")

    prompt_ids = tokenizer.encode(prompt)
    if not prompt_ids:
        # An empty (1, 0) input gives the model no last position to read.
        raise ValueError("prompt must encode to at least one token")

    # Shape (1, prompt_length): a batch holding the single prompt.
    generated = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    # NOTE(review): the sequence grows every step and is never cropped; if the
    # model has a maximum context length, long generations may fail — confirm.
    # Inference only: no_grad avoids building an autograd graph on every step.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(generated)[:, -1, :]  # Only the last token's logits

            if temperature == 0:
                # Greedy decoding; dividing by zero would yield NaN probabilities.
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                # Top-k sampling; k is clamped because torch.topk raises when k
                # exceeds the size of the dimension.
                vocab_size = logits.size(-1)
                k = vocab_size if top_k is None else min(top_k, vocab_size)
                top_k_logits, top_k_indices = torch.topk(logits, k=k, dim=-1)
                probs = torch.softmax(top_k_logits / temperature, dim=-1)
                choice = torch.multinomial(probs, num_samples=1)
                # Map the position inside the top-k list back to a vocabulary id.
                next_token = top_k_indices.gather(-1, choice)

            generated = torch.cat([generated, next_token], dim=1)

    return tokenizer.decode(generated[0].tolist())

In [7]:
# Text the model is asked to continue; passed to generate() in the next cell.
prompt = "In a future world,"

In [None]:
# ==== Generate ====
# Sampling is random: seed the global RNG so that Restart & Run All reproduces
# the same continuation. Change SEED to draw a different sample.
SEED = 42
torch.manual_seed(SEED)

output = generate(prompt=prompt, max_new_tokens=100, temperature=1.0, top_k=40)

# textwrap.fill() turns newlines into spaces, which would merge the model's
# paragraphs into one block; wrap each line separately (80 characters wide) so
# the generated line breaks survive.
print("\n".join(textwrap.fill(line, width=80) for line in output.splitlines()))

In a future world, he'll be able to use the "gauntlet of cards" – that is, one
with a more powerful ability to cast spells that are more powerful than cards
you might ever want to use. The first way is to get rid of the "slimpse" that's
just coming around. In the process, someone might be able to build a monster
that can summon one of those minions in another.  The first way to get rid of it
is to get rid of all the things you
