In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
mpl.rcParams['figure.figsize'] = (16.0, 8.0)
import numpy as np
import cv2
import os

In [3]:
device = torch.device("cuda:0")

In [4]:
class encoder(nn.Module):
    
    def __init__(self):
        super(encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5, stride=2, padding=2)
        self.conv2 = nn.Conv2d(16, 64, 5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(64, 128, 5, stride=4, padding=2)
       
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        
        return x        

In [5]:
class decoder(nn.Module):
   
    def __init__(self):
        super(decoder, self).__init__()
            
        self.deconv4 = nn.ConvTranspose2d(128, 64, 4, stride=4)
        self.deconv5 = nn.ConvTranspose2d(64, 16, 4, stride=4)
        self.conv6 = nn.Conv2d(16, 1, 5, padding=2)        
  
    def forward(self, x):
        x = F.relu(self.deconv4(x))
        x = F.relu(self.deconv5(x))
        x = self.conv6(x)
        
        return x 

In [11]:
class mfnet(nn.Module):
    
    def __init__(self):
        super(mfnet, self).__init__()
        
        self._encoder = encoder()
        self._decoder = decoder()
        
    def forward(self, x1, x2):
        x1 = self._encoder(x1)
        x2 = self._encoder(x2)
        x = torch.cat((x1, x2), 1)
        x = self._decoder(x)
        x = x.repeat(1, 3, 1)
        x = x * x1 + (1 - x) * x2
        return x

In [12]:
net = mfnet()
print(net)

mfnet(
  (_encoder): encoder(
    (conv1): Conv2d(3, 16, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (conv2): Conv2d(16, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (conv3): Conv2d(64, 128, kernel_size=(5, 5), stride=(4, 4), padding=(2, 2))
  )
  (_decoder): decoder(
    (deconv4): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(4, 4))
    (deconv5): ConvTranspose2d(64, 16, kernel_size=(4, 4), stride=(4, 4))
    (conv6): Conv2d(16, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
)


In [None]:
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr = 0.0001)

In [22]:
def dataloader(path):
    X1, X2, y = [], [], []
    for i in range(1, 101):
        y.append(torch.from_numpy(cv2.imread(os.path.join(path, "{:4d}_0.jpeg".format(i)))))
        X1.append(torch.from_numpy(cv2.imread(os.path.join(path, "{:4d}_1.jpeg".format(i)))))
        X2.append(torch.from_numpy(cv2.imread(os.path.join(path, "{:4d}_2.jpeg".format(i)))))
    X1 = torch.stack(X1)
    X2 = torch.stack(X2)
    y = torch.stack(y)
    return (X1, X2), y

In [23]:
(X1, X2), y = dataloader('../datasets/testcase_lite')

TypeError: expected np.ndarray (got NoneType)

In [None]:
for epoch in range(20):
    loss = 0.0
    optimizer.zero_grad()
    y_pred = net(X1, X2)
    
    loss = criterion(y_pred, X)
    loss.backward()
    optimizer.step()

    loss += loss.item()