In [1]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

In [2]:
with open('/home/iot/jupyter/root_dir/liudongdong/dataset/charprediction/train.txt', 'r') as f:
    text = f.read()

chars = tuple(set(text))
int2char = dict(enumerate(chars))
char2int = {ch: ii for ii, ch in int2char.items()}

encoded = np.array([char2int[ch] for ch in text])

In [4]:
def one_hot_encode(arr, n_labels):
    
    # Initialize the the encoded array
    one_hot = np.zeros((np.multiply(*arr.shape), n_labels), dtype=np.float32)
    
    # Fill the appropriate elements with ones
    one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1.
    
    # Finally reshape it to get back to the original array
    one_hot = one_hot.reshape((*arr.shape, n_labels))
    
    return one_hot

In [5]:
def get_batches(arr, batch_size, seq_length):
    '''Create a generator that returns batches of size
       batch_size x seq_length from arr.
       
       Arguments
       ---------
       arr: Array you want to make batches from
       batch_size: Batch size, the number of sequences per batch
       seq_length: Number of encoded chars in a sequence
    '''
    
    ## TODO: Get the number of batches we can make
    n_batches = (len(arr))//(batch_size*seq_length)
    
    ## TODO: Keep only enough characters to make full batches
    arr = arr[:(n_batches*batch_size*seq_length)]
    
    ## TODO: Reshape into batch_size rows
    size=(batch_size,-1)
    arr = arr.reshape(size)  #(batch, columns)  后续数据直接在 columns 遍历
    
    ## TODO: Iterate over the batches using a window of size seq_length
    for n in range(0, arr.shape[1], seq_length):
        # The features
        x = arr[:, n:n+seq_length]
        # The targets, shifted by one
        y = np.zeros_like(x)
        try:
          y[:, :-1],y[:, -1]=x[:,1:], arr[:,n+seq_length]
        except IndexError:
          y[:, :-1],y[:, -1]=x[:,1:], arr[:,0]
        yield x, y

In [6]:
# check if GPU is available
train_on_gpu = torch.cuda.is_available()
if(train_on_gpu):
    print('Training on GPU!')
else: 
    print('No GPU available, training on CPU; consider making n_epochs very small.')

Training on GPU!


In [45]:
class CharRNN(nn.Module):
    
    def __init__(self, tokens, n_hidden=256, n_layers=2,
                               drop_prob=0.5, lr=0.001):
        super().__init__()
        self.drop_prob = drop_prob
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.lr = lr
        
        # creating character dictionaries
        self.chars = tokens
        self.int2char = dict(enumerate(self.chars))
        self.char2int = {ch: ii for ii, ch in self.int2char.items()}
        
        ## TODO: define the layers of the model
        self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers, 
                            dropout=drop_prob, batch_first=True)   #注意这里
        
        self.dropout=nn.Dropout(drop_prob)

        self.fc = nn.Linear(n_hidden, len(self.chars))
      
    
    def forward(self, x, hidden):
        ''' Forward pass through the network. 
            These inputs are x, and the hidden/cell state `hidden`. '''
                
        ## TODO: Get the outputs and the new hidden state from the lstm
        #x=rearrange(x,'b s d-> s b d')
        r_output,hidden=self.lstm(x,hidden)
        #r_output=rearrange(r_output,'s b d-> b s d')
        out=self.dropout(r_output)
        #print("self.dropout",out.shape)
        out=out.contiguous().view(-1,self.n_hidden)
        #print("self.contiguous",out.shape)
        out=self.fc(out)
# self.dropout torch.Size([128, 100, 512])
# self.contiguous torch.Size([12800, 512])
# output, torch.Size([12800, 94])
        # return the final output and the hidden state
        return out, hidden
    
    
    def init_hidden(self, batch_size):
        ''' Initializes hidden state '''
        # Create two new tensors with sizes n_layers x batch_size x n_hidden,
        # initialized to zero, for hidden state and cell state of LSTM
        weight = next(self.parameters()).data
        
        if (train_on_gpu):
            hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda(),
                  weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda())
        else:
            hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_(),
                      weight.new(self.n_layers, batch_size, self.n_hidden).zero_())
        
        return hidden

In [46]:
## TODO: set you model hyperparameters
# define and print the net
n_hidden=512
n_layers=2

net = CharRNN(chars, n_hidden, n_layers)
print(net)


CharRNN(
  (lstm): LSTM(94, 512, num_layers=2, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=512, out_features=94, bias=True)
)


In [96]:
def train(net, data, epochs=10, batch_size=10, seq_length=50, lr=0.001, clip=5, val_frac=0.1, print_every=10):
    ''' Training a network 
    
        Arguments
        ---------
        
        net: CharRNN network
        data: text data to train the network
        epochs: Number of epochs to train
        batch_size: Number of mini-sequences per mini-batch, aka batch size
        seq_length: Number of character steps per mini-batch
        lr: learning rate
        clip: gradient clipping
        val_frac: Fraction of data to hold out for validation
        print_every: Number of steps for printing training and validation loss
    
    '''
    net.train()
    
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
#     scheduler=ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
    # create training and validation data
    val_idx = int(len(data)*(1-val_frac))
    data, val_data = data[:val_idx], data[val_idx:]
    
    if(train_on_gpu):
        net.cuda()
    
    counter = 0
    n_chars = len(net.chars)
    for e in range(epochs):
        # initialize hidden state
        h = net.init_hidden(batch_size)
        
        for x, y in get_batches(data, batch_size, seq_length):
            counter += 1
            
            #print("inputs.shape,targets.shape",x.shape,y.shape,[int2char[ch] for ch in x[0]])
            # One-hot encode our data and make them Torch tensors
            x = one_hot_encode(x, n_chars)
            inputs, targets = torch.from_numpy(x), torch.from_numpy(y)
            #print("inputs.shape,targets.shape",inputs.shape,targets.shape)
            
            if(train_on_gpu):
                inputs, targets = inputs.cuda(), targets.cuda()
            
            # Creating new variables for the hidden state, otherwise
            # we'd backprop through the entire training history
            h = tuple([each.data for each in h])

            # zero accumulated gradients
            net.zero_grad()
            
            # get the output from the model
            output, h = net(inputs, h)
            #print("output,",output.shape)
    
            # calculate the loss and perform backprop
            loss = criterion(output, targets.view(batch_size*seq_length))
            loss.backward()
            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            nn.utils.clip_grad_norm_(net.parameters(), clip)
            opt.step()
            
            # loss stats
            if counter % print_every == 0:
                # Get validation loss
                val_h = net.init_hidden(batch_size)
                val_losses = []
                accuracy_train=0
                net.eval()
                for x, y in get_batches(val_data, batch_size, seq_length):
                    # One-hot encode our data and make them Torch tensors
                    x = one_hot_encode(x, n_chars)
                    x, y = torch.from_numpy(x), torch.from_numpy(y)
                    
                    # Creating new variables for the hidden state, otherwise
                    # we'd backprop through the entire training history
                    val_h = tuple([each.data for each in val_h])
                    
                    inputs, targets = x, y
                    if(train_on_gpu):
                        inputs, targets = inputs.cuda(), targets.cuda()

                    output, val_h = net(inputs, val_h)
                    #print(output.shape)
                    val_loss = criterion(output, targets.view(batch_size*seq_length))
                    pre_train=F.softmax(np.reshape(output.cpu().detach(),(batch_size,seq_length,-1)),dim=2)
                    #print("inputs.shape,targets.shape",pre_train[0],targets.shape)
                    pre_train=torch.argmax(pre_train,dim = 2)
                    #print("inputs.shape,targets.shape",pre_train[0],targets.shape)
        
                    
                    #print("inputs.shape,targets.shape",pre_train.shape,targets.shape)
                    targets_train=targets.cpu().detach()
                    accuracy_train = torch.sum(pre_train == targets_train)/(targets_train.shape[0] * targets_train.shape[1])
                    val_losses.append(val_loss.item())
                
                net.train() # reset to train mode after iterationg through validation data
                
                print("Epoch: {}/{}...".format(e+1, epochs),
                      "Step: {}...".format(counter),
                      "Loss: {:.4f}...".format(loss.item()),
                      "Val Loss: {:.4f} acc:{}".format(np.mean(val_losses),accuracy_train))

In [98]:
batch_size = 128
seq_length = 100
n_epochs = 100 # start small if you are just testing initial behavior

# train the model
train(net, encoded, epochs=n_epochs, batch_size=batch_size, seq_length=seq_length, lr=0.001, print_every=10)

Epoch: 1/100... Step: 10... Loss: 1.5137... Val Loss: 1.4837 acc:0.5511718988418579
Epoch: 1/100... Step: 20... Loss: 1.5372... Val Loss: 1.4763 acc:0.5548437237739563
Epoch: 1/100... Step: 30... Loss: 1.4863... Val Loss: 1.4723 acc:0.5546093583106995
Epoch: 1/100... Step: 40... Loss: 1.4652... Val Loss: 1.4681 acc:0.5569531321525574
Epoch: 1/100... Step: 50... Loss: 1.5087... Val Loss: 1.4653 acc:0.5575781464576721
Epoch: 1/100... Step: 60... Loss: 1.4948... Val Loss: 1.4610 acc:0.5596093535423279
Epoch: 1/100... Step: 70... Loss: 1.4860... Val Loss: 1.4586 acc:0.5614843964576721
Epoch: 1/100... Step: 80... Loss: 1.4970... Val Loss: 1.4597 acc:0.5596874952316284
Epoch: 1/100... Step: 90... Loss: 1.4899... Val Loss: 1.4542 acc:0.5625781416893005
Epoch: 1/100... Step: 100... Loss: 1.4754... Val Loss: 1.4546 acc:0.5637500286102295
Epoch: 1/100... Step: 110... Loss: 1.4868... Val Loss: 1.4503 acc:0.5625
Epoch: 1/100... Step: 120... Loss: 1.4911... Val Loss: 1.4478 acc:0.5608593821525574
E

Epoch: 5/100... Step: 990... Loss: 1.3812... Val Loss: 1.3289 acc:0.5948437452316284
Epoch: 5/100... Step: 1000... Loss: 1.4033... Val Loss: 1.3273 acc:0.5956249833106995
Epoch: 6/100... Step: 1010... Loss: 1.3398... Val Loss: 1.3288 acc:0.5960937738418579
Epoch: 6/100... Step: 1020... Loss: 1.3531... Val Loss: 1.3252 acc:0.5955469012260437
Epoch: 6/100... Step: 1030... Loss: 1.3147... Val Loss: 1.3259 acc:0.5941406488418579
Epoch: 6/100... Step: 1040... Loss: 1.2861... Val Loss: 1.3262 acc:0.596484363079071
Epoch: 6/100... Step: 1050... Loss: 1.3324... Val Loss: 1.3235 acc:0.5992968678474426
Epoch: 6/100... Step: 1060... Loss: 1.3257... Val Loss: 1.3219 acc:0.5950781106948853
Epoch: 6/100... Step: 1070... Loss: 1.3229... Val Loss: 1.3228 acc:0.594921886920929
Epoch: 6/100... Step: 1080... Loss: 1.3293... Val Loss: 1.3202 acc:0.599609375
Epoch: 6/100... Step: 1090... Loss: 1.3138... Val Loss: 1.3213 acc:0.5986718535423279
Epoch: 6/100... Step: 1100... Loss: 1.3092... Val Loss: 1.3214 a

Epoch: 10/100... Step: 1950... Loss: 1.2829... Val Loss: 1.2703 acc:0.6094531416893005
Epoch: 10/100... Step: 1960... Loss: 1.2591... Val Loss: 1.2714 acc:0.6100780963897705
Epoch: 10/100... Step: 1970... Loss: 1.2666... Val Loss: 1.2731 acc:0.6104687452316284
Epoch: 10/100... Step: 1980... Loss: 1.2534... Val Loss: 1.2693 acc:0.6111719012260437
Epoch: 10/100... Step: 1990... Loss: 1.2917... Val Loss: 1.2677 acc:0.610546886920929
Epoch: 10/100... Step: 2000... Loss: 1.3183... Val Loss: 1.2669 acc:0.6103125214576721
Epoch: 11/100... Step: 2010... Loss: 1.2534... Val Loss: 1.2689 acc:0.6100000143051147
Epoch: 11/100... Step: 2020... Loss: 1.2606... Val Loss: 1.2688 acc:0.6115624904632568
Epoch: 11/100... Step: 2030... Loss: 1.2401... Val Loss: 1.2652 acc:0.6121875047683716
Epoch: 11/100... Step: 2040... Loss: 1.2184... Val Loss: 1.2654 acc:0.61328125
Epoch: 11/100... Step: 2050... Loss: 1.2563... Val Loss: 1.2670 acc:0.6145312786102295
Epoch: 11/100... Step: 2060... Loss: 1.2489... Val L

Epoch: 15/100... Step: 2900... Loss: 1.1905... Val Loss: 1.2396 acc:0.622265636920929
Epoch: 15/100... Step: 2910... Loss: 1.2085... Val Loss: 1.2431 acc:0.6189844012260437
Epoch: 15/100... Step: 2920... Loss: 1.1920... Val Loss: 1.2388 acc:0.6224218606948853
Epoch: 15/100... Step: 2930... Loss: 1.1870... Val Loss: 1.2347 acc:0.6206250190734863
Epoch: 15/100... Step: 2940... Loss: 1.2240... Val Loss: 1.2388 acc:0.6184375286102295
Epoch: 15/100... Step: 2950... Loss: 1.2201... Val Loss: 1.2392 acc:0.6200781464576721
Epoch: 15/100... Step: 2960... Loss: 1.2119... Val Loss: 1.2426 acc:0.6197656393051147
Epoch: 15/100... Step: 2970... Loss: 1.2080... Val Loss: 1.2401 acc:0.6193749904632568
Epoch: 15/100... Step: 2980... Loss: 1.2075... Val Loss: 1.2413 acc:0.6211718916893005
Epoch: 15/100... Step: 2990... Loss: 1.2371... Val Loss: 1.2407 acc:0.6236718893051147
Epoch: 15/100... Step: 3000... Loss: 1.2705... Val Loss: 1.2357 acc:0.6200000047683716
Epoch: 16/100... Step: 3010... Loss: 1.2095.

Epoch: 20/100... Step: 3850... Loss: 1.1669... Val Loss: 1.2238 acc:0.6278906464576721
Epoch: 20/100... Step: 3860... Loss: 1.1775... Val Loss: 1.2195 acc:0.6271093487739563
Epoch: 20/100... Step: 3870... Loss: 1.1473... Val Loss: 1.2232 acc:0.6303125023841858
Epoch: 20/100... Step: 3880... Loss: 1.1628... Val Loss: 1.2201 acc:0.6287500262260437
Epoch: 20/100... Step: 3890... Loss: 1.1515... Val Loss: 1.2202 acc:0.6324999928474426
Epoch: 20/100... Step: 3900... Loss: 1.1283... Val Loss: 1.2210 acc:0.6288281083106995
Epoch: 20/100... Step: 3910... Loss: 1.1612... Val Loss: 1.2221 acc:0.6268749833106995
Epoch: 20/100... Step: 3920... Loss: 1.1467... Val Loss: 1.2212 acc:0.6304687261581421
Epoch: 20/100... Step: 3930... Loss: 1.1456... Val Loss: 1.2178 acc:0.6301562786102295
Epoch: 20/100... Step: 3940... Loss: 1.1811... Val Loss: 1.2152 acc:0.6287500262260437
Epoch: 20/100... Step: 3950... Loss: 1.1784... Val Loss: 1.2161 acc:0.6271093487739563
Epoch: 20/100... Step: 3960... Loss: 1.1695

Epoch: 24/100... Step: 4800... Loss: 1.2084... Val Loss: 1.2053 acc:0.6334375143051147
Epoch: 25/100... Step: 4810... Loss: 1.1435... Val Loss: 1.2099 acc:0.6310937404632568
Epoch: 25/100... Step: 4820... Loss: 1.1580... Val Loss: 1.2077 acc:0.633593738079071
Epoch: 25/100... Step: 4830... Loss: 1.1104... Val Loss: 1.2097 acc:0.6315624713897705
Epoch: 25/100... Step: 4840... Loss: 1.1081... Val Loss: 1.2092 acc:0.6365625262260437
Epoch: 25/100... Step: 4850... Loss: 1.1281... Val Loss: 1.2085 acc:0.6353906393051147
Epoch: 25/100... Step: 4860... Loss: 1.1370... Val Loss: 1.2060 acc:0.6337500214576721
Epoch: 25/100... Step: 4870... Loss: 1.1256... Val Loss: 1.2130 acc:0.6327343583106995
Epoch: 25/100... Step: 4880... Loss: 1.1356... Val Loss: 1.2084 acc:0.635546863079071
Epoch: 25/100... Step: 4890... Loss: 1.1241... Val Loss: 1.2055 acc:0.6318749785423279
Epoch: 25/100... Step: 4900... Loss: 1.1185... Val Loss: 1.2056 acc:0.6341406106948853
Epoch: 25/100... Step: 4910... Loss: 1.1365..

Epoch: 29/100... Step: 5750... Loss: 1.1321... Val Loss: 1.2077 acc:0.6331250071525574
Epoch: 29/100... Step: 5760... Loss: 1.1080... Val Loss: 1.2014 acc:0.6367968916893005
Epoch: 29/100... Step: 5770... Loss: 1.1161... Val Loss: 1.2224 acc:0.6333593726158142
Epoch: 29/100... Step: 5780... Loss: 1.1159... Val Loss: 1.2093 acc:0.6357031464576721
Epoch: 29/100... Step: 5790... Loss: 1.1379... Val Loss: 1.2070 acc:0.6327343583106995
Epoch: 29/100... Step: 5800... Loss: 1.1848... Val Loss: 1.2027 acc:0.6360156536102295
Epoch: 30/100... Step: 5810... Loss: 1.1183... Val Loss: 1.2090 acc:0.6344531178474426
Epoch: 30/100... Step: 5820... Loss: 1.1350... Val Loss: 1.2043 acc:0.6392968893051147
Epoch: 30/100... Step: 5830... Loss: 1.1016... Val Loss: 1.2074 acc:0.634765625
Epoch: 30/100... Step: 5840... Loss: 1.0835... Val Loss: 1.2091 acc:0.6354687213897705
Epoch: 30/100... Step: 5850... Loss: 1.1178... Val Loss: 1.2004 acc:0.6360937356948853
Epoch: 30/100... Step: 5860... Loss: 1.1179... Val

Epoch: 34/100... Step: 6710... Loss: 1.1036... Val Loss: 1.2043 acc:0.6392187476158142
Epoch: 34/100... Step: 6720... Loss: 1.0793... Val Loss: 1.2019 acc:0.6407812237739563
Epoch: 34/100... Step: 6730... Loss: 1.0804... Val Loss: 1.2016 acc:0.6353906393051147
Epoch: 34/100... Step: 6740... Loss: 1.1122... Val Loss: 1.2009 acc:0.6385937333106995
Epoch: 34/100... Step: 6750... Loss: 1.1155... Val Loss: 1.2105 acc:0.6382812261581421
Epoch: 34/100... Step: 6760... Loss: 1.1063... Val Loss: 1.2072 acc:0.6352343559265137
Epoch: 34/100... Step: 6770... Loss: 1.1072... Val Loss: 1.2128 acc:0.6360937356948853
Epoch: 34/100... Step: 6780... Loss: 1.0983... Val Loss: 1.2113 acc:0.6362500190734863
Epoch: 34/100... Step: 6790... Loss: 1.1183... Val Loss: 1.2021 acc:0.6391406059265137
Epoch: 34/100... Step: 6800... Loss: 1.1604... Val Loss: 1.2011 acc:0.637890636920929
Epoch: 35/100... Step: 6810... Loss: 1.0942... Val Loss: 1.2074 acc:0.6368749737739563
Epoch: 35/100... Step: 6820... Loss: 1.1107.

Epoch: 39/100... Step: 7660... Loss: 1.0740... Val Loss: 1.2015 acc:0.6396093964576721
Epoch: 39/100... Step: 7670... Loss: 1.0665... Val Loss: 1.2062 acc:0.639453113079071
Epoch: 39/100... Step: 7680... Loss: 1.0789... Val Loss: 1.2051 acc:0.639843761920929
Epoch: 39/100... Step: 7690... Loss: 1.0615... Val Loss: 1.2032 acc:0.637499988079071
Epoch: 39/100... Step: 7700... Loss: 1.0676... Val Loss: 1.2016 acc:0.6389843821525574
Epoch: 39/100... Step: 7710... Loss: 1.0749... Val Loss: 1.2051 acc:0.6392968893051147
Epoch: 39/100... Step: 7720... Loss: 1.0634... Val Loss: 1.1996 acc:0.6407031416893005
Epoch: 39/100... Step: 7730... Loss: 1.0659... Val Loss: 1.1975 acc:0.6397656202316284
Epoch: 39/100... Step: 7740... Loss: 1.0905... Val Loss: 1.2031 acc:0.638671875
Epoch: 39/100... Step: 7750... Loss: 1.0928... Val Loss: 1.2083 acc:0.6382031440734863
Epoch: 39/100... Step: 7760... Loss: 1.0841... Val Loss: 1.2049 acc:0.6381250023841858
Epoch: 39/100... Step: 7770... Loss: 1.0816... Val Lo

Epoch: 44/100... Step: 8610... Loss: 1.0727... Val Loss: 1.2122 acc:0.6379687786102295
Epoch: 44/100... Step: 8620... Loss: 1.0718... Val Loss: 1.2046 acc:0.6401562690734863
Epoch: 44/100... Step: 8630... Loss: 1.0481... Val Loss: 1.2020 acc:0.6407812237739563
Epoch: 44/100... Step: 8640... Loss: 1.0433... Val Loss: 1.2101 acc:0.6380468606948853
Epoch: 44/100... Step: 8650... Loss: 1.0658... Val Loss: 1.2067 acc:0.6396093964576721
Epoch: 44/100... Step: 8660... Loss: 1.0655... Val Loss: 1.2036 acc:0.6404687762260437
Epoch: 44/100... Step: 8670... Loss: 1.0541... Val Loss: 1.2060 acc:0.6393749713897705
Epoch: 44/100... Step: 8680... Loss: 1.0677... Val Loss: 1.2081 acc:0.6390625238418579
Epoch: 44/100... Step: 8690... Loss: 1.0430... Val Loss: 1.2042 acc:0.6384375095367432
Epoch: 44/100... Step: 8700... Loss: 1.0586... Val Loss: 1.2006 acc:0.6389843821525574
Epoch: 44/100... Step: 8710... Loss: 1.0682... Val Loss: 1.2145 acc:0.6397656202316284
Epoch: 44/100... Step: 8720... Loss: 1.0486

Epoch: 48/100... Step: 9560... Loss: 1.0575... Val Loss: 1.2042 acc:0.6381250023841858
Epoch: 48/100... Step: 9570... Loss: 1.0493... Val Loss: 1.2095 acc:0.6374218463897705
Epoch: 48/100... Step: 9580... Loss: 1.0580... Val Loss: 1.2069 acc:0.6379687786102295
Epoch: 48/100... Step: 9590... Loss: 1.0725... Val Loss: 1.2034 acc:0.6401562690734863
Epoch: 48/100... Step: 9600... Loss: 1.1065... Val Loss: 1.2083 acc:0.6382031440734863
Epoch: 49/100... Step: 9610... Loss: 1.0527... Val Loss: 1.2138 acc:0.6385156512260437
Epoch: 49/100... Step: 9620... Loss: 1.0664... Val Loss: 1.2031 acc:0.6404687762260437
Epoch: 49/100... Step: 9630... Loss: 1.0347... Val Loss: 1.2052 acc:0.6387500166893005
Epoch: 49/100... Step: 9640... Loss: 1.0334... Val Loss: 1.2108 acc:0.6374218463897705
Epoch: 49/100... Step: 9650... Loss: 1.0365... Val Loss: 1.2127 acc:0.6385937333106995
Epoch: 49/100... Step: 9660... Loss: 1.0432... Val Loss: 1.2035 acc:0.6414843797683716
Epoch: 49/100... Step: 9670... Loss: 1.0402

Epoch: 53/100... Step: 10510... Loss: 1.0399... Val Loss: 1.2161 acc:0.6380468606948853
Epoch: 53/100... Step: 10520... Loss: 1.0296... Val Loss: 1.2085 acc:0.6414843797683716
Epoch: 53/100... Step: 10530... Loss: 1.0313... Val Loss: 1.2122 acc:0.6370312571525574
Epoch: 53/100... Step: 10540... Loss: 1.0470... Val Loss: 1.2112 acc:0.6401562690734863
Epoch: 53/100... Step: 10550... Loss: 1.0509... Val Loss: 1.2168 acc:0.6392187476158142
Epoch: 53/100... Step: 10560... Loss: 1.0488... Val Loss: 1.2075 acc:0.6388280987739563
Epoch: 53/100... Step: 10570... Loss: 1.0342... Val Loss: 1.2141 acc:0.637499988079071
Epoch: 53/100... Step: 10580... Loss: 1.0397... Val Loss: 1.2111 acc:0.6385937333106995
Epoch: 53/100... Step: 10590... Loss: 1.0621... Val Loss: 1.2077 acc:0.6391406059265137
Epoch: 53/100... Step: 10600... Loss: 1.1009... Val Loss: 1.2140 acc:0.640625
Epoch: 54/100... Step: 10610... Loss: 1.0416... Val Loss: 1.2166 acc:0.6399999856948853
Epoch: 54/100... Step: 10620... Loss: 1.056

Epoch: 58/100... Step: 11450... Loss: 1.0290... Val Loss: 1.2161 acc:0.6414843797683716
Epoch: 58/100... Step: 11460... Loss: 1.0263... Val Loss: 1.2174 acc:0.6403906345367432
Epoch: 58/100... Step: 11470... Loss: 1.0173... Val Loss: 1.2199 acc:0.639843761920929
Epoch: 58/100... Step: 11480... Loss: 1.0218... Val Loss: 1.2190 acc:0.639843761920929
Epoch: 58/100... Step: 11490... Loss: 1.0135... Val Loss: 1.2145 acc:0.6417187452316284
Epoch: 58/100... Step: 11500... Loss: 1.0093... Val Loss: 1.2154 acc:0.6392187476158142
Epoch: 58/100... Step: 11510... Loss: 1.0286... Val Loss: 1.2203 acc:0.6391406059265137
Epoch: 58/100... Step: 11520... Loss: 1.0136... Val Loss: 1.2164 acc:0.6419531106948853
Epoch: 58/100... Step: 11530... Loss: 1.0164... Val Loss: 1.2182 acc:0.6382031440734863
Epoch: 58/100... Step: 11540... Loss: 1.0341... Val Loss: 1.2153 acc:0.6380468606948853
Epoch: 58/100... Step: 11550... Loss: 1.0415... Val Loss: 1.2199 acc:0.6393749713897705
Epoch: 58/100... Step: 11560... Lo

Epoch: 62/100... Step: 12390... Loss: 1.0420... Val Loss: 1.2212 acc:0.638671875
Epoch: 62/100... Step: 12400... Loss: 1.0718... Val Loss: 1.2230 acc:0.6378124952316284
Epoch: 63/100... Step: 12410... Loss: 1.0213... Val Loss: 1.2218 acc:0.639453113079071
Epoch: 63/100... Step: 12420... Loss: 1.0362... Val Loss: 1.2197 acc:0.6389062404632568
Epoch: 63/100... Step: 12430... Loss: 1.0012... Val Loss: 1.2269 acc:0.6382031440734863
Epoch: 63/100... Step: 12440... Loss: 0.9980... Val Loss: 1.2253 acc:0.6407812237739563
Epoch: 63/100... Step: 12450... Loss: 1.0108... Val Loss: 1.2224 acc:0.638671875
Epoch: 63/100... Step: 12460... Loss: 1.0180... Val Loss: 1.2236 acc:0.6380468606948853
Epoch: 63/100... Step: 12470... Loss: 1.0054... Val Loss: 1.2263 acc:0.6408593654632568
Epoch: 63/100... Step: 12480... Loss: 1.0229... Val Loss: 1.2244 acc:0.639843761920929
Epoch: 63/100... Step: 12490... Loss: 0.9990... Val Loss: 1.2199 acc:0.6418750286102295
Epoch: 63/100... Step: 12500... Loss: 1.0062... 

Epoch: 67/100... Step: 13330... Loss: 0.9985... Val Loss: 1.2286 acc:0.6369531154632568
Epoch: 67/100... Step: 13340... Loss: 1.0210... Val Loss: 1.2203 acc:0.6401562690734863
Epoch: 67/100... Step: 13350... Loss: 1.0217... Val Loss: 1.2307 acc:0.6380468606948853
Epoch: 67/100... Step: 13360... Loss: 1.0168... Val Loss: 1.2225 acc:0.6385937333106995
Epoch: 67/100... Step: 13370... Loss: 1.0052... Val Loss: 1.2289 acc:0.6378124952316284
Epoch: 67/100... Step: 13380... Loss: 1.0061... Val Loss: 1.2276 acc:0.6357812285423279
Epoch: 67/100... Step: 13390... Loss: 1.0288... Val Loss: 1.2260 acc:0.6370312571525574
Epoch: 67/100... Step: 13400... Loss: 1.0692... Val Loss: 1.2258 acc:0.6376562714576721
Epoch: 68/100... Step: 13410... Loss: 1.0112... Val Loss: 1.2248 acc:0.6389062404632568
Epoch: 68/100... Step: 13420... Loss: 1.0230... Val Loss: 1.2240 acc:0.6407031416893005
Epoch: 68/100... Step: 13430... Loss: 0.9907... Val Loss: 1.2295 acc:0.637499988079071
Epoch: 68/100... Step: 13440... L

Epoch: 72/100... Step: 14270... Loss: 0.9872... Val Loss: 1.2323 acc:0.6395312547683716
Epoch: 72/100... Step: 14280... Loss: 1.0060... Val Loss: 1.2325 acc:0.6392187476158142
Epoch: 72/100... Step: 14290... Loss: 0.9951... Val Loss: 1.2258 acc:0.6365625262260437
Epoch: 72/100... Step: 14300... Loss: 0.9886... Val Loss: 1.2419 acc:0.6378124952316284
Epoch: 72/100... Step: 14310... Loss: 1.0030... Val Loss: 1.2342 acc:0.6411718726158142
Epoch: 72/100... Step: 14320... Loss: 0.9833... Val Loss: 1.2325 acc:0.6384375095367432
Epoch: 72/100... Step: 14330... Loss: 0.9945... Val Loss: 1.2267 acc:0.6385937333106995
Epoch: 72/100... Step: 14340... Loss: 1.0186... Val Loss: 1.2269 acc:0.6403906345367432
Epoch: 72/100... Step: 14350... Loss: 1.0079... Val Loss: 1.2350 acc:0.6384375095367432
Epoch: 72/100... Step: 14360... Loss: 1.0073... Val Loss: 1.2308 acc:0.6360156536102295
Epoch: 72/100... Step: 14370... Loss: 0.9949... Val Loss: 1.2361 acc:0.6353906393051147
Epoch: 72/100... Step: 14380... 

Epoch: 77/100... Step: 15210... Loss: 1.0012... Val Loss: 1.2398 acc:0.6358593702316284
Epoch: 77/100... Step: 15220... Loss: 1.0125... Val Loss: 1.2398 acc:0.6369531154632568
Epoch: 77/100... Step: 15230... Loss: 0.9728... Val Loss: 1.2333 acc:0.639843761920929
Epoch: 77/100... Step: 15240... Loss: 0.9731... Val Loss: 1.2358 acc:0.6393749713897705
Epoch: 77/100... Step: 15250... Loss: 0.9898... Val Loss: 1.2374 acc:0.6372656226158142
Epoch: 77/100... Step: 15260... Loss: 0.9914... Val Loss: 1.2368 acc:0.6374218463897705
Epoch: 77/100... Step: 15270... Loss: 0.9817... Val Loss: 1.2392 acc:0.6373437643051147
Epoch: 77/100... Step: 15280... Loss: 0.9950... Val Loss: 1.2366 acc:0.6392187476158142
Epoch: 77/100... Step: 15290... Loss: 0.9767... Val Loss: 1.2360 acc:0.6371874809265137
Epoch: 77/100... Step: 15300... Loss: 0.9850... Val Loss: 1.2494 acc:0.6378124952316284
Epoch: 77/100... Step: 15310... Loss: 0.9947... Val Loss: 1.2374 acc:0.6371874809265137
Epoch: 77/100... Step: 15320... L

Epoch: 81/100... Step: 16150... Loss: 0.9935... Val Loss: 1.2415 acc:0.6346874833106995
Epoch: 81/100... Step: 16160... Loss: 0.9910... Val Loss: 1.2383 acc:0.6385156512260437
Epoch: 81/100... Step: 16170... Loss: 0.9792... Val Loss: 1.2447 acc:0.6374218463897705
Epoch: 81/100... Step: 16180... Loss: 0.9885... Val Loss: 1.2441 acc:0.6377343535423279
Epoch: 81/100... Step: 16190... Loss: 0.9991... Val Loss: 1.2456 acc:0.6376562714576721
Epoch: 81/100... Step: 16200... Loss: 1.0453... Val Loss: 1.2366 acc:0.6374218463897705
Epoch: 82/100... Step: 16210... Loss: 0.9943... Val Loss: 1.2442 acc:0.637890636920929
Epoch: 82/100... Step: 16220... Loss: 0.9959... Val Loss: 1.2401 acc:0.6399999856948853
Epoch: 82/100... Step: 16230... Loss: 0.9667... Val Loss: 1.2421 acc:0.639453113079071
Epoch: 82/100... Step: 16240... Loss: 0.9838... Val Loss: 1.2395 acc:0.6391406059265137
Epoch: 82/100... Step: 16250... Loss: 0.9796... Val Loss: 1.2466 acc:0.6349218487739563
Epoch: 82/100... Step: 16260... Lo

Epoch: 86/100... Step: 17090... Loss: 0.9681... Val Loss: 1.2529 acc:0.6364843845367432
Epoch: 86/100... Step: 17100... Loss: 0.9744... Val Loss: 1.2668 acc:0.6356250047683716
Epoch: 86/100... Step: 17110... Loss: 0.9965... Val Loss: 1.2435 acc:0.6381250023841858
Epoch: 86/100... Step: 17120... Loss: 0.9625... Val Loss: 1.2612 acc:0.6339062452316284
Epoch: 86/100... Step: 17130... Loss: 0.9661... Val Loss: 1.2499 acc:0.6382031440734863
Epoch: 86/100... Step: 17140... Loss: 0.9863... Val Loss: 1.2499 acc:0.6370312571525574
Epoch: 86/100... Step: 17150... Loss: 0.9940... Val Loss: 1.2441 acc:0.6371093988418579
Epoch: 86/100... Step: 17160... Loss: 0.9901... Val Loss: 1.2477 acc:0.6362500190734863
Epoch: 86/100... Step: 17170... Loss: 0.9789... Val Loss: 1.2538 acc:0.635937511920929
Epoch: 86/100... Step: 17180... Loss: 0.9717... Val Loss: 1.2501 acc:0.6357031464576721
Epoch: 86/100... Step: 17190... Loss: 0.9994... Val Loss: 1.2547 acc:0.6396093964576721
Epoch: 86/100... Step: 17200... L

Epoch: 91/100... Step: 18030... Loss: 0.9655... Val Loss: 1.2483 acc:0.6396093964576721
Epoch: 91/100... Step: 18040... Loss: 0.9675... Val Loss: 1.2492 acc:0.6389062404632568
Epoch: 91/100... Step: 18050... Loss: 0.9736... Val Loss: 1.2574 acc:0.6366406083106995
Epoch: 91/100... Step: 18060... Loss: 0.9727... Val Loss: 1.2539 acc:0.6375781297683716
Epoch: 91/100... Step: 18070... Loss: 0.9608... Val Loss: 1.2533 acc:0.6370312571525574
Epoch: 91/100... Step: 18080... Loss: 0.9768... Val Loss: 1.2509 acc:0.6368749737739563
Epoch: 91/100... Step: 18090... Loss: 0.9541... Val Loss: 1.2542 acc:0.6357812285423279
Epoch: 91/100... Step: 18100... Loss: 0.9676... Val Loss: 1.2643 acc:0.6349218487739563
Epoch: 91/100... Step: 18110... Loss: 0.9894... Val Loss: 1.2450 acc:0.6354687213897705
Epoch: 91/100... Step: 18120... Loss: 0.9615... Val Loss: 1.2645 acc:0.6379687786102295
Epoch: 91/100... Step: 18130... Loss: 0.9717... Val Loss: 1.2520 acc:0.6370312571525574
Epoch: 91/100... Step: 18140... 

Epoch: 95/100... Step: 18970... Loss: 0.9643... Val Loss: 1.2654 acc:0.633984386920929
Epoch: 95/100... Step: 18980... Loss: 0.9701... Val Loss: 1.2584 acc:0.6357812285423279
Epoch: 95/100... Step: 18990... Loss: 0.9832... Val Loss: 1.2566 acc:0.6390625238418579
Epoch: 95/100... Step: 19000... Loss: 1.0244... Val Loss: 1.2533 acc:0.6387500166893005
Epoch: 96/100... Step: 19010... Loss: 0.9702... Val Loss: 1.2652 acc:0.6371874809265137
Epoch: 96/100... Step: 19020... Loss: 0.9863... Val Loss: 1.2568 acc:0.6353124976158142
Epoch: 96/100... Step: 19030... Loss: 0.9490... Val Loss: 1.2527 acc:0.6360156536102295
Epoch: 96/100... Step: 19040... Loss: 0.9599... Val Loss: 1.2580 acc:0.633984386920929
Epoch: 96/100... Step: 19050... Loss: 0.9718... Val Loss: 1.2591 acc:0.6375781297683716
Epoch: 96/100... Step: 19060... Loss: 0.9730... Val Loss: 1.2582 acc:0.6322656273841858
Epoch: 96/100... Step: 19070... Loss: 0.9529... Val Loss: 1.2554 acc:0.6360156536102295
Epoch: 96/100... Step: 19080... Lo

Epoch: 100/100... Step: 19910... Loss: 0.9753... Val Loss: 1.2506 acc:0.6364062428474426
Epoch: 100/100... Step: 19920... Loss: 0.9431... Val Loss: 1.2673 acc:0.637890636920929
Epoch: 100/100... Step: 19930... Loss: 0.9500... Val Loss: 1.2677 acc:0.6368749737739563
Epoch: 100/100... Step: 19940... Loss: 0.9750... Val Loss: 1.2570 acc:0.6357812285423279
Epoch: 100/100... Step: 19950... Loss: 0.9730... Val Loss: 1.2585 acc:0.6350781321525574
Epoch: 100/100... Step: 19960... Loss: 0.9699... Val Loss: 1.2684 acc:0.6334375143051147
Epoch: 100/100... Step: 19970... Loss: 0.9628... Val Loss: 1.2629 acc:0.6343749761581421
Epoch: 100/100... Step: 19980... Loss: 0.9632... Val Loss: 1.2630 acc:0.6352343559265137
Epoch: 100/100... Step: 19990... Loss: 0.9842... Val Loss: 1.2679 acc:0.6343749761581421
Epoch: 100/100... Step: 20000... Loss: 1.0271... Val Loss: 1.2583 acc:0.6363281011581421


In [99]:
def predict(net, char, h=None, top_k=None):
        ''' Given a character, predict the next character.
            Returns the predicted character and the hidden state.
        '''
        
        # tensor inputs
        x = np.array([[net.char2int[char]]])
        x = one_hot_encode(x, len(net.chars))
        inputs = torch.from_numpy(x)
        
        if(train_on_gpu):
            inputs = inputs.cuda()
        
        # detach hidden state from history
        h = tuple([each.data for each in h])
        # get the output of the model
        out, h = net(inputs, h)

        # get the character probabilities
        p = F.softmax(out, dim=1).data
        if(train_on_gpu):
            p = p.cpu() # move to cpu
        
        # get top characters
        if top_k is None:
            top_ch = np.arange(len(net.chars))
        else:
            p, top_ch = p.topk(top_k)
            top_ch = top_ch.numpy().squeeze()
        
        # select the likely next character with some element of randomness
        p = p.numpy().squeeze()
        char = np.random.choice(top_ch, p=p/p.sum())
        
        # return the encoded value of the predicted char and the hidden state
        return net.int2char[char], h

In [100]:
def sample(net, size, prime='The', top_k=None):
        
    if(train_on_gpu):
        net.cuda()
    else:
        net.cpu()
    
    net.eval() # eval mode
    
    # First off, run through the prime characters
    chars = [ch for ch in prime]
    h = net.init_hidden(1)
    for ch in prime:
        char, h = predict(net, ch, h, top_k=top_k)

    chars.append(char)
    
    # Now pass in the previous character and get a new one
    for ii in range(size):
        char, h = predict(net, chars[-1], h, top_k=top_k)
        chars.append(char)

    return ''.join(chars)

In [108]:
print(sample(net, 1, prime='appl', top_k=5))

apple 


In [110]:
torch.save(net, 'output/model/twolayLSTM.pth') 

In [111]:
model=torch.load('output/model/twolayLSTM.pth') 

In [112]:
print(sample(model, 1, prime='appl', top_k=5))

apples
