In [None]:
import os
import time
from pprint import pprint

import torch
import torchvision

from sairg_utils import launch, get_data_loaders, train, test, set_random_seeds, define_finetune_model
set_random_seeds()
from meta_transformer import MetaTransformer

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at the first .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load an ImageNet-pretrained ResNet-18 together with its weight metadata.
base_net_weights = torchvision.models.ResNet18_Weights.DEFAULT
base_net = torchvision.models.resnet18(weights=base_net_weights)

# Collect the names of the modules handed to MetaTransformer later on:
# every top-level module (no '.' in its name) whose name contains 'layer'
# (layer1..layer4 for torchvision's ResNet), followed by the classifier head.
# Alternative selection (all batch-norm modules):
#   module_list = [name for name, _ in base_net.named_modules() if 'bn' in name]
module_list = [
    name
    for name, _ in base_net.named_modules()
    if 'layer' in name and '.' not in name
]
module_list.append('fc')
print(len(module_list))
module_list

In [None]:
# base_net_weights = torchvision.models.ViT_B_32_Weights.DEFAULT
# base_net = torchvision.models.vit_b_32(weights=base_net_weights)
# # [k for k, v in base_net.named_modules() if 'encoder_layer' in k and len(k.split('.')) < 4]
# # [k for k, v in base_net.named_modules()]
# dict(base_net.named_modules())['heads']

In [None]:
# Dataset and preprocessing for fine-tuning.
# `dataset` is the CIFAR10 *class* (not an instance); instantiation is left
# to get_data_loaders.
dataset = torchvision.datasets.CIFAR10
# Reuse the preprocessing pipeline shipped with the pretrained weights so the
# inputs match what the backbone was trained on.
transform = base_net_weights.transforms()
# get_data_loaders is a project helper; it returns a (train, test) pair.
# NOTE(review): batch size / shuffling are decided inside the helper — confirm.
train_loader, test_loader = get_data_loaders(dataset, transform)

# Quick way to peek at a single batch (use the built-in next(); iterators have
# no .next() method in Python 3):
# im, labels = next(iter(train_loader))

In [None]:
# Show the classifier head before and after it is swapped out for fine-tuning.
print(base_net.fc)
# Replace the 'fc' head for a 10-way output (CIFAR-10 has 10 classes).
# finetune_base=False is forwarded to the project helper.
# NOTE(review): presumably this freezes the backbone so only the new head
# trains — confirm in sairg_utils.
base_net = define_finetune_model(base_net, 10, 'fc', finetune_base=False)
print(base_net.fc)

In [None]:
# Train-or-load the fine-tuned base network.
# If a checkpoint already exists it is loaded; otherwise the model is trained
# via the project's `train` helper and the resulting weights are saved.
model_ckpt_dir = os.path.expanduser('~/Developer/experiments/meta-transformer/model_ckpts')
# Create the checkpoint directory up front: without it, torch.save below fails
# only *after* the (expensive) training run has finished.
os.makedirs(model_ckpt_dir, exist_ok=True)
base_model_filename = 'resnet18_cifar10.pth'
base_model_path = os.path.join(model_ckpt_dir, base_model_filename)
if os.path.exists(base_model_path):
    # map_location remaps tensors that were saved from a different device
    # (e.g. a GPU checkpoint opened on another machine) onto `device`.
    base_net.load_state_dict(torch.load(base_model_path, map_location=device))
else:
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(base_net.parameters(), lr=1e-3)
    base_net.to(device)
    # Wall-clock timing of the training call.
    tik = time.time()
    loss = train(base_net, loss_fn, optimizer, device, train_loader)
    tok = time.time()
    print(loss)
    print(tok - tik)
    torch.save(base_net.state_dict(), base_model_path)

In [None]:
# Evaluate the fine-tuned base network on the test loader.
# The explicit move is required: when the checkpoint was loaded from disk in
# the previous cell, the model was never moved to `device` (only the training
# branch does that).
base_net.to(device)
# `test` is a project helper; its return value is printed as-is.
# NOTE(review): presumably a classification accuracy — confirm in sairg_utils.
acc = test(base_net, device, test_loader)
print(acc)

In [None]:
# Build the MetaTransformer on top of the fine-tuned base network.
# Grab a single batch of training images (labels are not needed here).
# Use the built-in next(): Python 3 iterators expose __next__, not .next(),
# and recent PyTorch DataLoader iterators no longer provide a .next() alias.
images, _ = next(iter(train_loader))
images = images.to(device)
num_layers = 2
# The example batch is passed to the constructor alongside the module names.
# NOTE(review): presumably used to probe the feature shapes of the listed
# modules — confirm in meta_transformer.
meta_net = MetaTransformer(base_net, module_list, images, num_transformer_layers=num_layers)
meta_net.to(device)
# Sanity check: run one forward pass and display the output shape.
meta_net(images).shape

In [None]:
# Train-or-load the MetaTransformer.
# Set do_reload = False to force retraining even when a checkpoint exists
# (the existing checkpoint is then overwritten).
do_reload = True
# NOTE(review): `base_model_filename` / `base_model_path` are reused here for
# the *meta* model checkpoint, shadowing the base-network values set earlier.
# The names are kept so the notebook namespace is unchanged; consider
# dedicated names (e.g. meta_model_path) in a follow-up.
base_model_filename = 'meta_transformer_resnet18_cifar10.pth'
base_model_path = os.path.join(model_ckpt_dir, base_model_filename)
if do_reload and os.path.exists(base_model_path):
    # map_location remaps tensors that were saved from a different device
    # (e.g. a GPU checkpoint opened on another machine) onto `device`.
    meta_net.load_state_dict(torch.load(base_model_path, map_location=device))
else:
    # Make sure the target directory exists before spending time on training,
    # so torch.save at the end cannot fail on a missing folder.
    os.makedirs(model_ckpt_dir, exist_ok=True)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(meta_net.parameters(), lr=1e-4)
    meta_net.to(device)
    # Wall-clock timing of the training call.
    tik = time.time()
    loss = train(meta_net, loss_fn, optimizer, device, train_loader)
    tok = time.time()
    print(loss)
    print(tok - tik)
    torch.save(meta_net.state_dict(), base_model_path)

In [None]:
# Evaluate the MetaTransformer on the test loader.
# meta_net was already moved to `device` when it was built; repeating the
# move keeps this cell runnable on its own.
meta_net.to(device)
# `test` is a project helper; its return value is printed as-is.
# NOTE(review): presumably a classification accuracy — confirm in sairg_utils.
acc = test(meta_net, device, test_loader)
print(acc)

In [None]:
# num_epochs = 10
# tik = time.time()
# launch(num_epochs=num_epochs, num_proc=2)
# tok = time.time()
# tok - tik

In [None]:
 # vit_net = SimpleViT(
 #     image_size = 32,
 #     patch_size = 8,
 #     num_classes = len(classes),
 #     dim = 1024,
 #     depth = 6,
 #     heads = 16,
 #     mlp_dim = 2048
 # )