# Invariant Slot Attention

**Goal:** I have had the intuition for a while that I want to be able to encode "circleness" into the slot representations that we're learning.

This idea from the SA follow-up paper is not _exactly_ the same as this, but I think it's going in this direction!

**Other optimization tricks included in this paper:**
- Cosine decay (instead of exponential decay)
- Use ResNet-34 as the image feature extractor model
    * They did modify the base block of this model to have stride 1 instead of 3
- They also divide the positional embedding by a constant $\delta$ (they set $\delta = 5$).
   

In [None]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib as mlp
from mpl_toolkits.axes_grid1 import make_axes_locatable

import json, yaml, os
os.sys.path.append('../code')

from plotting import plot_kslots, plot_kslots_iters
from data import make_batch
from model import SoftPositionalEmbed, build_grid
from torch.nn import init
from train import hungarian_matching

import torch
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

In [None]:
# Prefer the third GPU when the machine actually has one; otherwise fall back
# to the CPU so the notebook still runs on smaller / CPU-only machines.
device = 'cuda:2' if torch.cuda.device_count() > 2 else 'cpu'

In [None]:
# Hyperparameters forwarded as keyword arguments to InvariantSlotAttention.
hps = dict(
    hidden_dim=16,
    k_slots=3,
    query_dim=128,
    pixel_mult=0.2,
    device=device,
)

**How was the data generator initialized?**
- $x,y \sim \text{Unif}(-0.5, 0.5)$
- $r \sim \text{Unif}(0.01, 0.05)$

In [None]:
from copy import copy

# Lower / upper bound of the uniform range the initial slot scale is drawn from
stdlow = 0.01
stdhigh = 0.05

In [None]:
class InvariantSlotAttention(torch.nn.Module):
    def __init__(self, 
                 resolution=(32,32),
                 xlow=-0.5,
                 xhigh=0.5,
                 k_slots=3, 
                 num_conv_layers=3,
                 hidden_dim=32, 
                 final_cnn_relu=False,
                 query_dim=32, 
                 n_iter=2,
                 pixel_mult=1,
                 device='cpu',
                 stdlow=0.01,
                 stdhigh=0.05,
                 eps=1e-8,
                 ):
        '''
        Slot attention encoder block with a per-slot relative positional embedding.

        Each slot carries a reference frame (x, y, scale). The pixel grid is
        expressed relative to that frame before being embedded and added to
        the keys and values.

        Inputs:
        - resolution: (height, width) of the pixel grid
        - xlow, xhigh: coordinate range spanned by the grid along each axis
            (also the range the initial slot positions are drawn from)
        - k_slots (default 3): number of slots
        - num_conv_layers: # of convolutional layers in the CNN encoder
            (the first layer takes a single input channel)
        - hidden_dim (default 32): number of channels of the CNN encoder
        - final_cnn_relu: whether to append a ReLU after the last conv layer
        - query_dim (default 32): the latent space dimension that the slots,
            keys and values live in. The softmax temperature is fixed to
            1/sqrt(query_dim).
        - n_iter (default 2): number of slot attention steps to apply
        - pixel_mult: multiplicative factor applied to the embedded relative grid
        - device (str): device the grid and the slot parameters are created on.
            Options: cpu (default), mps, cuda:{i}
        - stdlow, stdhigh: bounds of the uniform range the initial slot scale
            is drawn from
        - eps: small constant used to keep the attention normalization and the
            scale square-root numerically safe
        '''
        super().__init__()

        self.k_slots = k_slots
        self.hidden_dim = hidden_dim
        self.query_dim = query_dim
        self.n_iter = n_iter

        self.resolution = resolution
        self.xlow, self.xhigh = xlow, xhigh
        self.stdlow, self.stdhigh = stdlow, stdhigh
        self.eps = eps
        
        self.device=device
         
        self.softmax_T = 1/np.sqrt(query_dim)
        
        self.dataN = torch.nn.LayerNorm(self.hidden_dim)
        self.queryN = torch.nn.LayerNorm(self.query_dim)
        
        self.toK = torch.nn.Linear(self.hidden_dim, self.query_dim)
        self.toV = torch.nn.Linear(self.hidden_dim, self.query_dim)
        self.gru = torch.nn.GRUCell(self.query_dim, self.query_dim)

        # CNN encoder: conv -> (ReLU -> conv) * (num_conv_layers - 1) [-> ReLU]
        kwargs = {'out_channels': hidden_dim,'kernel_size': 5, 'padding':2 }
        cnn_layers = [torch.nn.Conv2d(1,**kwargs)]
        for _ in range(num_conv_layers-1):
            cnn_layers += [torch.nn.ReLU(), torch.nn.Conv2d(hidden_dim,**kwargs)] 
            
        if final_cnn_relu:
            cnn_layers.append(torch.nn.ReLU())

        self.CNN_encoder = torch.nn.Sequential(*cnn_layers)
            
        # Positional embedding inputs.
        # Registered as a (non-persistent) buffer so that `.to(...)` moves the
        # grid together with the parameters while saved state dicts keep the
        # same keys as before.
        self.register_buffer('abs_grid', self.build_grid(), persistent=False)
        
        self.dense = torch.nn.Linear(2, query_dim) 
        self.pixel_mult = pixel_mult # LH's proposal... but almost same as 1/delta in ISA

        # Apply after the data normalization
        self.init_mlp = torch.nn.Sequential(
            torch.nn.Linear(query_dim,query_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(query_dim,query_dim)
        )
           
        # Slot initialization setup
        self.slots_mu = torch.nn.Parameter(torch.randn(1, 1, self.query_dim,device=device))
        self.slots_logsigma = torch.nn.Parameter(torch.zeros(1, 1, self.query_dim,device=device))
        torch.nn.init.xavier_uniform_(self.slots_logsigma)

    def build_grid(self):
        '''
        Build the absolute pixel coordinate grid.

        From google slot attention repo:
        https://github.com/nhartman94/google-research/blob/master/slot_attention/model.py#L357C1-L364C53

        Returns a float tensor of shape (1, prod(resolution), 2) on self.device.
        '''
        resolution = self.resolution
        xlow, xhigh = self.xlow, self.xhigh
           
        ranges = [np.linspace(xlow, xhigh, num=res) for res in resolution]
        grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
        grid = np.stack(grid, axis=-1)
        grid = np.reshape(grid, [resolution[0], resolution[1], -1])
        grid = np.expand_dims(grid, axis=0)
        
        # Now make it a pytorch tensor on the model's own device
        grid = torch.as_tensor(grid, dtype=torch.float32, device=self.device)
        grid = torch.flatten(grid,1,2)
    
        return grid
        
    def init_slots(self,Nbatch):
        '''
        Draw the initial slots and their reference frames.

        Slot init taken from
        https://github.com/lucidrains/slot-attention/blob/master/slot_attention/slot_attention.py

        Returns:
        - queries: (Nbatch, k_slots, query_dim)
        - pos_scale: (Nbatch, k_slots, 3) holding (x, y, scale) per slot, with
            x, y ~ Unif(xlow, xhigh) and scale ~ Unif(stdlow, stdhigh)
        '''
        # Follow the parameters so the samples land wherever the model lives
        dev = self.slots_mu.device

        mu = self.slots_mu.expand(Nbatch, self.k_slots, -1)
        sigma = self.slots_logsigma.exp().expand(Nbatch, self.k_slots, -1)
    
        queries = mu + sigma * torch.randn(mu.shape,device=dev)
    
        # Add the position and scale initialization for the local ref frame
        ref_frame_dim = 3
        pos_scale = torch.rand(Nbatch, self.k_slots, ref_frame_dim,device=dev)

        # Index the LAST axis (x, y, scale), not the slot axis: every slot gets
        # a position inside the grid range and a strictly positive scale.
        pos_scale[..., :2] = self.xlow + (self.xhigh - self.xlow) * pos_scale[..., :2]
        pos_scale[..., -1] = (self.stdhigh - self.stdlow) * pos_scale[..., -1] + self.stdlow
        
        return queries, pos_scale
    
    def get_keys_vals(self, encoded_data, pos_scale):
        '''
        Compute the keys and values in each slot's reference frame.

        Inputs:
        - encoded_data: (bs, img_dim, hidden_dim) normalized CNN features
        - pos_scale: (bs, k_slots, 3) slot reference frames (x, y, scale)

        Returns keys, vals, each of shape (bs, k_slots, img_dim, query_dim)
        '''

        # Get the relative position embedding: shift by the slot position,
        # then divide by the slot scale
        centers = pos_scale[:,:,:2].unsqueeze(2)
        scales = pos_scale[:,:,-1].unsqueeze(2).unsqueeze(-1)
        rel_grid = (self.abs_grid.unsqueeze(1) - centers) / scales
        
        # Embed it in the same space as the query dimension 
        embed_grid = self.pixel_mult * self.dense( rel_grid )
        
        # Use this module's own projections (not a global model instance)
        keys = self.toK(encoded_data).unsqueeze(1) + embed_grid
        vals = self.toV(encoded_data).unsqueeze(1) + embed_grid
        
        keys = self.init_mlp(self.queryN(keys))
        vals = self.init_mlp(self.queryN(vals))
        
        return keys, vals
                
    def attention_and_weights(self,queries,keys):
        '''
        Compute the attention masks and the per-slot pixel weights.

        Returns:
        - att: (bs, k_slots, img_dim) softmax over the slot axis
        - wts: (bs, k_slots, img_dim) att renormalized to sum to 1 over pixels
        '''
        
        logits = torch.einsum('bse,bsde->bsd',queries,keys) * self.softmax_T
        
        att = torch.nn.functional.softmax(logits, dim = 1)
        
        # Add eps BEFORE normalizing: a slot that receives (numerically) zero
        # attention everywhere would otherwise produce 0/0 = nan.
        wts = att + self.eps
        wts = wts / torch.sum(wts, dim = -1, keepdim = True)
        return att,wts

    def update_frames(self,wts):
        '''
        Update the relative frame position and scale from the pixel weights.

        The new position is the weighted mean of the grid coordinates and the
        new scale the square root of the weighted squared distance to it.

        Returns a tensor of shape (bs, k_slots, 3) holding (x, y, scale).
        '''
        
        # expand to include the batch dim
        grid_exp = self.abs_grid.expand(wts.shape[0],-1,2)
        
        new_pos = torch.einsum('bsd,bde->bse',wts,grid_exp)
        
        sq_dist = torch.sum(torch.pow(grid_exp.unsqueeze(1) - new_pos.unsqueeze(2),2),dim=-1)
        
        new_scale = torch.einsum('bsd,bsd->bs', wts, sq_dist)
        # eps keeps the sqrt gradient finite and the scale strictly positive
        # (the scale is used as a divisor in get_keys_vals)
        new_scale = torch.sqrt(new_scale + self.eps)
        
        return torch.cat([new_pos,new_scale.unsqueeze(-1)],axis=-1)
        
    def iterate(self, queries, pos_scale, encoded_data):
        '''
        One slot attention step.

        Returns the updated queries (same shape as `queries`) and the updated
        reference frames (bs, k_slots, 3).
        '''
        
        # Get the keys and values in the slot ref frame
        keys, vals = self.get_keys_vals(encoded_data,pos_scale)
        
        # att,wts: (bs, k_slots, img_dim)
        att,wts = self.attention_and_weights(self.queryN(queries),keys)   
        
        new_pos_scale = self.update_frames(wts)
        
        # Update the queries with the recurrent block
        updates = torch.einsum('bsd,bsde->bse',wts,vals) # bs, n_slots, query_dim
        
        updates = self.gru(
            updates.reshape(-1,self.query_dim),
            queries.reshape(-1,self.query_dim),
        )
        
        return updates.reshape(queries.shape), new_pos_scale
        
    def forward(self, data, return_init=False):
        '''
        Run the encoder and n_iter slot attention steps.

        Inputs:
        - data: image batch fed to the CNN encoder (single input channel)
        - return_init: also return the initial queries and reference frames

        Returns queries, pos_scale, att, wts
        (plus init_queries, init_pos when return_init is True).
        '''
    
        # Step 1: Extract the CNN features
        encoded_data = self.CNN_encoder(data) # Apply the CNN encoder
        encoded_data = torch.permute(encoded_data,(0,2,3,1)) # Put channel dim at the end
        encoded_data = torch.flatten(encoded_data,1,2) # flatten pixel dims
        encoded_data = self.dataN(encoded_data)
        
        # Step 2: Initialize the slots
        Nbatch = data.shape[0]
        
        # queries: (Nbatch, k_slots, query_dim), pos_scale: (Nbatch, k_slots, 3)
        queries, pos_scale = self.init_slots(Nbatch)
        
        # Independent snapshots of the random init (for debugging / replay)
        init_queries = queries.detach().clone()
        init_pos = pos_scale.detach().clone()
        
        # Step 3: Iterate through the reconstruction
        for _ in range(self.n_iter):
            queries, pos_scale = self.iterate(queries, pos_scale, encoded_data)    
            
        # With the final query vector, calc the attn, weights, + rel ref frames
        keys, vals = self.get_keys_vals(encoded_data,pos_scale)
        att, wts = self.attention_and_weights(self.queryN(queries),keys)   
        new_pos_scale = self.update_frames(wts)
        
        if return_init:
            return queries, new_pos_scale, att, wts, init_queries, init_pos 
        else:
            return queries, new_pos_scale, att, wts

In [None]:
# Build the model from the hyperparameter dict, then move it to the target device.
m = InvariantSlotAttention(**hps)
m = m.to(device)

In [None]:
# Shape of the absolute pixel grid built by the model
m.abs_grid.shape

In [None]:
# m.load_state_dict(torch.load('code/models/test-isa/m_161.pt'))

In [None]:
from scipy.optimize import linear_sum_assignment

# Image side length in pixels
nPixels = 32

In [None]:
def train_ISA(model, 
          Ntrain = 5000, 
          bs=32, 
          lr=3e-4,
          warmup_steps=5_000,
          losses = None,
          kwargs=None,
          device='cpu',
          plot_every=250, 
          save_every=1000,
          color='C0',cmap='Blues',
          modelDir='.',figDir='',showImg=True):
    '''
    Train an invariant slot attention model with a warmup + cosine decay
    learning rate schedule and a Hungarian-matched BCE loss between the
    attention masks and the target masks.

    Inputs:
    - model: the model to train; called as model(X, return_init=True)
    - Ntrain: number of training iterations to run
    - bs: batch size passed to make_batch
    - lr: base learning rate
    - warmup_steps: number of iterations over which the lr is ramped up linearly
    - losses: list of previous losses to continue from (appended to in place);
        a fresh list is created when None
    - kwargs: keyword arguments for make_batch; must contain 'isRing' and
        'N_clusters'. Defaults to {'isRing': True, 'N_clusters': 2}.
        The dict passed in is not modified.
    - device: device forwarded to make_batch and used for the index tensors
    - plot_every / save_every: plotting and checkpointing periods (iterations)
    - color, cmap, figDir, showImg: forwarded to plot_kslots
    - modelDir: directory the checkpoints, loss.json and NaN dump are written to

    Returns:
    - (model, losses) when training runs to completion
    - (model, X, Y, mask, init_q, init_pos) if a NaN shows up in the attention,
      so that the failing batch and slot initialization can be inspected
    '''

    # Fresh containers per call: never share state through default arguments
    # and never mutate the caller's kwargs dict.
    if losses is None:
        losses = []
    kwargs = {'isRing': True, 'N_clusters': 2} if kwargs is None else dict(kwargs)
    kwargs['device'] = device

    loss_fct = torch.nn.BCELoss(reduction='none')
    
    # Learning rate schedule config
    base_learning_rate = lr
    opt = torch.optim.Adam(model.parameters(), lr)
    
    model.train()
    
    k_slots = model.k_slots
    resolution = model.resolution
    n_pixels = int(np.prod(resolution))

    max_n_rings = kwargs['N_clusters']
    isRing = kwargs["isRing"]
    print(f'Training model with {k_slots} slots on {max_n_rings}'+ ("rings" if isRing else "blobs"))

    start = len(losses)
    for i in range(start,start+Ntrain):
           
        # Cosine decay, with a linear ramp during the warmup phase
        learning_rate = base_learning_rate * 0.5 * (1 + np.cos(np.pi * i / Ntrain))
        if i < warmup_steps:
            learning_rate *= (i / warmup_steps)
        
        opt.param_groups[0]['lr'] = learning_rate
            
        X, Y, mask = make_batch(N_events=bs, **kwargs)
        
        opt.zero_grad()
        out = model(X,return_init=True)
        queries, pos_scale, att, wts, init_q, init_pos = out
        if torch.isnan(init_q).any():
            print('init_q is nan')
        
        if torch.isnan(att).any():
        
            print('# nan',torch.isnan(att).sum())
            print('att',att)
            print('try 2',torch.isnan(model(X)[2]).sum())
            
            # DEBUG: Save all sources of randomness
            ks = ['queries', 'pos_scale', 'att', 'wts', 'init_q', 'init_pos']
            data = {k: v.detach().cpu().tolist() for k,v in zip(ks,out)}
            with open(f'{modelDir}/nan_debug.json','w') as f:
                json.dump(data, f)
            
            return model, X,Y,mask, init_q, init_pos
            
        with torch.no_grad():
            
            # Calculate the loss of _all_ possible (slot, target) combinations  
            flat_mask = mask.reshape(-1,max_n_rings, n_pixels)[:,None,:,:]
        
            att_ext  = torch.tile(att.unsqueeze(2),  dims=(1,1,max_n_rings,1)) 
            mask_ext = torch.tile(flat_mask,dims=(1,k_slots,1,1)) 

            pairwise_cost = loss_fct(att_ext,mask_ext).mean(axis=-1)
            
            indices = hungarian_matching(pairwise_cost)
        
        # Apply the matching: indices[:,0] picks the slots, indices[:,1] the targets
        bis=torch.arange(bs, device=device)
        indices=indices.to(device)

        slots_sorted = torch.cat([att[bis,indices[:,0,ri]].unsqueeze(1) for ri in range(max_n_rings)],dim=1)
        
        flat_mask = mask.reshape(-1,max_n_rings, n_pixels)
        rings_sorted = torch.cat([flat_mask[bis,indices[:,1,ri]].unsqueeze(1) for ri in range(max_n_rings)],dim=1)

        # Calculate the loss
        loss = loss_fct(slots_sorted,rings_sorted).sum(axis=1).mean()
        
        # DEBUG: Save model before update
        torch.save(model.state_dict(), f'{modelDir}/m_-2.pt')
        
        loss.backward()
        opt.step()
        
        # DEBUG: Save model after update
        torch.save(model.state_dict(), f'{modelDir}/m_-1.pt')
        
        losses.append(float(loss))
            
        if i % plot_every == 0:
            
            print('iter',i,', loss',loss.detach().cpu().numpy(),', lr',opt.param_groups[0]['lr'])
            
            iEvt = 0
            att_img = att[iEvt].reshape(k_slots,*resolution)
            plot_kslots(losses, 
                        mask[iEvt].sum(axis=0).detach().cpu().numpy(), 
                        att_img.detach().cpu().numpy(),
                        k_slots, color=color,cmap=cmap,
                        figname=f'{figDir}/loss-slots-iter{i}-evt{iEvt}.jpg',showImg=showImg)
            
        if i % save_every == 0:
            torch.save(model.state_dict(), f'{modelDir}/m_{i}.pt')
            with open(f'{modelDir}/loss.json','w') as f:
                json.dump(losses, f)
                
    model.eval()
    return model,losses

In [None]:
cID = 'isa-cosine-decay'

modelDir = f'../code/models/{cID}'
figDir = f'../code/figures/{cID}'

# Checkpoints and figures are written into these directories during training,
# so make sure they exist. exist_ok makes the cell safe to re-run.
for d in (modelDir, figDir):
    os.makedirs(d, exist_ok=True)

In [None]:
# Batch size and maximum number of rings per event
bs, max_n_rings = 128, 2

In [None]:
losses = []

# Pass the list in explicitly: train_ISA appends to it in place, so the loss
# history stays reachable through `losses` even when training exits early via
# its NaN branch (which does not return the losses).
out = train_ISA(m, Ntrain=40_000, bs=bs, losses=losses, device=device,
                modelDir=modelDir, figDir=figDir, plot_every=50)

In [None]:
model, X,Y,mask, init_q, init_pos = out

In [None]:
queries, pos_scale, att, wts = model(X)

In [None]:
att[4]

**Ideas:**
1. I could see if an L2 regularization would help.
   - Plot the min and max of the model parameters over time.
2. Would gradient clipping help?
   - It seems more likely that something else in the optimization pipeline is off right now.

In [None]:
encoded_data

In [None]:
i_fail = 0

In [None]:
def _report(tag, t):
    # Print the nan count and the value range of a tensor
    print(tag, t.isnan().sum().item(), f'max {t.max().item():.2f},min {t.min().item():.2f}')


# Re-run the encoder by hand: CNN -> channels last -> flatten pixels -> LayerNorm
encoded_data = model.dataN(
    torch.flatten(torch.permute(model.CNN_encoder(X), (0, 2, 3, 1)), 1, 2)
)

# Use the init Q from the failure mode
queries, pos_scale = copy(init_q), copy(init_pos)

# Step 3: Iterate through the reconstruction
# (manual unrolling of model.iterate so that intermediates can be inspected)
for i in range(model.n_iter):

    print('t=',i)

    # Get the keys and values in the slot ref frame
    keys, vals = model.get_keys_vals(encoded_data, pos_scale)
    _report('keys', keys[i_fail])
    _report('vals', vals[i_fail])

    # att,wts: (bs, k_slots, img_dim)
    att, wts = model.attention_and_weights(model.queryN(queries), keys)
    _report('att', att[i_fail])
    _report('wts', wts[i_fail])

    new_pos_scale = model.update_frames(wts)

    # Update the queries with the recurrent block
    updates = torch.einsum('bsd,bsde->bse', wts, vals)  # bs, n_slots, query_dim
    updates = model.gru(
        updates.reshape(-1, model.query_dim),
        queries.reshape(-1, model.query_dim),
    )

    queries = updates.reshape(queries.shape)
    pos_scale = new_pos_scale

    print(i, queries[i_fail], pos_scale[i_fail])

    # Only the first step is inspected
    break
    
    
# # With the final query vector, calc the attn, weights, + rel ref frames
# keys, vals = model.get_keys_vals(encoded_data,pos_scale)
# att, wts = model.attention_and_weights(model.queryN(queries),keys)   
# new_pos_scale = model.update_frames(wts)

In [None]:
# Small two-event batch of two-ring events for a quick forward-pass check
X2, Y2, mask2 = make_batch(N_events=2, device=device, isRing=True, N_clusters=2)
        

In [None]:
# m_cpu = model.to(device)

In [None]:
model(X2)

In [None]:
plt.hist(encoded_data.flatten().detach().cpu().numpy())

In [None]:
# Relative position grid for the *initial* slot frames:
# shift by the slot position, then divide by the slot scale
centers = init_pos[:, :, :2].unsqueeze(2)
scales = init_pos[:, :, -1].unsqueeze(2).unsqueeze(-1)
rel_grid = (model.abs_grid.unsqueeze(1) - centers) / scales

# Embed it in the same space as the query dimension
embed_grid = model.pixel_mult * model.dense(rel_grid)

# Keys / values built from the initial frames.
# NOTE(review): unlike model.get_keys_vals, model.queryN is not applied before
# init_mlp here -- confirm whether that is intended.
k0 = model.init_mlp(model.toK(encoded_data).unsqueeze(1) + embed_grid)
v0 = model.init_mlp(model.toV(encoded_data).unsqueeze(1) + embed_grid)

In [None]:
# Distribution of the key / value features over the whole batch
nb = 100
r = (-5, 5)
for _feats, _label, _color in ((k0, 'keys', 'g'), (v0, 'values', 'b')):
    plt.hist(_feats.flatten().detach().cpu().numpy(), nb, r, label=_label,
             color=_color, histtype='step', lw=2)

plt.xlabel('features')
plt.ylabel('entries')
plt.legend()
plt.show()

In [None]:
# Same key / value histograms, restricted to the failing event
nb = 100
r = (-5, 5)
for _feats, _label, _color in ((k0, 'keys', 'g'), (v0, 'values', 'b')):
    plt.hist(_feats[i_fail].flatten().detach().cpu().numpy(), nb, r, label=_label,
             color=_color, histtype='step', lw=2)

plt.xlabel('features')
plt.ylabel('entries')
plt.title(f'Event {i_fail} (evt with nans)')
plt.legend()
plt.show()

In [None]:
k0[i_fail].shape

In [None]:
k0[i_fail][abs(k0[i_fail])>100]

In [None]:
v0[i_fail][abs(v0[i_fail])>100]

In [None]:
# Dump every parameter tensor together with its shape
for param in model.parameters():
    print(param.shape, param)
    # plt.hist(param.detach().cpu().numpy(),histtype='step')

In [None]:
losses

In [None]:
att.isnan().sum()

In [None]:
# Print the index of every event whose attention contains a nan
nan_events = [idx for idx, att_evt in enumerate(att) if att_evt.isnan().sum() > 0]
for idx in nan_events:
    print(idx)

In [None]:
i_fail = 23

In [None]:
init_q[i_fail]

In [None]:
init_pos[i_fail]

In [None]:
np.prod(att.shape)

In [None]:
init_q

In [None]:
init_pos

In [None]:
copy(att)

In [None]:
att.shape

In [None]:
m,losses = out

In [None]:
model, X,Y,mask = out

In [None]:
# Forward pass on the stored batch, then report whether the attention has any nan
queries, pos_scale, att, wts = m(X)

has_nan = torch.isnan(att).sum() > 0
print( has_nan )

In [None]:
losses

In [None]:
# att_img = att[iEvt].reshape(model.k_slots,*resolution)
# plot_kslots(losses, 
#             mask[iEvt].sum(axis=0).detach().cpu().numpy(), 
#             att_img.detach().cpu().numpy(),
#             k_slots, color=color,cmap=cmap,
#             figname=f'{figDir}/loss-slots-iter{i}-evt{iEvt}.jpg',showImg=showImg)
