In [1]:
import subprocess
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import mnist
import torch.nn as nn
%matplotlib inline

def generate_sign(n = 2000, size = 28, length_min = 13, length_max = 20, width_min = 3, width_max = 5):
    """Draw ``n`` random "minus" / "plus" glyphs as flattened ``size`` x ``size`` images.

    For every image a sign (-1 or 1), a bar length, a bar width and a horizontal
    start column are sampled with ``np.random.choice``. A sign of -1 yields a
    single horizontal bar centred on the middle row; a sign of 1 adds a vertical
    bar of the same length and width through the middle of the horizontal one.
    All bounds are inclusive, so a bar of "length" L covers L + 1 pixels and a
    bar of "width" W covers W + 1 pixels.

    Lit pixels are then set to ``0.9 + N(0, 0.05**2)`` noise; the background stays 0.

    Parameters
    ----------
    n : number of images to generate.
    size : side length of the square image in pixels.
    length_min, length_max : inclusive range of the sampled bar length.
    width_min, width_max : inclusive range of the sampled bar width.

    Returns
    -------
    sign_img : float array of shape ``(n, size**2)``, one row-major image per row.
    sign : int array of shape ``(n,)`` with entries -1 (minus) or 1 (plus).
    """
    half = math.floor(size / 2)

    # The order of these draws is fixed so that seeded runs stay reproducible.
    sign = np.random.choice((-1, 1), size = n)
    length = np.random.choice(np.arange(length_min, length_max + 1), size = n)
    width = np.random.choice(np.arange(width_min, width_max + 1), size = n)
    start = np.random.choice(np.arange(4, size - length_max - 1), size = n) - half

    # Centred pixel coordinates of the flattened image (column index varies fastest).
    col = np.tile(np.arange(size), size) - half
    row = np.repeat(np.arange(size), size) - half

    def _bar(col_lo, col_hi, row_lo, row_hi):
        # Boolean mask of the axis-aligned rectangle, bounds inclusive.
        return (col >= col_lo) & (col <= col_hi) & (row >= row_lo) & (row <= row_hi)

    sign_img = np.zeros((n, size**2))
    for k, (sign_k, length_k, width_k, start_k) in enumerate(zip(sign, length, width, start)):
        row_lo = - math.floor(width_k / 2)
        glyph = _bar(start_k, start_k + length_k, row_lo, row_lo + width_k)
        if sign_k != -1:
            # plus sign: add the vertical stroke through the middle of the horizontal one
            mid_k = (length_k + 2 * start_k) / 2
            vert_lo = -math.floor(length_k / 2)
            glyph = glyph | _bar(math.floor(mid_k - width_k / 2),
                                 math.floor(mid_k + width_k / 2),
                                 vert_lo,
                                 vert_lo + length_k)
        sign_img[k, :] = glyph * 1

    # Dim the strokes to 0.9 and jitter them; background pixels are left at exactly 0.
    lit = sign_img > 0
    sign_img[lit] = 0.9 * sign_img[lit] + np.random.normal(scale = 0.05, size = np.sum(lit))
    return sign_img, sign


class Dataset(torch.utils.data.Dataset):
    """Pairs a predictor array/tensor with a label array/tensor along axis 0.

    Parameters
    ----------
    predictors : indexable with a ``.shape`` (e.g. ``torch.Tensor`` or
        ``np.ndarray``); sample ``i`` is ``predictors[i]``.
    labels : indexable with a ``.shape``; sample ``i`` is ``labels[i]``.
        Its leading dimension defines the dataset length.
    """

    def __init__(self, predictors, labels):
        self.labels = labels
        self.predictors = predictors

    def __len__(self):
        # The number of samples is taken from the labels' leading dimension.
        return self.labels.shape[0]

    def __getitem__(self, index):
        # Leading-axis indexing returns the same slice as ``[index, :]`` for
        # inputs with two or more dimensions, and additionally works for 1-D
        # label vectors, for which ``[index, :]`` raises an IndexError.
        X = self.predictors[index]
        y = self.labels[index]

        return X, y
    

import torch.nn as nn

class CNN(nn.Module):
    """Small two-stage convolutional binary classifier.

    Each stage is a 5x5 convolution (padding 2, so the spatial size is kept),
    a ReLU and a 2x2 max-pool that halves the spatial size. The flattened
    features go through a single linear unit followed by a sigmoid, so the
    network returns one probability-like value per sample.

    The linear layer expects ``32 * 7 * 7`` features, i.e. a single-channel
    28 x 28 input after two halvings.
    """

    def __init__(self):
        super().__init__()
        # stage 1: 1 -> 16 channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # stage 2: 16 -> 32 channels
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # fully connected layer with a single output unit (binary decision)
        self.linear = nn.Linear(32 * 7 * 7, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a ``(batch, 1, 28, 28)`` tensor to ``(batch, 1)`` values in (0, 1)."""
        features = self.conv2(self.conv1(x))
        # flatten to (batch_size, 32 * 7 * 7) for the linear layer
        flat = features.view(features.size(0), -1)
        return self.sigmoid(self.linear(flat))

In [2]:
class RBDN(nn.Module):
    """Three-level branched encoder/decoder mapping a 1-channel image to a 1-channel image.

    Layout (every pool halves the spatial size and records its argmax indices,
    every unpool restores the size recorded before the matching pool):

        conv1 -> pool1 -> [branch B1] -> conv2..conv9 -> unpoolL -> deconvL

    Branch Bk is ``conv -> pool -> (deeper branch) -> concat -> conv -> unpool -> deconv``;
    the output of each branch is concatenated along the channel axis with the
    tensor that entered it, which is why ``convB12``, ``convB22`` and ``conv2``
    take ``2 * num_channels`` input channels.

    The input needs at least 16 pixels per spatial side (four successive 2x2 pools).

    NOTE: a later cell of this notebook defines another class that is also
    named ``RBDN``; once that cell has run, the name refers to that class.

    Parameters
    ----------
    num_channels : width of every hidden feature map (default 64).
    """

    def __init__(self, num_channels = 64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = num_channels, kernel_size = 5, stride = 1, padding = 2)

        # convBk1 opens branch k; convBk2 runs on the concatenation of the
        # branch input with the output of the next-deeper branch (2x channels).
        self.convB11 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.convB12 = nn.Conv2d(in_channels = num_channels * 2, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.convB21 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.convB22 = nn.Conv2d(in_channels = num_channels * 2, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.convB31 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.convB32 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)

        self.deconvB1 = nn.ConvTranspose2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.deconvB2 = nn.ConvTranspose2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.deconvB3 = nn.ConvTranspose2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)

        # MaxUnpool2d needs the argmax indices of the matching pool.
        self.pool1 = nn.MaxPool2d(2, 2, return_indices = True)
        self.poolB1 = nn.MaxPool2d(2, 2, return_indices = True)
        self.poolB2 = nn.MaxPool2d(2, 2, return_indices = True)
        self.poolB3 = nn.MaxPool2d(2, 2, return_indices = True)

        self.unpool1 = nn.MaxUnpool2d(2, 2)  # not used in forward()
        self.unpoolB1 = nn.MaxUnpool2d(2, 2)
        self.unpoolB2 = nn.MaxUnpool2d(2, 2)
        self.unpoolB3 = nn.MaxUnpool2d(2, 2)

        # conv2 sees the concatenation of the stem and branch B1; conv3..conv9 keep the width.
        self.conv2 = nn.Conv2d(in_channels = num_channels * 2, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv3 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv4 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv5 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv6 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv7 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv8 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv9 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)

        self.unpoolL = nn.MaxUnpool2d(2, 2)
        # single-channel output image
        self.deconvL = nn.ConvTranspose2d(in_channels = num_channels, out_channels = 1, kernel_size = 3, stride = 1, padding = 1)


    def forward(self, x):
        """Return a tensor with the same ``(batch, 1, H, W)`` shape as ``x``."""
        # The spatial size before every pool is remembered and handed to the
        # matching unpool: with odd sizes (e.g. 7 -> 3) the default unpool
        # output (6) would not line up with the tensor it is concatenated with.

        # conv1
        x1 = F.relu(self.conv1(x))
        size1 = x1.shape[-2:]
        x1, idx1 = self.pool1(x1)

        # B11
        xB11 = F.relu(self.convB11(x1))
        sizeB1 = xB11.shape[-2:]
        xB11, idxB1 = self.poolB1(xB11)

        # B21
        xB21 = F.relu(self.convB21(xB11))
        sizeB2 = xB21.shape[-2:]
        xB21, idxB2 = self.poolB2(xB21)

        # B3
        xB3 = F.relu(self.convB31(xB21))
        sizeB3 = xB3.shape[-2:]
        xB3, idxB3 = self.poolB3(xB3)
        xB3 = F.relu(self.convB32(xB3))
        xB3 = self.unpoolB3(xB3, indices = idxB3, output_size = sizeB3)
        xB3 = F.relu(self.deconvB3(xB3))

        # B22
        xB22 = torch.cat((xB21, xB3), dim = 1)
        xB22 = F.relu(self.convB22(xB22))
        xB22 = self.unpoolB2(xB22, indices = idxB2, output_size = sizeB2)
        xB22 = F.relu(self.deconvB2(xB22))

        # B12
        xB12 = torch.cat((xB11, xB22), dim = 1)
        xB12 = F.relu(self.convB12(xB12))
        xB12 = self.unpoolB1(xB12, indices = idxB1, output_size = sizeB1)
        xB12 = F.relu(self.deconvB1(xB12))

        # conv 2-9
        x = torch.cat((x1, xB12), dim = 1)
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        x = F.relu(self.conv9(x))

        x = self.unpoolL(x, indices = idx1, output_size = size1)
        x = self.deconvL(x)

        return x

In [2]:
class RBDN(nn.Module):
    """Single-branch encoder/decoder mapping a 1-channel image to a 1-channel image.

    Data flow:

        conv1 (5x5) -> pool1 --+--------------------------------------+
                               |                                      |
                               +-> convB11 -> poolB1 -> convB12       |
                                   -> unpoolB1 -> deconvB1 -> concat <-+
        concat -> conv2 .. conv9 (3x3) -> unpoolL (pool1 indices) -> deconvL

    Both pools return their argmax indices, which the matching unpool layers
    use to place values back. No explicit output size is passed to the
    unpools, so each one exactly doubles its input's height and width.

    Parameters
    ----------
    num_channels : width of every hidden feature map (default 64).
    """

    def __init__(self, num_channels = 64):
        super().__init__()
        width = num_channels

        def conv3x3(c_in, c_out):
            # size-preserving 3x3 convolution
            return nn.Conv2d(c_in, c_out, kernel_size = 3, stride = 1, padding = 1)

        def deconv3x3(c_in, c_out):
            # size-preserving 3x3 transposed convolution
            return nn.ConvTranspose2d(c_in, c_out, kernel_size = 3, stride = 1, padding = 1)

        # Layers are created in a fixed order so that seeded initialisation is reproducible.
        self.conv1 = nn.Conv2d(1, width, kernel_size = 5, stride = 1, padding = 2)

        self.convB11 = conv3x3(width, width)
        self.convB12 = conv3x3(width, width)

        self.deconvB1 = deconv3x3(width, width)

        self.pool1 = nn.MaxPool2d(kernel_size = 2, stride = 2, return_indices = True)
        self.poolB1 = nn.MaxPool2d(kernel_size = 2, stride = 2, return_indices = True)

        self.unpool1 = nn.MaxUnpool2d(kernel_size = 2, stride = 2)  # not used in forward()
        self.unpoolB1 = nn.MaxUnpool2d(kernel_size = 2, stride = 2)

        # conv2 consumes the stem and branch feature maps stacked along the channel axis
        self.conv2 = conv3x3(width * 2, width)
        for k in range(3, 10):
            setattr(self, "conv%d" % k, conv3x3(width, width))

        self.unpoolL = nn.MaxUnpool2d(kernel_size = 2, stride = 2)
        self.deconvL = deconv3x3(width, 1)


    def forward(self, x):
        """Return the network output for a ``(batch, 1, H, W)`` input tensor."""
        # stem
        stem, stem_idx = self.pool1(F.relu(self.conv1(x)))

        # branch B1: one level deeper, then back up to the stem resolution
        branch, branch_idx = self.poolB1(F.relu(self.convB11(stem)))
        branch = F.relu(self.convB12(branch))
        branch = self.unpoolB1(branch, indices = branch_idx)
        branch = F.relu(self.deconvB1(branch))

        # trunk: conv2 .. conv9 on the channel-wise concatenation
        out = torch.cat((stem, branch), dim = 1)
        trunk = (self.conv2, self.conv3, self.conv4, self.conv5,
                 self.conv6, self.conv7, self.conv8, self.conv9)
        for layer in trunk:
            out = F.relu(layer(out))

        # back to the input resolution; no activation on the output layer
        out = self.unpoolL(out, indices = stem_idx)
        return self.deconvL(out)

In [4]:
# Instantiate the network once to report its size.
num_channels = 64
rbdn = RBDN(num_channels)
# total number of trainable parameters (displayed as the cell output)
sum(p.numel() for p in rbdn.parameters() if p.requires_grad)

445313

In [5]:
# Load MNIST through the local `mnist` helper and divide the pixel values by 255.
train_img, train_label, test_img, test_label = mnist.load()
train_img, test_img = train_img / 255, test_img / 255

In [6]:
# ---------------------------------------------------------------------------
# Experiment driver. For each of `num_exp` seeds:
#   1. build a synthetic image-to-image task: the predictor is the 28 x 84
#      strip [digit 2 | sign glyph | digit 1]; the outcome is a random digit 1
#      when the glyph's sign is -1 and a random digit 3 otherwise;
#   2. train a CNN that separates outcome 1s from outcome 3s;
#   3. train an RBDN to regress the zero-padded outcome strip from the predictor;
#   4. record the MSE of the RBDN's middle 28 x 28 panel and how often the CNN
#      assigns that panel to the correct digit (train and test).
# ---------------------------------------------------------------------------
train_label_1_idx = np.where(train_label == 1)[0]
train_label_2_idx = np.where(train_label == 2)[0]
train_label_3_idx = np.where(train_label == 3)[0]
test_label_1_idx = np.where(test_label == 1)[0]
test_label_2_idx = np.where(test_label == 2)[0]
test_label_3_idx = np.where(test_label == 3)[0]

n = 2000
n_train = 1000
n_test = 1000

# Fall back to the CPU when no GPU is visible instead of failing on .to("cuda").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _build_split(images, idx_1, idx_3, ones, twos, signs, sign_imgs, n_obs):
    """Assemble predictors and outcomes for one split (train or test).

    Parameters
    ----------
    images : (num_images, 784) array the outcome digits are drawn from.
    idx_1, idx_3 : row indices into ``images`` of the digits 1 and 3.
    ones, twos : (n_obs, 784) arrays with the digit 1 / digit 2 side panels.
    signs : length-n_obs array of -1 / 1 glyph signs.
    sign_imgs : (n_obs, 784) array with the glyph images.
    n_obs : number of observations to build.

    Returns
    -------
    predictors : (n_obs, 28*28*3) flattened [digit 2 | glyph | digit 1] strips.
    outcomes : (n_obs, 28*28) flattened outcome digits.
    outcomes_padding : (n_obs, 28*28*3) flattened [zeros | outcome | zeros] strips.
    outcomes_label : (n_obs,) array holding 1 or 3.
    """
    predictors = np.zeros((n_obs, 28*28*3))
    outcomes = np.zeros((n_obs, 28*28))
    outcomes_padding = np.zeros((n_obs, 28*28*3))
    outcomes_label = np.zeros(n_obs)
    blank = np.zeros((28, 28))
    for i in range(n_obs):
        one_i = ones[i, :].reshape((28, 28))
        two_i = twos[i, :].reshape((28, 28))
        sign_img_i = sign_imgs[i, :].reshape((28, 28))
        # sign -1 -> outcome digit 1, otherwise -> outcome digit 3
        if signs[i] == -1:
            candidates, digit = idx_1, 1
        else:
            candidates, digit = idx_3, 3
        # one np.random draw per observation (keeps seeded runs reproducible)
        label_img_i = images[np.random.choice(candidates, size = 1, replace = False), :]
        predictors[i, :] = np.hstack((two_i, sign_img_i, one_i)).reshape(-1)
        outcomes[i, :] = label_img_i
        outcomes_label[i] = digit
        outcomes_padding[i, :] = np.hstack((blank, label_img_i.reshape((28, 28)), blank)).reshape(-1)
    return predictors, outcomes, outcomes_padding, outcomes_label


def _fit(model, dataloader, loss_fn, optimizer, num_epochs):
    """Run ``num_epochs`` passes of mini-batch gradient descent over ``dataloader``."""
    model.train()
    for epoch in range(num_epochs):
        for X_batch, y_batch in dataloader:
            output = model(X_batch)
            loss = loss_fn(output, y_batch)

            # clear gradients, backpropagate, apply the update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


num_exp = 50
# column 0: train, column 1: test
result_mse = np.zeros((num_exp, 2))
result_acc = np.zeros((num_exp, 2))
for exp in range(num_exp):
    random.seed(exp)
    torch.manual_seed(exp)
    np.random.seed(exp)
    #####################################################################
    # generate images
    sign_img, sign = generate_sign(n = n)
    train_sign = sign[0:n_train]
    train_sign_img = sign_img[0:n_train, :]
    test_sign = sign[n_train:n]
    test_sign_img = sign_img[n_train:n, :]
    train_1s = train_img[np.random.choice(train_label_1_idx, size = n_train, replace = False), :]
    train_2s = train_img[np.random.choice(train_label_2_idx, size = n_train, replace = False), :]
    # the test panels are sized by n_test (n_train was used here before)
    test_1s = test_img[np.random.choice(test_label_1_idx, size = n_test, replace = False), :]
    test_2s = test_img[np.random.choice(test_label_2_idx, size = n_test, replace = False), :]

    train_predictors, train_outcomes, train_outcomes_padding, train_outcomes_label = _build_split(
        train_img, train_label_1_idx, train_label_3_idx, train_1s, train_2s, train_sign, train_sign_img, n_train)
    test_predictors, test_outcomes, test_outcomes_padding, test_outcomes_label = _build_split(
        test_img, test_label_1_idx, test_label_3_idx, test_1s, test_2s, test_sign, test_sign_img, n_test)

    # These files are overwritten on every iteration; they hold the latest experiment only.
    np.savetxt("train_predictors.txt", train_predictors)
    np.savetxt("test_predictors.txt", test_predictors)
    np.savetxt("train_outcomes.txt", train_outcomes)
    np.savetxt("test_outcomes.txt", test_outcomes)

    #####################################################################
    # train cnn mnist classifier
    cnn = CNN().to(device)
    cnn_optimizer = torch.optim.Adam(cnn.parameters(), lr = 0.001)
    cnn_loss = nn.functional.binary_cross_entropy

    cnnX = torch.tensor(train_outcomes, dtype = torch.float32).reshape((n_train, 1, 28, 28)).to(device)
    # binary target: digit 1 -> 0, digit 3 -> 1
    cnny = torch.tensor(train_outcomes_label == 3, dtype = torch.float32).reshape((n_train, 1)).to(device)

    cnn_dataset = Dataset(cnnX, cnny)
    cnn_dataloader = torch.utils.data.DataLoader(cnn_dataset, batch_size = 64, shuffle = True)

    num_epochs = 100
    _fit(cnn, cnn_dataloader, cnn_loss, cnn_optimizer, num_epochs)

    #####################################################################
    # train rbdn image-to-image model
    rbdn = RBDN(num_channels).to(device)
    rbdn_train_predictors = torch.tensor(train_predictors, dtype = torch.float32).reshape((n_train, 1, 28, 84)).to(device)
    rbdn_train_outcomes = torch.tensor(train_outcomes_padding, dtype = torch.float32).reshape((n_train, 1, 28, 84)).to(device)
    rbdn_loss = nn.functional.mse_loss
    rbdn_optim = torch.optim.Adam(rbdn.parameters(), lr = 1e-3)

    rbdn_dataset = Dataset(rbdn_train_predictors, rbdn_train_outcomes)
    rbdn_dataloader = torch.utils.data.DataLoader(rbdn_dataset, batch_size = 64, shuffle = True)

    num_epochs = 50
    _fit(rbdn, rbdn_dataloader, rbdn_loss, rbdn_optim, num_epochs)

    #####################################################################
    # evaluate on the CPU with the whole split as one batch
    rbdn_train_predictors = rbdn_train_predictors.to("cpu")
    rbdn_test_predictors = torch.tensor(test_predictors, dtype = torch.float32).reshape((n_test, 1, 28, 84))
    rbdn = rbdn.to("cpu")
    cnn = cnn.to("cpu")
    rbdn.eval()
    cnn.eval()

    # No gradients are needed here; without no_grad() every activation of the
    # full-batch forward pass would be kept alive for a backward pass.
    with torch.no_grad():
        # keep only the middle 28 x 28 panel, where the outcome digit lives
        rbdn_train_pred = rbdn(rbdn_train_predictors)[:, :, :, 28:56]
        rbdn_test_pred = rbdn(rbdn_test_predictors)[:, :, :, 28:56]
        cnn_train_prob = cnn(rbdn_train_pred).numpy().reshape(-1)
        cnn_test_prob = cnn(rbdn_test_pred).numpy().reshape(-1)

    result_acc[exp, 0] = np.mean((cnn_train_prob > 0.5) == (train_outcomes_label == 3))
    result_acc[exp, 1] = np.mean((cnn_test_prob > 0.5) == (test_outcomes_label == 3))
    result_mse[exp, 0] = np.mean((rbdn_train_pred.numpy().reshape((n_train, 784)) - train_outcomes)**2)
    result_mse[exp, 1] = np.mean((rbdn_test_pred.numpy().reshape((n_test, 784)) - test_outcomes)**2)

    print(exp, result_acc[exp, 0], result_acc[exp, 1])
    torch.cuda.empty_cache()

    # checkpoint the result tables after every experiment
    np.savetxt("result_mse_rbdn.txt", result_mse)
    np.savetxt("result_acc_rbdn.txt", result_acc)



0 1.0 1.0
1 1.0 1.0
2 1.0 1.0
3 1.0 1.0
4 1.0 1.0
5 1.0 1.0
6 1.0 1.0
7 1.0 1.0
8 1.0 1.0
9 1.0 1.0
10 1.0 1.0
11 1.0 1.0
12 1.0 1.0
13 1.0 1.0
14 1.0 1.0
15 1.0 1.0
16 1.0 1.0
17 1.0 1.0
18 1.0 1.0
19 1.0 1.0
20 1.0 1.0
21 1.0 1.0
22 1.0 1.0
23 1.0 1.0
24 1.0 1.0
25 1.0 1.0
26 1.0 1.0
27 1.0 1.0
28 1.0 1.0
29 1.0 1.0
30 1.0 1.0
31 1.0 1.0
32 1.0 1.0
33 1.0 1.0
34 1.0 1.0
35 1.0 1.0
36 1.0 1.0
37 1.0 1.0
38 1.0 1.0
39 1.0 1.0
40 1.0 1.0
41 1.0 1.0
42 1.0 1.0
43 1.0 1.0
44 1.0 1.0
45 1.0 1.0
46 1.0 1.0
47 1.0 1.0
48 1.0 1.0
49 1.0 1.0
