Fix detachment again
carlo- committed May 14, 2018
1 parent 3c669ab commit fa8668e
Showing 3 changed files with 8 additions and 9 deletions.
10 changes: 6 additions & 4 deletions libs/sepconv/SeparableConvolution.py
@@ -12,9 +12,10 @@ def __init__(self):
         super(SeparableConvolution, self).__init__()
     # end
 
-    def forward(self, input, vertical, horizontal):
+    @staticmethod
+    def forward(context, input, vertical, horizontal):
 
-        self.save_for_backward(input, vertical, horizontal)
+        context.save_for_backward(input, vertical, horizontal)
 
         intBatches = input.size(0)
         intInputDepth = input.size(1)
@@ -50,9 +51,10 @@ def forward(self, input, vertical, horizontal):
         return output
     # end
 
-    def backward(self, grad_output):
+    @staticmethod
+    def backward(context, grad_output):
 
-        _input, vertical, horizontal = self.saved_tensors
+        _input, vertical, horizontal = context.saved_tensors
 
         grad_input = _input.new().resize_(_input.size()).zero_()
         grad_vertical = vertical.new().resize_(vertical.size()).zero_()
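For context, this hunk migrates SeparableConvolution from the deprecated old-style torch.autograd.Function (a stateful instance calling self.save_for_backward) to the new-style idiom: forward and backward become static methods, and per-call state lives on the context object passed as the first argument. A minimal sketch of the pattern, using a hypothetical ScaleBy function that is not part of this repository:

import torch
from torch.autograd import Function

class ScaleBy(Function):
    # New-style autograd Function: no instance state, everything goes through ctx.

    @staticmethod
    def forward(ctx, input, factor):
        ctx.save_for_backward(input)  # tensors are saved on the per-call context
        ctx.factor = factor           # non-tensor state can be a plain ctx attribute
        return input * factor

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Return one gradient per forward argument; factor is not a tensor, so None.
        return grad_output * ctx.factor, None

x = torch.randn(4, requires_grad=True)
y = ScaleBy.apply(x, 2.0)  # new-style Functions are invoked via .apply, never instantiated
y.sum().backward()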
2 changes: 0 additions & 2 deletions src/main.py
@@ -75,8 +75,6 @@ def train(epoch):
         print('Gradients ready.')
         optimizer.step()
 
-        detach_all(model.parameters())
-
         print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(training_data_loader), loss.item()))
     print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(training_data_loader)))
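Why detach_all can be dropped is an inference on my part (the commit message says only "Fix detachment again"): the old-style Function held saved tensors as instance state, so the training loop manually detached parameters after each step to break stale graph references; with the context-based Function, saved tensors are owned by the autograd graph and freed with it. For reference, a manual detach pass needs only a few lines; this is a sketch, and the removed helper's actual implementation may have differed:

def detach_all(params):
    # In-place detach: drop each parameter's autograd history.
    for p in params:
        p.detach_()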
5 changes: 2 additions & 3 deletions src/model.py
@@ -5,14 +5,13 @@
 import torch
 import torch.nn as nn
 import torch.nn.init as init
+from torch.autograd import Variable, gradcheck
 from torch.nn.modules.loss import _Loss, _assert_no_grad
 from src.separable_convolution import SeparableConvolutionSlow
 from libs.sepconv.SeparableConvolution import SeparableConvolution
 import src.config as config
 import src.interpolate as interpol
 import numpy as np
 
-from torch.autograd import Variable, gradcheck
-
 class Net(nn.Module):
 
@@ -50,7 +49,7 @@ def __init__(self):
         self.pad = nn.ConstantPad2d(sep_kernel // 2, 0.0)
 
         if torch.cuda.is_available() and not config.ALWAYS_SLOW_SEP_CONV:
-            self.separable_conv = SeparableConvolution()
+            self.separable_conv = SeparableConvolution.apply
         else:
             self.separable_conv = SeparableConvolutionSlow()
 
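Two notes on the model.py hunks. SeparableConvolution.apply is the callable form of a new-style Function, so both branches of the if leave self.separable_conv as a plain callable and downstream call sites are unchanged; instantiating the class and calling the instance is the old-style pattern this commit moves away from. The relocated gradcheck import is PyTorch's numerical checker for custom backward implementations; a minimal sketch against the toy ScaleBy function above (double precision is required for the finite-difference comparison, and swapping in SeparableConvolution.apply with suitably shaped CUDA tensors would exercise the real op):

import torch
from torch.autograd import gradcheck

x = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
# gradcheck compares the analytic backward against finite differences.
assert gradcheck(lambda t: ScaleBy.apply(t, 2.0), (x,), eps=1e-6, atol=1e-4)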
