In [1]:
import string
import random
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

#### Prepare the Dataset

In [2]:
# Vocabulary: every printable ASCII character (digits, letters, punctuation, whitespace).
all_chars       = string.printable
n_chars         = len(all_chars)

# Read the whole corpus with a context manager so the file handle is closed
# deterministically; an explicit encoding avoids depending on the OS locale.
with open('./shakespeare.txt', encoding='utf-8') as corpus_file:
    file        = corpus_file.read()
file_len        = len(file)

print('Length of file: {}'.format(file_len))
print('All possible characters: {}'.format(all_chars))
print('Number of all possible characters: {}'.format(n_chars))

Length of file: 1115394
All possible characters: 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ 	

Number of all possible characters: 100


In [3]:
# Get a random sequence of the Shakespeare dataset.
def get_random_seq(seq_len=128, text=None):
    """Return a random slice of exactly `seq_len + 1` consecutive characters.

    The extra character lets the caller split the slice into an input
    (`seq[:-1]`) and a next-character target (`seq[1:]`), each `seq_len` long.

    seq_len: Length of the input sequence (default 128).
    text:    Corpus to sample from; defaults to the notebook-level `file`.
    Raises ValueError if the corpus is shorter than `seq_len + 1` characters.
    """
    if text is None:
        text = file
    window = seq_len + 1
    if len(text) < window:
        raise ValueError('Corpus has {} characters; need at least {}.'.format(len(text), window))
    # random.randint is inclusive on both ends, so the largest start index must
    # still leave `window` characters; the old bound (len - seq_len) could
    # return a slice one character short at the end of the corpus.
    start_index = random.randint(0, len(text) - window)
    return text[start_index:start_index + window]

# Convert the sequence to one-hot tensor.
def seq_to_onehot(seq):
    """Encode a string as a one-hot float tensor.

    The result has shape (sequence length, batch size = 1, n_chars): position
    `step` holds a 1 at the vocabulary index of `seq[step]` and 0 elsewhere.
    A character missing from `all_chars` raises ValueError (from str.index).
    """
    onehot = torch.zeros(len(seq), 1, n_chars)
    positions = [all_chars.index(ch) for ch in seq]
    for step, idx in enumerate(positions):
        onehot[step, 0, idx] = 1
    return onehot

# Convert the sequence to index tensor.
def seq_to_index(seq):
    """Encode a string as a column of vocabulary indices.

    The result is a float tensor of shape (sequence length, batch size = 1)
    whose entry at `step` is the position of `seq[step]` in `all_chars`.
    Callers convert it with `.long()` before using it as a class target.
    """
    indices = torch.zeros(len(seq), 1)
    for step, ch in enumerate(seq):
        indices[step, 0] = all_chars.index(ch)
    return indices

# Sample a mini-batch including input tensor and target tensor.
def get_input_and_target():
    """Draw one random training pair from the corpus.

    Returns (input, target): the input is the one-hot encoding of every
    character but the last; the target is the long index tensor of every
    character but the first, i.e. the next character at each step.
    """
    sample = get_random_seq()
    input_tensor = seq_to_onehot(sample[:-1])
    target_tensor = seq_to_index(sample[1:]).long()
    return input_tensor, target_tensor

#### Choose a Device

In [4]:
# If there are GPUs, choose the first one for computing. Otherwise use CPU.
# Pick the first GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)  # Printing 'cuda:0' means computation will run on the GPU.

cpu


#### Network Definition

In [5]:
class Net(nn.Module):
    """ Character-level recurrent network: one GRU cell followed by a linear
        decoder that maps each hidden state to per-character logits.

        input_size:  Size of the one-hot input; defaults to `n_chars`.
        hidden_size: Size of the GRU hidden state (default 100).
        output_size: Number of output classes; defaults to `n_chars`.
    """
    def __init__(self, input_size=None, hidden_size=100, output_size=None):
        # Initialization.
        super(Net, self).__init__()
        # `None` means "use the notebook vocabulary size", resolved at construction.
        self.input_size  = n_chars if input_size is None else input_size
        self.hidden_size = hidden_size
        self.output_size = n_chars if output_size is None else output_size

        self.encoder = nn.GRUCell(self.input_size, self.hidden_size)
        self.decoder = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden):
        """ One time step.
              input:  One-hot input x_t, shape [batch, input_size].
              hidden: Previous hidden state h_{t-1}, shape [batch, hidden_size].
            Returns (output, hidden) where output is y_t, the unnormalized
            logits of shape [batch, output_size], and hidden is h_t.
        """
        hidden = self.encoder(input, hidden)  # h_t from x_t and h_{t-1}.
        output = self.decoder(hidden)         # Raw logits; no softmax applied here.
        return output, hidden

    def init_hidden(self, batch_size=1):
        """ Zero initial hidden state of shape [batch_size, hidden_size].

            It is allocated on the same device and dtype as the model's
            parameters, so it matches wherever the network was moved with
            `.to(...)` instead of relying on a global `device`.
        """
        weight = self.decoder.weight
        return torch.zeros(batch_size, self.hidden_size,
                           device=weight.device, dtype=weight.dtype)
    
# Build the model and move its parameters to the selected device in one step
# (nn.Module.to returns the module itself); the bare `net` line displays it.
net = Net().to(device)
net

Net(
  (encoder): GRUCell(100, 100)
  (decoder): Linear(in_features=100, out_features=100, bias=True)
)

#### Training Step and Evaluation Step

In [6]:
# Training step function.
def train_step(net, opt, input, target, loss_fn=None):
    """ Training step: run one sequence through the network, backpropagate
        the summed per-step loss, and apply one optimizer update.
        net:     The network instance. Must provide `init_hidden()` and
                 `forward(x_t, h_{t-1}) -> (y_t, h_t)`.
        opt:     The optimizer instance.
        input:   Input tensor.  Shape: [seq_len, 1, n_chars].
        target:  Target tensor. Shape: [seq_len, 1] (long class indices).
        loss_fn: Per-step loss; defaults to the notebook-level `loss_func`.
        Returns the loss averaged over the sequence length, as a Python float.
        Raises ValueError if the input has no time steps.
    """
    if loss_fn is None:
        loss_fn = loss_func     # Resolved at call time (defined in the training cell).

    seq_len = input.shape[0]    # Get the sequence length of current input.
    if seq_len == 0:
        raise ValueError('input must contain at least one time step.')

    hidden = net.init_hidden()  # Initial hidden state.
    net.zero_grad()             # Clear the gradient.
    loss = 0                    # Becomes a tensor after the first step.

    for t in range(seq_len):    # For each one in the input sequence.
        output, hidden = net(input[t], hidden)
        loss += loss_fn(output, target[t])

    loss.backward()             # Backward.
    opt.step()                  # Update the weights.

    # .item() returns a plain number. Returning the tensor kept the autograd
    # graph alive through the caller's running sum, and matplotlib cannot
    # plot tensors that require grad.
    return loss.item() / seq_len

In [7]:
# Evaluation step function.
def eval_step(net, init_seq='W', predicted_len=100, temperature=1.0):
    """ Generate text by sampling from the network one character at a time.
        net:           The network; must provide `init_hidden()` and
                       `forward(x_t, h_{t-1}) -> (logits, h_t)`.
        init_seq:      Non-empty priming string (characters from `all_chars`).
        predicted_len: Number of characters to sample after `init_seq`.
        temperature:   Positive softmax temperature. 1.0 samples from the
                       model's distribution; smaller values sharpen it.
        Returns `init_seq` followed by `predicted_len` sampled characters.
        Raises ValueError for an empty `init_seq` or a non-positive temperature.
    """
    if not init_seq:
        raise ValueError('init_seq must contain at least one character.')
    if temperature <= 0:
        raise ValueError('temperature must be positive.')

    # Generation needs no gradients; no_grad avoids recording an autograd
    # graph over the whole sampled sequence.
    with torch.no_grad():
        # Initialize the hidden state, input and the predicted sequence.
        hidden        = net.init_hidden()
        init_input    = seq_to_onehot(init_seq).to(device)
        predicted_seq = init_seq

        # Use all but the last priming character to "build up" the hidden state.
        for t in range(len(init_seq) - 1):
            _, hidden = net(init_input[t], hidden)

        # The last priming character is the first input to sample from.
        current_input = init_input[-1]

        for _ in range(predicted_len):
            output, hidden = net(current_input, hidden)

            # Turn logits into probabilities with softmax. Calling exp() on raw
            # logits can overflow to inf, which torch.multinomial rejects.
            probs = torch.softmax(output.view(-1) / temperature, dim=0)
            predicted_index = torch.multinomial(probs, 1).item()

            # Append the sampled character and feed it back as the next input.
            predicted_char  = all_chars[predicted_index]
            predicted_seq  += predicted_char
            current_input   = seq_to_onehot(predicted_char)[0].to(device)

    return predicted_seq

#### Training Procedure

In [8]:
# Number of iterations.
iters       = 15000  # Number of training iterations.
print_iters = 100    # Number of iterations for each log printing.

# The loss variables.
all_losses = []      # Mean loss of each `print_iters`-iteration window (plain floats).
loss_sum   = 0.0

# Initialize the optimizer and the loss function.
opt       = torch.optim.Adam(net.parameters(), lr=0.005)
loss_func = nn.CrossEntropyLoss()

# Training procedure.
for i in range(iters):
    input, target = get_input_and_target()              # Fetch input and target.
    input, target = input.to(device), target.to(device) # Move to the chosen device.
    loss      = train_step(net, opt, input, target)     # Calculate the loss.
    # float() converts a 0-dim tensor to a Python number (and is a no-op on a
    # float). Summing tensors would chain autograd nodes across iterations and
    # leave all_losses unplottable by matplotlib.
    loss_sum += float(loss)

    # Print the log.
    if i % print_iters == print_iters - 1:
        avg_loss = loss_sum / print_iters
        print('iter:{}/{} loss:{}'.format(i, iters, avg_loss))
        print('generated sequence: {}\n'.format(eval_step(net)))

        # Track the loss.
        all_losses.append(avg_loss)
        loss_sum = 0.0

iter:99/20000 loss:3.1595852375030518
generated sequence: W}a, mePc whae lere y >olg of
cf lasiRve:Nur me core, h fonof.
Ds sy cer i moss imi soiatir;
Rup:
IN'

iter:199/20000 loss:2.5630786418914795
generated sequence: Wea t am srad wo ciiwise noud, sm sa stK
I Io, math,
WO,
Whave! sthif,
WIm cote, Soo ou tho aople wat

iter:299/20000 loss:2.405449151992798
generated sequence: Whd un, youle;

urlirr chas heereicl io mens
Gim fir,
Cacrove mel,
Whaot,. o fat hagisurs:
Hnoumes hu

iter:399/20000 loss:2.3453826904296875
generated sequence: Wachat piyis.
T ans mys,
to-r to mithize,
Aly mall matly thu thacken bu.
AFret,
Ant:
Ahy o donds.
And

iter:499/20000 loss:2.300445079803467
generated sequence: WI someno rimigh in ave rother' non,
He hawh! 'r
Bewes pour brsime, be the hise's,

SLANUTENHFEKOLE:

iter:599/20000 loss:2.21688175201416
generated sequence: Walt seand unscennt auke lrrotaduss mister, not serarver ofurgust ap weak!
Ceay sheten act ben, efotn

iter:699/20000 loss:2.17476463317

iter:5199/20000 loss:1.7071537971496582
generated sequence: Wf destreaning me:
Shembere' now the service.

VELNST:
Then slain'd holy iganings for to come.
And ho

iter:5299/20000 loss:1.6862231492996216
generated sequence: Wallad wither, a tongue to York;
Wher! that rach, praint is alfold by Claul solder med;
Lorb coustch 

iter:5399/20000 loss:1.7111458778381348
generated sequence: We varing there,
Bhonoun and lie hourd's good,
With what forwistion peccedy:
What in muke yoy Towords

iter:5499/20000 loss:1.6861637830734253
generated sequence: Weat athy'd his lady and fear
Meadrrenow'd ever'd, sir,
If in thy may supp speit: demanting gan:
a la

iter:5599/20000 loss:1.66019868850708
generated sequence: Wast, and her his to dot--

Second time
The must? IG art thy Panisepr'd,
Good; me
As door not: to liv

iter:5699/20000 loss:1.7239387035369873
generated sequence: Which him.

GLOUCESTER:
Alked the wein so lients all me bide's carr tage.

BRUTUS:
Nor, his give good

iter:5799/20000 loss:1.7

iter:10299/20000 loss:1.6378488540649414
generated sequence: What atherdolas:
Is ventuse as not of lonb, I am as flow in joy's gob-stagy
This be thou spees, is he

iter:10399/20000 loss:1.6267616748809814
generated sequence: Wats like of concen beingman
For socre the bedod! Seelt the ourles, Reday that lookands we deserve,
T

iter:10499/20000 loss:1.6425352096557617
generated sequence: Which magy you thy deat,
And if the dur not shall death, furrhtors
We well.

LARY ANUE:
No hold, none

iter:10599/20000 loss:1.657002568244934
generated sequence: Was shepters other fford.
If be or my lords handle of pland but, sec Verose
Wethincy? hath speak, and

iter:10699/20000 loss:1.6465703248977661
generated sequence: Were this wherewe some,
No may worth me, for that formess.
Nom think would still, do would, do soill'

iter:10799/20000 loss:1.612097144126892
generated sequence: Whack ot first
she of his be justy touch wall art that go dusbory.

BRAKENBURY:
Ha moster's your moaf

iter:10899/20000 l

iter:15399/20000 loss:1.6024255752563477
generated sequence: WARWICI:
I'll be grat,
The revenge love were lender poor
Genttlat Marcim here's ass
For and wish have

iter:15499/20000 loss:1.5839675664901733
generated sequence: Whence: sir.

DUKE VINCENTIO:
The mind marry, have is Trying that
There
The title good I do worthmen.

iter:15599/20000 loss:1.6167601346969604
generated sequence: WBignas makes fie: I, I what to do Loo', as,
Re banding graced a half, far one than friendsliting for

iter:15699/20000 loss:1.6264317035675049
generated sequence: Weathy
Since yerval to graves on,
And set that monest were tood,--O, Alces;
Orshat disserve arrow of 

iter:15799/20000 loss:1.5929278135299683
generated sequence: WARLIO:
?e that stands, but 'tis, am a bolm old we.

KING RICHARD II:
All are-light and re nothing sk

iter:15899/20000 loss:1.5670872926712036
generated sequence: WARTHA:
Ay, that they mount follow back, and the not: sit
Ofking asticia's cause it worthy busiastick

iter:15999/20000

#### Training Loss Curve

In [None]:
# Each recorded point is the mean training loss over `print_iters` iterations,
# so plot it against the iteration at which that window ended.
loss_values = [float(v) for v in all_losses]   # Also accepts 0-dim tensors.
iter_marks  = [(k + 1) * print_iters for k in range(len(loss_values))]

fig, ax = plt.subplots()
ax.plot(iter_marks, loss_values)
ax.set(title='Training loss (mean per {} iterations)'.format(print_iters),
       xlabel='iters', ylabel='loss')
plt.show()

#### Evaluation: A Sample Generated Sequence

In [None]:
# Sample 600 characters after eval_step's default priming string.
sample_text = eval_step(net, predicted_len=600)
print(sample_text)