In [0]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

In [0]:
# NOTE(review): replay_ratio is not read by any visible cell yet; presumably the
# share of generated (replayed) samples once replay is implemented — confirm.
replay_ratio = 0.1
batch_size = 64
num_task = 10
mode_DL = True
mode_IL = not mode_DL  # always the complement of mode_DL

# Names used by the model / training cells below that were never defined
# anywhere, so a fresh "Restart & Run All" failed with NameError.
num_noise = 100                            # GAN latent (noise) dimension; 100 as in DCGAN
cuda_available = torch.cuda.is_available()
criterion = torch.nn.BCELoss()             # the Discriminator ends in Sigmoid -> probabilities

# Seed everything stochastic (task permutations, weight init, GAN noise) so
# results are reproducible across kernel restarts.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

### Define the Permuted-MNIST (P-MNIST) datasets and data loaders

In [0]:
class PermutedMNISTDataLoader(torchvision.datasets.MNIST):
    """One Permuted-MNIST task: MNIST with a fixed pixel permutation.

    Despite the name this is a ``Dataset`` (a ``torchvision.datasets.MNIST``
    subclass); ``permute_mnist`` wraps it in ``torch.utils.data.DataLoader``.

    Every image is flattened to a 28*28 = 784 float vector, scaled from
    [0, 255] to [0, 1], and reordered with the same permutation, so all samples
    of one instance share that permutation.

    Args:
        source: root directory handed to ``MNIST``; data is downloaded there
            if missing.
        train: use the training split if True, otherwise the test split.
        shuffle_seed: integer index array of length 784 (e.g. a shuffled
            ``np.arange(28*28)``) giving the new pixel order. ``None`` keeps
            the original pixel order.
    """
    def __init__(self, source='data/mnist_data', train = True, shuffle_seed = None):
        super(PermutedMNISTDataLoader, self).__init__(source, train=train, download=True)

        self.train = train

        # Vectorised over the whole split instead of a per-image Python loop.
        # `self.data` holds whichever split `train` selected; the previously used
        # `train_data` / `test_data` aliases are deprecated in torchvision.
        flat = self.data.to(torch.float32).reshape(self.data.shape[0], -1) / 255.0
        if shuffle_seed is not None:
            # Guarded because indexing with None would insert a new axis
            # (giving each sample shape (1, 784)) rather than skip the permutation.
            flat = flat[:, torch.as_tensor(shuffle_seed, dtype=torch.long)]

        # Keep the split-specific attribute names used before.
        if self.train:
            self.permuted_train_data = flat
        else:
            self.permuted_test_data = flat
        self.num_data = flat.shape[0]


    def __getitem__(self, index):
        """Return ``(pixels, label)``: the permuted 784-float vector and its label tensor."""
        data = self.permuted_train_data if self.train else self.permuted_test_data
        return data[index], self.targets[index]


    def getNumData(self):
        """Number of samples in the selected split."""
        return self.num_data

In [0]:
def permute_mnist():
    """Build one train/test DataLoader pair per Permuted-MNIST task.

    For each of the notebook-level ``num_task`` tasks, a fresh random
    permutation of the 784 pixel positions is drawn from ``np.random`` and
    shared by that task's train and test sets.

    Returns:
        (TrainLoaderList, TestLoaderList, train_per_task, test_per_task):
        two lists of ``num_task`` DataLoaders (batch size ``batch_size``),
        plus the average number of train / test samples per task.
    """
    TrainLoaderList = []
    TestLoaderList = []

    train_data_num = 0
    test_data_num = 0

    for _ in range(num_task):
        shuffle_seed = np.random.permutation(28 * 28)

        train_PMNIST_DataLoader = PermutedMNISTDataLoader(train=True, shuffle_seed=shuffle_seed)
        test_PMNIST_DataLoader = PermutedMNISTDataLoader(train=False, shuffle_seed=shuffle_seed)

        train_data_num += train_PMNIST_DataLoader.getNumData()
        test_data_num += test_PMNIST_DataLoader.getNumData()

        TrainLoaderList.append(torch.utils.data.DataLoader(
                                train_PMNIST_DataLoader,
                                batch_size=batch_size)
                            )
        TestLoaderList.append(torch.utils.data.DataLoader(
                                test_PMNIST_DataLoader,
                                batch_size=batch_size)
                            )

    # Return the lists built above (the old `train_loader` / `test_loader`
    # names were never defined and raised NameError).
    return TrainLoaderList, TestLoaderList, train_data_num // num_task, test_data_num // num_task

TrainLoaderList, TestLoaderList, train_data_num, test_data_num = permute_mnist()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting data/mnist_data/PermutedMNISTDataLoader/raw/train-images-idx3-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting data/mnist_data/PermutedMNISTDataLoader/raw/train-labels-idx1-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting data/mnist_data/PermutedMNISTDataLoader/raw/t10k-images-idx3-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting data/mnist_data/PermutedMNISTDataLoader/raw/t10k-labels-idx1-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw
Processing...
Done!




### Generative Adversarial Network (GAN) and Solver Modules

In [0]:
class Generator(torch.nn.Module):
    """
    Generator Class for GAN.

    DCGAN-style stack of transposed convolutions mapping a noise vector to a
    single-channel 28x28 image:
    1x1 -> 7x7 (kernel 7) -> 14x14 -> 28x28 (kernel 4, stride 2, padding 1),
    with BatchNorm + ReLU between layers and Tanh on the output.

    Args:
        noise_dim: length of the input noise vector. Defaults to the
            notebook-level ``num_noise`` so existing ``Generator()`` calls
            behave as before.

    NOTE(review): Tanh spans [-1, 1], while PermutedMNISTDataLoader scales real
    pixels to [0, 1]; consider Sigmoid or rescaling the data to [-1, 1].
    """
    def __init__(self, noise_dim=None):
        super(Generator, self).__init__()
        self.noise_dim = num_noise if noise_dim is None else noise_dim
        conv2d_1 = torch.nn.ConvTranspose2d(in_channels=self.noise_dim,
                                   out_channels=28*8,
                                   kernel_size=7,
                                   stride=1,
                                   padding=0,
                                   bias=False)
        conv2d_2 = torch.nn.ConvTranspose2d(in_channels=28*8,
                                   out_channels=28*4,
                                   kernel_size=4,
                                   stride=2,
                                   padding=1,
                                   bias=False)
        conv2d_3 = torch.nn.ConvTranspose2d(in_channels=28*4,
                                   out_channels=1,
                                   kernel_size=4,
                                   stride=2,
                                   padding=1,
                                   bias=False)

        self.network = torch.nn.Sequential(
            conv2d_1,
            torch.nn.BatchNorm2d(num_features = 28*8),
            torch.nn.ReLU(inplace=True),
            conv2d_2,
            torch.nn.BatchNorm2d(num_features = 28*4),
            torch.nn.ReLU(inplace=True),
            conv2d_3,
            torch.nn.Tanh()
        )

        # Query torch directly: `cuda_available` was never defined in the notebook.
        if torch.cuda.is_available():
            self.network = self.network.cuda()

    def forward(self, x):
        """Map noise of shape (N, noise_dim) (or anything viewable as such) to (N, 1, 28, 28)."""
        return self.network(x.view(-1, self.noise_dim, 1, 1))

In [0]:
class Discriminator(torch.nn.Module):
    """
    Discriminator Class for GAN.

    Convolutional classifier mapping a (N, 1, 28, 28) image batch to an (N, 1)
    probability that each image is real:
    28x28 -> 14x14 -> 7x7 (kernel 4, stride 2, padding 1) -> 1x1 (kernel 7),
    with BatchNorm + LeakyReLU (default negative slope) between layers and
    Sigmoid on the output.
    """
    def __init__(self):
        super(Discriminator, self).__init__()
        conv2d_1 = torch.nn.Conv2d(in_channels=1,
                                   out_channels=28*4,
                                   kernel_size=4,
                                   stride=2,
                                   padding=1,
                                   bias=False)
        conv2d_2 = torch.nn.Conv2d(in_channels=28*4,
                                   out_channels=28*8,
                                   kernel_size=4,
                                   stride=2,
                                   padding=1,
                                   bias=False)
        conv2d_3 = torch.nn.Conv2d(in_channels=28*8,
                                   out_channels=1,
                                   kernel_size=7,
                                   stride=1,
                                   padding=0,
                                   bias=False)

        self.network = torch.nn.Sequential(
            conv2d_1,
            torch.nn.BatchNorm2d(num_features=28*4),
            torch.nn.LeakyReLU(inplace=True),
            conv2d_2,
            torch.nn.BatchNorm2d(num_features=28*8),
            torch.nn.LeakyReLU(inplace=True),
            conv2d_3,
            torch.nn.Sigmoid()
        )

        # Query torch directly: `cuda_available` was never defined in the notebook.
        if torch.cuda.is_available():
            self.network = self.network.cuda()

    def forward(self, x):
        """Return real-probabilities of shape (N, 1) for an (N, 1, 28, 28) batch."""
        return self.network(x).view(-1, 1)

In [0]:
class Solver(torch.nn.Module):
    """
    Solver Class for Deep Generative Replay.

    A 784 -> 100 -> 100 -> (T_n * 10) MLP with ReLU activations that returns raw
    logits (no softmax), i.e. 10 outputs per task.

    Args:
        T_n: number of tasks to provide outputs for; must be >= 1.

    Raises:
        ValueError: if ``T_n < 1`` (the output layer would have no units).
    """
    def __init__(self, T_n):
        super(Solver, self).__init__()
        if T_n < 1:
            raise ValueError("T_n must be >= 1, got {}".format(T_n))
        fc1 = torch.nn.Linear(28*28, 100)
        fc2 = torch.nn.Linear(100, 100)
        fc3 = torch.nn.Linear(100, T_n * 10)
        self.network = torch.nn.Sequential(
            fc1,
            torch.nn.ReLU(),
            fc2,
            torch.nn.ReLU(),
            fc3
        )

        # Query torch directly: `cuda_available` was never defined in the notebook.
        if torch.cuda.is_available():
            self.network = self.network.cuda()

    def forward(self, x):
        """Return logits of shape (N, T_n * 10) for flattened inputs of shape (N, 784)."""
        return self.network(x)

In [0]:
def sample_noise(batch_size, N_noise):
    """
    Returns a (batch_size, N_noise) float32 tensor of i.i.d. standard-normal
    noise for the GAN generator. It is on the GPU when CUDA is available and
    on the CPU otherwise.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Sample directly on the target device instead of creating on CPU and copying.
    return torch.randn(batch_size, N_noise, device=device)

### Learning Scheme

In [0]:
def learn(num_task, ):
    """Train one GAN (and, eventually, one solver) per P-MNIST task.

    Work in progress. For each of the first ``num_task`` loaders in the
    notebook-level ``TrainLoaderList``, this builds a fresh Generator,
    Discriminator and Solver and runs one pass of GAN training over that
    task's data. Generative replay (via ``pre_generator`` / ``pre_solver``)
    and solver training are not implemented yet.

    Args:
        num_task: number of tasks to learn; only the first ``num_task``
            loaders of ``TrainLoaderList`` are used.
    """
    use_cuda = torch.cuda.is_available()
    # The Discriminator ends in Sigmoid, so binary cross-entropy on probabilities.
    criterion = torch.nn.BCELoss()

    gen = [Generator() for _ in range(num_task)]
    # enumerate() supplies the task index; iterating the list directly would try
    # to unpack each DataLoader into (i, trainloader).
    for i, trainloader in enumerate(TrainLoaderList[:num_task]):
        # Needed for training current generator & solver (replay, not yet used)
        if i > 0:
            pre_generator = gen[i-1]
            pre_solver = solver

        generator = gen[i]
        discriminator = Discriminator()
        # Solver(T_n) has T_n * 10 outputs; task i is the (i + 1)-th task, and
        # Solver(0) would have no outputs at all.
        # NOTE(review): if mode_DL means all tasks share the same 10 labels,
        # Solver(1) may be intended here — confirm.
        solver = Solver(i + 1)

        gen_optim = torch.optim.Adam(generator.parameters(), lr=0.0002)
        disc_optim = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
        solver_optim = torch.optim.Adam(solver.parameters(), lr=0.0001)

        # train GAN
        for image, label in trainloader:
            if use_cuda:
                image = image.cuda()
                label = label.cuda()
            # The last batch can be smaller than batch_size (60000 % 64 != 0); size
            # the fake batch to match so both BCE terms see the same target shape.
            cur_batch = image.shape[0]

            ### Discriminator Training
            disc_optim.zero_grad()

            p_real = discriminator(image.view(cur_batch, -1, 28, 28))
            # detach(): the discriminator step must not backpropagate into the generator.
            fake = generator(sample_noise(cur_batch, num_noise)).detach()
            p_fake = discriminator(fake)

            # *_like already places targets on the same device as the predictions.
            ones = torch.ones_like(p_real)
            zeros = torch.zeros_like(p_fake)

            loss_d = criterion(p_real, ones) + criterion(p_fake, zeros)

            loss_d.backward()
            disc_optim.step()

            ### Generator Training
            gen_optim.zero_grad()
            p_fake = discriminator(generator(sample_noise(cur_batch, num_noise)))

            loss_g = criterion(p_fake, torch.ones_like(p_fake))
            loss_g.backward()

            gen_optim.step()

            # Train Solver
            # TODO: solver training (with replay from pre_generator / pre_solver) is not implemented yet.
