In [5]:
import yaml
#mport omegaconf
#from hydra.utils import instantiate

with open('conf/config.yaml') as f:
    cfg = yaml.safe_load(f)


In [3]:
import pytorch_lightning as pl
from torch import nn
import torch
import torch.nn.functional as F
import numpy as np

In [18]:
class LitSlotAttention(pl.LightningModule):
    def __init__(self, num_slots = 11, num_iterations = 3, resolution = (128,128), hid_dim=64, lang_dim=4, lr=0.0004,
                encoder = None, decoder = None, slot_attention = None, graph_module = None):

        super().__init__()
        self.num_slots = num_slots
        self.num_iterations = num_iterations
        self.hid_dim = hid_dim
        self.lang_dim = lang_dim
        self.lr = lr
        self.resolution = resolution

        self.encoder = encoder if encoder else Encoder(self.resolution, self.hid_dim)
        self.decoder = decoder if decoder else Decoder(self.hid_dim, self.resolution)
        self.SA = slot_attention if slot_attention else SlotAttention(
                                                        num_slots = self.num_slots,
                                                        dim = self.hid_dim,
                                                        iters = self.num_iterations,
                                                        eps = 1e-8, 
                                                        hidden_dim = 128,
                                                        language_dim = self.hid_dim)

        
        self.fc1 = nn.Linear(hid_dim, hid_dim)
        self.fc2 = nn.Linear(hid_dim, hid_dim)
        self.layer_norm = nn.LayerNorm((self.resolution[0]*self.resolution[1], self.hid_dim))


    def training_step(self, data, batch_idx):
        recon_combined, recons, masks, slots = self.forward(data)
        loss = nn.MSELoss(recon_combined, data['image'])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, data, batch_idx):
        recon_combined, recons, masks, slots = self.forward(data)
        loss = nn.MSELoss(recon_combined, data['image'])
        self.log("val_loss", loss)
        return loss

   # def on_validation_epoch_end()     # log 6 example images
        # # or generated text... or whatever
        # sample_imgs = x[:6]
        # grid = torchvision.utils.make_grid(sample_imgs)
        # self.logger.experiment.add_image('example_images', grid, 0)

    def forward(self, data):
        # `image` has shape: [batch_size, num_channels, width, height].
        image, embeddings, mask = data['image'], data['objects'], data['mask']
        x = self.encoder(image)
        x = self.layer_norm(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        # `x` has shape: [batch_size, width*height, input_size].

        # Slot Attention module.
        slots = self.SA(x, embeddings, mask)
        # `slots` has shape: [batch_size, num_slots, slot_size].

        # """Broadcast slot features to a 2D grid and collapse slot dimension.""".
        slots = slots.reshape((-1, slots.shape[-1])).unsqueeze(1).unsqueeze(2)
        slots = slots.repeat((1, 8, 8, 1))
        
        # `slots` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
        x = self.decoder(slots)
        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].

        # Undo combination of slot and batch dimension; split alpha masks.
        recons, masks = x.reshape(image.shape[0], -1, x.shape[1], x.shape[2], x.shape[3]).split([3,1], dim=-1)
        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
        # `masks` has shape: [batch_size, num_slots, width, height, 1].

        # Normalize alpha masks over slots.
        masks = nn.Softmax(dim=1)(masks)
        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
        recon_combined = recon_combined.permute(0,3,1,2)
        # `recon_combined` has shape: [batch_size, width, height, num_channels].

        return recon_combined, recons, masks, slots
        
    
    def configure_optimizers(self):
        # Add Step LR
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    

In [13]:
class SlotAttention(nn.Module):
    def __init__(self, num_slots, dim, language_dim = 1024,iters = 3, eps = 1e-8, hidden_dim = 128):
        super().__init__()
        self.num_slots = num_slots
        self.iters = iters
        self.eps = eps
        self.scale = dim ** -0.5

        #self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        #self.slots_sigma = nn.Parameter(torch.randn(1, 1, dim))

        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

        self.gru = nn.GRUCell(dim, dim)

        hidden_dim = max(dim, hidden_dim)

        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

        self.fc1_embed = nn.Linear(language_dim, hidden_dim)
        self.fc2_embed = nn.Linear(hidden_dim, dim)
        

        self.norm_input  = nn.LayerNorm(dim)
        self.norm_slots  = nn.LayerNorm(dim)
        self.norm_pre_ff = nn.LayerNorm(dim)

    def forward(self, inputs, embeddings, masks, num_slots = None):
        b, n, d = inputs.shape
        _, n_e, _ = embeddings.shape

        # check our objects array (including empty ones) passed in 
        # is equal to slots we are training with
        assert n_e == self.num_slots

        #n_s = num_slots if num_slots is not None else self.num_slots

        # mu = self.slots_mu.expand(b, n_s, -1)
        # sigma = self.slots_sigma.expand(b, n_s, -1)
        # slots = torch.normal(mu, torch.exp(sigma))


        slots = self.fc2_embed(F.relu(self.fc1_embed(embeddings)))
        inputs = self.norm_input(inputs)        
        k, v = self.to_k(inputs), self.to_v(inputs)
        for _ in range(self.iters):
            slots_prev = slots

            slots = self.norm_slots(slots)
            q = self.to_q(slots)

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            masked_dots = dots.masked_fill(masks.unsqueeze(-1), 0)
            attn = masked_dots.softmax(dim=1) + self.eps
            attn = attn / attn.sum(dim=-1, keepdim=True)

            updates = torch.einsum('bjd,bij->bid', v, attn)

            slots = self.gru(
                updates.reshape(-1, d),
                slots_prev.reshape(-1, d)
            )

            slots = slots.reshape(b, -1, d)
            slots = slots + self.fc2(F.relu(self.fc1(self.norm_pre_ff(slots))))

        return slots

In [14]:
"""
Utils
"""
import numpy as np
from torch import nn
import torch

def build_grid(resolution):
    ranges = [np.linspace(0., 1., num=res) for res in resolution]
    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
    grid = np.stack(grid, axis=-1)
    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
    grid = np.expand_dims(grid, axis=0)
    grid = grid.astype(np.float32)
    return torch.from_numpy(np.concatenate([grid, 1.0 - grid], axis=-1))

"""Adds soft positional embedding with learnable projection."""
class SoftPositionEmbed(nn.Module):
    def __init__(self, hidden_size, resolution):
        """Builds the soft position embedding layer.
        Args:
        hidden_size: Size of input feature dimension.
        resolution: Tuple of integers specifying width and height of grid.
        """
        super().__init__()
        self.embedding = nn.Linear(4, hidden_size, bias=True)
        self.register_buffer("grid", build_grid(resolution))
    def forward(self, inputs):
        grid = self.embedding(self.grid)
        return inputs + grid

In [15]:
"""
Encoder + Decoder Parts
"""

class Decoder(nn.Module):
    def __init__(self, hid_dim, resolution):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(hid_dim, hid_dim, 5, stride=(2, 2), padding=3)
        self.conv2 = nn.ConvTranspose2d(hid_dim, hid_dim, 5, stride=(2, 2), padding=2)
        self.conv3 = nn.ConvTranspose2d(hid_dim, hid_dim, 5, stride=(2, 2), padding=2)
        self.conv4 = nn.ConvTranspose2d(hid_dim, hid_dim, 5, stride=(2, 2), padding=3)
        self.conv5 = nn.ConvTranspose2d(hid_dim, hid_dim, 5, stride=(1, 1), padding=1)
        self.conv6 = nn.ConvTranspose2d(hid_dim, 4, 3, stride=(1, 1), padding=1)
        self.decoder_initial_size = (8, 8)
        self.decoder_pos = SoftPositionEmbed(hid_dim, self.decoder_initial_size)
        self.resolution = resolution

    def forward(self, x):
        x = self.decoder_pos(x)
        x = x.permute(0,3,1,2)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.pad(x, (4,4,4,4))
        x = self.conv3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = F.relu(x)
        x = self.conv5(x)
        x = F.relu(x)
        x = self.conv6(x)
        x = x[:,:,:self.resolution[0], :self.resolution[1]]
        x = x.permute(0,2,3,1)
        return x

class Encoder(nn.Module):
    def __init__(self, resolution, hid_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(3, hid_dim, 5, padding = 2)
        self.conv2 = nn.Conv2d(hid_dim, hid_dim, 5, padding = 2)
        self.conv3 = nn.Conv2d(hid_dim, hid_dim, 5, padding = 2)
        self.conv4 = nn.Conv2d(hid_dim, hid_dim, 5, padding = 2)
        self.encoder_pos = SoftPositionEmbed(hid_dim, resolution)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = F.relu(x)
        x = x.permute(0,2,3,1)
        x = self.encoder_pos(x)
        x = torch.flatten(x, 1, 2)
        return x

In [16]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
from dataset import CLEVR

train_set = CLEVR('train',priors='rel_words')
test_set = CLEVR('val',priors='rel_words') #consider this pseudo test set since labels for test not provided

train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=cfg["batch_size"],
                        shuffle=True, num_workers=cfg["num_workers"])

test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=cfg["batch_size"],
                        shuffle=True, num_workers=cfg["num_workers"])


In [20]:
batch = iter(train_dataloader).next()
batch['objects'].shape

torch.Size([128, 11, 5, 3, 50])

In [17]:
model = LitSlotAttention()
trainer = pl.Trainer(limit_train_batches=128, max_epochs=200, devices=4, accelerator='gpu', strategy='dp')
trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/felix/vrgym_SA/lightning/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name       | Type          | Params
---------------------------------------------
0 | encoder    | Encoder       | 312 K 
1 | decoder    | Decoder       | 514 K 
2 | SA         | SlotAttention | 71.0 K
3 | fc1        | Linear        | 4.2 K 
4 | fc2        | Linear        | 4.2 K 
5 | layer_norm | LayerNorm     | 2.1 M 
---------------------------------------------
3.0 M     Trainable params
0         Non-trainable params
3.0 M     Total params
12.016    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:01<?, ?it/s]

ValueError: Caught ValueError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/felix/SA-geo/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/felix/SA-geo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/felix/SA-geo/lib/python3.8/site-packages/pytorch_lightning/overrides/data_parallel.py", line 64, in forward
    output = super().forward(*inputs, **kwargs)
  File "/home/felix/SA-geo/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 93, in forward
    return self.module.validation_step(*inputs, **kwargs)
  File "/tmp/ipykernel_448395/3839951270.py", line 36, in validation_step
    recon_combined, recons, masks, slots = self.forward(data)
  File "/tmp/ipykernel_448395/3839951270.py", line 58, in forward
    slots = self.SA(x, embeddings, mask)
  File "/home/felix/SA-geo/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/tmp/ipykernel_448395/1963388713.py", line 33, in forward
    _, n_e, _ = embeddings.shape
ValueError: too many values to unpack (expected 3)
