In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from pff.utils.dataset import PreloadedDataset
from pff.nn.models import PFF
from pff.optim import train_pff


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
device

device(type='cuda', index=0)

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.Lambda(lambda x: torch.flatten(x))
])
raw_dataset = dsets.MNIST(root='../Datasets/', train=True, transform=transform, download=False)

# SUBSET_SIZE = 5000
# raw_dataset = torch.utils.data.Subset(raw_dataset, range(SUBSET_SIZE))
VAL_RATIO = 0.2
length = len(raw_dataset) * 0.2
n_val = int(length * VAL_RATIO)
n_train = len(raw_dataset) - n_val
train_dataset, val_dataset = torch.utils.data.random_split(raw_dataset, [n_train, n_val])


train_dset = PreloadedDataset.from_dataset(raw_dataset, None, device)
val_dset = PreloadedDataset.from_dataset(raw_dataset, None, device)

  0%|          | 0/60000 [00:00<?, ?it/s]

                                                        

In [4]:
sizes = [784, 2000, 2000, 2000, 10]
model = PFF(sizes).to(device)
stats = None

In [7]:
# BATCH_SIZE = 2048
BATCH_SIZE = 512
EPOCHS = 80
LR = 0.001
optimiser = torch.optim.AdamW(model.parameters(), lr=LR)

stats = train_pff(
    model,
    train_dset,
    BATCH_SIZE,
    optimiser,
    EPOCHS,
    stats
)

                                                                                                                                                       

In [6]:
from pff.utils.functions import my_relu

In [41]:
G = nn.Linear(20, 100).to(device)
def step_gen_layer(z_above):
    z_above = F.normalize(z_above, dim=1)
    return my_relu(G(z_above))

def step_gen_model(z, z_g):
    errors = []
    z_g_bar = my_relu(z_g)
    
    z_above = z_g_bar
    z_pred = step_gen_layer(z_above)
    errors.append((z_pred - z[-1].detach()).square().sum(dim=1).mean())
    
    z_g.grad = None
    errors[-1].backward(retain_graph=True)
    with torch.no_grad():
        print(z_g.grad)
        z_g = z_g - 0.1 * z_g.grad
    

def infer_and_generate(model, x, y):
    z = model.forward(x)
    z_g = torch.zeros((x.shape[0], model.g_units), requires_grad=True).to(model.device)
    z_g.retain_grad()

    for i in range(5):
        # print(f"step {i}")
        y_hat, z = model.step_rep(x, y, z)
        x_hat, E, z_g = step_gen_model(z, z_g)

dataloader = torch.utils.data.DataLoader(train_dset, batch_size=16, shuffle=True)
X, Y = next(iter(dataloader))
Y = torch.nn.functional.one_hot(Y, 10).to(device)
infer_and_generate(model, X, Y)

tensor([[-2.1343e+08, -2.4996e+09,  2.0125e+10, -1.0635e+10,  1.7235e+10,
         -1.1003e+10,  3.4524e+09,  2.9326e+10, -2.9295e+10, -2.0506e+10,
         -7.0953e+09,  7.0628e+09,  2.0914e+10,  2.7987e+10,  1.6781e+09,
         -3.9245e+09,  1.6851e+10,  7.0713e+09, -2.6026e+10, -1.0132e+10],
        [-4.0795e+09, -1.0186e+09,  5.2653e+09, -6.1755e+09,  2.0358e+10,
         -1.4578e+10,  1.5307e+08,  2.4518e+10, -2.3845e+10, -1.9032e+10,
         -3.3909e+08,  1.2529e+10,  2.0205e+10,  2.6395e+10, -6.6079e+09,
         -4.4533e+09,  8.9984e+09,  1.7496e+10, -1.9924e+10, -2.6854e+09],
        [-6.9687e+09,  1.1715e+09,  1.1888e+10,  2.1008e+09,  1.3598e+10,
         -1.1218e+10,  3.2638e+09,  1.6608e+10, -2.4195e+10, -1.4288e+10,
          4.0473e+08,  5.4200e+09,  1.5732e+10,  2.8169e+10,  4.5415e+09,
         -3.6378e+09,  8.3780e+09,  1.1503e+10, -1.5220e+10, -5.8759e+09],
        [ 3.5456e+08, -2.8794e+09,  1.9781e+10, -1.1755e+10,  1.7099e+10,
         -1.1573e+10,  4.3872e+09, 

TypeError: cannot unpack non-iterable NoneType object

In [13]:

x = torch.rand((1,), requires_grad=True)
x_bar = my_relu(x.detach())
z = x_bar.pow(2)
z.backward()
x.grad

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [8]:
for i, layer in enumerate(model.layers):
    print(f"layer {i}: {layer.W.weight.shape}, {layer.V.weight.shape}, {layer.L.weight.shape}, {layer.G.weight.shape}")

layer 0: torch.Size([200, 784]), torch.Size([200, 500]), torch.Size([200, 200]), torch.Size([200, 500])
layer 1: torch.Size([500, 200]), torch.Size([500, 100]), torch.Size([500, 500]), torch.Size([500, 100])
layer 2: torch.Size([100, 500]), torch.Size([100, 10]), torch.Size([100, 100]), torch.Size([100, 10])
