In [1]:
!pip install -q datasets==2.18.0 tqdm

In [2]:
!pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121  # or cu118 / cpu
!pip install datasets>=2.18 transformers>=4.40

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu121


In [3]:
from datasets import load_dataset, Dataset
from tqdm.auto import tqdm

In [4]:
from datasets import load_dataset, Dataset, DatasetDict
from tqdm.auto import tqdm
import random, os, shutil

SEED          = 42          # change for a different shuffle
TEST_FRAC     = 0.3         # 10 % of threads → test
DATASET_NAME  = "rickRossie/bluemoon_roleplay_chat_data_300k_messages"

random.seed(SEED)

In [5]:
from datetime import datetime
from random import randint

# ------------------------------------------------------------------
# 1 · Stream & bucket messages per thread

In [6]:
raw = load_dataset(DATASET_NAME, split="train", streaming=True)

threads = {}
for row in tqdm(raw, desc="Grouping"):
    txt = row["message"]
    if not txt:
        continue
    threads.setdefault(row["thread_title"], []).append(
        (row["message_timestamp"], txt)
    )


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Grouping: 0it [00:00, ?it/s]

In [7]:
# 1. Basic size stats
print("Total threads:", len(threads))
lengths = [len(msgs) for msgs in threads.values()]
print(f"Min thread length: {min(lengths)}")
print(f"Max thread length: {max(lengths)}")
print(f"Avg thread length: {sum(lengths)/len(lengths):.2f}")

# 2. Sample a thread and inspect
sample_title = random.choice(list(threads.keys()))
sample_msgs = sorted(threads[sample_title], key=lambda x: x[0])

print(f"\n▶ Thread title: {sample_title}")
print(f"Number of messages: {len(sample_msgs)}\n")
for ts, msg in sample_msgs[:5]:
    print(msg[:100].replace("\n", " "), "…")

Total threads: 3637
Min thread length: 1
Max thread length: 7609
Avg thread length: 71.51

▶ Thread title: The Force Awakens: Kylo x Rey (newblood & malfrost)
Number of messages: 9

Rey licked her lips, which were becoming dry from nervousness. The dark haired woman's eyes narrowed …
Ren could feel the girl's willpower and resolve weakening as he drew her closer. He knew, that with  …
The last several hours had changed her life completely.Upon momentary contemplation, she could not p …
She was stronger than he originally guessed. He had felt the budding power of the force in her, but  …
Rey's breathing was labored, and as he approached her she let go of a breath that she hadn't been aw …


In [8]:
threads = {k: v for k, v in threads.items() if len(v) >= 4}

print("Threads after filtering:", len(threads))

Threads after filtering: 2845


In [9]:
lengths = [len(msgs) for msgs in threads.values()]
print(f"Min thread length: {min(lengths)}")
print(f"Max thread length: {max(lengths)}")
print(f"Avg thread length: {sum(lengths)/len(lengths):.2f}")


Min thread length: 4
Max thread length: 7609
Avg thread length: 90.95


# ------------------------------------------------------------------
# 2 · Train / test split at the thread level

In [10]:
titles = list(threads.keys())
random.shuffle(titles)

cut = int(len(titles) * (1 - TEST_FRAC))
train_titles, test_titles = set(titles[:cut]), set(titles[cut:])

def make_pairs(title_set):
    for tt in title_set:
        turns = sorted(threads[tt], key=lambda x: x[0])
        msgs  = [m for _, m in turns]
        for i in range(len(msgs) - 3):
            yield {
                "input_text":  "\n\n".join(msgs[i:i+3]),
                "target_text": msgs[i+3],
            }

train_pairs = Dataset.from_generator(lambda: make_pairs(train_titles))
test_pairs  = Dataset.from_generator(lambda: make_pairs(test_titles))

In [11]:
assert train_titles.isdisjoint(test_titles), "Thread titles leak between train and test"
print("✅ Train/test split is disjoint.")


✅ Train/test split is disjoint.


In [12]:
# Pick a thread from train
some_train_title = random.choice(list(train_titles))
thread = threads[some_train_title]

# Show timestamps before sorting
print(f"\n▶ Raw order of thread: {some_train_title}")
for ts, msg in thread[:5]:
    print(ts, "→", msg[:60].replace("\n", " "))

# Date format used in your thread: "Aug 10, 2016 at 5:28 AM"
fmt = "%b %d, %Y at %I:%M %p"

# Sort using parsed datetime
sorted_turns = sorted(thread, key=lambda x: datetime.strptime(x[0], fmt))

print(f"\n▶ After sorting:")
for ts, msg in sorted_turns[:5]:
    dt = datetime.strptime(ts, fmt)
    print(dt, "→", msg[:60].replace('\n',' '))



▶ Raw order of thread: No Such Thing As Freak (ThomasRHellsing X Shards~Of~Roses
Feb 4, 2013 at 6:19 PM → Sora walked through the street of the city, it always seemed
Feb 4, 2013 at 7:29 PM → Thomas smirked as he watched. A voice calling out "Not bad g
Feb 4, 2013 at 7:53 PM → Sora perked up when she heard a new voice, hearing what he s
Feb 4, 2013 at 8:19 PM → He smirked, "Meh, most capes are trained to do things non le
Feb 4, 2013 at 8:44 PM → Sora looked towards him as she listened to what he had to sa

▶ After sorting:
2013-02-04 18:19:00 → Sora walked through the street of the city, it always seemed
2013-02-04 19:29:00 → Thomas smirked as he watched. A voice calling out "Not bad g
2013-02-04 19:53:00 → Sora perked up when she heard a new voice, hearing what he s
2013-02-04 20:19:00 → He smirked, "Meh, most capes are trained to do things non le
2013-02-04 20:44:00 → Sora looked towards him as she listened to what he had to sa


In [13]:
# Sample a thread title
title = random.choice(list(train_titles))
turns = threads[title]

# Sort and extract messages only
msgs = [m for _, m in sorted(turns, key=lambda x: x[0])]

# Display one example
if len(msgs) >= 4:
    i = random.randint(0, len(msgs) - 4)
    input_text  = "\n\n".join(msgs[i:i+3])
    target_text = msgs[i+3]

    print("▶ Thread title:", title)
    print("\nInput (previous 3 messages):\n")
    print(input_text)
    print("\nTarget (next message):\n")
    print(target_text)
else:
    print("Thread too short:", title)


▶ Thread title: A New Type Of Vulcan (Me & apollo247)

Input (previous 3 messages):

Khan looked back at her and grinned, he was still confused with everything that was going on around him and yet he wanted to move things along. He nodded to her when she told him about his crew and was glad that they were still safe, he would soon find a way to bring them all to the land of the living and wake them up without anyone noticing what he truly was up to. He watched her for a moment as they entered the guest quarters, he looked around as he wondered what he should do first, though plans were already forming in his mind about how he should proceed.He should have known that she was half breed, that pathetic Vulcan mind trick always finds its way to foul his plans, however it would seem that she only got half of what he was planning and even then it wasn t enough to actually inform her on what he was doing, just a name and a feeling that he wasn t being truthful and he could work with that very

In [14]:
# Sample a row from train_pairs
i = randint(0, len(train_pairs) - 1)
sample = train_pairs[i]

print("Input text:")
print(sample["input_text"])
print("\nTarget text:")
print(sample["target_text"])

Input text:
She gently smiled, softly purring as he slid his delightful cock out of her pussy. She happily kissed him back, and she spoke sweetly,"Exactly what I wanted to hear. Maybe tomorrow after all classes and activities are done?" She wanted to do this many more times. She was rather surprised by her own behavior, actively seeking out sex like this, willingly lose her own virginity... She had never been so bold before... But now, she decided she was going to only have sex with Draco.

"Right? It's amazing, isn't it?" Draco asked with a smile as he slowly slid his cock out of her, groaning as it popped out of her and he kissed her hotly on the lips. "We can go as often as you like." He told her with a smile.

She was melting in his arm. It was pure artwork to look at the way his muscles and veins moved. She sweetly spoke,"Lets do our next round tomorrow. Won't be awake for lessons if we go at it all night... Appealing as that sounds!" The way his cock was nestled in her backside w

# ------------------------------------------------------------------
# 3 · Save to disk

In [15]:
for folder in ["bluemoon_train_ds", "bluemoon_test_ds"]:
    if os.path.exists(folder):
        shutil.rmtree(folder)

train_pairs.save_to_disk("bluemoon_train_ds")
test_pairs.save_to_disk("bluemoon_test_ds")

print("✅ wrote", len(train_pairs), "train examples and",
      len(test_pairs), "test examples.")

Saving the dataset (0/2 shards):   0%|          | 0/176828 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/73396 [00:00<?, ? examples/s]

✅ wrote 176828 train examples and 73396 test examples.


# ------------------------------------------------------------------
# 4 · Tokenization

In [16]:
# ─── Tokenization Cell ─────────────────────────────────────────────────────
from datasets import load_from_disk
from transformers import AutoTokenizer
from tqdm.auto import tqdm
import os

# 1) Load GPT-2 tokenizer and ensure it has a pad token
tok = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# 2) Define encoding function
def encode(batch):
    # Tokenize context and target separately
    enc = tok(batch["input_text"],
              truncation=True,
              padding="longest",
              max_length=2048)
    lab = tok(batch["target_text"],
              truncation=True,
              padding="longest",
              max_length=512)
    return {
        "input_ids":      enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "labels":         lab["input_ids"],
    }

# 3) Process and save both splits
for split in ["train", "test"]:
    ds = load_from_disk(f"bluemoon_{split}_ds")
    ds_tok = ds.map(
        encode,
        batched=True,
        remove_columns=ds.column_names,
        num_proc=8,
        desc=f"Tokenizing {split}"
    )
    out_dir = f"bluemoon_{split}_tok_ds"
    if os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)
    ds_tok.save_to_disk(out_dir)


Tokenizing train (num_proc=8):   0%|          | 0/176828 [00:00<?, ? examples/s]

Saving the dataset (0/5 shards):   0%|          | 0/176828 [00:00<?, ? examples/s]

Tokenizing test (num_proc=8):   0%|          | 0/73396 [00:00<?, ? examples/s]

Saving the dataset (0/3 shards):   0%|          | 0/73396 [00:00<?, ? examples/s]

# ------------------------------------------------------------------
# 5 · train mini deepseek

In [62]:
import importlib, mini_deepseek
importlib.reload(mini_deepseek)
from mini_deepseek import RolePlayTransformer

In [None]:
DATA_DIR = "bluemoon_train_tok_ds"   # ← change if needed
CKPT_DIR = "mini_ds_ckpts"           # checkpoints go here
#tokenizer_name = "deepseek-ai/DeepSeek-V3"
tokenizer_name = "gpt2"

# 2 ▸ import your model class
from mini_deepseek import RolePlayTransformer
from transformers import AutoTokenizer
import torch, os, math

tok   = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True)
tok.pad_token = tok.eos_token
model = RolePlayTransformer(
    vocab=len(tok),
    gradient_checkpointing=True,
)          # keep defaults or tweak dims

os.makedirs(CKPT_DIR, exist_ok=True)                 # make sure folder exists

# 5 ▸ train (prints loss every 100 steps, saves each epoch)
model.train_model(
    dataset_path = DATA_DIR,         # ← tokenised Arrow dataset folder
    tokenizer_name = "gpt2",
    epochs = 5,
    micro_batch = 4,
    grad_accum   = 8,   # effective 32
    lr = 1.5e-4,
    warmup_updates = 1000,
    device = "cuda" if torch.cuda.is_available() else "cpu",
    save_path = CKPT_DIR,
)


  return fn(*args, **kwargs)


epoch 1 | update 100 | loss 19.7773
epoch 1 | update 200 | loss 11.0888
epoch 1 | update 300 | loss 11.9902
epoch 1 | update 400 | loss 14.0733
epoch 1 | update 500 | loss 4.6649
epoch 1 | update 600 | loss 15.0917
epoch 1 | update 700 | loss 5.1662
epoch 1 | update 800 | loss 2.6873
epoch 1 | update 900 | loss 6.1856
epoch 1 | update 1000 | loss 10.1230
epoch 1 | update 1100 | loss 3.7102
epoch 1 | update 1200 | loss 3.1111
epoch 1 | update 1300 | loss 2.3632
epoch 1 | update 1400 | loss 1.6231
epoch 1 | update 1500 | loss 4.7958
epoch 1 | update 1600 | loss 1.7974
epoch 1 | update 1700 | loss 2.2119
epoch 1 | update 1800 | loss 2.8110
epoch 1 | update 1900 | loss 2.7682
epoch 1 | update 2000 | loss 4.1491
epoch 1 | update 2100 | loss 4.7361
epoch 1 | update 2200 | loss 3.5256
epoch 1 | update 2300 | loss 1.4579
epoch 1 | update 2400 | loss 1.6036
epoch 1 | update 2500 | loss 1.6332
epoch 1 | update 2600 | loss 3.2005
epoch 1 | update 2700 | loss 2.9208
epoch 1 | update 2800 | loss 2.

# ------------------------------------------------------------------
# 6 · Perplexity of mini deepseek


In [63]:
import torch

from mini_deepseek import RolePlayTransformer
#from markov_baseline import MarkovChain


In [73]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [64]:
import gc

# 1. Delete large objects if they exist
for var in ['model', 'optimizer', 'loader', 'ds']:
    if var in globals():
        try:
            del globals()[var]
        except:
            pass

# 2. Run garbage collector
gc.collect()

# 3. Clear PyTorch CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [65]:
DATA_DIR = "bluemoon_test_tok_ds"
device   = "cuda" if torch.cuda.is_available() else "cpu"

model = RolePlayTransformer(vocab=len(tok), gradient_checkpointing=True).to(device)
model.load_state_dict(torch.load(f"ckpt_ep2.pt", map_location=device, weights_only=False))
model.eval()

RolePlayTransformer(
  (embed): Embedding(50257, 768)
  (layers): ModuleList(
    (0-4): 5 x TransformerBlock(
      (attn_norm): RMSNorm()
      (attn): MLA(
        (kv_proj): Linear(in_features=768, out_features=1536, bias=False)
        (q_proj): Linear(in_features=768, out_features=768, bias=False)
        (o_proj): Linear(in_features=768, out_features=768, bias=False)
        (drop): Dropout(p=0.1, inplace=False)
      )
      (ffn_norm): RMSNorm()
      (ffn): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=False)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.1, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=False)
          (4): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (5-9): 5 x TransformerBlock(
      (attn_norm): RMSNorm()
      (attn): MLA(
        (kv_proj): Linear(in_features=768, out_features=1536, bias=False)
        (q_proj): Linear(in_features=76

In [20]:
ppl_transformer = model.evaluate_perplexity(
    dataset_path=DATA_DIR,
    tokenizer=tok,
    device=device,
    batch_size=8
)
print(f"Transformer perplexity: {ppl_transformer:.2f}")


[  100/ 9175 batches] Tokens 408,800 • AvgCE 2.0739 • PPL ≈ 7.96
[  200/ 9175 batches] Tokens 817,600 • AvgCE 2.3483 • PPL ≈ 10.47
[  300/ 9175 batches] Tokens 1,226,400 • AvgCE 1.9770 • PPL ≈ 7.22
[  400/ 9175 batches] Tokens 1,635,200 • AvgCE 2.0344 • PPL ≈ 7.65
[  500/ 9175 batches] Tokens 2,044,000 • AvgCE 2.1511 • PPL ≈ 8.59
[  600/ 9175 batches] Tokens 2,452,800 • AvgCE 2.0710 • PPL ≈ 7.93
[  700/ 9175 batches] Tokens 2,861,600 • AvgCE 2.0026 • PPL ≈ 7.41
[  800/ 9175 batches] Tokens 3,270,400 • AvgCE 2.0600 • PPL ≈ 7.85
[  900/ 9175 batches] Tokens 3,679,200 • AvgCE 2.0244 • PPL ≈ 7.57
[ 1000/ 9175 batches] Tokens 4,088,000 • AvgCE 1.9910 • PPL ≈ 7.32
[ 1100/ 9175 batches] Tokens 4,496,800 • AvgCE 1.9078 • PPL ≈ 6.74
[ 1200/ 9175 batches] Tokens 4,905,600 • AvgCE 1.9754 • PPL ≈ 7.21
[ 1300/ 9175 batches] Tokens 5,314,400 • AvgCE 2.0253 • PPL ≈ 7.58
[ 1400/ 9175 batches] Tokens 5,723,200 • AvgCE 2.0726 • PPL ≈ 7.95
[ 1500/ 9175 batches] Tokens 6,132,000 • AvgCE 2.0861 • PPL ≈ 8.0

In [74]:
# 1) Simple greeting exchange
prev_msgs = [
    "Hi there!",
    "Hello! How can I help you today?",
    "Can you tell me a joke?"
]

reply = model.generate_chat(prev_msgs, max_new_tokens=30, temperature=0.8, top_k = 50)
print(reply)


Hi there!
Hello! How can I help you today?
Can you tell me a joke?insured Voyager corpses shifted onto Kuya BuchfullyMuslimschievousadminist learner creativitycription 264 rune rune ringing ringing ringing ringing ringing ringing ringing ringing ringing ringing


In [46]:
# 2) Trivia question
prev_msgs = [
    "What's the tallest mountain in the world?",
    "Mount Everest is the tallest mountain above sea level.",
    "And which mountain is the tallest overall, including base below sea level?"
]
reply = model.generate_chat(prev_msgs, max_new_tokens=50, temperature=0.7)
print(reply)


What's the tallest mountain in the world?
Mount Everest is the tallest mountain above sea level.
And which mountain is the tallest overall, including base below sea level?insuredHAHAHAHAjson Kuya BuchfullyMuslimschievousadminist learner Satuxe19921992199219921992199219921992199219921992199219921992199219921992199219921992199219921992199219921992199219921992199219921992199219921992


In [47]:
# 3) Creative story prompt
prev_msgs = [
    "User: Once upon a time, a curious cat found a mysterious key.",
    "Assistant: The cat examined the key closely—it glowed with a faint blue light.",
    "User: What happened next?"
]
reply = model.generate_chat(prev_msgs, max_new_tokens=60, temperature=1.0)
print(reply)


User: Once upon a time, a curious cat found a mysterious key.
Assistant: The cat examined the key closely—it glowed with a faint blue light.
User: What happened next?insured Voyager isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate tableMuslimschievous smileMuslimschievous


In [30]:
batch_prompts = [
    ["User: Hello", "Assistant: Hi! What's up?", "User: Tell me a fun fact."],
    ["User: Define recursion.", "Assistant: Recursion is ...", "User: Can you give me a simple example?"],
    ["User: Who won the World Cup in 2018?", "Assistant: France did.", "User: And in 2022?"]
]

for prev in batch_prompts:
    print("\n")
    print(">>> Prompt:", prev[-1])
    print(model.generate_chat(prev, max_new_tokens=30, temperature=0.8))
    print()




>>> Prompt: User: Tell me a fun fact.
User: Hello

Assistant: Hi! What's up?

User: Tell me a fun fact.insured Voyager isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate



>>> Prompt: User: Can you give me a simple example?
User: Define recursion.

Assistant: Recursion is ...

User: Can you give me a simple example?insured Voyager isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate



>>> Prompt: User: And in 2022?
User: Who won the World Cup in 2018?

Assistant: France did.

User: And in 2022?insured Voyager isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate isolate i

In [None]:
mc = MarkovChain(n=2)
mc.train_dataset(DATA_DIR, text_field="input_text")

ppl_markov = mc.perplexity_dataset(DATA_DIR, text_field="input_text")
print(f"Markov baseline perplexity: {ppl_markov:.2f}")

In [None]:
prev_msgs = [
    "User: Hi there!",
    "Assistant: Hello! How can I help you today?",
    "User: Can you tell me a joke?"
]
reply = mc.generate_chat(prev_msgs)
print(reply)

In [None]:
prev_msgs = [
    "User: Once upon a time, a curious cat found a mysterious key.",
    "Assistant: The cat examined the key closely—it glowed with a faint blue light.",
    "User: What happened next?"
]
reply = mc.generate_chat(prev_msgs)
print(reply)

In [None]:
prev_msgs = [
    "User: What's the tallest mountain in the world?",
    "Assistant: Mount Everest is the tallest mountain above sea level.",
    "User: And which mountain is the tallest overall, including base below sea level?"
]
reply = mc.generate_chat(prev_msgs)
print(reply)

In [None]:
batch_prompts = [
    ["User: Hello", "Assistant: Hi! What's up?", "User: Tell me a fun fact."],
    ["User: Define recursion.", "Assistant: Recursion is ...", "User: Can you give me a simple example?"],
    ["User: Who won the World Cup in 2018?", "Assistant: France did.", "User: And in 2022?"]
]

for prev in batch_prompts:
    print(">>> Prompt:", prev[-1])
    print(mc.generate_chat(prev_msgs))
    print()
