In [None]:
import os
import sys

# To pin a specific GPU, set CUDA_VISIBLE_DEVICES (e.g. os.environ["CUDA_VISIBLE_DEVICES"] = "3")
# before torch initialises CUDA.

# Cache directory handed to LineageDataset; an existing TRANSFORMERS_CACHE environment
# variable takes precedence over this notebook's default location.
TRANSFORMERS_CACHE = os.environ.get('TRANSFORMERS_CACHE', '/workspace/HF_cache/transformers_cache/')

# Repository root, taken to be three directory levels above the notebook's working directory,
# so that the project's `utils` package is importable. The membership check keeps re-runs of
# this cell from appending duplicate sys.path entries.
MGIT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))
if MGIT_PATH not in sys.path:
    sys.path.append(MGIT_PATH)

In [None]:
# Explicit imports instead of `import *`, so every name used below has a visible origin.
from utils.lineage.graph import (
    LineageDataset,
    LineageGraph,
    LineageNode,
    LineageTest,
    LineageTrain,
)
from utils import meta_functions

In [None]:
# User-supplied predicate that decides whether a LineageTest passed; it lives in
# user_functions.py next to this notebook.
success_condition_name = 'test_success_condition'
success_condition_path = f"{os.getcwd()}/user_functions.py"

In [None]:
# Metric function used by the test, also defined in user_functions.py next to this notebook.
compute_metrics_name = 'compute_metrics'
compute_metrics_path = f"{os.getcwd()}/user_functions.py"

In [None]:
# GLUE / SST-2 validation split, with `sentence` as the feature column.
lineage_eval_dataset = LineageDataset(
    "glue",
    "sst2",
    split="validation",
    cache_dir=TRANSFORMERS_CACHE,
    feature_keys=['sentence'],
)

In [None]:
# NOTE(review): training reads the SST-2 *validation* split, the same split used for
# evaluation, so train and eval data overlap — presumably deliberate to keep this run
# short. Set TRAIN_SPLIT = "train" for a real fine-tuning run.
TRAIN_SPLIT = "validation"
lineage_train_dataset = LineageDataset(
    "glue",
    "sst2",
    split=TRAIN_SPLIT,
    cache_dir=TRANSFORMERS_CACHE,
    feature_keys=['sentence'],
)

In [None]:
# Preprocessing hook (glue_preprocess_function in MGIT's utils/preprocess_utils.py),
# passed to the test and to every LineageTrain below.
preprocess_function = 'glue_preprocess_function'
preprocess_file = os.path.join(MGIT_PATH, 'utils', 'preprocess_utils.py')

In [None]:
# Character-level misspelling perturbation from MGIT's perturbation utilities.
# NOTE(review): neither name is referenced by any later cell in this notebook; remove
# them, or pass them to a LineageTest if perturbation testing is intended.
perturbation_name = 'perturb_char_misspelledword'
perturbation_file = os.path.join(MGIT_PATH, 'utils', 'perturbations', 'perturbation_utils.py')

In [None]:
# Start from a clean slate: delete the node output directories and parameter_store
# left behind by a previous run. Pure Python instead of `!rm -rf`, so the cell is
# portable and valid outside IPython.
import shutil

for stale_dir in ('tmp_sst2_node1', 'tmp_sst2_node2', 'tmp_sst2_node3', 'parameter_store'):
    # ignore_errors mirrors `rm -rf`: a directory that does not exist is not an error.
    shutil.rmtree(stale_dir, ignore_errors=True)

In [None]:
# Lineage graph for this notebook: nodes are added with g.add(...) and tests are
# registered on it with g.register_test_to_type(...).
g = LineageGraph()

In [None]:
# One evaluation test: preprocess the SST-2 eval set, score it with compute_metrics
# (accuracy as the key metric), and judge pass/fail with test_success_condition.
test1 = LineageTest(
    name='test1',
    eval_dataset=lineage_eval_dataset,
    preprocess_function_path=preprocess_file,
    preprocess_function_name=preprocess_function,
    compute_metrics_path=compute_metrics_path,
    compute_metrics_name=compute_metrics_name,
    test_success_condition_path=success_condition_path,
    test_success_condition_name=success_condition_name,
    metric_for_best_model='accuracy',
)
# Attach the test to the 'sst2' model type (all nodes below use model_type='sst2').
g.register_test_to_type(test1, 'sst2')

In [None]:
# node1: added without a parent; trained for 2 epochs starting from roberta-base.
node1_train = LineageTrain(
    train_dataset=lineage_train_dataset,
    eval_dataset=lineage_eval_dataset,
    preprocess_function_path=preprocess_file,
    preprocess_function_name=preprocess_function,
    learning_rate=2e-5,
    num_train_epochs=2,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)
node1 = LineageNode(
    init_checkpoint='roberta-base',
    model_type='sst2',
    lineage_train=node1_train,
    output_dir='tmp_sst2_node1',
)

g.add(node1)

In [None]:
# node2: 'adapted' child of tmp_sst2_node1, trained for 1 epoch with is_delta=True
# (presumably persisted as a delta against its parent — confirm in LineageNode).
# The model-init function lives in user_functions.py, the same file as the success
# condition, which is why success_condition_path is reused here.
node2_train = LineageTrain(
    train_dataset=lineage_train_dataset,
    eval_dataset=lineage_eval_dataset,
    preprocess_function_path=preprocess_file,
    preprocess_function_name=preprocess_function,
    learning_rate=2e-5,
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)
node2 = LineageNode(
    init_checkpoint='roberta-base',
    model_type='sst2',
    model_init_function_path=success_condition_path,
    model_init_function_name='vanilla_finetune_init_function',
    lineage_train=node2_train,
    output_dir='tmp_sst2_node2',
    is_delta=True,
)

g.add(node2, etype='adapted', parent='tmp_sst2_node1')

In [None]:
# node3: second 'adapted' child of tmp_sst2_node1, trained for 2 epochs with is_delta=True
# (presumably persisted as a delta against its parent — confirm in LineageNode).
# The model-init function lives in user_functions.py, the same file as the success
# condition, which is why success_condition_path is reused here.
node3_train = LineageTrain(
    train_dataset=lineage_train_dataset,
    eval_dataset=lineage_eval_dataset,
    preprocess_function_path=preprocess_file,
    preprocess_function_name=preprocess_function,
    learning_rate=2e-5,
    num_train_epochs=2,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)
node3 = LineageNode(
    init_checkpoint='roberta-base',
    model_type='sst2',
    model_init_function_path=success_condition_path,
    model_init_function_name='vanilla_finetune_init_function',
    lineage_train=node3_train,
    output_dir='tmp_sst2_node3',
    is_delta=True,
)

g.add(node3, etype='adapted', parent='tmp_sst2_node1')

In [None]:
# Visualise the lineage graph; etype="adapted" presumably limits the drawing to
# 'adapted' edges — confirm in LineageGraph.show.
g.show(etype="adapted")

In [None]:
# Train each node, then run every test registered for it.
# NOTE(review): this relies on g.nodes iterating in insertion order so that node1 is
# trained before its 'adapted' children — confirm.
for lineage_node in g.nodes.values():
    lineage_node.train()
    lineage_node.run_all_tests()

In [None]:
# Per-node table of test results for the whole graph.
meta_functions.show_result_table(g)

In [None]:
# Same table, additionally showing the metric values (show_metrics=True).
meta_functions.show_result_table(g,show_metrics=True)

In [None]:
# Gate: every node must pass its tests before the reload / storage checks below.
# Iterating items() lets the failure message name the offending node.
for node_key, node in g.nodes.items():
    # `is False` (not merely falsy) is kept deliberately strict: a None result also fails.
    assert node.is_test_failure() is False, f"tests failed for node {node_key!r}"

In [None]:
# Drop each node's in-memory model and run the tests again, so the results come from
# a re-materialised model (presumably reloaded from disk / parameter_store — confirm).
for lineage_node in g.nodes.values():
    lineage_node.unload_model()
    lineage_node.run_all_tests()

In [None]:
# Metrics after unload_model() + re-testing; compare against the table shown before unloading.
meta_functions.show_result_table(g,show_metrics=True)

In [None]:
# Per-parameter difference between node2 and its parent node1.
# Each state dict is fetched once (the original re-fetched node2's for every parameter),
# and a compact summary is printed instead of dumping every full tensor.
node1_state = node1.get_pt_model().state_dict()
node2_state = node2.get_pt_model().state_dict()
for param_name, parent_value in node1_state.items():
    delta = node2_state[param_name] - parent_value
    # Guard zero-element tensors, for which max() has no value.
    max_abs_delta = delta.abs().max().item() if delta.numel() else 0.0
    print(param_name, tuple(delta.shape), max_abs_delta)

In [None]:
import subprocess


def _du_size(path):
    """Return the first column of `du -s path` (size in du's default block units).

    Only stdout is parsed, so du warnings on stderr cannot be mistaken for the size line
    (IPython's `!` capture mixes the two). Raises subprocess.CalledProcessError if du
    fails, e.g. because `path` does not exist.
    """
    output = subprocess.check_output(['du', '-s', path], text=True)
    return int(output.split('\t', 1)[0])


node1_store = _du_size('tmp_sst2_node1')
node2_store = _du_size('tmp_sst2_node2')
node3_store = _du_size('tmp_sst2_node3')
global_store = _du_size('parameter_store')

# Ratio of the combined per-node output directories to parameter_store.
if global_store == 0:
    print('storage savings: n/a (parameter_store reports zero size)')
else:
    print('storage savings:', (node1_store + node2_store + node3_store) / global_store)