<a href="https://colab.research.google.com/github/davidraamirez/GradientWithoutBackpropagation/blob/main/CNN_fwd_gradient.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Gradient Without Backpropagation

In [22]:
import torch
import tensorflow_datasets as tfds
from sklearn.model_selection import train_test_split
import tqdm
import torch.distributions as distr

In [23]:
%pip install torchmetrics --quiet

In [24]:
import torchmetrics
import torchvision
from torchvision import transforms as T
from matplotlib import pyplot as plt

Loading and preprocessing the dataset

In [25]:
#Load the dataset
train_data = torchvision.datasets.KMNIST('./data', train=True, download=True)

In [26]:
# This loads data with both data conversion.
train_data = torchvision.datasets.KMNIST('./data', train=True, transform=T.ToTensor())

In [27]:
# Loaders are used to shuffle, batch, and possibly sample the elements of the dataset
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True)

In [90]:
Xtrain = torch.randn(500, 1, 1, 28, 28)
ytrain = torch.randn(500, 1)
for i in range(500):
  xb, yb = next(iter(train_loader))
  Xtrain[i] = xb
  ytrain[i] = yb

print(Xtrain.shape)
print(ytrain.shape)

torch.Size([500, 1, 1, 28, 28])
torch.Size([500, 1])


In [82]:
# Loading the test data is similar, but (a) we do not apply data augmentation,
# and (b) we do not shuffle when building the mini-batches.
test_data = torchvision.datasets.KMNIST('./data', train=False, transform=T.ToTensor())
test_loader = torch.utils.data.DataLoader(test_data, shuffle=False)

In [83]:
Xtest = torch.randn(50, 1, 1, 28, 28)
ytest = torch.randn(50, 1)
for i in range(50):
  xb, yb = next(iter(test_loader))
  Xtest[i] = xb
  ytest[i] = yb

print(Xtest.shape)
print(ytest.shape)

torch.Size([50, 1, 1, 28, 28])
torch.Size([50, 1])


Define Convolutional Neural Network Class

In [43]:
from torch import nn
from torch.nn import functional as F

In [84]:
class SimpleCNN(nn.Module):
    def __init__(self, input_size, conv1w, conv1b, conv2w, conv2b, fc1w, fc1b, fc2w, fc2b,):
        super().__init__()
        self.conv1 = nn.Conv2d(input_size, 2, 3, padding=1)
        self.conv1.weight = torch.nn.Parameter(conv1w)
        self.conv1.bias = torch.nn.Parameter(conv1b)

        self.conv2 = nn.Conv2d(2, 4, 3, padding=1)
        self.conv2.weight = torch.nn.Parameter(conv2w)
        self.conv2.bias = torch.nn.Parameter(conv2b)

        self.max_pool = nn.MaxPool2d(2)

        self.w1 = torch.nn.Parameter(fc1w)
        self.b1 = torch.nn.Parameter(fc1b)

        self.w2 = torch.nn.Parameter(fc2w)
        self.b2 = torch.nn.Parameter(fc2b)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.max_pool(x)
        x = x.reshape((-1, 4*14*14))
        x = F.relu(x@self.w1 + self.b1)
        x = x@self.w2 + self.b2
        return torch.softmax(x,1)

In [85]:
# We check if CUDA is available. If you do not see it,
# activate a GPU from Runtime >> Change runtime type and 
# restart the notebook.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


Initialize the parameters

In [86]:
# We initialize the parameters randomly and the model with an input size
conv1w = torch.FloatTensor(2, 1, 3, 3).uniform_(-1, 1)
conv1b = torch.FloatTensor(2).uniform_(-1, 1)
conv2w = torch.FloatTensor(4, 2, 3, 3).uniform_(-1, 1)
conv2b = torch.FloatTensor(4).uniform_(-1, 1)

fc1w = torch.FloatTensor(4*14*14,8).uniform_(-1,1)
fc1b = torch.FloatTensor(8).uniform_(-1, 1)
fc2w = torch.FloatTensor(8,10).uniform_(-1, 1)
fc2b = torch.FloatTensor(10).uniform_(-1, 1)
cnn = SimpleCNN(1, conv1w, conv1b, conv2w, conv2b, fc1w, fc1b, fc2w, fc2b).to(device)

In [87]:
# Note: we also need to move data when asking for a prediction
print(cnn(xb.to(device)).shape)

torch.Size([1, 10])


In [88]:
def cross_entropycw1(conv1w, conv1b, conv2w, conv2b, fc1w, fc1b, fc2w, fc2b, ytrue, x):
  """ Cross-entropy loss.
  Inputs:
  - ytrue (n,): vector of indices for the correct class.
  - ypred (n, 3): predictions of the model.
  Returns the average cross-entropy.
  """
  # This is called integer array indexing in NumPy:
  # https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
  ypred=torch.randn((ytrue.shape[0],10))
  for j in range (ytrue.shape[0]):
    xj =x[j]
    conv1 = nn.Conv2d(1, 2, 3, padding=1)
    conv1.weight = torch.nn.Parameter(conv1w)
    conv1.bias = torch.nn.Parameter(conv1b)
    conv2 = nn.Conv2d(2, 4, 3, padding=1)
    conv2.weight = torch.nn.Parameter(conv2w)
    conv2.bias = torch.nn.Parameter(conv2b)

    xj = F.relu(conv1(xj))
    xj = F.relu(conv2(xj))
    xj = nn.MaxPool2d(2)(xj)
    xj = xj.reshape((-1, 4*14*14))
    xj = F.relu(xj@fc1w + fc1b)
    xj = xj@fc2w + fc2b
    ypred[j]=torch.softmax(xj,1)
  print(ypred.shape)
  print(ytrue.shape)
  return - ypred[torch.arange(0, ypred.shape[0]), ytrue].log().mean()


In [49]:
def cross_entropycb1(conv1b, conv1w, conv2w, conv2b, fc1w, fc1b, fc2w, fc2b, ytrue, x):
  """ Cross-entropy loss.
  Inputs:
  - ytrue (n,): vector of indices for the correct class.
  - ypred (n, 3): predictions of the model.
  Returns the average cross-entropy.
  """
  # This is called integer array indexing in NumPy:
  # https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
  ypred=torch.randn((ytrue.shape[0],10))
  for j in range (ytrue.shape[0]):
    xj =x[j]
    conv1 = nn.Conv2d(1, 2, 3, padding=1)
    conv1.weight = torch.nn.Parameter(conv1w)
    conv1.bias = torch.nn.Parameter(conv1b)
    conv2 = nn.Conv2d(2, 4, 3, padding=1)
    conv2.weight = torch.nn.Parameter(conv2w)
    conv2.bias = torch.nn.Parameter(conv2b)

    xj = F.relu(conv1(xj))
    xj = F.relu(conv2(xj))
    xj = nn.MaxPool2d(2)(xj)
    xj = xj.reshape((-1, 4*14*14))
    xj = F.relu(xj@fc1w + fc1b)
    xj = xj@fc2w + fc2b
    ypred[j]=torch.softmax(xj,1)
  return - ypred[torch.arange(0, ypred.shape[0]), ytrue].log().mean()

In [50]:
def cross_entropycw2(conv2w, conv1b, conv1w, conv2b, fc1w, fc1b, fc2w, fc2b, ytrue, x):
  """ Cross-entropy loss.
  Inputs:
  - ytrue (n,): vector of indices for the correct class.
  - ypred (n, 3): predictions of the model.
  Returns the average cross-entropy.
  """
  # This is called integer array indexing in NumPy:
  # https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
  ypred=torch.randn((ytrue.shape[0],10))
  for j in range (ytrue.shape[0]):
    xj =x[j]
    conv1 = nn.Conv2d(1, 2, 3, padding=1)
    conv1.weight = torch.nn.Parameter(conv1w)
    conv1.bias = torch.nn.Parameter(conv1b)
    conv2 = nn.Conv2d(2, 4, 3, padding=1)
    conv2.weight = torch.nn.Parameter(conv2w)
    conv2.bias = torch.nn.Parameter(conv2b)

    xj = F.relu(conv1(xj))
    xj = F.relu(conv2(xj))
    xj = nn.MaxPool2d(2)(xj)
    xj = xj.reshape((-1, 4*14*14))
    xj = F.relu(xj@fc1w + fc1b)
    xj = xj@fc2w + fc2b
    ypred[j]=torch.softmax(xj,1)
  return - ypred[torch.arange(0, ypred.shape[0]), ytrue].log().mean()

In [51]:
def cross_entropycb2(conv2b, conv1b, conv2w, conv1w, fc1w, fc1b, fc2w, fc2b, ytrue, x):
  """ Cross-entropy loss.
  Inputs:
  - ytrue (n,): vector of indices for the correct class.
  - ypred (n, 3): predictions of the model.
  Returns the average cross-entropy.
  """
  # This is called integer array indexing in NumPy:
  # https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
  ypred=torch.randn((ytrue.shape[0],10))
  for j in range (ytrue.shape[0]):
    xj =x[j]
    conv1 = nn.Conv2d(1, 2, 3, padding=1)
    conv1.weight = torch.nn.Parameter(conv1w)
    conv1.bias = torch.nn.Parameter(conv1b)
    conv2 = nn.Conv2d(2, 4, 3, padding=1)
    conv2.weight = torch.nn.Parameter(conv2w)
    conv2.bias = torch.nn.Parameter(conv2b)

    xj = F.relu(conv1(xj))
    xj = F.relu(conv2(xj))
    xj = nn.MaxPool2d(2)(xj)
    xj = xj.reshape((-1, 4*14*14))
    xj = F.relu(xj@fc1w + fc1b)
    xj = xj@fc2w + fc2b
    ypred[j]=torch.softmax(xj,1)
  return - ypred[torch.arange(0, ypred.shape[0]), ytrue].log().mean()

In [52]:
def cross_entropyfcw1(fc1w, conv1b, conv2w, conv2b, conv1w, fc1b, fc2w, fc2b, ytrue, x):
  """ Cross-entropy loss.
  Inputs:
  - ytrue (n,): vector of indices for the correct class.
  - ypred (n, 3): predictions of the model.
  Returns the average cross-entropy.
  """
  # This is called integer array indexing in NumPy:
  # https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
  ypred=torch.randn((ytrue.shape[0],10))
  for j in range (ytrue.shape[0]):
    xj =x[j]
    conv1 = nn.Conv2d(1, 2, 3, padding=1)
    conv1.weight = torch.nn.Parameter(conv1w)
    conv1.bias = torch.nn.Parameter(conv1b)
    conv2 = nn.Conv2d(2, 4, 3, padding=1)
    conv2.weight = torch.nn.Parameter(conv2w)
    conv2.bias = torch.nn.Parameter(conv2b)

    xj = F.relu(conv1(xj))
    xj = F.relu(conv2(xj))
    xj = nn.MaxPool2d(2)(xj)
    xj = xj.reshape((-1, 4*14*14))
    xj = F.relu(xj@fc1w + fc1b)
    xj = xj@fc2w + fc2b
    ypred[j]=torch.softmax(xj,1)
  return - ypred[torch.arange(0, ypred.shape[0]), ytrue].log().mean()

In [53]:
def cross_entropyfcb1(fc1b, conv1b, conv2w, conv2b, fc1w, conv1w, fc2w, fc2b, ytrue, x):
  """ Cross-entropy loss.
  Inputs:
  - ytrue (n,): vector of indices for the correct class.
  - ypred (n, 3): predictions of the model.
  Returns the average cross-entropy.
  """
  # This is called integer array indexing in NumPy:
  # https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
  ypred=torch.randn((ytrue.shape[0],10))
  for j in range (ytrue.shape[0]):
    xj =x[j]
    conv1 = nn.Conv2d(1, 2, 3, padding=1)
    conv1.weight = torch.nn.Parameter(conv1w)
    conv1.bias = torch.nn.Parameter(conv1b)
    conv2 = nn.Conv2d(2, 4, 3, padding=1)
    conv2.weight = torch.nn.Parameter(conv2w)
    conv2.bias = torch.nn.Parameter(conv2b)

    xj = F.relu(conv1(xj))
    xj = F.relu(conv2(xj))
    xj = nn.MaxPool2d(2)(xj)
    xj = xj.reshape((-1, 4*14*14))
    xj = F.relu(xj@fc1w + fc1b)
    xj = xj@fc2w + fc2b
    ypred[j]=torch.softmax(xj,1)
  return - ypred[torch.arange(0, ypred.shape[0]), ytrue].log().mean()

In [54]:
def cross_entropyfcw2(fc2w, conv1b, conv2w, conv2b, fc1w, fc1b, conv1w, fc2b, ytrue, x):
  """ Cross-entropy loss.
  Inputs:
  - ytrue (n,): vector of indices for the correct class.
  - ypred (n, 3): predictions of the model.
  Returns the average cross-entropy.
  """
  # This is called integer array indexing in NumPy:
  # https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
  ypred=torch.randn((ytrue.shape[0],10))
  for j in range (ytrue.shape[0]):
    xj =x[j]
    conv1 = nn.Conv2d(1, 2, 3, padding=1)
    conv1.weight = torch.nn.Parameter(conv1w)
    conv1.bias = torch.nn.Parameter(conv1b)
    conv2 = nn.Conv2d(2, 4, 3, padding=1)
    conv2.weight = torch.nn.Parameter(conv2w)
    conv2.bias = torch.nn.Parameter(conv2b)

    xj = F.relu(conv1(xj))
    xj = F.relu(conv2(xj))
    xj = nn.MaxPool2d(2)(xj)
    xj = xj.reshape((-1, 4*14*14))
    xj = F.relu(xj@fc1w + fc1b)
    xj = xj@fc2w + fc2b
    ypred[j]=torch.softmax(xj,1)
  return - ypred[torch.arange(0, ypred.shape[0]), ytrue].log().mean()

In [55]:
def cross_entropyfcb2(fc2b, conv1b, conv2w, conv2b, fc1w, fc1b, fc2w, conv1w, ytrue, x):
  """ Cross-entropy loss.
  Inputs:
  - ytrue (n,): vector of indices for the correct class.
  - ypred (n, 3): predictions of the model.
  Returns the average cross-entropy.
  """
  # This is called integer array indexing in NumPy:
  # https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
  ypred=torch.randn((ytrue.shape[0],10))
  for j in range (ytrue.shape[0]):
    xj =x[j]
    conv1 = nn.Conv2d(1, 2, 3, padding=1)
    conv1.weight = torch.nn.Parameter(conv1w)
    conv1.bias = torch.nn.Parameter(conv1b)
    conv2 = nn.Conv2d(2, 4, 3, padding=1)
    conv2.weight = torch.nn.Parameter(conv2w)
    conv2.bias = torch.nn.Parameter(conv2b)

    xj = F.relu(conv1(xj))
    xj = F.relu(conv2(xj))
    xj = nn.MaxPool2d(2)(xj)
    xj = xj.reshape((-1, 4*14*14))
    xj = F.relu(xj@fc1w + fc1b)
    xj = xj@fc2w + fc2b
    ypred[j]=torch.softmax(xj,1)
  return - ypred[torch.arange(0, ypred.shape[0]), ytrue].log().mean()

In [89]:
print(cross_entropycw1(conv1w, conv1b, conv2w, conv2b, fc1w, fc1b, fc2w, fc2b,ytrain, Xtrain))
print(cross_entropycb1(conv1b, conv1w, conv2w, conv2b, fc1w, fc1b, fc2w, fc2b,yb,xb))
print(cross_entropycw2(conv2w, conv1b, conv1w, conv2b, fc1w, fc1b, fc2w, fc2b,yb,xb))
print(cross_entropycb2(conv2b, conv1b, conv2w, conv1w, fc1w, fc1b, fc2w, fc2b,yb,xb))
print(cross_entropyfcw1(fc1w, conv1b, conv2w, conv2b, conv1w, fc1b, fc2w, fc2b,yb,xb))
print(cross_entropyfcb1(fc1b, conv1b, conv2w, conv2b, fc1w, conv1w, fc2w, fc2b,yb,xb))
print(cross_entropyfcw2(fc2w, conv1b, conv2w, conv2b, fc1w, fc1b, conv1w, fc2b,yb,xb))
print(cross_entropyfcb2(fc2b, conv1b, conv2w, conv2b, fc1w, fc1b, fc2w, conv1w,yb,xb))

torch.Size([500, 10])
torch.Size([500, 1])


IndexError: ignored

Train and evaluate the network with forward gradient

In [141]:
def accuracy(ytrue, ypred):
  return (ypred.argmax(1) == ytrue).float().mean()

In [140]:
# Average accuracy at initialization is 10% (random guessing).
print(yb)
print(cnn(xb.to(device)))
accuracy(yb.to(device), cnn(xb.to(device)))

tensor([0])
tensor([[4.9879e-01, 8.5248e-07, 9.2412e-06, 2.2820e-08, 3.3297e-05, 4.4109e-03,
         4.8422e-01, 1.2538e-02, 5.7235e-07, 1.4758e-06]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)


tensor(1., device='cuda:0')

DEFINE CROSS_ENTROPY

In [None]:
# Note: it is important to move the CNN to the device before initializing the optimizer,
# since the optimizer also has a state that must be moved to the GPU.
loss = nn.CrossEntropyLoss()


In [None]:
def beale_function(x):
  return (torch.pow(torch.tensor([1.5])-x[0]+x[0]*x[1],2) + torch.pow(torch.tensor([2.25])-x[0]+x[0]*torch.pow(x[1],2),2)+torch.pow(torch.tensor([2.625])-x[0]+x[0]*torch.pow(x[1],3),2))

In [None]:
def rosenbrock_function(x):
  sum=0
  for p in x.size():
    for i in range (x.size(1)-1):
      sum += (100*torch.pow(x[i+1] - torch.pow(x[i], 2), 2) + torch.pow(x[i]-1, 2))
  return sum

In [None]:
from functorch import jvp

In [None]:
def train_fwd_gradient(x, y):
  x, y = x.to(device), y.to(device)
  x = x / 255

  l_rate0 = 0.025
  f = rosenbrock_function

  #Parameters
  conv1w = torch.randn((8, 1, 3, 3), requires_grad=False)
  conv1b = torch.randn(8, requires_grad=False)
  conv2w = torch.randn((16, 8, 3, 3), requires_grad=False)
  conv2b = torch.randn(16, requires_grad=False)
  conv3w = torch.randn((32, 16, 3, 3), requires_grad=False)
  conv3b = torch.randn(32, requires_grad=False)
  conv4w = torch.randn((64, 32, 3, 3), requires_grad=False)
  conv4b = torch.randn(64, requires_grad=False)
  fc1w = torch.randn((1024, 3136), requires_grad=False)
  fc1b = torch.randn(1024, requires_grad=False)
  fc2w = torch.randn((1, 1024), requires_grad=False)
  fc2b = torch.randn(1, requires_grad=False) 

  conv1w1 = conv1w.reshape(-1)
  conv2w1 = conv2w.reshape(-1)
  conv3w1 = conv3w.reshape(-1)
  conv4w1 = conv4w.reshape(-1)
  fc1w1 = fc1w.reshape(-1)
  fc2w1 = fc2w.reshape(-1)

  cnn = SimpleCNN(1, conv1w, conv1b, conv2w, conv2b, conv3w, conv3b, conv4w, conv4b, fc1w, fc1b, fc2w, fc2b).to(device)
  error = torch.norm(cnn(x)-y, 2)

  t=torch.tensor([0])

  while (error>1e-3) :

    t=t+1

    vconv1w1=torch.diagonal(torch.normal(torch.zeros_like(conv1w1),torch.eye(conv1w1.shape[0])))
    vconv1b=torch.diagonal(torch.normal(torch.zeros_like(conv1b),torch.eye(conv1b.shape[0])))
    vconv2w1=torch.diagonal(torch.normal(torch.zeros_like(conv2w1),torch.eye(conv2w1.shape[0])))
    vconv2b=torch.diagonal(torch.normal(torch.zeros_like(conv2b),torch.eye(conv2b.shape[0])))
    vconv3w1=torch.diagonal(torch.normal(torch.zeros_like(conv3w1),torch.eye(conv3w1.shape[0])))
    vconv3b=torch.diagonal(torch.normal(torch.zeros_like(conv3b),torch.eye(conv3b.shape[0])))
    vconv4w1=torch.diagonal(torch.normal(torch.zeros_like(conv4w1),torch.eye(conv4w1.shape[0])))
    vconv4b=torch.diagonal(torch.normal(torch.zeros_like(conv4b),torch.eye(conv4b.shape[0])))
    vfc1w1=torch.diagonal(torch.normal(torch.zeros_like(fc1w1),torch.eye(fc1w1.shape[0])))
    vfc1b=torch.diagonal(torch.normal(torch.zeros_like(fc1b),torch.eye(fc1b.shape[0])))
    vfc2w1=torch.diagonal(torch.normal(torch.zeros_like(fc2w1),torch.eye(fc2w1.shape[0])))
    vfc2b=torch.diagonal(torch.normal(torch.zeros_like(fc2b),torch.eye(fc2b.shape[0])))

    ftconv1w1=f(conv1w1)
    ftconv1b=f(conv1b)
    ftconv2w1=f(conv2w1)
    ftconv2b=f(conv2b)
    ftconv3w1=f(conv3w1)
    ftconv3b=f(conv3b)
    ftconv4w1=f(conv4w1)
    ftconv4b=f(conv4b)
    ftfc1w1=f(fc1w1)
    ftfc1b=f(fc1b)
    ftfc2w1=f(fc2w1)
    ftfc2b=f(fc2b)

    dtconv1w1=torch.tensor(jvp(f,(conv1w1, ), (vconv1w1, ))[1])
    dtconv1b=torch.tensor(jvp(f,(conv1b, ), (vconv1b, ))[1])
    dtconv2w1=torch.tensor(jvp(f,(conv2w1, ), (vconv2w1, ))[1])
    dtconv2b=torch.tensor(jvp(f,(conv2b, ), (vconv2b, ))[1])
    dtconv3w1=torch.tensor(jvp(f,(conv3w1, ), (vconv3w1, ))[1])
    dtconv3b=torch.tensor(jvp(f,(conv3b, ), (vconv3b, ))[1])
    dtconv4w1=torch.tensor(jvp(f,(conv4w1, ), (vconv4w1, ))[1])
    dtconv4b=torch.tensor(jvp(f,(conv4b, ), (vconv4b, ))[1])
    dtfc1w1=torch.tensor(jvp(f,(fc1w1, ), (vfc1w1, ))[1])
    dtfc1b=torch.tensor(jvp(f,(fc1b, ), (vfc1b, ))[1])
    dtfc2w1=torch.tensor(jvp(f,(fc2w1, ), (vfc2w1, ))[1])
    dtfc2b=torch.tensor(jvp(f,(fc2b, ), (vfc2b, ))[1])

    gtconv1w1 = vconv1w1*dtconv1w1
    gtconv1b = vconv1b*dtconv1b
    gtconv2w1 = vconv2w1*dtconv2w1
    gtconv2b = vconv2b*dtconv2b
    gtconv3w1 = vconv3w1*dtconv3w1
    gtconv3b = vconv3b*dtconv3b
    gtconv4w1 = vconv4w1*dtconv4w1
    gtconv4b = vconv4b*dtconv4b
    gtfc1w1 = vfc1w1*dtfc1w1
    gtfc1b = vfc1b*dtfc1b
    gtfc2w1 = vfc2w1*dtfc2w1
    gtfc2b = vfc2b*dtfc2b

    conv1w1 -= l_rate0*gtconv1w1
    conv1b -= l_rate0*gtconv1b
    conv2w1 -= l_rate0*gtconv2w1
    conv2b -= l_rate0*gtconv2b
    conv3w1 -= l_rate0*gtconv3w1
    conv3b -= l_rate0*gtconv3w1
    conv4w1 -= l_rate0*gtconv4w1
    conv4b -= l_rate0*gtconv4b
    fc1w1 -= l_rate0*gtfc1w1
    fc1b -= l_rate0*gtfc1b
    fc2w1 -= l_rate0*gtfc2w1
    fc2b -= l_rate0*gtfc2b

    conv1w = conv1w1.reshape(-1, 1, 3, 3)
    conv2w = conv2w1.reshape(-1, 8, 3, 3)
    conv3w = conv3w1.reshape(-1, 16, 3, 3)
    conv4w = conv4w1.reshape(-1, 32, 3, 3)
    fc1w = fc1w1.reshape(-1, 3136)
    fc2w = fc2w1.reshape(-1, 1024)

    cnn = SimpleCNN(1, conv1w, conv1b, conv2w, conv2b, conv3w, conv3b, conv4w, conv4b, fc1w, fc1b, fc2w, fc2b).to(device)
    error = torch.norm(cnn(x)-y, 2)

  return conv1w, conv1b, conv2w, conv2b, conv3w, conv3b, conv4w, conv4b, fc1w, fc1b, fc2w, fc2b 

In [None]:
for epoch in range(1):

  cnn.train()
  for i in range(1):
    xb, yb = next(iter(train_loader))
    xb = xb.to(device)
    yb = yb.to(device)

    conv1w, conv1b, conv2w, conv2b, conv3w, conv3b, conv4w, conv4b, fc1w, fc1b, fc2w, fc2b = train_fwd_gradient(xb, yb)
    cnn = SimpleCNN(1, conv1w, conv1b, conv2w, conv2b, conv3w, conv3b, conv4w, conv4b, fc1w, fc1b, fc2w, fc2b).to(device)

    #Update cnn parameters
    #Recalculate ypred and loss
    #MIRAR NN_LAB_LOGISITC_REGRESSION
    #CALCULAR G(THETA) QUE ES EL GRADIENTE Y APLICARLO A LOS PARAMETROS DEL CNN, LOS WEIGHTS