In [2]:
import time, sys
sys.path.append('/media/carson/New Volume/Chloe/chloebot/env36/lib/python3.6/site-packages')
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.distributions import Categorical

import import_ipynb
from MoveData import Options, json2datatools, num_batches, nopeak_mask, create_masks
from EncoderDecoder import Encoder, Decoder

In [3]:
class CosineWithRestarts(torch.optim.lr_scheduler._LRScheduler):
    """
    Cosine annealing with restarts.

    Within a cycle, the learning rate of every parameter group is
    ``eta_min + (base_lr - eta_min) / 2 * (1 + cos(pi * t / cycle_len))``,
    where ``t`` is the number of steps since the last restart. When a cycle
    completes, the rate is back at ``base_lr`` and the length of the next cycle
    is ``T_max`` multiplied by the accumulated ``factor`` (never less than one
    step).

    Parameters
    ----------
    optimizer : torch.optim.Optimizer

    T_max : int
        The maximum number of iterations within the first cycle. Must be
        positive.

    eta_min : float, optional (default: 0)
        The minimum learning rate.

    last_epoch : int, optional (default: -1)
        The index of the last epoch.

    factor : float, optional (default: 1)
        Multiplier applied to the cycle length at every restart. Values above 1
        make successive cycles longer, values below 1 make them shorter.

    Raises
    ------
    ValueError
        If ``T_max`` is not positive.

    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 T_max: int,
                 eta_min: float = 0.,
                 last_epoch: int = -1,
                 factor: float = 1.) -> None:
        # pylint: disable=invalid-name
        if T_max <= 0:
            # A non-positive cycle length would only surface later as a
            # ZeroDivisionError inside get_lr(); reject it immediately instead.
            raise ValueError("T_max must be a positive integer, got {}".format(T_max))
        self.T_max = T_max
        self.eta_min = eta_min
        self.factor = factor
        self._last_restart: int = 0        # step at which the current cycle began
        self._cycle_counter: int = 0       # steps elapsed since the last restart
        self._cycle_factor: float = 1.     # product of `factor` over all restarts so far
        self._updated_cycle_len: int = T_max  # length of the current cycle
        self._initialized: bool = False
        # The base class constructor triggers the first get_lr() call, so every
        # attribute above has to exist before it runs.
        super(CosineWithRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Get updated learning rate.

        Returns
        -------
        list of float
            One learning rate per parameter group, in the order of
            ``self.base_lrs``.
        """
        # HACK: We need to check if this is the first time get_lr() was called, since
        # we want to start with step = 0, but _LRScheduler calls get_lr with
        # last_epoch + 1 when initialized.
        if not self._initialized:
            self._initialized = True
            return self.base_lrs

        step = self.last_epoch + 1
        self._cycle_counter = step - self._last_restart

        # Position inside the current cycle; 0 means the cycle has just completed.
        cycle_pos = self._cycle_counter % self._updated_cycle_len
        # Identical for every parameter group, so compute it once.
        cosine_term = np.cos(np.pi * cycle_pos / self._updated_cycle_len) + 1

        # Cast to built-in float so the optimizer's param groups hold plain
        # Python numbers rather than NumPy scalar objects.
        lrs = [
            float(self.eta_min + ((lr - self.eta_min) / 2) * cosine_term)
            for lr in self.base_lrs
        ]

        if cycle_pos == 0:
            # Adjust the cycle length.
            self._cycle_factor *= self.factor
            self._cycle_counter = 0
            # Keep at least one step per cycle: with factor < 1 the product
            # eventually truncates to 0, which would make the modulo above
            # raise ZeroDivisionError on the next call.
            self._updated_cycle_len = max(1, int(self._cycle_factor * self.T_max))
            self._last_restart = step

        return lrs

In [None]:
def talk_to_chloe(input_str, chloe, opt, infield, outfield):
    """
    Sample a reply to ``input_str`` from ``chloe``, one token at a time.

    The input is encoded once. The decoder is then re-run on everything
    generated so far, and the next token is drawn at random from the categorical
    distribution defined by the model output at the last position. Generation
    stops as soon as ``<eos>`` is drawn, or after ``opt.max_len`` tokens.

    Parameters
    ----------
    input_str :
        Text to respond to; converted with ``string2tensor(input_str, infield)``.
        NOTE(review): ``string2tensor`` is not imported in the visible import
        cell - confirm it is in scope before this cell runs.
    chloe :
        Model exposing ``encoder``, ``decoder`` and ``out``. It is switched to
        evaluation mode via ``chloe.eval()`` as a side effect. Gradients are
        not disabled.
    opt :
        Options object. ``opt.max_len`` bounds the number of sampled tokens and
        ``opt`` itself is forwarded to ``nopeak_mask``.
    infield, outfield :
        Objects whose ``vocab.stoi`` / ``vocab.itos`` map between tokens and
        ids for the input and output side respectively.

    Returns
    -------
    decoder_input : torch.Tensor
        Integer tensor of shape ``(1, n + 1)``: ``<sos>`` followed by the ``n``
        sampled token ids (including ``<eos>`` when it was drawn).
    de_str : str
        The tokens of ``decoder_input`` joined with single spaces.
    logprobs : torch.Tensor
        Shape ``(1, n)``. Despite the name, each entry is the NEGATIVE
        log-probability of the corresponding sampled token.
    """
    input_sequence = string2tensor(input_str, infield)
    input_mask = (input_sequence != infield.vocab.stoi['<pad>']).unsqueeze(-2)
    chloe.eval()
    encoding = chloe.encoder(input_sequence, input_mask)

    itos = outfield.vocab.itos
    # Create the running tensors next to the input so they can be concatenated
    # with model outputs regardless of where the input lives.
    device = input_sequence.device
    init_tok = outfield.vocab.stoi['<sos>']
    decoder_input = torch.tensor([[init_tok]], dtype=torch.long, device=device)
    logprobs = torch.empty((1, 0), device=device)

    for pos in range(opt.max_len):
        decoder_input_mask = nopeak_mask(size=pos + 1, opt=opt)
        # assumes the model output is (batch, seq, vocab) with batch == 1,
        # matching the single-row decoder_input - TODO confirm
        out = chloe.out(chloe.decoder(decoder_input, encoding, input_mask, decoder_input_mask))

        # Only the newest position decides the next token, so build the
        # distribution from that slice alone. Passing logits lets Categorical
        # normalise them itself instead of taking the log of softmax output.
        distr = Categorical(logits=out[:, -1])
        action = distr.sample()             # shape (batch,)
        logprob = -distr.log_prob(action)   # negative log-likelihood of the draw

        decoder_input = torch.cat((decoder_input, action.unsqueeze(1)), dim=1)
        logprobs = torch.cat((logprobs, logprob.unsqueeze(1)), dim=1)

        if itos[action.item()] == '<eos>':
            break

    de_str = ' '.join(itos[tok] for tok in decoder_input[0].tolist())
    return decoder_input, de_str, logprobs