In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from data.task_dataset import TaskDataset

from models.encoder import Encoder
from models.architecture_generator import ArchitectureGenerator
from models.weight_generator import WeightGenerator
from models.evaluator import Evaluator

from utils.helpers import (
    create_optimizers, 
    train_model, 
    create_model_from_architecture, 
    train_generated_model
)

In [None]:
# Run configuration: everything a reader might want to tune lives in this cell.
seed = 42           # fixed seed so Restart Kernel -> Run All reproduces the same run
vocab_size = 10000  # passed to TaskDataset; also the encoder's input_size
max_length = 50     # passed to TaskDataset
num_classes = 10    # passed to TaskDataset and create_model_from_architecture
num_samples = 1000  # passed to TaskDataset
batch_size = 32     # mini-batch size used by the DataLoader
num_epochs = 10     # number of passes of the outer training loop

# Seed PyTorch's global RNG before any model is built or any batch is drawn.
# NOTE(review): if TaskDataset samples via `random` or NumPy, seed those as well.
torch.manual_seed(seed)

In [None]:
# Build the task dataset, then wrap it in a loader that yields shuffled
# mini-batches of `batch_size` items.
dataset = TaskDataset(num_samples, vocab_size, max_length, num_classes)
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [None]:
# Sizes that tie the four networks together in the training loop:
# the encoder's output is fed to both generators, and the architecture
# generator's output is fed to the evaluator.
encoding_size = 256
architecture_size = 128

# Keyword arguments that are identical for every network below.
shared_hparams = {'hidden_size': 128, 'num_layers': 2, 'dropout': 0.5}

# Task description -> encoding.
encoder = Encoder(
    input_size=vocab_size,
    output_size=encoding_size,
    **shared_hparams,
)

# Encoding -> generated architecture.
architecture_generator = ArchitectureGenerator(
    input_size=encoding_size,
    output_size=architecture_size,
    **shared_hparams,
)

# Encoding -> generated weights.
weight_generator = WeightGenerator(
    input_size=encoding_size,
    output_size=128,
    **shared_hparams,
)

# Generated architecture -> single quality value (output_size=1).
evaluator = Evaluator(
    input_size=architecture_size,
    output_size=1,
    **shared_hparams,
)

In [None]:
# Register every trainable network under the name the training loop uses
# to look up its optimizer.
models = dict(
    encoder=encoder,
    architecture_generator=architecture_generator,
    weight_generator=weight_generator,
    evaluator=evaluator,
)

# The result is indexed by the same names as `models` in the training loop.
optimizers = create_optimizers(models, learning_rate=0.001)

In [None]:
# One independent mean-squared-error loss per supervised target.
criterion = {
    target: nn.MSELoss()
    for target in ('architecture', 'weights', 'evaluator')
}

In [None]:
# Outer training loop.
#
# For every batch: encode the task description, generate an architecture and
# weights from the encoding, score the generated architecture with the
# evaluator, and regress all three outputs onto their targets with MSE.
# Per-epoch mean losses are printed at the end of each epoch.

# Names used to look up the optimizer of each jointly trained network.
component_names = (
    'encoder',
    'architecture_generator',
    'weight_generator',
    'evaluator',
)

# The loss module for the generated model is loop-invariant, so build it once.
generated_model_criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    total_loss_architecture = 0.0
    total_loss_weights = 0.0
    total_loss_evaluator = 0.0

    for (task_description, true_architecture,
         true_weights, true_quality_metric) in dataloader:
        # Forward pass through the encoder and both generators.
        encoded_task = encoder(task_description)
        generated_architecture = architecture_generator(encoded_task)
        generated_weights = weight_generator(encoded_task)

        # Build and train a throw-away model from the generated architecture.
        # NOTE(review): `generated_weights` is not passed to this model, and
        # nothing below reads the trained model, so this step only costs time
        # (5 full passes over `dataloader` per outer batch) -- confirm intent.
        generated_model = create_model_from_architecture(
            generated_architecture,
            num_classes)
        generated_model_optimizer = optim.Adam(
            generated_model.parameters(), lr=0.001)

        train_generated_model(
            generated_model,
            dataloader,
            generated_model_criterion,
            generated_model_optimizer,
            num_epochs=5)

        # NOTE(review): the evaluator consumes the non-detached architecture,
        # so its loss also sends gradients into the generator and encoder.
        quality_metric = evaluator(generated_architecture)

        loss_architecture = criterion['architecture'](
            generated_architecture,
            true_architecture)
        loss_weights = criterion['weights'](
            generated_weights,
            true_weights)
        # NOTE(review): assumes quality_metric and true_quality_metric have the
        # same shape; MSELoss broadcasts (with a warning) otherwise -- confirm.
        loss_evaluator = criterion['evaluator'](
            quality_metric,
            true_quality_metric)

        for name in component_names:
            optimizers[name].zero_grad()

        # A single backward pass over the summed loss accumulates the same
        # gradients as three separate passes, without retain_graph=True and
        # without walking the shared encoder graph three times.
        (loss_architecture + loss_weights + loss_evaluator).backward()

        for name in component_names:
            optimizers[name].step()

        total_loss_architecture += loss_architecture.item()
        total_loss_weights += loss_weights.item()
        total_loss_evaluator += loss_evaluator.item()

    # Mean of the per-batch losses over this epoch.
    num_batches = len(dataloader)
    avg_loss_architecture: float = total_loss_architecture / num_batches
    avg_loss_weights: float = total_loss_weights / num_batches
    avg_loss_evaluator: float = total_loss_evaluator / num_batches

    # Adjacent f-string literals are concatenated; each one is closed on its
    # own line (a single-quoted string cannot span physical lines).
    print(f'Epoch [{epoch+1}/{num_epochs}], '
          f'Loss Architecture: {avg_loss_architecture:.4f}, '
          f'Loss Weights: {avg_loss_weights:.4f}, '
          f'Loss Evaluator: {avg_loss_evaluator:.4f}')