In [1]:
import torch
import torch.nn as nn

In [2]:
class PatchMerging(nn.Module):
    """Downsample a token grid by 2x, merging every 2x2 neighbourhood into one token.

    Args:
        hw: Expected ``(H, W)`` of the incoming token grid. It is used only when
            ``H * W`` equals the actual sequence length; otherwise the grid is
            assumed to be square (the behaviour of the earlier revision).
        emb_dim: Channel width ``C`` of the incoming tokens.

    Input ``(B, H*W, C)`` -> output ``(B, ceil(H/2) * ceil(W/2), 2*C)``.
    """

    def __init__(self, hw, emb_dim):
        super().__init__()
        self.dim = emb_dim
        # Earlier revisions ignored `hw`, so tolerate anything that is not an
        # indexable (H, W) pair instead of failing at construction time.
        try:
            self.hw = (int(hw[0]), int(hw[1]))
        except (TypeError, ValueError, LookupError):
            self.hw = None
        self.reduction = nn.Linear(4 * emb_dim, 2 * emb_dim, bias=False)
        self.norm = nn.LayerNorm(4 * emb_dim)

    def _grid_size(self, num_tokens):
        """Return the (H, W) grid to use for a sequence of `num_tokens` tokens."""
        if self.hw is not None and self.hw[0] * self.hw[1] == num_tokens:
            return self.hw
        # Fallback: square grid. Verify the root instead of truncating silently.
        side = round(num_tokens ** 0.5)
        if side * side != num_tokens:
            raise ValueError(
                f"Cannot arrange {num_tokens} tokens on a grid: hw={self.hw} does not match "
                f"and {num_tokens} is not a perfect square"
            )
        return side, side

    def forward(self, x):
        B, L, C = x.shape
        H, W = self._grid_size(L)

        # reshape (not view) so non-contiguous inputs are accepted as well
        x = x.reshape(B, H, W, C)

        # Pad one row/column of zeros at the bottom/right when a side is odd
        if (H % 2 == 1) or (W % 2 == 1):
            x = nn.functional.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        # Slice the four members of every 2x2 patch
        x0 = x[:, 0::2, 0::2, :]  # Top-Left
        x1 = x[:, 1::2, 0::2, :]  # Bottom-Left
        x2 = x[:, 0::2, 1::2, :]  # Top-Right
        x3 = x[:, 1::2, 1::2, :]  # Bottom-Right

        x = torch.cat([x0, x1, x2, x3], -1)  # B, H/2, W/2, 4*C
        x = x.reshape(B, -1, 4 * C)  # Flatten the grid back into a sequence
        x = self.norm(x)
        x = self.reduction(x)
        return x

class SwinTransformerLayer(nn.Module):
    """
    Stand-in Swin layer that keeps the demo runnable.

    It is a plain pre-norm transformer block: global multi-head self-attention
    followed by an MLP, each wrapped in a residual connection. The arguments
    `hw`, `window_size`, `shift_size` and `drop_path_p` are accepted for
    signature compatibility but are not used here.
    """
    def __init__(self, hw, window_size, num_heads, emb_dim, shift_size,
                 output_dropout_p, mlp_drop_p, mlp_expansion, drop_path_p):
        super().__init__()
        hidden_dim = int(emb_dim * mlp_expansion)

        self.norm1 = nn.LayerNorm(emb_dim)
        # nn.MultiheadAttention stands in for window attention
        self.attn = nn.MultiheadAttention(emb_dim, num_heads,
                                          dropout=output_dropout_p, batch_first=True)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(mlp_drop_p),
            nn.Linear(hidden_dim, emb_dim),
            nn.Dropout(mlp_drop_p)
        )

    def forward(self, x):
        # Attention sub-block: residual + attn(norm(x))
        normed = self.norm1(x)
        attended, _attn_weights = self.attn(normed, normed, normed)
        x = x + attended

        # MLP sub-block: residual + mlp(norm(x))
        return x + self.mlp(self.norm2(x))

# --- 2. Your Provided Components ---

class PatchPartition(nn.Module):
    """Split an image batch into non-overlapping patches and embed each patch as a token.

    A convolution whose kernel size equals its stride projects every
    `patch_size` x `patch_size` patch to `emb_dim` channels; the spatial grid is
    then flattened into a sequence and layer-normalised.
    """

    def __init__(self, in_channels, emb_dim, patch_size):
        super().__init__()
        self.patcher = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.flatter = nn.Flatten(start_dim=-2, end_dim=-1)
        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        patch_grid = self.patcher(x)                        # (B, emb_dim, H/p, W/p)
        tokens = self.flatter(patch_grid).transpose(1, 2)   # (B, num_patches, emb_dim)
        return self.norm(tokens)

class SwinTransformerBlock(nn.Module):
    """One Swin stage: `n_layers // 2` pairs of (regular, shifted) layers.

    Args:
        hw, window_size, num_heads, emb_dim, output_dropout_p, mlp_drop_p,
        mlp_expansion: forwarded unchanged to every `SwinTransformerLayer`.
        shift_size: shift used by the second layer of every pair (the first
            layer of a pair always gets 0).
        n_layers: total number of layers requested; an odd value is rounded down.
        drop_path_p: stochastic-depth rate. Either a single number applied to
            all layers, or a sequence with one rate per pair.
    """

    def __init__(self, hw, window_size, num_heads, emb_dim, shift_size, n_layers, output_dropout_p, mlp_drop_p, mlp_expansion, drop_path_p):
        super().__init__()
        self.layers = nn.ModuleList()
        # Swin layers come in pairs (non-shifted + shifted), hence two appends
        # per iteration so that the stage holds `n_layers` layers in total.
        for i in range(n_layers // 2):
            pair_drop_path = self._pair_drop_path(drop_path_p, i)
            for layer_shift in (0, shift_size):
                self.layers.append(SwinTransformerLayer(
                    hw=hw, window_size=window_size, num_heads=num_heads, emb_dim=emb_dim,
                    shift_size=layer_shift, output_dropout_p=output_dropout_p, mlp_drop_p=mlp_drop_p,
                    mlp_expansion=mlp_expansion, drop_path_p=pair_drop_path
                ))

    @staticmethod
    def _pair_drop_path(drop_path_p, pair_index):
        """Resolve the drop-path rate for pair `pair_index` from a scalar or a per-pair sequence."""
        if drop_path_p is None:
            return 0.0
        if isinstance(drop_path_p, (int, float)):
            return float(drop_path_p)
        # Per-pair schedule: previously discarded and replaced by 0.0.
        if pair_index < len(drop_path_p):
            return float(drop_path_p[pair_index])
        return 0.0

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class SwinTransformer(nn.Module):
    """Hierarchical Swin-style classifier built from `depth` stages.

    Stage 0 is `PatchPartition` + `SwinTransformerBlock`; every later stage is
    `PatchMerging` + `SwinTransformerBlock`. The final tokens are normalised,
    average-pooled over the sequence and fed to a linear classification head.

    Args:
        img_size: ``(height, width)`` of the input images.
        in_channels: number of image channels.
        emb_dims: channel width per stage.
        patch_size: side length of the initial patches.
        depth: number of stages to build (must not exceed the per-stage lists).
        num_classes: size of the head; ``<= 0`` replaces the head by identity.
        n_layers: layer count per stage.
        window_size: attention window side; shift size is ``window_size // 2``.
        num_heads: attention heads per stage.
        output_dropout_p, mlp_drop_p, mlp_expansion: forwarded to the blocks
            (``mlp_drop_p`` is indexed per stage).
        drop_path_p: maximum stochastic-depth rate; rates ramp linearly from 0.

    The list defaults are only read, never mutated.
    """

    def __init__(self, img_size=(224, 224), in_channels=3, emb_dims=[96, 192, 384, 768], patch_size=4, depth=4,
                 num_classes=1000, n_layers=[2, 2, 6, 2], window_size=7, num_heads=[3, 6, 12, 24],
                 output_dropout_p=0, mlp_drop_p=[0, 0, 0, 0], mlp_expansion=4, drop_path_p=0.1):
        super().__init__()
        img_h, img_w = img_size[0], img_size[1]
        drop_path_rates = torch.linspace(0, drop_path_p, sum(n_layers)//2).tolist()

        # Token grid produced by the patch partition; updated after each merge.
        h, w = img_h // patch_size, img_w // patch_size

        self.layers = nn.ModuleList()
        for i in range(depth):
            if i == 0:
                downsample = PatchPartition(in_channels=in_channels, emb_dim=emb_dims[i], patch_size=patch_size)
            else:
                # PatchMerging receives the grid *before* merging ...
                downsample = PatchMerging(hw=(h, w), emb_dim=emb_dims[i-1])
                # ... and halves it, rounding up (odd sides are padded by the
                # PatchMerging defined in this cell). The block below must see the
                # merged grid; the old `2 ** (i + 1)` factor handed it the
                # pre-merge grid and was only meaningful for patch_size == 4.
                h, w = (h + 1) // 2, (w + 1) // 2

            # Same slice rule for every stage (stage 0 used to take twice as many rates).
            first_rate, last_rate = sum(n_layers[:i]) // 2, sum(n_layers[:i+1]) // 2
            swin_block = SwinTransformerBlock(hw=(h, w), window_size=window_size,
                                              num_heads=num_heads[i], emb_dim=emb_dims[i],
                                              shift_size=window_size // 2, n_layers=n_layers[i], output_dropout_p=output_dropout_p,
                                              mlp_drop_p=mlp_drop_p[i], mlp_expansion=mlp_expansion,
                                              drop_path_p=drop_path_rates[first_rate:last_rate])
            self.layers.append(downsample)
            self.layers.append(swin_block)

        # Width of the last stage actually built (equals emb_dims[-1] when depth == len(emb_dims)).
        out_dim = emb_dims[depth - 1]
        self.norm = nn.LayerNorm(out_dim)
        self.avg_pool = nn.AdaptiveAvgPool1d(output_size=1)
        self.fc = nn.Linear(in_features=out_dim, out_features=num_classes) if num_classes > 0 else nn.Identity()
        self.flatten = nn.Flatten(start_dim=-2, end_dim=-1)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Truncated-normal init for Linear weights, zeros for biases, unit LayerNorm scale."""
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.weight, 1.0)
            nn.init.constant_(module.bias, 0.0)

    def forward(self, x):
        # The shape trace is part of the demo output and is printed on every call.
        print(f"Input: {x.shape}")
        for layer in self.layers:
            x = layer(x)
            print(f" -> After {layer.__class__.__name__}: {x.shape}")

        x = self.norm(x).transpose(1, 2)  # (B, C, N) so the pool runs over tokens
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        print(f"Output: {x.shape}")
        return x

# --- 3. The Demo Execution ---

if __name__ == "__main__":
    # Create the model
    model = SwinTransformer(num_classes=10) # 10 classes for example
    # Forward-pass demo only: switch dropout-style layers to inference behaviour.
    model.eval()

    # Create dummy image data (Batch Size=1, Channels=3, H=224, W=224)
    dummy_input = torch.randn(1, 3, 224, 224)

    print("-" * 30)
    print("Starting Forward Pass Demo")
    print("-" * 30)

    # Run the model without recording an autograd graph (nothing is back-propagated here)
    with torch.no_grad():
        output = model(dummy_input)

    print("-" * 30)
    print("Demo Complete")

------------------------------
Starting Forward Pass Demo
------------------------------
Input: torch.Size([1, 3, 224, 224])
 -> After PatchPartition: torch.Size([1, 3136, 96])
 -> After SwinTransformerBlock: torch.Size([1, 3136, 96])
 -> After PatchMerging: torch.Size([1, 784, 192])
 -> After SwinTransformerBlock: torch.Size([1, 784, 192])
 -> After PatchMerging: torch.Size([1, 196, 384])
 -> After SwinTransformerBlock: torch.Size([1, 196, 384])
 -> After PatchMerging: torch.Size([1, 49, 768])
 -> After SwinTransformerBlock: torch.Size([1, 49, 768])
Output: torch.Size([1, 10])
------------------------------
Demo Complete


In [3]:
class PatchPartition(nn.Module):
    """Patch embedding: image batch in, normalised patch tokens out."""

    def __init__(self, in_channels, emb_dim, patch_size):
        super().__init__()
        # kernel_size == stride makes the convolution act on disjoint patches,
        # similar to nn.Unfold followed by nn.Linear.
        self.patcher = nn.Conv2d(
            in_channels=in_channels,
            out_channels=emb_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.flatter = nn.Flatten(start_dim=-2, end_dim=-1)
        self.norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        embedded = self.flatter(self.patcher(x))   # spatial grid collapsed into one axis
        sequence = embedded.transpose(-1, -2)      # channels last, one row per patch
        return self.norm(sequence)

In [4]:
def compute_attn_mask(H, W, window_size, shift_size):
    """Build the additive attention mask for shifted-window attention.

    After the feature map is rolled by ``-shift_size`` along both axes, the
    last ``window_size`` rows/columns contain tokens that were not neighbours
    in the original image. Each axis is therefore cut into three bands
    (``[0, -window_size)``, ``[-window_size, -shift_size)``, ``[-shift_size, end)``)
    and tokens may only attend to tokens in the same band combination.

    Args:
        H, W: token-grid size; both must be multiples of ``window_size``.
        window_size: side length of an attention window.
        shift_size: cyclic shift applied before windowing.

    Returns:
        Float tensor of shape ``(num_windows, window_size**2, window_size**2)``
        holding 0.0 where attention is allowed and -100.0 where it is blocked.
    """
    img_mask = torch.zeros((1, H, W, 1))
    # Negative bounds: the bands are measured from the *end* of each axis.
    # (Positive bounds left the middle band empty and mislabelled the regions.)
    h_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    w_slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
    region_id = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = region_id
            region_id += 1

    # Partition the label map into windows exactly like the features are partitioned.
    mask_windows = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
    mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)

    # Pairwise label difference: non-zero means the two tokens belong to different regions.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask

In [10]:
class WindowedAttention(nn.Module):
    """Multi-head self-attention inside a window, with a learned relative position bias.

    Args:
        emb_dim: token width; split evenly across `num_heads`.
        num_heads: number of attention heads.
        window_size: side length of the (square) window; sequences passed to
            `forward` are expected to hold ``window_size**2`` tokens.
        attn_dropout_p: dropout on the attention probabilities.
        output_dropout_p: dropout after the output projection.
        shift_size: stored for reference only; not used by this module.
    """

    def __init__(self, emb_dim, num_heads, window_size, attn_dropout_p, output_dropout_p, shift_size):
        super().__init__()
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.scale = (emb_dim // num_heads) ** -0.5
        self.qkv_proj = nn.Linear(in_features=emb_dim, out_features=emb_dim * 3)
        self.attn_dropout = nn.Dropout(p=attn_dropout_p)
        self.softmax = nn.Softmax(dim=-1)
        self.output_projection = nn.Linear(in_features=emb_dim, out_features=emb_dim)
        self.output_dropout = nn.Dropout(p=output_dropout_p)

        # One bias per head and per relative offset in [-(w-1), w-1]^2.
        # Allocated directly in its final (heads, 1, 2w-1, 2w-1) shape: calling
        # .unsqueeze() on an nn.Parameter returns a plain tensor, which would not
        # be registered with the module (no training, no state_dict, no .to()).
        table_side = 2 * window_size - 1
        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_heads, 1, table_side, table_side))
        self.bias_unfold = torch.nn.Unfold(kernel_size=(window_size, window_size), stride=1)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """Attend within each window.

        Args:
            x: ``(num_windows * batch, tokens_per_window, emb_dim)``.
            mask: optional additive mask ``(num_windows, tokens, tokens)``.

        Returns:
            Tensor with the same shape as `x`.
        """
        B, N, C = x.shape
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Unfold yields (heads, key_pos, query_pos); flipping the key axis turns the
        # sliding-window offset into the relative offset query - key. The explicit
        # transpose replaces the deprecated `.T` on a 3-D tensor.
        relative_position_bias = self.bias_unfold(self.relative_position_bias_table).flip(dims=(1,)).transpose(1, 2)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            mask = mask.to(device=attn.device, dtype=attn.dtype)
            attn = attn.view(B // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.output_projection(x)
        x = self.output_dropout(x)
        return x

In [7]:
class SwinTransformerBlock(nn.Module):
    """A stage made of `n_layers // 2` `SwinTransformerLayer` modules applied in sequence.

    `drop_path_p` is indexed per layer, so it must provide at least
    `n_layers // 2` entries.
    """

    def __init__(self, hw, window_size, num_heads, emb_dim, shift_size, n_layers, output_dropout_p, mlp_drop_p, mlp_expansion, drop_path_p):
        super().__init__()
        # Arguments shared by every layer of the stage; only the drop-path rate varies.
        shared_kwargs = dict(
            hw=hw,
            window_size=window_size,
            num_heads=num_heads,
            emb_dim=emb_dim,
            shift_size=shift_size,
            output_dropout_p=output_dropout_p,
            mlp_drop_p=mlp_drop_p,
            mlp_expansion=mlp_expansion,
        )
        self.layers = nn.ModuleList(
            [SwinTransformerLayer(drop_path_p=drop_path_p[idx], **shared_kwargs)
             for idx in range(n_layers // 2)]
        )

    def forward(self, x):
        out = x
        for swin_layer in self.layers:
            out = swin_layer(out)
        return out

In [9]:
class SwinTransformerLayer(nn.Module):
    """A pair of Swin sub-blocks: window attention + MLP, then shifted-window attention + MLP.

    Args:
        hw: ``(H, W)`` token grid of the input; both sides must be multiples of
            `window_size` and ``H * W`` must equal the sequence length.
        window_size: side length of an attention window.
        num_heads: attention heads.
        emb_dim: token width.
        shift_size: cyclic shift for the second attention pass; forced to 0
            when the grid is not larger than one window.
        mlp_expansion: hidden width multiplier of the MLP.
        output_dropout_p: dropout after the attention output projection.
        mlp_drop_p: dropout inside the MLP.
        drop_path_p: stochastic-depth rate; 0 disables it.

    NOTE(review): `w_msa`, `mlp` and `drop_path` are single modules used by both
    sub-blocks, i.e. their weights are shared between the regular and the shifted
    pass. Only the four LayerNorms are separate. Confirm this is intended.
    NOTE(review): `DropPath` is referenced but not defined in the cells visible
    here; it must be available before a layer with drop_path_p > 0 is built.
    """

    def __init__(self, hw, window_size, num_heads, emb_dim, shift_size, mlp_expansion, output_dropout_p, mlp_drop_p, drop_path_p):
        super().__init__()
        self.shift_size = shift_size
        self.window_size = window_size
        self.H, self.W = hw[0], hw[1]
        if (self.H <= self.window_size) or (self.W <= self.window_size):
            # A single window covers the grid: shifting would be pointless.
            self.shift_size = 0
        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.norm3 = nn.LayerNorm(emb_dim)
        self.norm4 = nn.LayerNorm(emb_dim)
        # Pass the effective (possibly clamped) shift so both modules agree.
        self.w_msa = WindowedAttention(emb_dim=emb_dim, num_heads=num_heads, window_size=window_size, output_dropout_p=output_dropout_p, attn_dropout_p=0, shift_size=self.shift_size)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=emb_dim, out_features=emb_dim * mlp_expansion),
            nn.GELU(),
            nn.Dropout(p=mlp_drop_p),
            nn.Linear(in_features=emb_dim * mlp_expansion, out_features=emb_dim),
            nn.Dropout(p=mlp_drop_p),
        )
        self.drop_path = DropPath(drop_prob=drop_path_p) if drop_path_p > 0.0 else nn.Identity()

        # Registered as a non-persistent buffer so it follows .to(device)/.cuda()
        # with the module while leaving the state_dict keys unchanged.
        attn_mask = None
        if self.shift_size > 0:
            attn_mask = compute_attn_mask(self.H, self.W, window_size, self.shift_size)
        self.register_buffer("attn_mask", attn_mask, persistent=False)

    def _partition_windows(self, grid):
        """(B, H, W, C) -> (B * num_windows, window_size**2, C)."""
        B, _, _, C = grid.shape
        ws = self.window_size
        windows = grid.reshape(B, self.H // ws, ws, self.W // ws, ws, C)
        return windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws * ws, C)

    def _merge_windows(self, windows, B):
        """Inverse of `_partition_windows`: back to (B, H, W, C)."""
        C = windows.shape[-1]
        ws = self.window_size
        grid = windows.reshape(B, self.H // ws, self.W // ws, ws, ws, C).permute(0, 1, 3, 2, 4, 5)
        return grid.reshape(B, self.H, self.W, C)

    def _window_attention(self, x, shift):
        """Run window attention on a (B, N, C) sequence, optionally on a cyclically shifted grid."""
        B, N, C = x.shape
        grid = x.reshape(B, self.H, self.W, C)
        shifted = shift > 0
        if shifted:
            grid = torch.roll(grid, shifts=(-shift, -shift), dims=(1, 2))  # cyclic shift
        windows = self._partition_windows(grid)
        windows = self.w_msa(windows, mask=self.attn_mask if shifted else None)
        grid = self._merge_windows(windows, B)
        if shifted:
            grid = torch.roll(grid, shifts=(shift, shift), dims=(1, 2))  # reverse cyclic shift
        return grid.reshape(B, N, C)

    def forward(self, x):
        # Sub-block 1: regular window attention, then MLP.
        x = x + self.drop_path(self._window_attention(self.norm1(x), shift=0))
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        # Sub-block 2: shifted window attention (falls back to regular windows
        # when the effective shift is 0), then MLP.
        x = x + self.drop_path(self._window_attention(self.norm3(x), shift=self.shift_size))
        x = x + self.drop_path(self.mlp(self.norm4(x)))
        return x

In [11]:
from PIL import Image
import gradio as gr
import torch
from transformers import AutoImageProcessor, SwinForImageClassification

# Checkpoint identifier handed to both from_pretrained() calls below.
model_name = "microsoft/swin-tiny-patch4-window7-224"
# Image preprocessor matching the checkpoint; used by classify_image().
processor = AutoImageProcessor.from_pretrained(model_name)
# NOTE: this rebinds the notebook-level name `model`, replacing the SwinTransformer
# instance created by the forward-pass demo in an earlier cell.
model = SwinForImageClassification.from_pretrained(model_name)

def classify_image(image):
    """Classify one image with the module-level `processor` and `model`.

    Args:
        image: image accepted by `processor(images=...)`, or None.

    Returns:
        None when `image` is None, otherwise a one-entry dict mapping the
        top-scoring label name to its softmax probability.
    """
    if image is None:
        return None

    encoded = processor(images=image, return_tensors="pt")

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    top_index = logits.argmax(-1).item()
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    top_label = model.config.id2label[top_index]
    top_score = probabilities[0][top_index].item()
    return {top_label: top_score}

# Gradio UI: one image upload wired to classify_image(), whose {label: score}
# dict is rendered by a Label component limited to the single top class.
# Needs the `gradio` package imported at the top of this cell.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Label(num_top_classes=1, label="Swin Prediction"),
    title="Swin Transformer to Classify Images - 1BM23AI109 & 1BM23AI100",
    description="Upload an image to see how the Swin Transformer architecture sees the world."
)

# Start the web app only when this code runs as the main module.
if __name__ == "__main__":
    demo.launch()

ModuleNotFoundError: No module named 'gradio'

In [None]:
class PatchMerging(nn.Module):
    """Merge each 2x2 group of tokens into one token with twice the channels.

    `hw` is the (H, W) grid of the incoming sequence. The 2x2 groups are
    gathered with `nn.Unfold`, layer-normalised (4 * emb_dim features) and
    projected down to 2 * emb_dim.
    """

    def __init__(self, hw, emb_dim):
        super().__init__()
        merged_dim = 4 * emb_dim
        self.H, self.W = hw[0], hw[1]
        self.unfold = nn.Unfold(kernel_size=(2, 2), stride=2)
        self.norm = nn.LayerNorm(merged_dim)
        self.proj = nn.Linear(merged_dim, 2 * emb_dim)

    def forward(self, x):
        batch, _num_tokens, channels = x.shape
        # Sequence -> channel-first grid, as required by nn.Unfold.
        grid = x.view(batch, self.H, self.W, channels).permute(0, 3, 1, 2)
        # One column per 2x2 group; move the group axis in front of the features.
        merged = self.unfold(grid).transpose(1, 2)
        # Unlike PatchPartition, the LayerNorm comes before the projection.
        return self.proj(self.norm(merged))