In [None]:
# For colab

#!pip install dgl-cu100
#!pip install scipy --upgrade

In [None]:
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import dgl
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

### Homework 4: Attention Mechanisms

This is a key exercise for learning transformers, but in this case we will do it with graph neural networks.

The goals are:

1. Learn about heterogeneous graphs in DGL (graphs with multiple types of nodes and edges)
2. Implement key - query attention
3. Learn about slot attention and permutation invariant loss

<b> The task is object detection around a cloud of points. </b>

In [None]:
# Re-import modified modules automatically before each cell executes, so edits
# to the local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Download the dataset archive only when it is not already in the working directory.
if not os.path.exists('Dataset.zip'):
    !wget https://www.dropbox.com/s/qrivkcb50yliez9/Dataset.zip

In [None]:
# Extract the archive only when the 'Dataset' directory does not exist yet.
if not os.path.exists('Dataset'):
    !unzip Dataset.zip

In [None]:
# Already implemented

from dataloader import RandomShapeDataset, collate_graphs,plot_graph

In [None]:
# Load the training and validation splits from the extracted 'Dataset' directory.
dataset = RandomShapeDataset('Dataset/training.bin')
validation_ds = RandomShapeDataset('Dataset/validation.bin')

Your input information is a set of points and their positions. 

You want to identify how many clusters there are and for each cluster you want to draw a box around it!

In [None]:
# Preview of the training target: 16 randomly picked validation graphs on a 4x4 grid.
fig, ax = plt.subplots(4, 4, figsize=(8, 8), dpi=100)

# ax.flat walks the grid row by row, i.e. in the same order as nested row/column loops.
for ax_i in ax.flat:
    g = validation_ds[np.random.randint(len(validation_ds))]
    plot_graph(g, ax_i, size=0.2)

plt.tight_layout()
plt.show()

In [None]:
# Display the first training sample (the cell output shows how one item of the dataset looks).
dataset[0]

#### How is this represented on our graph?

1. Each node type (objects, points and predicted objects) stores a dictionary of features
2. You have different edges (the points to predicted object, the predicted objects to target)

<img src="structure.jpeg" width="800" height="400">

The points are the blue cloud, the objects are your target. Each bounding box is represented by four numbers (2 coordinates for the center, height and width of the box). 

The predicted objects are elements of the graph where we store our predictions, which are afterwards compared to the target objects.

In [None]:
from torch.utils.data import Dataset, DataLoader

# Batches of 300 graphs; collate_graphs (from dataloader.py) merges the samples of a batch.
# Only the training loader shuffles its samples.
data_loader = DataLoader(dataset, batch_size=300, shuffle=True,
                         collate_fn=collate_graphs)

valid_data_loader = DataLoader(validation_ds, batch_size=300, shuffle=False,
                         collate_fn=collate_graphs)

In [None]:
# Grab a single batch to explore. next(iter(...)) states the intent directly and
# raises StopIteration right here if the loader is empty, instead of silently
# leaving `batched_g` undefined for the following cells.
batched_g = next(iter(data_loader))

In [None]:
# Display the batched graph (the cell output is its summary representation).
batched_g

In [None]:
# Number of clusters in each graph of the batch (you need to specify which node type you want to count)
batched_g.batch_num_nodes('objects')

In [None]:
# Since this is a heterogeneous graph, we need to specify which node type we are dealing with.
# Before, we were using batched_g.ndata[...]; that is not possible here because the graph is heterogeneous!
batched_g.nodes['points'].data

## The model

It is based on Object-Centric Learning with Slot Attention, https://arxiv.org/abs/2006.15055.

We start with an array of points. First, we want to pass these through a DeepSet (like exercise 3, part 1). This will produce a global representation for the graph and hidden representation for the nodes.

On one side, we use this global representation to predict how many objects (clusters) there are in the cloud of points (a classification problem: 2/3/4 objects). On the other side, we ignore the set size prediction and we CHEAT during the training: we initialize the number of predicted objects to be the same as the number of real objects. 

Then we have a slot attention part (figure below, with key, value and query). The key and value come from the points, while the query comes from the objects. We take the dot product of the key and the query, which gives the weights for a weighted sum of the values for each one of the objects. We put all of this through a GRU cell (recurrent network). 

This creates an updated hidden representation of the predicted objects that captures more features about our data.

The last part is a simple FC network to predict the box boundaries (center, width and height).

<img src="model_1.jpeg" width="800" height="400">

<img src="model_2.jpeg" width="800" height="400">

After running the DeepSet you have to:

1. Create the size prediction
2. Create the object prediction (center, width and height) in case of training

In [None]:
from model import Net

In [None]:
# Instantiate the model with its default constructor arguments.
net = Net()

In [None]:
# Grab a single batch to try the model on. next(iter(...)) states the intent
# directly and raises StopIteration right here if the loader is empty, instead of
# silently leaving `batched_g` undefined for the following cells.
batched_g = next(iter(data_loader))

In [None]:
# Put the network in training mode; the trailing ';' suppresses the cell output.
net.train();

In [None]:
# The forward pass returns the graph holding the predictions and the prediction of how many clusters there are.
predicted_g, size_pred = net(batched_g)

In [None]:
# Display the returned graph together with the shape of the size prediction.
predicted_g, size_pred.shape

In [None]:
# Inspect the features stored on the target 'objects' nodes.
# (The original `.datata['']` was a typo that raised AttributeError and stopped "Run All".)
predicted_g.nodes['objects'].data

In [None]:
# Shape of the 'properties' tensor stored on the predicted-object nodes.
predicted_g.nodes['predicted objects'].data['properties'].shape

## Permutation invariant loss

We need to compute two different losses in order to train our network:

1. The loss for the object boundaries

    https://en.wikipedia.org/wiki/Hungarian_algorithm

    The loss computation has to take into account the fact that there is no order to the output. I can predict the objects boundaries in whatever order I want, and the loss should not be affected by this.
    

2. The loss for the size prediction (a simple CrossEntropyLoss)

In [None]:
# Already implemented.. have a look!

from loss import Set2SetLoss

In [None]:
# Loss for the predicted objects (implemented in loss.py).
loss_func = Set2SetLoss()

In [None]:
# The loss takes only the graph as input — presumably it reads the predictions that
# net(batched_g) stored on it; confirm in loss.py.
loss_func(batched_g)

In [None]:
# Loss for the size prediction, a classical classification task

size_loss_func = nn.CrossEntropyLoss()

In [None]:
# The target is the number of objects per graph shifted by 2, so that it can be used
# as a class index starting at 0 (the evaluation cells add the 2 back to the argmax).
size_loss_func(size_pred, batched_g.batch_num_nodes('objects')-2)

### Training the objects prediction

The idea is to first train the bounding boxes, since they do not care about the size prediction (we give it to the network). 

Afterwards, we will freeze all these weights and train only the size prediction!

In [None]:
# Start from a freshly initialized network for the actual training.
net = Net()

In [None]:
# Re-create the datasets and loaders so this section can be run on its own.
dataset = RandomShapeDataset('Dataset/training.bin')
validation_ds = RandomShapeDataset('Dataset/validation.bin')

# Batches of 300 graphs; only the training loader shuffles its samples.
data_loader = DataLoader(dataset, batch_size=300, shuffle=True,
                         collate_fn=collate_graphs)

valid_data_loader = DataLoader(validation_ds, batch_size=300, shuffle=False,
                         collate_fn=collate_graphs)

In [None]:
# Adam over all the parameters of the network, learning rate 5e-4.
optimizer = optim.Adam(net.parameters(), lr=0.0005) 

In [None]:
# Move the network to the GPU when one is available.
if torch.cuda.is_available():
    net.cuda()

In [None]:
# Per-epoch loss history, filled by the training loop below.
training_loss_vs_epoch = []
validation_loss_vs_epoch = []

In [None]:
# I run it on colab

# Trains the bounding-box prediction. The whole cell is skipped without a GPU.
if torch.cuda.is_available():

    n_epochs = 400 #it takes a while.. like 2 hours!
    device = torch.device('cuda')  # CUDA is guaranteed inside this branch

    for epoch in range(n_epochs):

        # Report the losses of the previous epoch (nothing to report on the first pass).
        if validation_loss_vs_epoch:
            print(epoch, 'train loss',training_loss_vs_epoch[-1],'validation loss',validation_loss_vs_epoch[-1])

        # ---- training pass ----
        net.train() # put the net into "training mode"

        running_loss = 0
        for batched_g in tqdm(data_loader):
            batched_g = batched_g.to(device)

            optimizer.zero_grad()

            predicted_g, size_pred = net(batched_g)
            loss = loss_func(batched_g)
            running_loss += loss.item()

            loss.backward()
            optimizer.step()

        # Average over the number of batches served by the loader.
        training_loss_vs_epoch.append(running_loss / len(data_loader))

        # ---- validation pass (no gradients, target set size given to the network) ----
        net.eval()
        running_loss = 0
        with torch.no_grad():
            for batched_g in tqdm(valid_data_loader):
                batched_g = batched_g.to(device)

                predicted_g, size_pred = net(batched_g, use_target_size=True)
                loss = loss_func(batched_g)
                running_loss += loss.item()

        validation_loss_vs_epoch.append(running_loss / len(valid_data_loader))

        # Keep a checkpoint whenever the newest validation loss beats all the earlier ones
        # (the very first epoch is always saved).
        *earlier_losses, latest_loss = validation_loss_vs_epoch
        if not earlier_losses or np.amin(earlier_losses) > latest_loss:
            torch.save(net.state_dict(), 'trained_model.pt')

In [None]:
# The loss history is only filled when the training ran (i.e. on a GPU machine).
if torch.cuda.is_available():

    # Labelled curves, so the figure can be read without looking at the code.
    fig, ax = plt.subplots()
    ax.plot(training_loss_vs_epoch, label='training')
    ax.plot(validation_loss_vs_epoch, label='validation')
    ax.set_xlabel('epoch')
    ax.set_ylabel('Set2Set loss')
    ax.set_title('Object prediction training')
    ax.legend()
    plt.show()

In [None]:
#!cp trained_model.pt trained_model_objects.pt #making a copy in case something goes wrong

In [None]:
# Back to the CPU, then restore the saved weights; map_location='cpu' lets a
# checkpoint written on a GPU machine be loaded without CUDA.
net.cpu()
net.load_state_dict(torch.load('trained_model.pt',map_location='cpu'))

### Results of first training 

I check the results without training the size prediction.

In [None]:
# Predict the set size of every validation graph.
net.eval()
net.cpu()
predicted_sizes = []

# Pure inference: no_grad avoids building an autograd graph for every validation batch.
with torch.no_grad():
    for batched_g in valid_data_loader:
        predicted_g, size_pred = net(batched_g)

        # Most likely class for each graph of the batch.
        predicted_sizes.extend(torch.argmax(size_pred, dim=1).cpu().numpy())

# Class index 0 corresponds to 2 objects, hence the shift.
predicted_sizes = np.array(predicted_sizes) + 2

In [None]:
# True number of objects of every validation graph, in dataset order
# (the validation loader does not shuffle, so this lines up with predicted_sizes).
target_sizes = np.array([validation_ds[i].num_nodes('objects') for i in range(len(validation_ds))])

In [None]:
# We are not training the size prediction yet
# In this plot you will see how the predicted labels differ from the true labels

size_classes = [2, 3, 4]

# Fail loudly instead of silently dropping predictions outside the expected classes.
assert set(np.unique(predicted_sizes)) <= set(size_classes), "unexpected predicted set size"

# Pin the classes explicitly: otherwise the matrix size is inferred from the values
# present in the data and may not match the three display labels (e.g. when a class
# never shows up), which makes the display fail.
cm = confusion_matrix(target_sizes, predicted_sizes, labels=size_classes)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[str(c) for c in size_classes])
disp.plot()
disp.ax_.set_title('Set size prediction (size predictor not trained yet)');

In [None]:
# Index of the validation graph to inspect; 90 is an arbitrary choice (you can change this)

idxValidation = 90

In [None]:
# Run the network on the selected validation graph, on the CPU.
g = validation_ds[idxValidation].cpu()

net.eval()
# No use_target_size argument here, unlike the validation loop — presumably the number of
# predicted objects then follows the size prediction; confirm in model.py.
predicted_g,size_pred = net(g) 

In [None]:
# Number of objects the network predicted for this graph.
predicted_g.num_nodes('predicted objects')

In [None]:
# --- point cloud ---
points_xy = g.nodes['points'].data['xy']
x = points_xy[:,0].data.numpy()
y = points_xy[:,1].data.numpy()

# --- target boxes ---
target_data = g.nodes['objects'].data
object_centers = target_data['centers'].data.numpy()
object_width = target_data['width'].data.numpy()
object_height = target_data['height'].data.numpy()

# --- predicted boxes ---
# Columns are read here as (height, width, center x, center y).
# NOTE(review): the text above lists the center first — confirm this column order against model.py / loss.py.
properties = predicted_g.nodes['predicted objects'].data['properties']
predicted_heights = properties[:,0].data.numpy()
predicted_widths = properties[:,1].data.numpy()
predicted_centers = properties[:,[2,3]].data.numpy()

# --- one attention weight per points -> predicted object edge ---
attn_weights = predicted_g.edges['points_to_object'].data['attention weights'].cpu().data.numpy()

In [None]:
# Node counts used to build the attention plots; the cell displays the number of predicted objects.
n_points = predicted_g.num_nodes('points')
n_objects = predicted_g.num_nodes('predicted objects')
n_objects

In [None]:
# Source (point) and destination (predicted object) index of every attention edge.
edge_src, edge_dst = predicted_g.edges(etype='points_to_object')
estart = edge_src.cpu().data.numpy()
eend = edge_dst.cpu().data.numpy()

# weight_dict[object index][point index] -> attention weight of the edge joining them
weight_dict = {i:{} for i in range(n_objects)}
for es, ee, edge_weight in zip(estart, eend, attn_weights):
    weight_dict[ee][es] = edge_weight

In [None]:
# Top row: target (left) and prediction (right). Bottom row: one attention map per predicted object.
# NOTE: the layout needs at least two columns, i.e. at least two predicted objects.
fig,ax = plt.subplots(2,n_objects,figsize=(3*n_objects,6),dpi=100)
target_ax, pred_ax = ax[0][0], ax[0][1]

# Point cloud on both top panels. No `cmap` here: the colour is a fixed name, so a
# colormap would be ignored and matplotlib would only emit a warning about it.
for top_ax in (target_ax, pred_ax):
    top_ax.scatter(x,y,c='cornflowerblue',s=3)
    top_ax.set_xlim(-1,1)
    top_ax.set_ylim(-1,1)

target_ax.set_title('target')
pred_ax.set_title('prediction')

# Only the first two panels of the top row are used; hide the remaining empty frames.
for unused_ax in ax[0][2:]:
    unused_ax.axis('off')

target_ax.scatter(object_centers[:,0],object_centers[:,1],c='r',marker='o',s=30,ec='k')
pred_ax.scatter(predicted_centers[:,0],predicted_centers[:,1],c='r',s=30,ec='k')

# Rectangle takes the lower-left corner, so shift the center by half the width/height.
for i in range(len(object_height)):

    bounding_box = patches.Rectangle((object_centers[i][0]-object_width[i]/2, object_centers[i][1]-object_height[i]/2),
                             object_width[i], object_height[i], linewidth=1, edgecolor='r', facecolor='none')

    target_ax.add_patch(bounding_box)

for i in range(len(predicted_centers)):

    bounding_box = patches.Rectangle((predicted_centers[i][0]-predicted_widths[i]/2,
                                          predicted_centers[i][1]-predicted_heights[i]/2),
                             predicted_widths[i], predicted_heights[i], linewidth=1,
                                         edgecolor='darkgreen', facecolor='none')

    pred_ax.add_patch(bounding_box)

for object_idx in range(n_objects):
    attn_ax = ax[1][object_idx]

    # Attention weight of every point for this object, ordered by point index.
    object_attn_weights = [weight_dict[object_idx][point_i] for point_i in range(n_points)]

    # Softmax over the points, so that marker sizes/colours are comparable between objects.
    object_attn_weights = torch.softmax(torch.tensor(object_attn_weights),dim=0).data.numpy()

    # Faint full cloud underneath, attention-scaled markers on top.
    attn_ax.scatter(x,y,s=0.2,alpha=0.2)
    attn_ax.scatter(x,y,s=300.0*object_attn_weights,alpha=0.8,c=object_attn_weights,cmap='Reds')

    attn_ax.set_xlim(-1,1)
    attn_ax.set_ylim(-1,1)
    attn_ax.set_title('attention of object %d' % object_idx)

    bounding_box = patches.Rectangle((predicted_centers[object_idx][0]-predicted_widths[object_idx]/2,
                                          predicted_centers[object_idx][1]-predicted_heights[object_idx]/2),
                             predicted_widths[object_idx], predicted_heights[object_idx], linewidth=1,
                                         edgecolor='r', facecolor='none')

    attn_ax.scatter(predicted_centers[:,0][object_idx],
                           predicted_centers[:,1][object_idx],c='r',marker='o',s=30,ec='k')
    attn_ax.add_patch(bounding_box)

plt.tight_layout()
plt.show()

If you did everything correctly, the model will be able to draw correct boxes around the different clusters. Notice that, since the size prediction has not been trained yet, the number of clusters found may still be wrong.

The top left plot corresponds to our target, while the top right plot to our prediction.
The bottom plots reflect the slot attention mechanism. Each of the predicted objects should pay attention to the part of points within the box. 

### Training the size prediction

Now I can freeze everything and only focus on training the size prediction.

In [None]:
# Start a fresh loss history for the size-prediction training.
training_loss_vs_epoch = []
validation_loss_vs_epoch = []

In [None]:
# Re-create the datasets and loaders so this section can be run on its own.
dataset = RandomShapeDataset('Dataset/training.bin')
validation_ds = RandomShapeDataset('Dataset/validation.bin')

# Batches of 300 graphs; only the training loader shuffles its samples.
data_loader = DataLoader(dataset, batch_size=300, shuffle=True,
                         collate_fn=collate_graphs)

valid_data_loader = DataLoader(validation_ds, batch_size=300, shuffle=False,
                         collate_fn=collate_graphs)

In [None]:
# I loop over the network, and unless is 'size_predictor', I freeze the weights

# Freeze every parameter whose name does not contain 'size_predictor',
# so that only the size predictor remains trainable.
for p_name, p in net.named_parameters():
    if 'size_predictor' in p_name:
        continue
    p.requires_grad = False

In [None]:
# Adam over the size predictor's parameters only, learning rate 1e-3.
optimizer = optim.Adam(net.size_predictor.parameters(), lr=0.001) 

In [None]:
# Move the network (back) to the GPU when one is available.
if torch.cuda.is_available():
    net.cuda()

In [None]:
# Trains the set size prediction. The whole cell is skipped without a GPU.
if torch.cuda.is_available():

    n_epochs = 180
    device = torch.device('cuda')  # CUDA is guaranteed inside this branch

    for epoch in range(n_epochs):

        # Report the losses of the previous epoch (nothing to report on the first pass).
        if validation_loss_vs_epoch:
            print(epoch, 'train loss',training_loss_vs_epoch[-1],'validation loss',validation_loss_vs_epoch[-1])

        # ---- training pass ----
        net.train() # put the net into "training mode"

        running_loss = 0
        for batched_g in tqdm(data_loader):
            batched_g = batched_g.to(device)

            optimizer.zero_grad()

            predicted_g, size_pred = net(batched_g)

            # Class index = number of objects shifted by 2.
            size_target = batched_g.batch_num_nodes('objects') - 2
            loss = size_loss_func(size_pred, size_target)
            running_loss += loss.item()

            loss.backward()
            optimizer.step()

        # Average over the number of batches served by the loader.
        training_loss_vs_epoch.append(running_loss / len(data_loader))

        # ---- validation pass (no gradients) ----
        net.eval()
        running_loss = 0
        with torch.no_grad():
            for batched_g in tqdm(valid_data_loader):
                batched_g = batched_g.to(device)

                predicted_g, size_pred = net(batched_g, use_target_size=True)

                size_target = batched_g.batch_num_nodes('objects') - 2
                loss = size_loss_func(size_pred, size_target)
                running_loss += loss.item()

        validation_loss_vs_epoch.append(running_loss / len(valid_data_loader))

        # Keep a checkpoint whenever the newest validation loss beats all the earlier ones
        # (the very first epoch is always saved).
        *earlier_losses, latest_loss = validation_loss_vs_epoch
        if not earlier_losses or np.amin(earlier_losses) > latest_loss:
            torch.save(net.state_dict(), 'trained_model.pt')

In [None]:
# The loss history is only filled when the training ran (i.e. on a GPU machine).
if torch.cuda.is_available():

    # Labelled curves, so the figure can be read without looking at the code.
    fig, ax = plt.subplots()
    ax.plot(training_loss_vs_epoch, label='training')
    ax.plot(validation_loss_vs_epoch, label='validation')
    ax.set_xlabel('epoch')
    ax.set_ylabel('cross-entropy loss')
    ax.set_title('Size prediction training')
    ax.legend()
    plt.show()

In [None]:
# Back to the CPU, then restore the saved weights; map_location='cpu' lets a
# checkpoint written on a GPU machine be loaded without CUDA.
net.cpu()
net.load_state_dict(torch.load('trained_model.pt',map_location='cpu'))

### Results with everything trained

In [None]:
# Predict the set size of every validation graph with the fully trained network.
net.eval()
net.cpu()
predicted_sizes = []

# Pure inference: no_grad avoids building an autograd graph for every validation batch.
with torch.no_grad():
    for batched_g in valid_data_loader:
        predicted_g, size_pred = net(batched_g)

        # Most likely class for each graph of the batch.
        predicted_sizes.extend(torch.argmax(size_pred, dim=1).cpu().numpy())

# Class index 0 corresponds to 2 objects, hence the shift.
predicted_sizes = np.array(predicted_sizes) + 2

In [None]:
# True number of objects of every validation graph, in dataset order
# (the validation loader does not shuffle, so this lines up with predicted_sizes).
target_sizes = np.array([validation_ds[i].num_nodes('objects') for i in range(len(validation_ds))])

In [None]:
size_classes = [2, 3, 4]

# Fail loudly instead of silently dropping predictions outside the expected classes.
assert set(np.unique(predicted_sizes)) <= set(size_classes), "unexpected predicted set size"

# Pin the classes explicitly: otherwise the matrix size is inferred from the values
# present in the data and may not match the three display labels (e.g. when a class
# never shows up), which makes the display fail.
cm = confusion_matrix(target_sizes, predicted_sizes, labels=size_classes)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[str(c) for c in size_classes])
disp.plot()
disp.ax_.set_title('Set size prediction (after training the size predictor)');

In [None]:
# Run the fully trained network on the same validation graph as before, on the CPU.
g = validation_ds[idxValidation].cpu()

net.eval()
# No use_target_size argument here, unlike the validation loop — presumably the number of
# predicted objects then follows the size prediction; confirm in model.py.
predicted_g, size_pred = net(g) 

In [None]:
# Number of objects the network predicted for this graph.
predicted_g.num_nodes('predicted objects')

In [None]:
# --- point cloud ---
points_xy = g.nodes['points'].data['xy']
x = points_xy[:,0].data.numpy()
y = points_xy[:,1].data.numpy()

# --- target boxes ---
target_data = g.nodes['objects'].data
object_centers = target_data['centers'].data.numpy()
object_width = target_data['width'].data.numpy()
object_height = target_data['height'].data.numpy()

# --- predicted boxes ---
# Columns are read here as (height, width, center x, center y).
# NOTE(review): the text above lists the center first — confirm this column order against model.py / loss.py.
properties = predicted_g.nodes['predicted objects'].data['properties']
predicted_heights = properties[:,0].data.numpy()
predicted_widths = properties[:,1].data.numpy()
predicted_centers = properties[:,[2,3]].data.numpy()

# --- one attention weight per points -> predicted object edge ---
attn_weights = predicted_g.edges['points_to_object'].data['attention weights'].cpu().data.numpy()

In [None]:
# Node counts used to build the attention plots; the cell displays the number of predicted objects.
n_points = predicted_g.num_nodes('points')
n_objects = predicted_g.num_nodes('predicted objects')
n_objects

In [None]:
# Source (point) and destination (predicted object) index of every attention edge.
edge_src, edge_dst = predicted_g.edges(etype='points_to_object')
estart = edge_src.cpu().data.numpy()
eend = edge_dst.cpu().data.numpy()

# weight_dict[object index][point index] -> attention weight of the edge joining them
weight_dict = {i:{} for i in range(n_objects)}
for es, ee, edge_weight in zip(estart, eend, attn_weights):
    weight_dict[ee][es] = edge_weight

In [None]:
# Top row: target (left) and prediction (right). Bottom row: one attention map per predicted object.
# NOTE: the layout needs at least two columns, i.e. at least two predicted objects.
fig,ax = plt.subplots(2,n_objects,figsize=(3*n_objects,6),dpi=100)
target_ax, pred_ax = ax[0][0], ax[0][1]

# Point cloud on both top panels. No `cmap` here: the colour is a fixed name, so a
# colormap would be ignored and matplotlib would only emit a warning about it.
for top_ax in (target_ax, pred_ax):
    top_ax.scatter(x,y,c='cornflowerblue',s=3)
    top_ax.set_xlim(-1,1)
    top_ax.set_ylim(-1,1)

target_ax.set_title('target')
pred_ax.set_title('prediction')

# Only the first two panels of the top row are used; hide the remaining empty frames.
for unused_ax in ax[0][2:]:
    unused_ax.axis('off')

target_ax.scatter(object_centers[:,0],object_centers[:,1],c='r',marker='o',s=30,ec='k')
pred_ax.scatter(predicted_centers[:,0],predicted_centers[:,1],c='r',s=30,ec='k')

# Rectangle takes the lower-left corner, so shift the center by half the width/height.
for i in range(len(object_height)):

    bounding_box = patches.Rectangle((object_centers[i][0]-object_width[i]/2, object_centers[i][1]-object_height[i]/2),
                             object_width[i], object_height[i], linewidth=1, edgecolor='r', facecolor='none')

    target_ax.add_patch(bounding_box)

for i in range(len(predicted_centers)):

    bounding_box = patches.Rectangle((predicted_centers[i][0]-predicted_widths[i]/2,
                                          predicted_centers[i][1]-predicted_heights[i]/2),
                             predicted_widths[i], predicted_heights[i], linewidth=1,
                                         edgecolor='darkgreen', facecolor='none')

    pred_ax.add_patch(bounding_box)

for object_idx in range(n_objects):
    attn_ax = ax[1][object_idx]

    # Attention weight of every point for this object, ordered by point index.
    object_attn_weights = [weight_dict[object_idx][point_i] for point_i in range(n_points)]

    # Softmax over the points, so that marker sizes/colours are comparable between objects.
    object_attn_weights = torch.softmax(torch.tensor(object_attn_weights),dim=0).data.numpy()

    # Faint full cloud underneath, attention-scaled markers on top.
    attn_ax.scatter(x,y,s=0.2,alpha=0.2)
    attn_ax.scatter(x,y,s=300.0*object_attn_weights,alpha=0.8,c=object_attn_weights,cmap='Reds')

    attn_ax.set_xlim(-1,1)
    attn_ax.set_ylim(-1,1)
    attn_ax.set_title('attention of object %d' % object_idx)

    bounding_box = patches.Rectangle((predicted_centers[object_idx][0]-predicted_widths[object_idx]/2,
                                          predicted_centers[object_idx][1]-predicted_heights[object_idx]/2),
                             predicted_widths[object_idx], predicted_heights[object_idx], linewidth=1,
                                         edgecolor='r', facecolor='none')

    attn_ax.scatter(predicted_centers[:,0][object_idx],
                           predicted_centers[:,1][object_idx],c='r',marker='o',s=30,ec='k')
    attn_ax.add_patch(bounding_box)

plt.tight_layout()
plt.show()

Now the model should correctly perform the task!!!! :)

In [None]:
# Import only the entry point that is called below; a star import would pull every
# public name of the module into the notebook namespace and could shadow existing variables.
from test_homework import test_homework

In [None]:
# Run the checks provided in test_homework.py.
test_homework()