In [None]:
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
torch.manual_seed(1234)

In [8]:
import matplotlib.pyplot as plt
import numpy as np


def imshow(img):
    """Show an image tensor with matplotlib.

    The input is rescaled with ``img / 2 + 0.5`` (the "unnormalize" step),
    converted to a NumPy array, and its leading axis is moved to the end
    before being handed to ``plt.imshow``.
    """
    unnormalized = img / 2 + 0.5
    as_array = unnormalized.numpy()
    # Axis order (1, 2, 0): the first axis becomes the last one.
    channels_last = np.transpose(as_array, (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

In [14]:
# `pickle` is needed by this cell; import it here so the cell does not rely on
# a different cell having been executed first.
import pickle

# Load the torch-serialized copies of the paired data.
dic = torch.load("./MNIST_pair/training_dict.pt")
# NOTE(review): the save cells visible in this notebook write
# "training_tuple.pt", not "training.pt" -- confirm this path is still valid.
tup = torch.load("./MNIST_pair/training.pt")

# Load the pickled copies from the current working directory.
# pickle.load can execute arbitrary code: only open files you produced yourself.
with open('training_dict.pkl', 'rb') as f:
    pdic = pickle.load(f)
with open('training.pkl', 'rb') as f:
    ptup = pickle.load(f)

In [None]:
# torch.save takes the object to serialize first and the destination second;
# with only the path given the call raises a TypeError.
# NOTE(review): `tup` is assumed to be the object meant for this file -- confirm.
torch.save(tup, './MNIST_pair/training_tuple.pt')

In [10]:
import pickle

In [36]:
# Write `tup` to a pickle file using the highest protocol this Python supports.
with open('training_tuple.pkl', mode='wb') as f:
    pickle.dump(tup, f, pickle.HIGHEST_PROTOCOL)

In [34]:
# Check that the pickled dictionary matches the torch-saved dictionary.
# Keys are compared explicitly: zipping the two .items() views would silently
# ignore differing keys and stop at the shorter dictionary.
assert pdic.keys() == dic.keys(), "pickled and torch-saved dicts have different keys"
for key, expected in dic.items():
    actual = pdic[key]
    if torch.is_tensor(expected):
        # torch.equal requires identical size and identical elements.
        assert torch.equal(actual, expected), key
    else:
        # Non-tensor entries (e.g. a list of strings) have no .numel();
        # compare them with plain equality instead.
        assert actual == expected, key

In [9]:
# Display the shape of the 'num_pair' entry of the loaded dictionary.
dic['num_pair'].shape

torch.Size([60000, 2, 28, 28])

In [9]:
# Preprocessing pipeline: ToTensor followed by Normalize((0.5,), (0.5,)).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Train and test splits, stored under ./MNIST and downloaded when requested.
mnist_train_data = torchvision.datasets.MNIST(
    root='./MNIST', train=True, download=True, transform=transform
)
mnist_test_data = torchvision.datasets.MNIST(
    root='./MNIST', train=False, download=True, transform=transform
)

In [10]:
import warnings
from PIL import Image
import os
import os.path
import numpy as np
import torch
import codecs
import string
import gzip
import lzma
from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union

def get_int(b: bytes) -> int:
    """Interpret ``b`` as an unsigned integer by parsing its hexadecimal text.

    The first byte is the most significant one.
    """
    hex_digits = b.hex()
    return int(hex_digits, base=16)


def open_maybe_compressed_file(path: Union[str, IO]) -> Union[IO, gzip.GzipFile]:
    """Return a file object that possibly decompresses 'path' on the fly.

    Decompression occurs when argument `path` is a string and ends with
    '.gz' or '.xz'. Any other string is opened as a plain binary file.

    Args:
        path: a filename, or an object that is not a string (returned as-is).

    Returns:
        A binary file object opened for reading, or `path` itself when it is
        not a string.
    """
    # Use the builtin `str` rather than `torch._six.string_classes`:
    # `torch._six` is a private module that no longer exists in current
    # PyTorch releases, so referencing it raises AttributeError.
    if not isinstance(path, str):
        return path
    if path.endswith('.gz'):
        return gzip.open(path, 'rb')
    if path.endswith('.xz'):
        return lzma.open(path, 'rb')
    return open(path, 'rb')


# Lookup table keyed by the type code taken from the file header
# (`magic // 256` in read_sn3_pascalvincent_tensor). Each value is a triple:
#   [0] a torch dtype (not read by the code visible in this notebook),
#   [1] the NumPy dtype handed to np.frombuffer to decode the raw bytes
#       ('>' marks big-endian for the multi-byte types),
#   [2] the NumPy dtype the decoded array is converted to with .astype().
# Note that there is no entry for code 10.
SN3_PASCALVINCENT_TYPEMAP = {
    8: (torch.uint8, np.uint8, np.uint8),
    9: (torch.int8, np.int8, np.int8),
    11: (torch.int16, np.dtype('>i2'), 'i2'),
    12: (torch.int32, np.dtype('>i4'), 'i4'),
    13: (torch.float32, np.dtype('>f4'), 'f4'),
    14: (torch.float64, np.dtype('>f8'), 'f8')
}


def read_sn3_pascalvincent_tensor(path: Union[str, IO], strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').

    Args:
        path: a filename, compressed filename, or file object.
        strict: when True, assert that the number of decoded elements equals
            the product of the dimensions declared in the header.

    Returns:
        A tensor holding the decoded data, viewed with the header dimensions.

    Raises:
        AssertionError: if the header declares an unsupported number of
            dimensions or type code, or (with ``strict``) the element count
            does not match the header.
    """
    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()
    # parse: the low byte of the magic number is the number of dimensions,
    # the remaining high part is the element type code.
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    # Code 10 is inside the 8..14 range but has no table entry; reject it with
    # the same AssertionError as other bad headers instead of a bare KeyError.
    assert ty in SN3_PASCALVINCENT_TYPEMAP
    m = SN3_PASCALVINCENT_TYPEMAP[ty]
    # One 4-byte size field per dimension follows the magic number.
    s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    # np.frombuffer over an immutable bytes object yields a read-only array.
    # Let astype() copy (its default) so the tensor is backed by writable
    # memory instead of wrapping the read-only buffer.
    return torch.from_numpy(parsed.astype(m[2])).view(*s)


def read_label_file(path: str) -> torch.Tensor:
    """Parse the file at ``path`` and return its contents converted with ``.long()``.

    The file is read non-strictly; the parsed tensor must be uint8 and
    one-dimensional, otherwise an AssertionError is raised.
    """
    with open(path, 'rb') as handle:
        labels = read_sn3_pascalvincent_tensor(handle, strict=False)
    assert labels.dtype == torch.uint8
    assert labels.dim() == 1
    return labels.long()


def read_image_file(path: str) -> torch.Tensor:
    """Parse the file at ``path`` and return the resulting uint8 tensor.

    The file is read non-strictly; the parsed tensor must be uint8 and
    three-dimensional, otherwise an AssertionError is raised.
    """
    with open(path, 'rb') as handle:
        images = read_sn3_pascalvincent_tensor(handle, strict=False)
    assert images.dtype == torch.uint8
    assert images.dim() == 3
    return images

In [11]:
 # (images, labels) parsed from the raw training files.
 training_set = tuple(
     reader(os.path.join("./MNIST/MNIST/raw/", filename))
     for reader, filename in (
         (read_image_file, 'train-images-idx3-ubyte'),
         (read_label_file, 'train-labels-idx1-ubyte'),
     )
 )

In [13]:
# First element of training_set holds the images, second the labels.
imgs, targets = training_set[0], training_set[1]
# One independent empty list per field filled in by the generation loop.
data = {
    field: []
    for field in ('num_tuple_pair', 'num_concat_pair', 'summation', 'summation_concat', 'formula')
}
zero_idx = []
# Counter decremented by the generation loop; it stops when this reaches 0.
num_data = 60000

In [14]:
# Draw random pairs of samples until `num_data` pairs have been recorded.
while num_data > 0:
    # Two independent draws; `i` is sampled before `j`.
    i = torch.randint(len(targets), (1,))
    j = torch.randint(len(targets), (1,))
    tar1, tar2 = targets[i], targets[j]
    img1, img2 = imgs[i], imgs[j]
    img_pair = torch.cat((img1, img2))

    data['num_tuple_pair'].append((img1, img2))
    data['num_concat_pair'].append(img_pair)
    data['formula'].append("{}+{}".format(tar1.item(), tar2.item()))
    data['summation'].append(tar1.item() + tar2.item())
    # Same sum as just recorded, wrapped in a tensor.
    data['summation_concat'].append(torch.tensor(data['summation'][-1]))
    num_data -= 1

In [19]:
# Replace the two per-sample collections with single stacked tensors.
data.update(
    num_concat_pair=torch.stack(tuple(data['num_concat_pair'])),
    summation_concat=torch.stack(tuple(data['summation_concat'])),
)

In [31]:
# Subset of `data` kept for saving, under shorter key names.
concat_data = {
    'num_pair': data['num_concat_pair'],
    'summation': data['summation_concat'],
    'formula': data['formula'],
}

In [34]:
# Pair the 'num_pair' and 'summation' entries as a 2-tuple.
tup = concat_data['num_pair'], concat_data['summation']

In [None]:
# Create the output directory; exist_ok=True keeps the cell re-runnable
# (os.mkdir raises FileExistsError when the directory is already there).
os.makedirs('./MNIST_pair', exist_ok=True)

In [39]:
# Serialize `tup` with torch.save into the MNIST_pair directory.
with open("./MNIST_pair/training_tuple.pt", mode='wb') as f:
    torch.save(tup, f)

In [19]:
# Serialize the second element of `tup` with torch.save.
# NOTE(review): the file is called "formula.pt", but in the visible cell that
# builds `tup`, tup[1] comes from concat_data['summation']; the formula strings
# are stored in concat_data['formula'] -- confirm which one was meant here.
with open("./MNIST_pair/formula.pt", 'wb') as f:
        torch.save(tup[1], f)

In [48]:
# Serialize the whole `concat_data` dictionary with torch.save.
with open("./MNIST_pair/training_dict.pt", mode='wb') as f:
    torch.save(concat_data, f)

In [9]:
# Pair-generation loop that only uses samples whose label is not 0.
#
# Zero-labelled samples are filtered out once, up front. Removing them one at
# a time while sampling is error-prone: after deleting index i, a previously
# drawn index j may point at a different sample (or at the same one when
# i == j), so a non-zero sample can be dropped by mistake. It also rebuilds
# the full image tensor on every removal.
#
# NOTE(review): the loop runs only while num_data > 0 and appends to the lists
# in `data`; re-run the cell that initialises `num_data` and `data` first if
# they have already been consumed.
if num_data > 0:
    nonzero_mask = targets != 0
    imgs, targets = imgs[nonzero_mask], targets[nonzero_mask]

while num_data > 0:
    # Indices stay 1-element tensors so every selected image keeps a leading
    # axis of size 1; torch.cat then joins the two images along that axis.
    i, j = torch.randint(len(targets), (1,)), torch.randint(len(targets), (1,))
    tar1, tar2 = targets[i], targets[j]
    img1, img2 = imgs[i], imgs[j]

    data['num_tuple_pair'].append((img1, img2))
    img_pair = torch.cat((img1, img2))
    data['num_concat_pair'].append(img_pair)
    data['formula'].append(f"{tar1.item()}+{tar2.item()}")
    data['summation'].append(tar1.item() + tar2.item())
    data['summation_concat'].append(torch.tensor(tar1.item() + tar2.item()))
    num_data -= 1
        