In [None]:
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import pickle as pkl
import wandb
from pathlib import Path
sys.path.append('../..')

from solutions.convolution.conv1_model import Conv1Model
from cgol.generator.uniform_density_generator import UniformDensityGenerator
from cgol.simulator.minimal_architecture_simulator import MinimalArchitectureSimulator
from cgol.dataloader.dataloader_1 import Dataloader1

In [None]:
# --- Experiment settings ---------------------------------------------------
seed = 0
width, height = 20, 20      # both are passed positionally to Dataloader1 below
batch_size = 10000
dtype = torch.float
preprocess_device = 'cpu'   # device handed to the generator and simulator
model_device = 'cuda'       # device handed to the model

# Adam hyperparameters; kept as individual globals so other cells can read them.
lr = 0.001
betas = (0.9, 0.999)
eps = 1e-8
weight_decay = 0

# --- Data pipeline ---------------------------------------------------------
simulator = MinimalArchitectureSimulator(device=preprocess_device, dtype=dtype)
generator = UniformDensityGenerator(seed, preprocess_device, dtype)
# NOTE(review): the trailing 0.1 is an unnamed positional Dataloader1 argument;
# its meaning is not visible from this notebook — confirm against Dataloader1.
dataloader = Dataloader1(
    generator, simulator, batch_size, width, height,
    preprocess_device, model_device, dtype, 0.1,
)

# --- Model, optimizer, loss ------------------------------------------------
model = Conv1Model(device=model_device, dtype=dtype)
model.initialize()

# One dict feeds both the optimizer and the logged config so they cannot drift.
adam_settings = {
    "lr": lr,
    "betas": betas,
    "eps": eps,
    "weight_decay": weight_decay,
}
optimizer = torch.optim.Adam(model.parameters(), **adam_settings)
loss_fn = torch.nn.MSELoss()

# Run description; train() forwards it to wandb.init when logging is enabled.
config = {
    "dataloader": dataloader.get_config(),
    "model": model.get_config(),
    "optimizer": {"type": "Adam", **adam_settings},
    "loss": {"type": "MSELoss"},
}

config

In [None]:
# Switch for Weights & Biases logging; read as a global by train() and by the
# wandb login cell further down.
wandb_log = True

In [None]:
def safe_state(path: 'str | Path',
               model: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               dataloader,
               run_state: dict):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict with four entries:

    * ``'model_state'``      -- ``model.state_dict()``
    * ``'optimizer_state'``  -- ``optimizer.state_dict()``
    * ``'dataloader_state'`` -- ``dataloader.get_state()``
    * ``'run_state'``        -- the ``run_state`` dict as passed in

    Args:
        path: Destination of the checkpoint. A ``str`` or path-like object is
            written atomically (temp file + rename); anything else (e.g. an
            open binary file object) is handed to ``torch.save`` unchanged.
        model: Module whose parameters/buffers are stored.
        optimizer: Optimizer whose state is stored.
        dataloader: Object exposing ``get_state()``.
        run_state: Bookkeeping dict stored alongside the states.

    Raises:
        Whatever ``torch.save`` / ``os.replace`` raise; a partially written
        temp file is removed before the exception propagates.
    """
    import os  # local import: the notebook's import cell does not provide `os`

    save_state = {
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'dataloader_state': dataloader.get_state(),
        'run_state': run_state
    }

    if not isinstance(path, (str, os.PathLike)):
        # File-like destination: there is no filename to swap atomically.
        torch.save(save_state, path)
        return

    # Write next to the destination, then rename over it. An interruption in
    # the middle of the write therefore never truncates an existing checkpoint.
    target = os.fspath(path)
    tmp_target = target + '.tmp'
    try:
        torch.save(save_state, tmp_target)
        os.replace(tmp_target, target)
    except BaseException:
        if os.path.exists(tmp_target):
            os.remove(tmp_target)
        raise

def train(
        model: torch.nn.Module,
        dataloader,
        optimizer: torch.optim.Optimizer,
        loss_fn: torch.nn.Module,
        checkpoint_path: Path,
        max_steps: int = 300000,
        patience: int = 75000,
        run_state: dict = None) -> dict:
    """Run the training loop, checkpointing on every new best metric.

    Besides its arguments, this function reads the notebook globals
    ``wandb_log`` (enable Weights & Biases logging) and ``config`` (sent to
    ``wandb.init``), and writes checkpoints through ``safe_state``.

    Args:
        model: Network to optimise; called as ``model(batch[0])``.
        dataloader: Iterable yielding indexable batches where ``batch[0]`` is
            the model input and ``batch[1]`` the target. It must provide
            ``get_state()`` for checkpoints and, when logging is enabled,
            ``batch_age`` and ``batch_diffs_per_cell`` tensors.
        optimizer: Optimizer stepping ``model``'s parameters.
        loss_fn: Callable ``loss_fn(output, target)`` returning a tensor that
            supports ``backward()``.
        checkpoint_path: Existing directory that receives the checkpoint files.
        max_steps: The loop stops once the step counter has reached this value.
        patience: The loop stops once *any* of loss, micro accuracy or macro
            accuracy has not improved for this many steps.
        run_state: Optional partial state (e.g. from a checkpoint); its
            entries override the fresh defaults.

    Returns:
        The final ``run_state`` dict: step counters plus current and best
        loss / accuracies. Metric entries are tensors detached from autograd.

    Raises:
        Any exception raised inside the loop, re-raised after ``error.chkpt``
        has been written.
    """
    # `None` default instead of a shared mutable `{}`; caller entries win.
    run_state = {
        'step': 0,
        'best_loss_step': 0,
        'best_accuracy_micro_step': 0,
        'best_accuracy_macro_step': 0,
        'loss': sys.float_info.max,
        'accuracy_micro': 0,
        'accuracy_macro': 0,
        'best_loss': sys.float_info.max,
        'best_accuracy_micro': 0,
        'best_accuracy_macro': 0
    } | (run_state or {})

    safe_state(checkpoint_path/'initial.ckpt', model, optimizer, dataloader, run_state)

    if wandb_log:
        # Keywords rather than positions: the positional order of wandb.init's
        # parameters is not the same in every wandb release.
        wandb.init(entity='llottenbach', project='cgol', config={
            "config": config
        })

    model.train()
    try:
        for batch in dataloader:
            inputs, targets = batch[0], batch[1]
            optimizer.zero_grad()
            # Call the module itself (not .forward) so registered hooks run.
            output = model(inputs)

            # metrics
            loss = loss_fn(output, targets)
            # Element-wise correctness after thresholding at 0.5, computed once.
            correct = (output >= 0.5) == targets
            # Only the detached value is kept, so run_state and the pickled
            # checkpoints do not carry the autograd history of the loss.
            run_state['loss'] = loss.detach()
            # Fraction of correctly predicted elements.
            run_state['accuracy_micro'] = correct.sum() / output.numel()
            # Fraction of samples whose last two dimensions are entirely correct.
            run_state['accuracy_macro'] = (correct.all((-1, -2)).sum()
                / output.shape[0])

            # Checkpoints are taken before optimizer.step(), i.e. they hold the
            # weights that actually produced the recorded metric.
            if run_state['best_loss'] > run_state['loss']:
                run_state['best_loss'] = run_state['loss']
                run_state['best_loss_step'] = run_state['step']
                safe_state(checkpoint_path/'best_loss.chkpt', model, optimizer, dataloader, run_state)
            if run_state['best_accuracy_micro'] < run_state['accuracy_micro']:
                run_state['best_accuracy_micro'] = run_state['accuracy_micro']
                run_state['best_accuracy_micro_step'] = run_state['step']
                safe_state(checkpoint_path/'best_acc_micro.chkpt', model, optimizer, dataloader, run_state)
            if run_state['best_accuracy_macro'] < run_state['accuracy_macro']:
                run_state['best_accuracy_macro'] = run_state['accuracy_macro']
                run_state['best_accuracy_macro_step'] = run_state['step']
                safe_state(checkpoint_path/'best_acc_macro.chkpt', model, optimizer, dataloader, run_state)

            # optimize
            loss.backward()
            optimizer.step()

            # log
            print(f'step: {run_state["step"]:7d},',
                f'loss: {run_state["loss"]:1.16f},',
                f'accuracy-micro: {run_state["accuracy_micro"]:1.16f},',
                f'accuracy-macro: {run_state["accuracy_macro"]:1.16f}')
            if wandb_log:
                wandb.log({
                    'train/loss': run_state['loss'],
                    'train/accuracy-micro': run_state['accuracy_micro'],
                    'train/accuracy-macro': run_state['accuracy_macro'],
                    # dataloader
                    'dataloader/batch_age/min': dataloader.batch_age.min(),
                    'dataloader/batch_age/max': dataloader.batch_age.max(),
                    'dataloader/batch_age/median': dataloader.batch_age.float().median(),
                    'dataloader/batch_age/mean': dataloader.batch_age.float().mean(),
                    'dataloader/batch_age/std': dataloader.batch_age.float().std(),
                    'dataloader/batch_diffs_per_cell/min': dataloader.batch_diffs_per_cell.min(),
                    'dataloader/batch_diffs_per_cell/max': dataloader.batch_diffs_per_cell.max(),
                    'dataloader/batch_diffs_per_cell/median': dataloader.batch_diffs_per_cell.median(),
                    'dataloader/batch_diffs_per_cell/mean': dataloader.batch_diffs_per_cell.mean(),
                    'dataloader/batch_diffs_per_cell/std': dataloader.batch_diffs_per_cell.std(),
                },
                step=run_state['step'])

            # stop condition
            # NOTE(review): the clauses are OR-ed, so one stalled metric ends
            # the run even while the others still improve — confirm intended.
            if (max_steps <= run_state['step']
                or run_state['best_loss_step'] + patience <= run_state['step']
                or run_state['best_accuracy_micro_step'] + patience <= run_state['step']
                or run_state['best_accuracy_macro_step'] + patience <= run_state['step']):

                safe_state(checkpoint_path/'stop.chkpt', model, optimizer, dataloader, run_state)
                break

            run_state['step'] += 1
    except Exception:
        # Keep whatever progress exists, then re-raise with the original traceback.
        safe_state(checkpoint_path/'error.chkpt', model, optimizer, dataloader, run_state)
        raise

    return run_state

In [None]:
if wandb_log:
    # Weights & Biases entity and project names, kept as notebook globals.
    wandb_ent, wandb_proj = 'llottenbach', 'cgol'
    # Log in before the training cell starts a run.
    wandb.login()

In [None]:
# Directory that receives every checkpoint written during this run.
run_checkpoint_path = Path('checkpoints')
run_checkpoint_path.mkdir(exist_ok=True)

# Start training with the default step limit and patience.
run_state = train(
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    checkpoint_path=run_checkpoint_path,
)