In [1]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from code import helper_functions as hf

In [2]:
data_dir = '/Users/jqin/Documents/Adulting/Studying/mnist_practice/data'
train_data_file = 'train-images.idx3-ubyte'
train_labels_file = 'train-labels.idx1-ubyte'
test_data_file = 't10k-images.idx3-ubyte'
test_labels_file = 't10k-labels.idx1-ubyte'

train_data, train_labels = hf.load_data(data_dir, train_data_file, train_labels_file)
test_data, test_labels = hf.load_data(data_dir, test_data_file, test_labels_file)

In [3]:
class Transformer(nn.Module):
    def __init__(self, n_feat, n_hidden, n_out, n_heads, n_layer):
        super().__init__()
        self.embeddings = nn.Embedding(n_feat, n_hidden) 
        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(n_hidden, nhead=n_heads), num_layers=n_layer)
        self.ffn = nn.Linear(n_hidden*n_feat, n_out)
        self.output_fn = nn.LogSoftmax(dim=0)
    
    def forward(self, data):
        batch_size = data.shape[0]
        embedded = self.embeddings(data)
        encoded = torch.reshape(self.encoder(embedded), (batch_size, -1))
        output = self.ffn(encoded)
        preds = self.output_fn(output)
        
        return preds
        

In [4]:
n_feat = 28*28
n_hidden = 30
n_out = 10
n_heads = 5
n_layer = 3
learning_rate = 0.001

model = Transformer(n_feat, n_hidden, n_out, n_heads, n_layer)
optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [5]:
n_train = train_labels.shape[0]
batch_size = 10
loss_fn = nn.CrossEntropyLoss()

for epoch in range(401):
    optim.zero_grad()
    
    start_idx = (epoch * batch_size) % n_train
    end_idx = min(n_train, start_idx + batch_size)
    batch_data = torch.tensor(train_data[start_idx:end_idx]).to(torch.int64)
    batch_labels = torch.tensor(train_labels[start_idx:end_idx]).to(torch.int64)

    preds = model(batch_data)
    loss = loss_fn(preds, batch_labels)
    
    if (epoch % 10 == 0):
        print('Loss on epoch {} is {}'.format(epoch, torch.sum(loss)))
    
    loss.backward()
    optim.step()

Loss on epoch 0 is 2.327805995941162
Loss on epoch 10 is 1.2722746133804321
Loss on epoch 20 is 0.996909499168396
Loss on epoch 30 is 0.9750736951828003
Loss on epoch 40 is 0.6945288777351379
Loss on epoch 50 is 1.5724852085113525
Loss on epoch 60 is 0.7617841362953186
Loss on epoch 70 is 1.0861659049987793
Loss on epoch 80 is 0.7380486726760864
Loss on epoch 90 is 1.2996108531951904
Loss on epoch 100 is 0.8445766568183899
Loss on epoch 110 is 0.7441269159317017
Loss on epoch 120 is 0.8465034365653992
Loss on epoch 130 is 1.1179803609848022
Loss on epoch 140 is 0.8090758323669434
Loss on epoch 150 is 0.7708579301834106
Loss on epoch 160 is 1.0725456476211548
Loss on epoch 170 is 1.0307934284210205
Loss on epoch 180 is 0.6791755557060242
Loss on epoch 190 is 0.7750086784362793
Loss on epoch 200 is 0.38142916560173035
Loss on epoch 210 is 0.8576698303222656
Loss on epoch 220 is 1.5651919841766357
Loss on epoch 230 is 1.218719244003296
Loss on epoch 240 is 1.3919713497161865
Loss on epoch

In [6]:
test_data = torch.tensor(test_data).to(torch.int64)
test_labels = torch.tensor(train_labels).to(torch.int64)

test_preds = model(test_data[0:20])