<a href="https://colab.research.google.com/github/nick-baliesnyi/self-attention-rl/blob/main/Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### This notebook implements and trains the Self-Attention model from https://arxiv.org/abs/1907.08027

## Self-attention model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib
import math
from tqdm.notebook import tqdm

In [None]:
class PositionalEncoding(nn.Module):
    '''Adds fixed sinusoidal positional encodings to a sequence-first tensor.

    From: https://github.com/pytorch/examples/blob/master/word_language_model/model.py

    The table is precomputed once for ``max_len`` positions and stored as a
    non-trainable buffer ``pe`` of shape (max_len, 1, d_model), so it follows
    the module across ``.to(device)`` calls and broadcasts over the batch axis.

    Args:
        d_model: size of the last (feature) dimension of the input. May be odd;
            in that case the last sine column simply has no cosine partner.
        max_len: longest sequence (first dimension of the input) supported.
    '''

    def __init__(self, d_model, max_len=250):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # Odd d_model has one fewer odd-indexed column than even-indexed ones,
        # so only the first d_model // 2 frequencies get a cosine column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # -> (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Args:
            x: tensor whose first dimension is the sequence position and whose
               last dimension is ``d_model``.

        Raises:
            ValueError: if the sequence is longer than the precomputed table.
        '''
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f'sequence length {seq_len} exceeds max_len={self.pe.size(0)} of the positional encoding'
            )
        x = x + self.pe[:seq_len, :]
        return x

class SelfAttentionForRL(nn.Module):
  '''Per-timestep reward classifier built from one transformer decoder layer.

  Observations are embedded with a linear layer + ReLU, concatenated with the
  action vectors and position-encoded; this sequence is the decoder "memory".
  The (padded) reward sequence is embedded and used as the decoder "target".
  The decoder output is projected to 3 reward-class logits per position.

  All sequence tensors are sequence-first: (seq_length, N, ...).

  NOTE(review): the target fed to the decoder at position t is the embedding of
  the reward at position t, and the causal mask keeps the diagonal, so the
  logits for step t can attend to the true reward of step t. Confirm this is
  intended (the rewards may need to be shifted or masked).

  Args:
    observation_size: feature size of one observation.
    action_size: feature size of one action vector.
    device: device the attention masks are moved to.
    embedding_size: output size of the observation embedding.
    dim_feedforward: hidden size of the decoder layer's feed-forward block.
    pad_val: reward value that marks padded positions.
    max_len: longest supported sequence (size of the positional table).
    verbose: print tensor shapes during the forward pass.
  '''

  def __init__(
      self,
      observation_size,
      action_size,
      device,
      embedding_size=127, # chosen odd so that embedding_size+action_size is even
      dim_feedforward=128,
      pad_val=10.,
      max_len=250,
      verbose=False,
      ):
    super(SelfAttentionForRL, self).__init__()

    self.verbose = verbose
    self.device = device
    self.pad_val = pad_val

    d_model = embedding_size + action_size
    # The reward embedding must be able to look up the 3 reward classes AND the
    # padding value, so the table needs at least pad_val + 1 rows (11 for the
    # default pad_val of 10).
    n_reward_tokens = max(3, int(pad_val) + 1)

    self.observation_embedding = nn.Sequential(
        nn.Linear(observation_size, embedding_size),
        nn.ReLU()
    )
    self.pos_encoder = PositionalEncoding(d_model, max_len)
    self.reward_embedding = nn.Embedding(n_reward_tokens, d_model)
    self.self_attention = nn.TransformerDecoderLayer(d_model, nhead=1, dim_feedforward=dim_feedforward, dropout=0.2)
    self.fc_out = nn.Linear(d_model, 3) # 3 reward classes

  def forward(self, observations, actions, rewards):
    '''Compute reward-class logits for every position.

    Args:
      observations: (seq_length, N, observation_size) float tensor.
      actions: (seq_length, N, action_size) float tensor.
      rewards: (seq_length, N) integer tensor of reward indices, padded with pad_val.

    Returns:
      (seq_length, N, 3) tensor of logits.
    '''
    # Unpacking doubles as a rank check on the three inputs.
    seq_length, N, observation_size = observations.shape
    seq_length, N, action_size = actions.shape
    seq_length, N = rewards.shape

    x = self.observation_embedding(observations) # dropout?
    x = torch.cat((x, actions), dim=2) # concatenate action vector to observation vector
    x = self.pos_encoder(x) # dropout?

    y = self.reward_embedding(rewards) # (seq_length, N, d_model)

    padding_mask = self.make_padding_mask(rewards) # allows to skip computation for padded positions in x and y
    y_mask = self.generate_square_subsequent_mask(sz=seq_length).to(self.device)

    if self.verbose:
      print('seq_length', seq_length)
      print('N', N)
      print('observation_size', observation_size)
      print('action_size', action_size)
      print('key_padding_mask', padding_mask.shape)
      print('y_mask', y_mask.shape)
      print('x', x.shape)
      print('y', y.shape)

    # tgt = reward embeddings, memory = encoded (observation, action) pairs
    out = self.self_attention(
        y,
        x,
        tgt_mask=y_mask,
        memory_key_padding_mask=padding_mask,
        tgt_key_padding_mask=padding_mask
    )
    out = self.fc_out(out)
    return out

  def make_padding_mask(self, batch_of_sequences):
    '''Return a boolean (N, seq_length) mask that is True at padded positions.'''
    # batch_of_sequences is of (seq_length, N)
    # key_padding_mask expects (N, seq_length)
    padding_mask = batch_of_sequences.transpose(0, 1) == self.pad_val

    if self.verbose:
      print('padding_mask:', padding_mask.shape)

    return padding_mask.to(self.device)

  def generate_square_subsequent_mask(self, sz):
    '''Return an additive (sz, sz) causal mask: 0 on/below the diagonal, -inf above.'''
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

## Model training

In [None]:
# support the format of the trajectories dataset
class TrajectoriesDataset(Dataset):
    '''Map-style dataset that pairs up slices of three aligned tensors.

    Sample ``i`` is the tuple ``(observations[i], actions[i], rewards[i])``;
    the number of samples is the size of the first axis of ``observations``.
    '''

    def __init__(self, observations, actions, rewards):
        n_trajectories = observations.shape[0]
        # Slice every tensor along its first axis once, up front.
        self.samples = [
            (observations[idx, :], actions[idx, :], rewards[idx, :])
            for idx in range(n_trajectories)
        ]

    def __len__(self):
        '''Number of stored samples.'''
        return len(self.samples)

    def __getitem__(self, idx):
        '''Return the (observations, actions, rewards) tuple at ``idx``.'''
        return self.samples[idx]

# Location of the serialized trajectories dataset. The /content/drive prefix is
# presumably a Google Drive mount in Colab — it has to be available before this runs.
path_to_dataset = '/content/drive/MyDrive/Self-Attention/dataset-1608570873.pt'
# The loaded object is used below as a sized dataset (len(), random_split).
# Presumably it is a pickled TrajectoriesDataset, in which case the class above
# must already be defined when this line runs — TODO confirm.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which
# rejects arbitrary pickled objects — verify against the installed version.
dataset = torch.load(path_to_dataset)

In [None]:
# Fraction of the samples held out for validation, and the loader batch size.
validation_subset = 0.25
batch_size = 4 # TODO: change to much bigger (2048?)

# Split sizes: the train part is rounded down, validation takes the remainder.
n_total_samples = len(dataset)
n_train_samples = math.floor(n_total_samples * (1-validation_subset))
n_valid_samples = n_total_samples - n_train_samples

# A fixed seed keeps the train/validation membership identical between runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, valid_dataset = random_split(
    dataset,
    [n_train_samples, n_valid_samples],
    generator=split_generator,
)

for caption, subset in (('Train set size:', train_dataset), ('Validation set size:', valid_dataset)):
  print(caption, len(subset), 'samples')

# Both loaders iterate in the stored order (no shuffling).
train_loader = DataLoader(train_dataset, batch_size=batch_size)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size)

for caption, loader in (('Train set:', train_loader), ('Validation set:', valid_loader)):
  print(caption, len(loader), 'batches')

Train set size: 24 samples
Validation set size: 8 samples
Train set: 6 batches
Validation set: 2 batches


In [None]:
def evaluate_validation(model, criterion, validation_set, verbose=False):
  '''Compute the mean loss and mean accuracy of ``model`` over ``validation_set``.

  The caller is expected to have put the model in eval mode; this function
  does not change the mode. No gradients are recorded.

  Args:
    model: callable taking sequence-first ``(observations, actions, rewards)``
      and returning logits of shape (seq_length, N, n_classes).
    criterion: loss taking (N, n_classes, seq_length) logits and (N, seq_length)
      targets. If it has an ``ignore_index`` attribute, positions whose target
      equals it are excluded from the accuracy as well.
    validation_set: iterable of batch-first ``(observations, actions, rewards)``
      batches, i.e. (N, seq_length, ...) for observations/actions and
      (N, seq_length) for rewards.
    verbose: print the two means before returning.

  Returns:
    ``(mean_loss, mean_acc)`` as Python floats, averaged over batches;
    ``(nan, nan)`` if ``validation_set`` yields no batches.
  '''
  # Run on whatever device the model lives on (skipped for parameter-free models).
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else None
  ignore_index = getattr(criterion, 'ignore_index', None)

  losses = []
  accuracies = []

  with torch.no_grad():
    for observations, actions, rewards in validation_set:
      if device is not None:
        observations = observations.to(device)
        actions = actions.to(device)
        rewards = rewards.to(device)

      # Batch-first -> sequence-first. The axes must be swapped with transpose;
      # a plain reshape keeps the memory order and would mix up trajectories.
      # The trailing reshape only flattens any extra feature axes.
      seq_observations = observations.transpose(0, 1)
      seq_observations = seq_observations.reshape(seq_observations.shape[0], seq_observations.shape[1], -1)
      seq_actions = actions.transpose(0, 1)
      seq_actions = seq_actions.reshape(seq_actions.shape[0], seq_actions.shape[1], -1)
      seq_rewards = rewards.transpose(0, 1)

      output = model(seq_observations, seq_actions, seq_rewards)

      # (seq_length, N, C) -> (N, C, seq_length) for the K-dimensional loss;
      # the batch-first rewards already have the matching (N, seq_length) layout.
      logits = output.permute(1, 2, 0)

      loss = criterion(logits, rewards)
      preds = logits.argmax(dim=1)

      # Padded positions can never be predicted correctly, so leave them out.
      if ignore_index is None:
        counted = torch.ones_like(rewards, dtype=torch.bool)
      else:
        counted = rewards != ignore_index
      n_counted = int(counted.sum().item())
      n_correct = int(((preds == rewards) & counted).sum().item())

      losses.append(loss.item())
      accuracies.append(n_correct / max(n_counted, 1))

  if not losses:
    mean_loss = float('nan')
    mean_acc = float('nan')
  else:
    mean_loss = sum(losses) / len(losses)
    mean_acc = sum(accuracies) / len(accuracies)

  if verbose:
    print('mean_loss:', mean_loss)
    print('mean_acc:', mean_acc)

  return mean_loss, mean_acc

In [None]:
# Input parameters
observation_size = 164
action_size = 27

# Training hyperparameters
num_epochs = 20
learning_rate = 3e-3
batch_size = 32 # NOTE: has no effect here — train_loader/valid_loader were already built with their own batch size
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The loss weights must live on the same device as the logits passed to the criterion.
class_weights = torch.tensor([0.499, 0.02, 0.499], device=device) # TODO: second weight might be a typo in the paper, consider 0.002

# Model hyperparameters
max_len = 250
pad_val = 10 # value used to pad sequences to same size, will be ignored by the model

# Pass the padding value and maximum length explicitly so that the model, the
# loss (ignore_index) and the accuracy below all agree on them.
model = SelfAttentionForRL(
    observation_size, action_size, device, pad_val=pad_val, max_len=max_len, verbose=False
).to(device)

# Tensorboard to get nice loss plot
writer = SummaryWriter("runs/")
step = 0

optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=10, verbose=True
)

criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=pad_val)

for epoch in range(num_epochs):
  print(f"> Epoch {epoch+1}/{num_epochs}", end=' ')

  losses = []

  # Iterating the loader directly lets tqdm know the number of batches.
  for batch in tqdm(train_loader):
    # The loader yields batch-first tensors: (N, seq_length, ...).
    observations, actions, rewards = (tensor.to(device) for tensor in batch)

    # The model is sequence-first. Swap the axes with transpose — a plain reshape
    # keeps the memory order and would mix up different trajectories. The
    # trailing reshape only flattens any extra feature axes.
    observations = observations.transpose(0, 1)
    observations = observations.reshape(observations.shape[0], observations.shape[1], -1)
    actions = actions.transpose(0, 1)
    actions = actions.reshape(actions.shape[0], actions.shape[1], -1)

    output = model(observations, actions, rewards.transpose(0, 1))

    # K-dimensional CrossEntropy loss expects (N, C, seq_length) logits and
    # (N, seq_length) targets; the batch-first rewards already have that layout.
    output = output.permute(1, 2, 0)

    optimizer.zero_grad()

    loss = criterion(output, rewards)
    losses.append(loss.item())

    loss.backward()
    # Clip to avoid exploding gradient issues
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

    optimizer.step()

    # Validation is evaluated after every optimizer step so that all four
    # curves share the same global step in tensorboard.
    model.eval()
    val_loss, val_acc = evaluate_validation(model, criterion, valid_loader)
    model.train()

    # Training accuracy over the non-padded positions only.
    with torch.no_grad():
      preds = output.argmax(dim=1)
      counted = rewards != pad_val
      n_correct = ((preds == rewards) & counted).sum().item()
      acc = n_correct / max(counted.sum().item(), 1)

    # plot to tensorboard
    writer.add_scalar("Training loss", loss.item(), global_step=step)
    writer.add_scalar("Training acc", acc, global_step=step)
    writer.add_scalar("Validation loss", val_loss, global_step=step)
    writer.add_scalar("Validation acc", val_acc, global_step=step)

    print(f'loss: {loss.item():0.5f}, acc: {acc:0.5f}, val_loss: {val_loss:0.5f}, val_acc: {val_acc:0.5f}', end='\r')

    step += 1

  # The plateau scheduler is driven by the mean training loss of the epoch.
  mean_loss = sum(losses) / len(losses)
  scheduler.step(mean_loss)

# Make sure every logged scalar is on disk before the log directory is read.
writer.flush()

> Epoch 1/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.26195, acc: 0.21961, val_loss: 1.13706, val_acc: 0.17472
> Epoch 2/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.17254, acc: 0.21133, val_loss: 1.17707, val_acc: 0.23204
> Epoch 3/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.10315, acc: 0.17127, val_loss: 1.09291, val_acc: 0.17058
> Epoch 4/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.05790, acc: 0.16436, val_loss: 1.12168, val_acc: 0.18025
> Epoch 5/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.11556, acc: 0.15470, val_loss: 1.11113, val_acc: 0.16367
> Epoch 6/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.08748, acc: 0.16436, val_loss: 1.09959, val_acc: 0.15746
> Epoch 7/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.08178, acc: 0.17680, val_loss: 1.09223, val_acc: 0.16022
> Epoch 8/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.12505, acc: 0.16575, val_loss: 1.12000, val_acc: 0.18163
> Epoch 9/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.05356, acc: 0.16298, val_loss: 1.10466, val_acc: 0.17127
> Epoch 10/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.10333, acc: 0.17956, val_loss: 1.09529, val_acc: 0.15539
> Epoch 11/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.06138, acc: 0.15608, val_loss: 1.11185, val_acc: 0.17196
> Epoch 12/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.07279, acc: 0.17127, val_loss: 1.09520, val_acc: 0.15124
> Epoch 13/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.04626, acc: 0.14365, val_loss: 1.14068, val_acc: 0.15470
> Epoch 14/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.07505, acc: 0.20028, val_loss: 1.09009, val_acc: 0.16022
> Epoch 15/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.06757, acc: 0.16160, val_loss: 1.10063, val_acc: 0.14019
> Epoch 16/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.10385, acc: 0.15193, val_loss: 1.12440, val_acc: 0.16367
> Epoch 17/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.03952, acc: 0.19199, val_loss: 1.10601, val_acc: 0.15884
> Epoch 18/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.06097, acc: 0.15746, val_loss: 1.12051, val_acc: 0.14710
> Epoch 19/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.11652, acc: 0.22238, val_loss: 1.11321, val_acc: 0.17472
> Epoch 20/20 

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

loss: 1.06975, acc: 0.15608, val_loss: 1.09201, val_acc: 0.16436


In [None]:
# Upload the scalars logged under ./runs as a new tensorboard.dev experiment;
# --one_shot makes the command scan the log directory once and then exit.
# NOTE(review): confirm the tensorboard.dev hosting service is still available before relying on this cell.
!tensorboard dev upload --logdir ./runs \
  --name "Self-attention test run #2" \
  --description "Training results from https://colab.research.google.com/drive/1yivILB8utI4a0ChtTrdFN57H1VFHpO0y?usp=sharing" \
  --one_shot

2020-12-23 16:19:42.663064: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1

New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/Ub0kXMzvRjKvYQEaclV1gQ/

[1m[2020-12-23T16:19:44][0m Started scanning logdir.
[1m[2020-12-23T16:19:45][0m Total uploaded: 480 scalars, 0 tensors, 0 binary objects
[1m[2020-12-23T16:19:45][0m Done scanning logdir.


Done. View your TensorBoard at https://tensorboard.dev/experiment/Ub0kXMzvRjKvYQEaclV1gQ/
