In [1]:
import torch
import pickle
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm


# Execute the two companion notebooks in this kernel. Every name used below that
# is not defined in this cell (src_vocab, tgt_vocab, vocab_inp_size,
# vocab_tar_size, embedding_dim, units, BATCH_SIZE, device, Encoder,
# BahdanauAttention, Decoder, train_loader, ...) is expected to come from them.
%run 'Dataset_and_preprocessing.ipynb'

%run 'model.ipynb'

print("Source Vocabulary Size:", len(src_vocab))
print("Target Vocabulary Size:", len(tgt_vocab))


# Initialize the encoder and move it to the training device.
# NOTE(review): no random seed is set in this cell, so parameter initialisation
# is not reproducible unless one of the notebooks run above seeds torch.
encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE).to(device)

# Initialize the attention layer
# NOTE(review): this standalone layer is not moved to `device`, is not handed to
# the optimizer below and is not referenced again in the visible cells;
# presumably the decoder owns its own attention module - confirm in model.ipynb.
attention_layer = BahdanauAttention(units)

# Initialize the decoder

decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE).to(device)

# Optimizer
# A single Adam instance updates the encoder and decoder parameters together.
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.001)

# Custom Loss Function
def loss_function(real, pred):
    """Cross-entropy for one decoding step with target id 0 masked out.

    Args:
        real: Tensor of target class indices; entries equal to 0 contribute
            no loss.
        pred: Tensor of unnormalised class scores matching ``real``.

    Returns:
        Scalar tensor: the masked per-element losses averaged over *all*
        positions (masked positions count as zeros in the mean).
    """
    per_token = F.cross_entropy(pred, real, reduction='none')
    # Zero the loss wherever the target id is 0.
    keep = real != 0
    return (per_token * keep).mean()

# Training Step Function
def train_step(inp, targ, enc_hidden):
    """Run one teacher-forced optimisation step on a single batch.

    Uses the notebook-level ``encoder``, ``decoder``, ``optimizer``,
    ``tgt_vocab``, ``device`` and ``loss_function``.

    Args:
        inp: Batch of source token ids; dimension 0 is the batch dimension.
        targ: Batch of target token ids; dimension 1 is iterated over as the
            decoding time axis.
        enc_hidden: Initial encoder hidden state; it is sliced along
            dimension 1 to the size of the current batch.

    Returns:
        float: Sum of the per-step losses divided by ``targ.size(1)``.
    """
    # Move data to the device
    inp = inp.to(device)
    targ = targ.to(device)
    enc_hidden = enc_hidden.to(device)

    optimizer.zero_grad()

    batch_size = inp.size(0)
    target_len = targ.size(1)

    # The encoder's final hidden state seeds the decoder.
    enc_output, dec_hidden = encoder(inp, enc_hidden[:, :batch_size, :])

    # Every sequence starts decoding from the <sos> token.
    dec_input = torch.full(
        (batch_size, 1),
        tgt_vocab['<sos>'],
        dtype=torch.long,
        device=inp.device,
    )

    loss = 0
    for step in range(1, target_len):
        predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
        gold = targ[:, step]
        loss += loss_function(gold, predictions.squeeze(1))
        # Teacher forcing: the ground-truth token is the next decoder input.
        dec_input = gold.unsqueeze(1)

    batch_loss = loss / int(target_len)
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()

# Training Loop
# Makes EPOCHS full passes over train_loader. Each batch gets a fresh encoder
# hidden state and is optimised by train_step(), which returns that batch's
# loss as a Python float.
EPOCHS = 40
for epoch in tqdm(range(EPOCHS)):
    # Running sum of the per-batch losses for this epoch.
    total_loss = 0

    for inp, targ in train_loader:
        current_batch_size = inp.size(0)
        # Initialize hidden state with the correct current batch size
        enc_hidden = encoder.initialize_hidden_state(current_batch_size).to(device)

        
        batch_loss = train_step(inp, targ, enc_hidden)
        total_loss += batch_loss

    # `epoch` is 0-based, so this reports epochs 1, 5, 9, ..., 37 with the mean
    # loss per batch. NOTE(review): the final epoch (40) is never printed.
    if epoch % 4 == 0:
        print(f'Epoch {epoch + 1}/{EPOCHS}, Loss: {total_loss / len(train_loader)}')

        
        
# Save the trained weights (state_dicts only; the optimizer state is not saved).
# Both files are written relative to the current working directory.
torch.save(encoder.state_dict(), 'encoder.pth')
torch.save(decoder.state_dict(), 'decoder.pth')

Batch shapes: torch.Size([64, 20]) torch.Size([64, 15])
2107
2131
Device: cpu
2107
2131
Source Vocabulary Size: 2107
Target Vocabulary Size: 2131


  2%|█                                           | 1/40 [01:17<50:41, 78.00s/it]

Epoch 1/40, Loss: 2.2938504954601857


 12%|█████▌                                      | 5/40 [05:51<40:25, 69.31s/it]

Epoch 5/40, Loss: 1.1056782730082249


 22%|█████████▉                                  | 9/40 [10:34<36:39, 70.96s/it]

Epoch 9/40, Loss: 0.16371341208194165


 32%|█████████████▉                             | 13/40 [15:31<32:49, 72.94s/it]

Epoch 13/40, Loss: 0.02625366375642888


 42%|██████████████████▎                        | 17/40 [20:23<28:02, 73.13s/it]

Epoch 17/40, Loss: 0.01864302252478739


 52%|██████████████████████▌                    | 21/40 [25:56<25:36, 80.89s/it]

Epoch 21/40, Loss: 0.015422908995459055


 62%|██████████████████████████▉                | 25/40 [31:37<20:54, 83.63s/it]

Epoch 25/40, Loss: 0.01460023790082716


 72%|███████████████████████████████▏           | 29/40 [37:21<15:57, 87.08s/it]

Epoch 29/40, Loss: 0.020350097658786367


 82%|███████████████████████████████████▍       | 33/40 [43:17<10:09, 87.03s/it]

Epoch 33/40, Loss: 0.01705871930623308


 92%|███████████████████████████████████████▊   | 37/40 [48:42<04:06, 82.17s/it]

Epoch 37/40, Loss: 0.016087393593439397


100%|███████████████████████████████████████████| 40/40 [52:44<00:00, 79.11s/it]


In [None]:
# Upper bound on the number of decoding steps: the longest training answer,
# measured in whitespace-separated tokens.
max_length_targ = max(len(t.split()) for t in train_data['answer'])

def evaluate(sentence):
    """Greedy-decode a reply for ``sentence`` with the trained encoder/decoder.

    Uses the notebook-level ``clean_text``, ``src_vocab``, ``tgt_vocab``,
    ``encoder``, ``decoder``, ``units``, ``device`` and ``max_length_targ``.

    Args:
        sentence: Raw question text; it is normalised with ``clean_text``
            before tokenisation.

    Returns:
        tuple: ``(result, sentence)`` where ``result`` is the decoded words,
        each followed by a single space, and ``sentence`` is the cleaned
        input. ``result`` is ``''`` when the cleaned input has no tokens.
    """
    sentence = clean_text(sentence)

    # split() (rather than split(' ')) drops the empty-string tokens that
    # repeated, leading or trailing spaces would otherwise produce.
    tokens = sentence.split()
    if not tokens:
        # Nothing to encode; avoid feeding a zero-length sequence to the model.
        return '', sentence

    # NOTE(review): a token missing from src_vocab is looked up as-is; whether
    # that raises or maps to an unknown-token id depends on how the vocabulary
    # was built in the preprocessing notebook.
    inputs = torch.tensor([[src_vocab[token] for token in tokens]]).to(device)

    # Build the index -> word table once instead of on every decoding step.
    itos = tgt_vocab.get_itos()
    words = []

    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        # Initialize the hidden state with zeros
        hidden = torch.zeros((1, 1, units)).to(device)  # Modify the shape according to your GRU layer
        enc_out, enc_hidden = encoder(inputs, hidden)

        dec_hidden = enc_hidden
        dec_input = torch.tensor([[tgt_vocab['<sos>']]], dtype=torch.long).to(device)

        for _step in range(max_length_targ):
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_out)

            predicted_id = torch.argmax(predictions[0]).item()
            word = itos[predicted_id]

            if word == '<eos>':
                break

            words.append(word)

            # The predicted ID is fed back into the model
            dec_input = torch.tensor([[predicted_id]], dtype=torch.long).to(device)

    # Keep the original output format: every word is followed by one space.
    result = ''.join(word + ' ' for word in words)
    return result, sentence


# Example usage
def ask(sentence):
    """Decode a reply for ``sentence`` and print it next to the cleaned question.

    Args:
        sentence: Raw question text, passed straight to ``evaluate``.

    Returns:
        None. The question and the predicted answer are written to stdout.
    """
    predicted, cleaned = evaluate(sentence)

    question_line = 'Question: %s' % (cleaned)
    print(question_line)
    answer_line = 'Predicted answer: {}'.format(predicted)
    print(answer_line)

# Load questions and answers from a file.
# Each usable line has the form "<question>\t<answer>"; any further
# tab-separated fields are ignored.
questions = []
answers = []
with open("./dialogs.txt", 'r') as f:
    for line in f:
        # Drop the line terminator so answers do not carry a trailing newline.
        fields = line.rstrip('\n').split('\t')
        if len(fields) < 2:
            # Blank or malformed line (no tab): skip it instead of raising IndexError.
            continue
        questions.append(fields[0])
        answers.append(fields[1])

print(len(questions) == len(answers))



In [None]:
# Example usage with a specific question.
# ask() prints the question and prediction itself and returns None, so it is
# called directly; wrapping it in print() would emit a stray "None".
ask(questions[15])
# Reference answer from the dialog file, for comparison.
print(answers[15])

In [None]:
# Function to interactively ask questions and get answers
def interact_with_model():
    """Prompt for questions in a loop and print the model's reply to each.

    The loop ends when the user types ``exit`` (case-insensitive, surrounding
    whitespace ignored). Blank input is skipped without querying the model.

    Returns:
        None. Replies are written to stdout.
    """
    while True:
        # Get user input
        user_input = input("Type your question (or 'exit' to quit): ")
        stripped = user_input.strip()

        # Check if the user wants to exit
        if stripped.lower() == 'exit':
            break

        # Nothing to answer; ask again.
        if not stripped:
            continue

        # evaluate() returns (decoded_text, cleaned_question); only the decoded
        # text is the answer.
        answer, _cleaned = evaluate(user_input)

        # Display the model's answer
        print("Model's answer:", answer)
        print("\n")

# Start the interactive loop
interact_with_model()
