In [1]:
from collections import Counter, defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence as pad
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack
import random as rd
from pathlib import Path

In [2]:
# Training sentences with more tokens than this are dropped.
MAX_TRAIN_SENT_LEN = 80

# Every line of the corpus files is read as one whitespace-tokenised sentence.
with open('train.txt') as f:
    train_sents = [sent.strip().split() for sent in f]
    train_sents = [sent for sent in train_sents if len(sent) <= MAX_TRAIN_SENT_LEN]
with open('valid.txt') as f:
    valid_sents = [sent.strip().split() for sent in f]

# (word, frequency) pairs over the training sentences, most frequent first.
# Counter.most_common() orders equal counts by first occurrence, which is the
# same order a stable sort on descending frequency gives.
word_count = Counter(word for sent in train_sents for word in sent).most_common()

# The three special tokens take ids 0, 1 and 2. A literal '<unk>' found in the
# corpus is removed from the frequency list so it is not listed twice.
word_list = [word for word, freq in word_count if word != '<unk>']
word_list = ['<pad>', '<eos>', '<unk>'] + word_list

In [3]:
class Vocab:
    """Two-way mapping between tokens and integer ids.

    ``tokens`` is a sequence in which a token's position is its id. It must
    contain ``'<pad>'``, ``'<eos>'`` and ``'<unk>'``; their ids are exposed as
    ``self.pad``, ``self.eos`` and ``self.unk``.

    * ``vocab(token)``  -> id of ``token``, or ``self.unk`` if it is unknown.
    * ``vocab[index]``  -> token stored at ``index``.
    * ``token in vocab`` -> whether ``token`` has an id.
    * ``len(vocab)``    -> number of tokens.
    """

    def __init__(self, tokens):
        self.tokens = tokens
        self.token_dict = {token: index for index, token in enumerate(tokens)}
        self.pad = self.token_dict['<pad>']
        self.eos = self.token_dict['<eos>']
        self.unk = self.token_dict['<unk>']

    def __contains__(self, x):
        """Return True if ``x`` is a known token (dictionary lookup)."""
        return x in self.token_dict

    # The method used to be spelled ``__contain__``, which Python never calls:
    # ``x in vocab`` then fell back to scanning ``tokens`` through
    # ``__getitem__``. The old name is kept as an alias for any direct caller.
    __contain__ = __contains__

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, x):
        """Return the token stored at position ``x``."""
        return self.tokens[x]

    def __call__(self, x):
        """Return the id of token ``x``, or the ``<unk>`` id if it is unknown."""
        if x in self:
            return self.token_dict[x]
        return self.unk

    
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of sentences.

    ``sents`` is stored as given and items are returned unchanged. ``lengths``
    holds ``len(sent)`` for every sentence as a tensor.
    """

    def __init__(self, sents):
        self.sents = sents
        # One entry per sentence, in the same order as ``sents``.
        self.lengths = torch.tensor(list(map(len, sents)))

    def __getitem__(self, index):
        return self.sents[index]

    def __len__(self):
        return len(self.sents)
    

class Batch:
    """Container for one mini-batch.

    Attributes:
        inputs:  tensor whose second dimension is the batch size
                 (``len(batch)`` returns ``inputs.shape[1]``).
        outputs: optional target tensor, ``None`` when there are no targets.
        lengths: optional iterable of per-sequence lengths.
    """

    def __init__(self, inputs, outputs = None, lengths = None):
        self.inputs, self.outputs, self.lengths = inputs, outputs, lengths

    def __len__(self):
        # Number of sequences in the batch.
        return self.inputs.shape[1]

    def get_num_tokens(self):
        """Return the sum of ``lengths``."""
        total = sum(self.lengths)
        return total

    def cuda(self):
        """Move ``inputs`` (and ``outputs`` if present) to the GPU in place.

        ``lengths`` is left where it is. Returns ``self`` so calls can be chained.
        """
        self.inputs = self.inputs.cuda()
        if self.outputs is None:
            return self
        self.outputs = self.outputs.cuda()
        return self

    
class Collator:
    """Collate function that turns a list of id sequences into a ``Batch``.

    For a sentence ``s`` the model input is ``[eos] + s`` and the target is
    ``s + [eos]``. Inputs are padded with ``vocab.pad``, targets with ``-100``.
    """

    def __init__(self, vocab):
        self.vocab = vocab

    def make_tensors(self, batch):
        """Return ``(inputs, outputs, lengths)`` for a list of id lists.

        ``inputs`` and ``outputs`` come from ``pad_sequence`` with its default
        layout (time first). ``lengths`` is a plain list of ``len(sent) + 1``.
        """
        eos = self.vocab.eos
        inputs, outputs, lengths = [], [], []
        for sent in batch:
            inputs.append(torch.tensor([eos] + sent))
            outputs.append(torch.tensor(sent + [eos]))
            lengths.append(len(sent) + 1)

        inputs = pad(inputs, padding_value = self.vocab.pad)
        # -100 is the padding value for targets (it matches the ignore_index
        # used by the loss criteria in this notebook).
        outputs = pad(outputs, padding_value = -100)
        return inputs, outputs, lengths

    def __call__(self, batch):
        return Batch(*self.make_tensors(batch))
    
    
class Sampler(torch.utils.data.Sampler):
    """Batch sampler that packs sentences under a padded-size budget.

    Subclasses provide the sentence order through ``get_indices`` and decide
    in ``terminate_batches`` what happens to the cached batches once an
    iteration has finished.

    Args:
        dataset: object exposing ``lengths``, indexable by sentence index.
        max_tokens: upper bound on ``batch_size * longest_length`` per batch.
            Lengths are taken from ``dataset.lengths`` as they are.
    """

    def __init__(
            self,
            dataset,
            max_tokens):

        self.dataset = dataset
        self.max_tokens = max_tokens
        self.batches = None

    def get_indices(self):
        """Return the sentence indices in the order they should be packed."""
        raise NotImplementedError

    def terminate_batches(self):
        """Hook called after a full pass over the batches."""
        raise NotImplementedError

    def generate_batches(self):
        """Greedily group indices into batches and return them shuffled.

        Indices are consumed in ``get_indices()`` order. A batch is closed as
        soon as adding the next sentence would make
        ``batch_size * longest_length`` exceed ``max_tokens``. A sentence that
        exceeds the budget on its own still gets a batch to itself; no empty
        batch is ever produced.
        """
        batches = []
        batch = []
        max_len = 0
        for index in self.get_indices():
            # Plain ints keep the arithmetic and the yielded batches free of
            # 0-dim tensors.
            index = int(index)
            this_len = int(self.dataset.lengths[index])
            grown_len = max(max_len, this_len)
            over_budget = (len(batch) + 1) * grown_len > self.max_tokens
            if batch and over_budget:
                batches.append(batch)
                batch = [index]
                max_len = this_len
            else:
                # Also taken when ``batch`` is empty, so an oversized first
                # sentence does not flush an empty batch.
                batch.append(index)
                max_len = grown_len
        if batch:
            batches.append(batch)
        rd.shuffle(batches)
        return batches

    def init_batches(self):
        """Build the batches unless a cached list already exists."""
        if self.batches is None:
            self.batches = self.generate_batches()

    def __len__(self):
        self.init_batches()
        return len(self.batches)

    def __iter__(self):
        self.init_batches()
        for batch in self.batches:
            yield batch
        self.terminate_batches()


class FixedSampler(Sampler):
    """Sampler whose sentence order is computed once and then reused."""

    def get_indices(self):
        """Return dataset indices sorted by descending length (cached)."""
        if hasattr(self, 'indices'):
            return self.indices
        positions = torch.arange(len(self.dataset))
        by_length = self.dataset.lengths[positions].argsort(descending = True)
        self.indices = positions[by_length]
        return self.indices

    def terminate_batches(self):
        # Nothing to reset: ``self.batches`` is left in place after a pass.
        return None


class RandomSampler(Sampler):
    """Sampler that draws a new sentence order every time it is asked."""

    def get_indices(self):
        """Return a random permutation re-ordered by descending length.

        NOTE(review): permuting before the sort presumably varies the order
        among sentences of equal length between calls -- confirm.
        """
        shuffled = torch.randperm(len(self.dataset))
        by_length = self.dataset.lengths[shuffled].argsort(descending = True)
        return shuffled[by_length]

    def terminate_batches(self):
        # Drop the cached batches so ``init_batches`` rebuilds them next time.
        self.batches = None

In [4]:
# Budget handed to the samplers (bounds batch_size * longest sentence length).
max_tokens = 4000

# Vocabulary and collate function are shared by both splits.
vocab = Vocab(word_list)
collator = Collator(vocab)

# Training split: sentences -> ids -> dataset -> freshly ordered batches.
train_data = [list(map(vocab, sent)) for sent in train_sents]
train_dataset = Dataset(train_data)
train_sampler = RandomSampler(train_dataset, max_tokens)
train_loader = DataLoader(
        train_dataset,
        batch_sampler = train_sampler,
        collate_fn = collator)

# Validation split: same steps, but with the fixed-order sampler.
valid_data = [list(map(vocab, sent)) for sent in valid_sents]
valid_dataset = Dataset(valid_data)
valid_sampler = FixedSampler(valid_dataset, max_tokens)
valid_loader = DataLoader(
        valid_dataset,
        batch_sampler = valid_sampler,
        collate_fn = collator)

In [5]:
class RNNLM(nn.Module):
    """Single-layer RNN language model.

    Pipeline: embedding -> ``nn.RNN`` over packed sequences -> dropout ->
    linear projection to vocabulary logits.

    Args:
        vocab_size: number of rows in the embedding table and size of the
            output layer.
        hidden_size: width of both the embeddings and the RNN state.
        dropout: dropout probability applied to the RNN outputs.
    """

    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size)
        self.dropout = nn.Dropout(p=dropout)
        self.proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, batch):
        """Return logits for every position of ``batch.inputs``.

        Implemented as ``forward`` rather than ``__call__`` so that
        ``model(batch)`` goes through ``nn.Module.__call__`` and registered
        hooks still run.

        Args:
            batch: object with ``inputs`` (time-first tensor of token ids) and
                ``lengths`` (per-sequence lengths, as accepted by
                ``pack_padded_sequence``).

        Returns:
            Tensor with the leading dimensions of ``batch.inputs`` and a last
            dimension of size ``vocab_size``.
        """
        x = self.embeddings(batch.inputs)
        packed = pack(x, batch.lengths, enforce_sorted=False)
        output, _ = self.rnn(packed)
        # Pad back to the full input length so the logits keep the same time
        # dimension as the padded inputs/targets.
        x, _ = unpack(output, total_length=batch.inputs.size(0))
        x = self.dropout(x)
        return self.proj(x)

In [6]:
class Accumulator:
    """Collects per-batch statistics and formats them as log lines.

    Two levels are kept:

    * ``tmp_*`` lists hold the batches seen since the last optimizer step
      (filled by ``update``, folded and emptied by ``step_log``);
    * ``loss_list`` / ``wpb_list`` / ``spb_list`` / ``lr_list`` hold one entry
      per step for the current epoch (emptied by ``epoch_log``).

    ``wpb`` is tokens per step and ``spb`` is sentences per step.
    """

    def __init__(self, name):
        self.name = name

        self.clear_epoch()
        self.clear_tmp()

    def update(self, batch, loss, lr):
        """Record one batch: its loss, token count, sentence count and lr."""
        self.tmp_loss_list.append(loss)
        self.tmp_wpb_list.append(batch.get_num_tokens())
        self.tmp_spb_list.append(len(batch))
        self.tmp_lr_list.append(lr)

    def step_tmp(self):
        """Fold the pending batches into one step entry and return it.

        Loss and lr are plain means over the pending batches; token and
        sentence counts are summed. Requires at least one prior ``update``.
        """
        loss = sum(self.tmp_loss_list) / len(self.tmp_loss_list)
        wpb = sum(self.tmp_wpb_list)
        spb = sum(self.tmp_spb_list)
        lr = sum(self.tmp_lr_list) / len(self.tmp_lr_list)

        self.loss_list.append(loss)
        self.wpb_list.append(wpb)
        self.spb_list.append(spb)
        self.lr_list.append(lr)
        return loss, wpb, spb, lr

    def clear_tmp(self):
        self.tmp_loss_list = []
        self.tmp_wpb_list = []
        self.tmp_spb_list = []
        self.tmp_lr_list = []

    def clear_epoch(self):
        self.loss_list = []
        self.wpb_list = []
        self.spb_list = []
        self.lr_list = []

    def step_log(self, epoch, num_steps, grad = None):
        """Close the current step and return its log line.

        ``grad`` is included whenever it is given, including a value of 0.
        """
        loss, wpb, spb, lr = self.step_tmp()
        self.clear_tmp()

        line = '| {}-inner'.format(self.name)
        line += ' | epoch {}, {}/{}'.format(
                epoch,
                len(self.spb_list),
                num_steps)
        line += ' | loss {:.4f}'.format(loss)
        line += ' | lr {:.8f}'.format(lr)
        # ``is not None`` rather than truthiness: a zero norm is still a value.
        if grad is not None:
            line += ' | grad {:.4f}'.format(grad)
        line += ' | w/b {}'.format(wpb)
        line += ' | s/b {}'.format(spb)

        return line

    def avg(self, lst):
        """Average ``lst`` over the epoch, weighting each step by its sentence count.

        Returns ``nan`` when no sentences were recorded, instead of dividing
        by zero.
        """
        num_examples = sum(self.spb_list)
        if num_examples == 0:
            return float('nan')
        return sum(n * x for n, x in zip(self.spb_list, lst)) / num_examples

    def epoch_log(self, epoch, num_steps = None):
        """Return the epoch summary line and reset the epoch statistics."""
        line = '| {}'.format(self.name)
        line += ' | epoch {}'.format(epoch)
        line += ' | loss {:.4f}'.format(self.avg(self.loss_list))
        line += ' | lr {:.8f}'.format(self.avg(self.lr_list))
        line += ' | w/b {:.1f}'.format(self.avg(self.wpb_list))
        line += ' | s/b {:.1f}'.format(self.avg(self.spb_list))
        if num_steps is not None:
            line += ' | steps {}'.format(num_steps)
        self.clear_epoch()
        return line

In [7]:
class LinExpScheduler(optim.lr_scheduler.LambdaLR):
    """Linear warm-up followed by inverse-square-root decay.

    With ``r = step / warmup_steps`` the learning-rate multiplier is
    ``min(r, r ** -0.5)``: it rises linearly to 1 at ``step == warmup_steps``
    and then decays as ``r ** -0.5``. ``r`` is floored at ``1e-8`` so step 0
    does not evaluate ``0 ** -0.5``.

    Args:
        optimizer: wrapped optimizer.
        warmup_steps: number of warm-up steps. Values ``<= 0`` are treated as
            1, so "no warm-up" does not divide by zero.
        last_epoch: forwarded to ``LambdaLR``.
    """

    def __init__(
            self,
            optimizer,
            warmup_steps,
            last_epoch=-1):

        # LambdaLR evaluates the lambda at step 0 during construction, so a
        # zero divisor would fail right here.
        warmup = warmup_steps if warmup_steps > 0 else 1

        def lr_lambda(step):
            r = max(1e-8, step / warmup)
            return min(r, r ** -0.5)

        super().__init__(
                optimizer,
                lr_lambda,
                last_epoch = last_epoch)


class Opter:
    """Bundles AdamW, a learning-rate scheduler, an AMP grad scaler and clipping.

    Args:
        model: module whose parameters are optimised and clipped.
        lr: base learning rate for AdamW.
        max_grad_norm: threshold passed to ``clip_grad_norm_``.
        scheduler: ``'constant'`` (``ConstantLR``), ``'linear'`` (``LinearLR``)
            or ``'linexp'`` (``LinExpScheduler``).
        warmup_steps: ``total_iters`` for the two built-in schedulers, warm-up
            length for ``'linexp'``.
        start_factor: ``factor`` / ``start_factor`` for the built-in
            schedulers; not used by ``'linexp'``.
        weight_decay: AdamW weight decay.

    Raises:
        ValueError: if ``scheduler`` is not one of the supported names.
    """

    SCHEDULERS = ('constant', 'linear', 'linexp')

    def __init__(
            self,
            model,
            lr,
            max_grad_norm,
            scheduler = 'constant',
            warmup_steps = 0,
            start_factor = 1.0/3,
            weight_decay = 0.01):

        # Checked up front with a real exception: a bare ``assert`` is removed
        # under ``python -O`` and would leave ``self.scheduler`` undefined.
        if scheduler not in self.SCHEDULERS:
            raise ValueError(
                    'unknown scheduler {!r}; expected one of {}'.format(
                        scheduler, ', '.join(self.SCHEDULERS)))

        self.model = model
        self.scaler = torch.cuda.amp.GradScaler()
        self.max_grad_norm = max_grad_norm
        # Norm returned by the most recent clipping call; None before any step.
        self.total_grad_norm = None

        self.optimizer = optim.AdamW(
                model.parameters(),
                lr = lr,
                weight_decay = weight_decay)

        if scheduler == 'constant':
            self.scheduler = optim.lr_scheduler.ConstantLR(
                    self.optimizer,
                    total_iters = warmup_steps,
                    factor = start_factor)
        elif scheduler == 'linear':
            self.scheduler = optim.lr_scheduler.LinearLR(
                    self.optimizer,
                    total_iters = warmup_steps,
                    start_factor = start_factor)
        else:  # 'linexp'
            self.scheduler = LinExpScheduler(
                    self.optimizer,
                    warmup_steps)

    def get_lr(self):
        """Return the scheduler's last learning rate for the first param group."""
        return self.scheduler.get_last_lr()[0]

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        """Unscale, clip, apply the optimizer step, then advance the scheduler."""
        # Gradients are unscaled first so clipping sees their true magnitudes.
        self.scaler.unscale_(self.optimizer)
        self.total_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.max_grad_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()

In [8]:
class LossCalc:
    """Cross-entropy loss for training and validation batches.

    The training criterion applies ``label_smoothing``; the validation one
    does not. Both ignore targets equal to ``-100``. The model is reached
    through the trainer registered with ``set_trainer``.
    """

    def __init__(self, label_smoothing = 0.0):
        self.train_criterion = nn.CrossEntropyLoss(
            ignore_index = -100,
            label_smoothing = label_smoothing)
        self.valid_criterion = nn.CrossEntropyLoss(ignore_index = -100)

    def set_trainer(self, trainer):
        self.trainer = trainer

    def get_pred_and_target(self, batch):
        """Run the model on ``batch`` and flatten logits and targets.

        The batch is moved to the GPU first. Returns ``(pred, target)`` where
        ``pred`` is reshaped to two dimensions (last one kept) and ``target``
        to one.
        """
        batch.cuda()
        logits = self.trainer.model(batch)
        num_classes = logits.size(-1)
        return logits.view(-1, num_classes), batch.outputs.view(-1)

    def _loss(self, criterion, batch):
        # Shared path for both modes; only the criterion differs.
        pred, target = self.get_pred_and_target(batch)
        return criterion(pred, target)

    def for_train(self, batch):
        return self._loss(self.train_criterion, batch)

    def for_valid(self, batch):
        return self._loss(self.valid_criterion, batch)

In [9]:
class Trainer:
    """Runs the train / validate / checkpoint loop.

    Args:
        train_loader, valid_loader: iterables of batches; ``len()`` must work.
        model: module being trained.
        opter: optimizer wrapper exposing ``zero_grad``, ``step``, ``get_lr``,
            ``scaler`` and ``total_grad_norm``.
        losscalc: loss helper; it is given a reference to this trainer.
        max_epochs: number of epochs ``run`` executes.
        step_interval: number of batches accumulated per optimizer step.
        save_interval: a checkpoint is written whenever
            ``epoch % save_interval == 0``.
    """

    def __init__(
            self,
            train_loader,
            valid_loader,
            model,
            opter,
            losscalc,
            max_epochs,
            step_interval,
            save_interval):

        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.model = model
        self.opter = opter
        self.losscalc = losscalc
        self.losscalc.set_trainer(self)

        self.max_epochs = max_epochs
        self.epoch = 0
        self.step = 0
        self.step_interval = step_interval
        # Batches accumulated since the last optimizer step. It is not reset
        # at epoch boundaries, so a partial accumulation carries over.
        self.num_accum = 0
        self.grad_to_init = True
        self.save_interval = save_interval
        self.train_accum = Accumulator('train')

    def train(self):
        """Train for one pass over ``train_loader`` and print the epoch summary."""
        self.model.train()
        num_steps = (self.num_accum + len(self.train_loader)) // self.step_interval
        for batch in self.train_loader:

            # Gradients are cleared only at the start of an accumulation window.
            if self.grad_to_init:
                self.opter.zero_grad()
                self.grad_to_init = False

            with torch.cuda.amp.autocast():
                loss = self.losscalc.for_train(batch)
                loss_val = loss.item()
                # Scale so the accumulated gradient is the mean over the window.
                loss = loss / self.step_interval
            self.opter.scaler.scale(loss).backward()
            self.num_accum += 1
            self.train_accum.update(batch, loss_val, self.opter.get_lr())

            if self.num_accum == self.step_interval:
                self.step += 1
                self.num_accum = 0
                self.opter.step()
                self.grad_to_init = True
                # Called for its bookkeeping; the per-step line is not printed.
                self.train_accum.step_log(
                    self.epoch,
                    num_steps,
                    grad = self.opter.total_grad_norm)

        print(self.train_accum.epoch_log(self.epoch, num_steps))

    def valid(self):
        """Evaluate on ``valid_loader`` without gradients and print the summary."""
        self.model.eval()

        accum = Accumulator('valid')
        for batch in self.valid_loader:
            with torch.no_grad():
                loss = self.losscalc.for_valid(batch)
            accum.update(batch, loss.item(), self.opter.get_lr())
        # Folds all validation batches into a single step entry.
        accum.step_log(self.epoch, len(self.valid_loader))
        print(accum.epoch_log(self.epoch))

    def save_checkpoint(self):
        """Write the model ``state_dict`` to ``checkpoints/<epoch>.pt``."""
        Path('checkpoints').mkdir(parents = True, exist_ok = True)
        path = 'checkpoints/{}.pt'.format(self.epoch)
        torch.save(self.model.state_dict(), path)
        # No ``logger`` is defined in this notebook, so the message is printed
        # like the other log lines.
        print('| checkpoint | saved to {}'.format(path))

    def run(self):
        """Run ``max_epochs`` epochs of train + valid, saving on schedule."""
        for _ in range(self.max_epochs):
            self.epoch += 1
            self.train()
            self.valid()
            if self.epoch % self.save_interval == 0:
                self.save_checkpoint()

In [10]:
# Optimisation settings.
lr = 0.001
max_grad_norm = 1.0
scheduler = 'constant'
warmup_steps = 0
start_factor = 1.0
weight_decay = 0.01

# Loop settings. With epochs = 50 and save_interval = 100 the condition
# `epoch % save_interval == 0` is never met, so this run writes no checkpoint.
epochs = 50
step_interval = 1
save_interval = 100

model = RNNLM(vocab_size = len(vocab), hidden_size = 256, dropout = 0.2)
# Total number of parameter elements in the model.
print(sum(p.numel() for p in model.parameters()))
model.cuda()

opter = Opter(
        model,
        lr = lr,
        max_grad_norm = max_grad_norm,
        scheduler = scheduler,
        warmup_steps = warmup_steps,
        start_factor = start_factor,
        weight_decay = weight_decay)
losscalc = LossCalc()
trainer = Trainer(
        train_loader = train_loader,
        valid_loader = valid_loader,
        model = model,
        opter = opter,
        losscalc = losscalc,
        max_epochs = epochs,
        step_interval = step_interval,
        save_interval = save_interval)
trainer.run()

211099
| train | epoch 1 | loss 2.8736 | lr 0.00100000 | w/b 4377.4 | s/b 420.5 | steps 146
| valid | epoch 1 | loss 2.6767 | lr 0.00100000 | w/b 3490.0 | s/b 400.0
| train | epoch 2 | loss 2.5067 | lr 0.00100000 | w/b 4377.4 | s/b 420.5 | steps 146
| valid | epoch 2 | loss 2.6614 | lr 0.00100000 | w/b 3490.0 | s/b 400.0
| train | epoch 3 | loss 2.4416 | lr 0.00100000 | w/b 4377.4 | s/b 420.5 | steps 146
| valid | epoch 3 | loss 2.6031 | lr 0.00100000 | w/b 3490.0 | s/b 400.0
| train | epoch 4 | loss 2.3937 | lr 0.00100000 | w/b 4377.4 | s/b 420.5 | steps 146
| valid | epoch 4 | loss 2.5777 | lr 0.00100000 | w/b 3490.0 | s/b 400.0
| train | epoch 5 | loss 2.3680 | lr 0.00100000 | w/b 4377.4 | s/b 420.5 | steps 146
| valid | epoch 5 | loss 2.5344 | lr 0.00100000 | w/b 3490.0 | s/b 400.0
| train | epoch 6 | loss 2.3470 | lr 0.00100000 | w/b 4377.4 | s/b 420.5 | steps 146
| valid | epoch 6 | loss 2.5509 | lr 0.00100000 | w/b 3490.0 | s/b 400.0
| train | epoch 7 | loss 2.3296 | lr 0.001000

In [11]:
# Highest next-token probability at every position of every validation sentence.
lst = []

for sent in valid_data:
    # Same input layout as in training: <eos> followed by the sentence.
    sent = [vocab.eos] + sent
    batch = Batch(pad([torch.tensor(sent)]), lengths = [len(sent)]).cuda()
    with torch.no_grad():
        outputs = model(batch)
    outputs = outputs.cpu()
    probs = torch.softmax(outputs, dim = -1)
    values, _ = probs.max(dim = -1)
    lst.extend(values)

# 2 ** mean(-log2 p) over the collected probabilities.
# NOTE(review): p is the model's top-1 probability at each position, not the
# probability of the actual next token, so this number is not the perplexity
# of the validation text -- confirm this is the intended metric.
print(2 ** (-torch.log2(torch.tensor(lst))).mean())

tensor(2.8266)
