In [1]:
import numpy as np
import pickle
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from torch.utils.data import TensorDataset, DataLoader
from torch.distributions.normal import Normal
import torch.nn.functional as F
import gc
import imutils
import math

import data
import models
from models import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda:0"
LATENT_DIM = 64
NUM_GENERATORS = 1

In [3]:
# We will work with the distilled data for the sake of simplicity of our work, it turns out that the distilled dataset contains only 0s and 1s

In [4]:
X_train, X_test, Y_train, Y_test = data.get_dataset_distilled()

train_dataset = TensorDataset(X_train,Y_train)
train_dataloader = DataLoader(train_dataset,
                              batch_size=512,
                              shuffle=True,
                              drop_last = True,
                              num_workers=8,
                              pin_memory = False)

test_dataset = TensorDataset(X_test,Y_test)
test_dataloader = DataLoader(test_dataset,
                              batch_size=512,
                              shuffle=True,
                              drop_last = True,
                              num_workers=8,
                              pin_memory = False)

model_VAE = torch.load("../symmetry_2/VAE.pt")

In [5]:
model_VAE.eval()
train_Z = model_VAE.fc2(model_VAE.fc_mu(model_VAE.encoder(X_train.to(device)))).cpu().detach()
test_Z = model_VAE.fc2(model_VAE.fc_mu(model_VAE.encoder(X_test.to(device)))).cpu().detach()

train_dataset_Z = TensorDataset(train_Z,Y_train)
train_dataloader_Z = DataLoader(train_dataset_Z,
                              batch_size=512,
                              shuffle=True,
                              drop_last = True,
                              num_workers=8,
                              pin_memory = False)

test_dataset_Z = TensorDataset(test_Z,Y_test)
test_dataloader_Z = DataLoader(test_dataset_Z,
                              batch_size=512,
                              shuffle=True,
                              drop_last = True,
                              num_workers=8,
                              pin_memory = False)

In [6]:
model_fe = models.MLP(feature_size=LATENT_DIM).to(device)
model_fd = models.MLP(feature_size=LATENT_DIM).to(device)
# model_fo = models.LatentOracle().to(device)
model_fo = models.LatentDescriminator().to(device) # Since we ony have two elements

model_symmetry = models.GroupLatent(num_features=LATENT_DIM,num_generators=NUM_GENERATORS).to(device)

optimiser_fe = torch.optim.Adam(model_fe.parameters(), lr = 1e-3)
optimiser_fd = torch.optim.Adam(model_fd.parameters(), lr = 1e-3)
optimiser_fo = torch.optim.Adam(model_fo.parameters(), lr = 1e-3)
optimiser_symmetry = torch.optim.Adam(model_symmetry.parameters(), lr = 1e-3)


criterion_mse = nn.MSELoss()
criterion_BCE = nn.BCEWithLogitsLoss()



In [7]:
loss_S_closure = []
loss_S_orth = []
loss_S_collapse = []

loss_space = []
loss_oracle = []
for i in range(300):
    
    loss_S_closure_ = 0
    loss_S_orth_ = 0
    loss_S_collapse_ = 0

    loss_space_ = 0
    loss_oracle_ = 0
    
    for Z,M in tqdm(train_dataloader):
        Z = Z.to(device)
        M = M.to(device)
        
        optimiser_fd.zero_grad()
        optimiser_fe.zero_grad()
        optimiser_fo.zero_grad()
        optimiser_symmetry.zero_grad()

        theta = [(2*torch.rand(Z.shape[0],device = device) - 1) for i in range(NUM_GENERATORS)]  #Sampling

        P = model_fe(Z)
        P_S = model_symmetry(theta = theta, x = P)
        Z_S = model_fd(P_S)
        m = model_fo(Z)
        m_S = model_fo(Z_S)
        Z_P = model_fd(P)

        loss1 = criterion_BCE(m_S,torch.sigmoid(m))
        loss2 = model_symmetry.orthogonal_loss()
        loss3 = model_symmetry.collapse_loss()
        loss5 = criterion_mse(Z,Z_P)
        loss6 = criterion_BCE(m.squeeze(),M)
        
        loss_S = loss1 + loss2 + loss3
        loss_Ae = loss5
        loss_O = loss6
        
        loss_S.backward(retain_graph=True)
        loss_Ae.backward()
        loss_O.backward()
        
        
        
        optimiser_fd.step()
        optimiser_fe.step()
        optimiser_fo.step()
        optimiser_symmetry.step()
    
        

  0%|          | 0/227 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (14336x28 and 64x64)

In [11]:
Z,M  =  next(iter(train_dataloader_Z))
Z = Z.to(device)
M = M.to(device)

optimiser_fd.zero_grad()
optimiser_fe.zero_grad()
optimiser_fo.zero_grad()
optimiser_symmetry.zero_grad()

theta = [(2*torch.rand(Z.shape[0],device = device) - 1) for i in range(NUM_GENERATORS)]  #Sampling

P = model_fe(Z)
P_S = model_symmetry(theta = theta, x = P)
Z_S = model_fd(P_S)
m = model_fo(Z)
m_S = model_fo(Z_S)
Z_P = model_fd(P)

loss1 = criterion_BCE(m_S,torch.sigmoid(m))
loss2 = model_symmetry.orthogonal_loss()
loss3 = model_symmetry.collapse_loss()
loss5 = criterion_mse(Z,Z_P)
loss6 = criterion_BCE(m.squeeze(),M)

loss_S = loss1 + loss2 + loss3
loss_Ae = loss5
loss_O = loss6

loss_S.backward(retain_graph=True)
loss_Ae.backward()
loss_O.backward()

optimiser_fd.step()
optimiser_fe.step()
optimiser_fo.step()
optimiser_symmetry.step()


In [12]:
loss1.to(loss1.device)

tensor(0.7070, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [None]:
Z.shape

In [None]:
torch.sigmoid(m)

In [None]:
model_symmetry.group[0].algebra.device

In [11]:
M.shape

torch.Size([512])