## Predictions from the Origin and Noisy Predictors

In [176]:
import torch
import numpy as np
import torch.nn.functional as F
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [177]:
# Artifacts of the `none` pattern-type run (presumably the un-noised baseline — TODO confirm).
# The directory is defined once so the three loads cannot drift apart.
origin_dir = r"E:\python\2024.12.01\ContraVis_next\dataset\cifar_resnet_none"
origin_data = np.load(rf"{origin_dir}\features.npy")
labels = np.load(rf"{origin_dir}\labels.npy")
origin_pred = np.load(rf"{origin_dir}\predictions.npy")

In [178]:
pattern_type = "random"
intensity = 0.3
# Artifacts for one (pattern_type, intensity) setting live in a directory named after both values.
noisy_dir = rf"E:\python\2024.12.01\ContraVis_next\dataset\cifar_resnet_{pattern_type}_{intensity}"
noisy_data = np.load(rf"{noisy_dir}\features.npy")
noisy_pred = np.load(rf"{noisy_dir}\predictions.npy")
similarity_scores = np.load(rf"{noisy_dir}\similarity.npy")

In [179]:
from torch.utils.data import Dataset


class FeatureDataset(Dataset):
    """Map-style dataset pairing precomputed feature rows with their labels.

    Args:
        features_np: array indexed along its first axis, one row per sample.
        labels_np: array of per-sample labels; must have the same number of
            samples (first-axis length) as ``features_np``.
    """

    def __init__(self, features_np, labels_np):
        self.features_np = features_np
        self.labels_np = labels_np
        # Only the sample axis has to agree; the feature and label shapes otherwise differ.
        assert self.features_np.shape[0] == self.labels_np.shape[0], \
            "features_np and labels_np must have the same number of samples"

    def __len__(self):
        return len(self.labels_np)

    def __getitem__(self, idx):
        # Raw numpy rows are returned; DataLoader's default collate turns them into tensors.
        return self.features_np[idx], self.labels_np[idx]

### Find Indices and Count of Samples with Differing Predictions

In [180]:
similarity_threshold = 0.05
# A sample is flagged as "semantically different" when the origin and noisy predictors
# disagree AND its similarity score falls below the threshold.
pred_changed = origin_pred != noisy_pred
low_similarity = similarity_scores < similarity_threshold
different_semantics_mask = pred_changed & low_similarity
print("total different semantics mask:", different_semantics_mask.sum())

total different semantics mask: 13120


### Find Indices and Count of Differing Predictions After Transformation

In [181]:
from torch.optim.lr_scheduler import StepLR
from models.mlp import TimeMLP
from models.flow_matching import RectifiedFlow
import umap

# Seed every RNG *before* building the model so its random initialisation is reproducible.
# np.random.seed alone does not influence torch's parameter initialisation.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

fm_predictor = TimeMLP(512, [128, 64], 512, 128).to(device)
# NOTE(review): none of the cells shown here load trained weights into fm_predictor, so the
# inference cell below would run a randomly initialised flow model. Load the trained
# checkpoint (e.g. fm_predictor.load_state_dict(...)) before inference — TODO confirm path.
rf = RectifiedFlow()

In [182]:
class TransformationDataset(torch.utils.data.Dataset):
    """Paired dataset of reference/target feature rows with per-sample labels.

    Each item is ``(reference, target, label)`` as tensors. ``label`` has shape ``(1,)``
    because labels are stored as a column vector.

    Args:
        reference_np: numpy array, one row per sample.
        target_np: numpy array with exactly the same shape as ``reference_np``.
        labels: 1-D numpy array with one label per sample.
    """

    def __init__(self, reference_np, target_np, labels):
        assert reference_np.shape == target_np.shape, "reference_np and target_np have different shapes"
        assert len(labels) == len(reference_np), "labels and reference_np have different numbers of samples"
        self.reference_np = reference_np
        self.target_np = target_np
        # Trailing axis: indexing a 1-D array yields a numpy scalar, which torch.from_numpy
        # rejects; a (N, 1) array yields a 1-element ndarray instead.
        self.labels = labels[:, None]

    def __getitem__(self, index):
        reference = torch.from_numpy(self.reference_np[index])
        target = torch.from_numpy(self.target_np[index])
        label = torch.from_numpy(self.labels[index])
        return reference, target, label

    def __len__(self):
        return len(self.reference_np)

In [183]:
from models.flow_matching import fm_infer

# Inference only: switch off training-time behaviour (e.g. dropout) if the model has any.
fm_predictor.eval()
transformation_dataset_train = TransformationDataset(origin_data, noisy_data, labels)
# shuffle=False keeps the outputs aligned with the row order of labels / origin_pred.
transformation_loader_train = torch.utils.data.DataLoader(
    dataset=transformation_dataset_train, batch_size=512, shuffle=False
)
transformed_tar_all = []
# Mapping the noisy (target) features needs no gradients; skipping autograd saves memory/time.
with torch.no_grad():
    for _, tar_test, _ in transformation_loader_train:
        x_ref_pred = fm_infer(fm_predictor, tar_test.to(device))
        transformed_tar_all.append(x_ref_pred.detach().cpu().numpy())
transformed_tar_all_np = np.concatenate(transformed_tar_all)

In [184]:
import os
from torch import nn

model_save_dir = "model_weights/cifar_resnet_classifier"
origin_predictor = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10),  # CIFAR-10 has 10 classes
).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
origin_predictor.load_state_dict(
    torch.load(os.path.join(model_save_dir, r"origin_mlp_10_0.405_85.884.pth"), map_location=device)
)
origin_predictor.eval()
transformed_feature_dataset = FeatureDataset(transformed_tar_all_np, labels)
transformed_feature_dataloader = torch.utils.data.DataLoader(transformed_feature_dataset, batch_size=32, shuffle=False)

In [185]:
pred_all_transformed = []
total_correct = 0
origin_predictor.eval()
# Pure evaluation: no autograd graph is needed for the forward passes.
with torch.no_grad():
    for feature_transformed, label_transformed in transformed_feature_dataloader:
        feature_transformed = feature_transformed.to(device)
        label_transformed = label_transformed.to(device)
        pred_transformed = origin_predictor(feature_transformed)
        _, pred_category = torch.max(pred_transformed, dim=1)
        total_correct += (pred_category == label_transformed).sum().item()
        pred_all_transformed.append(pred_category.cpu().numpy())
pred_all_transformed_np = np.concatenate(pred_all_transformed)
print(f"accuracy:{100*total_correct/len(labels)}%")
# Number of "different semantics" samples whose prediction on the transformed features
# differs from origin_pred.
print(np.sum((pred_all_transformed_np != origin_pred) & different_semantics_mask))
np.save(rf"E:\python\2024.12.01\ContraVis_next\dataset\cifar_resnet_{pattern_type}_{intensity}\predictions_transformed.npy",pred_all_transformed_np)

accuracy:54.976%
9861


In [186]:
# "Similar" samples: score at or above the threshold, i.e. the exact complement of the
# `similarity_scores < similarity_threshold` test, so no sample falls in neither group.
similar_mask = similarity_scores >= similarity_threshold
similar_all_count = np.sum(similar_mask)
# Of those, how many keep the same prediction under the noisy predictor.
similar_preserved = np.sum((noisy_pred == origin_pred) & similar_mask)
preserved_ratio = similar_preserved / similar_all_count if similar_all_count else float("nan")
print(f"similar_preserved/similar_all_count, ratio:\n{similar_preserved}/{similar_all_count},{preserved_ratio:.3f}")

similar_all_count,similar_preserved:
5952/7629,0.780
