## **1. 进一步优化 VAE 与 CVAE 模型**

### **VAE 模型改进**

1. **Batch Normalization 和 Dropout**：增强泛化能力。
2. **改进的损失函数**：考虑 KL 散度与重构误差的平衡。
3. **多层网络**：增加模型复杂性以提升表示能力。

In [37]:
import torch
import torch.nn as nn

class ConditionalVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, cond_dim):
        super(ConditionalVAE, self).__init__()
        
        # 编码器
        self.fc1 = nn.Linear(input_dim + cond_dim, hidden_dim)  # 注意这里的输入维度要匹配拼接后的维度
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # 均值
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # 方差

        # 解码器
        self.fc3 = nn.Linear(latent_dim + cond_dim, hidden_dim)  # 解码时也要拼接条件
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(0.3)

    def encode(self, x, cond):
        # 拼接输入和条件
        x_cond = torch.cat([x, cond], dim=-1)
        h1 = torch.relu(self.bn1(self.fc1(x_cond)))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, cond):
        # 拼接潜在向量和条件
        z_cond = torch.cat([z, cond], dim=-1)
        h3 = torch.relu(self.bn2(self.fc3(z_cond)))
        h3 = self.dropout(h3)
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x, cond):
        mu, logvar = self.encode(x, cond)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, cond), mu, logvar


In [38]:
# 数据准备
data = torch.randn(32, 100)  # 32 条样本，每条 100 维
labels = torch.nn.functional.one_hot(torch.randint(0, 5, (32,)), num_classes=5).float()  # 生成 one-hot 标签

# 创建数据集和 DataLoader
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 初始化模型
cvae_model = ConditionalVAE(input_dim=100, hidden_dim=50, latent_dim=10, cond_dim=5)


1. 

### **CVAE 模型实现**

1. **条件输入**：嵌入特定标签以生成具有条件约束的序列。
2. **标签与特征连接**：将标签与原始输入连接以进行编码。

In [39]:
import torch.optim as optim

def train(model, dataloader, epochs=100):
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for x, cond in dataloader:
            optimizer.zero_grad()
            recon_x, mu, logvar = model(x, cond)  # 正确传入条件
            loss = ((x - recon_x) ** 2).sum() + 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")


## **2. 训练与比较 VAE 和 CVAE**

### **训练函数**

In [40]:
def generate_sequences(model, num_sequences=100, cond=None):
    sequences = []
    with torch.no_grad():
        for _ in range(num_sequences):
            z = torch.randn(1, model.fc21.out_features)  # 使用潜在向量生成
            generated = model.decode(z, cond).numpy().flatten()
            sequence = ''.join(random.choices("ACDEFGHIKLMNPQRSTVWY", k=len(generated)))
            sequences.append(sequence)
    return sequences

def evaluate_sequences(sequences):
    diversity = len(set(sequences)) / len(sequences)
    stability = np.mean([sum(1 for aa in seq if aa in "AILMFWV") / len(seq) for seq in sequences])
    return diversity, stability


In [41]:
# 训练 CVAE
train(cvae_model, dataloader)

# 生成与评估序列
cvae_sequences = generate_sequences(cvae_model, num_sequences=100, cond=torch.tensor([[1, 0, 0, 0, 0]]))

# 评估结果
cvae_diversity, cvae_stability = evaluate_sequences(cvae_sequences)
print(f"CVAE - Diversity: {cvae_diversity:.2f}, Stability: {cvae_stability:.2f}")


Epoch 1, Loss: 3780.9712
Epoch 2, Loss: 3760.2236
Epoch 3, Loss: 3758.2026
Epoch 4, Loss: 3732.9578
Epoch 5, Loss: 3707.5842
Epoch 6, Loss: 3688.3037
Epoch 7, Loss: 3680.5059
Epoch 8, Loss: 3644.3115
Epoch 9, Loss: 3634.7539
Epoch 10, Loss: 3613.0684
Epoch 11, Loss: 3587.0449
Epoch 12, Loss: 3574.7756
Epoch 13, Loss: 3531.6206
Epoch 14, Loss: 3541.4749
Epoch 15, Loss: 3512.5715
Epoch 16, Loss: 3471.8289
Epoch 17, Loss: 3441.1858
Epoch 18, Loss: 3451.1194
Epoch 19, Loss: 3414.0679
Epoch 20, Loss: 3382.2461
Epoch 21, Loss: 3346.7485
Epoch 22, Loss: 3349.8340
Epoch 23, Loss: 3298.3862
Epoch 24, Loss: 3279.8489
Epoch 25, Loss: 3247.9692
Epoch 26, Loss: 3181.2300
Epoch 27, Loss: 3133.9131
Epoch 28, Loss: 3087.9595
Epoch 29, Loss: 3020.1404
Epoch 30, Loss: 2889.3647
Epoch 31, Loss: 2828.0889
Epoch 32, Loss: 2617.7615
Epoch 33, Loss: 2403.8254
Epoch 34, Loss: 2061.0054
Epoch 35, Loss: 1576.6433
Epoch 36, Loss: 864.0876
Epoch 37, Loss: -114.1929
Epoch 38, Loss: -1463.3604
Epoch 39, Loss: -3265

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 50])

生成与评估函数