-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'melhora_acuracia' of https://github.com/fga-eps-mds/202…
…2-2-IsItKbs into melhora_acuracia
- Loading branch information
Showing
3 changed files
with
162 additions
and
350 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"from sklearn.feature_extraction.text import TfidfVectorizer\n", | ||
"import pickle\n", | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"\n", | ||
"## Reading and filtering data:\n", | ||
"\n", | ||
"txt = \"\"\n", | ||
"with open(\"data/raw/mashing.txt\", \"r\", encoding=\"utf-8\") as g:\n", | ||
" txt = g.read()\n", | ||
"\n", | ||
"sentences = txt.split(\"\\n\")\n", | ||
"for i in range(len(sentences)):\n", | ||
" sentences[i] = sentences[i].strip()\n", | ||
"\n", | ||
"sentences = [x for x in sentences if x != \"\"]\n", | ||
"\n", | ||
"texto = \"\"\n", | ||
"with open(\"data/raw/large-2014.txt\", \"r\", encoding=\"utf-8\") as k:\n", | ||
" texto = k.read()\n", | ||
" \n", | ||
"texto = texto.replace(\"?\",\".\")\n", | ||
"texto = texto.replace(\"!\",\".\")\n", | ||
"texto = texto.replace(\"»\",\"\")\n", | ||
"texto = texto.replace(\"«\",\"\")\n", | ||
"texto = texto.replace(\":\",\"\")\n", | ||
"texto = texto.replace(\";\",\"\")\n", | ||
"texto = texto.replace(\"...\",\".\")\n", | ||
"texto = texto.replace(\"…\",\".\")\n", | ||
"texto = texto.replace(\"\\n\",\".\")\n", | ||
"texto = texto.replace(\" \",\" \")\n", | ||
"texto = texto.replace(\"—\", \"\")\n", | ||
"texto = texto.replace(\"\\\"\",\"\")\n", | ||
"texto = texto.replace(\"„\",\"\")\n", | ||
"texto = texto.replace(\"eKGWB\", \"\")\n", | ||
"texto = texto.replace(\"Studia Nietzscheana (2014), www.nietzschesource.org/SN/large-2014.\",\"\")\n", | ||
"sentencas = texto.split(\" \")\n", | ||
"for i in range(len(sentencas)):\n", | ||
" sentencas[i] = sentencas[i].strip()\n", | ||
" \n", | ||
"sentencas = [x for x in sentencas if x != \"\"]\n", | ||
"\n", | ||
"## Splitting into training and testing data:\n", | ||
"\n", | ||
"X = np.array(sentences + sentencas)\n", | ||
"Y = np.array(['Y']*len(sentences) + ['N']*len(sentencas))\n", | ||
"df = pd.DataFrame({'sentence':X,'mashing':Y})\n", | ||
"df.to_csv(\"data/processed/dataframe.csv\") ##saving the filtered data as csv\n", | ||
"\n", | ||
"from sklearn.model_selection import train_test_split\n", | ||
"X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.33, random_state=2)\n", | ||
"\n", | ||
"## Training the model:\n", | ||
"\n", | ||
"from sklearn.linear_model import LogisticRegression\n", | ||
"from sklearn.feature_extraction.text import CountVectorizer, HashingVectorizer\n", | ||
"\n", | ||
"vectorizer = TfidfVectorizer(ngram_range=(1, 4),\n", | ||
" lowercase=True,\n", | ||
" analyzer='char', binary=True,\n", | ||
" strip_accents=\"unicode\")\n", | ||
"vectorizer = CountVectorizer (ngram_range=(1,4),\n", | ||
" lowercase = True,\n", | ||
" analyzer = 'char_wb',\n", | ||
" binary = False,\n", | ||
" strip_accents = \"unicode\")\n", | ||
"vectorizer = HashingVectorizer (ngram_range=(1,4),\n", | ||
" lowercase = True,\n", | ||
" analyzer = 'char',\n", | ||
" binary = True,\n", | ||
" strip_accents = \"unicode\")\n", | ||
"vectorizer.fit(X_train)\n", | ||
"vectorizer.fit(X_train)\n", | ||
"model = LogisticRegression()\n", | ||
"\n", | ||
"model.fit(vectorizer.transform(X_train),Y_train)\n", | ||
"\n", | ||
"from sklearn.metrics import confusion_matrix, accuracy_score\n", | ||
"\n", | ||
"Y_pred = model.predict(vectorizer.transform(X_test))\n", | ||
"\n", | ||
"confusion_matrix(Y_test, Y_pred)\n", | ||
"\n", | ||
"## Accuracy score:\n", | ||
"\n", | ||
"from sklearn.metrics import accuracy_score\n", | ||
"\n", | ||
"Y_pred = model.predict(X_test)\n", | ||
"\n", | ||
"acc = accuracy_score(Y_test, Y_pred) ##0.985\n", | ||
"print('Acc:', acc)\n", | ||
"\n", | ||
"## Evaluating the model with new metrics\n", | ||
"\n", | ||
"from sklearn.metrics import balanced_accuracy_score\n", | ||
"\n", | ||
"score1 = balanced_accuracy_score(Y_test, Y_pred)\n", | ||
"\n", | ||
"print('Acc:', score1)\n", | ||
"\n", | ||
"from sklearn.metrics import f1_score\n", | ||
"\n", | ||
"score2 = f1_score(Y_test, Y_pred, labels=None, pos_label='0', average='binary', sample_weight=None, zero_division='warn')\n", | ||
"\n", | ||
"print('Acc:', score2)\n", | ||
"\n", | ||
"from sklearn.metrics import recall_score\n", | ||
"\n", | ||
"score3 = recall_score(Y_test, Y_pred, labels=None, pos_label='0', average='binary', sample_weight=None, zero_division='warn')\n", | ||
"\n", | ||
"print('Acc:', score3)\n", | ||
"\n", | ||
"## Compressing Model\n", | ||
"\n", | ||
"import pickle\n", | ||
"\n", | ||
"pickle.dump(model, open(\"models/logistic-reg.sav\", 'wb'))\n", | ||
"\n", | ||
"# Input in case you want to test it:\n", | ||
"input_data = [(input(\"\"))]\n", | ||
"pred = model.predict(vectorizer.transform([x[0] for x in input_data]))\n", | ||
"print(pred)\n", | ||
"\n", | ||
"print('Acc:', acc)\n", | ||
"\n", | ||
"## Compressing Model\n", | ||
"\n", | ||
"import pickle\n", | ||
"\n", | ||
"pickle.dump(model, open(\"models/logistic-reg.pkl\", 'wb'))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python", | ||
"version": "3.11.0" | ||
}, | ||
"orig_nbformat": 4, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "a39d81b61f9ce6a24f8f526befdd921f4c6667ad35f1e729c3a4277aa5c26a6d" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.