In [19]:
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self, normalization):
        super(CNN, self).__init__()
        self.a = False
        if(normalization == nn.GroupNorm):
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
            self.n1 = normalization(4, 32)

            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
            self.n2 = normalization(8, 64)

            self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
            self.n3 = normalization(16, 128)

            self.pool = nn.MaxPool2d(2, 2)

            self.fc1 = nn.Linear(128 * 16 * 16, 128) 
            self.fc2 = nn.Linear(128, 10)
        elif(normalization == nn.LayerNorm):
            self.a = True
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)

            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)

            self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
            
            self.pool = nn.MaxPool2d(2, 2)

            self.fc1 = nn.Linear(128 * 16 * 16, 128) 
            self.n1 = normalization(128)
            self.fc2 = nn.Linear(128, 10)
        else:
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
            self.n1 = normalization(32)

            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
            self.n2 = normalization(64)

            self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
            self.n3 = normalization(128)

            self.pool = nn.MaxPool2d(2, 2)

            self.fc1 = nn.Linear(128 * 16 * 16, 128) 
            self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        if self.a == True:
            x = F.relu(self.conv1(x))
        
            x = F.relu(self.conv2(x))

            x = F.relu(self.conv3(x))
            
            x = self.pool(x)
            
            x = x.view(-1, 128 * 16 * 16)  

            x = F.relu(self.n1(self.fc1(x)))

            x = self.fc2(x)
        else:
            x = F.relu(self.n1(self.conv1(x)))
            
            x = F.relu(self.n2(self.conv2(x)))

            x = F.relu(self.n3(self.conv3(x)))
            
            x = self.pool(x)
            
            x = x.view(-1, 128 * 16 * 16)  

            x = F.relu(self.fc1(x))
            x = self.fc2(x)

        return x


In [16]:
# train
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 数据加载和预处理
transforms_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.226, 0.224, 0.225))
])

train_loader = DataLoader(
    datasets.CIFAR10('./data/CIFAR10', train=True, download=True, transform=transforms_train),
    batch_size=128, shuffle=True)

def train(normalization, name):

    model = CNN(normalization)
    optimizer = optim.Adamax(model.parameters(), lr=0.01)


    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    # 训练过程
    for epoch in range(1, 30):
        for batch_idx, (data, target) in enumerate(train_loader):
            
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = torch.nn.CrossEntropyLoss()(output, target)
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    # 保存模型
    torch.save(model.state_dict(), f'./model/{name}_cnn.pt')


transforms_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.226, 0.224, 0.225))
])

test_loader = DataLoader(
    datasets.CIFAR10('./data/CIFAR10', train=False, download=True, transform=transforms_train),
    batch_size=1000, shuffle=True)

def test(normalization, name):
    model = CNN(normalization)
    model.load_state_dict(torch.load(f'./model/{name}_cnn.pt'))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: name:{name} Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.4f}%)\n')


Files already downloaded and verified
Files already downloaded and verified


In [20]:
train(nn.BatchNorm2d, "Batch")
test(nn.BatchNorm2d, "Batch")



In [None]:
train(nn.LayerNorm, "Layer")
test(nn.LayerNorm, "Layer")


Test set: name:Layer Average loss: 0.8150, Accuracy: 7556/10000 (75.5600%)



In [None]:

train(nn.InstanceNorm2d, "Instance")
test(nn.InstanceNorm2d, "Instance")


Test set: name:Instance Average loss: 1.7997, Accuracy: 2339/10000 (23.3900%)



In [None]:
train(nn.GroupNorm, "Group")
test(nn.GroupNorm, "Group")


Test set: name:Group Average loss: 0.9146, Accuracy: 6763/10000 (67.6300%)

