In [5]:
from tqdm import tqdm
from typing import Dict
import time

import torch
import torch.fx
from torch.fx.node import Node

import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt

import fast_nas_adapt.src.module2graph
from fast_nas_adapt.src.module2graph import GraphInterperterWithGamma, GraphInterperterWithBernGamma
from fast_nas_adapt.src.resnet18 import ResNet18


Тут был код, который замерял время. Сейчас он в файле

In [6]:
class GraphInterperterWithGumbelSoftmaxGamma(GraphInterperterWithGamma):
    def __init__(self, mod, gamma_shift=0.0, temperature=1.0):
        self.gamma_shift = gamma_shift
        self.temperature = temperature
        super().__init__(mod)

    def init_gammas(self):
        i = 0
        gammas = []
        self.gammas_name = {}
        for node in self.graph.nodes:
            if node.op == 'call_module':
                gammas.append(np.random.randn()+self.gamma_shift)
                self.gammas_name[str(node)] = i# перевод в str тут для удобства. в реалньых методах это не нужно
                i+=1                        # да и вообще, тут по идее должен быть тензор/параметр
        self.gammas =  torch.nn.Parameter(torch.as_tensor(gammas), requires_grad = True)
        self.discrete = False 

    def sample_gammas(self):
        if self.discrete:
            return self.gammas
        else:
            d = torch.distributions.RelaxedBernoulli(logits=self.gammas, temperature=self.temperature)
            return d.rsample()
        
    def make_gammas_discrete(self):
        self.gammas.data = (self.gammas.data>=0) * 1.0
        self.gammas.requires_grad = False 
        self.discrete = True

Получим претрейн на n эпох

In [7]:
from fast_nas_adapt.src.cifar_data import get_dataloaders
from fast_nas_adapt.src import *

In [8]:
model = ResNet18(num_classes=10)


train_dl, test_dl = get_dataloaders(classes=range(10), batch_size=64,
                                    img_size=33)


Using cache found in /Users/b1/.cache/torch/hub/pytorch_vision_v0.10.0


Files already downloaded and verified
Files already downloaded and verified


In [9]:
model.load_state_dict(torch.load('fast_nas_adapt/data/model_23.ckpt'))

<All keys matched successfully>

In [10]:
@torch.no_grad()
def validate(model, dataloader, device):
    model = model.to(device)
    n_true = 0
    n_tot = 0
    for i, (X, y) in tqdm(enumerate(dataloader), 'validating'):
        if X.shape[0] != 64:
            continue
        n_true += (model(X.to(device)).argmax(-1) == y.to(device)).sum().item()
        n_tot += 64
    return n_true / n_tot

In [11]:
validate(model, test_dl, 'cpu')

validating: 157it [00:32,  4.77it/s]


0.6464342948717948

Гиперпараметры:

In [12]:
device = 'cpu'

imodel = GraphInterperterWithGumbelSoftmaxGamma(model.eval(), 4, temperature=0.2).to(device)
optimizer = torch.optim.Adam([imodel.gammas], lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

times = torch.ones_like(imodel.gammas).detach().div(imodel.gammas.numel())  # uniform
lambd = 2.0
print(imodel.gammas.sigmoid())

validate(imodel, test_dl, device), imodel.sample_gammas()

tensor([0.9854, 0.9622, 0.9940, 0.9864, 0.9501, 0.9873, 0.9878, 0.9672, 0.9231,
        0.9770, 0.9550, 0.9899, 0.9966, 0.9054, 0.9719, 0.9955, 0.9875, 0.9847,
        0.9914, 0.9788, 0.9910, 0.9117, 0.9902, 0.9712, 0.9939, 0.9483, 0.9647,
        0.9793, 0.9759, 0.9857, 0.9708, 0.9937, 0.9636, 0.9611, 0.8591, 0.9919,
        0.9822, 0.9795, 0.9729, 0.9790, 0.9432, 0.9704, 0.9777, 0.9937, 0.9905,
        0.9808, 0.9781, 0.9789, 0.9969, 0.9987, 0.9185, 0.9567, 0.9519, 0.9754,
        0.9745, 0.9747, 0.9908, 0.9596, 0.9951, 0.9472],
       grad_fn=<SigmoidBackward0>)


validating: 157it [00:34,  4.53it/s]


(0.390224358974359,
 tensor([1.0000, 0.9998, 0.9995, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         0.9990, 1.0000, 1.0000, 1.0000, 0.9994, 0.7397, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9996, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 0.9938, 1.0000, 0.6093, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9856, 1.0000, 1.0000, 1.0000,
         1.0000, 1.0000, 1.0000, 0.9999, 1.0000, 1.0000],
        grad_fn=<ClampBackward1>))

In [13]:
### TEST
with torch.no_grad():
    x = torch.randn(64, 3, 32, 32).to(device)
    print(model(x).shape, imodel(x).shape)
    # assert (model(x) - imodel(x)).abs().mean().item() < 2

torch.Size([64, 10]) torch.Size([64, 10])


In [16]:
epoch_number = 0

EPOCHS = 1

best_acc = 0.0
epoch_history = []
val_accs = []


for epoch in range(EPOCHS):
    print(f'EPOCH {epoch_number}:')

    imodel.train()

    for i, (X, y) in tqdm(enumerate(train_dl), 'training', total=len(train_dl)):
        if X.shape[0] != 64:
            continue
        optimizer.zero_grad()
        y_pred = imodel(X.to(device))
        loss = loss_fn(y_pred, y.to(device)) + lambd * imodel.sample_gammas().dot(times)
        loss.backward()
        optimizer.step()

        epoch_history.append((y_pred.argmax(-1) == y.to(device)).float().mean().item())
        
        if i == 300:
            break
        
        
    avg_loss = np.mean(epoch_history)

    imodel.eval()
    

    val_acc = validate(imodel, test_dl, device)
    val_accs.append(val_acc)
    print('LOSS train {} valid {}'.format(avg_loss, val_acc))

    # Track best performance, and save the model's state
    if val_acc > best_acc:
        best_acc = val_acc
        # model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        # torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 0:


training:  16%|█████                          | 128/782 [03:34<18:18,  1.68s/it]


KeyboardInterrupt: 

In [None]:
plt.plot(epoch_history)

In [None]:
plt.plot(val_accs)
max(val_accs)

In [None]:
(imodel.gammas.detach().cpu().sigmoid() >= 0.5).float().mean()