<a href="https://colab.research.google.com/github/daleas0120/Example_notebooks/blob/main/pytorch_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


How To Invert a Neural Network

---

Curious to see if I can take a neural network, train it, and then run it backwards.


In [1]:
#!pip install torchinfo



In [2]:
#!pip install torcheval



In [1]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
from torchinfo import summary
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torcheval.metrics.functional import multiclass_f1_score

from tqdm import tqdm, trange

In [2]:
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
batch_size_train = 8
batch_size_test = 8

# 0. Load Data


In [41]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=False)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=False)

## Create Custom, Random Dataset

In [4]:
rand_img = torch.rand(1, 28, 28, requires_grad=True)

# 1. Define Model

In [146]:
class CNN_classifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(CNN_classifier, self).__init__()
        self.conv0 = nn.Conv2d(1, 4, (3, 3), stride=2)
        self.conv1 = nn.Conv2d(4, 8, (3,3), stride=2)
        self.conv2 = nn.Conv2d(8, 16, (3,3), stride=2)
        self.flat = nn.Flatten(start_dim=1, end_dim=-1)
        self.linear = nn.Linear(64, num_classes)
        self.LeakyReLU = nn.LeakyReLU(0.2)
        self.num_classes = num_classes

    def forward(self, input):
        c0 = F.relu(self.conv0(input))
        c1 = F.relu(self.conv1(c0))
        c2 = F.relu(self.conv2(c1))
        c2_flat = self.flat(c2)
        feat_vec = self.linear(c2_flat)
        logits = F.softmax(feat_vec, dim=1)

        return logits


In [147]:
classifier_model = CNN_classifier(input_dim=(1, 28, 28), num_classes=10).to(DEVICE)

In [149]:
summary(classifier_model, (1, 1, 28, 28))

Layer (type:depth-idx)                   Output Shape              Param #
CNN_classifier                           [1, 10]                   --
├─Conv2d: 1-1                            [1, 4, 13, 13]            40
├─Conv2d: 1-2                            [1, 8, 6, 6]              296
├─Conv2d: 1-3                            [1, 16, 2, 2]             1,168
├─Flatten: 1-4                           [1, 64]                   --
├─Linear: 1-5                            [1, 10]                   650
Total params: 2,154
Trainable params: 2,154
Non-trainable params: 0
Total mult-adds (M): 0.02
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.01
Estimated Total Size (MB): 0.02

# 2. Prepare for Training

In [13]:
learning_rate = 1e-5
num_epochs = 1
MSE_loss = nn.MSELoss()

In [9]:
from torch.optim import LBFGS, Adam

BCE_loss = nn.BCELoss()
MSE_loss = nn.MSELoss()

def loss_function(x, x_hat, mean, log_var, recon_const=1.):
    #reconstruction_loss = recon_const*nn.functional.binary_cross_entropy(x_hat, x, reduction='sum',)
    reconstruction_loss = nn.functional.mse_loss(x_hat, x, reduction='sum')
    KLD      = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())

    return reconstruction_loss + KLD, reconstruction_loss, KLD


optimizer = Adam(classifier_model.parameters(), lr=learning_rate)
#optimizer = optim.LBFGS([rand_img])
loss = nn.MSELoss()

# 3. Training

In [13]:
for epoch in range(num_epochs):

    for batch_idx, (x, Y) in enumerate(tqdm(train_loader, desc="Training Epoch "+str(epoch))):

        x = x.to(DEVICE)

        one_hot = F.one_hot(Y, num_classes=10)

        label = torch.tensor(one_hot, dtype=torch.float).to(DEVICE)

        optimizer.zero_grad()

        logits = classifier_model(x)

        loss = MSE_loss(label, logits)

        loss.backward()
        optimizer.step()

    print('Epoch [{}/{}], Loss: {:.6f}'.format(epoch+1, num_epochs, loss.item()))


  label = torch.tensor(one_hot, dtype=torch.float).to(DEVICE)
  logits = F.softmax(feat_vec)
Training Epoch 0: 100%|██████████| 7500/7500 [00:42<00:00, 178.37it/s]

Epoch [1/1], Loss: 0.089502





# 4. Testing

In [67]:
labels_true = []
labels_pred = []

for batch_idx, (x, Y) in enumerate(tqdm(test_loader, desc="Testing ")):

        x = x.to(DEVICE)

        one_hot = F.one_hot(Y, num_classes=10)

        #label = torch.tensor(Y, dtype=torch.float).to(DEVICE)
        labels_true.extend(Y.detach().numpy())

        logits = classifier_model(x)
        labels_pred.extend(torch.argmax(logits, axis=1).detach().numpy())
        break



  logits = F.softmax(feat_vec)
Testing :   0%|          | 0/1250 [00:00<?, ?it/s]


In [9]:
multiclass_f1_score(torch.tensor(labels_true), torch.tensor(labels_pred), num_classes=10, average=None)



tensor([0.0000, 0.2063, 0.0686, 0.0069, 0.0000, 0.0000, 0.0000, 0.0000, 0.0152,
        0.0000])

# 5. Loss Landscape

# 6. Hessian

In [10]:
from torch.func import functional_call, vmap, hessian

torch.autograd.functional.hessian(error_func, set of values)


In [110]:
my_params = dict(classifier_model.named_parameters())

In [151]:
net = classifier_model

batch_size=8

targets = torch.randn(batch_size)
inputs = torch.randn(batch_size, 1)
params = dict(net.named_parameters())

def fcall(params, inputs):
  outputs = functional_call(net, params, inputs)
  return outputs

def loss_fn(outputs, targets):
  return torch.mean((outputs - targets)**2, dim=0)

def compute_loss(params, inputs, targets):
  #outputs = vmap(fcall, in_dims=(None,0))(params, inputs) #vectorize over batch
  outputs = net(inputs)
  return loss_fn(outputs, targets)

def compute_hessian_loss(params, inputs, targets):
  return hessian(compute_loss, argnums=(0))(params, inputs, targets)

loss = compute_loss(my_params, x, one_hot)
print(loss)

hess = compute_hessian_loss(my_params, x, one_hot)
key=list(params.keys())[0] #take weight in first layer as example key
print(hess[key][key].shape) #Hessian of loss w.r.t first weight (shape [16, 1, 16, 1])

tensor([0.1101, 0.2043, 0.1092, 0.0088, 0.2086, 0.0090, 0.0097, 0.1095, 0.0074,
        0.1112], grad_fn=<MeanBackward1>)
torch.Size([10, 4, 1, 3, 3, 4, 1, 3, 3])


In [141]:
one_hot[0].shape

torch.Size([10])

In [137]:
import torch
from torch import nn
from torch.func import functional_call, vmap, hessian

class Model(nn.Module):
  def __init__(self):
    super(Model, self).__init__()
    self.fc1=nn.Linear(1,16)
    self.fc2=nn.Linear(16,1)
    self.af=nn.Tanh()
  def forward(self, x):
    x=self.fc1(x)
    x=self.af(x)
    x=self.fc2(x)
    return x.squeeze(-1)

net = Model()

batch_size=1

targets = torch.randn(batch_size)
inputs = torch.randn(batch_size, 1)
params = dict(net.named_parameters())

def fcall(params, inputs):
  outputs = functional_call(net, params, inputs)
  return outputs

def loss_fn(outputs, targets):
  return torch.mean((outputs - targets)**2, dim=0)

def compute_loss(params, inputs, targets):
  outputs = vmap(fcall, in_dims=(None,0))(params, inputs) #vectorize over batch
  return loss_fn(outputs, targets)

def compute_hessian_loss(params, inputs, targets):
  return hessian(compute_loss, argnums=(0))(params, inputs, targets)

loss = compute_loss(params, inputs, targets)
print(loss)

hess = compute_hessian_loss(params, inputs, targets)
key=list(params.keys())[0] #take weight in first layer as example key
print(hess[key][key].shape) #Hessian of loss w.r.t first weight (shape [16, 1, 16, 1])

tensor(1.9320, grad_fn=<MeanBackward1>)
torch.Size([16, 1, 16, 1])
