In [None]:
import numpy as np

import torch
from torch.autograd import Variable

"""
This implementation defines the model as a custom Module subclass. Whenever you
want a model more complex than a simple sequence of existing Modules you will
need to define your model this way.
""";

In [None]:
# Configuration and synthetic data.
# d is the number of selectable indices (size of the one-hot input).
batch_size, d = 256, 5
SEED = 0

# Seed both RNGs so the sampled data (numpy) and any torch-initialised
# weights created afterwards are reproducible under Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use CUDA tensor types when a GPU is available. Tensors built with these
# types live on that device, so anything consuming them must be there too.
t_long = torch.LongTensor
t_float = torch.FloatTensor
if torch.cuda.is_available():
    t_long = torch.cuda.LongTensor
    t_float = torch.cuda.FloatTensor

# One random index in [0, d) per row: shape (batch_size, 1).
x_np = np.random.randint(d, size=(batch_size, 1))

# The target is the input itself: the net should learn to echo its index.
x = torch.from_numpy(x_np).type(t_long)
y = torch.from_numpy(x_np).type(t_long)

In [None]:
class SelectorNet(torch.nn.Module):
  """Map an integer index in [0, D) to a selected index in [0, D).

  The index is one-hot encoded and passed through a single D -> D linear
  layer. The selected index is the arg-max of the resulting scores, so
  training the layer towards the identity makes the net echo its input.
  """

  def __init__(self, D):
    """
    Args:
      D: number of selectable indices (width of the one-hot input and of the
        score vector).
    """
    super(SelectorNet, self).__init__()
    self.D = D
    self.linear = torch.nn.Linear(D, D)  # This should become the identity

  def forward(self, i, return_logits=False):
    """Select an index for each row of ``i``.

    Args:
      i: integer (Long) tensor of shape (batch, 1) with values in [0, D).
        Any batch size is accepted.
      return_logits: if True, return the raw (batch, D) scores instead of the
        selected indices. Use this for training, because the arg-max carries
        no gradient.

    Returns:
      Long tensor of shape (batch, 1) holding the selected index per row
      (same layout as ``i``), or the (batch, D) scores if ``return_logits``.
    """
    # Build the one-hot encoding per call: sized to this batch (not a global
    # batch size), on the input's device and in the layer's dtype, so the
    # model works on CPU and GPU alike after .cuda()/.to().
    one_hot = torch.zeros(i.size(0), self.D,
                          dtype=self.linear.weight.dtype, device=i.device)
    one_hot.scatter_(1, i, 1)

    logits = self.linear(one_hot)
    if return_logits:
      return logits

    # Probabilities over the D choices; softmax must be taken per row (dim=1).
    action1 = torch.nn.functional.softmax(logits, dim=1)

    # torch.max with dim returns (values, indices); keepdim keeps (batch, 1).
    y_probs, y_idx = torch.max(action1, dim=1, keepdim=True)
    return y_idx


In [None]:
# Construct the model. Move it to the GPU when CUDA is available, because
# x and y were built with CUDA tensor types in that case and the parameters
# must live on the same device as the data.
model = SelectorNet(d)
if torch.cuda.is_available():
  model.cuda()

# The model's selection ends in an arg-max, which has no gradient, so we train
# the pre-arg-max scores instead. One-hot encode x once (it never changes)
# and feed it through model.linear, the model's single nn.Linear layer.
x_one_hot = torch.zeros(x.size(0), d).type(t_float).scatter_(1, x, 1)

# CrossEntropyLoss takes raw scores of shape (N, C) and integer class targets
# of shape (N,). reduction='sum' replaces the deprecated size_average=False.
targets = y.view(-1)
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

n_steps = 500
for t in range(n_steps):
  # Forward pass: scores for every possible index.
  logits = model.linear(x_one_hot)

  # Compute the loss; print periodically rather than on every step.
  loss = criterion(logits, targets)
  if t % 50 == 0 or t == n_steps - 1:
    print(t, loss.item())  # .item(): loss is a 0-dim tensor

  # Zero gradients, perform a backward pass, and update the weights.
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

# Fraction of rows whose arg-max score recovers the input index.
with torch.no_grad():
  accuracy = (model.linear(x_one_hot).argmax(dim=1) == targets).float().mean().item()
print('train accuracy:', accuracy)
