## Deep Learning for text with PyTorch

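Source: DataCamp course "Deep Learning for Text with PyTorch", chapter "Advanced topics in deep learning for text with PyTorch", exercise 7: https://campus.datacamp.com/courses/deep-learning-for-text-with-pytorch/advanced-topics-in-deep-learning-for-text-with-pytorch?ex=7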

In [9]:
import torch
from torch import nn

In [183]:
# Toy corpus: the last word of every sentence is the prediction target.
train_data = [
    "the cat sat on the mat",
    "the cat wears a hat",
    "a cat eats cat food",
    "the cat paints a picture",
]

# Vocabulary and the two lookup tables (word -> index, index -> word).
# NOTE: `vocab` is a set of strings, so the index assigned to each word follows
# the set's iteration order, which is not guaranteed to be the same in every
# interpreter session.
vocab = set(' '.join(train_data).split())
word_to_ix = {word: i for i, word in enumerate(vocab)}
ix_to_word = {i: word for word, i in word_to_ix.items()}

# Each sentence as a list of word tokens.
pairs = [sentence.split() for sentence in train_data]

# Every (prefix -> next word) example of every sentence, as index tensors:
# for "the cat sat" this yields [the] -> cat and [the, cat] -> sat.
# Iterating over range(1, len(pair)) means a one-word sentence simply
# contributes no examples.
input_data_index = []
target_data_index = []
for pair in pairs:
    for split_at in range(1, len(pair)):
        prefix_ids = [word_to_ix[token] for token in pair[:split_at]]
        next_id = word_to_ix[pair[split_at]]
        input_data_index.append(torch.tensor(prefix_ids, dtype=torch.long))
        target_data_index.append(torch.tensor([next_id], dtype=torch.long))

# Examples used for training: all words but the last -> the last word.
input_data = [[word_to_ix[word] for word in sentence[:-1]] for sentence in pairs]
target_data = [word_to_ix[sentence[-1]] for sentence in pairs]
inputs = [torch.tensor(seq, dtype=torch.long) for seq in input_data]
targets = torch.tensor(target_data, dtype=torch.long)
vocab_size = len(vocab)

print(inputs)
print(targets)

[tensor([ 3,  1, 10,  4,  3]), tensor([3, 1, 0, 6]), tensor([6, 1, 8, 1]), tensor([3, 1, 9, 6])]
tensor([ 5,  7, 11,  2])


In [184]:
# Default layer sizes, used when the model is constructed without arguments.
embedding_dimension = 10
hidden_dimension = 16

class RNNWithAttentionModel(nn.Module):
    """Single-layer RNN whose hidden states are pooled by learned attention.

    Token indices are embedded and run through an ``nn.RNN``. A linear layer
    scores every time step, the scores are softmax-normalised over the time
    axis, and the weighted sum of hidden states is projected to one logit per
    vocabulary entry.

    Args:
        num_tokens: Size of the embedding table and of the output layer.
            Defaults to the module-level ``vocab_size``.
        embedding_dim: Embedding width. Defaults to the module-level
            ``embedding_dimension``.
        hidden_dim: RNN hidden width. Defaults to the module-level
            ``hidden_dimension``.
    """

    def __init__(self, num_tokens=None, embedding_dim=None, hidden_dim=None):
        super().__init__()
        # Fall back to the notebook-level settings so that the zero-argument
        # call `RNNWithAttentionModel()` keeps working.
        if num_tokens is None:
            num_tokens = vocab_size
        if embedding_dim is None:
            embedding_dim = embedding_dimension
        if hidden_dim is None:
            hidden_dim = hidden_dimension
        self.embeddings = nn.Embedding(num_tokens, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.attention = nn.Linear(hidden_dim, 1)
        self.fc = nn.Linear(hidden_dim, num_tokens)

    def forward(self, x, lengths=None):
        """Return vocabulary logits for a batch of index sequences.

        Args:
            x: Integer tensor of token indices, shape (batch, time).
            lengths: Optional per-row count of real (non-padding) tokens, for
                right-padded batches. When given, time steps at or beyond a
                row's length get zero attention weight. Every length must be
                at least 1, otherwise that row's softmax has no finite score
                and yields NaN. When omitted, all time steps are attended to.

        Returns:
            Tensor of shape (batch, num_tokens) with unnormalised scores.
        """
        embedded = self.embeddings(x)
        out, _ = self.rnn(embedded)
        # One score per time step: (batch, time, 1) -> (batch, time).
        scores = self.attention(out).squeeze(2)
        if lengths is not None:
            lengths = torch.as_tensor(lengths, device=scores.device)
            positions = torch.arange(scores.size(1), device=scores.device)
            is_padding = positions.unsqueeze(0) >= lengths.unsqueeze(1)
            # -inf becomes exactly 0 after the softmax.
            scores = scores.masked_fill(is_padding, float("-inf"))
        attention_weights = torch.nn.functional.softmax(scores, dim=1)
        # Attention-weighted sum of the hidden states over the time axis.
        context = torch.sum(attention_weights.unsqueeze(2) * out, dim=1)
        return self.fc(context)

def pad_sequences(batch, padding_value=0):
    """Right-pad 1-D tensors to a common length and stack them.

    Args:
        batch: Non-empty collection of 1-D tensors.
        padding_value: Value written into the padded positions (default 0).

    Returns:
        Tensor of shape (len(batch), longest sequence length). The padding is
        created with the dtype and device of the sequence it extends.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    batch = list(batch)
    if not batch:
        raise ValueError("pad_sequences() requires at least one sequence")
    max_length = max(len(seq) for seq in batch)
    return torch.stack([
        torch.cat([seq, seq.new_full((max_length - len(seq),), padding_value)])
        for seq in batch
    ])

In [185]:
# Seed the RNG before the model is built so that weight initialisation, and
# therefore the printed losses, are repeatable between runs.
SEED = 0
torch.manual_seed(SEED)

criterion = nn.CrossEntropyLoss()
attention_model = RNNWithAttentionModel()
optimizer = torch.optim.Adam(attention_model.parameters(), lr=0.01)

epochs = 2000

# The batch is the same every epoch, so pad it once instead of once per epoch.
# NOTE: pad_sequences fills shorter sequences with index 0 and the model is
# called without any length information, so padded steps take part in the
# attention during training.
padding_inputs = pad_sequences(inputs)

attention_model.train()
for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = attention_model(padding_inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print(f"epoch: {epoch}, loss: {loss.item()}")

epoch: 0, loss: 2.360353946685791
epoch: 100, loss: 0.007466976065188646
epoch: 200, loss: 0.0025251884944736958
epoch: 300, loss: 0.0013571387389674783
epoch: 400, loss: 0.000878841383382678
epoch: 500, loss: 0.0006261293892748654
epoch: 600, loss: 0.0004728807834908366
epoch: 700, loss: 0.0003714765189215541
epoch: 800, loss: 0.00030036226962693036
epoch: 900, loss: 0.0002483417047187686
epoch: 1000, loss: 0.00020901163225062191
epoch: 1100, loss: 0.00017844037211034447
epoch: 1200, loss: 0.00015418532711919397
epoch: 1300, loss: 0.0001344590273220092
epoch: 1400, loss: 0.00011830820585601032
epoch: 1500, loss: 0.00010489866690477356
epoch: 1600, loss: 9.351530752610415e-05
epoch: 1700, loss: 8.377082122024149e-05
epoch: 1800, loss: 7.551623275503516e-05
epoch: 1900, loss: 6.827478500781581e-05


In [182]:
print(word_to_ix)
print(input_data)
print(target_data)

# Inference only: switch to eval mode once and skip autograd bookkeeping.
# Each sentence is fed on its own (batch of one, no padding), unlike the
# padded batch used during training.
attention_model.eval()
with torch.no_grad():
    for input_seq, target in zip(input_data, targets):
        input_test = torch.tensor(input_seq, dtype=torch.long).unsqueeze(0)
        attention_output = attention_model(input_test)

        # Highest-scoring vocabulary entry is the predicted next word.
        attention_prediction = ix_to_word[torch.argmax(attention_output).item()]

        print(f"\nInput: {' '.join([ix_to_word[int(ix)] for ix in input_seq])}")
        print(f"Target: {ix_to_word[int(target)]}")
        print(f"RNN with Attention prediction: {attention_prediction}")
        print(attention_output)

{'wears': 0, 'cat': 1, 'picture': 2, 'the': 3, 'on': 4, 'mat': 5, 'a': 6, 'hat': 7, 'eats': 8, 'paints': 9, 'sat': 10, 'food': 11}
[[3, 1, 10, 4, 3], [3, 1, 0, 6], [6, 1, 8, 1], [3, 1, 9, 6]]
[5, 7, 11, 2]

Input: the cat sat on the
Target: mat
RNN with Attention prediction: mat
tensor([[-2.0674, -2.0125, -1.9536, -2.2417, -2.8000, 10.3503, -2.0787, -3.5007,
         -2.3193, -1.8649, -1.6632, -0.3335]], grad_fn=<AddmmBackward0>)

Input: the cat wears a
Target: hat
RNN with Attention prediction: hat
tensor([[-2.1528, -2.0714,  0.8729, -2.5259, -2.7319, -0.1561, -3.5221,  7.7299,
         -2.7464, -2.7942, -2.2278,  1.1865]], grad_fn=<AddmmBackward0>)

Input: a cat eats cat
Target: food
RNN with Attention prediction: food
tensor([[-2.5573, -3.0988, -3.4908, -2.7104, -2.6637, -0.0671, -2.9777,  0.7935,
         -3.4106, -2.0855, -2.8723, 11.0869]], grad_fn=<AddmmBackward0>)

Input: the cat paints a
Target: picture
RNN with Attention prediction: picture
tensor([[-1.1960, -1.3891, 10.8273,

In [192]:
# The four training sentences followed by four shorter fragments; as with the
# training data, the last word of each entry is the expected prediction.
test_data = ["the cat sat on the mat","the cat wears a hat","a cat eats cat food","the cat paints a picture",
            "the cat", "the cat wears", "cat eats", "the cat paints"]
test_pairs = [sentence.split() for sentence in test_data]
test_input_data = [[word_to_ix[word] for word in sentence[:-1]] for sentence in test_pairs]
test_target_data = [word_to_ix[sentence[-1]] for sentence in test_pairs]
test_inputs = [torch.tensor(seq, dtype=torch.long) for seq in test_input_data]
test_targets = torch.tensor(test_target_data, dtype=torch.long)

# Inference only: switch to eval mode once and skip autograd bookkeeping.
attention_model.eval()
with torch.no_grad():
    for test_input_seq, test_target in zip(test_input_data, test_targets):
        print(test_input_seq)
        # Batch of one, unpadded.
        input_test = torch.tensor(test_input_seq, dtype=torch.long).unsqueeze(0)
        attention_output = attention_model(input_test)
        attention_prediction = ix_to_word[torch.argmax(attention_output).item()]
        print(f"\nInput: {' '.join([ix_to_word[int(ix)] for ix in test_input_seq])}")
        print(f"Target: {ix_to_word[int(test_target)]}")
        print(f"RNN with Attention prediction: {attention_prediction}")
        print(attention_output)

[3, 1, 10, 4, 3]

Input: the cat sat on the
Target: mat
RNN with Attention prediction: mat
tensor([[-1.4444, -1.9390,  0.1107, -1.5883, -1.9620, 10.8326, -2.4085, -0.3996,
         -1.5323, -1.9698, -1.6586, -2.7604]], grad_fn=<AddmmBackward0>)
[3, 1, 0, 6]

Input: the cat wears a
Target: hat
RNN with Attention prediction: hat
tensor([[-2.4892, -2.6596, -0.3764, -2.4425, -2.5558,  0.5581, -2.1837, 10.5335,
         -1.7379, -2.1922, -2.4021, -0.6459]], grad_fn=<AddmmBackward0>)
[6, 1, 8, 1]

Input: a cat eats cat
Target: food
RNN with Attention prediction: food
tensor([[-1.7456, -2.2132, -0.0474, -1.8779, -2.0949, -4.2556, -2.5420, -0.3577,
         -1.9348, -2.2835, -1.8511, 10.6666]], grad_fn=<AddmmBackward0>)
[3, 1, 9, 6]

Input: the cat paints a
Target: picture
RNN with Attention prediction: picture
tensor([[-3.2474e+00, -3.5137e+00,  1.0952e+01, -2.5388e+00, -2.9565e+00,
         -1.7396e-01, -3.3336e+00, -2.9747e-01, -2.4403e+00, -3.9164e+00,
         -2.5756e+00, -1.0504e-02]], 