In [None]:
import copy
import os
import time
from pprint import pprint

import torch
import torchvision

import sairg_utils

# Seed the random number generators before anything stochastic runs (data
# shuffling, weight init). Which libraries get seeded, and with what value, is
# decided inside sairg_utils — presumably python/numpy/torch; confirm there.
sairg_utils.set_random_seeds()

# Compute device for models/tensors. Prefer the GPU, but fall back to CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing at the
# first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Base model: torchvision ResNet-18 with pretrained weights, fine-tuned for the
# 10 CIFAR-10 classes by replacing its head layer, all driven through the SAIRG
# training helpers. Components are given as dotted import paths (strings) that
# sairg_utils resolves later.

# Data pipeline and pretrained-weight specifications.
dataset = 'torchvision.datasets.CIFAR10'
base_model_weights = 'torchvision.models.ResNet18_Weights'
transforms = 'torchvision.models.ResNet18_Weights.DEFAULT.transforms'

# Model construction: constructor path, its positional/keyword arguments, and
# the fine-tuning options (new class count, name of the layer to replace).
model_class = 'torchvision.models.resnet18'
model_initializer = 'sairg_utils.builtin_model_initializer'
finetune_params = dict(num_classes=10, head_layer_name='fc')
model_params = dict(
    pos_params=[],
    key_params=dict(weights=base_model_weights),
    finetune_params=finetune_params,
)

# Optimisation setup.
loss_fn_class = 'torch.nn.CrossEntropyLoss'
optimizer_class = 'torch.optim.Adam'
optimizer_params = dict(lr=1e-3)

# Where checkpoints for this model are written / looked up.
checkpoint_dir = os.path.join(os.getcwd(), "model_ckpts")
checkpoint_prefix = "resnet18_cifar10"

base_training_args = dict(
    dataset=dataset,
    transforms=transforms,
    num_epochs=10,
    batch_size=64,
    model_class=model_class,
    model_params=model_params,
    model_initializer=model_initializer,
    loss_fn_class=loss_fn_class,
    optimizer_class=optimizer_class,
    optimizer_params=optimizer_params,
    checkpoint_dir=checkpoint_dir,
    checkpoint_prefix=checkpoint_prefix,
)

# Training / evaluating the base model is disabled; uncomment to run it.
# NOTE(review): `get_data_loaders` and `test` are not defined in this notebook's
# namespace — presumably sairg_utils.get_data_loaders and a sairg_utils eval
# helper are meant; verify before uncommenting.
# sairg_utils.train_process(0, base_training_args)
# base_model = sairg_utils.get_model(base_training_args, train=True)

# train_loader, test_loader = get_data_loaders(dataset, transforms)

# base_model.to(device)
# acc = test(base_model, device, test_loader)
# print(acc)

In [None]:
# Train MetaNet: a MetaTransformer built on top of the base ResNet-18 configured
# above. It is given the names of selected base-model layers (presumably it
# consumes their activations — see meta_transformer_vanilla to confirm).
# Past runs (wall time, accuracy, setup):
#   725s, 86%, ZeRO + 2 layers
#   191s, 85%, 1 thr + 2 layers

# Reuse the base model's data pipeline instead of re-typing the paths, so the
# two configurations cannot silently drift apart.
dataset = base_training_args['dataset']
transforms = base_training_args['transforms']

model_class = 'meta_transformer_vanilla.MetaTransformer'

# Instantiate the base model only to discover which layers to tap.
base_model = sairg_utils.get_model(base_training_args)
dataset_class = sairg_utils.get_class_by_path(dataset)
# NOTE: despite its name, this holds the transform object produced by calling
# the resolved `...DEFAULT.transforms` factory, not a class.
transform_class = sairg_utils.get_class_by_path(transforms)()
train_loader, test_loader = sairg_utils.get_data_loaders(dataset_class, transform_class)
# Top-level modules whose name contains 'layer' (nested submodules contain '.'),
# plus the head layer that the fine-tuning config replaced.
head_layer_name = base_training_args['model_params']['finetune_params']['head_layer_name']
layer_names = [name for name, _ in base_model.named_modules()
               if 'layer' in name and '.' not in name] + [head_layer_name]
del base_model  # only the layer names were needed

# Builtin next(): DataLoader iterators in recent PyTorch no longer provide the
# Python-2 style `.next()` method.
input_batch, _ = next(iter(train_loader))
# Not used by the live code in this cell; kept for interactive use.
base_model_checkpoint_path = os.path.join(base_training_args['checkpoint_dir'],
                                          base_training_args['checkpoint_prefix'])
model_params = {'layer_names': layer_names,
                # NOTE(review): key says 'input_batch' but only the batch *shape*
                # is passed — presumably what the initializer expects; confirm.
                'input_batch': input_batch.shape,
                'base_training_args': base_training_args,
                'kwargs': {'num_transformer_layers': 2}}
model_initializer = 'meta_transformer_vanilla.metatransformer_model_initializer'
loss_fn_class = 'torch.nn.CrossEntropyLoss'
optimizer_class = 'torch.optim.Adam'
optimizer_params = {'lr': 1e-4}
checkpoint_dir = os.path.join(os.getcwd(), "model_ckpts")
checkpoint_prefix = "meta_transformer_resnet18_cifar10"

meta_training_args = {'dataset': dataset,
                      'transforms': transforms,
                      'num_epochs': 1,
                      'batch_size': 64,
                      'model_class': model_class,
                      'model_params': model_params,
                      'model_initializer': model_initializer,
                      'loss_fn_class': loss_fn_class,
                      'optimizer_class': optimizer_class,
                      'optimizer_params': optimizer_params,
                      'checkpoint_dir': checkpoint_dir,
                      'checkpoint_prefix': checkpoint_prefix
                     }
# To force (re)training, use the variant below instead; without train=True,
# get_model presumably restores an existing checkpoint — verify in sairg_utils.
# meta_model = sairg_utils.get_model(meta_training_args, train=True, force=True)
meta_model = sairg_utils.get_model(meta_training_args)

In [None]:
# Time the (currently disabled) MetaTransformer training launch. With both
# launch lines commented out, the elapsed time is ~0 s and measures nothing.
# NOTE(review): `launch` / `train_process` are not defined in this notebook's
# namespace — presumably sairg_utils.launch / sairg_utils.train_process; verify.
# perf_counter is monotonic and high-resolution, unlike time.time(), which can
# jump if the system clock is adjusted mid-run.
tik = time.perf_counter()  # start (tik/tok naming matches the other timing cells)
# launch(meta_training_args, num_proc=2)
# train_process(0, meta_training_args)
tok = time.perf_counter()  # end
tok - tik  # elapsed seconds, displayed as the cell output

In [None]:
# base_net_weights = torchvision.models.ResNet18_Weights.DEFAULT
# base_net = torchvision.models.resnet18(weights=base_net_weights)
# module_list = [k for k, v in base_net.named_modules() if 'layer' in k and '.' not in k]
# # module_list = [k for k, v in base_net.named_modules() if 'bn' in k]
# module_list += ['fc']
# print(len(module_list))
# module_list

In [None]:
# type(base_net_weights)
# base_net_weights == torchvision.models.ResNet18_Weights.DEFAULT
# type(base_net)

In [None]:
# base_net_weights = torchvision.models.ViT_B_32_Weights.DEFAULT
# base_net = torchvision.models.vit_b_32(weights=base_net_weights)
# # [k for k, v in base_net.named_modules() if 'encoder_layer' in k and len(k.split('.')) < 4]
# # [k for k, v in base_net.named_modules()]
# dict(base_net.named_modules())['heads']

In [None]:
# dataset = torchvision.datasets.CIFAR10
# transform = base_net_weights.transforms()
# train_loader, test_loader = get_data_loaders(dataset, transform)

In [None]:
# print(getattr(base_net, 'fc'))
# base_net = define_finetune_model(base_net, 10, 'fc', finetune_base=False)
# print(getattr(base_net, 'fc'))

In [None]:
# model_ckpt_dir = os.path.expanduser('~/Developer/experiments/meta-transformer/model_ckpts')
# base_model_filename = 'resnet18_cifar10.pth'
# base_model_path = os.path.join(model_ckpt_dir, base_model_filename)
# if os.path.exists(base_model_path):
#     base_net.load_state_dict(torch.load(base_model_path))
# else:
#     loss_fn = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(base_net.parameters(), lr=1e-3)
#     base_net.to(device)
#     tik = time.time()
#     loss = train(base_net, loss_fn, optimizer, device, train_loader)
#     tok = time.time()
#     print(loss)
#     print(tok - tik)
#     torch.save(base_net.state_dict(), base_model_path)

In [None]:
# base_net.to(device)
# acc = test(base_net, device, test_loader)
# print(acc)

In [None]:
# images, _ = iter(train_loader).next()
# images = images.to(device)
# num_layers = 2
# meta_net = MetaTransformer(base_net, module_list, images, num_transformer_layers=num_layers)
# meta_net.to(device)
# meta_net(images).shape

In [None]:
# do_reload = True
# base_model_filename = 'meta_transformer_resnet18_cifar10.pth'
# base_model_path = os.path.join(model_ckpt_dir, base_model_filename)
# if do_reload and os.path.exists(base_model_path):
#     meta_net.load_state_dict(torch.load(base_model_path))
# else:
#     loss_fn = torch.nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(meta_net.parameters(), lr=1e-4)
#     meta_net.to(device)
#     tik = time.time()
#     loss = train(meta_net, loss_fn, optimizer, device, train_loader)
#     tok = time.time()
#     print(loss)
#     print(tok - tik)
#     torch.save(meta_net.state_dict(), base_model_path)

In [None]:
# meta_net.to(device)
# acc = test(meta_net, device, test_loader)
# print(acc)

In [None]:
# meta_net.eval()
# meta_net(images)

# acts = meta_net._activations['fc']
# vals, inds = torch.max(acts, 1)
# vals = torch.unsqueeze(vals, 1)
# inds = torch.unsqueeze(inds, 1).repeat((1, 10))
# src = torch.tensor(range(100)).reshape((10, 10))
# src = src.unsqueeze(0).repeat((64, 1, 1))
# torch.scatter(torch.zeros_like(acts), 1, inds, vals)

# print(acts.shape)
# res = acts[inds == 6].reshape((-1, 10))
# print(res.shape)
# # print(acts.unsqueeze(-1)[inds == 6].shape)
# src = torch.tensor(range(10), dtype=acts.dtype, device=acts.device).repeat((6))
# print(src.shape)
# acts[inds == 6] = src
# print(acts.shape)
# acts
# vals.squeeze().shape
# print(f"{acts.shape}, {inds.shape}, {vals.shape}")

# a = torch.reshape(torch.tensor(range(40)), (4, 10))
# inds = torch.tensor([[2, 3], [6, 7], [0, 1], [4, 5]])
# inds == 2
# b = torch.gather(a, 1, inds)
# b[inds == 2]

# inds = inds.repeat((1, output.shape[-1]))
# output = torch.reshape(output[inds == 19], (-1, output.shape[-1]))
# avgs = torch.mean(output, 0, keepdim=True)
# avgs.shape
# output.shape

# meta_net._activations['fc'].shape

In [None]:
# meta_net.templates

In [None]:
# num_epochs = 10
# tik = time.time()
# launch(num_epochs=num_epochs, num_proc=2)
# tok = time.time()
# tok - tik

In [None]:
 # vit_net = SimpleViT(
 #     image_size = 32,
 #     patch_size = 8,
 #     num_classes = len(classes),
 #     dim = 1024,
 #     depth = 6,
 #     heads = 16,
 #     mlp_dim = 2048
 # )