In [1]:
import torch
from torch import nn
from jlib.classifier import Classifier
from jlib.get_shakespeare_loaders import get_shakespeare_loaders

text = ""
with open('data/sequence.txt', 'r') as f:
    text = f.read()

class ShakespeareRNN(Classifier):
    def __init__(self, alphabet_size, hidden_size, rnn=nn.RNN, linear_network=[]):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(alphabet_size, hidden_size)
        self.rnn = rnn(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Sequential()
        linear_in = hidden_size
        linear_out = alphabet_size
        for i, layer_size in enumerate(linear_network):
            self.fc.add_module(f'linear_{i}', nn.Linear(linear_in, layer_size))
            self.fc.add_module(f'relu_{i}', nn.ReLU())
            linear_in = layer_size
        self.fc.add_module('final_linear', nn.Linear(linear_in, linear_out))
        
    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.rnn(x)
        x = self.fc(x[:, -1, :])  # Get the output of the last RNN cell
        return x

def train_and_plot(data, model: ShakespeareRNN, name, *training_args, **training_kwargs):
    print(f"Training {name}")
    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {param_count} parameters")
    model.train_model(
        *training_args,
        **training_kwargs,
        train_loader=data['train_loader'],
        val_loader=data['val_loader'],
    )
    fig = model.plot_training(f"{name} Training")
    fig.savefig(f"images/{name}_training_new.png")


  from .autonotebook import tqdm as notebook_tqdm


# LSTM 20

In [2]:
data_20 = get_shakespeare_loaders(
    train_batch_size=32,
    val_batch_size=512,
    redownload=False,
    sequence_length=20
)

Train GPU Prefetch: 819200.0
Train CPU Prefetch: 81920.0
Val GPU Prefetch: 25600.0
Val CPU Prefetch: 5120.0
Train Loader
Begin init data loader
Batch Size: 0.0048828125 MiB
Data Loader init time: 2.510459 s
Begin init fetcher
Fetcher init time: 3.874673 s
Val Loader
Begin init data loader
Batch Size: 0.078125 MiB
Data Loader init time: 1.123367 s
Begin init fetcher
Fetcher init time: 1.301090 s


In [3]:
lstm20 = ShakespeareRNN(
    alphabet_size=26,
    hidden_size=128,
    rnn=nn.LSTM,
).to('cuda')
train_and_plot(
    data_20,
    lstm20,
    "LSTM-20",
    epochs=100,
    optimizer = torch.optim.Adam,
    optimizer_kwargs = {'lr': 0.001},
    min_accuracy = 0.7
)

Training LSTM-20
Model has 138778 parameters
Training ShakespeareRNN

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
training
Max: 63 Min: 0


/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [249,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [249,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [249,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [249,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [249,0,0], thread: [100,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1422: indexSelectLargeIndex: block: [249,0,0], thread: [101,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
