In [1]:
### Import libraries
%matplotlib notebook
# Load libraries
import numpy as np
import torch
import matplotlib.pyplot as plt
import sklearn.model_selection
from torch.utils.data import DataLoader, TensorDataset, Subset

# Get data

In [2]:
# get data
data = np.genfromtxt('spambase.data', delimiter=',')
X_numpy, Y_numpy = data[:, :-1], data[:, -1]

print(X_numpy.shape)
print(Y_numpy.shape)

(4601, 57)
(4601,)


In [3]:
# Conver to tensor
X = torch.from_numpy(X_numpy).float()
Y = torch.from_numpy(Y_numpy).float()

print("X shape: ", X.shape)
print("Y shape: ", Y.shape)

X shape:  torch.Size([4601, 57])
Y shape:  torch.Size([4601])


# Describe data

In [4]:
fig, ax = plt.subplots()
ax.hist(Y)
ax.set_title("Spam vs No Spam class distribution")

<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'Spam vs No Spam class distribution')

# Define model arquitecture

In [5]:
# MLP
class MLP(torch.nn.Module):
    
    # Constructor
    def __init__(self, input_dim = 57, hidden_dim = 50, output_dim = 1):
    
        super(MLP, self).__init__()
        
        self.hidden1 = torch.nn.Linear(input_dim, hidden_dim, bias = True)
        self.output = torch.nn.Linear(hidden_dim, output_dim, bias = True)
        self.activation = torch.nn.Sigmoid()
        
    def forward(self, x):
#         print(x.shape)
        z = self.activation(self.hidden1(x))
#         print(z.shape)
        output = self.activation(self.output(z))
#         print(output.shape)
        return output

# Define and test model architecture

In [6]:
# model to cuda
model = MLP().to("cuda")

# Y_hat = model.forward(X)

# print("Y hat shape: ", Y_hat.shape)

# Define loss and optimizer

In [7]:
# loss function
criterion = torch.nn.BCELoss()

# optimizer algorithm
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-3)

# Split training and test data

In [8]:
# generate train and test idx 
train_idx, valid_idx = next(sklearn.model_selection.ShuffleSplit(train_size=0.6).split(X, Y))

# Creamos un conjunto de datos en formato tensor
torch_set = TensorDataset(X, Y)

# Data loader de entrenamiento
torch_train_loader = DataLoader(Subset(torch_set, train_idx), shuffle=True, batch_size=32)
# Data loader de validación
torch_valid_loader = DataLoader(Subset(torch_set, valid_idx), shuffle=False, batch_size=256)

# Train NN

In [9]:
def train_one_epoch(k, X, Y, model, criterion, optimizer, train_loss_values, test_loss_values):
    
    global best_valid
    
    # store train loss
    train_loss = 0
    valid_loss = 0
    
    # minibatch training
    for sample_data, sample_labels in torch_train_loader:
        
        sample_data = sample_data.cuda()
        sample_labels = sample_labels.cuda()
        
         # get predictions
        Y_predict = model.forward(sample_data)

        # get loss
        loss = criterion(Y_predict, sample_labels)
        
        # update loss
        train_loss += loss.cpu().detach().numpy()
        
        # reset grads
        optimizer.zero_grad()

        # backpropagation
        loss.backward()

        # update parameters
        optimizer.step()
    
    # add train loss to array
    train_loss_values.append(train_loss)
    
    # minibatch test
    for sample_data, sample_labels in torch_valid_loader:
        
        sample_data = sample_data.cuda()
        sample_labels = sample_labels.cuda()
        
         # get predictions
        Y_predict = model.forward(sample_data)

        # get loss
        loss = criterion(Y_predict, sample_labels)
        
        # update loss
        valid_loss += loss.cpu().detach().numpy()
        
    # add train loss to array
    test_loss_values.append(valid_loss)
    
    # check if there is a better value
    if k%2 == 0:
        
        if valid_loss < best_valid:
            
            print("new best model with valid loss: ", valid_loss)
            
            best_valid = valid_loss
            torch.save(
                {
                    'epoch': k,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': valid_loss
                },
                '/home/leo/Desktop/master_UACH/2S-2019/AI/artificial_intelligence_master/supervised_learning/works/best_MLP_spam.pt'
            )

In [96]:
epochs = 10000

best_valid = np.inf

train_loss_values = []
test_loss_values = []

# model.state_dict()
# global best_valid, model

for epoch in range(0, epochs):
    
    print("epoch ", epoch)
    
    train_one_epoch(epoch, X, Y, model, criterion, optimizer, train_loss_values, test_loss_values)

epoch  0
new best model with valid loss:  2.3974872678518295
epoch  1
epoch  2
epoch  3
epoch  4
epoch  5
epoch  6
epoch  7
epoch  8
new best model with valid loss:  2.2538104206323624
epoch  9
epoch  10
epoch  11
epoch  12
epoch  13
epoch  14
epoch  15
epoch  16
epoch  17
epoch  18
epoch  19
epoch  20
new best model with valid loss:  2.1527238339185715
epoch  21
epoch  22
epoch  23
epoch  24
epoch  25
epoch  26
epoch  27
epoch  28
epoch  29
epoch  30
epoch  31
epoch  32
epoch  33
epoch  34
epoch  35
epoch  36
epoch  37
epoch  38
epoch  39
epoch  40
epoch  41
epoch  42
epoch  43
epoch  44
epoch  45
epoch  46
epoch  47
epoch  48
epoch  49
epoch  50
epoch  51
epoch  52
epoch  53
epoch  54
epoch  55
epoch  56
epoch  57
epoch  58
epoch  59
epoch  60
epoch  61
epoch  62
epoch  63
epoch  64
epoch  65
epoch  66
epoch  67
epoch  68
epoch  69
epoch  70
epoch  71
epoch  72
epoch  73
epoch  74
epoch  75
epoch  76
epoch  77
epoch  78
epoch  79
epoch  80
epoch  81
epoch  82
epoch  83
epoch  84
epoc

epoch  728
epoch  729
epoch  730
epoch  731
epoch  732
epoch  733
epoch  734
epoch  735
epoch  736
epoch  737
epoch  738
epoch  739
epoch  740
epoch  741
epoch  742
epoch  743
epoch  744
epoch  745
epoch  746
epoch  747
epoch  748
epoch  749
epoch  750
epoch  751
epoch  752
epoch  753
epoch  754
epoch  755
epoch  756
epoch  757
epoch  758
epoch  759
epoch  760
epoch  761
epoch  762
epoch  763
epoch  764
epoch  765
epoch  766
epoch  767
epoch  768
epoch  769
epoch  770
epoch  771
epoch  772
epoch  773
epoch  774
epoch  775
epoch  776
epoch  777
epoch  778
epoch  779
epoch  780
epoch  781
epoch  782
epoch  783
epoch  784
epoch  785
epoch  786
epoch  787
epoch  788
epoch  789
epoch  790
epoch  791
epoch  792
epoch  793
epoch  794
epoch  795
epoch  796
epoch  797
epoch  798
epoch  799
epoch  800
epoch  801
epoch  802
epoch  803
epoch  804
epoch  805
epoch  806
epoch  807
epoch  808
epoch  809
epoch  810
epoch  811
epoch  812
epoch  813
epoch  814
epoch  815
epoch  816
epoch  817
epoch  818

epoch  1434
epoch  1435
epoch  1436
epoch  1437
epoch  1438
epoch  1439
epoch  1440
epoch  1441
epoch  1442
epoch  1443
epoch  1444
epoch  1445
epoch  1446
epoch  1447
epoch  1448
epoch  1449
epoch  1450
epoch  1451
epoch  1452
epoch  1453
epoch  1454
epoch  1455
epoch  1456
epoch  1457
epoch  1458
epoch  1459
epoch  1460
epoch  1461
epoch  1462
epoch  1463
epoch  1464
epoch  1465
epoch  1466
epoch  1467
epoch  1468
epoch  1469
epoch  1470
epoch  1471
epoch  1472
epoch  1473
epoch  1474
epoch  1475
epoch  1476
epoch  1477
epoch  1478
epoch  1479
epoch  1480
epoch  1481
epoch  1482
epoch  1483
epoch  1484
epoch  1485
epoch  1486
epoch  1487
epoch  1488
epoch  1489
epoch  1490
epoch  1491
epoch  1492
epoch  1493
epoch  1494
epoch  1495
epoch  1496
epoch  1497
epoch  1498
epoch  1499
epoch  1500
epoch  1501
epoch  1502
epoch  1503
epoch  1504
epoch  1505
epoch  1506
epoch  1507
epoch  1508
epoch  1509
epoch  1510
epoch  1511
epoch  1512
epoch  1513
epoch  1514
epoch  1515
epoch  1516
epoc

epoch  2114
epoch  2115
epoch  2116
epoch  2117
epoch  2118
epoch  2119
epoch  2120
epoch  2121
epoch  2122
epoch  2123
epoch  2124
epoch  2125
epoch  2126
epoch  2127
epoch  2128
epoch  2129
epoch  2130
epoch  2131
epoch  2132
epoch  2133
epoch  2134
epoch  2135
epoch  2136
epoch  2137
epoch  2138
epoch  2139
epoch  2140
epoch  2141
epoch  2142
epoch  2143
epoch  2144
epoch  2145
epoch  2146
epoch  2147
epoch  2148
epoch  2149
epoch  2150
epoch  2151
epoch  2152
epoch  2153
epoch  2154
epoch  2155
epoch  2156
epoch  2157
epoch  2158
epoch  2159
epoch  2160
epoch  2161
epoch  2162
epoch  2163
epoch  2164
epoch  2165
epoch  2166
epoch  2167
epoch  2168
epoch  2169
epoch  2170
epoch  2171
epoch  2172
epoch  2173
epoch  2174
epoch  2175
epoch  2176
epoch  2177
epoch  2178
epoch  2179
epoch  2180
epoch  2181
epoch  2182
epoch  2183
epoch  2184
epoch  2185
epoch  2186
epoch  2187
epoch  2188
epoch  2189
epoch  2190
epoch  2191
epoch  2192
epoch  2193
epoch  2194
epoch  2195
epoch  2196
epoc

epoch  2794
epoch  2795
epoch  2796
epoch  2797
epoch  2798
epoch  2799
epoch  2800
epoch  2801
epoch  2802
epoch  2803
epoch  2804
epoch  2805
epoch  2806
epoch  2807
epoch  2808
epoch  2809
epoch  2810
epoch  2811
epoch  2812
epoch  2813
epoch  2814
epoch  2815
epoch  2816
epoch  2817
epoch  2818
epoch  2819
epoch  2820
epoch  2821
epoch  2822
epoch  2823
epoch  2824
epoch  2825
epoch  2826
epoch  2827
epoch  2828
epoch  2829
epoch  2830
epoch  2831
epoch  2832
epoch  2833
epoch  2834
epoch  2835
epoch  2836
epoch  2837
epoch  2838
epoch  2839
epoch  2840
epoch  2841
epoch  2842
epoch  2843
epoch  2844
epoch  2845
epoch  2846
epoch  2847
epoch  2848
epoch  2849
epoch  2850
epoch  2851
epoch  2852
epoch  2853
epoch  2854
epoch  2855
epoch  2856
epoch  2857
epoch  2858
epoch  2859
epoch  2860
epoch  2861
epoch  2862
epoch  2863
epoch  2864
epoch  2865
epoch  2866
epoch  2867
epoch  2868
epoch  2869
epoch  2870
epoch  2871
epoch  2872
epoch  2873
epoch  2874
epoch  2875
epoch  2876
epoc

epoch  3474
epoch  3475
epoch  3476
epoch  3477
epoch  3478
epoch  3479
epoch  3480
epoch  3481
epoch  3482
epoch  3483
epoch  3484
epoch  3485
epoch  3486
epoch  3487
epoch  3488
epoch  3489
epoch  3490
epoch  3491
epoch  3492
epoch  3493
epoch  3494
epoch  3495
epoch  3496
epoch  3497
epoch  3498
epoch  3499
epoch  3500
epoch  3501
epoch  3502
epoch  3503
epoch  3504
epoch  3505
epoch  3506
epoch  3507
epoch  3508
epoch  3509
epoch  3510
epoch  3511
epoch  3512
epoch  3513
epoch  3514
epoch  3515
epoch  3516
epoch  3517
epoch  3518
epoch  3519
epoch  3520
epoch  3521
epoch  3522
epoch  3523
epoch  3524
epoch  3525
epoch  3526
epoch  3527
epoch  3528
epoch  3529
epoch  3530
epoch  3531
epoch  3532
epoch  3533
epoch  3534
epoch  3535
epoch  3536
epoch  3537
epoch  3538
epoch  3539
epoch  3540
epoch  3541
epoch  3542
epoch  3543
epoch  3544
epoch  3545
epoch  3546
epoch  3547
epoch  3548
epoch  3549
epoch  3550
epoch  3551
epoch  3552
epoch  3553
epoch  3554
epoch  3555
epoch  3556
epoc

epoch  4158
epoch  4159
epoch  4160
epoch  4161
epoch  4162
epoch  4163
epoch  4164
epoch  4165
epoch  4166
epoch  4167
epoch  4168
epoch  4169
epoch  4170
epoch  4171
epoch  4172
epoch  4173
epoch  4174
epoch  4175
epoch  4176
epoch  4177
epoch  4178
epoch  4179
epoch  4180
epoch  4181
epoch  4182
epoch  4183
epoch  4184
epoch  4185
epoch  4186
epoch  4187
epoch  4188
epoch  4189
epoch  4190
epoch  4191
epoch  4192
epoch  4193
epoch  4194
epoch  4195
epoch  4196
epoch  4197
epoch  4198
epoch  4199
epoch  4200
epoch  4201
epoch  4202
epoch  4203
epoch  4204
epoch  4205
epoch  4206
epoch  4207
epoch  4208
epoch  4209
epoch  4210
epoch  4211
epoch  4212
epoch  4213
epoch  4214
epoch  4215
epoch  4216
epoch  4217
epoch  4218
epoch  4219
epoch  4220
epoch  4221
epoch  4222
epoch  4223
epoch  4224
epoch  4225
epoch  4226
epoch  4227
epoch  4228
epoch  4229
epoch  4230
epoch  4231
epoch  4232
epoch  4233
epoch  4234
epoch  4235
epoch  4236
epoch  4237
epoch  4238
epoch  4239
epoch  4240
epoc

epoch  4842
epoch  4843
epoch  4844
epoch  4845
epoch  4846
epoch  4847
epoch  4848
epoch  4849
epoch  4850
epoch  4851
epoch  4852
epoch  4853
epoch  4854
epoch  4855
epoch  4856
epoch  4857
epoch  4858
epoch  4859
epoch  4860
epoch  4861
epoch  4862
epoch  4863
epoch  4864
epoch  4865
epoch  4866
epoch  4867
epoch  4868
epoch  4869
epoch  4870
epoch  4871
epoch  4872
epoch  4873
epoch  4874
epoch  4875
epoch  4876
epoch  4877
epoch  4878
epoch  4879
epoch  4880
epoch  4881
epoch  4882
epoch  4883
epoch  4884
epoch  4885
epoch  4886
epoch  4887
epoch  4888
epoch  4889
epoch  4890
epoch  4891
epoch  4892
epoch  4893
epoch  4894
epoch  4895
epoch  4896
epoch  4897
epoch  4898
epoch  4899
epoch  4900
epoch  4901
epoch  4902
epoch  4903
epoch  4904
epoch  4905
epoch  4906
epoch  4907
epoch  4908
epoch  4909
epoch  4910
epoch  4911
epoch  4912
epoch  4913
epoch  4914
epoch  4915
epoch  4916
epoch  4917
epoch  4918
epoch  4919
epoch  4920
epoch  4921
epoch  4922
epoch  4923
epoch  4924
epoc

epoch  5526
epoch  5527
epoch  5528
epoch  5529
epoch  5530
epoch  5531
epoch  5532
epoch  5533
epoch  5534
epoch  5535
epoch  5536
epoch  5537
epoch  5538
epoch  5539
epoch  5540
epoch  5541
epoch  5542
epoch  5543
epoch  5544
epoch  5545
epoch  5546
epoch  5547
epoch  5548
epoch  5549
epoch  5550
epoch  5551
epoch  5552
epoch  5553
epoch  5554
epoch  5555
epoch  5556
epoch  5557
epoch  5558
epoch  5559
epoch  5560
epoch  5561
epoch  5562
epoch  5563
epoch  5564
epoch  5565
epoch  5566
epoch  5567
epoch  5568
epoch  5569
epoch  5570
epoch  5571
epoch  5572
epoch  5573
epoch  5574
epoch  5575
epoch  5576
epoch  5577
epoch  5578
epoch  5579
epoch  5580
epoch  5581
epoch  5582
epoch  5583
epoch  5584
epoch  5585
epoch  5586
epoch  5587
epoch  5588
epoch  5589
epoch  5590
epoch  5591
epoch  5592
epoch  5593
epoch  5594
epoch  5595
epoch  5596
epoch  5597
epoch  5598
epoch  5599
epoch  5600
epoch  5601
epoch  5602
epoch  5603
epoch  5604
epoch  5605
epoch  5606
epoch  5607
epoch  5608
epoc

KeyboardInterrupt: 

In [11]:
# plotting results
fig, ax = plt.subplots()

ax.plot(range(0, epoch), train_loss_values, label = "train")
ax.plot(range(0, epoch), test_loss_values, label = "test")
ax.set_title("Loss v/s epoch for MLP")

<IPython.core.display.Javascript object>

NameError: name 'epoch' is not defined

# Recovering best model

In [10]:
for params in model.state_dict():
    
    print(params, "\t", model.state_dict()[params][0])
    
model.load_state_dict(torch.load("/home/leo/Desktop/master_UACH/2S-2019/AI/artificial_intelligence_master/supervised_learning/works/best_MLP_spam.pt")['model_state_dict'])

print("Parameters of recovery model")

for params in model.state_dict():
    
    print(params, "\t", model.state_dict()[params])

hidden1.weight 	 tensor([ 0.0573, -0.0546,  0.0909, -0.1064, -0.1202, -0.0513,  0.1318,  0.1098,
        -0.0839,  0.0585, -0.0233, -0.0034,  0.0692, -0.1199, -0.1213,  0.1013,
         0.0060,  0.1225,  0.0132, -0.1082,  0.0819,  0.0055,  0.0220, -0.1120,
        -0.0895, -0.0408,  0.1153, -0.0814,  0.0023, -0.0514,  0.0038,  0.1307,
         0.0684,  0.0105, -0.1227, -0.1088, -0.0470, -0.0254,  0.0305,  0.0405,
         0.0865, -0.0245,  0.0244, -0.0673, -0.1169,  0.0268,  0.1319, -0.1052,
         0.1030,  0.0948,  0.0389,  0.1119,  0.0613, -0.0870, -0.0263,  0.0379,
         0.0377], device='cuda:0')
hidden1.bias 	 tensor(-0.1321, device='cuda:0')
output.weight 	 tensor([ 0.0676, -0.0713, -0.0802, -0.0825,  0.0356, -0.1076,  0.0640,  0.1326,
        -0.1077, -0.1022, -0.0616,  0.0367,  0.0332,  0.1353, -0.0071,  0.0340,
         0.1009, -0.0552,  0.1360, -0.0709, -0.0335,  0.1013, -0.0056, -0.0535,
        -0.0801, -0.0017,  0.0833,  0.0847, -0.0559, -0.0893,  0.1296,  0.1125,
    