In [None]:
import sys

# Environment detection: the presence of `google.colab` in sys.modules is used
# as the marker for "running inside Colab"; anything else is treated as local.
IN_COLAB = 'google.colab' in sys.modules
IN_LOCAL = not IN_COLAB

# Where the Colab setup cell fetches the project code from (choose at most one).
USE_GITHUB = True
USE_DRIVE = False

# Device used for checkpoint loading and training.
device = 'cuda'

# Cloning from GitHub and mounting Drive are mutually exclusive sources.
assert not USE_GITHUB or not USE_DRIVE

In [None]:
if IN_COLAB:
  !pip install wandb -qU
  from google.colab import runtime
  if USE_GITHUB:
    !git clone https://github.com/kejeon/in_dev_RN20Q.git
    %cd '/content/in_dev_RN20Q'
  elif USE_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive')
    %cd '/content/drive/MyDrive/GitHub/in_dev_RN20Q'

import wandb
import torch
from model.resnet4c10q import ResNet20_Q
from model.resnet4c10 import resnet20
from mylib.WCResNetTrainer import ResNetTrainer

wandb.login(key='e0c11d3ff2bee4c8775ba05863038fdac671c043')

In [None]:
def _strip_module_prefix(state):
    """Return a copy of ``state`` with one leading ``'module.'`` removed from each key.

    Presumably the checkpoint was saved from a model wrapped in
    ``torch.nn.DataParallel``/``DistributedDataParallel``, which prefixes every
    parameter name with ``'module.'``. Unlike ``str.replace``, only the leading
    prefix is removed, so an inner ``'module.'`` substring in a key is preserved.
    """
    prefix = 'module.'
    return {(key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state.items()}


def train(seed=0):
    """Run one W&B sweep trial: fine-tune ResNet20_Q from a pretrained checkpoint.

    Intended as the ``function`` passed to ``wandb.agent``. Per-trial
    hyperparameters (lr, lambda_l1, lambda_kl, a_bit, w_bit, train_epoch) are
    read from the run's config. The starting weights come from the pinned W&B
    artifact ``jke1994/ResNet20_WC_L1/model:v129`` (file ``ckpt.pth``, key ``'net'``).

    Args:
        seed: torch RNG seed applied at the start of every trial. A single
            notebook-level ``torch.manual_seed`` only fixes the first trial's
            RNG state; reseeding here makes every grid point start from the same
            state so trials are comparable. ``wandb.agent`` calls ``train()``
            with no arguments, so the default is used.
    """
    arch_tag = "ResNet20_Q"
    dataset = "CIFAR10"
    batch_size = 128

    # The context manager finishes the run when the trial ends (or fails).
    with wandb.init() as run:
        torch.manual_seed(seed)
        config = run.config

        artifact = wandb.Api().artifact('jke1994/ResNet20_WC_L1/model:v129')
        artifact.download(root='./pretrained_model')
        checkpoint = torch.load('./pretrained_model/ckpt.pth',
                                map_location=torch.device(device))

        model = ResNet20_Q(a_bit=config.a_bit, w_bit=config.w_bit)
        try:
            model.load_state_dict(checkpoint['net'])
        except RuntimeError:
            # load_state_dict raises RuntimeError on missing/unexpected keys,
            # e.g. when the saved names carry a 'module.' wrapper prefix.
            model.load_state_dict(_strip_module_prefix(checkpoint['net']))

        trainer = ResNetTrainer(dataset=dataset,
                                arch_tag=arch_tag,
                                lambda_l1=config.lambda_l1,
                                lambda_kl=config.lambda_kl,
                                model=model,
                                device=device,
                                batch_size=batch_size,
                                lr=config.lr)
        trainer.train_script(config.train_epoch)

In [None]:
# Fix the global torch RNG before the sweep is created.
torch.manual_seed(0)

# Grid sweep: every combination of the listed values becomes one trial, and
# trials are ranked by maximizing `test_acc` (expected to be logged by the
# trainer — confirm). Only lambda_l1 and lambda_kl vary (5 x 5 = 25 points);
# learning rate, bit-widths and epoch count are pinned.
sweep_configuration = dict(
    method='grid',
    name='sweep_1',
    metric=dict(goal='maximize', name='test_acc'),
    parameters=dict(
        lr={'values': [0.001]},
        lambda_l1={'values': [2e-6, 4e-6, 6e-6, 8e-6, 1e-5]},
        lambda_kl={'values': [100, 150, 200, 250, 300]},
        a_bit={'values': [4]},
        w_bit={'values': [4]},
        train_epoch={'values': [100]},
    ),
)

sweep_id = wandb.sweep(sweep=sweep_configuration, project='ResNet20_WC_L1_Sweep')

In [None]:
# Run every point of the grid. The trial count is the product of the number of
# candidate values per swept parameter (currently 5 x 5 = 25), derived here so
# editing the grid cannot silently leave trials unrun.
n_trials = 1
for spec in sweep_configuration['parameters'].values():
    n_trials *= len(spec['values'])

wandb.agent(sweep_id, function=train, count=n_trials)

wandb.finish()

# `runtime` is only imported by the setup cell when IN_COLAB, so the call must
# be guarded; unassign() disconnects the Colab runtime once the sweep is done.
if IN_COLAB:
    runtime.unassign()