In [52]:
import torch
import torch.nn as nn
import torch.functional as F
import torch.autograd as autograd
import numpy as np
import matplotlib.pyplot as plt

T = 8
batch_size = 100
A = 60
B = 60
z_size = 10
read_size = 12
write_size = 12
enc_size = 256
dec_size = 256
epoch_num = 50
learning_rate = 1e-3
beta1 = 0.5
USE_CUDA = True
clip = 5.0
attention=1


def Variable(data, *args, **kwargs):
    if USE_CUDA:
        data = data.cuda()
    return autograd.Variable(data,*args, **kwargs)


class DrawModel(nn.Module):
    def __init__(self,T,A,B,z_size,read_size,write_size,dec_size,enc_size,attention):
        super(DrawModel,self).__init__()
        self.T = T
        # self.batch_size = batch_size
        self.attention = attention
        self.A = A
        self.B = B
        self.z_size = z_size
        self.read_size = read_size
        self.write_size = write_size
        self.dec_size = dec_size
        self.enc_size = enc_size
        self.cs = [0] * T #canvas
        self.logsigmas,self.sigmas,self.mus = [0] * T,[0] * T,[0] * T

        self.encoder = nn.LSTMCell(2 * B * A + dec_size, enc_size)   #encoder for read without attention
        self.encoder_a = nn.LSTMCell(2 * read_size * read_size + dec_size, enc_size) #encoder for read with attention
        self.decoder = nn.LSTMCell(z_size, dec_size)                 #decoder

        self.mu_linear = nn.Linear(dec_size, z_size)    #linear for eq(1) for latent distribution
        self.sigma_linear = nn.Linear(dec_size, z_size) #linear for eq(2) for latent distribution
        self.dec_linear = nn.Linear(dec_size, 5)        #linear for eq(21) generating 5 params from h_dec
        
        self.dec_w_linear_a = nn.Linear(dec_size, write_size*write_size)     #linear for eq(28)
        self.dec_w_linear = nn.Linear(dec_size, A*B)                         #linear for eq(18)

        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(enc_size, 10)


    def forward(self,x):
        self.batch_size = x.size()[0]
        h_enc_prev = Variable(torch.zeros(self.batch_size, self.enc_size))
        h_dec_prev = Variable(torch.zeros(self.batch_size, self.dec_size))

        enc_state = Variable(torch.zeros(self.batch_size,self.enc_size))
        dec_state = Variable(torch.zeros(self.batch_size, self.dec_size))

        if (self.attention):
            for t in range(self.T):
              c_prev = Variable(torch.zeros(self.batch_size,self.A * self.B)) if t == 0 else self.cs[t-1]
              x_hat = x - self.sigmoid(c_prev)     #eq(3)
              r_t = self.read(x,x_hat,h_dec_prev,1)  #eq(4)
              h_enc, enc_state = self.encoder_a(torch.cat((r_t,h_dec_prev),1),(h_enc_prev,enc_state)) #eq(5)
              h_enc_prev = h_enc
              z, self.mus[t], self.logsigmas[t], self.sigmas[t] = self.sampleQ(h_enc) #eq(6)
              h_dec, dec_state = self.decoder(z, (h_dec_prev, dec_state))
              self.cs[t] = c_prev + self.write(h_dec,1)
              h_dec_prev = h_dec
            
            probs = torch.nn.functional.log_softmax(self.fc(h_enc), dim=1)
            #print(probs.view(batch_size,10))
            return probs.view(batch_size,10)


    def gaussian_filter(self,h_dec,N):
        params = self.dec_linear(h_dec) #eq(21)
        gx, gy, log_sigma2, log_delta, log_gamma = torch.split(params, 1, dim=1) #each param has size batch_sizex1
        gx = (self.A+1)/2 * (gx+1) #eq(22)
        gy = (self.B+1)/2 * (gy+1) #eq(23)
        delta = (max(self.A,self.B) - 1) / (N - 1) * torch.exp(log_delta)  #eq(24)
        sigma2 = torch.exp(log_sigma2) 
        gamma = torch.exp(log_gamma)
        
        #make mean_x, mean_y, a, b into tensors with shape (1, N, AorB) to make Fx and Fy with the same shape
        mean_x = torch.zeros((batch_size,A)).to('cuda:0') #use this trick because I cannot make empty tensor 
        mean_y = torch.zeros((batch_size,B)).to('cuda:0') #use this trick because I cannot make empty tensor
        for i in range (N):
          mean_x_i = torch.ones((batch_size,A)).to('cuda:0') * (gx + (i- N/2 -0.5)*delta)  #eq(19) 
          #print(mean_x.shape,mean_x_i.shape)
          mean_x = torch.cat([mean_x, mean_x_i], axis=-1)
          mean_y_j = torch.ones((batch_size,A)).to('cuda:0') * (gy + (i- N/2 -0.5)*delta)  #eq(20) 
          mean_y = torch.cat([mean_y,mean_y_j], axis =-1)
        mean_x = mean_x[:,A:]     #undo the torch.zeros trick
        mean_y = mean_y[:,B:]     #undo the torch.zeros trick
        mean_x = mean_x.view(batch_size, N, A) #each row consists of A elements of mean_x(i)
        mean_y = mean_y.view(batch_size, N, B) #each row consists of B elements of mean_y(j)
        a = torch.zeros((delta.size()[0],A)).to('cuda:0')
        b = torch.zeros((delta.size()[0],B)).to('cuda:0')
        for i in range (N):
          ai = torch.ones((delta.size()[0],A)).to('cuda:0') * torch.range(0,A-1).to('cuda:0')
          bj = torch.ones((delta.size()[0],B)).to('cuda:0') * torch.range(0,B-1).to('cuda:0')
          a = torch.cat([a,ai], axis=-1)
          b = torch.cat([b,bj], axis=-1)
        a = a[:,A:]
        b = b[:,B:]
        a = a.view(batch_size, N, A)
        b = b.view(batch_size, N, B)

        sig = torch.ones(A).to('cuda:0') * sigma2
        sigma_2 = sig
        for i in range(N-1):
          sigma_2 = torch.cat([sigma_2,sig], axis=1)
        sigma_2 = sigma_2.view(-1,N,A)
        #print(a.size(),mean_x.size(),sigma_2.size())
        Fx = torch.exp(-torch.square(a-mean_x)/(2*sigma_2))  #eq(25)
        Fy = torch.exp(-torch.square(b-mean_y)/(2*sigma_2))  #eq(26)
        #normalize each row of Fx and Fy to sum 1
        Fx=Fx/(torch.sum(Fx,2,keepdim=True)+1e-8)
        Fy=Fy/(torch.sum(Fy,2,keepdim=True)+1e-8)
        

        return Fx,Fy,gamma

    def read(self,x,x_hat,h_dec_prev, attention):
        if (attention):
          Fx,Fy,gamma = self.gaussian_filter(h_dec_prev, self.read_size)
          #transpose Fx per batch
          Fxt = Fx.transpose(2,1)
          #reshape x from (batch,B*A) to (batch,B,A)
          x_matrix = x.view(-1,B,A)
          xhat_matrix = x_hat.view(-1,B,A)
          #concat
          first_half = torch.bmm(torch.bmm(Fy,x_matrix),Fxt)
          first_half = first_half.view(-1,self.read_size*self.read_size)
          second_half = torch.bmm(torch.bmm(Fy,xhat_matrix),Fxt)
          second_half = first_half.view(-1,self.read_size*self.read_size)
          gamma = gamma.view(-1,1)
          return gamma*(torch.cat([first_half,second_half], axis=1)) #eq(27)
        else:
          return torch.cat([x,x_hat], 1)  #eq(17)
            
    def write(self,h_dec, attention):
        if (attention):
          w = self.dec_w_linear_a(h_dec) #eq(28)
          w = w.view(self.batch_size,self.write_size,self.write_size)
          Fx,Fy,gamma = self.gaussian_filter(h_dec,write_size)
          Fyt = Fy.transpose(2,1)
          wr =  torch.bmm(torch.bmm(Fyt,w),Fx)      #eq(29)
          wr = wr.view(self.batch_size,self.A*self.B)
          return wr / gamma.view(-1,1)
        else:
          return self.dec_w_linear(h_dec) #eq(18)

    def sampleQ(self,h_enc):
        noise = Variable(torch.randn(self.batch_size, self.z_size))
        mu = self.mu_linear(h_enc)           #eq(1)
        log_sigma = self.sigma_linear(h_enc) #eq(2)
        sigma = torch.exp(log_sigma)

        return mu + sigma * noise , mu , log_sigma, sigma

# Datasets

In [1]:
#cluttered_MNIST (60X60)
from scipy.io import loadmat
import numpy as np
from matplotlib import pyplot as plt

data = loadmat('/Users/bellagodiva/Downloads/cluttered-mnist.mat')
x_train = data['x_tr']
x_train = np.transpose(x_train, (3,0,1,2))
x_train = x_train.reshape(-1,60,60)
y_train = data['y_tr']
y_train = np.transpose(y_train)
y_train = y_train.reshape(-1,1)
y_train_onehot = []

x_val = data['x_vl']
x_val = np.transpose(x_val, (3,0,1,2))
x_val = x_val.reshape(-1,60,60)
y_val = data['y_vl']
y_val = np.transpose(y_val)
y_val = y_val.reshape(-1,1)

x_test = data['x_ts']
x_test = np.transpose(x_test, (3,0,1,2))
x_test = x_test.reshape(-1,60,60)
y_test = data['y_ts']
y_test = np.transpose(y_test)
y_test = y_test.reshape(-1,1)
'''
y_test_onehot = []
for i in range(y_test.shape[0]):
    basic = np.zeros(10)
    basic[y_test[i][0]] = 1
    y_test_onehot = np.append(y_test_onehot,basic)
y_test_onehot = y_test_onehot.reshape(y_test.shape[0],-1)
print(y_test_onehot.shape)


#plt.imshow(x_train[8], interpolation='nearest')
#plt.show()

np.save('CMNIST_ytrain',y_train_onehot)
np.save('CMNIST_yval',y_val_onehot)
np.save('CMNIST_ytest',y_test_onehot)
'''


"\ny_test_onehot = []\nfor i in range(y_test.shape[0]):\n    basic = np.zeros(10)\n    basic[y_test[i][0]] = 1\n    y_test_onehot = np.append(y_test_onehot,basic)\ny_test_onehot = y_test_onehot.reshape(y_test.shape[0],-1)\nprint(y_test_onehot.shape)\n\n\n#plt.imshow(x_train[8], interpolation='nearest')\n#plt.show()\n\nnp.save('CMNIST_ytrain',y_train_onehot)\nnp.save('CMNIST_yval',y_val_onehot)\nnp.save('CMNIST_ytest',y_test_onehot)\n"

In [2]:
import os
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image


#x_train = 255 * x_train
#x_train = x_train.astype(np.uint8)
#x_val = 255 * x_val
#x_val = x_val.astype(np.uint8)

#plt.imshow(x_train[8], interpolation='nearest')
#plt.show()



class MyDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = torch.LongTensor(targets)
        self.transform = transform
        
    def __getitem__(self, index):
        x = self.data[index]
        x = self.transform(x)
        y = self.targets[index]
        return x,y
    
    def __len__(self):
        return len(self.data)

train_dataset = MyDataset(x_train[:20000],y_train[:20000], transform=transforms.Compose([transforms.ToTensor()]))
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True )
val_dataset = MyDataset(x_val[:3000],y_val[:3000], transform=transforms.Compose([transforms.ToTensor()]))
val_loader = DataLoader(val_dataset, batch_size=100, shuffle=True )
test_dataset = MyDataset(x_test[:3000],y_test[:3000], transform=transforms.Compose([transforms.ToTensor()]))
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=True )

# Train

In [56]:
import torch.optim as optim


model = DrawModel(T,A,B,z_size,read_size,write_size,dec_size,enc_size,attention)
#optimizer = optim.Adam(model.parameters(), lr=learning_rate,betas=(beta1,0.999))
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

if USE_CUDA:
    model.to('cuda:0')
    
min_valid_loss = np.inf

for e in range(epoch_num):
    train_loss = 0.0
    for data, labels in train_loader:
        bs = data.size()[0]
        data = data.view(bs, -1)
        labels = torch.squeeze(labels)
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
        
        optimizer.zero_grad()
        target = model(data)
        loss = criterion(target,labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    valid_loss = 0.0
    for data, labels in val_loader:
        bs = data.size()[0]
        data = data.view(bs, -1)
        labels = torch.squeeze(labels)
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
        
        target = model(data)
        loss = criterion(target,labels)
        valid_loss = loss.item() * data.size(0)
    
    print(f'Epoch {e+1} \t\t Training Loss: {train_loss / len(train_loader)} \t\t Validation Loss: {valid_loss / len(val_loader)}')
    if min_valid_loss > valid_loss:
        #print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = valid_loss
        # Saving State Dict
        torch.save(model.state_dict(), '/Users/ivy2021/Documents/DRAW/class_model/saved_model.pth')




  ai = torch.ones((delta.size()[0],A)).to('cuda:0') * torch.range(0,A-1).to('cuda:0')
  bj = torch.ones((delta.size()[0],B)).to('cuda:0') * torch.range(0,B-1).to('cuda:0')


Epoch 1 		 Training Loss: 2.302756202220917 		 Validation Loss: 7.665491898854573
Epoch 2 		 Training Loss: 2.3014004385471343 		 Validation Loss: 7.6585618654886884
Epoch 3 		 Training Loss: 2.300270003080368 		 Validation Loss: 7.622946898142497
Epoch 4 		 Training Loss: 2.299338111877441 		 Validation Loss: 7.6664360364278155
Epoch 5 		 Training Loss: 2.298586632013321 		 Validation Loss: 7.65940507253011
Epoch 6 		 Training Loss: 2.2976794731616974 		 Validation Loss: 7.644294897715251
Epoch 7 		 Training Loss: 2.296746028661728 		 Validation Loss: 7.61693795522054
Epoch 8 		 Training Loss: 2.2959323990345 		 Validation Loss: 7.665892442067464
Epoch 9 		 Training Loss: 2.2948017501831055 		 Validation Loss: 7.643311818440755
Epoch 10 		 Training Loss: 2.293797458410263 		 Validation Loss: 7.62001911799113
Epoch 11 		 Training Loss: 2.2926428055763246 		 Validation Loss: 7.590138117472331
Epoch 12 		 Training Loss: 2.291331737041473 		 Validation Loss: 7.630549271901448
Epoch 13 		 

In [57]:
model_test = DrawModel(T,A,B,z_size,read_size,write_size,dec_size,enc_size,attention)
model_test.cuda()
model_test.load_state_dict(torch.load('/Users/ivy2021/Documents/DRAW/class_model/saved_model.pth'))
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data, labels in test_loader:
        bs = data.size()[0]
        data = data.view(bs, -1)
        labels = torch.squeeze(labels)
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
        # calculate outputs by running images through the network
        outputs = model_test(data)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        predicted.cuda()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

  ai = torch.ones((delta.size()[0],A)).to('cuda:0') * torch.range(0,A-1).to('cuda:0')
  bj = torch.ones((delta.size()[0],B)).to('cuda:0') * torch.range(0,B-1).to('cuda:0')


Accuracy of the network on the 10000 test images: 21 %


# Conv Neural Network

In [27]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(6 * 28 * 28, 265)
        self.fc2 = nn.Linear(265, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


model = Net()

#cluttered_MNIST (60X60)
from scipy.io import loadmat
import numpy as np
from matplotlib import pyplot as plt

data = loadmat('/Users/bellagodiva/Downloads/cluttered-mnist.mat')
x_train = data['x_tr']
x_train = np.transpose(x_train, (3,2,0,1))
x_train = x_train.reshape(-1,1,60,60)
y_train = data['y_tr']
y_train = np.transpose(y_train)
y_train = y_train.reshape(-1,1)


x_val = data['x_vl']
x_val = np.transpose(x_val, (3,2,0,1))
x_val = x_val.reshape(-1,1,60,60)
y_val = data['y_vl']
y_val = np.transpose(y_val)
y_val = y_val.reshape(-1,1)

x_test = data['x_ts']
x_test = np.transpose(x_test, (3,2,0,1))
x_test = x_test.reshape(-1,1,60,60)
y_test = data['y_ts']
y_test = np.transpose(y_test)
y_test = y_test.reshape(-1,1)

print(x_train.shape)

import os
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image


#x_train = 255 * x_train
#x_train = x_train.astype(np.uint8)
#x_val = 255 * x_val
#x_val = x_val.astype(np.uint8)

#plt.imshow(x_train[8], interpolation='nearest')
#plt.show()



class MyDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = torch.LongTensor(targets)
        self.transform = transform
        
    def __getitem__(self, index):
        x = self.data[index]
        x = self.transform(x)
        y = self.targets[index]
        return x,y
    
    def __len__(self):
        return len(self.data)

train_dataset = MyDataset(x_train,y_train, transform=transforms.Compose([transforms.ToTensor()]))
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True )
val_dataset = MyDataset(x_val,y_val, transform=transforms.Compose([transforms.ToTensor()]))
val_loader = DataLoader(val_dataset, batch_size=100, shuffle=True )
test_dataset = MyDataset(x_test,y_test, transform=transforms.Compose([transforms.ToTensor()]))
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=True )

(50000, 1, 60, 60)


In [28]:
import torch.optim as optim

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)



In [32]:
min_valid_loss = np.inf

for e in range(20):
    train_loss = 0.0
    for data, labels in train_loader:
        bs = data.size()[0]
        data = data.view(bs, 1,60,60)
        labels = torch.squeeze(labels)
        #labels = labels.view(bs,1)
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
        
        optimizer.zero_grad()
        target = model(data)
        loss = criterion(target,labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    valid_loss = 0.0
    for data, labels in val_loader:
        bs = data.size()[0]
        data = data.view(bs, 1,60,60)
        labels = torch.squeeze(labels)
        #labels = labels.view(bs,1)
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
        
        target = model(data)
        loss = criterion(target,labels)
        valid_loss = loss.item() * data.size(0)
    
    print(f'Epoch {e+1} \t\t Training Loss: {train_loss / len(train_loader)} \t\t Validation Loss: {valid_loss / len(val_loader)}')
    if min_valid_loss > valid_loss:
        #print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = valid_loss
        # Saving State Dict
        torch.save(model.state_dict(), '/Users/bellagodiva/Downloads/DRAW/class_model/saved_model1.pth')



Epoch 1 		 Training Loss: 2.263202792644501 		 Validation Loss: 2.261820077896118
Epoch 2 		 Training Loss: 2.1896858344078063 		 Validation Loss: 2.2197563648223877
Epoch 3 		 Training Loss: 2.1505018973350527 		 Validation Loss: 2.1956377029418945
Epoch 4 		 Training Loss: 2.1313251876831054 		 Validation Loss: 2.23964786529541
Epoch 5 		 Training Loss: 2.1091938989162444 		 Validation Loss: 2.053295850753784
Epoch 6 		 Training Loss: 2.078376898288727 		 Validation Loss: 2.007622480392456
Epoch 7 		 Training Loss: 2.0328746910095217 		 Validation Loss: 2.04443097114563
Epoch 8 		 Training Loss: 1.9677379574775695 		 Validation Loss: 1.9945449829101562
Epoch 9 		 Training Loss: 1.8766384677886963 		 Validation Loss: 1.840621829032898
Epoch 10 		 Training Loss: 1.7647279064655303 		 Validation Loss: 1.795608639717102
Epoch 11 		 Training Loss: 1.6493438987731934 		 Validation Loss: 1.7827885150909424
Epoch 12 		 Training Loss: 1.535998100757599 		 Validation Loss: 1.6181226968765259
E

In [33]:
model_test = Net()
model_test.load_state_dict(torch.load('/Users/bellagodiva/Downloads/DRAW/class_model/saved_model1.pth'))
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data, labels in test_loader:
        bs = data.size()[0]
        data = data.view(bs, 1,60,60)
        labels = torch.squeeze(labels)
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
        # calculate outputs by running images through the network
        outputs = model_test(data)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 53 %
