In [17]:
import torch

In [1]:
# Read the whole corpus into memory as a single (decoded) string.
corpus_path = 'input.txt'
with open(corpus_path, 'r', encoding='utf-8') as fh:
    text = fh.read()

In [2]:
# Corpus size counted in characters (length of the decoded str), not bytes.
num_chars = len(text)
print("length of dataset in characters: ", num_chars)

length of dataset in characters:  1115394


In [3]:
# Peek at the first 1000 characters of the corpus.
preview = text[:1000]
print(preview)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [5]:
# The vocabulary is every distinct character of the corpus in sorted order;
# a character's position in `chars` later serves as its integer id.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print("vocab size: ", vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size:  65


In [15]:
# Character-level tokenizer: a character's id is its index in the sorted,
# de-duplicated vocabulary `chars`, so stoi and itos are inverse mappings.
stoi = {char: i for i, char in enumerate(chars)}
itos = {i: char for i, char in enumerate(chars)}


def encode(s):
    """Encode a string as a list of integer token ids.

    Args:
        s: text to encode; every character must appear in ``chars``.

    Returns:
        A list with one integer id per character of ``s``.

    Raises:
        KeyError: if ``s`` contains a character outside the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable of integer token ids back into a string.

    Args:
        l: integer token ids, e.g. as produced by ``encode``. The parameter
            name is kept from the original lambda so keyword calls still work.

    Returns:
        The string formed by concatenating the character for each id.

    Raises:
        KeyError: if an id is not in the vocabulary.
    """
    return ''.join([itos[i] for i in l])

In [16]:
# Round-trip check: decoding the encoded ids must reproduce the input exactly.
# The ids are computed rather than hard-coded so this cell stays valid if the
# vocabulary (and therefore the id assignment) changes.
sample_text = "i am navdeep"
sample_ids = encode(sample_text)
print(sample_ids)
print(decode(sample_ids))
assert decode(sample_ids) == sample_text

[47, 1, 39, 51, 1, 52, 39, 60, 42, 43, 43, 54]
i am navdeep


In [18]:
# Encode the entire corpus into one flat tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.int64)
print(data.size())

torch.Size([1115394])


In [19]:
# Train/validation split: the first 90% of the token stream is for training
# and the remaining 10% is held out. Slices are contiguous (no shuffling).
train_fraction = 0.9
n = int(train_fraction * len(data))
train_data, val_data = data[:n], data[n:]

In [29]:
torch.manual_seed(1337)  # fixed seed so the sampled batches below are reproducible

batch_size = 4  # independent sequences that will be processed in parallel
block_size = 8  # maximum context length


def get_batch(split):
    """Sample a random batch of (input, target) sequences from one data split.

    Reads the module-level ``train_data`` / ``val_data`` tensors and the
    ``batch_size`` / ``block_size`` globals at call time.

    Args:
        split: ``"train"`` to sample from ``train_data`` or ``"val"`` to sample
            from ``val_data``.

    Returns:
        A tuple ``(x, y)`` of tensors, each of shape ``(batch_size, block_size)``.
        ``y`` is ``x`` shifted one position forward in the source data, so
        ``y[b, t]`` is the token that follows the context ``x[b, :t+1]``.

    Raises:
        ValueError: if ``split`` is neither ``"train"`` nor ``"val"``. (Any
            other value used to fall back silently to validation data.)
    """
    if split == "train":
        source = train_data
    elif split == "val":
        source = val_data
    else:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    # batch_size random start offsets. torch.randint's upper bound is
    # exclusive, so the largest start is len(source) - block_size - 1 and the
    # shifted target window source[i+1 : i+1+block_size] never runs past the end.
    ix = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[i:i + block_size] for i in ix])
    y = torch.stack([source[i + 1:i + 1 + block_size] for i in ix])

    return x, y

In [36]:
xb, yb = get_batch("train")
print(xb)
print(yb)

# Each row packs block_size training examples: every prefix of an input row is
# a context, and its target is the token at the same position in the target row.
for b in range(batch_size):
    for t in range(block_size):
        context, target = xb[b, :t + 1], yb[b, t]
        print(f'{context.tolist()} --> {target}')

tensor([[52, 42,  8,  0,  0, 23, 21, 26],
        [45, 53, 42, 57,  0, 23, 43, 43],
        [52,  1, 61, 39, 57,  1, 51, 53],
        [39, 49, 12,  1, 27,  1, 58, 56]])
tensor([[42,  8,  0,  0, 23, 21, 26, 19],
        [53, 42, 57,  0, 23, 43, 43, 54],
        [ 1, 61, 39, 57,  1, 51, 53, 56],
        [49, 12,  1, 27,  1, 58, 56, 39]])
[52] --> 42
[52, 42] --> 8
[52, 42, 8] --> 0
[52, 42, 8, 0] --> 0
[52, 42, 8, 0, 0] --> 23
[52, 42, 8, 0, 0, 23] --> 21
[52, 42, 8, 0, 0, 23, 21] --> 26
[52, 42, 8, 0, 0, 23, 21, 26] --> 19
[45] --> 53
[45, 53] --> 42
[45, 53, 42] --> 57
[45, 53, 42, 57] --> 0
[45, 53, 42, 57, 0] --> 23
[45, 53, 42, 57, 0, 23] --> 43
[45, 53, 42, 57, 0, 23, 43] --> 43
[45, 53, 42, 57, 0, 23, 43, 43] --> 54
[52] --> 1
[52, 1] --> 61
[52, 1, 61] --> 39
[52, 1, 61, 39] --> 57
[52, 1, 61, 39, 57] --> 1
[52, 1, 61, 39, 57, 1] --> 51
[52, 1, 61, 39, 57, 1, 51] --> 53
[52, 1, 61, 39, 57, 1, 51, 53] --> 56
[39] --> 49
[39, 49] --> 12
[39, 49, 12] --> 1
[39, 49, 12, 1] --> 27
[39