In [1]:
import torch
from torch import nn
import re
import random
import tqdm
import time

In [2]:
!wget https://s3.amazonaws.com/text-datasets/nietzsche.txt

--2023-06-18 17:55:33--  https://s3.amazonaws.com/text-datasets/nietzsche.txt
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.167.184, 52.216.209.24, 52.217.160.168, ...
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.167.184|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 600901 (587K) [text/plain]
Saving to: ‘nietzsche.txt.2’


2023-06-18 17:55:34 (1.02 MB/s) - ‘nietzsche.txt.2’ saved [600901/600901]



In [3]:
# Read the whole corpus and lowercase it so the vocabulary stays small.
with open('nietzsche.txt', encoding='utf-8') as f:
    text = f.read().lower()
print('length:', len(text))
# Replace everything except a-z and the space with a space, then collapse runs of
# whitespace to one space. Raw strings keep the regex escape `\s` from being parsed
# as an (invalid) Python string escape.
text = re.sub(r'[^a-z ]', ' ', text)
text = re.sub(r'\s+', ' ', text)

length: 600893


In [4]:
# Peek at the first 100 characters of the cleaned corpus.
text[:100]

'preface supposing that truth is a woman what then is there not ground for suspecting that all philos'

In [5]:
# Vocabulary tables: the distinct characters of `text` in sorted order, and the
# inverse lookup from each character to its position in that order.
INDEX_TO_CHAR = sorted(set(text))
CHAR_TO_INDEX = dict(zip(INDEX_TO_CHAR, range(len(INDEX_TO_CHAR))))

In [6]:
# Display the character -> integer id mapping.
CHAR_TO_INDEX

{' ': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26}

In [7]:
# Slide a MAX_LEN-character window over the corpus in strides of STEP characters.
# Each window is one training input; the character right after it is the target.
MAX_LEN = 40
STEP = 3
window_starts = range(0, len(text) - MAX_LEN, STEP)
SENTENCES = [text[begin: begin + MAX_LEN] for begin in window_starts]
NEXT_CHARS = [text[begin + MAX_LEN] for begin in window_starts]
print('Num sents:', len(SENTENCES))

Num sents: 193075


In [8]:
print('Vectorization...')
# Encode all windows in a single tensor construction instead of writing one tensor
# element at a time from nested Python loops (len(SENTENCES) * MAX_LEN separate
# tensor assignments). The resulting values, shapes and dtype (int64) are the same.
# The reshape keeps the (num_windows, MAX_LEN) layout even when there are no windows.
X = torch.tensor(
    [[CHAR_TO_INDEX[char] for char in sentence] for sentence in SENTENCES],
    dtype=torch.long,
).reshape(len(SENTENCES), MAX_LEN)
# One target id per window: the character that follows it in the corpus.
Y = torch.tensor([CHAR_TO_INDEX[char] for char in NEXT_CHARS], dtype=torch.long)

Vectorization...


In [9]:
# Inspect the first encoded window (kept 2-D via the slice) and its target id.
X[0:1], Y[0]

(tensor([[16, 18,  5,  6,  1,  3,  5,  0, 19, 21, 16, 16, 15, 19,  9, 14,  7,  0,
          20,  8,  1, 20,  0, 20, 18, 21, 20,  8,  0,  9, 19,  0,  1,  0, 23, 15,
          13,  1, 14,  0]]),
 tensor(23))

In [10]:
# Pair inputs with targets and serve them as shuffled mini-batches.
BATCH_SIZE = 512
dataset = torch.utils.data.TensorDataset(X, Y)
data = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


In [11]:
class NeuralNetwork(nn.Module):
    """Next-token classifier: embedding -> recurrent layer -> linear head.

    Args:
        rnnClass: recurrent layer class, called as
            ``rnnClass(embedding_size, num_hiddens, batch_first=True)``
            (the notebook passes ``nn.LSTM``).
        dictionary_size: number of distinct input ids for the embedding table.
        embedding_size: width of each embedding vector.
        num_hiddens: hidden-state width of the recurrent layer.
        num_classes: number of logits produced per input sequence.
    """

    def __init__(self, rnnClass, dictionary_size, embedding_size, num_hiddens, num_classes):
        super().__init__()

        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(num_embeddings=dictionary_size, embedding_dim=embedding_size)
        self.hidden = rnnClass(embedding_size, num_hiddens, batch_first=True)
        self.output = nn.Linear(in_features=num_hiddens, out_features=num_classes)

    def forward(self, X):
        """Return unnormalised class scores for a batch of id sequences.

        ``X`` is an integer tensor indexed as (batch, time). Only the final
        recurrent state is used; the per-timestep outputs are discarded.

        Note: ``squeeze()`` removes every size-1 dimension, so a single-sequence
        batch yields a 1-D vector of ``num_classes`` logits rather than a
        ``(1, num_classes)`` matrix.
        """
        embedded = self.embedding(X)
        final_state = self.hidden(embedded)[1]
        # For an LSTM the state is an (h, c) tuple and [0] selects h; for a
        # tensor-valued state (e.g. GRU) [0] selects its first slice instead.
        last_hidden = final_state[0]
        return self.output(last_hidden.squeeze())

In [12]:
# Character-level LSTM: 64-dimensional embeddings, 128 hidden units, and one
# output logit per vocabulary entry.
model = NeuralNetwork(
    rnnClass=nn.LSTM,
    dictionary_size=len(CHAR_TO_INDEX),
    embedding_size=64,
    num_hiddens=128,
    num_classes=len(CHAR_TO_INDEX),
)

In [14]:
# Smoke-test the untrained model on one window: one logit per vocabulary entry.
# The result is 1-D because forward() squeezes away the size-1 batch dimension.
model(X[0:1])

tensor([-0.0124,  0.1266, -0.1151, -0.0075, -0.0161,  0.0299, -0.0930, -0.0228,
         0.1060,  0.0493, -0.0202, -0.0397,  0.0827, -0.0721,  0.0468,  0.0895,
         0.0609,  0.0054, -0.1286,  0.0954,  0.0660,  0.2153, -0.0481, -0.0264,
        -0.0420,  0.0932, -0.0165], grad_fn=<AddBackward0>)

In [15]:
# Stand-alone layers (separate from `model`) for inspecting what recurrent
# layers return: 15-dimensional embeddings feeding a 128-unit LSTM.
embedding = nn.Embedding(num_embeddings=len(INDEX_TO_CHAR), embedding_dim=15)
rnn = nn.LSTM(input_size=15, hidden_size=128, batch_first=True)

In [16]:
# Run the first ten windows through the scratch embedding + LSTM and compare shapes:
# `o` holds an output for every timestep, `s` is the final state returned as a
# 2-tuple (indexed below as s[0] and s[1]).
o, s = rnn(embedding(X[0:10]))
o.shape, len(s), s[0].shape, s[1].shape

(torch.Size([10, 40, 128]),
 2,
 torch.Size([1, 10, 128]),
 torch.Size([1, 10, 128]))

In [17]:
# Same experiment with a GRU. Its state `s` is a single tensor rather than a tuple,
# so len(s) and s[0] below index that tensor's leading dimension.
rnn = nn.GRU(15,128, batch_first=True)
o, s = rnn(embedding(X[0:10]))
o.shape, len(s), s[0].shape

(torch.Size([10, 40, 128]), 1, torch.Size([10, 128]))

In [18]:
# Repeats the GRU forward pass from the previous cell.
# NOTE(review): `o` and `s` are not referenced again in the visible part of the
# notebook, so this cell looks like a leftover - confirm before deleting.
o, s = rnn(embedding(X[0:10]))

In [19]:
def sample(preds, temperature=1.0):
    """Randomly draw one class index from unnormalised scores.

    Args:
        preds: tensor of logits whose last dimension enumerates the classes,
            either 1-D ``(num_classes,)`` or with a leading size-1 batch
            dimension ``(1, num_classes)``.
        temperature: positive divisor applied to the logits before the softmax.
            Values below 1 sharpen the distribution, values above 1 flatten it;
            the default of 1.0 leaves the logits unchanged.

    Returns:
        A 0-dim integer tensor holding the sampled class index.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError('temperature must be positive')
    # Normalise over the class axis (the last one). A fixed dim=0 is only right
    # for 1-D input and would normalise across the batch axis for (1, C) input.
    probs = torch.softmax(preds / temperature, dim=-1)
    # A single multinomial trial produces a one-hot count vector ...
    one_hot = torch.distributions.multinomial.Multinomial(1, probs).sample()
    # ... whose (flattened) argmax is the index of the drawn class.
    return one_hot.argmax()

def generate_text():
    """Print a random seed window from the corpus followed by MAX_LEN sampled characters.

    A window of MAX_LEN characters is picked at a random position in `text`.
    The global `model` is then queried MAX_LEN times; each time the most recent
    MAX_LEN characters are encoded, the next character is drawn with `sample`,
    and it is appended to the running string. The seed and the generated
    continuation are printed on one line, separated by '|'. Returns None.
    """
    start_index = random.randint(0, len(text) - MAX_LEN - 1)
    generated = text[start_index: start_index + MAX_LEN]

    # Generation is inference only: disabling autograd avoids recording a graph
    # for every forward pass.
    with torch.no_grad():
        for _ in range(MAX_LEN):
            window = generated[-MAX_LEN:]
            # Encode the current window as a (1, MAX_LEN) batch of character ids.
            x_pred = torch.tensor(
                [[CHAR_TO_INDEX[char] for char in window]], dtype=torch.long
            )
            preds = model(x_pred)
            next_char = INDEX_TO_CHAR[sample(preds)]
            generated += next_char

    print(generated[:MAX_LEN] + '|' + generated[MAX_LEN:])

In [20]:
# Toy logits used in the following cells to illustrate softmax + multinomial sampling.
a = torch.tensor([51, 50, 1, 49, 7], dtype=torch.float32)

In [21]:
# Softmax written out by hand: exp(logit) divided by the sum of exp over all logits.
exp_total = torch.sum(torch.exp(a))
p = [torch.exp(logit) / exp_total for logit in a]

In [22]:
# Pack the list of scalar probabilities into one float tensor (rebinds `p` from a list to a tensor).
p = torch.FloatTensor(p)

In [23]:
# Draw 50 trials from the distribution `p`; the result counts how often each class was drawn.
torch.distributions.multinomial.Multinomial(total_count=50, probs=p).sample()

tensor([29., 16.,  0.,  5.,  0.])

In [24]:
# Cross-entropy on the model's raw logits, optimised with Adam at its default settings.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

In [25]:
# Train for 100 epochs. After each epoch, report the wall-clock time and the mean
# per-batch loss, then print a generated sample.
for ep in range(100):
    start = time.time()
    train_loss = 0.

    model.train()
    for X_b, y_b in data:
        optimizer.zero_grad()
        loss = criterion(model(X_b), y_b)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # len(data) is the number of batches served this epoch.
    train_passed = len(data)
    print("Epoch {}. Time: {:.3f}, Train loss: {:.3f}".format(ep, time.time() - start, train_loss / train_passed))
    model.eval()
    generate_text()

Epoch 0. Time: 17.880, Train loss: 2.175
o be old fashioned and grandfatherly res|enas a ans at of he worsttlive and onsel
Epoch 1. Time: 18.955, Train loss: 1.795
ble filth of all political agitation the| vires theveres of persines a onlest tha
Epoch 2. Time: 17.772, Train loss: 1.656
ganizations how the dim mole eyes of suc|klic our it in grole in to the orlies wi
Epoch 3. Time: 18.977, Train loss: 1.569
nward impulse rules them with the master| a narder arted they in which knownsh to
Epoch 4. Time: 18.136, Train loss: 1.506
uld be philosophers nowadays would be co|nernow oun of ording to the drome to may
Epoch 5. Time: 18.743, Train loss: 1.460
osite but as its refinement it is to be |philosomptent with reart and the apsaine
Epoch 6. Time: 19.571, Train loss: 1.423
ere are systems of morals which are mean|s the en arthain pectleness to the stold
Epoch 7. Time: 18.369, Train loss: 1.391
et a german who was favourably inclined |close and the world bad one consciousent
Epoch 8. Time: 1