In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torch.nn.functional as F
import random
torch.cuda.empty_cache()
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

In [2]:
tasks = {
    "classify_boundary": {"input_dim": 2, "output_dim": 1},
    "add_numbers": {"input_dim": 2, "output_dim": 1},
    "subtract_numbers": {"input_dim": 2, "output_dim": 1}
}

In [3]:
# Dataset
class TaskDataset(Dataset):
    def __init__(self, task, num_samples=1000, num_points_per_sample=100):
        self.task = task
        self.num_samples = num_samples
        self.num_points_per_sample = num_points_per_sample
        self.offset = torch.randint(0, 5, (1,)).item()
        self.data, self.labels = self.generate_data(task, num_samples, num_points_per_sample)
    
    def generate_data(self, task, num_samples, num_points_per_sample):
        data = []
        labels = []
        for _ in range(num_samples):
            if task == "classify_boundary":
                x = torch.randint(-10, 10, (num_points_per_sample, 2)).float()
                y = (x[:, 0] + x[:, 1] > self.offset).float().unsqueeze(1)
            elif task == "add_numbers":
                a = torch.randint(0, 100, (num_points_per_sample, 1)).float()
                b = torch.randint(0, 100, (num_points_per_sample, 1)).float()
                x = torch.cat([a, b], dim=1)
                y = a + b + self.offset
            elif task == "subtract_numbers":
                a = torch.randint(0, 100, (num_points_per_sample, 1)).float()
                b = torch.randint(0, 100, (num_points_per_sample, 1)).float()
                x = torch.cat([a, b], dim=1)
                y = a - b + self.offset
            elif task == "multiclass_boundary":
                x = torch.randint(-10, 10, (num_points_per_sample, 2)).float()
                y = torch.zeros(num_points_per_sample, 1)
                y[(x[:, 0] + x[:, 1] > self.offset), 0] = 1
                y[(x[:, 0] - x[:, 1] > self.offset), 0] += 1
            data.append(x)
            labels.append(y)
        return data, labels
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.labels[idx]
        # Flatten x and y
        x_flat = x.view(-1)
        y_flat = y.view(-1)
        xy = torch.cat((x_flat, y_flat), dim=0)
        return xy
    
    def get_offset(self):
        return self.offset

In [4]:
class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Data understanding layers
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )
        
        # Task understanding
        self.task_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim)
        )

    def forward(self, x):
        # Extract features
        features = self.feature_extractor(x)
        # Infer task embedding
        z = self.task_head(features)
        return z

In [5]:
# Define the MetaModel that generates weight matrices
class MetaModel(nn.Module):
    def __init__(self, latent_dim, mlp_input_dim, mlp_hidden_dim, mlp_output_dim):
        super(MetaModel, self).__init__()
        total_weights = (mlp_input_dim * mlp_hidden_dim) + (mlp_hidden_dim * mlp_output_dim)
        hidden_size = 1024  # Increased hidden size for the MetaModel

        # Adding more layers and dropout to the MetaModel
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + 1, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_size, total_weights)
        )

        self.mlp_input_dim = mlp_input_dim
        self.mlp_hidden_dim = mlp_hidden_dim
        self.mlp_output_dim = mlp_output_dim

    def forward(self, z):
        weights = self.fc(z)
        # Split weights into two matrices for a two-layer MLP
        w1_size = self.mlp_input_dim * self.mlp_hidden_dim
        w2_size = self.mlp_hidden_dim * self.mlp_output_dim
        W1 = weights[:, :w1_size].view(-1, self.mlp_hidden_dim, self.mlp_input_dim)
        W2 = weights[:, w1_size:].view(-1, self.mlp_output_dim, self.mlp_hidden_dim)
        return W1, W2

In [6]:
# Initialize Model
num_points_per_sample = 200
input_dim_per_point = 2
output_dim_per_point = 1
input_dim = num_points_per_sample * (input_dim_per_point + output_dim_per_point)
latent_dim = 20
encoder = Encoder(input_dim=input_dim, latent_dim=latent_dim).to(device)
encoder.load_state_dict(torch.load('encoder_weights.pth'))
encoder.eval()  # Set encoder to evaluation mode


# Parameters for the MLP
mlp_input_dim = 2
mlp_hidden_dim = 128  # Increased hidden dimension
mlp_output_dim = 1  # Output dimension is 1 for all tasks

# Initialize the MetaModel
meta_model = MetaModel(latent_dim, mlp_input_dim, mlp_hidden_dim, mlp_output_dim)
meta_model.load_state_dict(torch.load('meta_model_weights.pth', map_location='cpu'))
meta_model = meta_model.to(device)
meta_model.eval()

  encoder.load_state_dict(torch.load('encoder_weights.pth'))
  meta_model.load_state_dict(torch.load('meta_model_weights.pth', map_location='cpu'))


MetaModel(
  (fc): Sequential(
    (0): Linear(in_features=21, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=1024, out_features=1024, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=1024, out_features=1024, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.5, inplace=False)
    (9): Linear(in_features=1024, out_features=384, bias=True)
  )
)

In [7]:
class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        
    def forward(self, x):
        h = self.relu(self.fc1(x))
        out = self.fc2(h)
        return out

In [8]:
num_points_per_sample = 200
input_dim_per_point = 2
output_dim_per_point = 1

In [9]:
def train_random_mlps(task_name, num_models=1, convergence_loss=0.01, max_epochs=1000):
    input_dim = tasks[task_name]['input_dim']
    output_dim = tasks[task_name]['output_dim']
    hidden_dim = mlp_hidden_dim  # Use the same hidden dimension as in the MetaModel
    epochs_list = []
    for _ in range(num_models):
        model = SimpleMLP(input_dim, hidden_dim, output_dim).to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        criterion = nn.MSELoss()
        dataset = TaskDataset(task_name, num_samples=1000, num_points_per_sample=1)
        loader = DataLoader(dataset, batch_size=128, shuffle=True)
        x_size = 1 * input_dim_per_point
        for epoch in range(max_epochs):
            total_loss = 0.0
            for xy in loader:
                # Shape: [batch_size, num_points, input_dim]
                x = xy[:, :x_size].view(-1, 1, input_dim_per_point).to(device)
                # Shape: [batch_size, num_points, 1]
                y = xy[:, x_size:].view(-1, 1, output_dim_per_point).to(device)
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                outputs = model(x)
                loss = criterion(outputs, y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            avg_loss = total_loss / len(loader)
            if avg_loss <= convergence_loss:
                epochs_list.append(epoch + 1)
                break
        else:
            epochs_list.append(max_epochs)
    average_epochs = sum(epochs_list) / len(epochs_list)
    print(f"Task: {task_name}, Random MLP average convergence epochs: {average_epochs}")
    return average_epochs


In [10]:
def train_metamodel_mlp(task_name, convergence_loss=0.01, max_epochs=1000):
    # Setup dataset
    dataset = TaskDataset(task_name, num_samples=1000, num_points_per_sample=num_points_per_sample)
    loader = DataLoader(dataset, batch_size=128, shuffle=True)
    valadation = TaskDataset(task_name, num_samples=(128 * 10), num_points_per_sample=1)
    valloader = DataLoader(valadation, batch_size=128, shuffle=True)
    
    # Get a batch of data for encoder
    xy_batch = next(iter(loader)).to(device)
    
    # Generate weights using encoder and meta_model
    with torch.no_grad():
        # Get task embedding from encoder
        z = encoder(xy_batch)
        
        # Add task indicator
        if task_name == "add_numbers":
            task_indicator = torch.ones(z.size(0), 1).to(device)
        elif task_name == "subtract_numbers":
            task_indicator = torch.full((z.size(0), 1), 2).to(device)
        else:
            task_indicator = torch.zeros(z.size(0), 1).to(device)
        
        # Concatenate embedding with task indicator
        z = torch.cat((z, task_indicator), dim=1)
        
        # Get weight matrices from meta_model
        W1, W2 = meta_model(z)
        
        # Average weights across batch
        W1_mean = W1.mean(dim=0)
        W2_mean = W2.mean(dim=0)
    
    # Initialize MLP with generated weights
    class SimpleMLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(input_dim_per_point, mlp_hidden_dim)
            self.fc2 = nn.Linear(mlp_hidden_dim, output_dim_per_point)
            
        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return x
    
    model = SimpleMLP().to(device)
    
    # Load generated weights
    with torch.no_grad():
        model.fc1.weight.copy_(W1_mean)
        model.fc2.weight.copy_(W2_mean)
    
    # Setup training
    optimizer = optim.Adam(model.parameters(), lr=1e-1)
    
    # Training loop
    for epoch in range(max_epochs):
        total_loss = 0.0
        num_batches = 0
        
        for xy in valloader:
            xy = xy.to(device)
            x_size = input_dim_per_point * 1
            
            # Split into x and y
            x = xy[:, :x_size].view(-1, 1, input_dim_per_point)
            y = xy[:, x_size:].view(-1, 1, output_dim_per_point)
            
            # Reshape for processing
            x_flat = x.reshape(-1, input_dim_per_point)
            y_flat = y.reshape(-1, output_dim_per_point)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(x_flat)
            
            # Compute loss
            if task_name == "classify_boundary":
                outputs = torch.sigmoid(outputs)
                loss = F.binary_cross_entropy(outputs, y_flat.float())
            else:
                loss = F.mse_loss(outputs, y_flat)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
        
        # Check convergence
        avg_loss = total_loss / num_batches
        if avg_loss <= convergence_loss:
            print(f"Task: {task_name}, MetaModel MLP converged in {epoch + 1} epochs")
            return epoch + 1
    
    print(f"Task: {task_name}, MetaModel MLP did not converge within {max_epochs} epochs")
    return max_epochs

In [11]:
for task_name in tasks.keys():
    print(f"Running experiments for task: {task_name}")
    average_random_epochs = train_random_mlps(task_name)
    metamodel_epochs = train_metamodel_mlp(task_name)

Running experiments for task: classify_boundary
Task: classify_boundary, Random MLP average convergence epochs: 246.0
Task: classify_boundary, MetaModel MLP converged in 6 epochs
Running experiments for task: add_numbers
Task: add_numbers, Random MLP average convergence epochs: 268.0
Task: add_numbers, MetaModel MLP converged in 37 epochs
Running experiments for task: subtract_numbers
Task: subtract_numbers, Random MLP average convergence epochs: 172.0
Task: subtract_numbers, MetaModel MLP converged in 28 epochs
