In [1]:
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST

# Custom dataset that yields pairs of images plus a "same digit?" target
class TwoMNISTDataset(Dataset):
    """Wrap a labelled image dataset so each item is a pair of samples for same/different classification.

    Item ``idx`` is the base dataset's sample ``idx`` paired with a partner sample that
    is chosen once, at construction time, by a seeded RNG. With probability
    ``positive_fraction`` the partner has the same label (positive pair); otherwise it
    has a different label (negative pair). Both values of the binary target therefore
    occur. Pairings are fixed for the lifetime of the instance, so repeated passes over
    it (e.g. evaluation every epoch) see the same pairs.

    Args:
        mnist_dataset: Indexable dataset returning ``(image, label)`` tuples. If it
            exposes a ``targets`` attribute of matching length (as torchvision's MNIST
            does), labels are read from it instead of loading every sample.
        positive_fraction: Probability that a sample is paired with a same-label sample.
        seed: Seed for the pairing RNG.
    """

    def __init__(self, mnist_dataset, positive_fraction=0.5, seed=0):
        self.mnist_dataset = mnist_dataset
        labels = self._read_labels(mnist_dataset)
        self.partner_indices = self._draw_partners(labels, positive_fraction, seed)

    @staticmethod
    def _read_labels(dataset):
        """Return the integer label of every sample, preferring a ``targets`` attribute."""
        targets = getattr(dataset, "targets", None)
        if targets is not None and len(targets) == len(dataset):
            raw = targets.tolist() if hasattr(targets, "tolist") else list(targets)
        else:
            # Fallback: index every sample (slower, since any transform runs on each image).
            raw = [dataset[i][1] for i in range(len(dataset))]
        return [int(label) for label in raw]

    @staticmethod
    def _draw_partners(labels, positive_fraction, seed):
        """Choose one partner index per sample; positives share its label, negatives do not."""
        rng = random.Random(seed)
        by_label = {}
        for i, label in enumerate(labels):
            by_label.setdefault(label, []).append(i)
        # With a single class present, a different-label partner cannot exist.
        can_draw_negative = len(by_label) > 1

        partners = []
        for idx, label in enumerate(labels):
            if rng.random() < positive_fraction or not can_draw_negative:
                pool = by_label[label]
                partner = rng.choice(pool)
                # Avoid pairing a sample with itself when another same-label sample exists.
                while partner == idx and len(pool) > 1:
                    partner = rng.choice(pool)
            else:
                # Rejection sampling; terminates because at least one other label exists.
                partner = rng.randrange(len(labels))
                while labels[partner] == label:
                    partner = rng.randrange(len(labels))
            partners.append(partner)
        return partners

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        """Return ``(img1, label1, img2, label2, same)``; ``same`` is 1.0 iff the labels match."""
        img1, label1 = self.mnist_dataset[idx]
        img2, label2 = self.mnist_dataset[self.partner_indices[idx]]
        return img1, label1, img2, label2, torch.tensor(label1 == label2, dtype=torch.float32)

# Load MNIST and wrap both splits as pair datasets
DATA_ROOT = './data'
# Normalisation mean / std (the commonly quoted MNIST training-set statistics)
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

mnist_train = MNIST(DATA_ROOT, train=True, download=True, transform=transform)
mnist_test = MNIST(DATA_ROOT, train=False, download=True, transform=transform)

train_dataset, test_dataset = TwoMNISTDataset(mnist_train), TwoMNISTDataset(mnist_test)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:04<00:00, 2052819.35it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 570735.21it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 2478802.57it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 11737848.90it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [None]:
# Siamese convolutional network: one shared conv tower embeds both images, an MLP scores the pair
class CNN(nn.Module):
    """Predict the probability that two single-channel images show the same digit.

    Both inputs pass through the same convolution layers (shared weights). For a 28x28
    input each embedding is 64 channels x 4 x 4 = 1024 features
    (28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4). The two embeddings are
    concatenated before the fully connected head.
    """

    # Flattened size of one image embedding for 1x28x28 inputs.
    EMBED_DIM = 64 * 4 * 4

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        # forward() concatenates two embeddings, so the head takes 2 * EMBED_DIM inputs.
        self.fc1 = nn.Linear(2 * self.EMBED_DIM, 256)
        self.fc2 = nn.Linear(256, 1)

    def _embed(self, x):
        """Run a batch of images through the shared conv tower and flatten to (batch, features)."""
        x = nn.functional.max_pool2d(nn.functional.relu(self.conv1(x)), 2)
        x = nn.functional.max_pool2d(nn.functional.relu(self.conv2(x)), 2)
        # flatten keeps the batch dimension intact; a wrong spatial size then fails in fc1
        # instead of being silently folded into the batch axis.
        return torch.flatten(x, start_dim=1)

    def forward(self, x1, x2):
        """Return sigmoid scores of shape (batch, 1) for the image pairs ``(x1[i], x2[i])``.

        Assumes ``x1`` and ``x2`` are (batch, 1, 28, 28) tensors with matching batch size.
        """
        x = torch.cat((self._embed(x1), self._embed(x2)), dim=1)
        x = nn.functional.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(x))

# Instantiate the model and define the loss and optimizer
SEED = 0
# Seed the global torch RNG so weight initialisation (and DataLoader shuffling, which
# uses the default generator when none is passed) is reproducible on Restart & Run All.
torch.manual_seed(SEED)

model = CNN()
# CNN.forward already applies sigmoid, so the model outputs probabilities for BCELoss.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Training and evaluation helpers
def train(model, criterion, optimizer, train_loader, device=None):
    """Run one training epoch over ``train_loader`` and return the mean per-sample loss.

    Each batch is expected to be ``(img1, label1, img2, label2, target)``, the tuple
    layout TwoMNISTDataset returns. The digit labels are ignored, and ``target``
    (1.0 = same digit) is the loss target.

    Args:
        model: Module called as ``model(img1, img2)``; switched to training mode here.
        criterion: Loss with mean reduction (e.g. ``nn.BCELoss()``). Batch losses are
            re-weighted by batch size so a short final batch is not over-counted.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Iterable of batches.
        device: Optional device each batch is moved to; ``None`` leaves tensors as-is.

    Returns:
        Average loss per sample, or ``nan`` if the loader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0
    for img1, _, img2, _, label in train_loader:
        if device is not None:
            img1, img2, label = img1.to(device), img2.to(device), label.to(device)
        optimizer.zero_grad()
        output = model(img1, img2)
        # unsqueeze so the (batch,) target matches the model's (batch, 1) output.
        loss = criterion(output, label.unsqueeze(1))
        loss.backward()
        optimizer.step()
        batch_size = label.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
    return loss_sum / n_samples if n_samples else float("nan")

def test(model, criterion, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            img1, _, img2, _, label = data
            outputs = model(img1, img2)
            predicted = torch.round(outputs)
            total += label.size(0)
            correct += (predicted == label.unsqueeze(1)).sum().item()
    return correct / total

In [None]:
# Train and evaluate the model
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

num_epochs = 5
for epoch in range(num_epochs):
    # One optimisation pass over the training pairs, then accuracy on the held-out pairs.
    train_loss = train(model, criterion, optimizer, train_loader)
    test_acc = test(model, criterion, test_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Test Accuracy: {test_acc:.4f}")
