In [None]:
import torch
import torch.nn as nn
from lib.architecture import Search

In [None]:
# Choose the compute device: first CUDA GPU if present, otherwise Apple MPS,
# otherwise fall back to the CPU. The bare `device` on the last line displays it.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

In [None]:
# Seed torch's global RNG so the random initialisation of `model`/`target` and the
# sampled training batches are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

D = 2  # input feature dimension (width of each random input row)
H = 5  # hidden width used throughout the encoder / Search / decoder
O = 2  # output dimension produced by the decoder

In [None]:
class Lambda(nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module``.

    This lets ad-hoc tensor operations (e.g. slicing) be placed inside an
    ``nn.Sequential``. The callable is kept as ``self.func`` and applied to the
    input unchanged in ``forward``. A plain function or lambda adds no parameters.
    """

    def __init__(self, func):
        super(Lambda, self).__init__()
        self.func = func

    def forward(self, x):
        fn = self.func
        result = fn(x)
        return result

In [None]:
def create_model(max_depth=4, max_width=8, beam_width=8):
    """Build an ``encoder -> Search -> decoder`` network mapping D-dim inputs to O-dim outputs.

    Every submodule is constructed inside this function, so each call returns a
    model with its own parameters. The LayerNorm in particular must not be a
    module-level instance: reusing one object across calls would tie the
    trainable ``model`` and the ``target`` together, and an optimizer over
    ``model.parameters()`` would then also update the target's LayerNorm weights.

    Args:
        max_depth, max_width, beam_width: forwarded unchanged to ``Search``;
            the defaults reproduce the original configuration.

    Returns:
        nn.Sequential: ``(encoder, search, decoder)`` in that order.
    """
    # A single LayerNorm shared *within* this model: it ends the encoder and is
    # also passed to Search as its `normalization` module.
    normalize = nn.LayerNorm(H)

    # Splits the last dimension of the 2*H transition output into a tuple of two
    # H-wide halves; how Search consumes the pair is defined in lib.architecture.
    split = Lambda(lambda x: (x[..., :H], x[..., H:]))

    encoder = nn.Sequential(
        nn.Linear(D, H),
        nn.ReLU(),
        normalize,
    )

    search = Search(
        transition=nn.Sequential(
            nn.Linear(H, 2 * H),
            nn.ReLU(),
            split,
        ),
        # Maps an H-dim input to a single value (presumably the fitness score
        # Search ranks candidates by - confirm in lib.architecture).
        fitness=nn.Sequential(
            nn.Linear(H, H),
            nn.ReLU(),
            nn.Linear(H, 1),
        ),
        normalization=normalize,
        max_depth=max_depth,
        max_width=max_width,
        beam_width=beam_width,
    )

    decoder = nn.Sequential(
        nn.Linear(H, O),
    )

    return nn.Sequential(encoder, search, decoder)

# Trainable network.
model = create_model().to(device)

# Second network built by `create_model()`; its outputs serve as the regression
# targets in the training loop. Switch it to evaluation mode with `.eval()`.
# Assigning `target.train = False` would only replace the `train()` method with a
# bool, leaving `target.training` True and breaking any later `target.train()` call.
target = create_model().to(device).eval()

In [None]:
learning_rate = 1e-3
lambda_l2 = 1e-5

# Regression objective: the model is fit to the target network's real-valued
# outputs, so mean-squared error is used (this is not a classification task).
criterion = nn.MSELoss()

# Adam over the trainable model only. Adam's `weight_decay` adds
# lambda_l2 * param to each gradient (classic L2 regularisation, not the
# decoupled decay of AdamW).
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=lambda_l2)


In [None]:
ITERATIONS = 10_000
BATCH_SIZE = 100
LOG_EVERY = 100   # print a progress line every LOG_EVERY iterations (and on the last one)
CLIP_NORM = None  # set e.g. 1e-2 to enable gradient-norm clipping; None disables it

# Ensure training mode even if another cell switched the model to eval.
model.train()

# Inside this context, tensor factory calls without an explicit `device=` allocate
# on `device` (presumably relied on by tensors created inside Search - confirm
# before removing).
with torch.device(device):
    for i in range(ITERATIONS):
        # Fresh uniform [0, 1) inputs every iteration.
        batch = torch.rand(BATCH_SIZE, D, device=device)

        # The target network only supplies labels; no gradients are needed.
        with torch.no_grad():
            targets = target(batch)

        # Forward pass of the trainable model (regression outputs, not logits).
        y_pred = model(batch)
        loss = criterion(y_pred, targets)

        # Clear stale gradients, then backpropagate.
        optimizer.zero_grad()
        loss.backward()

        # Clipping must run after backward() (so gradients exist) and before step().
        if CLIP_NORM is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)

        optimizer.step()

        if i % LOG_EVERY == 0 or i == ITERATIONS - 1:
            # Mean cosine similarity along dim 1 (the output-feature dim, assuming
            # (batch, O) outputs): a direction-agreement metric, not an accuracy.
            # Computed without autograd and converted to a Python float for printing.
            with torch.no_grad():
                cos_sim = torch.cosine_similarity(y_pred, targets).mean().item()
            print("[ITER]: %i, [LOSS]: %.6f, [COS-SIM]: %.3f" % (i, loss.item(), cos_sim))