In [2]:
import random
from collections import Counter

import numpy as np
import torch
from transformers import LlamaConfig, LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification

In [3]:
# Path (relative to the notebook's working directory) of the text corpus that is read below.
data_file = "./txt/input.txt"

In [4]:
# Read the whole corpus into memory. The encoding is pinned so the decoded
# characters (and therefore the vocabulary) do not depend on the platform's
# default locale encoding.
with open(data_file, "r", encoding="utf-8") as f:
    text = f.read()

# Character-level vocabulary: every distinct character, in sorted order so
# that the character -> id mapping is stable between runs.
vocab = sorted(set(text))
vocab_size = len(vocab)
vocab_size

65

In [5]:
# Number of characters (= tokens) in the corpus. len() works directly on the
# string, so there is no need to copy it into a list first.
total_length = len(text)
total_length

1115394

In [6]:
# Two-way lookup between characters and integer ids; ids follow the order of `vocab`.
token2idx = dict(zip(vocab, range(len(vocab))))
idx2token = dict(enumerate(vocab))
tokenizer = {"token2idx": token2idx, "idx2token": idx2token}

In [7]:
# Unigram distribution: relative frequency of each character, indexed by token id.
char_counts = Counter(text)
token_probs = np.zeros(vocab_size)
for key, value in char_counts.items():
    token_id = tokenizer["token2idx"][key]
    token_probs[token_id] = value / total_length
# Sanity check: the probabilities should add up to 1.
token_probs.sum()

1.0

In [8]:
# Sample one token id from the unigram distribution.
# Passing an int to np.random.choice samples from np.arange(vocab_size).
idx = np.random.choice(vocab_size, p=token_probs)
idx

56

In [70]:
# Torch copy of the unigram distribution in single precision.
# The dtype conversion produces a new tensor, so the numpy array is not aliased.
token_probs_tensor = torch.from_numpy(token_probs).to(torch.float32)
torch.sum(token_probs_tensor)

tensor(1.)

In [71]:
# Draw a start-token id from the unigram distribution with torch.
# NOTE: when sampling a whole batch, num_samples should equal the batch size.
idx = token_probs_tensor.multinomial(num_samples=1)
idx

tensor([56])

In [9]:
# Encode the full corpus: one integer id per character.
char_to_id = tokenizer["token2idx"]
text_indices = [char_to_id[char] for char in text]

In [10]:
# Number of tokens per chunk. Used below to slice the corpus into rows and as
# the max_length passed to generator.generate().
context_length = 128

In [11]:
# The encoded corpus as a 1-D numpy array of token ids.
full_text_array = np.asarray(text_indices)
full_text_array.shape

(1115394,)

In [12]:
# Random start offset so chunk boundaries do not always fall at the same positions.
# randrange() excludes context_length itself: an offset of exactly one chunk
# yields the same boundaries as offset 0 (minus the first chunk), so the
# inclusive randint(0, context_length) drew that alignment twice as often.
shift = random.randrange(context_length)
text_array = full_text_array[shift:]
shift

124

In [66]:
# Drop the tail so the remaining length is an exact multiple of context_length.
num_full_chunks = len(text_array) // context_length
chunked = text_array[:num_full_chunks * context_length]

In [67]:
# Convert to a tensor and fold it into rows of context_length tokens each.
chunk_tensor = torch.tensor(chunked)
chunked = chunk_tensor.view(-1, context_length)

In [68]:
# Number of chunks (rows) available.
chunked.shape[0]

8713

In [69]:
# Decode the first chunk back to text as a sanity check on the encoding.
first_chunk_ids = chunked[0].tolist()
example = "".join(tokenizer["idx2token"][token_id] for token_id in first_chunk_ids)
example

' to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nA'

In [16]:
import torch
from transformers import LlamaConfig, LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification

In [17]:
# Configuration for the generator: a 2-layer Llama-style model over the
# character vocabulary, with no BOS/EOS token ids configured.
generator_config = LlamaConfig(
    vocab_size=vocab_size,
    hidden_size=64,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    # Tied to the chunk length instead of repeating the literal 128, so the
    # position limit cannot drift out of sync with context_length.
    max_position_embeddings=context_length,
    bos_token_id=None,
    eos_token_id=None
)

In [18]:
# Build the generator (causal LM head included) from the config alone;
# no pretrained checkpoint is loaded in this cell.
generator = LlamaForCausalLM(generator_config)

In [19]:
# Total number of scalar parameters in the generator.
num_parameters = 0
for weight in generator.parameters():
    num_parameters += weight.numel()
num_parameters

139712

In [20]:
# Rebuild the unigram distribution tensor in float64; this replaces the
# float32 version assigned earlier in the file.
token_probs_tensor = torch.tensor(token_probs, dtype=torch.float64)

In [74]:
# Sample one start token per sequence in the batch, shaped (batch, 1).
# replacement=True makes the draws independent. Without it torch.multinomial
# forces all start tokens to be distinct and raises once num_samples exceeds
# the number of non-zero categories in the distribution.
inputs = torch.multinomial(token_probs_tensor, num_samples=3, replacement=True).view(-1, 1)
print(inputs.shape)
print(inputs)

torch.Size([3, 1])
tensor([[40],
        [ 1],
        [43]])


In [75]:
# Let the generator continue each sampled start token using nucleus sampling
# (top_p=0.95) at temperature 1.0, with max_length set to context_length.
fake_generation = generator.generate(inputs, max_length=context_length, temperature=1.0, top_p=0.95, do_sample=True)

In [23]:
# Decode the first generated sequence back into characters for inspection.
first_fake_ids = fake_generation[0].detach().cpu().tolist()
fake_text = "".join(tokenizer["idx2token"][token_id] for token_id in first_fake_ids)
fake_text

"gnPWo lOKU,NZOxWYiMX?zYGy&q ;edtAMZe& E'MJZ-ZTMVezK,YPUXM:kTvRW3OEW-LNRkn:UHapOnV.HH!GzxIA3tMGOgS?E'hg\ndbZ'N&rz HdR.tVDB\nJ'z'n,w"

In [32]:
# Configuration for the discriminator backbone: same architecture sizes as the
# generator config, over the same character vocabulary.
discriminator_config = LlamaConfig(
    vocab_size=vocab_size,
    hidden_size=64,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    # Tied to the chunk length instead of repeating the literal 128, so the
    # position limit cannot drift out of sync with context_length.
    max_position_embeddings=context_length,
    bos_token_id=None,
    eos_token_id=None
)

In [33]:
# Bare Llama backbone (no LM head); a classification head is applied to its
# hidden states further down.
discriminator = LlamaModel(discriminator_config)

In [34]:
# Total number of scalar parameters in the discriminator backbone.
# Note: this overwrites the generator's count stored under the same name.
num_parameters = 0
for weight in discriminator.parameters():
    num_parameters += weight.numel()
num_parameters

135552

In [42]:
# Run the generated (fake) token ids through the discriminator backbone.
classification = discriminator(input_ids=fake_generation)

In [48]:
# Hidden states from the backbone's last layer: one vector per token position.
cin = classification.last_hidden_state
cin.shape

torch.Size([1, 128, 64])

In [45]:
# Per-token classification head: project each 64-dim hidden state to a single
# score, then squash it into (0, 1) with a sigmoid.
out_layer = torch.nn.Linear(in_features=64, out_features=1)
sigmoid_layer = torch.nn.Sigmoid()

In [52]:
# Per-token probability from the head: linear projection followed by sigmoid.
c = sigmoid_layer(out_layer(cin))
c.shape

torch.Size([1, 128, 1])

In [51]:
# Discriminator loss on generated text: every token position gets target 0.
# Only the trailing size-1 feature dim is squeezed, so the batch dim survives
# even for a batch of one. The target is built with zeros_like so it always
# matches the prediction shape; binary_cross_entropy raises when the target
# and input sizes differ, which a fixed torch.zeros(context_length) target
# does for any batch larger than one.
fake_probs = c.squeeze(-1)
loss = torch.nn.functional.binary_cross_entropy(fake_probs, torch.zeros_like(fake_probs))
loss

tensor(0.6818, grad_fn=<BinaryCrossEntropyBackward0>)

In [None]:
# TODO: "real" loss for the discriminator: real sequences from the batch, scored against torch.ones(1)
# TODO: RL loss for the generator
# Open question: should the discriminator make a prediction for each token, not just at the end?
## Option: use the base Llama model and put a sigmoid head on each output
## Try both ways and compare

In [None]:
def train_step(batch):
    input_real = batch