In [18]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [75]:
dat = pd.read_csv("qualified_clean.csv")
dat['times'] = pd.to_datetime(dat['times'])
dat = dat.sort_values(by=['user', 'entry', 'times'])
dat = dat.reset_index(drop=True)
print(dat.columns)
dat = dat.drop(['times', 'next_char'], axis=1)
dat = dat.dropna()
dat.head()

Index(['user', 'entry', 'character', 'next_char', 'times', 'digraph',
       'phrasetime', 'del', 'err'],
      dtype='object')


Unnamed: 0,user,entry,character,digraph,phrasetime,del,err
0,1,1,8,0.0,0.303296,1.0,0.0
1,1,1,10,0.013945,0.303296,1.0,0.0
2,1,1,1,0.001499,0.303296,1.0,0.0
3,1,1,6,0.001758,0.303296,1.0,0.0
4,1,1,23,0.015259,0.303296,1.0,0.0


In [76]:
MAX_LEN = 100


x = []
y = []
for user in np.unique(dat['user']):
    u_x = dat[dat['user'] == user]
    u_e = []
    for entry in np.unique(u_x['entry']):
        new_entry = u_x[u_x['entry'] == entry]
        new_entry = np.pad(new_entry, [(MAX_LEN - new_entry.shape[0], 0), (0, 0)], mode='constant')
        u_e.append(new_entry)
    x.append(np.stack(u_e))
    num_entries = len(u_e)
    y.append(np.ones([num_entries]) * user)
    print(user)

x = np.concatenate(x)
y = np.concatenate(y)

print(x.shape, y.shape)
    

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
(14120, 100, 7) (14120,)


In [77]:
print(x.dtype)

float64


In [88]:
dataset = torch.utils.data.TensorDataset(torch.Tensor(x), torch.Tensor(y))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=500, shuffle=True)

In [89]:
class LogisticRegression(torch.nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(7, 26)

    def forward(self, x):
        x_simple = x[:, -1, :]
        outputs = self.linear(x_simple)
        
        return outputs

In [136]:
class GRUModel(torch.nn.Module):
    def __init__(self, hidden_size=100):
        super(GRUModel, self).__init__()
        self.gru = torch.nn.GRU(7, hidden_size, num_layers=2)
        self.linear = torch.nn.Linear(10 * hidden_size, 26)

    def forward(self, x):
        x = x[:, -10:, :]
        x = torch.transpose(x, 0, 1) # now (seq_len, batch, dim)
#         print(x.shape)
        gruOut, gruHN = self.gru(x)
#         print(gruHN)
        gruOut = torch.transpose(gruOut, 0, 1) # now (batch, seq_len, dim)
        gruOut = gruOut.reshape(gruOut.size(0), -1)
#         print(last_layer.shape)
        out = self.linear(gruOut)
        
        
        return out

In [138]:
model = GRUModel()

criterion = torch.nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) 

PRINT_RATE = 20

## Optimization Loop

for epoch in range(1000): 
    loss_sum = 0
    total, correct = 0, 0
    for idx, (x_batch, y_batch) in enumerate(dataloader):
        print(idx)
        y_pred = model(x_batch)
#         print(y_pred[:5])
#         print("y_pred shape is", y_pred.shape)
#         print("y_batch shape is", y_batch.shape)
        loss = criterion(y_pred, y_batch.long())
        optimizer.zero_grad() 
        loss.backward() 
        optimizer.step()
        loss_sum += loss
        y_pred_i = torch.argmax(y_pred, dim=-1)
        correct += torch.sum(torch.eq(y_pred_i, y_batch)).item()
        total += len(y_batch)
    print('epoch {}, loss {}'.format(epoch, loss_sum.item())) 
#     print(y_pred_i[:1])
    print(total, correct, correct/total)

        

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
epoch 0, loss 93.57019805908203
14120 786 0.0556657223796034
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
epoch 1, loss 92.33280181884766
14120 923 0.06536827195467422
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
epoch 2, loss 92.00379180908203
14120 917 0.06494334277620396
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
epoch 3, loss 91.70684051513672
14120 923 0.06536827195467422
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
epoch 4, loss 91.4397964477539
14120 935 0.06621813031161473
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
epoch 5, loss 91.27296447753906
14120 1040 0.07365439093484419
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
epoch 6, loss 90.79420471191406
14120 1042 0.07379603399433428
0
1
2
3
4
5
6
7
8
9
10
11
1

KeyboardInterrupt: 

In [2]:
dat = pd.read_csv("clean_wide.csv")
dat.head()

Unnamed: 0,user,entry,character_1,character_2,character_3,character_4,character_5,character_6,character_7,character_8,...,digraph_35,digraph_36,digraph_37,digraph_38,digraph_39,digraph_40,digraph_41,del,err,phrasetime
0,1,1,8,10,1,6,23,21,13,10,...,0.00477,0.01747,0.016675,0.001636,0.017868,,,1,0,0.303296
1,1,2,8,10,1,6,23,21,13,10,...,0.016757,0.00835,0.009592,0.018797,0.000709,0.00169,0.006765,2,1,0.305861
2,1,3,8,10,1,6,23,21,13,13,...,0.001077,0.001366,0.005116,0.002441,0.001095,0.008573,,2,1,0.30562
3,1,4,8,10,1,6,23,21,13,13,...,0.01478,0.001881,0.001936,0.008713,0.007559,0.00924,0.011532,2,0,0.272646
4,1,5,8,10,1,6,23,21,13,10,...,0.004859,0.016099,0.008431,0.004978,0.007762,,,1,0,0.336338


In [33]:
df = dat.to_numpy()
print(df.shape)
df = np.nan_to_num(df)

(16815, 128)


In [34]:
x = df[:,2:]
y = df[:,0]

print(x.shape)
print(y.shape)
print(x)
print(y)

(16815, 126)
(16815,)
[[ 8.         10.          1.         ...  1.          0.
   0.30329609]
 [ 8.         10.          1.         ...  2.          1.
   0.305861  ]
 [ 8.         10.          1.         ...  2.          1.
   0.30562019]
 ...
 [ 8.         10.          1.         ...  2.          1.
   1.5771718 ]
 [ 8.         10.          1.         ...  2.          1.
   1.55091119]
 [ 8.         10.          1.         ...  1.          0.
   1.48837495]]
[ 1.  1.  1. ... 25. 25. 25.]


In [35]:
dataset = torch.utils.data.TensorDataset(torch.Tensor(x), torch.Tensor(y))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=500, shuffle=True)

In [54]:
class LogisticRegression(torch.nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(126, 400),
            nn.ReLU(True),
            nn.Linear(400, 400),
            nn.ReLU(True),
             nn.Linear(400, 400),
             nn.ReLU(True),
             nn.Linear(400, 400),
             nn.ReLU(True),
            nn.Linear(400, 400),
            nn.ReLU(True),
            nn.Linear(400, 26)
        )

    def forward(self, x):
        return self.net(x)
        


In [55]:
model = LogisticRegression()

criterion = torch.nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) 

PRINT_RATE = 20

## Optimization Loop

for epoch in range(1000): 
    loss_sum = 0
    total, correct = 0, 0
    for idx, (x_batch, y_batch) in enumerate(dataloader):
        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch.long())
        optimizer.zero_grad() 
        loss.backward() 
        optimizer.step()
        loss_sum += loss
        y_pred_i = torch.argmax(y_pred, dim=-1)
        correct += torch.sum(torch.eq(y_pred_i, y_batch)).item()
        total += len(y_batch)
    print('epoch {}, loss {}'.format(epoch, loss_sum.item())) 
    print(total, correct, correct/total)

epoch 0, loss 110.45457458496094
16815 791 0.04704133214391912
epoch 1, loss 109.63056945800781
16815 1008 0.059946476360392506
epoch 2, loss 108.92058563232422
16815 1250 0.07433838834374071
epoch 3, loss 108.32846069335938
16815 1435 0.08534046981861433
epoch 4, loss 107.88009643554688
16815 1425 0.0847457627118644
epoch 5, loss 107.47245788574219
16815 1419 0.08438893844781445
epoch 6, loss 107.0788345336914
16815 1510 0.08980077311923877
epoch 7, loss 106.6823959350586
16815 1571 0.09342848647041332
epoch 8, loss 106.26136016845703
16815 1589 0.09449895926256319
epoch 9, loss 105.83038330078125
16815 1669 0.0992566161165626
epoch 10, loss 105.41015625
16815 1736 0.10324115373178709
epoch 11, loss 105.01292419433594
16815 1717 0.10211121022896223
epoch 12, loss 104.69383239746094
16815 1845 0.10972346119536129
epoch 13, loss 104.37446594238281
16815 1849 0.10996134403806125
epoch 14, loss 104.0894775390625
16815 1860 0.11061552185548618
epoch 15, loss 103.81623840332031
16815 1853 0

epoch 129, loss 97.54557800292969
16815 2342 0.13928040440083259
epoch 130, loss 97.68087005615234
16815 2260 0.1344038061254832
epoch 131, loss 97.68743133544922
16815 2326 0.13832887303003272
epoch 132, loss 97.57160949707031
16815 2408 0.1432054713053821
epoch 133, loss 97.56864166259766
16815 2406 0.14308652988403212
epoch 134, loss 97.57547760009766
16815 2338 0.13904252155813263
epoch 135, loss 97.49664306640625
16815 2425 0.14421647338685697
epoch 136, loss 97.4308853149414
16815 2319 0.13791257805530777
epoch 137, loss 97.47732543945312
16815 2423 0.14409753196550698
epoch 138, loss 97.46180725097656
16815 2371 0.1410050550104074
epoch 139, loss 97.33209991455078
16815 2384 0.14177817424918226
epoch 140, loss 97.62010955810547
16815 2348 0.13963722866488254
epoch 141, loss 97.33133697509766
16815 2402 0.14284864704133216
epoch 142, loss 97.51136779785156
16815 2406 0.14308652988403212
epoch 143, loss 97.4665298461914
16815 2362 0.14046981861433244
epoch 144, loss 97.52411651611

epoch 257, loss 95.77066040039062
16815 2585 0.15373178709485577
epoch 258, loss 96.31267547607422
16815 2546 0.15141242937853108
epoch 259, loss 95.83330535888672
16815 2607 0.15504014272970562
epoch 260, loss 95.93801879882812
16815 2616 0.15557537912578057
epoch 261, loss 96.24617004394531
16815 2508 0.14915254237288136
epoch 262, loss 95.79678344726562
16815 2559 0.152185548617306
epoch 263, loss 96.299560546875
16815 2517 0.14968777876895628
epoch 264, loss 96.74430847167969
16815 2465 0.1465953018138567
epoch 265, loss 95.98558807373047
16815 2606 0.15498067201903062
epoch 266, loss 95.90461730957031
16815 2561 0.15230449003865595
epoch 267, loss 96.30904388427734
16815 2505 0.14897413024085637
epoch 268, loss 96.0811538696289
16815 2561 0.15230449003865595
epoch 269, loss 95.99610137939453
16815 2545 0.15135295866785609
epoch 270, loss 96.09835052490234
16815 2561 0.15230449003865595
epoch 271, loss 95.98102569580078
16815 2530 0.1504608980077312
epoch 272, loss 96.4323272705078

epoch 385, loss 96.43907928466797
16815 2538 0.15093666369313113
epoch 386, loss 95.2856674194336
16815 2842 0.16901575973832889
epoch 387, loss 96.6275634765625
16815 2451 0.14576271186440679
epoch 388, loss 95.37239074707031
16815 2616 0.15557537912578057
epoch 389, loss 96.47876739501953
16815 2449 0.1456437704430568
epoch 390, loss 95.36764526367188
16815 2661 0.1582515611061552
epoch 391, loss 96.3208236694336
16815 2511 0.14933095450490633
epoch 392, loss 96.65312957763672
16815 2540 0.15105560511448113
epoch 393, loss 96.18526458740234
16815 2583 0.1536128456735058
epoch 394, loss 95.85448455810547
16815 2612 0.15533749628308058
epoch 395, loss 95.42078399658203
16815 2551 0.15170978293190604
epoch 396, loss 96.26543426513672
16815 2520 0.14986619090098127
epoch 397, loss 95.96440124511719
16815 2547 0.15147190008920608
epoch 398, loss 95.7623291015625
16815 2616 0.15557537912578057
epoch 399, loss 96.8863525390625
16815 2432 0.14463276836158193
epoch 400, loss 96.2594985961914


epoch 513, loss 95.84346008300781
16815 2822 0.16782634552482903
epoch 514, loss 94.84681701660156
16815 2772 0.1648528099910794
epoch 515, loss 98.59579467773438
16815 2313 0.1375557537912578
epoch 516, loss 96.31900787353516
16815 2472 0.1470115967885816
epoch 517, loss 95.54879760742188
16815 2711 0.16122509663990484
epoch 518, loss 94.1701889038086
16815 3041 0.1808504311626524
epoch 519, loss 94.90072631835938
16815 2623 0.1559916741005055
epoch 520, loss 97.30561065673828
16815 2452 0.14582218257508178
epoch 521, loss 94.75459289550781
16815 2835 0.16859946476360393
epoch 522, loss 96.66429901123047
16815 2476 0.1472494796312816
epoch 523, loss 96.02755737304688
16815 2803 0.16669640202200417
epoch 524, loss 95.6021957397461
16815 2626 0.15617008623253048
epoch 525, loss 95.93180847167969
16815 2558 0.152126077906631
epoch 526, loss 95.06046295166016
16815 2739 0.16289027653880464
epoch 527, loss 98.16374206542969
16815 2324 0.13820993160868272
epoch 528, loss 96.01087951660156
1

epoch 641, loss 97.55810546875
16815 2335 0.13886410942610763
epoch 642, loss 97.0419692993164
16815 2363 0.14052928932500744
epoch 643, loss 96.782958984375
16815 2404 0.14296758846268212
epoch 644, loss 96.49201202392578
16815 2475 0.1471900089206066
epoch 645, loss 96.21121215820312
16815 2444 0.14534641688968183
epoch 646, loss 95.83984375
16815 2708 0.16104668450787987
epoch 647, loss 95.38361358642578
16815 2770 0.1647338685697294
epoch 648, loss 94.74862670898438
16815 3084 0.18340767172167707
epoch 649, loss 93.88282775878906
16815 3027 0.1800178412132025
epoch 650, loss 94.15995025634766
16815 2787 0.16574487065120427
epoch 651, loss 96.79724884033203
16815 2580 0.1534344335414808
epoch 652, loss 94.58895874023438
16815 2655 0.15789473684210525
epoch 653, loss 98.17227935791016
16815 2399 0.14267023490930716
epoch 654, loss 96.64491271972656
16815 2536 0.15081772227178114
epoch 655, loss 95.73114013671875
16815 2544 0.1512934879571811
epoch 656, loss 94.67378234863281
16815 29

epoch 769, loss 96.28457641601562
16815 2674 0.1590246803449301
epoch 770, loss 93.53534698486328
16815 2803 0.16669640202200417
epoch 771, loss 95.9163589477539
16815 2616 0.15557537912578057
epoch 772, loss 97.96330261230469
16815 2430 0.14451382694023193
epoch 773, loss 94.9875259399414
16815 2790 0.16592328278322926
epoch 774, loss 94.14447021484375
16815 3138 0.18661909009812666
epoch 775, loss 95.05043029785156
16815 2706 0.16092774308652988
epoch 776, loss 96.19292449951172
16815 2630 0.15640796907523044
epoch 777, loss 97.7852554321289
16815 2719 0.16170086232530478
epoch 778, loss 97.65542602539062
16815 2322 0.13809099018733273
epoch 779, loss 96.71284484863281
16815 2519 0.14980672019030628
epoch 780, loss 95.98308563232422
16815 2488 0.1479631281593815
epoch 781, loss 95.13438415527344
16815 2671 0.15884626821290515
epoch 782, loss 93.45994567871094
16815 3185 0.18941421349985132
epoch 783, loss 94.6024169921875
16815 2782 0.16544751709782932
epoch 784, loss 99.210205078125

epoch 897, loss 93.90155029296875
16815 2735 0.16265239369610468
epoch 898, loss 96.34032440185547
16815 2619 0.15575379125780553
epoch 899, loss 94.6413803100586
16815 2684 0.15961938745168006
epoch 900, loss 94.35651397705078
16815 2664 0.1584299732381802
epoch 901, loss 95.19133758544922
16815 2692 0.16009515313707998
epoch 902, loss 94.82492065429688
16815 2737 0.16277133511745465
epoch 903, loss 94.5623550415039
16815 2662 0.1583110318168302
epoch 904, loss 95.70182037353516
16815 2693 0.16015462384775497
epoch 905, loss 98.69187927246094
16815 2272 0.1351174546535831
epoch 906, loss 97.21540069580078
16815 2492 0.14820101100208147
epoch 907, loss 96.53838348388672
16815 2524 0.15010407374368123
epoch 908, loss 95.99845123291016
16815 2514 0.14950936663693132
epoch 909, loss 95.30433654785156
16815 2643 0.15718108831400535
epoch 910, loss 94.17428588867188
16815 2971 0.1766874814154029
epoch 911, loss 92.49412536621094
16815 3204 0.19054415700267618
epoch 912, loss 94.505622863769