In [None]:

#@title 链接Google Drive
# Mount Google Drive at /content/drive; the training cell reads and writes
# its checkpoints under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

## 数据处理

In [None]:
#@title 读取一批次图片
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torchvision

# Data transform: ToTensor, then Normalize(mean=0.5, std=0.5), i.e.
# (x - 0.5) / 0.5, which maps [0, 1] pixel values to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # normalize to [-1, 1]
])

# DataLoader over the MNIST training split (downloaded into the current
# directory if missing), reshuffled every epoch.
batch_size = 64
dataloader = DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

# Fetch a single batch for visual inspection.
data_iter = iter(dataloader)
images, labels = next(data_iter)

# 反转处理，转换为 NumPy 数组以便显示
def imshow(img):
    """Undo the [-1, 1] normalization of an image tensor and display it.

    Args:
        img: image tensor shaped (C, H, W), e.g. the output of
            ``torchvision.utils.make_grid``. It may live on any device and may
            require grad. Single-channel images are shown in grayscale.
    """
    # detach/cpu so GPU tensors and tensors that require grad can be converted.
    img = img.detach().cpu() / 2 + 0.5  # undo Normalize(mean=0.5, std=0.5)
    npimg = img.numpy()
    if npimg.shape[0] == 1:
        # matplotlib rejects (H, W, 1) arrays, so drop the channel axis.
        plt.imshow(npimg[0], cmap='gray')
    else:
        # (C, H, W) -> (H, W, C) as matplotlib expects.
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Show the first 8 images of the batch as one grid.
imshow(torchvision.utils.make_grid(images[:8]))

# Print the labels of those 8 images.
print('Labels:', labels[:8].numpy())

In [7]:
#@title 简单数据
# Dataset loading: rebuilds the same normalization and shuffled MNIST
# training DataLoader as the preview cell, without displaying a batch.
# Uses `batch_size` defined in an earlier cell.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # (x - 0.5) / 0.5
])

dataloader = DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

## 模型

In [3]:
#@title 初始化模型参数
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# Hyperparameters
latent_dim = 100  # length of the noise vector z fed to the generator
batch_size = 64  # NOTE(review): the DataLoader cells read batch_size when they run; changing it here does not rebuild them
lr = 0.0002  # learning rate passed to both Adam optimizers
epochs = 12  # NOTE(review): the "1、线性模型" cell sets epochs = 1; whichever cell ran last wins

In [42]:
#@title 1、线性模型
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
import torchvision.utils as vutils

# Hyperparameters (duplicated from the hyperparameter cell)
latent_dim = 100  # length of the noise vector z fed to the generator
batch_size = 64  # NOTE(review): does not rebuild an already-created DataLoader
lr = 0.0002  # learning rate passed to both Adam optimizers
epochs = 1  # NOTE(review): overrides epochs = 12 from the earlier hyperparameter cell

# 生成器网络
class Generator(nn.Module):
    """Fully connected generator mapping a `latent_dim` noise vector to a 1x28x28 image.

    Four Linear+ReLU stages widen the features 128 -> 256 -> 512 -> 1024, a final
    Linear projects to 28*28 pixels, and Tanh keeps outputs in [-1, 1] (the range
    produced by the Normalize((0.5,), (0.5,)) transform on real images).
    """

    def __init__(self):
        super().__init__()
        widths = [latent_dim, 128, 256, 512, 1024]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(True)]
        layers += [nn.Linear(widths[-1], 28 * 28), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        pixels = self.model(z)
        # (batch, 784) -> (batch, 1, 28, 28)
        return pixels.view(pixels.size(0), 1, 28, 28)

# 判别器网络
class Discriminator(nn.Module):
    """Fully connected discriminator: flattens a 1x28x28 image and outputs P(real) in (0, 1)."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, img):
        batch = img.size(0)
        # (batch, 1, 28, 28) -> (batch, 784), then score -> (batch, 1)
        return self.model(img.view(batch, -1))








In [None]:
#@title   1.1、第三方全连接

# Shape of one MNIST image as (channels, height, width). The third-party code
# this cell was adapted from read `img_shape` and `opt.latent_dim` from names
# this notebook never defines, so instantiating the classes raised NameError;
# the notebook's own `latent_dim` is used instead.
img_shape = (1, 28, 28)


class Generator(nn.Module):
    """MLP generator mapping a `latent_dim` noise vector to an image of `img_shape`.

    Hidden widths 128 -> 256 -> 512 -> 1024 with LeakyReLU(0.2); every block but
    the first also has BatchNorm1d. Output goes through Tanh, so pixels lie in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            """Return [Linear, (BatchNorm1d), LeakyReLU(0.2)] as a list of layers."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # Upstream wrote BatchNorm1d(out_feat, 0.8); BatchNorm1d's second
                # positional argument is eps, so that silently set eps=0.8.
                # Use PyTorch's default eps/momentum instead.
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        # (batch, prod(img_shape)) -> (batch, *img_shape)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):
    """MLP discriminator: flattens an image of `img_shape` and outputs P(real) via Sigmoid."""

    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        # (batch, 1) probabilities
        return validity

In [97]:
#@title 2、conv2d模型
class Generator(nn.Module):
    """DCGAN-style generator: upsamples a latent vector to a 1x28x28 image in [-1, 1].

    Channels latent_dim -> 1024 -> 512 -> 256 -> 128 -> 1; spatial size through
    the ConvTranspose2d stack is 1 -> 2 -> 4 -> 7 -> 14 -> 28.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 1024, kernel_size=2, stride=1, padding=0),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        # Treat the latent vector as a latent_dim-channel 1x1 feature map.
        z = z.view(z.size(0), latent_dim, 1, 1)
        # No per-call shape print: it emitted a line on every training step.
        return self.model(z)


class Discriminator(nn.Module):
    """Convolutional discriminator: 1x28x28 image -> P(real) with shape (batch, 1).

    Seven stride-2 3x3 convolutions shrink the spatial size
    28 -> 14 -> 7 -> 4 -> 2 -> 1 -> 1 -> 1; a final Sigmoid gives the probability.
    NOTE(review): the last two convolutions operate on 1x1 maps, where most of
    each 3x3 kernel only sees zero padding.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=3, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, img):
        validity = self.model(img)
        # (batch, 1, 1, 1) -> (batch, 1), matching the BCELoss targets.
        return validity.view(validity.size(0), -1)


In [105]:
#@title 2.1 自注意力

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """SAGAN-style self-attention across the spatial positions of a feature map.

    1x1 convolutions produce queries/keys with in_dim // 8 channels and values
    with in_dim channels. Every position attends to every other position, and the
    result is blended in as ``gamma * attended + x``. gamma starts at 0, so the
    layer is initially the identity.
    """

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        reduced = in_dim // 8
        self.query_conv = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, h, w = x.size()
        positions = h * w

        queries = self.query_conv(x).view(n, -1, positions).transpose(1, 2)  # (n, positions, c//8)
        keys = self.key_conv(x).view(n, -1, positions)                        # (n, c//8, positions)
        attention = torch.softmax(torch.bmm(queries, keys), dim=-1)          # (n, positions, positions)

        values = self.value_conv(x).view(n, -1, positions)                    # (n, c, positions)
        attended = torch.bmm(values, attention.transpose(1, 2)).view(n, c, h, w)

        return self.gamma * attended + x

class Generator(nn.Module):
    """Convolutional generator with self-attention after the first two upsampling stages.

    Maps a latent vector to a 1x28x28 image in [-1, 1] (Tanh). Spatial size
    through the ConvTranspose2d stack is 1 -> 4 -> 7 -> 14 -> 28; self-attention
    runs on the 4x4 (512-channel) and 7x7 (256-channel) maps.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            SelfAttention(512),  # self-attention at 4x4
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            SelfAttention(256),  # self-attention at 7x7
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        # Treat the latent vector as a latent_dim-channel 1x1 feature map.
        z = z.view(z.size(0), latent_dim, 1, 1)
        # No per-call shape print: it emitted a line on every training step.
        return self.model(z)


class Discriminator(nn.Module):
    """Convolutional discriminator with one self-attention layer: image -> P(real), shape (batch, 1).

    Stride-2 3x3 convolutions shrink the spatial size 28 -> 14 -> 7 -> 4 -> 2 -> 1;
    self-attention runs on the 14x14 (64-channel) map, and Sigmoid gives the probability.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            SelfAttention(64),  # self-attention at 14x14
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=3, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, img):
        validity = self.model(img)
        # (batch, 1, 1, 1) -> (batch, 1), matching the BCELoss targets.
        return validity.view(validity.size(0), -1)


In [4]:
#@title 2.2 自注意力stable diffusion
import math
from torch.nn import functional as F

class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Adapted from the Stable Diffusion attention layer. Each of the H*W positions
    is a token whose embedding is its channel vector. One Linear produces Q, K and
    V. Attention is computed per head with 1/sqrt(d_head) scaling, then the heads
    are concatenated and passed through an output Linear.
    NOTE(review): there is no residual connection here (unlike the gamma-gated
    SAGAN layer), so the output replaces the input feature map.

    Args:
        d_embed: number of channels of the input feature map (token dimension).
        in_proj_bias: whether the fused Q/K/V projection has a bias.
        out_proj_bias: whether the output projection has a bias.
        n_heads: number of attention heads; must divide d_embed.

    Raises:
        ValueError: if d_embed is not divisible by n_heads.
    """

    def __init__(self, d_embed, in_proj_bias=True, out_proj_bias=True, n_heads=8):
        super().__init__()
        # Honour the caller's n_heads (it used to be overwritten with 8 here).
        if d_embed % n_heads != 0:
            raise ValueError(f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)

    def forward(self, x):
        """Map x of shape (batch, channels, height, width) to a tensor of the same shape."""
        batch_size, channels, height, width = x.shape

        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        # reshape (not view) so non-contiguous inputs are accepted as well.
        tokens = x.reshape(batch_size, channels, height * width).transpose(1, 2)
        seq_len = tokens.size(1)

        # (B, N, 3C) -> three (B, N, C) tensors
        q, k, v = self.in_proj(tokens).chunk(3, dim=-1)

        # (B, N, C) -> (B, n_heads, N, d_head)
        q = q.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product attention per head.
        weight = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        weight = F.softmax(weight, dim=-1)
        output = weight @ v  # (B, n_heads, N, d_head)

        # Concatenate heads: (B, n_heads, N, d_head) -> (B, N, C)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, channels)

        # Final linear projection.
        output = self.out_proj(output)

        # (B, N, C) -> (B, C, H, W)
        return output.transpose(1, 2).reshape(batch_size, channels, height, width)

# Generator 和 Discriminator 类的实现保持不变，只是在适当的位置使用 SelfAttention



class Generator(nn.Module):
    """Convolutional generator with multi-head self-attention after the first two upsampling stages.

    Maps a latent vector to a 1x28x28 image in [-1, 1] (Tanh). Spatial size
    through the ConvTranspose2d stack is 1 -> 4 -> 7 -> 14 -> 28; self-attention
    runs on the 4x4 (512-channel) and 7x7 (256-channel) maps.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            SelfAttention(512),  # self-attention at 4x4
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            SelfAttention(256),  # self-attention at 7x7
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        # Treat the latent vector as a latent_dim-channel 1x1 feature map.
        z = z.view(z.size(0), latent_dim, 1, 1)
        # No per-call shape print: it emitted a line on every training step.
        return self.model(z)


class Discriminator(nn.Module):
    """Convolutional discriminator with one multi-head self-attention layer: image -> P(real), shape (batch, 1).

    Stride-2 3x3 convolutions shrink the spatial size 28 -> 14 -> 7 -> 4 -> 2 -> 1;
    self-attention runs on the 14x14 (64-channel) map, and Sigmoid gives the probability.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            SelfAttention(64),  # self-attention at 14x14
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=3, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, img):
        validity = self.model(img)
        # (batch, 1, 1, 1) -> (batch, 1), matching the BCELoss targets.
        return validity.view(validity.size(0), -1)


In [17]:
#@title 损失器

# Initialize the generator and discriminator (whichever model cell ran last
# defines these classes).
# NOTE(review): the training cell re-creates both models, the loss and the
# optimizers, so the objects built here are replaced when that cell runs.
generator = Generator()
discriminator = Discriminator()

# Loss function: binary cross-entropy on the discriminator's sigmoid output.
adversarial_loss = nn.BCELoss()

# Optimizers: one Adam per network, same learning rate, default betas.
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

## 训练

In [None]:
#@title 训练GAN

import os
import torch

# Use the GPU when the runtime has one; everything below also runs on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint locations on the mounted Google Drive (overwritten every epoch).
generator_weights_path = '/content/drive/MyDrive/generator_epoch.pth'
discriminator_weights_path = '/content/drive/MyDrive/discriminator_epoch.pth'

# Print losses every `log_interval` batches (and on each epoch's last batch)
# rather than on every batch, which flooded the cell output.
log_interval = 100


def load_weights_if_compatible(model, path, name):
    """Load a saved state_dict into `model` if `path` exists and matches its architecture.

    Every model variant in this notebook shares the same checkpoint paths. A
    checkpoint written by a different architecture is therefore skipped, and the
    model keeps its fresh initialisation, instead of ``load_state_dict`` raising.

    Args:
        model: the nn.Module to load into (already moved to `device`).
        path: checkpoint file written by ``torch.save(model.state_dict(), path)``.
        name: display name used in the status messages (e.g. "生成器").
    """
    if not os.path.exists(path):
        print(f"{name}权重文件未找到：{path}")
        return
    # map_location lets a checkpoint saved on GPU load on a CPU runtime and vice versa.
    state = torch.load(path, map_location=device)
    current = model.state_dict()
    compatible = state.keys() == current.keys() and all(
        state[key].shape == current[key].shape for key in current)
    if not compatible:
        print(f"{name}权重与当前模型结构不匹配，已跳过加载：{path}")
        return
    model.load_state_dict(state)
    print(f"加载{name}的权重成功：{path}")


# Create both networks and move them to `device` before building the
# optimizers, so Adam holds references to the on-device parameters.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

load_weights_if_compatible(generator, generator_weights_path, "生成器")
load_weights_if_compatible(discriminator, discriminator_weights_path, "判别器")

# Loss: binary cross-entropy on the discriminator's sigmoid output.
adversarial_loss = nn.BCELoss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(device)

        # Smoothed targets: 0.9 for real and 0.1 for fake instead of hard 1/0.
        # NOTE(review): the generator is also trained toward 0.9, and fake
        # labels are smoothed too; one-sided smoothing (real labels only) is the
        # more common recipe. Kept as-is to preserve the training setup.
        valid = torch.full((imgs.size(0), 1), 0.9, device=device)
        fake = torch.full((imgs.size(0), 1), 0.1, device=device)

        # -----------------
        # Train the generator: push D to classify generated images as real.
        # -----------------
        optimizer_G.zero_grad()

        z = torch.randn(imgs.size(0), latent_dim, device=device)
        gen_imgs = generator(z)

        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        # Train the discriminator: real -> valid, generated -> fake.
        # zero_grad() also clears the gradients g_loss.backward() left in D.
        # ---------------------
        optimizer_D.zero_grad()

        real_loss = adversarial_loss(discriminator(imgs), valid)
        # detach() so the discriminator update does not backpropagate into G.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        if i % log_interval == 0 or i == len(dataloader) - 1:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")

    # After each epoch, sample 64 images from fresh noise.
    with torch.no_grad():
        z = torch.randn(64, latent_dim, device=device)
        gen_imgs = generator(z)
        gen_imgs = gen_imgs.view(gen_imgs.size(0), 1, 28, 28)  # (batch, channels, height, width)

        # Save the 8x8 grid as a single PNG in the working directory.
        grid_img = vutils.make_grid(gen_imgs, nrow=8, normalize=True, scale_each=True)
        save_path = f'generated_images_{epoch}.png'
        vutils.save_image(grid_img, save_path)

        print(f'生成的图像已保存到 {save_path}')

    # Save model weights (overwrites the previous epoch's checkpoint).
    torch.save(generator.state_dict(), generator_weights_path)
    torch.save(discriminator.state_dict(), discriminator_weights_path)

    print(f"模型权重已保存到 '{generator_weights_path}' 和 '{discriminator_weights_path}'")

## 验证图片

In [None]:
#@title 显示文件夹png后缀的图片
import os
import re
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Folder to scan for images. The training cell saves its sample grids with a
# relative path, i.e. into the working directory (presumably /content on Colab).
folder_path = '/content'  # replace with your folder path


def natural_sort_key(name):
    """Sort key that compares digit runs numerically, so 'x_2.png' sorts before 'x_10.png'."""
    # re.split with a capturing group alternates text (even index) and digits (odd index).
    parts = re.split(r'(\d+)', name)
    return [int(part) if idx % 2 else part.lower() for idx, part in enumerate(parts)]


# All PNG files in the folder (case-insensitive extension). os.listdir returns
# entries in arbitrary order, so sort them; generated_images_0 ... N then
# appear in epoch order.
png_files = sorted(
    (f for f in os.listdir(folder_path) if f.lower().endswith('.png')),
    key=natural_sort_key,
)

# Show each PNG in its own figure, titled with its file name.
for file_name in png_files:
    img_path = os.path.join(folder_path, file_name)
    img = mpimg.imread(img_path)
    plt.imshow(img)
    plt.axis('off')  # hide the axes
    plt.title(file_name)
    plt.show()


