In [1]:
import gc
import pickle
import torch
import torchvision
from torchvision import models
from torch import optim
from datetime import datetime
from tqdm.auto import tqdm
import pandas as pd
import itertools
from copy import deepcopy
from collections import defaultdict
from transformers import get_scheduler
from torch import nn
from torch.utils.data import DataLoader, Dataset
import wandb

## Prepare data
All four "training_" files contain the same data: pairs of MNIST digit images together with the corresponding sum of the two digits.
The four files are saved with different file extensions (.pt vs .pkl) or in different data types (dict vs tuple).

In [2]:
# NOTE(review): pickle.load (and pickle-based torch.load) may execute code
# embedded in the file; only load files that come from a trusted source.
# Every statement below rebinds `data`, so only the value from the last load
# (training_tuple.pt) is still bound when this cell finishes.
with open('./data/MNIST_pair/training_tuple.pkl', 'rb') as f:
    data = pickle.load(f)

with open('./data/MNIST_pair/training_dict.pkl', 'rb') as f:
    data = pickle.load(f)
    
data = torch.load("./data/MNIST_pair/training_dict.pt")

data = torch.load("./data/MNIST_pair/training_tuple.pt")


In [3]:
# Load the test split twice (pickle file, then .pt file); the second load
# rebinds `test_data`, so the .pt contents are what remain bound.
# NOTE(review): only load these files from a trusted source (pickle-based formats).
with open('./data/MNIST_pair/test.pkl', 'rb') as f:
    test_data = pickle.load(f)

test_data = torch.load("./data/MNIST_pair/test.pt")

## Prepare Dataset & DataLoader

Split the data into a training set and a validation set (if needed).

In [4]:
# Shuffle the sample indices once, then cut them 75% / 25% into train / validation.
tr_len = int(data[0].size(0) * 0.75)
perm_idxs = torch.randperm(data[0].size(0))
train_idx, valid_idx = perm_idxs[:tr_len], perm_idxs[tr_len:]

# data[0] holds the inputs and data[1] the labels; both are indexed with the
# same permutation slice so inputs and labels stay aligned.
train_data, train_label = data[0][train_idx], data[1][train_idx]
assert len(train_data) == len(train_label)

valid_data, valid_label = data[0][valid_idx], data[1][valid_idx]
assert len(valid_data) == len(valid_label)

In [5]:
# Standardise both splits with the mean / std of the training inputs.
# NOTE: this cell rebinds `train_data` / `valid_data` to (inputs, labels) tuples,
# so it cannot be re-run without first re-running the split cell above.
mean = train_data.float().mean()
std = train_data.float().std()

train_data = (train_data.sub(mean).div(std), train_label)
valid_data = (valid_data.sub(mean).div(std), valid_label)

In [2]:
class MNISTDataset(Dataset):
    """Dataset yielding ``(input1, input2, label)`` triples.

    ``dataset`` is indexed as ``dataset[0]`` (inputs) and ``dataset[1]`` (labels).

    * ``test=False``: ``dataset[0][:, 0]`` and ``dataset[0][:, 1]`` are used as
      the first and second input of every sample.
    * ``test=True``: the same input tensor is used for both outputs.

    In both modes a singleton axis is inserted right after the sample axis.
    """

    def __init__(self, dataset, test:bool = False):
        super().__init__()
        self.dataset = dataset

        inputs = dataset[0]
        self.labels = dataset[1]

        if not test:
            # Paired inputs: pick element 0 / 1 of the pair axis, keep a size-1 axis.
            self.num1 = inputs[:, 0, None]
            self.num2 = inputs[:, 1, None]
            assert len(self.num1) == len(self.num2)
        else:
            # Single inputs: both outputs are views of the same tensor.
            self.num1 = inputs[:, None]
            self.num2 = inputs[:, None]

        assert len(self.num1) == len(self.labels)

    def __len__(self,):
        """Number of samples."""
        return len(self.num1)

    def __getitem__(self, idx):
        """Return both inputs cast to float, plus the label at ``idx``."""
        first = self.num1[idx].float()
        second = self.num2[idx].float()
        return first, second, self.labels[idx]

In [3]:
# Requires `train_data`, `valid_data` and `test_data` from the earlier cells;
# the recorded output below shows a NameError from running this cell before them.
train_dataset = MNISTDataset(train_data, test=False)
val_dataset = MNISTDataset(valid_data, test=False)
test_dataset = MNISTDataset(test_data, test=True)

NameError: name 'train_data' is not defined

In [None]:
# Wrap each dataset in a DataLoader with batches of 64 samples; all three
# loaders (including validation and test) are shuffled here.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

## Loss function

In [9]:
predict_method = "combination"
if predict_method != "combination":
    loss_function = nn.CrossEntropyLoss()

else:
    # Look-up table of every ordered pair of values, grouped by the pair's sum.
    import itertools
    from copy import deepcopy
    from collections import defaultdict

    # Halving the distinct labels and truncating to int yields the distinct addends.
    num = (train_data[1].unique() / 2).int().unique().numpy().tolist()
    assert len(num) == 10
    num = num * 2

    # Duplicating `num` lets permutations() also produce pairs like (a, a);
    # the set removes the repeats and sorted() orders the tuples.
    per_list = sorted(set(itertools.permutations(num, 2)))
    assert len(per_list) == 100

    label_dic = defaultdict(list)
    for i in per_list:
        label_dic[sum(i)].append(i)

    # One column per sum; shorter columns are padded with missing values.
    label_dic = pd.DataFrame(data=label_dic.values(), index=label_dic.keys()).T
    assert (~label_dic.isna()).sum().sum() == 100

    def custom_loss(pred, label):
        """Per-sample negative log-likelihood of ``label`` under softmax(pred)."""
        log_probs = nn.functional.log_softmax(pred, dim=-1)
        return nn.functional.nll_loss(log_probs, label, reduction="none")

    loss_function = custom_loss
    

# ===========================================

# Main

# Prepare Data function

In [4]:
class MNISTDataset(Dataset):
    """Dataset yielding ``(input1, input2, label)`` triples.

    ``dataset`` is indexed as ``dataset[0]`` (inputs) and ``dataset[1]`` (labels).

    * ``test=False``: ``dataset[0][:, 0]`` and ``dataset[0][:, 1]`` are used as
      the first and second input of every sample.
    * ``test=True``: the same input tensor is used for both outputs.

    In both modes a singleton axis is inserted right after the sample axis.
    """

    def __init__(self, dataset, test:bool = False):
        super().__init__()
        self.dataset = dataset

        inputs = dataset[0]
        self.labels = dataset[1]

        if not test:
            # Paired inputs: pick element 0 / 1 of the pair axis, keep a size-1 axis.
            self.num1 = inputs[:, 0, None]
            self.num2 = inputs[:, 1, None]
            assert len(self.num1) == len(self.num2)
        else:
            # Single inputs: both outputs are views of the same tensor.
            self.num1 = inputs[:, None]
            self.num2 = inputs[:, None]

        assert len(self.num1) == len(self.labels)

    def __len__(self,):
        """Number of samples."""
        return len(self.num1)

    def __getitem__(self, idx):
        """Return both inputs cast to float, plus the label at ``idx``."""
        first = self.num1[idx].float()
        second = self.num2[idx].float()
        return first, second, self.labels[idx]

In [5]:
def get_loader(batch_size, data_dir='./data/MNIST_pair', train_frac=0.75, seed=None):
    """Build the train / validation / test DataLoaders.

    Args:
        batch_size: batch size used for all three loaders.
        data_dir: directory containing ``training_tuple.pkl`` and ``test.pkl``.
        train_frac: fraction of the training file used for training; the
            remainder becomes the validation split.
        seed: optional integer seed for the train/validation permutation.
            ``None`` (default) uses the global torch RNG, as before.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    # NOTE(review): pickle.load can execute code embedded in the file; only
    # point `data_dir` at trusted data.
    with open(f'{data_dir}/training_tuple.pkl', 'rb') as f:
        data = pickle.load(f)

    # Split train and validation set.
    num_samples = data[0].size(0)
    tr_len = int(num_samples * train_frac)

    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)
    perm_idxs = torch.randperm(num_samples, generator=generator)

    train_idx = perm_idxs[:tr_len]
    valid_idx = perm_idxs[tr_len:]

    train_data = data[0][train_idx]
    train_label = data[1][train_idx]
    assert len(train_data) == len(train_label)

    valid_data = data[0][valid_idx]
    valid_label = data[1][valid_idx]
    assert len(valid_data) == len(valid_label)

    # Normalisation statistics come from the training split only and are
    # re-used for validation AND test, so every split sees the same transform.
    mean, std = train_data.float().mean(), train_data.float().std()
    train_data = ((train_data - mean) / std, train_label)
    valid_data = ((valid_data - mean) / std, valid_label)

    # Data for test.
    with open(f'{data_dir}/test.pkl', 'rb') as f:
        test_data = pickle.load(f)
    test_label = test_data[1]

    # Fix: the test inputs used to be standardised with their own mean / std.
    test_data = ((test_data[0] - mean) / std, test_label)

    # Make datasets.
    train_dataset = MNISTDataset(train_data, test=False)
    val_dataset = MNISTDataset(valid_data, test=False)
    test_dataset = MNISTDataset(test_data, test=True)

    # Make loaders. Only the training loader needs shuffling; keeping the
    # evaluation loaders in a fixed order makes their batches deterministic.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

## Metric

In [6]:
def accuracy(pred, target):
    """Return the fraction of rows whose arg-max over the last axis equals ``target``.

    The predictions, the targets and the resulting accuracy are also logged to
    wandb under ``acc/pred``, ``acc/label`` and ``acc/acc``.
    """
    predicted = pred.argmax(dim=-1)
    batch_acc = (predicted == target).float().mean()

    wandb.log({"acc/pred": predicted,
               "acc/label": target,
               "acc/acc": batch_acc})

    return batch_acc

def accuracy_sum(pred1,pred2, target):
    """Fraction of rows where argmax(pred1) + argmax(pred2) equals ``target``."""
    first, second = (scores.argmax(dim=-1) for scores in (pred1, pred2))
    hits = (first + second) == target
    return hits.float().mean()

# Prepare loss function

In [7]:
def get_loss_function(predict_method):
    """Return ``(loss_function, label_dic)`` for the given prediction method.

    For ``"inverse_augment"`` the loss is a per-sample (unreduced) negative
    log-likelihood and ``label_dic`` is a DataFrame with one column per sum
    (0..18) listing every ordered digit pair that adds up to it. For any other
    method the loss is ``nn.CrossEntropyLoss()`` and ``label_dic`` is ``None``.
    """
    if predict_method != "inverse_augment":
        return nn.CrossEntropyLoss(), None

    # All 100 ordered pairs of digits 0..9, in lexicographic order.
    digit_pairs = sorted(itertools.product(range(10), repeat=2))
    assert len(digit_pairs) == 100

    pairs_by_sum = defaultdict(list)
    for pair in digit_pairs:
        pairs_by_sum[pair[0] + pair[1]].append(pair)

    # Rows of the intermediate frame are sums; transposing makes each sum a
    # column, padded with missing values where fewer pairs exist.
    label_dic = pd.DataFrame(data=pairs_by_sum.values(), index=pairs_by_sum.keys()).T
    assert (~label_dic.isna()).sum().sum() == 100

    def custom_loss(pred, label):
        """Per-sample negative log-likelihood of ``label`` under softmax(pred)."""
        log_probs = nn.functional.log_softmax(pred, dim=-1)
        return nn.functional.nll_loss(log_probs, label, reduction="none")

    return custom_loss, label_dic

## Model

In [14]:
class MNIST_Encoder(nn.Module):
    """Two conv / batch-norm / ReLU / max-pool stages followed by a flatten.

    ``forward`` reshapes the convolutional output to ``(-1, 16 * 4 * 4)``, i.e.
    256 features per row. The per-layer shape comments below assume 28x28
    single-channel inputs.
    """

    def __init__(self):
        super().__init__()
        stages = [
            nn.Conv2d(1, 6, 5),
            nn.BatchNorm2d(6),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),   # 6x24x24 -> 6x12x12
            nn.Conv2d(6, 16, 5),  # 6x12x12 -> 16x8x8
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),   # 16x8x8 -> 16x4x4
        ]
        self.encoder = nn.Sequential(*stages)

        # Kaiming-normal conv weights; batch-norm starts as the identity affine map.
        for layer in self.encoder:
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x):
        """Encode ``x`` and flatten to rows of 256 features."""
        features = self.encoder(x)
        return features.view(-1, 16 * 4 * 4)

In [15]:
class Classifier(nn.Module):
    """Two-layer MLP head: 256 -> 120 -> ``N`` with a ReLU in between."""

    def __init__(self, N=10):
        super().__init__()
        hidden = nn.Linear(16 * 4 * 4, 120)
        output = nn.Linear(120, N)
        self.classifier = nn.Sequential(hidden, nn.ReLU(), output)

        # Re-initialise the linear weights (biases keep their default init).
        for layer in (hidden, output):
            nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")

    def forward(self, x):
        """Return the ``N`` output scores for each row of ``x``."""
        return self.classifier(x)

In [16]:
def get_models(predict_method):
    """Create the encoder and a classifier head sized for ``predict_method``.

    ``"inverse_augment"`` and ``"combination"`` get a 10-way head;
    ``"add_hs"`` and ``"add_logits"`` get a 19-way head.

    Raises:
        ValueError: if ``predict_method`` is none of the four names above.
    """
    encoder = MNIST_Encoder()

    if predict_method in ("inverse_augment", "combination"):
        num_outputs = 10
    elif predict_method in ('add_hs', 'add_logits'):
        num_outputs = 19
    else:
        raise ValueError (f"Method only have 3 options : ['add_hs', 'add_logits', 'combination', 'inverse_augment'], but {predict_method} is given")

    return encoder, Classifier(N=num_outputs)

## Train loop

In [17]:
def _sum_distribution(prob1, prob2):
    """Distribution of the sum of two independent categorical variables.

    ``prob1`` and ``prob2`` are ``(batch, n)`` probabilities. The result is
    ``(batch, 2n - 1)``; entry ``k`` is the sum of ``prob1[:, i] * prob2[:, j]``
    over all ``i + j == k`` (the anti-diagonals of the outer product).
    """
    batch, num_classes = prob1.shape
    joint = prob1.unsqueeze(-1) * prob2.unsqueeze(1)          # (batch, n, n)
    values = torch.arange(num_classes, device=joint.device)
    sum_index = (values.unsqueeze(1) + values.unsqueeze(0)).reshape(-1)
    sum_probs = joint.new_zeros(batch, 2 * num_classes - 1)
    return sum_probs.index_add(1, sum_index, joint.reshape(batch, -1))


def calculate_loss_acc(h1, h2, label, classifier, loss_function, label_dic, predict_method: str = "add_hs", test=False):
    """Compute ``(loss, accuracy)`` for one batch under ``predict_method``.

    Args:
        h1, h2: encoder outputs for the two inputs of each sample. The test
            loop passes the same encoding for both.
        label: target tensor for the batch.
        classifier: module mapping an encoding to class scores.
        loss_function: callable ``(scores, target) -> loss``; may return a
            scalar or per-sample losses.
        label_dic: DataFrame mapping each sum to its candidate pairs; only
            used by ``"inverse_augment"`` in training mode.
        predict_method: ``"add_hs"``, ``"add_logits"``, ``"combination"`` or
            ``"inverse_augment"``.
        test: ``True`` when called from the test loop.

    Returns:
        ``(loss, acc)`` where ``loss`` is a scalar tensor.

    Raises:
        ValueError: for an unknown ``predict_method``.
    """
    if predict_method == "combination":
        if test:
            # Test mode scores a single encoding directly against `label`.
            logit = classifier(h1)
            loss = loss_function(logit, label).mean()
            acc = accuracy(logit, label)
            return loss, acc

        prob1 = nn.Softmax(dim=-1)(classifier(h1))
        prob2 = nn.Softmax(dim=-1)(classifier(h2))

        batch_preds = _sum_distribution(prob1, prob2)

        # Fix: `batch_preds` are probabilities, while the loss functions in
        # this file apply (log-)softmax to their input. Feeding log-probabilities
        # makes that softmax a no-op (they already sum to one), so the loss is
        # the exact negative log-likelihood of the true sum.
        log_preds = torch.log(batch_preds.clamp_min(1e-12))
        loss = loss_function(log_preds, label).mean()

        acc = accuracy(batch_preds, label)

        return loss, acc

    elif predict_method == "add_hs":
        if test:
            # Both encodings are the same input, so the target sum is 2 * label.
            label = label * 2

        hidden = (h1 + h2) / 2
        logit = classifier(hidden)

        # .mean() is a no-op for an already reduced loss and guarantees a scalar otherwise.
        loss = loss_function(logit, label).mean()
        acc = accuracy(logit, label)

        return loss, acc

    elif predict_method == "add_logits":
        if test:
            label = label * 2

        logit = (classifier(h1) + classifier(h2)) / 2

        loss = loss_function(logit, label).mean()
        acc = accuracy(logit, label)

        return loss, acc

    elif predict_method == "inverse_augment":
        if test:
            logit = classifier(h1)
            loss = loss_function(logit, label).mean()
            acc = accuracy(logit, label)
            return loss, acc

        assert label_dic is not None

        logit1 = classifier(h1)
        logit2 = classifier(h2)

        loss, acc = 0, 0

        # Select one column of candidate pairs per sample; plain ints are used
        # as column keys instead of indexing the DataFrame with a tensor.
        sums = label.tolist() if torch.is_tensor(label) else list(label)

        for lo1, lo2, (_, inverse_augments) in zip(logit1, logit2, label_dic[sums].items()):
            first_digits, second_digits = zip(*inverse_augments.dropna().values)
            label1 = torch.tensor(first_digits, device=lo1.device)
            label2 = torch.tensor(second_digits, device=lo2.device)

            # Score the same logits against every candidate pair.
            lo1 = lo1.expand(label1.shape[0], -1)
            lo2 = lo2.expand(label2.shape[0], -1)

            loss += (loss_function(lo1, label1) + loss_function(lo2, label2)).mean()
            acc += (accuracy(lo1, label1) + accuracy(lo2, label2)).mean()

        loss /= len(label)
        acc /= len(label)

        return loss, acc

    else:
        raise ValueError (f"Method only have 4 options : ['add_hs', 'add_logits', 'combination', 'inverse_augment'], but {predict_method} is given")

In [18]:
def train_epoch(epoch: int, 
                encoder: nn.Module,
                classifier:nn.Module,
                dataloader:torch.utils.data.DataLoader, 
                loss_function: nn.Module, 
                optimizer: torch.optim.Optimizer, 
                lr_scheduler,
                predict_method: str,
                label_dic = None,
               ):
    """Run one optimisation pass over ``dataloader``.

    Every batch performs a forward pass, a backward pass, an optimizer step
    and a scheduler step, and logs ``Train/step``, ``Train/Accuracy`` and
    ``Train/Loss`` to wandb.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the number of batches.
    """
    wandb.define_metric("Train/step")
    wandb.define_metric("Train/*", step_metric="Train/step")

    num_batches = len(dataloader)
    loss_sum = 0.0
    acc_sum = 0.0

    progress = tqdm(enumerate(dataloader), desc=f"Training Epoch {epoch}", total=num_batches)
    with progress as train_bar:
        for step, (num1, num2, label) in train_bar:

            encoder.train()
            classifier.train()
            optimizer.zero_grad()

            # The same encoder embeds both inputs of the pair.
            enc1 = encoder(num1)
            enc2 = encoder(num2)

            batch_loss, batch_acc = calculate_loss_acc(enc1, enc2, label,
                                                       classifier,
                                                       loss_function,
                                                       label_dic,
                                                       predict_method)

            batch_loss.backward()
            optimizer.step()
            lr_scheduler.step()

            loss_sum += batch_loss.item()
            acc_sum += batch_acc.item()

            train_bar.set_description(f"Train Step {step} || Train ACC {batch_acc: .4f} | Train Loss {batch_loss.item(): .4f}")

            # Global step index continues across epochs.
            wandb.log({"Train/step": step + epoch*num_batches,
                       "Train/Accuracy": batch_acc,
                       "Train/Loss": batch_loss})

    return loss_sum/num_batches, acc_sum / num_batches

In [19]:
def valid_epoch(epoch: int, 
                encoder: nn.Module,
                classifier:nn.Module,
                dataloader:torch.utils.data.DataLoader, 
                loss_function: nn.Module, 
                predict_method: str,
                label_dic = None                
               ):
    """Evaluate the models on the validation loader without gradient tracking.

    Logs ``Valid/step``, ``Valid/Accuracy`` and ``Valid/Loss`` to wandb for
    every batch.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the number of batches.
    """
    wandb.define_metric("Valid/step")
    wandb.define_metric("Valid/*", step_metric="Valid/step")

    # Fix: switch to eval mode. The training loop leaves both modules in
    # train mode, in which the encoder's BatchNorm layers normalise with
    # batch statistics and update their running stats on validation data.
    encoder.eval()
    classifier.eval()

    total_loss, total_acc = 0.0, 0.0

    with torch.no_grad():
        with tqdm(enumerate(dataloader), desc=f"Val Epoch {epoch}", total=len(dataloader)) as val_bar:
            for vli, batch  in val_bar:

                num1, num2, label = batch

                h1 = encoder(num1)
                h2 = encoder(num2)

                vl_loss, vl_acc = calculate_loss_acc(h1, h2, label,
                                                    classifier, 
                                                    loss_function, 
                                                    label_dic, 
                                                    predict_method)

                # Convert once to Python floats for accumulation and logging.
                loss_value = vl_loss.item()
                acc_value = vl_acc.item()

                total_loss += loss_value
                total_acc += acc_value

                val_bar.set_description(f"Val Step {vli} || Val ACC {acc_value: .4f} | Val Loss {loss_value: .4f}")

                log_dict = {"Valid/step": vli + epoch*len(dataloader),
                            "Valid/Accuracy": acc_value,
                            "Valid/Loss": loss_value}

                wandb.log(log_dict)

    return total_loss/len(dataloader), total_acc / len(dataloader)


In [20]:
def test(encoder: nn.Module,
        classifier:nn.Module,
        dataloader:torch.utils.data.DataLoader, 
        loss_function: nn.Module,
        predict_method: str,
        label_dic = None                
    ):
    """Evaluate the models on the test loader without gradient tracking.

    Only the first input of every batch is used; its encoding is passed as
    both ``h1`` and ``h2`` with ``test=True``. Logs ``Test/step``,
    ``Test/Accuracy`` and ``Test/Loss`` to wandb for every batch.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the number of batches.
    """
    wandb.define_metric("Test/step")
    wandb.define_metric("Test/*", step_metric="Test/step")

    # Fix: switch to eval mode so BatchNorm uses its running statistics and
    # is not updated by test batches.
    encoder.eval()
    classifier.eval()

    total_loss, total_acc = 0.0, 0.0

    with torch.no_grad():
        with tqdm(enumerate(dataloader), desc="Test", total=len(dataloader)) as test_bar:
            for tti, batch  in test_bar:

                num, _, label = batch

                hs = encoder(num)
                
                tt_loss, tt_acc = calculate_loss_acc(hs, hs, label,
                                                    classifier, 
                                                    loss_function, 
                                                    label_dic, 
                                                    predict_method,
                                                    True)

                # Convert once to Python floats for accumulation and logging.
                loss_value = tt_loss.item()
                acc_value = tt_acc.item()

                total_loss += loss_value
                total_acc += acc_value
                
                test_bar.set_description(f"Test Step {tti} || Test ACC {acc_value: .4f} | Test Loss {loss_value: .4f}")
                
                log_dict = {"Test/step": tti,
                            "Test/Accuracy": acc_value,
                            "Test/Loss": loss_value}
                
                wandb.log(log_dict)

    return total_loss/len(dataloader), total_acc / len(dataloader) 

In [34]:
def main(num_epoch, batch_size, predict_method, lr, weight_decay):
    """Train, validate and test the encoder / classifier pair.

    Builds loaders, loss and models, trains for ``num_epoch`` epochs with AdamW
    and a cosine schedule (10% warm-up), evaluates on the test loader, logs to
    wandb and saves both state dicts under ``./result/<wandb.config.id>/``.

    Requires an active wandb run whose config contains ``id``.
    """
    import os

    train_loader, valid_loader, test_loader = get_loader(batch_size)

    loss_function, label_dic = get_loss_function(predict_method)

    encoder, classifier = get_models(predict_method)

    wandb.watch((encoder, classifier))

    # Fix: the original dict literal repeated the "params" key, so only the
    # classifier parameters reached the optimizer and the encoder was never
    # trained. Use one parameter group per module.
    optimizer = torch.optim.AdamW(params=[{"params": encoder.parameters()},
                                          {"params": classifier.parameters()}],
                                  lr=lr,
                                  weight_decay=weight_decay
                                  )

    total_steps = len(train_loader) * num_epoch
    lr_scheduler = get_scheduler("cosine", optimizer=optimizer,
                                 num_warmup_steps=int(total_steps * 0.1),
                                 num_training_steps=total_steps
                                 )

    with tqdm(range(num_epoch), desc="Total Epoch", total=num_epoch) as total_bar:

        for epoch in total_bar:

            train_loss, train_acc = train_epoch(epoch,
                                                encoder,
                                                classifier,
                                                train_loader,
                                                loss_function,
                                                optimizer,
                                                lr_scheduler,
                                                predict_method,
                                                label_dic)

            valid_loss, valid_acc = valid_epoch(epoch,
                                                encoder,
                                                classifier,
                                                valid_loader,
                                                loss_function,
                                                predict_method,
                                                label_dic)

            total_bar.set_description(f"Epoch {epoch} |||| Train ACC {train_acc:.4f} \
                                        Train Epoch Loss {train_loss:.4f} || \
                                        Valid Epoch ACC {valid_acc:.4f} \
                                        Valid Epoch Loss {valid_loss:.4f}")

            wandb.log({"Epoch/Epoch": epoch,
                       "Total_ACC/Train Epoch ACC": train_acc,
                       "Total_Loss/Train Epoch Loss": train_loss,
                       "Total_ACC/Valid Epoch ACC ": valid_acc,
                       "Total_Loss/Valid Epoch Loss": valid_loss,
                        })

        test_loss, test_acc = test(encoder,
                                    classifier,
                                    test_loader,
                                    loss_function,
                                    predict_method,
                                    label_dic)

        wandb.log({"Total_ACC/Test Accuracy": test_acc,
                    "Total_Loss/Test Loss": test_loss})

    # Fix: torch.save was given the run directory itself (and the same path
    # for both modules). Write each state dict to its own file instead.
    result_dir = f'./result/{wandb.config.id}'
    os.makedirs(result_dir, exist_ok=True)
    torch.save(encoder.state_dict(), f'{result_dir}/encoder.pt')
    torch.save(classifier.state_dict(), f'{result_dir}/classifier.pt')
    

In [35]:
# Default hyper-parameters for an interactive run.
num_epoch = 2
batch_size = 64
# Available methods, in order: 'add_hs', 'add_logits', 'combination', 'inverse_augment'
predict_method = ('add_hs', 'add_logits', 'combination', 'inverse_augment')[2]
lr = 1e-4
weight_decay = 1e-5

In [37]:
if __name__ == "__main__":
    import json
    import os
    from random import random

    # Hyper-parameters come from the JSON config file.
    with open('./mnist/config.json') as f:
        config = json.load(f)

    # Required keys are indexed directly so a missing entry fails here with a
    # KeyError instead of propagating None into the training code.
    num_epoch = config['num_epoch']
    batch_size = config['batch_size']
    predict_method = config['predict_method']    # ['add_hs' or 'add_logits' or 'combination', 'inverse_augment']
    lr = config['lr']
    weight_decay = config['weight_decay']

    # Run identifier: method name plus a random suffix. Named `run_id` so the
    # builtin `id` is not shadowed for the rest of the session.
    run_id = predict_method + str(random())
    config['id'] = run_id
    print(config)

    wandb.init(project="MNIST addition", config=config, id=run_id)

    # exist_ok avoids a failure if the directory is already present.
    os.makedirs(f"./result/{run_id}/", exist_ok=True)

    # Keep a copy of the exact configuration next to the results.
    with open(f'./result/{run_id}/config.json', 'w') as f:
        json.dump(config, f)

    main(num_epoch=num_epoch, batch_size=batch_size, predict_method=predict_method, lr=lr, weight_decay=weight_decay)

{'num_epoch': 100, 'batch_size': 256, 'predict_method': 'combination', 'lr': 0.0005, 'weight_decay': 1e-05, 'id': 'combination0.19143462883965856'}


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

Total Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

Training Epoch 0:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 0:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 1:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 1:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 2:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 2:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 3:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 3:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 4:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 4:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 5:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 5:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 6:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 6:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 7:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 7:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 8:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 8:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 9:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 9:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 10:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 10:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 11:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 11:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 12:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 12:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 13:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 13:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 14:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 14:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 15:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 15:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 16:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 16:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 17:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 17:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 18:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 18:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 19:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 19:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 20:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 20:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 21:   0%|          | 0/176 [00:00<?, ?it/s]

Val Epoch 21:   0%|          | 0/59 [00:00<?, ?it/s]

Training Epoch 22:   0%|          | 0/176 [00:00<?, ?it/s]

KeyboardInterrupt: 