In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision import datasets
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim



# Three groups of five original CIFAR-100 class ids. The groups are laid out
# one after another, so group g occupies the contiguous new labels
# 5*g .. 5*g + 4.
all_class = [
    [0, 10, 20, 30, 40],
    [1, 11, 21, 31, 41],
    [2, 12, 22, 32, 42],
]

# Flatten the groups in order; the position of a class id in this list is its
# new label.
selected_classes = [class_id for group in all_class for class_id in group]
label_mapping = {
    class_id: new_label for new_label, class_id in enumerate(selected_classes)
}

print(label_mapping)



# Evaluation pipeline: tensor conversion followed by per-channel
# normalisation only (no augmentation).
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Training pipeline: random 32-pixel crop after 4 pixels of padding and a
# random horizontal flip, then the same tensor conversion and normalisation
# as the evaluation pipeline.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

def _select_and_remap(dataset, mapping):
    """Pick the samples whose label is a key of ``mapping`` and relabel them.

    Labels are read from ``dataset.targets`` rather than by iterating the
    dataset itself, so no image is decoded or transformed just to inspect its
    label.

    Args:
        dataset: dataset object exposing a mutable ``targets`` list.
        mapping: dict from original label to new label.

    Returns:
        List of indices (in dataset order) of the selected samples. As a side
        effect ``dataset.targets`` is rewritten in place at those indices;
        all other entries keep their original label.
    """
    indices = [idx for idx, label in enumerate(dataset.targets) if label in mapping]
    # Each selected index is visited exactly once, so a remapped label is
    # never fed through the mapping a second time.
    for idx in indices:
        dataset.targets[idx] = mapping[dataset.targets[idx]]
    return indices


bt_size=64

train_set = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=transform_train
)
selected_indices_train = _select_and_remap(train_set, label_mapping)

# Restrict the training set to the selected (now relabelled) samples.
filtered_train_set = Subset(train_set, selected_indices_train)
train_loader = DataLoader(filtered_train_set, batch_size=bt_size, shuffle=True, num_workers=4)

# download=False: the test split is expected to be on disk already (the
# training split above is requested with download=True).
test_set = torchvision.datasets.CIFAR100(
    root='./data', train=False, download=False, transform=transform_test
)
selected_indices_test = _select_and_remap(test_set, label_mapping)

filtered_test_set = Subset(test_set, selected_indices_test)
test_loader = DataLoader(filtered_test_set, batch_size=bt_size, shuffle=False, num_workers=2)



{0: 0, 10: 1, 20: 2, 30: 3, 40: 4, 1: 5, 11: 6, 21: 7, 31: 8, 41: 9, 2: 10, 12: 11, 22: 12, 32: 13, 42: 14}


In [None]:
class CNN_0(nn.Module):
    """Four-stage convolutional classifier.

    Every stage is a 3x3 convolution (stride 1, padding 1, so the spatial
    size is kept) followed by batch norm, ReLU and a 2x2 max-pool that halves
    height and width. Channel widths go 3 -> 64 -> 128 -> 256 -> 512.

    ``fc1`` expects ``512 * 2 * 2`` flattened features, i.e. a 2x2 map after
    the four pools, which is what a 32x32 input yields
    (32 -> 16 -> 8 -> 4 -> 2).

    Args:
        num_classes: number of output logits (default 5).
    """

    def __init__(self, num_classes=5):
        super().__init__()
        # Attribute names and creation order are kept as-is: the names are
        # the state_dict keys used by saved checkpoints.
        self.conv0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.bn0 = nn.BatchNorm2d(64)

        self.conv1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(128)

        self.conv2 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(256)

        self.conv3 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(512)

        # One parameter-free pooling layer is shared by all four stages.
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(512 * 2 * 2, 256)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return the ``(batch, num_classes)`` logits for an image batch ``x``."""
        stages = (
            (self.conv0, self.bn0),
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, bn in stages:
            # conv -> batch norm -> ReLU -> 2x2 max-pool
            x = self.pool(F.relu(bn(conv(x))))

        # Flatten everything except the batch dimension for the dense head.
        flat = x.view(x.shape[0], -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def create_classifier_list(file_list,device):
    """Load one ``CNN_0`` classifier per checkpoint file.

    Args:
        file_list: iterable of paths to ``state_dict`` checkpoints for
            ``CNN_0`` (built with its default arguments).
        device: device the checkpoint tensors are mapped to and the models
            are moved to.

    Returns:
        List of ``CNN_0`` models in the order of ``file_list``, each on
        ``device`` and switched to eval mode.
    """
    result=[]
    for f in file_list:
        model = CNN_0()
        # map_location remaps the stored tensors onto `device`, so a
        # checkpoint saved from a GPU can still be loaded when this notebook
        # falls back to the CPU.
        model.load_state_dict(torch.load(f, map_location=device))
        model.to(device)
        model.eval()
        result.append(model)
    return result
# NOTE(review): absolute, machine-specific checkpoint paths; consider making
# the root directory configurable.
classifier_file_list=[
    '/home/chunjielu/CIFAR100classifier/model1/model_weights_CNN_0.pth',
    '/home/chunjielu/CIFAR100classifier/model2/model_weights_CNN_0.pth',
    '/home/chunjielu/CIFAR100classifier/model3/model_weights_CNN_0_41.pth'

]
classifier_list=create_classifier_list(classifier_file_list,device)

In [None]:
class MetaModel(nn.Module):
    """Three-layer fully connected network used as the meta classifier.

    Layout: Linear(input_dim, hidden_dim1) -> ReLU -> Dropout(0.3)
    -> Linear(hidden_dim1, hidden_dim2) -> ReLU -> Linear(hidden_dim2, num_classes).

    Args:
        input_dim: size of each input feature vector (default 15).
        hidden_dim1: width of the first hidden layer (default 128).
        hidden_dim2: width of the second hidden layer (default 64).
        num_classes: number of output logits (default 15).
    """

    def __init__(self, input_dim=15, hidden_dim1=128, hidden_dim2=64, num_classes=15):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return the ``(batch, num_classes)`` logits for feature batch ``x``."""
        # Dropout is applied only after the first hidden layer.
        first_hidden = self.dropout(F.relu(self.fc1(x)))
        second_hidden = F.relu(self.fc2(first_hidden))
        return self.fc3(second_hidden)

def get_meta_data(input,classifier_list):
    """Build the meta-model input from the base classifiers' outputs.

    Every model in ``classifier_list`` is called on ``input`` with gradient
    tracking disabled, and the individual outputs are concatenated along
    dimension 1 in list order.

    Args:
        input: batch passed unchanged to each classifier.
        classifier_list: sequence of callables (models) to evaluate.

    Returns:
        Tensor holding all classifier outputs side by side along dim 1.
    """
    with torch.no_grad():
        per_model_outputs = [classifier(input) for classifier in classifier_list]
        return torch.cat(per_model_outputs, dim=1)

def train(model,optimizer,criterion,epoch=10,device='cpu'):
    """Train ``model`` on the outputs of the base classifiers.

    Batches come from the module-level ``train_loader``; each batch is turned
    into meta features with ``get_meta_data`` and the module-level
    ``classifier_list``. The learning rate follows a cosine annealing
    schedule stepped once per epoch with ``T_max=epoch``.

    Every third epoch (0, 3, 6, ...) the summed training loss of that epoch
    is printed and ``test`` is run on the module-level ``test_loader``. This
    function itself does not switch the model out of training mode before
    that call.

    Args:
        model: meta model to optimise.
        optimizer: optimizer bound to ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, target)``.
        epoch: number of epochs to run (default 10).
        device: device each batch is moved to (default ``'cpu'``).
    """
    lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch)
    for epoch_idx in range(epoch):
        model.train()
        epoch_loss = 0.0
        for inputs, target in train_loader:
            inputs = inputs.to(device)
            target = target.to(device)
            optimizer.zero_grad()

            meta_features = get_meta_data(inputs, classifier_list)
            loss = criterion(model(meta_features), target)
            loss.backward()
            optimizer.step()

            # Sum (not mean) of the per-batch losses over the epoch.
            epoch_loss += loss.item()

        report_now = epoch_idx % 3 == 0
        if report_now:
            print('epoch: %d loss:%.3f ' % (epoch_idx, epoch_loss), end=' ')
            test(test_loader, model, criterion, device=device)
        lr_schedule.step()


def test(test_loader,model,criterion,device='cpu'):
    correct = 0
    total = 0
    loss_all=0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device) , labels.to(device)
            
            outputs = model(get_meta_data(images,classifier_list))
            loss_all+=criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, dim=1)
            
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Accuracy on test set: %.3f %%' % (100 * correct / total),end=' ')   
    print('loss on test set: {:.3f} '.format(loss_all))
        



In [10]:
# Seed PyTorch's global RNG before the model is created so that weight
# initialisation and the other draws taken from that generator are repeatable
# across runs. NOTE(review): GPU kernels may still be nondeterministic.
torch.manual_seed(0)

model = MetaModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# 30 epochs; progress is printed every third epoch by train().
train(model,optimizer,criterion,epoch=30,device=device)

epoch: 0 loss:87.814  Accuracy on test set: 74.333 % loss on test set: 22.012 
epoch: 3 loss:67.534  Accuracy on test set: 74.400 % loss on test set: 24.023 
epoch: 6 loss:64.223  Accuracy on test set: 74.800 % loss on test set: 23.094 
epoch: 9 loss:61.866  Accuracy on test set: 75.467 % loss on test set: 23.276 
epoch: 12 loss:58.313  Accuracy on test set: 75.733 % loss on test set: 22.924 
epoch: 15 loss:56.334  Accuracy on test set: 75.267 % loss on test set: 22.696 
epoch: 18 loss:53.723  Accuracy on test set: 77.400 % loss on test set: 22.242 
epoch: 21 loss:49.044  Accuracy on test set: 76.800 % loss on test set: 23.309 
epoch: 24 loss:45.504  Accuracy on test set: 77.333 % loss on test set: 23.000 
epoch: 27 loss:44.860  Accuracy on test set: 76.733 % loss on test set: 23.967 


In [12]:
# Final evaluation of the trained meta model, with the model explicitly in
# eval mode before the test set is scored.
model.eval()
test(test_loader=test_loader, model=model, criterion=criterion, device=device)

Accuracy on test set: 77.667 % loss on test set: 21.908 


In [None]:
#torch.save(model, "./meta/meta_model_1.pth")