In [5]:
from mynet import MAE_ViT

import torch
import torch.nn as nn
from get_dataset import UnlabeledDataset, labeledDataset
from torchvision import transforms
from matplotlib import pyplot as plt
from torchvision.utils import make_grid

In [6]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])])

unlabeled_dataset = UnlabeledDataset(root='D:/FERexperiments/datasets/AffectNet', phase='val', transform=transform)
train_dataset = labeledDataset(root='D:/FERexperiments/datasets/RAF-DB', phase='train', transform=transform)
test_dataset = labeledDataset(root='D:/FERexperiments/datasets/RAF-DB', phase='test', transform=transform)

pretrain_dataloader = torch.utils.data.DataLoader(unlabeled_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
finetune_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
finetune_test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

# 初始化
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MAE_ViT().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

print("Initializing...")

Initializing...


In [7]:
num_epochs_1 = 20

print("stage 1")
# 阶段1：无监督预训练
model.stage = 1
for epoch in range(num_epochs_1):
    model.train()
    running_loss = 0.0
    for batch_idx, imgs in enumerate(pretrain_dataloader):
        imgs = imgs.to(device)  # 确保数据在GPU
        
        # 前向传播
        recon = model(imgs)
        loss = nn.MSELoss()(recon, model.patch_embed(imgs))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs_1}]| Loss: {running_loss / (batch_idx + 1):.4f}")

stage 1
Epoch [1/20]| Loss: 0.4892
Epoch [2/20]| Loss: 0.0944
Epoch [3/20]| Loss: 0.0482
Epoch [4/20]| Loss: 0.0335
Epoch [5/20]| Loss: 0.0258
Epoch [6/20]| Loss: 0.0214
Epoch [7/20]| Loss: 0.0181
Epoch [8/20]| Loss: 0.0157
Epoch [9/20]| Loss: 0.0140
Epoch [10/20]| Loss: 0.0124
Epoch [11/20]| Loss: 0.0114
Epoch [12/20]| Loss: 0.0106
Epoch [13/20]| Loss: 0.0096
Epoch [14/20]| Loss: 0.0093
Epoch [15/20]| Loss: 0.0083
Epoch [16/20]| Loss: 0.0081
Epoch [17/20]| Loss: 0.0074
Epoch [18/20]| Loss: 0.0071
Epoch [19/20]| Loss: 0.0067
Epoch [20/20]| Loss: 0.0066


In [9]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

model.eval()
device = torch.device('cuda')
with torch.no_grad():
    sample_images = next(iter(pretrain_dataloader))
    print(sample_images.shape)
    # sample_images = sample_images.to(device)
    sample_images = sample_images[:3]
    masked_images, _, ids_restore = model.random_masking(model.patch_embed(sample_images))
    recon_images = model.decoder(masked_images)
    recon_images = torch.gather(
        recon_images, 
        dim=1,
        index=ids_restore.unsqueeze(-1).expand(-1, -1, recon_images.shape[-1])
    )
    recon_images = model.patch_embed.proj(recon_images.permute(0, 2, 1).reshape(-1, 768, 14, 14))  # 反转patch embedding
    recon_images = recon_images.permute(0, 2, 3, 1)  # (B, H, W, C)
    sample_images = sample_images.permute(0, 2, 3, 1)  # (B, H, W, C)

    # 反标准化
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 1, 3).to(device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 1, 3).to(device)
    sample_images = sample_images * std + mean
    recon_images = recon_images * std + mean

    # 展示图像
    fig, axes = plt.subplots(3, 2, figsize=(10, 15))
    for i in range(3):
        axes[i, 0].imshow(sample_images[i].cpu().numpy())
        axes[i, 0].set_title('Original Image')
        axes[i, 0].axis('off')
        axes[i, 1].imshow(recon_images[i].cpu().numpy())
        axes[i, 1].set_title('Reconstructed Image')
        axes[i, 1].axis('off')
    plt.tight_layout()
    plt.show()

torch.Size([32, 3, 224, 224])


RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

In [6]:
# 阶段2：监督微调
model.stage = 2
num_epochs_2 = 40
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

from torch.optim.lr_scheduler import CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs_2, eta_min=1e-6, verbose=True)
for param in model.encoder.parameters():  
    param.requires_grad = False  

print("stage 2")



for epoch in range(num_epochs_2):
    model.train()
    train_loss = 0.0
    for imgs, labels in finetune_train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs)
        loss = criterion(preds, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    best_acc = 0.0
    with torch.no_grad():
        for imgs, labels in finetune_test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs)
            loss = criterion(preds, labels)
            
            test_loss += loss.item()
            _, predicted = torch.max(preds.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # 打印统计信息
    train_loss /= len(finetune_train_loader)
    test_loss /= len(finetune_test_loader)
    test_acc = 100 * correct / total
    
    print(f"Epoch [{epoch+1}/{num_epochs_2}] | "
            f"Train Loss: {train_loss:.4f} | "
            f"Test Loss: {test_loss:.4f} | "
            f"Accuracy: {test_acc:.2f}%")



stage 2
Epoch [1/40] | Train Loss: 2.7208 | Test Loss: 3.2608 | Accuracy: 35.40%
Epoch [2/40] | Train Loss: 2.8735 | Test Loss: 2.3418 | Accuracy: 35.27%
Epoch [3/40] | Train Loss: 2.5489 | Test Loss: 2.9469 | Accuracy: 38.53%
Epoch [4/40] | Train Loss: 2.2587 | Test Loss: 2.0061 | Accuracy: 33.05%
Epoch [5/40] | Train Loss: 2.2525 | Test Loss: 2.4739 | Accuracy: 36.57%
Epoch [6/40] | Train Loss: 3.2903 | Test Loss: 4.2051 | Accuracy: 30.28%
Epoch [7/40] | Train Loss: 3.3389 | Test Loss: 2.3670 | Accuracy: 33.02%
Epoch [8/40] | Train Loss: 1.9236 | Test Loss: 2.9628 | Accuracy: 34.35%
Epoch [9/40] | Train Loss: 2.7805 | Test Loss: 2.0044 | Accuracy: 36.31%
Epoch [10/40] | Train Loss: 2.0849 | Test Loss: 3.2670 | Accuracy: 32.79%
Epoch [11/40] | Train Loss: 2.3218 | Test Loss: 3.4057 | Accuracy: 29.50%
Epoch [12/40] | Train Loss: 2.5438 | Test Loss: 2.1034 | Accuracy: 36.57%
Epoch [13/40] | Train Loss: 2.8695 | Test Loss: 4.1046 | Accuracy: 36.18%


KeyboardInterrupt: 