In [18]:
from torchvision import datasets
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import torch.nn.functional as F
import os
import argparse
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
from torch.optim import Adam,AdamW
from tqdm.auto import tqdm
import torchvision

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])



bt_size=128
path = "./CIFAR10/generated_images"
 
train_set = datasets.ImageFolder(path, transform=transform_train)
 
train_loader = DataLoader(train_set, batch_size=bt_size, shuffle=True, num_workers=2)


selected_classes = [0, 1, 2, 3, 4]
label_mapping = {orig_label: idx for idx, orig_label in enumerate(selected_classes)}

test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=False, transform=transform_test
)


selected_indices_test = [
    idx for idx, (_, label) in enumerate(test_set)
    if label in selected_classes
]


for i in selected_indices_test:
    test_set.targets[i]=label_mapping[test_set.targets[i]]

# 创建子集数据集
filtered_test_set = Subset(test_set, selected_indices_test)
test_loader = DataLoader(filtered_test_set, batch_size=bt_size, shuffle=False, num_workers=2)

print('Finish data loading')
print(f"Training data size: {len(train_set)}")
print(f"Testing data size: {len(filtered_test_set)}")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

Finish data loading
Training data size: 5000
Testing data size: 5000
device: cuda


In [19]:

device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_mobilenetv2_x0_5", pretrained=False)
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()# give y hat and y calculate the loss
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

Using cache found in /home/chunjie/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


In [20]:
def test(test_loader,model,device='cpu'):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:#包含batch size=64张图片和label
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)#predicted是一个tensor,torch.tensor([1, 2, 3, ..., 64])  # 共 64 个元素
            # 返回最大值和下标
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set: %.3f %%' % (100 * correct / total))

def train(epoch):
    running_loss = 0.0 

    
    for batch_idx, data in enumerate(train_loader, 0):

        inputs, target = data
        inputs, target = inputs.to(device) , target.to(device)
        optimizer.zero_grad()
        #print(batch_idx)

        # forward + backward + update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print('epoch: %d loss:%.3f ' % (epoch,running_loss), end=' ')

In [21]:

if __name__ == "__main__":
    epoch=100

    for i in range(epoch):
        train(i)
        scheduler.step()
        test(test_loader,model,device)

epoch: 0 loss:159.172  Accuracy on test set: 33.440 %
epoch: 1 loss:59.465  Accuracy on test set: 41.720 %
epoch: 2 loss:48.843  Accuracy on test set: 43.100 %
epoch: 3 loss:44.655  Accuracy on test set: 47.800 %
epoch: 4 loss:43.382  Accuracy on test set: 48.900 %
epoch: 5 loss:40.535  Accuracy on test set: 52.220 %
epoch: 6 loss:38.553  Accuracy on test set: 53.920 %
epoch: 7 loss:38.380  Accuracy on test set: 52.560 %
epoch: 8 loss:37.534  Accuracy on test set: 54.700 %
epoch: 9 loss:37.588  Accuracy on test set: 56.040 %
epoch: 10 loss:36.176  Accuracy on test set: 56.880 %
epoch: 11 loss:36.266  Accuracy on test set: 57.260 %
epoch: 12 loss:33.876  Accuracy on test set: 58.380 %
epoch: 13 loss:32.780  Accuracy on test set: 60.140 %
epoch: 14 loss:34.098  Accuracy on test set: 59.240 %
epoch: 15 loss:32.820  Accuracy on test set: 59.420 %
epoch: 16 loss:31.443  Accuracy on test set: 59.940 %
epoch: 17 loss:28.738  Accuracy on test set: 61.960 %
epoch: 18 loss:28.734  Accuracy on te

In [22]:
model.eval()
test(test_loader,model,device)
test(train_loader,model,device)

Accuracy on test set: 81.420 %
Accuracy on test set: 97.320 %


In [23]:
import os
save_dir = './model_CIFAR10'  # 目录
save_path = f'{save_dir}/model_gen.pt'  # 文件路径

# 确保目录存在，但不要创建文件
os.makedirs(save_dir, exist_ok=True)  # ✅ 只创建目录，不影响文件

# 保存模型
torch.save(model.state_dict(), save_path)  # ✅ 保存文件
print(f"Model saved at: {save_path}")

Model saved at: ./model_CIFAR10/model_gen.pt


In [25]:
model_test=torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_mobilenetv2_x0_5", pretrained=False)
model_test.to(device)
model_test.load_state_dict(torch.load(f'./model_CIFAR10/model_gen.pt'))
model_test.eval()
test(test_loader,model_test,device)
test(train_loader,model_test,device)

Using cache found in /home/chunjie/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master
  model_test.load_state_dict(torch.load(f'./model_CIFAR10/model_gen.pt'))


Accuracy on test set: 81.420 %
Accuracy on test set: 97.040 %
