In [1]:
import copy
import os
import time
from pprint import pprint

import torch
import torchvision

from sairg_utils import launch, \
                        get_data_loaders, \
                        train, \
                        test, \
                        set_random_seeds, \
                        define_finetune_model, \
                        train_process, \
                        builtin_model_initializer, \
                        get_model
# Seed RNGs up front: this runs before the meta_transformer_vanilla import and
# before any model / data loader is constructed in the cells below.
# NOTE(review): which generators set_random_seeds() seeds (python/numpy/torch?)
# and with which seed is defined in sairg_utils — confirm there.
set_random_seeds()
from meta_transformer_vanilla import MetaTransformer, \
                                     metatransformer_model_initializer

# Target device for the single-process train/test helpers referenced by the
# (currently commented-out) evaluation cells. Constructing a torch.device only
# names the device; it does not check that CUDA is actually available.
device = torch.device(type="cuda")

In [2]:
# --- Base model: torchvision ResNet-18 fine-tuned on CIFAR-10 via SAIRG ---
# Several of these generic names (model_class, model_params, loss_fn_class,
# optimizer_*, checkpoint_*) are reassigned in the next cell, so downstream
# code should read the base configuration from `base_training_args`.
base_model_weights = torchvision.models.ResNet18_Weights
dataset = torchvision.datasets.CIFAR10
# Preprocessing bundled with the pretrained weights (its repr is visible in
# the training log further down: resize 256 -> crop 224, mean/std normalize).
transforms = base_model_weights.DEFAULT.transforms()

model_class = torchvision.models.resnet18
finetune_params = {'num_classes': 10,          # CIFAR-10 label count
                   'head_layer_name': 'fc'}    # classifier head to replace
model_params = {
    'pos_params': [],
    # NOTE(review): this is the Weights enum *class*, not a member such as
    # ResNet18_Weights.DEFAULT; presumably builtin_model_initializer resolves
    # it to a member — confirm, since resnet18(weights=...) expects a member.
    'key_params': {'weights': base_model_weights},
    'finetune_params': finetune_params,
}

loss_fn_class, optimizer_class = torch.nn.CrossEntropyLoss, torch.optim.Adam
optimizer_params = {'lr': 1e-3}
checkpoint_dir = os.path.join(os.getcwd(), "model_ckpts")
checkpoint_prefix = "resnet18_cifar10"

base_training_args = dict(
    dataset=dataset,
    transforms=transforms,
    num_epochs=10,
    batch_size=64,
    model_class=model_class,
    model_params=model_params,
    model_initializer=builtin_model_initializer,
    loss_fn_class=loss_fn_class,
    optimizer_class=optimizer_class,
    optimizer_params=optimizer_params,
    checkpoint_dir=checkpoint_dir,
    checkpoint_prefix=checkpoint_prefix,
)

# Optional sanity check of the base model's test accuracy (disabled):
# base_model = get_model(base_training_args, train=True)
# train_loader, test_loader = get_data_loaders(dataset, transforms)
# base_model.to(device)
# acc = test(base_model, device, test_loader)
# print(acc)

In [3]:
# Train MetaNet
# Notes from earlier runs:
#   725s, 86%, ZeRO + 2 layers
#   191s, 85%, 1 thr + 2 layers
#
# Reuse the base configuration's dataset and transforms (instead of re-deriving
# them here) so the meta model sees exactly the base model's input pipeline.
dataset = base_training_args['dataset']
transforms = base_training_args['transforms']

model_class = MetaTransformer

# The base model is instantiated only to enumerate the layers handed to the
# MetaTransformer: top-level modules whose name contains 'layer' (the residual
# stages of a torchvision ResNet) plus the 'fc' classifier head.
# NOTE(review): with train=True, get_model presumably trains and checkpoints the
# base model when no checkpoint exists yet — confirm in sairg_utils.
base_model = get_model(base_training_args, train=True)
train_loader, test_loader = get_data_loaders(dataset, transforms)
layer_names = [name for name, _ in base_model.named_modules()
               if 'layer' in name and '.' not in name] + ['fc']
del base_model  # only needed for its module names

# Built-in next() works on every PyTorch version; the Python-2 style
# `iter(loader).next()` alias was dropped from DataLoader iterators in newer
# releases and raises AttributeError there.
input_batch, _ = next(iter(train_loader))

# NOTE(review): not referenced anywhere else in this notebook.
base_model_checkpoint_path = os.path.join(base_training_args['checkpoint_dir'],
                                          base_training_args['checkpoint_prefix'])
model_params = {'layer_names': layer_names,
                # NOTE(review): despite the key name, only the batch's shape is
                # passed, not the batch itself.
                'input_batch': input_batch.shape,
                'base_training_args': base_training_args,
                'kwargs': {'num_transformer_layers': 2}}
loss_fn_class = torch.nn.CrossEntropyLoss
optimizer_class = torch.optim.Adam
optimizer_params = {'lr': 1e-4}
checkpoint_dir = os.path.join(os.getcwd(), "model_ckpts")
checkpoint_prefix = "meta_transformer_resnet18_cifar10"

meta_training_args = {'dataset': dataset,
                      'transforms': transforms,
                      'num_epochs': 1,
                      'batch_size': 64,
                      'model_class': model_class,
                      'model_params': model_params,
                      'model_initializer': metatransformer_model_initializer,
                      'loss_fn_class': loss_fn_class,
                      'optimizer_class': optimizer_class,
                      'optimizer_params': optimizer_params,
                      'checkpoint_dir': checkpoint_dir,
                      'checkpoint_prefix': checkpoint_prefix
                     }

# Swap in the commented call to request training through get_model.
# meta_model = get_model(meta_training_args, train=True)
meta_model = get_model(meta_training_args)

In [4]:
# Times a MetaTransformer training run, either multi-process (launch,
# num_proc=2) or single-process (train_process for rank 0).
# NOTE: both calls are currently commented out, so this cell times nothing and
# the saved output below is stale (it shows "Train process rank 0", presumably
# printed by train_process).
# Naming matches the other timing cells in this notebook: tik = start, tok = end.
tik = time.perf_counter()  # monotonic; unaffected by system clock adjustments
# launch(meta_training_args, num_proc=2)
# train_process(0, meta_training_args)
tok = time.perf_counter()
tok - tik  # elapsed seconds

Train process rank 0
base_training_args: {'dataset': <class 'torchvision.datasets.cifar.CIFAR10'>, 'transforms': ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
), 'num_epochs': 10, 'batch_size': 64, 'model_class': <function resnet18 at 0x7f8ee5133dd0>, 'model_params': {'pos_params': [], 'key_params': {'weights': <enum 'ResNet18_Weights'>}, 'finetune_params': {'num_classes': 10, 'head_layer_name': 'fc'}}, 'model_initializer': <function builtin_model_initializer at 0x7f8ee4f81950>, 'loss_fn_class': <class 'torch.nn.modules.loss.CrossEntropyLoss'>, 'optimizer_class': <class 'torch.optim.adam.Adam'>, 'optimizer_params': {'lr': 0.001}, 'checkpoint_dir': '/home/francis/Developer/experiments/meta-transformer/model_ckpts', 'checkpoint_prefix': 'resnet18_cifar10'}
[1,   156] loss: 0.831
[1,   312] loss: 0.637
[1,   468] loss: 0.581
[1,   624] loss: 0.578
[1,   780] loss: 0.5

184.67360830307007

In [6]:
# base_net_weights = torchvision.models.ResNet18_Weights.DEFAULT
# base_net = torchvision.models.resnet18(weights=base_net_weights)
# module_list = [k for k, v in base_net.named_modules() if 'layer' in k and '.' not in k]
# # module_list = [k for k, v in base_net.named_modules() if 'bn' in k]
# module_list += ['fc']
# print(len(module_list))
# module_list

In [7]:
# type(base_net_weights)
# base_net_weights == torchvision.models.ResNet18_Weights.DEFAULT
# type(base_net)

In [8]:
# base_net_weights = torchvision.models.ViT_B_32_Weights.DEFAULT
# base_net = torchvision.models.vit_b_32(weights=base_net_weights)
# # [k for k, v in base_net.named_modules() if 'encoder_layer' in k and len(k.split('.')) < 4]
# # [k for k, v in base_net.named_modules()]
# dict(base_net.named_modules())['heads']

In [9]:
# dataset = torchvision.datasets.CIFAR10
# transform = base_net_weights.transforms()
# train_loader, test_loader = get_data_loaders(dataset, transform)

In [10]:
# print(getattr(base_net, 'fc'))
# base_net = define_finetune_model(base_net, 10, 'fc', finetune_base=False)
# print(getattr(base_net, 'fc'))

In [11]:
# model_ckpt_dir = os.path.expanduser('~/Developer/experiments/meta-transformer/model_ckpts')
# base_model_filename = 'resnet18_cifar10.pth'
# base_model_path = os.path.join(model_ckpt_dir, base_model_filename)
# if os.path.exists(base_model_path):
#     base_net.load_state_dict(torch.load(base_model_path))
# else:
#     loss_fn = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(base_net.parameters(), lr=1e-3)
#     base_net.to(device)
#     tik = time.time()
#     loss = train(base_net, loss_fn, optimizer, device, train_loader)
#     tok = time.time()
#     print(loss)
#     print(tok - tik)
#     torch.save(base_net.state_dict(), base_model_path)

In [12]:
# base_net.to(device)
# acc = test(base_net, device, test_loader)
# print(acc)

In [13]:
# images, _ = iter(train_loader).next()
# images = images.to(device)
# num_layers = 2
# meta_net = MetaTransformer(base_net, module_list, images, num_transformer_layers=num_layers)
# meta_net.to(device)
# meta_net(images).shape

In [14]:
# do_reload = True
# base_model_filename = 'meta_transformer_resnet18_cifar10.pth'
# base_model_path = os.path.join(model_ckpt_dir, base_model_filename)
# if do_reload and os.path.exists(base_model_path):
#     meta_net.load_state_dict(torch.load(base_model_path))
# else:
#     loss_fn = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(meta_net.parameters(), lr=1e-4)
#     meta_net.to(device)
#     tik = time.time()
#     loss = train(meta_net, loss_fn, optimizer, device, train_loader)
#     tok = time.time()
#     print(loss)
#     print(tok - tik)
#     torch.save(meta_net.state_dict(), base_model_path)

In [15]:
# meta_net.to(device)
# acc = test(meta_net, device, test_loader)
# print(acc)

In [16]:
# meta_net.eval()
# meta_net(images)

# acts = meta_net._activations['fc']
# vals, inds = torch.max(acts, 1)
# vals = torch.unsqueeze(vals, 1)
# inds = torch.unsqueeze(inds, 1).repeat((1, 10))
# src = torch.tensor(range(100)).reshape((10, 10))
# src = src.unsqueeze(0).repeat((64, 1, 1))
# torch.scatter(torch.zeros_like(acts), 1, inds, vals)

# print(acts.shape)
# res = acts[inds == 6].reshape((-1, 10))
# print(res.shape)
# # print(acts.unsqueeze(-1)[inds == 6].shape)
# src = torch.tensor(range(10), dtype=acts.dtype, device=acts.device).repeat((6))
# print(src.shape)
# acts[inds == 6] = src
# print(acts.shape)
# acts
# vals.squeeze().shape
# print(f"{acts.shape}, {inds.shape}, {vals.shape}")

# a = torch.reshape(torch.tensor(range(40)), (4, 10))
# inds = torch.tensor([[2, 3], [6, 7], [0, 1], [4, 5]])
# inds == 2
# b = torch.gather(a, 1, inds)
# b[inds == 2]

# inds = inds.repeat((1, output.shape[-1]))
# output = torch.reshape(output[inds == 19], (-1, output.shape[-1]))
# avgs = torch.mean(output, 0, keepdim=True)
# avgs.shape
# output.shape

# meta_net._activations['fc'].shape

In [17]:
# meta_net.templates

In [18]:
# num_epochs = 10
# tik = time.time()
# launch(num_epochs=num_epochs, num_proc=2)
# tok = time.time()
# tok - tik

In [19]:
 # vit_net = SimpleViT(
 #     image_size = 32,
 #     patch_size = 8,
 #     num_classes = len(classes),
 #     dim = 1024,
 #     depth = 6,
 #     heads = 16,
 #     mlp_dim = 2048
 # )