In [None]:
import numpy as np
import torch
import tqdm
from torch.utils.data import Dataset
from matplotlib import pyplot as plt
from widis_lstm_tools_master.widis_lstm_tools.nn import LSTMLayer

# Fix the seeds up front so that weight initialisation (PyTorch) and the
# sampled action sequences (NumPy generator handed to the environment) are reproducible
_ = torch.manual_seed(123)
rnd_gen = np.random.RandomState(123)

In [None]:
class Environment(Dataset):
    """Simple 1D grid-world with random action sequences, exposed as a PyTorch Dataset.

    Every sample is one episode of ``max_timestep`` steps. The agent starts at index
    ``ceil(n_positions / 2)`` and moves one cell left (action 0) or right (action 1)
    per step, clipped to ``[0, n_positions - 1]``. The coin sits two cells to the
    right of the start index. The reward is delayed: it is zero everywhere except at
    the last timestep, where it equals the number of timesteps spent on the coin cell.

    Each item is a tuple ``(observations, actions, rewards)`` of float32 arrays with
    shapes ``(max_timestep, n_positions)``, ``(max_timestep, 2)`` and ``(max_timestep,)``;
    observations and actions are one-hot encoded.
    """

    def __init__(self, n_samples: int, max_timestep: int, n_positions: int, rnd_gen: np.random.RandomState):
        """Generate all episodes up front.

        Parameters
        ----------
        n_samples : int
            Number of episodes to generate.
        max_timestep : int
            Length of every episode.
        n_positions : int
            Number of cells in the 1D world. Must be large enough that the coin index
            ``ceil(n_positions / 2) + 2`` is a valid position, otherwise the reward
            indexing below raises ``IndexError``.
        rnd_gen : np.random.RandomState
            Random generator used to sample the action sequences.
        """
        super().__init__()
        n_actions = 2
        zero_position = int(np.ceil(n_positions / 2.))
        coin_position = zero_position + 2

        # Sample random action indices (0 = left, 1 = right).
        # Builtin `int` replaces the `np.int` alias, which was removed in NumPy 1.24.
        actions = np.asarray(rnd_gen.randint(low=0, high=n_actions, size=(n_samples, max_timestep)), dtype=int)
        actions_onehot = np.identity(n_actions, dtype=np.float32)[actions]

        # Map action indices {0, 1} to position deltas {-1, +1}
        steps = (actions * 2) - 1

        # Roll out positions: the action taken at t determines the observation at t+1,
        # so the action at the final timestep has no effect on the observations
        observations = np.full(fill_value=zero_position, shape=(n_samples, max_timestep), dtype=int)
        for t in range(max_timestep - 1):
            observations[:, t + 1] = np.clip(observations[:, t] + steps[:, t], 0, n_positions - 1)
        observations_onehot = np.identity(n_positions, dtype=np.float32)[observations]

        # Delayed reward: count the visits to the coin cell over the whole episode
        # and pay that count out at the last timestep only
        rewards = np.zeros(shape=(n_samples, max_timestep), dtype=np.float32)
        rewards[:, -1] = observations_onehot[:, :, coin_position].sum(axis=1)

        self.actions = actions_onehot
        self.observations = observations_onehot
        self.rewards = rewards

    def __len__(self):
        """Return the number of episodes."""
        return self.rewards.shape[0]

    def __getitem__(self, idx):
        """Return ``(observations, actions, rewards)`` of episode ``idx``."""
        return self.observations[idx], self.actions[idx], self.rewards[idx]


# Size of the 1D world; defined once and passed on, so the dataset and any later
# consumer of `n_positions` cannot drift apart
n_positions = 13
# 1000 random episodes of 50 steps each, served in mini-batches of 8
env = Environment(n_samples=1000, max_timestep=50, n_positions=n_positions, rnd_gen=rnd_gen)
env_loader = torch.utils.data.DataLoader(env, batch_size=8, num_workers=4)

In [None]:

# Two example episodes to visualise
obs0, a0, r0 = env[3]
obs1, a1, r1 = env[25]

# Positions are plotted relative to the start cell. Environment starts the agent at
# index ceil(n_positions / 2) and places the coin 2 cells to the right of it, so the
# offset has to be that start index for the dashed coin line to mark the rewarded cell.
start_index = int(np.ceil(n_positions / 2.))
coin_offset = 2

fig, axes = plt.subplots(3, 2, figsize=(8, 4.5), dpi=100)
for col, (obs, act, rew) in enumerate([(obs0, a0, r0), (obs1, a1, r1)]):
    ax_obs, ax_act, ax_rew = axes[:, col]

    # Row 0: agent position over time, with the coin position as dashed red line
    ax_obs.plot(obs.argmax(-1) - start_index)
    ax_obs.set_ylim(-start_index, n_positions - 1 - start_index)
    ax_obs.axhline(coin_offset, linestyle='--', color='r')
    ax_obs.set_title(f'observations (sample {col + 1})')

    # Row 1: chosen action index per timestep
    ax_act.plot(act.argmax(-1))
    ax_act.set_title(f'actions (sample {col + 1})')

    # Row 2: reward per timestep as provided by the environment
    ax_rew.plot(rew)
    ax_rew.set_title(f'original rewards (sample {col + 1})')

    for ax in (ax_obs, ax_act, ax_rew):
        ax.xaxis.grid(True)
        ax.set_xlabel('time (environment steps)')

fig.tight_layout()

In [None]:
class Net(torch.nn.Module):
    """Reduced LSTM followed by a linear read-out that emits one value per timestep.

    The LSTM input at every sequence position is the observation concatenated with
    the action along the last axis, hence ``n_positions + n_actions`` input features.
    """

    def __init__(self, n_positions, n_actions, n_lstm):
        super().__init__()

        # Recurrent layer fed with the concatenated observation and action features.
        # NOTE(review): 'NLC' presumably means (samples, sequence length, features) - confirm
        # against the widis_lstm_tools documentation.
        self.lstm1 = LSTMLayer(
            in_features=n_positions + n_actions,
            out_features=n_lstm,
            inputformat='NLC',
            # cell input: xavier-initialised forward connections, recurrent connections disabled
            w_ci=(torch.nn.init.xavier_normal_, False),
            # input gate: forward connections disabled, xavier-initialised recurrent connections
            w_ig=(False, torch.nn.init.xavier_normal_),
            # output gate: all connections and the bias disabled (= no output gate)
            w_og=False,
            b_og=False,
            # forget gate: all connections and the bias disabled (= no forget gate)
            w_fg=False,
            b_fg=False,
            # identity as LSTM output activation
            a_out=lambda h: h
        )

        # Linear read-out mapping the n_lstm LSTM outputs to a single value
        self.fc_out = torch.nn.Linear(n_lstm, 1)

    def forward(self, observations, actions):
        """Run the LSTM over the whole sequence and apply the read-out at every position."""
        lstm_input = torch.cat([observations, actions], dim=-1)
        # return_all_seq_pos=True: get the LSTM output for all sequence positions,
        # not only for the last one; any further return values are discarded
        hidden_seq, *_ = self.lstm1(lstm_input, return_all_seq_pos=True)
        return self.fc_out(hidden_seq)


# Instantiate the model (2 actions, 16 LSTM units) and move it to the compute device
device = 'cpu'
net = Net(n_lstm=16, n_actions=2, n_positions=n_positions)
_ = net.to(device)

In [None]:
def lossfunction(predictions, rewards):
    """Mean-squared-error loss for return prediction with an auxiliary per-timestep task.

    Parameters
    ----------
    predictions : torch.Tensor
        Predicted return per sample and timestep, indexed as ``[sample, timestep]``.
    rewards : torch.Tensor
        Reward per sample and timestep, indexed as ``[sample, timestep]``; the return
        of a sample is the sum of its rewards over the timestep axis.

    Returns
    -------
    torch.Tensor
        Scalar loss ``main_loss + 0.5 * aux_loss``.
    """
    returns = rewards.sum(dim=1)
    # Main task: predicting return at last timestep.
    # The square must be taken per element *before* averaging; squaring the mean error
    # would let errors of opposite sign cancel out.
    main_loss = torch.mean((predictions[:, -1] - returns) ** 2)
    # Auxiliary task: predicting final return at every timestep ([..., None] is for correct broadcasting)
    aux_loss = torch.mean((predictions[:, :] - returns[..., None]) ** 2)
    # Combine losses
    loss = main_loss + aux_loss * 0.5
    return loss

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-5)

update = 0
n_updates = 10000
# Exponential moving average of the loss, for display only (arbitrary high start value)
running_loss = 100.
progressbar = tqdm.tqdm(total=n_updates)
try:
    # Cycle through the dataset repeatedly until n_updates parameter updates were made
    while update < n_updates:
        for data in env_loader:
            # Get samples
            observations, actions, rewards = data
            observations, actions, rewards = observations.to(device), actions.to(device), rewards.to(device)

            # Reset gradients
            optimizer.zero_grad()

            # Get outputs for network
            outputs = net(observations=observations, actions=actions)

            # Calculate loss, do backward pass, and update
            # (outputs[..., 0] drops the trailing size-1 output feature axis)
            loss = lossfunction(outputs[..., 0], rewards)
            loss.backward()
            # .item() yields a plain Python float; accumulating the tensor itself would
            # keep every iteration's autograd graph referenced and steadily grow memory
            running_loss = running_loss * 0.99 + loss.item() * 0.01
            optimizer.step()
            update += 1
            progressbar.set_description(f"Loss: {running_loss:8.4f}")
            progressbar.update(1)

            # Stop mid-epoch once the requested number of updates is reached
            if update >= n_updates:
                break
finally:
    # Close the bar even if an iteration raised, so the notebook output is not left dangling
    progressbar.close()