In [1]:
from utils import get_data, Dataset, index2Sent, collate_fn_train, collate_fn_test

## Load Data

In [2]:
# Path of the news-summary CSV that is handed to get_data when loading the corpus.
DATA_PATH = "data/news_summary.csv"

In [3]:
# Passed to get_data; the model's VOCAB_SIZE is later derived as NUM_WORDS + 4.
NUM_WORDS = 10000
# Passed to Dataset as limits for the article text and the summary respectively
# (presumably maximum token counts — TODO confirm in utils.Dataset).
MAX_TEXT_LEN = 500
MAX_SUM_LEN = 100

In [4]:
# Load the corpus plus two lookup tables: w2i is indexed by token string
# (e.g. w2i["<pad>"]) and i2w by integer token id.
data, w2i, i2w = get_data(DATA_PATH, NUM_WORDS)

Length of the data: 4514
Length of the data after dropping nan: 4396


In [5]:
# Training-mode dataset (isTrain=True) built from the loaded data and the w2i table.
dataset = Dataset(data, w2i, MAX_TEXT_LEN, MAX_SUM_LEN, isTrain=True)

## Train

In [6]:
from Seq2Seq import Seq2Seq
import torch
import torch.nn as nn

In [7]:
# Model and training hyperparameters.
# NUM_WORDS + 4: presumably four special tokens (e.g. <pad>, <sos>) are added on
# top of the NUM_WORDS vocabulary — TODO confirm against utils.get_data.
VOCAB_SIZE = NUM_WORDS + 4
EMBEDDING_DIM = 50
HIDDEN_DIM = 128
BATCH_SIZE = 12
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
# Shuffled mini-batch iterator over the training dataset; collate_fn_train
# assembles each batch from the individual samples.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    collate_fn=collate_fn_train,
)

In [9]:
# Build the network, then move it to the target device.
model = Seq2Seq(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM)
model = model.to(DEVICE)

# reduction='none' keeps one loss value per target position so that the
# training loop can zero out the padded positions before averaging.
criterion = nn.CrossEntropyLoss(reduction='none')

trainable_params = model.parameters()
optimizer = torch.optim.Adam(trainable_params, lr=1e-3)

In [10]:
# Training loop: 10 passes over `dataloader`. For every batch the model is run,
# a per-position cross-entropy is computed against the target shifted left by
# one token, padded positions are masked out, and the mean loss over the real
# target tokens is back-propagated with gradient-norm clipping at 2.
#
# The epoch counter and the row index of the masking loop use distinct names;
# sharing a single `i` made the "EP:" line report the last batch row index
# instead of the epoch number.
#
# NOTE(review): in the recorded run the loss stays near ln(VOCAB_SIZE) ~= 9.21.
# nn.CrossEntropyLoss expects raw logits — confirm that Seq2Seq does not apply
# a softmax to its output itself.
for epoch in range(10):
    total_loss = 0.
    for batch_num, ((x, xlens), (y, ylens)) in enumerate(dataloader):

        # setup tensors
        x = x.long().to(DEVICE)
        y = y.long().to(DEVICE)

        # clear previous gradients
        optimizer.zero_grad()

        # generate predictions
        # output: presumably (BATCH_SIZE, time_steps, VOCAB_SIZE) — TODO confirm in Seq2Seq
        output = model(x, xlens, y)

        ### Calculate Loss
        # 1. y must be shifted by 1 for the loss calculation, since the outputs
        #    should not contain <sos>; a <pad> column keeps the length unchanged.
        pad_column = torch.ones((y.size(0), 1)).long().to(DEVICE) * w2i["<pad>"]
        y_true = torch.cat([y[:, 1:], pad_column], dim=-1)

        # 2. Output shape for the loss calculation must be (BATCH_SIZE, classes, *),
        #    hence the permute. Refer to the PyTorch docs for more details.
        loss = criterion(output.permute(0, 2, 1), y_true)

        # 3. Mask the loss so that padded positions do not contribute.
        #    Can avoid if using pack_padded sequence?
        num_tokens = 0
        for row, yl in enumerate(ylens):
            loss[row, yl-1:] *= 0  # yl-1 to remove <sos>
            num_tokens += yl - 1

        # 4. SUM the losses, divide by the number of tokens and finally call backward
        loss = loss.sum() / num_tokens
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        # Adjust parameters
        optimizer.step()

        batch_loss = loss.item()
        if (batch_num + 1) % 10 == 0:
            print("Step: {} Loss: {}".format(batch_num + 1, batch_loss))
        total_loss += batch_loss
    print("EP: {} Loss: {}".format(epoch + 1, total_loss / len(dataloader)))

Step: 10 Loss: 9.210471153259277
Step: 20 Loss: 9.089808464050293
Step: 30 Loss: 9.113744735717773
Step: 40 Loss: 9.088224411010742
Step: 50 Loss: 9.086456298828125
Step: 60 Loss: 9.070070266723633
Step: 70 Loss: 9.06949520111084
Step: 80 Loss: 9.092009544372559
Step: 90 Loss: 9.075565338134766
Step: 100 Loss: 9.085577011108398
Step: 110 Loss: 9.0734224319458
Step: 120 Loss: 9.111626625061035
Step: 130 Loss: 9.076574325561523
Step: 140 Loss: 9.066244125366211
Step: 150 Loss: 9.04426383972168
Step: 160 Loss: 9.093443870544434
Step: 170 Loss: 9.059752464294434
Step: 180 Loss: 9.063953399658203
Step: 190 Loss: 9.077128410339355
Step: 200 Loss: 9.085248947143555
Step: 210 Loss: 9.087504386901855
Step: 220 Loss: 9.125009536743164
Step: 230 Loss: 9.101083755493164
Step: 240 Loss: 9.06827163696289
Step: 250 Loss: 9.071728706359863
Step: 260 Loss: 9.079447746276855
Step: 270 Loss: 9.087320327758789
Step: 280 Loss: 9.104023933410645


KeyboardInterrupt: 

In [None]:
# Entry 0 of xlens from the last batch the training cell processed
# (relies on variables left over by that loop).
xlens[0]

In [None]:
# Index of the sample within the batch that the following cells inspect.
i = 0

In [None]:
# Token ids of input sample i, cut to its xlens entry.
print(x[i][:xlens[i]])

In [None]:
# Token ids of target sample i, cut to its ylens entry.
print(y[i][:ylens[i]])

In [None]:
# The same input sample rendered through index2Sent with the i2w table.
index2Sent(x[i, :xlens[i]].cpu().numpy(), i2w)

In [None]:
# The same target sample rendered through index2Sent with the i2w table.
index2Sent(y[i, :ylens[i]].cpu().numpy(), i2w)

In [None]:
# Model prediction for sample i: argmax over the last axis of output,
# rendered through index2Sent with the i2w table.
index2Sent(output.argmax(dim=-1)[i].cpu().numpy(), i2w)

In [None]:
# Argmax ids (over the last axis) for the final sample in the batch.
output.argmax(dim=-1)[-1]

In [None]:
# Look up the token stored under id 31 — presumably an id seen in the
# prediction above; TODO confirm.
i2w[31]