In [1]:
import os

# Expose only physical GPU 0 to this process. The variable is set before torch is
# imported so the restriction is already in place when CUDA gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch

cuda_ok = torch.cuda.is_available()
if not cuda_ok:
    print("CUDA is not available, using CPU.")
else:
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs available: {gpu_count}")
    for idx in range(gpu_count):
        print(f"GPU {idx}: {torch.cuda.get_device_name(idx)}")

# First visible GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0") if cuda_ok else torch.device("cpu")
print(f"Using device: {device}")

Number of GPUs available: 1
GPU 0: NVIDIA L40S
Using device: cuda:0


In [9]:
from spatialvla.mobilevlm.model.diffusion_heads import DiffusionActionHead
import torch
import torch.optim as optim
from tqdm import tqdm
# Configuration handed to DiffusionActionHead as `head_args`.
# NOTE(review): the meaning of each key is defined by DiffusionActionHead,
# which is not visible in this notebook.
DiffusionHead = dict(
    head_type='Diffusion',
    hidden_projection='pass',
    use_map=True,
    max_action=5.0,
    loss_type='mse',
    time_dim=32,
    num_blocks=3,
    dropout_rate=0.0,
    hidden_dim=1024,
    use_layer_norm=True,
    diffusion_steps=20,
    n_diffusion_samples=1,
)

In [11]:
# Seed torch (CPU and all CUDA devices) before building the model so weight init,
# and any random draws made later (e.g. during training / sampling), are reproducible
# under Restart & Run All.
torch.manual_seed(0)
# in_dim=32 matches the width of the all-zero conditioning tensors used below.
# Place the head on the device selected in the first cell instead of hard-coding
# `.cuda()`, so the CPU fallback chosen there is actually honoured.
model = DiffusionActionHead(in_dim=32, head_args=DiffusionHead).to(device)

In [12]:
import numpy as np

# Toy target: 7 evenly spaced samples of sin(x) on [0, 20]. The resulting tensor `y`
# is the second argument to `model.loss` in the training loop below.
x = np.linspace(0, 20, 7)
y_np = np.sin(x)
# Shape (1, 1, 7), float32 to match default model parameters, on the selected device.
# `torch.as_tensor` replaces the legacy `torch.Tensor(...)` constructor.
y = torch.as_tensor(y_np, dtype=torch.float32, device=device).reshape(1, 1, -1)

In [13]:
# Optimisation settings for the toy fit below.
learning_rate = 1e-3
batch_size = 16
epochs = 10000  # number of optimizer steps; each iteration of the loop below is one batch

optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate)

In [16]:
# Fit the head to the fixed target `y`: every iteration is one optimizer step on a
# batch of `batch_size` copies of `y`, conditioned on an all-zero input of width 32
# (matching `in_dim=32` used to build the model).
model.train()

# Loop-invariant inputs are built once. The batch size comes from `batch_size`
# rather than a hard-coded 16, so changing the hyperparameter cannot desync shapes.
labels_encoded = torch.zeros(batch_size, 1, 32, device=y.device)
targets = y.expand(batch_size, -1, -1)

# float16 autocast is paired with a GradScaler (the documented AMP recipe) so small
# gradients do not underflow to zero. Both are disabled off-GPU.
use_amp = y.device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for step in tqdm(range(epochs)):
    optimizer.zero_grad()
    with torch.autocast(device_type=y.device.type, dtype=torch.float16, enabled=use_amp):
        loss = model.loss(labels_encoded, targets)
    scaler.scale(loss).backward()
    # Unscale first so clipping sees the true gradient magnitudes.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [01:21<00:00, 122.29it/s]


In [18]:
# Loss of the last training iteration only (a single batch), as a Python float.
float(loss)

0.005590737797319889

In [38]:
# Sample from the trained head with the same all-zero conditioning (width 32) used in
# training. Switch to eval mode and disable autograd: this is inference only, so no
# graph should be recorded (the plotting cell still calls `.detach()`, which is harmless).
# NOTE(review): the meaning of the positional `5` is defined by predict_action, which
# is not visible here. With it, the recorded output shows len(h) == 6 — confirm.
model.eval()
with torch.no_grad():
    p, h = model.predict_action(
        torch.zeros(1, 1, 32, device=next(model.parameters()).device),
        5,
        return_history=True,
    )

In [39]:
# Number of entries in the sampling history returned by predict_action.
history_len = len(h)
history_len

6

In [41]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Render each entry of the sampling history `h` as a line plot and stitch the frames
# into an animated GIF.
# NOTE(review): assumes each element of `h` is a tensor whose first row `h_elem[0]`
# is a 1-D sequence of action values (as the original plotting code did) — confirm.
if len(h) == 0:
    raise ValueError("predict_action returned an empty history; nothing to animate.")

series = [h_elem[0].detach().float().cpu().numpy() for h_elem in h]

# Use one y-range for every frame. With per-frame autoscaling, the change between
# steps is hidden because the axes rescale along with the data.
y_min = min(float(s.min()) for s in series)
y_max = max(float(s.max()) for s in series)
pad = 0.05 * (y_max - y_min) or 0.5  # fall back to a fixed margin for a flat history

frames = []
for step_idx, values in enumerate(series):
    fig, ax = plt.subplots()
    ax.plot(np.arange(len(values)), values, marker="o")
    ax.set(
        ylim=(y_min - pad, y_max + pad),
        xlabel="action index",
        ylabel="value",
        title=f"Sampling history step {step_idx}/{len(series) - 1}",
    )

    # Rasterise in memory via the public Agg-canvas accessor. np.array copies the
    # buffer so the image stays valid after the figure is closed.
    fig.canvas.draw()
    frames.append(Image.fromarray(np.array(fig.canvas.buffer_rgba())))

    # Close each figure so the loop does not accumulate open figures.
    plt.close(fig)

# Save frames as an animated GIF (loops forever, 200 ms per frame).
frames[0].save("animation.gif", save_all=True, append_images=frames[1:], loop=0, duration=200)
