# unzip dataset

In [1]:
# import zipfile

# if __name__ == '__main__':
#     folder = "/home/gyuseonglee/workspace/2day"
#     unzip = zipfile.ZipFile(f"{folder}/Paired_MNIST.zip")
#     unzip.extractall("/home/gyuseonglee/workspace/2day/data")
#     unzip.close()

### import libraries

In [1]:
import os
import pickle as pkl
import torch
import pandas as pd
import numpy as np
import cv2
from tqdm.auto import tqdm as tq
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt


In [2]:
folder = './data/Paired_MNIST'
with open(f'{folder}/training_tuple.pkl', 'rb') as f:
    training_tuple = pkl.load(f)
with open(f'{folder}/training_dict.pkl', 'rb') as f:
    training_dict  = pkl.load(f)
with open(f'{folder}/test.pkl', 'rb') as f:
    test = pkl.load(f)

In [3]:
# print shape
X_train = training_tuple[0]
Y_train = training_tuple[1]
X_train, X_valid, Y_train, Y_valid = train_test_split(X_train, Y_train, test_size=0.16666666, random_state=1203)
X_test  = test[0]
Y_test  = test[1]

print(f"train img   : {X_train.shape}")
print(f"train label : {Y_train.shape}")
print(f"valid img   : {X_valid.shape}")
print(f"valid label : {Y_valid.shape}")

print(f"test img    : {test[0].shape}")
print(f"test label  : {test[1].shape}")


train img   : torch.Size([50000, 2, 28, 28])
train label : torch.Size([50000])
valid img   : torch.Size([10000, 2, 28, 28])
valid label : torch.Size([10000])
test img    : torch.Size([10000, 28, 28])
test label  : torch.Size([10000])


In [4]:
# https://github.com/ansh941/MnistSimpleCNN

''' SOTA model (homogeneous ensenble -m3) 
        super(ModelM3, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, bias=False)       # output becomes 26x26
        self.conv1_bn = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 48, 3, bias=False)      # output becomes 24x24
        self.conv2_bn = nn.BatchNorm2d(48)
        self.conv3 = nn.Conv2d(48, 64, 3, bias=False)      # output becomes 22x22
        self.conv3_bn = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 80, 3, bias=False)      # output becomes 20x20
        self.conv4_bn = nn.BatchNorm2d(80)
        self.conv5 = nn.Conv2d(80, 96, 3, bias=False)      # output becomes 18x18
        self.conv5_bn = nn.BatchNorm2d(96)
        self.conv6 = nn.Conv2d(96, 112, 3, bias=False)     # output becomes 16x16
        self.conv6_bn = nn.BatchNorm2d(112)
        self.conv7 = nn.Conv2d(112, 128, 3, bias=False)    # output becomes 14x14
        self.conv7_bn = nn.BatchNorm2d(128)
        self.conv8 = nn.Conv2d(128, 144, 3, bias=False)    # output becomes 12x12
        self.conv8_bn = nn.BatchNorm2d(144)
        self.conv9 = nn.Conv2d(144, 160, 3, bias=False)    # output becomes 10x10
        self.conv9_bn = nn.BatchNorm2d(160)
        self.conv10 = nn.Conv2d(160, 176, 3, bias=False)   # output becomes 8x8
        self.conv10_bn = nn.BatchNorm2d(176)
        self.fc1 = nn.Linear(11264, 10, bias=False)
        self.fc1_bn = nn.BatchNorm1d(10)
'''
print()




In [16]:
# single conv layer (conv -> batch norm -> activation)
class ConvBnAct(torch.nn.Module):
    def __init__(self, in_channels, out_channels, padding=0, activation='relu'):
        super().__init__()
        
        self.conv = torch.nn.Conv2d(
            in_channels=in_channels, 
            out_channels=out_channels, 
            kernel_size=3,
            stride=1,
            padding=padding,
            dilation=1,
            bias=False,
        )
        self.bn = torch.nn.BatchNorm2d(
            num_features=out_channels, eps=1e-05, momentum=0.1)
        
        if activation == 'gelu':
            self.activation = torch.nn.GELU(approximate='tanh') 
        else:
            self.activation = torch.nn.ReLU(inplace=True)
        
    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))


    
# single digit recognizer
class DigitRecognizer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.maxpool = torch.nn.MaxPool2d(2, 2) # 안쓸듯..?
        
        self.conv1 = ConvBnAct(in_channels= 1, out_channels=32, activation='relu')
        self.conv2 = ConvBnAct(in_channels=32, out_channels=48, activation='relu')
        self.conv3 = ConvBnAct(in_channels=48, out_channels=64, activation='relu')

        self.conv4 = ConvBnAct(in_channels=64, out_channels=80, activation='relu')
        self.conv5 = ConvBnAct(in_channels=80, out_channels=96, activation='relu')
        self.conv6 = ConvBnAct(in_channels=96, out_channels=112, activation='relu')

        self.conv7 = ConvBnAct(in_channels=112, out_channels=118, activation='relu')
        self.conv8 = ConvBnAct(in_channels=118, out_channels=144, activation='relu')
        self.conv9 = ConvBnAct(in_channels=144, out_channels=160, activation='relu')
        self.conv10 = ConvBnAct(in_channels=160, out_channels=176, activation='relu')

        self.convs = [
            self.conv1, self.conv2, self.conv3,
            self.conv4, self.conv5, self.conv6,
            self.conv7, self.conv8, self.conv9, self.conv10,
        ]
        self.linear = torch.nn.Linear(11264, 10)
        
    def forward(self, x):
        for c in self.convs:
            x = c(x)   # [batch_size, channels, height, weight]
        x = torch.flatten(x.permute(0, 2, 3, 1), 1) # embedding
        x = self.linear(x)
        return x
            
    
    
    
# model for addition 
class DigitAdder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.digit_recognizer = DigitRecognizer()
        self.add_layer = torch.nn.Linear(100, 19)

        
    def forward(self, x1, x2):
        B = x1.shape[0] # batch size
        
        x1 = self.digit_recognizer(x1)
        x2 = self.digit_recognizer(x2)
        comb = x1.unsqueeze(2) + x2.unsqueeze(1)
        summation = self.add_layer(comb.view(B, -1))
        return summation


In [31]:
class cfg:
    def __init__(self):
        return
    
configs = cfg()
configs.batch_size    = 512
configs.learning_rate = 0.0001
configs.device = 'cuda:0'
configs.epochs = 100
configs.num_gpus = torch.cuda.device_count()
configs.tqdm = True

In [32]:
class BaseDataset(torch.utils.data.Dataset):
    def __init__(self, X, Y, mode ='train'):
        self.X = X
        self.Y = Y
        self.mode = mode
                
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        if self.mode in ['train', 'valid']:
            x1 = self.X[idx][0].reshape(1, 28, 28).float()
            x2 = self.X[idx][1].reshape(1, 28, 28).float() 
            y = self.Y[idx] 
            return x1, x2, y
        else:
            x = self.X[idx].reshape(1, 28, 28).float()
            y = self.Y[idx]
            return x, y



train_dataset = BaseDataset(X_train, Y_train, mode='train')
valid_dataset = BaseDataset(X_train, Y_train, mode='valid')
test_dataset  = BaseDataset(X_test, Y_test, mode='test')

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=configs.batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=configs.batch_size, shuffle=False)
test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=configs.batch_size, shuffle=False)

In [19]:
model = DigitAdder()
optimizer = torch.optim.Adam(params= model.parameters(), lr=configs.learning_rate)
scheduler = None
warm_up = None
criterion = torch.nn.CrossEntropyLoss()

In [20]:
def train_fn(configs, model, optimizer, criterion, scheduler, warm_up):
    global y, yhat, labels, preds
    def forward_step(batch):
        x1, x2, y = batch
        x1 = x1.to(configs.device)
        x2 = x2.to(configs.device)
        y  = y.to(configs.device)
        
        yhat = model(x1, x2)
        loss = criterion(yhat, y)
        
        return yhat, loss
    
    train_loss_tracker = []
    valid_loss_tracker = []
    valid_acc_tracker  = []
    valid_f1_tracker   = []

    best_loss = 999999
    best_acc  = 0.0
    best_f1   = 0.0
    best_model = None

    model = model.to(configs.device)
    if configs.num_gpus >= 1:
        print("--current device : CUDA")
    if configs.num_gpus > 1:
        model = torch.nn.DataParallel(model)
        print(f"--distributed training : {['cuda:'+str(i) for i in range(torch.cuda.device_count())]}")

    criterion = criterion.to(configs.device)
    
    for epoch in range(1, (configs.epochs + 1)):
        # train stage
        model.train()
        train_loss = []
        train_iterator = tq(train_loader) if configs.tqdm else train_loader
        
        for batch in train_iterator:
            optimizer.zero_grad()
            _, loss = forward_step(batch)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            
        if scheduler is not None:
            scheduler.step()
    
        # validation stage
        model.eval()
        valid_loss = []
        labels = []
        preds  = []
        
        valid_iterator = tq(valid_loader) if configs.tqdm else valid_loader
        with torch.no_grad():
            for batch in valid_iterator:
                yhat, loss = forward_step(batch)
                valid_loss.append(loss.item())

                # result
                y = batch[2].detach().cpu().numpy()
                yhat =  yhat.argmax(1).detach().cpu().numpy()
                
                labels.append(y)
                preds.append(yhat)
                
        labels = np.concatenate(labels, axis=0)
        preds  = np.concatenate(preds,  axis=0)
        # metric
        acc, f1 = accuracy_score(labels, preds), f1_score(labels, preds, average = 'macro')
        
        if f1 > best_f1:
            best_f1 = f1
            best_model = model
        
        train_loss = round(np.mean(train_loss), 4)
        valid_loss = round(np.mean(valid_loss)  , 4)
        valid_acc  = round(acc, 4)
        valid_f1   = round(f1, 4)
        
        print(f"-- EPOCH {epoch} --")
        print(f"training   loss : {train_loss}")
        print(f"validation loss : {valid_loss}")
        print(f"current val acc : {valid_acc}")
        print(f"current val f1  : {valid_f1}") 
        print(f"best val acc    : {round(best_acc, 4)}")
        print(f"best val f1     : {round(best_f1, 4)}")
        print(f"labels (first 5 items)  : {labels[:5]}")
        print(f"preds  (first 5 items)  : {preds[:5]}")
        train_loss_tracker.append(train_loss)
        valid_loss_tracker.append(valid_loss)
        valid_acc_tracker.append(valid_acc)
        valid_f1_tracker.append(valid_f1)
        
    return best_model, train_loss_tracker, valid_loss_tracker, valid_acc_tracker, valid_f1_tracker

In [21]:
outputs = train_fn(configs, model, optimizer, criterion, scheduler, warm_up)

--current device : CUDA


  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

-- EPOCH 1 --
training   loss : 2.6051
validation loss : 2.1955
current val acc : 0.2117
current val f1  : 0.1884
best val acc    : 0.0
best val f1     : 0.1884
labels (first 5 items)  : [ 1  6 16  9 14]
preds  (first 5 items)  : [ 1  8 13 14 12]


  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

-- EPOCH 2 --
training   loss : 1.8968
validation loss : 1.714
current val acc : 0.3534
current val f1  : 0.3399
best val acc    : 0.0
best val f1     : 0.3399
labels (first 5 items)  : [ 1  6 16  9 14]
preds  (first 5 items)  : [ 1  6 14  9 12]


  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

-- EPOCH 3 --
training   loss : 1.5205
validation loss : 1.3614
current val acc : 0.4857
current val f1  : 0.4737
best val acc    : 0.0
best val f1     : 0.4737
labels (first 5 items)  : [ 1  6 16  9 14]
preds  (first 5 items)  : [ 1  5 15  9 12]


  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

-- EPOCH 4 --
training   loss : 1.3004
validation loss : 1.1846
current val acc : 0.5773
current val f1  : 0.5862
best val acc    : 0.0
best val f1     : 0.5862
labels (first 5 items)  : [ 1  6 16  9 14]
preds  (first 5 items)  : [ 1  5 16  9 14]


  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

-- EPOCH 5 --
training   loss : 1.1341
validation loss : 1.0596
current val acc : 0.6233
current val f1  : 0.6485
best val acc    : 0.0
best val f1     : 0.6485
labels (first 5 items)  : [ 1  6 16  9 14]
preds  (first 5 items)  : [ 1  6 16  9 13]


  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

-- EPOCH 6 --
training   loss : 0.9877
validation loss : 0.9448
current val acc : 0.6809
current val f1  : 0.6716
best val acc    : 0.0
best val f1     : 0.6716
labels (first 5 items)  : [ 1  6 16  9 14]
preds  (first 5 items)  : [ 1  6 16  9 14]


  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

-- EPOCH 7 --
training   loss : 0.8608
validation loss : 0.7592
current val acc : 0.7768
current val f1  : 0.7943
best val acc    : 0.0
best val f1     : 0.7943
labels (first 5 items)  : [ 1  6 16  9 14]
preds  (first 5 items)  : [ 1  6 16  9 13]


  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

-- EPOCH 8 --
training   loss : 0.7427
validation loss : 0.694
current val acc : 0.8112
current val f1  : 0.8322
best val acc    : 0.0
best val f1     : 0.8322
labels (first 5 items)  : [ 1  6 16  9 14]
preds  (first 5 items)  : [ 1  6 16  9 13]


  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

-- EPOCH 9 --
training   loss : 0.6318
validation loss : 0.5722
current val acc : 0.8604
current val f1  : 0.8598
best val acc    : 0.0
best val f1     : 0.8598
labels (first 5 items)  : [ 1  6 16  9 14]
preds  (first 5 items)  : [ 1  6 16  9 12]


  0%|          | 0/98 [00:00<?, ?it/s]

  0%|          | 0/98 [00:00<?, ?it/s]

-- EPOCH 10 --
training   loss : 0.5466
validation loss : 0.5653
current val acc : 0.8382
current val f1  : 0.841
best val acc    : 0.0
best val f1     : 0.8598
labels (first 5 items)  : [ 1  6 16  9 14]
preds  (first 5 items)  : [ 1  5 16  9 12]


In [22]:
best_model = outputs[0]

infer_model = best_model.digit_recognizer

In [25]:
infer_model

DigitRecognizer(
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv1): ConvBnAct(
    (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU(inplace=True)
  )
  (conv2): ConvBnAct(
    (conv): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU(inplace=True)
  )
  (conv3): ConvBnAct(
    (conv): Conv2d(48, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU(inplace=True)
  )
  (conv4): ConvBnAct(
    (conv): Conv2d(64, 80, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ReLU(inplac

In [26]:
def inference(configs, model, test_loader):
    global preds, labels
    def forward_step(batch):
        x, y = batch
        x = x.to(configs.device)
        yhat = model(x)
        return yhat
    
    model = model.to(configs.device)
    # test stage
    model.eval()
    test_loss = []
    labels = []
    preds  = []

    test_iterator = tq(test_loader) if configs.tqdm else test_loader
    with torch.no_grad():
        for batch in test_iterator:
            yhat = forward_step(batch)

            # result
            y = batch[1].detach().cpu().numpy()
            yhat =  yhat.argmax(1).detach().cpu().numpy()

            labels.append(y)
            preds.append(yhat)

    labels = np.concatenate(labels, axis=0)
    preds  = np.concatenate(preds,  axis=0)
    # metric
    acc, f1 = accuracy_score(labels, preds), f1_score(labels, preds, average = 'macro')

    test_acc  = round(acc, 4)
    test_f1   = round(f1, 4)
    
    print([test_acc, test_f1])
    
    return test_acc, test_f1, labels, preds



In [27]:
predict = inference(configs, infer_model, test_loader)

  0%|          | 0/20 [00:00<?, ?it/s]

[0.1142, 0.0396]


In [28]:
preds

array([2, 1, 1, ..., 2, 2, 2])

In [29]:
labels

array([7, 2, 1, ..., 4, 5, 6])

In [126]:
r = pd.DataFrame(preds==labels)
r[r[0] == True]

Unnamed: 0,0
64,True
111,True
144,True
175,True
411,True
...,...
9546,True
9618,True
9652,True
9746,True


In [127]:
r

Unnamed: 0,0
0,False
1,False
2,False
3,False
4,False
...,...
9995,False
9996,False
9997,False
9998,False
