# Behavioral Transformer Implementation Walkthrough
Adapted from the original BeT code (https://github.com/notmahi/bet) and paper (https://arxiv.org/pdf/2206.11251)

**Libraries Required**

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from gpt import GPT, GPTConfig
from sklearn.cluster import KMeans

# 1. K-Means Based (De/En)coder: $ a:= A_{\lfloor a \rfloor}$ + $ \langle a \rangle$
The first component of a Behavioral Transformer (BeT) is a K-means based encoder and decoder that is fitted before task training and testing.

<img width="600px" src="images/K-Mean.png">


**Encoder:** Given an action $a$, the encoder discretizes it into an action bin **$\lfloor a \rfloor$** and a residual action **$ \langle a \rangle $**, such that $ a := A_{\lfloor a \rfloor} + \langle a \rangle $.

**Decoder:** Given an action bin **$\lfloor a \rfloor$** and the residual action **$ \langle a \rangle $**, it would output **a** given by $ a := A_{\lfloor a \rfloor} + \langle a \rangle $.

In [274]:
class EncoderDecoder:
    """K-means based action (de/en)coder.

    An action ``a`` is represented as the index of its nearest K-means center
    (the action bin) plus the offset from that center (the action residual),
    so that ``a = center[bin] + residual``.
    """

    def __init__(self, action_dim, num_bins, actions):
        """Fit ``num_bins`` K-means clusters on ``actions``.

        Args:
            action_dim: Size of a single action vector.
            num_bins: Number of K-means clusters, i.e. action bins.
            actions: Collection of actions handed to ``KMeans.fit``
                (the notebook passes a ``(num_actions, action_dim)`` tensor).
        """
        self.action_dim = action_dim
        self.num_bins = num_bins
        # Note: Using sklearn KMeans iteration for simplicity
        self.kmeans = KMeans(n_clusters=self.num_bins)
        self.kmeans.fit(actions)

    def encode(self, action):
        """Split a batch of action sequences into bins and residuals.

        Args:
            action: Tensor of shape (batch_size, sequence_length, action_dim).

        Returns:
            Tuple ``(action_bin, action_residual)`` where ``action_bin`` is an
            int64 tensor of shape (batch_size, sequence_length, 1) and
            ``action_residual`` is a float32 tensor of shape
            (batch_size, sequence_length, action_dim).
        """
        batch_size, sequence_length, action_dim = action.shape

        # K-means works on a flat list of vectors, so fold batch and time together
        flat_action = action.reshape((batch_size * sequence_length, action_dim))

        # Hand sklearn a plain NumPy array; detaching lets tensors that track
        # gradients be encoded as well.
        flat_bins = torch.as_tensor(
            self.kmeans.predict(flat_action.detach().cpu().numpy()),
            dtype=torch.int64,
        )
        flat_center = self.get_center(flat_bins).to(flat_action.device)
        # Both operands are (N, action_dim) here, even when action_dim == 1
        flat_residual = flat_action - flat_center

        # Restoring the batch and sequence axes
        action_residual = flat_residual.reshape((batch_size, sequence_length, action_dim))
        action_bin = flat_bins.reshape((batch_size, sequence_length, 1)).to(action.device)

        # Returning the Action Bin and the Action Residual
        return action_bin, action_residual.to(torch.float)

    def decode(self, action_bin, action_residual):
        """Rebuild actions from bins and residuals: ``center[bin] + residual``.

        Args:
            action_bin: Integer bin indices, e.g. shape (..., 1) as returned by
                ``encode`` or shape (...,) without the trailing axis.
            action_residual: Tensor of shape (..., action_dim).

        Returns:
            Tensor of reconstructed actions with the shape of ``action_residual``.
        """
        action_center = self.get_center(action_bin)
        # Align the looked-up centers with the residual layout when they hold the
        # same number of elements (e.g. a sequence axis of length one).
        if (
            action_center.shape != action_residual.shape
            and action_center.numel() == action_residual.numel()
        ):
            action_center = action_center.reshape(action_residual.shape)
        action_center = action_center.to(action_residual.device)
        # Returning the associated action
        return action_residual + action_center

    def get_center(self, action_bin):
        """Look up the K-means center of every bin index.

        Only a trailing singleton bin axis (as produced by ``encode``) is
        dropped. A blanket ``squeeze()`` would also collapse a size-one batch
        axis or a one-dimensional action axis and break broadcasting.

        Args:
            action_bin: Integer bin indices (tensor, array or int).

        Returns:
            Tensor of centers with shape ``bins.shape + (action_dim,)``, where
            ``bins`` is ``action_bin`` without its trailing singleton axis.
        """
        centers = torch.as_tensor(self.kmeans.cluster_centers_)
        bins = torch.as_tensor(action_bin, dtype=torch.long).cpu()
        if bins.dim() > 1 and bins.shape[-1] == 1:
            bins = bins.squeeze(-1)
        return centers[bins]


**Testing the Encoder/Decoder Component**

In [275]:
# Smoke test: fit the (de/en)coder on random actions, then encode and decode a
# random batch and check that the round trip reproduces the input.
num_bins = 2
batch_size = 2
total_action_batch_size = 5
sequence_length = 10
action_dim = 3
total_actions = 100

actions = torch.rand((total_actions,action_dim))

encoder_decoder = EncoderDecoder(action_dim,num_bins,actions)

test_input = torch.rand((batch_size,sequence_length,action_dim))
# Expected: one bin index per step and a residual shaped like the input
action_bin, action_residual = encoder_decoder.encode(test_input)
print("Encoded Action Bin Shape:\t", action_bin.shape)
print("Encoded Action Residual Shape:\t", action_residual.shape)
decoded = encoder_decoder.decode(action_bin=action_bin,action_residual=action_residual)
# center + (action - center) must give the action back, up to float32 rounding
assert torch.allclose(decoded.to(test_input.dtype), test_input, atol=1e-6), "decode(encode(a)) != a"
print("Original:\t", test_input.numpy().tolist())
print("Decoded Action:\t", decoded.numpy().tolist())

Encoded Action Bin Shape:	 torch.Size([2, 10, 1])
Encoded Action Residual Shape:	 torch.Size([2, 10, 3])
Original:	 [[[0.24479037523269653, 0.6477630734443665, 0.2133672833442688], [0.18716996908187866, 0.28942757844924927, 0.03588855266571045], [0.15689361095428467, 0.668989896774292, 0.7439367175102234], [0.4677286744117737, 0.9012507200241089, 0.17209213972091675], [0.7616887092590332, 0.636254608631134, 0.5112188458442688], [0.27330106496810913, 0.29068005084991455, 0.1763436198234558], [0.001971602439880371, 0.0765460729598999, 0.9579184055328369], [0.3830993175506592, 0.2187047004699707, 0.33932608366012573], [0.5947545766830444, 0.18389487266540527, 0.8536553978919983], [0.30301356315612793, 0.06573772430419922, 0.2929864525794983]], [[0.20370203256607056, 0.05925190448760986, 0.619742214679718], [0.9325045347213745, 0.657662034034729, 0.5698668956756592], [0.7619799971580505, 0.9337874054908752, 0.590020477771759], [0.9497239589691162, 0.6005738973617554, 0.8144691586494446], [

# 2. nanoGPT for initial Action Bins $\lfloor a \rfloor$
The second component of BeT is a transformer that takes the history of the last $L$ observations $[o_{i-L}, ... o_{i-1}, o_i]$ as an input, and outputs a predicted probability sequence of the fixed possible action bins of the next predicted action to get $\lfloor a_{i+1} \rfloor$.

<img width="700px" src="images/gpt.png">

In [276]:
def make_gpt(observation_dim, num_bins, sequence_length, n_layer=6, n_head=8, n_embd=256):
    """Build the GPT that maps an observation history to action-bin logits.

    Args:
        observation_dim: Passed to ``GPTConfig`` as ``input_dim``.
        num_bins: Passed to ``GPTConfig`` as ``output_dim``.
        sequence_length: Passed to ``GPTConfig`` as ``block_size``.
        n_layer: Passed to ``GPTConfig`` as ``n_layer`` (default 6).
        n_head: Passed to ``GPTConfig`` as ``n_head`` (default 8).
        n_embd: Passed to ``GPTConfig`` as ``n_embd`` (default 256).

    Returns:
        A ``GPT`` instance constructed from that configuration.
    """
    # NOTE(review): the exact meaning of each field is defined in the local
    # `gpt` module, which is not part of this notebook. Confirm there.
    config = GPTConfig(
        block_size=sequence_length,
        input_dim=observation_dim,
        output_dim=num_bins,
        n_layer=n_layer,
        n_head=n_head,
        n_embd=n_embd,
    )
    return GPT(config)

**Testing the GPT**

In [277]:
# Smoke test for the GPT: a random observation history goes in,
# one vector of action-bin logits per time step comes out.
batch_size = 1
observation_dim = 1
num_bins = 10
sequence_length = 5

gpt = make_gpt(
    observation_dim=observation_dim,
    num_bins=num_bins,
    sequence_length=sequence_length,
)

observations_history = torch.rand(batch_size, sequence_length, observation_dim)
logits = gpt(observations_history)

print("Observation Shape:\t\t", observations_history.shape)
print("Action Bins Logits Shape:\t", logits.shape)


number of parameters: 4.74M
Observation Shape:		 torch.Size([1, 5, 1])
Action Bins Logits Shape:	 torch.Size([1, 5, 10])


# 3. Transformer Head for final Action Residuals $ \langle a \rangle $ and Bins $\lfloor a \rfloor$
The last component of BeT takes the action bin $\lfloor a \rfloor$ logits output by the transformer and embeds them into a representation that includes the final action bin $\lfloor a \rfloor$ and the action residual $ \langle a \rangle $.

In [278]:
class Head(nn.Module):
    """Linear head producing action-bin logits and one residual per bin.

    A single linear layer maps the transformer output to
    ``num_bins * (action_dim + 1)`` values: ``num_bins`` bin logits followed by
    ``num_bins`` residual vectors of size ``action_dim``. A bin is sampled from
    the softmax over the bin logits and the residual of that bin is returned.
    """

    def __init__(self, num_bins, action_dim, drop_out=0.2):
        """
        Args:
            num_bins: Number of action bins; also the input feature size.
            action_dim: Size of one action (and of one residual) vector.
            drop_out: Dropout probability applied to the linear layer output.
        """
        super(Head, self).__init__()

        self.num_bins = num_bins
        self.action_dim = action_dim

        # One logit per bin plus one action_dim-sized residual per bin
        self.layer = nn.Linear(num_bins, num_bins * (action_dim + 1))
        self.dropout = nn.Dropout(drop_out)

    def forward(self, transformer_logits):
        """Sample an action bin per step and pick the matching residual.

        All sizes come from the module's own configuration and from the input
        tensor, so the result does not depend on notebook-level globals.

        Args:
            transformer_logits: Tensor of shape (..., num_bins), typically
                (Batch_Size, Sequence_Length, Number_Action_Bins).

        Returns:
            Dict with
            ``seq_action_bins``: sampled bin indices, shape (..., 1);
            ``seq_action_residuals``: residual of the sampled bin, shape (..., action_dim);
            ``seq_action_bins_logits``: bin logits, shape (..., num_bins).
        """
        num_bins, action_dim = self.num_bins, self.action_dim

        # Action Data: (..., Number_Action_Bins*(Action_dim+1))
        action_data = self.dropout(self.layer(transformer_logits))

        # Sequence Action Bins Logits:   (..., Number_Action_Bins)
        # All Sequence Action Residuals: (..., Number_Action_Bins*Action_dim)
        seq_action_bins_logits, all_seq_action_residuals = torch.split(
            action_data, [num_bins, num_bins * action_dim], dim=-1
        )
        leading_shape = seq_action_bins_logits.shape[:-1]

        # Softmaxing the bin logits to get the probability per bin at every index,
        # then drawing one bin per (flattened) position
        seq_action_bins_probs = torch.softmax(seq_action_bins_logits, dim=-1)
        flat_bins = torch.multinomial(seq_action_bins_probs.reshape(-1, num_bins), num_samples=1)

        # Keeping the action residuals for the selected action bin
        flat_all_residuals = all_seq_action_residuals.reshape(-1, num_bins, action_dim)
        row_index = torch.arange(flat_all_residuals.shape[0], device=flat_all_residuals.device)
        flat_residuals = flat_all_residuals[row_index, flat_bins.flatten()]

        # Sequence Action Bins:      (..., 1)
        # Sequence Action Residuals: (..., Action_dim)
        seq_action_bins = flat_bins.reshape(*leading_shape, 1)
        seq_action_residuals = flat_residuals.reshape(*leading_shape, action_dim)

        return {
            "seq_action_bins": seq_action_bins,
            "seq_action_residuals": seq_action_residuals,
            "seq_action_bins_logits": seq_action_bins_logits,
        }

**Testing the Transformer Head**

In [279]:
# Smoke test for the head: random "transformer logits" in,
# sampled bins, their residuals and the bin logits out.
batch_size = 5
num_bins = 2
sequence_length = 10
action_dim = 3
logits = torch.rand(batch_size, sequence_length, num_bins)

head = Head(num_bins=num_bins, action_dim=action_dim)
output = head(logits)
seq_action_bins = output["seq_action_bins"]
seq_action_residuals = output["seq_action_residuals"]
seq_action_bins_logits = output["seq_action_bins_logits"]

print("Action Residuals Shape:\t\t", seq_action_residuals.shape)
print("Actions Bins Shape:\t\t", seq_action_bins.shape)
print("Actions Bins Logits Shape:\t", seq_action_bins_logits.shape)

Action Residuals Shape:		 torch.Size([5, 10, 3])
Actions Bins Shape:		 torch.Size([5, 10, 1])
Actions Bins Logits Shape:	 torch.Size([5, 10, 2])


# 4. Focal Loss for Action Bins $\lfloor a \rfloor$

$L_{focal}(p_t) = -(1 - p_t)^\gamma \log(p_t)$

Focal loss handles the class imbalance among action bins by assigning more weight to hard-to-classify examples, in order to improve performance when predicting low-probability classes (important for multi-modal behavior distributions).

In [280]:
# Source: https://github.com/notmahi/miniBET/blob/main/behavior_transformer/bet.py

class FocalLoss(nn.Module):
    """Focal loss ``-(1 - p_t) ** gamma * log(p_t)`` over class logits.

    ``p_t`` is the softmax probability assigned to the target class. With
    ``gamma == 0`` the per-element loss is the plain negative log-likelihood.
    """

    def __init__(self, gamma: float = 0, size_average: bool = True):
        """
        Args:
            gamma: Focusing exponent applied to ``(1 - p_t)``.
            size_average: If True return the mean over all elements,
                otherwise return their sum.
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: Logits of shape (..., num_classes). Leading dimensions are
                flattened, so both (N, C) and (B, S, C) are accepted.
            target: Integer class indices with one entry per logit row, e.g.
                shape (N,), (N, 1), (B, S) or (B, S, 1).

        Returns:
            Scalar tensor: mean (``size_average=True``) or sum of the losses.

        Raises:
            ValueError: If ``target`` does not hold exactly one index per row
                of ``input``.
        """
        flat_logits = input.reshape(-1, input.shape[-1])
        flat_target = target.reshape(-1, 1)
        if flat_target.shape[0] != flat_logits.shape[0]:
            raise ValueError(
                f"target has {flat_target.shape[0]} entries but input has "
                f"{flat_logits.shape[0]} rows"
            )

        # Log-probability of the target class for every row
        logpt = F.log_softmax(flat_logits, dim=-1).gather(1, flat_target).view(-1)
        pt = logpt.exp()

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()

# 5. BeT: Stitching all the Components Together

In [281]:
class BeT(nn.Module):
    """Behavior Transformer: K-means (de/en)coder + GPT + bin/residual head.

    The GPT maps an observation history to per-step logits, the head turns
    them into a sampled action bin and a residual, and the K-means decoder
    converts the last step's bin and residual into a continuous action.
    """

    def __init__(self, observation_dim, action_dim, num_bins, sequence_length, actions, gamma=2.0):
        """
        Args:
            observation_dim: Size of one observation vector.
            action_dim: Size of one action vector.
            num_bins: Number of K-means action bins.
            sequence_length: Length of the observation history fed to the GPT.
            actions: Action collection used to fit the K-means (de/en)coder.
            gamma: Focusing exponent of the focal loss on the action bins.
        """
        super(BeT, self).__init__()

        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.num_bins = num_bins
        self.sequence_length = sequence_length

        # Initializing the Encoder Decoder which is based on the K-Means
        self.encoderDecoder = EncoderDecoder(action_dim, num_bins, actions)

        # Initializing the GPT model to be used for sequence to sequence modeling
        self.gpt = make_gpt(observation_dim, num_bins, sequence_length)

        # Initializing the Head that takes the sequence output of the GPT to return the action bin and action residual
        self.head = Head(num_bins, action_dim)

        # Residual Loss Function
        self.residual_criterion = nn.MSELoss()

        # Action Bins Loss Function
        self.bin_criterion = FocalLoss(gamma)

    def forward(self, observations_history, train_data=False):
        """Predict the next action, or return the tensors needed for training.

        Args:
            observations_history: Tensor of shape
                (Batch_Size, Sequence_Length, Observation_dim).
            train_data: If False (default) return the decoded action of the
                last sequence step. If True return the per-step bin logits and
                residuals used by ``learn``.

        Returns:
            ``train_data=False``: tensor of shape (Batch_Size, Action_dim).
            ``train_data=True``: dict with ``seq_action_bins_logits`` of shape
            (Batch_Size, Sequence_Length, Number_Action_Bins) and
            ``predicted_seq_action_residuals`` of shape
            (Batch_Size, Sequence_Length, Action_dim).
        """
        # (Batch Size, Sequence Length, Number of Action Bins)
        gpt_logits = self.gpt(observations_history)

        head_output = self.head(gpt_logits)
        # Predicted Sequence Action Bins: (Batch_Size, Sequence_Length, 1)
        predicted_seq_action_bins = head_output["seq_action_bins"]
        # Predicted Sequence Action Residuals: (Batch_Size, Sequence_Length, Action_dim)
        predicted_seq_action_residuals = head_output["seq_action_residuals"]
        # Predicted Sequence Action Bins Logits: (Batch_Size, Sequence_Length, Number_Action_Bins)
        seq_action_bins_logits = head_output["seq_action_bins_logits"]

        # Training: only the raw head outputs are needed, so skip decoding
        if train_data:
            return {
                "seq_action_bins_logits": seq_action_bins_logits,
                "predicted_seq_action_residuals": predicted_seq_action_residuals,
            }

        # Inference: decode the last step only.
        # Predicted Action: (Batch_Size, Action_dim)
        return self.encoderDecoder.decode(
            predicted_seq_action_bins[:, -1, 0],
            predicted_seq_action_residuals[:, -1, :],
        )

    def learn(self, observations_history, actions_history, optimizer, residual_loss_scale=1e3):
        """Run one optimization step on a batch of demonstrations.

        Args:
            observations_history: Tensor of shape
                (Batch_Size, Sequence_Length, Observation_dim).
            actions_history: Tensor of shape
                (Batch_Size, Sequence_Length, Action_dim) with target actions.
            optimizer: Optimizer whose ``zero_grad``/``step`` are called.
            residual_loss_scale: Weight of the residual loss in the total loss.

        Returns:
            Dict with the tensors ``total_loss``, ``action_bins_loss`` and
            ``action_residual_loss``.
        """
        # Targets must come from this model's own K-means fit, not from any
        # notebook-level encoder that happens to be defined.
        # Target Sequence Action Bins: (Batch_Size, Sequence_Length, 1)
        # Target Sequence Action Residuals: (Batch_Size, Sequence_Length, Action_Dims)
        target_seq_action_bin, target_seq_action_residuals = self.encoderDecoder.encode(actions_history)

        training_data = self(observations_history, train_data=True)
        # Predicted Sequence Action Residuals: (Batch_Size, Sequence_Length, Action_dim)
        predicted_seq_action_residuals = training_data["predicted_seq_action_residuals"]
        # Predicted Sequence Action Bins Logits: (Batch_Size, Sequence_Length, Number_Action_Bins)
        seq_action_bins_logits = training_data["seq_action_bins_logits"]

        # Residual Loss
        # NOTE(review): the predicted residual belongs to the bin sampled by the
        # head, while the target residual is relative to the K-means bin of the
        # demonstrated action. Confirm that comparing them is intended.
        action_residual_loss = self.residual_criterion(predicted_seq_action_residuals, target_seq_action_residuals)

        # Actions Bin Loss (flattened to one row per batch/time position)
        seq_action_bins_logits = seq_action_bins_logits.reshape((-1, seq_action_bins_logits.shape[-1]))
        target_seq_action_bin = target_seq_action_bin.reshape((-1, 1))
        action_bins_loss = self.bin_criterion(seq_action_bins_logits, target_seq_action_bin)

        # Total Loss = Actions Bin Loss + Residual Loss * Loss Scale
        total_loss = action_bins_loss + action_residual_loss * residual_loss_scale

        # Training Step
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        return {"total_loss": total_loss, "action_bins_loss": action_bins_loss, "action_residual_loss": action_residual_loss}

    # Source: https://github.com/notmahi/miniBET/blob/main/behavior_transformer/bet.py
    def create_optimizer(self, weight_decay, learning_rate, betas):
        """Create the optimizer for the GPT and add the head's parameters to it.

        The arguments are forwarded unchanged to ``self.gpt.configure_optimizers``.
        The head's parameters are then added as an extra parameter group
        without group-specific options.

        Returns:
            The optimizer returned by ``configure_optimizers``, extended with
            the head parameters.
        """
        optimizer = self.gpt.configure_optimizers(
            weight_decay=weight_decay,
            learning_rate=learning_rate,
            betas=betas,
        )
        optimizer.add_param_group({"params": self.head.parameters()})
        return optimizer

**Testing the BeT Combined Component**

In [282]:
# Smoke test for the full model: every size the cell uses is defined here, so
# it does not depend on values left behind by the earlier component tests.
observation_dim = 1
num_bins = 2
batch_size = 2
total_action_batch_size = 5
sequence_length = 10
action_dim = 3
total_actions = 100

actions_collection = torch.rand((total_actions,action_dim))
bet = BeT(observation_dim, action_dim, num_bins, sequence_length, actions_collection)

observations_history = torch.rand((batch_size,sequence_length,observation_dim))
predicted_action = bet(observations_history)
# Inference returns one action per batch element
assert tuple(predicted_action.shape) == (batch_size, action_dim), predicted_action.shape
print("Action Shape: ", predicted_action.shape)

number of parameters: 4.74M
Action Shape:  torch.Size([2, 3])


**Training Test**

In [284]:
# One optimization step on random data; the cell displays the returned losses.
bet_optimizer = bet.create_optimizer(
    weight_decay=0.0002,
    learning_rate=0.00001,
    betas=[0.9,0.999],
)

observations_history = torch.rand(batch_size, sequence_length, observation_dim)
actions_history = torch.rand(batch_size, sequence_length, action_dim)
bet.learn(
    observations_history=observations_history,
    actions_history=actions_history,
    optimizer=bet_optimizer,
)

{'total_loss': tensor(295.8558, grad_fn=<AddBackward0>),
 'action_bins_loss': tensor(0.2179, grad_fn=<MeanBackward0>),
 'action_residual_loss': tensor(0.2956, grad_fn=<MseLossBackward0>)}