In [1]:
from train import valid, set_seed
from apex import amp
import torch
import argparse
from addict import Dict
import logging
from utils.construct_tff import construct_real_tff
import matplotlib.pyplot as plt
import logging
from utils.data_utils import get_loader
from models.modeling import VisionTransformer, CONFIGS
import numpy as np
from torch.utils.data import Subset
!export CUDA_VISIBLE_DEVICES=4

In [26]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [3]:
def count_parameters(model):
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params/1000000

In [20]:
args = Dict()
args.model_type = 'ViT-B_16'
args.dataset = 'inet1k_cats'
args.img_size = 224
args.pretrained_dir = 'checkpoint/ViT-B_16-224.npz'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.device = device
args.local_rank = -1
args.train_batch_size = 16
args.eval_batch_size = 16

# args.dataset = 'inet1k_cats'
# args.dataset_dir = 'data/inet1k_classes/cats'
# ckpt_path = 'output/inet1k_cats-2023-10-02-22-19-15/inet1k_cats_final_ckpt.bin' 
args.dataset = 'inet1k_birds'
args.dataset_dir = 'data/inet1k_classes/birds'
ckpt_path = 'output/inet1k_birds-2023-10-02-22-25-22/inet1k_birds_final_ckpt.bin'
# args.dataset = 'inet1k_dogs'
# args.dataset_dir = 'data/inet1k_classes/dogs'
# ckpt_path = 'output/inet1k_dogs-2023-09-24-21-00-17/inet1k_dogs_final_ckpt.bin'
# args.dataset = 'inet1k_snakes'
# args.dataset_dir = 'data/inet1k_classes/snakes'
# ckpt_path = 'output/inet1k_snakes-2023-10-02-22-28-06/inet1k_snakes_final_ckpt.bin'
# args.dataset = 'inet1k_trucks'
# args.dataset_dir = 'data/inet1k_classes/trucks'
# ckpt_path = 'output/inet1k_trucks-2023-09-24-20-47-28/inet1k_trucks_final_ckpt.bin'

In [27]:
# Prepare model
config = CONFIGS[args.model_type]

if args.dataset == "cifar10":
    num_classes = 10
elif args.dataset == "cifar100":
    num_classes = 100
elif 'inet' in args.dataset:
    num_classes = 1000

model = VisionTransformer(config, args.img_size, zero_head=False, num_classes=num_classes)
model.load_from(np.load(args.pretrained_dir))
model.to(args.device)
num_params = count_parameters(model)

logger.info("{}".format(config))
logger.info("Training parameters %s", args)
logger.info("Total Parameter: \t%2.1fM" % num_params)

2023-10-17 02:23:29,844 - __main__ - INFO - classifier: token
hidden_size: 768
patches:
  size: !!python/tuple
  - 16
  - 16
representation_size: null
transformer:
  attention_dropout_rate: 0.0
  dropout_rate: 0.1
  mlp_dim: 3072
  num_heads: 12
  num_layers: 12

2023-10-17 02:23:29,845 - __main__ - INFO - Training parameters {'model_type': 'ViT-B_16', 'dataset': 'inet1k_birds', 'img_size': 224, 'pretrained_dir': 'checkpoint/ViT-B_16-224.npz', 'device': device(type='cuda'), 'local_rank': -1, 'train_batch_size': 16, 'eval_batch_size': 16, 'dataset_dir': 'data/inet1k_classes/birds'}
2023-10-17 02:23:29,845 - __main__ - INFO - Total Parameter: 	86.6M


86.567656


In [22]:
train_loader, test_loader = get_loader(args)
classes = train_loader.dataset.dataset.classes

In [24]:
for data, label in train_loader:
    data = data.to(device)
    out_logits, _ = model(data)
    print(out_logits)
    print(classes[label])
    break

tensor([[-0.3387,  0.5969, -0.5674,  ..., -0.3600,  0.3715, -0.2212],
        [ 0.6837,  0.0188,  0.0290,  ..., -0.1652,  0.0880, -0.1276],
        [ 0.8962, -0.6325,  0.3972,  ...,  0.4540, -0.1371, -0.3120],
        ...,
        [ 0.0443, -0.0582, -0.0945,  ...,  0.1240,  0.0954, -0.1654],
        [ 0.5564,  0.5776,  0.4610,  ...,  0.1824,  0.0287, -1.5507],
        [ 0.5302, -0.3743, -0.5047,  ..., -0.5446, -0.4084, -0.1608]],
       device='cuda:0', grad_fn=<AddmmBackward0>)


TypeError: only integer tensors of a single element can be converted to an index

In [25]:
for i,j in zip(torch.amax(out_logits, dim=1).int(), label.detach()):
    print(i.item(), classes[j])

14 lorikeet
12 bald_eagle
15 prairie_chicken
13 vulture
13 African_grey
13 African_grey
13 lorikeet
14 hummingbird
12 vulture
7 bee_eater
15 coucal
13 lorikeet
11 bald_eagle
13 lorikeet
15 great_grey_owl
14 bald_eagle


torch.Size([16, 1000])
