In [1]:
import torch
from torch.distributions.categorical import Categorical
from torch import einsum
from typing import Optional
from einops import  reduce
import torch.nn as nn
import torch.nn.functional as F

class CategoricalMasked(Categorical):
    """Categorical distribution that can exclude actions through a mask.

    Logits whose mask entry is false are replaced by the most negative finite
    value of the logits' dtype before the parent ``Categorical`` is built, so
    those actions get (numerically) zero probability while every logit stays
    finite. If every action of a row is masked, all of its logits are equal
    and the row degenerates to a uniform distribution.

    Args:
        logits: Unnormalized log-probabilities of shape ``(..., nb_action)``.
        mask: Optional tensor broadcastable to ``logits``; true / non-zero
            entries mark actions that may be taken. ``None`` disables masking.

    Attributes:
        mask: Boolean mask on the logits' device, or ``None``.
        nb_action: Size of the last dimension of ``logits``.
        batch: Number of independent distributions (product of the leading
            dimensions); for 2-D logits this is the batch size.
    """

    def __init__(self, logits: torch.Tensor, mask: Optional[torch.Tensor] = None):
        # Derive the bookkeeping sizes without assuming exactly two dimensions.
        self.nb_action = logits.size(-1)
        self.batch = logits.shape[:-1].numel()
        if mask is None:
            self.mask = None
        else:
            # torch.where needs a boolean condition living next to the logits.
            self.mask = mask.to(device=logits.device, dtype=torch.bool)
            self.mask_value = torch.tensor(
                torch.finfo(logits.dtype).min,
                dtype=logits.dtype,
                device=logits.device,
            )
            logits = torch.where(self.mask, logits, self.mask_value)
        super().__init__(logits=logits)

    def entropy(self):
        """Return the entropy computed over the allowed actions only."""
        if self.mask is None:
            return super().entropy()
        # self.logits holds the normalized log-probabilities, so this is p * log(p).
        p_log_p = self.logits * self.probs
        # Masked actions contribute exactly zero to the sum.
        p_log_p = torch.where(self.mask, p_log_p, p_log_p.new_zeros(()))
        return -p_log_p.sum(dim=-1)


  from .autonotebook import tqdm as notebook_tqdm


In [13]:
# Random logits for a single batch row with three actions.
logits = torch.randn(1, 3, requires_grad=True)
print(logits)

# Allow only the first two actions; the third one stays masked out.
mask = torch.zeros(1, 3, dtype=torch.bool)
mask[0, :2] = True
print(mask)

# Reference distribution: every action is available.
action_dist = CategoricalMasked(logits=logits)
print(action_dist.probs, action_dist.entropy())

# Same logits, but restricted to the allowed actions.
action_dist_masked = CategoricalMasked(logits=logits, mask=mask)
print(action_dist_masked.probs, action_dist_masked.entropy())


tensor([[ 0.9309, -1.1033,  0.3168]], requires_grad=True)
tensor([[ True,  True, False]])
tensor([[0.5981, 0.0782, 0.3237]], grad_fn=<SoftmaxBackward0>) tensor([0.8719], grad_fn=<NegBackward0>)
tensor([[0.8843, 0.1157, 0.0000]], grad_fn=<SoftmaxBackward0>) tensor([0.3582], grad_fn=<NegBackward0>)


In [None]:
class Policy(torch.nn.Module):
    """Linear policy head that scores actions and applies an action mask.

    Args:
        input_dim: Number of input features.
        output_dim: Number of actions (size of the logits' last dimension).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project ``x`` through the linear layer and return the action logits."""
        # The layer was previously defined but never applied, so the input
        # was handed back untouched.
        return self.linear(x)

    def get_action(self, x, action, mask):
        """Return the masked action probabilities and their entropy.

        Args:
            x: Input features passed to ``forward``.
            action: Currently unused; kept so existing callers keep working.
            mask: Mask forwarded to ``CategoricalMasked``, or ``None``.

        Returns:
            Tuple ``(probs, entropy)`` of the masked distribution.
        """
        logits = self(x)
        masked_dist = CategoricalMasked(logits=logits, mask=mask)
        return masked_dist.probs, masked_dist.entropy()

In [None]:
class Attention(nn.Module):
    """Additive (tanh) attention scoring a query against every reference step.

    Args:
        hidden_size: Feature size of the query and of the reference vectors.
        use_tanh: If true, squash the scores with ``C * tanh(scores)``.
        C: Scale applied to the squashed scores when ``use_tanh`` is set.
        use_cuda: If true, allocate the scoring vector ``V`` on the CUDA
            device. The other sub-modules are not moved by this class.
    """

    def __init__(self, hidden_size, use_tanh=False, C=10, use_cuda=USE_CUDA):
        super(Attention, self).__init__()

        self.use_tanh = use_tanh
        self.W_query = nn.Linear(hidden_size, hidden_size)
        self.W_ref   = nn.Conv1d(hidden_size, hidden_size, 1, 1)
        self.C = C

        # Scoring vector, drawn uniformly from
        # [-1/sqrt(hidden_size), 1/sqrt(hidden_size)].
        bound = hidden_size ** -0.5
        self.V = nn.Parameter(
            torch.empty(hidden_size, device="cuda" if use_cuda else None)
        )
        nn.init.uniform_(self.V, -bound, bound)

    def forward(self, query, ref):
        """
        Args:
            query: [batch_size x hidden_size]
            ref:   [batch_size x seq_len x hidden_size]

        Returns:
            ref:    projected reference, [batch_size x hidden_size x seq_len]
            logits: attention scores,    [batch_size x seq_len]
        """
        ref = self.W_ref(ref.permute(0, 2, 1))     # [batch_size x hidden_size x seq_len]
        query = self.W_query(query).unsqueeze(2)   # [batch_size x hidden_size x 1]

        # Broadcasting adds the query to every step without materializing copies.
        energies = torch.tanh(query + ref)         # [batch_size x hidden_size x seq_len]
        # 1-D @ 3-D contracts the hidden dimension for every batch element.
        logits = torch.matmul(self.V, energies)    # [batch_size x seq_len]

        if self.use_tanh:
            logits = self.C * torch.tanh(logits)
        return ref, logits

In [None]:
class PointerNet(nn.Module):
    """Pointer network with an LSTM encoder/decoder and attention heads.

    ``forward`` runs the decoder with teacher forcing and returns the
    cross-entropy loss averaged over the decoding steps.

    Args:
        embedding_size: Size of the token embeddings.
        hidden_size: Hidden size of both LSTMs and of the attention modules.
        seq_len: Expected sequence length; also the embedding vocabulary size.
        n_glimpses: Number of glimpse-attention refinements per decoding step.
        tanh_exploration: ``C`` passed to the pointer attention.
        use_tanh: Whether the pointer attention squashes its logits.
        use_cuda: Forwarded to the attention modules and stored on the instance.
    """

    def __init__(self,
            embedding_size,
            hidden_size,
            seq_len,
            n_glimpses,
            tanh_exploration,
            use_tanh,
            use_cuda=USE_CUDA):
        super(PointerNet, self).__init__()

        self.embedding_size = embedding_size
        self.hidden_size    = hidden_size
        self.n_glimpses     = n_glimpses
        self.seq_len        = seq_len
        self.use_cuda       = use_cuda

        self.embedding = nn.Embedding(seq_len, embedding_size)
        self.encoder = nn.LSTM(embedding_size, hidden_size, batch_first=True)
        self.decoder = nn.LSTM(embedding_size, hidden_size, batch_first=True)
        self.pointer = Attention(hidden_size, use_tanh=use_tanh, C=tanh_exploration, use_cuda=use_cuda)
        self.glimpse = Attention(hidden_size, use_tanh=False, use_cuda=use_cuda)

        # Learned first decoder input, drawn uniformly from
        # [-1/sqrt(embedding_size), 1/sqrt(embedding_size)].
        bound = embedding_size ** -0.5
        self.decoder_start_input = nn.Parameter(torch.empty(embedding_size))
        nn.init.uniform_(self.decoder_start_input, -bound, bound)

        self.criterion = nn.CrossEntropyLoss()

    def apply_mask_to_logits(self, logits, mask, idxs):
        """Mark ``idxs`` as visited and set every visited logit to ``-inf``.

        Args:
            logits: [batch_size x seq_len] scores; modified in place when
                ``idxs`` is given.
            mask: [batch_size x seq_len] mask of already selected positions.
            idxs: One selected position per batch row, or ``None`` to leave
                both ``logits`` and the mask content unchanged.

        Returns:
            Tuple ``(logits, updated_mask)``; the mask is always a fresh clone.
        """
        clone_mask = mask.clone()

        if idxs is not None:
            rows = torch.arange(logits.size(0), device=clone_mask.device)
            clone_mask[rows, idxs] = 1
            # Boolean fill works for bool and legacy byte masks alike.
            logits.masked_fill_(clone_mask.bool(), float("-inf"))
        return logits, clone_mask

    def forward(self, inputs, target):
        """
        Args:
            inputs: [batch_size x sourceL]
            target: integer tensor indexed as ``target[:, step]``; used both as
                embedding indices (teacher forcing) and as class targets.

        Returns:
            Cross-entropy loss averaged over the ``seq_len`` decoding steps.
        """
        batch_size = inputs.size(0)
        seq_len    = inputs.size(1)
        assert seq_len == self.seq_len

        embedded = self.embedding(inputs)
        target_embedded = self.embedding(target)
        encoder_outputs, (hidden, context) = self.encoder(embedded)

        mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=inputs.device)

        # NOTE(review): idxs is never updated below, so the mask is never
        # actually applied during this forward pass -- confirm this is intended.
        idxs = None

        decoder_input = self.decoder_start_input.unsqueeze(0).repeat(batch_size, 1)

        loss = 0

        for step in range(seq_len):
            _, (hidden, context) = self.decoder(decoder_input.unsqueeze(1), (hidden, context))

            query = hidden.squeeze(0)
            # The glimpse loop uses a throwaway variable so it cannot clobber
            # the decoding-step index used for teacher forcing below.
            for _ in range(self.n_glimpses):
                ref, logits = self.glimpse(query, encoder_outputs)
                logits, mask = self.apply_mask_to_logits(logits, mask, idxs)
                query = torch.bmm(ref, F.softmax(logits, dim=1).unsqueeze(2)).squeeze(2)

            _, logits = self.pointer(query, encoder_outputs)
            logits, mask = self.apply_mask_to_logits(logits, mask, idxs)

            decoder_input = target_embedded[:, step, :]

            loss += self.criterion(logits, target[:, step])

        return loss / seq_len

In [None]:
# Pointer network for length-10 sequences, plus the Adam optimizer that trains it.
pointer = PointerNet(
    embedding_size=32,
    hidden_size=32,
    seq_len=10,
    n_glimpses=1,
    tanh_exploration=10,
    use_tanh=True,
)
adam = optim.Adam(pointer.parameters(), lr=1e-4)

In [None]:
n_epochs = 1
train_loss = []
val_loss   = []


def _prepare_batch(batch):
    """Return ``(inputs, target)``; target is the batch sorted along its last dim."""
    inputs = batch
    target = torch.sort(batch)[0]
    if USE_CUDA:
        inputs = inputs.cuda()
        target = target.cuda()
    return inputs, target


for epoch in range(n_epochs):
    for batch_id, sample_batch in enumerate(train_loader):
        inputs, target = _prepare_batch(sample_batch)

        loss = pointer(inputs, target)

        adam.zero_grad()
        loss.backward()
        adam.step()

        # The loss is a 0-dim tensor: .item() extracts the Python float
        # (indexing it with [0] raises on current PyTorch).
        train_loss.append(loss.item())

        if batch_id % 10 == 0:
            # Redraw the loss curves in place.
            clear_output(True)
            plt.figure(figsize=(20,5))
            plt.subplot(131)
            plt.title('train epoch %s loss %s' % (epoch, train_loss[-1] if len(train_loss) else 'collecting'))
            plt.plot(train_loss)
            plt.grid()
            plt.subplot(132)
            plt.title('val epoch %s loss %s' % (epoch, val_loss[-1] if len(val_loss) else 'collecting'))
            plt.plot(val_loss)
            plt.grid()
            plt.show()

        if batch_id % 100 == 0:
            pointer.eval()
            # No gradients are needed for validation; skip building the graph.
            with torch.no_grad():
                for val_batch in val_loader:
                    val_inputs, val_target = _prepare_batch(val_batch)
                    val_loss.append(pointer(val_inputs, val_target).item())
            # Switch back so the following training steps run in train mode.
            pointer.train()

In [None]:
import math

def attention(q, k, v, d_k, mask=None):
    """Scaled dot-product attention.

    Args:
        q: Query tensor; its last two dims are multiplied with ``k`` transposed.
        k: Key tensor.
        v: Value tensor, combined using the attention weights.
        d_k: Scaling dimension; scores are divided by ``sqrt(d_k)``.
        mask: Accepted for interface compatibility but currently ignored --
            no masking is applied to the scores.

    Returns:
        The attention-weighted combination of ``v``.
    """
    scale = math.sqrt(d_k)
    similarity = torch.matmul(q, k.transpose(-2, -1)) / scale
    # Normalize over the key dimension so each query's weights sum to one.
    weights = F.softmax(similarity, dim=-1)
    return torch.matmul(weights, v)