In [2]:
from functools import partial
import os

import torch.nn as nn
from torchvision.models import vit_b_16
import torch

import argparse

from wb_data import WaterBirdsDataset, get_loader, get_transform_cub, log_data
from utils import evaluate, get_y_p

In [6]:
NUM_CLASSES = 2

In [13]:
model = vit_b_16(weights='DEFAULT')
model.heads.head = nn.Linear(model.heads.head.in_features, NUM_CLASSES)


In [37]:
model

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [35]:
for name, param in model.named_parameters():
    print(name, param.requires_grad)

class_token False
conv_proj.weight False
conv_proj.bias False
encoder.pos_embedding False
encoder.layers.encoder_layer_0.ln_1.weight False
encoder.layers.encoder_layer_0.ln_1.bias False
encoder.layers.encoder_layer_0.self_attention.in_proj_weight False
encoder.layers.encoder_layer_0.self_attention.in_proj_bias False
encoder.layers.encoder_layer_0.self_attention.out_proj.weight False
encoder.layers.encoder_layer_0.self_attention.out_proj.bias False
encoder.layers.encoder_layer_0.ln_2.weight False
encoder.layers.encoder_layer_0.ln_2.bias False
encoder.layers.encoder_layer_0.mlp.0.weight True
encoder.layers.encoder_layer_0.mlp.0.bias True
encoder.layers.encoder_layer_0.mlp.3.weight True
encoder.layers.encoder_layer_0.mlp.3.bias True
encoder.layers.encoder_layer_1.ln_1.weight False
encoder.layers.encoder_layer_1.ln_1.bias False
encoder.layers.encoder_layer_1.self_attention.in_proj_weight False
encoder.layers.encoder_layer_1.self_attention.in_proj_bias False
encoder.layers.encoder_layer_1.s

In [65]:
for param in model.parameters():
    param.requires_grad = False

for param in model.heads.head.parameters():
    param.requires_grad = True

for name, module in model.named_modules():
    if "mlp" in name and isinstance(module, torch.nn.Linear):
        for param in module.parameters():
            param.requires_grad = True

 
for name, module in model.named_modules():
    if "encoder.ln" in name and isinstance(module, torch.nn.LayerNorm):
        for param in module.parameters():
            param.requires_grad = True
            
print("Model parameters and their trainable status:")
for name, module in model.named_modules():
    is_trainable = any(param.requires_grad for param in module.parameters())
    print(f"{name}: {'Trainable' if is_trainable else 'Frozen'}")
    

Model parameters and their trainable status:
: Trainable
conv_proj: Frozen
encoder: Trainable
encoder.dropout: Frozen
encoder.layers: Trainable
encoder.layers.encoder_layer_0: Trainable
encoder.layers.encoder_layer_0.ln_1: Frozen
encoder.layers.encoder_layer_0.self_attention: Frozen
encoder.layers.encoder_layer_0.self_attention.out_proj: Frozen
encoder.layers.encoder_layer_0.dropout: Frozen
encoder.layers.encoder_layer_0.ln_2: Frozen
encoder.layers.encoder_layer_0.mlp: Trainable
encoder.layers.encoder_layer_0.mlp.0: Trainable
encoder.layers.encoder_layer_0.mlp.1: Frozen
encoder.layers.encoder_layer_0.mlp.2: Frozen
encoder.layers.encoder_layer_0.mlp.3: Trainable
encoder.layers.encoder_layer_0.mlp.4: Frozen
encoder.layers.encoder_layer_1: Trainable
encoder.layers.encoder_layer_1.ln_1: Frozen
encoder.layers.encoder_layer_1.self_attention: Frozen
encoder.layers.encoder_layer_1.self_attention.out_proj: Frozen
encoder.layers.encoder_layer_1.dropout: Frozen
encoder.layers.encoder_layer_1.ln_2