In [None]:
from matplotlib import pyplot as plt #For image plotting
import numpy as np #basic data types and methods
from skimage import draw   #To create shape arrays
import sys                 #Just in case
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms,utils
from torch.utils.data import Dataset, DataLoader
import time
import os
import copy

import warnings
warnings.filterwarnings('ignore')


In [None]:
# Device on which dataset tensors are stored and on which the models run (CPU here).
device=torch.device("cpu")


## Environment

In [None]:
def mkquad(h, v):
    """Return the four pen-down strokes of a closed quadrilateral.

    Every row is a relative move ``[dv, dh, pen]`` (vertical offset,
    horizontal offset, pen state). The strokes go ``h`` along the horizontal
    axis, ``v`` along the vertical axis, then back by ``-h`` and ``-v``, so the
    pen finishes where it started.

    :param h: horizontal side length
    :param v: vertical side length
    :return: integer array of shape (4, 3) with the pen column set to 1
    """
    # Builtin ``int`` replaces the ``np.int`` alias, which NumPy 1.24 removed.
    o = np.zeros(shape=(4, 3), dtype=int)
    o[:, 0] = [0, v, 0, -v]   # vertical offsets
    o[:, 1] = [h, 0, -h, 0]   # horizontal offsets
    o[:, 2] = 1               # pen is down for every side
    return o
    
def mkarc(h, v, r):
    """Return the three pen-down strokes of an open, three-sided arc.

    Every row is a relative move ``[dv, dh, pen]``. ``r`` selects one of four
    orientations; any other value leaves the offsets at zero (three
    zero-length pen-down strokes).

    :param h: extent used for the horizontal strokes (for ``r == 2`` it is
        applied to the vertical column instead, as in the original layout)
    :param v: extent used for the remaining strokes
    :param r: orientation selector, 1-4
    :return: integer array of shape (3, 3) with the pen column set to 1
    """
    # Builtin ``int`` replaces the ``np.int`` alias, which NumPy 1.24 removed.
    o = np.zeros(shape=(3, 3), dtype=int)
    o[:, 2] = 1

    # (vertical offsets, horizontal offsets) for each orientation
    columns = {
        1: ([0, v, 0], [h, 0, -h]),
        2: ([h, 0, -h], [0, -v, 0]),
        3: ([0, v, 0], [-h, 0, h]),
        4: ([-v, 0, v], [0, h, 0]),
    }.get(r)
    if columns is not None:
        o[:, 0], o[:, 1] = columns

    return o

def mkang(h,v,r):
    """Return the two pen-down strokes of a right angle.

    Every row is a relative move ``[dv, dh, pen]``. ``r`` selects the order
    and direction of the vertical and horizontal strokes; any value other
    than 1-4 leaves both offsets at zero.

    :param h: length of the horizontal stroke
    :param v: length of the vertical stroke
    :param r: orientation selector, 1-4
    :return: integer array of shape (2, 3) with the pen column set to 1
    """
    # Builtin ``int`` replaces the ``np.int`` alias, which NumPy 1.24 removed.
    o = np.zeros(shape=(2, 3), dtype=int)
    o[:, 2] = 1

    # (vertical offsets, horizontal offsets) for each orientation
    columns = {
        1: ([v, 0], [0, h]),    # vertical first, then horizontal
        2: ([0, v], [h, 0]),    # horizontal first, then vertical
        3: ([-v, 0], [0, h]),   # as 1, vertical stroke reversed
        4: ([0, -v], [h, 0]),   # as 2, vertical stroke reversed
    }.get(r)
    if columns is not None:
        o[:, 0], o[:, 1] = columns

    return o

def mklin(h,v):
    """Return a single pen-down stroke as a (1, 3) integer array.

    :param h: horizontal offset of the stroke
    :param v: vertical offset of the stroke
    :return: ``[[v, h, 1]]`` (vertical offset first, pen state last)
    """
    # Builtin ``int`` replaces the ``np.int`` alias, which NumPy 1.24 removed.
    o = np.zeros(shape=(1, 3), dtype=int)
    o[0, :] = [v, h, 1]
    return o
        

In [None]:
def get_shape_from_coord(xy, picsize = 128, v=60, h=60, dt=0):
    """Render a sequence of relative pen strokes into an image.

    :param xy: array of shape (n_strokes, 3); each row is
        ``[dv, dh, pen]`` - egocentric vertical/horizontal offsets and a
        (possibly real-valued) pen state
    :param picsize: side length of the square image plane
    :param v: vertical starting location of the pen
    :param h: horizontal starting location of the pen
    :param dt: drawing threshold - a stroke is drawn when its pen state is
        strictly above this value
    :return: float array of shape (picsize, picsize, 3) with values in
        [0, 1]; only channel 0 is written

    The input array is never modified. (The previous implementation
    thresholded the pen column through a view and so overwrote the caller's
    data.)
    """
    # Work on a private float copy so the caller's array stays untouched.
    strokes = np.array(xy, dtype=float)
    # Decide pen state from the raw values; comparing the thresholded +/-1
    # values against dt again would misfire for dt >= 1 or dt < -1.
    pen_down = strokes[:, 2] > dt
    steps = strokes[:, :2].astype('int')  # integer pixel offsets (truncated)

    o = np.zeros(shape=(picsize, picsize, 3))  # Image array
    for (dv, dh), drawing in zip(steps, pen_down):
        if drawing:  # only draw if pen is down
            rr, cc, val = draw.line_aa(v, h, v + dv, h + dh)
            # Drop pixels that fall off the canvas: negative indices would
            # silently wrap to the opposite edge, large ones would raise.
            inside = (rr >= 0) & (rr < picsize) & (cc >= 0) & (cc < picsize)
            o[rr[inside], cc[inside], 0] = val[inside] * 255
        # The pen moves whether or not it drew.
        v = v + dv
        h = h + dh

    # Clip to 254 and rescale so the brightest pixel is exactly 1.0
    return np.minimum(o, 254) / 254

In [None]:
def get_absolute(relxy):
    """Convert relative strokes into running (absolute) totals.

    Row ``i`` of the result is the sum of rows ``0..i`` of ``relxy``, taken
    independently for every column (the pen column is accumulated as well).

    :param relxy: 2-D array of relative moves, one row per stroke
    :return: float64 array of the same shape; an input with zero rows yields
        an empty result instead of raising IndexError
    """
    # np.cumsum replaces the explicit Python loop; dtype keeps the float64
    # output of the original np.zeros-based implementation.
    return np.cumsum(relxy, axis=0, dtype=np.float64)

In [None]:
def rplace(relxy, wd=128, hg=128):
    """Pick a random pen start position that keeps the whole shape on canvas.

    The absolute path of the shape (including the origin) is computed with
    ``get_absolute``; the start position is then sampled uniformly from all
    integer offsets for which every visited cell lies inside the canvas.

    :param relxy: array of relative strokes ``[dv, dh, pen]``
    :param wd: canvas width (horizontal extent)
    :param hg: canvas height (vertical extent)
    :return: ``(v, h)`` vertical and horizontal start positions
    :raises ValueError: if the shape is larger than the canvas
    """
    # Prepend the origin so the starting cell itself is part of the extent.
    absxy = np.append([[0,0,0]], get_absolute(relxy), axis=0)
    vmin, vmax = np.min(absxy[:,0]), np.max(absxy[:,0])  # vertical extent
    hmin, hmax = np.min(absxy[:,1]), np.max(absxy[:,1])  # horizontal extent

    # Admissible start offsets. Builtin ``int`` replaces ``np.int``, which
    # NumPy 1.24 removed.
    vr = np.arange(-1 * vmin, hg - vmax, dtype=int)
    hr = np.arange(-1 * hmin, wd - hmax, dtype=int)
    if vr.size == 0 or hr.size == 0:
        raise ValueError("shape does not fit on a %d x %d canvas" % (hg, wd))

    # Draw v before h to keep the random stream identical to the original.
    v = np.random.choice(vr,1)[0]
    h = np.random.choice(hr,1)[0]

    return v, h

In [None]:
def flipim(relxy):
    """Return a horizontally mirrored copy of a stroke sequence.

    The horizontal offsets (column 1) are negated; vertical offsets and pen
    states are unchanged.

    :param relxy: array of relative strokes ``[dv, dh, pen]``
    :return: new array of the same shape and dtype; the argument is left
        untouched (the previous version aliased and mutated it)
    """
    o = relxy.copy()
    o[:, 1] = -1 * o[:, 1]

    return o

In [None]:
def mktable(h= -1, v= -1, li= -1):
    """Build the stroke sequence for a table (wider than it is tall).

    Any argument left at ``-1`` is sampled with ``np.random``.

    :param h: surface width; sampled from 20-49 when not given
    :param v: leg height; sampled from 1 to ``h - 1`` when not given
    :param li: inset of each leg from the surface edge; sampled from
        ``np.arange(h/3)`` when not given
    :return: integer array of shape (5, 3), rows ``[dv, dh, pen]``
    """
    # ``[0]`` turns the 1-element result of np.random.choice into a scalar,
    # as the other shape builders already do; builtin ``int`` replaces the
    # ``np.int`` alias that NumPy 1.24 removed.
    if(h == -1): #Sample width if not specified from 20-49
        h = np.random.choice(np.arange(30, dtype=int) + 20, 1)[0]

    if(v == -1): #Sample height if not specified from 1 to width - 1
        v = np.random.choice(np.arange(h - 1, dtype=int) + 1, 1)[0]

    if(li== -1): #Sample leg inset if not specified
        li = np.random.choice(np.arange(h/3, dtype=int))

    return np.array([
        [0, h, 1],                #Surface
        [0, -h + li, -1],         #Move left to the first leg
        [v, 0, 1],                #First leg
        [-v, h - (2*li), -1],     #Move to second leg
        [v, 0, 1],                #Second leg
    ], dtype=int)


def mkstool(h= -1, v= -1, li= -1):
    """Build the stroke sequence for a stool (taller than it is wide).

    Same five-stroke layout as a table, but the height is sampled first and
    the width is constrained to be smaller than it. Any argument left at
    ``-1`` is sampled with ``np.random``.

    :param h: seat width; sampled from 1 to ``v - 1`` when not given
    :param v: leg height; sampled from 20-49 when not given
    :param li: inset of each leg from the seat edge; sampled from
        ``np.arange(h/3)`` when not given
    :return: integer array of shape (5, 3), rows ``[dv, dh, pen]``
    """
    # ``[0]`` turns the 1-element result of np.random.choice into a scalar,
    # as the other shape builders already do; builtin ``int`` replaces the
    # ``np.int`` alias that NumPy 1.24 removed.
    if(v == -1): #Sample height if not specified from 20-49
        v = np.random.choice(np.arange(30, dtype=int) + 20, 1)[0]

    if(h == -1): #Sample width if not specified from 1 to height - 1
        h = np.random.choice(np.arange(v - 1, dtype=int) + 1, 1)[0]

    if(li== -1): #Sample leg inset if not specified
        li = np.random.choice(np.arange(h/3, dtype=int))

    return np.array([
        [0, h, 1],                #Surface
        [0, -h + li, -1],         #Move left to the first leg
        [v, 0, 1],                #First leg
        [-v, h - (2*li), -1],     #Move to second leg
        [v, 0, 1],                #Second leg
    ], dtype=int)

def mkchair(h= -1, v= -1, sh= -1):
    """Build the stroke sequence for a chair seen from the side.

    The chair is a full-height back, a pen-up move to seat height, a
    horizontal seat and a front leg. Any argument left at ``-1`` is sampled
    with ``np.random``.

    :param h: seat depth; sampled from 5 to ``v - 6`` when not given
    :param v: height of the back; sampled from 20-49 when not given
    :param sh: seat height (also the front-leg length); sampled from
        ``np.arange(v/2) + round(v/10)`` when not given
    :return: integer array of shape (4, 3), rows ``[dv, dh, pen]``
    """
    # ``[0]`` turns the 1-element result of np.random.choice into a scalar,
    # as the other shape builders already do; builtin ``int`` replaces the
    # ``np.int`` alias that NumPy 1.24 removed.
    if(v == -1): #Sample height if not specified from 20-49
        v = np.random.choice(np.arange(30, dtype=int) + 20, 1)[0]

    if(h == -1): #Sample seat depth if not specified from 5 to height - 6
        h = np.random.choice(np.arange(v - 10, dtype=int) + 5, 1)[0]

    if(sh== -1): #Sample seat height if not specified
        # np.round yields a float, so the candidates are whole-valued floats.
        sh = int(np.random.choice(np.arange(v/2, dtype=int) + np.round(v/10)))

    return np.array([
        [v, 0, 1],      #Back
        [-sh, 0, -1],   #Move up to seat height
        [0, h, 1],      #Seat
        [sh, 0, 1],     #Second leg
    ], dtype=int)


def mkmug(h= -1, v= -1, hsz= -1):
    """Build the stroke sequence for a mug: a body with a handle on the side.

    Any argument left at ``-1`` is sampled with ``np.random``.

    :param h: body width; sampled from 5 to ``v - 6`` when not given
    :param v: body height; sampled from 20-49 when not given
    :param hsz: handle size (vertical extent of the handle arc); sampled from
        ``np.arange(v/3) + int(v/3)`` when not given
    :return: integer array of shape (8, 3), rows ``[dv, dh, pen]``
    """
    # Builtin ``int`` replaces ``np.int`` throughout (removed in NumPy 1.24).
    if(v == -1): #Sample height if not specified from 20-49
        v = np.random.choice(np.arange(30, dtype=int) + 20, 1)[0]

    if(h == -1): #Sample width if not specified from 5 to height - 6
        h = np.random.choice(np.arange(v - 10, dtype=int) + 5, 1)[0]

    if(hsz== -1): #Sample handle size if not specified
        hsz = np.random.choice(np.arange(v/3, dtype=int) + int(v/3), 1)[0]

    handloc = int((v - hsz)/2)  # vertical offset that centres the handle
    return np.vstack([
        mkquad(v=v, h=h),                       #Body
        [[handloc, h, -1]],                     #Move to handle top
        mkarc(h = int(h/2), v = hsz, r=1),      #Handle
    ])

def mkcase(h= -1, v= -1, hsz= -1):
    """Build the stroke sequence for a case: a wide body with a top handle.

    Any argument left at ``-1`` is sampled with ``np.random``.

    :param h: body width; sampled from 20-49 when not given
    :param v: body height; sampled from 5 to ``h - 6`` when not given
    :param hsz: handle size (horizontal extent of the handle arc); sampled
        from ``np.arange(v/3) + int(v/3)`` when not given
    :return: integer array of shape (8, 3), rows ``[dv, dh, pen]``
    """
    # Builtin ``int`` replaces ``np.int`` throughout (removed in NumPy 1.24).
    if(h == -1): #Sample width if not specified from 20-49
        h = np.random.choice(np.arange(30, dtype=int) + 20, 1)[0]

    if(v == -1): #Sample height if not specified from 5 to width - 6
        v = np.random.choice(np.arange(h - 10, dtype=int) + 5, 1)[0]

    if(hsz== -1): #Sample handle size if not specified
        hsz = np.random.choice(np.arange(v/3, dtype=int) + int(v/3), 1)[0]

    handloc = int((h - hsz)/2)  # horizontal offset that centres the handle
    return np.vstack([
        mkquad(v=v, h=h),                       #Body
        [[0, handloc, -1]],                     #Move to handle start
        mkarc(h = hsz, v = int(v/3), r=4),      #Handle
    ])

def mkbird(hd = -1, bd = -1, nc = -1, bk = -1, lg = -1):
    """Build the stroke sequence for a bird.

    The bird is a square head with a beak, a neck, a square body and two
    legs. Any argument left at ``-1`` is sampled with ``np.random``.

    :param hd: head size; sampled from 5-14 when not given
    :param bd: body size; sampled from ``hd + 5`` to ``2*hd + 4``
    :param nc: neck length; sampled from 0 to ``2*hd - 1``
    :param bk: beak length; sampled from 3 to ``2*hd + 2``
    :param lg: leg length; sampled from 3 to ``2*hd + 2``
    :return: integer array of shape (16, 3), rows ``[dv, dh, pen]``
    """
    # Builtin ``int`` replaces ``np.int`` throughout (removed in NumPy 1.24).
    if hd == -1: #Sample head size
        hd = np.random.choice(np.arange(10, dtype=int) + 5, 1)[0]

    if bd == -1: #Sample body size
        bd = np.random.choice(np.arange(hd, dtype=int) + hd + 5, 1)[0]

    if nc == -1: #Sample neck length
        nc = np.random.choice(np.arange(2 * hd, dtype=int), 1)[0]

    if bk == -1: #Sample beak length
        bk = np.random.choice(np.arange(2 * hd, dtype=int) + 3, 1)[0]

    if lg == -1: #Sample leg length
        lg = np.random.choice(np.arange(2 * hd, dtype=int) + 3, 1)[0]

    bp = int(hd * .8)      #Beak position (vertical offset down the head)
    lp = int(bd/2) - 4     #Horizontal offset of the first leg

    return np.vstack([
        mkquad(v=hd, h=hd),           #Draw head
        [[bp, 0, -1]],                #Move to beak
        [[0, -1 * bk, 1]],            #Draw beak
        [[hd - bp, bk + hd, -1]],     #Move to neck
        mklin(v=nc, h=0),             #Draw neck
        mkquad(bd, bd),               #Draw body
        [[bd, lp, -1]],               #Move to leg 1
        mklin(v=lg, h=0),             #Draw leg 1
        [[-lg, 8, -1]],               #Move to leg 2 (8 cells to the right)
        mklin(v=lg, h=0),             #Draw leg 2
    ])
        
        
def mksheep(hd = -1, bd = -1, nc = -1, lg = -1):
    """Build the stroke sequence for a sheep.

    The sheep is a rectangular head, a neck, a rectangular body (1.3 times
    as wide as tall) and two legs at the body's lower corners. Any argument
    left at ``-1`` is sampled with ``np.random``.

    :param hd: head height; sampled from 5-14 when not given
    :param bd: body height; sampled from ``hd + 5`` to ``2*hd + 4``
    :param nc: neck length; sampled from ``np.arange(hd/2) + int(hd/2)``
    :param lg: leg length; sampled from ``np.arange(hd) + int(hd/2)``
    :return: integer array of shape (14, 3), rows ``[dv, dh, pen]``
    """
    # Builtin ``int`` replaces ``np.int`` throughout (removed in NumPy 1.24).
    if hd == -1: #Sample head size
        hd = np.random.choice(np.arange(10, dtype=int) + 5, 1)[0]

    if bd == -1: #Sample body size
        bd = np.random.choice(np.arange(hd, dtype=int) + hd + 5, 1)[0]

    if nc == -1: #Sample neck length
        nc = np.random.choice(np.arange(hd/2, dtype=int) + int(hd/2), 1)[0]

    if lg == -1: #Sample leg length
        lg = np.random.choice(np.arange(hd, dtype=int) + int(hd/2), 1)[0]

    head_w = int(hd * 1.3)
    body_w = int(bd * 1.3)

    return np.vstack([
        mkquad(v=hd, h=head_w),       #Draw head
        [[hd, head_w, -1]],           #Move to neck
        mklin(v=nc, h=0),             #Draw neck
        mkquad(v=bd, h=body_w),       #Draw body
        [[bd, 0, -1]],                #Move to leg 1
        mklin(v=lg, h=0),             #Draw leg 1
        [[-lg, body_w, -1]],          #Move to leg 2
        mklin(v=lg, h=0),             #Draw leg 2
    ])
        

def mkdog(hd = -1, bd = -1, nc = -1, bk = -1, lg = -1):
    """Build the stroke sequence for a dog.

    The dog is a rectangular head, a neck, a single horizontal line for the
    body and one leg at each end of it. Any argument left at ``-1`` is
    sampled with ``np.random``.

    :param hd: head height; sampled from 5-14 when not given
    :param bd: body length; sampled from ``2*hd`` to ``3*hd - 1``
    :param nc: neck length; sampled from ``np.arange(hd/2) + int(hd/2)``
    :param bk: accepted for signature parity with ``mkbird``; not used
    :param lg: leg length; sampled from ``np.arange(hd) + int(hd/2)``
    :return: integer array of shape (11, 3), rows ``[dv, dh, pen]``
    """
    # Builtin ``int`` replaces ``np.int`` throughout (removed in NumPy 1.24).
    if hd == -1: #Sample head size
        hd = np.random.choice(np.arange(10, dtype=int) + 5, 1)[0]

    if bd == -1: #Sample body size
        bd = np.random.choice(np.arange(hd, dtype=int) + 2 * hd, 1)[0]

    if nc == -1: #Sample neck length
        nc = np.random.choice(np.arange(hd/2, dtype=int) + int(hd/2), 1)[0]

    if lg == -1: #Sample leg length
        lg = np.random.choice(np.arange(hd, dtype=int) + int(hd/2), 1)[0]

    head_w = int(hd * 1.3)

    return np.vstack([
        mkquad(v=hd, h=head_w),       #Draw head
        [[hd, head_w, -1]],           #Move to neck
        mklin(v=nc, h=0),             #Draw neck
        mklin(v=0, h=bd),             #Draw body
        [[0, -bd, -1]],               #Move back to leg 1
        mklin(v=lg, h=0),             #Draw leg 1
        [[-lg, bd, -1]],              #Move to leg 2
        mklin(v=lg, h=0),             #Draw leg 2
    ])
        
def mkliz(h= -1, v= -1, bd= -1):
    """Build the stroke sequence for a lizard.

    The lizard is a square head of side ``h``, a long horizontal body and
    two short legs hanging from it. Any argument left at ``-1`` is sampled
    with ``np.random``.

    :param h: head size; sampled from 5-14 when not given
    :param v: leg length; sampled from 3 to ``h + 2`` when not given
    :param bd: body length; sampled from ``2*h`` to ``5*h - 1``
    :return: integer array of shape (10, 3), rows ``[dv, dh, pen]``
    """
    # Builtin ``int`` replaces ``np.int`` throughout (removed in NumPy 1.24).
    if(h == -1): #Sample head size if not specified from 5-14
        h = np.random.choice(np.arange(10, dtype=int) + 5, 1)[0]

    if(v == -1): #Sample leg length if not specified from 3 to h + 2
        v = np.random.choice(np.arange(h, dtype=int) + 3, 1)[0]

    if(bd== -1): #Sample body length if not specified
        bd = np.random.choice(np.arange(h*3, dtype=int) + h*2, 1)[0]

    return np.vstack([
        mkquad(h,h),                          #Draw head
        [[int(h/2), h, -1]],                  #Move to body
        mklin(h=bd, v=0),                     #Draw body
        [[0, int(-1 * bd * 7/15), -1]],       #Move to back leg
        mklin(h=0, v=v),                      #draw leg
        [[-v, int(-1 * bd * 5/15), -1]],      #Move to front leg
        mklin(h=0, v=v),                      #draw leg
    ])
        

def mkpig(hd = -1, bd = -1, lg = -1):
    """Build the stroke sequence for a pig.

    The pig is a square head with a square snout, a rectangular body (1.3
    times as wide as tall) and two legs at the body's lower corners. Any
    argument left at ``-1`` is sampled with ``np.random``.

    :param hd: head size; sampled from 10-19 when not given
    :param bd: body height; sampled from ``hd + 5`` to ``2*hd + 4``
    :param lg: leg length; sampled from ``np.arange(hd) + int(hd/2)``
    :return: integer array of shape (18, 3), rows ``[dv, dh, pen]``
    """
    # Builtin ``int`` replaces ``np.int`` throughout (removed in NumPy 1.24).
    if hd == -1: #Sample head size
        hd = np.random.choice(np.arange(10, dtype=int) + 10, 1)[0]

    if bd == -1: #Sample body size
        bd = np.random.choice(np.arange(hd, dtype=int) + hd + 5, 1)[0]

    if lg == -1: #Sample leg length
        lg = np.random.choice(np.arange(hd, dtype=int) + int(hd/2), 1)[0]

    sn = int(hd/3)          #Snout size
    body_w = int(bd * 1.3)

    return np.vstack([
        mkquad(v=hd, h=hd),               #draw head
        [[2 * sn, -sn, -1]],              #Move to snout
        mkquad(v=sn, h=sn),               #Draw snout
        [[-bd + sn, hd + sn, -1]],        #Move to body
        mkquad(v=bd, h=body_w),           #Draw body
        [[bd, 0, -1]],                    #Move to leg 1
        mklin(v=lg, h=0),                 #Draw leg 1
        [[-lg, body_w, -1]],              #Move to leg 2
        mklin(v=lg, h=0),                 #Draw leg 2
    ])
    

In [None]:
def get_rand_img(choose = None):
    """Return the stroke sequence of one randomly parameterised object.

    :param choose: category index 0-9 (0 table, 1 chair, 2 stool, 3 mug,
        4 case, 5 dog, 6 sheep, 7 bird, 8 lizard, 9 pig). ``None`` picks a
        category uniformly at random.
    :return: integer array of relative strokes ``[dv, dh, pen]``
    :raises ValueError: if ``choose`` is not one of 0-9
    """
    makers = (mktable, mkchair, mkstool, mkmug, mkcase,
              mkdog, mksheep, mkbird, mkliz, mkpig)

    # Test against None explicitly: ``not choose`` is also true for 0, which
    # made it impossible to request a table.
    if choose is None:
        i = np.random.randint(0,10)
    else:
        i = choose

    if i not in range(len(makers)):
        raise ValueError("choose must be an integer in 0-9, got %r" % (choose,))

    return makers[int(i)]()

def get_batch(bsize=200, ns = 20, picsize=128, rloc=False, flip=False, not_rand=False, choose=0):
    """Generate a batch of sketch images with their stroke sequences.

    :param bsize: number of sketches to generate
    :param ns: number of stroke slots per sketch; unused slots stay
        ``[0, 0, -1]`` (no movement, pen up)
    :param picsize: side length of the square images
    :param rloc: place each shape at a random on-canvas location (via
        ``rplace``) instead of the fixed start (30, 30)
    :param flip: mirror each shape horizontally with probability 1/2
    :param not_rand: when True, always draw category ``choose`` instead of a
        random category
    :param choose: category index passed to ``get_rand_img`` when
        ``not_rand`` is True
    :return: ``(xtrn, trnxy, imloc)`` - images of shape
        (bsize, picsize, picsize, 3), integer strokes of shape (bsize, ns, 3)
        and integer start positions ``[v, h]`` of shape (bsize, 2)
    :raises ValueError: if a generated shape has more than ``ns`` strokes
    """
    # Builtin ``int`` replaces ``np.int`` throughout (removed in NumPy 1.24).
    xtrn = np.zeros(shape = (bsize,picsize,picsize,3)) #array for training image
    trnxy = np.zeros(shape=(bsize, ns, 3), dtype=int)  #array for sequence of training coordinates
    trnxy[:,:,2] = -1 #Default is pen up
    imloc = np.zeros(shape = (bsize, 2), dtype = int)

    for i in range(bsize):
        if (not_rand==True):
            relxy = get_rand_img(choose=choose)
        else:
            relxy = get_rand_img()

        # The coin is tossed even when flip is False so the random stream
        # does not depend on the flag.
        if(np.random.randint(0,2)==1 and flip):
            relxy = flipim(relxy)

        nstep = relxy.shape[0]
        if nstep > ns:
            raise ValueError("shape has %d strokes but ns=%d" % (nstep, ns))

        if rloc:
            v, h = rplace(relxy, wd=picsize, hg=picsize)
        else:
            v = 30
            h = 30

        # Forward picsize so the rendered image matches the xtrn slot; the
        # renderer's own default (128) only worked for picsize == 128.
        xtrn[i,:,:,:] = get_shape_from_coord(relxy, picsize=picsize, v=v, h=h)
        trnxy[i,0:nstep,:] = relxy
        imloc[i,:] = [v,h]

    return xtrn, trnxy, imloc

## Data Prep

In [None]:
class ToTensorCustom(object):
    """Convert the ndarrays of an ``(x, y)`` sample to float tensors on ``device``."""

    @staticmethod
    def _to_tensor(arr):
        """Turn one ndarray into a float32 tensor, channels first if 3-D."""
        arr = np.asarray(arr)
        if arr.ndim == 3:
            # swap color axis because
            # numpy image: H x W x C
            # torch image: C X H X W
            arr = arr.transpose((2, 0, 1))
        # .float() matches the conversion done in the sketchdata dataset.
        return torch.from_numpy(arr).float().to(device)

    def __call__(self, sample):
        """
        :param sample: pair ``(_x, _y)`` of ndarrays; 3-D arrays are treated
            as H x W x C images and transposed, arrays of any other rank
            (e.g. a stroke table) are converted unchanged
        :return: tuple ``(x_tensor, y_tensor)`` located on ``device``
        """
        _x, _y = sample
        # The previous version called .to(device) on the tuple itself, which
        # always raised AttributeError and never produced tensors.
        return (self._to_tensor(_x), self._to_tensor(_y))
    
class sketchdata(Dataset):
    """In-memory dataset pairing sketch images with their stroke tables.

    Every image is converted once, at construction time, from an H x W x C
    array to a C x H x W float tensor on ``device``; every coordinate array
    becomes a float tensor on ``device``.
    """

    def __init__(self, _x,_y,transform=None):
        """
        :param _x: iterable of H x W x C image arrays
        :param _y: iterable of coordinate ndarrays, one per image
        :param transform: optional callable applied to the image tensor each
            time an item is fetched
        """
        self.transform = transform
        self.images = [
            torch.from_numpy(np.asarray(picture).transpose((2, 0, 1))).float().to(device)
            for picture in _x
        ]
        self.coordinates = [
            torch.from_numpy(strokes).float().to(device) for strokes in _y
        ]

    def __len__(self):
        """Number of stored sketches."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, coordinates)`` for index ``idx``."""
        picture = self.images[idx]
        strokes = self.coordinates[idx]

        # The transform only ever touches the image, never the targets.
        if self.transform:
            picture = self.transform(picture)

        return picture, strokes

In [None]:
# Seed NumPy's global RNG so the sampled sketches are reproducible.
np.random.seed(20)
bsize = 1000  # sketches per split
ns=20  # time slices per sketch (same value as get_batch's default ns)
picsz=128  # image side length in pixels

# NOTE(review): each of these float64 arrays holds bsize*ns*picsz*picsz*3
# elements (about 7.9 GB with the values above) and neither is read again in
# the cells visible here - confirm they are still needed before keeping them.
xtrn = np.zeros(shape=(bsize, ns, picsz, picsz, 3))
xtst = np.zeros(shape = (bsize, ns, picsz, picsz, 3))

#Get a batch of training data
x1, ytrn, imloc_trn = get_batch(bsize)
for i in np.arange(0,ns):
    xtrn[:,i,:,:] = x1 #Copy input image to each timeslice


#Get a batch of testing data
x2, ytst, imloc_tst = get_batch(bsize)
for i in np.arange(0,ns):
    xtst[:,i,:,:] = x2 #Copy input image to each timeslice



#a,b = gen_sketch_data(num_sketches=20)
#ytrn= torch.tensor(ytrn, dtype=torch.uint8)
# The datasets are built from the single-image batches (x1 / x2), not from the
# time-replicated xtrn / xtst arrays above.
traindata =  sketchdata(x1,ytrn )



#c,d = gen_sketch_data(num_sketches=10)
#ytst= torch.tensor(ytst, dtype=torch.uint8)
testdata =  sketchdata(x2,ytst )


# Mini-batches of 50 sketches; both loaders reshuffle on every pass.
trainloader = torch.utils.data.DataLoader(traindata, batch_size=50,shuffle=True)
testloader = torch.utils.data.DataLoader(testdata, batch_size=50,shuffle=True)


## Models

### LSTM Cell architecture

<div>
<img src="LSTMCell.png" width="700"/>

</div>

### Attention with LSTM
<div>
<img src="Attention.png" width="700"/>

</div>

In [None]:
   
class EncoderSimple(nn.Module):
    """Small CNN that maps a 3-channel image to a 256-d feature vector.

    Three 3x3 convolutions (3 -> 8 -> 16 -> 32 channels), each followed by
    ReLU and 2x2 max pooling, feed two fully connected layers
    (6272 -> 512 -> 256). The 6272 input size of the first linear layer
    corresponds to 128 x 128 inputs (32 channels x 14 x 14 after pooling).
    """

    def __init__(self, embed_size = 2):
        """
        :param embed_size: accepted but not used by this encoder
        """
        super().__init__()

        # Shared 2x2 pooling applied after every convolution
        self.maxpool = nn.MaxPool2d(kernel_size=2)

        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.conv2 = nn.Conv2d(8, 16, 3)
        self.conv3 = nn.Conv2d(16, 32, 3)

        # Fully connected head
        self.fc_mid = nn.Linear(in_features=6272, out_features=512)
        self.fc_out = nn.Linear(in_features=512, out_features=256)

        # Registered but not used in forward()
        self.dropout = nn.Dropout(p=0.5)
        self.prelu = nn.PReLU()

        self.relu = nn.ReLU()

    def forward(self, images):
        """
        :param images: batch of images, channels first
        :return: tensor of shape (batch, 256); the last layer has no
            activation
        """
        feats = images
        for conv in (self.conv1, self.conv2, self.conv3):
            feats = self.maxpool(self.relu(conv(feats)))

        # Flatten everything except the batch dimension
        flat = feats.view(feats.size(0), -1)

        hidden = self.relu(self.fc_mid(flat))
        return self.fc_out(hidden)
    
    
    
class DecoderRNN(nn.Module):
    """Autoregressive LSTM-cell decoder from image features to pen strokes.

    The 256-d image feature vector is projected to a stroke-sized vector that
    serves as the first LSTM input; at every later step the previous
    prediction is fed back in. Each hidden state is mapped through two linear
    layers to one stroke ``[dv, dh, pen]``.
    """

    def __init__(self, embed_size=2, hidden_size=256, coord_size=3, num_layers=1):
        """
        :param embed_size: stored but not used by the network
        :param hidden_size: size of the LSTM hidden and cell state
        :param coord_size: number of values predicted per stroke
        :param num_layers: accepted but not used (a single LSTM cell is built)
        """
        super(DecoderRNN, self).__init__()

        # define the properties
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.coord_size = coord_size

        # lstm cell; its input is a stroke, so it must match coord_size
        # because predictions are fed back in
        self.lstm_cell = nn.LSTMCell(input_size=self.coord_size, hidden_size=self.hidden_size)

        # layer from the image features to a hidden-state-sized vector
        # (registered but not used in forward())
        self.fc_hs = nn.Linear(in_features=256, out_features=self.hidden_size)

        # two layers that turn the image features into the first LSTM input
        self.fc_rnn_inp1 = nn.Linear(in_features=256, out_features=128)
        self.fc_rnn_inp2 = nn.Linear(in_features=128, out_features=self.coord_size)

        # two layers that turn an LSTM hidden state into one predicted stroke
        self.fc_mid = nn.Linear(in_features=self.hidden_size, out_features=128, bias=True)
        self.fc_out = nn.Linear(in_features=128, out_features=self.coord_size)

        # activations (registered but not used in forward())
        self.softmax = nn.Softmax(dim=1)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self,features,coordinates=[],coord_length=20, inference=False,teacher_forcing_ratio = 0):
        """Decode ``coord_length`` strokes from the image features.

        :param features: tensor of shape (batch, 256) from the encoder
        :param coordinates: accepted for interface compatibility; teacher
            forcing is not implemented, so it is not read
        :param coord_length: number of strokes to generate
        :param inference: accepted for interface compatibility; training and
            inference run the same computation
        :param teacher_forcing_ratio: accepted but not used
        :return: tensor of shape (batch, coord_length, coord_size) on the same
            device as ``features``

        Fixes relative to the previous version: the inference branch passed
        the raw 256-d features to ``fc_rnn_inp2`` (a shape error) and applied
        a ReLU to the first prediction that the training branch did not; the
        output buffer was always created on the default device; the state
        tensors depended on a module-level ``device`` variable.
        """
        batch_size = features.size(0)  # number of images

        # Zero initial state, allocated next to the input features.
        hidden_state = features.new_zeros((batch_size, self.hidden_size))
        cell_state = features.new_zeros((batch_size, self.hidden_size))
        outputs = features.new_empty((batch_size, coord_length, self.coord_size))

        # First input: the image features squeezed down to stroke size.
        inp = self.fc_rnn_inp2(self.fc_rnn_inp1(features))

        for t in range(coord_length):
            hidden_state, cell_state = self.lstm_cell(inp, (hidden_state, cell_state))

            mid = self.fc_mid(hidden_state)  # first output layer
            out = self.fc_out(mid)           # predicted stroke for step t
            outputs[:, t, :] = out

            # Feed the prediction back in as the next input.
            inp = out

        return outputs
    
    
class DecoderSimple(nn.Module):
    """Non-autoregressive LSTM decoder from image features to pen strokes.

    The same 256-d feature vector is presented to a single-layer LSTM at
    every time step; each LSTM output is mapped through two linear layers to
    one stroke ``[dv, dh, pen]``.
    """

    def __init__(self, coord_length=20, hidden_size=256, coord_size=3, num_layers=1):
        """
        :param coord_length: stored on the module; forward() uses its own
            ``coord_length`` argument
        :param hidden_size: size of the LSTM hidden state
        :param coord_size: number of values predicted per stroke
        :param num_layers: accepted but not used (the LSTM has one layer)
        """
        super(DecoderSimple, self).__init__()

        # define the properties
        self.hidden_size = hidden_size
        self.coord_size = coord_size
        self.coord_length = coord_length

        # single-layer LSTM over the repeated 256-d feature vector
        self.rnn = nn.LSTM(256, self.hidden_size, 1, batch_first=True)

        # fc_hs and fc_rnn_inp are registered but not used in forward()
        self.fc_hs = nn.Linear(in_features=256, out_features=self.hidden_size)
        self.fc_rnn_inp = nn.Linear(in_features=256, out_features=128)
        # output head: hidden state -> 128 -> stroke
        self.fc_mid = nn.Linear(in_features=self.hidden_size, out_features=128, bias=True)
        self.fc_out = nn.Linear(in_features=128, out_features=self.coord_size)

        # activations (registered but not used in forward())
        self.softmax = nn.Softmax(dim=1)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self,features,coord_length=20,inference=False,coords=[]):
        """Decode ``coord_length`` strokes from the image features.

        :param features: tensor of shape (batch, feat_dim) from the encoder;
            feat_dim must match the LSTM input size (256)
        :param coord_length: number of strokes to generate
        :param inference: accepted for interface compatibility; not used
        :param coords: accepted for interface compatibility; not used
        :return: tensor of shape (batch, coord_length, coord_size)
        """
        batch_size = features.size(0)
        feat_dim = features.size(1)

        # Repeat the feature vector along a new time axis (no copy is made).
        seq_in = features.unsqueeze(1).expand(batch_size, coord_length, feat_dim)

        # Zero initial state, created on the device/dtype of the input rather
        # than on a module-level ``device`` variable.
        hidden_state = features.new_zeros((1, batch_size, self.hidden_size))
        cell_state = features.new_zeros((1, batch_size, self.hidden_size))

        lstm_out, _ = self.rnn(seq_in, (hidden_state, cell_state))
        mid = self.fc_mid(lstm_out)
        outputs = self.fc_out(mid)

        return outputs
    

In [None]:
class DecoderWithAttention(nn.Module):
    """
    LSTM-cell stroke decoder.

    Despite the name, no attention network is built: the attention-related
    code of the captioning model this was adapted from is not present. The
    encoder output initialises the LSTM state and, through a linear layer,
    provides the first input; later inputs are either the ground-truth
    strokes (teacher forcing) or the decoder's own predictions.
    """

    def __init__(self, attention_dim=10, embed_dim=3, decoder_dim=256, encoder_dim=256, dropout=0.5,
     coord_length=20, hidden_size=256, coord_size=3):
        """
        :param attention_dim: accepted but not used
        :param embed_dim: size of one stroke; LSTM input size and output size
        :param decoder_dim: size of the decoder LSTM state
        :param encoder_dim: feature size of the encoded images
        :param dropout: dropout probability (layer is registered but not
            applied in forward())
        :param coord_length: accepted but not used; see forward()
        :param hidden_size: accepted but not used
        :param coord_size: accepted but not used; see forward()
        """
        super(DecoderWithAttention, self).__init__()

        self.encoder_dim = encoder_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.relu = nn.ReLU()

        self.dropout = nn.Dropout(p=dropout)
        self.decode_step = nn.LSTMCell(embed_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # registered but not used in forward()
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, embed_dim)  # registered but not used in forward()

        # output head: hidden state -> 128 -> stroke
        self.fc_mid = nn.Linear(in_features=self.decoder_dim, out_features=128, bias=True)
        self.fc_out = nn.Linear(in_features=128, out_features=self.embed_dim)

        self.init_weights()  # initialize some layers with the uniform distribution
        self.rnn_in = nn.Linear(encoder_dim,embed_dim)  # encoder features -> first LSTM input

    def init_weights(self):
        """
        Re-initialise ``self.fc``: zero bias, weights uniform in [-0.1, 0.1].
        """
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, encoder_out):
        """
        Creates the initial hidden and cell states for the decoder's LSTM from the encoded images.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, encoder_dim)
        :return: hidden state, cell state, each (batch_size, decoder_dim)
        """
        h = self.init_h(encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(encoder_out)
        return h, c

    def forward(self, encoder_out, coord_length=20,coord_size=3,coords=[],inference=False):
        """
        Forward propagation.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, encoder_dim)
        :param coord_length: number of strokes to decode
        :param coord_size: width of the returned prediction tensor; must be
            compatible with embed_dim, the size of each predicted stroke
        :param coords: ground-truth strokes, a tensor of dimension
            (batch_size, >= coord_length - 1, embed_dim); required when
            ``inference`` is False and more than one stroke is decoded
        :param inference: False - teacher forcing, step t receives
            ``coords[:, t-1, :]``; True - step t receives the previous
            prediction
        :return: predictions, a tensor of dimension (batch_size, coord_length, coord_size)
        :raises ValueError: if teacher forcing is requested without a
            ``coords`` tensor (previously an opaque TypeError from indexing
            the default list)
        """
        batch_size = encoder_out.size(0)
        teacher_forcing = not inference

        if teacher_forcing and coord_length > 1 and not torch.is_tensor(coords):
            raise ValueError(
                "coords must be a tensor of ground-truth strokes when inference is False")

        # First LSTM input, derived from the image features
        step_input = self.rnn_in(encoder_out)

        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # Prediction buffer, created next to the encoder output instead of on
        # a module-level ``device`` variable.
        predictions = encoder_out.new_zeros(batch_size, coord_length, coord_size)

        preds = None
        for t in range(coord_length):
            if t > 0:
                # Previous ground-truth stroke or previous prediction
                step_input = coords[:, t - 1, :] if teacher_forcing else preds

            h, c = self.decode_step(step_input, (h, c))

            mid = self.fc_mid(h)
            preds = self.fc_out(mid)

            predictions[:, t, :] = preds

        return predictions

In [None]:
# Device the encoder and decoder are moved to below.
device=torch.device("cpu")

# Image encoder and stroke decoder; the commented line swaps in the
# alternative decoder.
enc = EncoderSimple().to(device)
dec = DecoderWithAttention().to(device)
#dec = DecoderSimple().to(device)



# One Adam optimizer over the parameters of both networks, with weight decay;
# the criterion is mean squared error.
params = list(enc.parameters()) + list(dec.parameters())
optimizer = optim.Adam(params,weight_decay=1e-5)
criterion =  nn.MSELoss()


In [None]:
# Sanity check: pull one training batch and display the shape of its first element.
first_batch = next(iter(trainloader))
im, cd = first_batch
im.shape

In [None]:
# Train the encoder/decoder pair and record per-batch losses for visualization.

losses = list()       # training loss of every batch, across all epochs
val_losses = list()   # validation loss of every batch, across all epochs

num_epochs = 25

for epoch in range(1, num_epochs + 1):

    # - - - Train - - -
    # The mode only needs to be switched once per pass, not once per batch.
    enc.train()
    dec.train()

    for i, data in enumerate(trainloader):

        # Zero the gradients left over from the previous step.
        enc.zero_grad()
        dec.zero_grad()

        # Obtain the batch and move it to the configured device.
        images, coordinates = data
        images = images.to(device)
        coordinates = coordinates.to(device)

        # Encode the images, then decode with the target coordinates supplied.
        # NOTE(review): inference=True is passed during training as well;
        # confirm this selects the intended branch of the decoder's forward().
        features = enc(images).to(device)
        outputs = dec(features, inference=True, coords=coordinates).to(device)

        # Batch loss against the target coordinates.
        loss = criterion(outputs, coordinates)
        losses.append(loss.item())

        # Backward pass, then parameter update.
        loss.backward()
        optimizer.step()

    # - - - Validate - - -
    enc.eval()
    dec.eval()

    with torch.no_grad():
        for i, data in enumerate(testloader):

            # Obtain the validation batch and move it to the configured device.
            val_images, val_coordinates = data
            val_images = val_images.to(device)
            val_coordinates = val_coordinates.to(device)

            # No target coordinates are passed to the decoder here.
            features = enc(val_images).to(device)
            outputs = dec(features, inference=True).to(device)

            # Batch loss against the true coordinates.
            val_loss = criterion(outputs, val_coordinates)
            val_losses.append(val_loss.item())

            # Progress line, rewritten in place. The step counter refers to the
            # validation loader being iterated; `loss` is the loss of the last
            # training batch of this epoch.
            stats = 'Epoch [%d/%d], Step [%d/%d], Loss: %.4f, Val Loss: %.4f' % (
                epoch, num_epochs, i, len(testloader), loss.item(), val_loss.item())
            print('\r' + stats, end="")

    # Save the full loss histories once per epoch rather than once per batch.
    np.save('losses', np.array(losses))
    np.save('val_losses', np.array(val_losses))


In [None]:
# Plot the per-batch validation and training losses recorded during training.
# The two histories hold one entry per batch of their respective loaders and
# can therefore differ in length, so each curve is drawn against its own x
# positions; a single x array sized from val_losses makes ax.plot raise as soon
# as the lengths disagree.
fig, ax = plt.subplots()
ax.plot(np.arange(len(val_losses)), val_losses, label='validation')
ax.plot(np.arange(len(losses)), losses, label='training')
ax.set_xlabel('batch')
ax.set_ylabel('loss')
ax.legend()
ax.set_title('Losses');


In [None]:
# Take one batch from the test loader and keep its first nine items:
# `tru` holds the first element of the batch cast to float, `truc` the second.
batch = next(iter(testloader))
tru, truc = batch[0][:9].float(), batch[1][:9]

In [None]:
gridsz = 3
gsz = 3  # plotting grid size

# Random selection of item indices in 0..99, one per grid cell.
s = np.random.choice(np.arange(100), gridsz * gridsz)

# Starting locations of the selected test items.
# NOTE(review): `s` is random, whereas `tru` is the first nine items of a test
# batch; confirm that these rows of imloc_tst belong to the images in `tru`.
loc = imloc_tst[s, :]

# Run the images through the encoder and decoder and keep the detached
# prediction as a NumPy array.
encoded = enc(tru)
decoded = dec(encoded, inference=True)
pred_sm = np.array(decoded.detach())



In [None]:
#Plot model image and image regenerated from model output sequence
plt.figure(figsize=(15,15))
for i in range(0,gsz):
    for j in range(0,gsz):
        plt.subplot2grid((gsz,gsz), (i,j))
        indx = i * gsz + j
        # .numpy() shares memory with the tensor and transpose() is only a view,
        # so copy before editing; otherwise the overlay below would be written
        # straight into `tru` and change the model input for later cells.
        tmp = tru[indx,:,:,:].numpy().transpose((1,2,0)).copy()
        fr = get_shape_from_coord(np.squeeze(pred_sm[indx,:,:]), v=loc[indx, 0], h=loc[indx, 1])
        # Put the regenerated drawing into channel 1 of the displayed image.
        tmp[:,:,1] = fr[:,:,0]

        #Uncomment next line to see just the model output drawing
        #tmp[:,:,0] = 0

        plt.imshow(tmp)
        plt.axis('off')
# Spacing applies to the whole figure, so set it once after all panels exist.
plt.subplots_adjust(wspace=.1, hspace=.1)
plt.show()

In [None]:
#Plot the image regenerated from the model output sequence, with channel 0 of the input blanked
plt.figure(figsize=(15,15))
for i in range(0,gsz):
    for j in range(0,gsz):
        plt.subplot2grid((gsz,gsz), (i,j))
        indx = i * gsz + j
        # .numpy() shares memory with the tensor and transpose() is only a view,
        # so copy before editing; otherwise the edits below would be written
        # straight into `tru` and change the model input for later cells.
        tmp = tru[indx,:,:,:].numpy().transpose((1,2,0)).copy()
        fr = get_shape_from_coord(np.squeeze(pred_sm[indx,:,:]), v=loc[indx, 0], h=loc[indx, 1])
        # Put the regenerated drawing into channel 1 of the displayed image.
        tmp[:,:,1] = fr[:,:,0]

        #Comment out the next line to keep channel 0 of the input image visible
        tmp[:,:,0] = 0

        plt.imshow(tmp)
        plt.axis('off')
# Spacing applies to the whole figure, so set it once after all panels exist.
plt.subplots_adjust(wspace=.1, hspace=.1)
plt.show()