In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision import datasets
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim



all_class = [[0, 10, 20, 30, 40],
             [1, 11, 21, 31, 41],
             [2, 12, 22, 32, 42]]
bt_size = 32

# Flat list of every original CIFAR-100 label used in the experiment, in group order.
selected_classes = [orig_label for group in all_class for orig_label in group]

# Map each original CIFAR-100 label to the index of its group in `all_class`
# (e.g. 0, 10, 20, 30, 40 -> 0). The model is trained on these group indices.
label_mapping = {}
for label_idx, group in enumerate(all_class):
    for orig_label in group:
        label_mapping[orig_label] = label_idx

# label mapping
print(label_mapping)

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


def _filter_and_relabel(dataset, mapping):
    """Select the samples whose label is a key of `mapping` and relabel them.

    Labels are read from `dataset.targets` rather than by iterating `dataset`,
    which would load and transform every image just to obtain its label.

    Side effect: for the selected indices, `dataset.targets[idx]` is overwritten
    with `mapping[original_label]`. Calling this twice on the same dataset object
    would therefore select by already-mapped labels -- build a fresh dataset instead.

    Returns:
        (selected_indices, Subset(dataset, selected_indices))
    """
    selected = [idx for idx, label in enumerate(dataset.targets) if label in mapping]
    for idx in selected:
        dataset.targets[idx] = mapping[dataset.targets[idx]]
    return selected, Subset(dataset, selected)


train_set = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=False, transform=transform_train
)
test_set = torchvision.datasets.CIFAR100(
    root='./data', train=False, download=False, transform=transform_test
)
# Reuse train_set for the class names instead of constructing another
# CIFAR100 instance just to read `.classes`.
cifar100_classes = train_set.classes

# Create the subset datasets: only the selected classes, relabelled to group indices.
selected_indices_train, filtered_train_set = _filter_and_relabel(train_set, label_mapping)
selected_indices_test, filtered_test_set = _filter_and_relabel(test_set, label_mapping)

train_loader = DataLoader(filtered_train_set, batch_size=bt_size, shuffle=True, num_workers=2)
# Evaluation uses the same configured batch size as training.
test_loader = DataLoader(filtered_test_set, batch_size=bt_size, shuffle=False, num_workers=2)



{0: 0, 10: 0, 20: 0, 30: 0, 40: 0, 1: 1, 11: 1, 21: 1, 31: 1, 41: 1, 2: 2, 12: 2, 22: 2, 32: 2, 42: 2}


In [None]:
def train(model, optimizer, criterion, epoch=10, device='cpu', loader=None):
    """Fine-tune `model` for `epoch` epochs with a linearly decaying learning rate.

    Each batch is resized to 224x224 before the forward pass, and the loss is
    computed on `model(...).logits`.

    Args:
        model: module whose forward returns an object with a `.logits` attribute
            (here a Hugging Face image-classification model).
        optimizer: optimizer over `model`'s parameters. Its learning rate is
            decayed linearly from its initial value to 0 over `epoch` epochs.
        criterion: loss function called as `criterion(logits, targets)`.
        epoch: number of passes over the data.
        device: device each batch is moved to.
        loader: iterable of `(inputs, targets)` batches; defaults to the
            notebook-level `train_loader`.

    Prints the mean per-batch training loss after every epoch.
    """
    if loader is None:
        loader = train_loader

    # The schedule below ends at lr == 0, and LinearLR scales the *current* lr.
    # Without this reset, a second call with the same optimizer would start at 0
    # and silently stop learning. Restore the 'initial_lr' that PyTorch LR
    # schedulers record on each param group (absent on a fresh optimizer).
    for group in optimizer.param_groups:
        if 'initial_lr' in group:
            group['lr'] = group['initial_lr']

    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=epoch)
    for i in range(epoch):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, target in loader:
            inputs, target = inputs.to(device), target.to(device)
            optimizer.zero_grad()
            resized = torchvision.transforms.functional.resize(inputs, (224, 224))
            outputs = model(resized).logits
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
        scheduler.step()
        # Report the mean batch loss so the number is comparable across dataset sizes.
        print('epoch: %d loss:%.3f ' % (i, running_loss / max(num_batches, 1)))


import torch
from transformers import ViTForImageClassification
from transformers import AutoFeatureExtractor, SwinForImageClassification

model = SwinForImageClassification.from_pretrained("microsoft/swin-base-patch4-window7-224")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One output per class group defined in `all_class` (3 groups here).
num_classes = len(all_class)
# Replace the pretrained classification head with a fresh one for our groups.
model.classifier = torch.nn.Linear(model.config.hidden_size, num_classes)
# Keep the config and model attribute in sync with the new head, so that
# save_pretrained()/from_pretrained() round-trips with the right label count.
# The config setter also regenerates id2label/label2id for the new size.
model.config.num_labels = num_classes
model.num_labels = num_classes

print(model)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00004, betas=(0.9, 0.999), eps=1e-08)

train(model, optimizer, criterion, 5, device)



SwinForImageClassification(
  (swin): SwinModel(
    (embeddings): SwinEmbeddings(
      (patch_embeddings): SwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): SwinEncoder(
      (layers): ModuleList(
        (0): SwinStage(
          (blocks): ModuleList(
            (0-1): 2 x SwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): SwinAttention(
                (self): SwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): SwinSelfO

In [3]:
# Switch the model to evaluation mode (equivalent to model.eval()); this
# affects modules such as Dropout, which the printed Swin model contains.
model.train(False)
def test(test_loader,model,criterion,device='cpu'):
    correct = 0
    total = 0
    loss_all=0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device) , labels.to(device)

            filter_tensor = torchvision.transforms.functional.resize(images, (224, 224))
            outputs = model(filter_tensor).logits
            loss_all+=criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, dim=1)
            
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set: %.3f %%' % (100 * correct / total),end=' ')

# Evaluate on the filtered test split with a fresh cross-entropy criterion.
# The trailing ';' keeps Jupyter from echoing any return value.
test(
    test_loader=test_loader,
    model=model,
    criterion=nn.CrossEntropyLoss(),
    device=device,
);

Accuracy on test set: 97.733 % 

In [None]:
#model.save_pretrained("gating_network_swain/filter_model_2/")