In [156]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
import torchvision.models as models

import math

import os

from torch.utils.data import random_split, DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

import pickle

In [157]:
dataset = pickle.load(open("../data/image/MNIST/MNIST_SLIC_graph_28_16_0p5.pkl", "rb"))

In [189]:
def resize_stack_slic_graph_patches(data, size):
    r = Resize(size)
    
    for g in data:
        g.imgs = [torch.Tensor(img).unsqueeze(0) for img in g.imgs]
        g.imgs = [r(img) for img in g.imgs]
        g.imgs = torch.cat(g.imgs, dim=0).unsqueeze(1)
    
    return data

dataset = resize_stack_slic_graph_patches(dataset, (7, 7))

In [175]:
from torch.nn.utils.rnn import pad_sequence

def collate_slic_graph_patches(batch):
    lengths = torch.tensor([len(g.imgs) for g in batch])
    max_len = torch.max(lengths)
    mask = torch.arange(max_len).expand(len(lengths), max_len) >= lengths.unsqueeze(1)
    
    imgs = pad_sequence([g.imgs for g in batch], batch_first=True)
    coords = pad_sequence([g.centroid for g in batch], batch_first=True)
    
    y = torch.tensor([g.y for g in batch], dtype=torch.long)

    return imgs, coords, mask, y

train_dataset, test_dataset = random_split(dataset, [.9, .1])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_slic_graph_patches)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_slic_graph_patches)

next(iter(train_loader))[2][:,-5:]

tensor([[False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False,  True,  True],
        [False, False, False, False, False],
        [False, False, False, False,  True],
        [False, False,  True,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False, False,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [F

In [187]:
class PositionalEncoding2D(nn.Module):
    def __init__(self, dim, max_len=1000):
        super(PositionalEncoding2D, self).__init__()
        self.dim = dim
        
        self.pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        self.pe[:, 0::2] = torch.sin(position * div_term)
        self.pe[:, 1::2] = torch.cos(position * div_term)
    
        # self.register_buffer("pe_x", pe.clone()) # TODO do we need this?
        # self.register_buffer("pe_y", pe.clone())
    
    def forward(self, coords): # TODO forces ints
        pe_x = self.pe[coords[..., 0]]
        pe_y = self.pe[coords[..., 1]]
        
        return pe_x + pe_y
    
    def to(self, device):
        super(PositionalEncoding2D, self).to(device)
        self.pe = self.pe.to(device)
        return self


class VisionEncoder(nn.Module):
    r"""Vision Encoder Model

        An Encoder Layer with the added functionality to encode important local structures of a tokenized image

        Args:
            embed_size      (int): Embedding Size of Input
            num_heads       (int): Number of heads in multi-headed attention
            hidden_size     (int): Number of hidden layers
            dropout         (float, optional): A probability from 0 to 1 which determines the dropout rate

    """

    def __init__(self, embed_size: int, num_heads: int, hidden_size: int, dropout: float = 0.1):
        super(VisionEncoder, self).__init__()

        self.embed_size = embed_size
        self.num_heads = num_heads
        # self.hidden_size = hidden_size
        self.dropout = dropout

        self.norm1 = nn.LayerNorm(self.embed_size)
        self.norm2 = nn.LayerNorm(self.embed_size)
        
        self.attention = nn.MultiheadAttention(self.embed_size, self.num_heads, dropout=dropout)

        self.mlp = nn.Sequential(
            nn.Linear(self.embed_size, 4 * self.embed_size),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(4 * self.embed_size, self.embed_size),
            nn.Dropout(self.dropout)
        )

    def forward(self, x, mask=None):
        x = self.norm1(x)
        
        x = x.transpose(0, 1)
        attn, _ = self.attention(x, x, x, key_padding_mask=mask)
        x = x + attn
        x = x.transpose(0, 1)
        
        x = x + self.mlp(self.norm2(x))
        
        return x


class CoordViT(nn.Module):
    r"""Vision Transformer Model

        A transformer model to solve vision tasks by treating images as sequences of tokens.

        Args:
            image_size      (int): Size of input image
            channel_size    (int): Size of the channel
            patch_size      (int): Max patch size, determines number of split images/patches and token size
            embed_size      (int): Embedding size of input
            num_heads       (int): Number of heads in Multi-Headed Attention
            classes         (int): Number of classes for classification of data
            hidden_size     (int): Number of hidden layers
            dropout         (float, optional): A probability from 0 to 1 which determines the dropout rate

    """

    def __init__(self, image_size: int, channel_size: int, patch_size: int, embed_size: int, num_heads: int,
                 classes: int, num_layers: int, hidden_size: int, dropout: float = 0.1):
        super(CoordViT, self).__init__()

        self.p = patch_size
        self.image_size = image_size
        self.embed_size = embed_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_size = channel_size * (patch_size ** 2)
        self.num_heads = num_heads
        self.classes = classes
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.dropout_layer = nn.Dropout(dropout)

        self.embeddings = nn.Linear(self.patch_size, self.embed_size)
        self.class_token = nn.Parameter(torch.randn(1, 1, self.embed_size))
        self.positional_encoding = PositionalEncoding2D(self.embed_size, self.image_size)

        self.encoders = nn.ModuleList([])
        for layer in range(self.num_layers):
            self.encoders.append(VisionEncoder(self.embed_size, self.num_heads, self.hidden_size, self.dropout))

        self.norm = nn.LayerNorm(self.embed_size)

        self.classifier = nn.Sequential(
            nn.Linear(self.embed_size, self.classes)
        )

    def forward(self, x, coords, mask):
        b, n, c, h, w = x.size()

        x = x.reshape(b, n, h*w)
        x = self.embeddings(x)

        b, n, e = x.size()

        pe = self.positional_encoding(coords.int()).to(x.device) # TODO int
        x = x + pe
        
        class_token = self.class_token.expand(b, 1, e)
        x = torch.cat((x, class_token), dim=1)
        mask = torch.cat((mask, torch.tensor([False] * mask.shape[0], device=mask.device).unsqueeze(1)), dim=1)
        
        x = self.dropout_layer(x)
        
        for encoder in self.encoders:
            x = encoder(x, mask)

        x = x[:, -1, :]

        x = F.log_softmax(self.classifier(self.norm(x)), dim=-1)

        return x
    
    def to(self, device):
        super(CoordViT, self).to(device)
        self.positional_encoding = self.positional_encoding.to(device)
        return self


image_size = 28
channel_size = 1
patch_size = 7
embed_size = 512
num_heads = 8
classes = 10
num_layers = 3
hidden_size = 256
dropout = 0.2
model = CoordViT(image_size, channel_size, patch_size, embed_size, num_heads, classes, num_layers, hidden_size, dropout=dropout)

x, coords, mask, y = next(iter(train_loader))
model(x, coords, mask)

tensor([[-2.7354, -2.4129, -2.8052, -3.4367, -1.9890, -2.0890, -3.1326, -3.3764,
         -1.4232, -1.7512],
        [-3.1527, -2.7185, -2.8222, -2.8696, -1.9462, -2.6661, -3.0450, -3.2003,
         -1.5199, -1.3638],
        [-2.2215, -2.1765, -2.9552, -2.8865, -1.8127, -2.5994, -3.5094, -3.2663,
         -1.3715, -2.1985],
        [-2.7306, -2.4969, -3.0589, -3.0459, -1.8857, -2.2834, -2.9170, -3.2802,
         -1.2623, -2.0432],
        [-2.5614, -2.8845, -3.2692, -2.6177, -2.0206, -2.0739, -2.2596, -3.2860,
         -1.5621, -1.9231],
        [-2.5108, -2.4781, -2.9143, -3.1705, -2.0552, -2.3170, -3.2491, -3.8691,
         -1.0914, -2.1490],
        [-2.3300, -2.6724, -2.5239, -2.7192, -2.2130, -2.0612, -2.6448, -3.9853,
         -1.6115, -1.8222],
        [-2.6251, -2.9441, -2.7898, -2.7099, -2.1250, -2.3106, -2.7228, -3.5696,
         -1.0733, -2.3798],
        [-1.6949, -2.6911, -3.0347, -2.5203, -1.8715, -2.6245, -2.9079, -3.2275,
         -1.7247, -2.1101],
        [-2.9135, -

In [188]:
import os
import pickle

from time import time

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import os
from typing import Iterable

import matplotlib.pyplot as plt
from skimage import graph, io, color

device = "cuda" if torch.cuda.is_available() else "cpu"

def plot_train_test(x: Iterable, y_train: Iterable, y_test: Iterable,
                    title: str, save_path: str, figsize: tuple[int, int] = (12, 8)) -> None:
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    
    plt.clf()

    plt.figure(figsize=figsize)
    plt.plot(x, y_train)
    plt.plot(x, y_test)
    plt.legend(["Train", "Test"])
    plt.title(title)
    
    if os.path.isfile(save_path):
        os.remove(save_path)
    plt.savefig(save_path)
    
    plt.clf()

def train_epoch(model: nn.Module, optimizer: Optimizer, criterion: nn.Module,
                loader: DataLoader, normalise_loss=True) -> tuple[float, float]:
    model.train()
    correct, total_loss, total = 0, 0, 0
    for batch in loader:
        x, y = batch[0].to(device), batch[1].to(device)
        
        optimizer.zero_grad()
        
        out = model(x)
        loss = criterion(out, y)
        total_loss += loss.item()
        correct += out.argmax(dim=-1).eq(y).sum().item()
        total += len(y)
        
        loss.backward()
        optimizer.step()
    
    if total == 0:
        return 0, 1
    
    if normalise_loss:
        total_loss /= total
    
    return total_loss, correct / total

def train_epoch_coordvit(model: nn.Module, optimizer: Optimizer, criterion: nn.Module,
                loader: DataLoader, normalise_loss=True) -> tuple[float, float]:
    model.train()
    correct, total_loss, total = 0, 0, 0
    for batch in loader:
        x, coords, mask, y = batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3].to(device)
        
        optimizer.zero_grad()
        
        out = model(x, coords, mask)
        loss = criterion(out, y)
        total_loss += loss.item()
        correct += out.argmax(dim=-1).eq(y).sum().item()
        total += len(y)
        
        loss.backward()
        optimizer.step()
    
    if total == 0:
        return 0, 1
    
    if normalise_loss:
        total_loss /= total
    
    return total_loss, correct / total

def eval(model: nn.Module, criterion: nn.Module, loader: DataLoader, normalise_loss=True) -> tuple[float, float]:
    model.eval()
    correct, total_loss, total = 0, 0, 0
    for batch in loader:
        x, y = batch[0].to(device), batch[1].to(device)
        
        out = model(x)
        total_loss += criterion(out, y).item()
        correct += out.argmax(dim=-1).eq(y).sum().item()
        total += len(y)
        
    if total == 0:
        return 0, 1
    if normalise_loss:
        total_loss /= total
    return total_loss, correct / total

def eval_coordvit(model: nn.Module, criterion: nn.Module, loader: DataLoader, normalise_loss=True) -> tuple[float, float]:
    model.eval()
    correct, total_loss, total = 0, 0, 0
    for batch in loader:
        x, coords, mask, y = batch[0].to(device), batch[1].to(device), batch[2].to(device), batch[3].to(device)
        
        out = model(x, coords, mask)
        total_loss += criterion(out, y).item()
        correct += out.argmax(dim=-1).eq(y).sum().item()
        total += len(y)
        
    if total == 0:
        return 0, 1
    if normalise_loss:
        total_loss /= total
    return total_loss, correct / total

def train_test_loop(model: nn.Module, optimizer: Optimizer, criterion: nn.Module,
                    train_loader: DataLoader, test_loader: DataLoader,
                    num_epochs: int, lr_scheduler: LRScheduler = None,
                    save_path: str = None, plot=False,
                    train_epoch_fn=train_epoch, eval_fn=eval,
                    normalise_loss=True,
                    ) -> tuple[list[float]]:
    assert save_path is not None or plot == False
    
    train_accs, test_accs, train_losses, test_losses = [], [], [], []
    
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
        
        if os.path.exists(save_path + "model.pt"):
            model.load_state_dict(torch.load(save_path + "model.pt"))
        if os.path.exists(save_path + "metrics.pkl"):
            metrics = pickle.load(open(save_path + "metrics.pkl", "rb"))
            train_accs, test_accs, train_losses, test_losses = metrics
            
    
    for i in range(1+len(train_accs), num_epochs+1):
        interval = time()

        train_loss, train_acc = train_epoch_fn(model, optimizer, criterion, train_loader, normalise_loss=normalise_loss)
        test_loss, test_acc = eval_fn(model, criterion, test_loader, normalise_loss=normalise_loss)
        if lr_scheduler is not None:
            lr_scheduler.step(epoch=i, metrics=train_loss)

        interval = time() - interval
        
        print(
            f"Epoch {i:03d}: train loss {train_loss:.4f},",
            f"train accuracy {train_acc:.3f},",
            f"test accuracy {test_acc:.3f}, "
            f"time {int(interval)}s"
        )
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        
        if save_path is not None:
            torch.save(model.state_dict(), save_path + "model.pt")
            pickle.dump(
                (train_accs, test_accs, train_losses, test_losses),
                open(save_path + "metrics.pkl", "wb"),
            )
        
        if plot:
            plot_train_test(list(range(1, i+1)), train_losses, test_losses,
                            "Loss", save_path+"loss.png")
            plot_train_test(list(range(1, i+1)), train_accs, test_accs,
                            "Classification accuracy", save_path+"accuracy.png")

    return train_losses, train_accs, test_losses, test_accs


image_size = 28
channel_size = 1
patch_size = 7
embed_size = 512
num_heads = 8
classes = 10
num_layers = 3
hidden_size = 256
dropout = 0.2
model = CoordViT(image_size, channel_size, patch_size, embed_size, num_heads, classes, num_layers, hidden_size, dropout=dropout).to(device)


optimizer = Adam(model.parameters(), lr=5e-5)
criterion = CrossEntropyLoss()

num_epochs = 50

metrics = train_test_loop(
    model,
    optimizer,
    criterion,
    train_loader, 
    test_loader, 
    num_epochs,
    save_path=None,
    plot=False,
    train_epoch_fn=train_epoch_coordvit,
    eval_fn=eval_coordvit,
)

Epoch 001: train loss 0.0271, train accuracy 0.708, test accuracy 0.904, time 20s
Epoch 002: train loss 0.0105, train accuracy 0.894, test accuracy 0.931, time 19s
Epoch 003: train loss 0.0081, train accuracy 0.918, test accuracy 0.943, time 20s
Epoch 004: train loss 0.0067, train accuracy 0.931, test accuracy 0.949, time 21s
Epoch 005: train loss 0.0060, train accuracy 0.939, test accuracy 0.955, time 21s
Epoch 006: train loss 0.0053, train accuracy 0.945, test accuracy 0.955, time 22s
Epoch 007: train loss 0.0048, train accuracy 0.951, test accuracy 0.959, time 22s
Epoch 008: train loss 0.0044, train accuracy 0.955, test accuracy 0.960, time 23s


KeyboardInterrupt: 

In [190]:
dataset = pickle.load(open("../data/image/MNIST/MNIST_SLIC_graph_28_16_0p5.pkl", "rb"))
dataset = resize_stack_slic_graph_patches(dataset, (7, 7)) # this is part of the model
print("Finished preprocessing")

train_dataset, test_dataset = random_split(dataset, [.9, .1])
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_slic_graph_patches
)
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_slic_graph_patches
)

image_size = 28
channel_size = 1
patch_size = 7
embed_size = 512
num_heads = 8
classes = 10
num_layers = 3
hidden_size = 256
dropout = 0.2
model = CoordViT(
    image_size,
    channel_size,
    patch_size,
    embed_size,
    num_heads,
    classes,
    num_layers,
    hidden_size,
    dropout=dropout
).to(device)


optimizer = Adam(model.parameters(), lr=5e-5)
criterion = CrossEntropyLoss()

num_epochs = 50

metrics = train_test_loop(
    model,
    optimizer,
    criterion,
    train_loader, 
    test_loader, 
    num_epochs,
    save_path=None,
    plot=False,
    train_epoch_fn=train_epoch_coordvit,
    eval_fn=eval_coordvit,
)

Finished preprocessing
Epoch 001: train loss 0.0274, train accuracy 0.700, test accuracy 0.904, time 20s
Epoch 002: train loss 0.0107, train accuracy 0.892, test accuracy 0.935, time 19s


KeyboardInterrupt: 