In [46]:
import numpy as np
import dataset
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as LS
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms

In [45]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable
from torch.nn.modules.utils import _pair


class ConvRNNCellBase(nn.Module):
    def __repr__(self):
        s = (
            '{name}({input_channels}, {hidden_channels}, kernel_size={kernel_size}'
            ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        s += ', hidden_kernel_size={hidden_kernel_size}'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)


class ConvLSTMCell(ConvRNNCellBase):
    def __init__(self,
                 input_channels,
                 hidden_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 hidden_kernel_size=1,
                 bias=True):
        super(ConvLSTMCell, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)

        self.hidden_kernel_size = _pair(hidden_kernel_size)

        hidden_padding = _pair(hidden_kernel_size // 2)

        gate_channels = 4 * self.hidden_channels
        self.conv_ih = nn.Conv2d(
            in_channels=self.input_channels,
            out_channels=gate_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=bias)

        self.conv_hh = nn.Conv2d(
            in_channels=self.hidden_channels,
            out_channels=gate_channels,
            kernel_size=hidden_kernel_size,
            stride=1,
            padding=hidden_padding,
            dilation=1,
            bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        self.conv_ih.reset_parameters()
        self.conv_hh.reset_parameters()

    def forward(self, input, hidden):
        hx, cx = hidden
        gates = self.conv_ih(input) + self.conv_hh(hx)
#         print("gates",gates.shape)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = F.sigmoid(ingate)
        forgetgate = F.sigmoid(forgetgate)
        cellgate = F.tanh(cellgate)
        outgate = F.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * F.tanh(cy)

        return hy, cy


class HyperConvLSTMCell(ConvRNNCellBase):
    def __init__(self,input_channels,main_num_units,hyper_unit,context_input_channels,hyper_embedding = 128):
        super(HyperConvLSTMCell, self).__init__()
        self.input_channels = input_channels
        self.num_units = main_num_units
        self.hyper_num_unit = hyper_unit
        self.hyper_embedding =hyper_embedding
        self.gate_params  = self.num_units * 4
        self.context_input_channels = context_input_channels
        self.hyper_input_units = context_input_channels+ main_num_units  

 
        # print(self.hyper_input_units,self.hyper_num_unit)

        self.hyper_cell = ConvLSTMCell(
            self.hyper_input_units,
            self.hyper_num_unit,
            kernel_size=3,
            stride=1,
            padding=1,
            hidden_kernel_size=1,
            bias=False)
        
        self.hyper_cell = self.hyper_cell.cuda()

#         self.temp_conv = nn.Conv2d(
#                     input_channels , 
#                     input_channels, 
#                     kernel_size=3, stride=2, padding=1, bias=False)
#         self.temp_conv = self.temp_conv.cuda()

        self.conv_z_input  = nn.Conv2d(self.hyper_num_unit, self.hyper_embedding, 
                        kernel_size=1, stride=1, padding=0, bias=False).cuda()

        self.conv_z_state  = nn.Conv2d(self.hyper_num_unit, self.hyper_embedding, 
                        kernel_size=1, stride=1, padding=0, bias=False).cuda()

        self.conv_alpha_input  = nn.Conv2d(self.hyper_embedding , self.gate_params, 
                    kernel_size=1, stride=1, padding=0, bias=False).cuda()

        self.conv_alpha_state  = nn.Conv2d(self.hyper_embedding , self.gate_params, 
                    kernel_size=1, stride=1, padding=0, bias=False).cuda()

        
        
        self.conv_transform_gates_input  = nn.Conv2d(self.input_channels , self.gate_params, 
                    kernel_size=3, stride=1, padding=1, bias=False).cuda()

        self.conv_transform_gates_states  = nn.Conv2d(self.num_units , self.gate_params, 
                    kernel_size=1, stride=1, padding=0, bias=False).cuda()


        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.reset_parameters()

    def reset_parameters(self):
        self.hyper_cell.reset_parameters()
        self.conv_z_input.reset_parameters()
        self.conv_z_state.reset_parameters()
        self.conv_alpha_input.reset_parameters()
        self.conv_alpha_state.reset_parameters()
        self.conv_transform_gates_input.reset_parameters()
        self.conv_transform_gates_states.reset_parameters()

    def hyper_norm_input(self,input_layer,hyper_h):
        zw = self.conv_z_input(hyper_h)
        alpha = self.conv_alpha_input(zw)

        result = input_layer * alpha
        
        return result

    
    def hyper_norm_state(self,input_layer,hyper_h):
        zw = self.conv_z_state(hyper_h)
        
        alpha = self.conv_alpha_state(zw)
        result = input_layer * alpha
        
        return result    
    # def reset_parameters(self):
    #     self.conv_ih.reset_parameters()
    #     self.conv_hh.reset_parameters()

    def forward(self, input,context, hidden):
        h,c = hidden
        main_h = h[:,:self.num_units]
        main_c = c[:,:self.num_units]
        hyper_h = h[:,self.num_units:]
        hyper_c = h[:,self.num_units:]
        # print("input shape ",input.shape)
        hyper_states = (hyper_h,hyper_c)
        # if self.encoder:
        #     input = self.temp_conv(input)
        hyper_input = torch.cat([context,main_h],dim=1)
        # print("hyper shape ",hyper_input.shape,hyper_states[0].shape)

        #print(hyper_input.shape,hyper_states[0].shape)
        hyper_h,hyper_c = self.hyper_cell(hyper_input,hyper_states)

        input_below_ = self.conv_transform_gates_input(input)
        input_below_ = self.hyper_norm_input(input_below_,hyper_h)

        state_below_ = self.conv_transform_gates_states(main_h)
        state_below_ = self.hyper_norm_state(state_below_,hyper_h)

        lstm_matrix = input_below_ + state_below_

        i,j,f,o = lstm_matrix.chunk(4,1)


        new_main_c = (self.sigmoid(f)*main_c) + (self.sigmoid(i)*self.tanh(j))
        new_main_h = self.tanh(new_main_c)* self.sigmoid(o)

        new_total_h =torch.cat([new_main_h,hyper_h],dim=1)

        new_total_c = torch.cat([new_main_c,hyper_c],dim=1)


        return (new_total_h,new_total_c),new_main_h

In [57]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules import Sign


class EncoderCell(nn.Module):
    def __init__(self):
        super(EncoderCell, self).__init__()

        self.conv = nn.Conv2d(
            3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.rnn1 = ConvLSTMCell(
            64,
            256,
            kernel_size=3,
            stride=2,
            padding=1,
            hidden_kernel_size=1,
            bias=False)
        self.rnn2 = ConvLSTMCell(
            256,
            512,
            kernel_size=3,
            stride=2,
            padding=1,
            hidden_kernel_size=1,
            bias=False)
        self.rnn3 = ConvLSTMCell(
            512,
            512,
            kernel_size=3,
            stride=2,
            padding=1,
            hidden_kernel_size=1,
            bias=False)

    def forward(self, input, hidden1, hidden2, hidden3):
        x = self.conv(input)
#         print(x.shape)
        hidden1 = self.rnn1(x, hidden1)
        x = hidden1[0]

        hidden2 = self.rnn2(x, hidden2)
        x = hidden2[0]

        hidden3 = self.rnn3(x, hidden3)
        x = hidden3[0]

        return x, hidden1, hidden2, hidden3


class Binarizer(nn.Module):
    def __init__(self):
        super(Binarizer, self).__init__()
        self.conv = nn.Conv2d(512, 32, kernel_size=1, bias=False)
        self.sign = Sign()

    def forward(self, input):
        feat = self.conv(input)
        x = F.tanh(feat)
        return self.sign(x)
    
# input_channels,main_num_units,hyper_unit,context_input_channels

class DecoderCell(nn.Module):
    def __init__(self):
        super(DecoderCell, self).__init__()

        self.conv1 = nn.Conv2d(
            32, 512, kernel_size=1, stride=1, padding=0, bias=False)
        self.rnn1 = ConvLSTMCell(
            512,
            512,
            kernel_size=3,
            stride=1,
            padding=1,
            hidden_kernel_size=1,
            bias=False)
        self.hyper1 = HyperConvLSTMCell(512,512,512,512)
        self.hyper2 = HyperConvLSTMCell(128,512,512,512)
        self.hyper3 = HyperConvLSTMCell(128,256,256,256)
        self.hyper4 = HyperConvLSTMCell(64,128,128,128)
        
#         self.rnn2 = ConvLSTMCell(
#             128,
#             512,
#             kernel_size=3,
#             stride=1,
#             padding=1,
#             hidden_kernel_size=1,
#             bias=False)
#         self.rnn3 = ConvLSTMCell(
#             128,
#             256,
#             kernel_size=3,
#             stride=1,
#             padding=1,
#             hidden_kernel_size=3,
#             bias=False)
#         self.rnn4 = ConvLSTMCell(
#             64,
#             128,
#             kernel_size=3,
#             stride=1,
#             padding=1,
#             hidden_kernel_size=3,
#             bias=False)
        self.conv2 = nn.Conv2d(
            32, 3, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, input,context, hidden1, hidden2, hidden3, hidden4):
        x = self.conv1(input)

        hidden1,x = self.hyper1(x,context[0],hidden1)

        x = F.pixel_shuffle(x, 2)

        hidden2,x = self.hyper2(x,context[1], hidden2)
#         x = hidden2[0]
        x = F.pixel_shuffle(x, 2)

        hidden3,x = self.hyper3(x,context[2], hidden3)
#         x = hidden3[0]
        x = F.pixel_shuffle(x, 2)

        hidden4,x = self.hyper4(x,context[3], hidden4)
#         x = hidden4[0]
        x = F.pixel_shuffle(x, 2)

        x = F.tanh(self.conv2(x)) / 2

        return x, hidden1, hidden2, hidden3, hidden4

In [58]:
encoder = EncoderCell().cuda()
binarizer = Binarizer().cuda()
decoder = DecoderCell().cuda()

In [59]:
solver = optim.Adam(
    [
        {
            'params': encoder.parameters()
        },
        {
            'params': binarizer.parameters()
        },
        {
            'params': decoder.parameters()
        },
    ],
    lr=0.01)

In [60]:
from PIL import Image
arr = np.random.randint(0,255,size=(32,32,3),dtype=np.uint8)
im = Image.fromarray(arr).convert('RGB')

In [36]:
# modified from https://github.com/desimone/vision/blob/fb74c76d09bcc2594159613d5bdadd7d4697bb11/torchvision/datasets/folder.py
import cv2
import os
import numpy as np
import torch
from torchvision import transforms
import torch.utils.data as data
from PIL import Image
import pickle
import random
IMG_EXTENSIONS = [
    '.jpg',
    '.JPG',
    '.jpeg',
    '.JPEG',
    '.png',
    '.PNG',
    '.ppm',
    '.PPM',
    '.bmp',
    '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def readImage(path):
    img = cv2.imread(path)
#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img/255.0
    return img

def default_loader(paths,root):
    paths = [os.path.join(root,i) for i in paths]
    main_imgs = readImage(paths[0])
    #change this
    motion_imgs = np.zeros((288, 352,4))
    ctx_imgs  = [readImage(path) for path in paths[1:]]
    ctx_imgs = np.concatenate(ctx_imgs, axis=2)
    main_imgs = np.concatenate([main_imgs,motion_imgs], axis=2)
    return main_imgs,ctx_imgs

def crop_cv2(img, patch):
    height, width, c = img.shape
    start_x = random.randint(0, height - patch)
    start_y = random.randint(0, width - patch)
    return img[start_x : start_x + patch, start_y : start_y + patch]

def np_to_torch(img):
    img = np.swapaxes(img, 0, 1) #w, h, 9
    img = np.swapaxes(img, 0, 2) #9, h, w
    return torch.from_numpy(img).float()

def swap(img):
    img = np.swapaxes(img, 0, 2) #w, h, 9
    img = np.swapaxes(img, 0, 1) #9, h, w
    return img


def get_bmv_filenames(mv_dir, main_fn):

    fn = main_fn.split('/')[-1][:-4]

    return (os.path.join(mv_dir, fn + '_before_flow_x_0001.jpg'),
            os.path.join(mv_dir, fn + '_before_flow_y_0001.jpg'),
            os.path.join(mv_dir, fn + '_after_flow_x_0001.jpg'),
            os.path.join(mv_dir, fn + '_after_flow_y_0001.jpg'))


def get_identity_grid(shape):
    width, height = shape
    grid = np.zeros((width, height, 2))
    for i in range(width):
        for j in range(height):
            grid[i, j, 0] = float(j) * (2.0 / (height - 1.0)) - 1.0
            grid[i, j, 1] = float(i) * (2.0 / (width - 1.0)) - 1.0
    return grid

def get_bmv(img, fns):
    before_x, before_y, after_x, after_y = fns
    # get all motions for the main frame
    # 4 bmvs for a main frame
    bmvs = [read_bmv(before_x),
            read_bmv(before_y),
            read_bmv(after_x),
            read_bmv(after_y)]

    if bmvs[0] is None or bmvs[1] is None:
        if 'ultra_video_group' in before_x:
            # We need HW to be (16n1, 16n2).
            bmvs[0] = np.zeros((1072, 1920, 1))
            bmvs[1] = np.zeros((1072, 1920, 1))
        else:
            bmvs[0] = np.zeros((288, 352, 1))
            bmvs[1] = np.zeros((288, 352, 1))
    else:
        bmvs[0] = bmvs[0] * (-2.0)
        bmvs[1] = bmvs[1] * (-2.0)        
 
    if bmvs[2] is None or bmvs[3] is None:
        if 'ultra_video_group' in before_x:
            bmvs[2] = np.zeros((1072, 1920, 1))
            bmvs[3] = np.zeros((1072, 1920, 1))
        else:
            bmvs[2] = np.zeros((288, 352, 1))
            bmvs[3] = np.zeros((288, 352, 1))
    else:
        bmvs[2] = bmvs[2] * (-2.0)
        bmvs[3] = bmvs[3] * (-2.0)        
        # bmv -256 to + 256
    return bmvs



class ImageFolder(data.Dataset):
    """ ImageFolder can be used to load images where there are no labels."""

    def __init__(self, root, transform=None, loader=default_loader):
        images = []
        pickledFile = os.path.join(root,"trainHyperTuple100.p")
        images = pickle.load(open(pickledFile,"rb"))
        self.root = root
        self.imgs = images
        self.loader = loader

        # for filename in os.listdir(root):
        #     if is_image_file(filename):
        #         images.append('{}'.format(filename))
#         self.root = root
#         self.imgs = images
#         self.transform = transform
#         self.loader = loader

    def __getitem__(self, index):
        filenames = self.imgs[index]
        main_imgs,ctx_imgs = self.loader(filenames,self.root)
#         croppedImgs = imgs
        croppedImgs = crop_cv2(main_imgs,32)
        croppedImgs = np_to_torch(croppedImgs)
        ctx_imgs = np_to_torch(ctx_imgs)
        return croppedImgs,ctx_imgs,filenames

    def __len__(self):
        return len(self.imgs)


In [37]:
import torch.utils.data as data

train_transform = transforms.Compose([
    transforms.RandomCrop((128, 128)),
    transforms.ToTensor(),
])
train_set = ImageFolder(root="/home_01/f20150198/datasets/ActivityNet/Crawler/Kinetics", transform=train_transform)
train_loader = data.DataLoader(dataset=train_set, batch_size=1, shuffle=True, num_workers=1)


In [50]:
unet = UNet(9,1).cuda()

In [87]:


for batch, (data,context,name) in enumerate(train_loader):
#     cv2.imwrite(name[0][0].replace("outTrainImgs/","main_"),swap(data.numpy()[0])*255)
#     cv2.imwrite(name[1][0].replace("outTrainImgs/","ctx_"),swap(context.numpy()[0,:3])*255)
#     cv2.imwrite(name[2][0].replace("outTrainImgs/","ctx_"),swap(context.numpy()[0,3:6])*255)
#     cv2.imwrite(name[3][0].replace("outTrainImgs/","ctx_"),swap(context.numpy()[0,6:9])*255)

    print("data",data.shape,"context",context.shape)
    unet_outputs = forward_ctx(unet,context)
    res, warped_unet_outputs = prepare_inputs(data,unet_outputs)
    print([i.shape for i in warped_unet_outputs])
    
    batch_t0 = time.time()

    ## init lstm state
    encoder_h_1 = (Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()),
                   Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()))
    encoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()),
                   Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()))
    encoder_h_3 = (Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()),
                   Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()))

    decoder_h_1 = (Variable(torch.zeros(data.size(0), 512 + 512, 2, 2).cuda()),
                    Variable(torch.zeros(data.size(0), 512 + 512, 2, 2).cuda()))
    decoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()),
                   Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()))
    decoder_h_3 = (Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()),
                   Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()))
    decoder_h_4 = (Variable(torch.zeros(data.size(0), 128, 16, 16).cuda()),
                   Variable(torch.zeros(data.size(0), 128, 16, 16).cuda()))

    patches = Variable(data.cuda())

    solver.zero_grad()

    losses = []

    res = patches - 0.5

    context = context.cuda()

    context = unet(context)

    bp_t0 = time.time()

    for _ in range(args.iterations):
        encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder(
            res, encoder_h_1, encoder_h_2, encoder_h_3)

        codes = binarizer(encoded)

        output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
            codes,context, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)

        res = res - output
        losses.append(res.abs().mean())

    bp_t1 = time.time()

    loss = sum(losses) / args.iterations
    loss.backward()

    solver.step()

    batch_t1 = time.time()

    print(
        '[TRAIN] Epoch[{}]({}/{}); Loss: {:.6f}; Backpropagation: {:.4f} sec; Batch: {:.4f} sec'.
        format(epoch, batch + 1,
               len(train_loader), loss.data[0], bp_t1 - bp_t0, batch_t1 -
               batch_t0))
    print(('{:.4f} ' * args.iterations +
           '\n').format(* [l.data[0] for l in losses]))

    index = (epoch - 1) * len(train_loader) + batch

    ## save checkpoint every 500 training steps
    if index % 2000 == 0:
        save(0, False)

data torch.Size([1, 7, 32, 32]) context torch.Size([1, 9, 288, 352])
prepare input crop shape torch.Size([1, 7, 32, 32]) unet shape torch.Size([1, 256, 36, 44])
2
before warping [torch.Size([1, 256, 36, 44]), torch.Size([1, 128, 72, 88]), torch.Size([1, 64, 144, 176])]
[torch.Size([1, 256, 4, 4]), torch.Size([1, 128, 8, 8]), torch.Size([1, 64, 16, 16])]




In [72]:

def transpose_to_grid(frame2):
    # b, c, h, w
    # b, h, c, w
    # b, h, w, c
    frame2 = frame2.transpose(1, 2)
    frame2 = frame2.transpose(2, 3)
    return frame2

In [70]:
down_sample = nn.AvgPool2d(2, stride=2)

In [86]:
def prepare_inputs(data,unet_outputs):
    data_arr = []
    warped_unet_outputs = []
    # enumerating through 2 crops of same image
    patches = Variable(data.cuda())
    print("prepare input crop shape",data.shape,"unet shape",unet_outputs[0].shape)
    res, flows = prepare_batch(patches)
#     data_arr.append(res)
#     frame1_arr.append(frame1)
#     frame2_arr.append(frame2)
    print(len(flows))

    print("before warping",[i.shape for i in unet_outputs])
    wuo = warp_unet_outputs(flows, unet_outputs)

#     warped_unet_output1.append(wuo1)
#     warped_unet_output2.append(wuo2)


# #     res = torch.cat(data_arr, dim=0)
# #     frame1 = torch.cat(frame1_arr, dim=0)
# #     frame2 = torch.cat(frame2_arr, dim=0)
#     warped_unet_output1 = [torch.cat(wuos, dim=0) for wuos in zip(*warped_unet_output1)]
#     warped_unet_output2 = [torch.cat(wuos, dim=0) for wuos in zip(*warped_unet_output2)]

    return res, wuo

In [65]:

def prepare_batch(batch):
    res = batch - 0.5

    flows = []
    frame1, frame2 = None, None

    #after x and y             
    flow_1 = res[:, 3:5]
    #before x and y
    flow_2 = res[:, 5:7]

    flows.append(get_flows(flow_1))
    flows.append(get_flows(flow_2))
    res = res[:, :3]

    return res, flows


In [61]:
def get_flows(flow):
    flow_4 = down_sample(flow)
    flow_3 = down_sample(flow_4)
    flow_2 = down_sample(flow_3)

    flow_4 = transpose_to_grid(flow_4)
    flow_3 = transpose_to_grid(flow_3)
    flow_2 = transpose_to_grid(flow_2)

    final_grid_4 = flow_4 + 0.5
    final_grid_3 = flow_3 + 0.5
    final_grid_2 = flow_2 + 0.5

    return [final_grid_4, final_grid_3, final_grid_2]

In [77]:
def warp_unet_outputs(flows, unet_outputs):
    [grid_1_4, grid_1_3, grid_1_2] = flows[0]
    [grid_2_4, grid_2_3, grid_2_2] = flows[1]

    warped_unet_outputs= []

    warped_unet_outputs.append(F.grid_sample(
        unet_outputs[0], grid_1_2, padding_mode='border'))

    warped_unet_outputs.append(F.grid_sample(
        unet_outputs[1], grid_1_3, padding_mode='border'))

    warped_unet_outputs.append(F.grid_sample(
        unet_outputs[2], grid_1_4, padding_mode='border'))

    return warped_unet_outputs

In [53]:
def forward_ctx(unet, ctx_frames):
    ctx_frames = Variable(ctx_frames.cuda()) - 0.5

    unet_outputs = unet(ctx_frames)

    return unet_outputs

In [54]:
def get_identity_grid(shape):
    width, height = shape
    grid = np.zeros((width, height, 2))
    for i in range(width):
        for j in range(height):
            grid[i, j, 0] = float(j) * (2.0 / (height - 1.0)) - 1.0
            grid[i, j, 1] = float(i) * (2.0 / (width - 1.0)) - 1.0
    return grid

In [58]:
a = get_identity_grid((256,352))
a

array([[[-1.        , -1.        ],
        [-0.99430199, -1.        ],
        [-0.98860399, -1.        ],
        ...,
        [ 0.98860399, -1.        ],
        [ 0.99430199, -1.        ],
        [ 1.        , -1.        ]],

       [[-1.        , -0.99215686],
        [-0.99430199, -0.99215686],
        [-0.98860399, -0.99215686],
        ...,
        [ 0.98860399, -0.99215686],
        [ 0.99430199, -0.99215686],
        [ 1.        , -0.99215686]],

       [[-1.        , -0.98431373],
        [-0.99430199, -0.98431373],
        [-0.98860399, -0.98431373],
        ...,
        [ 0.98860399, -0.98431373],
        [ 0.99430199, -0.98431373],
        [ 1.        , -0.98431373]],

       ...,

       [[-1.        ,  0.98431373],
        [-0.99430199,  0.98431373],
        [-0.98860399,  0.98431373],
        ...,
        [ 0.98860399,  0.98431373],
        [ 0.99430199,  0.98431373],
        [ 1.        ,  0.98431373]],

       [[-1.        ,  0.99215686],
        [-0.99430199,  0.99

In [39]:
#!/usr/bin/python

# sub-parts of the U-Net model

import torch
import torch.nn as nn
import torch.nn.functional as F


class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights
        if bilinear:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)
        else:
            self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffX = x1.size()[2] - x2.size()[2]
        diffY = x1.size()[3] - x2.size()[3]
        x2 = F.pad(x2, (diffX // 2, int(diffX / 2),
                        diffY // 2, int(diffY / 2)))
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

In [40]:
#!/usr/bin/python
# full assembly of the sub-parts to form the complete net

import torch
import torch.nn as nn
import torch.nn.functional as F

# python 3 confusing imports :(
from unet_parts import *

class UNet(nn.Module):
    def __init__(self, n_channels, shrink):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64 // shrink)
        self.down1 = down(64 // shrink, 128 // shrink)
        self.down2 = down(128 // shrink, 256 // shrink)
        self.down3 = down(256 // shrink, 512 // shrink)
        self.down4 = down(512 // shrink, 512 // shrink)
        self.up1 = up(1024 // shrink, 256 // shrink)
        self.up2 = up(512 // shrink, 128 // shrink)
        self.up3 = up(256 // shrink, 64 // shrink)
        self.up4 = up(128 // shrink, 64 // shrink)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        out1 = self.up1(x5, x4)
        out2 = self.up2(out1, x3)
        out3 = self.up3(out2, x2)
        return [out1, out2, out3]

In [47]:
import numpy as np
a = torch.randn(1, 3, 32, 32,dtype=torch.float)
unet = UNet(3,1)
b = unet(a)

torch.Size([1, 64, 32, 32])
x5 torch.Size([1, 512, 2, 2])
out1 torch.Size([1, 512, 4, 4])
out2 torch.Size([1, 256, 8, 8])
out3 torch.Size([1, 128, 16, 16])


In [20]:
print(b.shape)

torch.Size([1, 512, 2, 2])
