In [1]:
import numpy as np


# data I/O
data = open('input.txt', 'r').read() # should be simple plain text file
chars = list(set(data))
data_size, vocab_size = len(data), len(chars)
print('data has %d characters, %d unique.' % (data_size, vocab_size))
char_to_ix = { ch:i for i,ch in enumerate(chars) }
ix_to_char = { i:ch for i,ch in enumerate(chars) }

data has 1115393 characters, 65 unique.


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class CharLSTMOneHot(nn.Module):
    """
    A LSTM that takes one-hot inputs of shape (batch, seq_length, vocab_size).
    No embedding layer is used.
    """
    def __init__(self, vocab_size, hidden_size=512):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        
        # LSTM expects input_size = vocab_size (since one-hot dimension)
        # batch_first=True => input shape (batch, seq_length, vocab_size)
        self.lstm = nn.LSTM(input_size=vocab_size, hidden_size=hidden_size,
                            num_layers=2, batch_first=True)
        
        # Map hidden states to vocab-size logits
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None, cell=None):
        """
        x: (batch, seq_length, vocab_size) => one-hot vectors
        hidden: (1, batch, hidden_size) initial hidden state
        cell: (1, batch, hidden_size) initial cell state
        returns: (logits, hidden)
          logits shape: (batch, seq_length, vocab_size)
        """
        # Pass x directly to LSTM (no embedding)
        # output => (batch, seq_length, hidden_size)
        # hidden => (2, batch, hidden_size)
        # cell => (2, batch, hidden_size)
        output, (hidden, cell) = self.lstm(x, (hidden, cell))
        
        # Map to vocab logits
        # shape => (batch, seq_length, vocab_size)
        logits = self.fc(output)
        
        return logits, (hidden, cell)
    
    def init_hidden(self, batch_size=1):
        """ Return a fresh hidden state of shape (2, batch_size, hidden_size). """
        return torch.zeros(2, batch_size, self.hidden_size)

In [3]:
def indices_to_onehot(indices, vocab_size):
    """
    indices: shape (batch, seq_length), dtype long
    returns: float tensor of shape (batch, seq_length, vocab_size)
    """
    # F.one_hot => (batch, seq_length, vocab_size) with dtype=int
    one_hot = F.one_hot(indices, num_classes=vocab_size)
    return one_hot.float()

In [4]:
def sampleLSTM(model, hidden, cell, start_ix, length, vocab_size, ix_to_char):
    """
    model: Your CharGRUOneHot instance
    hidden: initial hidden state (shape [1, batch=1, hidden_size])
    cell: initial cell state (shape [1, batch=1, hidden_size])
    start_ix: integer index of the first character
    length: how many characters to generate
    vocab_size: total number of possible chars
    ix_to_char: mapping int -> character (for final text)
    returns: a string of generated characters
    """

    # We'll store the generated characters' indices
    generated_indices = []
    # current input char index
    char_ix = start_ix

    # We do a loop for 'length' steps
    for _ in range(length):
        # Build a one-hot input of shape (batch=1, seq_len=1, vocab_size)
        x_onehot = torch.zeros((1, 1, vocab_size))
        x_onehot[0, 0, char_ix] = 1.0

        # Forward pass (no grad)
        with torch.no_grad():
            logits, (hidden, cell) = model(x_onehot, hidden, cell)
            # logits shape => (1, 1, vocab_size)

        # Take the last time-step => shape (vocab_size,)
        probs = F.softmax(logits[0, 0], dim=-1).cpu().numpy()
        # Sample from the probability distribution
        char_ix = np.random.choice(vocab_size, p=probs)
        
        generated_indices.append(char_ix)

    # Convert all indices to characters
    generated_text = ''.join(ix_to_char[ix] for ix in generated_indices)
    return generated_text

In [None]:
# Suppose you have:
# data_indices: a long list of integers in [0, vocab_size-1]
# vocab_size: total number of unique chars
# idx_to_char, char_to_idx: optional for sampling

model = CharLSTMOneHot(vocab_size, hidden_size=512)
criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=1e-3)

seq_length = 25
pointer = 0
n = 0
max_iters = 44700 * 100 # 100 times pass the full text 
smooth_loss = 0.0

# init hidden state for single-batch
hidden = model.init_hidden(batch_size=1)
cell = model.init_hidden(batch_size=1)

while True:

    # if we near end of data, wrap around
    if pointer + seq_length + 1 >= len(data):
        pointer = 0
        hidden = model.init_hidden(batch_size=1)
        cell = model.init_hidden(batch_size=1)
    
    # Grab chunk of length seq_length
    input_seq = [char_to_ix[ch] for ch in data[pointer:pointer+seq_length]]
    target_seq = [char_to_ix[ch] for ch in data[pointer+1:pointer+seq_length + 1]]

    # Convert to tensors, shape => (1, seq_length)
    inputs_t = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)
    targets_t = torch.tensor(target_seq, dtype=torch.long).unsqueeze(0)

    # Convert input_seq to one-hot => shape (1, seq_length, vocab_size)
    x_onehot = indices_to_onehot(inputs_t, vocab_size=vocab_size)
    
    optimizer.zero_grad()
    # forward pass
    logits, (hidden, cell) = model(x_onehot, hidden, cell)
    
    # shape => (1, seq_length, vocab_size)
    # Flatten so cross-entropy can be computed:
    #  => logits: (1*seq_length, vocab_size), targets: (1*seq_length)
    loss = criterion(logits.view(-1, vocab_size), targets_t.view(-1))
    
    
    loss.backward()
    optimizer.step()

    hidden = hidden.detach()
    cell = cell.detach()

     # smooth loss
    smooth_loss = 0.99 * smooth_loss + 0.01 * loss.item()
    
    if n % 100 == 0:
        print(f"Iter {n}, smooth_loss={smooth_loss:.4f}")
        # Optional: sample from model...
        print(sampleLSTM(model, hidden, cell, input_seq[0], 200, vocab_size, ix_to_char))
    
    pointer += seq_length
    n += 1
    
    if n == max_iters: 
        break

Iter 0, smooth_loss=1.0426
m,rmdF&Pg!AR;?fJ?wazXtMFnn$&KZzmdF!kj3&;BByxVtYtFtGg;T3,fF
Qi3HVrQ$ULE:eP
Buv
V$S'zwWfy'KwT!VKLejEgg?kpKHhpoGHDGEo!;xwkhQ.Dgjwuj:KawjLnvkF3'?P'jW ag;dJKiXq?geTmavW,;VQFi rEvE;x$c&W'CSVrUphuEtJ?NXHijo,
Iter 100, smooth_loss=51.8650
yeii 'h:hl, s h
zm
nrt  nh as.Wn
hesEcs  z
 :mao e snhO;npsatoa hetn tsaIL nhmn:p cTss ka,h 
t
m.d o ehvtohyon! i mhn 
bsvfW weCoue'  so h,.iy
eC! 

t eaohn ynaluvtCY'Fslt? igcano.mw e

e t r oco r 
 
Iter 200, smooth_loss=68.7032
 lthd td iuem, EtsnrSjnanru ,eiTust, l aoe iowukhi:Ht:hmbicn !ortinWbm, vioytshy ouulr ar h
aTuoe,tuyln, 
hret
uowtob
.lo

rrn Itue oli tCis nol ati
g-,
ydcIsl nhi mb,lvertsl neep
oo yit wedr gee ipir
Iter 300, smooth_loss=70.1024
s sfEspSye puwmmer Doks fosit:r.
Le
Taes
ghok thh, foame whit bho belea wimgwbes ,iri? wasg, alaoo, wacs bus triae man? anet houe ar cirsr WofrNI
 hinc toug rau' sosb bous,lo Foonns
lhlo sgobtap
chat 
Iter 400, smooth_loss=67.4858
nno nsec thiiN- ghes
Wheig bosd the rse toogtWane

In [14]:
print(sampleLSTM(model, hidden, cell, 1, 400, vocab_size, ix_to_char))

--for me? But, for the prince! we were worse o' the isle,
And have fled
The shepheals of her respecting no
Will they will undo.

VALUSIET:
Affordsing them;
And crap up the number and a noble countryment,
My cow I am when I rather stay.

MARCIUS:
We are:' gentle Mariana?

MARCIUS:
Consura, the oranoubs have none opposed
The sake is reapped with accuspt out.

Second Peter:
When 'tis a cever you hald
