In [8]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb
from ff import FF, FFLayer
from data import MNIST
from tqdm import tqdm

In [2]:
# device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "cpu"

In [3]:
batch_size_train = 32
batch_size_test = 32

In [4]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./datasets/MNIST/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

In [5]:
test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./datasets/MNIST/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [6]:
threshold = 1.5
epochs = 50
model = FF(logging=True)
model.add_layer(FFLayer(nn.Linear(784, 500), optimizer=torch.optim.Adam, epochs=epochs, threshold=threshold, activation=nn.ReLU(), lr=0.01, positive_lr=0.005, negative_lr=0.005, logging=True, name="layer 1"))
model.add_layer(FFLayer(nn.Linear(500, 500), optimizer=torch.optim.Adam, epochs=epochs, threshold=threshold, activation=nn.ReLU(), lr=0.01, positive_lr=0.005, negative_lr=0.005, logging=True, name="layer 2"))
model.add_layer(FFLayer(nn.Linear(500, 500), optimizer=torch.optim.Adam, epochs=epochs, threshold=threshold, activation=nn.ReLU(), lr=0.01, positive_lr=0.005, negative_lr=0.005, logging=True, name="layer 3"))

In [6]:
wandb.init(project="MNIST", entity="ffalgo")
wandb.config = {
  "learning_rate": 0.01,
  "epochs": 50,
  "batch_size": 32,
  "activation": "relu",
  "positive_lr": 0.005,
  "negative_lr": 0.005,
  "threshold": threshold,
  "optimizer": torch.optim.Adam
}

[34m[1mwandb[0m: Currently logged in as: [33mmirceatlx[0m ([33mffalgo[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
model.train()
for i, (x, y) in enumerate(train_loader):
    x_pos, _ = MNIST.overlay_y_on_x(x, y)
    rnd = torch.randperm(x.size(0))
    x_neg, _ = MNIST.overlay_y_on_x(x, y[rnd])
    #print(x_pos.requires_grad, x_neg.requires_grad)

    # Nu ii place deloc daca fac una dupa alta
    #losses = model.forward_positive(x_pos)
    #losses += model.forward_negative(x_neg)
    #print(losses/2)

    losses = model(x_pos, x_neg)
    print(losses)

0.7304187019666036
0.6940457832813264
0.6992494734128316
0.6901481437683105
0.6886713886260987
0.6688858354091645
0.6541979861259462
0.6494408269723256
0.6519698528448741
0.6525717842578888
0.6376169613997141
0.6445633188883463
0.6315015542507171
0.6230955962340038
0.6381931241353352
0.6115521041552227
0.5940692156553268
0.6070854097604751
0.6151680779457093
0.5911633735895157
0.608948582013448
0.6016064776976903
0.608482506275177
0.607167570590973
0.582430297334989
0.6064931412537893
0.6090555266539256
0.567383065422376
0.6119307045141856
0.5919301617145538
0.54632268846035
0.6015938305854798
0.6006661049524943
0.5480481632550558
0.5462324368953705
0.539910926024119
0.5650316305955251
0.5467461574077607
0.5336650808652242
0.5475887296597163
0.4983527217308679
0.5600311789909999
0.5431181307633718
0.5359541871150335
0.5638215911388397
0.5351374884446461
0.5256327968835831
0.5294865922133128
0.539250586827596
0.5284269851446152
0.5314625473817189
0.5427025264501572
0.5306513359149297
0.

0.2707367571194967
0.24032197554906207
0.20558441321055096
0.41084830343723294
0.2767601255575816
0.3579827763636907
0.3338602510094642
0.30746602644522986
0.3468237163623174
0.3030395436286926
0.4345294415950775
0.2727509571115176
0.33774443129698434
0.3599084283908209
0.28832097629706066
0.24152117093404132
0.25169296185175577
0.35533702601989114
0.21566522459189097
0.36833932191133495
0.34035232206185656
0.28866958429416023
0.3826875797907512
0.2769618015487989
0.223435157140096
0.3493169449766477
0.2668446271618207
0.30073517342408496
0.3094137870272
0.2860943063100179
0.3435290169715881
0.32162175516287483
0.33647395312786105
0.3516354580720266
0.26350540687640506
0.2490991481145223
0.2664443862438202
0.29837276607751845
0.28091862410306934
0.2324049100279808
0.35776911526918415
0.3224206591645877
0.34528807163238523
0.31784755369027456
0.21225944360097249
0.18592376112937928
0.34071621189514795
0.3396377943952878
0.3525979870557785
0.35201725582281745
0.282135773897171
0.29767884

0.2580850194891294
0.2781615868210792
0.25364479551712676
0.22049315671126046
0.16120778754353524
0.2374884088834127
0.2894110021988551
0.2814656906326612
0.17114199588696163
0.27777114262183505
0.31233548243840537
0.2420353971918424
0.3521361172199249
0.31446684449911116
0.3331568044424057
0.22094415446122487
0.27555215915044146
0.3138770009080569
0.3435916260878245
0.34092005074024195
0.304676159620285
0.2531906753778458
0.22215097784996032
0.2935682418942452
0.28590924590826033
0.304365802804629
0.2174062226215998
0.26904397885004677
0.1899471561113993
0.27904256065686545
0.27082634131113686
0.20389786779880525
0.25472469806671144
0.29450102160374325
0.23549788842598596
0.2593783202767372
0.21244408736626308
0.17187967191139855
0.2532156522075335
0.3213878484567006
0.2991406704982122
0.36806681215763093
0.29242913524309794
0.31865011960268025
0.31047735472520194
0.2878421090046565
0.2976577123006185
0.2771198117733002
0.35518109023571015
0.17094288468360905
0.19902783185243608
0.291

0.30123846183220543
0.30858389218648274
0.36379007836182914
0.27168508380651474
0.2438894369204839
0.2738784913221995
0.2660268380244573
0.38523237884044653
0.2579285290837288
0.3036803655823072
0.3023427355289459
0.24185718615849813
0.3605796531836192
0.34818686465422316
0.2906393891572952
0.2541265331705411
0.2800509669383367
0.2004938476284345
0.3594508024056753
0.23672378659248353
0.21588822702566782
0.22177280724048617
0.31678694238265354
0.3288858895500501
0.18014864673217137
0.25177544802427293
0.32258524696032204
0.29179624458154046
0.1897989681363106
0.3230688053369522
0.25999516556660335
0.2513187623023987
0.22033847719430924
0.21409197419881823
0.24312581350406012
0.17478157927592597
0.1770545811454455
0.30849453548590344
0.31531111379464466
0.25938349743684136
0.26751023451487227
0.2929982234040896
0.48676972428957627
0.2504561679561933
0.2187459241350492
0.4313229131698608
0.22963972906271615
0.23293085485696793
0.3155960129698117
0.3960075261195501
0.24516681492328643
0.2

0.30495974232753115
0.22934796681006753
0.2615286562840144
0.2378381148974101
0.20783035268386205
0.310898436208566
0.24232255091269814
0.2900091927250226
0.19147446542978286
0.19250524630149204
0.3074006222685178
0.35439479917287825
0.37269321183363596
0.33760950456062955
0.24580484052499138
0.306780454814434
0.26171837230523426
0.23252082129319507
0.30010572959979376
0.34676019191741947
0.28973075379927954
0.25621671438217164
0.19923926581939058
0.2321077111363411
0.4483871940771739
0.317124441464742
0.20095536470413208
0.23390143126249316
0.2886563742160797
0.1958679296573003
0.3189887009064356
0.28371377497911454
0.32083483040332794
0.2804253495732943
0.3748931175470352
0.1971286756793658
0.287148132622242
0.2478410383065542
0.28590396771828336
0.24068826854228975
0.32363226751486457
0.35787485619386034
0.3730248056848844
0.2893875866134961
0.2497946737209956
0.2542868413527806
0.32395124346017834
0.3548137980699539
0.3186525116364161
0.2997459827860196
0.2961047530174255
0.2627882

In [8]:
wandb.finish()

0,1
loss,█▇▆▅▃▃▃▄▃▄▄▄▃▃▃▄▂▄▂▃▃▃▃▄▁▃▃▄▂▃▃▃▂▃▅▃▂▄▃▅

0,1
loss,0.20355


In [None]:
predictions, real = MNIST.predict(test_loader, model)
print("Accuracy ", np.sum(predictions == real)/len(real))

# Training

In [7]:
wandb.init(project="MNIST", entity="ffalgo")
wandb.config = {
  "learning_rate": 0.01,
  "epochs": 50,
  "batch_size": 32,
  "activation": "relu",
  "positive_lr": 0.005,
  "negative_lr": 0.005,
  "threshold": threshold,
  "optimizer": torch.optim.Adam,
  "device": device
}

[34m[1mwandb[0m: Currently logged in as: [33mmirceatlx[0m ([33mffalgo[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
model = model.to(device)
epochs = 2
best_acc = 0.0
for i in tqdm(range(epochs)):
    if i % 10 == 0:
        predictions, real = MNIST.predict(test_loader, model, device)
        acc = np.sum(predictions == real)/len(real)
        wandb.log({"Accuracy on test data": acc})
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'best_mnist.ph')
        
    predictions, real = MNIST.predict(train_loader, model, device)
    acc = np.sum(predictions == real)/len(real)
    wandb.log({"Accuracy on train data": acc})
    model.train()
    for i, (x, y) in enumerate(train_loader):
        x_pos, _ = MNIST.overlay_y_on_x(x, y)
        rnd = torch.randperm(x.size(0))
        x_neg, _ = MNIST.overlay_y_on_x(x, y[rnd])
        x_pos, x_neg = x_pos.to(device), x_neg.to(device)
        losses = model(x_pos, x_neg)
        
wandb.finish()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [12:51<00:00, 385.64s/it]


0,1
Accuracy on test data,▁
Accuracy on train data,▁
loss on layer 1,█▅▄▅▄▅▂▆▅▃▄▄▄▆▃▅▁▇▅▂▃▁▄▆▆▂▂▄▄▄▂▃▃▃▂▄▃▃▁█
loss on layer 2,▇▅▃▃▃▃▁▄▄▂▃▂▃▃▂▄▁▄▃▃▄▃▃▄▃▂▂▃▃▃▃▂▂▃▄▄▄▃▁█
loss on layer 3,█▇▅▅▅▃▁▄▅▂▃▃▃▃▂▄▁▃▄▂▃▂▃▄▂▁▂▄▃▂▂▂▃▂▄▃▃▃▁▅
overall loss,█▆▄▅▄▄▁▅▅▂▃▃▃▄▂▄▁▅▄▂▃▂▄▅▃▁▂▄▃▃▂▂▃▂▄▄▄▃▁▇

0,1
Accuracy on test data,0.1329
Accuracy on train data,0.13075
loss on layer 1,0.28371
loss on layer 2,0.5098
loss on layer 3,0.49409
overall loss,0.4292
