In [1]:
import nni
import numpy as np
import torch
from nni.nas.hub.pytorch import DARTS
from nni.nas.strategy import Random
import io
import graphviz
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
from nni.nas.evaluator import FunctionalEvaluator
from nni.nas.experiment import NasExperiment
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim
import sys
from tqdm import tqdm
sys.path.insert(0, '/pbabkin/main/mipt/nas-for-moe/code')
import nas_moe.utils
import nas_moe.dataset
import nas_moe.moe_arch
import nas_moe.single_arch
import nas_moe.nni_utils

# Single global seed for the whole notebook (also passed to the
# architecture generator below). nas_moe.utils.set_seed is project-local;
# presumably it seeds python/numpy/torch RNGs — TODO confirm.
SEED = 322
nas_moe.utils.set_seed(SEED)



In [2]:
# Dataset the notebook is configured for; the image-size cell below only
# knows about 'mnist'.
dataset_name = 'mnist'

In [3]:
# Image (height, width) per supported dataset. Failing fast here beats the
# NameError on `h` that an unknown dataset_name would otherwise cause.
IMAGE_SIZES = {'mnist': (28, 28)}
if dataset_name not in IMAGE_SIZES:
    raise ValueError(
        f'Unsupported dataset_name {dataset_name!r}; '
        f'expected one of {sorted(IMAGE_SIZES)}'
    )
h, w = IMAGE_SIZES[dataset_name]
length = h * w  # number of pixels in the flattened image

# Random permutation of pixel indices and its inverse:
# argsort of a permutation p satisfies reverse_permutation[p[i]] == i.
permutation = torch.randperm(length)
reverse_permutation = torch.argsort(permutation)

In [4]:
# NAS search space (project-local PixelPermutationSpace) configured with the
# pixel permutation drawn above; width / num_cells are its size knobs.
space_kwargs = dict(
    dataset=dataset_name,
    permutation=permutation,
    width=8,
    num_cells=3,
)
model_space = nas_moe.nni_utils.PixelPermutationSpace(**space_kwargs)

In [5]:
# Number of architectures to sample; they are later handed to the MoE,
# presumably one per expert.
K = 3
# NOTE(review): meaning of the positional `5` is not visible here — confirm
# against ArchitectureGenerator's signature and give it a named constant.
archGenerator = nas_moe.single_arch.ArchitectureGenerator(model_space, 5, SEED)
arch_dicts = [
    archGenerator.generate_arch()['architecture']
    for _ in range(K)
]

In [6]:
# Tensor conversion + normalisation with the commonly used MNIST mean / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# NOTE(review): absolute, machine-specific path — make configurable before sharing.
DATA_ROOT = '/pbabkin/main/mipt/nas-for-moe/code/data'

# Training split with a 'permutation' distortion driven by reverse_permutation
# (presumably scrambling pixel order; the model space was built with the
# forward `permutation`).
train_dataset = nas_moe.dataset.DistortedMNIST(
    root=DATA_ROOT,
    train=True,
    download=True,
    custom_transform=transform,
    distortions=['permutation'],
    permutation=reverse_permutation,
)

batch_size = 64
shuffle = True
num_workers = 4  # or 0 for debugging

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
    # Pinned host memory only speeds up host->GPU transfers; without CUDA it
    # buys nothing and PyTorch warns about it.
    pin_memory=torch.cuda.is_available(),
)

In [7]:
# Mixture of experts over the sampled architectures. The search-space class
# and its options are forwarded, presumably so MoE can instantiate each expert.
# NOTE(review): width=8 used for model_space is not forwarded — confirm MoE's default agrees.
expert_space_cls = nas_moe.nni_utils.PixelPermutationSpace
moe = nas_moe.moe_arch.MoE(
    arch_dicts,
    expert_space_cls,
    train_dataset.input_size,
    num_cells=model_space.num_cells,
    permutation=permutation,
    dataset=dataset_name,
)

In [8]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
moe.to(device)  # nn.Module.to moves the parameters in place

# Number of passes over train_loader in the training cell.
num_epochs = 2

In [9]:
# Classification loss and an Adam optimizer over all MoE parameters.
LEARNING_RATE = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(moe.parameters(), lr=LEARNING_RATE)

In [10]:
def train_one_epoch(model=None, loader=None, opt=None, loss_fn=None, dev=None):
    """Run one training epoch, print its loss / accuracy and return them.

    Every argument is optional. When one is omitted, the corresponding
    notebook global is used. Globals are looked up at call time, so
    re-running an earlier cell that rebuilds them is picked up.

    Args:
        model: module to train; defaults to ``moe``.
        loader: iterable of ``(images, labels)`` batches; defaults to
            ``train_loader``.
        opt: optimizer over ``model``'s parameters; defaults to ``optimizer``.
        loss_fn: ``loss_fn(outputs, labels)`` -> scalar tensor; defaults to
            ``criterion``. Assumed to be a per-batch mean (as with the default
            ``nn.CrossEntropyLoss``) so that the size-weighted sum below gives
            a per-sample average.
        dev: device each batch is moved to; defaults to ``device``.

    Returns:
        ``(avg_loss, accuracy)`` over all samples seen in the epoch.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model = moe if model is None else model
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for images, labels in tqdm(loader):
        images, labels = images.to(dev), labels.to(dev)

        opt.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        # A MoE load-balancing term was sketched here as
        # `loss + 0.1 * moe.compute_load_balancing_loss(images)`; it is
        # disabled, so back-propagate the task loss alone (the previous
        # `loss + 0.1` only added a meaningless constant).
        loss.backward()
        opt.step()

        n = labels.size(0)
        total_loss += loss.item() * n  # undo the batch mean
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_samples += n

    if total_samples == 0:
        raise ValueError('train_one_epoch: loader yielded no samples')

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples
    print(f'Train  Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')
    return avg_loss, accuracy

In [11]:
# Train for num_epochs epochs. num_epochs is defined once, in the device
# cell; re-assigning it here silently overrode that value.
for epoch in range(1, num_epochs + 1):
    print(f'Epoch {epoch}/{num_epochs}')
    train_one_epoch()

Epoch 1/2


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:42<00:00, 44.37it/s]


Train  Loss: 0.2055, Accuracy: 0.9353
Epoch 2/2


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:51<00:00, 36.24it/s]

Train  Loss: 0.1359, Accuracy: 0.9577





In [32]:
# Scratch check: sqrt(1024) == 32, presumably the side of a square
# 1024-pixel (32x32) image. NOTE(review): leftover out-of-order exploration
# cell (In [32]); consider removing it before sharing the notebook.
side = np.sqrt(1024)
side

np.float64(32.0)