# 1D Convolutional Neural Network 
In this notebook we implement a first attempt at next-token prediction using a neural network. To align with previous weeks in this course we choose a 'simple' 1D convolutional neural network. The motivation is that by composing convolutions we can create a receptive field over our sequence context that uses its parameters more efficiently than our previous n-gram Markov model, thereby allowing a longer sequence context to be used.

In [None]:
# All dependencies for the entire notebook
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import RandomSampler

from dataclasses import dataclass, field
from tqdm.auto import tqdm, trange
from matplotlib import pyplot as plt

# Pick the best available accelerator and fall back to the CPU, so the notebook
# also runs on machines without Apple's Metal (MPS) backend.
if torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

## Data

In [None]:
# Download the tinyshakespeare dataset
# -nc (no-clobber) skips the download when input.txt already exists, so the cell is safe to re-run
!wget -nc https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

We use a character-level tokenizer and a dataset class that selects (and one-hot encodes) a context of a given size together with the next token to predict.

In [None]:
@dataclass
class CharacterTokenizer:
    """Character-level tokenizer backed by a pair of lookup tables.

    ``encode_dict`` maps each known character to its token id and
    ``decode_dict`` is the inverse mapping. Both start out empty and are
    filled by :meth:`train`.
    """
    decode_dict: dict[int, str] = field(default_factory=dict)
    encode_dict: dict[str, int] = field(default_factory=dict)

    def get_vocab(self):
        """Return the character -> token id table"""
        return self.encode_dict

    def train(self, input_str: str) -> None:
        """Rebuild both lookup tables from the distinct characters of ``input_str``.

        Token ids are assigned in sorted (lexicographical) character order.
        """
        alphabet = sorted(set(input_str))
        self.decode_dict = {token: char for token, char in enumerate(alphabet)}
        self.encode_dict = {char: token for token, char in enumerate(alphabet)}

    def encode(self, input: str) -> list[int]:
        """Map every character of ``input`` to its token id (KeyError on unknown characters)"""
        lookup = self.encode_dict
        return [lookup[char] for char in input]

    def decode(self, tokens: list[int]) -> str:
        """Map token ids back to characters and concatenate them (KeyError on unknown ids)"""
        lookup = self.decode_dict
        return ''.join(lookup[token] for token in tokens)

class CharacterDataset:
    """Sliding-window next-character dataset over a raw text string.

    Sample ``idx`` pairs the one-hot encoded window
    ``data[idx:idx + context_size]`` with the id of the character that
    immediately follows it.
    """

    def __init__(self, data: str, tokenizer: 'CharacterTokenizer', context_size: int=256):
        """Wrap ``data`` with a trained ``tokenizer``.

        Args:
            data: the raw text the samples are cut from.
            tokenizer: object providing ``get_vocab()`` and ``encode()``.
            context_size: number of characters in each input window.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.get_vocab())
        self.context_size = context_size

    def __repr__(self):
        n_chars = len(self.data)
        vocab_size = self.vocab_size
        block_size = self.context_size
        return f'CharacterDataset({n_chars=}, {vocab_size=}, {block_size=})'

    @classmethod
    def from_textfile(cls, filename: str, context_size: int=256) -> 'CharacterDataset':
        """Read ``filename``, train a fresh tokenizer on its contents and build the dataset"""
        tokenizer = CharacterTokenizer()
        # Explicit encoding so the result does not depend on the platform default
        with open(filename, 'r', encoding='utf-8') as fh:
            data = fh.read()
        tokenizer.train(data)
        return cls(data, tokenizer, context_size=context_size)

    def train_test_split(self, train_percentage: float=0.8) -> tuple['CharacterDataset','CharacterDataset']:
        """Split the text into a leading train part and a trailing test part.

        Both parts share this dataset's tokenizer and context size, so token
        ids stay consistent between them.
        """
        n_train_chars = int(train_percentage * len(self.data))

        train_data = self.data[:n_train_chars]
        train_dataset = CharacterDataset(train_data, self.tokenizer, self.context_size)

        test_data = self.data[n_train_chars:]
        test_dataset = CharacterDataset(test_data, self.tokenizer, self.context_size)

        return train_dataset, test_dataset

    def __len__(self) -> int:
        """Number of (context, next character) samples; 0 when the text is too short"""
        # Every sample needs context_size characters plus one target character.
        # Clamp at zero because len() rejects negative values.
        return max(0, len(self.data) - self.context_size)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return sample ``idx`` as ``(x, y)``.

        ``x`` is a float32 one-hot tensor of shape (vocab_size, context_size)
        and ``y`` is a long tensor of shape (1,) holding the next token id.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_samples = len(self)
        position = idx + n_samples if idx < 0 else idx
        if not 0 <= position < n_samples:
            raise IndexError(f'sample index {idx} out of range for {n_samples} samples')
        # grab a chunk of context_size + 1 characters from the data
        chunk = self.data[position:position + self.context_size + 1]
        # encode every character to an integer and convert to tensor
        tokens = torch.tensor(self.tokenizer.encode(chunk), dtype=torch.long)
        # Onehot encode, transpose because Conv1D takes (batch, channels, length) as input dims: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
        x = F.one_hot(tokens[:-1], num_classes=self.vocab_size).to(torch.float32).T
        # Use last character as target (sliced so it keeps a dimension of size 1)
        y = tokens[-1:]
        return x,y

# Build a dataset with a short context and inspect the first sample
dataset = CharacterDataset.from_textfile('./input.txt', context_size=32)
# x is the one-hot context (vocab_size, context_size), y the next token id (shape (1,))
x,y = dataset[0]
# Last expression of the cell, so the notebook displays both shapes
x.shape,y.shape

## Model
Below we implement our 1D convolutional neural network for next character prediction. 

## Exercise 1
Add the 1D CNN to the provided implementation below. Create two convolution blocks with the specified number of channels, each using [Conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html), [ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html), and [MaxPool1d](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool1d.html), followed by a [linear projection](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) to the output dimensionality (vocab_size). Use a convolution kernel size of 3, with padding that keeps the output the same length as the input. Use max-pooling with a kernel size of 2 and a stride of 2, so that each block halves the sequence length. Train your model using the training code block. What training and test loss does your 'simple' 1D convolutional net achieve?

In [None]:
class CNN1D(nn.Module):
    """1D Convolutional Neural Network for next token prediction"""
    def __init__(self, vocab_size: int, context_size: int, conv_channels: int=128, use_bias: bool=False):
        """Set up the network.

        Args:
            vocab_size: number of distinct tokens (one-hot input channels and output logits).
            context_size: number of previous tokens the prediction is conditioned on; must be even.
            conv_channels: number of channels for the convolution blocks of Exercise 1.
            use_bias: whether the layers of Exercise 1 should use a bias term.
        """
        super().__init__()
        assert context_size % 2 == 0, f'Invalid context_size, {context_size} is not an even number'
        self.vocab_size = vocab_size
        self.context_size = context_size
        # IMPLEMENT ME! (Exercise 1): replace None with your network, which maps a
        # (batch, vocab_size, context_size) one-hot input to (batch, vocab_size) logits
        self.cnn = None

    def forward(self, X: torch.Tensor, targets: 'torch.Tensor | None'=None) -> 'tuple[torch.Tensor, torch.Tensor | None]':
        """Predict logits of next character conditioned on context_size previous characters

        Returns ``(logits, loss)``; ``loss`` is the cross-entropy against
        ``targets`` (flattened to 1D, entries equal to -1 ignored) or None
        when no targets are given.
        """
        if self.cnn is None:
            raise NotImplementedError('CNN1D.cnn is not implemented yet, see Exercise 1')
        logits = self.cnn(X)
        loss = None if targets is None else F.cross_entropy(logits, targets.view(-1), ignore_index=-1)
        return logits,loss

    def generate(self, sample_length: int=256) -> list[int]:
        """Generate sample of tokens by repeatedly sampling from the predicted distribution"""
        device = next(self.parameters()).device
        # Start generating with n=context_size tokens of id 0 (presumably the newline
        # character for this dataset -- verify against the tokenizer), these will later be omitted
        idx = torch.zeros((1, self.context_size), dtype=torch.long, device=device)

        # Sampling needs no gradients, so skip building the autograd graph
        with torch.no_grad():
            for _ in trange(sample_length, desc='Sampling'):
                # onehot encode the last context_size tokens
                context_tokens = idx[:, -self.context_size:]
                context = F.one_hot(context_tokens, self.vocab_size).to(torch.float).transpose(1,2) # transpose because of Conv1D shape requirements
                # forward this model (not a notebook-global one) and calculate token probabilities
                logits, _ = self(context)
                probs = F.softmax(logits, dim=-1)
                # sample next token
                idx_next = torch.multinomial(probs, num_samples=1)
                # append generated token to current sample
                idx = torch.cat([idx, idx_next], dim=1)

        # Omit first context_size tokens (these were only there to get the sampling started)
        return idx[0, self.context_size:].tolist()

# Create a sample with an untrained model for comparison/testing
# (requires the CNN of Exercise 1 to be implemented in CNN1D first)
model = CNN1D(dataset.vocab_size, context_size=dataset.context_size)
# generate() returns a list of token ids, decode() turns them back into text
sample = model.generate()
print(dataset.tokenizer.decode(sample))

## Training

In [None]:
# Rebuild the dataset with a longer context; the default split keeps the first 80% of
# the text for training and the remaining 20% for testing
dataset = CharacterDataset.from_textfile('./input.txt', context_size=64)
train_dataset,test_dataset = dataset.train_test_split()

model = CNN1D(vocab_size=dataset.vocab_size, context_size=dataset.context_size)
model.to(DEVICE)

train_steps = 2000
batch_size = 64

# Sample exactly train_steps * batch_size training examples, i.e. one batch per training step
train_dataloader = DataLoader(
    dataset=train_dataset,
    sampler=RandomSampler(train_dataset, num_samples=train_steps * batch_size),
    batch_size=batch_size,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    sampler=RandomSampler(test_dataset),
    batch_size=batch_size,
)
# Turn the test loader into an iterator so we can pull a single random batch on demand
test_dataloader = iter(test_dataloader)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

train_losses = []
test_losses = []

for i,(x,y) in enumerate(tqdm(train_dataloader, desc='Training')):
    x,y = x.to(DEVICE), y.to(DEVICE)
    # forward model and calculate loss
    _,loss = model(x,y)
    # save train and test loss every 20 steps
    if i % 20 == 0:
        train_losses.append(loss.item())
        test_x, test_y = next(test_dataloader)
        # evaluation only: no gradients needed, so do not build an autograd graph
        with torch.no_grad():
            _,test_loss = model(test_x.to(DEVICE), test_y.to(DEVICE))
        test_losses.append(test_loss.item())
    # backprop and update the parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# One point per 20 training steps for both curves
plt.plot(train_losses, label='train')
plt.plot(test_losses, label='test')
plt.legend()

## Evaluation

In [None]:
# Sample 256 tokens from the trained model and decode them back into text
sample = model.generate(sample_length=256)
print(dataset.tokenizer.decode(sample))

## Answers

### Exercise 1
An example 1D CNN implementation is given below. After training for 2000 iterations using batch_size=64 and context_size=64, it achieves a training and test loss of roughly 2.0 - 2.5.

In [None]:
# Answer to Exercise 1: this assignment belongs in CNN1D.__init__, in place of the IMPLEMENT ME placeholder.
# The Lazy* layers infer their input size from the first batch that is passed through them.
self.cnn = nn.Sequential(
    # conv block 1: (batch, vocab_size, context_size) -> (batch, conv_channels, context_size // 2)
    # padding='same' keeps the sequence length, the pooling below halves it
    nn.LazyConv1d(out_channels=conv_channels, kernel_size=3, padding='same', bias=use_bias),
    nn.ReLU(),
    nn.MaxPool1d(kernel_size=2, stride=2),
    # conv block 2: halves the sequence length once more
    nn.LazyConv1d(out_channels=conv_channels, kernel_size=3, padding='same', bias=use_bias),
    nn.ReLU(),
    nn.MaxPool1d(kernel_size=2, stride=2),
    # output projection: merge the channel and length dims, then map to one logit per token
    nn.Flatten(1, -1),
    nn.LazyLinear(self.vocab_size, bias=use_bias)
)