In [1]:
import torch
from torch import nn
import numpy as np

In [2]:
data_path = 'DATA/rus.txt'

In [3]:
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()

with open(data_path, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')
for line in lines[:-1]:
    input_text, target_text = line.split('\t')
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    for char in input_text:
        if char not in input_characters:
            input_characters.add(char)
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char)

input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)

In [4]:
print('Number of samples:', len(input_texts))
print('Number of unique input tokens:', num_encoder_tokens)
print('Number of unique output tokens:', num_decoder_tokens)
print('Max sequence length for inputs:', max([len(txt) for txt in input_texts]))
print('Max sequence length for outputs:', max([len(txt) for txt in target_texts]))
print('Median sequence length for inputs:', np.median([len(txt) for txt in input_texts]))
print('Median sequence length for outputs:', np.median([len(txt) for txt in target_texts]))

Number of samples: 304513
Number of unique input tokens: 93
Number of unique output tokens: 158
Max sequence length for inputs: 239
Max sequence length for outputs: 267
Median sequence length for inputs: 27.0
Median sequence length for outputs: 29.0


In [5]:
max_encoder_seq_length = 45
max_decoder_seq_length = 45

In [6]:
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])

encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length),dtype='float32')
decoder_input_data = np.zeros((len(input_texts), max_decoder_seq_length),dtype='float32')
decoder_target_data = np.zeros((len(input_texts), max_decoder_seq_length),dtype='float32')

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    # Zip stops after it hits the max seq len
    for t, char in zip(range(max_encoder_seq_length), input_text):
        encoder_input_data[i, t] = input_token_index[char]
    for t, char in zip(range(0, max_decoder_seq_length), target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t] = target_token_index[char]
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1] = target_token_index[char]

print(decoder_target_data.shape)
print(encoder_input_data.shape)
print(decoder_target_data.shape)

(304513, 45)
(304513, 45)
(304513, 45)


In [7]:
def create_emb(vecs, itos, em_sz):
    # Make regular embedding with vocab sz, em_sz, and pad
    emb = nn.Embedding(len(itos), em_sz, padding_idx=1)
    # Get embedding weights
    wgts = emb.weight.data
    
    miss = []
    for i,w in enumerate(itos):
        # Idk triple it or something
        try: wgts[i] = torch.from_numpy(vecs[w]*3)
        # If you missed some, append to w without tripling 
        except: miss.append(w)
    print(len(miss),miss[5:10])
    return emb

In [8]:
class AttentionRNN(nn.Module):
    def __init__(vecs_enc, vecs_dec, em_sz, n_h, n_l, inp_sz, out_sz):
        super().__init__()
        self.em_sz, self.n_h, self.n_l, self.inp_sz, self.out_sz = em_sz, n_h, n_l, inp_sz, out_sz
        # Encoder
        self.enc_em = create_emb(self.inp_sz, itos_enc, self.em_sz)
        self.em_drp = nn.Dropout(0.15)
        self.enc_gru = nn.GRU(self.em_sz, self.h_sz, num_layers=self.n_l, dropout=0.2)
        self.enc_drp = nn.Dropout(0.3)
        self.dec_out = nn.Linear(self.h_sz, self.em_sz, bias=False)
        # Decoder
        self.dec_em = create_emb(self.h_sz, itos_dec, self.em_sz)
        self.dec_gru = nn.GRU(self.em_sz, self.h_sz, num_layers=self.n_l, dropout=0.2)
        self.dec_drp = nn.Dropout(0.3)
        self.dec_out = nn.Linear(self.em_sz, self.out_sz)
        self.out.weight.data = self.enc_em.weight.data
        
    def forward(self, inp):
        sl, bs = inp.shape
        h = self.initHidden(bs)
        
        x = self.em_drp(self.em_sz(inp))
        enc_out, h = self.enc_drp(self.enc_em(x, h))
        h = self.out_enc(h)
        
        dec_inp = torch.zeros(bs).long()
        result = []
        # What is this?
        for i in range(self.out_sz):
            emb = self.dec_emb(dec_inp).unsqueeze(0)
            outp, h = self.dec_gru(emb, h)
            outp = self.out(self.dec_drp(outp[0]))
            result.append(outp)
            dec_inp = outp.data.max(1)[1]
            if (dec_inp==1).all(): break
        return torch.stack(res)
    
     def initHidden(self, bs): 
        # Num_layers, batch size, num hidden
        return torch.zeros(self.nl, bs, self.nh)

IndentationError: unindent does not match any outer indentation level (<tokenize>, line 38)

In [None]:
n_h = 128
n_l = 2
inp_sz = num_encoder_tokens max_encoder_seq_length
out_sz = num_decoder_tokens
model = AttentionRNN()

# TODO

In [9]:

###RUN THROUGH DL2 TRANSLATE NOTEBOOK AND ANSWER THESE QUESTIONS
#########################################
## WHAT IS 

#vecs_enc - 
    # Dict of words, with embedding vectors values
    # https://i.imgur.com/nIELpdY.png
    # https://i.imgur.com/RgBnu4O.png
#itos_enc - 
    # Index to string
    # List of strings, of which the list index is pointing to a word
    # https://i.imgur.com/oq6Kcv1.png
    # https://i.imgur.com/dXSh60V.png
#vecs_dec - Same as vecs_enc, but for dec
#itos_dec - Same as itos_enc but for dec
#########################################
##WHAT DOES create_emb DO
    # Makes an embedding with wiki vectors weights tripled 
## WHAT IS THE sl,bs IN inp.size()
    # bs is batch size
    # sl is seq_len https://i.imgur.com/icfxqv9.png

##WHAT DOES THE FOR LOOP IN FORWARD DO

##WHY DO YOU TAKE WEIGHT DATA OF OUTPUT EMBEDDINGS (IS THAT RELATED TO THE RETURN?)


##########################################
###FIGURE OUT THE LOSS FUNCTION

## WHY DO YOU PAD THE INPUT LIKE THAT

## WHY DO YOU SLICE THE INPUT
##########################################

In [19]:
# TRY TO TRAIN IT, WRITE OWN TRAIN LOOP

In [17]:
# CONVERT IT TO BIDIR

In [15]:
# TEACHER FORCING

In [16]:
# ATTENTION

In [18]:
# ALL

In [20]:
# CONVERT MODEL TO PRDOCTION