In [370]:
from __future__ import unicode_literals, print_function, division

import os
from io import open
import sys
import math
import random
import argparse
import operator
import pdb

import torch
import torch.autograd as autograd
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


from collections import defaultdict
from collections import Counter

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score


# Kyle's attempt
import faker
from faker import Faker
import pandas as pd
import numpy as np
import re
from string import punctuation
import glob
import unicodedata
import string
import random
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

fake = Faker()

In [371]:
# Need to have the class of the model in local memory to load a saved model in pytorch
class LSTMClassifier(nn.Module):

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_size):
        super(LSTMClassifier, self).__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1)

        self.hidden2out = nn.Linear(hidden_dim, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

        self.dropout_layer = nn.Dropout(p=0.2)


    def init_hidden(self, batch_size):
        return(autograd.Variable(torch.randn(1, batch_size, self.hidden_dim)),
                    autograd.Variable(torch.randn(1, batch_size, self.hidden_dim)))


    def forward(self, batch, lengths):

        self.hidden = self.init_hidden(batch.size(-1))

        embeds = self.embedding(batch)
        packed_input = pack_padded_sequence(embeds, lengths)
        outputs, (ht, ct) = self.lstm(packed_input, self.hidden)
        # ht is the last hidden state of the sequences
        # ht = (1 x batch_size x hidden_dim)
        # ht[-1] = (batch_size x hidden_dim)
        output = self.dropout_layer(ht[-1])
        output = self.hidden2out(output)
        output = self.softmax(output)

        return output

In [372]:
        
class PaddedTensorDataset(Dataset):
#     """Dataset wrapping data, target and length tensors.

#     Each sample will be retrieved by indexing both tensors along the first
#     dimension.

#     Arguments:
#         data_tensor (Tensor): contains sample data.
#         target_tensor (Tensor): contains sample targets (labels).
#         length (Tensor): contains sample lengths.
#         raw_data (Any): The data that has been transformed into tensor, useful for debugging
#     """

    def __init__(self, data_tensor, target_tensor, length_tensor, raw_data):
        assert data_tensor.size(0) == target_tensor.size(0) == length_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
        self.length_tensor = length_tensor
        self.raw_data = raw_data

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index], self.length_tensor[index], self.raw_data[index]

    def __len__(self):
        return self.data_tensor.size(0)

In [390]:
class DF_To_Tensors():
    def __init__(self):
        self.number_of_random_samples=20
        self.country_lookup= pd.read_csv('datasets/country_lookup.csv')
        self.tag2id = defaultdict(int,
                        {'city': 0,
                         'first_name': 1,
                         'geo': 2,
                         'percent': 3,
                         'year': 4,
                         'ssn': 5,
                         'language_name': 6,
                         'country_name': 7,
                         'phone_number': 8,
                         'month': 9,
                         'zipcode': 10,
                         'iso8601': 11,
                         'paragraph': 12,
                         'pyfloat': 13,
                         'email': 14,
                         'prefix': 15,
                         'pystr': 16,
                         'isbn': 17,
                         'boolean': 18,
                        'country_code':19,
                         'country_GID':20,
                        'continent':21,
                        'date': 22,
                        'day_of_month':23,
                        'day_of_week':24,
                        'date_long_dmdy':25,
                        'date_long_mdy': 26,
                        'date_long_dmdyt':27,
                        'date_long_mdyt_m':28,
                         'city_suffix':29,
                         'month_name':30

                         })
        self.n_categories = len(self.tag2id)
        self.token_set={'a','b','c','d','e',
                        'f','g','h','i','j','k','l',
                        'm','n','o','p','q','r','s',
                        't','u','v','w','x','y','z',
                        'A','B','C','D','E','F','G',
                        'H','I','J','K','L','M','N',
                        'O','P','Q','R','S','T','U',
                        'V','W','X','Y','Z','1','2',
                        '3','4','5','6','7','8','9','0',
                        "'",',','.',';','*','!','@',
                        '#','$','%','^','&','(',')',
                        '_','=','-',':','+','/',"\\", '*'}
        self.token2id = defaultdict(int,
            {'PAD': 0,
             'UNK': 1,
             'a':2,
             'b':3,
             'c': 4,
             'd': 5,
             'e': 6,
             'f': 7,
             'g':8,
             'h': 9,
             'i': 10,
             'j':11,
             'k':12,
             'l':13,
             'm':14,
             'n':15,
             'o':16,
             'p':17,
             'q':18,
             'r':19,
             's':20,
             't':21,
             'u':22,
             'v':23,
             'w':24,
             'x':25,
             'y':26,
             'z':27,
             'A':28,
             'B':29,
             'C':30,
             'D':31,
             'E':32,
             'F':33,
             'G':34,
             'H':35,
             'I':36,
             'J':37,
             'K':38,
             'L':39,
             'N':40,
             'O':41,
             'P':42,
             'Q':43,
             'R':44,
             'S':45,
             'T':46,
             'U':47,
             'V':48,
             'W':49,
             'X':50,
             'Y':51,
             'Z':52,
             '1':53,
             '2':54,
             '3':55,
             '4':56,
             '5':57,
             '6':58,
             '7':59,
             '8':60,
             '9':61,
             '0':62,
             "'":63,
             ',':64,
             '.':65,
             ';':66,
             '*':67,
             '!':68,
             '@':68,
             '#':70,
             '$':71,
             '%':72,
             '^':73,
             '&':74,
             '(':75,
             ')':76,
             '_':77,
             '=':78,
             '-':79,
             ':':80,
             '+':81,
             '/':82,
             '\\':83,
             '*': 84})
    
    def vectorized_string(self, string):
            return [self.token2id[token] if token in self.token2id else self.token2id['UNK'] for token in str(string)]
        
    def vectorized_array(self, array):
        vecorized_array=[]
        for stringValue in array:
            vecorized_array.append(self.vectorized_string(str(stringValue)))
        return vecorized_array
    
    def pad_sequences(self, vectorized_seqs, seq_lengths):
        # create a zero matrix
        seq_tensor = torch.zeros((len(vectorized_seqs), seq_lengths.max())).long()

        # fill the index
        for idx, (seq, seqlen) in enumerate(zip(vectorized_seqs, seq_lengths)):
            seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
        return seq_tensor


    def create_dataset(self, data, batch_size=1):
        
        vectorized_seqs = self.vectorized_array(data)
        seq_lengths = torch.LongTensor([len(s) for s in vectorized_seqs])
        seq_tensor = self.pad_sequences(vectorized_seqs, seq_lengths)
        target_tensor = torch.LongTensor([self.tag2id[y] for  y in data])
        raw_data = [x for x in data]
        
        return DataLoader(PaddedTensorDataset(seq_tensor, target_tensor, seq_lengths, raw_data), batch_size=batch_size)

    def sort_batch(self,batch, targets, lengths):
        seq_lengths, perm_idx = lengths.sort(0, descending=True)
        seq_tensor = batch[perm_idx]
        target_tensor = targets[perm_idx]

        return seq_tensor.transpose(0, 1), target_tensor, seq_lengths


    def evaluate_test_set(self, model, test):
        y_pred = list()
        all_predictionsforValue=[]

        for batch, targets, lengths, raw_data in self.create_dataset(test, batch_size=1):
            batch, targets, lengths = self.sort_batch(batch, targets, lengths)
            pred = model(torch.autograd.Variable(batch), lengths.cpu().numpy())
            pred_idx = torch.max(pred, 1)[1]
            def get_key(val):
                for key, value in self.tag2id.items():
                     if val == value:
                            return {'top_pred':key, 'tensor':pred, 'pred_idx':pred_idx}
#                             all_predictionsforValue.append({'key':key, 'tensor':pred, 'pred_idx':pred_idx})

            all_predictionsforValue.append(get_key(pred_idx[0]))
        return all_predictionsforValue
        
    def read_in_csv(self,path):
        df = pd.read_csv(path)
        return df
#         print(self.df.head())
    
    def get_arrayOfValues_df(self, df):
        column_value_object={}

        for column in df.columns:
            guesses=[]
            column_value_object[column]=[]
            for _ in range(1,self.number_of_random_samples):
                random_values = str(np.random.choice(df[column]))
                random_col = column
                column_value_object[column].append(random_values)

        return column_value_object
    def averaged_predictions(self, all_predictions):
        all_arrays=[]
        for pred in all_predictions:
                all_arrays.append(pred['tensor'].detach().numpy())
        
        out = np.mean(all_arrays, axis=0)
        def get_key(val):
                for key, value in self.tag2id.items():
                     if val == value:
                        return key

        return {
            "averaged_tensor":out,
            'averaged_top_category':get_key(np.argmax(out))
        }
    
    def predictions(self, model, path_to_csv):
        df = self.read_in_csv(path=path_to_csv)
        column_value_object = self.get_arrayOfValues_df(df)
        self.column_value_object = column_value_object
        predictions=[]
        for column in column_value_object:
#             print(column_value_object[column])
#             print(column)
#             print(self.evaluate_test_set(model, column_value_object[column]))
            all_predictions = self.evaluate_test_set(model, column_value_object[column])
            avg_predictions = self.averaged_predictions(all_predictions)
            predictions.append({
                    'column': column,
                    'values': column_value_object[column],
                    'avg_predictions': avg_predictions,
                    'model_predictions': all_predictions
                   
                })
    
        
#         for pred in predictions:
#                 if len(pred['values']) == len(pred['model_predictions']):
#                         for i, v in enumerate(pred['values']):
# #                             print(v, i)
#                             pred['model_predictions'][i]['original']=pred['values'][i]
# #                         print('values', pred['values'])
# #                         print(pred['model_predictions'])
        self.predictions = predictions
        return self.predictions
    
    
    def assign_heuristic_function(self, predictions):
        def none_f(values):
            print('none')
            return 'none'
            
        def city_country_f(values):
            for city_country in values:
                print('city_country', city_country)
            
            return 'city_country'
        
        def country_GID_f(values):
            GID_in_lookup=[]
            for gid in values:
                GID_in_lookup.append(self.country_lookup.isin([gid]).any())
                           
            print(len(GID_in_lookup))

            if len(GID_in_lookup) >= (len(values) * .85):
                print(len(values)*.85)
                return 'GID'
                
                
        def continent_f(values):
            for cont in values:
                print('continent',cont)
            return 'cont'
                
        def geo_f(values):
            geo_valid=[]
            for geo in values:
                if float(geo) <=180 and float(geo) >= -180:
                    if float(geo) <=90 and float(geo) >= -90:
                        print('geo_lat_lng', geo)
                        geo_valid.append('latlng')
                        
                    else:
                        print('lng', geo)
                        geo_valid.append('lng')
                else:
                    geo_valid.append('failed')
                        
            (unique, counts) = np.unique(geo_valid, return_counts=True)
            print(unique, counts)
            return 'geo'
                
        
        def iso_f(values):
            for iso in values:
                print('iso',iso)
            return 'iso'
                
        def year_f(values):
            year_values_valid=[]
            for year in values:
                if str.isdigit(str(year)):
                    if int(year) > 1300 and int(year)<2500:
                        print('probably valid year')
                        year_values_valid.append('true')
                        
                    else:
                        print('not a normal year')
                        year_values_valid.append('maybe')
                else:
                    print('not a digit')
                    year_values_valid.append('failed')
                    
            (unique, counts) = np.unique(year_values_valid, return_counts=True)
            print('unique', unique, 'counts', counts)
            if "failed" in year_values_valid:
                
                return 'failed'
            else:
            
                return 'year'
        
        def bool_f(values):
            for bools in values:
                print('boolian',bools)
            return 'boolian'
        
        def date_f(values):
            for date in values:
                print('date',date)
            return 'date'
                
        def month_day_f(values):
            month_day_results=[]
            for i, md in enumerate(values):
#                 print('monthday',md)
                if str.isdigit(md):
                    if int(md) <=12 and int(md) >= 1:
#                         print("month", md)
                        month_day_results.append('month_day')
                    elif int(md) >12 and int(md) <= 31:
#                         print('day', md)
                        month_day_results.append('day')
                    else:
                        month_day_results.append('failed')
#                         print('not month or day'
                else:
                    print('not a valid digit')
            
            if 'failed' in month_day_results:
                return 'none_by_failed'
            elif 'day' in month_day_results:
                return 'day'
            elif 'month_day' in month_day_results:
                return 'month'
            else:
                return 'none'
        
        def month_day_name_f(values):
            for month_day in values:
                print('month_day_name', month_day)
                
        functionlist = defaultdict(int,
            {'city': city_country_f,
             'language_name': city_country_f,
             'country_name': city_country_f,
             'country_code':city_country_f,
             'country_GID':country_GID_f,
             'city_suffix':city_country_f,

             'continent':continent_f,

             'geo': geo_f,


             'first_name': none_f,
             'percent': none_f,

             'ssn': none_f,
             'phone_number': none_f,
             'zipcode': none_f,
             'paragraph': none_f,
             'pyfloat': none_f,
             'email': none_f,
             'prefix': none_f,
             'pystr': none_f,
             'isbn': none_f,

             'boolean': bool_f,

             'iso8601': iso_f,


             'year': year_f,

             'day_of_month':month_day_f,
             'month': month_day_f,

             'month_name':month_day_name_f,
             'day_of_week':month_day_name_f,

             'date': date_f,
             'date_long_dmdy':date_f,
             'date_long_mdy': date_f,
             'date_long_dmdyt':date_f,
             'date_long_mdyt_m':date_f,
                 })
        final_column_classification=[]

        for pred in predictions:
            print('pred')
#             print('column', pred['column'])
#             print('pred',pred['avg_predictions'])
            fun = functionlist[pred['avg_predictions']['averaged_top_category']](self.column_value_object[pred['column']])
#             print(fun)
            final_column_classification.append({'column': pred['column'], 'classifcation': fun})
#         print('f_response', fun)
        
        return final_column_classification



In [361]:
c_lookup= pd.read_csv('datasets/country_lookup.csv')

In [364]:
c_lookup.head()

Unnamed: 0,iso3,Country.x
0,AFG,Afghanistan
1,ALB,Albania
2,ANT,Netherlands Antilles
3,ARB,Aruba
4,ARE,United Arab Emirates


In [391]:
# create our class and load the saved model
dft_tensor=DF_To_Tensors()
model2 = torch.load('./models/LSTM_RNN_Geotime_Classify_v_0.03.pth')
model2.eval()

LSTMClassifier(
  (embedding): Embedding(84, 128)
  (lstm): LSTM(128, 32)
  (hidden2out): Linear(in_features=32, out_features=31, bias=True)
  (softmax): LogSoftmax(dim=1)
  (dropout_layer): Dropout(p=0.2, inplace=False)
)

In [392]:
preds4=dft_tensor.predictions(model=model2, path_to_csv='datasets/data/africa_test.csv')


In [393]:
preds4

[{'column': 'fid',
  'values': ['31',
   '20',
   '71',
   '59',
   '50',
   '41',
   '58',
   '77',
   '82',
   '73',
   '47',
   '33',
   '19',
   '75',
   '15',
   '62',
   '46',
   '78',
   '19'],
  'avg_predictions': {'averaged_tensor': array([[-10.25605   , -10.747086  ,  -9.734103  ,  -7.4590816 ,
            -4.8392344 , -11.851789  ,  -8.677875  ,  -9.47764   ,
            -9.674136  ,  -1.8404397 ,  -7.30422   ,  -9.29116   ,
           -15.417233  , -10.907297  ,  -9.796656  , -15.877197  ,
           -10.626991  , -11.269174  , -18.436161  ,  -9.762193  ,
           -12.689939  , -12.317076  ,  -8.548533  ,  -0.33501458,
           -15.414353  ,  -8.343185  ,  -9.918771  ,  -9.879967  ,
           -11.040152  , -13.485345  , -13.765615  ]], dtype=float32),
   'averaged_top_category': 'day_of_month'},
  'model_predictions': [{'top_pred': 'day_of_month',
    'tensor': tensor([[-10.6928, -10.7004, -11.4096,  -5.2742,  -6.2837, -12.7637, -10.5388,
             -10.6554, -12.634

In [394]:
dft_tensor.assign_heuristic_function(preds4)

pred
pred
city_country Yemen
city_country Egypt
city_country Somalia
city_country Bahrain
city_country Malawi
city_country Seychelles
city_country Comoros
city_country Egypt
city_country Botswana
city_country Morocco
city_country Namibia
city_country Tanzania
city_country Yemen
city_country Tanzania
city_country Angola
city_country Mauritania
city_country Yemen
city_country Seychelles
city_country Egypt
pred
none
pred
city_country Freetown
city_country Windhoek
city_country Lilongwe
city_country Jerusalem
city_country Pretoria
city_country Abu Dhabi
city_country Addis Ababa
city_country Lome
city_country N'Djamena
city_country Yaounde
city_country Masqat
city_country Yaounde
city_country N'Djamena
city_country N'Djamena
city_country Sana'a
city_country Rabat
city_country Juba
city_country Jerusalem
city_country Tripoli
pred
geo_lat_lng 30.06056
geo_lat_lng 39.18333
geo_lat_lng 13.23444
geo_lat_lng 13.18
geo_lat_lng 35.22361
geo_lat_lng 35.22361
geo_lat_lng -17.5
geo_lat_lng -1.524722
g

[{'column': 'fid', 'classifcation': 'none_by_failed'},
 {'column': 'CNTRY_NAME', 'classifcation': 'city_country'},
 {'column': 'AREA', 'classifcation': 'none'},
 {'column': 'CAPNAME', 'classifcation': 'city_country'},
 {'column': 'CAPLONG', 'classifcation': 'geo'},
 {'column': 'CAPLAT', 'classifcation': 'geo'},
 {'column': 'FEATUREID', 'classifcation': 'none_by_failed'},
 {'column': 'COWCODE', 'classifcation': 'year'},
 {'column': 'COWSYEAR', 'classifcation': 'year'},
 {'column': 'COWSMONTH', 'classifcation': 'month'},
 {'column': 'COWSDAY', 'classifcation': 'day'},
 {'column': 'COWEYEAR', 'classifcation': 'year'},
 {'column': 'COWEMONTH', 'classifcation': 'month'},
 {'column': 'COWEDAY', 'classifcation': 'day'},
 {'column': 'GWCODE', 'classifcation': 'year'},
 {'column': 'GWSYEAR', 'classifcation': 'failed'},
 {'column': 'GWSMONTH', 'classifcation': 'month'},
 {'column': 'GWSDAY', 'classifcation': 'day'},
 {'column': 'GWEYEAR', 'classifcation': 'year'},
 {'column': 'GWEMONTH', 'classi

In [359]:
preds3=dft_tensor.predictions(model=model2, path_to_csv='datasets/data/four_col_test.csv')
preds3

[{'column': 'iso3',
  'values': ['SDN',
   'MDA',
   'GAB',
   'TGO',
   'GMB',
   'SWZ',
   'PRY',
   'SWE',
   'JPN',
   'IRN',
   'NER',
   'SDN',
   'KHM',
   'MLT',
   'LTU',
   'CZE',
   'AUS',
   'CHL',
   'CHN'],
  'avg_predictions': {'averaged_tensor': array([[-14.255122 ,  -9.381731 , -13.233103 , -13.732901 , -13.661171 ,
           -12.358272 , -12.251612 , -11.431745 , -11.204177 , -11.405403 ,
           -11.240194 , -12.564417 , -15.554382 , -15.704672 , -14.209358 ,
           -11.197466 ,  -6.1207824, -14.951247 , -12.574597 ,  -2.172541 ,
            -0.1859478,  -7.381302 , -12.568318 , -12.071199 , -11.897299 ,
           -11.070905 , -13.114594 , -14.236744 , -12.013137 , -16.62177  ,
           -10.656276 ]], dtype=float32),
   'averaged_top_category': 'country_GID'},
  'model_predictions': [{'top_pred': 'country_GID',
    'tensor': tensor([[-13.3551,  -8.5325, -11.2893, -12.9909, -11.7928, -12.2919, -11.1005,
             -10.1040, -10.5430,  -9.6620, -11.0159, -

In [360]:
dft_tensor.assign_heuristic_function(preds3)

pred
city_country SDN
city_country MDA
city_country GAB
city_country TGO
city_country GMB
city_country SWZ
city_country PRY
city_country SWE
city_country JPN
city_country IRN
city_country NER
city_country SDN
city_country KHM
city_country MLT
city_country LTU
city_country CZE
city_country AUS
city_country CHL
city_country CHN
pred
city_country Cyprus
city_country Cuba
city_country Madagascar
city_country Benin
city_country Sweden
city_country Sierra Leone
city_country Singapore
city_country Lesotho
city_country Argentina
city_country Somalia
city_country Taiwan
city_country Paraguay
city_country Mauritania
city_country Bermuda
city_country Turkey
city_country Chile
city_country Solomon Islands
city_country Nepal
city_country Dominican Republic
pred
pred
none
pred
none
pred
none
pred
geo_lat_lng -0.00395954
geo_lat_lng 0.0
geo_lat_lng 0.0
geo_lat_lng -0.034388804
geo_lat_lng -9.34e-05
geo_lat_lng -0.027486997
geo_lat_lng -0.114451442
geo_lat_lng -0.007441719
geo_lat_lng -0.005390241
geo

[{'column': 'iso3', 'classifcation': 'city_country'},
 {'column': 'Country.x', 'classifcation': 'city_country'},
 {'column': 'Year', 'classifcation': 'none_by_failed'},
 {'column': 'Value_Reserve_TimeSeries', 'classifcation': 'none'},
 {'column': 'Value_Shortage_TimeSeries', 'classifcation': 'none'},
 {'column': 'Value_ConsumptiontoC0_TimeSeries', 'classifcation': 'none'},
 {'column': 'Value_ReserveChangetoC0_TimeSeries', 'classifcation': 'geo'},
 {'column': 'Value_Production_TimeSeries', 'classifcation': 'none'}]

In [188]:
preds4=dft_tensor.predictions(model=model, path_to_csv='datasets/data/cshape_africa.csv')
preds4

[{'column': 'country', 'values': ['Zambia', 'Liberia', 'Rwanda', 'Egypt', 'Saudi Arabia', 'Guinea-Bissau', 'The Gambia', 'Israel', 'Equatorial Guinea'], 'avg_predictions': {'averaged_tensor': array([[ -4.430699  ,  -3.4931073 , -18.022495  , -14.277381  ,
        -12.642521  , -14.933206  ,  -2.9969385 ,  -0.55397224,
        -15.657098  ,  -9.320038  , -17.297201  , -13.390749  ,
        -10.028104  , -18.216972  , -11.068612  , -16.259323  ,
         -8.840366  , -14.575265  , -10.401575  , -12.39828   ,
        -15.327362  ,  -9.153301  , -15.053763  ,  -8.945912  ,
         -9.090069  ,  -9.435516  , -14.266641  , -13.459404  ,
        -14.361956  , -10.552912  , -10.768927  ]], dtype=float32), 'averaged_top_category': 'country_name'}, 'model_predictions': [{'top_pred': 'country_name', 'tensor': tensor([[ -6.0386,  -2.8219, -16.5380, -14.3044, -13.9985, -14.2122,  -2.0446,
          -0.2139, -14.9084, -10.2137, -17.6548, -12.0147,  -9.5135, -16.7261,
         -10.3072, -15.3954,  -

In [189]:
# Just test a random array by itself.
dft_tensor.evaluate_test_set(model=model,test=['01/02/2020', '09', 'USA', 'WOWOOWOWOWOWOWOWO', '1999', '1200', '402', '.3934', 'iiii', '12', 'USA',"ETH",'South America'])

[{'top_pred': 'date',
  'tensor': tensor([[-1.8405e+01, -1.3475e+01, -1.0843e+01, -1.3926e+01, -9.8927e+00,
           -4.1174e+00, -1.3473e+01, -1.5847e+01, -8.4487e+00, -7.4543e+00,
           -1.0893e+01, -6.7751e+00, -2.0338e+01, -1.3331e+01, -1.9767e+01,
           -1.5002e+01, -1.4928e+01, -8.4395e+00, -1.4448e+01, -8.6970e+00,
           -1.0094e+01, -1.1057e+01, -1.9654e-02, -8.0456e+00, -1.5556e+01,
           -9.3909e+00, -9.4819e+00, -8.5697e+00, -1.0287e+01, -1.7160e+01,
           -1.4445e+01]], grad_fn=<LogSoftmaxBackward>),
  'pred_idx': tensor([22])},
 {'top_pred': 'month',
  'tensor': tensor([[-10.4091, -12.7464, -12.3732,  -7.0635,  -5.8465, -16.1444,  -7.0775,
            -9.0475, -13.9970,  -0.5013, -11.1098, -11.0008, -14.9622, -12.5449,
           -10.6254, -17.3869, -10.4383, -12.3830, -18.6079, -12.5750, -15.1814,
           -11.7973, -10.2138,  -0.9437, -16.1655,  -9.1353, -11.2784,  -9.8058,
           -12.1363, -15.2351, -16.6754]], grad_fn=<LogSoftmaxBackwar

In [41]:

# for later use
def randomChoice(self, values):
    return values[random.randint(0, len(values) - 1)]

def getRandomSet(self):
    category = self.randomChoice(self.all_categories)
#         print(category)
    line = self.randomChoice(list(self.category_values[category]['obj']))
#         print('line', line)
    return (line, category)