In [1]:
import sys
import torch
import wandb
from pathlib import Path
sys.path.append('../..')

from solutions.convolution.conv1_model import Conv1Model
from cgol.generator.uniform_density_generator import UniformDensityGenerator
from cgol.simulator.minimal_architecture_simulator import MinimalArchitectureSimulator
from cgol.dataloader.dataloader_2 import Dataloader2
from cgol.loss.completion_loss import CompletionLoss



In [2]:
# Load the checkpoint this notebook resumes from. Later cells read its
# 'dataloader_state', 'model_state', 'optimizer_state' and 'run_state' entries.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load('checkpoints_train_grok3/final.chkpt')

In [3]:
# --- Run configuration -------------------------------------------------------
seed = 0
width, height = 20, 20
batch_size = 100
dtype = torch.float
preprocess_device = 'cpu'
model_device = 'cuda'

# --- Data pipeline, restored from the checkpoint -----------------------------
dataloader_state = checkpoint['dataloader_state']

simulator = MinimalArchitectureSimulator(device=preprocess_device, dtype=dtype)
generator = UniformDensityGenerator(seed, preprocess_device, dtype)
# Restore the generator's RNG before the dataloader is built on top of it.
generator.rng.set_state(dataloader_state['generator']['rng'])

# Names follow the keys that dataloader.get_config() reports for these two
# values in this cell's output; they are passed positionally as before.
min_change_threshold = 0.1
max_sequence_age = 150
dataloader = Dataloader2(
    generator, simulator,
    batch_size, width, height,
    preprocess_device, model_device, dtype,
    min_change_threshold, max_sequence_age,
)
# Put the dataloader back where the checkpointed run left off.
for attribute in ('last_batch', 'sequence_ages', 'step'):
    setattr(dataloader, attribute, dataloader_state[attribute])

# --- Model -------------------------------------------------------------------
model = Conv1Model(device=model_device, dtype=dtype, n_channels=2000, n_hidden_layers=4)
model.load_state_dict(checkpoint['model_state'])

# --- Optimizer and loss ------------------------------------------------------
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = 0
adam_kwargs = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
optimizer = torch.optim.Adam(model.parameters(), **adam_kwargs)
optimizer.load_state_dict(checkpoint['optimizer_state'])

loss_fn = torch.nn.MSELoss()

# --- Summary of the whole setup (displayed as the cell output) ---------------
config = dict(
    dataloader=dataloader.get_config(),
    model=model.get_config(),
    optimizer={"type": "Adam", **adam_kwargs},
    loss={"type": "MSELoss"},
)

config

{'dataloader': {'type': 'Dataloader1',
  'generator': {'type': 'UniformDensityGenerator', 'seed': 0},
  'width': 20,
  'height': 20,
  'batch_size': 100,
  'min_change_threshold': 0.1,
  'max_sequence_age': 150},
 'model': {'type': 'Conv1Model',
  'is_toroidal': False,
  'kernel_size': 5,
  'activation': 'type',
  'last_activation': 'type',
  'n_hidden_layers': 4,
  'n_channels': 2000,
  'n_parameters': 102001,
  'weight_init': 'xavier_uniform',
  'bias_init': 'zeros_'},
 'optimizer': {'type': 'Adam',
  'lr': 0.001,
  'betas': (0.9, 0.999),
  'eps': 1e-08,
  'weight_decay': 0},
 'loss': {'type': 'MSELoss'}}

In [4]:
# Switch for Weights & Biases logging: train() only calls wandb.init()/wandb.log()
# and the login cell only calls wandb.login() when this is True.
wandb_log = True

In [5]:
def safe_state(path: "str | Path",
               model: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               dataloader,
               run_state: dict) -> None:
    """Save a resumable training checkpoint to ``path`` via ``torch.save``.

    The file holds a single dict with four entries:

    * ``'model_state'``      -- ``model.state_dict()``
    * ``'optimizer_state'``  -- ``optimizer.state_dict()``
    * ``'dataloader_state'`` -- ``dataloader.get_state()``
    * ``'run_state'``        -- ``run_state`` exactly as passed in

    Args:
        path: Destination file; a ``str`` or a ``pathlib.Path`` both work.
        model: Module whose parameters are stored.
        optimizer: Optimizer whose state is stored.
        dataloader: Any object exposing ``get_state()``.
        run_state: Bookkeeping dict (step counters, best metrics, ...).
    """
    save_state = {
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'dataloader_state': dataloader.get_state(),
        'run_state': run_state
    }
    torch.save(save_state, path)

def train(
        run_state: dict,
        model: torch.nn.Module, 
        dataloader, 
        optimizer: torch.optim.Optimizer, 
        loss_fn: torch.nn.Module, 
        checkpoint_path: Path, 
        max_steps: int=300000, 
        patience: int=75000) -> dict:
    """Train ``model`` on batches from ``dataloader`` until a stop condition hits.

    The same batch is trained on repeatedly; a new batch is only drawn once
    ``accuracy_macro`` on the current batch reaches 0.5. Best-score
    bookkeeping and the ``first_encounter`` / ``dataloader`` logs happen only
    on the first step of a new batch.

    Training stops when ``run_state['step']`` reaches ``max_steps`` or when any
    of the three best-* steps is at least ``patience`` steps old.

    Checkpoints are written into ``checkpoint_path`` through ``safe_state``:
    ``initial.ckpt``, ``best_loss.chkpt``, ``best_acc_micro.chkpt``,
    ``best_acc_macro.chkpt``, ``stop.chkpt`` and, on any exception (including
    ``KeyboardInterrupt``), ``error.chkpt``.

    Reads the module-level ``wandb_log`` flag to decide whether to log to
    Weights & Biases.

    Args:
        run_state: Bookkeeping dict, updated in place. Missing keys are
            filled with defaults; existing keys (e.g. from a checkpoint) win.
        model: Network to optimise; called as ``model(batch[0])``.
        dataloader: Iterator yielding ``(input, target)`` pairs and exposing
            ``step``, ``sequence_ages`` and ``batch_diffs_per_cell``.
        optimizer: Optimizer stepping ``model``'s parameters.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.
        checkpoint_path: Existing directory for the checkpoint files.
        max_steps: Hard upper bound on ``run_state['step']``.
        patience: Maximum number of steps without an improvement.

    Returns:
        The same ``run_state`` dict. ``run_state['loss']`` is a detached
        tensor, so it carries no autograd graph.

    Raises:
        BaseException: Whatever the loop raised, re-raised unchanged after
            ``error.chkpt`` has been written.
    """
    # Fill in defaults without overriding values restored from a checkpoint.
    defaults = {
        'step': 0,
        'best_loss_step': 0,
        'best_accuracy_micro_step': 0,
        'best_accuracy_macro_step': 0,
        'loss': sys.float_info.max,
        'accuracy_micro': 0,
        'accuracy_macro': 0,
        'best_loss': sys.float_info.max,
        'best_accuracy_micro': 0,
        'best_accuracy_macro': 0
    }
    for key, default in defaults.items():
        run_state.setdefault(key, default)
    do_train = True

    # NOTE(review): this file uses '.ckpt' while every other checkpoint uses
    # '.chkpt'; left unchanged so existing tooling keeps finding it.
    safe_state(checkpoint_path/'initial.ckpt', model, optimizer, dataloader, run_state)
    
    if wandb_log:
        # NOTE(review): the wandb docs describe `resume_from` as a
        # '{run_id}?_step={step}' string that cannot be combined with
        # `resume`, and the positional argument order of wandb.init differs
        # between releases -- confirm against the installed wandb version.
        wandb.init('llottenbach', 'cgol', resume='allow', resume_from=run_state['step'], id='lh6vi61e')

    model.train()
    try:
        batch = next(dataloader)
        new_batch = True
        
        while do_train:
            logs = {}
            optimizer.zero_grad()
            # Call the module itself (not .forward) so registered hooks run.
            output = model(batch[0])

            # metrics
            loss = loss_fn(output, batch[1])
            # Store a detached copy: keeps the autograd graph out of
            # run_state, the checkpoints and the wandb logs.
            run_state['loss'] = loss.detach()
            # A cell counts as correct when thresholding at 0.5 matches the target.
            correct = (output >= 0.5) == batch[1]
            run_state['accuracy_micro'] = (correct.sum()
                / (output.shape[0] * output.shape[1] * output.shape[2]))
            # Two single-dim reductions equal all((-1, -2)) and also work on
            # torch versions whose all() takes only one dim.
            run_state['accuracy_macro'] = (correct.all(-1).all(-1).sum()
                / (output.shape[0]))
            
            # optimize
            loss.backward()
            optimizer.step()
            
            if new_batch:
                # update highscore
                # NOTE(review): these checkpoints are written after
                # optimizer.step(), so the saved weights are one update past
                # the ones that produced the recorded metric.
                if run_state['best_loss'] > run_state['loss']:
                    run_state['best_loss'] = run_state['loss'].clone()
                    run_state['best_loss_step'] = run_state['step']
                    safe_state(checkpoint_path/'best_loss.chkpt', model, optimizer, dataloader, run_state)
                if run_state['best_accuracy_micro'] < run_state['accuracy_micro']:
                    run_state['best_accuracy_micro'] = run_state['accuracy_micro'].detach().clone()
                    run_state['best_accuracy_micro_step'] = run_state['step']
                    safe_state(checkpoint_path/'best_acc_micro.chkpt', model, optimizer, dataloader, run_state)
                if run_state['best_accuracy_macro'] < run_state['accuracy_macro']:
                    run_state['best_accuracy_macro'] = run_state['accuracy_macro'].detach().clone()
                    run_state['best_accuracy_macro_step'] = run_state['step']
                    safe_state(checkpoint_path/'best_acc_macro.chkpt', model, optimizer, dataloader, run_state)

                # log: metrics on the first step of a batch, plus dataloader statistics
                logs.update({
                    'train/first_encounter/loss': run_state['loss'],
                    'train/first_encounter/accuracy_micro': run_state['accuracy_micro'],
                    'train/first_encounter/accuracy_macro': run_state['accuracy_macro'],
                    'dataloader/step': dataloader.step,
                    'dataloader/batch_age/min': dataloader.sequence_ages.min(),
                    'dataloader/batch_age/max': dataloader.sequence_ages.max(),
                    'dataloader/batch_age/median': dataloader.sequence_ages.float().median(),
                    'dataloader/batch_age/mean': dataloader.sequence_ages.float().mean(),
                    'dataloader/batch_age/std': dataloader.sequence_ages.float().std(),
                    'dataloader/batch_diffs_per_cell/min': dataloader.batch_diffs_per_cell.min(),
                    'dataloader/batch_diffs_per_cell/max': dataloader.batch_diffs_per_cell.max(),
                    'dataloader/batch_diffs_per_cell/median': dataloader.batch_diffs_per_cell.median(),
                    'dataloader/batch_diffs_per_cell/mean': dataloader.batch_diffs_per_cell.mean(),
                    'dataloader/batch_diffs_per_cell/std': dataloader.batch_diffs_per_cell.std(),
                })
            
            # new batch condition: move on once at least half of the samples
            # in the current batch are predicted entirely correctly
            new_batch = False
            if run_state['accuracy_macro'] >= .5:
                batch = next(dataloader)
                new_batch = True

            # stop condition: step budget exhausted, or any best-* score is
            # at least `patience` steps old
            if (max_steps <= run_state['step']
                or run_state['best_loss_step'] + patience <= run_state['step']
                or run_state['best_accuracy_micro_step'] + patience <= run_state['step']
                or run_state['best_accuracy_macro_step'] + patience <= run_state['step']):
                do_train = False

            # TODO (original author's marker): SPLIT UP
            # log
            print(run_state)
            print(f'step: {run_state["step"]:7d},',
                f'loss: {run_state["loss"]:1.16f},',
                f'accuracy-micro: {run_state["accuracy_micro"]:1.16f},',
                f'accuracy-macro: {run_state["accuracy_macro"]:1.16f}')
            logs.update({
                'train/loss': run_state['loss'],
                'train/accuracy-micro': run_state['accuracy_micro'],
                'train/accuracy-macro': run_state['accuracy_macro']
            })
            
            if wandb_log:
                wandb.log(logs, run_state['step'])

            # The step counter is not advanced on the final iteration, so the
            # stop checkpoint records the last step that was trained.
            if do_train:
                run_state['step'] += 1
            else:
                safe_state(checkpoint_path/'stop.chkpt', model, optimizer, dataloader, run_state)
    except BaseException:
        # Deliberately broad: also preserves progress on KeyboardInterrupt.
        safe_state(checkpoint_path/'error.chkpt', model, optimizer, dataloader, run_state)
        raise

    return run_state

In [6]:
# Authenticate with Weights & Biases only when logging is enabled.
if wandb_log:
    wandb.login()

# Author's notes -- presumably ideas for follow-up experiments:
# - bigger model
# - slight weight decay
# - new loss function

[34m[1mwandb[0m: Currently logged in as: [33mlukas-lottenbach[0m ([33mllottenbach-nlp[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
# Directory that receives every checkpoint written during this resumed run.
run_checkpoint_path = Path('checkpoints_train_grok3_resume')
run_checkpoint_path.mkdir(exist_ok=True)
# Continue from the counters and best scores stored in the loaded checkpoint.
# train() updates this dict in place, so `run_state` stays current even when
# the call is interrupted before it can return.
run_state = checkpoint['run_state']
run_state = train(run_state, model, dataloader, optimizer, loss_fn, run_checkpoint_path)

{'step': 9911, 'best_loss_step': 0, 'best_accuracy_micro_step': 9911, 'best_accuracy_macro_step': 0, 'loss': tensor(0.3163, device='cuda:0', grad_fn=<MseLossBackward0>), 'accuracy_micro': tensor(0.6681, device='cuda:0'), 'accuracy_macro': tensor(0., device='cuda:0'), 'best_loss': tensor(0.2501, device='cuda:0'), 'best_accuracy_micro': tensor(0.6681, device='cuda:0'), 'best_accuracy_macro': 0}
step:    9911, loss: 0.3162914216518402, accuracy-micro: 0.6680999994277954, accuracy-macro: 0.0000000000000000
{'step': 9912, 'best_loss_step': 0, 'best_accuracy_micro_step': 9911, 'best_accuracy_macro_step': 0, 'loss': tensor(0.3160, device='cuda:0', grad_fn=<MseLossBackward0>), 'accuracy_micro': tensor(0.6678, device='cuda:0'), 'accuracy_macro': tensor(0., device='cuda:0'), 'best_loss': tensor(0.2501, device='cuda:0'), 'best_accuracy_micro': tensor(0.6681, device='cuda:0'), 'best_accuracy_macro': 0}
step:    9912, loss: 0.3160276710987091, accuracy-micro: 0.6678000092506409, accuracy-macro: 0.0

KeyboardInterrupt: 

In [9]:
# Persist model, optimizer, dataloader and run state as they are now
# (e.g. after interrupting train() manually) under 'final.chkpt'.
safe_state(run_checkpoint_path/'final.chkpt', model, optimizer, dataloader, run_state)

In [None]:
# Inspect the bookkeeping dict (step counters, latest and best metrics).
print(run_state)

{'step': 9911, 'best_loss_step': 0, 'best_accuracy_micro_step': 0, 'best_accuracy_macro_step': 0, 'loss': tensor(0.0030, device='cuda:0', grad_fn=<MseLossBackward0>), 'accuracy_micro': tensor(0.9968, device='cuda:0'), 'accuracy_macro': tensor(0.6100, device='cuda:0'), 'best_loss': tensor(0.2501, device='cuda:0'), 'best_accuracy_micro': tensor(0.5182, device='cuda:0'), 'best_accuracy_macro': 0}


In [None]:
# Extra backup holding only the model weights (no optimizer, dataloader or run state).
torch.save(model.state_dict(), 'train_grok3_final_backup.chkpt')

In [8]:
# Finish the W&B run that train() started.
# NOTE(review): called unconditionally, unlike wandb.login()/wandb.init(),
# which are guarded by `wandb_log`.
wandb.finish()

0,1
dataloader/batch_age/max,▁
dataloader/batch_age/mean,▁
dataloader/batch_age/median,▁
dataloader/batch_age/min,▁
dataloader/batch_age/std,▁
dataloader/batch_diffs_per_cell/max,▁
dataloader/batch_diffs_per_cell/mean,▁
dataloader/batch_diffs_per_cell/median,▁
dataloader/batch_diffs_per_cell/min,▁
dataloader/batch_diffs_per_cell/std,▁

0,1
dataloader/batch_age/max,2.0
dataloader/batch_age/mean,1.66
dataloader/batch_age/median,2.0
dataloader/batch_age/min,1.0
dataloader/batch_age/std,0.4761
dataloader/batch_diffs_per_cell/max,0.9675
dataloader/batch_diffs_per_cell/mean,0.30602
dataloader/batch_diffs_per_cell/median,0.26
dataloader/batch_diffs_per_cell/min,0.1
dataloader/batch_diffs_per_cell/std,0.20512
