# EWE分析

## 1. 导入必要的库和类

In [28]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from models.conv import ConvModel

import os
import pickle
import numpy as np
from tqdm import tqdm

In [29]:
from utils.utils import setup_seed

## 2. 运行设置

In [30]:
# Global run configuration shared by the cells below.
seed = 44
dataset, model_type = "mnist", '2_conv'
shuffle = True  # forwarded to the DataLoader built in the inference cell

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed via the project helper (presumably seeds the RNGs for reproducibility;
# see utils.utils.setup_seed for what exactly is seeded).
setup_seed(seed)

In [31]:
# Per-dataset hyperparameters.
#
# Sets: batch_size, epochs, source, target, num_classes, channels (and, for
# mnist, forces model_type to '2_conv').
# `source` / `target` are presumably the watermark source and target class
# labels used by EWE -- TODO confirm against the training script.
#
# Raises NotImplementedError for an unknown dataset or for a dataset/model
# combination that has no hyperparameters defined here.
if dataset == 'mnist':
    model_type = '2_conv'
    batch_size = 512
    epochs = 10
    source = 1
    target = 7
    num_classes = 10
    channels = 1
elif dataset == 'fashion':
    num_classes = 10
    channels = 1
    if model_type == '2_conv':
        batch_size = 128
        epochs = 10
        source = 8
        target = 0
    else:
        # Fail here instead of leaving batch_size/epochs/source/target
        # undefined and hitting a NameError in a later cell.
        raise NotImplementedError('Model is not implemented.')
else:
    raise NotImplementedError('Dataset is not implemented.')

## 3. 加载模型

In [32]:
# Build three identically configured ConvModel instances and restore their
# trained weights from disk: ewe_model, attack_model and clean_model.
# Raises NotImplementedError for any model_type other than '2_conv'.
if model_type == '2_conv':
    ewe_model = ConvModel(num_classes=num_classes, batch_size=batch_size, in_channels=channels, device=device)
    attack_model = ConvModel(num_classes=num_classes, batch_size=batch_size, in_channels=channels, device=device)
    clean_model = ConvModel(num_classes=num_classes, batch_size=batch_size, in_channels=channels, device=device)

    # NOTE(review): this directory is hard-coded to the mnist checkpoints, so
    # running with dataset == 'fashion' would still load mnist weights -- confirm
    # whether a per-dataset directory is intended.
    save_dir = "/home/mlsnrs/data/data/21ss/24-xiyuan/XiYuan-WM-Reproduction/EWE/trained_model/mnist"

    # map_location='cpu' lets checkpoints that were saved from a CUDA device be
    # deserialized on a CPU-only machine; load_state_dict then copies the values
    # into the model's existing parameters wherever they live.
    ewe_model.load_state_dict(torch.load(f'{save_dir}/ewe_model.pth', map_location='cpu'))
    attack_model.load_state_dict(torch.load(f'{save_dir}/attack_model.pth', map_location='cpu'))
    clean_model.load_state_dict(torch.load(f'{save_dir}/clean_model.pth', map_location='cpu'))
else:
    raise NotImplementedError('Model is not implemented.')

## 4. 加载数据集

In [33]:
# Load the pickled train/test split for the selected dataset.
# Only 'mnist' and 'fashion' are supported; anything else is rejected up front.
if not (dataset == 'mnist' or dataset == 'fashion'):
    raise NotImplementedError('Dataset is not implemented.')

# The file is read from data/<dataset>.pkl relative to the working directory.
# pickle.load can execute arbitrary code, so only point this at trusted files.
with open(os.path.join("data", f"{dataset}.pkl"), 'rb') as f:
    mnist = pickle.load(f)

x_train = mnist["training_images"]
y_train = mnist["training_labels"]
x_test = mnist["test_images"]
y_test = mnist["test_labels"]

# Divide by 255 (presumably mapping 8-bit pixel values into [0, 1]) and
# reshape to (N, 1, 28, 28): one channel, 28x28 images.
x_train = np.reshape(x_train / 255, [-1, 1, 28, 28])
x_test = np.reshape(x_test / 255, [-1, 1, 28, 28])

## 5. 推理

In [37]:
# Run ewe_model over the whole test set and collect its per-sample outputs
# (as NumPy rows) in output_test.
x_test_tensor = torch.tensor(x_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)
test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
# NOTE(review): with shuffle=True the rows of output_test are in a random
# order and no longer line up with y_test (the batch labels are discarded
# below) -- confirm this is intended before comparing outputs to labels.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)

output_test = []
ewe_model.eval()
ewe_model.to(device)
# Inference only: disable autograd so no graph is built for each batch.
with torch.no_grad():
    for (x, _) in tqdm(test_loader, desc="inferencing~", ncols=100):
        x = x.to(device)
        # extend() appends one array per sample (iterates the batch dimension).
        output_test.extend(ewe_model(x).cpu().detach().numpy())

# Move the model back to the CPU and release cached GPU memory.
ewe_model.to('cpu')
torch.cuda.empty_cache()

inferencing~: 100%|████████████████████████████████████████████████| 20/20 [00:00<00:00, 147.26it/s]
