In [1]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available. Using device:", device)
    print("GPU Name:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("GPU is not available. Using CPU.")

GPU is not available. Using CPU.


 Data Loading

In [6]:
import numpy as np
import struct
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

def read_idx(filename):
    """Reads an IDX file and returns a NumPy array."""
    with open(filename, 'rb') as f:
        magic, size = struct.unpack(">II", f.read(8))

        if magic == 2051:
            rows, cols = struct.unpack(">II", f.read(8))
            data = np.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>')).reshape(size, rows, cols)
        elif magic == 2049:
            data = np.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>'))
        else:
            raise ValueError("Invalid magic number: {}".format(magic))
    return data

class MNISTUByteDataset(Dataset):
    def __init__(self, images_file, labels_file, transform=None):
        self.images = read_idx(images_file)
        self.labels = read_idx(labels_file)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]


        image = Image.fromarray(image)

        if self.transform:
            image = self.transform(image)

        return image, label



transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])


train_images_file = 'train-images.idx3-ubyte'  # If the file is in the root
train_labels_file = 'train-labels.idx1-ubyte'
test_images_file = 't10k-images.idx3-ubyte'
test_labels_file = 't10k-labels.idx1-ubyte'


train_dataset = MNISTUByteDataset(train_images_file, train_labels_file, transform=transform)
test_dataset = MNISTUByteDataset(test_images_file, test_labels_file, transform=transform)


train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

print("Data loaded from UByte files.")

Data loaded from UByte files.


In [7]:
import numpy as np
from tqdm import tqdm, trange
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image


def read_idx(filename):
    """Reads an IDX file and returns a NumPy array."""
    with open(filename, 'rb') as f:
        magic, size = struct.unpack(">II", f.read(8))

        if magic == 2051:
            rows, cols = struct.unpack(">II", f.read(8))
            data = np.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>')).reshape(size, rows, cols)
        elif magic == 2049:
            data = np.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>'))
        else:
            raise ValueError("Invalid magic number: {}".format(magic))
    return data

class MNISTUByteDataset(Dataset):
    def __init__(self, images_file, labels_file, transform=None):
        self.images = read_idx(images_file)
        self.labels = read_idx(labels_file)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]


        image = Image.fromarray(image)

        if self.transform:
            image = self.transform(image)

        return image, label

np.random.seed(0)
torch.manual_seed(0)


def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result


class MyMSA(nn.Module):
    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = int(d / n_heads)
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):

        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])


class MyViTBlock(nn.Module):
    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(MyViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d)
        )

    def forward(self, x):
        out = x + self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out


class MyViT(nn.Module):
    def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):

        super(MyViT, self).__init__()


        self.chw = chw
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d


        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        self.patch_size = (chw[1] / n_patches, chw[2] / n_patches)


        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)


        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))


        self.register_buffer('positional_embeddings', get_positional_embeddings(n_patches ** 2 + 1, hidden_d),
                             persistent=False)


        self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])


        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1)
        )

    def patchify(self, images, n_patches):
        n, c, h, w = images.shape

        assert h == w, "Patchify method is implemented for square images only"

        patches = torch.zeros(n, n_patches ** 2, h * w * c // n_patches ** 2)
        patch_size = h // n_patches

        for idx, image in enumerate(images):
            for i in range(n_patches):
                for j in range(n_patches):
                    patch = image[:, i * patch_size: (i + 1) * patch_size,
                            j * patch_size: (j + 1) * patch_size]
                    patches[idx, i * n_patches + j] = patch.flatten()
        return patches

    def forward(self, images):

        n, c, h, w = images.shape
        patches = self.patchify(images, self.n_patches).to(self.positional_embeddings.device)

        tokens = self.linear_mapper(patches)


        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)


        out = tokens + self.positional_embeddings.repeat(n, 1, 1)


        for block in self.blocks:
            out = block(out)


        out = out[:, 0]

        return self.mlp(out)

def main():


    train_images_file = 'train-images.idx3-ubyte'  # If the file is in the root
    train_labels_file = 'train-labels.idx1-ubyte'
    test_images_file = 't10k-images.idx3-ubyte'
    test_labels_file = 't10k-labels.idx1-ubyte'

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_set = MNISTUByteDataset(train_images_file, train_labels_file, transform=transform)
    test_set = MNISTUByteDataset(test_images_file, test_labels_file, transform=transform)


    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
    model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
    N_EPOCHS = 5
    LR = 0.005


    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")


    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

if __name__ == '__main__':
    main()

Using device:  cpu 


Training:   0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 1/469 [00:00<05:22,  1.45it/s][A
Epoch 1 in training:   0%|          | 2/469 [00:01<04:32,  1.72it/s][A
Epoch 1 in training:   1%|          | 3/469 [00:01<04:15,  1.83it/s][A
Epoch 1 in training:   1%|          | 4/469 [00:02<04:03,  1.91it/s][A
Epoch 1 in training:   1%|          | 5/469 [00:02<03:54,  1.98it/s][A
Epoch 1 in training:   1%|▏         | 6/469 [00:03<04:37,  1.67it/s][A
Epoch 1 in training:   1%|▏         | 7/469 [00:04<04:49,  1.60it/s][A
Epoch 1 in training:   2%|▏         | 8/469 [00:04<04:59,  1.54it/s][A
Epoch 1 in training:   2%|▏         | 9/469 [00:05<05:24,  1.42it/s][A
Epoch 1 in training:   2%|▏         | 10/469 [00:06<04:57,  1.55it/s][A
Epoch 1 in training:   2%|▏         | 11/469 [00:06<04:36,  1.66it/s][A
Epoch 1 in training:   3%|▎         | 12/469 [00:07<04:21,  1.75it/s][A
Epoch 1 in training: 

Epoch 1/5 loss: 2.11



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   0%|          | 1/469 [00:00<05:35,  1.40it/s][A
Epoch 2 in training:   0%|          | 2/469 [00:01<05:25,  1.44it/s][A
Epoch 2 in training:   1%|          | 3/469 [00:02<05:36,  1.39it/s][A
Epoch 2 in training:   1%|          | 4/469 [00:02<05:37,  1.38it/s][A
Epoch 2 in training:   1%|          | 5/469 [00:03<05:03,  1.53it/s][A
Epoch 2 in training:   1%|▏         | 6/469 [00:03<04:40,  1.65it/s][A
Epoch 2 in training:   1%|▏         | 7/469 [00:04<04:21,  1.76it/s][A
Epoch 2 in training:   2%|▏         | 8/469 [00:04<04:07,  1.86it/s][A
Epoch 2 in training:   2%|▏         | 9/469 [00:05<04:01,  1.91it/s][A
Epoch 2 in training:   2%|▏         | 10/469 [00:05<03:54,  1.96it/s][A
Epoch 2 in training:   2%|▏         | 11/469 [00:06<03:54,  1.95it/s][A
Epoch 2 in training:   3%|▎         | 12/469 [00:06<03:50,  1.98it/s][A
Epoch 2 in training:   3%|▎         | 13/469 [00:07<03:52,  1.96it/s

Epoch 2/5 loss: 1.94



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   0%|          | 1/469 [00:00<05:41,  1.37it/s][A
Epoch 3 in training:   0%|          | 2/469 [00:01<04:37,  1.68it/s][A
Epoch 3 in training:   1%|          | 3/469 [00:01<04:18,  1.80it/s][A
Epoch 3 in training:   1%|          | 4/469 [00:02<04:08,  1.87it/s][A
Epoch 3 in training:   1%|          | 5/469 [00:02<04:04,  1.90it/s][A
Epoch 3 in training:   1%|▏         | 6/469 [00:03<03:54,  1.97it/s][A
Epoch 3 in training:   1%|▏         | 7/469 [00:03<03:49,  2.01it/s][A
Epoch 3 in training:   2%|▏         | 8/469 [00:04<03:49,  2.01it/s][A
Epoch 3 in training:   2%|▏         | 9/469 [00:04<03:49,  2.01it/s][A
Epoch 3 in training:   2%|▏         | 10/469 [00:05<03:45,  2.03it/s][A
Epoch 3 in training:   2%|▏         | 11/469 [00:05<03:51,  1.98it/s][A
Epoch 3 in training:   3%|▎         | 12/469 [00:06<03:46,  2.02it/s][A
Epoch 3 in training:   3%|▎         | 13/469 [00:06<03:41,  2.06it/s

Epoch 3/5 loss: 1.84



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   0%|          | 1/469 [00:00<04:13,  1.85it/s][A
Epoch 4 in training:   0%|          | 2/469 [00:01<04:05,  1.91it/s][A
Epoch 4 in training:   1%|          | 3/469 [00:01<04:01,  1.93it/s][A
Epoch 4 in training:   1%|          | 4/469 [00:02<03:53,  2.00it/s][A
Epoch 4 in training:   1%|          | 5/469 [00:02<03:59,  1.94it/s][A
Epoch 4 in training:   1%|▏         | 6/469 [00:03<04:17,  1.80it/s][A
Epoch 4 in training:   1%|▏         | 7/469 [00:03<04:38,  1.66it/s][A
Epoch 4 in training:   2%|▏         | 8/469 [00:04<04:53,  1.57it/s][A
Epoch 4 in training:   2%|▏         | 9/469 [00:05<05:16,  1.45it/s][A
Epoch 4 in training:   2%|▏         | 10/469 [00:05<04:54,  1.56it/s][A
Epoch 4 in training:   2%|▏         | 11/469 [00:06<04:37,  1.65it/s][A
Epoch 4 in training:   3%|▎         | 12/469 [00:06<04:15,  1.79it/s][A
Epoch 4 in training:   3%|▎         | 13/469 [00:07<04:05,  1.86it/s

Epoch 4/5 loss: 1.79



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   0%|          | 1/469 [00:00<04:15,  1.83it/s][A
Epoch 5 in training:   0%|          | 2/469 [00:01<04:03,  1.92it/s][A
Epoch 5 in training:   1%|          | 3/469 [00:01<04:13,  1.84it/s][A
Epoch 5 in training:   1%|          | 4/469 [00:02<04:06,  1.89it/s][A
Epoch 5 in training:   1%|          | 5/469 [00:02<04:04,  1.90it/s][A
Epoch 5 in training:   1%|▏         | 6/469 [00:03<03:59,  1.93it/s][A
Epoch 5 in training:   1%|▏         | 7/469 [00:03<03:59,  1.93it/s][A
Epoch 5 in training:   2%|▏         | 8/469 [00:04<03:56,  1.95it/s][A
Epoch 5 in training:   2%|▏         | 9/469 [00:04<03:56,  1.95it/s][A
Epoch 5 in training:   2%|▏         | 10/469 [00:05<03:56,  1.94it/s][A
Epoch 5 in training:   2%|▏         | 11/469 [00:05<03:50,  1.98it/s][A
Epoch 5 in training:   3%|▎         | 12/469 [00:06<03:51,  1.98it/s][A
Epoch 5 in training:   3%|▎         | 13/469 [00:06<03:47,  2.01it/s

Epoch 5/5 loss: 1.76


Testing: 100%|██████████| 79/79 [00:22<00:00,  3.56it/s]

Test loss: 1.74
Test accuracy: 71.58%



