In [None]:
# libs
import csv
import os
import sys
import time
from glob import glob

from tqdm import tqdm_notebook as tqdm

import libs.bag_of_worder as bag_of_worder
import libs.preprocessor as tweet_preproc

# Keras
from keras.models import load_model
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

## Load Models

In [None]:
from joblib import dump, load

# --- CNN ---
# Best-effort load: a failure is reported instead of stopping the notebook,
# but later cells that use `cnn_clf` will then fail with a NameError.
# `Exception` (not a bare except) so Ctrl-C / SystemExit still interrupt.
try:
    cnn_clf = load_model('model/dl_cnn.h5')
    print("CNN classifier loaded!")
except Exception as exc:
    print("ERROR: CNN not loaded (" + type(exc).__name__ + ": " + str(exc) + ")")

## Load Objects

In [None]:
# Project tweet preprocessor.
# NOTE(review): not referenced by any of the cells shown here — confirm it is still needed.
twitterPreprocessor = tweet_preproc.TwitterPreprocessor()

# Fitted Keras Tokenizer persisted with joblib. joblib.load unpickles the file,
# so only point this at trusted model artifacts.
TOKENIZER_PATH = 'model/tokenizer.joblib'
with open(TOKENIZER_PATH, 'rb') as tokenizer_file:
    tokenizer = load(tokenizer_file)

In [None]:
def transformTweets(tweets, maxlen=100):
    """Turn raw tweet strings into fixed-length integer sequences for the CNN.

    Uses the module-level Keras ``tokenizer``.

    Args:
        tweets: iterable of tweet strings.
        maxlen: output sequence length. Presumably this must equal the length
            the CNN was trained with — TODO confirm before changing the default.

    Returns:
        The result of ``pad_sequences``: one row of length ``maxlen`` per tweet,
        zero-padded at the end. Longer sequences are truncated with Keras'
        default truncation mode, since none is passed here.
    """
    sequences = tokenizer.texts_to_sequences(tweets)
    return pad_sequences(sequences, padding='post', maxlen=maxlen)
    

## Predict Functions

In [None]:
def predict_cnn(tweets, max_thresh=0.5, min_thresh=0.5):
    """Classify tweets with the CNN and map each score to a hard label.

    Args:
        tweets: sequence of raw tweet strings.
        max_thresh: scores strictly above this become label 1.
        min_thresh: scores strictly below this become label 0.

    Returns:
        A list with exactly one entry per input tweet, in input order:
        1, 0, or -1 when the score lies in [min_thresh, max_thresh]
        (too uncertain to label).
    """
    sequences = transformTweets(tweets)
    scores = cnn_clf.predict(sequences)

    # One score per input row is required for the labels to line up with tweets.
    assert len(scores) == len(sequences)

    labels = []
    for score in scores:
        if score > max_thresh:
            labels.append(1)
        elif score < min_thresh:
            labels.append(0)
        else:
            # Emit a placeholder rather than dropping the entry: dropping it
            # shifts every later label onto the wrong tweet. Negative labels
            # are skipped by predictFile.
            labels.append(-1)
    return labels

## Load tweets, predict labels and write predictions

In [None]:
def file_len(fname):
    """Count, print and return the number of physical lines in a text file.

    Returns 0 for an empty file. This counts raw lines, not CSV records: a
    quoted field that contains newlines spans several lines.
    """
    # UTF-8 matches how predictFile reads the same files; errors="replace"
    # because only line breaks matter here, so undecodable bytes must not abort.
    with open(fname, encoding="utf-8", errors="replace") as f:
        nbrOfLines = sum(1 for _ in f)

    print("Nbr of lines : " + str(nbrOfLines))
    return nbrOfLines

In [None]:
def _read_tweets(src_path):
    """Return ``[created_at, text, location]`` for every usable row of a tweets CSV.

    Raises ValueError (from ``list.index``) if one of those columns is missing
    from the header. Blank or truncated rows are skipped. An empty file yields [].
    """
    with open(src_path, 'r', newline='', encoding="utf-8") as csvfile:
        reader = csv.reader(csvfile, quotechar='"', delimiter=',')

        header = next(reader, None)
        if header is None:
            return []

        columns = [header.index(name) for name in ('created_at', 'text', 'location')]
        last_col = max(columns)

        rows = []
        for row in reader:
            if len(row) <= last_col:
                # Blank line or truncated record: nothing usable to classify.
                continue
            rows.append([row[col] for col in columns])
    return rows


def _write_predictions(out_path, preds, rows):
    """Write labelled rows to a fully quoted CSV, skipping negative labels.

    Returns the number of data rows written (header excluded).
    """
    written = 0
    with open(out_path, 'w', newline='', encoding="utf-8") as outfile:
        # QUOTE_ALL with '\n' reproduces the previous hand-built format, and the
        # csv module also escapes quotes embedded in tweet text (" -> "").
        writer = csv.writer(outfile, quoting=csv.QUOTE_ALL, lineterminator='\n')
        writer.writerow(["label", "created_at", "text", "location"])

        for pred, (created_at, text, location) in tqdm(zip(preds, rows), total=len(preds)):
            # Negative label = no confident prediction for this tweet.
            if pred < 0:
                continue
            writer.writerow([str(pred), created_at, text, location])
            written += 1
    return written


def predictFile(src_path, out_path, max_thresh=0.95, min_thresh=0.05):
    """Classify every tweet in ``src_path`` and write labelled rows to ``out_path``.

    The source CSV must have ``created_at``, ``text`` and ``location`` columns.
    The output has columns label, created_at, text, location. Tweets whose label
    is negative (uncertain prediction) are omitted.

    Args:
        src_path: input tweets CSV (UTF-8).
        out_path: output CSV path; overwritten only after prediction succeeds.
        max_thresh, min_thresh: forwarded to ``predict_cnn``.

    Raises:
        ValueError: a required column is missing, or ``predict_cnn`` returned
            a different number of labels than tweets, which would mislabel rows.
    """
    # Informational only: prints the physical line count of the source file.
    file_len(src_path)

    rows = _read_tweets(src_path)
    tweets = [text for _, text, _ in rows]

    # Skip the model entirely when there is nothing to classify.
    preds = predict_cnn(tweets, max_thresh=max_thresh, min_thresh=min_thresh) if tweets else []

    if len(preds) != len(rows):
        raise ValueError(
            "predict_cnn returned %d labels for %d tweets in %s; refusing to write misaligned labels"
            % (len(preds), len(rows), src_path)
        )

    tweet_counter = _write_predictions(out_path, preds, rows)
    print("Nbr of tweets labeled: " + str(tweet_counter))

In [None]:
# Glob all the tweets csv
filenames = glob("data/general/*/tweets.csv")
for fname in filenames:
    
    print(fname)
    outpath = "/".join(fname.split("/")[:-1]) + "/predictions.csv"
    predictFile(fname, outpath)