In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ff import FF
from data import MNIST

In [2]:
batch_size_train = 32
batch_size_test = 32

In [3]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./datasets/MNIST/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

In [4]:
test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./datasets/MNIST/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [5]:
config = {
    "in_dims": [784, 500, 500],
    "out_dims": [500, 500, 500],
    "epochs": 50,
    "threshold": 1.5,
}
num_layers = 3

In [6]:
def overlay_y_on_x(x, y):
    """Replace the first 10 pixels of data [x] with one-hot-encoded label [y]
    """
    x_ = x.clone()
    x_ = x_.view(-1, 784)
    #x_[:, :10] *= 0.0
    x_[range(x.shape[0]), y] = 1
    #print(x.max())
    return x_

In [7]:
model = FF(num_layers, config, torch.optim.Adam)

In [8]:
model.train()
for i, (x, y) in enumerate(train_loader):
    x_pos, _ = MNIST.overlay_y_on_x(x, y)
    rnd = torch.randperm(x.size(0))
    x_neg, _ = MNIST.overlay_y_on_x(x, y[rnd])
    #print(x_pos.requires_grad, x_neg.requires_grad)
    losses = model(x_pos, x_neg)
    print(losses)

0.8921538416544598
0.7131796967983246
0.6958352637290955
0.6943308190504709
0.6943092203140259
0.6937384490172068
0.693321602344513
0.6935725724697113
0.6939375817775727
0.6935251931349437
0.6933620802561441
0.6932956596215566
0.6934369130929312
0.6931616004308064
0.6932443551222484
0.6933431545893352
0.6934152948856354
0.6932981741428375
0.6930201737085978
0.6929927893479665
0.6932642845312754
0.6930877884229023
0.6930809338887532
0.6927498877048492
0.6926173837979634
0.6924555321534475
0.6928468728065491
0.6917631240685781
0.6925203748544057
0.6915167355537415
0.6916525836785633
0.6912747112909953
0.6912840521335601
0.6915015812714896
0.6912551414966583
0.6913789717356363
0.6915867257118226
0.6912017230192821
0.6898128441969554
0.6888645283381144
0.6915825486183166
0.6894165058930715
0.6897816030184427
0.6899740362167358
0.6871826887130736
0.6915605938434601
0.6889102395375569
0.686359928448995
0.6885792366663616
0.6893411632378896
0.6900021366278332
0.687642384370168
0.6853090528647

0.5657136384646099
0.6008066817124685
0.5838808568318685
0.6075978016853333
0.6119965843359629
0.5969610798358918
0.606229636669159
0.6036020942529042
0.5886315087477366
0.5899072229862213
0.6098426342010498
0.6055630457401275
0.5832257906595866
0.5698482251167297
0.5718832759062448
0.5606295971075693
0.597249253988266
0.5867380468050639
0.6310700444380442
0.6024830142656962
0.6214959335327149
0.6028196481863658
0.5680178209145864
0.5761476852496464
0.5864850755532584
0.6190977350870769
0.5742939221858978
0.5904055412610373
0.6012225612004598
0.5892963111400603
0.5881817619005839
0.5760610250631969
0.5757343530654908
0.5982614549001058
0.6205565901597341
0.5780161873499553
0.5868872114022573
0.6070559430122375
0.5762736781438191
0.6053559724489848
0.6050023448467253
0.589564197063446
0.5748450756072998
0.6024035402139027
0.6111250313123068
0.5840914956728617
0.5992016192277272
0.5763300653298695
0.6188488630453746
0.5685594765345255
0.5554799099763236
0.5897221926848094
0.5475947125752

KeyboardInterrupt: 

In [46]:
class Net(nn.Module):
    
    def __init__(self):
        
        super(Net, self).__init__()
        self.l = nn.Linear(784, 10)
        
    def forward(self, x):
        #print(self.l(x))
        return F.relu(self.l(x))

In [47]:
net = Net()

In [48]:
net.train()
criterion = torch.nn.MSELoss()
optim = torch.optim.Adam(net.parameters())
for i, (x, y) in enumerate(train_loader):
    optim.zero_grad()
    out = net(x.view(-1, 784))
    temp = np.zeros((32, 10))
    for i in range(32):
        temp[i][y[i].item()] = 1
    loss = criterion(out, torch.Tensor(temp))
    print(loss.item())
    loss.backward()
    optim.step()

0.30330196022987366
0.13947108387947083
0.0982712134718895
0.10159628093242645
0.09316181391477585
0.09114620834589005
0.09905010461807251
0.09172849357128143
0.09627287089824677
0.094866082072258
0.08204140514135361
0.09555324167013168
0.08723263442516327
0.0862930566072464
0.09281317889690399
0.09035802632570267
0.0991055816411972
0.09251095354557037
0.09062061458826065
0.08190781623125076
0.07937588542699814
0.09207557141780853
0.08093785494565964
0.08611945807933807
0.09239598363637924
0.0877300351858139
0.06865394115447998
0.08277469128370285
0.08663459867238998
0.08925193548202515
0.08823592215776443
0.08161928504705429
0.08451317250728607
0.080644890666008
0.09183598309755325
0.09225058555603027
0.0951344221830368
0.08700670301914215
0.08770539611577988
0.08389990031719208
0.0863223671913147
0.0883345976471901
0.07750582695007324
0.08303310722112656
0.09172053635120392
0.0851646214723587
0.08881989866495132
0.0941840410232544
0.08158032596111298
0.09009160101413727
0.08472682535

0.07039080560207367
0.05649804323911667
0.05219923332333565
0.04913550615310669
0.05891081690788269
0.06446491181850433
0.053518205881118774
0.039981089532375336
0.03846997022628784
0.060955893248319626
0.039018552750349045
0.06326793879270554
0.056445688009262085
0.04904559627175331
0.059680670499801636
0.04813117906451225
0.055394165217876434
0.05441557615995407
0.05352281406521797
0.04692145064473152
0.06806307286024094
0.06118291616439819
0.05945088714361191
0.05875366926193237
0.058744482696056366
0.05375584214925766
0.06880637258291245
0.05015195161104202
0.05277208238840103
0.06057973951101303
0.060159653425216675
0.06403863430023193
0.05468534305691719
0.047256868332624435
0.06395430862903595
0.04324820637702942
0.06248891353607178
0.04762891307473183
0.05617370083928108
0.05073431134223938
0.07015683501958847
0.05004090815782547
0.04927970468997955
0.06450856477022171
0.05919228121638298
0.048128463327884674
0.04626912623643875
0.06578798592090607
0.05395808815956116
0.0591269

0.03822817653417587
0.03386981785297394
0.05038536339998245
0.05341368168592453
0.054455745965242386
0.032495271414518356
0.0688137412071228
0.05332159996032715
0.06361020356416702
0.05073153227567673
0.038579147309064865
0.04416673257946968
0.045398786664009094
0.04097389057278633
0.05229438096284866
0.053118497133255005
0.04473122954368591
0.05296897888183594
0.05997862294316292
0.05484083294868469
0.05718078464269638
0.04981305077672005
0.04240817949175835
0.04630979895591736
0.03380831331014633
0.06155339628458023
0.06777091324329376
0.06797096878290176
0.041003577411174774
0.06439359486103058
0.06596965342760086
0.04902214929461479
0.06011213734745979
0.04836622253060341
0.054075490683317184
0.042250849306583405
0.0490957610309124
0.06024710088968277
0.061517514288425446
0.062831811606884
0.05648305267095566
0.04603654518723488
0.06623373925685883
0.04364997148513794
0.05199909210205078
0.038251377642154694
0.03685372322797775
0.03149805963039398
0.048222918063402176
0.05150713771

0.06313473731279373
0.04355808347463608
0.042215295135974884
0.043717749416828156
0.05765565484762192
0.06268251687288284
0.0572952926158905
0.05100027844309807
0.041073743253946304
0.06152771785855293
0.05511915683746338
0.03446272388100624
0.04778880253434181
0.05339698866009712
0.05997694656252861
0.05820993706583977
0.06341379880905151
0.05091303586959839
0.04804682359099388
0.0474696159362793
0.055748455226421356
0.05329084396362305
0.04469794034957886
0.05776378512382507
0.04493885859847069
0.06050761789083481
0.050685781985521317
0.049076467752456665
0.05224036052823067
0.0439777597784996
0.03288992866873741
0.05500759929418564
0.04345678165555
0.05988488718867302
0.05799297243356705
0.04315267130732536
0.0533309280872345
0.06862841546535492
0.06476818025112152
0.05709416791796684
0.05872873216867447
0.062210965901613235
0.04881645366549492
0.05253224819898605
0.04764803126454353
0.0471847839653492
0.05634375289082527
0.053544364869594574
0.058374740183353424
0.04967742413282394