In [68]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import distributions as dist
from torch.utils.data import DataLoader, TensorDataset
from torch import optim

from torchvision.datasets import MNIST, FashionMNIST
from torchvision import transforms as tr
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [217]:
iris = datasets.load_iris()
X = iris['data']
y = iris['target']
J = (np.eye(X.shape[1])[np.random.choice(X.shape[1], X.shape[0])] ==0).astype(int)

X = X * J # masked elements are zeros
X.shape, J.shape, y.shape, set(y)

((150, 4), (150, 4), (150,), {0, 1, 2})

In [42]:
X_train, X_test, J_train, J_test, y_train, y_test = train_test_split(X, J, y, test_size=0.33, random_state=42)

In [43]:
ds_train = TensorDataset(
    torch.Tensor(X_train), 
    torch.Tensor(J_train),
    torch.Tensor(y_train).long()
)

ds_test = TensorDataset(
    torch.Tensor(X_test), 
    torch.Tensor(J_test),
    torch.Tensor(y_test).long()
)

batch_size=16
dl_train = DataLoader(ds_train, batch_size, shuffle=True)
dl_test = DataLoader(ds_test, batch_size, shuffle=True)

In [214]:
class Reshaper(nn.Module):
    def __init__(self, out_size):
        super().__init__()

        self.out_size = out_size
        
    def forward(self, x):
        return x.view(*self.out_size)

In [280]:
class IrisInpainter(
    nn.Module
):
    def __init__(self, n_mixes: int = 3, in_size: int = 4):
        super().__init__()

        
        self.extractor = nn.Sequential(
            nn.Linear(in_size * 2, 10),
            nn.ReLU(),
            nn.Linear(10, 20),
            nn.ReLU(),
        )
        
        self.a_extractor = nn.Sequential(
            nn.Linear(20, in_size * n_mixes),
            Reshaper((-1, n_mixes, in_size))
        )
        self.m_extractor = nn.Sequential(
            nn.Linear(20, n_mixes * in_size),
            Reshaper((-1, n_mixes, in_size))

        )
        
        self.d_extractor = nn.Sequential(
            nn.Linear(20, n_mixes * in_size),
            Reshaper((-1, n_mixes, in_size))

        )
        
        self.p_extractor = nn.Sequential(
            nn.Linear(20, n_mixes),
            nn.Softmax()
        )

    def forward(self, X, J):
        X_J = torch.cat([X, J], dim=1)
        features = self.extractor(X_J)
        m = self.m_extractor(features)
        d = self.d_extractor(features)
        p = self.p_extractor(features)
        a = self.a_extractor(features)
        
        return  p, m, a, d

In [287]:

def nll_loss(X, J, P, M, A, D) -> torch.autograd.Variable:
    zipped = zip(X, J, P, M, A, D)

    for i, (x, j, p, m, a, d) in enumerate(zipped):
        print(i)
        print("x", x.shape)
        print("j", j)
        print("p", p.shape)
        print("m", m.shape)
        print("a", a.shape)
        print("d", d.shape, )
        for (p_i, m_i, d_i, a_i) in zip(p, m, d, a):
            a_i_u = a_i.unsqueeze(0)
            cov = (a_i_u.T @ a_i_u) + torch.diag(d_i ** 2)
            mvn_d = dist.MultivariateNormal(m_i, cov)
            l = mvn_d.log_prob(x) 
            # add log p instead of multiplying prob by p
            # but since i'm summing all of this
            
        
        break

In [288]:
inpainter = IrisInpainter()
for x, j, y in dl_train:
    p, m, a, d = inpainter(x, j)
    nll_loss(x, j, p, m, a, d)
    
    break

0
x torch.Size([4])
j tensor([0., 1., 1., 1.])
p torch.Size([3])
m torch.Size([3, 4])
a torch.Size([3, 4])
d torch.Size([3, 4])
tensor(-332.8036, grad_fn=<SubBackward0>)
tensor(-6484.9761, grad_fn=<SubBackward0>)
tensor(-204.3279, grad_fn=<SubBackward0>)
