<a href="https://colab.research.google.com/github/louisemoelgaard/Dataproject/blob/main/slot_virker_upconvolution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [95]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import torchvision
from torchvision import datasets, models, transforms
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchvision.io import read_image
import matplotlib.pyplot as plt
import time
import os
import copy
import math
from PIL import Image

In [96]:
class SlotAttention(nn.Module):
    """Slot-attention head: learned slots (grouped per class) attend over spatial tokens.

    The slots are refined for ``iters`` rounds with a GRU. The logit of each
    class is the sum of the attention-pooled value features of that class's
    slots, scaled by ``loss_status``.

    Args:
        num_classes: number of output classes.
        slots_per_class: slots allocated to each class; their updates are summed.
        dim: feature dimension of the slots and of the key/value tokens.
        iters: number of refinement iterations (must be >= 1).
        eps: lower bound on the min-max range used when normalising the
            attention maps for visualisation (avoids 0/0 on constant maps).
        vis: when True, every forward pass writes ``slot_{i}.png`` attention
            maps for batch element ``vis_id`` to the current working directory.
        vis_id: index of the batch element that is visualised.
        loss_status: factor (e.g. +1/-1) applied to the returned logits.
        power: exponent applied to the attention-area loss.
        to_k_layer: number of Linear layers in the key projection (ReLU between).

    Raises:
        ValueError: if ``iters`` < 1 (the forward pass needs at least one round).
    """

    def __init__(self, num_classes, slots_per_class, dim, iters=3, eps=1e-8, vis=True, vis_id=0, loss_status=1, power=1, to_k_layer=1):
        super().__init__()
        if iters < 1:
            raise ValueError(f"iters must be >= 1, got {iters}")
        self.num_classes = num_classes
        self.slots_per_class = slots_per_class
        self.num_slots = num_classes * slots_per_class
        self.iters = iters
        self.eps = eps
        self.scale = dim ** -0.5
        self.loss_status = loss_status

        # mu/sigma are temporaries (not registered on the module); only the single
        # random draw stored in initial_slots is a learned parameter.
        slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        slots_sigma = abs(nn.Parameter(torch.randn(1, 1, dim)))

        mu = slots_mu.expand(1, self.num_slots, -1)
        sigma = slots_sigma.expand(1, self.num_slots, -1)
        self.initial_slots = nn.Parameter(torch.normal(mu, sigma))

        # Registered (so it is part of state_dict) but not used in forward():
        # the slots themselves act as the queries.
        self.to_q = nn.Sequential(
            nn.Linear(dim, dim),
        )
        to_k_layer_list = [nn.Linear(dim, dim)]
        for _ in range(1, to_k_layer):
            to_k_layer_list.append(nn.ReLU(inplace=True))
            to_k_layer_list.append(nn.Linear(dim, dim))

        self.to_k = nn.Sequential(
            *to_k_layer_list
        )
        self.gru = nn.GRU(dim, dim)

        self.vis = vis
        self.vis_id = vis_id
        self.power = power

    def _sum_per_class(self, per_slot):
        """Sum a (b, num_slots, ...) tensor over each class's contiguous block of slots."""
        # Slot index = slots_per_class * class + k, so a reshape groups them by class.
        b = per_slot.size(0)
        return per_slot.reshape(b, self.num_classes, self.slots_per_class, *per_slot.shape[2:]).sum(2)

    def _save_attention_maps(self, attn):
        """Write one greyscale ``slot_{i}.png`` per class for batch element ``vis_id``."""
        maps = attn.detach()
        if self.slots_per_class > 1:
            maps = self._sum_per_class(maps)
        maps = maps[self.vis_id]
        # Tokens are reshaped to a square map; fails unless n is a perfect square.
        side = int(maps.size(1) ** 0.5)
        lo, hi = maps.min(), maps.max()
        # Clamp the range so a constant map becomes all zeros instead of NaN.
        maps = (maps - lo) / (hi - lo).clamp_min(self.eps) * 255.
        maps = maps.reshape(maps.size(0), side, side)
        for slot_id, image in enumerate(maps.cpu().numpy().astype(np.uint8)):
            Image.fromarray(image, mode='L').save(f'slot_{slot_id:d}.png')

    def forward(self, inputs, inputs_x):
        """Run slot attention and return per-class logits plus the attention-area loss.

        Args:
            inputs: key tokens of shape (b, n, d); projected by ``to_k``.
            inputs_x: value tokens of shape (b, n, d) pooled by the attention.

        Returns:
            ``(logits, loss)``: ``logits`` has shape (b, num_classes) and
            ``loss`` is ``mean(attn) ** power``.
        """
        b, n, d = inputs.shape
        slots = self.initial_slots.expand(b, -1, -1)
        k = self.to_k(inputs)

        for _ in range(self.iters):
            slots_prev = slots
            q = slots  # to_q is intentionally bypassed

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            # Rescale each slot's row to sum to the sample's total over all slots
            # and positions: dots / row_sum * total_sum.
            row_sum = dots.sum(2, keepdim=True)
            total_sum = row_sum.sum(1, keepdim=True)
            dots = torch.div(dots, row_sum) * total_sum
            attn = torch.sigmoid(dots)
            updates = torch.einsum('bjd,bij->bid', inputs_x, attn)
            updates = updates / inputs_x.size(2)
            self.gru.flatten_parameters()
            slots, _ = self.gru(
                updates.reshape(1, -1, d),
                slots_prev.reshape(1, -1, d)
            )
            slots = slots.reshape(b, -1, d)

        if self.vis:
            # Uses the attention of the final iteration only.
            self._save_attention_maps(attn)

        if self.slots_per_class > 1:
            updates = self._sum_per_class(updates)

        # attn is a sigmoid output (>= 0), so relu is a no-op; loss = mean attention.
        attn_relu = torch.relu(attn)
        slot_loss = torch.sum(attn_relu) / attn.size(0) / attn.size(1) / attn.size(2)

        return self.loss_status * torch.sum(updates, dim=2, keepdim=False), torch.pow(slot_loss, self.power)

class PositionEmbeddingSine(nn.Module):
    """Fixed 2-D sine/cosine positional encoding for (b, c, h, w) feature maps.

    Follows the sinusoidal scheme of "Attention Is All You Need", applied
    separately to the row (y) and column (x) index. The output has
    ``2 * num_pos_feats`` channels: the y half first, then the x half.

    Args:
        num_pos_feats: channels produced per axis.
        temperature: base of the geometric frequency progression.
        normalize: if True, positions are rescaled to (0, scale].
        scale: range used when ``normalize`` is True (defaults to 2*pi).

    Raises:
        ValueError: if ``scale`` is given while ``normalize`` is False.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, tensor_list):
        feats = tensor_list
        batch, _, height, width = feats.shape
        # Every position is valid, so the running counts are simply 1..H and 1..W.
        valid = torch.ones((batch, height, width), dtype=torch.bool, device=feats.device)
        y_pos = valid.cumsum(1, dtype=torch.float32)
        x_pos = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_pos = y_pos / (y_pos[:, -1:, :] + eps) * self.scale
            x_pos = x_pos / (x_pos[:, :, -1:] + eps) * self.scale

        freq = torch.arange(self.num_pos_feats, dtype=torch.float32, device=feats.device)
        freq = self.temperature ** (2 * (freq // 2) / self.num_pos_feats)

        def interleave(pos):
            # sin on even channels, cos on odd channels, interleaved pairwise.
            scaled = pos[:, :, :, None] / freq
            return torch.stack((scaled[:, :, :, 0::2].sin(), scaled[:, :, :, 1::2].cos()), dim=4).flatten(3)

        encoding = torch.cat((interleave(y_pos), interleave(x_pos)), dim=3).permute(0, 3, 1, 2)
        return encoding.to(feats.dtype)

class Identical(nn.Module):
    """Parameter-free pass-through module: ``forward`` returns its input object unchanged."""

    def forward(self, x):
        return x

class SlotModel(nn.Module):
    """Classification head that turns a backbone feature map into slot-attention logits.

    The incoming features are viewed as an
    (in_channels, feature_size, feature_size) map and projected to
    ``hidden_dim`` channels with a ReLU'd 1x1 conv. A sine positional
    encoding is added to that map, and :class:`SlotAttention` then produces
    per-class log-probabilities.

    Args:
        num_classes: number of output classes.
        slots_per_class: slots allocated to each class inside SlotAttention.
        hidden_dim: channel width after the 1x1 projection; must be even because
            the positional encoding splits it between the y and x axes.
        in_channels: channels of the incoming feature map (default 256).
        feature_size: spatial side length of the incoming square map (default 14).

    Raises:
        ValueError: if ``hidden_dim`` is odd.
    """

    def __init__(self, num_classes, slots_per_class, hidden_dim, in_channels=256, feature_size=14):
        super().__init__()
        if hidden_dim % 2:
            raise ValueError(f"hidden_dim must be even, got {hidden_dim}")

        self.slots_per_class = slots_per_class
        self.in_channels = in_channels
        self.feature_size = feature_size
        self.conv1x1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=(1, 1), stride=(1, 1))

        self.slot = SlotAttention(num_classes, self.slots_per_class, hidden_dim)

        # hidden_dim // 2 features per axis -> hidden_dim channels, matching conv1x1.
        N_steps = hidden_dim // 2
        self.position_emb = PositionEmbeddingSine(N_steps, normalize=True)

    def forward(self, x):
        """Return ``(log_probs, attn_loss)`` for a batch of backbone features.

        Args:
            x: per-sample features with ``in_channels * feature_size ** 2``
                elements (presumably flattened by the ResNet forward before
                ``fc`` — any layout with that element count is viewed back
                into a map).
        """
        x = x.view(x.size(0), self.in_channels, self.feature_size, self.feature_size)
        x = torch.relu(self.conv1x1(x))
        x_pe = x + self.position_emb(x)

        b, n = x.shape[:2]
        # (b, n, h, w) -> (b, h*w, n): one token per spatial position.
        x = x.reshape((b, n, -1)).permute((0, 2, 1))
        x_pe = x_pe.reshape((b, n, -1)).permute((0, 2, 1))
        # Keys come from the position-augmented tokens, values from the plain ones.
        x, attn_loss = self.slot(x_pe, x)
        output = F.log_softmax(x, dim=1)

        return output, attn_loss
        

In [97]:
def train(model, dataloader, nEpochs, optimizer, device, lambda_value=1.):
    """Train ``model`` and return per-epoch accuracy and loss histories.

    ``model`` must return ``(log_probs, attn_loss)``. The optimised loss is
    ``nll_loss(log_probs, targets) + lambda_value * attn_loss``.

    Args:
        model: network to train (switched to train mode every epoch).
        dataloader: yields ``(images, targets)`` batches; its ``dataset``
            length is used to average the epoch statistics.
        nEpochs: number of passes over ``dataloader``.
        optimizer: optimizer over the trainable parameters of ``model``.
        device: device the batches are moved to.
        lambda_value: weight of the attention loss (default 1.0).

    Returns:
        ``[acc_history, loss_history]``: per-epoch accuracy and sample-weighted
        mean loss, as Python floats (so they can be plotted directly).
    """
    acc_history = []
    loss_history = []
    # Guard against an empty dataset: statistics then report 0.0 instead of crashing.
    n_samples = len(dataloader.dataset) or 1
    for epoch in range(nEpochs):
        print(epoch)
        running_loss = 0.0
        running_corrects = 0
        model.train()

        for images, targets in dataloader:
            images = images.to(device)
            targets = targets.to(device)
            output, attn_loss = model(images)

            loss = F.nll_loss(output, targets) + lambda_value * attn_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            running_corrects += (output.argmax(dim=1) == targets).sum().item()

        acc_history.append(running_corrects / n_samples)
        loss_history.append(running_loss / n_samples)

    return [acc_history, loss_history]

In [98]:
# Pre-download the CIFAR-10 training split into ./data (train defaults to True,
# as the echoed "Split: Train" output shows). The dataset cell below also passes
# download=True, so this call only warms the local cache; in a notebook the
# returned dataset object is echoed as the cell output.
torchvision.datasets.CIFAR10(root='./data', download=True)

Files already downloaded and verified


Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./data
    Split: Train

In [99]:
import torch
import torchvision
import torchvision.transforms as transforms

# Preprocessing applied to every CIFAR-10 image, in order:
#   - ToTensor: convert the PIL image to a float tensor with values in [0, 1]
#   - Resize: upsample to 224x224, the input size the ResNet backbone below is used with
#   - Normalize: mean 0.5 and standard deviation 0.5 for each of the three
#     channels, which maps [0, 1] to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([224, 224]),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 150

# CIFAR-10 ships with a fixed train/test split; `train=` selects which half to load.
# Training batches are reshuffled every epoch; test batches keep a fixed order.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Human-readable names for the ten CIFAR-10 labels, indexed by class id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [100]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# ImageNet-pretrained ResNet-18 backbone. Its global average pool is swapped for a
# transposed convolution: for 224x224 inputs layer4 yields a 7x7 map, and
# (7 - 1) * 1 - 2 * 1 + 10 = 14, giving the (256, 14, 14) features SlotModel views.
resnet = models.resnet18(pretrained=True)
resnet.avgpool = nn.ConvTranspose2d(512, out_channels=256, kernel_size=(10, 10), stride=(1, 1), padding=(1, 1), bias=False)
# The linear classifier is replaced by the slot-attention head (one slot per class).
resnet.fc = SlotModel(num_classes=10, slots_per_class=1, hidden_dim=512)

resnet.to(device)

# Fine-tune layer4 together with the two freshly initialised modules (avgpool, fc).
# Freezing every parameter whose name lacks "layer4" would also freeze these new,
# randomly initialised layers, so the classification head would never be trained.
TRAINABLE_PREFIXES = ("layer4.", "avgpool.", "fc.")
for name, param in resnet.named_parameters():
    param.requires_grad = name.startswith(TRAINABLE_PREFIXES)

params_to_update = [param for param in resnet.parameters() if param.requires_grad]

optimizer = torch.optim.Adam(params_to_update, lr=1e-4)

rs = train(model=resnet, dataloader=trainloader, nEpochs=5, optimizer=optimizer, device=device)

0




1


KeyboardInterrupt: ignored

In [None]:
# Persist only the learned parameters and buffers (state_dict) of the network.
MODEL_PATH = 'cifar_net.pth'
torch.save(obj=resnet.state_dict(), f=MODEL_PATH)

In [None]:
# Echo the training history returned by train(): [acc_history, loss_history].
rs

In [None]:
#device = "cuda" if torch.cuda.is_available() else "cpu"
#resnet = models.resnet18(pretrained=True)

#Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#resnet.avgpool = nn.ConvTranspose2d(512, out_channels = 256, kernel_size = (6,6), stride = (1,1), padding=(1, 1), bias=False)

#resnet.fc = Identical()

#resnet.fc = SlotModel(num_classes=10, slots_per_class=1, hidden_dim=512)

#resnet.to(device)


#for name, param in resnet.named_parameters():
#    if "layer4" not in name:  
#        param.requires_grad = False

#params_to_update = []
#for name, param in resnet.named_parameters():
   # if param.requires_grad:
 #       params_to_update.append(param)
#input = torch.rand(1,3,224,224).to(device)

#resnet(input).size()

In [None]:
import matplotlib.image as mpimg
def eval_attn(model, images, targets):
    """Show the slot-attention maps for the first image of a batch.

    Runs ``model`` on ``images[0]`` alone and prints the true and predicted
    class. It then plots the ``slot_{i}.png`` files (i = 0..9) over the
    un-normalised input image, one panel per class. Those files are expected
    to have been written by the model's forward pass (SlotAttention with
    ``vis=True`` saves them to the working directory). Only the first image
    is used because every forward pass overwrites the same files.

    Args:
        model: network returning ``(log_probs, attn_loss)``; set to eval mode.
        images: batch of images normalised with mean/std 0.5; assumed
            (B, 3, H, W) — TODO confirm.
        targets: class indices aligned with ``images``.
    """
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    model.eval()
    if len(images) == 0 or len(targets) == 0:
        return

    image, target = images[0].unsqueeze(0), targets[0]
    with torch.no_grad():  # inference only; no autograd graph needed
        out, _ = model(image)
    out = out.cpu()
    print(f'True class: {classes[target.item()]}')
    print(f'Predicted class: {classes[out.argmax(dim=1).item()]}')

    # Undo Normalize(mean=0.5, std=0.5) and convert CHW -> HWC for imshow.
    npimg = np.transpose((image[0].cpu() / 2 + 0.5).numpy(), (1, 2, 0))

    fig, axs = plt.subplots(2, 5)
    extent = 0, 224, 0, 224
    for i, ax in enumerate(axs.ravel()):
        attn_map = mpimg.imread(f'slot_{i}.png')
        ax.imshow(npimg, extent=extent)
        ax.imshow(attn_map, alpha=0.5, extent=extent)
        ax.set_title(f'Why {classes[i]}')

# Iterator over the (unshuffled) test loader; the next cell pulls one batch from it.
dataiter = iter(testloader)



In [None]:
# Take one test batch and visualise the slot attention for its first image.
# DataLoader iterators do not provide a `.next()` method in current PyTorch;
# use the builtin next() instead.
image, target = next(dataiter)
image = image.to(device)
target = target.to(device)

eval_attn(resnet, image, target)