Merge pull request #41 from minggnim/llm
Add LLM src and notebooks
Showing 21 changed files with 2,419 additions and 4 deletions.
341 changes: 341 additions & 0 deletions
notebooks/01_bert-classification-finetuning/01_a_classification_model_training_example.ipynb
@@ -0,0 +1,341 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pathlib import Path\n",
"from torch import cuda\n",
"from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
"from nlp_models.base.data import create_label_dict, label_to_id_int, CustomDataset\n",
"from nlp_models.base.metrics import classification_metrics\n",
"from nlp_models.base.io import (\n",
" get_pretrained_tokenizer,\n",
" get_pretrained_model,\n",
" save_model,\n",
" save_label_dict\n",
")\n",
"from nlp_models.bert_classifier.bert import BertClass\n",
"from nlp_models.bert_classifier.train import custom_trainer, validate, optimizer_obj"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"DEVICE = 'cuda' if cuda.is_available() else 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"TRAINING_DATA = Path('../data/0_external/data.csv')\n",
"MODEL_DIR = Path('../models/bert')\n",
"LABEL_COL = 'label'\n",
"DATA_COL = 'data'\n",
"MAX_LEN = 512\n",
"TRAIN_BATCH_SIZE = 8\n",
"TEST_BATCH_SIZE = 4\n",
"EPOCHS = 5\n",
"LEARNING_RATE = 1e-05\n",
"MODEL_NAME = 'bert-base-uncased'\n",
"PRETRAINED_TOKENIZER = MODEL_DIR / 'pretrained/tokenizer-uncased'\n",
"PRETRAINED_MODEL = MODEL_DIR / 'pretrained/bert-base-uncased'\n",
"FINETUNED_DIR = MODEL_DIR / 'fine-tuned'\n",
"FINETUNED_MODEL = FINETUNED_DIR / 'fine-tuned-uncased'\n",
"FINETUNED_MODEL_STATE = FINETUNED_DIR / 'model-state-dict'\n",
"FINETUNED_OPT_STATE = FINETUNED_DIR / 'opt-state-dict'\n",
"CHECKPOINT_DIR = MODEL_DIR / 'checkpoint'\n",
"pretrained_tokenizer = get_pretrained_tokenizer(MODEL_NAME, PRETRAINED_TOKENIZER)\n",
"pretrained_model = get_pretrained_model(MODEL_NAME, PRETRAINED_MODEL)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_csv(TRAINING_DATA)\n",
"label_dict = create_label_dict(data, LABEL_COL)\n",
"data[LABEL_COL] = data[LABEL_COL].apply(lambda c: label_to_id_int(c, label_dict))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"test_size = 0.2\n",
"df_test = data.groupby(LABEL_COL, group_keys=False).apply(pd.DataFrame.sample, frac=test_size)\n",
"df_train = data[~data.index.isin(df_test.index)]\n",
"df_train.reset_index(drop=True, inplace=True)\n",
"df_test.reset_index(drop=True, inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"training_set = CustomDataset(df_train, DATA_COL, LABEL_COL, pretrained_tokenizer, MAX_LEN, int_labels=True)\n",
"testing_set = CustomDataset(df_test, DATA_COL, LABEL_COL, pretrained_tokenizer, MAX_LEN, int_labels=True)\n",
"\n",
"train_dataloader = DataLoader(\n",
" training_set, \n",
" sampler=RandomSampler(training_set), \n",
" batch_size=TRAIN_BATCH_SIZE \n",
" )\n",
"test_dataloader = DataLoader(\n",
" testing_set,\n",
" sampler=SequentialSampler(testing_set),\n",
" batch_size=TEST_BATCH_SIZE\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"======== Epoch 1 / 5 ========\n",
"Total steps: 1 || Training in progress...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"0it [00:00, ?it/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1it [00:45, 45.88s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" Training time: 45.886006116867065 seconds ||\n",
" Training loss: 0.6435834765434265 ||\n",
" Training accuracy: 0.75\n",
" \n",
"Evaluation in progress...\n",
"\n",
" Validation time: 48.50241994857788 seconds\n",
" Validation loss: 0.6544889211654663 ||\n",
" Validation accuracy: 0.5\n",
" \n",
"\n",
"======== Epoch 2 / 5 ========\n",
"Total steps: 1 || Training in progress...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1it [00:40, 40.69s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" Training time: 40.690561056137085 seconds ||\n",
" Training loss: 0.5998444557189941 ||\n",
" Training accuracy: 0.875\n",
" \n",
"Evaluation in progress...\n",
"\n",
" Validation time: 42.543277978897095 seconds\n",
" Validation loss: 0.6307095289230347 ||\n",
" Validation accuracy: 0.5\n",
" \n",
"\n",
"======== Epoch 3 / 5 ========\n",
"Total steps: 1 || Training in progress...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1it [00:31, 31.09s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" Training time: 31.095382928848267 seconds ||\n",
" Training loss: 0.5841808915138245 ||\n",
" Training accuracy: 0.875\n",
" \n",
"Evaluation in progress...\n",
"\n",
" Validation time: 33.1079638004303 seconds\n",
" Validation loss: 0.603857696056366 ||\n",
" Validation accuracy: 1.0\n",
" \n",
"\n",
"======== Epoch 4 / 5 ========\n",
"Total steps: 1 || Training in progress...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1it [00:28, 28.89s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" Training time: 28.888659954071045 seconds ||\n",
" Training loss: 0.4942058324813843 ||\n",
" Training accuracy: 1.0\n",
" \n",
"Evaluation in progress...\n",
"\n",
" Validation time: 30.851701021194458 seconds\n",
" Validation loss: 0.5770883560180664 ||\n",
" Validation accuracy: 1.0\n",
" \n",
"\n",
"======== Epoch 5 / 5 ========\n",
"Total steps: 1 || Training in progress...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1it [00:26, 26.57s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" Training time: 26.574002027511597 seconds ||\n",
" Training loss: 0.41605648398399353 ||\n",
" Training accuracy: 1.0\n",
" \n",
"Evaluation in progress...\n",
"\n",
" Validation time: 28.43067193031311 seconds\n",
" Validation loss: 0.5508862137794495 ||\n",
" Validation accuracy: 1.0\n",
" \n"
]
}
],
"source": [
"model = BertClass(len(label_dict), pretrained_model)\n",
"optimizer = optimizer_obj(model, LEARNING_RATE)\n",
"custom_trainer(model, optimizer, train_dataloader, test_dataloader, EPOCHS, DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'n': 2,\n",
" 'accuracy': 1.0,\n",
" 'precision': 1.0,\n",
" 'recall': 1.0,\n",
" 'f1': 1.0,\n",
" 'confusion': array([[1, 0],\n",
" [0, 1]])}"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"outputs, targets, loss = validate(model, test_dataloader, DEVICE)\n",
"classification_metrics(outputs, targets)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"save_model(model, optimizer)\n",
"save_label_dict(label_dict)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.7 ('ml')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "a29fdb6f3e01a2d0ababb7a9ec88cb1ac07dcd7af98b4f6ed90db69885c2ce81"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
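The notebook ends by saving the fine-tuned weights and the label dictionary. As a rough companion, the sketch below shows one way those artifacts might be reloaded for inference. This is a hedged example and not part of the commit: it assumes save_model() writes a plain PyTorch state dict to FINETUNED_MODEL_STATE, that get_pretrained_tokenizer() returns a callable Hugging Face tokenizer, that BertClass.forward accepts (input_ids, attention_mask, token_type_ids), and that the label_dict shown is only a placeholder for whatever save_label_dict() persisted. Check nlp_models.base.io and nlp_models.bert_classifier for the actual interfaces.

# Minimal inference sketch under the assumptions stated above; verify every
# interface against the nlp_models package before relying on it.
from pathlib import Path

import torch

from nlp_models.base.io import get_pretrained_model, get_pretrained_tokenizer
from nlp_models.bert_classifier.bert import BertClass

# Constants copied from the notebook so the snippet is self-contained.
MODEL_NAME = 'bert-base-uncased'
MODEL_DIR = Path('../models/bert')
PRETRAINED_TOKENIZER = MODEL_DIR / 'pretrained/tokenizer-uncased'
PRETRAINED_MODEL = MODEL_DIR / 'pretrained/bert-base-uncased'
FINETUNED_MODEL_STATE = MODEL_DIR / 'fine-tuned/model-state-dict'
MAX_LEN = 512

# Placeholder mapping; in practice, reload whatever save_label_dict() wrote.
label_dict = {'negative': 0, 'positive': 1}

tokenizer = get_pretrained_tokenizer(MODEL_NAME, PRETRAINED_TOKENIZER)
model = BertClass(len(label_dict), get_pretrained_model(MODEL_NAME, PRETRAINED_MODEL))
model.load_state_dict(torch.load(FINETUNED_MODEL_STATE, map_location='cpu'))
model.eval()

# Tokenize a single example and run it through the classifier.
enc = tokenizer(
    'example text to classify',
    truncation=True,
    max_length=MAX_LEN,
    return_tensors='pt',
)
with torch.no_grad():
    logits = model(enc['input_ids'], enc['attention_mask'], enc['token_type_ids'])

# Map the predicted class index back to its original label.
id_to_label = {v: k for k, v in label_dict.items()}
print(id_to_label[int(logits.argmax(dim=-1))])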