In [1]:
! pip install -e gym-autotrain

Defaulting to user installation because normal site-packages is not writeable
Obtaining file:///home/jupyter-skenjeye%40broadinst-05974/AutoTrain/gym-autotrain
Installing collected packages: gym-autotrain
  Attempting uninstall: gym-autotrain
    Found existing installation: gym-autotrain 0.0.1
    Uninstalling gym-autotrain-0.0.1:
      Successfully uninstalled gym-autotrain-0.0.1
  Running setup.py develop for gym-autotrain
Successfully installed gym-autotrain


In [1]:
import torch
import torchvision
import torchvision.transforms as T
import torchvision.models as models

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.laplace import Laplace
import torch.utils.data as torchdata

import numpy as np
import matplotlib.pyplot as plt

import seaborn as sns
import pickle as pkl
from pathlib import Path
from functools import partial
import pandas as pd

from tqdm.notebook import tqdm

import gym
from gym import error, spaces, utils
from gym.utils import seeding

import gym_autotrain.envs.utils as utils

from gym_autotrain.envs.thresholdout import Thresholdout

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as  torchdata

import pandas as pd
import numpy as np

from pathlib import Path
from functools import partial



# Data

In [2]:
DATA_ROOT = Path('./data')
DATA_SPLIT = 0.6

ENV_PATH = Path('./autotrain-run')
ENV_PATH.mkdir(exist_ok=True)

DEVICE = torch.device("cuda:3")

In [3]:

CLASSES = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def splitds(train, test, no_signal=False, pct_cap=None):
    X = np.concatenate((train.data,test.data), axis=0)
    Y = train.targets + test.targets
    
    if pct_cap:
        cap = int(pct_cap*len(X))
        X, Y = X[:cap], Y[:cap]
        
    
    if no_signal:
        print('suffling labels')
        np.random.shuffle(Y)
    
    split_id = int(len(X) * DATA_SPLIT)
    train.data, train.targets = X[:split_id], Y[:split_id]
    test.data, test.targets = X[split_id:], Y[split_id:]

def get_dataset(tfms, no_signal=False, pct_cap=None):
    train = torchvision.datasets.CIFAR10(root=DATA_ROOT / 'cifar-10-data', train=True,
                                        download=True, transform=tfms)

    holdout = torchvision.datasets.CIFAR10(root=DATA_ROOT / 'cifar-10-data', train=False,
                                           download=True, transform=tfms)
    
    splitds(train, holdout, no_signal, pct_cap)
    
    print(f'length of trainset: [{len(train)}]; len of holdout: [{len(holdout)}]')
    
    return train, holdout


In [4]:
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])

TFMS = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor(), normalize])

train, holdout = get_dataset(TFMS, pct_cap=0.1)

Files already downloaded and verified
Files already downloaded and verified
length of trainset: [3600]; len of holdout: [2400]


In [5]:
def accuracy(data: torchdata.DataLoader, model: nn.Module): # phi
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in tqdm(data,total=len(data)):
            images, labels = batch[0].to(DEVICE), batch[1]
            outputs = model(images).cpu()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    return correct / total

# Model

In [6]:
backbone = models.resnet18(pretrained=False)
backbone.fc = nn.Linear(512, len(CLASSES))
backbone.to(DEVICE)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

# Env

In [44]:
def make_o(loss_vec:np.array, lr:float, phi_val:float):
      return np.concatenate((loss_vec, [lr, phi_val]), axis=0)


class AutoTrainEnvironment(gym.Env):
      metadata = {'render.modes': ['human']}

      def __init__(self):
            pass

      def __repr__(self):
            return f"""AutoTrainEnvironment with the following parameters:
                        lr_init={self.lr_init}, inter_reward={self._inter_reward}, H={self.H}, K={self.K}, T={self.T}"""
      
      def log(self, s):
            if self.v: print(f'[time_step:{self.time_step}] ',s)

      def init(self, backbone: nn.Module,  phi: callable, savedir:Path,
             trnds:torchdata.Dataset, valds:torchdata.Dataset, 
             T=3, H=5, S=2, lr_init=3e-4, inter_reward=0.05,
             num_workers=4, bs=16, v=False, device=None):
            """
            params:
                  - backbone: nn.Module, neural network architecture for training
                  - trainds, valds: Dataset, train and validation datasets
                  - phi: callable, function to be optimised  

                  - T: num batch updates that constitutes one time step for the environment
                  - H: length of rewind vector
                  - S: sampling interval, determines K which is the loss vector
            """
            # experiment paramter setup
            
            self.T, self.H, self.sampling_interval = T, H, S
            self.K = self.T // self.sampling_interval
            
            self._inter_reward = inter_reward
            self.v = v # is verbose
            self.device = device
            self.savedir = savedir
            # rewind actions * lr_scale actions [decrease  10%, keep, increase 10%] + reinit + stop 
            self.action_space_dim = self.H*3 + 2 
            # loss_vec of size K + lr + phi_val
            self.observation_space_dim = self.K + 2
            
            self.time_step = 0

            # model 
            self.ll = StateLinkedList(savedir=savedir, dim=self.observation_space_dim)
            self.backbone = backbone

            self.criterion = nn.CrossEntropyLoss()
            self.lr_init = self._curr_lr = lr_init

            self._init_backbone()

            # package data
            self.trnds, self.valds = trnds, valds

            self.trndl, self.fixdl, self.valdl =  utils.create_dls(self.trnds, self.valds, bs=bs, num_workers=num_workers)


            # Thresholdout & statistic of interest
            self.thresholdout = Thresholdout(self.trndl, self.valdl)
            self.phi = phi 
            self._phi_func = partial(self.phi, model=self.backbone) #  does partial work on this

            self._init_phi()

            # calculate the sampling interval

            self.logdf = pd.DataFrame(columns=['t', 'reward', 'is_stop', 'action_id'])

            self._add_observation(np.zeros(K), self._get_phi_val())

            self.log(f'environment initialised : {self.__repr__()}')
            
            return self.ll.get_observations(self.H)

      def  _init_phi(self):
            self.log('initialised phi value: started ...')
            self._prev_phi_val = 0
            self._cur_phi_val = self.thresholdout.verify(self._phi_func)
            self._last_phi_update = self.time_step
            self.log('initialised phi value: done')

      def _init_backbone(self):

            if self.device:
                  self.backbone.to(self.device) # check

            utils.init_params(self.backbone)

            self.opt = optim.Adam(self.backbone.parameters(), lr=self.lr_init)
            self.log(f'initialised backbone parameters & optimizer')


      def _get_phi_val(self) -> float:
            if self._last_phi_update != self.time_step:
                  self._prev_phi_val = self._cur_phi_val
                  self._cur_phi_val = self.thresholdout.verify(self._phi_func)
                  self._last_phi_update = self.time_step

            return self._cur_phi_val

      def _add_observation(self,  loss_vec: np.array,  phi_val: float):
            o_state = ObservationAndState(
                  param_dict=self.backbone.state_dict(),
                  o=make_o(loss_vec, self._curr_lr, phi_val)
            )
            self.ll.append(o_state)
            self.log(f'added observation')

      def visualise_data(self):
            """
            two plots: 
                  - class distribution of training data
                  - class distribution of validation data
            """
            pass 

      def step(self, action: int): 
            """
            step(self, action: int):
                @action: index of the max value of the action probability vector 
            
            """

            self.log(f'action [{action}] recieved')

            is_stop = action == self.action_space_dim
            is_reinit = action == self.action_space_dim - 1

            if is_stop:
                  final_reward = self._compute_final_reward()
                  self.log(f'recieved STOP signal, final reward is: [{final_reward}]')
                  return None, final_reward, True, {}


            # lr and rewind steps
            if action < 5:
                  self._scale_lr(0.9)
                  rewind_steps = action
                  self.log(f'decreased lr by 10% -> [lr:{self._curr_lr}]')
            elif action >= 5 and action < 10:
                  rewind_steps = action - 5 
            else:
                  self._scale_lr(1.1)
                  self.log(f'increased lr by 10% -> [lr:{self._curr_lr}]')
                  rewind_steps = action - 10
            

            if rewind_steps >= self.ll.len or is_reinit:
                  self.log(f'recieved RE-INIT signal or rewind_steps[{rewind_steps}] > len(ll)')

                  self._init_backbone()

                  self._init_phi()
                  self._add_observation(np.zeros(self.K), self._get_phi_val()) # do we add here or no

                  self.ll = StateLinkedList(savedir=self.savedir, dim=self.observation_space_dim)

            elif rewind_steps != 0:
                  self.log(f'rewind weights [{rewind_steps}] steps back')
                  self.ll.rewind(rewind_steps)
                  o = self.ll[self.ll.len-1]  # get the latest state after rewind
                  self.backbone.load_state_dict(o['param_dict'])

            # do training 
            loss_vec = self._train_one_cycle()
            # set current observation
            self._add_observation(loss_vec, self._get_phi_val()) # whats this for then

            # get last H observations
            o_history = self.ll.get_observations(self.H)

            # compute intermediate reward
            step_reward = self._compute_intermediate_reward()
            self.log(f'reward at the end of time step is [{step_reward}]')

            return o_history, step_reward, False, {}


      def _scale_lr(self, scale_factor):

            for g in self.opt.param_groups:
                  g['lr'] *= scale_factor

            self._curr_lr *= scale_factor


      def _train_one_cycle(self, loss_vec=None, steps=0):
            print('train called, loss_vec: ', loss_vec, ' steps ',steps)
            if loss_vec is None:
                  loss_vec = np.zeros(self.K)

            self.backbone.train()
            for i, batch in enumerate(self.trndl):
                  inputs, labels = batch[0], batch[1]

                  if self.device:
                      inputs, labels = inputs.to(self.device), labels.to(self.device)
            
                  self.opt.zero_grad()

                  outputs = self.backbone(inputs)
                  loss = self.criterion(outputs, labels)
                  loss.backward()
                  self.opt.step()

                  steps += 1
                    
                  if steps >= self.T:
                        self.time_step += 1 # this defines the time step
                        return loss_vec
                  
                  if steps % self.sampling_interval == 0:
                        loss_vec[(steps // self.sampling_interval) - 1] = loss.item()
                
            return self._train_one_cycle(loss_vec=loss_vec, steps=steps)

      def _compute_final_reward(self):
            return self._get_phi_val()

      def _compute_intermediate_reward(self):
            delta = self._get_phi_val() - self._prev_phi_val
            if  delta > 0:
                  return self._inter_reward
            else:
                  return -self._inter_reward


      def reset(self):

            return self.init(self.backbone, self.phi, self.savedir, self.trnds, self.valds,
            T=self.T, H=self.H, K=self.K, lr_init=self.lr_init, inter_reward=self._inter_reward,
            num_workers=self.num_workers, bs=self.bs, v=self.v, device=self.device)

      def render(self, mode='human', close=False):
            """render observation; for that need to add logs"""
            pass



class ObservationAndState:
      def __init__(self, param_dict: dict, o: np.array):
            self.param_dict = param_dict
            self.o = o

            self.dim = self.o.size

      def to_dict(self):
            return {
                  'param_dict': self.param_dict,
                  'o': self.o
            }

      def __repr__(self):
            return f"ObservationAndState Object --> o={self.o} param_dict={self.param_dict}"

      
            

class StateLinkedList:

      def __init__(self, savedir: Path, dim: int):

            if type(savedir) == str:
                  savedir = Path(savedir) 

            assert savedir.exists() and savedir.is_dir(), "please make sure save path exists and is directory"
            assert dim > 0

            self.savedir = savedir
            self.dim = dim # observation dimension 

            self.len = 0 # the id of the next node

      def get_observations(self, size):
            os = []
            if size > self.len:
                  for _ in range(size - self.len):
                        os.append(np.zeros(self.dim))
                  size = self.len

            os += [self[i].o for i in range(size)]
            return np.vstack(os)
            

      def append(self, state: ObservationAndState):

            new_node_path = self.node_path(self.len)
            torch.save(state, new_node_path)
            self.len += 1

      def node_path(self, idx) -> Path:
            return self.savedir / f'state_{idx}.ckpt'


      def __len__(self):
            return self.len

      def __getitem__(self, idx):

            if idx >= self.len:
                  raise ValueError('idx too large')

            return torch.load(self.node_path(idx))
            

      def rewind(self, steps: int):

            steps = min(steps, self.len)
            
            for i in range(1, steps+1):
                  print('rewind')
                  nodepath = self.node_path(self.len - i)
                  nodepath.unlink()

            self.len -= steps





# Running Experiments

In [45]:
H = 5
BS = 16
S = 2 # sampling interval
T = len(train) // BS * 3 # three epochs
K = T // S
K

337

In [46]:
env = AutoTrainEnvironment()



ob = env.init(backbone=backbone,  phi=accuracy, savedir=ENV_PATH,
         trnds=train, valds=holdout, 
         T=T, H=H, S=2, lr_init=3e-4, inter_reward=0.05,
         num_workers=4, bs=BS, v=True, device=DEVICE)

[time_step:0]  initialised backbone parameters & optimizer
[time_step:0]  initialised phi value: started ...


HBox(children=(FloatProgress(value=0.0, max=225.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=150.0), HTML(value='')))


[time_step:0]  initialised phi value: done
[time_step:0]  added observation
[time_step:0]  environment initialised : AutoTrainEnvironment with the following parameters:
                        lr_init=0.0003, inter_reward=0.05, H=5, K=337, T=675


In [41]:
def agent():
    return np.random.choice(range(env.action_space_dim))
    

In [42]:
agent()

8

In [47]:
env.step(0)

[time_step:0]  action [0] recieved
[time_step:0]  decreased lr by 10% -> [lr:0.00027]
train called, loss_vec:  None  steps  0
train called, loss_vec:  [2.6716814  1.89849043 1.46394598 1.46715999 1.51975882 1.01550984
 1.1220268  0.67024863 0.5065791  0.5564186  0.40669262 0.51554328
 0.35622099 0.44355494 0.48641109 0.35580549 0.43212035 0.26500952
 0.27012148 0.35450253 0.14052968 0.18816264 0.42199397 0.09489837
 0.19831866 0.21077883 0.11365804 0.07457419 0.11297897 0.07429682
 0.24430957 0.06534055 0.24004477 0.08660294 0.41303706 0.07469225
 0.09402582 0.22449659 0.12264282 0.06622043 0.13399404 0.06270418
 0.05542123 0.2226918  0.21362366 0.21102114 0.1198792  0.1962316
 0.11414631 0.25017709 0.24400733 0.09989141 0.12945461 0.12918201
 0.21318395 0.23703092 0.4249813  0.33737153 0.23304154 0.22955856
 0.16239452 0.23685353 0.35459965 0.08655542 0.25610399 0.28174251
 0.15637863 0.13442664 0.20358257 0.26980224 0.06227034 0.29633158
 0.30740952 0.10634193 0.25232834 0.10382569 0

HBox(children=(FloatProgress(value=0.0, max=225.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=150.0), HTML(value='')))


[2.6716814  1.89849043 1.46394598 1.46715999 1.51975882 1.01550984
 1.1220268  0.67024863 0.5065791  0.5564186  0.40669262 0.51554328
 0.35622099 0.44355494 0.48641109 0.35580549 0.43212035 0.26500952
 0.27012148 0.35450253 0.14052968 0.18816264 0.42199397 0.09489837
 0.19831866 0.21077883 0.11365804 0.07457419 0.11297897 0.07429682
 0.24430957 0.06534055 0.24004477 0.08660294 0.41303706 0.07469225
 0.09402582 0.22449659 0.12264282 0.06622043 0.13399404 0.06270418
 0.05542123 0.2226918  0.21362366 0.21102114 0.1198792  0.1962316
 0.11414631 0.25017709 0.24400733 0.09989141 0.12945461 0.12918201
 0.21318395 0.23703092 0.4249813  0.33737153 0.23304154 0.22955856
 0.16239452 0.23685353 0.35459965 0.08655542 0.25610399 0.28174251
 0.15637863 0.13442664 0.20358257 0.26980224 0.06227034 0.29633158
 0.30740952 0.10634193 0.25232834 0.10382569 0.07704481 0.2559174
 0.61242193 0.2016139  0.26684785 0.33359915 0.26172924 0.13693061
 0.35444772 0.20887946 0.20910668 0.02807483 0.16731872 0.41783

(array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 3.00000000e-04, 8.47065400e-02],
        [2.67168140e+00, 1.89849043e+00, 1.46394598e+00, ...,
         4.30972904e-01, 2.70000000e-04, 5.36689875e-01]]),
 0.05,
 False,
 {})