In [2]:
from End2End_Model import End2End_Model

In [10]:
from typing import Literal, Union
import torch
from torchinfo import summary

In [4]:
# Ask PyTorch to release cached CUDA memory that is currently unused, so that the
# model built in the following cells starts from as much free GPU memory as possible.
torch.cuda.empty_cache()

In [5]:
## HYPERPARAMETERS:
# Of these constants, MODEL_TYPE, TRAINABLE_LAYERS, BCKB_DROPOUT, MAX_LR, DIV_FACTOR,
# EPOCHS, N_CYCLES and BOOTSTRAP_METHOD are forwarded to End2End_Model in the model
# construction cell of this notebook.
MODEL_NR: int = 34
WANDB_PROJECT: str = "3DViT_E2E"
# Allowed backbone identifiers, one per line; the selected value is on the closing line.
MODEL_TYPE: Literal[
    "dino_vits8",
    "dino_vitb8",
    "dino_vits16",
    "dino_vitb16",
    "vit_b_16",
    "vit_l_16",
    "3Dvit_8",
    "3Dvit_16",
] = "3Dvit_16"
FULL_VOLUME: bool = True
BOOTSTRAP_METHOD: Literal["centering", "inflation", None] = "centering"
EPOCHS: int = 60
BATCH_SIZE: int = 16
MAX_LR: float = 1e-5
DIV_FACTOR: int = 100  # Base LR is computed as MAX_LR/DIV_FACTOR.
N_CYCLES: int = 6
# Either an integer or the literal string "all".
TRAINABLE_LAYERS: Union[int, Literal["all"]] = "all"
BCKB_DROPOUT: float = 0.12
LOCAL: bool = True
SAVE_TOP_CKPTS: int = 0

In [7]:
# Build the end-to-end model from the hyperparameter constants of this notebook.
# NOTE(review): steps_per_epoch is the literal 50 rather than a value derived from a
# dataloader; presumably a placeholder that is only good enough for inspecting the
# architecture. Confirm before using this instance for training.
model = End2End_Model(
    model_type=MODEL_TYPE,
    trainable_layers=TRAINABLE_LAYERS,
    backbone_dropout=BCKB_DROPOUT,
    max_lr=MAX_LR,
    div_factor=DIV_FACTOR,
    steps_per_epoch=50,
    epochs=EPOCHS,
    n_cycles=N_CYCLES,
    bootstrap_method=BOOTSTRAP_METHOD,
)

<All keys matched successfully>


In [11]:
# Per-layer summary (input/output shapes, parameter counts, trainability) of the
# End2End_Model instance.
# The batch size is taken from the BATCH_SIZE hyperparameter instead of repeating the
# literal 16, so the summary always reflects the configured value. `batch_size` is
# still bound for any later cell that reads it.
# NOTE(review): the remaining input dimensions (1, 32, 224, 224) are presumably
# (channels, depth, height, width); confirm against End2End_Model.
batch_size = BATCH_SIZE
summary(
    model,
    input_size=(batch_size, 1, 32, 224, 224),
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Trainable
End2End_Model                            [16, 1, 32, 224, 224]     [16, 1]                   --                        True
├─VisionTransformer3D: 1-1               [16, 1, 32, 224, 224]     [16, 768]                 152,064                   True
│    └─PatchEmbed: 2-1                   [16, 1, 32, 224, 224]     [16, 196, 768]            --                        True
│    │    └─Conv3d: 3-1                  [16, 1, 32, 224, 224]     [16, 768, 1, 14, 14]      6,292,224                 True
│    └─Dropout: 2-2                      [16, 197, 768]            [16, 197, 768]            --                        --
│    └─ModuleList: 2-3                   --                        --                        --                        True
│    │    └─Block: 3-2                   [16, 197, 768]            [16, 197, 768]            7,087,872                 True
│    

In [1]:
from LIDC_Dataset import LIDC_Dataset

In [2]:
import os

# Dataset root handed to LIDC_Dataset in the next cell.
# The location can be overridden with the LIDC_DATAPATH environment variable so the
# notebook is not tied to a single machine; when the variable is unset, the original
# hard-coded path is used unchanged (including its trailing slash).
datapath = os.environ.get("LIDC_DATAPATH", "/home/jbinda/INFORM/LIDC_ViTs/dataset/")

In [3]:
# Instantiate the dataset for fold 1 in "train" mode for the "Classification" task,
# with full_volume enabled and no input or label transforms.
train_ds = LIDC_Dataset(
    datapath=datapath,
    fold=1,
    transform=None,
    label_transform=None,
    mode="train",
    task="Classification",
    full_volume=True,
)

In [5]:
# Inspect the shape of the first element of the first dataset item.
# NOTE(review): the stored output is [32, 32, 32], while the model summary earlier in
# this notebook uses a per-sample input of (1, 32, 224, 224); presumably a transform
# is expected to bridge the two. Confirm.
train_ds[0][0].shape

torch.Size([32, 32, 32])

# ViT

In [24]:
import torchvision
from torchinfo import summary

In [34]:
# Load torchvision's vit_h_14 with the 'IMAGENET1K_SWAG_E2E_V1' weights.
# On first use the checkpoint is downloaded into the torch hub cache; the stored log
# for this cell shows a 2.36 GB transfer.
# NOTE(review): this rebinds `model`, which earlier in the notebook named the
# End2End_Model instance.
model = torchvision.models.vit_h_14(weights='IMAGENET1K_SWAG_E2E_V1')

Downloading: "https://download.pytorch.org/models/vit_h_14_swag-80465313.pth" to /home/jbinda/.cache/torch/hub/checkpoints/vit_h_14_swag-80465313.pth
100%|██████████████████████████████████████████████████████████████████████████████| 2.36G/2.36G [23:39<00:00, 1.79MB/s]


In [37]:
# Per-layer summary of the torchvision ViT for a batch of four inputs of size
# 3 x 518 x 518, listing shapes, parameter counts and trainability.
batch_size = 4
summary(
    model,
    input_size=(batch_size, 3, 518, 518),
    col_names=["input_size", "output_size", "num_params", "trainable"],
)

Layer (type:depth-idx)                             Input Shape               Output Shape              Param #                   Trainable
VisionTransformer                                  [4, 3, 518, 518]          [4, 1000]                 1,280                     True
├─Conv2d: 1-1                                      [4, 3, 518, 518]          [4, 1280, 37, 37]         753,920                   True
├─Encoder: 1-2                                     [4, 1370, 1280]           [4, 1370, 1280]           1,753,600                 True
│    └─Dropout: 2-1                                [4, 1370, 1280]           [4, 1370, 1280]           --                        --
│    └─Sequential: 2-2                             [4, 1370, 1280]           [4, 1370, 1280]           --                        True
│    │    └─EncoderBlock: 3-1                      [4, 1370, 1280]           [4, 1370, 1280]           19,677,440                True
│    │    └─EncoderBlock: 3-2                      [4, 1370

In [38]:
# Display the model's `heads` submodule.
model.heads

Sequential(
  (head): Linear(in_features=1280, out_features=1000, bias=True)
)

In [31]:
# Display the model's `hidden_dim` attribute.
# NOTE(review): the stored output (1024) belongs to execution In [31], which predates
# the vit_h_14 load in In [34]; the ViT-H layers shown elsewhere in this notebook have
# width 1280, so this output looks stale. Re-run to confirm.
model.hidden_dim

1024

In [21]:
# Replace the model's `heads` attribute with a pass-through module.
# torch.nn.Identity ignores its constructor arguments, so the former `768` argument
# had no effect and only suggested a feature width that is not enforced; it is
# dropped here. Runtime behaviour is unchanged.
# NOTE(review): this cell carries execution count In [21], i.e. it ran before the
# ViT cells above it; confirm which `model` object it is meant to modify.
model.heads = torch.nn.Identity()

In [22]:
# Print the full module tree of `model`.
# NOTE(review): the stored output shows an End2End_Model with a 384-wide 2D backbone
# and comes from execution In [22], before `model` was rebound to the torchvision ViT
# (In [34]); re-run to refresh the output.
print(model)

End2End_Model(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.15, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.15, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.15, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (drop): Dropout(p=0.15, inplace=False)
        )
      )
    )
    (norm)