In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch import Tensor
import torch.nn.functional as F
from torch.optim import AdamW, Optimizer
from torch.utils.data import DataLoader, Dataset
from typing import Callable, Optional, Tuple

from torchvision import transforms
from tqdm import tqdm

from sklearn import cluster, datasets
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal

import matplotlib
# matplotlib.rcParams.update({'font.size': 1})
from matplotlib import rcParams, rcParamsDefault
rcParams.update(rcParamsDefault)
import sys

sys.path.append('../eval')

from metrics import slot_mean_corr_coef, r2_score

In [None]:
# Synthetic dataset: `n_dp` point sets, each a 2-D mixture of `ncenters` isotropic
# Gaussian clusters whose centres are drawn without replacement from `locs`.
SEED = 0
np.random.seed(SEED)  # make the generated dataset reproducible across kernel restarts

n_dp = 1000          # number of point sets (datapoints)
nsamples = 128       # points per set
ncenters = 3         # clusters per set
nfeatures = 2
cluster_var = 0.5    # isotropic variance of every cluster
locs = [[0, 0], [10, 10], [-10, -10], [-10, 10], [10, -10]]
colors = ['green', 'blue', 'orange', 'purple', 'cyan', 'brown', 'pink', 'gray', 'olive', 'red']

locs = np.array(locs)
assert ncenters <= len(locs), "cannot draw more distinct centres than `locs` provides"

# Split the points as evenly as possible; the last cluster absorbs the remainder.
# (Works for any `ncenters`, not only 3.)
per_center = nsamples // ncenters
nsamples = [per_center] * (ncenters - 1) + [nsamples - per_center * (ncenters - 1)]

varied = []; cluster_idxs = []
idxs = np.arange(len(locs))
for _ in range(n_dp):
    np.random.shuffle(idxs)
    cidx = idxs[:ncenters]

    # One (nsamples[ei], nfeatures) block per chosen centre, in the same order as cidx.
    chunks = [
        np.random.multivariate_normal(mean=locs[i],
                                      cov=cluster_var * np.eye(nfeatures),
                                      size=nsamples[ei])
        for ei, i in enumerate(cidx)
    ]
    varied.append(np.concatenate(chunks))
    # Ground-truth centre index for every point, aligned with the rows above.
    cluster_idxs.append(np.repeat(cidx, nsamples))

varied = np.array(varied)              # (n_dp, sum(nsamples), nfeatures)
cluster_idxs = np.array(cluster_idxs)  # (n_dp, sum(nsamples))

In [None]:
# Visual sanity check: the first (up to) 100 point sets, coloured by ground-truth cluster.
n_show = min(100, len(varied))
fig, axes = plt.subplots(10, 10, figsize=(50, 50))
for i, ax in enumerate(axes.flat[:n_show]):
    for cidx in np.unique(cluster_idxs[i]):
        idx = np.where(cluster_idxs[i] == cidx)[0]
        ax.scatter(varied[i][idx, 0], varied[i][idx, 1], color=colors[cidx])
    # Axis limits are the same for every cluster; set them once per panel.
    ax.set_xlim(-15, 15)
    ax.set_ylim(-15, 15)

fig.savefig("data_samples.pdf", bbox_inches='tight')
plt.show()

In [None]:
class SlotAttentionFixedEM(nn.Module):
    """Diagonal-covariance Gaussian-mixture EM over a set of input vectors.

    Initial slot means are sampled from N(loc, exp(logscale)**2) with learnable
    `loc`/`logscale`; `routing_iters` EM iterations then refine the means
    (slots), per-slot standard deviations (sigma) and mixing weights (pi).

    Args:
        input_dim: unused; kept for interface compatibility.
        num_slots: default number of mixture components K.
        slot_dim: dimensionality of each slot; must equal the input feature dim
            because slots are computed as weighted means of the inputs.
        routing_iters: number of EM iterations run in `forward` (must be >= 1).
        hidden_dim: unused; kept for interface compatibility.
    """

    def __init__(
        self,
        input_dim: int = 2,
        num_slots: int = 7,
        slot_dim: int = 2,
        routing_iters: int = 3,
        hidden_dim: int = 2,
    ):
        super().__init__()
        self.num_slots = num_slots
        self.slot_dim = slot_dim
        self.routing_iters = routing_iters

        # Learnable mean and log-std of the slot initialisation distribution.
        self.loc = nn.Parameter(torch.zeros(1, self.slot_dim))
        self.logscale = nn.Parameter(torch.zeros(1, self.slot_dim))
        # Floor used for log/division stability and added to responsibilities.
        self.eps = 1e-6

    def forward(self, x: Tensor, num_slots: Optional[int] = None):
        """Run EM on `x` of shape (b, N, d).

        Args:
            x: inputs, (b, N, d) with d == slot_dim.
            num_slots: overrides `self.num_slots` for this call when given.

        Returns:
            slots (b, K, d): component means.
            sigma (b, K, d): component standard deviations (+eps).
            attn  (b, K, N): responsibilities (softmax over K, +eps).
            pi    (b, K, 1): mixing weights.

        Raises:
            ValueError: if `routing_iters` < 1, since no responsibilities exist.
        """
        import math

        if self.routing_iters < 1:
            raise ValueError(
                f"routing_iters must be >= 1, got {self.routing_iters}"
            )

        b, N, _ = x.shape
        K = num_slots if num_slots is not None else self.num_slots

        # (b, K, d) initial means drawn from the learnable init distribution.
        slots = self.loc + self.logscale.exp() * torch.randn(
            b, K, self.slot_dim, device=x.device, dtype=x.dtype
        )
        pi = torch.ones(b, K, 1, device=x.device, dtype=x.dtype) / K
        # (1, K, d): broadcasts over the batch until the first M-step.
        sigma = self.logscale.exp().unsqueeze(1).repeat(1, K, 1)

        # Per-dimension Gaussian normaliser, -0.5 * log(2*pi); loop invariant.
        log_norm = -0.5 * math.log(2 * math.pi)

        for _ in range(self.routing_iters):
            # E-step: log N(x | slot, diag(sigma^2)) + log pi, per (b, K, N).
            scale = torch.clamp(sigma.unsqueeze(2), min=self.eps)      # (b, K, 1, d)
            diff = x.unsqueeze(1) - slots.unsqueeze(2)                  # (b, K, N, d)
            log_gauss = (log_norm - torch.log(scale) - 0.5 * diff ** 2 / scale ** 2).sum(dim=-1)
            log_probs = torch.log(torch.clamp(pi, min=self.eps)) + log_gauss  # (b, K, N)

            # +eps keeps every component's effective count strictly positive.
            attn = log_probs.softmax(dim=1) + self.eps                  # (b, K, N)

            # M-step: weighted counts, means and standard deviations.
            Nk = attn.sum(dim=2, keepdim=True)                          # (b, K, 1)
            pi = Nk / N
            weights = attn.unsqueeze(-1)                                # (b, K, N, 1)
            slots = (weights * x.unsqueeze(1)).sum(dim=2) / Nk          # (b, K, d)
            var = (weights * (x.unsqueeze(1) - slots.unsqueeze(2)) ** 2).sum(dim=2) / Nk
            sigma = torch.sqrt(var) + self.eps                          # (b, K, d)

        return slots, sigma, attn, pi


class PositionEmbed(nn.Module):
    """Add a learned projection of a fixed 1-D ramp to a sequence of features.

    The ramp is [t, 1 - t] for t evenly spaced in [0, 1] over `resolution`
    positions; a Linear(2, out_channels) maps it to the feature space.

    Args:
        out_channels: feature dimension of the inputs this is added to.
        resolution: number of positions N the inputs must have.
    """

    def __init__(self, out_channels: int = 2, resolution: int = 1500):
        super().__init__()
        lp = torch.linspace(0.0, 1.0, steps=resolution).unsqueeze(-1)
        # (1, N, 2). Registered as a non-persistent buffer so it follows
        # `.to(device)` / `.cuda()` without changing the state_dict keys.
        self.register_buffer(
            "grid", torch.cat([lp, 1.0 - lp], dim=-1).unsqueeze(0), persistent=False
        )
        self.mlp = nn.Linear(2, out_channels)  # 2 inputs: (t, 1 - t)

    def forward(self, x: Tensor):
        """Return x + mlp(grid).

        x is expected to be (batch, N, out_channels) with N == resolution so the
        (1, N, out_channels) embedding broadcasts over the batch.
        """
        grid = self.mlp(self.grid)  # (1, N, out_channels)
        return x + grid

    def build_grid(self, resolution: Tuple[int, int]) -> Tensor:
        """Build a 2-D coordinate grid of shape (1, H, W, 4): (x, y, 1 - x, 1 - y)."""
        xy = [torch.linspace(0.0, 1.0, steps=r) for r in resolution]
        xx, yy = torch.meshgrid(xy, indexing="ij")
        grid = torch.stack([xx, yy], dim=-1)
        grid = grid.unsqueeze(0)
        return torch.cat([grid, 1.0 - grid], dim=-1)

class SlotAutoencoder(nn.Module):
    """Encoder -> Gaussian-mixture slot attention -> decoder, for point sets.

    Args:
        num_inputs: points per set; unused by the active architecture, kept
            for interface compatibility.
        emb_dim: feature dimension of inputs / encodings.
        num_slots: number of mixture components (slots).
        slot_dim: dimensionality of each slot.
        routing_iters: EM iterations run by the slot-attention module.
        additive_decoder: if True, decode each slot separately and combine the
            per-slot reconstructions with softmax masks; otherwise decode the
            mean over slots.
        slot_type: unused; kept for interface compatibility.
    """

    def __init__(
        self,
        num_inputs: int = 1500,
        emb_dim: int = 2,
        num_slots: int = 10,
        slot_dim: int = 2,
        routing_iters: int = 1,
        additive_decoder: bool = False,
        slot_type: Optional[str] = "fixedEM",
    ):
        super().__init__()

        # Single linear map, kept inside nn.Sequential so parameter names
        # ("encoder.0.weight", ...) are unchanged.
        self.encoder = nn.Sequential(nn.Linear(emb_dim, emb_dim))

        self.slot_attention = SlotAttentionFixedEM(
            input_dim=emb_dim,
            num_slots=num_slots,
            slot_dim=slot_dim,
            routing_iters=routing_iters,
            hidden_dim=emb_dim,
        )

        self.additive_decoder = additive_decoder
        # The additive decoder emits one extra channel per slot: a mask logit.
        out_dim = emb_dim + 1 if additive_decoder else emb_dim
        self.decoder = nn.Sequential(nn.Linear(emb_dim, out_dim))

    def approximate_posterior(self, x: Tensor, use_encoder: bool = True, routing_iters: int = 10):
        """Fit the mixture to each set in `x` and pool all components into one mixture.

        Temporarily sets the EM iteration count to `routing_iters`; the previous
        value is restored even if the forward pass raises. Mixing weights are
        divided by the batch size so the pooled weights sum to ~1.

        Args:
            x: (b, n, d) batch of point sets.
            use_encoder: forwarded to `forward`; pass False for models trained
                without the encoder.
            routing_iters: EM iterations used for this call.

        Returns:
            joint_pi (b*K, 1), joint_slots (b*K, slot_dim),
            joint_sigma (b*K, slot_dim), encodings flattened to (b*n, emb_dim).
        """
        prev_routing_iters = self.slot_attention.routing_iters
        self.slot_attention.routing_iters = routing_iters
        try:
            _, _, _, slots, sigma, _, pi, encodings = self.forward(x, use_encoder=use_encoder)
        finally:
            self.slot_attention.routing_iters = prev_routing_iters

        joint_pi = pi.flatten(0, 1) / slots.shape[0]
        joint_sigma = sigma.flatten(0, 1)
        joint_slots = slots.flatten(0, 1)
        return joint_pi, joint_slots, joint_sigma, encodings.flatten(0, 1)

    def forward(self, x: Tensor, use_encoder: bool = True):
        """Encode, fit the mixture, sample per-point slot features and decode.

        Args:
            x: (b, n, d) batch of point sets.
            use_encoder: if False, the mixture is fit to `x` directly.

        Returns:
            recon_combined (b, n, d); recons (b, K, n, d) or None;
            masks (b, K, n, 1) or None; slots, sigma (b, K, slot_dim);
            attn (b, K, n); pi (b, K, 1); encodings (b, n, emb_dim).
        """
        b, n, d = x.shape

        encodings = self.encoder(x) if use_encoder else x

        slots, sigma, attn, pi = self.slot_attention(encodings)

        # Reparameterised sample per (point, slot). `sigma` is already a standard
        # deviation (SlotAttentionFixedEM takes sqrt in its M-step), so the noise
        # is scaled by sigma, not sigma**0.5.
        noise = torch.randn(
            b, n, slots.shape[1], slots.shape[2], device=slots.device, dtype=slots.dtype
        )
        x = slots.unsqueeze(1) + sigma.unsqueeze(1) * noise  # (b, n, K, slot_dim)
        x = x * pi.unsqueeze(1)                               # weight by mixing prob

        if self.additive_decoder:
            # nn.Linear acts on the last dim: (b, n, K, d + 1).
            x = self.decoder(x)
            # Reorder axes explicitly to (b, K, n, d + 1). A plain .view() here
            # would reinterpret memory and mix up points and slots.
            x = x.permute(0, 2, 1, 3)

            # (b, K, n, d), (b, K, n, 1)
            recons, masks = torch.split(x, [d, 1], dim=3)
            masks = masks.softmax(dim=1)  # normalise over slots

            recon_combined = torch.sum(recons * masks, dim=1)  # (b, n, d)
        else:
            recon_combined = self.decoder(x.mean(2))  # average over slots
            masks = None; recons = None

        return recon_combined, recons, masks, slots, sigma, attn, pi, encodings

In [None]:
def _gaussian_log_prob(x, loc, scale):
    """Element-wise log-density of N(loc, scale**2); broadcasts over all arguments."""
    return (
        -0.5 * torch.tensor(2 * torch.pi, device=x.device).log()
        - torch.log(scale)
        - 0.5 * (x - loc) ** 2 / scale ** 2
    )


def _plot_mixture_contours(samples, pi_, slots_, sigma_, n_grid=100):
    """Overlay contours of a diagonal-Gaussian mixture density on the current axes.

    samples: (M, 2) points, used only to choose the plotting range.
    pi_: (K, 1) mixing weights; slots_, sigma_: (K, 2) means and stds.
    """
    grid_x, grid_y = np.meshgrid(
        np.linspace(samples[:, 0].min(), samples[:, 0].max(), n_grid),
        np.linspace(samples[:, 1].min(), samples[:, 1].max(), n_grid),
    )
    coord = torch.from_numpy(np.array([grid_x.ravel(), grid_y.ravel()]).T).unsqueeze(1)  # (G, 1, 2)
    log_probs = torch.log(pi_) + _gaussian_log_prob(coord, slots_, sigma_).sum(dim=-1, keepdim=True)  # (G, K, 1)
    likelihood = log_probs.exp().sum(dim=1).reshape(grid_x.shape)
    plt.contour(grid_x, grid_y, likelihood, levels=10)


def _aggregate_posterior(model, x, use_encoder=True, routing_iters=10):
    """Pool the per-set mixtures fitted to `x` into a single mixture.

    Same computation as `SlotAutoencoder.approximate_posterior`, but forwards
    `use_encoder`. A model trained with use_encoder=False never sends gradients
    through its encoder, so it must not be evaluated through it. Runs without
    gradient tracking and restores the model's routing-iteration count.

    Returns (pi (b*K, 1), slots (b*K, 2), sigma (b*K, 2), encodings (b*n, 2)).
    """
    prev_iters = model.slot_attention.routing_iters
    model.slot_attention.routing_iters = routing_iters
    try:
        with torch.no_grad():
            _, _, _, slots, sigma, _, pi, encodings = model(x, use_encoder=use_encoder)
    finally:
        model.slot_attention.routing_iters = prev_iters
    joint_pi = pi.flatten(0, 1) / slots.shape[0]
    return joint_pi, slots.flatten(0, 1), sigma.flatten(0, 1), encodings.flatten(0, 1)


def run(runid = 1, additive_decoder = True, use_encoder = True, n_routing_iters = None, n_posterior = 50):
    """Train one SlotAutoencoder on the global `data`, plot its fits, return slot means.

    Reads globals from earlier cells: data, cluster_idxs, colors, nepochs, lr,
    wd, num_slots, num_inputs.

    Args:
        runid: tag used in the saved figure filenames.
        additive_decoder: passed to SlotAutoencoder.
        use_encoder: passed to every forward pass, including the posterior fit.
        n_routing_iters: EM iterations during training. None keeps the original
            behaviour of using `num_inputs`.
            NOTE(review): the config cell defines `routing_iters = 3`, which is
            otherwise unused; confirm which value was intended.
        n_posterior: number of leading datapoints pooled into the aggregate posterior.

    Returns:
        (n_posterior * num_slots, 2) CPU tensor of aggregate-posterior slot means.
    """
    if n_routing_iters is None:
        n_routing_iters = num_inputs

    model = SlotAutoencoder(num_slots = num_slots,
                            num_inputs = num_inputs,
                            routing_iters = n_routing_iters,
                            additive_decoder = additive_decoder)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=wd)

    # ---- training (full batch; row order is shuffled every epoch) ----------
    pbar = tqdm(range(nepochs))
    data_idx = np.arange(data.shape[0])
    for _ in pbar:
        np.random.shuffle(data_idx)
        batch = data[data_idx]
        optimizer.zero_grad()

        recon, _, _, slots, sigma, attn, pi, encodings = model(batch, use_encoder = use_encoder)
        loss = F.mse_loss(recon, batch)

        loss.backward()
        optimizer.step()
        pbar.set_description("Training Loss:{:.4f}".format(loss.item()))

    # ---- per-datapoint mixture fits from the last training batch -----------
    plt.figure(figsize=(50, 50))
    for i in range(min(100, len(data_idx))):
        # Row i of the last (shuffled) batch is data[data_idx[i]], so its
        # ground-truth labels are cluster_idxs[data_idx[i]], not cluster_idxs[i].
        src = data_idx[i]
        labels = cluster_idxs[src]
        slots_ = slots[i].detach().cpu(); sigma_ = sigma[i].detach().cpu(); pi_ = pi[i].detach().cpu()
        samples = encodings[i].detach().cpu().view(-1, 2)

        plt.subplot(10, 10, i + 1)
        for cidx in np.unique(labels):
            idx = np.where(labels == cidx)[0]
            plt.scatter(samples[idx, 0], samples[idx, 1], color = colors[cidx])
        _plot_mixture_contours(samples, pi_, slots_, sigma_)

        plt.grid()
        plt.xlabel('Feature-1')
        plt.ylabel('Feature-2')
        plt.title(f'Datapoint-{src}')

    plt.tight_layout()
    plt.savefig(f"encoded_GMM_fit_run{runid}.pdf", bbox_inches='tight')
    plt.show()

    # ---- aggregate posterior over the first n_posterior datapoints ---------
    approx_pi, approx_slots, approx_sigma, encodings = _aggregate_posterior(
        model, data[:n_posterior], use_encoder=use_encoder
    )
    approx_slots = approx_slots.detach().cpu(); approx_sigma = approx_sigma.detach().cpu()
    approx_pi = approx_pi.detach().cpu(); approx_samples = encodings.detach().cpu()

    # encodings were flattened row-major over (datapoint, point), matching this.
    cluster_idx_flat = cluster_idxs[:n_posterior].flatten()

    plt.figure(figsize=(10, 10))
    for cidx in np.unique(cluster_idx_flat):
        idx = np.where(cluster_idx_flat == cidx)[0]
        plt.scatter(approx_samples[idx, 0], approx_samples[idx, 1], color = colors[cidx])
    _plot_mixture_contours(approx_samples, approx_pi, approx_slots, approx_sigma)
    plt.tight_layout()

    plt.grid()
    plt.xlabel('Feature-1')
    plt.ylabel('Feature-2')
    plt.title('Aggregate posterior')

    plt.savefig(f"Aggregate_posterior_run_{runid}.pdf", bbox_inches='tight')
    plt.show()
    return approx_slots

In [None]:
# Per-datapoint standardisation: every point set is shifted and scaled so each
# feature has zero mean and unit std across that set's points (dim 1).
data = torch.from_numpy(varied).type(torch.float32)
set_mean = data.mean(1, keepdim=True)
set_std = data.std(1, keepdim=True)
data = (data - set_mean) / set_std

# Quick visual check of the first standardised point set.
v_numpy = data[0].detach().cpu().numpy()
plt.clf()
plt.scatter(v_numpy[:, 0], v_numpy[:, 1])
plt.show()

# Training hyper-parameters (read as globals by `run`).
nepochs, lr, wd = 15, 0.10, 1e-6
num_slots = 3
num_inputs = data.shape[1]  # points per set
# NOTE(review): `routing_iters` is not read by `run`, which passes `num_inputs`
# as the EM iteration count; confirm which was intended.
routing_iters = 3

# No encoder, non-additive (mean-pooled) decoder

In [None]:
# Train independent runs without the encoder and with the mean-pooled
# (non-additive) decoder; each `run` returns its aggregate-posterior slot means.
# `run_id` (not `id`) avoids shadowing the `id` builtin for the whole kernel.
n_runs = 15
logs = [run(run_id, additive_decoder = False, use_encoder = False) for run_id in range(n_runs)]

In [None]:
# Seed-stability: compare each run's slot means with the next run's, after
# `slot_mean_corr_coef` matches run i+1's slots to run i's.
mccs, r2s = [], []
for reference, candidate in zip(logs[:-1], logs[1:]):
    xdata = reference.view(-1, num_slots, nfeatures)
    ydata = candidate.view(-1, num_slots, nfeatures)

    mcc_score, ordered_logs = slot_mean_corr_coef(xdata, ydata,
                                                  return_ordered = True,
                                                  affine_transformation = True)
    mccs.append(mcc_score.item())
    r2s.append(r2_score(xdata, ordered_logs))

# NOTE(review): only the 5 highest-MCC pairs are reported, so these means are
# optimistic relative to all pairs.
idxs = np.argsort(mccs)[::-1]
top = idxs[:5]
mccs = np.array(mccs)[top]
r2s = np.array(r2s)[top]
print(f'SMCC: ({np.mean(mccs)} +- {np.std(mccs)}), R2: ({np.mean(r2s)} +- {np.std(r2s)})')

# No encoder, additive (per-slot, mask-weighted) decoder

In [None]:
# Train independent runs without the encoder and with the additive decoder;
# each `run` returns its aggregate-posterior slot means.
# `run_id` (not `id`) avoids shadowing the `id` builtin for the whole kernel.
n_runs = 15
logs = [run(run_id, additive_decoder = True, use_encoder = False) for run_id in range(n_runs)]

In [None]:
# Seed-stability: compare each run's slot means with the next run's, after
# `slot_mean_corr_coef` matches run i+1's slots to run i's.
mccs, r2s = [], []
for reference, candidate in zip(logs[:-1], logs[1:]):
    xdata = reference.view(-1, num_slots, nfeatures)
    ydata = candidate.view(-1, num_slots, nfeatures)

    mcc_score, ordered_logs = slot_mean_corr_coef(xdata, ydata,
                                                  return_ordered = True,
                                                  affine_transformation = True)
    mccs.append(mcc_score.item())
    r2s.append(r2_score(xdata, ordered_logs))

# NOTE(review): only the 5 highest-MCC pairs are reported, so these means are
# optimistic relative to all pairs.
idxs = np.argsort(mccs)[::-1]
top = idxs[:5]
mccs = np.array(mccs)[top]
r2s = np.array(r2s)[top]
print(f'SMCC: ({np.mean(mccs)} +- {np.std(mccs)}), R2: ({np.mean(r2s)} +- {np.std(r2s)})')

# Linear encoder, non-additive (mean-pooled) decoder

In [None]:
# Train independent runs with the linear encoder and the mean-pooled
# (non-additive) decoder; each `run` returns its aggregate-posterior slot means.
# `run_id` (not `id`) avoids shadowing the `id` builtin for the whole kernel.
n_runs = 15
logs = [run(run_id, additive_decoder = False, use_encoder = True) for run_id in range(n_runs)]

In [None]:
# Seed-stability: compare each run's slot means with the next run's, after
# `slot_mean_corr_coef` matches run i+1's slots to run i's.
mccs, r2s = [], []
for reference, candidate in zip(logs[:-1], logs[1:]):
    xdata = reference.view(-1, num_slots, nfeatures)
    ydata = candidate.view(-1, num_slots, nfeatures)

    mcc_score, ordered_logs = slot_mean_corr_coef(xdata, ydata,
                                                  return_ordered = True,
                                                  affine_transformation = True)
    mccs.append(mcc_score.item())
    r2s.append(r2_score(xdata, ordered_logs))

# NOTE(review): only the 5 highest-MCC pairs are reported, so these means are
# optimistic relative to all pairs.
idxs = np.argsort(mccs)[::-1]
top = idxs[:5]
mccs = np.array(mccs)[top]
r2s = np.array(r2s)[top]
print(f'SMCC: ({np.mean(mccs)} +- {np.std(mccs)}), R2: ({np.mean(r2s)} +- {np.std(r2s)})')

# Linear encoder, additive (per-slot, mask-weighted) decoder

In [None]:
# Train independent runs with the linear encoder and the additive decoder;
# each `run` returns its aggregate-posterior slot means.
# `run_id` (not `id`) avoids shadowing the `id` builtin for the whole kernel.
n_runs = 15
logs = [run(run_id, additive_decoder = True, use_encoder = True) for run_id in range(n_runs)]

In [None]:
# Seed-stability: compare each run's slot means with the next run's, after
# `slot_mean_corr_coef` matches run i+1's slots to run i's.
mccs, r2s = [], []
for reference, candidate in zip(logs[:-1], logs[1:]):
    xdata = reference.view(-1, num_slots, nfeatures)
    ydata = candidate.view(-1, num_slots, nfeatures)

    mcc_score, ordered_logs = slot_mean_corr_coef(xdata, ydata,
                                                  return_ordered = True,
                                                  affine_transformation = True)
    mccs.append(mcc_score.item())
    r2s.append(r2_score(xdata, ordered_logs))

# NOTE(review): only the 5 highest-MCC pairs are reported, so these means are
# optimistic relative to all pairs.
idxs = np.argsort(mccs)[::-1]
top = idxs[:5]
mccs = np.array(mccs)[top]
r2s = np.array(r2s)[top]
print(f'SMCC: ({np.mean(mccs)} +- {np.std(mccs)}), R2: ({np.mean(r2s)} +- {np.std(r2s)})')