In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer
import cv2
import numpy as np
import os
import json
import glob
import matplotlib.pyplot as plt

In [6]:

class VAE64(nn.Module):
    def __init__(self, latent_dim=128):
        super(VAE64, self).__init__()
        
        # Encoder: 64x64 -> 32x32 -> 16x16 -> 8x8
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1), 
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1), 
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1), 
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 512),
            nn.GELU()
        )
        
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)
        
        # Decoder: 1x1 -> 8x8 -> 16x16 -> 32x32 -> 64x64
        self.decoder_input = nn.Linear(latent_dim, 512)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (512, 1, 1)),
            nn.ConvTranspose2d(512, 128, 8, stride=1, padding=0), # 8x8
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 16x16
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),   # 32x32
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),    # 64x64
            nn.Sigmoid() 
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.decoder_input(z)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


class VLA_WorldModel(nn.Module):
    def __init__(self, latent_dim=128, action_dim=3, text_embed_dim=512):
        super(VLA_WorldModel, self).__init__()
        
        # Combined Input: Latent (128) + Action (3) + Text (512) = 643
        self.network = nn.Sequential(
            nn.Linear(latent_dim + action_dim + text_embed_dim, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Linear(256, latent_dim) # Predicts next Z
        )

    def forward(self, z, action, text_embedding):
        # Concatenate everything into one long vector
        # z: (batch, 128), action: (batch, 3), text_embedding: (batch, 512)
        x = torch.cat([z, action, text_embedding], dim=-1)
        return self.network(x)

class CLIPHandler:
    def __init__(self, device):
        self.device = device
        self.model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        # Freeze CLIP weights (we only want to use it, not train it)
        for param in self.model.parameters():
            param.requires_grad = False

    def embed_text(self, text_list):
        inputs = self.tokenizer(text_list, padding=True, return_tensors="pt").to(self.device)
        outputs = self.model(**inputs)
        # We use the pooled output (final representation of the sentence)
        return outputs.pooler_output

# --- 2. The Tester Engine ---

class VLATester:
    def __init__(self, vae_path, wm_path, device):
        self.device = device
        self.vae = VAE64().to(device)
        self.vae.load_state_dict(torch.load(vae_path, map_location=device))
        self.vae.eval()

        self.wm = VLA_WorldModel().to(device)
        self.wm.load_state_dict(torch.load(wm_path, map_location=device))
        self.wm.eval()

        self.clip = CLIPHandler(device)
        self.transform = transforms.Compose([
            transforms.ToPILImage(), transforms.Resize((64, 64)), transforms.ToTensor()
        ])

    def test_sequence(self, episode_path, num_steps=5):
        """ Tests if the model can predict a sequence of actions accurately. """
        
        # Load the meta data (Instruction and Actions)
        with open(os.path.join(episode_path, "meta.json"), "r") as f:
            meta = json.load(f)
        
        instruction = meta["instruction"]
        actions = meta["actions"][:num_steps]
        
        # Load the initial frame (T=0)
        frames = sorted(glob.glob(os.path.join(episode_path, "*.png")))
        img_start = cv2.imread(frames[0], cv2.IMREAD_GRAYSCALE)
        img_tensor = self.transform(img_start).unsqueeze(0).to(self.device)

        # Get Text Embedding
        text_embed = self.clip.embed_text(instruction)

        # Storage for results
        predicted_images = []
        actual_images = []

        with torch.no_grad():
            # Initial Latent
            z, _ = self.vae.encode(img_tensor)

            for i in range(num_steps):
                # 1. Prediction step
                action_tensor = torch.tensor([actions[i]], dtype=torch.float32).to(self.device)
                z = self.wm(z, action_tensor, text_embed)
                
                # 2. Decode Predicted Latent
                pred_img = self.vae.decode(z).squeeze().cpu().numpy()
                predicted_images.append(pred_img)

                # 3. Load Actual Frame for Comparison
                actual_img = cv2.imread(frames[i+1], cv2.IMREAD_GRAYSCALE)
                actual_img = cv2.resize(actual_img, (64, 64)) / 255.0
                actual_images.append(actual_img)

        self.visualize(instruction, actual_images, predicted_images)

    def visualize(self, instruction, actuals, preds):
        n = len(actuals)
        fig, axes = plt.subplots(2, n, figsize=(15, 6))
        for i in range(n):
            axes[0, i].imshow(actuals[i], cmap='gray')
            axes[0, i].set_title(f"Target T+{i+1}")
            axes[0, i].axis('off')

            axes[1, i].imshow(preds[i], cmap='gray')
            axes[1, i].set_title(f"Dream T+{i+1}")
            axes[1, i].axis('off')
        
        plt.suptitle(f"Instruction: {instruction}", fontsize=16)
        plt.tight_layout()
        plt.show()

# --- 3. Run Test Cases ---

if __name__ == "__main__":
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tester = VLATester("vae_vla_base.pth", "world_model.pth", DEVICE)

    # Test Case 1: Load a specific episode from your data
    # Change 'ep_0' to an episode that exists in your 'drawing_data' folder
    try:
        tester.test_sequence("drawing_data/ep_0", num_steps=5)
    except Exception as e:
        print(f"Test failed: {e}. Check if 'drawing_data/ep_0' exists.")

RuntimeError: Error(s) in loading state_dict for VLA_WorldModel:
	Missing key(s) in state_dict: "network.6.weight", "network.6.bias". 
	Unexpected key(s) in state_dict: "latent_norm.weight", "latent_norm.bias", "action_norm.weight", "action_norm.bias", "text_norm.weight", "text_norm.bias". 
	size mismatch for network.4.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for network.4.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).