# 从Hugging Face上拉取ViT

pip install transformers huggingface_hub

In [52]:
import torch
import torch.nn as nn
from transformers import ViTModel, ViTConfig

class ViTForInterpolation(nn.Module):
    """Single-channel ViT encoder with a 1x1-conv head that predicts a dense image.

    The encoder's patch tokens (the [CLS] token is dropped) are laid out on their
    2-D patch grid, projected to one channel by a 1x1 convolution, and bilinearly
    upsampled to the encoder's configured image size.

    Args:
        model_name: Hugging Face model id whose configuration (and, optionally,
            weights) is used. Defaults to the checkpoint used originally.
        pretrained: If ``False`` (default, same as before) only the configuration
            is downloaded and the ViT encoder is randomly initialized. If ``True``
            the checkpoint weights are loaded as well; because the checkpoint's
            patch embedding expects 3 channels, that layer cannot be reused for
            1-channel input and is re-initialized (``ignore_mismatched_sizes``).
    """

    def __init__(self, model_name='google/vit-base-patch16-224-in21k', pretrained=False):
        super().__init__()
        config = ViTConfig.from_pretrained(model_name)
        config.num_channels = 1  # single-channel input
        if pretrained:
            self.vit = ViTModel.from_pretrained(
                model_name, config=config, ignore_mismatched_sizes=True
            )
        else:
            self.vit = ViTModel(config)

        # 1x1 conv maps each patch's hidden vector to a single output value.
        self.conv = nn.Conv2d(config.hidden_size, 1, kernel_size=1)

        # Upsample back to the resolution the encoder accepts (the ViT embedding
        # rejects other input sizes), instead of repeating a hard-coded 224.
        image_size = config.image_size
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.upsample = nn.Upsample(size=tuple(image_size), mode='bilinear', align_corners=False)

    def forward(self, x):
        """Predict a 1-channel image from ``x``.

        Args:
            x: pixel tensor passed to the ViT as ``pixel_values``; presumably
                (batch, 1, H, W) with (H, W) equal to the config's image size —
                TODO confirm against the data pipeline.

        Returns:
            Tensor of shape (batch, 1, *image_size).

        Raises:
            ValueError: if the number of patch tokens does not form a square grid.
        """
        hidden = self.vit(pixel_values=x).last_hidden_state  # (batch, 1 + num_patches, hidden)

        # Drop the [CLS] token at position 0; only patch tokens map to image locations.
        patch_tokens = hidden[:, 1:, :]
        batch_size, num_patches, hidden_size = patch_tokens.shape

        # Patches are laid out on a square grid; verify instead of silently truncating.
        grid = round(num_patches ** 0.5)
        if grid * grid != num_patches:
            raise ValueError(
                f"expected a square patch grid, got {num_patches} patch tokens"
            )

        # (batch, num_patches, hidden) -> (batch, hidden, grid, grid) for the conv head.
        feature_map = patch_tokens.reshape(batch_size, grid, grid, hidden_size)
        feature_map = feature_map.permute(0, 3, 1, 2)

        coarse = self.conv(feature_map)  # (batch, 1, grid, grid)
        return self.upsample(coarse)  # (batch, 1, *image_size)

# Build the model on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = ViTForInterpolation()
model = model.to(device)

# Loss and optimizer.
# NOTE(review): the training loop in this notebook uses masked_mse_loss, not
# this criterion — confirm whether it is still needed.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
def masked_mse_loss(output, target, mask):
    """Squared error summed over the masked positions, divided by the mask sum.

    Args:
        output: model prediction.
        target: reference values, same shape as ``output``.
        mask: weights multiplied element-wise into the squared error; 1/True marks
            positions that contribute, 0/False positions that are ignored.

    Returns:
        Scalar tensor ``sum(mask * (output - target) ** 2) / sum(mask)``. If the
        mask selects nothing, a zero that is still connected to the autograd graph
        is returned rather than 0 / 0 = NaN (which would poison the weights on the
        next optimizer step).
    """
    weighted_sq_err = (output - target) ** 2 * mask
    total = torch.sum(weighted_sq_err)
    denom = torch.sum(mask)
    if denom == 0:
        # Nothing to supervise: for finite inputs `total` is 0 with zero gradients.
        return total
    return total / denom


# Train the model.
# NOTE(review): the recorded epoch losses (~1.67e9) suggest the inputs are on a
# very large numeric scale; normalizing `data` (and de-normalizing predictions)
# is likely needed for the loss to move meaningfully — confirm against the data.
# NOTE(review): `data` is both the model input and the loss target. If mask == 1
# marks the pixels to be interpolated, those pixels should probably be hidden
# from the input — confirm the mask convention.
model.train()
epochs = 10

for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0  # batches that actually contributed to epoch_loss
    for data, mask in dataloader:
        # A single NaN/inf makes the loss NaN, and one optimizer step on NaN
        # gradients corrupts every weight, so bad batches are skipped, not just
        # reported.
        data_ok = bool(torch.isfinite(data).all())
        mask_ok = bool(torch.isfinite(mask).all())
        if not data_ok:
            print("Input data contains NaN or inf values")
        if not mask_ok:
            print("Mask contains NaN or inf values")
        if not (data_ok and mask_ok):
            continue

        data, mask = data.to(device), mask.to(device)

        optimizer.zero_grad()

        # Forward pass.
        outputs = model(data)

        # Loss only over positions where the mask is non-zero.
        loss = masked_mse_loss(outputs, data, mask)
        if not torch.isfinite(loss):
            # e.g. overflow in the forward pass; skip the update to protect the weights.
            print(f"Non-finite loss in epoch {epoch + 1}; skipping batch")
            continue

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss / max(num_batches, 1)}")


Epoch [1/10], Loss: 1666595566.845855
Epoch [2/10], Loss: 1666386644.4509685
Epoch [3/10], Loss: 1666340013.930888
Epoch [4/10], Loss: 1666314923.534682
Epoch [5/10], Loss: 1666293931.3936207
Epoch [6/10], Loss: 1666273793.1288376
Epoch [7/10], Loss: 1666252629.3849704
Epoch [8/10], Loss: 1666230784.4244547
Epoch [9/10], Loss: 1666209109.7028522
Epoch [10/10], Loss: 1666187264.015358


In [62]:
# Sanity check: prediction, target and mask need matching shapes for the
# element-wise masked loss.
for _caption, _tensor in (
    ("outputs shape:", outputs),
    ("data shape:", data),
    ("mask shape:", mask),
):
    print(_caption, _tensor.shape)
del _caption, _tensor

outputs shape: torch.Size([8, 1, 224, 224])
data shape: torch.Size([8, 1, 224, 224])
mask shape: torch.Size([8, 1, 224, 224])
