# Test: few-shot fine-tuning of OPT models on in-domain vs. out-of-domain data

In [1]:
# Reload edited Python modules (e.g. project code under src/) automatically before
# each cell executes, so changes take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

Import datasets using functions from `src/data/data.py`. Datasets are downloaded from Hugging Face on first use and stored in `/data`; after that, they are loaded from local storage.

If importing the `src` modules fails, run `pip install -e .` from the project root to install the project in editable mode.

In [2]:
from src.data.data import get_in_domain, get_out_domain
from src.data.utils import get_random_subsets

# Index of the single example record printed from each split.
IN_DOMAIN_EXAMPLE_IDX = 0
OUT_DOMAIN_EXAMPLE_IDX = 10

# First call downloads each dataset; later calls load the locally stored copy.
in_domain = get_in_domain()
out_domain = get_out_domain()

# Show each split's schema/size, then one concrete premise/hypothesis record.
print(f"In domain:\n{in_domain}")
print(in_domain[IN_DOMAIN_EXAMPLE_IDX])

print(f"Out of domain:\n{out_domain}")
print(out_domain[OUT_DOMAIN_EXAMPLE_IDX])

  from .autonotebook import tqdm as notebook_tqdm


In domain:
Dataset({
    features: ['premise', 'hypothesis', 'label', 'idx'],
    num_rows: 261802
})
{'premise': 'you know during the season and i guess at at your level uh you lose them to the next level if if they decide to recall the the parent team the Braves decide to call to recall a guy from triple A then a double A guy goes up to replace him and a single A guy goes up to replace him', 'hypothesis': 'You lose the things to the following level if the people recall.', 'label': 0, 'idx': 1}
Out of domain:
Dataset({
    features: ['premise', 'hypothesis', 'label', 'parse_premise', 'parse_hypothesis', 'binary_parse_premise', 'binary_parse_hypothesis', 'heuristic', 'subcase', 'template'],
    num_rows: 10000
})
{'premise': 'The president avoided the athlete .', 'hypothesis': 'The athlete avoided the president .', 'label': 1, 'parse_premise': '(ROOT (S (NP (DT The) (NN president)) (VP (VBD avoided) (NP (DT the) (NN athlete))) (. .)))', 'parse_hypothesis': '(ROOT (S (NP (DT The) (NN at

Import models using functions from `src/model/model.py`. Models are downloaded from Hugging Face on first use and stored in `/models/pretrained`; after that, they are loaded from local storage.

In [3]:
from src.model.model import get_model

# get_model yields a (model, tokenizer) pair per checkpoint name; unpack both
# checkpoints in the same order they are requested.
# NOTE(review): model_opt350 / tokenizer_opt350 are not used by any later cell
# shown here; consider dropping this load if memory or startup time matters.
(model_opt125, tokenizer_opt125), (model_opt350, tokenizer_opt350) = (
    get_model(checkpoint) for checkpoint in ('opt-125m', 'opt-350m')
)

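## Few-shot fine-tuning

Fine-tune OPT-125m on the first two in-domain examples and evaluate on the first two out-of-domain examples. This is a smoke test, so the accuracy on two examples is not meaningful.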

In [4]:
from src.finetuners.fewshot import fine_tune

# Smoke-test sizes: number of examples taken from the start of each split.
N_FEWSHOT_TRAIN = 2
N_FEWSHOT_EVAL = 2

# Dedicated names so later cells that build their own evaluation subsets
# cannot silently overwrite these.
fewshot_train_dataset = in_domain.select(range(N_FEWSHOT_TRAIN))
fewshot_eval_dataset = out_domain.select(range(N_FEWSHOT_EVAL))

# NOTE(review): if fine_tune updates model_opt125 in place, re-running this cell
# continues training already fine-tuned weights; reload the model for a clean run.
# Confirm against src/finetuners/fewshot.py.
fine_tune(
    model=model_opt125,
    tokenizer=tokenizer_opt125,
    train_dataset=fewshot_train_dataset,
    eval_dataset=fewshot_eval_dataset,
)

100%|██████████| 40/40 [00:09<00:00,  4.06it/s]


{'train_runtime': 9.8602, 'train_samples_per_second': 8.113, 'train_steps_per_second': 4.057, 'train_loss': 0.061122357845306396, 'epoch': 40.0}


100%|██████████| 1/1 [00:42<00:00, 42.00s/it]


{'accuracy': 0.0,
 'total_inference_time': 42.0685,
 'average_inference_time_per_sample': 21.03425,
 'peak_memory_usage_gb': 1.5613317489624023}

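## Batch few-shot fine-tuning

Repeat few-shot fine-tuning of OPT-125m for each training sample size in `[2, 4]`, with 5 trials per size. Each run is evaluated on the first 10 out-of-domain examples, and averaged metrics are reported per sample size.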

In [10]:
from src.finetuners.fewshot import batch_fine_tune
import json

# Experiment grid passed to batch_fine_tune, plus the evaluation subset size.
BATCH_SAMPLE_SIZES = [2, 4]
BATCH_NUM_TRIALS = 5
N_BATCH_EVAL = 10

# A dedicated name avoids overwriting evaluation subsets defined in earlier cells.
batch_eval_dataset = out_domain.select(range(N_BATCH_EVAL))

# NOTE(review): no random seed is set here; if batch_fine_tune samples random
# training subsets per trial, results will vary between runs. Confirm and seed if so.
results, avg_results = batch_fine_tune(
    model_name='opt-125m',
    train_dataset=in_domain,
    eval_dataset=batch_eval_dataset,
    sample_sizes=BATCH_SAMPLE_SIZES,
    num_trials=BATCH_NUM_TRIALS,
)

print(json.dumps(avg_results, indent=4))

  0%|          | 0/2 [00:00<?, ?it/s]

{'train_runtime': 8.6874, 'train_samples_per_second': 9.209, 'train_steps_per_second': 4.604, 'train_loss': 0.05942646861076355, 'epoch': 40.0}
{'train_runtime': 8.7788, 'train_samples_per_second': 9.113, 'train_steps_per_second': 4.556, 'train_loss': 0.08801689147949218, 'epoch': 40.0}
{'train_runtime': 8.8401, 'train_samples_per_second': 9.05, 'train_steps_per_second': 4.525, 'train_loss': 0.05300096273422241, 'epoch': 40.0}
{'train_runtime': 8.9181, 'train_samples_per_second': 8.971, 'train_steps_per_second': 4.485, 'train_loss': 0.09598073959350586, 'epoch': 40.0}
{'train_runtime': 8.7003, 'train_samples_per_second': 9.195, 'train_steps_per_second': 4.598, 'train_loss': 0.08078636527061463, 'epoch': 40.0}


 50%|█████     | 1/2 [02:05<02:05, 125.48s/it]

{'train_runtime': 86.5902, 'train_samples_per_second': 1.848, 'train_steps_per_second': 0.462, 'train_loss': 0.11859711408615112, 'epoch': 40.0}
{'train_runtime': 87.0571, 'train_samples_per_second': 1.838, 'train_steps_per_second': 0.459, 'train_loss': 0.09925844073295594, 'epoch': 40.0}
{'train_runtime': 88.1669, 'train_samples_per_second': 1.815, 'train_steps_per_second': 0.454, 'train_loss': 0.06952434778213501, 'epoch': 40.0}
{'train_runtime': 86.2525, 'train_samples_per_second': 1.855, 'train_steps_per_second': 0.464, 'train_loss': 0.11014838218688965, 'epoch': 40.0}
{'train_runtime': 92.4881, 'train_samples_per_second': 1.73, 'train_steps_per_second': 0.432, 'train_loss': 0.12173117399215698, 'epoch': 40.0}


100%|██████████| 2/2 [11:00<00:00, 330.15s/it]

{
    "2": {
        "accuracy": 0.8800000000000001,
        "total_inference_time": 14.52936,
        "average_inference_time_per_sample": 1.452936,
        "peak_memory_usage_gb": 3.8645865440368654
    },
    "4": {
        "accuracy": 0.62,
        "total_inference_time": 17.0671,
        "average_inference_time_per_sample": 1.70671,
        "peak_memory_usage_gb": 3.872313976287842
    }
}



