In [1]:
import torch
import torch.nn.functional
import pandas as pd
import numpy as np

In [2]:
train_data = pd.read_csv("labeled_file.txt", sep="\t", header=None)
test_data = pd.read_csv("unlabeled_file.txt", sep="\t", header=None)
test_data_bk = test_data

In [3]:
train_data = train_data.sample(frac=1).reset_index(drop=True)

In [4]:
layers_dim = [21, 128, 128, 128]

In [5]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.fc1 = torch.nn.Linear(layers_dim[0], layers_dim[1])
        self.relu1 = torch.nn.ReLU()
        self.drop = torch.nn.Dropout(0.4)
        self.fc2 = torch.nn.Linear(layers_dim[1], layers_dim[2])
        self.fc3 = torch.nn.Linear(layers_dim[2], layers_dim[3])
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.relu1(x)
        x = self.drop(x)
        x = self.fc3(x)
        
        return x
    
net = Net()

In [6]:
#USE THIS ONLY FOR LOADING THE MODEL

netLOADER = torch.nn.Sequential(
    torch.nn.Linear(layers_dim[0], layers_dim[1]),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.4),
    torch.nn.Linear(layers_dim[1], layers_dim[2]),
)

In [7]:
train_x = train_data.iloc[:, 2:-1]
test_x = test_data.iloc[:, 2:-1]
train_y = train_data.iloc[:, -1]

train_x = np.array(train_x.values, dtype=np.float32)
train_y = np.array(train_y.values, dtype=np.long)
test_x = np.array(test_x.values, dtype=np.float32)

In [8]:
split_ratio = 0.8   #train_size/(train_size+val_size)
assert(split_ratio >= 0.0 and split_ratio <= 1.0)
index = (int)(split_ratio * train_x.shape[0])

In [9]:
train_val_x = train_x[index:, :]
train_val_y = train_y[index:]
train_x = train_x[:index, :]
train_y = train_y[:index]
print(train_x.shape, train_y.shape, train_val_x.shape, train_val_y.shape)

(3708, 21) (3708,) (928, 21) (928,)


In [10]:
print(net.parameters)

<bound method Module.parameters of Net(
  (fc1): Linear(in_features=21, out_features=128, bias=True)
  (relu1): ReLU()
  (drop): Dropout(p=0.4)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
)>


In [11]:
train_x = torch.from_numpy(train_x)
test_x = torch.from_numpy(test_x)
train_y = torch.from_numpy(train_y)
train_val_x = torch.from_numpy(train_val_x)
train_val_y = torch.from_numpy(train_val_y)

In [12]:
print(train_y.shape)
print(train_val_y.shape)

torch.Size([3708])
torch.Size([928])


In [13]:
train_y = torch.tensor(train_y, dtype=torch.long)
train_val_y = torch.tensor(train_val_y, dtype=torch.long)

  """Entry point for launching an IPython kernel.
  


In [14]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
num_epochs = 5000

In [15]:
for epoch in range(num_epochs):  # number of epochs
    outputs = net(train_x)
    loss = criterion(outputs, train_y)
    
    #backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

Epoch [1/5000], Loss: 133.3223
Epoch [2/5000], Loss: 75.7861
Epoch [3/5000], Loss: 56.9665
Epoch [4/5000], Loss: 41.5990
Epoch [5/5000], Loss: 28.7331
Epoch [6/5000], Loss: 17.9247
Epoch [7/5000], Loss: 12.3046
Epoch [8/5000], Loss: 8.5382
Epoch [9/5000], Loss: 5.9507
Epoch [10/5000], Loss: 5.3762
Epoch [11/5000], Loss: 4.8733
Epoch [12/5000], Loss: 4.6441
Epoch [13/5000], Loss: 4.7675
Epoch [14/5000], Loss: 4.5131
Epoch [15/5000], Loss: 4.4777
Epoch [16/5000], Loss: 4.4918
Epoch [17/5000], Loss: 4.4506
Epoch [18/5000], Loss: 4.4032
Epoch [19/5000], Loss: 4.4010
Epoch [20/5000], Loss: 4.3471
Epoch [21/5000], Loss: 4.3199
Epoch [22/5000], Loss: 4.2795
Epoch [23/5000], Loss: 4.2190
Epoch [24/5000], Loss: 4.2103
Epoch [25/5000], Loss: 4.2066
Epoch [26/5000], Loss: 4.1366
Epoch [27/5000], Loss: 4.1048
Epoch [28/5000], Loss: 4.0903
Epoch [29/5000], Loss: 4.0392
Epoch [30/5000], Loss: 4.0363
Epoch [31/5000], Loss: 4.1542
Epoch [32/5000], Loss: 4.0564
Epoch [33/5000], Loss: 3.9602
Epoch [34/5

Epoch [275/5000], Loss: 2.9133
Epoch [276/5000], Loss: 2.9223
Epoch [277/5000], Loss: 2.9210
Epoch [278/5000], Loss: 2.9171
Epoch [279/5000], Loss: 2.9312
Epoch [280/5000], Loss: 2.9357
Epoch [281/5000], Loss: 2.9142
Epoch [282/5000], Loss: 2.9089
Epoch [283/5000], Loss: 2.8847
Epoch [284/5000], Loss: 2.9126
Epoch [285/5000], Loss: 2.8728
Epoch [286/5000], Loss: 2.8719
Epoch [287/5000], Loss: 2.8768
Epoch [288/5000], Loss: 2.8834
Epoch [289/5000], Loss: 2.8491
Epoch [290/5000], Loss: 2.8535
Epoch [291/5000], Loss: 2.8805
Epoch [292/5000], Loss: 2.8476
Epoch [293/5000], Loss: 2.8453
Epoch [294/5000], Loss: 2.8485
Epoch [295/5000], Loss: 2.8340
Epoch [296/5000], Loss: 2.8271
Epoch [297/5000], Loss: 2.8149
Epoch [298/5000], Loss: 2.8431
Epoch [299/5000], Loss: 2.8251
Epoch [300/5000], Loss: 2.8541
Epoch [301/5000], Loss: 2.8473
Epoch [302/5000], Loss: 2.8656
Epoch [303/5000], Loss: 2.8470
Epoch [304/5000], Loss: 2.8420
Epoch [305/5000], Loss: 2.8182
Epoch [306/5000], Loss: 2.8430
Epoch [3

Epoch [543/5000], Loss: 2.8216
Epoch [544/5000], Loss: 2.8182
Epoch [545/5000], Loss: 2.8035
Epoch [546/5000], Loss: 2.8086
Epoch [547/5000], Loss: 2.8059
Epoch [548/5000], Loss: 2.8112
Epoch [549/5000], Loss: 2.8160
Epoch [550/5000], Loss: 2.7930
Epoch [551/5000], Loss: 2.7980
Epoch [552/5000], Loss: 2.7933
Epoch [553/5000], Loss: 2.7818
Epoch [554/5000], Loss: 2.8112
Epoch [555/5000], Loss: 2.7972
Epoch [556/5000], Loss: 2.8000
Epoch [557/5000], Loss: 2.8065
Epoch [558/5000], Loss: 2.7881
Epoch [559/5000], Loss: 2.7775
Epoch [560/5000], Loss: 2.7724
Epoch [561/5000], Loss: 2.7812
Epoch [562/5000], Loss: 2.7777
Epoch [563/5000], Loss: 2.7719
Epoch [564/5000], Loss: 2.7652
Epoch [565/5000], Loss: 2.7719
Epoch [566/5000], Loss: 2.7470
Epoch [567/5000], Loss: 2.7646
Epoch [568/5000], Loss: 2.7519
Epoch [569/5000], Loss: 2.7035
Epoch [570/5000], Loss: 2.7380
Epoch [571/5000], Loss: 2.7316
Epoch [572/5000], Loss: 2.7126
Epoch [573/5000], Loss: 2.7070
Epoch [574/5000], Loss: 2.7121
Epoch [5

Epoch [811/5000], Loss: 2.8418
Epoch [812/5000], Loss: 2.8723
Epoch [813/5000], Loss: 2.8500
Epoch [814/5000], Loss: 2.8629
Epoch [815/5000], Loss: 2.8919
Epoch [816/5000], Loss: 2.8729
Epoch [817/5000], Loss: 2.8716
Epoch [818/5000], Loss: 2.8573
Epoch [819/5000], Loss: 2.8619
Epoch [820/5000], Loss: 2.8475
Epoch [821/5000], Loss: 2.8689
Epoch [822/5000], Loss: 2.8495
Epoch [823/5000], Loss: 2.8352
Epoch [824/5000], Loss: 2.8285
Epoch [825/5000], Loss: 2.8158
Epoch [826/5000], Loss: 2.7930
Epoch [827/5000], Loss: 2.7873
Epoch [828/5000], Loss: 2.7949
Epoch [829/5000], Loss: 2.7283
Epoch [830/5000], Loss: 2.7362
Epoch [831/5000], Loss: 2.7199
Epoch [832/5000], Loss: 2.7133
Epoch [833/5000], Loss: 2.6947
Epoch [834/5000], Loss: 2.7205
Epoch [835/5000], Loss: 2.7077
Epoch [836/5000], Loss: 2.7070
Epoch [837/5000], Loss: 2.6980
Epoch [838/5000], Loss: 2.6963
Epoch [839/5000], Loss: 2.6887
Epoch [840/5000], Loss: 2.6993
Epoch [841/5000], Loss: 2.7093
Epoch [842/5000], Loss: 2.6681
Epoch [8

Epoch [1079/5000], Loss: 2.5546
Epoch [1080/5000], Loss: 2.5526
Epoch [1081/5000], Loss: 2.5527
Epoch [1082/5000], Loss: 2.5414
Epoch [1083/5000], Loss: 2.5477
Epoch [1084/5000], Loss: 2.5438
Epoch [1085/5000], Loss: 2.5441
Epoch [1086/5000], Loss: 2.5589
Epoch [1087/5000], Loss: 2.5448
Epoch [1088/5000], Loss: 2.5499
Epoch [1089/5000], Loss: 2.5546
Epoch [1090/5000], Loss: 2.5305
Epoch [1091/5000], Loss: 2.5327
Epoch [1092/5000], Loss: 2.5306
Epoch [1093/5000], Loss: 2.5423
Epoch [1094/5000], Loss: 2.5392
Epoch [1095/5000], Loss: 2.5443
Epoch [1096/5000], Loss: 2.5373
Epoch [1097/5000], Loss: 2.5326
Epoch [1098/5000], Loss: 2.5275
Epoch [1099/5000], Loss: 2.5479
Epoch [1100/5000], Loss: 2.5387
Epoch [1101/5000], Loss: 2.5296
Epoch [1102/5000], Loss: 2.5568
Epoch [1103/5000], Loss: 2.5453
Epoch [1104/5000], Loss: 2.5439
Epoch [1105/5000], Loss: 2.5426
Epoch [1106/5000], Loss: 2.5602
Epoch [1107/5000], Loss: 2.5641
Epoch [1108/5000], Loss: 2.5534
Epoch [1109/5000], Loss: 2.5294
Epoch [1

Epoch [1340/5000], Loss: 2.8326
Epoch [1341/5000], Loss: 2.8340
Epoch [1342/5000], Loss: 2.8385
Epoch [1343/5000], Loss: 2.8148
Epoch [1344/5000], Loss: 2.8016
Epoch [1345/5000], Loss: 2.7820
Epoch [1346/5000], Loss: 2.7718
Epoch [1347/5000], Loss: 2.8129
Epoch [1348/5000], Loss: 2.8002
Epoch [1349/5000], Loss: 2.7791
Epoch [1350/5000], Loss: 2.7701
Epoch [1351/5000], Loss: 2.7921
Epoch [1352/5000], Loss: 2.8181
Epoch [1353/5000], Loss: 2.7863
Epoch [1354/5000], Loss: 2.7918
Epoch [1355/5000], Loss: 2.7772
Epoch [1356/5000], Loss: 2.7647
Epoch [1357/5000], Loss: 2.7651
Epoch [1358/5000], Loss: 2.7682
Epoch [1359/5000], Loss: 2.7643
Epoch [1360/5000], Loss: 2.7505
Epoch [1361/5000], Loss: 2.7316
Epoch [1362/5000], Loss: 2.7506
Epoch [1363/5000], Loss: 2.7490
Epoch [1364/5000], Loss: 2.7712
Epoch [1365/5000], Loss: 2.7511
Epoch [1366/5000], Loss: 2.7498
Epoch [1367/5000], Loss: 2.7037
Epoch [1368/5000], Loss: 2.7686
Epoch [1369/5000], Loss: 2.7507
Epoch [1370/5000], Loss: 2.7854
Epoch [1

Epoch [1599/5000], Loss: 2.4957
Epoch [1600/5000], Loss: 2.4831
Epoch [1601/5000], Loss: 2.5596
Epoch [1602/5000], Loss: 2.5625
Epoch [1603/5000], Loss: 2.5330
Epoch [1604/5000], Loss: 2.5329
Epoch [1605/5000], Loss: 2.5391
Epoch [1606/5000], Loss: 2.5211
Epoch [1607/5000], Loss: 2.5516
Epoch [1608/5000], Loss: 2.5330
Epoch [1609/5000], Loss: 2.5956
Epoch [1610/5000], Loss: 2.5693
Epoch [1611/5000], Loss: 2.6342
Epoch [1612/5000], Loss: 2.6415
Epoch [1613/5000], Loss: 2.6801
Epoch [1614/5000], Loss: 2.6848
Epoch [1615/5000], Loss: 2.6465
Epoch [1616/5000], Loss: 2.5911
Epoch [1617/5000], Loss: 2.7748
Epoch [1618/5000], Loss: 2.7056
Epoch [1619/5000], Loss: 2.7409
Epoch [1620/5000], Loss: 2.7828
Epoch [1621/5000], Loss: 2.7649
Epoch [1622/5000], Loss: 2.7618
Epoch [1623/5000], Loss: 2.7645
Epoch [1624/5000], Loss: 2.7577
Epoch [1625/5000], Loss: 2.7701
Epoch [1626/5000], Loss: 2.7519
Epoch [1627/5000], Loss: 2.7354
Epoch [1628/5000], Loss: 2.7505
Epoch [1629/5000], Loss: 2.7273
Epoch [1

Epoch [1857/5000], Loss: 2.6490
Epoch [1858/5000], Loss: 2.6072
Epoch [1859/5000], Loss: 2.6355
Epoch [1860/5000], Loss: 2.6604
Epoch [1861/5000], Loss: 2.6889
Epoch [1862/5000], Loss: 2.6371
Epoch [1863/5000], Loss: 2.6100
Epoch [1864/5000], Loss: 2.6332
Epoch [1865/5000], Loss: 2.6355
Epoch [1866/5000], Loss: 2.6640
Epoch [1867/5000], Loss: 2.6704
Epoch [1868/5000], Loss: 2.6166
Epoch [1869/5000], Loss: 2.6422
Epoch [1870/5000], Loss: 2.6056
Epoch [1871/5000], Loss: 2.6125
Epoch [1872/5000], Loss: 2.6001
Epoch [1873/5000], Loss: 2.5967
Epoch [1874/5000], Loss: 2.5877
Epoch [1875/5000], Loss: 2.5772
Epoch [1876/5000], Loss: 2.5803
Epoch [1877/5000], Loss: 2.5691
Epoch [1878/5000], Loss: 2.5950
Epoch [1879/5000], Loss: 2.5879
Epoch [1880/5000], Loss: 2.6060
Epoch [1881/5000], Loss: 2.5793
Epoch [1882/5000], Loss: 2.6230
Epoch [1883/5000], Loss: 2.6162
Epoch [1884/5000], Loss: 2.6586
Epoch [1885/5000], Loss: 2.6590
Epoch [1886/5000], Loss: 3.0476
Epoch [1887/5000], Loss: 2.7134
Epoch [1

Epoch [2114/5000], Loss: 2.5455
Epoch [2115/5000], Loss: 2.5264
Epoch [2116/5000], Loss: 2.5266
Epoch [2117/5000], Loss: 2.5231
Epoch [2118/5000], Loss: 2.5104
Epoch [2119/5000], Loss: 2.5387
Epoch [2120/5000], Loss: 2.5322
Epoch [2121/5000], Loss: 2.5236
Epoch [2122/5000], Loss: 2.5304
Epoch [2123/5000], Loss: 2.5033
Epoch [2124/5000], Loss: 2.5086
Epoch [2125/5000], Loss: 2.5201
Epoch [2126/5000], Loss: 2.5198
Epoch [2127/5000], Loss: 2.5189
Epoch [2128/5000], Loss: 2.5272
Epoch [2129/5000], Loss: 2.5222
Epoch [2130/5000], Loss: 2.5181
Epoch [2131/5000], Loss: 2.5306
Epoch [2132/5000], Loss: 2.5334
Epoch [2133/5000], Loss: 2.5230
Epoch [2134/5000], Loss: 2.5117
Epoch [2135/5000], Loss: 2.5606
Epoch [2136/5000], Loss: 2.5266
Epoch [2137/5000], Loss: 2.5318
Epoch [2138/5000], Loss: 2.5230
Epoch [2139/5000], Loss: 2.5242
Epoch [2140/5000], Loss: 2.5190
Epoch [2141/5000], Loss: 2.5352
Epoch [2142/5000], Loss: 2.5171
Epoch [2143/5000], Loss: 2.5211
Epoch [2144/5000], Loss: 2.5390
Epoch [2

Epoch [2374/5000], Loss: 2.7597
Epoch [2375/5000], Loss: 2.8006
Epoch [2376/5000], Loss: 2.8641
Epoch [2377/5000], Loss: 2.8818
Epoch [2378/5000], Loss: 2.9264
Epoch [2379/5000], Loss: 2.9419
Epoch [2380/5000], Loss: 2.8846
Epoch [2381/5000], Loss: 2.9858
Epoch [2382/5000], Loss: 2.8915
Epoch [2383/5000], Loss: 2.9164
Epoch [2384/5000], Loss: 2.8607
Epoch [2385/5000], Loss: 2.8438
Epoch [2386/5000], Loss: 3.0157
Epoch [2387/5000], Loss: 2.9393
Epoch [2388/5000], Loss: 3.0252
Epoch [2389/5000], Loss: 3.0017
Epoch [2390/5000], Loss: 3.0891
Epoch [2391/5000], Loss: 3.0769
Epoch [2392/5000], Loss: 3.0904
Epoch [2393/5000], Loss: 3.0537
Epoch [2394/5000], Loss: 3.0764
Epoch [2395/5000], Loss: 3.0323
Epoch [2396/5000], Loss: 3.0935
Epoch [2397/5000], Loss: 3.0598
Epoch [2398/5000], Loss: 3.0660
Epoch [2399/5000], Loss: 3.0433
Epoch [2400/5000], Loss: 3.0259
Epoch [2401/5000], Loss: 3.0469
Epoch [2402/5000], Loss: 2.9847
Epoch [2403/5000], Loss: 3.0059
Epoch [2404/5000], Loss: 2.9206
Epoch [2

Epoch [2632/5000], Loss: 2.7034
Epoch [2633/5000], Loss: 2.7068
Epoch [2634/5000], Loss: 2.7010
Epoch [2635/5000], Loss: 2.6822
Epoch [2636/5000], Loss: 2.6768
Epoch [2637/5000], Loss: 2.6432
Epoch [2638/5000], Loss: 2.6442
Epoch [2639/5000], Loss: 2.6377
Epoch [2640/5000], Loss: 2.6636
Epoch [2641/5000], Loss: 2.6536
Epoch [2642/5000], Loss: 2.6435
Epoch [2643/5000], Loss: 2.6476
Epoch [2644/5000], Loss: 2.6457
Epoch [2645/5000], Loss: 2.6460
Epoch [2646/5000], Loss: 2.6287
Epoch [2647/5000], Loss: 2.6328
Epoch [2648/5000], Loss: 2.6196
Epoch [2649/5000], Loss: 2.6410
Epoch [2650/5000], Loss: 2.6382
Epoch [2651/5000], Loss: 2.6370
Epoch [2652/5000], Loss: 2.6195
Epoch [2653/5000], Loss: 2.6079
Epoch [2654/5000], Loss: 2.6136
Epoch [2655/5000], Loss: 2.6232
Epoch [2656/5000], Loss: 2.6151
Epoch [2657/5000], Loss: 2.6150
Epoch [2658/5000], Loss: 2.6048
Epoch [2659/5000], Loss: 2.6106
Epoch [2660/5000], Loss: 2.6239
Epoch [2661/5000], Loss: 2.6082
Epoch [2662/5000], Loss: 2.5965
Epoch [2

Epoch [2890/5000], Loss: 2.6166
Epoch [2891/5000], Loss: 2.5829
Epoch [2892/5000], Loss: 2.6050
Epoch [2893/5000], Loss: 2.5961
Epoch [2894/5000], Loss: 2.5959
Epoch [2895/5000], Loss: 2.5945
Epoch [2896/5000], Loss: 2.5991
Epoch [2897/5000], Loss: 2.6016
Epoch [2898/5000], Loss: 2.5866
Epoch [2899/5000], Loss: 2.5871
Epoch [2900/5000], Loss: 2.5888
Epoch [2901/5000], Loss: 2.5777
Epoch [2902/5000], Loss: 2.5855
Epoch [2903/5000], Loss: 2.6055
Epoch [2904/5000], Loss: 2.5930
Epoch [2905/5000], Loss: 2.5891
Epoch [2906/5000], Loss: 2.5779
Epoch [2907/5000], Loss: 2.5783
Epoch [2908/5000], Loss: 2.5757
Epoch [2909/5000], Loss: 2.5833
Epoch [2910/5000], Loss: 2.5844
Epoch [2911/5000], Loss: 2.5830
Epoch [2912/5000], Loss: 2.5758
Epoch [2913/5000], Loss: 2.5710
Epoch [2914/5000], Loss: 2.5688
Epoch [2915/5000], Loss: 2.5594
Epoch [2916/5000], Loss: 2.5494
Epoch [2917/5000], Loss: 2.5490
Epoch [2918/5000], Loss: 2.5603
Epoch [2919/5000], Loss: 2.5548
Epoch [2920/5000], Loss: 2.5508
Epoch [2

Epoch [3149/5000], Loss: 2.4726
Epoch [3150/5000], Loss: 2.4534
Epoch [3151/5000], Loss: 2.4529
Epoch [3152/5000], Loss: 2.4463
Epoch [3153/5000], Loss: 2.4616
Epoch [3154/5000], Loss: 2.4558
Epoch [3155/5000], Loss: 2.4522
Epoch [3156/5000], Loss: 2.4629
Epoch [3157/5000], Loss: 2.4612
Epoch [3158/5000], Loss: 2.4515
Epoch [3159/5000], Loss: 2.4631
Epoch [3160/5000], Loss: 2.4548
Epoch [3161/5000], Loss: 2.4540
Epoch [3162/5000], Loss: 2.4720
Epoch [3163/5000], Loss: 2.4645
Epoch [3164/5000], Loss: 2.4732
Epoch [3165/5000], Loss: 2.4552
Epoch [3166/5000], Loss: 2.4722
Epoch [3167/5000], Loss: 2.4722
Epoch [3168/5000], Loss: 2.4453
Epoch [3169/5000], Loss: 2.4591
Epoch [3170/5000], Loss: 2.4590
Epoch [3171/5000], Loss: 2.4595
Epoch [3172/5000], Loss: 2.4483
Epoch [3173/5000], Loss: 2.4652
Epoch [3174/5000], Loss: 2.4734
Epoch [3175/5000], Loss: 2.4632
Epoch [3176/5000], Loss: 2.4567
Epoch [3177/5000], Loss: 2.4640
Epoch [3178/5000], Loss: 2.4522
Epoch [3179/5000], Loss: 2.4455
Epoch [3

Epoch [3406/5000], Loss: 2.5874
Epoch [3407/5000], Loss: 2.6089
Epoch [3408/5000], Loss: 2.5889
Epoch [3409/5000], Loss: 2.5844
Epoch [3410/5000], Loss: 2.5762
Epoch [3411/5000], Loss: 2.6148
Epoch [3412/5000], Loss: 2.5866
Epoch [3413/5000], Loss: 2.5781
Epoch [3414/5000], Loss: 2.6004
Epoch [3415/5000], Loss: 2.6024
Epoch [3416/5000], Loss: 2.5757
Epoch [3417/5000], Loss: 2.6000
Epoch [3418/5000], Loss: 2.5831
Epoch [3419/5000], Loss: 2.5754
Epoch [3420/5000], Loss: 2.5998
Epoch [3421/5000], Loss: 2.5771
Epoch [3422/5000], Loss: 2.5888
Epoch [3423/5000], Loss: 2.5857
Epoch [3424/5000], Loss: 2.5668
Epoch [3425/5000], Loss: 2.5714
Epoch [3426/5000], Loss: 2.5785
Epoch [3427/5000], Loss: 2.5740
Epoch [3428/5000], Loss: 2.5787
Epoch [3429/5000], Loss: 2.5799
Epoch [3430/5000], Loss: 2.5496
Epoch [3431/5000], Loss: 2.5785
Epoch [3432/5000], Loss: 2.5619
Epoch [3433/5000], Loss: 2.5499
Epoch [3434/5000], Loss: 2.5690
Epoch [3435/5000], Loss: 2.5678
Epoch [3436/5000], Loss: 2.5590
Epoch [3

Epoch [3664/5000], Loss: 2.5338
Epoch [3665/5000], Loss: 2.5307
Epoch [3666/5000], Loss: 2.5268
Epoch [3667/5000], Loss: 2.5306
Epoch [3668/5000], Loss: 2.5089
Epoch [3669/5000], Loss: 2.5292
Epoch [3670/5000], Loss: 2.5005
Epoch [3671/5000], Loss: 2.4873
Epoch [3672/5000], Loss: 2.4998
Epoch [3673/5000], Loss: 2.4978
Epoch [3674/5000], Loss: 2.4979
Epoch [3675/5000], Loss: 2.4969
Epoch [3676/5000], Loss: 2.5238
Epoch [3677/5000], Loss: 2.5027
Epoch [3678/5000], Loss: 2.5058
Epoch [3679/5000], Loss: 2.5032
Epoch [3680/5000], Loss: 2.4850
Epoch [3681/5000], Loss: 2.4760
Epoch [3682/5000], Loss: 2.5012
Epoch [3683/5000], Loss: 2.5067
Epoch [3684/5000], Loss: 2.5022
Epoch [3685/5000], Loss: 2.5089
Epoch [3686/5000], Loss: 2.4924
Epoch [3687/5000], Loss: 2.4880
Epoch [3688/5000], Loss: 2.4770
Epoch [3689/5000], Loss: 2.4805
Epoch [3690/5000], Loss: 2.4991
Epoch [3691/5000], Loss: 2.4873
Epoch [3692/5000], Loss: 2.4774
Epoch [3693/5000], Loss: 2.4607
Epoch [3694/5000], Loss: 2.4675
Epoch [3

Epoch [3924/5000], Loss: 2.5039
Epoch [3925/5000], Loss: 2.5049
Epoch [3926/5000], Loss: 2.5175
Epoch [3927/5000], Loss: 2.4961
Epoch [3928/5000], Loss: 2.5039
Epoch [3929/5000], Loss: 2.4998
Epoch [3930/5000], Loss: 2.5099
Epoch [3931/5000], Loss: 2.4956
Epoch [3932/5000], Loss: 2.4800
Epoch [3933/5000], Loss: 2.4930
Epoch [3934/5000], Loss: 2.5050
Epoch [3935/5000], Loss: 2.4973
Epoch [3936/5000], Loss: 2.5045
Epoch [3937/5000], Loss: 2.4917
Epoch [3938/5000], Loss: 2.4801
Epoch [3939/5000], Loss: 2.4912
Epoch [3940/5000], Loss: 2.5142
Epoch [3941/5000], Loss: 2.5023
Epoch [3942/5000], Loss: 2.5259
Epoch [3943/5000], Loss: 2.5009
Epoch [3944/5000], Loss: 2.5024
Epoch [3945/5000], Loss: 2.4933
Epoch [3946/5000], Loss: 2.4974
Epoch [3947/5000], Loss: 2.4831
Epoch [3948/5000], Loss: 2.5180
Epoch [3949/5000], Loss: 2.4823
Epoch [3950/5000], Loss: 2.5047
Epoch [3951/5000], Loss: 2.5174
Epoch [3952/5000], Loss: 2.4974
Epoch [3953/5000], Loss: 2.5057
Epoch [3954/5000], Loss: 2.4817
Epoch [3

Epoch [4184/5000], Loss: 2.6130
Epoch [4185/5000], Loss: 2.6027
Epoch [4186/5000], Loss: 2.6145
Epoch [4187/5000], Loss: 2.5787
Epoch [4188/5000], Loss: 2.5690
Epoch [4189/5000], Loss: 2.5972
Epoch [4190/5000], Loss: 2.6240
Epoch [4191/5000], Loss: 2.5905
Epoch [4192/5000], Loss: 2.6253
Epoch [4193/5000], Loss: 2.6083
Epoch [4194/5000], Loss: 2.6231
Epoch [4195/5000], Loss: 2.6167
Epoch [4196/5000], Loss: 2.5731
Epoch [4197/5000], Loss: 2.6660
Epoch [4198/5000], Loss: 2.5951
Epoch [4199/5000], Loss: 2.6543
Epoch [4200/5000], Loss: 2.6695
Epoch [4201/5000], Loss: 2.6707
Epoch [4202/5000], Loss: 2.6371
Epoch [4203/5000], Loss: 2.6293
Epoch [4204/5000], Loss: 2.5897
Epoch [4205/5000], Loss: 2.6696
Epoch [4206/5000], Loss: 2.5867
Epoch [4207/5000], Loss: 2.6531
Epoch [4208/5000], Loss: 2.6380
Epoch [4209/5000], Loss: 2.6529
Epoch [4210/5000], Loss: 2.6252
Epoch [4211/5000], Loss: 2.5636
Epoch [4212/5000], Loss: 2.6598
Epoch [4213/5000], Loss: 2.5682
Epoch [4214/5000], Loss: 2.6212
Epoch [4

Epoch [4442/5000], Loss: 2.3826
Epoch [4443/5000], Loss: 2.3787
Epoch [4444/5000], Loss: 2.4013
Epoch [4445/5000], Loss: 2.3652
Epoch [4446/5000], Loss: 2.3928
Epoch [4447/5000], Loss: 2.4343
Epoch [4448/5000], Loss: 2.4428
Epoch [4449/5000], Loss: 2.4980
Epoch [4450/5000], Loss: 2.5336
Epoch [4451/5000], Loss: 2.5251
Epoch [4452/5000], Loss: 2.4065
Epoch [4453/5000], Loss: 2.9046
Epoch [4454/5000], Loss: 2.4924
Epoch [4455/5000], Loss: 2.5219
Epoch [4456/5000], Loss: 2.5183
Epoch [4457/5000], Loss: 2.5075
Epoch [4458/5000], Loss: 2.5007
Epoch [4459/5000], Loss: 2.5121
Epoch [4460/5000], Loss: 2.4966
Epoch [4461/5000], Loss: 2.5090
Epoch [4462/5000], Loss: 2.5231
Epoch [4463/5000], Loss: 2.5054
Epoch [4464/5000], Loss: 2.5049
Epoch [4465/5000], Loss: 2.4967
Epoch [4466/5000], Loss: 2.4947
Epoch [4467/5000], Loss: 2.4815
Epoch [4468/5000], Loss: 2.4948
Epoch [4469/5000], Loss: 2.4749
Epoch [4470/5000], Loss: 2.5085
Epoch [4471/5000], Loss: 2.5219
Epoch [4472/5000], Loss: 2.5349
Epoch [4

Epoch [4704/5000], Loss: 2.7028
Epoch [4705/5000], Loss: 2.7299
Epoch [4706/5000], Loss: 2.6943
Epoch [4707/5000], Loss: 2.7085
Epoch [4708/5000], Loss: 2.7074
Epoch [4709/5000], Loss: 2.7283
Epoch [4710/5000], Loss: 2.7362
Epoch [4711/5000], Loss: 2.7444
Epoch [4712/5000], Loss: 2.7594
Epoch [4713/5000], Loss: 2.7578
Epoch [4714/5000], Loss: 2.7614
Epoch [4715/5000], Loss: 2.7899
Epoch [4716/5000], Loss: 2.7779
Epoch [4717/5000], Loss: 2.7845
Epoch [4718/5000], Loss: 2.7842
Epoch [4719/5000], Loss: 2.7707
Epoch [4720/5000], Loss: 2.7762
Epoch [4721/5000], Loss: 2.7740
Epoch [4722/5000], Loss: 2.7579
Epoch [4723/5000], Loss: 2.7391
Epoch [4724/5000], Loss: 2.8022
Epoch [4725/5000], Loss: 2.7152
Epoch [4726/5000], Loss: 2.7661
Epoch [4727/5000], Loss: 2.7673
Epoch [4728/5000], Loss: 2.7297
Epoch [4729/5000], Loss: 2.7365
Epoch [4730/5000], Loss: 2.7269
Epoch [4731/5000], Loss: 2.7747
Epoch [4732/5000], Loss: 2.7368
Epoch [4733/5000], Loss: 2.7302
Epoch [4734/5000], Loss: 2.7585
Epoch [4

Epoch [4961/5000], Loss: 2.5567
Epoch [4962/5000], Loss: 2.5560
Epoch [4963/5000], Loss: 2.5383
Epoch [4964/5000], Loss: 2.5319
Epoch [4965/5000], Loss: 2.5237
Epoch [4966/5000], Loss: 2.5431
Epoch [4967/5000], Loss: 2.5406
Epoch [4968/5000], Loss: 2.5477
Epoch [4969/5000], Loss: 2.5564
Epoch [4970/5000], Loss: 2.5469
Epoch [4971/5000], Loss: 2.5302
Epoch [4972/5000], Loss: 2.5376
Epoch [4973/5000], Loss: 2.5492
Epoch [4974/5000], Loss: 2.5274
Epoch [4975/5000], Loss: 2.5393
Epoch [4976/5000], Loss: 2.5480
Epoch [4977/5000], Loss: 2.5427
Epoch [4978/5000], Loss: 2.5362
Epoch [4979/5000], Loss: 2.5402
Epoch [4980/5000], Loss: 2.5342
Epoch [4981/5000], Loss: 2.5437
Epoch [4982/5000], Loss: 2.5403
Epoch [4983/5000], Loss: 2.5392
Epoch [4984/5000], Loss: 2.5563
Epoch [4985/5000], Loss: 2.5284
Epoch [4986/5000], Loss: 2.5353
Epoch [4987/5000], Loss: 2.5497
Epoch [4988/5000], Loss: 2.5356
Epoch [4989/5000], Loss: 2.5316
Epoch [4990/5000], Loss: 2.5399
Epoch [4991/5000], Loss: 2.5333
Epoch [4

In [16]:
def predict_accuracy(val_x, val_y):
    outputs = net(val_x)
    _, predicted = torch.max(outputs, 1)
    correct = (predicted == val_y).sum().item()
    print("Accuracy = {}%".format(100 *(float(correct/val_x.shape[0]))))

In [17]:
#outputs = net(train_x)

In [18]:
#_, predicted = torch.max(outputs, 1)

In [19]:
#print(predicted)

In [20]:
#correct = (predicted == train_y).sum().item()

In [21]:
#print("Accuracy = {}%".format(100 *(float(correct/train_data.shape[0]))))

In [22]:
predict_accuracy(train_x, train_y)
predict_accuracy(train_val_x, train_val_y)

Accuracy = 40.83063646170442%
Accuracy = 40.94827586206897%


In [None]:
def generate_labels(test_x):
    outputs = net(test_x)
    _, predictions = torch.max(outputs, 1)
    return predictions

In [None]:
torch.save(net, "trained_model_10000_epochs.pt")

In [None]:
netLOADER = torch.load("trained_model_10000_epochs.pt")

In [None]:
test_labels = generate_labels(test_x)
print(test_labels)

In [None]:
test_data[[23]] = test_labels
test_data

In [None]:
test_data.to_csv("Labelled_prev_unlabelBCLL.txt", sep=" ", header=False, index=False)