In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader

# Dataset that yields pairs of MNIST images plus a "same digit" target.
class TwoMNISTDataset(Dataset):
    """Wrap an (image, label) dataset so each item is a pair of samples.

    ``dataset[idx]`` returns ``(img1, label1, img2, label2, same)`` where
    ``same`` is a float32 scalar tensor equal to 1.0 when ``label1 == label2``
    and 0.0 otherwise.

    Pairing is deterministic:

    * odd ``idx``  -> sample ``idx`` is paired with itself, so these pairs are
      always positives made of two identical images;
    * even ``idx`` -> sample ``len - idx - 1`` (the mirrored position) is
      paired with sample ``idx``.

    The length equals the length of the wrapped dataset.
    """

    def __init__(self, mnist_dataset):
        # Wrapped dataset; indexing it must yield an (image, label) tuple.
        self.mnist_dataset = mnist_dataset

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        source = self.mnist_dataset
        # Odd positions reuse idx (identical-image positive pair); even ones
        # mirror the index from the end of the dataset.
        partner = idx if idx % 2 else abs(len(source) - idx - 1)
        # The wrapped dataset is looked up twice, in this order, even when
        # partner == idx, so the two images come from separate lookups.
        img1, label1 = source[partner]
        img2, label2 = source[idx]
        is_same = torch.tensor(label1 == label2, dtype=torch.float32)
        return img1, label1, img2, label2, is_same


# Reproducibility: seed torch's global RNG, which drives DataLoader shuffling
# and the weight initialisation performed when the model is built afterwards.
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE = 64
to_tensor = transforms.ToTensor()  # stateless, safe to share between splits

mnist_train = MNIST('./data', train=True, download=True, transform=to_tensor)
mnist_test = MNIST('./data', train=False, download=True, transform=to_tensor)

train_dataset = TwoMNISTDataset(mnist_train)
test_dataset = TwoMNISTDataset(mnist_test)

# Train/test loaders: only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [22]:
# Siamese-style CNN: a shared conv tower embeds each image, the two embeddings
# are concatenated and classified by a small MLP.
class CNN(nn.Module):
    """Predict whether two single-channel images belong to the same class.

    Both inputs pass through the same ``conv1``/``conv2`` layers (shared
    weights). Their flattened features are concatenated (``x1`` first) and
    fed to ``fc1 -> fc2 -> fc3 -> sigmoid``.

    The ``fc1`` input width assumes each image is ``(batch, 1, 28, 28)``:
    28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, with 64 channels,
    i.e. 64 * 4 * 4 = 1024 features per image. Other spatial sizes now fail
    loudly in ``fc1`` instead of silently regrouping samples.
    """

    # Flattened size of one conv-tower output for a 28x28 input.
    EMBED_DIM = 64 * 4 * 4

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(2 * self.EMBED_DIM, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def _embed(self, x):
        """Run one image batch through the shared conv tower and flatten it."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # flatten(x, 1) keeps the batch dimension; a fixed view(-1, 1024)
        # would fold several samples into one row for non-28x28 inputs.
        return torch.flatten(x, 1)

    def forward(self, x1, x2):
        """Return the "same class" probability, shape (batch, 1), in (0, 1)."""
        x = torch.cat((self._embed(x1), self._embed(x2)), dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.sigmoid(self.fc3(x))

# Build the model, the loss and the optimizer.
# BCELoss consumes probabilities, matching the sigmoid at the end of CNN.forward.
LEARNING_RATE = 0.002
MOMENTUM = 0.9

model = CNN()
criterion = nn.BCELoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
)

In [23]:
# Training and evaluation helpers.
def train(model, criterion, optimizer, train_loader):
    """Run one training epoch and return the mean per-sample loss.

    Each batch yielded by ``train_loader`` is unpacked as
    ``(img1, label1, img2, label2, same)``; only the two images and the pair
    target ``same`` are used. ``criterion`` is called on the model output
    and ``same.unsqueeze(1)``.

    Each batch loss is weighted by its batch size, so a smaller final batch
    does not skew the epoch average. This assumes ``criterion`` uses mean
    reduction (the default for ``nn.BCELoss``).

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    for img1, _, img2, _, label in train_loader:
        target = label.unsqueeze(1)
        optimizer.zero_grad()
        output = model(img1, img2)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        batch_size = target.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size
    if n_samples == 0:
        raise ValueError("train_loader yielded no samples")
    return total_loss / n_samples

def test(model, criterion, test_loader):
    model.eval()
    correct = 0
    total = 0
    test_loss = 0
    with torch.no_grad():
        for data in test_loader:
            img1, _, img2, _, label = data
            outputs = model(img1, img2)
            loss = criterion(outputs,label.unsqueeze(1))
            predicted = torch.round(outputs)
            total += label.size(0)
            correct += (predicted == label.unsqueeze(1)).sum().item()
            test_loss += loss.item()
    return test_loss/len(test_loader), correct / total

In [24]:
# Train for `num_epochs` epochs, evaluating on the held-out split after each one.
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        train_loader=train_loader,
    )
    test_loss, test_acc = test(model=model, criterion=criterion, test_loader=test_loader)
    epoch_report = f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}"
    print(epoch_report)

Epoch [1/10], Train Loss: 0.5863, Test Loss: 0.3941, Test Accuracy: 0.8446
Epoch [2/10], Train Loss: 0.2028, Test Loss: 0.1507, Test Accuracy: 0.9525
Epoch [3/10], Train Loss: 0.1346, Test Loss: 0.1245, Test Accuracy: 0.9570
Epoch [4/10], Train Loss: 0.1128, Test Loss: 0.1186, Test Accuracy: 0.9555
Epoch [5/10], Train Loss: 0.1004, Test Loss: 0.1058, Test Accuracy: 0.9644
Epoch [6/10], Train Loss: 0.0910, Test Loss: 0.1010, Test Accuracy: 0.9638
Epoch [7/10], Train Loss: 0.0835, Test Loss: 0.0938, Test Accuracy: 0.9652
Epoch [8/10], Train Loss: 0.0759, Test Loss: 0.0940, Test Accuracy: 0.9680
Epoch [9/10], Train Loss: 0.0679, Test Loss: 0.0864, Test Accuracy: 0.9666
Epoch [10/10], Train Loss: 0.0618, Test Loss: 0.0786, Test Accuracy: 0.9701
