# Implementation of ResNet to ISA 

Let's implement Sanzianas ResNet encoder!

Ok, it works kind of... but only because I've added a transverse CNN layer to upsample from a [bs, 33, 4, 4] to [bs, 16, 32, 32]. Maybe this is better solved with a Linear layers which is then reshaped into [bs, 16, 32, 32]?

In [1]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib as mlp
from mpl_toolkits.axes_grid1 import make_axes_locatable

import json, yaml, os
os.sys.path.append('./../../code')

from plotting import plot_kslots, plot_kslots_iters
from data_scclevr import makeRings 
from model import InvariantSlotAttention

from matplotlib.patches import Circle
import json

# Set numpy seed for test set sampling 
torch_seed = 24082023
torch.manual_seed( torch_seed )

import random
random.seed(torch_seed)

import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from plotting import plot_chosen_slots, plot_kslots, plot_kslots_iters, plot_kslots_grads

%load_ext autoreload
%autoreload 2

In [2]:
from torch.nn import init
import scclevr
import torch.nn as nn

In [3]:
import torchvision.models as models

In [4]:
device = 'cpu'
cID_prev = 'isa-alpha3_scclevr'
with open(f'./../../code/configs/{cID_prev}.yaml') as f:
    cd = yaml.safe_load(f)

hps = cd['hps']
hps['device'] = device

In [5]:
class Sampling(nn.Module):
    def forward(self, z_mean, z_log_var):
        # get the shape of the tensor for the mean and log variance
        batch, dim = z_mean.shape
        # generate a normal random tensor (epsilon) with the same shape as z_mean
        # this tensor will be used for reparameterization trick
        epsilon = Normal(0, 1).sample((batch, dim)).to(z_mean.device)
        # apply the reparameterization trick to generate the samples in the
        # latent space
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon

In [6]:
class InvariantSlotAttention_ResNet(torch.nn.Module):
    def __init__(self, 
                 resolution=(32,32),
                 xlow=-0.5,
                 xhigh=0.5,
                 varlow=0.01,
                 varhigh=0.05,
                 k_slots=3, 
                 num_conv_layers=3,
                 which_encoder='ResNet',
                 hidden_dim=32, 
                 query_dim=32, 
                 n_iter=2,
                 pixel_mult=1,
                 device='cpu' ,
                 learn_slot_feat=True
                 ):
        '''
        Slot attention encoder block, block attention
        '''
        super().__init__()

        self.k_slots = k_slots
        self.hidden_dim = hidden_dim
        self.query_dim = query_dim
        self.n_iter = n_iter

        self.resolution = resolution
        self.xlow, self.xhigh = xlow, xhigh
        self.rlow, self.rhigh = np.sqrt(varlow), np.sqrt(varhigh)
        
        self.device=device
         
        self.softmax_T = 1/np.sqrt(query_dim)
        
        self.dataN = torch.nn.LayerNorm(self.hidden_dim)
        self.queryN = torch.nn.LayerNorm(self.query_dim)
        
        self.toK = torch.nn.Linear(self.hidden_dim, self.query_dim)
        self.toV = torch.nn.Linear(self.hidden_dim, self.query_dim)
        self.gru = torch.nn.GRUCell(self.query_dim, self.query_dim)
        
        '''
        CNN feature extractor
        '''
        kwargs = {'out_channels': hidden_dim,'kernel_size': 5, 'padding':2 }
        cnn_layers = [torch.nn.Conv2d(1,**kwargs)]
        for i in range(num_conv_layers-1):
            cnn_layers += [torch.nn.ReLU(), torch.nn.Conv2d(hidden_dim,**kwargs)] 
        cnn_layers.append(torch.nn.ReLU())
          
        self.CNN_encoder = torch.nn.Sequential(*cnn_layers) # 3 CNN layers by default
        if which_encoder=='ResNet':
            self.CNN_encoder = Encoder_resnet_1()
            

            
          
            
        # Grid + query init
        self.abs_grid = self.build_grid()
                   
        self.dense = torch.nn.Linear(2, query_dim) 
        self.pixel_mult = pixel_mult # LH's proposal... but almost same as 1/delta in ISA

        # Apply after the data normalization
        self.init_mlp = torch.nn.Sequential(
            torch.nn.Linear(query_dim,query_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(query_dim,query_dim)
        )
            
        self.slots_mu = torch.nn.Parameter(torch.randn(1, 1, self.query_dim))
        self.slots_logsigma = torch.nn.Parameter(torch.zeros(1, 1, self.query_dim))
        init.xavier_uniform_(self.slots_logsigma)

        self.init_slots = self.init_slots

        
        '''
        Option to add a final (x,y,r) prediction to each slot
        '''
        self.learn_slot_feat = learn_slot_feat
        if self.learn_slot_feat:
            self.final_mlp = torch.nn.Sequential(
                torch.nn.Linear(query_dim,hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim, 3)
            )
        
    def build_grid(self):
        '''
        From google slot attention repo:
        https://github.com/nhartman94/google-research/blob/master/slot_attention/model.py#L357C1-L364C53
        '''
        resolution = self.resolution
        xlow, xhigh = self.xlow, self.xhigh
           
        ranges = [np.linspace(xlow, xhigh, num=res) for res in resolution]
        grid = np.meshgrid(*ranges, sparse=False, indexing="xy")
        grid = np.stack(grid, axis=-1)
        grid = np.reshape(grid, [resolution[0], resolution[1], -1])
        grid = np.expand_dims(grid, axis=0)
        
        grid = torch.FloatTensor( grid ).to(self.device)
        grid = torch.flatten(grid,1,2)
    
        return grid
                
    def init_slots(self,Nbatch):
        '''
        Slot init taken from
        https://github.com/lucidrains/slot-attention/blob/master/slot_attention/slot_attention.py
        '''
        
        stdhigh, stdlow = self.rlow, self.rhigh
        
        mu = self.slots_mu.expand(Nbatch, self.k_slots, -1)
        sigma = self.slots_logsigma.exp().expand(Nbatch, self.k_slots, -1)
    
        queries = mu + sigma * torch.randn(mu.shape,device=self.device)
    
        # Add the position and scale initialization for the local ref frame
        ref_frame_dim = 3
        pos_scale = torch.rand(Nbatch, self.k_slots, ref_frame_dim,device=self.device)

        pos_scale[:,:2] -= 0.5
        pos_scale[:,-1]  = (stdhigh - stdlow) * pos_scale[:,-1] + stdlow
        
        return queries, pos_scale
     
    def get_keys_vals(self, encoded_data, pos_scale):

        # Get the relative position embedding
        rel_grid = self.abs_grid.unsqueeze(1) - pos_scale[:,:,:2].unsqueeze(2)
        rel_grid /= pos_scale[:,:,-1].unsqueeze(2).unsqueeze(-1)
        
        # Embed it in the same space as the query dimension 
        embed_grid = self.pixel_mult * self.dense( rel_grid )
        
        # keys, vals: (bs, img_dim, query_dim)
        keys = self.toK(encoded_data).unsqueeze(1) + embed_grid
        vals = self.toV(encoded_data).unsqueeze(1) + embed_grid
        
        keys = self.init_mlp(self.queryN(keys))
        vals = self.init_mlp(self.queryN(vals))
        
        return keys, vals
                
    def attention_and_weights(self,queries,keys):
        
        logits = torch.einsum('bse,bsde->bsd',queries,keys) * self.softmax_T
        
        att = torch.nn.functional.softmax(logits, dim = 1)
        
        div = torch.sum(att, dim = -1, keepdims = True)
        wts = att/div + 1e-8
        return att,wts

    def update_frames(self,wts):
        '''
        Update the relative frame position
        '''
        
        # expand to include the batch dim
        grid_exp = self.abs_grid.expand(wts.shape[0],-1,2)
        
        new_pos = torch.einsum('bsd,bde->bse',wts,grid_exp)
        
        new_scale = torch.sum(torch.pow(grid_exp.unsqueeze(1) - new_pos.unsqueeze(2),2),dim=-1)
        
        new_scale = torch.einsum('bsd,bsd->bs', wts, new_scale)
        new_scale = torch.sqrt(new_scale)
        
        return torch.cat([new_pos,new_scale.unsqueeze(-1)],axis=-1)
        
    def iterate(self, queries, pos_scale, encoded_data):
        
        # Get the keys and values in the ref ref frame
        keys, vals = self.get_keys_vals(encoded_data,pos_scale)
        
        # att,wts: (bs, k_slots, img_dim)
        att,wts = self.attention_and_weights(self.queryN(queries),keys)   
        
        new_pos_scale = self.update_frames(wts)
        
        # Update the queries with the recurrent block
        updates = torch.einsum('bsd,bsde->bse',wts,vals) # bs, n_slots, query_dim
        
        updates = self.gru(
            updates.reshape(-1,self.query_dim),
            queries.reshape(-1,self.query_dim),
        )
        
        return updates.reshape(queries.shape), new_pos_scale
        
    def forward(self, data):
    
        '''
        Step 1: Extract the CNN features
        '''
        #print('data shape', data.shape)
        encoded_data = self.CNN_encoder(data) # Apply the CNN encoder
    
        #print(encoded_data.shape) 
        # for ModifiedResNet18: torch.Size([32, 16, 1, 1]) but should be torch.Size([32, 16, 32, 32]) with [bs, kernel, picture] - at least no error anymore!
        # CustomResNet retruns indead [32,1,32, 32]
        encoded_data = torch.permute(encoded_data,(0,2,3,1)) # Put channel dim at the end
        encoded_data = torch.flatten(encoded_data,1,2) # flatten pixel dims
        #print(encoded_data.shape)
        encoded_data = self.dataN(encoded_data)
        
        '''
        Step 2: Initialize the slots
        '''
        Nbatch = data.shape[0]
        queries, pos_scale = self.init_slots(Nbatch) # Shape (Nbatch, k_slots, query_dim)
                
        '''
        Step 3: Iterate through the reconstruction
        '''
        for i in range(self.n_iter):
            queries, pos_scale = self.iterate(queries, pos_scale, encoded_data)    
            
        # With the final query vector, calc the attn, weights, + rel ref frames
        keys, vals = self.get_keys_vals(encoded_data,pos_scale)
        att, wts = self.attention_and_weights(self.queryN(queries),keys)   
        new_pos_scale = self.update_frames(wts)
                
        if self.learn_slot_feat:
            slot_feat = self.final_mlp(queries)
            
            # Want to learn the delta from the previously estimated position
            slot_feat += new_pos_scale
            
            return queries, att, slot_feat 
        
        else:
            return queries, att, wts

In [7]:
def hungarian_matching(pairwise_cost):
    '''
    Input:
    - pairwise_cost


    Hungarian section Translated from the TensorFlow loss function (from 2006.15055 code):
    https://github.com/nhartman94/google-research/blob/master/slot_attention/utils.py#L26-L57
    '''
    
    indices = list(map(linear_sum_assignment, pairwise_cost.cpu()))
    indices = torch.LongTensor(np.array(indices))
    
    loss = 0
    for pi,(ri,ci) in zip(pairwise_cost,indices):
        loss += pi[ri,ci].sum()
    
    return indices 

In [8]:
def _convert_into_pytorch_tensors(event_images, object_images, n_objects, object_features, device):
    return torch.FloatTensor(event_images).to(device), \
               torch.FloatTensor(object_images).to(device), \
               torch.FloatTensor(n_objects).to(device), \
               torch.FloatTensor(object_features).to(device)

In [16]:
model = InvariantSlotAttention_ResNet(**hps).to(device)

In [17]:
Ntrain = 5000 
bs=32 
lr=3e-4
warmup_steps=5_000
alpha=1
losses = {'tot':[],'bce':[],'mse':[]}
kwargs={'isRing': True, 'N_clusters':2}
clip_val = 1
device='cpu'
plot_every=20 
save_every=1000
color='C0'
cmap='Blues'

In [18]:
# Learning rate schedule config
base_learning_rate = lr

opt = torch.optim.Adam(model.parameters(), base_learning_rate)
model.train()

k_slots = model.k_slots
max_n_rings = kwargs['N_clusters']
resolution = model.resolution
kwargs['device'] = device
N_obj = kwargs['N_clusters'] # pass to makeRing fct

In [19]:


start = len(losses)
for i in range(start,start+Ntrain):

    learning_rate = base_learning_rate * 0.5 * (1 + np.cos(np.pi * i / Ntrain))
    if i < warmup_steps:
        learning_rate *= (i / warmup_steps)
    
    opt.param_groups[0]['lr'] = learning_rate
    
    # make scclevr data
    rings = scclevr.RingsBinaryUniform(N_obj) # two rings per imagne
    event_images, object_images, n_objects, object_features =  rings.gen_events(bs)
    event_images, object_images, n_objects, object_features =  _convert_into_pytorch_tensors(event_images, object_images, n_objects, object_features, device)
    # assign shorter names
    X = event_images
    mask = object_images
    Y = object_features
    
    queries, att, Y_pred = model(X)
        
    # Reshape the target mask to be flat in the pixels (same shape as att)
    flat_mask = mask.reshape(-1,max_n_rings, np.prod(resolution))   
    with torch.no_grad():
        
        att_ext  = torch.tile(att.unsqueeze(2), dims=(1,1,max_n_rings,1)) 
        mask_ext = torch.tile(flat_mask.unsqueeze(1),dims=(1,k_slots,1,1)) 
        
        pairwise_cost = F.binary_cross_entropy(att_ext,mask_ext,reduction='none').mean(axis=-1)
        
        # pairwise_cost = comb_loss(att,flat_mask,Y,Y_pred,alpha)
        indices = hungarian_matching(pairwise_cost)

    # Apply the sorting to the predict
    bis=torch.arange(bs).to(device)
    indices=indices.to(device)
    # Loss calc
    slots_sorted = torch.cat([att[bis,indices[:,0,ri]].unsqueeze(1) for ri in range(max_n_rings)],dim=1)
    rings_sorted = torch.cat([flat_mask[bis,indices[:,1,ri]].unsqueeze(1) for ri in range(max_n_rings)],dim=1)
    l_bce = F.binary_cross_entropy(slots_sorted,rings_sorted,reduction='none').sum(axis=1).mean()
    
    Y_pred_sorted = torch.cat([Y_pred[bis,indices[:,0,ri]].unsqueeze(1) for ri in range(max_n_rings)],dim=1)
    Y_true_sorted = torch.cat([Y[bis,indices[:,1,ri]].unsqueeze(1) for ri in range(max_n_rings)],dim=1)

    l_mse = torch.nn.MSELoss(reduction='none')(Y_pred_sorted,Y_true_sorted).sum(axis=1).mean()

    # Calculate the loss
    li = l_bce + alpha*l_mse
    
    li.backward()
    clip_val=1
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)
    
    opt.step()
    opt.zero_grad()

    losses['tot'].append(float(li))
    losses['bce'].append(float(l_bce))
    losses['mse'].append(float(l_mse))
    
    if i % plot_every == 0:
        print('iter',i,', loss',li.detach().cpu().numpy(),', lr',opt.param_groups[0]['lr'])  
        iEvt = 0

        # losses, mask, att_img, Y_true, Y_pred
        plot_chosen_slots(losses,
                            mask[iEvt].sum(axis=0), 
                            slots_sorted[iEvt].reshape(max_n_rings,*resolution),
                            Y_true_sorted[iEvt],
                            Y_pred_sorted[iEvt])
        




input shape:  torch.Size([32, 1, 32, 32])
input shape:  torch.Size([32, 1, 32, 32])
input shape:  torch.Size([32, 1, 32, 32])
input shape:  torch.Size([32, 1, 32, 32])
input shape:  torch.Size([32, 1, 32, 32])
input shape:  torch.Size([32, 1, 32, 32])
input shape:  torch.Size([32, 1, 32, 32])
input shape:  torch.Size([32, 1, 32, 32])
input shape:  torch.Size([32, 1, 32, 32])
input shape:  torch.Size([32, 1, 32, 32])
input shape:  torch.Size([32, 1, 32, 32])


KeyboardInterrupt: 

Ok. It runs! Now try to adjust with ResNet!!

In [15]:
class mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, activation, pre_act):
        super(ResidualBlock, self).__init__()
        if activation == 'mish':
            activation = mish()
        self.activation = activation
        self.pre_act = pre_act
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv_res = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, inputs, inputs_scaled):
        x = inputs
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.activation(y)
        y = self.conv2(y)
        y = self.bn2(y)
        if self.pre_act:
            y = self.activation(y)
            
        x = self.conv_res(x)
        x = x + y
        if not self.pre_act:
            x = self.activation(x)
            
        x = torch.cat((x, inputs_scaled), dim=1)
        return x

class Encoder_resnet_1(nn.Module):
    def __init__(self, nPixels=32, latent_dim=128, nMaxClusters=2, activation=mish(), use_vae=False, pre_act=False, filters=[16,16,32]):
        super(Encoder_resnet_1, self).__init__()
        #self.initial_conv = nn.Conv2d(3, filters[0], kernel_size=3, stride=1, padding=1)
        self.res_blocks = nn.ModuleList()
        self.activation = activation
        self.use_vae = use_vae
        
        self.res_block11 = ResidualBlock(1,16, activation, pre_act)
        self.res_block12 = ResidualBlock(17,16, activation, pre_act)
        self.res_block13 = ResidualBlock(17,16, activation, pre_act)
        self.res_block14 = ResidualBlock(17,16, activation, pre_act)
        self.res_block15 = ResidualBlock(17,16, activation, pre_act)

        self.res_block21 = ResidualBlock(17,16, activation, pre_act)
        self.res_block22 = ResidualBlock(17,16, activation, pre_act)
        self.res_block23 = ResidualBlock(17,16, activation, pre_act)
        self.res_block24 = ResidualBlock(17,16, activation, pre_act)
        self.res_block25 = ResidualBlock(17,16, activation, pre_act)

        self.res_block31 = ResidualBlock(17,32, activation, pre_act)
        self.res_block32 = ResidualBlock(33,32, activation, pre_act)
        self.res_block33 = ResidualBlock(33,32, activation, pre_act)
        self.res_block34 = ResidualBlock(33,32, activation, pre_act)
        self.res_block35 = ResidualBlock(33,32, activation, pre_act)
        
        self.pooling = nn.MaxPool2d(kernel_size=2)
        self.flatten = nn.Flatten()
        self.dense = nn.Linear((filters[-1]+1) * nPixels * nPixels // (2 ** len(filters))**2, 256, bias=False)
       
        self.bn_dense = nn.BatchNorm1d(256)
        #self.sampling_layer = Sampling()

        if use_vae:
            self.z_mean_full = nn.Linear(256, nMaxClusters*latent_dim)
            self.z_log_var_full = nn.Linear(256,nMaxClusters*latent_dim)
        else:
            self.z_full = nn.Linear(256, nMaxClusters*latent_dim)
            
            
        # Sara's adjustment ideas
        self.upsample = nn.ConvTranspose2d(33, 16, 8, stride=8, padding=0) # is this a smart thing to do?
        self.lastdense = nn.Linear(256, 16*32*32)

    def forward(self, inputs):
        x = inputs
        inputs_scaled = inputs
        print("input shape: ", inputs.shape)
        #inputs_scaled = F.interpolate(x, size=(nPixels, nPixels))

        inputs_scaled = F.interpolate(inputs_scaled, size=(x.shape[2], x.shape[3]))
        x = self.res_block11(x,inputs_scaled)
        x = self.res_block12(x,inputs_scaled)
        x = self.res_block13(x,inputs_scaled)
        x = self.res_block14(x,inputs_scaled)
        x = self.res_block15(x,inputs_scaled)  
        x = self.pooling(x)

        inputs_scaled = F.interpolate(inputs_scaled, size=(x.shape[2], x.shape[3]))
        x = self.res_block21(x,inputs_scaled)
        x = self.res_block22(x,inputs_scaled)
        x = self.res_block23(x,inputs_scaled)
        x = self.res_block24(x,inputs_scaled)
        x = self.res_block25(x,inputs_scaled)
        x = self.pooling(x)

        inputs_scaled = F.interpolate(inputs_scaled, size=(x.shape[2], x.shape[3]))
        x = self.res_block31(x,inputs_scaled)
        x = self.res_block32(x,inputs_scaled)
        x = self.res_block33(x,inputs_scaled)
        x = self.res_block34(x,inputs_scaled)
        x = self.res_block35(x,inputs_scaled)
        x = self.pooling(x)
        #x = self.upsample(x) # option 1... upsampling with CNN from [bs, 33, 4,4] to [bs, 16, 32, 32]
        #return x
        
        x = self.flatten(x)
        #print(x.shape)
        x = self.dense(x)
        
        x = self.bn_dense(x)
        x = self.activation(x)
        
        x = self.lastdense(x)
        
        x = x.reshape([inputs.shape[0], 16, 32, 32]) # horribly hard coded I know
        
        return x
        """
        if self.use_vae:
            z_mean_full = self.z_mean_full(x)
            z_log_var_full = self.z_log_var_full(x)
            z_full = self.sampling_layer(z_mean_full, z_log_var_full)
            return z_mean_full, z_log_var_full, z_full
        else:
            z_full = self.z_full(x)
            return z_full, z_full, z_full
        """


Ok, let's back off a bit? Implement torch network?