In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import gym
from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from ncps.datasets.torch import AtariCloningDataset
from ncps.torch import CfC, CfCCell
from ncps.wirings.wirings import Wiring
import torch
from torch import nn
from typing import Optional, Union
import ncps
from ncps.torch.lstm import LSTMCell 

# Convolutional feature extractor (ConvBlock)
class ConvBlock(nn.Module):
    """Four stride-2, 5x5 convolutions followed by global average pooling.

    Takes 4 input channels and produces a 256-dimensional feature vector per image.
    Batch normalisation is applied after the second and fourth convolutions only.
    The attribute names (conv1, conv2, bn2, conv3, conv4, bn4) become the state_dict
    keys of saved checkpoints, so they and their creation order are kept stable.
    """

    def __init__(self):
        super().__init__()
        # 4 -> 64 channels. Every conv uses kernel 5, stride 2, padding 2,
        # which halves H and W (rounding up).
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=64, kernel_size=5, stride=2, padding=2)
        # 64 -> 128 channels, followed by batch norm.
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2)
        self.bn2 = nn.BatchNorm2d(num_features=128)
        # 128 -> 128 channels, no batch norm.
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=2, padding=2)
        # 128 -> 256 channels, followed by batch norm.
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=2)
        self.bn4 = nn.BatchNorm2d(num_features=256)

    def forward(self, x):
        """Map images (assumed (N, 4, H, W)) to pooled features of shape (N, 256)."""
        h = F.relu(self.conv1(x))
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.relu(self.conv3(h))
        h = F.relu(self.bn4(self.conv4(h)))
        # Global average pooling: average each channel over the two spatial axes.
        return h.mean((-1, -2))

# Network combining the convolutional block with a CfC recurrent layer (ConvCfC)
class ConvCfC(nn.Module):
    """ConvBlock feature extractor followed by a CfC RNN producing per-step action scores.

    Args:
        n_actions: CfC projection size (``proj_size``), i.e. the number of output
            scores per time step.
    """

    def __init__(self, n_actions):
        super().__init__()
        self.conv_block = ConvBlock()
        # CfC: 256 input features (the ConvBlock output width), 64 hidden units,
        # batch-first layout, projected down to n_actions outputs.
        self.rnn = CfC(256, 64, batch_first=True, proj_size=n_actions)

    def forward(self, x, hx=None):
        """Run the conv block on every frame, then the CfC over the sequence.

        Args:
            x: tensor of shape (batch, seq_len, C, H, W) (the reshapes below
                assume this layout).
            hx: optional recurrent state, passed straight to the CfC layer.

        Returns:
            ``(outputs, hx)`` exactly as returned by the CfC layer.
        """
        batch_size, seq_len = x.shape[0], x.shape[1]
        # Fold time into the batch axis, since Conv2d expects (N, C, H, W).
        # reshape (unlike view) also accepts non-contiguous inputs such as a
        # permuted batch; on contiguous tensors it is a zero-copy view.
        x = x.reshape(batch_size * seq_len, *x.shape[2:])
        x = self.conv_block(x)
        # Unfold back to (batch, seq_len, features) for the batch-first RNN.
        x = x.reshape(batch_size, seq_len, *x.shape[1:])
        x, hx = self.rnn(x, hx)
        return x, hx



In [None]:
# Environment: Breakout wrapped with the standard deep-RL (DeepMind-style) Atari preprocessing.
env = wrap_deepmind(gym.make("ALE/Breakout-v5"))

# Behaviour-cloning datasets for Breakout: training and validation splits.
train_ds = AtariCloningDataset("breakout", split="train")
val_ds = AtariCloningDataset("breakout", split="val")

# Data loaders: 32 samples per batch and 4 worker processes each.
# Only the training loader shuffles.
trainloader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
valloader = DataLoader(val_ds, num_workers=4, batch_size=32)


In [None]:
# Training and evaluation functions

def train_one_epoch(model, criterion, optimizer, trainloader):
    """Run one optimisation pass over ``trainloader``, showing a tqdm progress bar.

    The model is called as ``model(inputs)`` and must return ``(outputs, hx)``.
    Outputs and labels are flattened over their first two axes (batch and time
    for this notebook's sequence data) before ``criterion`` is applied.

    Args:
        model: module returning ``(outputs, hx)``; its parameters pick the device.
        criterion: loss applied to the flattened outputs and labels.
        optimizer: optimizer over ``model``'s parameters.
        trainloader: sized iterable of ``(inputs, labels)`` batches.

    Returns:
        Mean training loss over the epoch's batches, or ``0.0`` if the loader is empty.
    """
    model.train()
    device = next(model.parameters()).device
    running_loss = 0.0
    num_batches = 0
    # The context manager closes the progress bar even if a batch raises.
    with tqdm(total=len(trainloader)) as pbar:
        for inputs, labels in trainloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs, _ = model(inputs)
            # Merge the first two axes so every time step is one classification sample.
            # reshape also copes with non-contiguous tensors, where view would raise.
            labels = labels.reshape(-1, *labels.shape[2:])
            outputs = outputs.reshape(-1, *outputs.shape[2:])
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            pbar.set_description(f"loss={running_loss / num_batches:0.4g}")
            pbar.update(1)
    return running_loss / num_batches if num_batches else 0.0

def eval(model, valloader, criterion=None):
    """Evaluate ``model`` on ``valloader`` without gradient tracking.

    Note: this name shadows Python's builtin ``eval`` in the notebook namespace;
    it is kept for compatibility with existing callers.

    Args:
        model: module returning ``(outputs, hx)``, where ``outputs`` holds per-step
            class scores; it is switched to eval mode.
        valloader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function, assumed to use mean reduction. Defaults to
            ``nn.CrossEntropyLoss()``. Previously this function silently read a
            notebook-global ``criterion``, which broke on a fresh kernel.

    Returns:
        ``(mean_loss, accuracy)``, averaged over all flattened samples so that a
        smaller final batch is not over-weighted. Both are NaN if no samples are seen.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    model.eval()
    device = next(model.parameters()).device
    total_loss, total_correct, total_count = 0.0, 0.0, 0
    with torch.no_grad():
        for inputs, labels in valloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs, _ = model(inputs)
            # Merge the first two axes so every time step is one classification sample.
            outputs = outputs.reshape(-1, *outputs.shape[2:])
            labels = labels.reshape(-1, *labels.shape[2:])
            n = labels.shape[0]
            # criterion returns the batch mean, so weight it by the batch's sample count.
            total_loss += criterion(outputs, labels).item() * n
            total_correct += (outputs.argmax(-1) == labels).sum().item()
            total_count += n
    if total_count == 0:
        return float("nan"), float("nan")
    return total_loss / total_count, total_correct / total_count


In [None]:
# Visualize Atari game and play endlessly

def _strip_info(reset_or_obs):
    """Return the observation from a gym result that may be an ``(obs, info)`` tuple."""
    if isinstance(reset_or_obs, tuple):
        return reset_or_obs[0]
    return reset_or_obs


def run_closed_loop(model, env, num_episodes=None):
    """Roll out ``model`` as a greedy policy in ``env`` and collect episode returns.

    Both step APIs are supported: the old gym ``(obs, reward, done, info)`` and the
    newer ``(obs, reward, terminated, truncated, info)``. In the latter, an episode
    ends when it is terminated OR truncated (e.g. by a time limit).

    Args:
        model: recurrent policy called as ``model(x, hx) -> (scores, hx)`` with ``x``
            of shape (1, 1, C, H, W), float32, pixels scaled by 1/255.
        env: environment whose observations are 3-D (H, W, C) arrays; ``reset`` may
            return either ``obs`` or ``(obs, info)``.
        num_episodes: number of episodes to play; ``None`` plays forever. A value
            ``<= 0`` returns an empty list immediately (it previously looped forever).

    Returns:
        List of per-episode total rewards (only reached when ``num_episodes`` is set).

    Raises:
        ValueError: if an observation is not 3-dimensional.
    """
    if num_episodes is not None and num_episodes <= 0:
        return []
    obs = _strip_info(env.reset())
    device = next(model.parameters()).device
    hx = None  # recurrent state; cleared at every episode boundary
    returns = []
    total_reward = 0
    with torch.no_grad():
        while True:
            frame = np.asarray(obs)
            if frame.ndim != 3:
                raise ValueError(f"Unexpected obs shape: {frame.shape}")
            # (H, W, C) -> (C, H, W), scaled by 1/255, plus leading batch and time axes.
            frame = np.transpose(frame, [2, 0, 1]).astype(np.float32) / 255.0
            x = torch.from_numpy(frame).unsqueeze(0).unsqueeze(0).to(device)
            pred, hx = model(x, hx)
            # Greedy action: index of the highest score for this single step.
            action = pred.squeeze(0).squeeze(0).argmax().item()

            result = env.step(action)
            if len(result) == 4:
                # Old gym API: (obs, reward, done, info).
                obs, r, done, _info = result
            else:
                # New API: (obs, reward, terminated, truncated, info). A truncated
                # episode must be reset too, otherwise the env is stepped past its end.
                obs, r, terminated, truncated, _info = result
                done = terminated or truncated
            obs = _strip_info(obs)
            total_reward += r

            if done:
                obs = _strip_info(env.reset())
                hx = None
                returns.append(total_reward)
                total_reward = 0
                if num_episodes is not None:
                    num_episodes -= 1
                    if num_episodes == 0:
                        return returns

# Device setup: keep the original preference for GPU index 3, but fall back to the
# default GPU (or the CPU) on machines with fewer GPUs instead of failing with an
# invalid device ordinal.
if torch.cuda.device_count() > 3:
    device = torch.device("cuda:3")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = ConvCfC(n_actions=env.action_space.n).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

NUM_EPOCHS = 20  # passes over the training set

# Append per-epoch validation metrics to the log file. The `with` block closes the
# file even if training is interrupted. flush() makes each line durable as soon as
# its epoch finishes.
with open("training_log_wiring.txt", "a") as log_file:
    for epoch in range(NUM_EPOCHS):
        train_one_epoch(model, criterion, optimizer, trainloader)
        # Evaluate on the validation set.
        val_loss, val_acc = eval(model, valloader)

        log_message = f"Epoch {epoch+1}, val_loss={val_loss:0.4g}, val_acc={100*val_acc:0.2f}%"
        # Print once to the console (the message was previously printed twice).
        print(log_message)
        log_file.write(log_message + "\n")
        log_file.flush()


In [None]:
# Save the trained weights only (the state_dict), not the pickled module object.
model_weights = model.state_dict()
torch.save(model_weights, 'cfc_model_wiring.pt')
print("Model saved as cfc_model_wiring.pt")

In [None]:
# Device setup: prefer GPU index 3 as in training, falling back to the default GPU
# (or the CPU) on machines with fewer GPUs.
if torch.cuda.device_count() > 3:
    device = torch.device("cuda:3")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
loaded_model = ConvCfC(n_actions=env.action_space.n).to(device)
# map_location lets a checkpoint written on another device (e.g. cuda:3) load here.
loaded_model.load_state_dict(torch.load('cfc_model_wiring.pt', map_location=device))
print("Model loaded from cfc_model_wiring.pt")

# Switch to evaluation mode so BatchNorm uses its running statistics.
loaded_model.eval()

# Run the closed-loop rollout again, for 10 episodes.
returns = run_closed_loop(loaded_model, env, num_episodes=10)
# Report the actual mean (the old message printed the raw list labelled "Mean return").
print(f"Mean return {np.mean(returns)} (n={len(returns)})")
print(f"Per-episode returns: {returns}")