# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.optim.lr_scheduler import StepLR

import os
import cv2
import uuid
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

def save_model():
    """Placeholder for a model-saving helper.

    Returns:
        An empty string; no model is written by this function.
    """
    nothing_saved = ""
    return nothing_saved

def load_model(model, weight_folder):
    """Load saved weights into ``model`` and switch it to evaluation mode.

    Args:
        model: An ``nn.Module`` whose ``state_dict`` layout matches the file.
        weight_folder: Path to the file written by ``torch.save`` (despite
            the parameter name, this is passed straight to ``torch.load``).

    Returns:
        The string ``"Model loaded."``.
    """
    # Map storages to CPU first so a checkpoint saved from a CUDA device can
    # still be read on a CPU-only host; load_state_dict then copies each
    # tensor onto whatever device the model's parameters live on.
    state = torch.load(weight_folder, map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    return "Model loaded."

# CUDA
def check_cuda():
    """Pick the compute device: the first CUDA GPU if available, else the CPU.

    Returns:
        A ``torch.device`` for ``"cuda:0"`` or ``"cpu"``.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)

# Module-level device, chosen once here and read as a global by helpers in
# this file (for example Sobels.batch_sobels and batch_target).
device = check_cuda()

# GENERATE RANDOM ID
def gen_id():
    """Return a short random identifier made of decimal digits.

    The identifier is the first six characters of the decimal form of the
    last field (the node field) of a freshly generated UUID4.
    """
    node_digits = str(uuid.uuid4().node)
    return node_digits[:6]

# GENERATE X
def call_grid(size):
    """Build the initial grid: all zeros except one seeded cell at the centre.

    Args:
        size: A 4-item sequence ``(batch, channels, height, width)``.

    Returns:
        A float32 tensor of shape ``size`` in which channels 3 and above are
        set to 1.0 at position ``(height // 2, width // 2)`` for every sample.
    """
    grid = torch.zeros(size, dtype=torch.float32)
    _, _, height, width = size
    centre_row, centre_col = height // 2, width // 2
    grid[:, 3:, centre_row, centre_col] = 1.0
    return grid

# SOBEL CLASS
class Sobels():
    """Source of the 3x3 identity and Sobel kernels used as perception filters."""

    def sobels(self):
        """Return the raw ``(identity, dx, dy)`` 3x3 integer kernels."""
        centre_only = torch.tensor([[0, 0, 0], [0, 1, 0], [0, 0, 0]])
        horizontal = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
        vertical = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
        return centre_only, horizontal, vertical

    def batch_sobels(self, n_channel):
        """Return the three kernels replicated once per channel.

        Each returned tensor is float32 with shape ``(n_channel, 1, 3, 3)``
        and is moved to the module-level ``device``. The channel count is
        also remembered on the instance as ``self.n_channel``.
        """
        self.n_channel = n_channel
        identify, dx, dy = (
            kernel.repeat(self.n_channel, 1, 1, 1).type(torch.float32).to(device)
            for kernel in self.sobels()
        )
        return identify, dx, dy

    def __len__(self):
        """Number of kernels produced by ``sobels()``."""
        kernels = self.sobels()
        return len(kernels)

# CREATE PERCEPTION
def perception(x):
    """Compute the perception tensor of ``x`` with the identity/Sobel kernels.

    Every channel of ``x`` is convolved independently (``groups=c``) with the
    identity, horizontal and vertical kernels from ``Sobels``. The three
    responses are concatenated along the channel axis.

    Args:
        x: 4-D tensor ``(batch, channels, height, width)``.

    Returns:
        A float32 tensor of shape ``(batch, 3 * channels, height, width)``
        laid out as ``[identity | dx | dy]`` channel groups.
    """
    b, c, h, w = x.size()
    # Sobels takes no constructor arguments; the per-channel kernels come
    # from batch_sobels(c), each shaped (c, 1, 3, 3).
    kernels = Sobels().batch_sobels(c)
    responses = [
        F.conv2d(x, kernel.to(x.device), padding=1, groups=c)
        for kernel in kernels
    ]
    # Concatenate on the channel axis. Stacking on a new leading axis and then
    # calling view() would interleave the filter and batch dimensions.
    return torch.cat(responses, dim=1).type(torch.float32)

def perception_2(x):
    """Compute the perception tensor of ``x`` with the identity/Sobel kernels.

    Args:
        x: 4-D tensor ``(batch, channels, height, width)``.

    Returns:
        A tensor of shape ``(batch, 3 * channels, height, width)``: the
        per-channel (``groups=c``) convolution of ``x`` with the identity, dx
        and dy kernels, placed one after another along the channel axis.
    """
    b, c, h, w = x.size()
    kernel_bank = Sobels().batch_sobels(c)

    # One perception slab per filter; with 3 filters and 16 channels this
    # gives 48 components.
    slabs = [F.conv2d(x, kernel, groups=c, padding=1) for kernel in kernel_bank]
    return torch.cat(slabs, dim=1)

# GENERATE A STOCHASTIC UPDATE
def stochastic_update(x, output):
    """Draw a random 0/1 mask with the same shape as ``x``.

    Args:
        x: Tensor whose shape and device the mask should match.
        output: Unused; kept so existing callers keep working.

    Returns:
        An int64 tensor shaped like ``x`` holding independent draws from
        ``{0, 1}``. The upper bound of ``randint`` is exclusive, so ``2``
        never appears.
    """
    # Sample directly on x's device rather than relying on a module-level
    # global and copying the mask over after creation.
    return torch.randint(0, 2, tuple(x.size()), device=x.device)

# DETECT ALIVE CELLS 
def detect_alives_cells(output):
    """Mark every cell that has an alive cell in its 3x3 neighbourhood.

    A cell counts as alive when channel 3 of ``output`` is greater than 0.1.

    Args:
        output: 4-D tensor ``(batch, channels, height, width)`` with at
            least four channels.

    Returns:
        A boolean tensor with the same shape as ``output``. The per-cell mask
        is repeated across all channels.
    """
    b, c, h, w = output.size()
    alpha = (output[:, 3:4, :, :] > 0.1).type(torch.double)
    # Kernel dims are (out_channels, in_channels, kH, kW), not batch size, so
    # a single 1x1x3x3 kernel applies to every sample. It is created on the
    # input's device so the function does not depend on a module-level global.
    neighbourhood = torch.ones((1, 1, 3, 3), dtype=torch.double, device=output.device)
    # Zero padding keeps the spatial size; a positive sum means at least one
    # alive cell lies within the 3x3 window.
    alives = F.conv2d(alpha, neighbourhood, padding=1) > 0.0
    return alives.repeat(1, c, 1, 1)

# BATCH TARGET IMAGE
def batch_target(file_path, batch):
    """Load an image and return it as a batch of identical RGBA tensors.

    Args:
        file_path: Path of the image file to open with PIL.
        batch: Number of copies to stack along the batch axis.

    Returns:
        A float32 tensor of shape ``(batch, 4, height, width)`` with values
        scaled to ``[0, 1]``, placed on the module-level ``device``.
    """
    # Force four channels so RGB, greyscale or palette images still line up
    # with the four-channel slice the training loop compares against. The
    # context manager releases the file handle once the pixels are read.
    with Image.open(file_path) as picture:
        rgba = np.array(picture.convert("RGBA")) / 255.
    # (H, W, C) -> (1, C, H, W)
    target = torch.from_numpy(rgba).permute(2, 0, 1).unsqueeze(0)
    target = target.type(torch.float32).to(device)
    return target.repeat(batch, 1, 1, 1)

# UNBATCH TENSOR   
def unbatch_tensor(x, default="Tensor"):
    """Convert the first sample of a batched tensor into an image-like array.

    Args:
        x: 4-D tensor ``(batch, channels, height, width)``.
        default: Unused; kept for signature compatibility.

    Returns:
        A NumPy array of shape ``(height, width, min(channels, 4))`` holding
        the first four channels of sample 0 in channels-last order.
    """
    channels_last = x.permute(0, 2, 3, 1).detach().cpu().numpy()
    first_sample = channels_last[0]
    return first_sample[..., :4]

# PLOT A TENSOR
def plot_tensor(x, title, default="Tensor"):
    """Show a tensor or an image file on the current matplotlib axes.

    Args:
        x: With ``default == "Tensor"``, a 4-D tensor ``(batch, channels,
            height, width)``; the first four channels of sample 0 are drawn.
            With ``default == "Image"``, the path of an image file.
        title: Plot title; it is upper-cased before display.
        default: ``"Tensor"`` or ``"Image"``. Any other value draws nothing.

    Raises:
        FileNotFoundError: In ``"Image"`` mode when the file cannot be read.
    """
    if default == "Tensor":
        pixels = x.permute(0, 2, 3, 1).detach().cpu().numpy()
        plt.title(str(title).upper())
        plt.imshow(pixels[0, :, :, :4])

    elif default == "Image":
        img = cv2.imread(x, -1)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError("Could not read image: {}".format(x))
        # OpenCV returns colour images as BGR(A) while matplotlib expects
        # RGB(A); reorder the channels so colours are not swapped.
        if img.ndim == 3 and img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
        elif img.ndim == 3 and img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        plt.title(str(title).upper())
        plt.imshow(img)

# DEBUG TENSOR MIN MAX 
def tensorMinMax(x, name, debug):
    """Print the minimum and maximum of ``x`` when debugging is enabled.

    Args:
        x: Object exposing ``min()`` and ``max()`` (e.g. a tensor).
        name: Label for the printed line; it is capitalised.
        debug: When falsy, nothing is printed and ``x`` is not touched.
    """
    if not debug:
        return
    label = name.capitalize()
    print("[{0}]\tMin: {1}, Max: {2}".format(label, x.min(), x.max()))


class NCA(nn.Module):
    """Per-cell update network built from two 1x1 convolutions.

    Args:
        n_channel: Number of state channels per cell.
        n_filter: Number of perception filters applied to each channel; the
            first layer expects ``n_channel * n_filter`` input channels.
    """

    def __init__(self, n_channel, n_filter):
        super().__init__()
        self.n_channel = n_channel
        self.n_filter = n_filter
        perception_width = self.n_channel * self.n_filter
        self.fc1 = nn.Conv2d(perception_width, 128, (1, 1))
        self.fc2 = nn.Conv2d(128, n_channel, (1, 1))
        # Only the weights of the output layer start at zero; its bias keeps
        # the default initialisation.
        nn.init.zeros_(self.fc2.weight)

    def forward(self, x):
        """Apply one update step to the state ``x`` and return the new state.

        The perception tensor goes through the two 1x1 convolutions (ReLU in
        between) to produce an increment. The increment is gated by a random
        0/1 mask and added to ``x``. The result is multiplied by the
        alive-cell mask computed on the updated state.
        """
        hidden = F.relu(self.fc1(perception_2(x)))
        delta = self.fc2(hidden)

        update_mask = stochastic_update(x, delta)
        grown = x + update_mask * delta
        # Alive mask: 3x3 neighbourhood of cells whose channel 3 exceeds 0.1.
        return detect_alives_cells(grown) * grown

# HYPERPARAMETERS #
lr = 2e-3
epochs = 1001
# Seed both NumPy (used for n_steps) and torch (used for weight init and the
# stochastic update mask). np.random.seed() returns None, so the seed value
# itself is kept in r_seed.
r_seed = 24
np.random.seed(r_seed)
torch.manual_seed(r_seed)
n_steps = np.random.randint(64, 96)     # Steps per epoch, drawn once for the whole run.

# TENSOR FORMAT #
b = 8
h = 32
w = 32
n_channel = 16
n_filter = len(Sobels())                # Number of perception filters configured.

# FOLDERS #
weights_folder = "./weights"
results_folder = "./training_output"
os.makedirs(results_folder, exist_ok=True)

# MODEL PARAMETERS #
L2 = nn.MSELoss()                       # MSE as loss function.
device = check_cuda()                   # Select CUDA if available, otherwise CPU.
# The intervals are used as modulo divisors in the training loop; keep them
# at 1 or more so a small `epochs` cannot cause a division by zero.
save_steps_time = max(1, epochs // 3)   # Export the growth sequence about 3 times per run.
save_outpus_time = max(1, epochs // 10) # Export the final output about 10 times per run.

# DEFINING MODEL
model = NCA(n_channel, n_filter).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=500, gamma=0.1)

# DEFINING Y (TARGET IMAGE)
target_img = "./input/cynda.png"
target = batch_target(target_img, b).to(device)

# TRAINING PROCESS
for epoch in range(epochs):

    # Start every epoch from a fresh grid with a single seeded cell.
    X = call_grid((b, n_channel, h, w)).to(device)

    # Decide once per epoch whether the growth sequence is exported, and
    # create its folder before the step loop instead of re-checking each step.
    export_steps = epoch % save_steps_time == 0
    if export_steps:
        steps_folder = os.path.join(results_folder, "steps_epoch_{}".format(epoch))
        os.makedirs(steps_folder, exist_ok=True)

    # STEP PROCESS
    for step in range(n_steps):
        # Call the module itself (not .forward) so registered hooks run.
        X = model(X)
        X = torch.clamp(X, 0, 1)

        # EXPORTING: SEQUENCE OF GROWING CELLS
        if export_steps:
            step_path = os.path.join(steps_folder, '{}_{}.png'.format(epoch, step))

            fig = plt.figure()
            plt.imshow(unbatch_tensor(X))
            plt.savefig(step_path, bbox_inches='tight', dpi=100)
            plt.close(fig)

    # RUNNING PREDICTION
    # X is already clamped to [0, 1] by the last step above.
    optimizer.zero_grad()
    output = X[:, :4, :, :]

    # CALCULATING LOSS
    loss = L2(output, target)
    # Fixed-point formatting: slicing str(loss) would cut scientific notation
    # (e.g. "1.2e-05") into a misleading value.
    loss_text = "{:.6f}".format(loss.item())
    print("Epoch {}/{} - Loss: {}".format(epoch, epochs, loss_text))

    # EXPORTING: PREDICTION IMAGE
    if epoch % save_outpus_time == 0:
        fig = plt.figure()
        output_file = 'output_epoch_{}_loss_{}.png'.format(epoch, loss_text)
        output_path = os.path.join(results_folder, output_file)
        plt.imshow(unbatch_tensor(output))
        plt.savefig(output_path, bbox_inches='tight', dpi=200)
        plt.close(fig)

    # OPTIMIZING PROCESS
    loss.backward()
    optimizer.step()
    scheduler.step()

# SAVING TRAINING
uid = gen_id()
# Fixed-point formatting: slicing str(loss) would cut scientific notation
# (e.g. "1.2e-05") into a misleading value.
final_loss = "{:.6f}".format(loss.item())
weight_file = "nca-{}_epoch{}_loss{}.path".format(uid, epochs, final_loss)

# The weights folder is not created anywhere earlier; make sure it exists so
# the save at the end of training cannot fail on a missing directory.
os.makedirs(weights_folder, exist_ok=True)
weights_path = os.path.join(weights_folder, weight_file)
torch.save(model.state_dict(), weights_path)
print("Model Saved.")