In [1]:
#%pip install sklearn
#%pip install torch

import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

## project structure
DATA_DIR = "/data/projects/capturingBias/research/framing/data/" 
DATA_NPZ = DATA_DIR + "data2021.npz"
DATA_NEW_NPZ = DATA_DIR + "dataReduced2021.npz"


## load files
data = np.load(DATA_NPZ)

X_2D = data['X_2D']
X_2D = torch.from_numpy(X_2D)

In [2]:
def set_seed(seed=-1):
    if seed < 0:
        seed = np.random.randint(0, 2**32-1)

    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    return seed
    
print(set_seed())  # make reproducable

1703112340


In [3]:
class autoencoder(nn.Module):                    
    def __init__(self, input_dim, embedding_dim, num_layers=3, dropout=0.1):
        super().__init__()
        
        steps = (input_dim + embedding_dim) // (num_layers)
        nodes_per_layer = [input_dim-i*steps for i in range(num_layers)]
        nodes_per_layer.append(embedding_dim)
                
        encoder_layers = []
        for i in range(len(nodes_per_layer)-1):
            encoder_layers.extend([nn.Linear(nodes_per_layer[i], nodes_per_layer[i+1]),
                                   nn.BatchNorm1d(nodes_per_layer[i+1]),
                                   nn.ReLU(),
                                   nn.Dropout(dropout)])

            
        decoder_layers = []
        for i in range(len(nodes_per_layer)-1):
            decoder_layers.extend([nn.Linear(nodes_per_layer[-(i+1)], nodes_per_layer[-(i+2)]),
                                   nn.BatchNorm1d(nodes_per_layer[-(i+2)]),
                                   nn.ReLU(),
                                   nn.Dropout(dropout)])
        
        self.encoder = nn.Sequential(*encoder_layers) 
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, X): 
        embeddings = self.encoder(X)
        reconstruction = self.decoder(embeddings)
        
        return (embeddings, reconstruction)

In [4]:
def fit(model, data, lr=0.001, l2norm=0.001, n_epoch=250):
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=l2norm)
    criterion = nn.MSELoss() 
    
    num_samples = data.shape[0]
    idx = np.arange(num_samples)
    
    best_model = (None, 99)
    
    model.train()
    for epoch in range(n_epoch):
        np.random.shuffle(idx)
        X = data[idx]  # feed data in random order
        
        _, X_reconstructed = model(X)
        loss = criterion(X_reconstructed, X)
        
        if loss.item() < best_model[1]:
            best_model = (model.state_dict(), float(loss.item()))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        loss = float(loss)
        print("epoch %d \t loss: %f" % (epoch, loss))
        
    model.train(False)
    return best_model

In [5]:
num_features = X_2D.shape[1]
ae = autoencoder(input_dim=num_features, embedding_dim=10, dropout=0.0, num_layers=3)
state, loss = fit(ae, X_2D.float(), n_epoch=3000)

print('loss: %f' % loss)
ae.load_state_dict(state)
with torch.no_grad():
    X_2D_reduced, _ = ae(X_2D.float())

epoch 0 	 loss: 2.500864
epoch 1 	 loss: 2.371450
epoch 2 	 loss: 2.292778
epoch 3 	 loss: 2.225749
epoch 4 	 loss: 2.165858
epoch 5 	 loss: 2.114548
epoch 6 	 loss: 2.070101
epoch 7 	 loss: 2.031875
epoch 8 	 loss: 1.998143
epoch 9 	 loss: 1.969283
epoch 10 	 loss: 1.944115
epoch 11 	 loss: 1.921941
epoch 12 	 loss: 1.902422
epoch 13 	 loss: 1.883923
epoch 14 	 loss: 1.866719
epoch 15 	 loss: 1.851180
epoch 16 	 loss: 1.836712
epoch 17 	 loss: 1.822829
epoch 18 	 loss: 1.809759
epoch 19 	 loss: 1.797550
epoch 20 	 loss: 1.785917
epoch 21 	 loss: 1.774446
epoch 22 	 loss: 1.763604
epoch 23 	 loss: 1.753079
epoch 24 	 loss: 1.742840
epoch 25 	 loss: 1.732898
epoch 26 	 loss: 1.723591
epoch 27 	 loss: 1.714260
epoch 28 	 loss: 1.705227
epoch 29 	 loss: 1.696725
epoch 30 	 loss: 1.688191
epoch 31 	 loss: 1.679738
epoch 32 	 loss: 1.671092
epoch 33 	 loss: 1.662647
epoch 34 	 loss: 1.654664
epoch 35 	 loss: 1.646685
epoch 36 	 loss: 1.638702
epoch 37 	 loss: 1.630872
epoch 38 	 loss: 1.623

epoch 309 	 loss: 1.204053
epoch 310 	 loss: 1.203879
epoch 311 	 loss: 1.203719
epoch 312 	 loss: 1.203589
epoch 313 	 loss: 1.203464
epoch 314 	 loss: 1.203340
epoch 315 	 loss: 1.203231
epoch 316 	 loss: 1.203099
epoch 317 	 loss: 1.202972
epoch 318 	 loss: 1.202939
epoch 319 	 loss: 1.202827
epoch 320 	 loss: 1.202737
epoch 321 	 loss: 1.202655
epoch 322 	 loss: 1.202563
epoch 323 	 loss: 1.202501
epoch 324 	 loss: 1.202427
epoch 325 	 loss: 1.202361
epoch 326 	 loss: 1.202285
epoch 327 	 loss: 1.202244
epoch 328 	 loss: 1.202166
epoch 329 	 loss: 1.202123
epoch 330 	 loss: 1.202084
epoch 331 	 loss: 1.202057
epoch 332 	 loss: 1.201997
epoch 333 	 loss: 1.201960
epoch 334 	 loss: 1.201907
epoch 335 	 loss: 1.201881
epoch 336 	 loss: 1.201831
epoch 337 	 loss: 1.201802
epoch 338 	 loss: 1.201773
epoch 339 	 loss: 1.201755
epoch 340 	 loss: 1.201722
epoch 341 	 loss: 1.201703
epoch 342 	 loss: 1.201695
epoch 343 	 loss: 1.201651
epoch 344 	 loss: 1.201635
epoch 345 	 loss: 1.201614
e

epoch 615 	 loss: 1.088161
epoch 616 	 loss: 1.087956
epoch 617 	 loss: 1.087755
epoch 618 	 loss: 1.087569
epoch 619 	 loss: 1.087389
epoch 620 	 loss: 1.087217
epoch 621 	 loss: 1.087043
epoch 622 	 loss: 1.086876
epoch 623 	 loss: 1.086714
epoch 624 	 loss: 1.086558
epoch 625 	 loss: 1.086405
epoch 626 	 loss: 1.086260
epoch 627 	 loss: 1.086124
epoch 628 	 loss: 1.085989
epoch 629 	 loss: 1.085860
epoch 630 	 loss: 1.085734
epoch 631 	 loss: 1.085613
epoch 632 	 loss: 1.085495
epoch 633 	 loss: 1.085382
epoch 634 	 loss: 1.085273
epoch 635 	 loss: 1.085170
epoch 636 	 loss: 1.085074
epoch 637 	 loss: 1.084982
epoch 638 	 loss: 1.084887
epoch 639 	 loss: 1.084797
epoch 640 	 loss: 1.084705
epoch 641 	 loss: 1.084615
epoch 642 	 loss: 1.084528
epoch 643 	 loss: 1.084445
epoch 644 	 loss: 1.084363
epoch 645 	 loss: 1.084283
epoch 646 	 loss: 1.084207
epoch 647 	 loss: 1.084133
epoch 648 	 loss: 1.084060
epoch 649 	 loss: 1.083990
epoch 650 	 loss: 1.083923
epoch 651 	 loss: 1.083857
e

epoch 920 	 loss: 1.049963
epoch 921 	 loss: 1.049937
epoch 922 	 loss: 1.049913
epoch 923 	 loss: 1.049889
epoch 924 	 loss: 1.049866
epoch 925 	 loss: 1.049844
epoch 926 	 loss: 1.049822
epoch 927 	 loss: 1.049801
epoch 928 	 loss: 1.049781
epoch 929 	 loss: 1.049761
epoch 930 	 loss: 1.049742
epoch 931 	 loss: 1.049723
epoch 932 	 loss: 1.049705
epoch 933 	 loss: 1.049687
epoch 934 	 loss: 1.049670
epoch 935 	 loss: 1.049653
epoch 936 	 loss: 1.049637
epoch 937 	 loss: 1.049622
epoch 938 	 loss: 1.049606
epoch 939 	 loss: 1.049591
epoch 940 	 loss: 1.049577
epoch 941 	 loss: 1.049563
epoch 942 	 loss: 1.049549
epoch 943 	 loss: 1.049536
epoch 944 	 loss: 1.049523
epoch 945 	 loss: 1.049511
epoch 946 	 loss: 1.049499
epoch 947 	 loss: 1.049487
epoch 948 	 loss: 1.049475
epoch 949 	 loss: 1.049464
epoch 950 	 loss: 1.049453
epoch 951 	 loss: 1.049443
epoch 952 	 loss: 1.049433
epoch 953 	 loss: 1.049422
epoch 954 	 loss: 1.049413
epoch 955 	 loss: 1.049403
epoch 956 	 loss: 1.049394
e

epoch 1216 	 loss: 1.039265
epoch 1217 	 loss: 1.039249
epoch 1218 	 loss: 1.039235
epoch 1219 	 loss: 1.039220
epoch 1220 	 loss: 1.039206
epoch 1221 	 loss: 1.039193
epoch 1222 	 loss: 1.039180
epoch 1223 	 loss: 1.039168
epoch 1224 	 loss: 1.039156
epoch 1225 	 loss: 1.039144
epoch 1226 	 loss: 1.039133
epoch 1227 	 loss: 1.039122
epoch 1228 	 loss: 1.039112
epoch 1229 	 loss: 1.039102
epoch 1230 	 loss: 1.039092
epoch 1231 	 loss: 1.039083
epoch 1232 	 loss: 1.039074
epoch 1233 	 loss: 1.039065
epoch 1234 	 loss: 1.039056
epoch 1235 	 loss: 1.039048
epoch 1236 	 loss: 1.039040
epoch 1237 	 loss: 1.039032
epoch 1238 	 loss: 1.039025
epoch 1239 	 loss: 1.039018
epoch 1240 	 loss: 1.039011
epoch 1241 	 loss: 1.039004
epoch 1242 	 loss: 1.038997
epoch 1243 	 loss: 1.038991
epoch 1244 	 loss: 1.038985
epoch 1245 	 loss: 1.038978
epoch 1246 	 loss: 1.038972
epoch 1247 	 loss: 1.038967
epoch 1248 	 loss: 1.038961
epoch 1249 	 loss: 1.038956
epoch 1250 	 loss: 1.038951
epoch 1251 	 loss: 1

epoch 1510 	 loss: 1.042047
epoch 1511 	 loss: 1.042174
epoch 1512 	 loss: 1.042319
epoch 1513 	 loss: 1.042537
epoch 1514 	 loss: 1.042733
epoch 1515 	 loss: 1.042723
epoch 1516 	 loss: 1.042759
epoch 1517 	 loss: 1.042945
epoch 1518 	 loss: 1.042783
epoch 1519 	 loss: 1.042474
epoch 1520 	 loss: 1.042557
epoch 1521 	 loss: 1.042435
epoch 1522 	 loss: 1.042315
epoch 1523 	 loss: 1.042239
epoch 1524 	 loss: 1.042192
epoch 1525 	 loss: 1.042184
epoch 1526 	 loss: 1.042043
epoch 1527 	 loss: 1.041930
epoch 1528 	 loss: 1.041859
epoch 1529 	 loss: 1.041797
epoch 1530 	 loss: 1.041891
epoch 1531 	 loss: 1.041443
epoch 1532 	 loss: 1.041404
epoch 1533 	 loss: 1.041352
epoch 1534 	 loss: 1.041231
epoch 1535 	 loss: 1.041213
epoch 1536 	 loss: 1.040990
epoch 1537 	 loss: 1.040935
epoch 1538 	 loss: 1.040827
epoch 1539 	 loss: 1.040795
epoch 1540 	 loss: 1.040782
epoch 1541 	 loss: 1.040755
epoch 1542 	 loss: 1.040776
epoch 1543 	 loss: 1.040791
epoch 1544 	 loss: 1.040627
epoch 1545 	 loss: 1

epoch 1806 	 loss: 1.038912
epoch 1807 	 loss: 1.038754
epoch 1808 	 loss: 1.038769
epoch 1809 	 loss: 1.038825
epoch 1810 	 loss: 1.038776
epoch 1811 	 loss: 1.038781
epoch 1812 	 loss: 1.038722
epoch 1813 	 loss: 1.038699
epoch 1814 	 loss: 1.038738
epoch 1815 	 loss: 1.038715
epoch 1816 	 loss: 1.038819
epoch 1817 	 loss: 1.038649
epoch 1818 	 loss: 1.038710
epoch 1819 	 loss: 1.038655
epoch 1820 	 loss: 1.038687
epoch 1821 	 loss: 1.038655
epoch 1822 	 loss: 1.038712
epoch 1823 	 loss: 1.038713
epoch 1824 	 loss: 1.038762
epoch 1825 	 loss: 1.038639
epoch 1826 	 loss: 1.038647
epoch 1827 	 loss: 1.038488
epoch 1828 	 loss: 1.038543
epoch 1829 	 loss: 1.038637
epoch 1830 	 loss: 1.038531
epoch 1831 	 loss: 1.038514
epoch 1832 	 loss: 1.038498
epoch 1833 	 loss: 1.038500
epoch 1834 	 loss: 1.038521
epoch 1835 	 loss: 1.038722
epoch 1836 	 loss: 1.038781
epoch 1837 	 loss: 1.038652
epoch 1838 	 loss: 1.038763
epoch 1839 	 loss: 1.038659
epoch 1840 	 loss: 1.039024
epoch 1841 	 loss: 1

epoch 2101 	 loss: 1.039309
epoch 2102 	 loss: 1.039373
epoch 2103 	 loss: 1.039352
epoch 2104 	 loss: 1.039319
epoch 2105 	 loss: 1.039381
epoch 2106 	 loss: 1.039442
epoch 2107 	 loss: 1.039400
epoch 2108 	 loss: 1.039381
epoch 2109 	 loss: 1.039376
epoch 2110 	 loss: 1.039157
epoch 2111 	 loss: 1.039149
epoch 2112 	 loss: 1.039218
epoch 2113 	 loss: 1.039231
epoch 2114 	 loss: 1.039146
epoch 2115 	 loss: 1.039088
epoch 2116 	 loss: 1.039075
epoch 2117 	 loss: 1.038910
epoch 2118 	 loss: 1.038931
epoch 2119 	 loss: 1.038881
epoch 2120 	 loss: 1.038797
epoch 2121 	 loss: 1.038775
epoch 2122 	 loss: 1.038617
epoch 2123 	 loss: 1.038708
epoch 2124 	 loss: 1.038647
epoch 2125 	 loss: 1.038514
epoch 2126 	 loss: 1.038549
epoch 2127 	 loss: 1.038316
epoch 2128 	 loss: 1.038481
epoch 2129 	 loss: 1.038320
epoch 2130 	 loss: 1.038371
epoch 2131 	 loss: 1.038230
epoch 2132 	 loss: 1.038227
epoch 2133 	 loss: 1.038135
epoch 2134 	 loss: 1.038195
epoch 2135 	 loss: 1.038097
epoch 2136 	 loss: 1

epoch 2396 	 loss: 1.037745
epoch 2397 	 loss: 1.037828
epoch 2398 	 loss: 1.037834
epoch 2399 	 loss: 1.037774
epoch 2400 	 loss: 1.037692
epoch 2401 	 loss: 1.037686
epoch 2402 	 loss: 1.037721
epoch 2403 	 loss: 1.037777
epoch 2404 	 loss: 1.037745
epoch 2405 	 loss: 1.037740
epoch 2406 	 loss: 1.037807
epoch 2407 	 loss: 1.037928
epoch 2408 	 loss: 1.038063
epoch 2409 	 loss: 1.038029
epoch 2410 	 loss: 1.038103
epoch 2411 	 loss: 1.038159
epoch 2412 	 loss: 1.038041
epoch 2413 	 loss: 1.038124
epoch 2414 	 loss: 1.038208
epoch 2415 	 loss: 1.038198
epoch 2416 	 loss: 1.038141
epoch 2417 	 loss: 1.038223
epoch 2418 	 loss: 1.038081
epoch 2419 	 loss: 1.038167
epoch 2420 	 loss: 1.038110
epoch 2421 	 loss: 1.038246
epoch 2422 	 loss: 1.038095
epoch 2423 	 loss: 1.038204
epoch 2424 	 loss: 1.038194
epoch 2425 	 loss: 1.038279
epoch 2426 	 loss: 1.038170
epoch 2427 	 loss: 1.038255
epoch 2428 	 loss: 1.038262
epoch 2429 	 loss: 1.038346
epoch 2430 	 loss: 1.038226
epoch 2431 	 loss: 1

epoch 2690 	 loss: 1.038005
epoch 2691 	 loss: 1.037970
epoch 2692 	 loss: 1.037804
epoch 2693 	 loss: 1.037982
epoch 2694 	 loss: 1.037862
epoch 2695 	 loss: 1.037872
epoch 2696 	 loss: 1.037870
epoch 2697 	 loss: 1.038043
epoch 2698 	 loss: 1.038411
epoch 2699 	 loss: 1.038924
epoch 2700 	 loss: 1.040194
epoch 2701 	 loss: 1.041528
epoch 2702 	 loss: 1.041860
epoch 2703 	 loss: 1.042477
epoch 2704 	 loss: 1.042915
epoch 2705 	 loss: 1.042484
epoch 2706 	 loss: 1.043469
epoch 2707 	 loss: 1.043468
epoch 2708 	 loss: 1.044073
epoch 2709 	 loss: 1.043613
epoch 2710 	 loss: 1.045096
epoch 2711 	 loss: 1.045214
epoch 2712 	 loss: 1.045928
epoch 2713 	 loss: 1.046395
epoch 2714 	 loss: 1.046375
epoch 2715 	 loss: 1.045910
epoch 2716 	 loss: 1.045554
epoch 2717 	 loss: 1.044534
epoch 2718 	 loss: 1.044151
epoch 2719 	 loss: 1.043480
epoch 2720 	 loss: 1.043399
epoch 2721 	 loss: 1.042803
epoch 2722 	 loss: 1.042243
epoch 2723 	 loss: 1.041769
epoch 2724 	 loss: 1.041334
epoch 2725 	 loss: 1

epoch 2986 	 loss: 1.036018
epoch 2987 	 loss: 1.036016
epoch 2988 	 loss: 1.036010
epoch 2989 	 loss: 1.036009
epoch 2990 	 loss: 1.036010
epoch 2991 	 loss: 1.036016
epoch 2992 	 loss: 1.036017
epoch 2993 	 loss: 1.036012
epoch 2994 	 loss: 1.036013
epoch 2995 	 loss: 1.036020
epoch 2996 	 loss: 1.036025
epoch 2997 	 loss: 1.036033
epoch 2998 	 loss: 1.036040
epoch 2999 	 loss: 1.036034
loss: 1.035702


In [6]:
np.savez_compressed(DATA_NEW_NPZ,
                    X_2D_reduced = X_2D_reduced.numpy())

In [11]:
# the worst possible MSE
f_loss = nn.MSELoss()
loss = 0.0
for _ in range(10):
    noise = torch.rand(X_2D.shape)
    loss += f_loss(noise, X_2D).item()
print(loss/10)

2.3424200378500233
