In [1]:
import torch
import torch.nn as nn
import timm  # Hugging Face timm library supports ViTs


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Modify ViViT for 96 Frames
def modify_vivit_for_96_frames(model, num_frames=96, num_classes=8, device="cuda"):
    # Expand the positional embeddings to match 96 frames
    pretrained_pos_embed = model.pos_embed  # Original embedding
    
    # Create new embedding for 96 frames
    new_pos_embed = torch.nn.Parameter(torch.randn(1, num_frames + 1, pretrained_pos_embed.shape[-1]))  # +1 for CLS token
    
    # Replace the existing positional embeddings
    model.pos_embed = new_pos_embed

    # 🔹 Modify the classifier head for 8 classes
    model.head = nn.Linear(model.head.in_features, num_classes).to(device)

    # 🔹 Enable Fine-Tuning (Unfreeze All Layers)
    for param in model.parameters():
        param.requires_grad = True

    print(f"Model modified for {num_frames} frames and {num_classes} classes.")
    return model


In [3]:

# Load Pretrained ViViT Model
def load_vivit_model(pretrained_path, num_classes=8, num_frames=96, device="cuda"):
    print(f"Loading ViViT model fine-tuned on Kinetics-600 from: {pretrained_path}")

    # Load Pretrained ViViT Model from `timm`
    model = timm.create_model("vivit_large_patch16_224", pretrained=True, num_classes=600)  # K600 has 600 classes
    
    # Load checkpoint (Modify this based on the format of your .pth file)
    checkpoint = torch.load(pretrained_path, map_location=device)
    
    if "model_state" in checkpoint:
        model.load_state_dict(checkpoint["model_state"], strict=False)
    else:
        model.load_state_dict(checkpoint, strict=False)

    print("Pretrained ViViT model loaded successfully!")

    # 🔹 Modify the Model for 96 Frames
    model = modify_vivit_for_96_frames(model, num_frames, num_classes, device)

    return model



In [4]:


# Example Usage
pretrained_model_path = "path/to/your/ViViT_K600.pth"  # Change this to the actual model path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vivit_model = load_vivit_model(pretrained_model_path, num_classes=8, num_frames=96, device=device)
vivit_model.to(device)


Loading ViViT model fine-tuned on Kinetics-600 from: path/to/your/ViViT_K600.pth


RuntimeError: Unknown model (vivit_large_patch16_224)

In [7]:
import timm

all_models = timm.list_models(pretrained=True)
for model_name in all_models:
    print(model_name)

aimv2_1b_patch14_224.apple_pt
aimv2_1b_patch14_336.apple_pt
aimv2_1b_patch14_448.apple_pt
aimv2_3b_patch14_224.apple_pt
aimv2_3b_patch14_336.apple_pt
aimv2_3b_patch14_448.apple_pt
aimv2_huge_patch14_224.apple_pt
aimv2_huge_patch14_336.apple_pt
aimv2_huge_patch14_448.apple_pt
aimv2_large_patch14_224.apple_pt
aimv2_large_patch14_224.apple_pt_dist
aimv2_large_patch14_336.apple_pt
aimv2_large_patch14_336.apple_pt_dist
aimv2_large_patch14_448.apple_pt
bat_resnext26ts.ch_in1k
beit_base_patch16_224.in22k_ft_in22k
beit_base_patch16_224.in22k_ft_in22k_in1k
beit_base_patch16_384.in22k_ft_in22k_in1k
beit_large_patch16_224.in22k_ft_in22k
beit_large_patch16_224.in22k_ft_in22k_in1k
beit_large_patch16_384.in22k_ft_in22k_in1k
beit_large_patch16_512.in22k_ft_in22k_in1k
beitv2_base_patch16_224.in1k_ft_in1k
beitv2_base_patch16_224.in1k_ft_in22k
beitv2_base_patch16_224.in1k_ft_in22k_in1k
beitv2_large_patch16_224.in1k_ft_in1k
beitv2_large_patch16_224.in1k_ft_in22k
beitv2_large_patch16_224.in1k_ft_in22k_in1

In [13]:
if "vivit_large_patch16" in all_models:
    print(True)
else:
    print(False)

False


In [17]:
for model_name in all_models:
    if "vit" in model_name:
        print(model_name)

convit_base.fb_in1k
convit_small.fb_in1k
convit_tiny.fb_in1k
crossvit_9_240.in1k
crossvit_9_dagger_240.in1k
crossvit_15_240.in1k
crossvit_15_dagger_240.in1k
crossvit_15_dagger_408.in1k
crossvit_18_240.in1k
crossvit_18_dagger_240.in1k
crossvit_18_dagger_408.in1k
crossvit_base_240.in1k
crossvit_small_240.in1k
crossvit_tiny_240.in1k
davit_base.msft_in1k
davit_base_fl.msft_florence2
davit_huge_fl.msft_florence2
davit_small.msft_in1k
davit_tiny.msft_in1k
efficientvit_b0.r224_in1k
efficientvit_b1.r224_in1k
efficientvit_b1.r256_in1k
efficientvit_b1.r288_in1k
efficientvit_b2.r224_in1k
efficientvit_b2.r256_in1k
efficientvit_b2.r288_in1k
efficientvit_b3.r224_in1k
efficientvit_b3.r256_in1k
efficientvit_b3.r288_in1k
efficientvit_l1.r224_in1k
efficientvit_l2.r224_in1k
efficientvit_l2.r256_in1k
efficientvit_l2.r288_in1k
efficientvit_l2.r384_in1k
efficientvit_l3.r224_in1k
efficientvit_l3.r256_in1k
efficientvit_l3.r320_in1k
efficientvit_l3.r384_in1k
efficientvit_m0.r224_in1k
efficientvit_m1.r224_in1k
