In [116]:
from torchvision import datasets
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
import os
import argparse
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
from torch.optim import Adam,AdamW
from tqdm.auto import tqdm
import torchvision

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # 转换为单通道灰度图
    transforms.ToTensor(),

])
transform1 = transforms.Compose([

    transforms.ToTensor(),

])
path = "./MNIST/generated_images"
 
train_set = datasets.ImageFolder(path, transform=transform)
 
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)

test_set = torchvision.datasets.MNIST(root='./data',
                               train=False,
                               download=False,
                               transform=transform1)


test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)

print('Finish data loading')
print(f"Training data size: {len(train_set)}")
print(f"Testing data size: {len(test_set)}")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

Finish data loading
Training data size: 10000
Testing data size: 10000
device: cuda


In [117]:
for images, labels in train_loader:
    print(f"Images shape: {images.shape}")  # (batch_size, channels, height, width)
    print(f"Batch size: {images.shape[0]}, Channels: {images.shape[1]}, Height: {images.shape[2]}, Width: {images.shape[3]}")
    break  # 只查看第一个 batch


Images shape: torch.Size([64, 1, 28, 28])
Batch size: 64, Channels: 1, Height: 28, Width: 28


In [118]:
for images, labels in train_loader:
    print(f"Batch size: {images.shape[0]}, Label batch size: {labels.shape[0]}")
    break


Batch size: 64, Label batch size: 64


In [119]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        return self.l5(x)#最后一层不做激活

model=Net()
model.to(device)

Net(
  (l1): Linear(in_features=784, out_features=512, bias=True)
  (l2): Linear(in_features=512, out_features=256, bias=True)
  (l3): Linear(in_features=256, out_features=128, bias=True)
  (l4): Linear(in_features=128, out_features=64, bias=True)
  (l5): Linear(in_features=64, out_features=10, bias=True)
)

In [120]:
def train(model,train_loader, epochs=10, device='cpu'):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, betas=(0.85, 0.98), weight_decay=5e-4, amsgrad=True)

    criterion = torch.nn.CrossEntropyLoss()# give y hat and y calculate the loss
    for i in range(epochs):

        loss_final = 0.0
        loop = tqdm(enumerate(train_loader), total=len(train_loader))
        loop.set_description(f'Epoch [{i}/{epochs}]')

        for step, (images, labels) in loop:

            
            images, labels = images.to(device), labels.to(device)
            # forward + backward + update
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_final += loss.item()
            loop.set_postfix(loss=loss_final)
        test(test_loader,model)


def test(test_loader,model):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:#包含batch size=64张图片和label
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)#predicted是一个tensor,torch.tensor([1, 2, 3, ..., 64])  # 共 64 个元素
            # 返回最大值和下标
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set: %.3f %%' % (100 * correct / total))


In [121]:
train(model,train_loader,epochs=20,device=device)
test(test_loader,model)

  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 78.670 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 84.160 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 85.610 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 85.140 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 87.260 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 87.530 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 87.100 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 87.140 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 88.590 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 88.460 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 87.870 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 87.720 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 89.120 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 89.480 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 89.590 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 89.560 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 86.950 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 89.310 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 89.900 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 90.390 %
Accuracy on test set: 90.390 %


In [122]:
test(train_loader,model)

Accuracy on test set: 98.750 %


In [127]:

save_dir = './model_MNIST'  # 目录
save_path = f'{save_dir}/model_gen.pt'  # 文件路径

# 确保目录存在，但不要创建文件
os.makedirs(save_dir, exist_ok=True)  # ✅ 只创建目录，不影响文件

# 保存模型
torch.save(model.state_dict(), save_path)  # ✅ 保存文件
print(f"Model saved at: {save_path}")


Model saved at: ./model_MNIST/model_gen.pt


In [128]:
model_test=Net()
model_test.to(device)
model_test.load_state_dict(torch.load(f'./model_MNIST/model_gen.pt'))
test(test_loader,model_test)

  model_test.load_state_dict(torch.load(f'./model_MNIST/model_gen.pt'))


Accuracy on test set: 90.390 %


In [125]:
 #train(model,train_loader,epochs=10,device=device)