In [1]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import sys

# Project checkout on Google Drive that is placed at the front of the import path.
PROJECT_ROOT = "/content/drive/MyDrive/fish_attention_HydraViT_test"

# Put PROJECT_ROOT first on sys.path. Existing occurrences are removed first so
# that re-running this cell neither stacks duplicate entries nor leaves the
# project directory at a lower priority than it should have.
sys.path[:] = [entry for entry in sys.path if entry != PROJECT_ROOT]
sys.path.insert(0, PROJECT_ROOT)

# Completely drop any already-imported timm modules (e.g. a pip-installed copy)
# so that the next `import timm` is resolved again against the updated sys.path.
# The names are snapshotted into a list because sys.modules is mutated in the loop.
for module_name in [m for m in sys.modules if m == "timm" or m.startswith("timm.")]:
    sys.modules.pop(module_name, None)


In [3]:
import timm

# Show where timm was loaded from, so it is visible whether the copy under the
# project directory or some other installation was picked up.
timm_module_file = getattr(timm, "__file__", None)
timm_package_dirs = list(getattr(timm, "__path__", []))
print("timm module path:", timm_module_file)
print("timm package paths:", timm_package_dirs)


timm module path: /content/drive/MyDrive/fish_attention_HydraViT_test/timm/__init__.py
timm package paths: ['/content/drive/MyDrive/fish_attention_HydraViT_test/timm']


In [4]:
from timm.models.hydravit import HydraViT, Block, FiSHBlock

In [5]:
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

def get_dataloaders(data_dir: str, batch_size: int, num_workers: int):
    """Build the CIFAR-10 training and test data loaders.

    Both pipelines apply ``Resize(256)``, a crop to 224, ``ToTensor`` and a
    normalization with ``IMAGENET_INCEPTION_MEAN`` / ``IMAGENET_INCEPTION_STD``.
    The training pipeline uses a random resized crop followed by a random
    horizontal flip; the test pipeline uses a center crop.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.CIFAR10``
            (both splits are requested with ``download=True``).
        batch_size: Batch size used by both loaders.
        num_workers: Number of loader workers used by both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its data.
    """
    # Tail shared by both pipelines: tensor conversion, then normalization.
    to_tensor = T.ToTensor()
    normalize = T.Normalize(IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD)

    train_pipeline = T.Compose(
        [T.Resize(256), T.RandomResizedCrop(224), T.RandomHorizontalFlip(), to_tensor, normalize]
    )
    test_pipeline = T.Compose(
        [T.Resize(256), T.CenterCrop(224), to_tensor, normalize]
    )

    def _cifar10(is_train, pipeline):
        # One CIFAR-10 split rooted at data_dir with the given transform.
        return torchvision.datasets.CIFAR10(
            root=data_dir, train=is_train, download=True, transform=pipeline
        )

    def _loader(dataset, shuffle):
        # DataLoader with the caller's batch size and worker count.
        return DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
        )

    train_split = _cifar10(True, train_pipeline)
    test_split = _cifar10(False, test_pipeline)

    return _loader(train_split, True), _loader(test_split, False)


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Use the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Default number of output classes for the model factories (CIFAR-10 has 10).
NUM_CLASSES = 10

def _create_hydravit(block_fn, num_classes: int, dim: int, depth: int, num_heads: int):
    """Build a HydraViT with the given block class and move it to `device`.

    The fixed settings (``img_size=224``, ``patch_size=16``, ``in_chans=3``)
    are shared by every model variant created in this notebook.
    """
    model = HydraViT(
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=num_classes,
        embed_dim=dim,
        depth=depth,
        num_heads=num_heads,
        block_fn=block_fn,
    )
    return model.to(device)


def create_baseline_hydravit(
    num_classes: int = NUM_CLASSES,
    dim: int = 384,
    depth: int = 8,
    num_heads: int = 6,
):
    """Create a HydraViT that uses ``Block`` as its block class.

    Args:
        num_classes: Passed to HydraViT as ``num_classes``.
        dim: Passed to HydraViT as ``embed_dim``.
        depth: Passed to HydraViT as ``depth``.
        num_heads: Passed to HydraViT as ``num_heads``.

    Returns:
        The model, already moved to `device`.
    """
    return _create_hydravit(Block, num_classes, dim, depth, num_heads)


def create_fish_hydravit(
    num_classes: int = NUM_CLASSES,
    dim: int = 384,
    depth: int = 8,
    num_heads: int = 6,
):
    """Create a HydraViT that uses ``FiSHBlock`` as its block class.

    Takes the same arguments as `create_baseline_hydravit`; the two models
    differ only in the ``block_fn`` handed to HydraViT.

    Returns:
        The model, already moved to `device`.
    """
    return _create_hydravit(FiSHBlock, num_classes, dim, depth, num_heads)


In [7]:
# Experiment configuration, kept in one place so both models share it.
SEED = 42
BATCH_SIZE = 64
NUM_WORKERS = 2
LEARNING_RATE = 3e-4

# Seed PyTorch's global RNG before any model or loader is created so that
# weight initialization and data shuffling are repeatable when this cell is
# re-run. (GPU kernels may still introduce small nondeterminism.)
torch.manual_seed(SEED)

train_loader, test_loader = get_dataloaders(
    data_dir="/content/data",
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

criterion = nn.CrossEntropyLoss()

# Baseline HydraViT
baseline_model = create_baseline_hydravit()
optimizer_baseline = optim.AdamW(baseline_model.parameters(), lr=LEARNING_RATE)

# FiSH-HydraViT (same optimizer settings as the baseline)
fish_model = create_fish_hydravit()
optimizer_fish = optim.AdamW(fish_model.parameters(), lr=LEARNING_RATE)

def train_one_epoch(model, loader, criterion, optimizer):
    """Train `model` for one pass over `loader` and return the training accuracy.

    Args:
        model: Module to train; it is switched to training mode.
        loader: Iterable yielding ``(imgs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.

    Returns:
        Fraction of samples whose highest-scoring output index (along dim 1)
        equals the label, measured on the outputs computed during the pass.
        Returns 0.0 when the loader yields no samples.
    """
    model.train()

    # Move batches to wherever the model's weights live, so the function does
    # not depend on the notebook-level `device` matching the model. The global
    # is only consulted for a model that has no parameters at all.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device(device)

    total = 0
    correct = 0
    for imgs, labels in tqdm(loader):
        imgs, labels = imgs.to(target_device), labels.to(target_device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accuracy uses the outputs from before this step's weight update.
        _, pred = outputs.max(1)
        correct += pred.eq(labels).sum().item()
        total += labels.size(0)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return correct / total if total else 0.0


def evaluate(model, loader):
    """Return the classification accuracy of `model` over `loader`.

    The model is switched to evaluation mode (and left in it) and the pass
    runs under ``torch.no_grad()``.

    Args:
        model: Module to evaluate.
        loader: Iterable yielding ``(imgs, labels)`` batches.

    Returns:
        Fraction of samples whose highest-scoring output index (along dim 1)
        equals the label. Returns 0.0 when the loader yields no samples.
    """
    model.eval()

    # Move batches to wherever the model's weights live, so the function does
    # not depend on the notebook-level `device` matching the model. The global
    # is only consulted for a model that has no parameters at all.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device(device)

    total = 0
    correct = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(target_device), labels.to(target_device)
            outputs = model(imgs)
            _, pred = outputs.max(1)
            correct += pred.eq(labels).sum().item()
            total += labels.size(0)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return correct / total if total else 0.0

EPOCHS = 3


def _train_and_evaluate(tag, model, optimizer):
    """Train `model` for EPOCHS epochs, then return its test accuracy.

    Prints the training accuracy after every epoch and the test accuracy at
    the end, each prefixed with ``[tag]``. Uses the notebook-level
    `train_loader`, `test_loader` and `criterion`.
    """
    for epoch in range(EPOCHS):
        acc = train_one_epoch(model, train_loader, criterion, optimizer)
        print(f"[{tag}] Epoch {epoch+1}: Train Acc = {acc:.4f}")

    test_acc = evaluate(model, test_loader)
    print(f"[{tag}] Test Acc = {test_acc:.4f}")
    return test_acc


# Both models go through the identical train-then-test procedure.
print("=== Training Baseline HydraViT ===")
baseline_test_acc = _train_and_evaluate("Baseline", baseline_model, optimizer_baseline)

print("\n=== Training FiSH-HydraViT ===")
fish_test_acc = _train_and_evaluate("FiSH", fish_model, optimizer_fish)


=== Training Baseline HydraViT ===


100%|██████████| 782/782 [05:42<00:00,  2.28it/s]


[Baseline] Epoch 1: Train Acc = 0.2818


100%|██████████| 782/782 [05:41<00:00,  2.29it/s]


[Baseline] Epoch 2: Train Acc = 0.3709


100%|██████████| 782/782 [05:42<00:00,  2.28it/s]

[Baseline] Epoch 3: Train Acc = 0.4067





[Baseline] Test Acc = 0.4650

=== Training FiSH-HydraViT ===


100%|██████████| 782/782 [05:52<00:00,  2.22it/s]


[FiSH] Epoch 1: Train Acc = 0.2600


100%|██████████| 782/782 [05:53<00:00,  2.21it/s]


[FiSH] Epoch 2: Train Acc = 0.3345


100%|██████████| 782/782 [05:54<00:00,  2.21it/s]

[FiSH] Epoch 3: Train Acc = 0.3632





[FiSH] Test Acc = 0.4171
