In [14]:
from torchvision.models.video import swin_transformer as swin

# Inspect every name the video Swin module exposes (model builders such as
# swin3d_s, weight enums such as Swin3D_S_Weights, and building blocks).
print(dir(swin))

['Any', 'Callable', 'F', 'List', 'Optional', 'PatchEmbed3d', 'PatchMerging', 'ShiftedWindowAttention3d', 'Swin3D_B_Weights', 'Swin3D_S_Weights', 'Swin3D_T_Weights', 'SwinTransformer3d', 'SwinTransformerBlock', 'Tensor', 'Tuple', 'VideoClassification', 'Weights', 'WeightsEnum', '_COMMON_META', '_KINETICS400_CATEGORIES', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', '_compute_attention_mask_3d', '_compute_pad_size_3d', '_get_relative_position_bias', '_get_window_and_shift_size', '_log_api_usage_once', '_ovewrite_named_param', '_swin_transformer3d', 'handle_legacy_interface', 'nn', 'partial', 'register_model', 'shifted_window_attention_3d', 'swin3d_b', 'swin3d_s', 'swin3d_t', 'torch']


In [None]:
import torch
from torchvision.models.video import swin_transformer

# Load Swin3D-S with its Kinetics-400 checkpoint. Using the weights enum member
# (instead of the bare string "KINETICS400_V1") makes the checkpoint choice
# explicit and discoverable.
model = swin_transformer.swin3d_s(weights=swin_transformer.Swin3D_S_Weights.KINETICS400_V1)
model.eval()

# Random clip laid out as (B, C, T, H, W). The values are not normalised, so
# only the output *shape* is meaningful here, not the logits.
dummy_input = torch.randn(1, 3, 16, 224, 224)
with torch.no_grad():
    output = model(dummy_input)

# One logit per Kinetics-400 class.
assert output.shape == (1, 400), output.shape
print(output.shape)


Downloading: "https://download.pytorch.org/models/swin3d_s-da41c237.pth" to /mnt/efs/fs1/cache/torch/hub/checkpoints/swin3d_s-da41c237.pth


100%|██████████| 218M/218M [00:02<00:00, 89.7MB/s] 


torch.Size([1, 400])


In [25]:
import torch
from torchvision.models.video import swin_transformer

# Same shape check with a randomly initialised backbone (no weight download)
# and a longer 32-frame clip. The classifier output keeps the same (1, 400)
# shape as for the 16-frame clip.
torch.manual_seed(0)  # random weights + random input: seed so re-runs are reproducible
model = swin_transformer.swin3d_s()
model.eval()

dummy_input = torch.randn(1, 3, 32, 224, 224)  # B, C, T, H, W
with torch.no_grad():
    output = model(dummy_input)

assert output.shape == (1, 400), output.shape
print(output.shape)

torch.Size([1, 400])


In [31]:
import torch

from torchvision.models.video import swin_transformer

model = swin_transformer.swin3d_s(weights=swin_transformer.Swin3D_S_Weights.KINETICS400_V1)
model.eval()

# Separate container for hook output, so the name `features` below always
# refers to a tensor rather than being rebound from a dict.
captured = {}


def hook_fn(module, input, output):
    """Forward hook that stores the hooked module's output in ``captured['feat']``."""
    captured['feat'] = output


# Hook the backbone's final norm layer (`model.norm`). This is the last point
# where the spatio-temporal patch grid is still intact, before the classifier
# pools it away.
hook_handle = model.norm.register_forward_hook(hook_fn)

dummy_input = torch.randn(1, 3, 16, 224, 224)  # B, C, T, H, W
try:
    with torch.no_grad():
        model(dummy_input)
finally:
    # Remove the hook even if the forward pass raises, so a failed run does
    # not leave a stale hook attached to `model.norm`.
    hook_handle.remove()

# The norm output is channels-last: (B, T_patch, H_patch, W_patch, C).
# For this 16x224x224 clip that is (1, 8, 7, 7, 768).
features = captured['feat']
print(features.shape)

# Average over the temporal patch axis (dim=1); channels stay last:
# (B, H_patch, W_patch, C).
features_2d = features.mean(dim=1)
print(features_2d.shape)


torch.Size([1, 8, 7, 7, 768])
torch.Size([1, 7, 7, 768])


In [32]:
import torch.nn as nn
import torch.nn.functional as F

class Decoder2D(nn.Module):
    """Upsampling head that turns a 2D feature map into per-pixel logits.

    The upsampling happens in stages:
    - Two stride-4 transposed convolutions, each followed by BatchNorm + ReLU,
      upsample by 16x.
    - A final bilinear 2x resize brings the total to 32x (e.g. 7x7 -> 224x224).
    - A 1x1 convolution then projects to ``num_classes`` channels.

    Args:
        in_channels: channel count of the incoming feature map.
        num_classes: number of output logit channels.

    Shapes:
        input:  (B, in_channels, H, W)
        output: (B, num_classes, 32 * H, 32 * W)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Stage 1: in_channels -> 256 channels, spatial x4.
        self.up1 = nn.ConvTranspose2d(in_channels=in_channels, out_channels=256, kernel_size=4, stride=4)
        self.bn1 = nn.BatchNorm2d(num_features=256)
        # Stage 2: 256 -> 64 channels, spatial x4.
        self.up2 = nn.ConvTranspose2d(in_channels=256, out_channels=64, kernel_size=4, stride=4)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        # Per-pixel classifier.
        self.conv_out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        # kernel 4, stride 4, no padding: out = (in - 1) * 4 + 4 = 4 * in,
        # so 7 -> 28 -> 112 for a 7x7 input.
        for upsample, norm in ((self.up1, self.bn1), (self.up2, self.bn2)):
            x = F.relu(norm(upsample(x)))
        # Final 2x bilinear resize (112 -> 224 for a 7x7 input).
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        return self.conv_out(x)


In [35]:
import torch
from torchvision.models.video import swin_transformer
import torch.nn.functional as F
import torch.nn as nn

# class Swin3DSegmentationModel(nn.Module):
#     def __init__(self, num_classes=2):
#         super().__init__()
#         # Load pretrained swin3d small
#         self.backbone = swin_transformer.swin3d_s()
        
#         # Remove classifier head
#         self.backbone.head = nn.Identity()
        
#         # Patch size from model config (e.g., 2x4x4)
#         self.patch_size = (2, 4, 4)
#         self.embed_dim = 96  # swin3d_s embed dim
        
#         # Decoder: simple conv to upsample features to input spatial size
#         # Adjust channels and upsampling to your needs
#         self.decoder = nn.Sequential(
#             nn.Conv3d(self.embed_dim, 128, kernel_size=3, padding=1),
#             nn.BatchNorm3d(128),
#             nn.ReLU(),
#             nn.Conv3d(128, num_classes, kernel_size=1)
#         )
    
#     def forward(self, x):
#         # x shape: (B, 3, T, H, W)
        
#         # Forward through patch embedding and transformer blocks:
#         # Instead of forward_features, replicate and grab patch tokens:
#         x = self.backbone.patch_embed(x)  # (B, embed_dim, T_patch, H_patch, W_patch)
#         features = model.forward_features(dummy_input)
#         print(features.shape)

        
#         # flatten spatial-temporal patches for transformer blocks
#         x = x.flatten(2).transpose(1, 2)  # (B, N_patches, embed_dim)
        
#         # Apply transformer blocks manually (simplified)
#         for blk in self.backbone.layers[0].blocks:
#             x = blk(x)
        
#         # Reshape back to 3D patch feature map
#         B, N, C = x.shape
#         T_patch = self.backbone.patch_embed.patch_size[0]
#         H_patch = self.backbone.patch_embed.patch_size[1]
#         W_patch = self.backbone.patch_embed.patch_size[2]
        
#         # Number of patches per dim (calculated from input size, approximate)
#         # You may need to get exact number of patches from input dims and patch size
#         # Here assume known patch grid dims
#         # For example, T_patch_grid, H_patch_grid, W_patch_grid:
        
#         # For demonstration, reshape assuming you know grid dims
#         # x = x.transpose(1, 2).reshape(B, C, T_patch_grid, H_patch_grid, W_patch_grid)
        
#         # Simplify: collapse temporal dim to get 2D map
#         x_2d = x.mean(dim=1)  # naive average over patches (for demo)
        
#         # Pass through decoder conv layers (expand dims)
#         # You would ideally upsample from patch-level to original spatial dims here
        
#         # For demo, just output x_2d shape (B, C)
#         return x_2d
import torch.nn as nn
import torch.nn.functional as F

# NOTE(review): this is an exact copy of the Decoder2D defined in an earlier
# cell and silently shadows it. Keep a single definition (or move it into a
# module) so the two copies cannot drift apart.
class Decoder2D(nn.Module):
    """Upsampling head that turns a 2D feature map into per-pixel logits.

    The upsampling happens in stages:
    - Two stride-4 transposed convolutions, each followed by BatchNorm + ReLU,
      upsample by 16x.
    - A final bilinear 2x resize brings the total to 32x (e.g. 7x7 -> 224x224).
    - A 1x1 convolution then projects to ``num_classes`` channels.

    Args:
        in_channels: channel count of the incoming feature map.
        num_classes: number of output logit channels.

    Shapes:
        input:  (B, in_channels, H, W)
        output: (B, num_classes, 32 * H, 32 * W)
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Stage 1: in_channels -> 256 channels, spatial x4.
        self.up1 = nn.ConvTranspose2d(in_channels=in_channels, out_channels=256, kernel_size=4, stride=4)
        self.bn1 = nn.BatchNorm2d(num_features=256)
        # Stage 2: 256 -> 64 channels, spatial x4.
        self.up2 = nn.ConvTranspose2d(in_channels=256, out_channels=64, kernel_size=4, stride=4)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        # Per-pixel classifier.
        self.conv_out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        # kernel 4, stride 4, no padding: out = (in - 1) * 4 + 4 = 4 * in,
        # so 7 -> 28 -> 112 for a 7x7 input.
        for upsample, norm in ((self.up1, self.bn1), (self.up2, self.bn2)):
            x = F.relu(norm(upsample(x)))
        # Final 2x bilinear resize (112 -> 224 for a 7x7 input).
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        return self.conv_out(x)
class Swin3DSegmenter(nn.Module):
    """Swin3D-S video backbone plus a 2D decoder: one segmentation map per clip.

    The forward pass works as follows:
    - Take the backbone's spatio-temporal feature grid after its final norm
      layer, before pooling/classification.
    - Average that grid over the temporal patch axis.
    - Decode the result to per-pixel logits.

    Args:
        num_classes: number of output channels of the segmentation logits.
        weights: passed through to ``swin3d_s``. Defaults to the Kinetics-400
            checkpoint. Pass ``None`` for a randomly initialised backbone
            (no download).

    Shapes:
        input:  (B, 3, T, H, W) clip
        output: (B, num_classes, 32 * H_patch, 32 * W_patch) logits,
                i.e. (B, num_classes, 224, 224) for a 224x224 clip.
    """

    def __init__(self, num_classes=2, weights="KINETICS400_V1"):
        super().__init__()
        self.backbone = swin_transformer.swin3d_s(weights=weights)
        # Read the feature width from the classifier before discarding it,
        # rather than hard-coding 768.
        in_channels = self.backbone.head.in_features
        # The classifier is never used; replacing it drops its parameters.
        self.backbone.head = nn.Identity()
        self.decoder = Decoder2D(in_channels=in_channels, num_classes=num_classes)

        # Most recent (B, T_patch, H_patch, W_patch, C) feature grid. It is
        # refreshed on every forward call and kept for inspection.
        self.features = None

    def _extract_features(self, x):
        """Run the backbone up to and including its final norm layer.

        This calls the backbone stages directly (patch_embed -> pos_drop ->
        features -> norm) and stops before global pooling and the head. No
        forward hook is involved, so no hidden state is shared with the
        backbone. That shared state breaks when the module is replicated,
        e.g. by DataParallel.

        Returns:
            A channels-last (B, T_patch, H_patch, W_patch, C) tensor.
        """
        x = self.backbone.patch_embed(x)
        x = self.backbone.pos_drop(x)
        x = self.backbone.features(x)
        return self.backbone.norm(x)

    def forward(self, x):
        feat = self._extract_features(x)   # (B, T_patch, H_patch, W_patch, C)
        self.features = feat
        feat = feat.permute(0, 4, 1, 2, 3)  # (B, C, T_patch, H_patch, W_patch)
        feat_2d = feat.mean(dim=2)          # average temporal patches: (B, C, H_patch, W_patch)
        return self.decoder(feat_2d)        # (B, num_classes, 32 * H_patch, 32 * W_patch)

# Usage: shape check on a random 16-frame 224x224 clip.
# eval() + no_grad() so that:
#   - BatchNorm uses its running stats instead of updating them;
#   - no autograd graph is built for this throwaway forward pass.
model = Swin3DSegmenter(num_classes=1)
model.eval()
dummy_input = torch.randn(1, 3, 16, 224, 224)  # B, C, T, H, W
with torch.no_grad():
    out = model(dummy_input)
assert out.shape == (1, 1, 224, 224), out.shape
print(out.shape)


torch.Size([1, 1, 224, 224])


In [None]:
# Scratch cell: prints the integers 0 through 7, one per line.
# NOTE(review): the saved output under this cell lists 0..99, so it came from
# an earlier version of this loop. Re-run the cell to refresh it, or delete
# the cell.
for i in range(0, 8, 1):
    print(f"{i}")


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
