In [1]:
import torchvision
import os

torchvision_dir = os.path.dirname(torchvision.__file__)
print("torchvision is installed in:", torchvision_dir)

torchvision is installed in: /home/hoon2/anaconda3/envs/pda/lib/python3.9/site-packages/torchvision


In [19]:
import torch
from torchvision.models.video import swin3d_s, swin3d_b

model = swin3d_s(weights=None)
model2 = swin3d_b(weights=None)

dummy_input = torch.randn(1, 30, 3, 256, 256)
dummy_input = dummy_input.permute(0, 2, 1, 3, 4)

output_s = model(dummy_input)
output_b = model2(dummy_input)

print("swint_s output shape:", output_s.shape)
print("swint_b output shape:", output_b.shape)

swint_s output shape: torch.Size([1, 400])
swint_b output shape: torch.Size([1, 400])


In [20]:
model.head

Linear(in_features=768, out_features=400, bias=True)

In [21]:
import torch.nn as nn

in_features = model.head.in_features  
model.head = nn.Linear(in_features, 4)

in_features = model2.head.in_features 
model2.head = nn.Linear(in_features, 4)

In [22]:
model.head

Linear(in_features=768, out_features=4, bias=True)

In [23]:
from torchinfo import summary
from fvcore.nn import FlopCountAnalysis, parameter_count_table

print(parameter_count_table(model))

| name                       | #elements or shape   |
|:---------------------------|:---------------------|
| model                      | 49.5M                |
|  patch_embed               |  9.5K                |
|   patch_embed.proj         |   9.3K               |
|    patch_embed.proj.weight |    (96, 3, 2, 4, 4)  |
|    patch_embed.proj.bias   |    (96,)             |
|   patch_embed.norm         |   0.2K               |
|    patch_embed.norm.weight |    (96,)             |
|    patch_embed.norm.bias   |    (96,)             |
|  features                  |  49.5M               |
|   features.0               |   0.2M               |
|    features.0.0            |    0.1M              |
|    features.0.1            |    0.1M              |
|   features.1               |   74.5K              |
|    features.1.reduction    |    73.7K             |
|    features.1.norm         |    0.8K              |
|   features.2               |   0.9M               |
|    features.2.0           

In [24]:
print(parameter_count_table(model2))

| name                       | #elements or shape   |
|:---------------------------|:---------------------|
| model                      | 87.6M                |
|  patch_embed               |  12.7K               |
|   patch_embed.proj         |   12.4K              |
|    patch_embed.proj.weight |    (128, 3, 2, 4, 4) |
|    patch_embed.proj.bias   |    (128,)            |
|   patch_embed.norm         |   0.3K               |
|    patch_embed.norm.weight |    (128,)            |
|    patch_embed.norm.bias   |    (128,)            |
|  features                  |  87.6M               |
|   features.0               |   0.4M               |
|    features.0.0            |    0.2M              |
|    features.0.1            |    0.2M              |
|   features.1               |   0.1M               |
|    features.1.reduction    |    0.1M              |
|    features.1.norm         |    1.0K              |
|   features.2               |   1.6M               |
|    features.2.0           

In [2]:
import torch
from torchvision.models.video import mvit_v2_s, mvit_v1_b

model = mvit_v2_s(weights=None)
model2 = mvit_v1_b(weights=None)

dummy_input = torch.randn(1, 30, 3, 256, 256)
dummy_input = dummy_input.permute(0, 2, 1, 3, 4)

output_v2_s = model(dummy_input)
output_v1_b = model2(dummy_input)

print("mvit_v2_s output shape:", output_v2_s.shape)
print("mvit_v2_b output shape:", output_v1_b.shape)

mvit_v2_s output shape: torch.Size([1, 400])
mvit_v2_b output shape: torch.Size([1, 400])


In [4]:
import torch.nn as nn

in_features = model.head[1].in_features  # 기존 in_features 유지
model.head[1] = nn.Linear(in_features, 4)

in_features = model2.head[1].in_features  # 기존 in_features 유지
model2.head[1] = nn.Linear(in_features, 4)

In [5]:
model2

MViT(
  (conv_proj): Conv3d(3, 96, kernel_size=(3, 7, 7), stride=(2, 4, 4), padding=(1, 3, 3))
  (pos_encoding): PositionalEncoding()
  (blocks): ModuleList(
    (0): MultiscaleBlock(
      (norm1): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
      (norm2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
      (attn): MultiscaleAttention(
        (qkv): Linear(in_features=96, out_features=288, bias=True)
        (project): Sequential(
          (0): Linear(in_features=96, out_features=96, bias=True)
        )
        (pool_k): Pool(
          (pool): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 8, 8), padding=(1, 1, 1), groups=96, bias=False)
          (norm_act): Sequential(
            (0): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          )
        )
        (pool_v): Pool(
          (pool): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 8, 8), padding=(1, 1, 1), groups=96, bias=False)
          (norm_act): Sequential(
            (0): LayerNorm((96,

In [8]:
from torchinfo import summary
from fvcore.nn import FlopCountAnalysis, parameter_count_table

print(parameter_count_table(model))

| name                       | #elements or shape   |
|:---------------------------|:---------------------|
| model                      | 34.3M                |
|  conv_proj                 |  42.4K               |
|   conv_proj.weight         |   (96, 3, 3, 7, 7)   |
|   conv_proj.bias           |   (96,)              |
|  pos_encoding              |  96                  |
|   pos_encoding.class_token |   (96,)              |
|  blocks                    |  34.2M               |
|   blocks.0                 |   0.1M               |
|    blocks.0.norm1          |    0.2K              |
|    blocks.0.norm2          |    0.2K              |
|    blocks.0.attn           |    72.8K             |
|    blocks.0.mlp            |    74.2K             |
|   blocks.1                 |   0.4M               |
|    blocks.1.norm1          |    0.2K              |
|    blocks.1.norm2          |    0.4K              |
|    blocks.1.attn           |    0.1M              |
|    blocks.1.mlp           

In [9]:
print(parameter_count_table(model2))

| name                        | #elements or shape   |
|:----------------------------|:---------------------|
| model                       | 36.4M                |
|  conv_proj                  |  42.4K               |
|   conv_proj.weight          |   (96, 3, 3, 7, 7)   |
|   conv_proj.bias            |   (96,)              |
|  pos_encoding               |  0.4M                |
|   pos_encoding.class_token  |   (96,)              |
|   pos_encoding.spatial_pos  |   (4096, 96)         |
|   pos_encoding.temporal_pos |   (15, 96)           |
|   pos_encoding.class_pos    |   (96,)              |
|  blocks                     |  36.0M               |
|   blocks.0                  |   0.2M               |
|    blocks.0.norm1           |    0.2K              |
|    blocks.0.norm2           |    0.2K              |
|    blocks.0.attn            |    42.8K             |
|    blocks.0.mlp             |    0.1M              |
|    blocks.0.project         |    18.6K             |
|   blocks

In [4]:
from vivit import *

img = torch.ones([1, 16, 3, 256, 256])  # 수정: 256x256 크기의 입력 이미지 생성

model = ViViT(256, 16, 100, 16)  # 수정: image_size를 256으로 설정
parameters = filter(lambda p: p.requires_grad, model.parameters())
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
print('Trainable Parameters: %.3fM' % parameters)

out = model(img)

print("Shape of out :", out.shape)  # [B, num_classes]

Trainable Parameters: 4.512M
Shape of out : torch.Size([1, 100])


In [5]:
parameters

np.float64(4.512292)