# Fast API Test

This notebook tests my transformer API: a simple end-to-end check that the API is working correctly. Training will be **very** slow, but this is just a proof of concept.

In [1]:
import requests
import json
import re
from collections import Counter
# Window length: make_sequences() cuts training samples of this many ids by
# default, and autocomplete() sends at most this many of the most recent ids
# to /predict.
SEQUENCE_LENGTH = 200
# Base URL of the model server; the /create, /train and /predict paths are
# appended to it.
API_URL = "http://localhost:8080"

In [2]:
def load_corpus(path):
    """Return the full contents of the UTF-8 text file at *path* as one string."""
    with open(path, encoding='utf-8') as handle:
        contents = handle.read()
    return contents

def tokenize(text):
    """Split *text* into single-character tokens (a character-level tokenizer)."""
    return [character for character in text]

def build_vocab(tokens, max_size=None):
    """Build token<->id lookup tables ordered by token frequency.

    The most frequent token gets id 1, the next id 2, and so on. Id 0 is
    reserved for the ``'<unk>'`` placeholder that ``encode`` falls back to
    for out-of-vocabulary tokens.

    Args:
        tokens: Iterable of hashable tokens.
        max_size: Keep only the ``max_size`` most frequent tokens (``'<unk>'``
            is added on top of these); ``None`` keeps all of them.

    Returns:
        ``(vocab, inv_vocab)``: token -> id and id -> token dictionaries.
        The ids are exactly ``0 .. len(vocab) - 1``.
    """
    freq = Counter(tokens)
    # '<unk>' is reserved for id 0. If it were also counted as an ordinary
    # token, it would first get a frequency-rank id that is then overwritten.
    # That leaves a gap in the id range and an id equal to len(vocab), which
    # is out of range for a model sized with vocab_size=len(vocab).
    freq.pop('<unk>', None)
    ranked = freq.most_common(max_size)
    vocab = {tok: i for i, (tok, _) in enumerate(ranked, start=1)}
    vocab['<unk>'] = 0
    inv_vocab = {i: tok for tok, i in vocab.items()}
    return vocab, inv_vocab

def encode(tokens, vocab):
    """Translate a token sequence into a list of integer ids.

    Tokens missing from *vocab* are mapped to the id stored under '<unk>'.
    """
    def lookup(token):
        return vocab.get(token, vocab['<unk>'])

    return list(map(lookup, tokens))


def make_sequences(ids, seq_len=SEQUENCE_LENGTH, stride=1):
    """Slice an id stream into (input, target) windows for next-token training.

    Each input is ``ids[i:i+seq_len]`` and its target is the same window
    shifted one position to the right, so ``target[k]`` is the id that
    follows ``input[k]`` in the stream.

    Args:
        ids: Encoded token ids.
        seq_len: Window length.
        stride: Step between consecutive window starts. The default of 1
            yields every possible window, maximally overlapping. Larger values
            cut the number of sequences (and memory) by roughly that factor.

    Returns:
        ``(inputs, targets)``: two equal-length lists whose elements are
        lists of ``seq_len`` ids. Both are empty when ``len(ids) <= seq_len``.

    Raises:
        ValueError: if ``stride`` is less than 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    # The last valid start leaves room for the one-step-shifted target window.
    starts = range(0, len(ids) - seq_len, stride)
    inputs = [ids[i:i + seq_len] for i in starts]
    targets = [ids[i + 1:i + 1 + seq_len] for i in starts]
    return inputs, targets

In [3]:
def create_model(vocab_size, d_model=128, num_heads=4, d_ff=512, num_layers=1, max_len=5000):
    """Ask the server to create a model with the given hyperparameters.

    POSTs the configuration as JSON to ``{API_URL}/create``. Raises
    ``requests.HTTPError`` on an error status, then prints the decoded JSON
    reply. Returns None.
    """
    config = dict(
        vocab_size=vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        num_layers=num_layers,
        max_len=max_len,
    )
    response = requests.post(f"{API_URL}/create", json=config)
    response.raise_for_status()
    reply = response.json()
    print(reply)

def train_model(inputs, targets, epochs=1, lr=1e-3, batch_size=64):
    """Train the server-side model by streaming mini-batches to ``/train``.

    Each mini-batch is sent as its own POST with ``"epochs": 1``. The outer
    loop repeats the full pass ``epochs`` times, and after each pass the JSON
    reply for that pass's final batch is printed.

    Args:
        inputs: Input id sequences (e.g. from ``make_sequences``).
        targets: Target id sequences, aligned one-to-one with ``inputs``.
        epochs: Number of full passes over the data.
        lr: Learning rate forwarded to the server.
        batch_size: Number of sequences per request.

    Raises:
        ValueError: if ``inputs`` and ``targets`` differ in length or
            ``batch_size`` is not positive.
        requests.HTTPError: if the server rejects a batch.
    """
    if len(inputs) != len(targets):
        raise ValueError(
            f"inputs and targets must have the same length "
            f"({len(inputs)} != {len(targets)})"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not inputs:
        # With no data the batch loop never runs, so there would be no
        # response to report; bail out instead of touching an unset variable.
        print("No training sequences; nothing to train on.")
        return

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        last_reply = None
        for start in range(0, len(inputs), batch_size):
            payload = {
                "token_sequences": inputs[start:start + batch_size],
                "target_sequences": targets[start:start + batch_size],
                "epochs": 1,
                "learning_rate": lr,
            }
            resp = requests.post(f"{API_URL}/train", json=payload)
            resp.raise_for_status()
            last_reply = resp.json()
        # Only the last batch's reply is printed to keep the output short.
        print(last_reply)


def autocomplete(prefix, vocab, inv_vocab, max_tokens=50, stop_tokens=('.', '!', '?')):
    """Greedily extend *prefix* one token at a time using the server's ``/predict``.

    Args:
        prefix: Prompt text. It is tokenized with ``tokenize`` and encoded
            with ``encode``; tokens not in *vocab* get the ``'<unk>'`` id.
        vocab: Token -> id mapping.
        inv_vocab: Id -> token mapping used to decode predictions. Unknown
            ids decode to the string ``'<unk>'``.
        max_tokens: Maximum number of tokens to generate.
        stop_tokens: Generation stops right after any of these tokens is
            produced.

    Returns:
        The original prefix followed by the generated tokens, concatenated.

    Raises:
        requests.HTTPError: if a ``/predict`` call fails.
    """
    tokens = tokenize(prefix)
    ids = encode(tokens, vocab)

    for _ in range(max_tokens):
        # Only the most recent SEQUENCE_LENGTH ids are sent, matching the
        # default window length used to build training sequences.
        resp = requests.post(f"{API_URL}/predict", json={"input_tokens": ids[-SEQUENCE_LENGTH:]})
        resp.raise_for_status()
        next_id = resp.json()['predicted_token']
        next_tok = inv_vocab.get(next_id, '<unk>')
        tokens.append(next_tok)
        ids.append(next_id)
        if next_tok in stop_tokens:
            break
    # tokenize() yields single characters, so join with no separator;
    # joining with ' ' would put a space between every character.
    return ''.join(tokens)

In [None]:
# 1) Load and preprocess
# Character-level pipeline: tokenize() splits the corpus into single
# characters, so the vocabulary is the set of distinct characters plus '<unk>'.
text = load_corpus('input_short.txt')
toks = tokenize(text)
vocab, inv_vocab = build_vocab(toks)
print(f"Vocabulary size: {len(vocab)}")

# 2) Build sequences
# With its default arguments, make_sequences() produces one
# SEQUENCE_LENGTH-long window per start position, i.e. about
# len(ids) - SEQUENCE_LENGTH (input, target) pairs.
ids = encode(toks, vocab)
inputs, targets = make_sequences(ids)
print(f"Total sequences: {len(inputs)}")

# 3) Initialize model
# vocab_size is len(vocab) so every id produced by encode() is in range.
# max_len is left at create_model's default of 5000, which exceeds
# SEQUENCE_LENGTH. Presumably this is the server's maximum sequence
# length; TODO confirm.
create_model(vocab_size=len(vocab), d_model=256, num_heads=8, d_ff=1024, num_layers=1)

# 4) Train
# train_model sends one HTTP request per batch of 64 sequences, which is a
# large part of why training is slow.
train_model(inputs, targets, epochs=1, lr=0.01)

# 5) Test autocomplete
# Prompt characters that never appear in the corpus are encoded with the
# '<unk>' id (0).
prompt = "once upon a time"
completion = autocomplete(prompt, vocab, inv_vocab)
print("\nCompletion:", completion)
