<a href="https://colab.research.google.com/github/etimush/MemoryNCA/blob/main/GeneCA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Gene CA**

In [None]:
#@title Get Packages and Images  { vertical-output: true}
!pip install git+https://github.com/etimush/MemoryNCA
!git clone --depth 1 --filter=blob:none --sparse https://github.com/etimush/MemoryNCA.git
!cd MemoryNCA && git sparse-checkout set Images
!mkdir Images && mv MemoryNCA/Images/* Images && rm -rf MemoryNCA

In [None]:
#@title Imports { vertical-output: true}
import random
import numpy as np
import matplotlib.pyplot as plt
from NCA.NCA import *
import NCA.utils as utils
from IPython.display import Image, HTML, clear_output
import logging
import os
from IPython.display import display, HTML, Video
from PIL import Image
import cv2
from base64 import b64encode
# Mute the root logger: level 100 is above logging.CRITICAL (50), so no record
# passes. The prior level is kept in `old_level` so it can be restored later
# with logger.setLevel(old_level).
logger = logging.root
old_level = logger.level
logger.setLevel(100)

In [None]:
#@title Setup { vertical-output: true}
# Hyper-parameters for the Gene CA section. Lines tagged `#@param` are exposed
# as Colab form fields.
DEVICE = "cuda:0" #<-- Device to use, CUDA recommended
# Target image size before padding; the "Load Primitives" cell adds 2*PADDING to each.
HEIGHT = 30 #@param {type:"integer"}
WIDTH = 30 #@param {type:"integer"}
CHANNELS = 16 # @param {type:"integer"}<--- NCA feature channels
BATCH_SIZE = 12 #@param {type:"integer"}
PADDING = 5 #@param {type:"integer"}
GENE_COUNT = 8 #@param {type:"integer"} <-- Number of gene channels to use for "private" information
POOL_SIZE = 2666 #@param {type:"integer"}<--- NCA training pool size, lower values train faster but are less stable
TRAINING_ITERS = 4000  #@param {type:"integer"}<-- Number of training iterations
HIDDEN_SIZE = 64 #@param {type:"integer"}<--- NCA hidden size
# Candidate sets of target images; "Load Primitives" selects one of them via `paths`.
PRIMITIVES_SHAPES = ["Images/square.png", "Images/circle.png", "Images/triangle.png"]
PRIMITIVES_BODY_PARTS = ["Images/Torso.png", "Images/Head.png", "Images/Tail.png", "Images/leg1.png", "Images/leg2.png", "Images/leg3.png", "Images/leg4.png"]
# NOTE(review): "Verical.png" may be a misspelling of "Vertical" -- confirm it
# matches the actual file name in Images/ before selecting this set.
PRIMITIVES_LINES = ["Images/horizontal.png", "Images/Verical.png"]
# CSS injected with display(HTML(style)) during training; lays the output area
# out as a flex row in reverse order (content aligned to the right).
style = """
<style>
.output_wrapper, .output {
    display: flex;
    flex-direction: row-reverse; /* Align content to the right */
}
</style>
"""


In [None]:
#@title Load Primitives { vertical-output: true}

paths = PRIMITIVES_SHAPES #@param {type:"string"}
# utils.get_image yields a pair per path: the training target (later tiled into
# the batch) and a version suitable for plt.imshow.
loaded = [utils.get_image(p, HEIGHT, WIDTH, padding=PADDING) for p in paths]
images = [target for target, _ in loaded]
images_to_display = [preview for _, preview in loaded]

# One gene per image, given as the indices of the gene bits that are set to 1,
# counting from the rightmost bit: for 3-bit genes [0] = 001, [0,1] = 011,
# [2] = 100, etc. The same rule applies for any gene size.
genes = [[0], [2], [1]]

# The working grid includes the padding on both sides.
HEIGHT += 2 * PADDING
WIDTH += 2 * PADDING
assert len(paths) == len(genes), 'Genes and images should have the same length '

In [None]:
#@title Display Primitives { vertical-output: true}
# Show each target primitive in its own figure (ids start at 3 so they do not
# collide with figure 1, which the training cell uses for the loss plot).
for i, image in enumerate(images_to_display):
    plt.figure(3 + i)
    plt.imshow(image)

# One sample pool per primitive, initialised with that primitive's gene.
# gene_size must follow GENE_COUNT (the GeneCA's gene channels) rather than a
# literal 8, or changing the GENE_COUNT form field silently breaks the pools.
pools = [
    utils.make_gene_pool(gene, pool_size=POOL_SIZE, height=HEIGHT, width=WIDTH,
                         channels=CHANNELS, gene_size=GENE_COUNT)
    for gene in genes
]
# The first entry of every pool is kept as that primitive's seed state
# (presumably the freshly initialised seed -- confirm in utils.make_gene_pool).
seeds = [pool[0].clone() for pool in pools]

In [None]:
#@title Get Batch Image Partitions { vertical-output: true}
# Split BATCH_SIZE as evenly as possible across the target images: the first
# `rem` images get one extra sample (12 over 3 -> [4, 4, 4]; 10 over 3 ->
# [4, 3, 3]). A single image naturally receives the whole batch, so no special
# case is needed.
partitions = len(paths)
assert partitions > 0, "At least one image path is required"
div, rem = divmod(BATCH_SIZE, partitions)
part = [div + 1 if i < rem else div for i in range(partitions)]
print(f"Batch image paritions = {part}. Batch Size of {BATCH_SIZE}. Number of Partitions = {partitions}")

In [None]:
#@title Load Filters for Loss Function { vertical-output: true}
# Fixed 3x3 kernels applied by perchannel_conv in the loss: Sobel-x, its
# transpose (Sobel-y) and a zero-sum Laplacian-style kernel. They are created on
# DEVICE so they always sit on the same device as the model, instead of being
# pinned to "cuda:0".
sobel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], dtype=torch.float32, device=DEVICE)
lap = torch.tensor([[1.0, 2.0, 1.0], [2.0, -12, 2.0], [1.0, 2.0, 1.0]], dtype=torch.float32, device=DEVICE)
filters = torch.stack([sobel_x, sobel_x.T, lap])  # shape (3, 3, 3)
folder = "Gene"

In [None]:
#@title Create Path for Saving Models { vertical-output: true}
# Checkpoints for this section are written under Trained_models/<folder>.
path = "Trained_models/" + folder
if os.path.exists(path):
    print(f"Path: {path} already exists, all OK!")
else:
    os.makedirs(path)
    print(f"Path: {path} created")


In [None]:
#@title Initialise NCA { vertical-output: true}
# Training targets: image k repeated part[k] times along the batch axis, in the
# same image order used for the partitions.
bases = [img.tile(n, 1, 1, 1) for img, n in zip(images, part)]
base = torch.cat(bases, dim=0)
loss_log = []

nca = GeneCA(CHANNELS, HIDDEN_SIZE, gene_size=GENE_COUNT).to(DEVICE)
optim = torch.optim.AdamW(nca.parameters(), lr=1e-3)
# Learning rate is multiplied by 0.3 every 2000 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=2000, gamma=0.3)
# Checkpoint name, relative to Trained_models/ (e.g. "Gene/GeneCA_gene_size_8").
name = f"{folder}/{type(nca).__name__}_gene_size_{GENE_COUNT}"

In [None]:
#@title Training { vertical-output: true}
# The target's filter responses never change, so compute them once instead of
# on every iteration.
with torch.no_grad():
    base_filtered = perchannel_conv(base, filters)

for i in range(TRAINING_ITERS + 1):
    # Sample a batch of states from the pools (no gradient through sampling).
    with torch.no_grad():
        idxs, x = utils.get_gene_pool(pools, part, seeds)

    # Roll the CA forward a random number of steps (32..91).
    for _ in range(random.randrange(32, 92)):
        x = nca(x)

    # Squared error on the first 4 channels against `base`, plus a 0.1-weighted
    # squared error on their filter responses.
    visible = x[:, :4, :, :]
    loss = (base - visible).pow(2).sum() + 0.1 * (base_filtered - perchannel_conv(visible, filters)).pow(2).sum()

    loss.backward()
    with torch.no_grad():
        # Per-parameter gradient normalisation. Parameters that received no
        # gradient are skipped instead of crashing on `None /= ...`.
        for p in nca.parameters():
            if p.grad is not None:
                p.grad /= (p.grad.norm() + 1e-8)
        optim.step()
        optim.zero_grad()
    x = x.detach()

    loss_log.append(loss.log().item())
    # Write the evolved states back into their pools.
    with torch.no_grad():
        pools = utils.udate_gene_pool(pools, x.clone(), idxs, part)
    scheduler.step()

    # Every 100 iterations: refresh the progress display, then checkpoint.
    if i % 100 == 0:
        print(f"Training itter {i}, loss = {loss.item()}")
        plt.clf()
        clear_output()
        plt.figure(1, figsize=(10, 4))
        plt.title('Loss history')
        plt.plot(loss_log, '.', alpha=0.5, color="b")
        print("Batch")
        utils.show_batch(x[2:10])
        display(HTML(style))
        plt.show(block=False)
        plt.pause(0.01)
        torch.save(nca.state_dict(), "Trained_models/" + name + ".pth")
        print("Trained_models/" + name + ".pth")



In [None]:
#@title Video Utils { vertical-output: true}
# Frames and the rendered video for the GeneCA demo are written here.
path_video = "Saved_frames/GeneCA"

if os.path.exists(path_video):
    print(f"Path: {path_video} already exists, all OK!")
else:
    os.makedirs(path_video)
    print(f"Path: {path_video} created")


def place_seed(x, center_x, center_y, seeds, seed_index, gene_size=8):
    """Plant ``seeds[seed_index]`` into one cell of the state tensor, in place.

    At ``x[:, :, center_x, center_y]`` (``center_x`` indexes dim 2, ``center_y``
    dim 3) this:

    * sets channels ``3`` up to, but excluding, the last ``gene_size`` channels to 1;
    * writes the seed vector into the last 3 channels;
    * leaves channels 0-2 and the remaining gene channels untouched.

    Args:
        x: state tensor indexed as ``x[batch, channel, dim2, dim3]``.
        center_x: index into dim 2 of ``x``.
        center_y: index into dim 3 of ``x``.
        seeds: sequence of seed vectors (3 values each).
        seed_index: which entry of ``seeds`` to plant.
        gene_size: number of trailing gene channels excluded from the "set to 1"
            range. Defaults to 8, the value previously hard-coded as ``-8``.

    Returns:
        ``x`` itself (it is modified in place).
    """
    n_channels = x.shape[1]
    # Explicit end index: `3:-gene_size` would be an empty slice for gene_size=0.
    x[:, 3:n_channels - gene_size, center_x, center_y] = 1
    x[:, -3:, center_x, center_y] = seeds[seed_index]
    return x

def write_frame(x, path, frame_number, height, width, chn):
    """Save the first three channels of batch item 0 of ``x`` as a PNG frame.

    ``x`` is a ``(batch, channels, rows, cols)`` tensor. The frame is written to
    ``{path}/frame_{frame_number}.png`` as a ``rows`` x ``cols`` image with
    values clipped to [0, 1].

    ``height``, ``width`` and ``chn`` are accepted for call-site compatibility
    but are not used.
    """
    # (B, C, H, W) -> (B, H, W, C). The previous permute(0, 3, 2, 1) produced
    # (B, W, H, C): a transposed frame, which for non-square grids also no longer
    # matches the (width, height) size make_video gives cv2.VideoWriter.
    image_np = x.detach().cpu().permute(0, 2, 3, 1).numpy().clip(0, 1)[0, :, :, :3]
    plt.imsave(f"{path}/frame_{frame_number}.png", image_np)

def make_video(path, total_frames, height, width):
    """Encode ``{path}/frame_0.png`` .. ``frame_{total_frames-1}.png`` into a video.

    Writes ``{path}/output_video.webm`` (VP8, 30 fps, ``width`` x ``height``).
    Frames whose size differs from ``(height, width)`` are resized with
    nearest-neighbour interpolation, because cv2.VideoWriter typically does not
    encode mismatched frames (and raises nothing), leaving an empty or short video.

    Raises:
        FileNotFoundError: if a frame image is missing or unreadable.
    """
    fourcc = cv2.VideoWriter_fourcc(*'VP80')
    out = cv2.VideoWriter(path + '/output_video.webm', fourcc, 30.0, (width, height))
    try:
        for frame_number in range(total_frames):
            frame_path = path + f"/frame_{frame_number}.png"
            # cv2.imread returns BGR, which is also what VideoWriter.write expects,
            # so no colour conversion is needed.
            frame = cv2.imread(frame_path)
            if frame is None:
                raise FileNotFoundError(f"Could not read video frame: {frame_path}")
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_NEAREST)
            out.write(frame)
    finally:
        # Always finalise the container, even if a frame could not be loaded.
        out.release()

In [None]:
#@title Create Video { vertical-output: true}
# Drop the training model (rebinding instead of `del` keeps the cell
# re-runnable) and release cached GPU memory before building the demo model.
nca = None
torch.cuda.empty_cache()
# Gene vectors planted one after another (every 50 steps) at seed_locs.
seeds = [
    torch.tensor([0.0, 0.0, 1.0], device=DEVICE),
    torch.tensor([1.0, 0.0, 0.0], device=DEVICE),
    torch.tensor([0.0, 1.0, 0.0], device=DEVICE),
    torch.tensor([1.0, 0.0, 1.0], device=DEVICE),
]
seed_locs = [[10, 10], [40, 40], [80, 80], [120, 120]]
seed_index = 0

nca = GeneCA(CHANNELS, hidden_n=HIDDEN_SIZE, gene_size=GENE_COUNT)
# Relative path: this is where the training cell saves the checkpoint. The old
# absolute "/Trained_models/..." pointed at the filesystem root.
nca.load_state_dict(torch.load("Trained_models/Gene/GeneCA_gene_size_" + str(GENE_COUNT) + ".pth", map_location=DEVICE))
nca.to(DEVICE).eval()
# 4x larger canvas than the training grid so several seeds can grow side by side.
x_prime = torch.zeros((1, CHANNELS, HEIGHT * 4, WIDTH * 4), dtype=torch.float32, device=DEVICE)
frame_count = 499
with torch.no_grad():  # inference only: no autograd graph needed
    for i in range(frame_count):
        x_prime = nca(x_prime)
        # Every 50 steps plant the next seed, cycling through `seeds`.
        if i % 50 == 0:
            place_seed(x_prime, seed_locs[seed_index][0], seed_locs[seed_index][1], seeds, seed_index)
            seed_index = (seed_index + 1) % len(seeds)
        write_frame(x_prime, path_video, i, HEIGHT * 4, WIDTH * 4, CHANNELS)
make_video(path_video, frame_count, HEIGHT * 4, WIDTH * 4)


In [None]:
#@title Display Video { vertical-output: true}
# Embed the rendered webm inline; as the cell's last expression it gets displayed.
Video(f"{path_video}/output_video.webm", width=320, height=320, embed=True)

# **Gene Propagation CA**

In [None]:
#@title Setup { vertical-output: true}
# Hyper-parameters for the Gene Propagation CA section. First drop the GeneCA
# model from the previous section and release cached GPU memory.
# NOTE(review): `del nca` raises NameError if this cell is re-run after `nca`
# has already been deleted.
del nca
torch.cuda.empty_cache()
BATCH_SIZE = 10 #@param {type:"integer"}
# Candidate training targets for the propagation CA ("Load Image" selects one via `paths`).
LIZARD = [ "Images/lizard.png"]
BUTTERFLY = [ "Images/BUTTERFLY.png"]
SPIDER = [ "Images/spider.png"]
MULTIPLE = [] #<-- add multiple paths for training multiple NCA morphologies into one
TRAINING_ITERS = 14000 #@param {type:"integer"}
# Hidden size of the GenePropCA model being trained in this section.
HIDDEN_SIZE_PROP = 124 #@param {type:"integer"}

In [None]:
#@title Load Image { vertical-output: true}
paths = LIZARD #@param {type:"string"}
# utils.get_image yields (training target, displayable image) for each path.
loaded = [utils.get_image(p, HEIGHT, WIDTH, padding=PADDING) for p in paths]
images = [target for target, _ in loaded]
images_to_display = [preview for _, preview in loaded]
# One gene per target image; [1] sets gene bit 1 (i.e. 010 for 3-bit genes).
genes = [[1]]
# NOTE(review): HEIGHT/WIDTH were already padded once in "Load Primitives", so
# this grid is 2*PADDING larger again than the one used for the primitives.
HEIGHT += 2 * PADDING
WIDTH += 2 * PADDING
assert len(paths) == len(genes), 'Genes and images should have the same length '

In [None]:
#@title Display Image { vertical-output: true}
# Show each target in its own figure (ids start at 3; figure 1 is the loss plot).
for i, image in enumerate(images_to_display):
    plt.figure(3 + i)
    plt.imshow(image)

# One sample pool per target image. gene_size follows GENE_COUNT (the gene
# channels of both CAs) rather than a literal 8.
pools = [
    utils.make_gene_pool(gene, pool_size=1000, height=HEIGHT, width=WIDTH,
                         channels=CHANNELS, gene_size=GENE_COUNT)
    for gene in genes
]
# First entry of each pool is kept as the seed state (presumably the initial
# seed -- confirm in utils.make_gene_pool).
seeds = [pool[0].clone() for pool in pools]

In [None]:
#@title Get Batch Image Partition { vertical-output: true}
# Split BATCH_SIZE as evenly as possible across the target images: the first
# `rem` images get one extra sample. A single image naturally receives the
# whole batch, so no special case is needed.
partitions = len(paths)
assert partitions > 0, "At least one image path is required"
div, rem = divmod(BATCH_SIZE, partitions)
part = [div + 1 if i < rem else div for i in range(partitions)]
print(f"Batch image paritions = {part}. Batch Size of {BATCH_SIZE}. Number of Partitions = {partitions}")

In [None]:
#@title Generate Extra Genes for Multi Morphology GeneProp CA { vertical-output: true}
def _centre_marker(value):
    """Return a (1, HEIGHT, WIDTH) zero plane on DEVICE whose centre cell holds ``value``."""
    plane = torch.zeros((1, HEIGHT, WIDTH), device=DEVICE)
    plane[:, HEIGHT // 2, WIDTH // 2] = value
    return plane


# Partition k contributes part[k] copies (shape (part[k], 1, HEIGHT, WIDTH)) of
# a plane marked with k at its centre; all partitions are stacked on the batch axis.
gene_2 = [_centre_marker(idx).tile(p, 1, 1, 1) for idx, p in enumerate(part)]
genes = torch.cat(gene_2, dim=0)

In [None]:
#@title Create Path for Saving Models { vertical-output: true}
# GeneProp checkpoints are written under Trained_models/GeneProp.
folder = "GeneProp"
path = "Trained_models/" + folder
if os.path.exists(path):
    print(f"Path: {path} already exists, all OK!")
else:
    os.makedirs(path)
    print(f"Path: {path} created")

In [None]:
#@title Initialise GeneCA (static) and GeneProp CA for Training { vertical-output: true}
# Training targets: image k repeated part[k] times along the batch axis.
bases = [img.tile(n, 1, 1, 1) for img, n in zip(images, part)]
base = torch.cat(bases, dim=0)
loss_log = []

# Frozen, pre-trained GeneCA from the first section. It is built with
# HIDDEN_SIZE (the value it was trained with) so the checkpoint still loads if
# that form field changes. It runs inside the training graph, so its parameters
# are marked as not requiring grad: otherwise every backward pass computes, and
# accumulates into .grad, gradients for weights the optimizer never updates.
ncaPre = GeneCA(CHANNELS, hidden_n=HIDDEN_SIZE, gene_size=GENE_COUNT)
ncaPre.load_state_dict(torch.load("Trained_models/Gene/GeneCA_gene_size_" + str(GENE_COUNT) + ".pth", map_location=DEVICE))
ncaPre.to(DEVICE).eval()
ncaPre.requires_grad_(False)

# The propagation CA being trained.
nca = GenePropCA(CHANNELS, HIDDEN_SIZE_PROP, gene_size=GENE_COUNT)
nca = nca.to(DEVICE)
optim = torch.optim.AdamW(nca.parameters(), lr=1e-3)
# Learning rate is multiplied by 0.3 every 5000 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=5000, gamma=0.3)
name = folder + "/" + type(nca).__name__ + "_gene_size_" + str(GENE_COUNT)

In [None]:
#@title Train GeneProp CA { vertical-output: true}
# TODO: make multi-image gene props use the secondary genes (presumably the
# `genes` tensor from the extra-genes cell).

# The target's filter responses never change; compute them once.
with torch.no_grad():
    base_filtered = perchannel_conv(base, filters)

for i in range(TRAINING_ITERS + 1):
    # Sample a batch of states from the pools (no gradient through sampling).
    with torch.no_grad():
        idxs, x = utils.get_gene_pool(pools, part, seeds)

    # Alternate the frozen GeneCA and the trainable GeneProp CA for a random
    # number of steps (32..91).
    itters = random.randrange(32, 92)
    for _ in range(itters):
        x = ncaPre(x)
        x = nca(x)

    # Squared error on the first 4 channels against `base`, plus a 0.1-weighted
    # squared error on their filter responses.
    visible = x[:, :4, :, :]
    loss = (base - visible).pow(2).sum() + 0.1 * (base_filtered - perchannel_conv(visible, filters)).pow(2).sum()

    loss.backward()
    with torch.no_grad():
        # Per-parameter gradient normalisation; skip parameters without a gradient.
        for p in nca.parameters():
            if p.grad is not None:
                p.grad /= (p.grad.norm() + 1e-8)
        optim.step()
        optim.zero_grad()
    x = x.detach()

    loss_log.append(loss.log().item())
    # Write the evolved states back into their pools.
    with torch.no_grad():
        pools = utils.udate_gene_pool(pools, x.clone(), idxs, part)
    scheduler.step()

    # Every 100 iterations: refresh the progress display, then checkpoint.
    if i % 100 == 0:
        print(f"Training itter {i}, loss = {loss.item()}")
        plt.clf()
        clear_output()
        plt.figure(1, figsize=(10, 4))
        plt.title('Loss history')
        print(name)
        plt.plot(loss_log, '.', alpha=0.5, color="b")
        utils.show_batch(x[2:10])
        plt.show(block=False)
        plt.pause(0.01)
        torch.save(nca.state_dict(), "Trained_models/" + name + ".pth")

In [None]:
#@title Create Video { vertical-output: true}
# Release the training-time models and optimizer before rendering. Rebinding
# to None (rather than `del`) keeps the cell re-runnable.
nca = ncaPre = optim = None
torch.cuda.empty_cache()
path_video = "Saved_frames/GenePropCA"
# Nothing else creates this folder; without it plt.imsave fails on frame 0.
os.makedirs(path_video, exist_ok=True)
seeds = [torch.tensor([0.0, 1.0, 0.0], device=DEVICE)]
seed_locs = [[80, 80]]
seed_index = 0

# Checkpoints are loaded from the same relative paths the training cells save
# to (equivalent to /content/... on Colab, but also works elsewhere).
nca = GeneCA(CHANNELS, hidden_n=HIDDEN_SIZE, gene_size=GENE_COUNT)
nca.load_state_dict(torch.load("Trained_models/Gene/GeneCA_gene_size_" + str(GENE_COUNT) + ".pth", map_location=DEVICE))
nca.to(DEVICE).eval()
nca_prop = GenePropCA(CHANNELS, hidden_n=HIDDEN_SIZE_PROP, gene_size=GENE_COUNT)
nca_prop.load_state_dict(torch.load("Trained_models/GeneProp/GenePropCA_gene_size_" + str(GENE_COUNT) + ".pth", map_location=DEVICE))
nca_prop.to(DEVICE).eval()

# Render on a 4x canvas. This matches the size handed to make_video, and it
# contains the seed at (80, 80). With the default settings the un-scaled
# HEIGHT x WIDTH grid is only 50 x 50, so placing the seed there raised IndexError.
canvas_h, canvas_w = HEIGHT * 4, WIDTH * 4
x_prime = torch.zeros((1, CHANNELS, canvas_h, canvas_w), dtype=torch.float32, device=DEVICE)
place_seed(x_prime, seed_locs[seed_index][0], seed_locs[seed_index][1], seeds, seed_index)
frame_count = 500
with torch.no_grad():  # inference only: no autograd graph needed
    for i in range(frame_count):
        x_prime = nca(x_prime)
        # The propagation CA is switched on only after frame 250.
        if i > 250:
            x_prime = nca_prop(x_prime)
        write_frame(x_prime, path_video, i, canvas_h, canvas_w, CHANNELS)
make_video(path_video, frame_count, canvas_h, canvas_w)

In [None]:
#@title Display Video { vertical-output: true}
# Embed the rendered webm inline; as the cell's last expression it gets displayed.
Video(f"{path_video}/output_video.webm", width=320, height=320, embed=True)