In [34]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
import os
import argparse
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
from torch.optim import Adam,AdamW
from tqdm.auto import tqdm


# Specify the required classes (placeholder: no class list is defined here).


# Preprocessing applied to every image: convert to a tensor, then standardise
# the single channel with a fixed mean/std pair.
# NOTE(review): 0.1307 / 0.3081 look like the commonly quoted MNIST mean and
# standard deviation — confirm.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)


# Load the MNIST training split from ./data. download=False means the files
# must already be present on disk.
train_set = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=False,
)
# Pick the indices to keep: walk the training labels in dataset order and
# accept a sample until its class has already contributed 1000 examples.
cnt = [0] * 10
selected_indices_train = []
for sample_idx, label in enumerate(train_set.targets.tolist()):
    if cnt[label] >= 1000:
        continue
    cnt[label] += 1
    selected_indices_train.append(sample_idx)


# Build the reduced training set and a shuffling loader over it.
filtered_train_set = Subset(train_set, selected_indices_train)

train_loader = DataLoader(
    dataset=filtered_train_set,
    batch_size=64,
    num_workers=4,
    shuffle=True,
)




# Load the MNIST test split (train=False) from ./data; as with the training
# split, nothing is downloaded.
test_set = torchvision.datasets.MNIST(
    root='./data',
    transform=transform,
    download=False,
    train=False,
)


# Fixed-order (unshuffled) loader over the whole test split.
test_loader = DataLoader(dataset=test_set, batch_size=64, num_workers=4, shuffle=False)


# Report what was loaded, one item per line.
print(
    'Finish data loading',
    f"Training data size: {len(filtered_train_set)}",
    f"Testing data size: {len(test_set)}",
    sep="\n",
)

# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"device: {device}")

Finish data loading
Training data size: 10000
Testing data size: 10000
device: cuda


In [35]:
class Net(torch.nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 128 -> 64 -> 10.

    The four hidden layers are followed by ReLU; the final layer returns its
    raw outputs with no activation.
    """

    def __init__(self):
        super().__init__()
        # The attribute names l1..l5 determine the state_dict keys.
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        """Reshape ``x`` into rows of 784 features and return 10 scores per row."""
        hidden = x.view(-1, 784)
        for layer in (self.l1, self.l2, self.l3, self.l4):
            hidden = F.relu(layer(hidden))
        # No activation on the last layer.
        return self.l5(hidden)

# Instantiate the network and move it to the selected device in one step.
model = Net().to(device)
# Bare expression so the notebook renders the module summary for this cell.
model



Net(
  (l1): Linear(in_features=784, out_features=512, bias=True)
  (l2): Linear(in_features=512, out_features=256, bias=True)
  (l3): Linear(in_features=256, out_features=128, bias=True)
  (l4): Linear(in_features=128, out_features=64, bias=True)
  (l5): Linear(in_features=64, out_features=10, bias=True)
)

In [36]:
def train(model, train_loader, epochs=10, device='cpu'):
    """Train ``model`` with Adam and cross-entropy, evaluating after every epoch.

    Args:
        model: Network to optimise; its parameters are updated in place.
            (This parameter used to be misspelled ``mdoel``, which made the
            body train the notebook-global ``model`` and ignore the argument.)
        train_loader: Iterable with a length that yields ``(images, labels)``
            batches.
        epochs: Number of passes over ``train_loader``.
        device: Device each batch is moved to before the forward pass. The
            model itself is not moved here.

    Returns:
        None. Progress is shown through a tqdm bar, and after each epoch
        ``test(test_loader, model)`` is called, where ``test_loader`` is the
        notebook-level global.
    """
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=0.0003,
        betas=(0.85, 0.98),
        weight_decay=5e-4,
        amsgrad=True,
    )
    # Takes the raw model outputs and the integer labels.
    criterion = torch.nn.CrossEntropyLoss()

    for i in range(epochs):
        # Re-enter training mode every epoch in case the evaluation call at
        # the end of the previous epoch switched the model to eval mode.
        model.train()

        # Running sum of the batch losses seen so far in this epoch.
        loss_final = 0.0
        # One-based epoch counter so the last epoch reads [epochs/epochs].
        loop = tqdm(train_loader, total=len(train_loader), desc=f'Epoch [{i + 1}/{epochs}]')

        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)

            # forward + backward + update
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_final += loss.item()
            loop.set_postfix(loss=loss_final)

        test(test_loader, model)

def test(test_loader,model):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:#包含batch size=64张图片和label
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)#predicted是一个tensor,torch.tensor([1, 2, 3, ..., 64])  # 共 64 个元素
            # 返回最大值和下标
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set: %.3f %%' % (100 * correct / total))


In [37]:
# Train for 25 epochs on the selected device; train() prints the test
# accuracy after every epoch.
train(
    model,
    train_loader,
    epochs=25,
    device=device,
)
# One more evaluation on the test loader once training has finished.
test(test_loader, model)

  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 88.410 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 91.090 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 92.330 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 93.140 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 93.540 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 94.010 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 94.350 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 94.620 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 94.680 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 94.820 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 94.930 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.270 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.040 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 94.760 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.240 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.200 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.500 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.530 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.470 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.410 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.350 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.420 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.430 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.580 %


  0%|          | 0/157 [00:00<?, ?it/s]

Accuracy on test set: 95.430 %
Accuracy on test set: 95.430 %


In [38]:
# Measure accuracy on the training subset. test() always labels its output
# "test set", so the printed line refers to training data here.
test(test_loader=train_loader, model=model)

Accuracy on test set: 99.990 %


In [39]:
# Output directory and checkpoint file for the trained weights.
save_dir = './model_MNIST'
save_path = f'{save_dir}/model_org.pt'

# Create the directory if it is missing; an existing directory is left as is.
os.makedirs(name=save_dir, exist_ok=True)

# Store only the parameters (state_dict), not the module object itself.
torch.save(obj=model.state_dict(), f=save_path)
print(f"Model saved at: {save_path}")

Model saved at: ./model_MNIST/model_org.pt
