In [None]:
import os

# Keep every Hugging Face download (models and datasets) inside this
# notebook's directory. This cell runs before the lineage utilities are
# imported, presumably so the libraries pick the location up at import time.
os.environ.update(HF_HOME='./', HF_DATASETS_CACHE='./')

# Presumed root of the MGit checkout: three directory levels above the
# working directory (it is added to sys.path and used to locate utils/...).
MGIT_PATH = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.getcwd())
    )
)

In [None]:
import sys

# Make the MGit repository importable. Guarded so that re-running this cell
# does not keep appending duplicate entries to sys.path.
if MGIT_PATH not in sys.path:
    sys.path.append(MGIT_PATH)

# NOTE(review): star import kept as-is; the Lineage* classes used below
# (LineageGraph, LineageDataset, LineageTest, LineageTrain, LineageNode) are
# presumably provided by it — consider importing them explicitly.
from utils.lineage.graph import *
from utils import meta_functions

In [None]:
# User-supplied test success condition: the function `test_success_condition`
# in user_functions.py in the current working directory. os.path.join is used
# instead of string concatenation so the separator is platform-correct.
success_condition_path = os.path.join(os.getcwd(), 'user_functions.py')
success_condition_name = 'test_success_condition'

In [None]:
# User-supplied metrics function: `compute_metrics` in user_functions.py in
# the current working directory. os.path.join is used instead of string
# concatenation so the separator is platform-correct.
compute_metrics_path = os.path.join(os.getcwd(), 'user_functions.py')
compute_metrics_name = 'compute_metrics'

In [None]:
# Evaluation data: GLUE / SST-2, validation split. feature_keys presumably
# names the input text field(s) of each example.
lineage_eval_dataset = LineageDataset(
    "glue",
    "sst2",
    split="validation",
    feature_keys=['sentence'],
)

In [None]:
# NOTE: for a quick demo the models are trained on the SST-2 *validation*
# split, which is also the split used by lineage_eval_dataset for testing.
# Test scores therefore measure fit on the training data, not generalization.
# Set USE_FULL_TRAIN_SPLIT = True to train on the real "train" split instead
# (a larger split, so training takes longer).
USE_FULL_TRAIN_SPLIT = False
lineage_train_dataset = LineageDataset(
    "glue",
    "sst2",
    split="train" if USE_FULL_TRAIN_SPLIT else "validation",
    feature_keys=['sentence'],
)

In [None]:
# Preprocessing hook used by both the training recipes and the test:
# `glue_preprocess_function` in <MGIT_PATH>/utils/preprocess_utils.py.
# Path components are passed separately so os.path.join picks the separator.
preprocess_file = os.path.join(MGIT_PATH, 'utils', 'preprocess_utils.py')
preprocess_function = 'glue_preprocess_function'

In [None]:
# Perturbation hook `perturb_char_misspelledword` (a character-level
# misspelling perturbation, going by its name) in
# <MGIT_PATH>/utils/perturbations/perturbation_utils.py.
# NOTE(review): neither variable is referenced by any visible cell below.
perturbation_file = os.path.join(
    MGIT_PATH, 'utils', 'perturbations', 'perturbation_utils.py'
)
perturbation_name = 'perturb_char_misspelledword'

In [None]:
import shutil

# Start from a clean slate: remove directories left over from a previous run.
# The tmp_sst2_node* entries are the output_dir values of the nodes created
# below; the others are presumably written by the lineage library / trainer
# (e.g. during the update cascade) — TODO confirm.
# Pure-Python replacement for `!rm -rf`, so no POSIX shell is needed; missing
# directories are ignored, as with `rm -f`.
STALE_OUTPUT_DIRS = [
    'tmp_sst2_node1',
    'tmp_sst2_node1_v2',
    'tmp_sst2_node2',
    'tmp_sst2_node2_versioned',
    'parameter_store',
    'tmp_trainer',
    'tmp_trainer_args',
]
for stale_dir in STALE_OUTPUT_DIRS:
    shutil.rmtree(stale_dir, ignore_errors=True)

In [None]:
# Create the lineage graph that the test and all nodes below are registered on.
g = LineageGraph()

In [None]:
# Test "test1": evaluates on lineage_eval_dataset using the user-supplied
# success condition and metrics functions, with 'accuracy' as the metric for
# the best model. It is registered against the 'sst2' model type, presumably
# so it applies to every node created with model_type='sst2'.
test1 = LineageTest(
    name='test1',
    eval_dataset=lineage_eval_dataset,
    preprocess_function_path=preprocess_file,
    preprocess_function_name=preprocess_function,
    compute_metrics_path=compute_metrics_path,
    compute_metrics_name=compute_metrics_name,
    test_success_condition_path=success_condition_path,
    test_success_condition_name=success_condition_name,
    metric_for_best_model='accuracy',
)
g.register_test_to_type(test1, 'sst2')

# Create root sst-2 node 1

In [None]:
# Training recipe for the root node: 2 epochs at learning rate 2e-5, with a
# per-device batch size of 64 for both training and evaluation.
lineage_train = LineageTrain(
    train_dataset=lineage_train_dataset,
    eval_dataset=lineage_eval_dataset,
    preprocess_function_path=preprocess_file,
    preprocess_function_name=preprocess_function,
    num_train_epochs=2,
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)

# Root node: initialised from the 'roberta-base' checkpoint. Its output_dir
# 'tmp_sst2_node1' is also the name later cells pass as `parent=`.
node1 = LineageNode(
    init_checkpoint='roberta-base',
    model_type='sst2',
    output_dir='tmp_sst2_node1',
    lineage_train=lineage_train,
)
g.add(node1)

In [None]:
# Display whether node1 reports its training as finished (checked before train()).
node1.is_training_finished()

In [None]:
# Train the root node using the LineageTrain recipe attached to it above.
node1.train()

In [None]:
# Run every registered test on each node currently in the graph, then render
# the per-node results table including metrics.
for lineage_node in g.nodes.values():
    lineage_node.run_all_tests()
meta_functions.show_result_table(g, show_metrics=True)

# Create sst-2 node 2 from node 1

In [None]:
# Training recipe for node 2: same data and optimiser settings as node 1,
# but a single epoch.
lineage_train = LineageTrain(
    train_dataset=lineage_train_dataset,
    eval_dataset=lineage_eval_dataset,
    preprocess_function_path=preprocess_file,
    preprocess_function_name=preprocess_function,
    num_train_epochs=1,
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)

# Node 2 uses the user hook `vanilla_finetune_init_function` for model
# initialisation. The hook is looked up in user_functions.py via
# success_condition_path, which points at that same file.
node2 = LineageNode(
    init_checkpoint='roberta-base',
    model_type='sst2',
    output_dir='tmp_sst2_node2',
    lineage_train=lineage_train,
    model_init_function_path=success_condition_path,
    model_init_function_name='vanilla_finetune_init_function',
    # is_delta=True,  # optional LineageNode flag, left disabled
)

# Record node 2 as an 'adapted' child of node 1 (identified by its output_dir).
g.add(node2, etype='adapted', parent='tmp_sst2_node1')

In [None]:
# Train node 2 using its single-epoch LineageTrain recipe.
node2.train()

In [None]:
# Re-run every registered test on all nodes (node 2 included now) and
# render the results table including metrics.
for lineage_node in g.nodes.values():
    lineage_node.run_all_tests()
meta_functions.show_result_table(g, show_metrics=True)

# Create sst-2 node 1 v2 from node 1

In [None]:
# Training recipe for node 1 v2: 2 epochs at learning rate 2e-5, with a
# per-device batch size of 64 for both training and evaluation.
lineage_train = LineageTrain(
    train_dataset=lineage_train_dataset,
    eval_dataset=lineage_eval_dataset,
    preprocess_function_path=preprocess_file,
    preprocess_function_name=preprocess_function,
    num_train_epochs=2,
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)

# Node 1 v2 uses the user hook `vanilla_finetune_init_function` (looked up in
# user_functions.py, the file success_condition_path points at). It is added
# to the graph in the next cell.
node1_v2 = LineageNode(
    init_checkpoint='roberta-base',
    model_type='sst2',
    output_dir='tmp_sst2_node1_v2',
    lineage_train=lineage_train,
    model_init_function_path=success_condition_path,
    model_init_function_name='vanilla_finetune_init_function',
    # is_delta=True,  # optional LineageNode flag, left disabled
)

In [None]:
# Link node1_v2 to node 1 twice, first with an 'adapted' edge and then with a
# 'versioned' edge (order preserved).
# NOTE(review): registering the new version as an *adapted* child of node 1
# as well means it is presumably one of node 1's adapted children when
# run_update_cascade(old_node=node1, ...) runs — confirm this is intended.
for edge_type in ('adapted', 'versioned'):
    g.add(node1_v2, etype=edge_type, parent='tmp_sst2_node1')

In [None]:
# Train node 1 v2 using its LineageTrain recipe.
node1_v2.train()

In [None]:
# Re-run every registered test on all nodes (node 1 v2 included now) and
# render the results table including metrics.
for lineage_node in g.nodes.values():
    lineage_node.run_all_tests()
meta_functions.show_result_table(g, show_metrics=True)

# Cascade the update from node 1 to node 1 v2 through all adapted children of node 1

In [None]:
# Propagate the change from node1 (old) to node1_v2 (updated) through the
# graph; presumably node1's adapted children (e.g. node2) are re-derived from
# node1_v2 — TODO confirm against LineageGraph.run_update_cascade.
g.run_update_cascade(old_node=node1,updated_node=node1_v2)

In [None]:
# Re-run every registered test on all nodes, including any nodes the update
# cascade may have added to the graph.
for lineage_node in g.nodes.values():
    lineage_node.run_all_tests()

In [None]:
# Render the final per-node results table, including metrics, after the cascade.
meta_functions.show_result_table(g,show_metrics=True)

In [None]:
# Optional, disabled: g.show presumably renders the lineage graph restricted
# to 'adapted' edges. Uncomment the line below to use it.
#g.show(etype="adapted")