In [3]:
import gymnasium as gym
import torch
import torch.nn as nn
from peft import LoraConfig, PeftModel, get_peft_model
from safetensors.torch import load_file

from actor_critic import ActorCritic
from world_model import MiniWorldModel
from vq_vae import Encoder, vq_vae

### Load Models


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sizes used to build the models below. ActorCritic and MiniWorldModel are both
# built with the same action count, so it is defined once here.
NUM_ACTIONS = 18
NUM_GAMES = 6
MAX_SEQ_LEN = 32
WORLD_MODEL_CKPT = "checkpoints/best_model_step_46500.pth"

# Load pretrained weights
encoder_state = load_file("pretrained/encoder.safetensors")
vq_state = load_file("pretrained/vq.safetensors")
# map_location: without it, a checkpoint saved from GPU tensors cannot be
# loaded on a CPU-only machine, even though `device` above falls back to CPU.
# weights_only: the file is only used as a state dict, so refuse to unpickle
# arbitrary Python objects from it.
world_model_state = torch.load(WORLD_MODEL_CKPT, map_location=device, weights_only=True)

encoder = Encoder().to(device)
encoder.load_state_dict(encoder_state)

# `vq_vae` is the instance imported from the vq_vae module (presumably an
# nn.Module, whose .to() moves it in place and returns it). The name is reused
# rather than changed because later cells may refer to it.
vq_vae = vq_vae.to(device)
vq_vae.load_state_dict(vq_state)

world_model = MiniWorldModel(num_actions=NUM_ACTIONS, num_games=NUM_GAMES).to(device)
world_model.load_state_dict(world_model_state)

# No checkpoint is loaded for the actor-critic in this cell.
# NOTE(review): if encoder / vq_vae are only used for inference, consider
# .eval() and .requires_grad_(False) on them — confirm against later cells.
actor_critic = ActorCritic(num_actions=NUM_ACTIONS, max_seq_len=MAX_SEQ_LEN).to(device)



In [5]:
# Print the qualified name of every submodule of the actor-critic, one per line,
# as a reference for choosing LoRA `target_modules`. The root module itself is
# reported with an empty name, so the first printed line is blank.
print("\n".join(qualified_name for qualified_name, _ in actor_critic.named_modules()))


embed
blocks
blocks.layers
blocks.layers.0
blocks.layers.0.self_attn
blocks.layers.0.self_attn.out_proj
blocks.layers.0.linear1
blocks.layers.0.dropout
blocks.layers.0.linear2
blocks.layers.0.norm1
blocks.layers.0.norm2
blocks.layers.0.dropout1
blocks.layers.0.dropout2
blocks.layers.1
blocks.layers.1.self_attn
blocks.layers.1.self_attn.out_proj
blocks.layers.1.linear1
blocks.layers.1.dropout
blocks.layers.1.linear2
blocks.layers.1.norm1
blocks.layers.1.norm2
blocks.layers.1.dropout1
blocks.layers.1.dropout2
blocks.layers.2
blocks.layers.2.self_attn
blocks.layers.2.self_attn.out_proj
blocks.layers.2.linear1
blocks.layers.2.dropout
blocks.layers.2.linear2
blocks.layers.2.norm1
blocks.layers.2.norm2
blocks.layers.2.dropout1
blocks.layers.2.dropout2
blocks.layers.3
blocks.layers.3.self_attn
blocks.layers.3.self_attn.out_proj
blocks.layers.3.linear1
blocks.layers.3.dropout
blocks.layers.3.linear2
blocks.layers.3.norm1
blocks.layers.3.norm2
blocks.layers.3.dropout1
blocks.layers.3.dropout2


### Add LoRA adapters to the world model


In [20]:
# Wrap the world model with LoRA adapters. When `target_modules` is a list, PEFT
# matches a module whose qualified name equals, or ends with, one of the entries.
# So "linear1" selects the `linear1` submodule of every layer, and no wildcard is
# needed.
lora_config = LoraConfig(
    r=16,  # Low-rank dimension (adjust 8-32 based on compute)
    lora_alpha=32,  # Scaling factor (LoRA scales updates by lora_alpha / r)
    target_modules=[
        # NOTE(review): these names match torch.nn.TransformerEncoderLayer /
        # nn.MultiheadAttention attributes — confirm against world_model's
        # named_modules(). If they are, MultiheadAttention hands
        # out_proj.weight straight to its functional kernel instead of calling
        # out_proj(...), so an adapter on "out_proj" may never contribute.
        # The encoder layer's inference fast path (eval mode, autograd off)
        # likewise reads linear1/linear2 weights directly. Verify that the
        # adapters change the output, or target the attention module with a
        # PEFT release that supports nn.MultiheadAttention.
        "out_proj",  # Attention output Linear
        "linear1",  # FFN first Linear
        "linear2",  # FFN second Linear
        "obs_head.1",  # obs_head Linear
        "reward_head.1",  # reward_head Linear
    ],
    lora_dropout=0.1,
    bias="none",  # Don't adapt biases (optional, but efficient)
    modules_to_save=[
        "obs_embed",
        "action_embed",
        "game_embed",
    ],  # Fully train/save embeddings
)

# get_peft_model injects adapters into the base model in place and rebinds
# `world_model` to the wrapper. Re-running this cell would otherwise try to
# adapt an already-adapted model, so skip the wrap in that case.
if isinstance(world_model, PeftModel):
    print("world_model already wraps LoRA adapters; skipping get_peft_model")
else:
    world_model = get_peft_model(world_model, lora_config)
world_model.print_trainable_parameters()

trainable params: 448,528 || all params: 5,719,313 || trainable%: 7.8423
