In [1]:
import torch
from transformers import ViTModel, ViTFeatureExtractor
from torch import nn

# Load pre-trained ViT model and feature extractor
model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

# Define a segmentation head
class SegmentationHead(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(SegmentationHead, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        return x

# Modify ViT model to include segmentation head
class ViTSegmentationModel(nn.Module):
    def __init__(self, vit, segmentation_head):
        super(ViTSegmentationModel, self).__init__()
        self.vit = vit
        self.segmentation_head = segmentation_head

    def forward(self, x):
        x = self.vit(x).last_hidden_state
        x = x.permute(0, 2, 1).reshape(x.size(0), -1, 14, 14)  # Reshape to 2D
        x = nn.functional.interpolate(x, size=(224, 224), mode='bilinear')  # Upsample
        x = self.segmentation_head(x)
        return x

# Instantiate the segmentation model
segmentation_head = SegmentationHead(in_channels=768, num_classes=21)  # Example with 21 classes
segmentation_model = ViTSegmentationModel(vit=model, segmentation_head=segmentation_head)

# Example input
input_tensor = torch.randn(1, 3, 224, 224)  # Batch size 1, RGB image 224x224
output = segmentation_model(input_tensor)

print(output.shape)  # Should be (1, 21, 224, 224) for 21 segmentation classes


  from .autonotebook import tqdm as notebook_tqdm
Downloading: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 502/502 [00:00<00:00, 951kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 330M/330M [00:11<00:00, 30.3MB/s]
Downloading: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 160/160 [00:00<00:00, 389kB/s]


RuntimeError: shape '[1, -1, 14, 14]' is invalid for input of size 151296