Load the RLBench behavior-cloning (BC) data from each task's saved HDF5 files. The task files were created with the `instructRL/data/collect_data.py` script and contain an array of data samples. Each sample is a dict with the following keys:

* image - data for 4 camera positions: front_rgb, left_shoulder_rgb, right_shoulder_rgb, wrist_rgb
* instruct - the text instruction
* action - the action vector `[p; q; g]`, where for RLBench the gripper state `g` is a single scalar: 1 = open, 0 = closed

In [None]:
from typing import Iterable

import torch

from getdata import RLBenchDataset


# Read in the RLBench data. RLBenchDataset subclasses the PyTorch Dataset class,
# so each split can be wrapped directly in a torch DataLoader.
batch_size = 2


def _make_split(split: str, shuffle: bool):
    """Build the ``reach_target`` dataset for ``split`` and a DataLoader over it.

    Returns:
        ``(dataset, loader)``; the loader uses the module-level ``batch_size``
        and loads in the main process (``num_workers=0``).
    """
    dataset = RLBenchDataset(
        update=None,
        dataset_name="reach_target",
        start_offset_ratio=None,
        split=split,
    )
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=0,
    )
    return dataset, loader


# Training batches are shuffled; validation batches keep a fixed order.
train_dataset, train_loader = _make_split("train", shuffle=True)
val_dataset, val_loader = _make_split("val", shuffle=False)

In [None]:
# Peek at the first batch of each loader to report tensor shapes, then the batch count.
for label, loader in (("Training", train_loader), ("Val", val_loader)):
    print(f"{label} dataset batch info:")
    for data, target in loader:
        print("data:", data['action'].shape)
        print("target:", target.shape)
        break  # one batch is enough to show the shapes
    print("num batches: " + str(len(loader)))


In [None]:
from torch import nn, Tensor
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to an embedding sequence, then apply dropout.

    Args:
        d_model: embedding dimension (odd values are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length the precomputed table supports.
        batch_first: if True, ``forward`` expects ``[batch_size, seq_len, d_model]``;
            the default (False) keeps the ``[seq_len, batch_size, d_model]`` layout.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000,
                 batch_first: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``, or
               ``[batch_size, seq_len, embedding_dim]`` when ``batch_first`` is True.

        Raises:
            ValueError: if the sequence is longer than ``max_len``. Without this
                check, slicing the table would return too few rows and either
                broadcast silently or fail with an obscure shape error.
        """
        seq_dim = 1 if self.batch_first else 0
        seq_len = x.size(seq_dim)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}")
        enc = self.pe[:seq_len]  # [seq_len, 1, d_model]
        if self.batch_first:
            enc = enc.transpose(0, 1)  # [1, seq_len, d_model]
        return self.dropout(x + enc)

In [None]:

from torch.nn import TransformerDecoder, TransformerDecoderLayer

class ActionDecoderModel(nn.Module):
    """Transformer decoder mapping a short action history plus a memory sequence to one action.

    ``actions`` are projected to ``d_model``, given sinusoidal positional encodings,
    decoded against ``memory`` with causal target and memory masks, then flattened
    and projected back to ``action_dim``.

    The decoder layers are ``batch_first``, so ``actions`` is expected as
    ``[batch, action_seq_len, action_dim]``. ``memory`` is presumably
    ``[batch, mem_seq_len, d_model]`` — TODO confirm against the data loader.
    """

    def __init__(self, action_dim: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5, action_seq_len: int = 4,
                 mem_seq_len: int = 10):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=action_seq_len)
        decoder_layer = TransformerDecoderLayer(d_model, nhead, d_hid, dropout, batch_first=True)
        self.transformer_decoder = TransformerDecoder(decoder_layer, nlayers)
        self.d_model = d_model
        self.linear_action_in = nn.Linear(action_dim, d_model)
        # The whole decoded sequence is flattened into one vector before projecting out.
        self.linear_action_out = nn.Linear(action_seq_len * d_model, action_dim)
        # Masks are registered as non-persistent buffers so model.to(device)/.cuda()
        # moves them with the parameters, without adding entries to the state_dict.
        self.register_buffer(
            "tgt_mask",
            nn.Transformer.generate_square_subsequent_mask(action_seq_len),
            persistent=False,
        )
        # memory_mask must be shaped (tgt_len, src_len) = (action_seq_len, mem_seq_len).
        # Query position i may attend memory positions 0..i (-inf above the diagonal),
        # which is the causal pattern that the memory_is_causal hint declares.
        mem_mask = torch.full((action_seq_len, mem_seq_len), float("-inf")).triu(diagonal=1)
        self.register_buffer("mem_mask", mem_mask, persistent=False)

    def forward(self, actions: Tensor, memory: Tensor) -> Tensor:
        """Predict one action vector per batch element.

        Arguments:
            actions: ``[batch, action_seq_len, action_dim]``.
            memory: encoder output attended to by the decoder (batch-first).

        Returns:
            Tensor of shape ``[batch, action_dim]``.
        """
        x = self.linear_action_in(actions)
        # PositionalEncoding treats dim 0 as the sequence axis, but x is batch-first;
        # swap to [seq, batch, d_model] for the encoding and back afterwards.
        x = self.pos_encoder(x.transpose(0, 1)).transpose(0, 1)
        output = self.transformer_decoder(
            tgt=x,
            memory=memory,
            tgt_is_causal=True,
            memory_is_causal=True,
            tgt_mask=self.tgt_mask,
            memory_mask=self.mem_mask,
        )
        output = output.reshape(output.shape[0], -1)
        return self.linear_action_out(output)

In [None]:

seq_len = 4  # memory sequence length; the decoder sees seq_len - 1 actions

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
action_dim = 8  # length of the action vector
mm_dim = 768  # embedding dimension
d_hid = 768  # dimension of the feedforward network model in ``nn.TransformerDecoder``
nlayers = 2  # number of ``nn.TransformerDecoderLayer`` in ``nn.TransformerDecoder``
nhead = 2  # number of heads in ``nn.MultiheadAttention``
dropout = 0.2  # dropout probability
model = ActionDecoderModel(
    action_dim=action_dim,
    d_model=mm_dim,
    nhead=nhead,
    d_hid=d_hid,
    nlayers=nlayers,
    dropout=dropout,
    action_seq_len=seq_len - 1,
    mem_seq_len=seq_len,
).to(device)  # use the detected device instead of forcing CUDA

In [None]:
# print whether model is on GPU or CPU
print(f"Model is on {next(model.parameters()).device}")

# Display the model architecture and number of trainable parameters
print(model)
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_trainable_params:,}")

In [None]:
import time
import numpy as np

# Mean-squared error between predicted and target action vectors.
loss_fn = nn.MSELoss(
    reduction='mean'
)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    betas=(0.9, 0.999),
)
# Multiplies the learning rate by gamma on every scheduler.step().
# `verbose` is deprecated in recent PyTorch releases; the current LR is already
# logged through scheduler.get_last_lr() in the training loop.
scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=0.9,
)

# Per-batch training loss: one row per epoch, one column per training batch.
# NOTE: the row count must be at least the number of epochs that are run.
train_loss_buf = np.zeros((5, len(train_loader)))

def train(model: nn.Module) -> None:
    """Train ``model`` for one epoch over ``train_loader``.

    This notebook helper relies on module-level globals: ``train_loader``,
    ``optimizer``, ``loss_fn``, ``scheduler`` (only to log the current learning
    rate), ``train_loss_buf`` and the loop variable ``epoch``, which is 1-based
    (``for epoch in range(1, epochs + 1)``).

    Each batch's loss is stored in ``train_loss_buf[epoch - 1, i]``. A progress
    line is printed every ``log_interval`` batches.
    """
    model.train()
    # Send each batch to wherever the model lives instead of hard-coding CUDA.
    device = next(model.parameters()).device
    log_interval = 1  # batches between progress lines
    log_interval_loss = 0.
    start_time = time.time()

    for i, (batch, targets) in enumerate(train_loader):
        encoder_embeddings = batch['mm_embeddings'].to(device)
        action_inputs = batch['action'].to(device)
        # NOTE(review): squeeze() also drops the batch dim when a batch has size 1.
        targets = torch.squeeze(targets).to(device)

        optimizer.zero_grad()
        output = model(actions=action_inputs, memory=encoder_embeddings)
        batch_train_loss = loss_fn(output, targets)
        batch_train_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        batch_loss = batch_train_loss.item()
        log_interval_loss += batch_loss
        # `epoch` is 1-based but buffer rows are 0-based (row 0 = epoch 1);
        # indexing with `epoch` directly overflowed the buffer in the last epoch.
        train_loss_buf[epoch - 1, i] = batch_loss

        if (i + 1) % log_interval == 0:
            cur_loss = log_interval_loss / log_interval
            elapsed = time.time() - start_time
            print(f'| epoch: {epoch:3d} | {i+1:5d}/{len(train_loader):5d} batches | '
                  f'lr: {scheduler.get_last_lr()[0]:02.3f} | ms/batch: {elapsed * 1000 / log_interval:5.2f} | '
                  f'log int loss: {cur_loss:5.2f} | ')
            log_interval_loss = 0.
            start_time = time.time()
        
        

In [None]:


def evaluate(model: nn.Module, val_loader: Iterable, criterion=None) -> float:
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    Args:
        model: decoder called as ``model(actions=..., memory=...)``.
        val_loader: iterable of ``(batch, targets)`` pairs, where ``batch`` is a
            dict holding ``'mm_embeddings'`` and ``'action'`` tensors.
        criterion: loss function called as ``criterion(output, targets)``;
            defaults to the module-level ``loss_fn``.

    Returns:
        The average of the per-batch losses, or ``nan`` if ``val_loader``
        yields no batches.
    """
    if criterion is None:
        criterion = loss_fn
    model.eval()  # turn on evaluation mode
    # Send batches to the model's device; fall back to CPU for parameterless models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_val_loss = 0.
    num_batches = 0
    with torch.no_grad():
        for batch, targets in val_loader:
            encoder_embeddings = batch['mm_embeddings'].to(device)
            action_inputs = batch['action'].to(device)
            targets = torch.squeeze(targets).to(device)
            output = model(actions=action_inputs, memory=encoder_embeddings)
            total_val_loss += criterion(output, targets).item()
            num_batches += 1

    if num_batches == 0:
        return float('nan')
    return total_val_loss / num_batches

In [None]:
epochs = 5  # The number of epochs

# train() records per-batch losses in train_loss_buf, one row per epoch,
# so the buffer must have at least `epochs` rows.
if epochs > train_loss_buf.shape[0]:
    raise ValueError(
        f"epochs={epochs} exceeds the {train_loss_buf.shape[0]} rows of train_loss_buf")

# `epoch` is a module-level name that train() reads for logging and indexing.
for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)
    # Validation is currently disabled; to enable it: val_loss = evaluate(model, val_loader)
    scheduler.step()
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch: {epoch:3d} | epoch time: {elapsed:5.2f}s | ')
    print('-' * 89)
        

In [None]:
#    print(f'| end of epoch: {epoch:3d} | epoc time: {elapsed:5.2f}s | '
     #         f'valid loss: {val_loss:5.2f} | valid ppl: {np.exp(np.min([val_loss,10])):8.2f}')

In [None]:
from datetime import date
import matplotlib.pyplot as plt

# Plot the log of the per-batch training loss, one line per epoch (row of train_loss_buf).
plt.plot(np.log(np.transpose(train_loss_buf)))
plt.xlabel('Batch')
plt.ylabel('Log Training loss')

# Legend entries follow the buffer's row count instead of a hard-coded list.
plt.legend([f'Epoch {n}' for n in range(1, train_loss_buf.shape[0] + 1)])

# Save BEFORE plt.show(): show() can close the figure (e.g. inline or blocking
# backends), after which plt.gcf() returns a new, empty figure and the saved
# image is blank. The file name includes the current date.
today = date.today()
fig1 = plt.gcf()
fig1.savefig(f'/home/levi/transformer_training_loss_{today}.jpg')
plt.show()
