Skip to content

Commit

Permalink
Fix upsampling issues (maybe) (#38, #32)
Browse files Browse the repository at this point in the history
  • Loading branch information
milesial committed Nov 19, 2018
1 parent 7dd7c8b commit 0f45521
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions unet/unet_parts.py
Expand Up @@ -61,10 +61,18 @@ def __init__(self, in_ch, out_ch, bilinear=True):

def forward(self, x1, x2):
x1 = self.up(x1)
diffX = x1.size()[2] - x2.size()[2]
diffY = x1.size()[3] - x2.size()[3]
x2 = F.pad(x2, (diffX // 2, int(diffX / 2),
diffY // 2, int(diffY / 2)))

# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]

x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
diffY // 2, diffY - diffY//2))

# for padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
Expand Down

0 comments on commit 0f45521

Please sign in to comment.