In [None]:
import torch
import math
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.nn.parameter import Parameter
import torch.nn.functional as F

# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Load the CIFAR-10 dataset and build the train / test data loaders.
batch_size = 64

# Normalisation shared by both splits: (x - 0.5) / 0.5 per channel.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training pipeline: random crop + horizontal flip augmentation.
transform = transforms.Compose(
    [transforms.RandomCrop(32, padding=4),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     normalize])

# Evaluation pipeline: no random augmentation, so the reported test accuracy
# is deterministic and measured on the unmodified images.
test_transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                           shuffle=True, num_workers=2, drop_last=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
# drop_last=False: keep the final partial batch so every test image is scored.
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                          shuffle=False, num_workers=2, drop_last=False)

In [None]:
def scaled_dot_product(q, k, v):
    """Scaled dot-product attention.

    Computes ``softmax(q @ k^T / sqrt(d_k)) @ v`` where ``d_k`` is the size of
    the last dimension of ``q`` and the softmax is taken over the last axis of
    the score matrix.

    Returns:
        A ``(values, attention)`` tuple: the attention-weighted combination of
        ``v`` and the attention weights themselves.
    """
    scale = math.sqrt(q.shape[-1])

    # Similarity of every query against every key, scaled by sqrt(d_k).
    scores = (q @ k.transpose(-1, -2)) / scale
    weights = scores.softmax(dim=-1)

    return weights @ v, weights

In [None]:
class MultiheadAttention(nn.Module):
    """Bias-free multi-head self-attention over sequence-first inputs.

    A single packed weight of shape ``(3 * embed_dim, embed_dim)`` projects the
    input to queries, keys and values; a second ``(embed_dim, embed_dim)``
    weight projects the concatenated heads back to ``embed_dim``.

    Args:
        embed_dim: Size of the model dimension; must be divisible by ``nhead``.
        nhead: Number of attention heads.
        dropout: Stored on the instance as ``self.dropout``; this module does
            not apply it.
        batch_size: Stored on the instance as ``self.batch_size``; ``forward``
            reads the batch size from its input instead.
    """

    def __init__(self, embed_dim, nhead, dropout, batch_size = batch_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.nhead = nhead
        self.dropout = dropout
        self.head_dim = embed_dim // nhead
        self.batch_size = batch_size

        assert self.head_dim * nhead == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Packed q/k/v projection weight, Xavier-uniform initialised.
        self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim)))
        nn.init.xavier_uniform_(self.in_proj_weight)

        # Output projection weight, Xavier-uniform initialised.
        self.out_proj_weight = Parameter(torch.empty((embed_dim, embed_dim)))
        nn.init.xavier_uniform_(self.out_proj_weight)

    def _split_heads(self, t, seq_length, batch_size):
        """Reshape ``(L, B, E)`` into ``(B, nhead, L, head_dim)`` without copying."""
        return t.view(seq_length, batch_size, self.nhead, self.head_dim).permute(1, 2, 0, 3)

    def _masked_attention(self, q, k, v, attn_mask):
        """Scaled dot-product attention with ``attn_mask`` applied to the logits.

        A boolean mask marks positions that must NOT be attended to (they are
        filled with ``-inf``); any other dtype is added to the logits. A 3-D
        mask of shape ``(B * nhead, L, L)`` is reshaped to ``(B, nhead, L, L)``;
        a 2-D ``(L, L)`` mask broadcasts over batch and heads.
        """
        batch_size, nhead, seq_length, _ = q.shape
        if attn_mask.dim() == 3:
            attn_mask = attn_mask.reshape(batch_size, nhead, seq_length, -1)

        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attn_mask.dtype == torch.bool:
            logits = logits.masked_fill(attn_mask, float('-inf'))
        else:
            logits = logits + attn_mask

        return torch.matmul(F.softmax(logits, dim=-1), v)

    def forward(self, query, key, value, attn_mask = None):
        """Apply self-attention to ``query``.

        Args:
            query: Tensor of shape ``(seq_length, batch_size, embed_dim)``.
            key, value: Accepted for signature compatibility but not read;
                queries, keys and values are all projected from ``query``.
            attn_mask: Optional attention mask (see ``_masked_attention``).

        Returns:
            Tensor of shape ``(seq_length, batch_size, embed_dim)``.
        """
        seq_length, batch_size, embed_dim = query.size()

        # Packed projection via the public functional API: (L, B, 3E).
        qkv = F.linear(query, self.in_proj_weight)

        # (L, B, 3E) -> (L, B, 3, E) -> (3, L, B, E), then split into q, k, v.
        qkv = qkv.unflatten(-1, (3, self.embed_dim)).permute(2, 0, 1, 3).contiguous()
        q, k, v = (self._split_heads(t, seq_length, batch_size) for t in qkv.unbind(0))

        if attn_mask is None:
            values, _ = scaled_dot_product(q, k, v)                      # (B, nhead, L, head_dim)
        else:
            values = self._masked_attention(q, k, v, attn_mask)

        # Merge heads back: (B, nhead, L, head_dim) -> (L, B, E), flattened to
        # (L * B, E) for the output projection.
        values = values.permute(2, 0, 1, 3).contiguous().view(seq_length * batch_size, embed_dim)

        o = F.linear(values, self.out_proj_weight)
        return o.view(seq_length, batch_size, embed_dim)


In [None]:
class EncoderLayer(nn.Module):
    """Transformer encoder layer: self-attention followed by a feed-forward block.

    Each sub-block is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``
    (normalisation applied after the residual sum).
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout):
        super().__init__()
        # Feed-forward network: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(in_features=d_model, out_features=dim_feedforward)
        self.linear2 = nn.Linear(in_features=dim_feedforward, out_features=d_model)

        # One LayerNorm per residual branch.
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5)

        # `dropout` sits inside the feed-forward block; `dropout1` / `dropout2`
        # act on the attention and feed-forward outputs respectively.
        self.dropout = nn.Dropout(p=dropout)
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)

        self.activation = nn.ReLU()

        self.attention = MultiheadAttention(d_model, nhead, dropout = dropout)

    def forward(self, x, src_mask = None, src_key_padding_mask = None, is_causal = False):
        """Run the layer on ``x``.

        ``src_mask`` is forwarded to the attention module as ``attn_mask``.
        ``src_key_padding_mask`` and ``is_causal`` are accepted but not used.
        """
        # Self-attention branch with residual connection.
        attended = self.attention(x, x, x, attn_mask = src_mask)
        x = self.norm1(x + self.dropout1(attended))

        # Feed-forward branch with residual connection.
        hidden = self.dropout(self.activation(self.linear1(x)))
        x = self.norm2(x + self.dropout2(self.linear2(hidden)))

        return x

import copy

def _get_clones(module, N):
    # FIXME: copy.deepcopy() is not defined on nn.module
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

class Encoder(nn.Module):
    """Stack of ``num_layers`` deep copies of ``encoder_layer`` applied in sequence.

    Args:
        encoder_layer: Layer to clone; called as ``layer(x, src_mask)``.
        num_layers: Number of clones in the stack.
        norm: Optional module applied to the output of the final layer.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()

        self.num_layers = num_layers
        self.norm = norm
        self.layers = _get_clones(encoder_layer, num_layers)

    def forward(self, x, src_mask = None):
        """Pass ``x`` through every layer in order, then the optional norm."""
        # Chain the layers: each one consumes the previous layer's output.
        # Starting from `x` also makes an empty stack return its input.
        output = x
        for mod in self.layers:
            output = mod(output, src_mask)

        if self.norm is not None:
            output = self.norm(output)

        return output

In [None]:
class PatchEmbedding(nn.Module):
    """Split an image batch into non-overlapping patches and embed each one.

    A convolution with ``kernel_size == stride == patch_size`` maps every patch
    to an ``emb_size``-dimensional vector.

    Args:
        in_channels: Number of channels in the input images.
        emb_size: Embedding dimension produced for every patch.
        patch_size: Side length of the square patches.
    """

    def __init__(self, in_channels, emb_size, patch_size):
        super(PatchEmbedding, self).__init__()

        # Honour `in_channels` instead of assuming RGB input.
        self.conv = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        self.flatten = nn.Flatten(start_dim=2)
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x: torch.Tensor):
        """Map ``(B, C, H, W)`` images to ``(num_patches, B, emb_size)`` tokens.

        Patches are ordered row-major over the convolution's output grid.
        """
        x = self.conv(x)            # (B, emb_size, H', W')
        x = self.flatten(x)         # (B, emb_size, H' * W')
        return x.permute(2, 0, 1)   # (H' * W', B, emb_size)

class SecondEmbedding(nn.Module):
    """Convolutional stem: a 2x2, stride-2 convolution to 64 channels, flattened spatially.

    Args:
        in_channels: Number of channels in the input images.
        emb_size: Accepted for signature parity with ``PatchEmbedding``; not used.
        patch_size: Accepted for signature parity with ``PatchEmbedding``; not used.
    """

    def __init__(self, in_channels, emb_size, patch_size):
        super(SecondEmbedding, self).__init__()

        # Honour `in_channels` instead of assuming RGB input.
        self.conv_1 = nn.Conv2d(in_channels, 64, kernel_size=2, stride=2)
        self.flatten = nn.Flatten(start_dim=2)

    def forward(self, x: torch.Tensor):
        """Map ``(B, C, H, W)`` to ``(B, 64, (H // 2) * (W // 2))``.

        For a batch of 64 images of size 32x32 this is ``(64, 64, 256)``, the
        shape the previous hard-coded ``view`` required; other batch sizes and
        image sizes no longer raise.
        """
        x = self.conv_1(x)      # (B, 64, H // 2, W // 2)
        x = self.flatten(x)     # (B, 64, (H // 2) * (W // 2))
        return x


class PositionalEncoding(nn.Module):
    """Learned additive positional embedding, indexed along the first axis of the input.

    Holds a trainable, zero-initialised table of shape ``(max_len, 1, d_model)``.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        table = torch.zeros(max_len, 1, d_model)
        self.positional_encodings = nn.Parameter(table, requires_grad=True)

    def forward(self, x: torch.Tensor):
        """Add the first ``x.shape[0]`` table rows to ``x`` (broadcast over axis 1)."""
        seq_len = x.shape[0]
        return x + self.positional_encodings[:seq_len]

class Classification(nn.Module):
    """Two-layer MLP head: ``Linear -> ReLU -> Linear``.

    Args:
        d_model: Size of the input features.
        hidden: Width of the hidden layer.
        num_classes: Number of output logits (defaults to 10).
    """

    def __init__(self, d_model, hidden, num_classes = 10):
        super(Classification, self).__init__()

        self.linear1 = nn.Linear(d_model, hidden)
        self.act = nn.ReLU()
        self.linear2 = nn.Linear(hidden, num_classes)

    def forward(self, x: torch.Tensor):
        """Return class logits of shape ``(..., num_classes)``."""
        x = self.act(self.linear1(x))
        x = self.linear2(x)
        return x

class enhance_classifier(nn.Module):
    """Patch-based Transformer image classifier.

    Pipeline: patch embedding -> learned positional encoding -> prepend a
    learned CLS token -> Transformer encoder -> LayerNorm on the CLS output ->
    MLP classification head.

    Args:
        d_model: Embedding / model dimension.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward block.
        hidden: Hidden width of the classification head.
        num_classes: Number of output logits.
        patch_size: Side length of the square image patches.
    """

    def __init__(self, d_model = 512, nhead = 8, num_layers = 6, dim_feedforward = 512, hidden = 2048, num_classes = 10, patch_size = 16):
        super(enhance_classifier, self).__init__()

        self.d_model = d_model

        self.patch_embedding = PatchEmbedding(in_channels = 3, emb_size = d_model, patch_size = patch_size)
        # Registered (and therefore part of the parameters / state dict) but
        # not called in forward().
        self.second_embedding = SecondEmbedding(in_channels = 3, emb_size = d_model, patch_size = patch_size)
        self.positional_encoding = PositionalEncoding(d_model = d_model)

        encoder_layers = EncoderLayer(d_model, nhead, dim_feedforward, 0.1)
        # NOTE(review): nn.TransformerEncoder is handed a custom layer rather than
        # nn.TransformerEncoderLayer; confirm the installed PyTorch version does not
        # rely on attributes of the built-in layer in its forward pass.
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)

        # Forward `num_classes` to the head so the constructor argument takes effect.
        self.classifaction = Classification(d_model, hidden, num_classes)
        self.cls_token_emb = nn.Parameter(torch.zeros(1, 1, d_model), requires_grad=True)
        self.ln = nn.LayerNorm([d_model])

    def forward(self, x):
        """Classify a batch of images; returns logits of shape ``(B, num_classes)``."""
        # Patch tokens, sequence-first: (N, B, d_model) with N = number of patches.
        x = self.patch_embedding(x)

        # Add positional information to the patch tokens.
        x = self.positional_encoding(x)                                 # (N, B, d_model)

        # Prepend the CLS token (added after the positional encoding, so it
        # receives no positional term).
        cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)   # (1, B, d_model)
        x = torch.cat([cls_token_emb, x])                               # (N + 1, B, d_model)

        # Feed into the encoder stack.
        x = self.transformer_encoder(x)                                 # (N + 1, B, d_model)

        # The CLS position summarises the image.
        x = x[0]                                                        # (B, d_model)

        x = self.ln(x)
        x = self.classifaction(x)
        return x

In [None]:
# Build the classifier (256-dim tokens, 8 heads, 6 encoder layers, 4x4 patches)
# and move it to the selected device in one step.
model = enhance_classifier(
    d_model = 256,
    nhead = 8,
    num_layers = 6,
    dim_feedforward = 512,
    hidden = 512,
    patch_size = 4,
).to(device)

# Cross-entropy on the raw logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-5)

In [None]:
# Train the model
model.train()
for epoch in range(30):
    # Running totals for this epoch: sample-weighted loss, correct predictions,
    # and number of samples seen.
    loss_sum = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Bookkeeping. argmax is not differentiable, so no graph is built here.
        n_samples = labels.size(0)
        predicted = outputs.argmax(dim=1)
        loss_sum += loss.item() * n_samples
        total += n_samples
        correct += (predicted == labels).sum().item()
    print("epoch", epoch, 'loss:', loss_sum / total, 'acc:', correct / total)


In [None]:
# Eval the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the number of images actually evaluated rather than a hard-coded
# count: the loader may skip a partial final batch.
print('Test Accuracy of the model on the {} test images: {} %'.format(total, 100 * correct / total))

In [None]:
class PatchEmbedding(nn.Module):
    """Debug variant of the patch embedding that prints the tensor shape at each stage.

    Running this cell rebinds the name ``PatchEmbedding``; models built
    afterwards will use this printing version.

    Args:
        in_channels: Number of channels in the input images.
        emb_size: Embedding dimension produced for every patch.
        patch_size: Side length of the square patches.
    """

    def __init__(self, in_channels, emb_size, patch_size):
        super(PatchEmbedding, self).__init__()

        # Honour `in_channels` instead of assuming RGB input.
        self.conv = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        self.flatten = nn.Flatten(start_dim=2)
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x: torch.Tensor):
        """Map ``(B, C, H, W)`` images to ``(num_patches, B, emb_size)``, printing each shape."""
        x = self.conv(x)                    # (B, emb_size, H', W')
        print(1, x.size())
        batch_size, c, h, w = x.shape
        x = x.permute(2, 3, 0, 1)           # (H', W', B, emb_size)
        print(2, x.size())
        x = x.view(h * w, batch_size, c)    # (H' * W', B, emb_size)
        print(3, x.size())

        return x

In [None]:
# Shape walkthrough: push a random batch through a 2x2 / stride-2 convolution
# and flatten the spatial grid, printing the size after each step.
test = torch.rand(64, 3, 32, 32)

conv_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=2, stride=2)
flatten = nn.Flatten(start_dim=2)

x = conv_1(test)        # (64, 64, 16, 16)
print(x.size())

x = flatten(x)          # (64, 64, 256)
print(x.size())

# Already (64, 64, 256) at this point, so this view leaves the shape unchanged.
x = x.view(64, 64, 256)