In [13]:
import tqdm
import torch
import torchvision
import numpy as np
import pickle
import random
import cv2
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
def setup_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and all CUDA devices), and asks cuDNN to select deterministic
    algorithms so that repeated runs give the same results.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


setup_seed(56274635)

In [3]:
# Training hyper-parameters.
epoches = 20      # number of passes over the training set
batch_size = 64   # samples per mini-batch
lr = 1e-5         # Adam learning rate

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [68]:
class Binarify(torch.nn.Module):
    """Image transform that turns a grayscale PIL image into a 0/1 image.

    Pixel values are divided by 255 (mapping 8-bit grayscale into [0, 1]) and
    then thresholded with OpenCV's ``THRESH_BINARY`` using ``maxval=1``: values
    above ``threshold`` become 1, all others become 0. The result is wrapped
    back into a PIL image so later PIL-based transforms (e.g. ``Resize``) can
    consume it.

    Args:
        threshold: cut-off applied to the [0, 1]-scaled pixel values.
    """

    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    def forward(self, img):
        scaled = np.asarray(img) / 255
        _, binary = cv2.threshold(scaled, self.threshold, 1, cv2.THRESH_BINARY)
        return Image.fromarray(binary)

In [69]:
# Raw string: the backslashes are Windows path separators, not escape sequences.
# (The plain literal only happened to work because none of them forms a valid escape,
# and it triggers an "invalid escape sequence" warning.)
# NOTE(review): absolute local path; consider making it configurable or relative.
mnist_path = r'C:\MyWorks\Git Repository\Remote Repo\cnn\mnist.pkl'

# SECURITY: pickle.load can execute arbitrary code; only load files from a trusted source.
# The `with` block guarantees the file handle is closed after loading.
with open(mnist_path, 'rb') as mnist_file:
    dataset = pickle.load(mnist_file, encoding='iso-8859-1')

class MNISTDataset(torch.utils.data.Dataset):
    """One MNIST split yielding binarised, resized and centred digit images.

    Args:
        data: ``(images, labels)`` pair of NumPy arrays. Each image entry is
            reshaped to ``(1, 28, 28)`` before transforming (presumably flattened
            28x28 pixels; TODO confirm dtype/range against the pickle).

    Each item is ``(image, label)``:

    - ``image`` is a ``(1, 32, 32)`` tensor on ``device``. It is produced by
      PIL conversion, ``Binarify``, resizing to 32x32 and ``ToTensor``. It is then
      rolled so that the bounding span of its occupied rows and columns is
      centred in the frame.
    - ``label`` is the matching entry of ``labels``, on ``device``.
    """

    IMAGE_SIZE = 32  # side length after resizing; also the frame the digit is centred in

    def __init__(self, data: tuple):
        self.x_raw, self.y_raw = data
        # Images stay on the CPU. ToPILImage needs host memory anyway, so storing them
        # on the GPU only forced a device->host copy for every item. Each finished
        # image is moved to `device` in __getitem__.
        self.x_raw = torch.from_numpy(self.x_raw)
        self.y_raw = torch.from_numpy(self.y_raw).to(device)
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToPILImage(),
            Binarify(threshold=0.5),
            torchvision.transforms.Resize((self.IMAGE_SIZE, self.IMAGE_SIZE)),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.x_raw)

    @staticmethod
    def _centering_shift(profile: torch.Tensor, size: int = 32) -> int:
        """Roll offset that centres the occupied span of a 1-D pixel-count profile.

        Positions whose count is greater than 1 are treated as occupied. The offset
        moves the midpoint of the first and last occupied position to ``size // 2``.
        If no position is occupied (e.g. a blank image), returns 0 instead of calling
        ``max()`` on an empty tensor, which would raise.
        """
        occupied = torch.nonzero(profile > 1)
        if occupied.numel() == 0:
            return 0
        return size // 2 - (occupied.max() + occupied.min()).item() // 2

    def __getitem__(self, idx):
        im = self.transforms(self.x_raw[idx].reshape((1, 28, 28)))
        plane = im.squeeze()
        # Column sums give the horizontal extent of the digit; row sums give the vertical extent.
        shift_cols = self._centering_shift(plane.sum(dim=-2), self.IMAGE_SIZE)
        shift_rows = self._centering_shift(plane.sum(dim=-1), self.IMAGE_SIZE)
        # torch.roll wraps around: pixels pushed past one edge re-enter on the opposite edge.
        im = torch.roll(input=im, shifts=(shift_rows, shift_cols), dims=(-2, -1))
        im = im.reshape((1, self.IMAGE_SIZE, self.IMAGE_SIZE))
        return im.to(device), self.y_raw[idx]


# One dataset per split of the pickled tuple: [0] train, [1] validation, [2] test.
train_dataset = MNISTDataset(data=dataset[0])
valid_dataset = MNISTDataset(data=dataset[1])
test_dataset = MNISTDataset(data=dataset[2])

# num_workers=0: the datasets return tensors already on `device` (possibly CUDA), which
# should not be produced inside forked worker processes.
# Only the training loader is shuffled. Evaluation accuracy does not depend on sample order,
# and shuffling the evaluation loaders would just consume the seeded RNG.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)


In [5]:
class Transpose(torch.nn.Module):
    """Module wrapper around ``Tensor.transpose``.

    It lets a dimension swap be placed inside ``torch.nn.Sequential``.

    Args:
        dim0: first dimension to swap.
        dim1: second dimension to swap.
    """

    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor):
        return torch.transpose(x, self.dim0, self.dim1)

class VitNet(torch.nn.Module):
    """Small Vision-Transformer (ViT) classifier.

    Pipeline:

    1. A strided convolution cuts the image into non-overlapping
       ``patch_size`` x ``patch_size`` patches and embeds each into ``embed_dim``.
    2. A learnable [CLS] token is prepended and a learnable positional embedding
       is added.
    3. The token sequence runs through a Transformer encoder.
    4. The final [CLS] representation goes through a small MLP head that produces
       class logits.

    With the defaults, a ``(batch, 1, 32, 32)`` input gives 8 x 8 = 64 patch
    tokens (65 including [CLS]) of width 32, and the output is ``(batch, 10)``.

    Args:
        image_size: side length of the square input image.
        patch_size: side length of each square patch; must divide ``image_size``.
        in_channels: number of input image channels.
        embed_dim: token width (``d_model`` of the encoder).
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        hidden_dim: hidden width of the classification head.
        num_classes: number of output classes.

    Raises:
        ValueError: if ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, image_size=32, patch_size=4, in_channels=1, embed_dim=32,
                 num_heads=4, num_layers=12, hidden_dim=64, num_classes=10):
        super(VitNet, self).__init__()
        if image_size % patch_size != 0:
            raise ValueError('image_size (%d) must be divisible by patch_size (%d)'
                             % (image_size, patch_size))
        num_patches = (image_size // patch_size) ** 2

        # Learnable [CLS] token and positional embedding (one position per patch plus [CLS]).
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.positional_embedding = torch.nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        self.Embedding = torch.nn.Sequential(*[
            # Patchify: (B, C, H, W) -> (B, D, H/p, W/p)
            torch.nn.Conv2d(in_channels=in_channels, out_channels=embed_dim,
                            kernel_size=patch_size, stride=patch_size),
            # (B, D, H/p, W/p) -> (B, D, N)
            torch.nn.Flatten(start_dim=-2, end_dim=-1),
            # (B, D, N) -> (B, N, D)
            Transpose(-1, -2),
        ])

        self.TransformerEncoder = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                # Tokens are laid out (batch, seq, dim). Without batch_first=True the layer
                # reads its input as (seq, batch, dim) and attends across the images of the
                # batch instead of across the patches of one image.
                batch_first=True,
            ),
            num_layers=num_layers,
            norm=torch.nn.LayerNorm(normalized_shape=embed_dim),
        )

        # Classification head on the [CLS] token; it returns raw logits.
        # torch.nn.CrossEntropyLoss (used for training) applies log-softmax itself, so a
        # trailing Softmax here would make it a double softmax. Argmax predictions are
        # unaffected by dropping it.
        # The no-op Flatten is kept so parameter names stay Output.1.* / Output.3.*.
        self.Output = torch.nn.Sequential(*[
            torch.nn.Flatten(start_dim=1, end_dim=-1),
            torch.nn.Linear(in_features=embed_dim, out_features=hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(in_features=hidden_dim, out_features=num_classes),
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Args:
            x: images of shape (batch, in_channels, image_size, image_size).

        Returns:
            Logits of shape (batch, num_classes).
        """
        batch_num = x.shape[0]
        tokens = self.Embedding(x)                               # (B, N, D)
        cls_token = self.cls_token.expand(batch_num, -1, -1)     # (B, 1, D)
        tokens = torch.cat((cls_token, tokens), dim=1)           # (B, N+1, D)
        tokens = tokens + self.positional_embedding
        tokens = self.TransformerEncoder(tokens)                 # (B, N+1, D)
        return self.Output(tokens[:, 0, :])                      # (B, num_classes)




# Model, optimiser and loss used by the training loop.
vitNet = VitNet().to(device)
vitNet_optimiser = torch.optim.Adam(params=vitNet.parameters(), lr=lr)
vitNet_loss_func = torch.nn.CrossEntropyLoss()

# Histories filled during training: one loss per training batch, and one validation
# accuracy per epoch (stored under the historical name `valid_loss_list`).
train_loss_list, valid_loss_list = [], []

In [6]:
%matplotlib auto
plt.ion()

# Redrawing the live loss plot (plus a 10 ms plt.pause) on every batch dominates the
# runtime of a model this small, so the plot is refreshed only every `plot_every` batches.
plot_every = 20

for epoch in tqdm.tqdm(range(epoches)):
    # ---- training pass ----
    vitNet.train()  # once per epoch is enough: eval() is only switched on for validation below
    for step, (x, y) in enumerate(train_loader):
        vitNet_optimiser.zero_grad()
        y_ = vitNet(x)
        loss = vitNet_loss_func(y_, y)
        loss.backward()
        vitNet_optimiser.step()
        train_loss_list.append(loss.item())  # CrossEntropyLoss already reduces to a scalar

        if step % plot_every == 0:
            recent = train_loss_list[-100:]
            plt.subplot(1, 2, 1)
            plt.cla()
            plt.plot(range(len(train_loss_list) - len(recent), len(train_loss_list)), recent,
                     color='black', linewidth=1)
            plt.title('train loss for last 100 cases')
            plt.pause(0.01)

    # ---- validation pass ----
    with torch.no_grad():
        vitNet.eval()
        # Labels and argmax indices are integer tensors. Giving the accumulators an explicit
        # integer dtype avoids relying on type promotion from a default float32 empty tensor.
        Y = torch.tensor([], dtype=torch.long, device=device)
        Y_ = torch.tensor([], dtype=torch.long, device=device)
        for x, y in valid_loader:
            y_ = vitNet(x)
            Y = torch.cat((Y, y))
            Y_ = torch.cat((Y_, y_.argmax(dim=1)))

        # Despite its name, valid_loss_list stores validation accuracy (fraction correct).
        valid_loss_list.append((Y == Y_).float().mean().item())

    plt.subplot(1, 2, 2)
    plt.cla()
    plt.plot(list(range(len(valid_loss_list[-100:]))), valid_loss_list[-100:], color='red', linewidth=1)
    plt.text(len(valid_loss_list[-100:]), valid_loss_list[-1], str(valid_loss_list[-1]*100))
    plt.title('valid accuracy')
    plt.pause(0.01)
    torch.save(vitNet.state_dict(), 'vitNet e%d at %.2f.pt' % (epoch, valid_loss_list[-1]*100))

plt.ioff()
plt.close()

Using matplotlib backend: TkAgg


  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
%matplotlib inline

In [None]:
# Training loss is recorded per batch and validation accuracy per epoch, so the two curves
# have different x-axes and scales. Draw them on separate axes; on one shared axes the
# second title would also replace the first.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
ax_loss.plot(range(len(train_loss_list)), train_loss_list)
ax_loss.set(title='train loss', xlabel='batch', ylabel='cross-entropy loss')
ax_acc.plot(range(len(valid_loss_list)), valid_loss_list)
ax_acc.set(title='valid accuracy', xlabel='epoch', ylabel='accuracy')
fig.tight_layout()

In [None]:
# Release the training and validation data, then drop PyTorch's cached CUDA memory,
# before running the test evaluation.
del train_dataset, train_loader
del valid_dataset, valid_loader
torch.cuda.empty_cache()

In [None]:
# Evaluate on the held-out test set. Y / Y_ are created fresh here so the accuracy covers
# only test predictions; the validation pass in the training cell leaves its own Y / Y_ behind.
vitNet.eval()
Y = torch.tensor([], dtype=torch.long, device=device)
Y_ = torch.tensor([], dtype=torch.long, device=device)
with torch.no_grad():  # inference only: no autograd graph is needed
    for x, y in test_loader:
        y_ = vitNet(x)
        Y = torch.cat((Y, y))
        Y_ = torch.cat((Y_, y_.argmax(dim=1)))
print('test accuracy')
print((Y == Y_).float().mean().item() * 100, '%')
print('GroundTruth:', Y[-10:])
print('Net Output:', Y_[-10:])


In [None]:
# Persist the training history: per-batch training losses and per-epoch validation accuracies.
history = (train_loss_list, valid_loss_list)
with open('vit_log.pkl', 'wb') as log_file:
    pickle.dump(history, log_file)