In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from Lookup_ViT import LookupViT

In [3]:
# Build the LookupViT classifier used throughout this notebook.
model = LookupViT(
    # Input geometry: 128-pixel, single-channel images cut into 16-pixel lookup patches.
    image_size=128,
    in_channels=1,
    lookup_patch_size=16,
    # Transformer size: embedding width, number of blocks, attention heads.
    dim=64,
    depth=6,
    heads=4,
    # Classifier head: one output per class.
    num_classes=10,
)

In [4]:
# Print the model's module hierarchy (the nn.Module repr) to inspect its layers.
print(model)

LookupViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
    (1): Linear(in_features=256, out_features=64, bias=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): ModuleList(
    (0-5): 6 x LookupTransformerBlock(
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList()
      (cross_attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (self_attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (norm3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ffn): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GE

In [11]:
# Import the MNIST dataset and data-loading utilities
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Loader settings
batch_size = 64  # samples per mini-batch
shuffle = True   # reshuffle the training set at every epoch

# Preprocessing applied to every image, in this order:
#   1. convert to a PyTorch tensor
#   2. normalise with mean 0.5 / std 0.5 (maps [0, 1] inputs to [-1, 1])
#   3. resize to 128x128, the image_size the model was built with
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Resize((128, 128)),
])

# MNIST train / test splits, stored under ./data and downloaded when missing
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Mini-batch iterators; only the training loader is shuffled
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)



In [12]:
# Training hyperparameters: number of passes over the data and Adam step size
epochs, lr = 50, 1e-4

# Classification loss, and an Adam optimizer over every model parameter
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [13]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move the loss module and the model onto the selected device
criterion = criterion.to(device)
model = model.to(device)


In [14]:

def train_epoch(model, data_loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``data_loader`` and report metrics.

    Args:
        model: Module called as ``model(data)``; its output must hold one
            score per class along dim 1.
        data_loader: Iterable yielding ``(data, target)`` batches. It does not
            need to support ``len()``.
        criterion: Loss called as ``criterion(output, target)``. It is assumed
            to return the batch *mean* (the default reduction of
            ``nn.CrossEntropyLoss``); other reductions make the reported loss
            meaningless.
        optimizer: Optimizer that owns the parameters of ``model``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` where both values are averaged over
        *samples*, so a smaller final batch is not over-weighted. An empty
        loader yields ``(0.0, 0.0)`` instead of raising ``ZeroDivisionError``.
    """
    model.train()
    loss_sum = 0.0  # sum of per-sample losses seen so far
    correct = 0
    total = 0

    for data, target in data_loader:
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_count = target.size(0)
        # Undo the batch-mean reduction so every sample carries equal weight.
        loss_sum += loss.item() * batch_count
        total += batch_count
        correct += (output.argmax(dim=1) == target).sum().item()

    if total == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0, 0.0
    return loss_sum / total, correct / total

In [15]:
def test_epoch(model, data_loader, criterion, device):
    
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in data_loader:
            data = data.to(device)
            target = target.to(device)
            
            output = model(data)
            loss = criterion(output, target)
            
            test_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
    return test_loss / len(data_loader), correct / total

In [16]:
# Alternate one training pass and one evaluation pass per epoch, logging both.
for epoch in range(epochs):
    train_loss, train_acc = train_epoch(
        model=model,
        data_loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
    )
    test_loss, test_acc = test_epoch(
        model=model,
        data_loader=test_loader,
        criterion=criterion,
        device=device,
    )

    # The epoch index is printed zero-based.
    print(f'Epoch [{epoch}/{epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

Epoch [0/50], Train Loss: 0.8425, Train Acc: 0.7149, Test Loss: 0.2465, Test Acc: 0.9280
Epoch [1/50], Train Loss: 0.2087, Train Acc: 0.9375, Test Loss: 0.1664, Test Acc: 0.9481
Epoch [2/50], Train Loss: 0.1494, Train Acc: 0.9538, Test Loss: 0.1158, Test Acc: 0.9652
Epoch [3/50], Train Loss: 0.1257, Train Acc: 0.9608, Test Loss: 0.1173, Test Acc: 0.9634
Epoch [4/50], Train Loss: 0.1052, Train Acc: 0.9668, Test Loss: 0.0992, Test Acc: 0.9704
Epoch [5/50], Train Loss: 0.0941, Train Acc: 0.9703, Test Loss: 0.0926, Test Acc: 0.9689
Epoch [6/50], Train Loss: 0.0834, Train Acc: 0.9737, Test Loss: 0.0937, Test Acc: 0.9694
Epoch [7/50], Train Loss: 0.0731, Train Acc: 0.9756, Test Loss: 0.0913, Test Acc: 0.9706
Epoch [8/50], Train Loss: 0.0681, Train Acc: 0.9779, Test Loss: 0.0742, Test Acc: 0.9766


KeyboardInterrupt: 