In [None]:
import torch
    
from data.task_dataset import TaskDataset

from models.encoder import Encoder
from models.architecture_generator import ArchitectureGenerator
from models.weight_generator import WeightGenerator
from models.evaluator import Evaluator

In [None]:
# Settings passed to TaskDataset further down; vocab_size is additionally
# used as the encoder's input_size.
vocab_size, max_length, num_classes = 10_000, 50, 10

# How many samples the dataset is built with, and the DataLoader batch size.
num_samples, batch_size = 1_000, 32

In [None]:
# Build the task dataset from the notebook-level settings, then wrap it in a
# loader that yields batches in a fixed (unshuffled) order.
dataset = TaskDataset(
    num_samples,
    vocab_size,
    max_length,
    num_classes,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=False,
    batch_size=batch_size,
)

In [None]:

# Size of the encoder's output; both generators are constructed with the same
# value as their input_size because they are fed the encoder's result.
_encoding_size = 256

# Constructor settings shared by every network in this notebook.
_shared_kwargs = dict(hidden_size=128, num_layers=2, dropout=0.5)

encoder = Encoder(
    input_size=vocab_size,
    output_size=_encoding_size,
    **_shared_kwargs,
)

architecture_generator = ArchitectureGenerator(
    input_size=_encoding_size,
    output_size=128,
    **_shared_kwargs,
)

weight_generator = WeightGenerator(
    input_size=_encoding_size,
    output_size=128,
    **_shared_kwargs,
)

# input_size equals the architecture generator's output_size (128): the
# evaluator is called on the generated architecture.
evaluator = Evaluator(
    input_size=128,
    output_size=1,
    **_shared_kwargs,
)

In [None]:
# Load the previously saved weights into each network.
#
# map_location='cpu' lets the checkpoints load on a machine that lacks the
# device they were saved from; no cell shown here moves the networks to
# another device, so CPU is where the weights are needed.
#
# NOTE(review): torch.load unpickles the file, so only open checkpoints that
# come from a trusted source (or pass weights_only=True if the installed torch
# version supports it).
for network, checkpoint_path in (
    (encoder, 'encoder.pth'),
    (architecture_generator, 'architecture_generator.pth'),
    (weight_generator, 'weight_generator.pth'),
    (evaluator, 'evaluator.pth'),
):
    network.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

In [None]:
# Put every network into evaluation mode before running inference.
for _network in (encoder, architecture_generator, weight_generator):
    _network.eval()

# Kept as the cell's final expression so the cell output is the same as before.
evaluator.eval()

In [None]:
# Run the whole pipeline over the dataset and print the true versus the
# predicted quality metric for each sample. Gradient tracking is switched off
# because nothing is trained here.
with torch.no_grad():
    for (task_description, true_architecture,
         true_weights, true_quality_metric) in dataloader:

        encoded_task = encoder(task_description)
        generated_architecture = architecture_generator(encoded_task)
        # Not consumed below; kept so the generated weights stay available in
        # the notebook namespace after the loop.
        generated_weights = weight_generator(encoded_task)
        quality_metric = evaluator(generated_architecture)

        # .item() only accepts one-element tensors, but a batch can hold up to
        # batch_size samples, so flatten both tensors and report per sample.
        # Assumes one scalar metric per sample on both sides — TODO confirm.
        for true_value, generated_value in zip(
                true_quality_metric.flatten().tolist(),
                quality_metric.flatten().tolist()):
            # A single-quoted f-string cannot span lines, so the message is
            # assembled from two adjacent literals.
            print(f'True Quality Metric: {true_value}, '
                  f'Generated Quality Metric: {generated_value}')

In [None]:
# NOTE(review): this cell repeats the preceding evaluation cell; consider
# keeping only one of them.
#
# Run the whole pipeline over the dataset and print the true versus the
# predicted quality metric for each sample. Gradient tracking is switched off
# because nothing is trained here.
with torch.no_grad():
    for (task_description, true_architecture,
         true_weights, true_quality_metric) in dataloader:

        encoded_task = encoder(task_description)
        generated_architecture = architecture_generator(encoded_task)
        # Not consumed below; kept so the generated weights stay available in
        # the notebook namespace after the loop.
        generated_weights = weight_generator(encoded_task)
        quality_metric = evaluator(generated_architecture)

        # .item() only accepts one-element tensors, but a batch can hold up to
        # batch_size samples, so flatten both tensors and report per sample.
        # Assumes one scalar metric per sample on both sides — TODO confirm.
        for true_value, generated_value in zip(
                true_quality_metric.flatten().tolist(),
                quality_metric.flatten().tolist()):
            # A single-quoted f-string cannot span lines, so the message is
            # assembled from two adjacent literals.
            print(f'True Quality Metric: {true_value}, '
                  f'Generated Quality Metric: {generated_value}')