In [1]:
import sys
import torch
import wandb
from pathlib import Path
sys.path.append('../..')

from solutions.convolution.conv1_model import Conv1Model
from cgol.generator.uniform_density_generator import UniformDensityGenerator
from cgol.simulator.minimal_architecture_simulator import MinimalArchitectureSimulator
from cgol.dataloader.dataloader_2 import Dataloader2



In [2]:
# --- Experiment configuration ---------------------------------------------
seed = 0
width = 20
height = 20
batch_size = 1000
dtype = torch.float
preprocess_device = 'cpu'
model_device = 'cuda'
# Last two positional arguments of Dataloader2. The names are taken from the
# keys reported by dataloader.get_config() -- TODO confirm against Dataloader2.
min_change_threshold = 0.1
max_sequence_age = 150

# `seed` is otherwise only handed to the data generator. Seeding the global
# torch RNG as well makes the model weight initialisation repeatable across
# "Restart & Run All" (assumes Conv1Model draws from the global RNG).
torch.manual_seed(seed)

# --- Data pipeline ---------------------------------------------------------
simulator = MinimalArchitectureSimulator(device=preprocess_device, dtype=dtype)
generator = UniformDensityGenerator(seed, preprocess_device, dtype)
dataloader = Dataloader2(generator, simulator,
                         batch_size, width, height,
                         preprocess_device, model_device, dtype,
                         min_change_threshold, max_sequence_age)

# --- Model, optimizer, loss -------------------------------------------------
model = Conv1Model(device=model_device, dtype=dtype, n_channels=1000, n_hidden_layers=4)
model.initialize()
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = 0
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
loss_fn = torch.nn.MSELoss()

# Run description that train() forwards to wandb.init as the run config.
config = {
    "dataloader": dataloader.get_config(),
    "model": model.get_config(),
    "optimizer": {
        "type": "Adam",
        "lr": lr,
        "betas": betas,
        "eps": eps,
        "weight_decay": weight_decay
    },
    "loss": {
        "type": "MSELoss"
    }
}

config

  init_weight_f(self.conv_start.weight)
  init_weight_f(hidden_conf.weight)
  init_weight_f(self.conv_end.weight)


{'dataloader': {'type': 'Dataloader1',
  'generator': {'type': 'UniformDensityGenerator', 'seed': 0},
  'width': 20,
  'height': 20,
  'batch_size': 1000,
  'min_change_threshold': 0.1,
  'max_sequence_age': 150},
 'model': {'type': 'Conv1Model',
  'is_toroidal': False,
  'kernel_size': 5,
  'activation': 'type',
  'last_activation': 'type',
  'n_hidden_layers': 4,
  'n_channels': 1000,
  'n_parameters': 51001,
  'weight_init': 'xavier_uniform',
  'bias_init': 'zeros_'},
 'optimizer': {'type': 'Adam',
  'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0},
 'loss': {'type': 'MSELoss'}}

In [3]:
# Global switch read by train() and by the login cell: when True, metrics are
# sent to Weights & Biases; when False, no wandb.init/wandb.log calls are made.
wandb_log = True

In [8]:
def safe_state(path: "str | Path",
               model: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               dataloader,
               run_state: dict):
    """Write a training checkpoint to ``path``.

    The checkpoint is a dict with the keys ``model_state``,
    ``optimizer_state``, ``dataloader_state`` and ``run_state``.

    The file is first written next to the target as ``<name>.tmp`` and then
    renamed over the target, so an interrupted write never leaves a truncated
    checkpoint under the final name.

    Args:
        path: Destination file (string or ``Path``).
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        dataloader: Object exposing ``get_state()``; its return value is stored as-is.
        run_state: Bookkeeping dict of the training loop, stored as-is.
    """
    target = Path(path)
    partial = target.with_name(target.name + '.tmp')
    checkpoint = {
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'dataloader_state': dataloader.get_state(),
        'run_state': run_state
    }
    try:
        torch.save(checkpoint, partial)
        # Path.replace overwrites an existing target in a single rename.
        partial.replace(target)
    except BaseException:
        # Do not leave a half-written temp file behind (also on Ctrl-C).
        partial.unlink(missing_ok=True)
        raise

def train(
        model: torch.nn.Module,
        dataloader,
        optimizer: torch.optim.Optimizer,
        loss_fn: torch.nn.Module,
        checkpoint_path: Path,
        max_steps: int = 300000,
        patience: int = 75000,
        run_state: dict = None,
        resume: bool = False) -> dict:
    """Train ``model`` on batches from ``dataloader`` until a stop condition hits.

    The same batch is trained on repeatedly until every cell is predicted
    correctly (micro accuracy of 1), only then is the next batch fetched.
    Best-metric bookkeeping, the ``best_*`` checkpoints and the
    ``train/first_encounter/*`` logs are updated only on the first step of a
    freshly fetched batch.

    Training stops when ``max_steps`` is reached or when any of the three best
    metrics (loss, micro accuracy, macro accuracy) has not improved for
    ``patience`` steps.

    Reads the notebook globals ``wandb_log`` (enables W&B logging) and
    ``config`` (sent to ``wandb.init``).

    Args:
        model: Network to train; called as ``model(batch[0])``.
        dataloader: Iterator yielding ``(input, target)`` batches. It must also
            expose ``get_state()`` and, when W&B logging is on, ``step``,
            ``sequence_ages`` and ``batch_diffs_per_cell``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Loss called as ``loss_fn(output, target)``.
        checkpoint_path: Existing directory that receives the checkpoints.
        max_steps: Hard upper bound on the step counter.
        patience: Steps without improvement before stopping.
        run_state: Optional bookkeeping from an earlier run; its entries
            override the fresh defaults.
        resume: When True, ``wandb.init`` is skipped (a run is assumed active).

    Returns:
        The final ``run_state`` dict. Metric entries are plain Python numbers.

    Raises:
        Exception: Any exception from the loop is re-raised after
            ``error.chkpt`` has been written.
    """
    run_state = {
        'step': 0,
        'best_loss_step': 0,
        'best_accuracy_micro_step': 0,
        'best_accuracy_macro_step': 0,
        'loss': sys.float_info.max,
        'accuracy_micro': 0,
        'accuracy_macro': 0,
        'best_loss': sys.float_info.max,
        'best_accuracy_micro': 0,
        'best_accuracy_macro': 0
    } | (run_state or {})
    checkpoint_path = Path(checkpoint_path)
    do_train = True

    # NOTE(review): this file uses the '.ckpt' suffix while every other
    # checkpoint uses '.chkpt', and it is rewritten even when resume=True.
    # Both are kept as-is to stay compatible with existing files and loaders.
    safe_state(checkpoint_path/'initial.ckpt', model, optimizer, dataloader, run_state)

    if wandb_log and not resume:
        # Keyword arguments: the positional order of wandb.init's parameters
        # is not the same across wandb releases.
        wandb.init(entity='llottenbach', project='cgol', config={
            "config": config
        })

    # (run_state key, checkpoint file name, True when a larger value is better)
    tracked_metrics = (
        ('loss', 'best_loss.chkpt', False),
        ('accuracy_micro', 'best_acc_micro.chkpt', True),
        ('accuracy_macro', 'best_acc_macro.chkpt', True),
    )

    model.train()
    try:
        batch = next(dataloader)
        new_batch = True

        while do_train:
            optimizer.zero_grad()
            # Call the module rather than .forward() so registered hooks run.
            output = model(batch[0])

            # metrics -- stored as Python numbers so run_state (and every
            # checkpoint containing it) does not hold on to the autograd graph
            loss = loss_fn(output, batch[1])
            correct = (output >= 0.5) == batch[1]
            run_state['loss'] = loss.item()
            # fraction of correctly predicted cells
            run_state['accuracy_micro'] = correct.sum().item() / correct.numel()
            # fraction of samples whose last two dimensions are entirely correct
            run_state['accuracy_macro'] = (correct.all((-1, -2)).sum().item()
                / output.shape[0])

            # optimize
            loss.backward()
            optimizer.step()

            if new_batch:
                # update highscore
                # NOTE: the weights saved here already include this step's
                # update, i.e. they are one step past the ones that produced
                # the metric.
                for key, filename, higher_is_better in tracked_metrics:
                    current = run_state[key]
                    best = run_state[f'best_{key}']
                    improved = current > best if higher_is_better else current < best
                    if improved:
                        run_state[f'best_{key}'] = current
                        run_state[f'best_{key}_step'] = run_state['step']
                        safe_state(checkpoint_path/filename, model, optimizer, dataloader, run_state)

                # log
                if wandb_log:
                    ages = dataloader.sequence_ages
                    diffs = dataloader.batch_diffs_per_cell
                    wandb.log({
                        'train/first_encounter/loss': run_state['loss'],
                        'train/first_encounter/accuracy_micro': run_state['accuracy_micro'],
                        'train/first_encounter/accuracy_macro': run_state['accuracy_macro'],
                        'dataloader/step': dataloader.step,
                        'dataloader/batch_age/min': ages.min(),
                        'dataloader/batch_age/max': ages.max(),
                        'dataloader/batch_age/median': ages.float().median(),
                        'dataloader/batch_age/mean': ages.float().mean(),
                        'dataloader/batch_age/std': ages.float().std(),
                        'dataloader/batch_diffs_per_cell/min': diffs.min(),
                        'dataloader/batch_diffs_per_cell/max': diffs.max(),
                        'dataloader/batch_diffs_per_cell/median': diffs.median(),
                        'dataloader/batch_diffs_per_cell/mean': diffs.mean(),
                        'dataloader/batch_diffs_per_cell/std': diffs.std(),
                    },
                    step=run_state['step'])

            # new batch condition: move on only once the batch is fully solved
            new_batch = False
            if run_state['accuracy_micro'] >= 1.:
                batch = next(dataloader)
                new_batch = True

            # stop condition
            step = run_state['step']
            if (max_steps <= step
                or run_state['best_loss_step'] + patience <= step
                or run_state['best_accuracy_micro_step'] + patience <= step
                or run_state['best_accuracy_macro_step'] + patience <= step):
                do_train = False

            # log
            print(f'step: {run_state["step"]:7d},',
                f'loss: {run_state["loss"]:1.16f},',
                f'accuracy-micro: {run_state["accuracy_micro"]:1.16f},',
                f'accuracy-macro: {run_state["accuracy_macro"]:1.16f}')
            if wandb_log:
                wandb.log({
                    'train/loss': run_state['loss'],
                    'train/accuracy-micro': run_state['accuracy_micro'],
                    'train/accuracy-macro': run_state['accuracy_macro']
                },
                step=run_state['step'])

            if do_train:
                run_state['step'] += 1
            else:
                safe_state(checkpoint_path/'stop.chkpt', model, optimizer, dataloader, run_state)
    except Exception:
        # Keep whatever progress exists, then re-raise with the original traceback.
        safe_state(checkpoint_path/'error.chkpt', model, optimizer, dataloader, run_state)
        raise

    return run_state

Error in callback <bound method _WandbInit._pre_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x76ad7c2dc7d0>> (for pre_run_cell):


BrokenPipeError: [Errno 32] Broken pipe

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x76ad7c2dc7d0>> (for post_run_cell):


BrokenPipeError: [Errno 32] Broken pipe

In [5]:
# Authenticate with Weights & Biases only when logging is enabled.
if wandb_log:
    wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mlukas-lottenbach[0m ([33mllottenbach-nlp[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# Create the checkpoint directory for this run (no error if it already exists)
# and start training with the default max_steps / patience.
run_checkpoint_path = Path('checkpoints_grok2')
run_checkpoint_path.mkdir(exist_ok=True)
run_state = train(model, dataloader, optimizer, loss_fn, run_checkpoint_path)

step:       0, loss: 0.2499116957187653, accuracy-micro: 0.4718225002288818, accuracy-macro: 0.0000000000000000
step:       1, loss: 0.2487521022558212, accuracy-micro: 0.4506449997425079, accuracy-macro: 0.0000000000000000
step:       2, loss: 0.2473231554031372, accuracy-micro: 0.4503849744796753, accuracy-macro: 0.0000000000000000
step:       3, loss: 0.2462428659200668, accuracy-micro: 0.4502199888229370, accuracy-macro: 0.0000000000000000
step:       4, loss: 0.2463601231575012, accuracy-micro: 0.4510324895381927, accuracy-macro: 0.0000000000000000
step:       5, loss: 0.2465206384658813, accuracy-micro: 0.4529999792575836, accuracy-macro: 0.0000000000000000
step:       6, loss: 0.2459297925233841, accuracy-micro: 0.4646099805831909, accuracy-macro: 0.0000000000000000
step:       7, loss: 0.2453355789184570, accuracy-micro: 0.4957149922847748, accuracy-macro: 0.0000000000000000
step:       8, loss: 0.2450833171606064, accuracy-micro: 0.5320425033569336, accuracy-macro: 0.000000000

KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x76ad7c2dc7d0>> (for post_run_cell):


BrokenPipeError: [Errno 32] Broken pipe

In [17]:
# Close the W&B run opened by train().
wandb.finish()

Error in callback <bound method _WandbInit._pre_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x76ad7c2dc7d0>> (for pre_run_cell):


BrokenPipeError: [Errno 32] Broken pipe

BrokenPipeError: [Errno 32] Broken pipe

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x76ad7c2dc7d0>> (for post_run_cell):


BrokenPipeError: [Errno 32] Broken pipe