In [1]:
import logging
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

In [2]:
logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter('%asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)

logger.info('Test log message')

In [3]:
def preprocess_data(txt_data_path: str) -> list[int]:
    with open(txt_data_path, 'r') as f:
        txt_data = f.read()

    chars = list(set(txt_data))

    num_chars = len(chars)
    txt_data_size = (len(txt_data))

    print(f'Input dataset length: {txt_data_size} \t Unique characters: {num_chars}')
    # logger.info(f'Input dataset length: {txt_data_size} \t Unique characters: {num_chars}')

    char_to_int = dict((c, i ) for i, c in enumerate(chars))
    int_to_char = dict((v,k) for k,v in char_to_int.items())

    #TODO logging

    txt_data_encoded = [char_to_int[i] for i in txt_data]
    return txt_data_encoded, txt_data_size, chars, num_chars

In [28]:
data, data_size, chars, chars_size = preprocess_data("data/shakespeare_input.txt")

Input dataset length: 4573338 	 Unique characters: 67


In [5]:
class RNN(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, num_layers):
        super(RNN, self).__init__()
        self.embedding = nn.Embedding(input_size, input_size)
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input_seq, hidden_state):
        embed = self.embedding(input_seq)
        output, hidden_state = self.rnn(embed, hidden_state)
        output = self.decoder(output)
        return output, (hidden_state[0].detach(), hidden_state[1].detach())

In [6]:
hidden_size = 512 #config.hidden_size
seq_len = 100 #config.seq_len
num_layers = 3 #config.num_layers
lr = 0.002 #config.lr
epochs = 100 #config.epochs
eval_sample_length = 200 #config.eval_sample_length
load_chk = False #config.load_chk
save_path = ".pretrained/test.pth" #config.save_path
data_path = "data/shakespeare_input.txt" #config.data_path
device = torch.device("mps" if torch.backends.mps.is_available() and torch.backends.mps.is_built() else "cpu") 
print(f'Device found: {device}')

Device found: mps


In [7]:
data, data_size, chars, chars_size = preprocess_data(data_path)

# data tensor on device
data = torch.tensor(data).to(device)
data = torch.unsqueeze(data, dim=1)

Input dataset length: 4573338 	 Unique characters: 67


In [8]:
rnn = RNN(chars_size, chars_size, hidden_size, num_layers).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)

In [10]:
start_idx = np.random.randint(100)
n = 0
running_loss = 0
hidden_state = None

input_seq = data[start_idx : start_idx + seq_len]
target_seq = data[start_idx + 1 : start_idx + seq_len + 1]

In [11]:
output, hidden_state = rnn(input_seq, hidden_state)

In [22]:
torch.squeeze(output[0])

tensor([-1.0571,  0.7263, -6.7040, -3.4512, -0.8408, -5.9098, -5.8185, -0.0410,
         0.7861, -0.0268, -1.4102,  1.5347, -0.2554,  0.3503, -1.3698, -4.4935,
        -0.1480, -5.5501,  0.5013, -1.7296, -2.6169,  0.0665, -1.7384, -0.3351,
         0.7180,  1.3147, -6.5140, -2.1672, -6.7022,  1.2267, -1.0990, -3.7923,
         0.8452, -2.4252, -0.1474, -0.5068, -6.6027,  0.7670, -6.3746, -3.4414,
        -3.9210, -0.7353, -6.5947,  0.6785, -0.1872,  0.7637, -2.0007,  1.2866,
        -3.7340, -4.4850, -2.9054, -7.1683, -3.6637, -1.8749, -4.0536, -4.0645,
        -0.9814, -1.4169,  2.0271, -0.4198, -5.0288,  1.8017, -0.8093, -2.4910,
        -1.9575, -1.9671, -1.8819], device='mps:0', grad_fn=<SqueezeBackward0>)

In [23]:
torch.squeeze(target_seq[0])

tensor(56, device='mps:0')

In [13]:
output.shape

torch.Size([100, 1, 67])

In [15]:
torch.squeeze(output).shape

torch.Size([100, 67])

In [18]:
torch.squeeze(target_seq).shape

torch.Size([100])

In [19]:
loss_fn(torch.squeeze(output), torch.squeeze(target_seq))

tensor(3.6798, device='mps:0', grad_fn=<NllLossBackward0>)

In [16]:
for i in range(1, epochs+1):

        start_idx = np.random.randint(100)
        n = 0
        running_loss = 0
        hidden_state = None

        while True:
            input_seq = data[start_idx : start_idx + seq_len]
            target_seq = data[start_idx + 1 : start_idx + seq_len + 1]


            output, hidden_state = rnn(input_seq, hidden_state)

            loss = loss_fn(torch.squeeze(output), torch.squeeze(target_seq))
            running_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            start_idx += seq_len
            n += 1

            if start_idx + seq_len +1 > data_size:
                break

KeyboardInterrupt: 

In [25]:
def vectorized_stride(array, start, sub_window_size, stride_size):
    time_steps = len(array) - start
    max_time = time_steps - time_steps % stride_size
    
    sub_windows = (
        start + 
        np.expand_dims(np.arange(sub_window_size), 0) +
        np.expand_dims(np.arange(max_time), 0).T
    )
    
    # Fancy indexing to select every V rows.
    return array[sub_windows[::stride_size]]

In [26]:
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, data, seq_length): 
        'Strides a window of seq_length across the text data from a random starting point'
        self.start_idx = np.random.randint(100)
        self.X = vectorized_stride(data, self.start_idx, seq_length, seq_length) 
        self.y = vectorized_stride(data, self.start_idx + 1, seq_length, seq_length)
    def __len__(self):
        'Denotes the total number of samples'
        return len(self.X)
    
    def __getitem__(self, idx):
        'Generates one sample of data'

        return torch.tensor(self.X[idx]), torch.tensor(self.y[idx])
    

In [75]:
seq_length = 100
dataset = TextDataset(np.array(data), seq_length)

In [76]:
from torch.utils.data import DataLoader
batch_size = 32
loader = DataLoader(dataset,
                    batch_size=batch_size,
                    shuffle=True,
                    pin_memory=True)

In [50]:
dataiter = iter(loader)
X0, y0 = next(dataiter)
X0.shape, y0.shape

(torch.Size([1, 100]), torch.Size([1, 100]))

In [51]:
X0 = X0.to(device)
y0 = y0.to(device)

In [52]:
output, hidden_state = rnn(X0, None)


In [53]:
torch.squeeze(output).shape, torch.squeeze(y0).shape

(torch.Size([100, 67]), torch.Size([100]))

In [71]:
loss_fn(output.reshape(67,100), y0)

SyntaxError: invalid syntax (662699743.py, line 1)

In [77]:
for i in range(1, epochs+1):

    running_loss = 0
    hidden_state = None

    for i, (X, y) in enumerate(loader):
        X, y = X.to(device), y.to(device)
        output, hidden_state = rnn(X, hidden_state)

        loss = loss_fn(output.reshape(batch_size,chars_size,seq_length), y) 
        running_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 1000 == 0:
            print(f'{int(i/1000)}000 sequences processed')

        if start_idx + seq_len +1 > data_size:
            break

    print(f'Epoch: {i} \t Loss: {running_loss/n:.8f}')
    # logger.info(f'Epoch: {i} \t Loss: {running_loss/n:.8f}')
    # torch.save(rnn.state_dict(), save_path)

0k sequences processed
0k sequences processed


RuntimeError: shape '[32, 67, 100]' is invalid for input of size 33500