In [1]:

!git clone https://github.com/companys1234/Jumping_LLM_Flash.git

Cloning into 'Jumping_LLM_Flash'...
remote: Enumerating objects: 304, done.[K
remote: Counting objects: 100% (103/103), done.[K
remote: Compressing objects: 100% (100/100), done.[K
remote: Total 304 (delta 52), reused 3 (delta 3), pack-reused 201 (from 1)[K
Receiving objects: 100% (304/304), 120.02 KiB | 12.00 MiB/s, done.
Resolving deltas: 100% (122/122), done.


In [2]:

from Jumping_LLM_Flash.scr.preprocessing import BPE
from Jumping_LLM_Flash.scr.architecture import *
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
import math

In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2026-01-28 09:57:43--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2026-01-28 09:57:44 (120 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
# Read the whole downloaded corpus into a single UTF-8 string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()


In [5]:
# NOTE(review): `for tex in text` walks the string one *character* at a time, so this cell
# does not split into words. It keeps every non-whitespace character as its own list entry
# and drops whitespace. The BPE tokenizer in the next cell is therefore fit on single
# characters. If word pieces were intended, `words = text.split()` is the usual form.
# Before switching, check that the resulting token ids still fit Grok1's 260-entry embedding.
words = [ch for ch in text if not ch.isspace()]

In [6]:
# Fit the project's BPE tokenizer on `words` from the previous cell.
# NOTE(review): what the constructor argument 100 means (vocabulary size vs. number of
# merges) is defined in Jumping_LLM_Flash.scr.preprocessing — confirm there. The ids it
# produces must stay below Grok1's embedding size (260).
tokenizer = BPE(100)
tokenizer.fit(words)


In [7]:
def get_batch(data, block_size, batch_size, device):
    """Sample a random batch of (input, next-token target) windows from a token stream.

    Args:
        data: 1-D torch.Tensor of token ids.
        block_size: context length of each window.
        batch_size: number of windows to sample (with replacement).
        device: device the returned tensors are moved to.

    Returns:
        (x, y): two tensors of shape (batch_size, block_size) on ``device``; ``y`` is
        the window of ``data`` starting one position after ``x``'s window.

    Raises:
        ValueError: if ``data`` holds fewer than ``block_size + 1`` tokens.
    """
    # A window starting at s needs data[s + block_size] for its last target,
    # so valid starts are 0 .. len(data) - block_size - 1 (randint's high is exclusive).
    num_starts = len(data) - block_size
    if num_starts < 1:
        raise ValueError(
            f"data has {len(data)} tokens; need at least block_size + 1 = {block_size + 1}"
        )
    # Sample on data's own device so indexing never mixes devices; move results afterwards.
    starts = torch.randint(0, num_starts, (batch_size,), device=data.device)
    positions = starts.unsqueeze(1) + torch.arange(block_size, device=data.device)  # (B, block_size)
    x = data[positions]
    y = data[positions + 1]
    return x.to(device), y.to(device)


device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
class Grok1(nn.Module):
    """Small mixture-of-experts language model built from the project's GQA/MoE layers.

    Pipeline: token ids -> ``nn.Embedding`` -> ``MoE`` over ``self.num_experts`` experts
    (each a plain stack of ``GQA`` attention layers, with no residuals/norms/FFN here)
    -> ``nn.Linear`` back to vocabulary logits.

    The sizes are constructor arguments. Their defaults reproduce the original
    hard-coded model: vocab 260, width 100, context 32, 4 heads, 3 layers per expert.

    NOTE(review): ``MoE`` receives ``d_model * block_size`` as its second argument.
    That suggests it routes on the whole flattened (block_size, d_model) window, so inputs
    presumably must be exactly ``block_size`` tokens long. A router that sees the whole
    window could also let later tokens influence earlier positions. Also confirm that
    ``GQA`` applies a causal mask. The recorded run (train loss ~0.91 while samples
    degenerate into repeated characters) would be consistent with such a look-ahead leak.
    Verify against the MoE/GQA sources.
    """

    def __init__(self, vocab_size=260, d_model=100, block_size=32, num_heads=4,
                 layers_per_expert=3):
        super().__init__()
        self.num_experts = 8
        self.block_size = block_size
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Module names/order match the original, so state_dict keys are unchanged.
        self.expert = nn.ModuleList([
            nn.Sequential(*[
                GQA(d_model=d_model, num_heads=num_heads, num_kv_heads=None, pos_enc=RoPE)
                for _ in range(layers_per_expert)
            ])
            for _ in range(self.num_experts)
        ])

        # NOTE(review): the third argument (8) is presumably the expert count, matching
        # self.num_experts — confirm against MoE's signature.
        self.moe = MoE(self.expert, d_model * block_size, 8)

        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x, targets=None):
        """Compute vocabulary logits and, optionally, the next-token loss.

        Args:
            x: LongTensor of token ids; presumably (batch, block_size) — see class note.
            targets: optional LongTensor with the same shape as ``x``.

        Returns:
            (logits, loss): ``logits`` has ``vocab_size`` on its last dimension. ``loss``
            is the mean cross-entropy over all positions when ``targets`` is given,
            otherwise ``None``.
        """
        logits = self.output(self.moe(self.embedding(x)))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, block_size=None, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        Args:
            idx: LongTensor (batch, T) of context token ids.
            max_new_tokens: number of tokens to sample.
            block_size: context fed to the model. Only the last ``block_size`` tokens
                are used. Defaults to ``self.block_size``.
            temperature: logits are divided by this before the softmax. The default 1.0
                leaves the distribution unchanged.
            top_k: if given, sample only among the ``top_k`` highest-scoring tokens.

        Returns:
            LongTensor (batch, T + max_new_tokens).

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if block_size is None:
            block_size = self.block_size
        if temperature <= 0:
            raise ValueError("temperature must be positive")

        for _ in range(max_new_tokens):
            # Slicing is a no-op when the sequence is already <= block_size.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature  # only the last position predicts the next token
            if top_k is not None:
                kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[:, -1:]
                logits = logits.masked_fill(logits < kth_best, float('-inf'))
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_token], dim=1)
        return idx

In [9]:

def train(model, data, optimizer, criterion, device, epochs=3, batch_size=32, block_size=32):
    """Train a next-token model on random windows sampled from a 1-D token tensor.

    Each epoch runs ceil((len(data) - block_size) / batch_size) optimizer steps, which is
    roughly one pass' worth of windows. Every step draws ``batch_size`` random start
    offsets, so windows are sampled with replacement.

    Args:
        model: module whose ``forward(x)`` returns a tuple whose first element holds
            logits with the vocabulary on the last dimension.
        data: 1-D LongTensor of token ids.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking (N, vocab) logits and (N,) targets, e.g. ``nn.CrossEntropyLoss``.
        device: device the model and every batch are moved to.
        epochs: number of epochs.
        batch_size: windows per step.
        block_size: context length.

    Returns:
        list[float]: mean training loss of each epoch (each is also printed).

    Raises:
        ValueError: if ``data`` is too short to hold one (input, target) window.
    """
    # Valid starts are 0 .. len(data) - block_size - 1; randint's high bound is exclusive.
    num_starts = len(data) - block_size
    if num_starts < 1:
        raise ValueError(
            f"data has {len(data)} tokens; need at least block_size + 1 = {block_size + 1}"
        )
    steps_per_epoch = (num_starts + batch_size - 1) // batch_size
    offsets = torch.arange(block_size, device=data.device)

    model.train()
    model.to(device)

    history = []
    for epoch in range(epochs):
        total_loss = 0.0
        for _ in range(steps_per_epoch):
            starts = torch.randint(0, num_starts, (batch_size,), device=data.device)
            positions = starts.unsqueeze(1) + offsets  # (B, block_size)
            x = data[positions].to(device)
            y = data[positions + 1].to(device)  # targets: the same windows shifted by one token

            optimizer.zero_grad()
            logits, _ = model(x)
            # reshape (not view) also works when the model returns non-contiguous logits.
            loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / steps_per_epoch
        history.append(avg_loss)
        print(f"Epoch {epoch+1}: loss={avg_loss:.4f}")
    return history

In [10]:
# Encode only a 10,000-character slice of the corpus to keep this CPU demo fast.
# Casing is left as-is (a bare `text.lower()` statement would discard its result anyway);
# the tokenizer was fit on the original-case text.
TRAIN_CHARS = 10_000
tokens = torch.tensor(tokenizer.encode(text[:TRAIN_CHARS], return_ids=True), dtype=torch.long)

# Cheap invariants for the training cell: a flat id stream longer than the default
# context (32) whose ids fit Grok1's 260-entry embedding.
assert tokens.ndim == 1 and len(tokens) > 32, "train() needs more tokens than block_size"
assert int(tokens.max()) < 260, "token id outside Grok1's 260-entry embedding"

In [11]:
# Seed before building the model so weight init and batch sampling are reproducible.
torch.manual_seed(0)

model = Grok1()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

# Train on CPU on purpose: the chat() call in the demo cell also builds its input on 'cpu'.
train(
    model=model,
    data=tokens,
    optimizer=optimizer,
    criterion=criterion,
    device='cpu',
)

Epoch 1: loss=3.0092
Epoch 2: loss=1.5306
Epoch 3: loss=0.9081


In [12]:

def chat(model, tokenizer, prompt, device, max_new_tokens):
    """Continue ``prompt`` with ``model.generate`` and return the decoded text.

    Switches ``model`` to eval mode (left that way on return) and encodes ``prompt`` with
    ``tokenizer.encode(prompt, return_ids=True)``. It then generates ``max_new_tokens``
    tokens and decodes the whole sequence with ``tokenizer.decode(..., from_ids=True)``,
    so the returned string begins with the (re-decoded) prompt.

    Raises:
        ValueError: if the prompt encodes to no tokens. Generation needs at least one
            token of context, and an empty (1, 0) tensor would otherwise fail deep
            inside the model.
    """
    model.eval()
    prompt_ids = tokenizer.encode(prompt, return_ids=True)
    if len(prompt_ids) == 0:
        raise ValueError("prompt must encode to at least one token")
    input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)  # (1, T)

    with torch.no_grad():
        output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)

    return tokenizer.decode(output_ids[0].tolist(), from_ids=True)

In [15]:
# Character inventory kept for reference; chat() below does not use it.
# The list has 71 entries: '!' and '?' appear twice, and it also contains a Cyrillic 'э'.
vocab = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,!?:;'\n-&3$#@э!?*")

# Count distinct symbols instead of hard-coding 69.
vocab_size = len(set(vocab))

prompt = "We are accounted poor citizens, the patricians good. What authority surfeits on would relieve us: if they would yield us but the superfluity, while it were wholesome, we might guess they relieved us humanely;"

response = chat(model, tokenizer, prompt, device='cpu', max_new_tokens=60)
print('response:', response)

response: We are accounted poor citizens, the patricians good. What authority surfeits on would relieve us: if they would yield us but the superfluity, while it were wholesome, we might guess they relieved us humanely;E,,?,,gi''cc::::::::i uuuuuu iMeUUee tllllecuu:dddb:bdd:d:d
