In [1]:
import torch
import numpy as np
import os
import torch.nn.functional as F
from torch.utils.data import Dataset

In [2]:
print(os.getcwd())

E:\python\2024.12.01\ContraVis_next


In [3]:
from torch import nn
from models.mlp import MLP
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
origin_predictor = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10)  # CIFAR-10 共有 10 类
).to(device)
noisy_predictor = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10)  # CIFAR-10 共有 10 类
).to(device)
optimizer_origin = torch.optim.Adam(origin_predictor.parameters(), lr=1e-3)
optimizer_noisy = torch.optim.Adam(noisy_predictor.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

In [4]:
class FeatureDataset(Dataset):
    def __init__(self,features_np,labels_np):
        self.features_np = features_np
        self.labels_np = labels_np
        assert self.features_np.shape[0] == self.labels_np.shape[0],"features_np and labels_np must have same shape"
    def __len__(self):
        return len(self.labels_np)
    def __getitem__(self, idx):
        return self.features_np[idx],self.labels_np[idx]

In [5]:
pattern_type = "stripes"
intensity = 0.3
raw_features = np.load(r"E:\python\2024.12.01\ContraVis_next\dataset\cifar_resnet_none\features.npy")
labels = np.load(r"E:\python\2024.12.01\ContraVis_next\dataset\cifar_resnet_none\labels.npy")
noisy_features = np.load(rf"E:\python\2024.12.01\ContraVis_next\dataset\cifar_resnet_{pattern_type}_{intensity}\features.npy")

origin_dataset_train = FeatureDataset(raw_features,labels)
noisy_dataset_train = FeatureDataset(noisy_features,labels)



origin_train_dataloader = torch.utils.data.DataLoader(origin_dataset_train,batch_size=32,shuffle=False)
noisy_train_dataloader = torch.utils.data.DataLoader(noisy_dataset_train,batch_size=32,shuffle=False)

In [6]:
model_save_dir = "model_weights/cifar_resnet_classifier"
if not os.path.exists(model_save_dir):
    os.mkdir(model_save_dir)

## Train Origin Predictor

In [7]:
# from tqdm import tqdm
# 
# origin_epochs = 12
# for epoch in tqdm(range(origin_epochs)):
#     origin_predictor.train()
#     origin_loss_epoch = 0.0
#     for batch in origin_train_dataloader:
#         optimizer_origin.zero_grad()
#         feature_origin,label_origin = batch
#         feature_origin = feature_origin.to(device)
#         label_origin = label_origin.to(device)
#         pred_origin = origin_predictor(feature_origin)
#         loss = criterion(pred_origin,label_origin)
#         origin_loss_epoch+=loss.item()
#         loss.backward()
#         optimizer_origin.step()
#         
#     print(f"epoch:{epoch} loss:{origin_loss_epoch:.3f}")
#     if epoch % 2 == 0:
#         origin_predictor.eval()
#         model_weight = origin_predictor.state_dict()
#         with torch.no_grad():
#             total_origin = 0
#             correct_origin = 0
#             for batch in origin_train_dataloader:
#                 feature_origin,label_origin = batch
#                 feature_origin = feature_origin.to(device)
#                 label_origin = label_origin.to(device)
#                 pred_origin = origin_predictor(feature_origin)
#                 _,pred_category = torch.max(pred_origin,dim=1)
#                 total_origin += label_origin.size(0)
#                 correct_origin += (pred_category == label_origin).sum().item()
#                 accuracy_origin = 100*correct_origin / total_origin
#             print(f"epoch:{epoch} Test Accuracy: {accuracy_origin:.2f}%")
#             save_path = os.path.join(model_save_dir,f"origin_mlp_{epoch}_{loss:.3f}_{accuracy_origin}.pth")
#             torch.save(model_weight,save_path)

In [8]:
# origin_predictor.load_state_dict(torch.load(os.path.join(model_save_dir,r"origin_mlp_10_0.405_85.884.pth")))
# pred_all_origin = []
# for batch in origin_train_dataloader:
#     feature_origin,label_origin = batch
#     feature_origin = feature_origin.to(device)
#     label_origin = label_origin.to(device)
#     pred_origin = origin_predictor(feature_origin)
#     _,pred_category = torch.max(pred_origin,dim=1)
#     pred_all_origin.append(pred_category.cpu().numpy())
# pred_all_origin_np = np.concatenate(pred_all_origin)
# pred_all_origin_save_path = os.path.join("dataset/cifar_resnet_none","predictions.npy")
# np.save(pred_all_origin_save_path,pred_all_origin_np)

## Train Noisy Predictor

In [9]:
from tqdm import tqdm

noisy_epochs = 12
for epoch in tqdm(range(noisy_epochs)):
    noisy_predictor.train()
    noisy_loss_epoch = 0.0
    for batch in noisy_train_dataloader:
        optimizer_noisy.zero_grad()
        feature_noisy,label_noisy = batch
        feature_noisy = feature_noisy.to(device)
        label_noisy = label_noisy.to(device)
        pred_noisy = noisy_predictor(feature_noisy)
        loss = criterion(pred_noisy,label_noisy)
        noisy_loss_epoch+=loss.item()
        loss.backward()
        optimizer_noisy.step()
        
    print(f"epoch:{epoch} loss:{noisy_loss_epoch:.3f}")
    if epoch % 2 == 0:
        noisy_predictor.eval()
        model_weight = noisy_predictor.state_dict()
        with torch.no_grad():
            total_noisy = 0
            correct_noisy = 0
            for batch in noisy_train_dataloader:
                feature_noisy,label_noisy = batch
                feature_noisy = feature_noisy.to(device)
                label_noisy = label_noisy.to(device)
                pred_noisy = noisy_predictor(feature_noisy)
                _,pred_category = torch.max(pred_noisy,dim=1)
                total_noisy += label_noisy.size(0)
                correct_noisy += (pred_category == label_noisy).sum().item()
                accuracy_noisy = 100 * correct_noisy / total_noisy
            print(f"epoch:{epoch} Test Accuracy: {accuracy_noisy:.2f}%")
            save_path = os.path.join(model_save_dir,f"{pattern_type}_{intensity}_mlp_{epoch}_{loss:.3f}_{accuracy_noisy}.pth")
            torch.save(model_weight,save_path)

  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

epoch:0 loss:1201.339


  8%|██████▉                                                                            | 1/12 [00:06<01:08,  6.25s/it]

epoch:0 Test Accuracy: 76.25%


 17%|█████████████▊                                                                     | 2/12 [00:10<00:53,  5.33s/it]

epoch:1 loss:991.041
epoch:2 loss:921.454


 25%|████████████████████▊                                                              | 3/12 [00:17<00:52,  5.83s/it]

epoch:2 Test Accuracy: 79.84%


 33%|███████████████████████████▋                                                       | 4/12 [00:22<00:43,  5.38s/it]

epoch:3 loss:860.115
epoch:4 loss:804.285


 42%|██████████████████████████████████▌                                                | 5/12 [00:29<00:44,  6.29s/it]

epoch:4 Test Accuracy: 82.83%


 50%|█████████████████████████████████████████▌                                         | 6/12 [00:34<00:34,  5.82s/it]

epoch:5 loss:752.155
epoch:6 loss:703.486


 58%|████████████████████████████████████████████████▍                                  | 7/12 [00:42<00:31,  6.28s/it]

epoch:6 Test Accuracy: 83.88%


 67%|███████████████████████████████████████████████████████▎                           | 8/12 [00:46<00:23,  5.83s/it]

epoch:7 loss:662.040
epoch:8 loss:622.253


 75%|██████████████████████████████████████████████████████████████▎                    | 9/12 [00:53<00:18,  6.11s/it]

epoch:8 Test Accuracy: 83.81%


 83%|████████████████████████████████████████████████████████████████████▎             | 10/12 [00:58<00:11,  5.60s/it]

epoch:9 loss:581.505
epoch:10 loss:544.995


 92%|███████████████████████████████████████████████████████████████████████████▏      | 11/12 [01:05<00:06,  6.04s/it]

epoch:10 Test Accuracy: 83.34%


100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [01:09<00:00,  5.83s/it]

epoch:11 loss:513.890





In [10]:
# noisy_predictor.load_state_dict(torch.load(os.path.join(model_save_dir,r"stripes_0.2_mlp_8_0.224_84.516.pth")))
noisy_predictor.load_state_dict(torch.load(save_path))
pred_all_noisy = []
for batch in noisy_train_dataloader:
    feature_noisy,label_noisy = batch
    feature_noisy = feature_noisy.to(device)
    label_noisy = label_noisy.to(device)
    pred_noisy = noisy_predictor(feature_noisy)
    _,pred_category = torch.max(pred_noisy,dim=1)
    pred_all_noisy.append(pred_category.cpu().numpy())
pred_all_noisy_np = np.concatenate(pred_all_noisy)
pred_all_noisy_save_path = os.path.join(f"dataset/cifar_resnet_{pattern_type}_{intensity}","predictions.npy")
np.save(pred_all_noisy_save_path,pred_all_noisy_np)