In [1]:
import sys
import torch
import wandb
from pathlib import Path
sys.path.append('../..')

from solutions.nca.nca_feedback_model import NCAFeedbackModel
from cgol.generator.uniform_density_generator import UniformDensityGenerator
from cgol.simulator.torch_simulator import TorchSimulator
from cgol.simulator.minimal_architecture_simulator import MinimalArchitectureSimulator
from cgol.simulator.derivable_minimal_architecture_simulator import DerivableMinimalArchitectureSimulator
from cgol.dataloader.dataloader_2 import Dataloader2
from cgol.loss.completion_loss2 import CompletionLoss2



In [2]:
seed = 0
width = 20
height = 20
batch_size = 500
dtype = torch.float
preprocess_device = 'cpu'
model_device = 'cuda'

dataloader_simulator = MinimalArchitectureSimulator(device=preprocess_device, dtype=dtype)
#train_simulator = DerivableMinimalArchitectureSimulator(device=model_device, dtype=dtype)
train_simulator = MinimalArchitectureSimulator(device=model_device, dtype=dtype)

generator = UniformDensityGenerator(seed, preprocess_device, dtype)
dataloader = Dataloader2(generator, dataloader_simulator, batch_size, width, height, preprocess_device, model_device, dtype, 0.1)

#model = Conv1Model(device=model_device, dtype=dtype, n_hidden_layers=4, n_channels=2000, 
#                   last_activation=lambda: lambda x: torch.nn.functional.sigmoid(x-0.5)*20)
model = NCAFeedbackModel(device=model_device, dtype=dtype)
model.initialize()
lr = 0.001
betas = (0.9,0.999)
eps = 0.00000001
#weight_decay = 0
weight_decay = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
loss_fn = torch.nn.MSELoss()
#loss_fn = CompletionLoss2()
n_solving_steps = 60
gradient_clipping = True

config = {
    "dataloader": dataloader.get_config(),
    "model": model.get_config(),
    "optimizer": {
        "type": type(optimizer).__name__,
        "lr": lr,
        "betas": betas,
        "eps": eps,
        "weight_decay": weight_decay
    },
    "loss": {
        "type": type(loss_fn).__name__
    },
    "gradient_clipping": gradient_clipping,
    "n_solving_steps": n_solving_steps
}

config

  init_weight_f(self.conv_start.weight)
  init_weight_f(self.conv_end.weight)


{'dataloader': {'type': 'Dataloader1',
  'generator': {'type': 'UniformDensityGenerator', 'seed': 0},
  'width': 20,
  'height': 20,
  'batch_size': 500,
  'min_change_threshold': 0.1,
  'max_sequence_age': 150},
 'model': {'type': 'NCAFeedbackModel',
  'is_toroidal': False,
  'perception_size': 3,
  'activation': 'type',
  'last_activation': 'type',
  'hidden_size': 100,
  'n_hidden_layers': 0,
  'n_channels': 100,
  'n_parameters': 103201,
  'weight_init': 'xavier_uniform',
  'bias_init': 'zeros_',
  'batch_norm': True},
 'optimizer': {'type': 'Adam',
  'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0.0001},
 'loss': {'type': 'MSELoss'},
 'gradient_clipping': True,
 'n_solving_steps': 60}

In [3]:
wandb_log = True

In [None]:
def safe_state(path: str, 
               model: NCAFeedbackModel, 
               optimizer: torch.nn.Module, 
               dataloader, 
               run_state: dict):
    save_state = {
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'dataloader_state': dataloader.get_state(),
        'run_state': run_state
    }
    torch.save(save_state, path)

def train(
        run_state: dict,
        model: NCAFeedbackModel, 
        dataloader, 
        optimizer: torch.optim.Optimizer, 
        simulator: TorchSimulator,
        loss_fn: torch.nn.Module, 
        checkpoint_path: Path, 
        max_steps: int=300000, 
        patience: int=75000) -> dict:
    run_state.update({
        'step': 0,
        'best_loss_step': 0,
        'best_accuracy_micro_step': 0,
        'best_accuracy_macro_step': 0,
        'loss': sys.float_info.max,
        'accuracy_micro': 0,
        'accuracy_macro': 0,
        'best_loss': sys.float_info.max,
        'best_accuracy_micro': 0,
        'best_accuracy_macro': 0,
        'best_end_state_accuracy_micro': 0,
        'best_end_state_accuracy_macro': 0
    } | run_state)

    safe_state(checkpoint_path/'initial.ckpt', model, optimizer, dataloader, run_state)
    
    if wandb_log:
        wandb.init('llottenbach', 'cgol', config={
            "config": config
        })

    model.train()
    try:
        with open('nca_feedback_train_log.txt', 'a') as log_file:
            for batch in dataloader:
                optimizer.zero_grad()

                target_end_state = batch[0]
                target_initial_state = batch[1]

                model_input: torch.Tensor = model.init_model_input(target_end_state)
                
                run_state['loss'] = 1.

                for i_solving_step in range(n_solving_steps):

                    model_output = model.forward(model_input)
                    predicted_initial_state: torch.Tensor = model.output_batch_from_model_output(model_output)
                    #predicted_end_state = simulator.step_batch_tensor(predicted_initial_state)
                    predicted_end_state = simulator.step_batch_tensor((predicted_initial_state >= 0.5) * 1.)
                    model_input = model.model_input_from_model_output(model_output, target_end_state, (predicted_end_state != target_end_state)*1.)

                    # metrics
                    #initial_state_loss = loss_fn(predicted_initial_state, target_initial_state)
                    #end_state_loss = loss_fn(predicted_end_state, target_end_state)
                    run_state['loss'] = run_state['loss'] + loss_fn(predicted_initial_state, target_initial_state)
                    
                    run_state['initial_state_accuracy_micro'] = (((predicted_initial_state >= 0.5) == target_initial_state).sum()
                        / (target_initial_state.shape[0] * target_initial_state.shape[1] * target_initial_state.shape[2]))
                    run_state['initial_state_accuracy_macro'] = (((predicted_initial_state >= 0.5) == target_initial_state).all((-1,-2)).sum()
                        / (target_initial_state.shape[0]))
                    
                    run_state['end_state_accuracy_micro'] = (((predicted_end_state >= 0.5) == target_end_state).sum()
                        / (target_end_state.shape[0] * target_end_state.shape[1] * target_end_state.shape[2]))
                    run_state['end_state_accuracy_macro'] = (((predicted_end_state >= 0.5) == target_end_state).all((-1,-2)).sum()
                        / (target_end_state.shape[0]))

                    run_state['accuracy_micro'] = run_state['initial_state_accuracy_micro']
                    run_state['accuracy_macro'] = run_state['initial_state_accuracy_macro']
                    
                    if run_state['best_loss'] > run_state['loss'] / (i_solving_step + 1):
                        run_state['best_loss'] = run_state['loss'].clone().detach() / (i_solving_step + 1)
                        run_state['best_loss_step'] = run_state['step']
                        safe_state(checkpoint_path/'best_loss.chkpt', model, optimizer, dataloader, run_state)
                    if run_state['best_accuracy_micro'] < run_state['accuracy_micro']:
                        run_state['best_accuracy_micro'] = run_state['accuracy_micro'].detach().clone()
                        run_state['best_accuracy_micro_step'] = run_state['step']
                        safe_state(checkpoint_path/'best_acc_micro.chkpt', model, optimizer, dataloader, run_state)
                    if run_state['best_accuracy_macro'] < run_state['accuracy_macro']:
                        run_state['best_accuracy_macro'] = run_state['accuracy_macro'].detach().clone()
                        run_state['best_accuracy_macro_step'] = run_state['step']
                        safe_state(checkpoint_path/'best_acc_macro.chkpt', model, optimizer, dataloader, run_state)
                    if run_state['best_end_state_accuracy_micro'] < run_state['end_state_accuracy_micro']:
                        run_state['best_end_state_accuracy_micro'] = run_state['end_state_accuracy_micro'].detach().clone()
                        run_state['best_end_state_accuracy_micro_step'] = run_state['step']
                        safe_state(checkpoint_path/'best_end_acc_micro.chkpt', model, optimizer, dataloader, run_state)
                    if run_state['best_end_state_accuracy_macro'] < run_state['end_state_accuracy_macro']:
                        run_state['best_end_state_accuracy_macro'] = run_state['end_state_accuracy_macro'].detach().clone()
                        run_state['best_end_state_accuracy_macro_step'] = run_state['step']
                        safe_state(checkpoint_path/'best_end_acc_macro.chkpt', model, optimizer, dataloader, run_state)
                                    
                    # log
                    print(run_state, file=log_file)
                    print(f'step: {run_state["step"]:7d},',
                        f'loss: {run_state["loss"]:1.16f},',
                        f'accuracy-micro: {run_state["accuracy_micro"]:1.16f},',
                        f'accuracy-macro: {run_state["accuracy_macro"]:1.16f}', file=log_file)
                    if wandb_log:
                        wandb.log({
                            'train/loss': run_state['loss'] / (i_solving_step + 1),
                            'train/accuracy-micro': run_state['accuracy_micro'],
                            'train/accuracy-macro': run_state['accuracy_macro'],
                            'train/initial-state/accuracy-micro': run_state['initial_state_accuracy_micro'],
                            'train/initial-state/accuracy-macro': run_state['initial_state_accuracy_macro'],
                            'train/end-state/accuracy-micro': run_state['end_state_accuracy_micro'],
                            'train/end-state/accuracy-macro': run_state['end_state_accuracy_macro'],

                            # dataloader
                            'dataloader/batch_age/min': dataloader.sequence_ages.min(),
                            'dataloader/batch_age/max': dataloader.sequence_ages.max(),
                            'dataloader/batch_age/median': dataloader.sequence_ages.float().median(),
                            'dataloader/batch_age/mean': dataloader.sequence_ages.float().mean(),
                            'dataloader/batch_age/std': dataloader.sequence_ages.float().std(),
                            'dataloader/batch_diffs_per_cell/min': dataloader.batch_diffs_per_cell.min(),
                            'dataloader/batch_diffs_per_cell/max': dataloader.batch_diffs_per_cell.max(),
                            'dataloader/batch_diffs_per_cell/median': dataloader.batch_diffs_per_cell.median(),
                            'dataloader/batch_diffs_per_cell/mean': dataloader.batch_diffs_per_cell.mean(),
                            'dataloader/batch_diffs_per_cell/std': dataloader.batch_diffs_per_cell.std(),
                        },
                        run_state['step'])

                    # stop condition
                    if (max_steps <= run_state['step']
                        or run_state['best_loss_step'] + patience <= run_state['step']
                        or run_state['best_accuracy_micro_step'] + patience <= run_state['step']
                        or run_state['best_accuracy_macro_step'] + patience <= run_state['step']):

                        safe_state(checkpoint_path/'stop.chkpt', model, optimizer, dataloader, run_state)
                        raise KeyboardInterrupt('Stop Condition reached')
                    
                    run_state['step'] += 1
                    
                # optimize
                run_state['loss'] = run_state['loss'] / n_solving_steps
                run_state['loss'].backward()
                if gradient_clipping:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
    except BaseException as e:
        safe_state(checkpoint_path/'error.chkpt', model, optimizer, dataloader, run_state)
        wandb.finish()
        raise e

    return run_state

In [5]:
if wandb_log:
    wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mlukas-lottenbach[0m ([33mllottenbach-nlp[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
run_checkpoint_path = Path('nca_feedback_train')
run_checkpoint_path.mkdir(exist_ok=True)
run_state = {}
train(run_state, model, dataloader, optimizer, train_simulator, loss_fn, run_checkpoint_path,
      max_steps=300000*n_solving_steps, patience=75000*n_solving_steps)
# Train with feedback
# https://wandb.ai/llottenbach/cgol/runs/3oxz77pn

0,1
dataloader/batch_age/max,▁▂█▁▁▂▂▂▂▂▂▁▂▂▁▃▂▄▂▂▂▃▂▁▂▂▂▃▂▂▃▃▁▃▃▂▃▂▁▄
dataloader/batch_age/mean,▆▄▅▅▅▄▄▂▆▂▄▄▅▂▃▆▄▃▅▇▁▆▅▃▄▅▇▄█▇▅▂▅▄▆▂▁▅▅▆
dataloader/batch_age/median,▃▃▆▃█▆▆▃▁▆▃▆▆▃▆▆▃▆▆▆▃▃▆▆▆▃▃▃▆▆▆▆▆▃▆▃▃▃▃▃
dataloader/batch_age/min,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dataloader/batch_age/std,▂▃▅▆▆▃▁▃▄▃▂▄▅▇▃▅▄▅█▅▆▄▁▄▄▃▇▄▄▆▁▅▂▆▃▂▄█▂▆
dataloader/batch_diffs_per_cell/max,█▇█▇█▇█▆███▇▇▇▇▇█▇▇█▇██▇▇█▆▃█▁█▄▄█▇█▇█▇▇
dataloader/batch_diffs_per_cell/mean,█▄▆▂▆▅▅▅▃▆▇▄▅█▁▅▅▄▅▅▃▄▄▄▅▇▅▃▄▅▄▆▅▅▆▅▂▆▆▇
dataloader/batch_diffs_per_cell/median,▃▄▅▂▂▇▄▂▄█▃▂▂▆▅▂▂▄▅▃▄▃▆▃▆▄▄▃▅▃▅▄▁▂▃▆▄▄▄▃
dataloader/batch_diffs_per_cell/min,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dataloader/batch_diffs_per_cell/std,▇▆▇▃▇▆▃▆▅▅▃▄▂▆▄▅▂▄▄▄▄▅▂▄▂▁▁▅▁▁▅█▁▄▂▅▃▆▃▁

0,1
dataloader/batch_age/max,68.0
dataloader/batch_age/mean,10.544
dataloader/batch_age/median,8.0
dataloader/batch_age/min,1.0
dataloader/batch_age/std,9.42517
dataloader/batch_diffs_per_cell/max,0.965
dataloader/batch_diffs_per_cell/mean,0.21707
dataloader/batch_diffs_per_cell/median,0.175
dataloader/batch_diffs_per_cell/min,0.1
dataloader/batch_diffs_per_cell/std,0.14751


KeyboardInterrupt: Stop Condition reached

In [None]:
wandb.finish()

In [None]:
safe_state(run_checkpoint_path/'finish.chkpt', model, optimizer, dataloader, run_state)