In [1]:
import sys
import torch
import wandb
from pathlib import Path
sys.path.append('../..')

from solutions.convolution.conv1_model import Conv1Model
from cgol.generator.uniform_density_generator import UniformDensityGenerator
from cgol.simulator.torch_simulator import TorchSimulator
from cgol.simulator.minimal_architecture_simulator import MinimalArchitectureSimulator
from cgol.simulator.derivable_minimal_architecture_simulator import DerivableMinimalArchitectureSimulator
from cgol.dataloader.dataloader_2 import Dataloader2
from cgol.loss.completion_loss import CompletionLoss



In [2]:
seed = 0
width = 20
height = 20
batch_size = 500
dtype = torch.float
preprocess_device = 'cpu'
model_device = 'cuda'

dataloader_simulator = MinimalArchitectureSimulator(device=preprocess_device, dtype=dtype)
train_simulator = DerivableMinimalArchitectureSimulator(device=model_device, dtype=dtype)

generator = UniformDensityGenerator(seed, preprocess_device, dtype)
dataloader = Dataloader2(generator, dataloader_simulator, batch_size, width, height, preprocess_device, model_device, dtype, 0.1)

#model = Conv1Model(device=model_device, dtype=dtype, n_hidden_layers=4, n_channels=2000, 
#                   last_activation=lambda: lambda x: torch.nn.functional.sigmoid(x-0.5)*20)
model = Conv1Model(device=model_device, dtype=dtype, n_hidden_layers=4, n_channels=2000)
model.initialize()
lr = 0.001
betas = (0.9,0.999)
eps = 0.00000001
weight_decay = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
#loss_fn = CompletionLoss()
loss_fn = torch.nn.MSELoss()

config = {
    "dataloader": dataloader.get_config(),
    "model": model.get_config(),
    "optimizer": {
        "type": "Adam",
        "lr": lr,
        "betas": betas,
        "eps": eps,
        "weight_decay": weight_decay
    },
    "loss": {
        "type": "CompletionLoss"
    }
}

config

  init_weight_f(self.conv_start.weight)
  init_weight_f(hidden_conf.weight)
  init_weight_f(self.conv_end.weight)


{'dataloader': {'type': 'Dataloader1',
  'generator': {'type': 'UniformDensityGenerator', 'seed': 0},
  'width': 20,
  'height': 20,
  'batch_size': 500,
  'min_change_threshold': 0.1,
  'max_sequence_age': 150},
 'model': {'type': 'Conv1Model',
  'is_toroidal': False,
  'kernel_size': 5,
  'activation': 'type',
  'last_activation': 'type',
  'n_hidden_layers': 4,
  'n_channels': 2000,
  'n_parameters': 102001,
  'weight_init': 'xavier_uniform',
  'bias_init': 'zeros_'},
 'optimizer': {'type': 'Adam',
  'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0.0001},
 'loss': {'type': 'CompletionLoss'}}

In [3]:
wandb_log = True

In [4]:
def safe_state(path: str, 
               model: torch.nn.Module, 
               optimizer: torch.nn.Module, 
               dataloader, 
               run_state: dict):
    save_state = {
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'dataloader_state': dataloader.get_state(),
        'run_state': run_state
    }
    torch.save(save_state, path)

def train(
        run_state: dict,
        model: torch.nn.Module, 
        dataloader, 
        optimizer: torch.optim.Optimizer, 
        derivable_simulator: TorchSimulator,
        loss_fn: torch.nn.Module, 
        checkpoint_path: Path, 
        max_steps: int=300000, 
        patience: int=75000) -> dict:
    run_state.update({
        'step': 0,
        'best_loss_step': 0,
        'best_accuracy_micro_step': 0,
        'best_accuracy_macro_step': 0,
        'loss': sys.float_info.max,
        'accuracy_micro': 0,
        'accuracy_macro': 0,
        'best_loss': sys.float_info.max,
        'best_accuracy_micro': 0,
        'best_accuracy_macro': 0
    } | run_state)

    safe_state(checkpoint_path/'initial.ckpt', model, optimizer, dataloader, run_state)
    
    if wandb_log:
        wandb.init('llottenbach', 'cgol', config={
            "config": config
        })

    model.train()
    try:
        for batch in dataloader:
            optimizer.zero_grad()
            output = model.forward(batch[0])
            output = derivable_simulator.step_batch_tensor(output)

            # metrics
            run_state['loss'] = loss_fn(output, batch[0])
            run_state['accuracy_micro'] = (((output >= 0.5) == batch[1]).sum()
                / (output.shape[0] * output.shape[1] * output.shape[2]))
            run_state['accuracy_macro'] = (((output >= 0.5) == batch[1]).all((-1,-2)).sum()
                / (output.shape[0]))
            
            if run_state['best_loss'] > run_state['loss']:
                run_state['best_loss'] = run_state['loss'].clone().detach()
                run_state['best_loss_step'] = run_state['step']
                safe_state(checkpoint_path/'best_loss.chkpt', model, optimizer, dataloader, run_state)
            if run_state['best_accuracy_micro'] < run_state['accuracy_micro']:
                run_state['best_accuracy_micro'] = run_state['accuracy_micro'].detach().clone()
                run_state['best_accuracy_micro_step'] = run_state['step']
                safe_state(checkpoint_path/'best_acc_micro.chkpt', model, optimizer, dataloader, run_state)
            if run_state['best_accuracy_macro'] < run_state['accuracy_macro']:
                run_state['best_accuracy_macro'] = run_state['accuracy_macro'].detach().clone()
                run_state['best_accuracy_macro_step'] = run_state['step']
                safe_state(checkpoint_path/'best_acc_macro.chkpt', model, optimizer, dataloader, run_state)
            
            # optimize
            run_state['loss'].backward()
            optimizer.step()
            
            # log
            print(run_state)
            print(f'step: {run_state["step"]:7d},',
                f'loss: {run_state["loss"]:1.16f},',
                f'accuracy-micro: {run_state["accuracy_micro"]:1.16f},',
                f'accuracy-macro: {run_state["accuracy_macro"]:1.16f}')
            if wandb_log:
                wandb.log({
                    'train/loss': run_state['loss'],
                    'train/accuracy-micro': run_state['accuracy_micro'],
                    'train/accuracy-macro': run_state['accuracy_macro'],
                    # dataloader
                    'dataloader/batch_age/min': dataloader.sequence_ages.min(),
                    'dataloader/batch_age/max': dataloader.sequence_ages.max(),
                    'dataloader/batch_age/median': dataloader.sequence_ages.float().median(),
                    'dataloader/batch_age/mean': dataloader.sequence_ages.float().mean(),
                    'dataloader/batch_age/std': dataloader.sequence_ages.float().std(),
                    'dataloader/batch_diffs_per_cell/min': dataloader.batch_diffs_per_cell.min(),
                    'dataloader/batch_diffs_per_cell/max': dataloader.batch_diffs_per_cell.max(),
                    'dataloader/batch_diffs_per_cell/median': dataloader.batch_diffs_per_cell.median(),
                    'dataloader/batch_diffs_per_cell/mean': dataloader.batch_diffs_per_cell.mean(),
                    'dataloader/batch_diffs_per_cell/std': dataloader.batch_diffs_per_cell.std(),
                },
                run_state['step'])

            # stop condition
            if (max_steps <= run_state['step']
                or run_state['best_loss_step'] + patience <= run_state['step']
                or run_state['best_accuracy_micro_step'] + patience <= run_state['step']
                or run_state['best_accuracy_macro_step'] + patience <= run_state['step']):

                safe_state(checkpoint_path/'stop.chkpt', model, optimizer, dataloader, run_state)
                break
            
            run_state['step'] += 1
    except BaseException as e:
        safe_state(checkpoint_path/'error.chkpt', model, optimizer, dataloader, run_state)
        raise e

    return run_state

In [5]:
if wandb_log:
    wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mlukas-lottenbach[0m ([33mllottenbach-nlp[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
run_checkpoint_path = Path('conv1_train_on_endstate_3')
run_checkpoint_path.mkdir(exist_ok=True)
run_state = {}
train(run_state, model, dataloader, optimizer, train_simulator, loss_fn, run_checkpoint_path)
# https://wandb.ai/llottenbach/cgol/runs/pd9ohjsb
# changed back activation function
# use MSE loss

{'step': 0, 'best_loss_step': 0, 'best_accuracy_micro_step': 0, 'best_accuracy_macro_step': 0, 'loss': tensor(0.2633, device='cuda:0', grad_fn=<MseLossBackward0>), 'accuracy_micro': tensor(0.4724, device='cuda:0'), 'accuracy_macro': tensor(0., device='cuda:0'), 'best_loss': tensor(0.2633, device='cuda:0'), 'best_accuracy_micro': tensor(0.4724, device='cuda:0'), 'best_accuracy_macro': 0}
step:       0, loss: 0.2632910907268524, accuracy-micro: 0.4723699986934662, accuracy-macro: 0.0000000000000000
{'step': 1, 'best_loss_step': 0, 'best_accuracy_micro_step': 1, 'best_accuracy_macro_step': 0, 'loss': tensor(0.2742, device='cuda:0', grad_fn=<MseLossBackward0>), 'accuracy_micro': tensor(0.6029, device='cuda:0'), 'accuracy_macro': tensor(0., device='cuda:0'), 'best_loss': tensor(0.2633, device='cuda:0'), 'best_accuracy_micro': tensor(0.6029, device='cuda:0'), 'best_accuracy_macro': 0}
step:       1, loss: 0.2742233872413635, accuracy-micro: 0.6029300093650818, accuracy-macro: 0.0000000000000

KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x733d16a57010>> (for post_run_cell):


BrokenPipeError: [Errno 32] Broken pipe

In [None]:
wandb.finish()

In [7]:
safe_state(run_checkpoint_path/'finish.chkpt', model, optimizer, dataloader, run_state)

Error in callback <bound method _WandbInit._pre_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x733d16a57010>> (for pre_run_cell):


BrokenPipeError: [Errno 32] Broken pipe

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x733d16a57010>> (for post_run_cell):


BrokenPipeError: [Errno 32] Broken pipe