In [None]:
from hydra import initialize_config_dir, compose
from gaze_av_aloha.configs import Config
from omegaconf import OmegaConf
import gaze_av_aloha
from gaze_av_aloha.policies.gaze_policy.gaze_policy import GazePolicy
from gym_av_aloha.datasets.av_aloha_dataset import AVAlohaDataset, AVAlohaDatasetMeta
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
# GFLOPs Calculation

from fvcore.nn import FlopCountAnalysis
import torch
import time

os.environ["CUDA_VISIBLE_DEVICES"] = "3" 

# Path to your config directory (adjust as needed)
config_dir = os.path.abspath("../configs")

# overrides = [
#     "policy=vit_policy",
#     "task=av_aloha_sim_thread_needle",
#     "policy.visualize=False",
# ]
# overrides = [
#     "policy=low_res_vit_policy",
#     "task=av_aloha_sim_thread_needle",
#     "policy.visualize=False",
# ]
overrides = [
    "policy=foveated_vit_policy",
    "task=av_aloha_sim_thread_needle",
    "policy.visualize=False",
]
# overrides = [
#     "policy=foveated_vit_policy",
#     "task=av_aloha_sim_thread_needle",
#     "policy.visualize=False",
#     "policy.use_gaze_as_action=false", 
#     "policy.gaze_model_repo_id=iantc104/gaze_model_av_aloha_sim_thread_needle",
# ]

device = torch.device("cuda")



with initialize_config_dir(config_dir=config_dir, job_name="my_app"):
    cfg: Config = compose(config_name="default", overrides=overrides)

dataset_meta = AVAlohaDatasetMeta(repo_id=cfg.task.dataset_repo_id, root=cfg.task.dataset_root)
policy = GazePolicy(cfg.policy, cfg.task, dataset_meta.stats).to(device)
policy = policy.eval()

dataset = AVAlohaDataset(
    repo_id=cfg.task.dataset_repo_id,
    delta_timestamps=policy.get_delta_timestamps(),
)



The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize_config_dir(config_dir=config_dir, job_name="my_app"):


In [2]:
# Example model and input
model = policy.flow.backbone.backbone
patch_size = policy.flow.backbone.tokenizer.token_size
x = torch.randn(1, policy.flow.backbone.get_num_tokens(*cfg.policy.input_shape, device=device), 3, patch_size, patch_size).to(device)
flops = FlopCountAnalysis(model, x)
print(f"FLOPs: {flops.total()}")  # total in float
print(f"GFLOPs: {flops.total() / 1e9:.2f}")

n = 100
with torch.inference_mode():
    start_time = time.perf_counter()
    for _ in range(n):
        model(x)
    end_time = time.perf_counter()
inference_time = (end_time - start_time) / n 
print(f" inference time: {inference_time:.6f} seconds")

del x
del flops
del model


Unsupported operator aten::add encountered 25 time(s)
Unsupported operator aten::div encountered 12 time(s)
Unsupported operator aten::unflatten encountered 12 time(s)
Unsupported operator aten::mul encountered 60 time(s)
Unsupported operator aten::softmax encountered 12 time(s)
Unsupported operator aten::gelu encountered 12 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
blocks.0.attn.out_proj, blocks.0.drop_path, blocks.1.attn.out_proj, blocks.1.drop_path, blocks.10.attn.out_proj, blocks.10.drop_path, blocks.11.attn.out_proj, blocks.11.drop_path, blocks.2.attn.out_proj, blocks.2.drop_path, blocks.3.attn.out_proj, blocks.3.drop_path, blocks.4.attn.out_proj, blocks.4.drop_path, blocks.5.attn.out_proj, blocks.

FLOPs: 1982435328
GFLOPs: 1.98
 inference time: 0.003605 seconds


In [4]:
# Inference Time Calc
del policy
torch.cuda.empty_cache()

policy = GazePolicy(cfg.policy, cfg.task, dataset_meta.stats).to(device)
policy = policy.eval()

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)
batch = next(iter(dataloader))

for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        try:
            batch[k] = v.to(device)[:, :cfg.policy.n_obs_steps]
        except:
            print(f"Skipping {k} due to shape mismatch")

if torch.cuda.is_available():
    torch.cuda.synchronize()
policy.eval()
n = 100
with torch.inference_mode():
    start_time = time.perf_counter()
    for _ in range(n):
        _ = policy.generate_actions(batch)
    end_time = time.perf_counter()
inference_time = (end_time - start_time) / n 
print(f"Policy inference time: {inference_time:.6f} seconds")

# Print memory used
allocated = torch.cuda.memory_allocated() / 1024**2  # in MB
reserved = torch.cuda.memory_reserved() / 1024**2  # in MB

print(f"Memory allocated: {allocated:.2f} MB")
print(f"Memory reserved:  {reserved:.2f} MB")

del batch

Skipping episode_index due to shape mismatch
Policy inference time: 0.087952 seconds
Memory allocated: 671.56 MB
Memory reserved:  986.00 MB


In [6]:
# training time calc

# Inference Time Calc
del policy
torch.cuda.empty_cache()

policy = GazePolicy(cfg.policy, cfg.task, dataset_meta.stats).to(device)
policy = policy.train()

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)
batch = next(iter(dataloader))

for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        batch[k] = v.to(device)

if torch.cuda.is_available():
    torch.cuda.synchronize()

policy.train()
# Optimizer and loss
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-4)

n = 100
start_time = time.perf_counter()
for _ in range(n):
    optimizer.zero_grad()
    loss, _ = policy.forward(batch)
    loss.backward()
    optimizer.step()
end_time = time.perf_counter()
inference_time = (end_time - start_time) / n 
print(f"Policy train time: {inference_time:.6f} seconds")

del batch

Policy train time: 0.123751 seconds


In [11]:
num_total_params = sum(p.numel() for p in policy.flow.pool.parameters())
print(num_total_params)

15774720
