[Bug]: - vanilla_export leads to loss of accuracy #17

Open
HeinrichAD opened this issue Jul 6, 2022 · 9 comments
Labels
documentation Improvements or additions to documentation

Comments

@HeinrichAD

Module

Layers

Contact Details

No response

Current Behavior

The accuracy is much lower after vanilla_export.

acc(model) != acc(model.vanilla_export())

Expected Behavior

For supported layers, the accuracy should be equal (or at least almost equal).

acc(model) == acc(model.vanilla_export())

Version

v0.1.0

Environment

- OS: Linux arch 5.18.9-arch1-1
- Python version: 3.7
- PyTorch version: 1.11.0+cu102
- Cuda version: 10.2
- Packages used version: deel-torchlip sklearn torch torchvision tqdm

Relevant log output

without vanilla_export()
------------------------
Sequential model contains a layer which is not a Lipschitz layer: ReLU(inplace=True)
Sequential model contains a layer which is not a Lipschitz layer: MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Sequential model contains a layer which is not a Lipschitz layer: ReLU(inplace=True)
Sequential model contains a layer which is not a Lipschitz layer: MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Sequential model contains a layer which is not a Lipschitz layer: ReLU(inplace=True)
Sequential model contains a layer which is not a Lipschitz layer: ReLU(inplace=True)
Sequential model contains a layer which is not a Lipschitz layer: Flatten(start_dim=1, end_dim=-1)
Files already downloaded and verified
Load data for evaluation: 100%|████████████████████| 10/10 [00:02<00:00,  4.56it/s]
0.4959


with vanilla_export()
---------------------
Sequential model contains a layer which is not a Lipschitz layer: ReLU(inplace=True)
Sequential model contains a layer which is not a Lipschitz layer: MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Sequential model contains a layer which is not a Lipschitz layer: ReLU(inplace=True)
Sequential model contains a layer which is not a Lipschitz layer: MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Sequential model contains a layer which is not a Lipschitz layer: ReLU(inplace=True)
Sequential model contains a layer which is not a Lipschitz layer: ReLU(inplace=True)
Sequential model contains a layer which is not a Lipschitz layer: Flatten(start_dim=1, end_dim=-1)
Files already downloaded and verified
Load data for evaluation: 100%|████████████████████| 10/10 [00:02<00:00,  4.78it/s]
0.2002

To Reproduce

#!/usr/bin/env python3
from collections import OrderedDict
from deel.torchlip import Sequential, SpectralConv2d, SpectralLinear
import sklearn.metrics
import torch
from torch.nn import Flatten, MaxPool2d, ReLU
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm


# config
seed = 42
batch_size = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# determinism
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# model
model = Sequential(OrderedDict([
    ("features", Sequential(
        SpectralConv2d(3, 6, 5),
        ReLU(True),
        MaxPool2d(2, 2),
        SpectralConv2d(6, 16, 5),
        ReLU(True),
        MaxPool2d(2, 2)
    )),
    ("flatten", Flatten()),
    ("classifier", Sequential(
        SpectralLinear(16 * 5 * 5, 120),
        ReLU(True),
        SpectralLinear(120, 84),
        ReLU(True),
        SpectralLinear(84, 10)
    ))
]))
state_dict = torch.load("state_dict.pt")
model.load_state_dict(state_dict)
#model = model.vanilla_export()  # <---- change this line
model.to(device)
model.eval()

# data
testset = datasets.CIFAR10("data/raw", train=False, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]))
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, pin_memory=True)
labels = ["plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

# predict test data
targets = []
outputs = []
for _inputs, _targets in tqdm(testloader, position=0, leave=True, desc="Load data for evaluation"):
    targets.append(_targets.detach())
    _inputs, _targets = _inputs.to(device, non_blocking=True), _targets.to(device, non_blocking=True)
    outputs.append(model(_inputs).detach().cpu().argmax(1))
targets = torch.cat(targets).numpy()
outputs = torch.cat(outputs).numpy()

# accuracy
acc = sklearn.metrics.accuracy_score(targets, outputs)
print(acc)
HeinrichAD added the bug label on Jul 6, 2022
@HeinrichAD
Author

Used state_dict.pt: state_dict.zip

Note: the file was only renamed due to GitHub upload restrictions, so mv state_dict.zip state_dict.pt is sufficient.

@franckma31
Collaborator

Thanks for reporting this bug and sharing the code. Indeed, both networks should have the same accuracy. We will check and give you feedback soon.

@franckma31
Collaborator

@HeinrichAD,
Exporting the network just after loading the state dict may lead to such errors. In fact, vanilla_export has to be done at the end of the training phase and before saving the weights. Here is an example of how to use it:

#!/usr/bin/env python3
from collections import OrderedDict
from deel.torchlip import Sequential, SpectralConv2d, SpectralLinear
import deel.torchlip as torchlip
import sklearn.metrics
import torch
from torch.nn import Flatten, MaxPool2d, ReLU, Conv2d, Linear
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm


features = {}
def get_features(name):
    def hook(model, input, output):
        features[name] = output.detach()
    return hook

# config
seed = 42
batch_size = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# determinism
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# model
model = Sequential(OrderedDict([
    ("features", Sequential(
        SpectralConv2d(3, 6, 5),
        ReLU(True),
        MaxPool2d(2, 2),
        SpectralConv2d(6, 16, 5),
        ReLU(True),
        MaxPool2d(2, 2)
    )),
    ("flatten", Flatten()),
    ("classifier", Sequential(
        SpectralLinear(16 * 5 * 5, 120),
        ReLU(True),
        SpectralLinear(120, 84),
        ReLU(True),
        SpectralLinear(84, 10)
    ))
]))
model.to(device)


# data
trainset = datasets.CIFAR10("data", train=True, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]))
testset = datasets.CIFAR10("data", train=False, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]))
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, pin_memory=True)
labels = ["plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]


##### Training the model
optimizer = torch.optim.Adam(lr=0.001, params=model.parameters())
hkr_loss = torchlip.HKRMulticlassLoss(alpha=100, min_margin=0.25)
epochs = 2

# loss parameters
min_margin = 1
alpha = 10

for epoch in range(epochs):
    m_kr, m_hm, m_acc = 0, 0, 0
    model.train()

    for step, (data, target) in enumerate(trainloader):
        target = torch.nn.functional.one_hot(target, num_classes=10)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = hkr_loss(output, target)
        loss.backward()
        optimizer.step()

        # Compute metrics on batch
        m_kr += torchlip.functional.kr_multiclass_loss(output, target)
        m_acc += (output.argmax(dim=1) == target.argmax(dim=1)).sum() / len(target)


    # Train metrics for the current epoch
    metrics = [
        f"{k}: {v:.04f}"
        for k, v in {
            "loss": loss,
            "KR": m_kr / (step + 1),
            "acc": m_acc / (step + 1),
        }.items()
    ]

    # Compute test loss for the current epoch
    model.eval()
    testC = []
    acc_t = 0
    cnt_samples = 0
    for data, target in testloader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        acc_t += (output.argmax(dim=1) == target).sum()
        cnt_samples += len(target)
    # Validation metrics for the current epoch
    metrics += [
        f"val_{k}: {v:.04f}"
        for k, v in {
            #"loss": hkr_loss(
            #    testo, test.tensors[1], alpha=alpha, min_margin=min_margin
            #),
            #"KR": kr_loss(testo.flatten(), test.tensors[1], (1, -1)),
            "acc": acc_t/cnt_samples
        }.items()
    ]

    print(f"Epoch {epoch + 1}/{epochs}")
    print(" - ".join(metrics))


### Export before saving
model_v = model.vanilla_export()
print(model_v)

## save model after export        
torch.save(model_v.state_dict(), "test_save.pt")

## not needed for saving: only here to also check the accuracy after export
model_v.to(device)

# model after export without SpectralConv2d and SpectralLinear. With the upcoming fix the block names will be kept; here I was on the master version and had to use the keys 0, 1 and 2

model_loaded = Sequential(OrderedDict([
    ("0", Sequential(
        Conv2d(3, 6, 5),
        ReLU(True),
        MaxPool2d(2, 2),
        Conv2d(6, 16, 5),
        ReLU(True),
        MaxPool2d(2, 2)
    )),
    ("1", Flatten()),
    ("2", Sequential(
        Linear(16 * 5 * 5, 120),
        ReLU(True),
        Linear(120, 84),
        ReLU(True),
        Linear(84, 10)
    ))
]))
state_dict = torch.load("test_save.pt")
model_loaded.load_state_dict(state_dict)

model_loaded.to(device)

# predict test data
targets = []
outputs = []
outputs_v = []
outputs_l = []
for _inputs, _targets in tqdm(testloader, position=0, leave=True, desc="Load data for evaluation"):
    targets.append(_targets.detach())
    _inputs, _targets = _inputs.to(device, non_blocking=True), _targets.to(device, non_blocking=True)
    outputs.append(model(_inputs).detach().cpu().argmax(1))
    outputs_v.append(model_v(_inputs).detach().cpu().argmax(1))
    outputs_l.append(model_loaded(_inputs).detach().cpu().argmax(1))
targets = torch.cat(targets).numpy()
outputs = torch.cat(outputs).numpy()
outputs_v = torch.cat(outputs_v).numpy()
outputs_l = torch.cat(outputs_l).numpy()

# accuracy
acc = sklearn.metrics.accuracy_score(targets, outputs)
# accuracy
acc_v = sklearn.metrics.accuracy_score(targets, outputs_v)

# accuracy
acc_l = sklearn.metrics.accuracy_score(targets, outputs_l)

print("ref accuracy ",acc)
print("vanilla_export accuracy ",acc_v)
print("loaded accuracy ",acc_l)

We will at least document this usage in the full torchlip documentation.
We will also try to find out why exporting after loading fails.

Hope this helps you continue your tests with torchlip. Thanks for your help.

@HeinrichAD
Author

@franckma31, thank you for your reply. This workaround works for me. Just to keep in mind: transfer learning isn't possible with this solution, is it? As far as I understand, transfer learning generally isn't possible because it breaks the 1-Lipschitz property, but if the loaded state were already 1-Lipschitz, e.g. if it was trained via torchlip, it should be possible in theory.

I don't know if you want to keep this issue open for further investigations and documentation but feel free to close it.

@franckma31
Collaborator

@HeinrichAD, indeed, we are not able to save, load and restart training. But for transfer learning, it is still possible to vanilla_export, save, load and freeze the feature extractor, while learning a new Lipschitz head with torchlip, as sketched below.
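Here is a minimal sketch of that idea, assuming the vanilla-exported feature extractor was saved separately to a hypothetical vanilla_features.pt file (layer sizes follow the example above):

import torch
import torch.nn as nn
from deel.torchlip import Sequential, SpectralLinear

# Rebuild the vanilla (exported) feature extractor and load its weights.
# "vanilla_features.pt" is a hypothetical file holding the state_dict of the
# exported feature extractor.
features = nn.Sequential(
    nn.Conv2d(3, 6, 5), nn.ReLU(True), nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 5), nn.ReLU(True), nn.MaxPool2d(2, 2),
)
features.load_state_dict(torch.load("vanilla_features.pt"))

# Freeze the feature extractor: its weights are no longer updated.
features.eval()
for p in features.parameters():
    p.requires_grad = False

# New 1-Lipschitz head, trained with torchlip on top of the frozen features.
head = Sequential(
    SpectralLinear(16 * 5 * 5, 120), nn.ReLU(True),
    SpectralLinear(120, 84), nn.ReLU(True),
    SpectralLinear(84, 10),
)
optimizer = torch.optim.Adam(head.parameters(), lr=0.001)

# In the training loop, only the head receives gradients:
#   out = head(torch.flatten(features(data), 1))
#   loss = hkr_loss(out, target); loss.backward(); optimizer.step()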

I will change the label to documentation so that we add documentation on vanilla_export.
Thanks for your help.

franckma31 added the documentation label and removed the bug label on Jul 18, 2022
@HeinrichAD
Author

@franckma31 thank you as well.

Most likely the lowest priority, but I think it would also be great to add a little transfer learning example to the documentation.

@cofri
Collaborator

cofri commented Jul 18, 2022

Hi @HeinrichAD,
It seems that a forward pass is required to activate internal hooks in the Lipschitz layers. When loading a Lipschitz model, adding a forward pass before the vanilla export should fix your problem. Could you try it and give us feedback?

# Load Lipschitz model
state_dict = torch.load("state_dict.pt")
model.load_state_dict(state_dict)
model.eval()
model.to(device)

# Forward with any input (an image or even a dummy input) to activate pre_forward hooks
x = ...
model(x)

# Vanilla export
model = model.vanilla_export()

@HeinrichAD
Author

I can confirm that if I add model(torch.zeros(1, 3, 32, 32)) before model = model.vanilla_export(), the expected accuracy of 0.4959 is obtained.
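For reference, the relevant section of the reproduction script above then becomes (a minimal sketch; the dummy forward pass runs before the model is moved to the device):

# load the trained Lipschitz weights
state_dict = torch.load("state_dict.pt")
model.load_state_dict(state_dict)

# dummy forward pass to activate the pre_forward hooks, then export
model(torch.zeros(1, 3, 32, 32))
model = model.vanilla_export()

model.to(device)
model.eval()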

@cofri
Collaborator

cofri commented Jul 20, 2022

Thanks for your confirmation. This solution is only a workaround; we are currently working on a long-term fix that does not require any extra step from the user. We will let you know about our progress.
