```python
# IPython magics: reload project modules automatically when their source changes
%load_ext autoreload
%autoreload 2
```

# Small-LLM (Locomotion Language Model)

### Q: Can text-based language models understand and reason about physics and locomotion?

Animals have an implicit understanding of physics and locomotion.


<div style="display: flex; justify-content: center; align-items: center;">
  <div style="text-align: center;">
    <img src="media/leg-giraffe.gif" alt="Giraffe walking" style="max-width: 45%; height: auto;">
    <p>Giraffe walking</p>
  </div>
  <div style="text-align: center;">
    <img src="media/half_cheetah.png" alt="Cheetah" style="max-width: 45%; height: auto;">
    <p>RL Gym Cheetah</p>
  </div>
</div>

## (1) RAG In-Context Learning + Closed-Loop Control

System prompt: "You are an expert Mujoco Half Cheetah V0 environment controller."

Dynamic prompt, filled in at every time step:

```
"Time step {t}. HalfCheetah-v0 state vector has dimension 17. 
Current state: {state_list}.
Here are similar states to the current state and their corresponding actions to take you should use as a reference:
Similar state: {near_state} : action {near_action}...

Respond in strict JSON: {\"action\": [f1, f2, f3, f4, f5, f6]}. No extra text."
```
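The retrieval step can be sketched as a k-nearest-neighbour lookup over a dataset of expert (state, action) pairs, whose results are then formatted into the dynamic prompt. This is a minimal sketch, not the `gpt_wrapper.rag` implementation: Euclidean distance, `k=3`, the hypothetical helpers `retrieve_neighbors` and `build_prompt`, and the random stand-in dataset are all assumptions for illustration.

```python
import numpy as np


def retrieve_neighbors(state, dataset_states, dataset_actions, k=3):
    """Return the k dataset (state, action) pairs closest to `state` (Euclidean distance)."""
    dists = np.linalg.norm(dataset_states - state, axis=1)
    idx = np.argsort(dists)[:k]
    return dataset_states[idx], dataset_actions[idx]


def build_prompt(t, state, near_states, near_actions):
    """Fill the dynamic prompt template with the current state and retrieved examples."""

    def fmt(v):
        return np.round(v, 3).tolist()  # short, JSON-like number lists

    lines = [
        f"Time step {t}. HalfCheetah-v0 state vector has dimension {len(state)}.",
        f"Current state: {fmt(state)}.",
        "Here are similar states to the current state and their corresponding actions "
        "to take you should use as a reference:",
    ]
    for s, a in zip(near_states, near_actions):
        lines.append(f"Similar state: {fmt(s)} : action {fmt(a)}")
    lines.append('Respond in strict JSON: {"action": [f1, f2, f3, f4, f5, f6]}. No extra text.')
    return "\n".join(lines)


# Hypothetical stand-in "expert" dataset: 17-dim states, 6-dim actions in [-1, 1]
rng = np.random.default_rng(0)
dataset_states = rng.normal(size=(1000, 17))
dataset_actions = rng.uniform(-1.0, 1.0, size=(1000, 6))

state = dataset_states[42] + 0.01  # a query state very close to dataset entry 42
near_s, near_a = retrieve_neighbors(state, dataset_states, dataset_actions, k=3)
prompt = build_prompt(t=0, state=state, near_states=near_s, near_actions=near_a)
print(prompt)
```

A real system would likely normalise state dimensions before measuring distance, since HalfCheetah velocities (see the step-40 state in the output) span a much wider range than joint angles.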


<div style="text-align: center;">
  <img src="media/closed_loop.png" alt="Closed Loop">
</div>
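The closed loop in the figure can be sketched as: observe the state, query the LLM, parse its JSON reply, and step the environment with the resulting action, so each new state drives the next query. In this sketch both the simulator and the LLM are stand-ins (the hypothetical `ToyEnv` and `fake_llm`), and the fallback to a zero action on malformed replies plus clipping to [-1, 1] (the HalfCheetah action bounds) are our assumptions, not necessarily what `rag_with_gpt` does.

```python
import json

import numpy as np

ACT_DIM = 6  # HalfCheetah action dimension; valid actions lie in [-1, 1]


def parse_action(reply: str) -> np.ndarray:
    """Parse {"action": [...]} from an LLM reply; return a zero action if it is malformed."""
    try:
        action = np.asarray(json.loads(reply)["action"], dtype=float)
        if action.shape != (ACT_DIM,) or not np.all(np.isfinite(action)):
            raise ValueError("bad action")
    except (KeyError, TypeError, ValueError):  # json.JSONDecodeError subclasses ValueError
        return np.zeros(ACT_DIM)
    return np.clip(action, -1.0, 1.0)  # enforce the action bounds


class ToyEnv:
    """Hypothetical stand-in for the simulator: 17-dim state, 6-dim action."""

    def __init__(self, seed: int = 0):
        self.rng = np.random.default_rng(seed)
        self.state = self.rng.normal(size=17)

    def step(self, action: np.ndarray) -> np.ndarray:
        # Toy dynamics: random drift plus a small push from the action
        self.state = self.state + 0.1 * self.rng.normal(size=17)
        self.state[:ACT_DIM] += 0.05 * action
        return self.state.copy()


def fake_llm(prompt: str) -> str:
    """Hypothetical LLM stand-in that always answers with the same out-of-range action."""
    return json.dumps({"action": [0.5, -0.5, 1.5, 0.0, 0.2, -2.0]})


def run_closed_loop(env, llm, max_steps: int = 40) -> np.ndarray:
    state, actions = env.state.copy(), []
    for t in range(max_steps):
        prompt = f"Time step {t}. Current state: {np.round(state, 3).tolist()}."
        action = parse_action(llm(prompt))  # decide from the latest observed state
        state = env.step(action)            # the new state feeds the next query
        actions.append(action)
    return np.stack(actions)


actions = run_closed_loop(ToyEnv(), fake_llm, max_steps=40)
print(actions.shape)  # (40, 6)
```

Validating and clipping the reply keeps a single malformed or out-of-range LLM answer from crashing the rollout or driving the simulator with invalid torques.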

```python
import os
import random
import warnings

from gpt_wrapper.rag import rag_with_gpt
from visualize import replay_offscreen

warnings.filterwarnings("ignore")

# Generate a 40-step trajectory: at each step GPT chooses an action from the
# current state and the retrieved similar expert states/actions
np_actions = rag_with_gpt(max_steps=40)

# Replay the generated actions offscreen and save the rollout as a video
save_dir = "/home/ubuntu/small-llm/test-decision-transformer/saved_vids"
out_path = os.path.join(save_dir, f"ragwrapper_cheetah_{random.randint(0, 100000)}.mp4")
replay_offscreen("mujoco/halfcheetah/expert-v0", np_actions, out_path=out_path)
```

```
Step 40/40
Current state: [ -0.189   3.261  -0.334   0.322   0.167  -0.483  -0.105   0.186   1.081
  -2.492   6.21   -8.33   -5.76  -10.103  -0.018 -10.015   2.661]
GPT generated action: [ 0.644  0.527 -0.183 -0.775 -0.737  0.712]

<<< Saved video to /home/ubuntu/small-llm/test-decision-transformer/saved_vids/ragwrapper_cheetah_99059.mp4 >>>
```



Demo videos

<div style="display: flex; justify-content: space-around; align-items: flex-start; flex-wrap: nowrap; overflow-x: auto;">
<video width="640" height="480" controls>
  <source src="saved_vids/ragwrapper_cheetah_35441.mp4" type="video/mp4">
  Your browser does not support the video tag.
</video>

<video width="640" height="480" controls>
  <source src="saved_vids/ragwrapper_cheetah_55262.mp4" type="video/mp4">
  Your browser does not support the video tag.
</video>

<video width="640" height="480" controls>
  <source src="saved_vids/ragwrapper_cheetah_99059.mp4" type="video/mp4">
  Your browser does not support the video tag.
</video>
</div>



## (2) Fine-tuned small LLM (Pythia-410M)

We freeze all pretrained Pythia-410M weights and train only the linear encoder (input) and decoder (output) layers (4M trainable parameters).

<div style="text-align: center;">
  <img src="media/pythia_finetune.png" alt="Pythia fine-tuning setup">
</div>
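The freezing scheme can be sketched in PyTorch. This is a toy illustration, not the project's training code: a 2-layer `nn.TransformerEncoder` with a causal mask stands in for the frozen Pythia-410M backbone, the class name `FrozenBackbonePolicy` is hypothetical, and passing the return-to-go and state through one shared linear encoder is a simplification we assume here (Decision Transformer-style models usually embed returns, states and actions as separate tokens).

```python
import torch
import torch.nn as nn

STATE_DIM, ACT_DIM = 17, 6  # HalfCheetah observation / action sizes
HIDDEN = 64                 # toy width; Pythia-410M's real hidden size is much larger


class FrozenBackbonePolicy(nn.Module):
    """Trainable linear encoder/decoder wrapped around a frozen sequence model."""

    def __init__(self, backbone: nn.Module, hidden: int = HIDDEN):
        super().__init__()
        # Encoder: (return-to-go, state) -> hidden vector (assumed simplification)
        self.encoder = nn.Linear(1 + STATE_DIM, hidden)
        self.backbone = backbone.requires_grad_(False)  # freeze every backbone weight
        self.decoder = nn.Linear(hidden, ACT_DIM)        # hidden vector -> action

    def forward(self, rtg: torch.Tensor, states: torch.Tensor) -> torch.Tensor:
        # rtg: (batch, seq, 1), states: (batch, seq, STATE_DIM)
        x = self.encoder(torch.cat([rtg, states], dim=-1))
        seq_len = x.shape[1]
        causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
        h = self.backbone(x, mask=causal)    # causal attention, like a decoder-only LM
        return torch.tanh(self.decoder(h))   # actions squashed into [-1, 1]


torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=HIDDEN, nhead=4, batch_first=True)
policy = FrozenBackbonePolicy(nn.TransformerEncoder(layer, num_layers=2))

trainable = [p for p in policy.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable))  # 1606 = (18*64 + 64) + (64*6 + 6)
optimizer = torch.optim.AdamW(trainable, lr=1e-3)  # only encoder/decoder are optimized

# One behaviour-cloning step on random stand-in data
rtg = torch.randn(8, 10, 1)
states = torch.randn(8, 10, STATE_DIM)
target_actions = torch.rand(8, 10, ACT_DIM) * 2 - 1
loss = nn.functional.mse_loss(policy(rtg, states), target_actions)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Because the backbone's parameters have `requires_grad=False` and are never passed to the optimizer, gradients still flow *through* the backbone to the encoder, but only the encoder and decoder weights are updated.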

```python
from visualize import viz_driver

# Roll out the fine-tuned Pythia policy conditioned on a target reward of 300
viz_driver("pythia", target_rew=300)
```

```
<<< Saved video to /home/ubuntu/small-llm/test-decision-transformer/saved_vids/pythia_targetreward_300_cheetah_20298.mp4 >>>
```



Demo videos with different reward conditions


<div style="display: flex; justify-content: space-around; align-items: flex-start; flex-wrap: nowrap; overflow-x: auto;">
  <div style="text-align: center; min-width: 300px; margin: 0 10px;">
    <h4>Reward Target: 600</h4>
    <video width="640" height="480" controls>
      <source src="saved_vids/pythia_targetreward_600_cheetah_81075.mp4" type="video/mp4"> 
      Your browser does not support the video tag.
    </video>
  </div>
  
  <div style="text-align: center; min-width: 300px; margin: 0 10px;">
    <h4>Reward Target: 1200</h4>
    <video width="640" height="480" controls>
      <source src="saved_vids/pythia_targetreward_1200_cheetah_68985.mp4" type="video/mp4">
      Your browser does not support the video tag.
    </video>
  </div>
  
  <div style="text-align: center; min-width: 300px; margin: 0 10px;">
    <h4>Reward Target: 2400</h4>
    <video width="640" height="480" controls>
      <source src="saved_vids/pythia_targetreward_2400_cheetah_31850.mp4" type="video/mp4">
      Your browser does not support the video tag.
    </video>
  </div>
</div>

## (3) Train GPT-2 from scratch

Following *Decision Transformer* (Chen et al., 2021), we train a GPT-2 decoder-only model (700K parameters) from scratch.

<div style="text-align: center;">
  <img src="media/decision_transformer.png" alt="Decision Transformer architecture" width="800">
</div>
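In Decision Transformer, each timestep contributes three tokens (return-to-go, state, action), where the return-to-go at step t is the sum of the rewards from t to the end of the episode. At evaluation time the sequence is seeded with the target reward, and each observed reward is subtracted from it. A minimal numpy sketch of both steps follows; the helper names and the reward values are made up for illustration, and in the real model each token is embedded by a learned linear layer rather than kept as a tagged tuple.

```python
import numpy as np


def returns_to_go(rewards):
    """R_t = r_t + r_{t+1} + ... + r_T, computed as a reversed cumulative sum."""
    return np.cumsum(rewards[::-1])[::-1]


def interleave(rtg, states, actions):
    """Order tokens as (R_1, s_1, a_1, R_2, s_2, a_2, ...); each is a (kind, value) pair."""
    tokens = []
    for r, s, a in zip(rtg, states, actions):
        tokens += [("return", r), ("state", s), ("action", a)]
    return tokens


# Training: returns-to-go from a (made-up) reward sequence
rewards = np.array([1.0, 2.0, 3.0, 4.0])
rtg = returns_to_go(rewards)
print(rtg)  # [10.  9.  7.  4.]

tokens = interleave(rtg, np.zeros((4, 17)), np.zeros((4, 6)))
print([kind for kind, _ in tokens[:6]])  # ['return', 'state', 'action', 'return', 'state', 'action']

# Evaluation: start from the target reward and subtract each observed reward
target_return = 300.0  # matches target_rew=300 in the cells; rewards below are made up
observed_rewards = [5.0, 7.5, 6.0]
rtg_sequence = [target_return]
for r in observed_rewards:
    rtg_sequence.append(rtg_sequence[-1] - r)
print(rtg_sequence)  # [300.0, 295.0, 287.5, 281.5]
```

This is why `target_rew` acts as a control knob: the model is trained to produce actions consistent with the return still remaining, so asking for a larger target should elicit faster running, as the demo videos with different reward targets illustrate.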

```python
from visualize import viz_driver

# Roll out the Decision Transformer (GPT-2) policy conditioned on a target reward of 300
viz_driver("dt", target_rew=300)
```

```
<<< Saved video to /home/ubuntu/small-llm/test-decision-transformer/saved_vids/dt_targetreward_300_cheetah_18256.mp4 >>>
```



Demo videos with different reward conditions

<div style="display: flex; justify-content: space-around; align-items: flex-start; flex-wrap: nowrap; overflow-x: auto;">
  <div style="text-align: center; min-width: 320px; margin: 0 10px;">
    <h4>Reward Target: 300</h4>
      <video width="640" height="480" controls>
      <source src="saved_vids/dt_targetreward_300_cheetah_56626.mp4" type="video/mp4">
      Your browser does not support the video tag.
    </video>
  </div>
  
  <div style="text-align: center; min-width: 320px; margin: 0 10px;">
    <h4>Reward Target: 600</h4>
      <video width="640" height="480" controls>
      <source src="saved_vids/dt_targetreward_600_cheetah_58199.mp4" type="video/mp4">
      Your browser does not support the video tag.
    </video>
  </div>
  
  <div style="text-align: center; min-width: 320px; margin: 0 10px;">
    <h4>Reward Target: 1200</h4>
      <video width="640" height="480" controls>
      <source src="saved_vids/dt_targetreward_1200_cheetah_44888.mp4" type="video/mp4">
      Your browser does not support the video tag.
    </video>
  </div>
</div>

## Evaluation comparison

Comparison of the fine-tuned, frozen Pythia-410M model with the GPT-2 model trained from scratch.


![Model comparison](media/model_comparison.png)

## Conclusion


<div style="display: flex; justify-content: center; align-items: center;">
  <div style="text-align: center;">
    <img src="media/leg-giraffe.gif" alt="Giraffe walking" style="max-width: 45%; height: auto;">
    <p>Giraffe walking</p>
  </div>
  <div style="text-align: center;">
    <img src="media/dt_cheetah_1-3s.gif" alt="Cheetah" style="max-width: 45%; height: auto;">
    <p>HalfCheetah controlled by the Decision Transformer</p>
  </div>
</div>