In [None]:

#@title 链接Google Drive
# Colab-only step: mount the user's Google Drive into the runtime filesystem.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

## 全连接

In [None]:
#@title vae全连接训练

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

# Fully connected variational autoencoder
class VAE(nn.Module):
    """Variational autoencoder built only from linear layers.

    Args:
        input_dim: Size of each flattened input vector.
        hidden_dim: Width of the hidden layers in encoder and decoder.
        latent_dim: Size of the latent code.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # Encoder trunk: two linear layers, each followed by a ReLU.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Two heads on top of the trunk: posterior mean and log-variance.
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: latent code -> hidden -> hidden -> input_dim, squashed to
        # (0, 1) by the final sigmoid.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return the tuple ``(mu, logvar)`` computed from ``x``."""
        features = self.encoder(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes ``z`` back to input space."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        reconstruction = self.decode(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar

# VAE training loss
def loss_function(recon_x, x, mu, logvar):
    """Return summed binary cross-entropy plus the KL term.

    Args:
        recon_x: Reconstruction, values in [0, 1], same shape as ``x``.
        x: Target tensor with values in [0, 1].
        mu: Posterior means.
        logvar: Posterior log-variances.

    Returns:
        Scalar tensor: reconstruction term (``reduction='sum'``) plus
        ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    reconstruction_term = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_term + kl_term

# Hyperparameters
input_dim = 28 * 28  # flattened size of one MNIST image
hidden_dim = 400
latent_dim = 20
lr = 1e-3
batch_size = 128
epochs = 10

# Run on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# MNIST training split (downloaded into ./data on first use)
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Model and optimizer
model = VAE(input_dim, hidden_dim, latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop; labels are ignored because the VAE is trained unsupervised
model.train()
for epoch in range(epochs):
    train_loss = 0
    for data, _ in train_loader:
        # Flatten each image to a vector before feeding the linear layers.
        data = data.view(-1, input_dim).to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # The loss is summed over samples, so divide by the dataset size.
    print(f'Epoch {epoch+1}, Loss: {train_loss/len(train_loader.dataset):.4f}')

# Decode 64 latent vectors drawn from a standard normal into images
model.eval()
with torch.no_grad():
    z = torch.randn(64, latent_dim).to(device)
    sample = model.decode(z).cpu().view(64, 1, 28, 28)


In [None]:
#@title 验证模型

import matplotlib.pyplot as plt

def show_images(original, reconstructed, num_images=10, image_shape=(28, 28)):
    """Plot originals (top row) above their reconstructions (bottom row).

    Args:
        original: Indexable batch of images; each item is reshaped to
            ``image_shape`` before plotting.
        reconstructed: Batch of reconstructions in the same order as
            ``original``.
        num_images: Maximum number of columns to draw. It is clamped to the
            number of items available in both batches, so a short batch no
            longer raises ``IndexError``.
        image_shape: ``(height, width)`` each item is reshaped to. Defaults
            to the MNIST size.

    Returns:
        None. The figure is displayed with ``plt.show()``. Nothing is drawn
        when there are no images to show.
    """
    count = min(num_images, len(original), len(reconstructed))
    if count <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even when count == 1.
    fig, axes = plt.subplots(2, count, figsize=(20, 4), squeeze=False)
    for col in range(count):
        for row, batch in enumerate((original, reconstructed)):
            ax = axes[row, col]
            ax.imshow(batch[col].reshape(image_shape), cmap="gray")
            ax.axis('off')

    plt.show()

# Loader over the MNIST test split
test_dataset = datasets.MNIST('./data', train=False, transform=transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Reconstruct a single test batch and display it next to the originals
model.eval()
with torch.no_grad():
    data, _ = next(iter(test_loader))
    data = data.view(-1, input_dim).to(device)
    recon_data, _, _ = model(data)
    show_images(data.cpu(), recon_data.cpu())

## 卷积


In [None]:
#@title vae卷积训练


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

# Hyperparameters
input_dim = 28 * 28  # number of pixels in one MNIST image
hidden_dim = 400
latent_dim = 20
lr = 1e-3
batch_size = 128
epochs = 10

# Pick the compute device: GPU when available, otherwise CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")

# Convolutional VAE
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for ``[batch, 1, 28, 28]`` images.

    The encoder reduces the image to ``128 * 3 * 3 = 1152`` features, from
    which the posterior mean and log-variance are predicted. The decoder
    expands a latent code back to 1152 features, views them as a
    ``[72, 4, 4]`` feature map (``72 * 4 * 4 == 1152``) and upsamples it to
    ``[1, 28, 28]`` with transposed convolutions.

    Args:
        latent_dim: Size of the latent code.
    """

    def __init__(self, latent_dim):
        super(ConvVAE, self).__init__()

        # Encoder. Spatial size per Conv2d(k=4, s=2, p=1): floor((n - 2) / 2) + 1.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),  # [batch, 1, 28, 28] -> [batch, 32, 14, 14]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # [batch, 32, 14, 14] -> [batch, 64, 7, 7]
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # [batch, 64, 7, 7] -> [batch, 128, 3, 3]
            nn.ReLU(),
            nn.Flatten()  # [batch, 128, 3, 3] -> [batch, 128*3*3]
        )

        self.fc_mu = nn.Linear(128 * 3 * 3, latent_dim)
        self.fc_logvar = nn.Linear(128 * 3 * 3, latent_dim)

        # Decoder. Spatial size per ConvTranspose2d: (n - 1) * stride - 2 * padding + kernel.
        self.decoder_input = nn.Linear(latent_dim, 128 * 3 * 3)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(72, 64, kernel_size=3, stride=2, padding=1),  # [batch, 72, 4, 4] -> [batch, 64, 7, 7]
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # [batch, 64, 7, 7] -> [batch, 32, 14, 14]
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),  # [batch, 32, 14, 14] -> [batch, 1, 28, 28]
            nn.Sigmoid()
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of images ``x``."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent codes ``z`` of shape ``[batch, latent_dim]`` to images.

        Returns:
            Tensor of shape ``[batch, 1, 28, 28]`` with values in (0, 1).
        """
        h = self.decoder_input(z)
        # Take the batch size from the tensor instead of hard-coding 128, so
        # any batch size works. For a batch of 128 this is exactly the old
        # view(128, -1, 4, 4), i.e. [128, 72, 4, 4].
        h = h.view(h.size(0), 72, 4, 4)
        return self.decoder(h)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        # Decode once; the earlier duplicate decode result was never used.
        return self.decode(z), mu, logvar

# VAE training loss
def loss_function(recon_x, x, mu, logvar):
    """Return summed binary cross-entropy plus the KL term.

    Args:
        recon_x: Reconstruction, values in [0, 1], same shape as ``x``.
        x: Target tensor with values in [0, 1].
        mu: Posterior means.
        logvar: Posterior log-variances.

    Returns:
        Scalar tensor: reconstruction term (``reduction='sum'``) plus
        ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    """
    reconstruction_term = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction_term + kl_term

# Hyperparameters
latent_dim = 20
lr = 1e-3
batch_size = 128
epochs = 10

# MNIST training split (downloaded into ./data on first use)
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Model and optimizer
model = ConvVAE(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop; images keep their [batch, 1, 28, 28] layout for the conv model
model.train()
for epoch in range(epochs):
    train_loss = 0
    for data, _ in train_loader:
        # Batches with fewer than 128 samples (the final partial batch) are skipped.
        if data.shape[0] < 128:
            continue
        optimizer.zero_grad()
        data = data.to(device)
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # NOTE: the divisor is the full dataset size, although skipped batches
    # contribute nothing to train_loss.
    print(f'Epoch {epoch+1}, Loss: {train_loss/len(train_loader.dataset):.4f}')

# # 生成样本
# model.eval()
# with torch.no_grad():
#     z = torch.randn(64, latent_dim).to(device)
#     sample = model.decode(z).to(device).cpu()
#     sample = sample.view(64, 1, 28, 28)


In [None]:
#@title 验证模型

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def show_images(original, reconstructed, num_images=10, image_shape=(28, 28)):
    """Plot originals (top row) above their reconstructions (bottom row).

    Args:
        original: Indexable batch of images; each item is reshaped to
            ``image_shape`` before plotting.
        reconstructed: Batch of reconstructions in the same order as
            ``original``.
        num_images: Maximum number of columns to draw. It is clamped to the
            number of items available in both batches, so a short batch no
            longer raises ``IndexError``.
        image_shape: ``(height, width)`` each item is reshaped to. Defaults
            to the MNIST size.

    Returns:
        None. The figure is displayed with ``plt.show()``. Nothing is drawn
        when there are no images to show.
    """
    count = min(num_images, len(original), len(reconstructed))
    if count <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even when count == 1.
    fig, axes = plt.subplots(2, count, figsize=(20, 4), squeeze=False)
    for col in range(count):
        for row, batch in enumerate((original, reconstructed)):
            ax = axes[row, col]
            ax.imshow(batch[col].reshape(image_shape), cmap="gray")
            ax.axis('off')

    plt.show()

# Select the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loader over the MNIST test split
test_dataset = datasets.MNIST('./data', train=False, transform=transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Reconstruct a single test batch and display it next to the originals
model.eval()
model.to(device)  # make sure the model lives on the selected device
with torch.no_grad():
    data, _ = next(iter(test_loader))
    data = data.to(device)  # images stay [batch, 1, 28, 28]; no flattening
    recon_data, _, _ = model(data)
    show_images(data.cpu(), recon_data.cpu())


## vae和gan

In [None]:
#@title 训练

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

# Pick the compute device: GPU when available, otherwise CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using device: {device}")

# Hyperparameters
latent_dim = 20
lr = 1e-3
batch_size = 128
epochs = 10


# MNIST training split (downloaded into ./data on first use)
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

class Encoder(nn.Module):
    """Convolutional encoder mapping ``[batch, 1, 28, 28]`` images to ``(mu, logvar)``.

    Args:
        latent_dim: Size of the latent code predicted by both heads.
    """

    def __init__(self, latent_dim):
        super().__init__()

        # Three stride-2 convolutions: 28 -> 14 -> 7 -> 3 spatially, with
        # channels 1 -> 32 -> 64 -> 128, each followed by a ReLU.
        channels = (1, 32, 64, 128)
        layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            layers.append(nn.ReLU())
        layers.append(nn.Flatten())  # [batch, 128, 3, 3] -> [batch, 128*3*3]
        self.encoder = nn.Sequential(*layers)

        self.fc_mu = nn.Linear(128 * 3 * 3, latent_dim)
        self.fc_logvar = nn.Linear(128 * 3 * 3, latent_dim)

    def forward(self, x):
        """Return the tuple ``(mu, logvar)`` for the image batch ``x``."""
        features = self.encoder(x)
        return self.fc_mu(features), self.fc_logvar(features)


class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent codes to ``[batch, 1, 28, 28]`` images.

    Args:
        latent_dim: Size of the incoming latent code.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Project the latent code to a flat 128*3*3 feature vector.
        self.decoder_input = nn.Linear(latent_dim, 128 * 3 * 3)
        # Spatial size per ConvTranspose2d: (n - 1) * stride - 2 * padding + kernel.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=1),  # [batch, 128, 3, 3] -> [batch, 64, 7, 7]
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # [batch, 64, 7, 7] -> [batch, 32, 14, 14]
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),  # [batch, 32, 14, 14] -> [batch, 1, 28, 28]
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode ``z`` of shape ``[batch, latent_dim]`` into images in (0, 1)."""
        # Reshape the flat projection into the feature map the decoder expects.
        feature_map = self.decoder_input(z).view(z.size(0), 128, 3, 3)
        return self.decoder(feature_map)



class Discriminator(nn.Module):
    """Convolutional critic scoring ``[batch, 1, 28, 28]`` images with a value in (0, 1)."""

    def __init__(self):
        super().__init__()

        # Three stride-2 convolutions: 28 -> 14 -> 7 -> 3 spatially, with
        # channels 1 -> 32 -> 64 -> 128, each followed by LeakyReLU(0.2).
        channels = (1, 32, 64, 128)
        layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2))
        # Flatten the 128x3x3 map and reduce it to a single sigmoid score.
        layers.extend([nn.Flatten(), nn.Linear(128 * 3 * 3, 1), nn.Sigmoid()])
        self.discriminator = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``[batch, 1]`` tensor of scores for the images ``x``."""
        return self.discriminator(x)

def reparameterize(mu, logvar):
    """Sample ``mu + eps * std`` where ``std = exp(0.5 * logvar)``.

    ``eps`` is drawn from a standard normal with the same shape as ``logvar``.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

def vae_gan_loss(recon_x, x, mu, logvar, D_real, D_fake):
    """Compute the combined VAE/generator loss and the discriminator loss.

    Args:
        recon_x: Reconstructed batch with values in [0, 1].
        x: Target batch with values in [0, 1].
        mu: Posterior means.
        logvar: Posterior log-variances.
        D_real: Discriminator outputs for real images.
        D_fake: Discriminator outputs for reconstructed images.

    Returns:
        Tuple ``(vae_loss, d_loss)``: the summed reconstruction term plus the
        KL term plus the generator term, and the discriminator loss.
    """
    bce = nn.functional.binary_cross_entropy

    # VAE part: summed reconstruction error and KL term.
    reconstruction_term = bce(recon_x, x, reduction='sum')
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # GAN part: the discriminator should say 1 for real and 0 for fake;
    # the generator wants fakes to be scored as 1.
    real_term = bce(D_real, torch.ones_like(D_real))
    fake_term = bce(D_fake, torch.zeros_like(D_fake))
    discriminator_loss = real_term + fake_term
    generator_term = bce(D_fake, torch.ones_like(D_fake))

    return reconstruction_term + kl_term + generator_term, discriminator_loss

# Build the three networks
latent_dim = 20
encoder = Encoder(latent_dim).to(device)
decoder = Decoder(latent_dim).to(device)
discriminator = Discriminator().to(device)

# One optimizer for the VAE (encoder + decoder), one for the discriminator
vae_params = [*encoder.parameters(), *decoder.parameters()]
optimizer_vae = optim.Adam(vae_params, lr=lr)
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr)

bce = nn.functional.binary_cross_entropy

# Training loop: a discriminator step followed by a VAE/generator step per batch
for epoch in range(epochs):
    train_loss = 0
    for data, _ in train_loader:
        data = data.to(device)

        # VAE forward pass
        mu, logvar = encoder(data)
        z = reparameterize(mu, logvar)
        recon_batch = decoder(z)

        # Discriminator step. The reconstruction is detached so this
        # backward pass does not reach the encoder/decoder.
        optimizer_d.zero_grad()
        D_real = discriminator(data)
        D_fake = discriminator(recon_batch.detach())
        loss_d = bce(D_real, torch.ones_like(D_real)) + bce(D_fake, torch.zeros_like(D_fake))
        loss_d.backward()
        optimizer_d.step()

        # VAE/generator step. The reconstruction is scored again by the
        # just-updated discriminator, this time without detaching.
        optimizer_vae.zero_grad()
        D_fake_for_g = discriminator(recon_batch)
        loss_g = bce(D_fake_for_g, torch.ones_like(D_fake))

        # Total VAE loss: summed reconstruction + KL + generator term.
        # The backward pass also writes gradients into the discriminator;
        # they are cleared by optimizer_d.zero_grad() on the next iteration.
        reconstruction_term = bce(recon_batch, data, reduction='sum')
        kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss_vae = reconstruction_term + kl_term + loss_g
        loss_vae.backward()
        optimizer_vae.step()

        train_loss += loss_vae.item()

    print(f'Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset):.4f}')



In [None]:
#@title 验证

import matplotlib.pyplot as plt

# Decode 64 latent vectors drawn from a standard normal
with torch.no_grad():
    z = torch.randn(64, latent_dim).to(device)
    generated_images = decoder(z).cpu()

# Show the generated images on an 8x8 grid
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for ax, image in zip(axes.flat, generated_images):
    ax.imshow(image.view(28, 28), cmap='gray')
    ax.axis('off')
plt.show()
