In [None]:
import torch
from torch import nn
import numpy as np
from timeit import default_timer as timer
from time import perf_counter

In [None]:
no_of_chars = 142000

# Read the corpus, lower-case it and keep only the first `no_of_chars` characters.
with open('text2', 'r') as fd:
    full_text = fd.read().lower()
full_text = full_text[:no_of_chars]

vocab = set(full_text)
# Sort before enumerating: iteration order of a set of str depends on the
# per-process hash seed, so an unsorted mapping would change after every kernel
# restart and make saved weights unusable with a freshly built mapping.
int2char = dict(enumerate(sorted(vocab)))
char2int = {char: ind for ind, char in int2char.items()}
vocab_size = len(char2int)
print("Vocabulary size:", vocab_size)
print("Text length:", len(full_text))

In [None]:
is_cuda = torch.cuda.is_available()
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if is_cuda else "cpu")

In [None]:
class ModelLSTM(nn.Module):
    """Character-level model: a stacked `nn.LSTM` followed by a linear layer.

    Args:
        input_size: length of each input vector (the one-hot vocabulary size).
            It is also used as the output size, so logits cover the same vocabulary.
        hidden_size: number of features in each LSTM layer's hidden state.
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, n_layers):
        super().__init__()
        output_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.lstm = nn.LSTM(input_size, hidden_size, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, full_hidden=None):
        """Score every time step of `x`.

        Args:
            x: input of shape (batch, seq_len, input_size) (batch_first=True).
            full_hidden: optional (hidden, cell) pair; when None, nn.LSTM
                starts from zeros.

        Returns:
            (logits, full_hidden): logits flattened to
            (batch * seq_len, output_size), and the final (hidden, cell) pair.
        """
        out, full_hidden = self.lstm(x, full_hidden)
        # Flatten time steps so one linear layer scores every position and the
        # result matches CrossEntropyLoss's (N, C) input layout.
        out = out.reshape(-1, self.hidden_size)
        out = self.fc(out)
        return out, full_hidden

    def init_full_hidden(self, batch_size):
        """Return a zero (hidden, cell) pair for `batch_size` sequences.

        The tensors are created on the same device and dtype as the model's
        weights, so this works wherever the model has been moved. Zeros (the
        nn.LSTM default) make greedy generation deterministic.
        """
        weight = next(self.parameters())
        shape = (self.n_layers, batch_size, self.hidden_size)
        hidden = torch.zeros(shape, device=weight.device, dtype=weight.dtype)
        cell_state = torch.zeros(shape, device=weight.device, dtype=weight.dtype)
        return (hidden, cell_state)

In [None]:
def split_eq(text, no):
    """Split `text` into exactly `no` consecutive chunks of equal length.

    The chunk length is ``len(text) // no``. Trailing characters that do not
    fill a whole chunk are discarded.

    Args:
        text: sequence to split (a string in this notebook).
        no: number of chunks to produce; must be positive.

    Returns:
        A list of `no` slices of `text`, each of length ``len(text) // no``.

    Raises:
        ValueError: if `no` is not positive, or `text` is shorter than `no`
            (the chunks would be empty).
    """
    if no <= 0:
        raise ValueError("no must be a positive integer, got {}".format(no))
    cnt = len(text) // no
    if cnt == 0:
        raise ValueError(
            "text of length {} is too short to split into {} chunks".format(len(text), no))
    # Slice by chunk index rather than stepping to the end of `text`. This
    # guarantees exactly `no` chunks even when the remainder is >= cnt.
    return [text[i * cnt:(i + 1) * cnt] for i in range(no)]

def produce_targets(examples):
    """Build (input, target) sequences for next-character prediction.

    Each input drops the last element of its example and each target drops
    the first, so target[k] is the element that follows input[k].

    Returns:
        (inputs, targets): two lists parallel to `examples`.
    """
    inputs, targets = [], []
    for example in examples:
        inputs.append(example[:-1])
        targets.append(example[1:])
    return inputs, targets

def translate_to_int(examples, mapping=None):
    """Map every character of every example to its integer index.

    Args:
        examples: iterable of strings.
        mapping: char -> int lookup. Defaults to the global `char2int`, which is
            looked up at call time.

    Returns:
        A list of integer lists, one per example.

    Raises:
        KeyError: if a character is missing from the mapping.
    """
    if mapping is None:
        mapping = char2int
    return [[mapping[ch] for ch in ex] for ex in examples]

def translate_to_char(examples, mapping=None):
    """Map sequences of integer indices back to strings.

    Args:
        examples: iterable of integer sequences.
        mapping: int -> char lookup. Defaults to the global `int2char`, which is
            looked up at call time.

    Returns:
        A list of strings, one per example.

    Raises:
        KeyError: if an index is missing from the mapping.
    """
    if mapping is None:
        mapping = int2char
    return [''.join(mapping[i] for i in ex) for ex in examples]

def one_hot_encode(examples, vocab_size=None):
    """One-hot encode integer sequences into a float32 array.

    Args:
        examples: list of integer-index sequences. Shorter sequences are
            zero-padded at the end up to the longest one.
        vocab_size: size of the one-hot dimension. Defaults to
            ``len(char2int)``, which is looked up at call time.

    Returns:
        np.ndarray of shape (len(examples), max_len, vocab_size), dtype float32.

    Raises:
        IndexError: if an index is >= vocab_size.
    """
    if vocab_size is None:
        vocab_size = len(char2int)
    seq_len = max((len(ex) for ex in examples), default=0)
    features = np.zeros((len(examples), seq_len, vocab_size), dtype=np.float32)

    for i, example in enumerate(examples):
        # Encode every position, including the last one. The model's prediction
        # at the final step must see the final character.
        for pos, idx in enumerate(example):
            features[i, pos, idx] = 1
    return features

def to_model_format(inputs):
    """Turn one string, or a list of strings, into the model's input tensor.

    A bare string is treated as a batch of one. The strings are mapped to
    integer indices and one-hot encoded, and the resulting NumPy array is
    wrapped as a torch tensor.
    """
    batch = [inputs] if isinstance(inputs, str) else inputs
    return torch.from_numpy(one_hot_encode(translate_to_int(batch)))

In [None]:
# configuration
no_of_examples = 64
batch_size = examples_per_batch = 32
lr = 0.0048

# Only complete batches are built; leftover examples (if any) are dropped.
no_of_batches = no_of_examples // examples_per_batch

examples = split_eq(full_text, no_of_examples)
chars_per_example = len(examples[0])
# Next-character prediction: targets are the inputs shifted left by one.
inputs, targets = produce_targets(examples)
trans_inputs = translate_to_int(inputs)
trans_targets = translate_to_int(targets)

batches = []

for i in range(no_of_batches):
    lo, hi = i * examples_per_batch, (i + 1) * examples_per_batch
    input_seq = torch.from_numpy(one_hot_encode(trans_inputs[lo:hi]))
    # CrossEntropyLoss expects integer (long) class indices, not floats.
    target_seq = torch.tensor(trans_targets[lo:hi], dtype=torch.long)
    batches.append((input_seq, target_seq))

print("No of examples/No of data parts:", no_of_examples)
print("No of batches:", no_of_batches)
print("Examples per batch:", examples_per_batch)
print("Chars per example:", chars_per_example)

In [None]:
dict_size = len(char2int)
hidden_size = 36
n_layers = 3
model = ModelLSTM(input_size=dict_size, hidden_size=hidden_size, n_layers=n_layers)
# Move the model to its device *before* building the optimizer, as the
# torch.optim docs recommend, so the optimizer holds the final parameters.
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
model

In [None]:
epochs = 1700
print_every = 50

t_start = perf_counter()
model.train()
for epoch in range(1, epochs + 1):
    for inp, target in batches:
        inp, target = inp.to(device), target.to(device)
        # Fresh initial state for every batch, sized from the actual batch
        # rather than a global setting.
        h = model.init_full_hidden(inp.size(0))
        optimizer.zero_grad()
        output, h = model(inp, h)
        # `output` is (batch * seq_len, vocab); flatten the targets to match.
        loss = criterion(output, target.view(-1).long())
        loss.backward()
        optimizer.step()

    if epoch % print_every == 0:
        # Reported loss is that of the last batch of the epoch.
        print("Epoch: {}/{}...".format(epoch, epochs),
              "Loss: {:.6f}...".format(loss.item()))
        t_stop = perf_counter()
        print("Time elapsed:", t_stop - t_start)

In [None]:
def predict_next(device, model, full_hidden, input_string):
    """Run `input_string` through `model` and return its greedy next-character guess.

    Args:
        device: torch device the encoded input is moved to (should match the model's).
        model: module returning (logits, hidden), e.g. ModelLSTM.
        full_hidden: (hidden, cell) state the LSTM starts from for this call.
        input_string: text to condition on; every character must be in `char2int`.

    Returns:
        (next_char, hidden): the highest-probability next character, and the
        hidden state after consuming `input_string`.
    """
    encoded_input = to_model_format(input_string)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        out, hidden = model(encoded_input.to(device), full_hidden)

    # A single string is a batch of one, so the last row holds the logits for
    # the last time step. Choose the one with the highest probability.
    prob = nn.functional.softmax(out[-1], dim=0)
    char_ind = torch.max(prob, dim=0)[1].item()
    return int2char[char_ind], hidden


def run_model(device, model, starting_seq, size=50):
    """Greedily extend the lower-cased `starting_seq` by `size` characters.

    Each step encodes the whole text generated so far, starting from a fresh
    `init_full_hidden` state, as every training batch does. The hidden state
    returned by a step is not carried into the next one: it has already read
    the whole sequence, so feeding that sequence again would process the
    prefix twice.

    Args:
        device: torch device used for the model inputs.
        model: trained ModelLSTM-style model; switched to eval mode.
        starting_seq: non-empty prompt text.
        size: number of characters to generate.

    Returns:
        The lower-cased prompt followed by the generated characters.

    Raises:
        ValueError: if `starting_seq` is empty.
    """
    if not starting_seq:
        raise ValueError("starting_seq must be a non-empty string")
    model.eval()
    seq = starting_seq.lower()
    for _ in range(size):
        h = model.init_full_hidden(1)
        char, _hidden = predict_next(device, model, h, seq)
        seq += char
    return seq

In [None]:
# Generate 250 characters after the prompt (run_model lower-cases it).
prompt = 'A great and advanced society has '
res = run_model(device, model, prompt, size=250)
print(res)

In [None]:
# Persist only the learned weights. The char<->index mapping (int2char /
# char2int) is not stored here. Reusing these weights requires recreating
# exactly the same index assignment.
torch.save(obj=model.state_dict(), f="./lstm_gpu")

Some results from earlier training runs (text generated from the prompt above):

Epoch: 2000/2000... Loss: 1.196766... 36/3 12800/32/16
"a great and advanced society has eopr doass moaddyr oolot  oorloo ,awn ooloa sahr aots aots loild ooloa sahr aots ooloa sahr aots loisd  oorsoa  oor,oart- eogr ooloo ,awn ooloa sahr aots loisd  oorsoa  oor,oart- eogr ooloo ,awn ooloa sahr aots loisd  oorsoa  oor,oart- eogr ooloo ,aw"

Epoch 1500/1500 Loss: 1.50 36/3 13200/16/16
a great and advanced society has euft rult  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld  uedetiig tueld 

Epoch 2250/2250 Loss: 1.350470...13200/16/16
a great and advanced society has euft ruotd tuutcoeddt  ueteel,s iutt rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rult rul

Epoch: 3000/3000... Loss: 1.262259...13200/16/16
a great and advanced society has eeftetg noubd tuet iutt rutttttutl lundetde rult rult rult rult nuuts iuttrrltt tuet eugttnitd eugtrnttrrln., ,uev tuet iuttrrtttttuudtt nuuts iuttrrltt tuet eugttnitd eugtrnttrrln., ,uev tuet iuttrrtttttuudtt nuuts iuttrrltt tuet eugttnitd eugtrnttr

Epoch: 4000/4000... Loss: 1.187211...13200/16/16
a great and advanced society has eeftetc  updu mm  feudrpling tuek tuetging tuet iugttnitg tuek tuelgn tftet  eenrton  ueielc  umdiilt  eentingtnrtnmtnnttrmm iftet  uede duoicg  umiill ,uifct mettoyt ruttttmunt rult nuutt  uede,drtmrr lfnitd eufc mf tuutdamdiagdtn tftet  eefrritdtyn

Epoch 4500
a great and advanced society has eeftetc  umdli,g lengyn.,irdnt ruots tuet iugttn tutkoetle  eerroitdeudttl npulrnn,c luyttnuuld ,uef eugrrstautd efteldn,trowttn tuetging tuet lundetde suo gupdepgisg,r relltruusd ymuls iitdouddtrrtlunl ,peodet iutt ruttttmunt rill oeg  ueielc  pmolc

Epoch 5000
a great and advanced society has eeftetc  umdiilt  iedtmnreli ruld tpetcoiddtt nuuts ,peodlttretct tuek ,feedaggttrrn  peotlo tiedd nuuusd nfurr npuftu teedtm, audlist  imdieitg tuuklynn,, ,uefette,,n,n nerttrnct nuuts ipt relct uudrend mmt iettott uid tmenc  peotdetti nfwwo raoldr 

Epoch 5500
a great and advanced society has eeftetci teed nmumv  uedrtlunn yuu rfldttrttmt tftetl  ppict,i mmtt relct yiult reltnrnnitd eufc efaed-neniesc ,uef tuet iugttnitg tuekg,n mettoyt roasr  uebi ddotttn tfetri gpetteidg,,  ueditg rill oeg  ueielcs niutdoiddttmnnytn afdetiyg.m guedct nu

Epoch 6200
a great and advanced society has eeftetci tteutdoild yuu rfldntnrt uutr lpnte riolet ,pe dmuitggt wuttteugt rill oet iutt rotlln tuutlatd,y ,ee gmer teesm  feidtawd nuutr upd nfuwt uudring iutr upv ncuct..iiddtm mutt yfutdeer.. tuutllm siuddtt,tawiells yuu rfodd  uudrsna rdlsndns rf

Epoch 6200 + 1700
a great and advanced society has eufr tuek nfurtt tuet huo sdiit  feudtprpmesrlct eettert  cewdternttrrnmawd  etr iuttrrmtwue- iudrnlett neutd iotsetl  mp gopdluande reotgitgtmln spond ietcielg  imgleem,," 2uea"fy nmutu nuutt rult ,uef tuet iu tvendttr ncu foedm,l,r reokgtnwt neudlt
