In [2]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import cv2
%matplotlib inline

In [3]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
# Show which device string was selected ('cuda' or 'cpu').
print(device)

cuda


In [5]:
# Load the test image and shrink it to the 127x127 input size used by the encoder.
im = plt.imread('cat.png')
im = cv2.resize(im, (127,127), interpolation=cv2.INTER_AREA)
# Still channels-last here: (height, width, channels).
print(im.shape)
# Conv2d expects channels-first (C, H, W).  A plain reshape(3, 127, 127) only
# reinterprets the flat buffer and scrambles pixels across channels, so the
# axes are permuted instead; contiguous() gives the result a dense layout.
im = torch.tensor(im).permute(2, 0, 1).contiguous()

(127, 127, 3)


## Testing the Encoding Layer

In [6]:
# First convolution: 3 input channels -> 96 feature maps with a 7x7 kernel.
# padding='same' keeps the spatial size, so only the pooling shrinks the maps.
conv2d1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=7, padding='same')
# 2x2 max pooling; this one module is applied after every convolution.
pool1 = nn.MaxPool2d(kernel_size=(2, 2))

# Remaining convolutions all use 3x3 kernels.
conv2d2 = nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, padding='same')
conv2d3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding='same')
conv2d4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same')
conv2d5 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same')
conv2d6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding='same')

# Fully connected head: 256 channels on a 1x1 map -> 1024 features.
fc_layer = nn.Linear(in_features=256 * 1 * 1, out_features=1024)

In [7]:
# Shape check for the encoder stack: push the image through every conv / pool
# stage (no activations here) and keep each intermediate tensor for inspection.
# The single pool1 module is reused after every convolution.
im_c1 = conv2d1(im)
im_p1 = pool1(im_c1)
im_c2 = conv2d2(im_p1)
im_p2 = pool1(im_c2)
im_c3 = conv2d3(im_p2)
im_p3 = pool1(im_c3)
im_c4 = conv2d4(im_p3)
im_p4 = pool1(im_c4)
im_c5 = conv2d5(im_p4)
im_p5 = pool1(im_c5)
im_c6 = conv2d6(im_p5)
im_p6 = pool1(im_c6)
# Collapse everything into one vector before the fully connected layer.
im_flat = torch.flatten(im_p6)
im_fc = fc_layer(im_flat)

# Report the shapes in execution order.
for _stage_out in (
    im_c1, im_p1, im_c2, im_p2, im_c3, im_p3,
    im_c4, im_p4, im_c5, im_p5, im_c6, im_p6,
    im_flat, im_fc,
):
    print(_stage_out.shape)

torch.Size([96, 127, 127])
torch.Size([96, 63, 63])
torch.Size([128, 63, 63])
torch.Size([128, 31, 31])
torch.Size([256, 31, 31])
torch.Size([256, 15, 15])
torch.Size([256, 15, 15])
torch.Size([256, 7, 7])
torch.Size([256, 7, 7])
torch.Size([256, 3, 3])
torch.Size([256, 3, 3])
torch.Size([256, 1, 1])
torch.Size([256])
torch.Size([1024])


In [8]:
# Encoder: six (conv -> 2x2 max-pool -> LeakyReLU) stages followed by one fully
# connected layer.  Every layer keeps its own module-level name so that other
# cells can still reference it.
conv2d1 = torch.nn.Conv2d(3, 96, (7, 7), padding='same')

# One pooling module per stage; each halves the spatial size (rounding down).
pool1 = torch.nn.MaxPool2d((2, 2))
pool2 = torch.nn.MaxPool2d((2, 2))
pool3 = torch.nn.MaxPool2d((2, 2))
pool4 = torch.nn.MaxPool2d((2, 2))
pool5 = torch.nn.MaxPool2d((2, 2))
pool6 = torch.nn.MaxPool2d((2, 2))

conv2d2 = torch.nn.Conv2d(96, 128, (3, 3), padding='same')
conv2d3 = torch.nn.Conv2d(128, 256, (3, 3), padding='same')
conv2d4 = torch.nn.Conv2d(256, 256, (3, 3), padding='same')
conv2d5 = torch.nn.Conv2d(256, 256, (3, 3), padding='same')
conv2d6 = torch.nn.Conv2d(256, 256, (3, 3), padding='same')

# One activation per stage, plus relu7 for the fully connected output.
relu1 = torch.nn.LeakyReLU()
relu2 = torch.nn.LeakyReLU()
relu3 = torch.nn.LeakyReLU()
relu4 = torch.nn.LeakyReLU()
relu5 = torch.nn.LeakyReLU()
relu6 = torch.nn.LeakyReLU()
relu7 = torch.nn.LeakyReLU()

# 256 channels remain on a 1x1 map once six poolings have reduced the input
# (e.g. 127 -> 63 -> 31 -> 15 -> 7 -> 3 -> 1).
fc_layer = torch.nn.Linear(256*1*1, 1024)

def encode(x):
    """Encode one image, or a batch of images, into 1024-dimensional features.

    Args:
        x: Tensor of shape (3, H, W) or (N, 3, H, W).  H and W must shrink to
            1x1 after the six 2x2 poolings (e.g. 127x127), because fc_layer
            expects exactly 256 input features.

    Returns:
        Tensor of shape (1024,) for a single image, or (N, 1024) for a batch.
    """
    # The stages are looked up at call time, so re-running a layer definition
    # in another cell is picked up by the next call.
    stages = (
        (conv2d1, pool1, relu1),
        (conv2d2, pool2, relu2),
        (conv2d3, pool3, relu3),
        (conv2d4, pool4, relu4),
        (conv2d5, pool5, relu5),
        (conv2d6, pool6, relu6),
    )
    for conv, pool, act in stages:
        x = act(pool(conv(x)))
    # Flatten only the (C, H, W) axes.  A bare torch.flatten(x) would also
    # merge a leading batch axis into the feature vector and break fc_layer.
    x = torch.flatten(x, start_dim=-3)
    return relu7(fc_layer(x))

In [9]:
print( encode(im).shape )

torch.Size([1024])


## Testing 3D-LSTM Layer

In [10]:
# Batch size used while prototyping the recurrent unit.
batch_size = 1
# Initial hidden state (h0) and cell state (s0): 128 channels on a 4x4x4 grid,
# both starting at zero.
h0 = torch.zeros(128, 4, 4, 4)
s0 = torch.zeros_like(h0)

In [26]:
# Grid size N = 4: the 4x4x4 spatial resolution the paper specifies for the 3D
# reconstruction grid (the resolution can be changed).
# Hidden size N_h = 128: each cell of the N x N x N grid holds an N_h-sized hidden vector.

# One 3x3x3 convolution over the previous hidden state per gate.
conv3d1 = nn.Conv3d(in_channels=128, out_channels=128, kernel_size=3, padding='same', bias=True)
conv3d2 = nn.Conv3d(in_channels=128, out_channels=128, kernel_size=3, padding='same', bias=True)
conv3d3 = nn.Conv3d(in_channels=128, out_channels=128, kernel_size=3, padding='same', bias=True)

# One linear map per gate from the 1024-d image feature to a flattened
# 128 x 4 x 4 x 4 grid.
hidden1 = nn.Linear(in_features=1024, out_features=128 * 4 * 4 * 4)
hidden2 = nn.Linear(in_features=1024, out_features=128 * 4 * 4 * 4)
hidden3 = nn.Linear(in_features=1024, out_features=128 * 4 * 4 * 4)

In [30]:
# Encode the image once; the same feature vector feeds all three gates.
test = encode(im)

# Recurrent terms: a 3D convolution of the previous hidden state per gate.
conv1_h0, conv2_h0, conv3_h0 = conv3d1(h0), conv3d2(h0), conv3d3(h0)

# Input terms: a linear projection of the image feature per gate.
hidden1_test, hidden2_test, hidden3_test = hidden1(test), hidden2(test), hidden3(test)

print(hidden1_test.shape, conv1_h0.shape)
# Each flat projection is folded into the (128, 4, 4, 4) grid before it is
# added to the matching convolution output.
ft = (hidden1_test.reshape(128, 4, 4, 4) + conv1_h0).sigmoid()   # forget gate
it = (hidden2_test.reshape(128, 4, 4, 4) + conv2_h0).sigmoid()   # input gate
gt = (hidden3_test.reshape(128, 4, 4, 4) + conv3_h0).tanh()      # candidate state
print(ft.shape)
print(it.shape)
# Cell update: keep part of the old state and add the gated candidate.
st = s0 * ft + gt * it
print(st.shape)
# No output gate in this unit: the hidden state is tanh of the cell state.
ht = st.tanh()
print(ht.shape)

torch.Size([8192]) torch.Size([128, 4, 4, 4])
torch.Size([128, 4, 4, 4])
torch.Size([128, 4, 4, 4])
torch.Size([128, 4, 4, 4])
torch.Size([128, 4, 4, 4])


## Testing Decoder Network

In [96]:
# Decoder building blocks: five 3D convolutions plus unpooling modules.
# NOTE(review): unpool1..unpool3 are not used by the decoder pass visible in
# this notebook section (it upsamples with interpolate); kept for later cells.
unpool1 = nn.MaxUnpool3d(kernel_size=(2, 2, 2))
unpool2 = nn.MaxUnpool3d(kernel_size=(2, 2, 2))
unpool3 = nn.MaxUnpool3d(kernel_size=(2, 2, 2))


# A single activation module shared by all decoder stages.
relu10 = nn.LeakyReLU()

# Channel progression 128 -> 128 -> 128 -> 64 -> 32 -> 2, all 3x3x3 kernels.
conv3d_dec1 = nn.Conv3d(in_channels=128, out_channels=128, kernel_size=3, padding='same', bias=True)
conv3d_dec2 = nn.Conv3d(in_channels=128, out_channels=128, kernel_size=3, padding='same', bias=True)
conv3d_dec3 = nn.Conv3d(in_channels=128, out_channels=64, kernel_size=3, padding='same', bias=True)
conv3d_dec4 = nn.Conv3d(in_channels=64, out_channels=32, kernel_size=3, padding='same', bias=True)
conv3d_dec5 = nn.Conv3d(in_channels=32, out_channels=2, kernel_size=3, padding='same', bias=True)

# Normalises over dim=1, which is the channel axis of a batched
# (N, C, D, H, W) tensor.
softmax = nn.Softmax(dim=1)

In [97]:
# Display the shape of the hidden state that is fed to the decoder.
ht.shape

torch.Size([128, 4, 4, 4])

In [98]:
# The decoder works on batched (N, C, D, H, W) tensors, so give the hidden
# state a leading batch axis of size 1.
ht_res = ht.unsqueeze(0)

# Stages 1-3: double the grid with nearest-neighbour upsampling, then
# convolve and apply the shared LeakyReLU.
up1 = nn.functional.interpolate(ht_res, scale_factor=2, mode='nearest')
dec1 = relu10(conv3d_dec1(up1))
print(dec1.shape)
up2 = nn.functional.interpolate(dec1, scale_factor=2, mode='nearest')
dec2 = relu10(conv3d_dec2(up2))
print(dec2.shape)
up3 = nn.functional.interpolate(dec2, scale_factor=2, mode='nearest')
dec3 = relu10(conv3d_dec3(up3))
print(dec3.shape)

# Stages 4-5: no further upsampling; the last convolution has no activation.
dec4 = relu10(conv3d_dec4(dec3))
print(dec4.shape)
dec5 = conv3d_dec5(dec4)
print(dec5.shape)

# Softmax over the two output channels of every grid cell.
final = softmax(dec5)
print(final.shape)

torch.Size([1, 128, 8, 8, 8])
torch.Size([1, 128, 16, 16, 16])
torch.Size([1, 64, 32, 32, 32])
torch.Size([1, 32, 32, 32, 32])
torch.Size([1, 2, 32, 32, 32])
torch.Size([1, 2, 32, 32, 32])
