In [1]:
import subprocess
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import torch.nn as nn
%matplotlib inline


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs entry ``i`` of ``predictors`` with entry ``i`` of ``labels``.

    Both containers are indexed along their first axis only, so they may have
    any number of trailing dimensions (including none, e.g. a 1-D label vector).

    Parameters
    ----------
    predictors : indexable with a leading sample axis
        Model inputs.
    labels : indexable with a leading sample axis and a ``.shape`` attribute
        Targets; ``labels.shape[0]`` defines the dataset length.
    """

    def __init__(self, predictors, labels):
        self.labels = labels
        self.predictors = predictors

    def __len__(self):
        """Return the number of samples, taken from the leading axis of ``labels``."""
        return self.labels.shape[0]

    def __getitem__(self, index):
        """Return the ``(predictor, label)`` pair stored at ``index``."""
        # Index the leading axis only: equivalent to ``[index, :]`` for inputs
        # with two or more dimensions, and it also works for 1-D label vectors.
        X = self.predictors[index]
        y = self.labels[index]

        return X, y

In [2]:
class RBDN(nn.Module):
    """Three-level recursively branched convolution/deconvolution network.

    A 5x5 convolution followed by 2x2 max-pooling feeds a chain of nested
    branches (B1 -> B2 -> B3). Each branch pools its input, processes it,
    un-pools with the indices recorded by its own pooling layer, and applies a
    transposed convolution; its output is concatenated channel-wise with the
    features it branched from. The merged features pass through eight 3x3
    convolutions, are un-pooled back to the input resolution and mapped to a
    single output channel.

    Compared with the earlier revision of this class:
    * pooling layers are created with ``return_indices = True`` and the
      indices are passed to the matching ``MaxUnpool2d`` (required by PyTorch);
    * layers that consume a channel-wise concatenation declare
      ``2 * num_channels`` input channels, and ``conv3``..``conv9`` declare
      ``num_channels`` input channels;
    * ``output_size`` is passed to every un-pooling call so odd intermediate
      sizes are restored exactly;
    * the final transposed convolution emits one channel, like the input.

    Parameters
    ----------
    num_channels : int, default 64
        Number of feature maps used by every hidden layer.

    Notes
    -----
    The input must have one channel and spatial sides of at least 16, because
    the signal is max-pooled four times with a 2x2 window.
    """

    def __init__(self, num_channels = 64):
        super().__init__()
        merged_channels = 2 * num_channels  # width after torch.cat of two feature stacks

        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = num_channels, kernel_size = 5, stride = 1, padding = 2)

        # "...1" layers run before a branch's pooling, "...2" layers after the merge with the deeper branch.
        self.convB11 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.convB12 = nn.Conv2d(in_channels = merged_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.convB21 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.convB22 = nn.Conv2d(in_channels = merged_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.convB31 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.convB32 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)

        self.deconvB1 = nn.ConvTranspose2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.deconvB2 = nn.ConvTranspose2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.deconvB3 = nn.ConvTranspose2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)

        # return_indices = True: MaxUnpool2d needs the arg-max locations of the matching pool.
        self.pool1 = nn.MaxPool2d(2, 2, return_indices = True)
        self.poolB1 = nn.MaxPool2d(2, 2, return_indices = True)
        self.poolB2 = nn.MaxPool2d(2, 2, return_indices = True)
        self.poolB3 = nn.MaxPool2d(2, 2, return_indices = True)

        self.unpool1 = nn.MaxUnpool2d(2, 2)  # not used in forward(); kept so the attribute still exists
        self.unpoolB1 = nn.MaxUnpool2d(2, 2)
        self.unpoolB2 = nn.MaxUnpool2d(2, 2)
        self.unpoolB3 = nn.MaxUnpool2d(2, 2)

        self.conv2 = nn.Conv2d(in_channels = merged_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv3 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv4 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv5 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv6 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv7 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv8 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)
        self.conv9 = nn.Conv2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)

        self.unpoolL = nn.MaxUnpool2d(2, 2)
        self.deconvL = nn.ConvTranspose2d(in_channels = num_channels, out_channels = 1, kernel_size = 3, stride = 1, padding = 1)


    def forward(self, x):
        """Map a ``(N, 1, H, W)`` tensor to a ``(N, 1, H, W)`` tensor."""
        # conv1
        x1 = F.relu(self.conv1(x))
        size1 = x1.shape[-2:]  # pre-pool size, restored by unpoolL
        x1, idx1 = self.pool1(x1)

        # B11
        xB11 = F.relu(self.convB11(x1))
        sizeB1 = xB11.shape[-2:]
        xB11, idxB1 = self.poolB1(xB11)

        # B21
        xB21 = F.relu(self.convB21(xB11))
        sizeB2 = xB21.shape[-2:]
        xB21, idxB2 = self.poolB2(xB21)

        # B3
        xB3 = F.relu(self.convB31(xB21))
        sizeB3 = xB3.shape[-2:]
        xB3, idxB3 = self.poolB3(xB3)
        xB3 = F.relu(self.convB32(xB3))
        xB3 = self.unpoolB3(xB3, indices = idxB3, output_size = sizeB3)
        xB3 = F.relu(self.deconvB3(xB3))

        # B22: merge B3 back into the B2 features, then undo poolB2
        xB22 = torch.cat((xB21, xB3), dim = 1)
        xB22 = F.relu(self.convB22(xB22))
        xB22 = self.unpoolB2(xB22, indices = idxB2, output_size = sizeB2)
        xB22 = F.relu(self.deconvB2(xB22))

        # B12: merge B2 back into the B1 features, then undo poolB1
        xB12 = torch.cat((xB11, xB22), dim = 1)
        xB12 = F.relu(self.convB12(xB12))
        xB12 = self.unpoolB1(xB12, indices = idxB1, output_size = sizeB1)
        xB12 = F.relu(self.deconvB1(xB12))

        # conv 2-9
        x = torch.cat((x1, xB12), dim = 1)
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        x = F.relu(self.conv9(x))

        # Undo pool1 and project to one channel (no activation on the output).
        x = self.unpoolL(x, indices = idx1, output_size = size1)
        x = self.deconvL(x)

        return x

In [49]:
class RBDN(nn.Module):
    """Single-branch convolution/deconvolution network.

    Pipeline: 5x5 convolution -> 2x2 max-pool -> branch B1 (conv, pool, conv,
    un-pool, transposed conv) -> channel-wise concatenation of the branch
    output with the pooled ``conv1`` features -> eight 3x3 convolutions ->
    un-pool with the ``pool1`` indices -> transposed convolution to one channel.

    Every un-pooling call receives the exact pre-pool spatial size, so inputs
    whose intermediate sizes are odd (e.g. 30x30 or 29x29) are restored to the
    right shape instead of failing at the concatenation.

    Parameters
    ----------
    num_channels : int, default 64
        Number of feature maps used by every hidden layer.
    """

    def __init__(self, num_channels = 64):
        super().__init__()

        def conv3x3(in_channels):
            # All hidden convolutions share this shape-preserving configuration.
            return nn.Conv2d(in_channels = in_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)

        # NOTE: layers are created in the same order as before so that, for a
        # given torch seed, each layer receives the same initial weights.
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = num_channels, kernel_size = 5, stride = 1, padding = 2)

        self.convB11 = conv3x3(num_channels)
        self.convB12 = conv3x3(num_channels)

        self.deconvB1 = nn.ConvTranspose2d(in_channels = num_channels, out_channels = num_channels, kernel_size = 3, stride = 1, padding = 1)

        # return_indices = True: the un-pooling layers need the arg-max locations.
        self.pool1 = nn.MaxPool2d(2, 2, return_indices = True)
        self.poolB1 = nn.MaxPool2d(2, 2, return_indices = True)

        self.unpool1 = nn.MaxUnpool2d(2, 2)  # not used in forward(); kept so the attribute still exists
        self.unpoolB1 = nn.MaxUnpool2d(2, 2)

        # conv2 consumes the concatenation of two num_channels-wide stacks.
        self.conv2 = conv3x3(num_channels * 2)
        self.conv3 = conv3x3(num_channels)
        self.conv4 = conv3x3(num_channels)
        self.conv5 = conv3x3(num_channels)
        self.conv6 = conv3x3(num_channels)
        self.conv7 = conv3x3(num_channels)
        self.conv8 = conv3x3(num_channels)
        self.conv9 = conv3x3(num_channels)

        self.unpoolL = nn.MaxUnpool2d(2, 2)
        self.deconvL = nn.ConvTranspose2d(in_channels = num_channels, out_channels = 1, kernel_size = 3, stride = 1, padding = 1)

    def forward(self, x):
        """Map a ``(N, 1, H, W)`` tensor to a ``(N, 1, H, W)`` tensor."""
        # conv1
        x1 = F.relu(self.conv1(x))
        size1 = x1.shape[-2:]  # pre-pool size, restored by unpoolL
        x1, idx1 = self.pool1(x1)

        # B1
        xB11 = F.relu(self.convB11(x1))
        sizeB1 = xB11.shape[-2:]  # pre-pool size, restored by unpoolB1
        xB11, idxB1 = self.poolB1(xB11)
        xB12 = F.relu(self.convB12(xB11))
        xB12 = self.unpoolB1(xB12, indices = idxB1, output_size = sizeB1)
        xB12 = F.relu(self.deconvB1(xB12))

        # conv 2-9
        x = torch.cat((x1, xB12), dim = 1)
        for conv in (self.conv2, self.conv3, self.conv4, self.conv5,
                     self.conv6, self.conv7, self.conv8, self.conv9):
            x = F.relu(conv(x))

        # Undo pool1 and project to one channel (no activation on the output).
        x = self.unpoolL(x, indices = idx1, output_size = size1)
        x = self.deconvL(x)

        return x

In [50]:
# Instantiate the network at the width used by the experiments and display
# how many trainable scalars it holds (the cell's last expression is its output).
num_channels = 64
rbdn = RBDN(num_channels)
sum(p.numel() for p in rbdn.parameters() if p.requires_grad)

445313

In [41]:
# Read both CSV files, skipping their first (header) row.
train_data = np.loadtxt("fashion-mnist_train.csv", skiprows = 1, delimiter = ",")
test_data = np.loadtxt("fashion-mnist_test.csv", skiprows = 1, delimiter = ",")

# Column 0 is used as the label; the remaining columns are the pixel values,
# divided by 255 (presumably 8-bit intensities, giving values in [0, 1]).
train_label = train_data[:, 0]
test_label = test_data[:, 0]
train_img = train_data[:, 1:] / 255
test_img = test_data[:, 1:] / 255

# Row indices to sample from, derived from the data actually loaded rather than
# hard-coded, so a file with a different number of rows cannot produce
# out-of-range or never-sampled indices.
train_idx_all = np.arange(train_img.shape[0])
test_idx_all = np.arange(test_img.shape[0])

In [80]:
n = 2000
n_train = 1000
n_test = 1000

# Geometry of the predictor / target canvases built below.
IMG_SIDE = 28                               # each flattened image is reshaped to 28 x 28
N_BANDS = 4                                 # intensity bands per image
CANVAS_WIDTH = IMG_SIDE * N_BANDS           # 112: the four bands laid side by side
PAD_WIDTH = (CANVAS_WIDTH - IMG_SIDE) // 2  # 42 zero columns on each side of the target

# Train on the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def make_band_predictors(imgs):
    """Build the banded predictor canvas and the zero-padded target canvas.

    For every image the 0 / 0.25 / 0.5 / 0.75 quantiles of its strictly
    positive pixels define four half-open intensity bands
    ``[q0, q1)``, ``[q1, q2)``, ``[q2, q3)`` and ``[q3, inf)``. Each band keeps
    the pixels that fall inside it and is zero elsewhere.

    Parameters
    ----------
    imgs : np.ndarray, shape (n_img, IMG_SIDE * IMG_SIDE)
        Flattened images.

    Returns
    -------
    predictors : np.ndarray, shape (n_img, IMG_SIDE * CANVAS_WIDTH)
        The four band images placed side by side, flattened row-major.
    outcomes_padding : np.ndarray, shape (n_img, IMG_SIDE * CANVAS_WIDTH)
        The original image centred on a canvas of the same size, with
        ``PAD_WIDTH`` zero columns on each side, flattened row-major.
    """
    n_img = imgs.shape[0]

    quantiles = np.zeros((N_BANDS, n_img))
    for i in range(n_img):
        positive = imgs[i, imgs[i, :] > 0]
        # An image without positive pixels has no quantiles to compute; leaving
        # them at zero makes all of its bands come out as all-zero images.
        if positive.size > 0:
            quantiles[:, i] = np.quantile(positive, [0, 0.25, 0.5, 0.75])

    # Column vectors broadcast against (n_img, n_pixels): one threshold set per image.
    q0, q1, q2, q3 = (quantiles[k, :].reshape((n_img, 1)) for k in range(N_BANDS))
    band_masks = (
        (imgs >= q0) & (imgs < q1),
        (imgs >= q1) & (imgs < q2),
        (imgs >= q2) & (imgs < q3),
        imgs >= q3,
    )
    bands = [np.where(mask, imgs, 0.0).reshape((n_img, IMG_SIDE, IMG_SIDE)) for mask in band_masks]

    # Concatenating along the column axis places the bands left to right.
    predictors = np.concatenate(bands, axis = 2).reshape((n_img, -1))
    outcomes_padding = np.pad(
        imgs.reshape((n_img, IMG_SIDE, IMG_SIDE)),
        ((0, 0), (0, 0), (PAD_WIDTH, PAD_WIDTH)),
    ).reshape((n_img, -1))
    return predictors, outcomes_padding


num_exp = 50
num_epochs = 50
result_mse = np.zeros((num_exp, 2))  # column 0: train MSE, column 1: test MSE
for exp in range(num_exp):
    # The experiment index seeds every RNG, so each run is reproducible on its own.
    random.seed(exp)
    torch.manual_seed(exp)
    np.random.seed(exp)
    #####################################################################
    # generate images
    train_idx = np.random.choice(train_idx_all, size = n_train, replace = False)
    test_idx = np.random.choice(test_idx_all, size = n_test, replace = False)

    train_imgs = train_img[train_idx, ]
    test_imgs = test_img[test_idx, ]
    train_outcomes = train_imgs
    test_outcomes = test_imgs

    train_predictors, train_outcomes_padding = make_band_predictors(train_imgs)
    test_predictors, test_outcomes_padding = make_band_predictors(test_imgs)

    # Overwritten on every experiment: the files always hold the latest sample.
    np.savetxt("train_predictors.txt", train_predictors)
    np.savetxt("test_predictors.txt", test_predictors)
    np.savetxt("train_outcomes.txt", train_outcomes)
    np.savetxt("test_outcomes.txt", test_outcomes)

    #####################################################################
    # train
    rbdn = RBDN(num_channels).to(device)
    rbdn_train_predictors = torch.tensor(train_predictors, dtype = torch.float32).reshape((n_train, 1, IMG_SIDE, CANVAS_WIDTH)).to(device)
    rbdn_train_outcomes = torch.tensor(train_outcomes_padding, dtype = torch.float32).reshape((n_train, 1, IMG_SIDE, CANVAS_WIDTH)).to(device)
    rbdn_loss = nn.functional.mse_loss
    rbdn_optim = torch.optim.Adam(rbdn.parameters(), lr = 1e-3)

    rbdn_dataset = Dataset(rbdn_train_predictors, rbdn_train_outcomes)
    rbdn_dataloader = torch.utils.data.DataLoader(rbdn_dataset, batch_size = 64, shuffle = True)

    rbdn.train()
    for epoch in range(num_epochs):
        # Batches are slices of tensors that already live on `device`.
        for X_batch, y_batch in rbdn_dataloader:
            output = rbdn(X_batch)
            loss = rbdn_loss(output, y_batch)

            rbdn_optim.zero_grad()
            loss.backward()
            rbdn_optim.step()

    #####################################################################
    # evaluate (on the CPU, full train / test sets in one forward pass each)
    rbdn_train_predictors = rbdn_train_predictors.to("cpu")
    rbdn = rbdn.to("cpu")
    rbdn.eval()
    rbdn_test_predictors = torch.tensor(test_predictors, dtype = torch.float32).reshape((n_test, 1, IMG_SIDE, CANVAS_WIDTH))

    # no_grad: no autograd graph is kept for these 1000-image forward passes.
    with torch.no_grad():
        rbdn_train_pred = rbdn(rbdn_train_predictors)
        rbdn_test_pred = rbdn(rbdn_test_predictors)

    # Keep only the central columns, i.e. the region where the target image sits.
    rbdn_train_pred = rbdn_train_pred[:, :, :, PAD_WIDTH:PAD_WIDTH + IMG_SIDE]
    rbdn_test_pred = rbdn_test_pred[:, :, :, PAD_WIDTH:PAD_WIDTH + IMG_SIDE]

    result_mse[exp, 0] = np.mean((rbdn_train_pred.numpy().reshape((n_train, IMG_SIDE * IMG_SIDE)) - train_outcomes)**2)
    result_mse[exp, 1] = np.mean((rbdn_test_pred.numpy().reshape((n_test, IMG_SIDE * IMG_SIDE)) - test_outcomes)**2)

    print(exp, result_mse[exp, 0], result_mse[exp, 1])
    torch.cuda.empty_cache()

    # Saved after every experiment so partial results survive an interruption.
    np.savetxt("result_mse_rbdn.txt", result_mse)

0 0.03183671604201324 0.05997440799546761
1 0.028112346753245348 0.03967142362843392
2 0.03247798115966877 0.05614195362114102
3 0.03405832968122399 0.045341083475639
4 0.030918063012268338 0.04698811771253044
5 0.026032466034815026 0.051097723735983584
6 0.029460508808956634 0.04446909990919047
7 0.025476917214574835 0.04611815398865871
8 0.03630754846466739 0.056051359671980876
9 0.029242342680338537 0.0463854644746687
10 0.02743156440980799 0.04556926659342202
11 0.028667399283391043 0.04355979551438229
12 0.03042843209822699 0.05227563496039377
13 0.03312195053248766 0.05117278825814095
14 0.034899879891482286 0.0426519609902003
15 0.025113375162007255 0.03985696482785238
16 0.031628826066188755 0.04266884652952742
17 0.027995357440242716 0.047439949356599995
18 0.033382047463223576 0.0421281882526685
19 0.03437315819703424 0.05961349356876301
20 0.031401038909475626 0.0523975706886006
21 0.02822860716068578 0.04594070237042509
22 0.03903883548866728 0.04732787610639706
23 0.028323