In [None]:
class BFNet(nn.Module):
    """Encoder-decoder network that predicts per-pixel maps plus global parameters.

    Structure (all defined in this class):

    * five ``BFblock`` encoder stages (32, 64, 128, 256, 512 features) separated
      by 2x2 max-pooling;
    * a U-Net style decoder (bilinear 2x upsampling + skip concatenation)
      ending in a single-channel 3x3 convolution;
    * a convolutional head on the bottleneck predicting 17 channels:
      15 lighting parameters followed by 2 ``b`` parameters.

    Shape requirements implied by the layers: the input has 3 channels; its
    height and width must be divisible by 16 so upsampled decoder features
    match the encoder skips; and the bottleneck (input / 16) must be at least
    8x8 because the head uses unpadded 4x4 and 5x5 convolutions (a 128x128
    input yields a 1x1 head output).
    """

    # Number of per-pixel maps handed to the physics decoder:
    # fmel, fblood, shading, specmask.
    _NUM_MAPS = 4

    def __init__(self):
        super().__init__()
        in_chans = 3
        # Plain ints rather than a numpy array; the layer shapes are unchanged.
        nfeatures = (32, 64, 128, 256, 512)
        ker_size = 3
        pad = 1

        # Encoder. Attribute names are preserved so existing state_dicts load.
        self.encoder1 = BFNet.BFblock(in_channels=in_chans, nfeature=nfeatures[0], kernel_size=ker_size, padding=pad, name='enc1')
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = BFNet.BFblock(nfeatures[0], nfeatures[1], ker_size, pad, name='enc2')
        self.encoder3 = BFNet.BFblock(nfeatures[1], nfeatures[2], ker_size, pad, name='enc3')
        self.encoder4 = BFNet.BFblock(nfeatures[2], nfeatures[3], ker_size, pad, name='enc4')
        self.encoder5 = BFNet.BFblock(nfeatures[3], nfeatures[4], ker_size, pad, name='enc5')

        # Decoder. align_corners=False is the value PyTorch already uses for
        # bilinear mode; stating it avoids the warning without changing results.
        self.upsamp = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.decoder1 = BFNet.BFblock(nfeatures[4] + nfeatures[3], nfeatures[3], ker_size, pad, name="dec1")
        self.decoder2 = BFNet.BFblock(nfeatures[3] + nfeatures[2], nfeatures[2], ker_size, pad, name="dec2")
        self.decoder3 = BFNet.BFblock(nfeatures[2] + nfeatures[1], nfeatures[1], ker_size, pad, name="dec3")
        self.decoder4 = BFNet.BFblock(nfeatures[1] + nfeatures[0], nfeatures[0], ker_size, pad, name="dec4")
        self.conv = nn.Conv2d(in_channels=nfeatures[0], out_channels=1, kernel_size=ker_size, padding=pad)

        # Convolutional "fully connected" head on the bottleneck features.
        self.convFCN1 = nn.Conv2d(in_channels=nfeatures[4], out_channels=nfeatures[4], kernel_size=4)
        self.bnormFCN1 = nn.BatchNorm2d(num_features=nfeatures[4])
        self.reluFCN1 = nn.ReLU(inplace=True)
        self.convFCN2 = nn.Conv2d(in_channels=nfeatures[4], out_channels=nfeatures[4], kernel_size=1)
        self.bnormFCN2 = nn.BatchNorm2d(num_features=nfeatures[4])
        self.reluFCN2 = nn.ReLU(inplace=True)
        self.convFCN3 = nn.Conv2d(in_channels=nfeatures[4], out_channels=17, kernel_size=5)

        # nn.Softmax() with no dim picks dim=1 for 4-D input (with a
        # deprecation warning); make that choice explicit.
        self.soft = nn.Softmax(dim=1)
        self.sig = nn.Sigmoid()

    def forward(self, x, options):
        """Predict parameter maps and render them with ``model_based_decoder``.

        Args:
            x: input image batch; 3 channels, spatial size constrained as
                described in the class docstring.
            options: passed through unchanged to ``model_based_decoder``
                (defined elsewhere; its contents are not visible here).

        Returns:
            ``(fmel, fblood, shading, rgbim, specularities)``, where the first
            three are the activated maps and the last two come from
            ``model_based_decoder``.
        """
        skips = self._encode(x)
        maps = self._decode(skips)
        prediction = self._fully_connected(skips[-1])

        # Per-pixel maps: fmel/fblood squashed to (-1, 1), shading/specmask
        # made positive via exp.
        fmel = 2 * self.sig(maps[:, 0]) - 1
        fblood = 2 * self.sig(maps[:, 1]) - 1
        shading = torch.exp(maps[:, 2])
        specmask = torch.exp(maps[:, 3])

        weightA, weightD, Fweights, CCT = self._lighting(prediction[:, 0:15])

        # b squashed to (-3, 3); BGrid is the same values scaled to (-1, 1).
        b = 6 * self.sig(self._flatten_b(prediction[:, 15:17])) - 3
        nbatch = b.shape[1]
        BGrid = b.reshape(2, 1, 1, nbatch) / 3

        rgbim, specularities = model_based_decoder(options, fmel, fblood, shading, specmask, weightA, weightD, Fweights, CCT, b, BGrid)
        return fmel, fblood, shading, rgbim, specularities

    def _encode(self, x):
        """Run the five encoder stages; return their outputs shallowest-first."""
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool(enc1))
        enc3 = self.encoder3(self.pool(enc2))
        enc4 = self.encoder4(self.pool(enc3))
        enc5 = self.encoder5(self.pool(enc4))
        return enc1, enc2, enc3, enc4, enc5

    def _decode(self, skips):
        """Decode the bottleneck back to full resolution.

        Returns a ``(N, _NUM_MAPS, H, W)`` tensor.

        The original implementation ran this same shared decoder
        ``_NUM_MAPS`` times on identical inputs and concatenated the results,
        so every output channel was identical. Running it once and repeating
        the output gives the same values and gradients at a quarter of the
        cost. In training mode it also updates decoder BatchNorm running stats
        once per step instead of four times.

        NOTE(review): because one decoder feeds all maps, fmel/fblood/shading/
        specmask share the same pre-activation. If distinct maps are intended,
        this needs separate decoders or a multi-channel output conv; that
        would change the state_dict.
        """
        enc1, enc2, enc3, enc4, enc5 = skips
        features = enc5
        for decoder, skip in ((self.decoder1, enc4), (self.decoder2, enc3),
                              (self.decoder3, enc2), (self.decoder4, enc1)):
            features = decoder(torch.cat((self.upsamp(features), skip), 1))
        single_map = self.conv(features)
        return single_map.repeat(1, self._NUM_MAPS, 1, 1)

    def _fully_connected(self, bottleneck):
        """Apply the convolutional head; returns 17 prediction channels."""
        fc = self.reluFCN1(self.bnormFCN1(self.convFCN1(bottleneck)))
        fc = self.reluFCN2(self.bnormFCN2(self.convFCN2(fc)))
        return self.convFCN3(fc)

    def _lighting(self, lightParams):
        """Turn the 15 lighting channels into weights and a CCT value.

        Channels 0..13 are normalised with a softmax over exactly those 14
        channels, so weightA + weightD + sum(Fweights) == 1. The original
        softmax also included channel 14, which is used only as the raw CCT
        logit; that leaked probability mass into an unused slot.
        Channel 14 is mapped to CCT in (1, 22) as 21 * sigmoid + 1.
        """
        weights = self.soft(lightParams[:, 0:14])
        weightA = weights[:, 0]
        weightD = weights[:, 1]
        Fweights = weights[:, 2:14]
        CCT = (22 - 1) * self.sig(lightParams[:, 14]) + 1
        return weightA, weightD, Fweights, CCT

    @staticmethod
    def _flatten_b(batch):
        """Reshape ``(N, 2, H, W)`` to ``(2, N*H*W)`` with row k = channel k.

        A plain ``reshape(2, ...)`` (used previously) flattens row-major and
        interleaves the two channels of different samples whenever N*H*W > 1.
        Moving the channel axis first keeps each row a single channel.
        """
        nbatch = batch.shape[0] * batch.shape[2] * batch.shape[3]
        return batch.permute(1, 0, 2, 3).reshape(2, nbatch)

    @staticmethod
    def BFblock(in_channels, nfeature, kernel_size, padding, name):
        """Build three (bias-free conv -> BatchNorm -> ReLU) stages.

        The first conv maps ``in_channels`` to ``nfeature``; the other two keep
        ``nfeature``. Layers are named ``<name>conv<i>``, ``<name>norm<i>`` and
        ``<name>relu<i>`` for i = 1..3, which fixes the state_dict keys.
        """
        layers = []
        channels = in_channels
        for i in range(1, 4):
            layers.append((name + "conv" + str(i), nn.Conv2d(
                in_channels=channels,
                out_channels=nfeature,
                kernel_size=kernel_size,
                padding=padding,
                bias=False,
            )))
            layers.append((name + "norm" + str(i), nn.BatchNorm2d(num_features=nfeature)))
            layers.append((name + "relu" + str(i), nn.ReLU(inplace=True)))
            channels = nfeature
        return nn.Sequential(OrderedDict(layers))

    def _initialize_weights(self):
        """Orthogonally initialise the weight of every ``nn.Conv2d`` submodule.

        Nothing in this class calls this method. It relies on a module-level
        ``init`` name (presumably ``torch.nn.init`` — confirm the file imports
        it).
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)