In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ff import FF
from data import MNIST

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size_train = 32
batch_size_test = 32

In [3]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./datasets/MNIST/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

In [4]:
test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./datasets/MNIST/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [5]:
config = {
    "in_dims": [784, 500, 500],
    "out_dims": [500, 500, 500],
    "epochs": 50,
    "threshold": 1.5,
}
num_layers = 3

In [6]:
def overlay_y_on_x(x, y):
    """Replace the first 10 pixels of data [x] with one-hot-encoded label [y]
    """
    x_ = x.clone()
    x_ = x_.view(-1, 784)
    #x_[:, :10] *= 0.0
    x_[range(x.shape[0]), y] = 1
    #print(x.max())
    return x_

In [7]:
model = FF(num_layers, config, torch.optim.Adam)

In [8]:
model.train()
for i, (x, y) in enumerate(train_loader):
    x_pos, _ = MNIST.overlay_y_on_x(x, y)
    rnd = torch.randperm(x.size(0))
    x_neg, _ = MNIST.overlay_y_on_x(x, y[rnd])
    #print(x_pos.requires_grad, x_neg.requires_grad)
    losses = model(x_pos, x_neg)
    print(losses)

0.8904549996058146
0.7119278844197591
0.6968376847108205
0.6945142245292665
0.6943762091795603
0.6944306111335754
0.6936002266407013
0.6934893743197122
0.6935726590951283
0.6935866951942443
0.693378122250239
0.693519124587377
0.6930455636978149
0.6933305589358012
0.6929041004180908
0.6933121689160665
0.6935051433245341
0.6932327262560527
0.6926033023993176
0.6925544202327728
0.6914953148365021
0.6903269414107006
0.6906576192378998
0.6894031997521718
0.687938203016917
0.6886534182230631
0.6881546227137249
0.6873592774073284
0.6856378269195557
0.6850971853733062
0.6850525907675425
0.6903017214934031
0.6832916494210561
0.6836666981379191
0.6782299606005351
0.6793849643071491
0.6884123639265697
0.6772475266456603
0.6765178402264912
0.6751298622290293
0.6768537716070812
0.6851858341693878
0.6735738968849182
0.6821114412943521
0.6769849793116252
0.6706606034437815
0.674903666973114
0.6665087751547495
0.6698934630552927
0.6700105202198028
0.6736619436740875
0.6550097382068635
0.67373182257016

KeyboardInterrupt: 

In [11]:
predictions, real = MNIST.predict(test_loader, model)
print("Accuracy ", np.sum(predictions == real)/len(real))

Accuracy  0.7441


In [None]:
class Net(nn.Module):
    
    def __init__(self):
        
        super(Net, self).__init__()
        self.l = nn.Linear(784, 10)
        
    def forward(self, x):
        #print(self.l(x))
        return F.relu(self.l(x))

In [None]:
net = Net()

In [None]:
net.train()
criterion = torch.nn.MSELoss()
optim = torch.optim.Adam(net.parameters())
for i, (x, y) in enumerate(train_loader):
    optim.zero_grad()
    out = net(x.view(-1, 784))
    temp = np.zeros((32, 10))
    for i in range(32):
        temp[i][y[i].item()] = 1
    loss = criterion(out, torch.Tensor(temp))
    print(loss.item())
    loss.backward()
    optim.step()

0.30330196022987366
0.13947108387947083
0.0982712134718895
0.10159628093242645
0.09316181391477585
0.09114620834589005
0.09905010461807251
0.09172849357128143
0.09627287089824677
0.094866082072258
0.08204140514135361
0.09555324167013168
0.08723263442516327
0.0862930566072464
0.09281317889690399
0.09035802632570267
0.0991055816411972
0.09251095354557037
0.09062061458826065
0.08190781623125076
0.07937588542699814
0.09207557141780853
0.08093785494565964
0.08611945807933807
0.09239598363637924
0.0877300351858139
0.06865394115447998
0.08277469128370285
0.08663459867238998
0.08925193548202515
0.08823592215776443
0.08161928504705429
0.08451317250728607
0.080644890666008
0.09183598309755325
0.09225058555603027
0.0951344221830368
0.08700670301914215
0.08770539611577988
0.08389990031719208
0.0863223671913147
0.0883345976471901
0.07750582695007324
0.08303310722112656
0.09172053635120392
0.0851646214723587
0.08881989866495132
0.0941840410232544
0.08158032596111298
0.09009160101413727
0.08472682535

0.07039080560207367
0.05649804323911667
0.05219923332333565
0.04913550615310669
0.05891081690788269
0.06446491181850433
0.053518205881118774
0.039981089532375336
0.03846997022628784
0.060955893248319626
0.039018552750349045
0.06326793879270554
0.056445688009262085
0.04904559627175331
0.059680670499801636
0.04813117906451225
0.055394165217876434
0.05441557615995407
0.05352281406521797
0.04692145064473152
0.06806307286024094
0.06118291616439819
0.05945088714361191
0.05875366926193237
0.058744482696056366
0.05375584214925766
0.06880637258291245
0.05015195161104202
0.05277208238840103
0.06057973951101303
0.060159653425216675
0.06403863430023193
0.05468534305691719
0.047256868332624435
0.06395430862903595
0.04324820637702942
0.06248891353607178
0.04762891307473183
0.05617370083928108
0.05073431134223938
0.07015683501958847
0.05004090815782547
0.04927970468997955
0.06450856477022171
0.05919228121638298
0.048128463327884674
0.04626912623643875
0.06578798592090607
0.05395808815956116
0.0591269

0.03822817653417587
0.03386981785297394
0.05038536339998245
0.05341368168592453
0.054455745965242386
0.032495271414518356
0.0688137412071228
0.05332159996032715
0.06361020356416702
0.05073153227567673
0.038579147309064865
0.04416673257946968
0.045398786664009094
0.04097389057278633
0.05229438096284866
0.053118497133255005
0.04473122954368591
0.05296897888183594
0.05997862294316292
0.05484083294868469
0.05718078464269638
0.04981305077672005
0.04240817949175835
0.04630979895591736
0.03380831331014633
0.06155339628458023
0.06777091324329376
0.06797096878290176
0.041003577411174774
0.06439359486103058
0.06596965342760086
0.04902214929461479
0.06011213734745979
0.04836622253060341
0.054075490683317184
0.042250849306583405
0.0490957610309124
0.06024710088968277
0.061517514288425446
0.062831811606884
0.05648305267095566
0.04603654518723488
0.06623373925685883
0.04364997148513794
0.05199909210205078
0.038251377642154694
0.03685372322797775
0.03149805963039398
0.048222918063402176
0.05150713771

0.06313473731279373
0.04355808347463608
0.042215295135974884
0.043717749416828156
0.05765565484762192
0.06268251687288284
0.0572952926158905
0.05100027844309807
0.041073743253946304
0.06152771785855293
0.05511915683746338
0.03446272388100624
0.04778880253434181
0.05339698866009712
0.05997694656252861
0.05820993706583977
0.06341379880905151
0.05091303586959839
0.04804682359099388
0.0474696159362793
0.055748455226421356
0.05329084396362305
0.04469794034957886
0.05776378512382507
0.04493885859847069
0.06050761789083481
0.050685781985521317
0.049076467752456665
0.05224036052823067
0.0439777597784996
0.03288992866873741
0.05500759929418564
0.04345678165555
0.05988488718867302
0.05799297243356705
0.04315267130732536
0.0533309280872345
0.06862841546535492
0.06476818025112152
0.05709416791796684
0.05872873216867447
0.062210965901613235
0.04881645366549492
0.05253224819898605
0.04764803126454353
0.0471847839653492
0.05634375289082527
0.053544364869594574
0.058374740183353424
0.04967742413282394