In [1]:
import copy
import os
import time
from pprint import pprint

import torch
import torchvision

from sairg_utils import launch, \
                        get_data_loaders, \
                        train, \
                        test, \
                        set_random_seeds, \
                        define_finetune_model, \
                        train_process, \
                        builtin_model_initializer, \
                        get_model
# Called with no arguments, before the meta_transformer_vanilla import below.
# NOTE(review): presumably fixes the RNG seeds for reproducibility — confirm in
# sairg_utils, and confirm whether it must run before that import.
set_random_seeds()
from meta_transformer_vanilla import MetaTransformer

# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines
# without CUDA (a hard-coded "cuda" device only fails later, on first use).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Root directory for experiment artifacts; the checkpoint directory is derived from it.
EXPERIMENT_ROOT = os.path.expanduser("~/Developer/experiments")

In [2]:
# ---- Experiment configuration: torchvision resnet18 fine-tuned on CIFAR10 ----

# NOTE(review): per the `copy` module docs, deepcopy returns classes unchanged,
# so this is the ResNet18_Weights enum class itself (not a member such as
# .DEFAULT). That class is what gets passed as `weights` below — confirm this is
# what builtin_model_initializer expects.
base_model_weights = copy.deepcopy(torchvision.models.ResNet18_Weights)
# Preprocessing pipeline bundled with the DEFAULT weights entry.
transforms = base_model_weights.DEFAULT.transforms()
dataset = torchvision.datasets.CIFAR10

# Model constructor plus the positional / keyword arguments it is built with.
model_class = torchvision.models.resnet18
finetune_params = dict(num_classes=10, head_layer_name='fc')
model_params = dict(
    pos_params=[],
    key_params=dict(weights=base_model_weights),
    finetune_params=finetune_params,
)

# Loss / optimizer are passed as classes together with their constructor kwargs.
loss_fn_class = torch.nn.CrossEntropyLoss
optimizer_class = torch.optim.Adam
optimizer_params = dict(lr=1e-3)

# Checkpoint location and filename prefix.
checkpoint_prefix = "resnet18_cifar10"
checkpoint_dir = os.path.join(EXPERIMENT_ROOT, "model_ckpts")

# Single bundle of settings handed to get_model() (and to the disabled launch() call).
training_args = dict(
    dataset=dataset,
    transforms=transforms,
    num_epochs=1,
    batch_size=64,
    model_class=model_class,
    model_params=model_params,
    model_initializer=builtin_model_initializer,
    loss_fn_class=loss_fn_class,
    optimizer_class=optimizer_class,
    optimizer_params=optimizer_params,
    checkpoint_dir=checkpoint_dir,
    checkpoint_prefix=checkpoint_prefix,
)

In [3]:
# Build the model described by training_args via sairg_utils.get_model.
# NOTE(review): the saved output under this cell starts with "(ResNet(", so the
# displayed value may have been a tuple rather than a bare module — confirm what
# get_model returns. The output is from a run where `model` was not commented out.
model = get_model(training_args)
# model

(ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (layer1): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
     (1): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU

In [4]:
# Disabled: timed 2-process training run via launch(). The output kept under this
# cell is from an earlier execution; the trailing number is presumably the elapsed
# seconds from the `tik - tok` expression.
# NOTE(review): tok/tik are used in the opposite order from the other timing
# cells in this notebook (there tik is the start and tok is the end).
# tok = time.time()
# launch(training_args, num_proc=2)
# tik = time.time()
# tik - tok

Train process rank 0
Running DDP on rank 0
Train process rank 1
Running DDP on rank 1
[1,    78] loss: 1.527
[1,    78] loss: 1.524
[1,   156] loss: 0.978
[1,   156] loss: 1.006
[1,   234] loss: 0.870
[1,   234] loss: 0.853
[1,   312] loss: 0.801
[1,   312] loss: 0.816
[1,   390] loss: 0.790
[1,   390] loss: 0.770
Accuracy after epoch 1: 75.37 %


79.73158383369446

In [5]:
# base_net_weights = torchvision.models.ResNet18_Weights.DEFAULT
# base_net = torchvision.models.resnet18(weights=base_net_weights)
# module_list = [k for k, v in base_net.named_modules() if 'layer' in k and '.' not in k]
# # module_list = [k for k, v in base_net.named_modules() if 'bn' in k]
# module_list += ['fc']
# print(len(module_list))
# module_list

In [6]:
# type(base_net_weights)
# base_net_weights == torchvision.models.ResNet18_Weights.DEFAULT
# type(base_net)

In [7]:
# base_net_weights = torchvision.models.ViT_B_32_Weights.DEFAULT
# base_net = torchvision.models.vit_b_32(weights=base_net_weights)
# # [k for k, v in base_net.named_modules() if 'encoder_layer' in k and len(k.split('.')) < 4]
# # [k for k, v in base_net.named_modules()]
# dict(base_net.named_modules())['heads']

In [8]:
# dataset = torchvision.datasets.CIFAR10
# transform = base_net_weights.transforms()
# train_loader, test_loader = get_data_loaders(dataset, transform)

In [9]:
# print(getattr(base_net, 'fc'))
# base_net = define_finetune_model(base_net, 10, 'fc', finetune_base=False)
# print(getattr(base_net, 'fc'))

In [10]:
# model_ckpt_dir = os.path.expanduser('~/Developer/experiments/meta-transformer/model_ckpts')
# base_model_filename = 'resnet18_cifar10.pth'
# base_model_path = os.path.join(model_ckpt_dir, base_model_filename)
# if os.path.exists(base_model_path):
#     base_net.load_state_dict(torch.load(base_model_path))
# else:
#     loss_fn = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(base_net.parameters(), lr=1e-3)
#     base_net.to(device)
#     tik = time.time()
#     loss = train(base_net, loss_fn, optimizer, device, train_loader)
#     tok = time.time()
#     print(loss)
#     print(tok - tik)
#     torch.save(base_net.state_dict(), base_model_path)

In [11]:
# base_net.to(device)
# acc = test(base_net, device, test_loader)
# print(acc)

In [12]:
# Disabled: wraps base_net in a MetaTransformer using one batch of images.
# NOTE(review): base_net, module_list and train_loader are only defined in other
# disabled cells, so those must be re-enabled first.
# NOTE(review): `iter(train_loader).next()` depends on the iterator exposing a
# `.next()` method; the builtin `next(iter(train_loader))` is the portable form —
# confirm against the installed torch version before re-enabling.
# images, _ = iter(train_loader).next()
# images = images.to(device)
# num_layers = 2
# meta_net = MetaTransformer(base_net, module_list, images, num_transformer_layers=num_layers)
# meta_net.to(device)
# meta_net(images).shape

In [13]:
# do_reload = True
# base_model_filename = 'meta_transformer_resnet18_cifar10.pth'
# base_model_path = os.path.join(model_ckpt_dir, base_model_filename)
# if do_reload and os.path.exists(base_model_path):
#     meta_net.load_state_dict(torch.load(base_model_path))
# else:
#     loss_fn = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(meta_net.parameters(), lr=1e-4)
#     meta_net.to(device)
#     tik = time.time()
#     loss = train(meta_net, loss_fn, optimizer, device, train_loader)
#     tok = time.time()
#     print(loss)
#     print(tok - tik)
#     torch.save(meta_net.state_dict(), base_model_path)

In [14]:
# meta_net.to(device)
# acc = test(meta_net, device, test_loader)
# print(acc)

In [15]:
# meta_net.eval()
# meta_net(images)

# acts = meta_net._activations['fc']
# vals, inds = torch.max(acts, 1)
# vals = torch.unsqueeze(vals, 1)
# inds = torch.unsqueeze(inds, 1).repeat((1, 10))
# src = torch.tensor(range(100)).reshape((10, 10))
# src = src.unsqueeze(0).repeat((64, 1, 1))
# torch.scatter(torch.zeros_like(acts), 1, inds, vals)

# print(acts.shape)
# res = acts[inds == 6].reshape((-1, 10))
# print(res.shape)
# # print(acts.unsqueeze(-1)[inds == 6].shape)
# src = torch.tensor(range(10), dtype=acts.dtype, device=acts.device).repeat((6))
# print(src.shape)
# acts[inds == 6] = src
# print(acts.shape)
# acts
# vals.squeeze().shape
# print(f"{acts.shape}, {inds.shape}, {vals.shape}")

# a = torch.reshape(torch.tensor(range(40)), (4, 10))
# inds = torch.tensor([[2, 3], [6, 7], [0, 1], [4, 5]])
# inds == 2
# b = torch.gather(a, 1, inds)
# b[inds == 2]

# inds = inds.repeat((1, output.shape[-1]))
# output = torch.reshape(output[inds == 19], (-1, output.shape[-1]))
# avgs = torch.mean(output, 0, keepdim=True)
# avgs.shape
# output.shape

# meta_net._activations['fc'].shape

In [16]:
# meta_net.templates

In [17]:
# Disabled: timed multi-epoch training run.
# NOTE(review): this calls launch(num_epochs=..., num_proc=2), whereas the earlier
# disabled cell calls launch(training_args, num_proc=2) — confirm which signature
# sairg_utils.launch currently accepts before re-enabling.
# num_epochs = 10
# tik = time.time()
# launch(num_epochs=num_epochs, num_proc=2)
# tok = time.time()
# tok - tik

In [18]:
 # Disabled: SimpleViT configuration sketch.
 # NOTE(review): neither `SimpleViT` nor `classes` is imported or defined in the
 # visible cells — both must be provided before this can be re-enabled.
 # vit_net = SimpleViT(
 #     image_size = 32,
 #     patch_size = 8,
 #     num_classes = len(classes),
 #     dim = 1024,
 #     depth = 6,
 #     heads = 16,
 #     mlp_dim = 2048
 # )