In [31]:
class ConvLSTMCell(nn.Module):
    
    def __init__(self, input_size, output_size, x_kernel_size, h_kernel_size, stride=1):        
        super(ConvLSTMCell, self).__init__()
        pad_x = math.floor(x_kernel_size/2)        
        pad_h = math.floor(h_kernel_size/2)
        self.output_size = output_size
        self.stride = stride
        
        # input gate
        self.conv_i_x = nn.Conv2d(input_size, output_size, x_kernel_size, stride=stride, padding=pad_x)
        self.batchnorm_i_x = nn.BatchNorm2d(output_size)
        self.conv_i_h = nn.Conv2d(output_size, output_size, h_kernel_size, stride=1, padding=pad_h)
        self.batchnorm_i_h = nn.BatchNorm2d(output_size)
        
        # forget gate
        self.conv_f_x = nn.Conv2d(input_size, output_size, x_kernel_size, stride=stride, padding=pad_x)
        self.batchnorm_f_x = nn.BatchNorm2d(output_size)
        self.conv_f_h = nn.Conv2d(output_size, output_size, h_kernel_size, stride=1, padding=pad_h)
        self.batchnorm_f_h = nn.BatchNorm2d(output_size)
        # initialize bias to 1 for x forget input
        self.conv_f_x.bias.data.fill_(1)
        
        # cell gate
        self.conv_c_x = nn.Conv2d(input_size, output_size, x_kernel_size, stride=stride, padding=pad_x)
        self.batchnorm_c_x = nn.BatchNorm2d(output_size)
        self.conv_c_h = nn.Conv2d(output_size, output_size, h_kernel_size, stride=1, padding=pad_h)
        self.batchnorm_c_h = nn.BatchNorm2d(output_size)

        # output gate
        self.conv_o_x = nn.Conv2d(input_size, output_size, x_kernel_size, stride=stride, padding=pad_x)
        self.batchnorm_o_x = nn.BatchNorm2d(output_size)
        self.conv_o_h = nn.Conv2d(output_size, output_size, h_kernel_size, stride=1, padding=pad_h)
        self.batchnorm_o_h = nn.BatchNorm2d(output_size)
        
        self.last_cell = None
        self.last_h = None
        
    
    def forward(self, x):
        if self.last_cell is None:
            self.last_cell = Variable(torch.zeros(
                (x.size(0), self.output_size, int(x.size(2)/self.stride), 
                 int(x.size(3)/self.stride))
            ))
        if self.last_h is None:
            self.last_h = Variable(torch.zeros(
                (x.size(0), self.output_size, int(x.size(2)/self.stride), 
                 int(x.size(3)/self.stride))
            ))
        h = self.last_h
        c = self.last_cell
        
        # input gate
        input_x = self.batchnorm_i_x(self.conv_i_x(x))
        input_h = self.batchnorm_i_h(self.conv_i_h(h))
        input_gate = F.sigmoid(input_x + input_h)
        
        # forget gate
        forget_x = self.batchnorm_f_x(self.conv_f_x(x))
        forget_h = self.batchnorm_f_h(self.conv_f_h(h))
        forget_gate = F.sigmoid(forget_x + forget_h)
        
        # forget gate
        cell_x = self.batchnorm_c_x(self.conv_c_x(x))
        cell_h = self.batchnorm_c_h(self.conv_c_h(h))
        cell_intermediate = F.tanh(cell_x + cell_h) # g
        cell_gate = (forget_gate * c) + (input_gate * cell_intermediate)
        
        # output gate
        output_x = self.batchnorm_o_x(self.conv_o_x(x))
        output_h = self.batchnorm_o_h(self.conv_o_h(h))
        output_gate = F.sigmoid(output_x + output_h)
        
        next_h = output_gate * F.tanh(cell_gate)
        self.last_cell = cell_gate
        self.last_h = next_h
        
        return next_h

In [12]:
class FeedbackConvLSTM(nn.Module):
    def __init__(self, input_size, output_sizes, strides, num_iterations, x_kernel_size, h_kernel_size):
        super(FeedbackConvLSTM, self).__init__()
        
        assert len(output_sizes) == len(strides)
        
        self.physical_depth = len(output_sizes)
        self.num_iterations = num_iterations
        self.input_size = input_size
        self.output_sizes = output_sizes
        self.strides = strides
        self.x_kernel_size = x_kernel_size
        self.h_kernel_size = h_kernel_size
        
        self.convlstm_cells = []
        for it in range(self.physical_depth):
            if it == 0:
                inp_size = input_size
            else:
                inp_size = output_sizes[it-1]
            outp_size = output_sizes[it]
            stride = strides[it]
            
            self.convlstm_cells.append(
                ConvLSTMCell(inp_size, outp_size, x_kernel_size, h_kernel_size, stride)
            )
        
    def forward(self, x):
        # torch.cat? torch.stack?
        end_xts = []
        for t in range(self.num_iterations):
            for d in range(self.physical_depth):
                if d == 0:
                    x_t = x # x_t^{d-1}
                x_t = self.convlstm_cells[d].forward(x_t)
            end_xts.append(x_t)
        #all_xts = torch.stack(end_xts, dim=0)
        #xts = torch.unbind(all_xts, dim=0)
        #return all_xts
        return end_xts

In [43]:
class FeedbackNet32(nn.Module):  # 4 physical depth, 8 iterations
    def __init__(self):
        super(FeedbackNet32, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.batchnorm = nn.BatchNorm2d(16)
        self.feedback_conv_lstm = FeedbackConvLSTM(
            16, [32, 32, 64, 64], [2, 1, 2, 1], 8, 3, 3
        )
        self.avg_pool = nn.AvgPool2d(8)
        self.linear = nn.Linear(64, 100)
        
        
    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x_all = self.feedback_conv_lstm(x)
        x_finished = []
        for x_i in x_all:
            x_i = F.relu(x_i)
            x_i = self.avg_pool(x_i)
            x_i = x_i.view(-1, 64)
            x_i = self.linear(x_i)
            x_finished.append(x_i)
        return x_finished

In [40]:
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, 
                                        transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, 
                                        transform=transform)
testloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [46]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)

In [45]:
feedback_net = FeedbackNet32()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(feedback_net.parameters())

for epoch in range(3):
    running_losses = np.zeros(8)
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)
        
        optimizer.zero_grad()
        outputs = feedback_net(inputs)
        
        losses = [criterion(out, labels) for out in outputs]
        loss = sum(losses)
        
        loss.backward(retain_graph=True)
        optimizer.step()
        running_losses += [l.data[0] for l in losses]
        running_loss += loss.data[0]
        if i % 10 == 0:
            print('Epoch %d, iteration %d: loss=%f' % (epoch, i, running_loss/1000))
            print('Running losses:')
            print(running_losses)
            running_loss = 0.0
            running_losses = np.zeros(8)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': feedback_net.state_dict(),
        'optimizer' : optimizer.state_dict(),
    })

print('done!')

Epoch 0, iteration 0: loss=0.036638
Running losses:
[ 4.5739212   4.57797337  4.57809973  4.57804489  4.58078194  4.58226204
  4.58308887  4.58362293]
Epoch 0, iteration 10: loss=0.365373
Running losses:
[ 45.66669703  45.66688061  45.66656113  45.66848183  45.67412376
  45.67894936  45.67781687  45.67306566]
Epoch 0, iteration 20: loss=0.361352
Running losses:
[ 45.17797899  45.17975473  45.17394161  45.16795635  45.16808844
  45.16553402  45.16149092  45.15750217]
Epoch 0, iteration 30: loss=0.351426
Running losses:
[ 43.92896414  43.93480921  43.93733072  43.9369936   43.9324584
  43.92726374  43.91997528  43.9082365 ]


Process Process-23:
Process Process-24:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/maxspero/virtualenvs/ml/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 34, in _worker_loop
    r = index_queue.get()
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._r

KeyboardInterrupt: 

In [None]:
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
