In [1]:
import torch
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torch import optim
import timm
import os
from torch.nn import functional as F
from PIL import Image
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = timm.create_model('resnet34', pretrained=True, num_classes=10)

In [3]:
class Classifier(torch.nn.Module):
    def __init__(self, input_dim, nb_classes, *args, **kwargs) -> None:
        super(Classifier, self).__init__(*args, **kwargs)
        self.fc = torch.nn.Linear(input_dim, nb_classes)
    def forward(self, x):
        return self.fc(x)
    

In [4]:
class Cifar100():
    def __init__(self):
        self.train_trsf = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.05, 1.0), ratio=(3. / 4., 4. / 3.)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])
        self.test_trsf = transforms.Compose([
            transforms.Resize(int((256 / 224) * 224), interpolation=3),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
        self.class_order = np.arange(100).tolist()
        self.train_dataset = datasets.cifar.CIFAR100("./data", train=True, download=True)
        self.test_dataset = datasets.cifar.CIFAR100("./data", train=False, download=True)
        self.train_data, self.train_targets = self.train_dataset.data, np.array(
            self.train_dataset.targets
        )
        self.test_data, self.test_targets = self.test_dataset.data, np.array(
            self.test_dataset.targets
        )

In [5]:
class Cifar10():
    def __init__(self):
        self.train_trsf = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.05, 1.0), ratio=(3. / 4., 4. / 3.)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor()
        ])
        self.test_trsf = transforms.Compose([
            transforms.Resize(int((256 / 224) * 224), interpolation=3),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
        self.train_dataset = datasets.cifar.CIFAR10("./data", train=True, download=True)
        self.test_dataset = datasets.cifar.CIFAR10("./data", train=False, download=True)
        self.train_data, self.train_targets = self.train_dataset.data, np.array(
            self.train_dataset.targets
        )
        self.test_data, self.test_targets = self.test_dataset.data, np.array(
            self.test_dataset.targets
        )

In [6]:
class DummyDataset(Dataset):
    def __init__(self, images, labels, trsf) -> None:
        super().__init__()
        self.images = images
        self.labels = labels
        self.trsf = trsf
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = self.trsf(Image.fromarray(self.images[idx]))
        label = self.labels[idx]
        return idx, image, label

In [7]:
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [8]:
dataset = Cifar10()
original_train_data, original_train_target = [], []
original_test_data, original_test_target = [], []
for i in range(10):
    train_index = np.where(np.logical_and(dataset.train_targets >= i, dataset.train_targets < i + 1))[0][:500]
    test_index = np.where(np.logical_and(dataset.test_targets >= i, dataset.test_targets < i + 1))[0][:100]
    original_train_data.append(dataset.train_data[train_index])
    original_train_target.append(dataset.train_targets[train_index])
    original_test_data.append(dataset.test_data[test_index])
    original_test_target.append(dataset.test_targets[test_index])
original_train_data, original_train_target = np.concatenate(original_train_data), np.concatenate(original_train_target)
original_test_data, original_test_target = np.concatenate(original_test_data), np.concatenate(original_test_target)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
# original_backbone.eval()
# for param in original_backbone.parameters():
#     param.requires_grad = False
    
# input_dim = original_backbone.head.in_features
# nb_classes = 10
# classifier = Classifier(input_dim, nb_classes)
# original_backbone.head = classifier

In [10]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [22]:
all_index = np.arange(500)
np.random.shuffle(all_index)
idxes = np.split(all_index, 5)
for i in range(len(idxes)):
    idxes[i] = np.random.choice(idxes[i], int(round(1.0 - i * 0.20, 1) * len(idxes[i])))

1.0
100
0.8
80
0.6
60
0.3999999999999999
40
0.19999999999999996
20


In [12]:
epochs = 5
model.cuda()
criterion.cuda()

test_dataset = DummyDataset(original_test_data, original_test_target, dataset.test_trsf)
test_dataloader = DataLoader(test_dataset, 8, True)

for i in range(5):
    train_data, train_target = [], []
    for cls in range(10):
        idx = np.where(np.logical_and(original_train_target >= cls,  original_train_target < cls + 1))[0]
        cls_data, cls_target = original_train_data[idx], original_train_target[idx]
        train_data.append(cls_data[idxes[i]])
        train_target.append(cls_target[idxes[i]])

    train_data, train_target = np.concatenate(train_data), np.concatenate(train_target)
    train_dataset = DummyDataset(train_data, train_target, dataset.train_trsf)
    train_dataloader = DataLoader(train_dataset, 8, True)
    print(len(train_target))
    
    prog_bar = tqdm(range(epochs))
    for _, e in enumerate(prog_bar):
        model.train()
        losses = 0.0
        
        for _, x, y in train_dataloader:
            x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y.long())
            loss.backward()
            optimizer.step()
            
            losses += loss.item()
            info = "Epoch {}/{} => Loss {:.3f}".format(
                            e + 1,
                            epochs,
                            losses / len(train_dataloader),
                        )
        prog_bar.set_description(info)
        
    model.eval()
    predictions = []  
    ground_truths = []  

    with torch.no_grad():  # 关闭梯度计算，以节省内存并加速测试过程  
        for idx, images, labels in test_dataloader:  
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)  # 通过模型传递输入数据并获取预测结果  
            _, predicted = torch.max(outputs.data, 1)  # 对预测结果进行argmax操作，获取预测的类别索引  
            predictions.extend(predicted.tolist())  # 将预测结果添加到预测列表中  
            ground_truths.extend(labels.tolist())  # 将实际标签添加到实际标签列表中  
    
    # 计算准确率作为评估指标  
    accuracy = np.mean(np.array(predictions) == np.array(ground_truths)) * 100  # 计算准确率（百分比）  
    print('Test Accuracy: {:.2f}%'.format(accuracy))  # 打印测试准确率结果


1000


Epoch 5/5 => Loss 1.922: 100%|██████████| 5/5 [00:32<00:00,  6.50s/it]


Test Accuracy: 32.40%
800


Epoch 5/5 => Loss 1.725: 100%|██████████| 5/5 [00:25<00:00,  5.16s/it]


Test Accuracy: 39.70%
600


Epoch 5/5 => Loss 1.777: 100%|██████████| 5/5 [00:19<00:00,  3.82s/it]


Test Accuracy: 42.10%
390


Epoch 5/5 => Loss 1.800: 100%|██████████| 5/5 [00:12<00:00,  2.55s/it]


Test Accuracy: 38.20%
190


Epoch 5/5 => Loss 1.587: 100%|██████████| 5/5 [00:06<00:00,  1.22s/it]


Test Accuracy: 39.80%


In [13]:
model.eval()
predictions = []  
ground_truths = []  
# test_dataset = DummyDataset(original_test_data, original_test_target, dataset.test_trsf)
# test_dataloader = DataLoader(test_dataset, 8, True)
with torch.no_grad():  # 关闭梯度计算，以节省内存并加速测试过程  
    for idx, images, labels in test_dataloader:  
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)  # 通过模型传递输入数据并获取预测结果  
        _, predicted = torch.max(outputs.data, 1)  # 对预测结果进行argmax操作，获取预测的类别索引  
        predictions.extend(predicted.tolist())  # 将预测结果添加到预测列表中  
        ground_truths.extend(labels.tolist())  # 将实际标签添加到实际标签列表中  
  
# 计算准确率作为评估指标  
accuracy = np.mean(np.array(predictions) == np.array(ground_truths)) * 100  # 计算准确率（百分比）  
print('Test Accuracy: {:.2f}%'.format(accuracy))  # 打印测试准确率结果


Test Accuracy: 39.80%
