-
Notifications
You must be signed in to change notification settings - Fork 0
/
8d-verification-apply-keras-alternative.R
35 lines (22 loc) · 1.33 KB
/
8d-verification-apply-keras-alternative.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# verifikace modelu - načte Keras model z adresáře keras a uplatní ho na data o nově stažených tweetech
library(tidyverse)
library(keras)
model <- load_model_hdf5("./models/bi-ltsm.h5") # načíst...
summary(model) # vytisknout shrnutí
pred_data <- readr::read_csv('./data/new-keras-matrix.csv') %>%
select(id, name, text)
train_data <- readr::read_csv('./data/new-keras-matrix.csv') %>%
mutate(pravy_tomio = ifelse(name == 'tomio_cz', 1,0)) %>% # klasifikace = potřebuju binární výstup
select(-id, -name, -text)
x_pred <- data.matrix(train_data %>% select(-pravy_tomio)) # všechno kromě targetu jako matice
pred <- model %>% # vektor pravděpodobnosti - inverval nula až jedna
predict_proba(x_pred)
verifikace <- pred_data %>% # doplnit podle pořadi id tweetu, pravděpodobnost + vlastní tweet
cbind(pred)
write_csv(verifikace, './data/verifikace.csv') # uložit pro budoucí použití
print(paste('Do /data/verifikace.csv uloženo', nrow(x_pred), 'namodelovaných řádků.'))
verifikace <- verifikace %>%
mutate(name_pred = ifelse(pred>0.5, 'tomio_cz', 'Tomio_Okamura'))
conf_mtx <- table(verifikace$name, verifikace$name_pred)
print(paste('Správně předpovězeno',sum(diag(conf_mtx)), 'z',sum(conf_mtx), 'tweetů, což představuje', round(100 * sum(diag(conf_mtx))/sum(conf_mtx), 2), 'procent.'))
print(conf_mtx)