In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import torch.optim as optim
import os
import h5py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_epochs = 1
from torch.utils.data import Dataset, DataLoader
dir = 'A2_dataset/'
batch_size = 64

In [2]:
class HDF5Dataset(Dataset):
    def __init__(self, hdf5_file):
        self.hdf5_file = hdf5_file
        
        # Open HDF5 file
        self.hdf5_handle = h5py.File(hdf5_file, 'r')
        self.images = self.hdf5_handle['images']
        self.masks = self.hdf5_handle['masks']
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        return (torch.tensor(self.images[idx], dtype=torch.float32), torch.tensor(self.masks[idx], dtype=torch.float32))

In [3]:
train_dataset = HDF5Dataset(dir + 'train_dataset.h5')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataset = HDF5Dataset(dir + 'val_dataset.h5')
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [4]:
class SlotAttention(nn.Module):
    def __init__(self, k, d_common=64, n_iter_train=3,n_iter_test=5, d_slot=64, d_inputs=64, hid_dim=128):
        super(SlotAttention, self).__init__()
        self.k = k
        self.d_common = d_common
        self.n_iter_train = n_iter_train
        self.n_iter_test = n_iter_test
        self.d_slot = d_slot
        self.d_inputs = d_inputs

        self.fc_q = nn.Linear(d_slot, d_common)
        self.fc_k = nn.Linear(d_inputs, d_common)
        self.fc_v = nn.Linear(d_inputs, d_common)

        self.gru = nn.GRUCell(d_common, d_slot)
        self.mlp = nn.Sequential(
            nn.Linear(d_slot, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, d_slot)
        )

        self.softmax = nn.Softmax(dim=2)
        self.mu = nn.Parameter(torch.randn(1, 1,d_common))
        self.sigma = nn.Parameter(torch.rand(1,1, d_common))

    
    def forward(self, inputs):
        # inputs: (batch_size, n_inputs, d_inputs)
        # slots: (batch_size, n_slots, d_slot)
        if self.training:
            n_iter = self.n_iter_train
        else:
            n_iter = self.n_iter_test
        batch_size, n_inputs, d_inputs = inputs.size()
        mu = self.mu.expand(batch_size, self.k, -1)
        sigma = self.sigma.expand(batch_size, self.k, -1)
        slots = torch.normal(mu, sigma).to(device)
        inputs = nn.LayerNorm(d_inputs).to(device)(inputs)
        k = self.fc_k(inputs)               # (batch_size, n_inputs, d_common)
        v = self.fc_v(inputs)               # (batch_size, n_inputs, d_common)
        for i in range(n_iter):
            q = self.fc_q(nn.LayerNorm(self.d_slot).to(device)(slots))                # (batch_size, n_slots, d_common)

            attn = torch.bmm(k, q.permute(0, 2, 1)) / np.sqrt(self.d_common)            # (batch_size, n_inputs, n_slots)
            attn = self.softmax(attn) +  1e-8                                           # (batch_size, n_inputs, n_slots)
            attn = attn / attn.sum(dim=1, keepdim=True)                                 # (batch_size, n_inputs, n_slots)
            attn = attn.permute(0,2,1)
            updates = torch.einsum('bjd,bij->bid', v, attn)                             # (batch_size, n_slots, d_common)


            slots = self.gru(updates.reshape(-1,self.d_common), slots.reshape(-1, self.d_slot)).reshape(batch_size, self.k, self.d_slot)
            slots = nn.LayerNorm(self.d_slot).to(device)(slots)
            slots = slots + self.mlp(slots)
        
        return slots



In [5]:
class PositionalEmbeddings(nn.Module):
    def __init__(self, H, W, hid_dim=64):
        super(PositionalEmbeddings, self).__init__()
        self.H = H
        self.W = W
        self.hid_dim = hid_dim
        self.project = nn.Linear(4, hid_dim)
    
    def construct_grid(self, H, W):
        x = torch.linspace(0, 1, W).unsqueeze(0).repeat(H, 1)
        y = torch.linspace(0, 1, H).unsqueeze(1).repeat(1, W)
        return torch.stack([x, 1-x, y, 1-y], dim=2)    # (H, W, 4)


    def forward(self, inputs):
        grid = self.construct_grid(self.H, self.W).to(device)  # (H, W, 4)
        grid = self.project(grid)
        return inputs + grid.unsqueeze(0).expand(inputs.size(0), self.H, self.W, self.hid_dim)

In [6]:
class CNNEncoder(nn.Module):
    def __init__(self, hid_dim=64):
        super(CNNEncoder, self).__init__()
        self.conv1 = nn.Conv2d(3, hid_dim, 5, padding=2)                    
        self.conv2 = nn.Conv2d(hid_dim, hid_dim, 5, padding=2)
        self.conv3 = nn.Conv2d(hid_dim, hid_dim, 5, padding=2)
        self.conv4 = nn.Conv2d(hid_dim, hid_dim, 5, padding=2)

        self.positionalEmb = PositionalEmbeddings(128, 128, hid_dim)
        self.relu = nn.ReLU()

        self.fc1 = nn.Linear(hid_dim, hid_dim)
        self.fc2 = nn.Linear(hid_dim, hid_dim)  


    def forward(self, inputs):
        inputs = self.conv1(inputs)
        inputs = self.relu(inputs)
        inputs = self.conv2(inputs)
        inputs = self.relu(inputs)
        inputs = self.conv3(inputs)
        inputs = self.relu(inputs)
        inputs = self.conv4(inputs)
        inputs = self.relu(inputs)
        
        inputs = self.positionalEmb(inputs.permute(0, 2, 3, 1))
        inputs = inputs.flatten(1, 2)
        inputs = nn.LayerNorm(inputs.size()[1:])(inputs) 
        inputs = self.fc1(inputs)
        inputs = self.relu(inputs)
        inputs = self.fc2(inputs)
        return inputs
           

In [7]:
class deconvDecoder(nn.Module):
    def __init__(self, hid_dim=64):
        super(deconvDecoder, self).__init__()
        self.hid_dim = hid_dim
        self.deconv1 = nn.ConvTranspose2d(hid_dim, hid_dim, 5, padding=2, output_padding=1, stride=2)
        self.deconv2 = nn.ConvTranspose2d(hid_dim, hid_dim, 5, padding=2, output_padding=1, stride=2)
        self.deconv3 = nn.ConvTranspose2d(hid_dim, hid_dim, 5, padding=2, output_padding=1, stride=2)
        self.deconv4 = nn.ConvTranspose2d(hid_dim, hid_dim, 5, padding=2, output_padding=1, stride=2)
        self.deconv5 = nn.ConvTranspose2d(hid_dim, hid_dim, 5, padding=2, output_padding=0, stride=1)
        self.deconv6 = nn.ConvTranspose2d(hid_dim, 4, 3, padding=1, output_padding=0, stride=1)

        self.relu = nn.ReLU()
        self.positionalEmb = PositionalEmbeddings(8, 8, hid_dim)

    def forward(self, slots):
        # slots: (batch_size, n_slots, d_slot)
        b, k, d = slots.size()
        slots = slots.unsqueeze(2).unsqueeze(3).expand(b, k, 8, 8, d)
        slots = slots.reshape(b*k, 8, 8, d)
        slots = self.positionalEmb(slots)
        slots = slots.permute(0, 3, 1, 2)
        slots = self.deconv1(slots)
        slots = self.relu(slots)
        slots = self.deconv2(slots)
        slots = self.relu(slots)
        slots = self.deconv3(slots)
        slots = self.relu(slots)
        slots = self.deconv4(slots)
        slots = self.relu(slots)
        slots = self.deconv5(slots)
        slots = self.relu(slots)
        slots = self.deconv6(slots)                 # (batch_size * n_slots, 4, 128, 128)

        slots = slots.reshape(b, k, 4, 128, 128)
        slots = slots.permute(0, 1, 3, 4, 2)
        contents, masks = slots.split([3, 1], dim=4)        # (batch_size, n_slots, 128, 128, 3)
        masks = nn.Softmax(dim=1)(masks)                    # (batch_size, n_slots, 128, 128, 1)
        img = (contents * masks).sum(dim=1)
        img = img.permute(0, 3, 1, 2)

        return img

        

In [8]:
class SlotAttentionModel(nn.Module):
    def __init__(self, k, d_common=64, n_iter_train=3, n_iter_test=5, d_slot=64, d_inputs=64, hid_dim=64):
        super(SlotAttentionModel, self).__init__()
        self.encoder = CNNEncoder(hid_dim)
        self.slotAttention = SlotAttention(k, d_common, n_iter_train, n_iter_test, d_slot, d_inputs, hid_dim)
        self.decoder = deconvDecoder(hid_dim)
    
    def forward(self, inputs):
        features = self.encoder(inputs)
        slots = self.slotAttention(features)
        img = self.decoder(slots)
        return img

In [9]:
SAM = SlotAttentionModel(4, 64, 3, 5, 64, 64, 64).to(device)
criterion = nn.MSELoss()
init_lr = 0.0004
optimizer = optim.Adam(SAM.parameters(), lr=init_lr)


warmup_iters = 10000
decay_steps = 10000
decay_rate = 0.5


  from .autonotebook import tqdm as notebook_tqdm


In [10]:
torch.autograd.set_detect_anomaly(True)
def train(model, criterion, optimizer, train_loader, num_epochs):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for i, (images, masks) in enumerate(train_loader):
            images = images.to(device)
            masks = masks.to(device)
            if i < warmup_iters:
                learning_rate = init_lr*(i/warmup_iters)
            else:
                learning_rate = init_lr
            learning_rate = learning_rate * (decay_rate ** (i/decay_steps))
            optimizer.param_groups[0]['lr'] = learning_rate
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, images)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 10 == 0:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 10))
                running_loss = 0.0
    print('Finished Training')

In [11]:
train(SAM, criterion, optimizer, train_loader, n_epochs)

[1,     1] loss: 2480.615
Finished Training
