In [1]:
# Install the required modules
!pip install transformers
!pip install pandas
!pip install numpy
!pip install tqdm
!pip install sklearn

from google.colab import drive
drive.mount('/content/gdrive')

# !nvidia-smi

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/70/1a/364556102943cacde1ee00fdcae3b1615b39e52649eddbf54953e5b144c9/transformers-2.2.1-py3-none-any.whl (364kB)
[K     |█                               | 10kB 26.0MB/s eta 0:00:01[K     |█▉                              | 20kB 2.2MB/s eta 0:00:01[K     |██▊                             | 30kB 3.2MB/s eta 0:00:01[K     |███▋                            | 40kB 2.1MB/s eta 0:00:01[K     |████▌                           | 51kB 2.6MB/s eta 0:00:01[K     |█████▍                          | 61kB 3.1MB/s eta 0:00:01[K     |██████▎                         | 71kB 3.6MB/s eta 0:00:01[K     |███████▏                        | 81kB 4.1MB/s eta 0:00:01[K     |████████                        | 92kB 4.6MB/s eta 0:00:01[K     |█████████                       | 102kB 3.5MB/s eta 0:00:01[K     |█████████▉                      | 112kB 3.5MB/s eta 0:00:01[K     |██████████▊                     | 122kB 3.5M

In [2]:
import pandas as pd
from tqdm import tqdm
from sklearn import preprocessing
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import confusion_matrix
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import *

In [0]:
# Embeddings can be derived from the last 1 or 4 layers, to reduce the computational cost, we used only the last layer.

class Embeddings:
    LAST_LAYER = 1
    LAST_4_LAYERS = 2
    def __init__(self):
        self._tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self._bert_model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
        self._bert_model.eval()

    def tokenize(self, sentence):
        """

        :param sentence: input sentence ['str']
        :return: tokenized sentence based on word piece model ['List']
        """
        marked_sentence = "[CLS] " + sentence + " [SEP]"
        tokenized_text = self._tokenizer.tokenize(marked_sentence)
        return tokenized_text

    def get_bert_embeddings(self, sentence):
        """

        :param sentence: input sentence ['str']
        :return: BERT pre-trained hidden states (list of torch tensors) ['List']
        """
        # Predict hidden states features for each layer

        tokenized_text = self.tokenize(sentence)
        indexed_tokens = self._tokenizer.convert_tokens_to_ids(tokenized_text)

        segments_ids = [1] * len(tokenized_text)

        # Convert inputs to PyTorch tensors
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])

        with torch.no_grad():
            encoded_layers = self._bert_model(tokens_tensor, segments_tensors)

        return encoded_layers[-1][0:12]

    def sentence2vec(self, sentence, layers):
        """

        :param sentence: input sentence ['str']
        :param layers: parameter to decide how word embeddings are obtained ['str]
            1. 'last' : last hidden state used to obtain word embeddings for sentence tokens
            2. 'last_4' : last 4 hidden states used to obtain word embeddings for sentence tokens

        :return: sentence vector [List]
        """
        encoded_layers = self.get_bert_embeddings(sentence)
        
        if layers == 1:
            # using the last layer embeddings
            token_embeddings = encoded_layers[-1]
            # summing the last layer vectors for each token
            sentence_embedding = torch.mean(token_embeddings, 1)
            return sentence_embedding.view(-1).tolist()

        elif layers == 2:
            token_embeddings = []
            tokenized_text = self.tokenize(sentence)

            batch_i = 0
            # For each token in the sentence...
            for token_i in range(len(tokenized_text)):

                # Holds 12 layers of hidden states for each token
                hidden_layers = []

                # For each of the 12 layers...
                for layer_i in range(len(encoded_layers)):
                    # Lookup the vector for `token_i` in `layer_i`
                    vec = encoded_layers[layer_i][batch_i][token_i]

                    hidden_layers.append(list(vec.numpy()))

                token_embeddings.append(hidden_layers)

            # using the last 4 layer embeddings
            token_vecs_sum = []

            # For each token in the sentence...
            for token in token_embeddings:
                # Sum the vectors from the last four layers.
                sum_vec = np.sum(token[-4:], axis=0)

                # Use `sum_vec` to represent `token`.
                token_vecs_sum.append(list(sum_vec))

            # summing the last layer vectors for each token
            sentence_embedding = np.mean(token_vecs_sum, axis=0)
            return sentence_embedding.ravel().tolist()

In [0]:
# Dataset: 3000 chunks * 3 authors, without masking

url = 'https://raw.githubusercontent.com/fy164251/text_style_transfer/master/Datasets/raw_text_3000.csv'
df = pd.read_csv(url)

X = df.text.astype('str')
y = df.author.astype('category')

# lbl_enc = preprocessing.LabelEncoder()
# y = lbl_enc.fit_transform(y.values)

y = np.asarray(y)
onehot_encoder = preprocessing.OneHotEncoder(sparse=False)
encoded = y.reshape(len(y), 1)
y = onehot_encoder.fit_transform(encoded)

In case you used a LabelEncoder before this OneHotEncoder to convert the categories to integers, then you can now use the OneHotEncoder directly.


In [0]:
# Dataset: 3000 chunks * 3 authors, with masking

url = 'https://raw.githubusercontent.com/fy164251/text_style_transfer/master/Datasets/masked_text_3000.csv'
df = pd.read_csv(url)

X = df.text.astype('str')
y = df.author.astype('category')

y = np.asarray(y)
onehot_encoder = preprocessing.OneHotEncoder(sparse=False)
encoded = y.reshape(len(y), 1)
y = onehot_encoder.fit_transform(encoded)

In [5]:
model = Embeddings()

X_text = []
for sentence in tqdm(X):
    X_text.append(model.sentence2vec(sentence, layers=model.LAST_LAYER))

100%|██████████| 231508/231508 [00:00<00:00, 2670653.60B/s]
100%|██████████| 313/313 [00:00<00:00, 76004.00B/s]
100%|██████████| 440473133/440473133 [00:08<00:00, 50072121.61B/s]
100%|██████████| 9000/9000 [1:19:40<00:00,  1.82it/s]


In [7]:
# X_df = pd.DataFrame(X_text)
# X_df.to_csv('./gdrive/My Drive/DL/Style/DistilBert_Embedding_3000_2.csv')
X_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,728,729,730,731,732,733,734,735,736,737,738,739,740,741,742,743,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767
0,0.089255,0.474704,0.171598,-0.445951,-0.139031,-0.095083,0.398022,0.262700,0.213528,-0.004183,0.576964,-0.121572,-0.239475,0.293347,-0.084161,0.894064,0.147231,0.401929,0.129791,0.258735,0.685502,0.079295,0.334323,1.034874,0.093368,0.347640,0.212769,-0.384932,-0.288673,-0.026614,0.530695,-0.004031,0.158224,-0.237739,0.388905,-0.054335,0.544653,-0.439034,0.523372,0.383997,...,0.037189,-0.414413,-0.182625,-0.996050,-0.053274,-0.579955,0.225314,0.360031,0.002425,0.005650,0.022793,0.268636,0.278417,-0.433382,0.136677,-0.082261,-0.230011,-0.054747,0.223735,0.856815,-0.240603,0.567785,-0.011137,-0.196050,0.296334,0.254625,-0.015677,-0.149860,-0.363567,-0.522562,-0.720226,-0.107708,0.265608,-0.207027,0.666353,-0.288064,-0.198666,0.036834,0.211623,0.179178
1,-0.340932,0.221050,0.782805,-0.581847,0.461820,-0.240045,0.203510,0.163061,0.228730,-0.404142,-0.079556,-0.300071,-0.163420,0.303376,0.099836,1.423623,0.011883,0.281750,0.211207,0.114722,1.000818,-0.080286,0.418081,0.432044,0.134941,0.329491,0.002067,-0.246614,-0.725172,0.290682,0.589696,-0.288656,0.250666,-0.276592,0.094576,-0.686130,0.035607,-0.318936,-0.025617,0.274376,...,-0.250307,-0.626264,-0.209951,-0.134688,-0.218366,-0.712746,-0.106775,-0.393539,-0.257189,0.242331,0.257851,-0.012872,0.346436,0.140753,0.183796,-0.812609,0.096131,0.087510,0.073986,0.639240,-0.432956,0.641596,0.018408,-0.177756,-0.095648,0.193993,0.135293,-0.477258,-0.690213,-0.110558,-0.294730,-0.006305,-0.257251,-0.614889,0.277610,-0.254113,-0.314011,-0.177580,0.440142,0.300442
2,-0.335891,0.215750,0.134236,-0.398985,0.100357,-0.276152,0.012297,0.451921,-0.033935,-0.082510,0.433583,-0.274968,0.025606,0.025651,-0.230578,0.716468,0.404485,0.144631,0.428208,0.019562,0.691490,0.314824,0.169805,0.501763,0.271473,0.199180,0.172446,-0.116188,-0.452280,0.228855,0.681822,0.008252,-0.115546,-0.551180,0.012454,-0.525111,0.252605,-0.403585,0.586612,0.120770,...,-0.129724,-0.507280,-0.156877,-0.334767,-0.140801,-0.423644,0.266033,0.174512,0.037891,0.047548,0.276495,0.265085,0.240179,0.121982,0.240448,-0.172127,-0.204381,0.216384,0.181148,0.653867,-0.144304,0.558028,-0.290160,-0.089322,-0.047649,-0.118058,0.032081,-0.104082,0.046842,-0.029964,-0.514549,0.057987,-0.115259,0.224609,0.359204,-0.294915,-0.363483,-0.051290,0.363466,0.152344
3,-0.156222,0.466047,0.407739,-0.488406,-0.134691,0.180689,0.119333,0.555957,-0.380613,0.106727,0.350529,-0.594996,-0.018465,0.219168,-0.017954,0.668042,0.372010,0.225265,-0.009484,0.251160,0.478516,0.029257,0.119113,0.528693,0.074440,0.295118,-0.032307,-0.149913,-0.470186,0.174785,0.256214,0.150042,0.303352,-0.476231,0.053985,-0.321143,0.158450,-0.342911,0.354510,0.329658,...,-0.254499,-0.511102,-0.289797,-0.663957,-0.208240,-0.269668,-0.000019,0.476862,-0.435070,-0.319236,-0.132682,0.485799,0.169192,-0.452174,0.311044,-0.247674,-0.211432,0.119443,0.363887,0.758694,-0.099674,0.347741,-0.247450,-0.375582,-0.018978,-0.099293,0.056441,-0.230396,-0.302153,-0.292718,-0.558637,-0.064437,0.064220,-0.150074,0.565909,-0.179520,-0.388975,0.025525,0.498361,0.313834
4,-0.187530,0.268600,0.311052,-0.200505,0.015567,-0.144803,-0.014919,0.373741,0.093312,-0.321745,0.163310,-0.396472,-0.016497,-0.006123,-0.377307,0.720480,0.312520,0.613831,0.084854,-0.234721,1.083672,-0.332469,0.131527,0.448986,0.228013,-0.003657,0.131388,-0.161173,-0.393312,0.170627,0.593013,-0.103269,0.474441,-0.234904,0.048970,-0.413106,0.394481,-0.358634,0.347066,-0.008929,...,0.117894,-0.644592,-0.153507,-0.261847,-0.348074,-0.581157,-0.168289,0.061962,0.115857,0.043849,-0.130765,0.537401,0.136393,-0.616925,-0.219375,-0.171489,0.045843,0.463768,0.320474,0.664438,-0.564877,0.214166,0.121975,-0.381541,0.032417,0.066512,0.479395,-0.364574,-0.432845,-0.817666,-0.713633,-0.180551,-0.376038,0.114734,0.641779,-0.507370,-0.481888,-0.259569,0.059434,0.451487
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8995,0.091006,0.561350,0.220087,-0.757771,0.133639,-0.109520,0.180277,0.367699,-0.224234,-0.332058,-0.089731,-0.218915,-0.109006,0.048059,-0.263213,0.863571,0.212466,0.389487,0.614161,-0.086711,0.788468,-0.060691,-0.076992,0.318297,0.327359,0.172609,0.494178,-0.093250,-0.642839,-0.063243,0.781365,0.175113,-0.152739,-0.280006,-0.169909,-0.313993,0.450466,-0.352787,-0.029192,0.211810,...,-0.026387,-0.291474,-0.140296,-0.030668,-0.095701,-0.493935,0.110218,-0.414005,-0.178922,0.224465,-0.055236,0.047705,0.480967,-0.190057,0.224529,-0.089975,0.056657,0.022135,-0.075964,0.397923,-0.350319,0.295857,0.259457,-0.614115,0.262725,-0.269345,-0.008568,-0.338333,-0.507381,-0.184507,-0.567601,0.110983,-0.044716,-0.272419,0.226073,-0.190810,-0.389146,-0.013880,0.462799,0.254994
8996,0.106377,0.364273,0.648710,-0.651089,0.579585,-0.194366,-0.166921,0.816909,-0.206333,-0.141815,0.311151,-0.382003,0.178676,0.188375,-0.377538,0.937112,0.124876,0.456146,0.147312,-0.224139,0.760303,0.072268,0.033559,0.086394,0.094563,0.491350,0.366786,0.298446,-0.476386,-0.132563,0.493288,0.055220,-0.000392,-0.430928,0.127056,-0.686848,0.537682,-0.260890,-0.077890,0.505422,...,0.073203,-0.474479,-0.314126,-0.324037,-0.047044,-0.735415,0.272512,-0.134246,-0.338725,0.124791,0.085217,0.354539,-0.002648,-0.183943,0.093347,-0.081096,-0.133964,0.267154,0.396371,0.388802,-0.493829,0.509024,0.112676,-0.273665,0.272412,0.059530,0.241940,-0.275573,-0.493048,-0.192752,-0.309807,-0.018762,0.060558,-0.009681,0.366008,-0.486010,-0.402157,0.202666,0.450038,0.294662
8997,0.200587,0.632470,0.452327,-0.333864,0.105404,0.116012,0.251767,0.328868,-0.012447,-0.314254,0.149082,-0.242813,-0.076636,0.593439,0.213088,0.834758,0.320840,0.100508,0.091595,-0.229268,1.150250,-0.167489,0.320666,0.302161,0.288873,0.288023,0.329725,-0.048066,-0.749308,-0.145215,0.506286,-0.147755,0.375724,-0.401739,-0.195842,-0.297437,-0.012209,-0.290802,-0.012745,-0.086297,...,0.340454,-0.293618,-0.314073,-0.433923,0.122052,-0.360461,-0.107032,-0.139771,-0.065770,0.072193,-0.287180,0.174737,0.141031,0.121455,0.405234,0.022682,0.163822,0.264478,0.184069,0.484682,-0.139608,0.799220,-0.029043,-0.470777,-0.035307,0.199051,0.209762,-0.564824,-0.887952,-0.082044,-0.386754,0.013043,0.367012,-0.454167,0.118667,-0.097619,-0.270266,0.187779,0.342686,-0.030291
8998,-0.275840,0.431549,0.353404,-0.389437,-0.153087,-0.148158,0.149882,0.477485,-0.354752,-0.184685,0.010818,-0.363396,0.134015,0.318329,-0.007664,0.815095,0.340491,0.229355,0.201636,0.155619,0.826316,0.026976,0.108195,0.706576,0.184612,0.110016,-0.107746,-0.405019,-0.477389,0.105523,0.329453,0.043848,0.128855,-0.314362,-0.489506,-0.226411,0.420743,-0.313750,0.194902,0.510837,...,-0.182181,-0.530322,-0.110891,-0.312774,-0.071164,-0.630336,-0.043317,0.257743,-0.055170,0.000216,-0.172036,0.161880,0.196169,-0.337393,0.292819,-0.008794,-0.377003,0.001030,-0.163677,0.576054,-0.423863,0.678619,0.131139,-0.296466,-0.155286,-0.280105,-0.025960,-0.580991,-0.219401,-0.463358,-0.647055,0.039960,0.287399,-0.079697,-0.009502,-0.003524,-0.030494,0.129787,0.480699,0.319759


In [8]:
# X_df = pd.DataFrame(X_text)
# X_df.to_csv('./gdrive/My Drive/DL/Style/DistilBert_Embedding_3000.csv')

# X_df = pd.read_csv('./gdrive/My Drive/DL/Style/DistilBert_Embedding_3000_2.csv').set_index('Unnamed: 0')

X_train, X_val, y_train, y_val = train_test_split(X_df, y, stratify=y, random_state=1, test_size=0.2, shuffle=True)
X_val, X_test, y_val, y_test = train_test_split(X_val, y_val, test_size=0.5, random_state=1, shuffle=True)

print(X_train.shape, X_val.shape)

(7200, 768) (900, 768)


In [9]:
X_train

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,...,728,729,730,731,732,733,734,735,736,737,738,739,740,741,742,743,744,745,746,747,748,749,750,751,752,753,754,755,756,757,758,759,760,761,762,763,764,765,766,767
5614,-0.340161,-0.148656,0.602034,-0.272003,-0.341873,-0.120357,0.224259,0.206777,-0.050708,-0.413362,0.106826,-0.422804,-0.132413,0.146387,-0.105816,0.753637,0.414343,0.122540,0.179269,0.088668,0.793956,-0.148916,0.095241,0.706351,0.180122,0.449534,0.147201,-0.237596,-0.151617,-0.037453,0.579897,0.074145,0.068724,-0.265257,-0.118240,-0.280390,0.099969,-0.501312,0.068412,0.072087,...,-0.236783,-0.722360,-0.142746,-0.070288,-0.142651,-0.857589,-0.207489,0.172034,-0.133808,-0.344236,-0.061278,0.064547,0.122214,-0.097235,0.183206,-0.092817,-0.457735,0.098827,-0.223504,0.734053,-0.446106,0.604875,0.058997,-0.558629,-0.211270,-0.087967,-0.125183,-0.588484,-0.453498,-0.176318,-0.628320,0.045665,0.183669,-0.280129,-0.195073,-0.313454,-0.230915,-0.017511,0.452152,0.481604
2383,-0.391934,0.589659,0.081654,-0.364856,0.272116,-0.185212,-0.179081,0.248599,-0.116149,-0.085957,0.279195,-0.495050,-0.024453,0.470089,0.149684,1.075918,0.142027,-0.185910,0.287008,0.135054,0.798489,-0.099341,0.331279,0.542002,0.268069,0.279488,0.068858,-0.253427,-0.456926,0.170671,0.415524,-0.215536,-0.112373,-0.604337,0.137619,-0.606019,0.378203,-0.173980,0.747436,0.213383,...,-0.194351,-0.251167,-0.103786,-0.614198,-0.427270,-0.536838,-0.142121,0.175018,0.257659,0.063984,0.204867,0.264982,0.090459,-0.156936,0.328172,-0.267422,0.004332,0.437211,0.271880,0.837500,-0.222474,0.509900,-0.052957,-0.122060,0.306961,0.155461,0.102520,-0.191638,-0.222842,-0.347603,-0.569109,0.100797,-0.224095,0.094102,0.375749,0.166803,-0.509638,0.045654,0.267326,0.322551
5448,-0.311391,0.246223,0.322206,-0.626673,-0.120535,-0.328576,0.340233,0.150335,-0.228633,-0.119519,0.234018,-0.455078,0.094305,0.039104,-0.354940,1.053825,0.437839,0.352679,0.163303,-0.222175,0.575262,-0.151778,0.039601,0.237160,0.136593,0.441844,0.216453,-0.411612,-0.499689,-0.125523,0.927454,-0.064041,0.280551,0.071708,-0.018094,-0.288303,0.428391,-0.406295,0.066329,0.261219,...,-0.063150,-0.204630,-0.137450,-0.170344,-0.314548,-0.759352,-0.347153,-0.266005,-0.034147,0.089698,-0.136013,0.242091,0.530148,-0.121507,0.408584,-0.177793,-0.448586,-0.175234,0.130837,0.919979,-0.481680,0.259219,0.216071,-0.403066,-0.032478,0.047672,0.106018,-0.227465,-0.409935,-0.288559,-0.625146,0.146311,0.043609,0.026973,0.094333,-0.698623,-0.382027,-0.658689,0.190388,0.197619
8412,-0.157567,0.439942,0.159633,-0.716131,0.394165,-0.170059,0.290004,0.776767,-0.385152,-0.022196,0.137164,-0.146209,0.309307,-0.086701,-0.312970,0.535559,0.378450,0.155951,0.159643,-0.078866,0.229330,-0.042130,-0.046688,0.110598,-0.052326,0.334858,0.294528,0.001301,-0.752542,0.043778,0.562868,0.093217,-0.056954,-0.345783,-0.071456,-0.615525,0.299450,-0.361907,0.153372,0.102023,...,-0.160457,-0.595206,-0.231311,-0.040880,0.016651,-0.327501,-0.104689,-0.158801,-0.196048,0.014901,-0.336172,-0.002332,0.202591,-0.242723,0.059663,-0.198993,0.126893,0.143854,0.008069,0.378277,-0.390169,0.389232,0.494898,-0.393445,0.279485,0.022502,0.140657,0.021231,-0.466449,0.030109,-0.550439,0.104452,-0.086756,-0.266146,0.756947,-0.396499,-0.247555,-0.413154,0.299626,0.228981
1448,-0.166418,0.318898,0.330421,-0.459612,-0.162131,0.267957,0.236269,0.415457,-0.052450,-0.242916,0.365095,-0.367487,-0.144583,0.366747,-0.179655,0.657255,0.487739,0.288062,-0.025196,-0.025123,1.028283,0.087602,0.023006,0.461438,0.195829,0.166410,-0.029758,-0.221556,-0.417432,0.024154,0.216279,-0.060858,0.727997,-0.428053,0.269999,-0.264617,0.174458,-0.049764,0.479403,0.235735,...,0.167074,-0.798584,-0.346666,-0.463683,-0.125071,-0.276027,-0.208865,0.138834,-0.086230,-0.017895,-0.248692,0.168226,0.456018,-0.261686,0.062923,-0.336427,-0.276823,0.119404,0.571929,0.441417,-0.024522,0.412292,-0.085683,-0.287694,-0.030651,0.387611,0.239172,-0.372846,-0.460792,-0.505841,-0.717796,0.123918,0.077422,0.259605,0.282041,-0.028263,-0.214877,-0.100879,0.318270,-0.004839
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2512,-0.095390,0.449202,0.224358,-0.576839,0.287336,-0.409430,-0.004248,0.780078,-0.096692,-0.372438,-0.105739,-0.333202,0.027714,0.324660,-0.215557,0.960718,0.474020,0.017450,0.015001,-0.056878,0.432339,-0.218721,-0.218295,0.222339,0.328858,0.313048,0.265837,-0.011704,-0.654497,-0.086932,0.333828,0.368232,0.113449,-0.388320,-0.254458,-0.130834,0.342358,-0.447215,0.014278,-0.027058,...,-0.537184,-0.904238,-0.234918,-0.608328,0.132864,-0.577511,-0.558329,-0.132951,-0.165845,-0.032519,-0.064881,0.362410,-0.008946,-0.385552,0.028843,-0.635229,-0.100598,0.389611,0.144955,0.429051,-0.211813,0.540881,0.100470,-0.189920,0.252190,0.060027,0.083821,-0.448317,-0.622622,-0.158499,-0.493352,0.116451,-0.185150,-0.215702,0.311338,0.104676,-0.270898,-0.086899,0.236654,0.390968
8694,-0.145079,0.475734,0.365407,-0.620131,0.328844,0.151111,0.108706,0.852761,-0.457996,-0.247563,-0.029036,-0.127879,0.112376,0.168221,-0.282487,0.590009,0.499115,0.254812,0.427291,-0.123747,0.638318,-0.070385,0.052523,0.335780,0.097913,0.173838,0.614560,0.106443,-0.349500,-0.058919,0.505293,-0.064001,-0.172986,-0.075850,-0.132227,-0.384453,0.370201,-0.468317,-0.201656,0.048853,...,-0.265857,-0.223197,-0.058934,-0.323132,0.073406,-0.371595,-0.044837,-0.001413,0.018509,0.258279,-0.317874,0.332849,0.212638,0.056430,0.370784,-0.076809,0.180343,0.098970,0.058247,0.772120,-0.482748,0.695418,0.071432,-0.349727,0.203655,0.144502,0.156430,-0.250324,-0.374293,-0.250994,-0.330648,-0.096069,0.097628,-0.629242,0.276352,-0.184843,-0.126158,0.001324,0.549695,0.128526
8368,-0.090581,0.432085,0.374877,-0.471052,0.101587,0.137666,0.020140,0.438660,-0.227439,-0.421619,-0.015464,-0.266985,-0.196794,0.108737,-0.082186,0.811929,0.190508,0.046835,-0.133266,-0.312746,0.837726,-0.376351,0.239629,0.392617,0.293395,0.253092,0.570215,-0.252462,-0.398234,-0.221244,0.518364,0.045404,0.120460,-0.335099,-0.117464,-0.292256,0.200237,-0.173357,0.191883,0.055143,...,-0.203764,-0.207143,-0.417275,-0.421830,0.229466,-0.182841,0.209093,-0.340755,-0.520002,0.335023,0.064961,0.460892,0.196377,0.173352,0.518383,0.143448,0.168927,-0.019226,-0.087348,0.389946,-0.294117,0.699093,0.011120,-0.384840,-0.081319,0.028739,0.000828,-0.147006,-0.560598,-0.503186,-0.468980,0.193140,0.157212,-0.501378,0.166521,-0.202339,-0.203073,0.063565,0.301909,0.507903
3886,-0.490045,0.403811,0.233433,-1.101597,0.230473,-0.140833,0.170800,0.423244,-0.396777,-0.049143,0.206543,-0.689884,0.505751,-0.306384,-0.464576,0.924472,0.342408,0.543019,0.400223,-0.173314,0.342301,-0.150971,-0.109750,-0.179589,0.230435,0.035653,0.412052,-0.130912,-0.566213,-0.381750,0.798290,0.468334,0.196522,0.037070,0.054352,-0.617227,0.368253,-0.336719,0.157408,0.318117,...,0.136145,-0.158449,0.145648,-0.330110,0.095605,-0.911029,-0.152521,-0.277227,-0.283700,0.005756,-0.165566,0.285174,0.500378,-0.387482,0.691950,-0.060325,-0.550582,0.222612,0.019106,1.069099,-0.530565,0.136196,0.366861,-0.426256,0.226655,0.038184,0.256335,0.157677,-0.489276,-0.687277,-0.630584,-0.016984,-0.350730,-0.172616,0.260950,-0.729053,-0.367372,-0.624519,0.301135,0.294430


In [0]:
# Feed-Forward Neural Nets
class FFNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        """
        Args:
            input_dim (int): the size of the input vectors
            hidden_dim (int): the output size of the first Linear layer
            output_dim (int): the output size of the second Linear layer
        """
        super(FFNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.bn3 = nn.BatchNorm1d(output_dim)

    def forward(self, x):
        """The forward pass of the FFNN
        
        Args:
            x (torch.Tensor): an input data tensor. 
                x_in.shape should be (batch, input_dim)
        Returns:
            the resulting tensor. tensor.shape should be (batch, output_dim)
        """
        c = self.fc1(x)
        x = self.bn1(c)
        x = F.relu(x)
        x = F.dropout(x, p=0.5)
        c = torch.cat((x, c), 1)
        x = self.fc2(c)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.dropout(x, p=0.5)
        c = torch.cat((x, c), 1)
        x = self.fc3(c)
        x = self.bn3(x)
        x = F.relu(x)
        output = F.dropout(x, p=0.5)
     
        return output

batch_size = 32 # number of samples input at once
input_dim = 768
hidden_dim = 128
output_dim = 3

# Initialize model
model = FFNN(input_dim, hidden_dim, output_dim)
print(model)

X = torch.tensor(np.array(X_train))
# y_output = model(X)
# describe(y_output)

FFNN(
  (fc1): Linear(in_features=768, out_features=128, bias=True)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=128, out_features=3, bias=True)
  (bn3): BatchNorm1d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


In [10]:
from keras.models import Sequential, Model, load_model
from keras.layers import Embedding, LSTM, Dense, Input, Dropout, GRU, Conv1D, MaxPooling1D, BatchNormalization, Activation, concatenate
from keras.layers import Bidirectional, Flatten, RepeatVector, Permute, Multiply, Lambda, TimeDistributed
from keras import backend as K

from keras.regularizers import l2
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint

Using TensorFlow backend.


In [11]:
units = 1024
lr = 0.0005
patience = 5


inputs = Input(shape=(768,), dtype='float32')
c = Dense(units)(inputs)
x = BatchNormalization()(c)
x = Dropout(0.5)(x)
x = Activation('relu')(x)
c = concatenate([x, c])

def FFUnit(c):
  x = Dense(units)(c)
  x = BatchNormalization()(x)
  x = Dropout(0.5)(x)
  x = Activation('relu')(x)
  c = concatenate([x, c])
  return c

for i in range(10):
  c = FFUnit(c)

x = Dense(3)(c)
x = BatchNormalization()(x)
outputs = Activation('softmax')(x)

model = Model(inputs=inputs, outputs=outputs)

model.summary()

model.compile(optimizer=Adam(lr=lr),
              loss="categorical_crossentropy",
              metrics=["acc"])

model.fit(x=X_train,
          y=y_train,
          validation_data=[X_val, y_val],
          epochs=patience, 
          batch_size=32)

cb = EarlyStopping(monitor='val_loss', 
                   mode='min', 
                   verbose=0, 
                   patience=patience,
                   restore_best_weights=True)

model.compile(optimizer=Adam(lr=lr/3),
              loss="categorical_crossentropy",
              metrics=["acc"])

model.fit(x=X_train, 
          y=y_train,
          validation_data=[X_val, y_val],
          epochs=99, 
          batch_size=32,
          callbacks=[cb])

model.compile(optimizer=Adam(lr=lr/6),
              loss="categorical_crossentropy",
              metrics=["acc"])

model.fit(x=X_train, 
          y=y_train,
          validation_data=[X_val, y_val],
          epochs=99, 
          batch_size=32,
          callbacks=[cb])


print('===Evaluation===')
model.evaluate(X_test, y_test)





Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 768)          0                                            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1024)         787456      input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 1024)         4096        dense_1[0][0]                    
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 1024)         0           batch_n

[0.06356117938955624, 0.9788888888888889]