In [1]:
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F

from encoder import *
from AttnDecoder import * 
from seq2seq import *

from build_dataset import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load data
# Load the processed source/target text splits and wrap them in DataLoaders.
# NOTE(review): QAPair, DataLoader, partial and pad_collate_fn are not defined
# in this cell; presumably they come from the wildcard imports above — confirm.
BATCH_SIZE = 1024

train_file_path = {
    'source': "data/processed/src-train.txt",
    'target': "data/processed/tgt-train.txt",
}

test_file_path = {
    'source': "data/processed/src-test.txt",
    'target': "data/processed/tgt-test.txt",
}

train_dataset = QAPair(train_file_path)
test_dataset = QAPair(test_file_path)


def make_dataloader(dataset, batch_size=BATCH_SIZE, shuffle=False):
    """Build a DataLoader that pads each batch with the dataset's own pad index.

    Args:
        dataset: dataset exposing a ``pad_idx`` attribute (e.g. ``QAPair``).
        batch_size: number of examples per batch.
        shuffle: whether to reshuffle the examples every epoch.

    Returns:
        A ``DataLoader`` whose collate function is ``pad_collate_fn`` bound to
        ``dataset.pad_idx``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=partial(pad_collate_fn, pad_token=dataset.pad_idx),
    )


# Shuffle only the training split; a fixed evaluation order keeps test-time
# results reproducible and comparable between runs.
train_dataloader = make_dataloader(train_dataset, shuffle=True)
test_dataloader = make_dataloader(test_dataset, shuffle=False)

In [3]:
# Build the seq2seq model from pretrained encoder ('enc') and decoder ('dec')
# embeddings loaded from disk.
pretrained_vectors = {
    'enc': torch.load('embeddings/encoder_emb.pt'),
    'dec': torch.load('embeddings/decoder_emb.pt'),
}

# Vocabulary sizes are taken from the training split: input_size from the
# answer vocabulary, output_size from the question vocabulary.
input_size, output_size = (
    len(train_dataset.answer_vocab),
    len(train_dataset.question_vocab),
)

# NOTE(review): output_size is passed *before* input_size here — confirm this
# matches the positional parameter order of Seq2Seq.__init__.
seq2seq = Seq2Seq(pretrained_vectors, output_size, input_size)

In [None]:
def train_step(batch, model, optimizer, criterion, device):
    """Run one optimization step of the seq2seq model on a single batch.

    Args:
        batch: object exposing ``input_vecs``, ``input_lens`` and
            ``target_vecs`` tensors (presumably produced by ``pad_collate_fn``
            — confirm).
        model: module called as ``model(inputs, lengths)`` that returns an
            ``(output, hidden)`` pair; the last dimension of ``output`` holds
            the per-token scores.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(N, num_classes)`` scores and ``(N,)`` targets.
        device: device the batch tensors are moved to; the model is assumed to
            already live there — this function does not move it.

    Returns:
        The batch loss as a Python float.
    """
    inputs = batch.input_vecs.to(device)
    input_lens = batch.input_lens.to(device)
    target = batch.target_vecs.to(device)

    # Switch the whole model (not just enc/dec) to training mode so the model's
    # own ``training`` flag and every submodule are reset after a prior eval().
    model.train()

    optimizer.zero_grad()
    s2s_output, _ = model(inputs, input_lens)

    # Flatten to (N, num_classes) scores and (N,) targets. reshape, unlike
    # view, also works on non-contiguous tensors (e.g. transposed outputs).
    scores = s2s_output.reshape(-1, s2s_output.size(-1))
    loss = criterion(scores, target.reshape(-1))

    loss.backward()
    optimizer.step()

    return loss.item()
    