In [None]:
import torch
import torch.nn as nn

from lib.architecture import RandomizedSearch
from lib.sample import SampleNormal

In [None]:
# Choose the compute device in order of preference: first CUDA GPU, then
# Apple's MPS backend, and finally the CPU when neither is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
device

In [None]:
# Layer widths read by create_model and by the cells that draw random input batches.
D = 2    # input feature size: the encoder's first Linear takes D features; batches are (BATCH_SIZE, D)
H = 5    # hidden width shared by the encoder output, the search sub-networks and the decoder input
O = 2    # output size produced by the decoder's Linear layer

In [None]:
from lib.sample import SampleUniform


def create_model(max_depth, beam_width, num_samples):
    """Assemble an encoder -> RandomizedSearch -> decoder network.

    The three stages are chained in an ``nn.Sequential``:

    * encoder: ``Linear(D, H) -> ReLU -> LayerNorm(H)``
    * search: a ``RandomizedSearch`` module configured with
        - ``transition``: MLP ``H -> H -> 2*H``
        - ``fitness``: MLP ``H -> H -> 1``
        - ``sample``: ``SampleUniform(H, num_samples)`` followed by ``LayerNorm(H)``
    * decoder: a single ``Linear(H, O)``

    The widths ``D``, ``H`` and ``O`` are module-level constants read at call
    time. How ``RandomizedSearch`` uses its sub-networks is defined in
    ``lib.architecture`` and is not visible here.

    Args:
        max_depth: forwarded unchanged to ``RandomizedSearch(max_depth=...)``.
        beam_width: forwarded unchanged to ``RandomizedSearch(beam_width=...)``.
        num_samples: forwarded to ``SampleUniform(H, num_samples=...)``.

    Returns:
        nn.Sequential: ``(encoder, search, decoder)`` with freshly initialised
        parameters, left on the default device.
    """
    # Sub-modules are built in pipeline order (encoder, transition, fitness,
    # sampler, search, decoder) so parameter initialisation always draws from
    # the global RNG in the same sequence.
    input_encoder = nn.Sequential(nn.Linear(D, H), nn.ReLU(), nn.LayerNorm(H))

    # The transition network emits 2*H features per H-dimensional input.
    transition_net = nn.Sequential(nn.Linear(H, H), nn.ReLU(), nn.Linear(H, 2 * H))
    # The fitness network reduces an H-dimensional input to a single scalar.
    fitness_net = nn.Sequential(nn.Linear(H, H), nn.ReLU(), nn.Linear(H, 1))
    sampler = nn.Sequential(SampleUniform(H, num_samples=num_samples), nn.LayerNorm(H))

    search_core = RandomizedSearch(
        transition=transition_net,
        fitness=fitness_net,
        sample=sampler,
        max_depth=max_depth,
        beam_width=beam_width,
    )

    output_decoder = nn.Sequential(nn.Linear(H, O))

    return nn.Sequential(input_encoder, search_core, output_decoder)

# Trainable network: same layout as the reference below but with a search
# depth of 1.
model = create_model(1, 8, 8).to(device)

# Reference network whose outputs serve as regression targets: search depth
# of 100, randomly initialised, and never handed to an optimizer in the cells
# shown here.
#
# `.eval()` puts it into inference mode. Assigning `target.train = False` does
# not do that: it replaces the bound `nn.Module.train()` method with a bool and
# leaves `target.training` set to True.
# NOTE(review): whether eval mode changes how RandomizedSearch/SampleUniform
# sample depends on lib code not visible here - confirm the baseline cell still
# measures what is intended.
target = create_model(100, 8, 8).to(device).eval()

In [None]:
learning_rate = 1e-3  # step size passed to Adam as `lr`
lambda_l2 = 1e-5  # passed to Adam as `weight_decay`

# The model is fitted to the real-valued outputs of the `target` network, so
# the objective is mean-squared error (a regression loss), not cross entropy.
criterion = torch.nn.MSELoss()

# Adam updates only `model`'s parameters; `target`'s parameters are not given
# to the optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=lambda_l2) # built-in L2


In [None]:
ITERATIONS = 200  # number of optimisation steps (one fresh random batch per step)
BATCH_SIZE = 100  # rows per random batch

# The default-device context is kept so that tensor factory calls made during
# the forward passes also allocate on `device`.
with torch.device(device):
    for i in range(ITERATIONS):
        # Fresh uniform-random inputs each step, allocated directly on `device`.
        batch = torch.rand(BATCH_SIZE, D, device=device)

        # Regression targets come from the reference network; no autograd
        # graph is needed for them.
        with torch.no_grad():
            targets = target(batch)

        # Forward pass of the trainable model and its MSE against the targets.
        y_pred = model(batch)
        loss = criterion(y_pred, targets)

        # Monitoring metric only: mean cosine similarity between predictions
        # and targets (reported under the "ACCURACY" label). Computed without
        # autograd so it does not extend the graph that backward() walks.
        with torch.no_grad():
            acc = torch.mean(torch.cosine_similarity(y_pred, targets))

        print("[EPOCH]: %i, [LOSS]: %.6f, [ACCURACY]: %.3f" % (i, loss.item(), acc.item()))

        # Standard update: clear stale gradients, back-propagate, step.
        optimizer.zero_grad()
        loss.backward()
        # If gradient clipping is ever needed it belongs here: after backward()
        # has filled the gradients and before the optimizer consumes them.
        optimizer.step()

In [None]:
# Baseline: evaluate `target` twice on the same random batch and compare the
# two outputs with the training loss and the mean cosine similarity.
# NOTE(review): a non-zero error here presumably reflects randomness inside
# the search module - confirm against lib.architecture.
with torch.device(device), torch.no_grad():
    batch = torch.rand(BATCH_SIZE, D).to(device)

    targets1 = target(batch)
    targets2 = target(batch)

    # Error and similarity between the two evaluations of the same network.
    loss = criterion(targets1, targets2)
    acc = torch.cosine_similarity(targets1, targets2).mean()

('baseline error', loss.item(), 'accuracy', acc.item())

In [None]:
# Compare the trained `model` against the reference `target` on a fresh random
# batch, using the training loss and the mean cosine similarity.
with torch.device(device):
    batch = torch.rand(BATCH_SIZE, D).to(device)

    with torch.no_grad():
        targets1 = target(batch)  # reference outputs
        targets2 = model(batch)   # model predictions (despite the variable name)

    # Error between reference outputs and model predictions.
    loss = criterion(targets1, targets2)

    # Mean cosine similarity between reference outputs and model predictions.
    acc = torch.mean(torch.cosine_similarity(targets1, targets2))

# The second value is the cosine-similarity metric, so it is labelled
# 'accuracy' in line with the baseline cell's report.
'learned error', loss.item(), 'accuracy', acc.item()