In [None]:
# IPython magic: sets the TOKENIZERS_PARALLELISM environment variable to "false"
# for this kernel.
# NOTE(review): presumably this silences the Hugging Face tokenizers fork warning
# when DataLoader / Pool workers are created later on — confirm.
%env TOKENIZERS_PARALLELISM=false

In [None]:
from freegroup.tools import from_string, flatten, normalize

# Free group dimension used throughout the notebook (passed as `freegroup_dimension`).
fdim = 4
# Length bound shared by the dataset filters (`total_max_length`) and generation (`max_length`).
max_length = 400

# Bracket expressions used to build the `non-trivial` evaluation set, keyed by dimension.
representatives = dict([
    (3, '[[x, y], [x, yz]]'),
    (4, '[[[x, y], [x, yz]], [[x, y], [x, yzp]]]'),
])

# Bit masks over fdim + 1 positions with exactly one bit cleared: the idx-th
# entry clears the idx-th most significant bit of the all-ones mask.
singles = [
    ((1 << (fdim + 1)) - 1) ^ (1 << (fdim - idx))
    for idx in range(fdim + 1)
]

# Number of optimisation steps and the interval (in steps) between scheduler updates.
total_steps = 500_000
sched_steps = 200

# Experiment configuration. The whole dict is handed to `wandb.init(config = ...)`
# and its sections are consumed by the dataset, collator, model and training cells below.
config = {
    'freegroup_dimension': fdim,
    # Passed to `transformers.set_seed`.
    'seed': 42,
    # File uploaded as a W&B `notebook` artifact at the start of the run.
    'notebook_name': 'train.ipynb',
    # Target of `model.to(...)`.
    'device': 'cuda:1',
    # W&B project name.
    'project': 'homotopy-groups',
    
    # Conditioning strategy forwarded to `data_collator`: 'ignore', 'prompt' or 'mask'.
    'method': 'mask',
    
    # Keyword arguments for `AutoTokenizer.from_pretrained`.
    'tokenizer': {
        'pretrained_model_name_or_path': f'tokenizer/word-level-tokenizer-{fdim}',
    },
    
    # Keyword arguments for `AutoConfig.for_model`.
    # With method 'mask' the collator tiles the fdim + 1 labels
    # n_head // (fdim + 1) times, so n_head should be a multiple of fdim + 1.
    'model': {
        'model_type': 'gpt2',
        'n_positions': 1024,
        'n_embd': 10 * 12,
        'n_layer': 12,
        'n_head': 10,
    },
    
    'train': {
        'batch_size': 16,
        
        # Total optimisation steps; train loss is reported every `log_steps`
        # and a checkpoint artifact is written every `save_steps`.
        'steps': total_steps,
        'log_steps': 400,
        'save_steps': 25000,
        
        # `args` are forwarded to `torch.optim.AdamW`.
        'optimizer': {
            'name': 'AdamW',
            'args': {
                'lr': 1e-5,
            },
        },
        
        # `args` are forwarded to `LinearLR`; `scheduler.step()` is called once
        # every `steps` training steps, hence total_iters = total_steps // sched_steps.
        'scheduler': {
            'name': 'Linear',
            'args': {
                'start_factor': 1.,
                'end_factor': 0.5,
                'total_iters': total_steps // sched_steps,
            },
            'steps': sched_steps,
        },
        
        # Resolved by `construct_dataset`: `dist` picks the sampling function and
        # its keyword arguments, `type` picks how samples are stored.
        'dataset': {
            'dist': {
                'name': 'incomplete-intersection',
                'args': {
                    'freegroup_dimension': fdim,
                    'zero_closure_parameters': {
                        'method': 'brackets',
                        'depth_method': 'u',
                        'depth_parameters': {'radius': 10},
                    },
                    'non_zero_closure_parameters': {
                        'method': 'brackets',
                        'depth_method': 'u',
                        'depth_parameters': {'radius': 30},
                    },
                    'max_freegroups': 1,
                    'freegroup_parameters': {
                        'length_method': 'u',
                        'length_parameters': {'radius': 10},
                    },
                    'total_max_length': max_length,
                    # Equal weight for every mask in `singles`; weights are
                    # normalised inside `incomplete_intersection_dist`.
                    'probas': {k: 1 for k in singles},
                },
            },
            'type': {
                # `RefillableDataset(size, max_calls)`.
                'name': 'refillable',
                'args': {
                    'size': 64 * 100,
                    'max_calls': 3,
                },
            },
        },
    },
    
    'eval': {
        # Evaluation losses are computed every `steps` training steps.
        'steps': 2000,
        'batch_size': 16,
        # One loss (`eval/<key>_loss`) is reported per entry.
        'dataset': {
            'validation': {
                'dist': {
                    'name': 'incomplete-intersection',
                    'args': {
                        'freegroup_dimension': fdim,
                        'zero_closure_parameters': {
                            'method': 'brackets',
                            'depth_method': 'u',
                            'depth_parameters': {'radius': 10},
                        },
                        'non_zero_closure_parameters': {
                            'method': 'brackets',
                            'depth_method': 'u',
                            'depth_parameters': {'radius': 30},
                        },
                        'max_freegroups': 1,
                        'freegroup_parameters': {
                            'length_method': 'u',
                            'length_parameters': {'radius': 10},
                        },
                        'total_max_length': max_length,
                        'probas': {k: 1 for k in singles},
                    },
                },
                'type': {
                    # Sampled once up front via `sample_dataset`.
                    'name': 'fixed',
                    'args': {'size': 64 * 10},
                },
            },
            'trivial': {
                'dist': {
                    'name': 'complete-intersection',
                    'args': {
                        'freegroup_dimension': fdim,
                        'zero_closure_parameters': {
                            'method': 'brackets',
                            'depth_method': 'u',
                            'depth_parameters': {'radius': 8},
                        },
                        'non_zero_closure_parameters': {
                            'method': 'brackets',
                            'depth_method': 'u',
                            'depth_parameters': {'radius': 15},
                        },
                        'max_freegroups': 1,
                        'freegroup_parameters': {
                            'length_method': 'u',
                            'length_parameters': {'radius': 10},
                        },
                        'total_max_length': max_length,
                        'max_multipliers': 1,
                    },
                },
                'type': {
                    'name': 'fixed',
                    'args': {'size': 64 * 10},
                },
            },
            'non-trivial': {
                'dist': {
                    'name': 'generator-permutations',
                    'args': {
                        'freegroup_dimension': fdim, 
                        # The representative for this dimension, parsed, flattened and normalised.
                        'word': normalize(flatten(from_string(representatives[fdim], method = 'lu')))
                    },
                },
                'type': {
                    # `generator_permutations_dist` yields exactly `fdim` entries.
                    'name': 'fixed',
                    'args': {'size': fdim},
                },
            },
        },
    },
    
    'gen': {
        # Generation metrics are computed every `steps` training steps.
        'steps': 2500,
        'batch_size': 16,
        # Prefix datasets: the `freegroup` dist is wrapped by `wrapper_dist`
        # around `freegroup_generator` with these arguments.
        'dataset': {
            'prefix_5': {
                'dist': {
                    'name': 'freegroup',
                    'args': {
                        'freegroup_dimension': fdim,
                        'length_method': 'c',
                        'length_parameters': {'radius': 5},
                    },
                },
                'type': {
                    'name': 'refillable',
                    'args': {
                        'size': 64 * 5,
                        'max_calls': 1,
                    },
                },
            },
            'prefix_7': {
                'dist': {
                    'name': 'freegroup',
                    'args': {
                        'freegroup_dimension': fdim,
                         'length_method': 'c',
                        'length_parameters': {'radius': 7},
                    },
                },
                'type': {
                    'name': 'refillable',
                    'args': {
                        'size': 64 * 5,
                        'max_calls': 1,
                    },
                },
            },
            'prefix_10': {
                'dist': {
                    'name': 'freegroup',
                    'args': {
                        'freegroup_dimension': fdim,
                        'length_method': 'c',
                        'length_parameters': {'radius': 10},
                    },
                },
                'type': {
                    'name': 'refillable',
                    'args': {
                        'size': 64 * 5,
                        'max_calls': 1,
                    },
                },
            },
        },
        # Keyword arguments for `model.generate`, one entry per decoding strategy.
        # `top_p` and `repetition_penalty` are popped by the training cell and
        # turned into explicit logits processors.
        'methods': {
            'beam': {
                'num_beams': 5,
                'num_return_sequences': 5,
                'max_length': max_length,
                'repetition_penalty': 1.2,
            },
            'sample': {
                'do_sample': True,
                'num_return_sequences': 5,
                'max_length': max_length,
                'top_p': 0.9,
            },
        },
    },
}

# DATASET

In [109]:
from freegroup.tools import (
    is_from_singleton_normal_closure, wu_closure,
    flatten, normalize, Mult
)
from freegroup.sampling import (
    normal_closure, freegroup,
    random_tree
)
from iteration_utilities import repeatfunc, unique_everseen
from itertools import islice
from tqdm.notebook import tqdm
from copy import deepcopy
from numpy.random import choice, shuffle, randint, geometric, binomial

from contextlib import contextmanager
from multiprocess import Pool
from torch.utils.data import Dataset
from copy import deepcopy

def compute_multi_label(word, freegroup_dimension):
    """Return one membership flag per closure index 0..freegroup_dimension.

    Entry `idx` is the result of `is_from_singleton_normal_closure` for `word`
    and the closure produced by `wu_closure(freegroup_dimension, idx)`.
    """
    labels = []
    for idx in range(freegroup_dimension + 1):
        closure = wu_closure(freegroup_dimension, idx)
        labels.append(is_from_singleton_normal_closure(word, closure))
    return labels


def remove_gens_in_a_row(word, gen, mode = 'crop', mode_params = {'max': 2}):
    """Shorten every maximal run of `gen` / `-gen` letters in `word`.

    A run is a maximal stretch of consecutive letters equal to `gen` or `-gen`.
    Each run of length `cur` is replaced by copies of its *last* letter:

    * ``'crop'``: ``min(cur, mode_params['max'])`` copies;
    * ``'geom'`` / ``'geometric'``: ``min(cur, k)`` copies, with ``k`` drawn from
      ``numpy.random.geometric(**mode_params)``;
    * ``'binomial'`` / ``'bin'``: ``max(1, k)`` copies, with ``k`` drawn from
      ``numpy.random.binomial(n = cur, **mode_params)``.

    Any other `mode` drops the runs entirely (unchanged from the original).

    `mode_params` is only read, never mutated. Returns the result passed
    through `normalize`.
    """
    res = []

    idx = 0
    while idx < len(word):
        # Measure the run starting at `idx`; `cur_gen` ends up as its last letter.
        cur, cur_gen = 0, None
        while idx < len(word) and word[idx] in [gen, -gen]:
            cur_gen = word[idx]
            cur, idx = cur + 1, idx + 1

        if cur > 0:
            if mode == 'crop':
                res.extend([cur_gen] * min(cur, mode_params['max']))
            elif mode in ['geom', 'geometric']:
                # `geometric` / `binomial` are imported from numpy.random at file
                # level; the former `random.` prefix referred to an undefined name.
                res.extend([cur_gen] * min(cur, geometric(**mode_params, size = None)))
            elif mode in ['binomial', 'bin']:
                res.extend([cur_gen] * max(1, binomial(**mode_params, n = cur, size = None)))
        # The letter that terminated the run (if any) is copied through unchanged.
        if idx < len(word): res.append(word[idx])
        idx += 1

    return normalize(res)

def sample_dataset(
    dist, size, dist_kwargs = {},
    unique = True, unique_key = None,
    progress = True, tqdm_kwargs = {}
):
    """Draw up to `size` entries from the iterator returned by `dist(**dist_kwargs)`.

    Args:
        dist: callable returning an iterable of entries.
        size: maximum number of entries to collect.
        dist_kwargs: keyword arguments for `dist` (read only).
        unique: when true, duplicates are skipped with `unique_everseen`.
        unique_key: optional function mapping an entry to its deduplication key;
            defaults to ``tuple(entry['word'])``.
        progress: when true, wrap the iteration in a tqdm progress bar.
        tqdm_kwargs: extra keyword arguments for tqdm (read only).

    Returns:
        A list with the collected entries.
    """
    iterator = dist(**dist_kwargs)

    if unique:
        # The previous lambda returned the `unique_key` function object itself
        # instead of calling it, which made every entry share a single key.
        key = unique_key if unique_key is not None else (lambda x: tuple(x['word']))
        iterator = unique_everseen(iterator, key = key)

    iterator = islice(iterator, size)

    if progress: iterator = tqdm(iterator, total = size, **tqdm_kwargs)

    return list(iterator)

class RefillableDataset(Dataset):
    """Dataset whose contents are periodically replaced by freshly sampled data.

    A single-worker pool runs `sample_dataset(dist, dist_kwargs, size)` in the
    background. The current batch is served until `max_calls * len(batch)`
    items have been requested, then the pre-computed next batch takes over and
    another one is scheduled. Call `stop()` to terminate the pool.
    """

    def __init__(self, dist, dist_kwargs, size, max_calls):
        self.size = size
        self.calls = 0
        self.max_calls = max_calls

        self.pool = Pool(1)

        def _schedule_batch():
            # Submits one asynchronous sampling job and returns its AsyncResult.
            return self.pool.apply_async(
                sample_dataset,
                kwds = dict(
                    dist = dist, dist_kwargs = dist_kwargs,
                    size = size, progress = False,
                ),
            )

        self.next_batch_fn = _schedule_batch

        # Nothing is available yet; the first __getitem__ waits for this job.
        self.curr_batch = None
        self.next_batch = self.next_batch_fn()

    def __len__(self):
        return self.size * self.max_calls

    def __getitem__(self, idx):
        self.calls += 1

        needs_refill = (
            self.curr_batch is None
            or self.calls >= self.max_calls * len(self.curr_batch)
        )
        if needs_refill:
            # Swap in the batch prepared in the background and queue another one.
            self.curr_batch = self.next_batch.get()
            self.next_batch = self.next_batch_fn()
            self.calls = 0

        return self.curr_batch[idx % self.size]

    def stop(self):
        self.pool.terminate()
        
@contextmanager
def refillable_dataset(dist, dist_kwargs, size, max_calls):
    """Context manager yielding a `RefillableDataset` and stopping it on exit.

    The arguments are forwarded unchanged to `RefillableDataset`. The worker
    pool is terminated even when the body of the `with` statement raises.
    """
    # The original call passed undefined names (dataset_fn, preprocess_fn,
    # dataset_config) and one argument too many.
    dataset = RefillableDataset(dist, dist_kwargs, size, max_calls)
    try:
        yield dataset
    finally:
        dataset.stop()
    
def incomplete_intersection_dist(
    freegroup_dimension = 3,
    zero_closure_parameters = {'method': 'brackets', 'depth_method': 'c', 'depth_parameters': {'radius': 5}},
    non_zero_closure_parameters = {'method': 'brackets', 'depth_method': 'c', 'depth_parameters': {'radius': 30}},
    max_freegroups = 1,
    freegroup_parameters = {'length_method': 'c', 'length_parameters': {'radius': 5}},
    max_multipliers = 1,
    probas = {k: 1 for k in range(1, 2 ** 4 - 1)},
    total_max_length = 200,
):
    """Endless iterator of words that do NOT lie in all closures at once.

    For every sample a bit mask is drawn according to `probas` (mask -> weight,
    normalised here). For each set bit `idx` a word is sampled with
    `normal_closure` from `wu_closure(freegroup_dimension, idx)`, using
    `zero_closure_parameters` for idx 0 and `non_zero_closure_parameters`
    otherwise. Optional extra `freegroup` words are added, the leaves are
    shuffled and combined with `random_tree`; between 1 and `max_multipliers`
    such trees are multiplied, flattened and normalised.

    Yields dicts ``{'word': ..., 'multi_label': ...}``. Entries are dropped when
    the word is empty, has length >= `total_max_length`, or when all
    `freegroup_dimension + 1` labels are truthy.

    The dict arguments are only read; `probas` is copied before being modified.
    """

    probas = deepcopy(probas)
    # Masks range over freegroup_dimension + 1 bits; the all-ones mask is left
    # out because it is the complete intersection. Missing masks get weight 0.
    for k in range(1, 2 ** (freegroup_dimension + 1) - 1):
        if k not in probas: probas[k] = 0.
    _sum = sum(probas.values())

    multi_labels = list(probas.keys())
    probas = [v / _sum for v in probas.values()]

    def _init():
        leaves = []
        multi_label = choice(multi_labels, p = probas)
        for idx in range(freegroup_dimension + 1):
            if multi_label & (1 << idx) > 0:
                leaves.append(normal_closure(
                    freegroup_dimension = freegroup_dimension,
                    closure = wu_closure(freegroup_dimension, idx),
                    **(zero_closure_parameters if idx == 0 else non_zero_closure_parameters),
                ))
        # NOTE(review): numpy's randint excludes `high`, so at most
        # max_freegroups - 1 extra words are added (none for max_freegroups = 1).
        # The multiplier draw below uses `max_multipliers + 1`; confirm whether
        # this bound is intended.
        for _ in range(randint(low = 0, high = max_freegroups)):
            leaves.append(freegroup(
                freegroup_dimension = freegroup_dimension,
                # Was `freegroups_parameters`, an undefined name.
                **freegroup_parameters,
            ))

        shuffle(leaves)
        return random_tree(leaves)

    def init(): return normalize(flatten(Mult([_init() for _ in range(randint(low = 1, high = max_multipliers + 1))])))

    def features(word):
        return {
            'word': word[::],
            'multi_label': compute_multi_label(word, freegroup_dimension)
        }

    def condition(entry):
        if 0 == len(entry['word']) or len(entry['word']) >= total_max_length:
            return False

        # Reject words that ended up in every closure.
        if sum(entry['multi_label']) >= freegroup_dimension + 1:
            return False
        return True

    return filter(condition, map(features, repeatfunc(init)))



def complete_intersection_dist(
    freegroup_dimension = 3,
    zero_closure_parameters = {'method': 'brackets', 'depth_method': 'c', 'depth_parameters': {'radius': 5}},
    non_zero_closure_parameters = {'method': 'brackets', 'depth_method': 'c', 'depth_parameters': {'radius': 30}},
    max_freegroups = 1,
    freegroup_parameters = {'length_method': 'c', 'length_parameters': {'radius': 5}},
    max_multipliers = 1,
    total_max_length = 200,
):
    """Endless iterator of words that lie in all closures at once.

    Every sample combines one `normal_closure` word per closure index
    0..freegroup_dimension (`zero_closure_parameters` for idx 0,
    `non_zero_closure_parameters` otherwise) plus optional `freegroup` words;
    the leaves are shuffled and combined with `random_tree`. Between 1 and
    `max_multipliers` such trees are multiplied, flattened and normalised.

    Yields dicts ``{'word': ..., 'multi_label': ...}``. Entries are dropped when
    the word is empty, has length >= `total_max_length`, or when the labels do
    not sum to `freegroup_dimension + 1`.

    The dict arguments are only read.
    """
    def _init():
        leaves = []
        for idx in range(freegroup_dimension + 1):
            leaves.append(normal_closure(
                freegroup_dimension = freegroup_dimension,
                closure = wu_closure(freegroup_dimension, idx),
                **(zero_closure_parameters if idx == 0 else non_zero_closure_parameters),
            ))
        # NOTE(review): numpy's randint excludes `high`, so at most
        # max_freegroups - 1 extra words are added (none for max_freegroups = 1).
        # The multiplier draw below uses `max_multipliers + 1`; confirm whether
        # this bound is intended.
        for _ in range(randint(low = 0, high = max_freegroups)):
            leaves.append(freegroup(
                freegroup_dimension = freegroup_dimension,
                # Was `freegroups_parameters`, an undefined name.
                **freegroup_parameters,
            ))

        shuffle(leaves)
        return random_tree(leaves)

    def init(): return normalize(flatten(Mult([_init() for _ in range(randint(low = 1, high = max_multipliers + 1))])))

    def features(word):
        return {
            'word': word[::],
            'multi_label': compute_multi_label(word, freegroup_dimension)
        }

    def condition(entry):
        if 0 == len(entry['word']) or len(entry['word']) >= total_max_length:
            return False

        # Keep only words found in every closure.
        if sum(entry['multi_label']) != freegroup_dimension + 1:
            return False
        return True

    return filter(condition, map(features, repeatfunc(init)))


def generator_permutations_dist(freegroup_dimension, word):
    """Yield `word` under each cyclic shift of the generators.

    Letters are non-zero integers: `k` denotes the k-th generator and `-k` its
    inverse. For every shift `s` in 0..freegroup_dimension - 1, generator `k`
    is mapped to ``1 + (k - 1 + s) % freegroup_dimension`` and the sign of the
    letter is kept. Each entry carries an all-ones `multi_label` of length
    `freegroup_dimension + 1`.
    """
    for s in range(freegroup_dimension):
        yield {
            'word': [
                # The sign must be parenthesised: without it the conditional
                # expression swallowed the product and every inverse letter
                # collapsed to -1.
                (-1 if f < 0 else 1) * (1 + (abs(f) - 1 + s) % freegroup_dimension)
                for f in word
            ],
            'multi_label': [1] * (freegroup_dimension + 1),
        }
        
def wrapper_dist(freegroup_dimension, word_generator, word_generator_kwargs):
    """Attach multi-labels to the words produced by `word_generator`.

    `word_generator(**word_generator_kwargs)` is called immediately; the
    returned iterator lazily yields ``{'word': w, 'multi_label': ...}`` with the
    label computed by `compute_multi_label(w, freegroup_dimension)`.
    """
    words = word_generator(**word_generator_kwargs)
    return (
        {'word': w, 'multi_label': compute_multi_label(w, freegroup_dimension)}
        for w in words
    )


In [None]:
from matplotlib import pyplot as plt

# Word-length histogram for 1000 samples of the training distribution.
visualize = sample_dataset(
    incomplete_intersection_dist,
    size = 1000,
    dist_kwargs = config['train']['dataset']['dist']['args'],
)

words = [entry['word'] for entry in visualize]
masks = [entry['multi_label'] for entry in visualize]

plt.hist([len(w) for w in words])
plt.show()

# Same plot for 1000 samples of the `trivial` evaluation distribution.
visualize = sample_dataset(
    complete_intersection_dist,
    size = 1000,
    dist_kwargs = config['eval']['dataset']['trivial']['dist']['args'],
)

words = [entry['word'] for entry in visualize]
masks = [entry['multi_label'] for entry in visualize]

plt.hist([len(w) for w in words])
plt.show()


# TRAIN

In [None]:
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Tokenizer is loaded from the location given in the config.
tokenizer = AutoTokenizer.from_pretrained(**config['tokenizer'])

# The model is created from a configuration only (no pretrained weights are
# loaded); special token ids are taken from the tokenizer.
model_config = AutoConfig.for_model(
    bos_token_id = tokenizer.bos_token_id,
    eos_token_id = tokenizer.eos_token_id,
    pad_token_id = tokenizer.pad_token_id,
    **config['model']
)
model = AutoModelForCausalLM.from_config(model_config)

# Number of trainable parameters (displayed as the cell output).
sum(weight.numel() for weight in filter(lambda w: w.requires_grad, model.parameters()))

In [None]:
from typing import List, Dict, Any
from freegroup.tools import batch_to_string
from torch import tensor

def data_collator(
    model,
    tokenizer,
    freegroup_dimension,
    method = 'ignore', # choice ['ignore', 'prompt', 'mask']
    mode = 'predict', # choice: ['predict', 'generate']
):
    """Build a collate function for a `DataLoader` over dataset entries.

    Args:
        model: only `model.config.n_head` and `model.config.n_layer` are read
            (for `method == 'mask'`).
        tokenizer: must expose exactly three `additional_special_tokens`,
            unpacked as (yes, no, separator).
        freegroup_dimension: multi-labels have `freegroup_dimension + 1` entries.
        method: how the multi-label is used. 'ignore' drops it, 'prompt'
            prepends it to the text as yes/no tokens followed by the separator,
            'mask' turns it into a `head_mask`.
        mode: 'predict' returns the training/evaluation collator; any other
            value returns the generation collator.

    Returns:
        A function mapping a list of ``{'word', 'multi_label'}`` dicts to a dict
        of tensors: `input_ids`, `attention_mask`, `labels` (plus `head_mask`
        for 'mask') in predict mode, or `inputs` in generate mode.
    """
    y, n, sep = tokenizer.additional_special_tokens

    def prompt(multi_label):
        return ' '.join([y if f else n for f in multi_label] + [sep])

    def predict_collate_fn(batch: List[Dict[str, Any]]):
        words = batch_to_string([x['word'] for x in batch])
        multi_labels = [x['multi_label'] for x in batch]

        if method == 'prompt':
            words = [prompt(label) + w for w, label in zip(words, multi_labels)]

        inputs = tokenizer(words, padding = True, return_tensors = 'pt')

        # Inputs drop the last position, labels drop the first one, so that
        # position t is trained to predict token t + 1.
        input_ids = inputs.input_ids.clone()
        input_ids = input_ids[:, :-1]

        attention_mask = inputs.attention_mask.clone()
        attention_mask = attention_mask[:, :-1]

        labels = inputs.input_ids.clone()
        # -100 marks padding positions in the labels.
        labels[labels == tokenizer.pad_token_id] = -100
        labels = labels[:, 1:]

        if method in ['ignore', 'prompt'] :
            return {
                'input_ids': input_ids,
                'labels': labels,
                'attention_mask': attention_mask,
            }

        if method == 'mask':
            head_mask = tensor(multi_labels)
            # must be [num_layers, batch, num_heads, seq_len, seq_len]
            # The labels are tiled n_head // (freegroup_dimension + 1) times
            # along the head axis and shared by all layers; the two trailing
            # axes have size 1.
            head_mask = head_mask.unsqueeze(0)\
                    .unsqueeze(-1).unsqueeze(-1)\
                    .repeat(1, 1, model.config.n_head // (freegroup_dimension + 1), 1, 1)\
                    .expand(model.config.n_layer, -1, -1, -1, -1)\
                    .clone()

            return {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels,
                'head_mask': head_mask,
            }

        # Previously fell through and returned None.
        raise ValueError(f'Unknown method: {method!r}')


    def generate_collate_fn(batch: List[Dict[str, Any]]):
        words = list(map(lambda x: x['word'], batch))
        words = batch_to_string(words)

        if method == 'prompt':
            # Generation is always conditioned on the all-ones label. The old
            # code zipped with `multi_labels`, which is not defined in this scope.
            all_ones_prompt = prompt([1] * (freegroup_dimension + 1))
            words = [all_ones_prompt + w for w in words]

        inputs = tokenizer(words, padding = True, return_tensors = 'pt')

        return {
            'inputs': inputs.input_ids[:, :-1] # exclude `eos` token
        }

    return predict_collate_fn if mode == 'predict' else generate_collate_fn

In [None]:
from typing import Dict
from transformers import LogitsProcessor
from torch import tensor, inf

class NoLastTokenReductionProcessor(LogitsProcessor):
    """Forbid generating the reciprocal of the most recently generated token.

    `reciprocal_tokens` maps a token id to the id of its reciprocal token. For
    every sequence whose last token is a key of that mapping, the score of the
    mapped token is set to -inf.
    """

    def __init__(self, reciprocal_tokens: Dict[int, int]):
        self.reciprocal_tokens = reciprocal_tokens

    @staticmethod
    def from_tokenizer(freegroup_dimension, tokenizer):
        """Build the processor from the tokens '-d'..'-1' and '1'..'d'."""
        from itertools import chain

        reciprocal_tokens = dict()
        for x in chain(range(-freegroup_dimension, 0), range(1, freegroup_dimension + 1)):
            idx, _idx = tokenizer.convert_tokens_to_ids([str(x), str(-x)])
            reciprocal_tokens[idx] = _idx

        return NoLastTokenReductionProcessor(reciprocal_tokens)

    def __call__(self, input_ids, scores):
        batch_ids, last_reciprocal_ids = [], []
        for i, token_idx in enumerate(input_ids[:, -1].tolist()):
            if token_idx in self.reciprocal_tokens:
                batch_ids.append(i)
                last_reciprocal_ids.append(self.reciprocal_tokens[token_idx])

        # An explicit emptiness check replaces the former `except ValueError`,
        # which also hid unrelated errors raised while writing the scores.
        if batch_ids:
            scores[
                tensor(batch_ids, dtype = int, device = scores.device),
                tensor(last_reciprocal_ids, dtype = int, device = scores.device),
            ] = -inf

        return scores
    
# https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/generation/logits_process.py#L892
class SuppressTokensLogitsProcessor(LogitsProcessor):
    """Logits processor that sets the scores of the given token ids to -inf."""

    def __init__(self, tokens):
        # Materialise once so any iterable of ids is accepted.
        self.tokens = [token for token in tokens]

    def __call__(self, input_ids, scores):
        # Applied to every row of `scores`, independent of `input_ids`.
        scores[:, self.tokens] = float('-inf')
        return scores
    

from freegroup.tools import (
    batch_normalize, batch_is_from_singleton_normal_closure, wu_closure,
    batch_reduce_modulo_singleton_normal_closure, batch_to_string,
    batch_from_string,
)

from freegroup.sampling import (
    freegroup_generator
)
from copy import deepcopy
from itertools import islice

import numpy as np

def completion_ratio(outputs, references, freegroup_dimension):
    """Share of references with at least one successful generated output.

    `outputs` holds the generated words, grouped consecutively per reference
    (``len(outputs)`` must be a multiple of ``len(references)``). An output is
    successful when it is non-empty and
    `batch_is_from_singleton_normal_closure` reports it for every closure
    `wu_closure(freegroup_dimension, idx)`, idx = 0..freegroup_dimension.

    Returns ``{'completion_ratio': value}``.
    """
    per_reference = len(outputs) // len(references)

    # Rows: closures, columns: outputs.
    membership = np.array([
        batch_is_from_singleton_normal_closure(outputs, wu_closure(freegroup_dimension, idx))
        for idx in range(freegroup_dimension + 1)
    ])
    in_every_closure = membership.T.all(axis = -1)

    non_empty = np.array([len(word) for word in outputs]) > 0
    successful = in_every_closure & non_empty

    # One row per reference; a reference counts when any of its outputs succeeded.
    solved = successful.reshape(len(references), per_reference).any(axis = -1)
    return {'completion_ratio': solved.mean()}

def reduction_ratio(outputs, references, freegroup_dimension):
    """Average relative shortening of `outputs` after reduction modulo each closure.

    For closure index idx = 0..freegroup_dimension every output is reduced with
    `batch_reduce_modulo_singleton_normal_closure` and the ratio
    ``(len(output) - len(reduced)) / len(output)`` is averaged over the outputs.
    NaN / inf ratios (e.g. from empty outputs) are replaced by `np.nan_to_num`.
    `references` is accepted for signature parity and not used.

    Returns a dict with ``reduction_ratio_<idx>`` per closure and their mean
    under ``reduction_ratio``.
    """
    reduced_words = [
        batch_reduce_modulo_singleton_normal_closure(outputs, wu_closure(freegroup_dimension, idx))
        for idx in range(freegroup_dimension + 1)
    ]

    # Rows: closures, columns: outputs.
    reduced_lengths = np.array([[len(word) for word in batch] for batch in reduced_words])
    original_lengths = np.array([len(word) for word in outputs])[None, :]

    ratios = np.nan_to_num((original_lengths - reduced_lengths) / original_lengths)
    per_closure = ratios.mean(axis=1)

    metrics = {
        f'reduction_ratio_{idx}': per_closure[idx]
        for idx in range(freegroup_dimension + 1)
    }
    metrics['reduction_ratio'] = per_closure.mean()

    return metrics
        


In [None]:
from IPython.display import display, Javascript
# Ask the notebook front-end to save a checkpoint by running a JavaScript snippet.
# NOTE(review): presumably done so that the notebook file uploaded as a W&B
# artifact in the training cell reflects the current state — confirm.
display(Javascript('IPython.notebook.save_checkpoint();'))

In [None]:
import wandb
from tqdm.auto import tqdm

import numpy as np

from os.path import exists
from os import makedirs
from shutil import rmtree

from freegroup.sampling import freegroup_generator
from freegroup.tools import batch_from_string

from transformers import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, LogitsProcessorList
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from torch import no_grad, save
from transformers import set_seed

# Seed the random number generators through `transformers.set_seed` with the configured seed.
set_seed(config['seed'])

# `metrics` accumulates raw observations per name; `report` holds the
# aggregated values that are sent to W&B.
metrics = {}
report = {}

def add_metric(name, value):
    """Append `value` to the pending observations stored under `name`."""
    metrics.setdefault(name, []).append(value)
    
def report_metric(name):
    """Aggregate the pending observations of `name` into `report` and reset them.

    Writes `<name>_mean`, `<name>_max` and `<name>_min`, then empties the
    observation list for `name`.
    """
    values = metrics[name]
    for suffix, reducer in (('mean', np.mean), ('max', np.max), ('min', np.min)):
        report[f'{name}_{suffix}'] = reducer(values)
    metrics[name] = []
    
def construct_dataset(conf):
    """Instantiate a dataset from a ``{'dist': ..., 'type': ...}`` config entry.

    ``conf['dist']['name']`` selects the sampling function:
    'incomplete-intersection', 'complete-intersection',
    'generator-permutations', or 'freegroup' (wraps `freegroup_generator` with
    `wrapper_dist` so that entries also carry a multi-label).
    ``conf['type']['name']`` selects the container: 'refillable' returns a
    `RefillableDataset`, 'fixed' returns a list sampled once with
    `sample_dataset`. ``conf['type']['args']`` is forwarded as keyword arguments.

    Raises:
        ValueError: for an unknown distribution or dataset type name.
    """
    dist_name = conf['dist']['name']
    dist_kwargs = conf['dist']['args']
    if dist_name == 'incomplete-intersection':
        dist = incomplete_intersection_dist
    elif dist_name == 'complete-intersection':
        dist = complete_intersection_dist
    elif dist_name == 'generator-permutations':
        dist = generator_permutations_dist
    elif dist_name == 'freegroup':
        dist = wrapper_dist
        dist_kwargs = {
            'freegroup_dimension': config['freegroup_dimension'],
            'word_generator': freegroup_generator,
            'word_generator_kwargs': conf['dist']['args'],
        }
    else:
        # Previously surfaced later as an UnboundLocalError on `dist`.
        raise ValueError(f'Unknown dataset distribution: {dist_name!r}')

    type_name = conf['type']['name']
    type_kwargs = conf['type']['args']
    if type_name == 'refillable':
        return RefillableDataset(dist, dist_kwargs, **type_kwargs)
    if type_name == 'fixed':
        return sample_dataset(dist, dist_kwargs = dist_kwargs, **type_kwargs)

    # Previously returned None silently.
    raise ValueError(f'Unknown dataset type: {type_name!r}')
        
# Single training dataset; evaluation and generation datasets are keyed by
# the names used in the config.
train_dataset = construct_dataset(config['train']['dataset'])
eval_dataset = {
    name: construct_dataset(spec)
    for name, spec in config['eval']['dataset'].items()
}
gen_dataset = {
    name: construct_dataset(spec)
    for name, spec in config['gen']['dataset'].items()
}

# Keyword arguments for `model.generate`, one entry per decoding method.
gen_methods = {}
for k, v in config['gen']['methods'].items():
    # Work on a copy: popping from `v` directly removed `top_p` and
    # `repetition_penalty` from `config` before it was logged to W&B, and made
    # a second execution of this cell silently skip both processors.
    generate_kwargs = dict(v)
    logits_processors = []

    # These two options are applied through explicit processors instead of
    # being passed to `generate`.
    if 'top_p' in generate_kwargs:
        logits_processors.append(TopPLogitsWarper(generate_kwargs.pop('top_p')))
    if 'repetition_penalty' in generate_kwargs:
        logits_processors.append(RepetitionPenaltyLogitsProcessor(generate_kwargs.pop('repetition_penalty')))

    # Never emit bracket, comma, label or separator tokens.
    logits_processors.append(
        SuppressTokensLogitsProcessor(tokenizer.convert_tokens_to_ids(['[', ']', ',', 'y', 'n', ':']))
    )
    # Never emit the reciprocal of the previous token.
    logits_processors.append(
        NoLastTokenReductionProcessor.from_tokenizer(config['freegroup_dimension'], tokenizer)
    )

    gen_methods[k] = {'logits_processor': LogitsProcessorList(logits_processors), **generate_kwargs}
    

# Main training loop: one optimisation step per iteration, with periodic
# evaluation, generation metrics, logging and checkpointing. The dataset
# worker pools are stopped in `finally`, so they are also released when
# training fails or is interrupted.
try:
    with    wandb.init(project = config['project'], notes = '', config = config) as run,\
            tqdm(total = config['train']['steps']) as progress:

        model.to(config['device'])

        # Store the notebook file with the run.
        artifact = wandb.Artifact('notebook', type='notebook')
        artifact.add_file(config['notebook_name'])
        run.log_artifact(artifact)

        if config['train']['optimizer']['name'] == 'AdamW':
            optimizer = AdamW(params = model.parameters(), **config['train']['optimizer']['args'])
        else:
            # Previously surfaced later as a NameError on `optimizer`.
            raise ValueError(f"Unknown optimizer: {config['train']['optimizer']['name']!r}")

        if config['train']['scheduler']['name'] == 'Linear':
            scheduler = LinearLR(optimizer,**config['train']['scheduler']['args'])
        else:
            # Previously surfaced later as a NameError on `scheduler`.
            raise ValueError(f"Unknown scheduler: {config['train']['scheduler']['name']!r}")

        train_dataloader_fn = lambda: DataLoader(
            train_dataset,
            batch_size = config['train']['batch_size'],
            collate_fn = data_collator(
                model, tokenizer, config['freegroup_dimension'],
                method = config['method'], mode = 'predict',
            )
        )

        # Start with an exhausted iterator so the first step creates the loader.
        dataloader = iter([])

        while progress.n < config['train']['steps']:
            progress.update()
            model.train()

            try:
                batch = next(dataloader)
            except StopIteration:
                # The loader is exhausted: start a new pass over the dataset.
                dataloader = iter(train_dataloader_fn())
                batch = next(dataloader)

            for k, v in batch.items():
                batch[k] = v.to(model.device)

            outputs = model(**batch)
            optimizer.zero_grad()
            outputs[0].backward()
            optimizer.step()

            add_metric('train/loss', outputs[0].item())

            if progress.n % config['train']['scheduler']['steps'] == 0:
                scheduler.step()

            if progress.n % config['eval']['steps'] == 0:
                model.eval()
                for key, dataset in eval_dataset.items():
                    for batch in DataLoader(
                        dataset, shuffle=True,
                        batch_size = config['eval']['batch_size'],
                        collate_fn = data_collator(
                            model, tokenizer, config['freegroup_dimension'],
                            method = config['method'], mode = 'predict',
                        )
                    ):
                        for k, v in batch.items():
                            batch[k] = v.to(model.device)
                        with no_grad():
                            outputs = model(**batch)
                        add_metric(f'eval/{key}_loss', outputs[0].item())
                    report_metric(f'eval/{key}_loss')

            if progress.n % config['gen']['steps'] == 0:
                model.eval()

                gen_metric_names = ['completion_ratio'] + ['reduction_ratio'] +\
                    [f'reduction_ratio_{idx}' for idx in range(config['freegroup_dimension'] + 1)]

                for dataset_key, dataset in gen_dataset.items():
                    for method_key, method in gen_methods.items():

                        for batch in DataLoader(
                            dataset, shuffle = True,
                            batch_size = config['gen']['batch_size'],
                            collate_fn = data_collator(
                                model, tokenizer, config['freegroup_dimension'],
                                method = config['method'], mode = 'generate',
                            )
                        ):
                            for k, v in batch.items():
                                batch[k] = v.to(model.device)

                            outputs = model.generate(**batch, **method)
                            outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
                            outputs = batch_from_string(outputs)

                            for k, v in completion_ratio(
                                outputs, batch['inputs'],
                                freegroup_dimension = config['freegroup_dimension']
                            ).items():
                                add_metric(f'gen/{method_key}_{dataset_key}/{k}', v)

                            for k, v in reduction_ratio(
                                outputs, batch['inputs'],
                                freegroup_dimension = config['freegroup_dimension']
                            ).items():
                                add_metric(f'gen/{method_key}_{dataset_key}/{k}', v)

                        for name in gen_metric_names:
                            report_metric(f'gen/{method_key}_{dataset_key}/{name}')

            if progress.n % config['train']['log_steps'] == 0:
                report_metric('train/loss')

            if progress.n % config['train']['save_steps'] == 0:
                # Replace the local checkpoint directory and upload it as an artifact.
                model_artifact = wandb.Artifact(run.id, type='model')
                checkpoint = f'{run.dir}/checkpoint'
                if exists(checkpoint): rmtree(checkpoint)
                makedirs(checkpoint)
                model.save_pretrained(checkpoint)
                tokenizer.save_pretrained(checkpoint)
                save(optimizer.state_dict(), f'{checkpoint}/optimizer.pt')
                save(scheduler.state_dict(), f'{checkpoint}/scheduler.pt')
                model_artifact.add_dir(checkpoint)
                run.log_artifact(model_artifact)

            if report:
                wandb.log(report, commit = True, step = progress.n)
                # Empty the report in place; otherwise the same stale values
                # were logged again on every following step.
                report.clear()
finally:
    for dataset in [train_dataset, *eval_dataset.values(), *gen_dataset.values()]:
        if isinstance(dataset, RefillableDataset): dataset.stop()