In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from tqdm.auto import tqdm

import torch
import torch.nn as nn

from torch.utils.data import DataLoader, Dataset
import torchvision

from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam, SGD

import numpy as np

In [3]:
from model import NonLinearICA

# CelebA

In [4]:
class CelebA_dataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        self.u_star = data.attr[torch.randperm(data.attr.shape[0])]
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        image, (attr, identity)  = self.data[idx]
        image = torch.tensor(np.array(image), dtype=torch.float32).moveaxis(-1, 0)
        u = attr
        u_star = self.u_star[idx]
        return image, u, u_star

In [5]:
train = torchvision.datasets.CelebA("data/CelebA", split="train", target_type=["attr", "identity"])
valid = torchvision.datasets.CelebA("data/CelebA", split="valid", target_type=["attr", "identity"])
test = torchvision.datasets.CelebA("data/CelebA", split="test", target_type=["attr", "identity"])

In [6]:
batch_size=32

In [7]:
train_dataloader = DataLoader(CelebA_dataset(train), batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(CelebA_dataset(valid), batch_size=batch_size)
test_dataloader = DataLoader(CelebA_dataset(test), batch_size=batch_size)

In [8]:
hidden_dimension = 25
model = NonLinearICA(3, hidden_dimension, dropout=0.2, data_type='CelebA')

In [9]:
epochs=35

optimizer = SGD(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()
log_dir = "logs/CelebA/"
writer = SummaryWriter(log_dir=log_dir)

device="cuda:9"

model = model.to(device)

In [10]:
for epoch in tqdm(range(epochs), desc="Training on epoch"):
    model.train()
    for i, batch in enumerate(train_dataloader):
        x, u, u_star = batch
        labels = torch.randint(0, 2, size=(x.shape[0], 1), dtype=x.dtype) # choose random u or u_star -> labels
        u = torch.where(labels.bool(), u, u_star) # get u or u_star depending on label
        
        x = x.to(device)
        u = u.to(device)
        labels = labels.to(device)
        
        output = model(x, u)
        loss = criterion(output, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        pred = (torch.sigmoid(output) > 0.5).float()
        correct = (pred == labels).float().sum()
        writer.add_scalar("Train/loss", loss.cpu().item(), len(train_dataloader)*epoch + i)
        writer.add_scalar("Train/accuracy", correct/output.shape[0], len(train_dataloader)*epoch + i)
    model.eval()
    val_loss = 0.0
    val_correct = 0.0
    with torch.no_grad():
        for batch in test_dataloader:
            x, u, u_star = batch
            labels = torch.randint(0, 2, size=(x.shape[0], 1), dtype=x.dtype) # choose random u or u_star -> labels
            u = torch.where(labels.bool(), u, u_star) # get u or u_star depending on label

            x = x.to(device)
            u = u.to(device)
            labels = labels.to(device)
            
            output = model(x, u)
            loss = criterion(output, labels)
            val_loss += loss.cpu().item()
            
            pred = (torch.sigmoid(output) > 0.5).float()
            val_correct += (pred == labels).float().sum()
    writer.add_scalar("Test/loss", val_loss / len(test_dataloader), epoch)
    writer.add_scalar("Test/accuracy", val_correct / (len(test_dataloader)*batch_size), epoch)

Training on epoch:   0%|          | 0/35 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# save the model
torch.save(model, log_dir + f"/model_n{hidden_dimension}")

# MNIST

## Load data

In [5]:
class MNIST_dataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        self.u_star = data.targets[torch.randperm(len(data.targets))]
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        image, u  = self.data[idx]
        u_star = self.u_star[idx]
        image = torch.tensor(np.array(image), dtype=torch.float32).unsqueeze(0)
        u = nn.functional.one_hot(torch.tensor(u), len(self.data.classes))
        u_star = nn.functional.one_hot(u_star, len(self.data.classes))
        return image, u, u_star

In [6]:
train = torchvision.datasets.MNIST("data/MNIST", download=True, train=True)
test = torchvision.datasets.MNIST("data/MNIST", download=True, train=False)

In [7]:
batch_size=32
train_dataloader = DataLoader(MNIST_dataset(train), batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(MNIST_dataset(test), batch_size=batch_size)

## Train model

In [7]:
hidden_dimension = 12
model = NonLinearICA(1, hidden_dimension, dropout=0.2, data_type='MNIST')

In [8]:
epochs=25

optimizer = SGD(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()
log_dir = "logs/MNIST/"
writer = SummaryWriter(log_dir=log_dir)

device="cuda:9"

model = model.to(device)

In [9]:
for epoch in tqdm(range(epochs), desc="Training on epoch"):
    model.train()
    for i, batch in enumerate(train_dataloader):
        x, u, u_star = batch
        labels = torch.randint(0, 2, size=(x.shape[0], 1), dtype=x.dtype) # choose random u or u_star -> labels
        u = torch.where(labels.bool(), u, u_star) # get u or u_star depending on label
        
        x = x.to(device)
        u = u.to(device)
        labels = labels.to(device)
        
        output = model(x, u)
        loss = criterion(output, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        pred = (torch.sigmoid(output) > 0.5).float()
        correct = (pred == labels).float().sum()
        writer.add_scalar("Train/loss", loss.cpu().item(), len(train_dataloader)*epoch + i)
        writer.add_scalar("Train/accuracy", correct/output.shape[0], len(train_dataloader)*epoch + i)
    model.eval()
    val_loss = 0.0
    val_correct = 0.0
    with torch.no_grad():
        for batch in test_dataloader:
            x, u, u_star = batch
            labels = torch.randint(0, 2, size=(x.shape[0], 1), dtype=x.dtype) # choose random u or u_star -> labels
            u = torch.where(labels.bool(), u, u_star) # get u or u_star depending on label

            x = x.to(device)
            u = u.to(device)
            labels = labels.to(device)
            
            output = model(x, u)
            loss = criterion(output, labels)
            val_loss += loss.cpu().item()
            
            pred = (torch.sigmoid(output) > 0.5).float()
            val_correct += (pred == labels).float().sum()
    writer.add_scalar("Test/loss", val_loss / len(test_dataloader), epoch)
    writer.add_scalar("Test/accuracy", val_correct / (len(test_dataloader)*batch_size), epoch)

Training on epoch:   0%|          | 0/25 [00:00<?, ?it/s]

In [10]:
# save the model
torch.save(model, log_dir + f"/model_n{hidden_dimension}")

## Investigate model features

In [4]:
log_dir = "logs/MNIST/"
hidden_dim = 12
model = torch.load("logs/MNIST/" + f"model_n{hidden_dim}")

In [8]:
model.forward_h()

NonLinearICA(
  (h): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Dropout(p=0.2, inplace=False)
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): Dropout(p=0.2, inplace=False)
    (5): ReLU()
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): Dropout(p=0.2, inplace=False)
    (8): ReLU()
    (9): Conv2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): Linear(in_features=1568, out_features=256, bias=True)
    (12): ReLU()
    (13): Dropout(p=0.2, inplace=False)
    (14): Linear(in_features=256, out_features=12, bias=True)
  )
  (psi): ModuleList(
    (0): PsiICA(
      (m): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.25, inplace=False)
        (2): Linear(in_features=11, out_features=128, bias=True)
        (3): ReLU()
        (4): Dropout(p=0.25, inplace=False)
        (5): Linear

$h_i$ represent a single dimension (unmixed source, disentangled feature).

The total model output:

$$r(x, u) = \sum\limits_{i=1}^n \psi_i(h_i(\mathbf{x}), \mathbf{u})$$

The $\mathbf{u}$ should be "one-hot" encoded.

To **train** the model, use binary cross-entropy, check the quality (accuracy, loss) on the validation dataset.

To **compare** the quality with InfoGAN, take correlations of values of $h_i$ with classes. Take the representation on the test dataset, look at the different values of $h_i$ within one class, try to find meaningful explanations.

Compare the correlation between sources ($h_i$, features).

So, we have the same weigths for $h_i$, and n outputs (Simple conv net).

Then, for $\phi_i$ we have $n$ different fully-connected networks, where we have the $\mathbb{u}$ vector dimension + 1.

In [48]:
criterion = nn.BCEWithLogitsLoss()

In [53]:
criterion(model(x, u), torch.randint(0, 2, size=(x.shape[0], 1), dtype=torch.float32))

tensor(0.7090, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)